SPOJ Combat on a tree

2016.03.12 14:10 Sat | 35次阅读 | 旧日oi | 固定链接 | 源码

题目大意

给出一颗n($n\le 10^5$)个点的树,每个点是黑色或者白色,每次可以选择把一个白点到1的路径上的所有白点都染黑,问先手能否必胜,能的话可以选择哪些点。

解题思路

sg函数+Trie树的合并

首先可以看出一个sg函数的模型,每次选一个白点,把树分成几个部分,变成子问题。那么问题是如何算出以某个点为根的子树的sg值。
设当前点为u,sg[u]表示u点的sg值,f[u][v]表示在u的子树中选择v并删掉u到v的路径上的白点后其他子树的sg值的和(和指异或和)。
以下sigma,+均表示异或。
则sg[u]=mex(f[u][w]),w是u子树中的白点。
那么如果u点是白色,有f[u][u]=sigma(sg[v]),v是u的儿子。
对于一个点u的儿子v中的某个白点w,有
f[u][w]=sigma(sg[son[u]]) + sg[v] + f[v][w],
暴力求这个东西的话是n^2的。
观察到对于u的一个儿子v,v的子树中的白点都会被异或上同一个值。
我们可以用数据结构去维护,用什么能支持打异或标记呢?
Trie树!只要记录下标记的值,看某个节点的儿子是否需要交换就行了。
那问题还剩一个,就是如何求这个mex。
我们考虑把Trie树合并。
启发式的话是log^2n,而事实上可以做到logn,方法和线段树合并是相同的。
然而我不会证……

my code

#include <bits/stdc++.h>
using namespace std;
#define D 20
#define N 100005
#define M 20000005
int h[N],tp=1;
struct edge{int to,next,val;}e[N<<1];
void ae(int u,int v,int w=0){e[++tp].to=v;e[tp].next=h[u];e[tp].val=w;h[u]=tp;}
int n,col[N],use[N],root[N],dp[N];
int ans[N],cnt;
int tot,son[M][2],siz[M],tag[M];
void insert(int &x,int y)
{
    if(!x) x=++tot;
    int p=x;siz[p]++;
    for(int now,i=D-1;i>=0;i--)
    {
        now=y>>i&1;
        if(!son[p][now]) son[p][now]=++tot;
        p=son[p][now];siz[p]++;
    }
}
void work(int x,int v,int d)
{
    if(!x) return;
    if(v>>d&1) swap(son[x][0],son[x][1]);
    tag[x]^=v;
}
void pushdown(int x,int d)
{
    if(d==0) return;
    if(tag[x])
    {
        work(son[x][0],tag[x],d-1);
        work(son[x][1],tag[x],d-1);
        tag[x]=0;
    }
}
void merge(int &x,int y,int d)
{
    if(!x||!y) {x=x+y;return;}
    if(d==0) {siz[x]=min(1,siz[x]+siz[y]);return;}
    pushdown(x,d-1);
    pushdown(y,d-1);
    merge(son[x][0],son[y][0],d-1);
    merge(son[x][1],son[y][1],d-1);
    siz[x]=siz[son[x][0]]+siz[son[x][1]];
}
int find(int x)
{
    int ans=0;
    for(int i=D-1;i>=0;i--)
    {
        pushdown(x,i);
        if(siz[son[x][0]]==(1<<i)) ans+=(1<<i),x=son[x][1];
        else x=son[x][0];
    }
    return ans;
}
int dfs(int u,int f)
{
    for(int i=h[u];i;i=e[i].next)
    if(e[i].to!=f) use[u]^=dfs(e[i].to,u);
    if(!col[u]) insert(root[u],use[u]);
    for(int v,i=h[u];i;i=e[i].next)
    if((v=e[i].to)!=f)
    {
        work(root[v],use[u]^dp[v],D-1);
        merge(root[u],root[v],D);
    }
    return dp[u]=find(root[u]);
}
void dfs2(int u,int f,int d)
{
    if(!col[u]&&((use[u]^d)==0)) ans[++cnt]=u;
    for(int v,i=h[u];i;i=e[i].next)
    if((v=e[i].to)!=f) dfs2(v,u,d^use[u]^dp[v]);
}
int main()
{
    cin>>n;
    for(int i=1;i<=n;i++) scanf("%d",&col[i]);
    for(int u,v,i=1;i<n;i++)
    {
        scanf("%d%d",&u,&v);
        ae(u,v);ae(v,u);
    }
    dfs(1,0);
    dfs2(1,0,0);
    sort(ans+1,ans+cnt+1);
    for(int i=1;i<=cnt;i++) printf("%d\n",ans[i]);
}