当前位置: 代码迷 >> 综合 >> [Codeforces755G][DP][NTT]PolandBall and Many Other Balls
  详细解决方案

[Codeforces755G][DP][NTT]PolandBall and Many Other Balls

热度:75   发布时间:2023-12-19 04:59:59.0

翻译

给你n个球,把他们分成K组,允许有的球没有组
每组不能为空也不能超过两个球
求方案数
n<=1e9 K<=2^15

题解

f [ i ] [ j ] f[i][j] f[i][j]表示前 i i i个球分成 j j j组的方案数
朴素DP容易想到
f [ i ] [ j ] = f [ i ? 1 ] [ j ] + f [ i ? 1 ] [ j ? 1 ] + f [ i ? 2 ] [ j ? 1 ] f[i][j]=f[i-1][j]+f[i-1][j-1]+f[i-2][j-1] f[i][j]=f[i?1][j]+f[i?1][j?1]+f[i?2][j?1]
优化不了… 没辙
换一个转移方式
对于 2 ? n 2*n 2?n,可以由 n n n得到
显然有
f [ 2 ? i ] [ K ] = ∑ j = 0 K f [ i ] [ j ] ? f [ i ] [ K ? j ] f[2*i][K]=\sum_{j=0}^{K}f[i][j]*f[i][K-j] f[2?i][K]=j=0K?f[i][j]?f[i][K?j]
发现漏算了一种情况,在分界点的两个球组成一组的情况
所以加上
f [ 2 ? i ] [ K ] + = ∑ j = 0 K ? 1 f [ i ? 1 ] [ j ] ? f [ i ? 1 ] [ K ? 1 ? j ] f[2*i][K]+=\sum_{j=0}^{K-1}f[i-1][j]*f[i-1][K-1-j] f[2?i][K]+=j=0K?1?f[i?1][j]?f[i?1][K?1?j]
这个可以优化了
上面两个是卷积形式可以NTT优化成 n log ? n n\log n nlogn
倍增DP
维护 f [ i ] , f [ i ? 1 ] , f [ i ? 2 ] f[i],f[i-1],f[i-2] f[i],f[i?1],f[i?2]的多项式,用系数当答案

f [ 2 ? i ] = f [ i ] ? f [ i ] + f [ i ? 1 ] ? f [ i ? 1 ] f[2*i]=f[i]*f[i]+f[i-1]*f[i-1] f[2?i]=f[i]?f[i]+f[i?1]?f[i?1]
f [ 2 ? i ? 1 ] = f [ i ] ? f [ i ? 1 ] + f [ i ? 1 ] ? f [ i ? 2 ] f[2*i-1]=f[i]*f[i-1]+f[i-1]*f[i-2] f[2?i?1]=f[i]?f[i?1]+f[i?1]?f[i?2]
f [ 2 ? i ? 2 ] = f [ i ? 1 ] ? f [ i ? 1 ] + f [ i ? 2 ] ? f [ i ? 2 ] f[2*i-2]=f[i-1]*f[i-1]+f[i-2]*f[i-2] f[2?i?2]=f[i?1]?f[i?1]+f[i?2]?f[i?2]
前面取第K项 后面取第K-1项
类似二进制一样维护,如果扫到这一位有1,那么暴力把这个1的贡献加上
相当于
n=1001000
1000->1001->10010…
暴力做一遍的复杂度是 O ( n ) O(n) O(n)
倍增复杂度 log ? n \log n logn
NTT的复杂度 n log ? n n\log n nlogn
所以复杂度完美的 n log ? 2 n n\log^2 n nlog2n
NTT做的有点多…常数大了有兴趣的可以帮我卡卡啊…

