社区讨论

AI 已经拿下本题

P14364[CSP-S 2025] 员工招聘参与者 156已保存回复 203

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
200 条
当前快照
1 份
快照标识符
@mi5d9f9y
此快照首次捕获于
2025/11/19 10:11
3 个月前
此快照最后确认于
2025/11/20 09:15
3 个月前
查看原帖
今天 0 时 Gemini3Pro 刚刚发布,我在 AIStudio 上测试了本题,被一发秒了:
CPP
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>

using namespace std;

// 常量定义
const int MOD = 998244353;
const int MAXN = 505;

// 全局变量
int n, m;
string s;
int cnt[MAXN]; // cnt[v] 表示耐心值 c 为 v 的人数
long long dp[MAXN][MAXN]; // dp[j][x]
long long next_dp[MAXN][MAXN];
long long inter[MAXN][MAXN]; // 中间状态,用于处理 lost 增加时的重分布
long long C[MAXN][MAXN]; // 组合数
int small_total[MAXN]; // 前缀和,small_total[j] 表示 c <= j 的总人数

// 快速幂计算模逆元
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

long long modInverse(long long n) {
    return power(n, MOD - 2);
}

// 初始化组合数
void init_comb() {
    for (int i = 0; i < MAXN; ++i) {
        C[i][0] = 1;
        for (int j = 1; j <= i; ++j) {
            C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % MOD;
        }
    }
}

// 初始化 small_total
void init_small_total() {
    small_total[0] = cnt[0];
    for (int i = 1; i < MAXN; ++i) {
        small_total[i] = small_total[i-1] + cnt[i];
    }
}

int get_small_total(int j) {
    if (j < 0) return 0;
    if (j >= MAXN) return small_total[MAXN-1];
    return small_total[j];
}

int main() {
    // 优化 I/O
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    if (!(cin >> n >> m)) return 0;
    cin >> s;
    for (int i = 0; i < n; ++i) {
        int val;
        cin >> val;
        cnt[val]++;
    }

    init_comb();
    init_small_total();

    // 初始状态:第 0 天前,lost=0,使用了 0 个 Small (c<=0) 的人
    // 注意:如果 cnt[0]>0,这些人在 j=0 时即为 Small,但尚未被使用,所以 x=0 是正确的。
    dp[0][0] = 1;

    for (int i = 0; i < n; ++i) { // 遍历每一天
        int diff = s[i] - '0';

        // 清空下一轮的 DP 数组
        for(int j=0; j<=n+1; ++j) {
            for(int x=0; x<=n+1; ++x) {
                next_dp[j][x] = 0;
                inter[j][x] = 0;
            }
        }

        // 第一阶段:根据当前 diff 分配人员
        // j: 当前 lost 人数
        // x: 当前已使用的 Small 人数 (c <= j)
        for (int j = 0; j <= i && j <= n - m + 1; ++j) {
            if (j > n - m) continue; // 剪枝:如果已经损失超过 n-m,无法满足条件

            int S_total = get_small_total(j); // 当前 Small 总数
            int L_total = n - S_total;        // 当前 Large 总数
            
            for (int x = 0; x <= S_total; ++x) {
                if (dp[j][x] == 0) continue;

                int used_l = i - x; // 已使用的 Large 人数
                if (used_l < 0 || used_l > L_total) continue;

                int avail_s = S_total - x;     // 可用的 Small
                int avail_l = L_total - used_l; // 可用的 Large

                if (diff == 1) {
                    // 选项 1: 录用 (Hire)。需要消耗一个 Large。
                    // 状态转移:(j, x) -> (j, x) (进入下一轮的 next_dp)
                    if (avail_l > 0) {
                        next_dp[j][x] = (next_dp[j][x] + dp[j][x] * avail_l) % MOD;
                    }
                    // 选项 2: 失败 (Fail)。需要消耗一个 Small。
                    // 状态转移:(j, x) -> (j+1, x+1) (进入 inter 等待重分布)
                    if (avail_s > 0) {
                        inter[j + 1][x + 1] = (inter[j + 1][x + 1] + dp[j][x] * avail_s) % MOD;
                    }
                } else { // diff == 0
                    // 选项 1: 选 Small。
                    // 状态转移:(j, x) -> (j+1, x+1) (进入 inter)
                    if (avail_s > 0) {
                        inter[j + 1][x + 1] = (inter[j + 1][x + 1] + dp[j][x] * avail_s) % MOD;
                    }
                    // 选项 2: 选 Large。
                    // 状态转移:(j, x) -> (j+1, x) (进入 inter,注意 x 暂时不变)
                    if (avail_l > 0) {
                        inter[j + 1][x] = (inter[j + 1][x] + dp[j][x] * avail_l) % MOD;
                    }
                }
            }
        }

        // 第二阶段:重分布 (Redistribution)
        // 处理 lost 从 j-1 增加到 j 的情况
        for (int j = 1; j <= i + 1 && j <= n; ++j) {
             if (j > n - m) continue; // 剪枝

             int c_val = cnt[j]; // 耐心值为 j 的人数(这些人从 Large 变成了 Small)
             int prev_S_total = get_small_total(j - 1);
             int prev_L_total = n - prev_S_total; // 上一状态下的 Large 总数

             // 计算分母的逆元:binom(prev_L_total, c_val)
             long long denom = C[prev_L_total][c_val];
             long long invDenom = modInverse(denom);

             // x_in: 在 j-1 状态下已使用的 Small 人数
             for (int x_in = 0; x_in <= prev_S_total; ++x_in) {
                 if (inter[j][x_in] == 0) continue;

                 // 在 j-1 状态下已使用的 Large 人数
                 int used_l = (i + 1) - x_in;
                 // 在 j-1 状态下可用的 Large 人数
                 int avail_l = prev_L_total - used_l;

                 // 我们需要确定在 used_l 中有多少人恰好是 c_val (即 c=j)
                 // 设这个数量为 y
                 int min_y = max(0, c_val - avail_l);
                 int max_y = min(used_l, c_val);

                 for (int y = min_y; y <= max_y; ++y) {
                     // 组合数:从 used_l 中选 y 个,从 avail_l 中选 c_val - y 个
                     long long ways = (C[used_l][y] * C[avail_l][c_val - y]) % MOD;
                     // 概率/比例
                     long long prob = (ways * invDenom) % MOD;
                     
                     long long added = (inter[j][x_in] * prob) % MOD;
                     
                     // 更新状态:新的 Small 使用数 = 原 Small 使用数 + 新变成 Small 的使用数
                     next_dp[j][x_in + y] = (next_dp[j][x_in + y] + added) % MOD;
                 }
             }
        }

        // 将 next_dp 更新到 dp
        for(int j=0; j<=n; ++j) {
            for(int x=0; x<=n; ++x) {
                dp[j][x] = next_dp[j][x];
            }
        }
    }

    long long ans = 0;
    // 统计所有录用至少 m 人的方案(即 lost <= n - m)
    for (int j = 0; j <= n - m; ++j) {
        for (int x = 0; x <= n; ++x) {
            ans = (ans + dp[j][x]) % MOD;
        }
    }

    cout << ans << endl;

    return 0;
}
因为当时误开了联网,尽管 Gemini 并没有搜索到相关的试题资料,但是依然感觉说服力不强。因此刚刚又测试了一次,虽然第一次提交没有通过该题,但是告知代码错误后,第二次提交成功通过了本题。
CPP
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>

