[BZOJ 4196]软件包管理器

danihao123 posted @ 2016年1月03日 13:14 in 题解 with tags bzoj 树链剖分 NOI , 505 阅读
转载请注明出处:http://danihao123.is-programmer.com/

终于A了!

在CodeVS,洛谷甚至UOJ上各种A

但是在BZOJ上各种TLE。BZOJ评测姬自带10倍常数?

这题处理安装很简单,一直溯到根。

删除……注意一下树剖的一些神奇性质。

我们都知道同一重链的点在线段树中是连续的,但是DFS在访问完重链之后访问的是什么?

当然是上一个结点的轻儿子了。

所以说一个子树的所有节点在线段树中也是连续的!

所以问题也就迎刃而解了~

下面是喜闻乐见的代码:

/**************************************************************
    Problem: 4196
    User: danihao123
    Language: C++
    Result: Accepted
    Time:9628 ms
    Memory:10884 kb
****************************************************************/
 
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
#include <cctype>
using namespace std;
int n;
const int maxn=100001;
vector<int> G[maxn];
inline void Add_Edge(int x,int y){
    G[x].push_back(y);
    return;
}
// 线段树
int sumv[maxn*4],unsumv[maxn*4],setv[maxn*4];
int _sum,ql,qr;
bool d;
void _query(int o,int L,int R){
    if(setv[o]>=0){
        _sum+=(d?setv[o]:(setv[o]^1))*(min(qr,R)-max(ql,L)+1);
    }else{
        if(ql<=L && R<=qr){
            _sum+=d?sumv[o]:unsumv[o];
        }else{
            int M=L+(R-L)/2;
            if(ql<=M)
                _query(o*2,L,M);
            if(qr>M)
                _query(o*2+1,M+1,R);
        }
    }
    return;
}
void maintain(int o,int L,int R){
    if(R>L){
        sumv[o]=sumv[o*2]+sumv[o*2+1];
        unsumv[o]=unsumv[o*2]+unsumv[o*2+1];
    }
    if(setv[o]>=0){
        sumv[o]=setv[o]*(R-L+1);
        unsumv[o]=(setv[o]^1)*(R-L+1);
    }
    return;
}
void pushdown(int o){
    if(setv[o]>=0){
        setv[o*2]=setv[o*2+1]=setv[o];
        setv[o]=-1;
    }
    return;
}
int yl,yr,v;
void _change(int o,int L,int R){
    int lc=o*2,rc=o*2+1;
    if(yl<=L && R<=yr){
        setv[o]=v;
    }else{
        pushdown(o);
        int M=L+(R-L)/2;
        if(yl<=M)
            _change(lc,L,M);
        else
            maintain(lc,L,M);
        if(yr>M)
            _change(rc,M+1,R);
        else
            maintain(rc,M+1,R);
    }
    maintain(o,L,R);
    return;
}
inline int query(int l,int r,bool done){
    _sum=0;
    ql=l;
    qr=r;
    d=done;
    _query(1,1,n);
    return _sum;
}
inline void change(int l,int r,int value){
    yl=l;
    yr=r;
    v=value;
    _change(1,1,n);
    return;
}
bool vis[maxn];
int son[maxn],fa[maxn],siz[maxn];
void dfs_1(int x,int father){
    fa[x]=father;
    siz[x]=1;
    int max_siz=0,temp;
    for(int i=0;i<G[x].size();i++){
        temp=G[x][i];
        dfs_1(temp,x);
        siz[x]+=siz[temp];
        if(siz[temp]>max_siz){
            max_siz=siz[temp];
            son[x]=temp;
        }
    }
    return;
}
int lable=0;
int tid[maxn],top[maxn];
void dfs_2(int x,int a){
    vis[x]=true;
    tid[x]=++lable;
    top[x]=a;
    int temp;
    if(son[x]>=0)
        dfs_2(son[x],a);
    else
        return;
    for(int i=0;i<G[x].size();i++){
        temp=G[x][i];
        if(!vis[temp]){
            dfs_2(temp,temp);
        }
    }
    return;
}
int install(int x){
    register int ans=0;
    do{
        ans+=query(tid[top[x]],tid[x],false);
        change(tid[top[x]],tid[x],1);
        x=fa[top[x]];
    }while(x>=0);
    return ans;
}
int uninstall(int x){
    register int ans,k=tid[x]+siz[x]-1;
    ans=query(tid[x],k,true);
    change(tid[x],k,0);
    return ans;
}
inline int readint(){
    char c=getchar();
    register int x=0;
    while(!isdigit(c))
        c=getchar();
    while(isdigit(c)){
        x=x*10+c-'0';
        c=getchar();
    }
    return x;
}
int bf[10];
inline void writeint(int x){
    register int p=0;
    if(x==0){
        bf[p++]=0;
    }else{
        while(x){
            bf[p++]=x%10;
            x/=10;
        }
    }
    for(register int i=p-1;i>=0;i--)
        putchar('0'+bf[i]);
}
int main(){
    register int i;
    int m,x;
    char buf[10];
    n=readint();
    for(i=1;i<n;i++){
        x=readint();
        Add_Edge(x,i);
    }
    memset(sumv,0,sizeof(sumv));
    memset(setv,-1,sizeof(setv));
    setv[1]=0;
    memset(son,-1,sizeof(son));
    memset(vis,0,sizeof(vis));
    dfs_1(0,-1);
    dfs_2(0,0);
    scanf("%d",&m);
    for(i=0;i<m;i++){
        scanf("\n%s",buf);
        x=readint();
        if(buf[0]=='i')
            writeint(install(x));
        else
            writeint(uninstall(x));
        putchar('\n');
    }
    return 0;
}

登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter