当前位置: 代码迷 >> 综合 >> [Codeforces1097G] Vladislav and a Great Legend
  详细解决方案

[Codeforces1097G] Vladislav and a Great Legend

热度:67   发布时间:2023-10-29 05:10:28.0

链接:https://codeforces.com/contest/1097/problem/G
大概说一下题意吧:
一棵n个点的树,一个点集S的权值定义为把这个点集连成一个联通块的最少边数
求所有点集的f(S)kf(S)^kf(S)k的和

对于这种带次方的,一般考虑两个方法
一个是二项式展开,一个是斯特林数展开
不知道斯特林数的可以看看这个第二类斯特林数
二项式展开我不是特别会用(虽然斯特林数也不会)
一开始是想用前者做的,但是发现不是很会写,并且二项式展开的话复杂度是稳定至少O(nk2)O(nk^2)O(nk2)的。。并不可以通过
但是如果斯特林数可以和子树大小挂钩,那么复杂度就可以降为O(nk)O(nk)O(nk)

先把式子写成ans=∑i=0k{ki}i!∑S?U(f(S)i)ans=\sum_{i=0}^k \begin{Bmatrix}k\\i\end{Bmatrix} i! \sum_{S? U}\begin{pmatrix}f(S)\\i\end{pmatrix}ans=i=0k?{ ki?}i!S?U?(f(S)i?)
你会发现,右边那个组合数的意义是,我们选择iii条边,对应多少个不同的点集
于是就可以考虑DP了
fi,jf_{i,j}fi,j?表示以i为根的子树,里面选择了jjj条边,有多少种选点方案
转移的话
就是先把儿子的合并,然后加上自己的
当然,有两个条件是不合法的要减去
第一个是,子树外面没有选点,但是我们选了x到父亲这条边
第二个是,子树里面没有选点,但是我们选了x到父亲这条边
两个情况均要在DP的时候暴力去掉
但是在去掉第一个的时候,要小心,别把包括第二个也顺便去掉了

update 1.10
更新一下二项式展开的做法
本来就想用这个的。。但是我不是很会写。。忽然发现有代码可以借鉴,于是就爽快地写了一发
考虑DP,fi,jf_{i,j}fi,j?表示以iii为根,iii这个点一定在范围内的边数的jjj次方的答案
但是你发现,这个东西很不好转移,于是便有了另外一个状态
gi,jg_{i,j}gi,j?表示以iii为根,iii这个点一定不在范围内的边数的jjj次方的答案(但子树不为空)
那么我们就可以让i不管限制随便转移,随后减去jjj就好了
容易发现ggg的话,就是只选了一个儿子的和
至于有多少条边,我们可以发现边数=点数-1
因此,我们可以把父亲去掉,每一个点的贡献在合并到父亲的时候贡献就好了
具体来说,就是在合并到父亲的时候,先和一个全是1的数组二项式合并一下
然后再进行操作
和斯特林展开类似的,每次DP完都要去掉空子树的情况
然后就可以做到nk2nk^2nk2
听说可以用NTT优化做到nklogknklogknklogk。。但是由于常数巨大,且这题并不是NTT模数因此并不好实现,写个nk2nk^2nk2就溜了吧。。

