题意:给一棵 n n n 个点的树,每条边需要染成黑白两种颜色中的一种。给出 m m m 个条件,每个条件给出 u , v u,v u,v,其中 u u u 是 v v v 的祖先,要求 u u u 到 v v v 的链上至少一条黑边。求方案数 模 998244353 998244353 998244353。
n , m ≤ 5 × 1 0 5 n,m\leq 5\times 10^5 n,m≤5×105
这个dp想了一上午
对于树上的一个点,考虑子树内有关的所有限制,唯一不好处理的是超出子树的部分,而这部分只需要考虑超出最短的。
定义: d p ( u , k ) dp(u,k) dp(u,k) 表示有多少种确定 u u u 子树内的边的颜色的方案,使得所有下端点在 u u u 子树内并且尚未满足的条件 的上端点的深度最大值恰好为 k k k。如果所有上述条件都满足, k = 0 k=0 k=0。
人话翻译:
考虑 u u u 子树内能影响到的条件,分为下列两种:
上端点在子树内(显然下端点就在子树内了)。如果这种条件没有满足,就永远不可能满足了,这个时候上面的定义表现为 k ≥ d e p u k\geq dep_u k≥depu,后面可以看到这部分状态是无用的。上端点是 u u u 的严格祖先,下端点在 u u u 子树内,且 u u u 到下端点这段没有黑边。此时就需要上端点到 u u u 有黑边。如果这样的条件的上端点的最大深度为 k k k,那么所有条件成立当且仅当 u u u 深度为 k k k 的祖先到 u u u 有一条黑边,处理方式后述。进行一次 dfs,每个点 u u u 先假设它没有儿子,即让 d p ( u , x ) = 1 dp(u,x)=1 dp(u,x)=1,其中 x x x 为所有下端点为 u u u 的条件的上端点的最大深度。
然后依次突然加入每个儿子,设儿子为 v v v,得到新的 dp 数组为 d p ′ dp' dp′
考虑连接儿子的这条边是黑边还是白边。
如果是黑边,对于 u u u 来说,从 v v v 子树内来的条件就全部满足了(当然要原来有机会满足),但 u u u 原来不满足的还是不满足。即
d p ( u , k ) ∑ i = 0 d e p u d p ( v , i ) dp(u,k)\sum_{i=0}^{dep_u}dp(v,i) dp(u,k)i=0∑depudp(v,i)
如果是白边,那么要同时满足两边的深度限制,即
∑ max ( i , j ) = k d p ( u , i ) d p ( v , j ) \sum_{\max(i,j)=k}dp(u,i)dp(v,j) max(i,j)=k∑dp(u,i)dp(v,j)
整理一下
d p ′ ( u , k ) = d p ( u , k ) ∑ i = 0 d e p u d p ( v , i ) + d p ( u , k ) ∑ i = 0 k d p ( v , i ) + d p ( v , k ) ∑ i = 0 k − 1 d p ( u , i ) dp'(u,k)=dp(u,k)\sum_{i=0}^{dep_u}dp(v,i)+dp(u,k)\sum_{i=0}^kdp(v,i)+dp(v,k)\sum_{i=0}^{k-1}dp(u,i) dp′(u,k)=dp(u,k)i=0∑depudp(v,i)+dp(u,k)i=0∑kdp(v,i)+dp(v,k)i=0∑k−1dp(u,i)
长这样子的式子都可以考虑线段树合并。
∑ i = 0 d e p u d p ( v , i ) \sum_{i=0}^{dep_u}dp(v,i) ∑i=0depudp(v,i) 是个常数,先算出来。
合并的时候顺便维护左边遍历过的结点的和,如果一边的结点为空,用维护的和给另一边的结点打乘法标记。递归到叶结点了再处理求和符号的边界情况。
注意维护的这个和是合并前的,要先维护再打标记。可以在递归的时候传引用。
复杂度 O ( n log n + m ) O(n\log n+m) O(nlogn+m)
#include <iostream> #include <cstdio> #include <cstring> #include <cctype> #include <vector> #define MAXN 500005 using namespace std; inline int read() { int ans=0; char c=getchar(); while (!isdigit(c)) c=getchar(); while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar(); return ans; } typedef long long ll; const int MOD=998244353; inline int add(const int& x,const int& y){return x+y>=MOD? x+y-MOD:x+y;} int n; int ch[MAXN<<5][2],sum[MAXN<<5],mul[MAXN<<5],cnt; inline void update(int x){sum[x]=add(sum[ch[x][0]],sum[ch[x][1]]);} inline void pushlzy(int x,ll v){sum[x]=sum[x]*v%MOD,mul[x]=mul[x]*v%MOD;} inline void pushdown(int x) { if (mul[x]!=1) { if (ch[x][0]) pushlzy(ch[x][0],mul[x]); if (ch[x][1]) pushlzy(ch[x][1],mul[x]); mul[x]=1; } } void modify(int& x,int l,int r,int k) { if (!x) mul[x=++cnt]=1; if (l==r) return (void)(++sum[x]); int mid=(l+r)>>1; if (k<=mid) modify(ch[x][0],l,mid,k); else modify(ch[x][1],mid+1,r,k); update(x); } int query(int x,int l,int r,int ql,int qr) { if (!x) return 0; if (ql<=l&&r<=qr) return sum[x]; if (qr<l||r<ql) return 0; pushdown(x); int mid=(l+r)>>1; return add(query(ch[x][0],l,mid,ql,qr),query(ch[x][1],mid+1,r,ql,qr)); } int merge(int x,int y,int l,int r,int& xsum,int& ysum) { if (!x&&!y) return 0; if (!x) return ysum=add(ysum,sum[y]),pushlzy(y,xsum),y; if (!y) return xsum=add(xsum,sum[x]),pushlzy(x,ysum),x; if (l==r) { ysum=add(ysum,sum[y]); int t=sum[x]; sum[x]=((ll)sum[x]*ysum+(ll)xsum*sum[y])%MOD; xsum=add(xsum,t); return x; } pushdown(x),pushdown(y); int mid=(l+r)>>1; ch[x][0]=merge(ch[x][0],ch[y][0],l,mid,xsum,ysum); ch[x][1]=merge(ch[x][1],ch[y][1],mid+1,r,xsum,ysum); update(x); return x; } vector<int> e[MAXN],lis[MAXN]; int dep[MAXN],rt[MAXN]; void dfs(int u) { int mx=0; for (int i=0;i<(int)lis[u].size();i++) mx=max(mx,dep[lis[u][i]]); modify(rt[u],0,n,mx); for (int i=0;i<(int)e[u].size();i++) if (!dep[e[u][i]]) { dep[e[u][i]]=dep[u]+1; dfs(e[u][i]); int xsum=0,ysum=query(rt[e[u][i]],0,n,0,dep[u]); rt[u]=merge(rt[u],rt[e[u][i]],0,n,xsum,ysum); } } int main() { n=read(); for (int i=1;i<n;i++) { int u,v; u=read(),v=read(); e[u].push_back(v),e[v].push_back(u); } int m=read(); while (m--) { int u,v; u=read(),v=read(); lis[v].push_back(u); } dfs(dep[1]=1); printf("%d\n",query(rt[1],0,n,0,0)); return 0; }