【BZOJ2164】采矿

2016.03.18 16:08 Fri | 33次阅读 | 旧日oi | 固定链接 | 源码

题解

树链剖分+线段树+dfs序

对于子树部分的查询,我们可以搞出dfs序后用线段树维护,对每个节点存一个大小为m的数组,合并子节点的时候做一遍背包就好了,每次合并复杂度m^2,这部分复杂度就是logn*m^2
对于链上的询问,我们只要维护出这条链上的对于1...m的每一个点的最大值就可以了,这步用树剖解决就好了。复杂度mlogn^2.
挺无聊的数据结构题……

my code

#include <bits/stdc++.h>
using namespace std;
#define N 20005
#define X 65536
#define Y 2147483647
#define ll long long
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
int n,m,A,B,Q,C,op,p,u,v;
int h[N],tp=1;
struct edge{int to,next,val;}e[N];
void ae(int u,int v,int w=0){e[++tp].to=v;e[tp].next=h[u];e[tp].val=w;h[u]=tp;}
int rk[N],ed[N],siz[N],son[N],top[N],fa[N],dep[N],tid[N],tim;
ll st[N][55];
int work()
{
    A=((A^B)+B/X+B*X)&Y;
    B=((A^B)+A/X+A*X)&Y;
    return (A^B)%Q;
}
void dfs(int u)
{
    siz[u]=1;
    for(int i=h[u];i;i=e[i].next)
    {
        fa[e[i].to]=u;
        dep[e[i].to]=dep[u]+1;
        dfs(e[i].to);
        siz[u]+=siz[e[i].to];
        if(siz[son[u]]<siz[e[i].to]) son[u]=e[i].to;
    }
}
void dfs(int u,int tp)
{
    rk[u]=++tim;tid[tim]=u;top[u]=tp;
    if(son[u]) dfs(son[u],tp);
    for(int i=h[u];i;i=e[i].next)
    if(e[i].to!=son[u]) dfs(e[i].to,e[i].to);
    ed[u]=tim;
}
struct node{
    ll mx[55],v[55];
}a[N<<2],tmp;
void pushup(node &a,node b,node c)
{
    for(int i=1;i<=m;i++) 
    {
        a.v[i]=max(b.v[i],c.v[i]);a.mx[i]=0;
        for(int j=0;j<=i;j++) a.mx[i]=max(a.mx[i],b.mx[j]+c.mx[i-j]);
    }
}
void build(int l,int r,int rt)
{
    if(l==r)
    {
        for(int i=1;i<=m;i++) a[rt].mx[i]=a[rt].v[i]=st[tid[l]][i];
        return;
    }
    int mid=(l+r)>>1;
    build(lson);build(rson);
    pushup(a[rt],a[rt<<1],a[rt<<1|1]);
}
void update(int l,int r,int rt,int p)
{
    if(l==r)
    {
        for(int i=1;i<=m;i++) a[rt].mx[i]=a[rt].v[i]=st[tid[l]][i];
        return;
    }
    int mid=(l+r)>>1;
    if(p<=mid) update(lson,p);
    else update(rson,p);
    pushup(a[rt],a[rt<<1],a[rt<<1|1]);
}
node query(int l,int r,int rt,int L,int R)
{
    if(L==l&&R==r) 
        return a[rt];
    int mid=(l+r)>>1;
    if(R<=mid) return query(lson,L,R);
    else if(L>mid) return query(rson,L,R);
    else 
    {
        node tmp;memset(&tmp,0,sizeof(tmp));
        pushup(tmp,query(lson,L,mid),query(rson,mid+1,R));
        return tmp;
    }
}
ll get_mx(int l,int r,int rt,int L,int R,int x)
{
    if(L==l&&R==r) return a[rt].v[x];
    int mid=(l+r)>>1;
    if(R<=mid) return get_mx(lson,L,R,x);
    else if(L>mid) return get_mx(rson,L,R,x);
    else return max(get_mx(lson,L,mid,x),get_mx(rson,mid+1,R,x));
}
ll get_mx(int x,int y,int z)
{
    ll ret=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ret=max(ret,get_mx(1,n,1,rk[top[x]],rk[x],z));
        x=fa[top[x]];
    }
    if(dep[x]<dep[y]) swap(x,y);
    ret=max(ret,get_mx(1,n,1,rk[y],rk[x],z));
    return ret;
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("tt.in","r",stdin);
    #endif
    cin>>n>>m>>A>>B>>Q;
    for(int f,i=2;i<=n;i++)
    {
        scanf("%d",&f);
        ae(f,i);
    }
    dfs(1);dfs(1,1);
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=m;j++) st[i][j]=work();
        sort(st[i]+1,st[i]+m+1);
    }
    build(1,n,1);
    for(cin>>C;C--;)
    {
        scanf("%d",&op);
        if(op==0)
        {
            scanf("%d",&p);
            for(int i=1;i<=m;i++) st[p][i]=work();
            sort(st[p]+1,st[p]+m+1);
            update(1,n,1,rk[p]);
        }
        else
        {
            scanf("%d%d",&u,&v);
            tmp=query(1,n,1,rk[u],ed[u]);
            ll ans=tmp.mx[m];
            if(v!=u) for(int i=1;i<=m;i++) ans=max(ans,tmp.mx[m-i]+get_mx(fa[u],v,i));
            printf("%lld\n",ans);
        }
    }
}