当前位置: 代码迷 >> 综合 >> 2020中国大学生程序设计竞赛(CCPC) - 网络选拔赛----HDU--6900、Residual Polynomial(分治、FFT)
  详细解决方案

2020中国大学生程序设计竞赛(CCPC) - 网络选拔赛----HDU--6900、Residual Polynomial(分治、FFT)

热度:92   发布时间:2024-02-21 11:43:24.0

题目链接

题面:
在这里插入图片描述

题意:
给定函数: f1(x)=∑i=0naixif_1(x)=\sum_{i=0}^na_ix^if1?(x)=i=0n?ai?xi

给定 b2,b3,...,bnb_2,b_3,...,b_nb2?,b3?,...,bn?c2,c3,...,cnc_2,c_3,...,c_nc2?,c3?,...,cn?

对于 i∈[2,n],fi(x)=bi(fi?1(x))′+cifi?1(x)i\in[2,n],f_i(x)=b_i(f_{i-1}(x))'+c_if_{i-1}(x)i[2,n]fi?(x)=bi?(fi?1?(x))+ci?fi?1?(x)

题解:
我们把每个 fif_ifi? 写成一列。
fi,jf_{i,j}fi,j? 表示 fif_ifi?jjj 次项的次数。其中 f1,j=ajf_{1,j}=a_jf1,j?=aj?

[f1,0f2,0f3,0?fn,0f1,1f2,1f3,1?fn,1f1,2f2,2f3,2?fn,2?????f1,nf2,nf3,n?fn,n]\begin{bmatrix}&f_{1,0}&f_{2,0}&f_{3,0}&\cdots&f_{n,0}&\\ &f_{1,1}&f_{2,1}&f_{3,1}&\cdots&f_{n,1}&\\&f_{1,2}&f_{2,2}&f_{3,2}&\cdots&f_{n,2}&\\&\vdots&\vdots&\vdots&\vdots&\vdots&\\&f_{1,n}&f_{2,n}&f_{3,n}&\cdots&f_{n,n}&\end{bmatrix}?????????f1,0?f1,1?f1,2??f1,n??f2,0?f2,1?f2,2??f2,n??f3,0?f3,1?f3,2??f3,n????????fn,0?fn,1?fn,2??fn,n???????????

考虑 fi,jf_{i,j}fi,j? 的转移,发现存在两种转移状态。

①、fi,j?ci+1?>fi+1,j,其中i<nf_{i,j}*c_{i+1}->f_{i+1,j},其中 i<nfi,j??ci+1??>fi+1,j?i<n
②、fi,j?(j?bi+1)?>fi+1,j?1,其中i<n,j>0f_{i,j}*(j*b_{i+1})->f_{i+1,j-1},其中 i<n,j>0fi,j??(j?bi+1?)?>fi+1,j?1?i<nj>0

我们发现,相当于在上方的矩阵中,每个状态 fi,j向fi+1,j和fi+1,j?1f_{i,j}向f_{i+1,j}和f_{i+1,j-1}fi,j?fi+1,j?fi+1,j?1? 连接了一条边。

我们考虑 f1,if_{1,i}f1,i? 对于 fn,jf_{n,j}fn,j? 的贡献,其中 i≥ji\ge jij

可以发现 f1,if_{1,i}f1,i? 对于 fn,jf_{n,j}fn,j? 的贡献就是从 f1,if_{1,i}f1,i?fn,jf_{n,j}fn,j? 的路径。

那么这个贡献 ans(f1,i?>fn,j)=f1,i?∑(∏path)ans(f_{1,i}->f_{n,j})=f_{1,i}*\sum(\prod path)ans(f1,i??>fn,j?)=f1,i??(path)

我们先不考虑②中 j?bi+1中的jj*b_{i+1}中的 jj?bi+1?j,那么对于转移 f1,i?>fn,jf_{1,i}->f_{n,j}f1,i??>fn,j?,就是对于 x∈[2,n]x\in[2,n]x[2,n] 选择 bxb_xbx? 或者 cxc_xcx?,其中 bxb_xbx?选择了 i?ji-ji?j 个,cxc_xcx? 选择了 n?1?(i?j)n-1-(i-j)n?1?(i?j) 个,然后把所有方案相加。

我们设 F(k)F(k)F(k) 为选 kkkbxb_xbx? ,选 n?1?kn-1-kn?1?kcxc_xcx? 的方案和。

我们设 F(l,r,k)F(l,r,k)F(l,r,k) 为在区间 [l,r][l,r][l,r] 中选 kkkbxb_xbx? 方案和。

那么 F(l,r,k)=∑i+j=kF(l,mid,i)?F(mid+1,r,j)F(l,r,k)=\sum_{i+j=k}F(l,mid,i)*F(mid+1,r,j)F(l,r,k)=i+j=k?F(l,mid,i)?F(mid+1,r,j)

分治+卷积时间复杂度为 O(nlog2n)O(nlog^2n)O(nlog2n)
 
 