#include<cstdio> #include<cstring> #include<cstdlib> #include<algorithm> #include<cmath> #include<queue> #include<vector> #include<ctime> #include<map> #define LL long long #define mp(x,y) make_pair(x,y) #define mod 998244353 #define MAXN 50005 using namespace std; inline int read() {
     int f=1,x=0;char ch=getchar();while(ch<'0'||ch>'9'){
     if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){
     x=x*10+ch-'0';ch=getchar();}return x*f; } inline void write(int x) {
     if(x<0)putchar('-'),x=-x;if(x>9)write(x/10);putchar(x%10+'0'); } inline void print(int x){
     write(x);printf(" ");} LL pow_mod(LL a,LL b) {
     LL ret=1;while(b){
     if(b&1)ret=ret*a%mod;a=a*a%mod;b>>=1;}return ret; } int R[MAXN*4],L; void NTT(LL *y,int len,int on) {
     for(int i=0;i<len;i++)if(i<R[i])swap(y[i],y[R[i]]);for(int i=1;i<len;i<<=1){
     LL wn=pow_mod(3,(mod-1)/(i<<1));if(on==-1)wn=pow_mod(wn,mod-2);for(int j=0;j<len;j+=(i<<1)){
     LL w=1;for(int k=0;k<i;k++){
     LL u=y[j+k];LL v=y[j+k+i]*w%mod;y[j+k]=(u+v)%mod;y[j+k+i]=(u-v+mod)%mod;w=w*wn%mod;}}}if(on==-1){
     LL tmp=pow_mod(len,mod-2);for(int i=0;i<len;i++)y[i]=y[i]*tmp%mod;} } LL A[MAXN*4],B[MAXN*4],C[MAXN*4]; LL n1[MAXN*4],n2[MAXN*4],n3[MAXN*4]; LL s1[MAXN*4],s2[MAXN*4],s3[MAXN*4],s4[MAXN*4],s5[MAXN*4]; int n,K; void update(int len) {
     memcpy(n1,A,sizeof(n1));memcpy(n2,B,sizeof(n2));memcpy(n3,C,sizeof(n3));NTT(n1,len,1);NTT(n2,len,1);NTT(n3,len,1);for(int i=0;i<len;i++)s1[i]=n1[i]*n1[i]%mod;for(int i=0;i<len;i++)s2[i]=n2[i]*n2[i]%mod;for(int i=0;i<len;i++)s3[i]=n1[i]*n2[i]%mod;for(int i=0;i<len;i++)s4[i]=n2[i]*n3[i]%mod;for(int i=0;i<len;i++)s5[i]=n3[i]*n3[i]%mod;NTT(s1,len,-1);NTT(s2,len,-1);NTT(s3,len,-1);NTT(s4,len,-1);NTT(s5,len,-1);for(int i=1;i<=K;i++)A[i]=(s1[i]+s2[i-1])%mod;A[0]=s1[0];for(int i=1;i<=K;i++)B[i]=(s3[i]+s4[i-1])%mod;B[0]=s3[0];for(int i=1;i<=K;i++)C[i]=(s2[i]+s5[i-1])%mod;C[0]=s2[0]; } LL tmp[MAXN*4],t1[MAXN*4]; void vio(int ok) {
     memcpy(tmp,A,sizeof(tmp));for(int i=1;i<=min(K,ok);i++)A[i]=(A[i]+tmp[i-1]+B[i-1])%mod;memcpy(t1,B,sizeof(t1));memcpy(B,tmp,sizeof(B));memcpy(C,t1,sizeof(C)); } int gets(int u){
     int ret=0;for(;u;u>>=1)ret++;return ret;} int main() {
     //freopen("a.in","r",stdin);//freopen("b.out","w",stdout);n=read();K=read();int ln=1;for(ln=1;ln<=2*K;ln<<=1)L++;for(int i=0;i<ln;i++)R[i]=(R[i>>1]>>1)|(i&1)<<(L-1);A[0]=1;int lg=gets(n);int sum=0;for(int i=lg-1;i>=0;i--){
     update(ln);sum<<=1;if(sum==2)C[1]=0;if(n&(1<<i))sum|=1,vio(sum);}for(int i=1;i<=K;i++)printf("%lld ",A[i]);puts("");return 0; }