CODE(斯特林的,二项式的在下面):

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
using namespace std;
typedef long long LL;
const LL MOD=1e9+7;
const LL N=100005;
const LL K=205;
LL n,k;
struct qq
{
    LL x,y,last;
}e[N*2];LL num,last[N];
void init (LL x,LL y)
{
    num++;e[num].x=x;e[num].y=y;e[num].last=last[x];last[x]=num;
}
LL S[K][K];
LL JC[K];
LL h[K];
LL siz[N];
LL f[N][K];
LL g[N];
void dfs (LL x,LL fa)
{
    siz[x]=1;f[x][0]=2;for (LL xx=last[x];xx!=-1;xx=e[xx].last){
    LL y=e[xx].y;if (y==fa) continue;dfs(y,x);for (LL u=0;u<=min(siz[x]+siz[y]-1,k);u++) g[u]=0;for (LL u=0;u<siz[x]&&u<=k;u++)for (LL i=0;i<=siz[y]&&(u+i)<=k;i++)g[u+i]=(g[u+i]+f[x][u]*f[y][i]%MOD)%MOD;siz[x]=siz[x]+siz[y];for (LL u=0;u<=min(siz[x]-1,k);u++) f[x][u]=g[u];}if (x==1){
    for (LL u=0;u<=k;u++) h[u]=h[u]+f[x][u];}else{
    for (LL u=1;u<=k;u++) h[u]=(h[u]-f[x][u-1])%MOD;h[1]=(h[1]+1)%MOD;}for (LL u=k;u>=1;u--) f[x][u]=(f[x][u]+f[x][u-1])%MOD;f[x][1]=(f[x][1]-1+MOD)%MOD;
}
int main()
{
    num=0;memset(last,-1,sizeof(last));scanf("%lld%lld",&n,&k);JC[0]=1;for (LL u=1;u<=k;u++) JC[u]=JC[u-1]*u%MOD;S[0][0]=1;for (LL u=1;u<=k;u++)for (LL i=1;i<=u;i++)S[u][i]=(S[u-1][i-1]+S[u-1][i]*i%MOD)%MOD;for (LL u=1;u<n;u++){
    LL x,y;scanf("%lld%lld",&x,&y);init(x,y);init(y,x);}dfs(1,0);LL ans=0;for (LL u=0;u<=k;u++) ans=(ans+S[k][u]*JC[u]%MOD*h[u]%MOD)%MOD;ans=(ans+MOD)%MOD;printf("%lld\n",ans);return 0;
}
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long LL;
const LL MOD=1e9+7;
const LL N=100005;
const LL K=205;
struct qq
{
    LL x,y,last;
}e[N];LL num,last[N];
LL n,k;
LL Pow[N];
void init (LL x,LL y)
{
    e[++num].x=x;e[num].y=y;e[num].last=last[x];last[x]=num;
}
LL f[N][K],g[N][K];
LL tmp[K];
LL siz[N];
LL C[K][K];
LL add (LL x,LL y)	{
    x=x+y;return x>=MOD?x-MOD:x;}
LL dec (LL x,LL y)	{
    x=x-y;return x<0?x+MOD:x;}
void dfs (LL x,LL fa)
{
    f[x][0]=2;siz[x]=1;for (LL u=last[x];u!=-1;u=e[u].last){
    LL y=e[u].y;if (y==fa) continue;dfs(y,x);for (LL i=0;i<=k;i++){
    tmp[i]=0;for (LL j=0;j<=i;j++)tmp[i]=add(tmp[i],C[i][j]*add(f[y][j],g[y][j])%MOD);}for (LL i=0;i<=k;i++)	g[x][i]=add(g[x][i],tmp[i]);for (LL i=k;i>=0;i--)for (LL j=i;j>=0;j--)f[x][i]=add(f[x][i],C[i][j]*f[x][j]%MOD*tmp[i-j]%MOD);siz[x]=siz[x]+siz[y];f[x][0]=Pow[siz[x]];}for (LL u=0;u<=k;u++) f[x][u]=dec(f[x][u],g[x][u]);f[x][0]=dec(f[x][0],1);
}
int main()
{
    num=0;memset(last,-1,sizeof(last));scanf("%lld%lld",&n,&k);Pow[0]=1;for (LL u=1;u<=n;u++) Pow[u]=Pow[u-1]*2%MOD;C[0][0]=1;for (LL u=1;u<=k;u++){
    C[u][0]=1;for (LL i=1;i<=u;i++)	C[u][i]=add(C[u-1][i-1],C[u-1][i]);}for (LL u=1;u<n;u++){
    LL x,y;scanf("%lld%lld",&x,&y);init(x,y);init(y,x);}dfs(1,0);LL ans=0;for (LL u=1;u<=n;u++) ans=add(ans,f[u][k]);printf("%lld\n",ans);return 0;
}
  相关解决方案