现在考虑②中的 jjj 的贡献,我们令 gj=fn,jg_j=f_{n,j}gj?=fn,j?fi=f1,if_i=f_{1,i}fi?=f1,i?

那么有 gj=fn,j=∑i?k=jF(k)?fi?i!j!g_j=f_{n,j}=\sum_{i-k=j}F(k)*f_{i}*\frac{i!}{j!}gj?=fn,j?=i?k=j?F(k)?fi??j!i!?

现在考虑怎么求解 gj=inv(j!)?∑i?k=jF(k)?fi?i!g_j=inv(j!)*\sum_{i-k=j}F(k)*f_{i}*i!gj?=inv(j!)?i?k=j?F(k)?fi??i!

我们设 h(n?i)=fi?i!h(n-i)=f_{i}*i!h(n?i)=fi??i!

那么有:gj=inv(j!)?∑i?k=jF(k)?hn?ig_j=inv(j!)*\sum_{i-k=j}F(k)*h_{n-i}gj?=inv(j!)?i?k=j?F(k)?hn?i?

我们令 i=n?ii=n-ii=n?i 则有:

gj=inv(j!)?∑i+k=n?jF(k)?hig_j=inv(j!)*\sum_{i+k=n-j}F(k)*h_{i}gj?=inv(j!)?i+k=n?j?F(k)?hi?

我们设 p(n?j)=∑i+k=n?jF(k)?hip(n-j)=\sum_{i+k=n-j}F(k)*h_{i}p(n?j)=i+k=n?j?F(k)?hi?

那么 gj=inv(j!)?p(n?j)g_j=inv(j!)*p(n-j)gj?=inv(j!)?p(n?j)

即做一次 FFTFFTFFT 即可。

时间复杂度 O(nlog2n+nlogn)=O(nlog2n)O(nlog^2n+nlogn)=O(nlog^2n)O(nlog2n+nlogn)=O(nlog2n)

