当前位置: 代码迷 >> 综合 >> CSU 1811 Tree Intersection (map启发式合并+树形DP)
  详细解决方案

CSU 1811 Tree Intersection (map启发式合并+树形DP)

热度:48   发布时间:2023-11-15 12:27:11.0

题目链接:http://acm.csu.edu.cn:20080/csuoj/problemset/problem?pid=1811

题目大意:

给定一颗树,
树上每个节点都有颜色,问
对于每条边,其两端子树颜色集合交集的大小。

题目分析: 

相比于上次用线段树动态开点类似的方法
来做合并,这次用map做启发式合并,这样的
复杂度会因为每次都合并小的而降维。
这道题同时还嵌套着树形DP的思想。
对于当前节点和其子节点维护好的状态,
分为X和Y,Y为当前考虑的子树的状态,
当然可以因为X和Y的大小关系而swap一下,达到降低
复杂度的效果。
下面考虑如何合并这两个状态:
对于特定的一种颜色,如果X原来为零且X+Y集合中该颜色数小于总数
(默认Y中该颜色数不为零不然不会扫描到),那么总体计数要加一。
如果X原来不为零且X+Y中包含了该颜色的所有数量,
那么总体计数就要减一,消去这一影响。

#include<bits/stdc++.h>
using namespace std;#define debug puts("YES");
#define rep(x,y,z) for(int (x)=(y);(x)<(z);(x)++)
#define ll long long#define lrt int l,int r,int rt
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define root l,r,rt
#define mst(a,b) memset((a),(b),sizeof(a))
#define pii pair<int,int>
#define fi first
#define se second
#define mk(x,y) make_pair(x,y)
const int mod=1e9+7;
const int maxn=1e5+5;
const int ub=1e6;
const double inf=1e-4;
ll powmod(ll x,ll y){ll t; for(t=1;y;y>>=1,x=x*x%mod) if(y&1) t=t*x%mod; return t;}
ll gcd(ll x,ll y){return y?gcd(y,x%y):x;}
/*
相比于上次用线段树动态开点类似的方法
来做合并,这次用map做启发式合并,这样的
复杂度会因为每次都合并小的而降维。
这道题同时还嵌套着树形DP的思想。
对于当前节点和其子节点维护好的状态,
分为X和Y,Y为当前考虑的子树的状态,
当然可以因为X和Y的大小关系而swap一下,达到降低
复杂度的效果。
下面考虑如何合并这两个状态:
对于特定的一种颜色,如果X原来为零且X+Y集合中该颜色数小于总数
(默认Y中该颜色数不为零不然不会扫描到),那么总体计数要加一。
如果X原来不为零且X+Y中包含了该颜色的所有数量,
那么总体计数就要减一,消去这一影响。
*/
int n,a[maxn];
map<int,int> mp[maxn];///
map<int,int>::iterator it;
int cnt[maxn],ans[maxn],sum[maxn];
///前式链向星
struct node{int u,nxt,id;
}e[maxn<<1];
int head[maxn<<1],tot=0;
void init(){mst(head,-1),tot=0;
}
void add(int x,int y,int id){e[tot]=node{y,head[x],id};head[x]=tot++;
}
///
void dfs(int u,int fa,int id){mp[u][a[u]]++;cnt[u]=mp[u][a[u]]<sum[a[u]]?1:0;for(int i=head[u];~i;i=e[i].nxt){int v=e[i].u;if(v==fa) continue;dfs(v,u,e[i].id);if(mp[u].size()<mp[v].size()) swap(mp[u],mp[v]),swap(cnt[u],cnt[v]);for(it=mp[v].begin();it!=mp[v].end();it++){int key=it->fi,val=it->se,mv=mp[u][key]+val;if(val==0) continue;///事实上这句可以省略if(mv<sum[key]&&mp[u][key]==0) cnt[u]++;if(mv==sum[key]&&mp[u][key]>0) cnt[u]--;mp[u][key]+=val;}}if(id) ans[id]=cnt[u];
}int main(){while(scanf("%d",&n)!=EOF){init(),mst(sum,0),mst(cnt,0);rep(i,1,n+1) scanf("%d",&a[i]),sum[a[i]]++,mp[i].clear();rep(i,1,n){int x,y;scanf("%d%d",&x,&y);add(x,y,i),add(y,x,i);}dfs(1,-1,0);rep(i,1,n) printf("%d\n",ans[i]);}return 0;
}

 

  相关解决方案