using namespace std;

const int MOD = 998244353;
const int MAXN = 505;

long long C[MAXN][MAXN];
long long fact[MAXN], invFact[MAXN];

int n, m;
string s;
int c[MAXN];
int cnt[MAXN]; 
int Suf[MAXN]; 
int Pre[MAXN]; 

// dp[j][u]
// j: number of rejects so far
// u: number of used candidates with c > j
long long dp[MAXN][MAXN];
long long next_dp[MAXN][MAXN];

// 快速幂计算模逆元
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

long long modInverse(long long n) {
    return power(n, MOD - 2);
}

// 初始化组合数和阶乘
void init() {
    // 组合数 C[n][k]
    for (int i = 0; i < MAXN; i++) {
        C[i][0] = 1;
        for (int j = 1; j <= i; j++) {
            C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % MOD;
        }
    }
    
    // 阶乘及其逆元,用于计算 C(n, k) 的逆元
    fact[0] = 1;
    invFact[0] = 1;
    for(int i=1; i<MAXN; ++i) {
        fact[i] = (fact[i-1] * i) % MOD;
        invFact[i] = modInverse(fact[i]);
    }
}

// 计算 C(n, k) 的模逆元: 1 / C(n, k) % MOD
long long get_invC(int n, int k) {
    if (k < 0 || k > n) return 0; 
    // inv(C(n, k)) = k! * (n-k)! * inv(n!)
    long long res = fact[k];
    res = (res * fact[n-k]) % MOD;
    res = (res * invFact[n]) % MOD;
    return res;
}