#pragma GCC optimize(2)
#pragma GCC optimize("Ofast","inline","-ffast-math")
#pragma GCC target("avx,sse2,sse3,sse4,mmx")
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
namespace onlyzhao
{
    #define ui unsigned int#define ll long long#define llu unsigned ll#define ld long double#define pr make_pair#define pb push_back#define lc (cnt<<1)#define rc (cnt<<1|1)#define len(x) (t[(x)].r-t[(x)].l+1)#define tmid ((l+r)>>1)#define fhead(x) for(int i=head[(x)];i;i=nt[i])#define max(x,y) ((x)>(y)?(x):(y))#define min(x,y) ((x)>(y)?(y):(x))#define one(n) for(int i=1;i<=(n);i++)#define rone(n) for(int i=(n);i>=1;i--)#define fone(i,x,n) for(int i=(x);i<=(n);i++)#define frone(i,n,x) for(int i=(n);i>=(x);i--)#define fonk(i,x,n,k) for(int i=(x);i<=(n);i+=(k))#define fronk(i,n,x,k) for(int i=(n);i>=(x);i-=(k))#define two(n,m) for(int i=1;i<=(n);i++) for(int j=1;j<=(m);j++)#define ftwo(i,n,j,m) for(int i=1;i<=(n);i++) for(int j=1;j<=(m);j++)#define fvc(vc) for(int i=0;i<vc.size();i++)#define frvc(vc) for(int i=vc.size()-1;i>=0;i--)#define forvc(i,vc) for(int i=0;i<vc.size();i++)#define forrvc(i,vc) for(int i=vc.size()-1;i>=0;i--)#define cls(a) memset(a,0,sizeof(a))#define cls1(a) memset(a,-1,sizeof(a))#define clsmax(a) memset(a,0x3f,sizeof(a))#define clsmin(a) memset(a,0x80,sizeof(a))#define cln(a,num) memset(a,0,sizeof(a[0])*num)#define cln1(a,num) memset(a,-1,sizeof(a[0])*num)#define clnmax(a,num) memset(a,0x3f,sizeof(a[0])*num)#define clnmin(a,num) memset(a,0x80,sizeof(a[0])*num)#define sc(x) scanf("%d",&x)#define sc2(x,y) scanf("%d%d",&x,&y)#define sc3(x,y,z) scanf("%d%d%d",&x,&y,&z)#define scl(x) scanf("%lld",&x)#define scl2(x,y) scanf("%lld%lld",&x,&y)#define scl3(x,y,z) scanf("%lld%lld%lld",&x,&y,&z)#define scf(x) scanf("%lf",&x)#define scf2(x,y) scanf("%lf%lf",&x,&y)#define scf3(x,y,z) scanf("%lf%lf%lf",&x,&y,&z)#define scs(x) scanf("%s",x+1)#define scs0(x) scanf("%s",x)#define scline(x) scanf("%[^\n]%*c",x+1)#define scline0(x) scanf("%[^\n]%*c",x)#define pcc(x) putchar(x)#define pc(x) printf("%d\n",x)#define pc2(x,y) printf("%d %d\n",x,y)#define pc3(x,y,z) printf("%d %d %d\n",x,y,z)#define pck(x) printf("%d ",x)#define pcl(x) printf("%lld\n",x)#define pcl2(x,y) printf("%lld %lld\n",x,y)#define pcl3(x,y,z) printf("%lld %lld %d\n",x,y,z)#define pclk(x) printf("%lld ",x)#define pcf2(x) printf("%.2f\n",x)#define pcf6(x) printf("%.6f\n",x)#define pcf8(x) printf("%.8f\n",x)#define pcs(x) printf("%s\n",x+1)#define pcs0(x) printf("%s\n",x)#define pcline(x) printf("%d**********\n",x)#define casett int tt;sc(tt);int pp=0;while(tt--)char buffer[100001],*S,*T;inline char Get_Char(){
    if (S==T){
    T=(S=buffer)+fread(buffer,1,100001,stdin);if (S==T) return EOF;}return *S++;}inline int read(){
    char c;int re=0;for(c=Get_Char();c<'0'||c>'9';c=Get_Char());while(c>='0'&&c<='9') re=re*10+(c-'0'),c=Get_Char();return re;}
};
using namespace onlyzhao;
using namespace std;const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=998244353;
const int p=mod;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=400100;
const int maxm=100100;
const int maxp=100100;
const int up=100100;
const int g=3;int A[maxn],B[maxn];
vector<int>a,b,c;
int fac[maxn],inv[maxn];
int fi[maxn];
int mypow(int a,int b)
{
    if(b<0) return mypow(mypow(a,p-2),-b);int ans=1;while(b){
    if(b&1) ans=1ll*ans*a%p;a=1ll*a*a%p;b>>=1;}return ans%p;
}void init(void)
{
    fac[0]=1;for(int i=1;i<maxn;i++)fac[i]=1ll*fac[i-1]*i%p;inv[maxn-1]=mypow(fac[maxn-1],p-2);for(int i=maxn-2;i>=0;i--)inv[i]=1ll*inv[i+1]*(i+1)%p;
}int init(int n,int m)
{
    int len=1,cnt=0;while(len<=n+m) len<<=1,cnt++;for(int i=0;i<len;i++)fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));return len;
}void ntt(int *x,int len,int f)
{
    for(int i=0;i<len;i++)if(i<fi[i]) swap(x[i],x[fi[i]]);for(int i=1;i<len;i<<=1){
    int r=i<<1;int wn=mypow(g,f*(p-1)/r);for(int j=0;j<len;j+=r){
    int w=1;for(int k=0;k<i;k++){
    int xx=x[j+k],yy=1ll*w*x[j+i+k]%p;x[j+k]=(xx+yy)%p;x[j+i+k]=(xx-yy+p)%p;w=1ll*w*wn%p;}}}if(f==-1){
    int invn=mypow(len,p-2);for(int i=0;i<len;i++)x[i]=1ll*x[i]*invn%p;}
}vector<int> dontt(const vector<int>&a,const vector<int>&b)
{
    //小范围暴力,大范围卷积会快很多int n=a.size()-1,m=b.size()-1;vector<int>vc(n+m+1);if(n<=50&&m<=50){
    for(int i=0;i<=n;i++){
    for(int j=0;j<=m;j++)vc[i+j]=(vc[i+j]+1ll*a[i]*b[j])%p;}return vc;}int len=init(n,m);for(int i=0;i<len;i++){
    if(i<=n) A[i]=a[i];else A[i]=0;;if(i<=m) B[i]=b[i];else B[i]=0;}ntt(A,len,1);ntt(B,len,1);for(int i=0;i<len;i++)A[i]=1ll*A[i]*B[i]%p;ntt(A,len,-1);for(int i=0;i<=n+m;i++)vc[i]=A[i];return vc;
}vector<int> sol(int l,int r)
{
    if(l==r){
    vector<int>vc;vc.pb(c[l]);vc.pb(b[l]);return vc;}return dontt(sol(l,tmid),sol(tmid+1,r));
}int main(void)
{
    init();int tt;scanf("%d",&tt);while(tt--){
    int n;scanf("%d",&n);a=vector<int>(n+1),b=vector<int>(n-1),c=vector<int>(n-1);for(int i=0;i<=n;i++)scanf("%d",&a[i]);for(int i=0;i<=n-2;i++)scanf("%d",&b[i]);for(int i=0;i<=n-2;i++)scanf("%d",&c[i]);vector<int>vc=sol(0,n-2);for(int i=0;i<a.size();i++)a[i]=1ll*a[i]*fac[i]%p;reverse(a.begin(),a.end());vector<int>ans=dontt(vc,a);for(int i=0;i<=n;i++){
    if(i!=0) putchar(' ');printf("%lld",1ll*ans[n-i]*inv[i]%p);}putchar('\n');}return 0;}
  相关解决方案