题意
题解
如果直接考虑DP
f[i][j]f[i][j]表示当前构造了ii位,余数有多少
种方案
然后再构造g[i][j]g[i][j]表示构造了至少ii位,这个的话当一个累加器就好了
但是这样太慢了,是
的
于是要考虑优化
我们考虑到,这种东西应该是可以合并的
于是我们就考虑使用一个类似快速幂的方式来求出这两个东西
如果是奇数的话,那么就先变为偶数,然后剩下一个暴力合并
偶数的话,就拆成i/2i/2
然后一个数o=basei2o=basei2
然后容易得到
f[j?o+k]=∑f[j]?f[k]f[j?o+k]=∑f[j]?f[k]
然后这个显然是一个卷积的形式
你只需要把第一个f[j]f[j]放在f[j?o]f[j?o]就可以了
然后gg的转移就是
然后操作是一样的
然后就可以了
时间复杂度O(nlognlogp)O(nlognlogp)
因为是求稳先写暴力的,所以可能代码比较长。。
但是你们也可以参考一下暴力
一开始因为pow的模数不同没有意识到而调了很久
CODE:
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<iostream>
#include<cstring>
using namespace std;
typedef long long LL;
const LL MOD=998244353;
const LL gi=3,ggi=332748118;
const LL N=50005;
LL base,p,x;
LL pow (LL x,LL y,LL p)
{if (y==1) return x;LL lalal=pow(x,y>>1,p);lalal=lalal*lalal%p; //printf("%I64d %I64d\n",x,y);if (y&1) lalal=lalal*x%p;return lalal;
}
LL F[N],G[N];
LL f[N],g[N];
LL bin[N];
void ntt (LL *a,LL n,LL o)
{for (LL u=0;u<n;u++) bin[u]=((bin[u>>1]>>1)|((u&1)*(n>>1)));for (LL u=0;u<n;u++) if (u<bin[u]) swap(a[u],a[bin[u]]);for (int u=1;u<n;u<<=1){LL wn=pow(o==1?gi:ggi,(MOD-1)/(u<<1),MOD),w,t;for (int i=0;i<n;i=i+(u<<1)){w=1;for (int k=0;k<u;k++){t=w*a[u+i+k]%MOD;a[u+i+k]=(a[i+k]-t+MOD)%MOD;a[i+k]=(a[i+k]+t)%MOD;w=w*wn%MOD;}}}if (o==-1){LL Inv=pow(n,MOD-2,MOD);for (int u=0;u<n;u++) a[u]=a[u]*Inv%MOD;}
}
LL a[N],b[N],now;
void solve (LL n)
{if (n==1){for (LL u='a';u<='z';u++) {f[u%p]++;g[u%p]++;}return ;}if (n&1){solve(n-1);for (LL u=0;u<p;u++) {F[u]=f[u];f[u]=0;}for (LL u='a';u<='z';u++)//枚举加上一个什么 for (LL i=0;i<p;i++){LL h=(i*base+u)%p;f[h]+=F[i];if (f[h]>=MOD) f[h]-=MOD;}for (LL u=0;u<p;u++) {g[u]=g[u]+f[u];if (g[u]>=MOD) g[u]-=MOD;}}else{solve(n/2);/*printf("%I64d\n",n);printf("f:");for (LL u=0;u<p;u++) printf("%I64d ",f[u]);printf("\n");printf("g:");for (LL u=0;u<p;u++) printf("%I64d ",g[u]);printf("\n");system("pause");*/LL o=pow(base,n/2,p);for (LL u=0;u<p;u++) {F[u]=f[u];G[u]=g[u];f[u]=0;}for (LL u=0;u<p;u++) a[u]=b[u]=0;for (LL u=0;u<p;u++) a[(u*o)%p]+=F[u];for (LL u=0;u<p;u++) b[u]=F[u];for (LL u=p;u<now;u++) a[u]=b[u]=0;/*for (LL u=0;u<p;u++)for (LL i=0;i<p;i++)f[(u+i)%p]=(f[(u+i)%p]+a[u]*b[i])%MOD;*/ntt(a,now,1);ntt(b,now,1);for (int u=0;u<now;u++) a[u]=a[u]*b[u]%MOD;ntt(a,now,-1);for (int u=0;u<now;u++) f[u%p]=(f[u%p]+a[u])%MOD;/*for (LL u=0;u<p;u++)for (LL i=0;i<p;i++){LL h=(u*o+i)%p;f[h]=(f[h]+F[u]*F[i]%MOD)%MOD;}*/for (LL u=0;u<p;u++) a[u]=b[u]=0;for (LL u=0;u<p;u++) a[(u*o)%p]+=G[u];for (LL u=0;u<p;u++) b[u]=F[u];for (LL u=p;u<now;u++) a[u]=b[u]=0;/*for (LL u=0;u<p;u++)for (LL i=0;i<p;i++)f[(u+i)%p]=(f[(u+i)%p]+a[u]*b[i])%MOD;*/ntt(a,now,1);ntt(b,now,1);for (int u=0;u<now;u++) a[u]=a[u]*b[u]%MOD;ntt(a,now,-1);for (int u=0;u<now;u++) g[u%p]=(g[u%p]+a[u])%MOD;/*for (LL u=0;u<p;u++)for (LL i=0;i<p;i++){LL h=(u*o+i)%p;g[h]=(g[h]+G[u]*F[i]%MOD)%MOD;}*/}/* printf("%I64d\n",n); printf("f:");for (LL u=0;u<p;u++) printf("%I64d ",f[u]);printf("\n");printf("g:");for (LL u=0;u<p;u++) printf("%I64d ",g[u]);printf("\n");system("pause");*/
}
int main()
{memset(f,0,sizeof(f));memset(g,0,sizeof(g));LL n;scanf("%I64d%I64d%I64d%I64d",&n,&base,&p,&x);now=1;while (now<p) now<<=1;now<<=1;solve(n);/*for (int u=0;u<p;u++) printf("%I64d ",f[u]);printf("\n");*///for (int u=0;u<p;u++) printf("%I64d ",g[u]);printf("%I64d\n",g[x]);return 0;
}