int main() {
    // 优化 I/O
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    init();

    if (!(cin >> n >> m)) return 0;
    cin >> s;
    for (int i = 0; i < n; i++) {
        cin >> c[i];
        if (c[i] > n) c[i] = n; // 耐心值上限处理
        cnt[c[i]]++;
    }

    // 预处理 Suf 和 Pre 数组
    // Suf[j] 表示 c >= j 的人数
    // Pre[j] 表示 c < j 的人数
    for (int j = 0; j <= n + 1; j++) {
        int s_count = 0;
        int p_count = 0;
        for (int i = 0; i < n; i++) {
            if (c[i] >= j) s_count++;
            if (c[i] < j) p_count++;
        }
        Suf[j] = s_count;
        Pre[j] = p_count;
    }

    // 初始状态:0 天,0 拒绝,0 个强候选人被使用
    dp[0][0] = 1;

    // 遍历每一天
    for (int i = 0; i < n; i++) {
        // 清空下一天的 DP 表
        for(int j=0; j<=n; ++j) {
            for(int u=0; u<=n; ++u) {
                next_dp[j][u] = 0;
            }
        }

        int current_s = s[i] - '0';

        // 遍历当前状态
        for (int j = 0; j <= i; j++) {
            // u 不可能超过总天数 i,也不可能超过当前强候选人总数
            int max_u = min(i, Suf[j+1]);
            for (int u = 0; u <= max_u; u++) {
                if (dp[j][u] == 0) continue;

                int S_curr = Suf[j+1]; // 当前强候选人总数 (c > j)
                int S_le = Pre[j+1];   // 当前弱候选人总数 (c <= j)
                int A_S = S_curr - u;  // 可用的强候选人
                int A_W = S_le - (i - u); // 可用的弱候选人 (总已用 i - 强已用 u)

                // 1. 录用 (Hire)
                // 条件:题目难度为 1 且有可用的强候选人
                if (current_s == 1 && A_S > 0) {
                    long long ways = (dp[j][u] * A_S) % MOD;
                    next_dp[j][u+1] = (next_dp[j][u+1] + ways) % MOD;
                }

                // 2. 拒绝 (Reject)
                // 无论如何 j 都会变成 j+1。
                // 此时,耐心值恰好为 j+1 的人从强变弱。
                // 我们需要将已使用的强候选人根据超几何分布拆分。
                
                int C_val = cnt[j+1]; // 耐心值恰好为 j+1 的人数
                int S_next = Suf[j+2]; // 新的强候选人总数 (c > j+1)
                
                // 2a. 使用弱候选人导致拒绝
                if (A_W > 0) {
                    long long ways = (dp[j][u] * A_W) % MOD;
                    
                    // 我们从 S_curr 中选了 u 个,现在 S_curr 分裂为 C_val 和 S_next。
                    // 我们需要知道这 u 个中有多少个 (y) 落在 C_val 中。
                    // 系数 = ways / C(S_curr, u)
                    long long factor = (ways * get_invC(S_curr, u)) % MOD;
                    
                    int min_y = 0;
                    if (u > S_next) min_y = u - S_next;
                    int max_y = min(u, C_val);
                    
                    for (int y = min_y; y <= max_y; y++) {
                        // 组合数:从 C_val 中选 y 个,从 S_next 中选 u-y 个
                        long long term = (C[C_val][y] * C[S_next][u - y]) % MOD;
                        term = (term * factor) % MOD;
                        // 新状态下,使用的强候选人数量为 u - y
                        next_dp[j+1][u-y] = (next_dp[j+1][u-y] + term) % MOD;
                    }
                }

                // 2b. 使用强候选人导致拒绝 (仅当 s=0)
                if (current_s == 0 && A_S > 0) {
                    long long ways = (dp[j][u] * A_S) % MOD;
                    
                    // 使用了一个强候选人,故总使用数为 u+1
                    int u_new = u + 1;
                    long long factor = (ways * get_invC(S_curr, u_new)) % MOD;
                    
                    int min_y = 0;
                    if (u_new > S_next) min_y = u_new - S_next;
                    int max_y = min(u_new, C_val);
                    
                    for (int y = min_y; y <= max_y; y++) {
                        long long term = (C[C_val][y] * C[S_next][u_new - y]) % MOD;
                        term = (term * factor) % MOD;
                        // 新状态下,使用的强候选人数量为 u_new - y
                        next_dp[j+1][u_new-y] = (next_dp[j+1][u_new-y] + term) % MOD;
                    }
                }
            }
        }
        
        // 滚动数组更新
        for(int j=0; j<=n; ++j) {
            for(int u=0; u<=n; ++u) {
                dp[j][u] = next_dp[j][u];
            }
        }
    }

    long long ans = 0;
    // 统计答案:录用人数 >= m
    // 总人数 n,拒绝人数 j,录用人数 = n - j
    // n - j >= m  =>  j <= n - m
    for (int j = 0; j <= n - m; j++) {
        for (int u = 0; u <= n; u++) {
            ans = (ans + dp[j][u]) % MOD;
        }
    }

    cout << ans << endl;

    return 0;
}

回复

203 条回复,欢迎继续交流。

正在加载回复...