树链剖分
树链剖分的核心思想是把一棵树变为一个序列,树中的路径全部转化为logn段连续的区间,接下只要去用线段树或分块去维护即可。剖分有几个定义:轻重儿子:对子树的点数进行排序,最多点的即为重儿子,其它全为轻儿子;轻重边:父节点向轻重儿子的边;重链:顺着重儿子开始一直遍历到底部。我们在第一次遍历时候记录重儿子,深度,父节点和大小。第二次dfs去完成最祖先节点的标记和序列的初始化,区间操作只要查询id即可直接进行修改。对于树上的路径我们用爬山法一步一步向上,把重链全部给一个一个的遍历。
给定一棵树,树中包含 n 个节点(编号 1∼n),其中第 i 个节点的权值为 ai。
初始时,1 号节点为树的根节点。
现在要对该树进行 m 次操作,操作分为以下 4 种类型:
1 u v k
,修改路径上节点权值,将节点 u 和节点 v 之间路径上的所有节点(包括这两个节点)的权值增加 k。2 u k
,修改子树上节点权值,将以节点 u 为根的子树上的所有节点的权值增加 k。3 u v
,询问路径,询问节点 u 和节点 v 之间路径上的所有节点(包括这两个节点)的权值和。4 u
,询问子树,询问以节点 u 为根的子树上的所有节点的权值和。
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long LL;
const int N = 1e5+10,M = 2e5+10;
int n,m;
int id[N],nw[N],cnt;
int dep[N],sz[N],top[N],fa[N],son[N];
int w[N],head[N],idx;
struct Edge{
int next,to;
}edge[M];
struct Tree{
int l,r;
LL add,sum;
}tr[N<<2];
void add(int a,int b) {
edge[idx]=(Edge){head[a],b},head[a]=idx++;
}
void dfs1(int u,int father,int depth)
{
dep[u]=depth,fa[u]=father,sz[u]=1;
for(int i=head[u];~i;i=edge[i].next)
{
int j=edge[i].to;
if(j==father) continue;
dfs1(j,u,depth+1);
sz[u]+=sz[j];
if(sz[j]>sz[son[u]]) son[u]=j;
}
}
void dfs2(int u,int t)
{
id[u]=++cnt,nw[cnt]=w[u],top[u]=t;
if(!son[u]) return;
dfs2(son[u],t);
for(int i=head[u];~i;i=edge[i].next)
{
int j=edge[i].to;
if(j==fa[u]||j==son[u]) continue;
dfs2(j,j);
}
}
void pushup(int u) {
tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
}
void pushdown(int u)
{
auto &root=tr[u],&left=tr[u<<1],&right=tr[u<<1|1];
if(root.add)
{
left.add+=root.add,left.sum+=root.add*(left.r-left.l+1);
right.add+=root.add,right.sum+=root.add*(right.r-right.l+1);
root.add=0;
}
}
void build(int u,int l,int r)
{
tr[u]=(Tree){l,r,0,nw[r]};
if(l==r) return;
int mid=l+r>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
pushup(u);
}
void update(int u,int l,int r,int k)
{
if(l<=tr[u].l&&tr[u].r<=r)
{
tr[u].add+=k;
tr[u].sum+=k*(tr[u].r-tr[u].l+1);
return;
}
pushdown(u);
int mid=tr[u].l+tr[u].r>>1;
if(l<=mid) update(u<<1,l,r,k);
if(r>mid) update(u<<1|1,l,r,k);
pushup(u);
}
LL query(int u,int l,int r)
{
if(l<=tr[u].l&&r>=tr[u].r) return tr[u].sum;
pushdown(u);
int mid=tr[u].l+tr[u].r>>1;
LL res=0;
if(l<=mid) res+=query(u<<1,l,r);
if(r>mid) res+=query(u<<1|1,l,r);
return res;
}
void update_path(int u,int v,int k)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u,v);
update(1,id[top[u]],id[u],k);
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
update(1,id[v],id[u],k);
}
LL query_path(int u,int v)
{
LL res=0;
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u,v);
res+=query(1,id[top[u]],id[u]);
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
res+=query(1,id[v],id[u]);
return res;
}
void update_tree(int u,int k)
{
update(1,id[u],id[u]+sz[u]-1,k);
}
LL query_tree(int u)
{
return query(1,id[u],id[u]+sz[u]-1);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
memset(head,-1,sizeof head);
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
add(a,b),add(b,a);
}
dfs1(1,-1,1);
dfs2(1,1);
build(1,1,n);
scanf("%d",&m);
while(m--)
{
int t,u,v,k;
scanf("%d%d",&t,&u);
if(t==1)
{
scanf("%d%d",&v,&k);
update_path(u,v,k);
}
else if(t==2)
{
scanf("%d",&k);
update_tree(u,k);
}
else if(t==3)
{
scanf("%d",&v);
printf("%lld\n",query_path(u,v));
}
else printf("%lld\n",query_tree(u));
}
return 0;
}