树状数组
题意:一个树,以树枝连接两个点的形式给出,固定以1为整棵树的根。苹果长在树的节点上,节点上只可能0或1个苹果,一开始每个节点都有1个苹果
有两种操作,C表示更改某个节点的苹果数,0变1,1变0。Q表示查询,某个节点(包括)的子树上一共有多少个苹果
这题是选拔赛时候的题,一看,单点修改,区间查询?线段树?后来一直没想出来,今天看解题报告才明白真的是,不过是树状数组(不过按照理论来将,凡是树状数组能解决的问题,线段树都可以解决,反之不然)。整个问题最难的时候怎么映射成树状数组,映射后,只是树状数组的模板操作
1.首先数据太大,不用显式建树,只是用vector来保存边的信息,仅仅是利用vector来遍历二叉树,而且是后序遍历
2.
后序遍历二叉树,并按照遍历顺序重新给节点标号,那么有几个较为容易理解的结论
1.根节点的标号一定比所有子孙后代节点的标号大
2.但是标号比根节点小的节点不一定是根节点的子孙后代
3.根节点的子孙后代一定是和根节点的标号是连着的,根节点的标号为a,子孙后代节点标号范围在[b,a-1],是一段连续的区间
上面这3点比较容易想到,基于上面这3点,我们不难提出一个问题,怎么确定哪些标号是根节点的子孙后代呢?即第3条结论的b要怎么确定呢?
用时间戳来记录,或者说简单点就是访问的顺序(或说)深度来表示
起始时间为0,每访问到一个节点其实就是花费了一个时间,记录下第1次访问到该节点的时间。另外,当最后一次回到该节点时(即访问完所有的子孙后代),
记录一下时间,其实也就是后序遍历时给这个节点标号,这两步是同时进行的。
那么我们就可以知道了一个节点第一次被访问的时间,和最后离开该节点的时间,这两个时间相减得到一个时间差,在这段时间里我们只做了什么?
没错,就是访问了该节点的子孙后代!所有,标号在该差值范围内的节点就是该节点的子孙后代
所以查询的时候,我们要的不是前缀和,只是要区间和,而区间和可以由两个前缀和相减所得
sum[num[u]] - sum[num[u]-x'-1] ,x'=last[u]-first[u]
今天很累不想说了,剩下的看代码吧,应该能看懂的。。。。
#include <cstdio> #include <cstring> #include <vector> using namespace std; #define N 100010int time,ccount; bool vis[N]; //dfs的标记数组 bool app[N]; //每个点是否有苹果 typedef vector<int> INF; vector<INF> e(N); int first[N],last[N],num[N]; //第一次访问和离开一个节点的时间,后序遍历该节点的编号 int C[N]; //映射后的树状数组int lowbit(int x) {return x&(-x); }void init(int n) {for(int i=0; i<=n; i++){app[i]=true;C[i]=lowbit(i); //因为一开每个点都有1个苹果所以可以初始化不用沿路径更新 } }void build(int n) {for(int i=1; i<n; i++){int u,v;scanf("%d%d",&u,&v);e[u].push_back(v);e[v].push_back(u);} }void dfs(int u) {vis[u]=true;first[u]=++time; //记录第一次访问该节点的时间int size=e[u].size();for(int i=0; i<size; i++){int v=e[u][i];if(!vis[v]) dfs(v);}last[u]=time; //记录最后一次访问该节点的时间num[u]=++ccount; }int sum(int pos) {int ans=0;while(pos){ans += C[pos];pos -= lowbit(pos);}return ans ; }void updata(int pos , int n) {int val;if(!app[pos]) { app[pos]=true; val=1; } else { app[pos]=false; val=-1; }while(pos<=n){C[pos] += val;pos += lowbit(pos);} }int main() {int n;scanf("%d",&n);build(n); //建树存边 init(n);time=ccount=0; //时间戳和节点标号memset(vis,false,sizeof(vis));dfs(1);char str[3]; int M,m;scanf("%d",&M);for(int i=0; i<M; i++){scanf("%s%d",str,&m);if(str[0]=='C') updata(num[m],N-10);else{int s1=sum( num[m] );int s2=sum( num[m] -(last[m]-first[m]) -1 );printf("%d\n",s1-s2);}}return 0; }