社区讨论

关于本题(可能是)边界处理的问题

P14254分割(divide)参与者 4已保存回复 4

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
4 条
当前快照
1 份
快照标识符
@mhj0tgdl
此快照首次捕获于
2025/11/03 18:52
4 个月前
此快照最后确认于
2025/11/03 18:52
4 个月前
查看原帖
我在一份代码里判断了组合数 AnmA_n ^ m 可能出现 n<mn < m 的情况,但另一份没有判断,结果都过了。然后我就以为这种情况不可能出现,于是我为了测试,把特判 nm<0n - m < 0 的部分改了一个错误的特判,然后错了。但我认为如果不特判直接访问 invnminv_{n - m} 可能会出现负数下标,那为什么都能过呢?
具体地,三份代码仅有一行不同,在 long long A(long long n , long long m) 函数计算 invnm 的一行(下面这份代码的通过记录见此),将这行变成 long long invnm = gksm(fac[n - m] , mod - 2); 也可以通过(见此)。但我为了测试是否会越界改为 long long invnm = (n - m < 0 ? 1 : gksm(fac[n - m] , mod - 2)); 就没法过样例二了(这种改法明显把特判部分改错了),说明存在越界情况,但为什么不会 RE?
另外我的 long long C(long long n , long long m) 函数的特判也写错了,但因为没用到此函数,所以代码里没改。
代码:
CPP
vector < int > v[1000005];
long long n , k;
int fa[1000005];
int depth[1000005];
int maxdep[1000005];
void dfs(int now)
{
    maxdep[now] = depth[now] = depth[fa[now]] + 1;
    for(int i : v[now])
    {
        dfs(i);
        maxdep[now] = max(maxdep[now] , maxdep[i]);
    }
}
long long fac[1000005];
const long long mod = 998244353;
long long gksm(long long a , long long b , long long ans = 1) { while(b) { if(b & 1) (ans *= a) %= mod; (a *= a) %= mod; b >>= 1; } return ans; }
void init()
{
    fac[0] = 1;
    for(int i = 1 ; i <= 1000000 ; i++)
    {
        fac[i] = fac[i - 1] * i % mod;
    }
}
long long C(long long n , long long m) // n 选 m
{
    if(m < 0)
    {
        return 0;
    }
    long long invnm = (n - m < 0 ? 1 : gksm(fac[n - m] , mod - 2));
    long long invm = gksm(fac[m] , mod - 2);
    return fac[n] * invnm % mod * invm % mod;
}
long long A(long long n , long long m) // n 选 m
{
    if(m < 0)
    {
        return 0;
    }
    long long invnm = (n - m < 0 ? 0 : gksm(fac[n - m] , mod - 2));
    return fac[n] * invnm % mod;
}
vector < int > depsort[1000005];
vector < pair < int , int > > depcnt[1000005];
signed main()
{
    init();
    read(n , k);
    readarray(fa , 2 , n);
    for(int i = 2 ; i <= n ; i++)
    {
        v[fa[i]].push_back(i);
    }
    dfs(1);
    for(int i = 1 ; i <= n ; i++)
    {
        depsort[depth[i]].push_back(maxdep[i]);
    }
    for(int i = 1 ; i <= n ; i++)
    {
        sort(depsort[i].begin() , depsort[i].end());
    }
    for(int i = 1 ; i <= n ; i++)
    {
        if(!depsort[i].size())
        {
            continue;
        }
        depcnt[i].push_back(make_pair(depsort[i][0] , 1));
        for(int j = 1 ; j < depsort[i].size() ; j++)
        {
            if(depsort[i][j] == depsort[i][j - 1])
            {
                depcnt[i][depcnt[i].size() - 1].second++;
            }
            else
            {
                depcnt[i].push_back(make_pair(depsort[i][j] , 1));
            }
        }
    }
    long long ans = 0;
    for(int i = 2 ; i <= n ; i++)
    {
        if(!depcnt[i].size())
        {
            continue;
        }
        reverse(depcnt[i].begin() , depcnt[i].end());
        long long bigger = 0;
        for(pair < int , int > j : depcnt[i])
        {
            long long equ = j.second;
            if(bigger + equ >= k + 1 && equ >= 2)
            {
                ans += ((A(bigger + equ - 1 , k - 1) - A(bigger , k - 1)) % mod + mod) * equ % mod;
                ans %= mod;
            }
            if(bigger == k - 1 && equ >= 2)
            {
                ans += equ * fac[k - 1] % mod;
                ans %= mod;
            }
            bigger += equ;
        }
    }
    printnl((ans + mod) % mod);
	return 0;
}

回复

4 条回复,欢迎继续交流。

正在加载回复...