到现在也只会照着std打板子..

虽然这样,树链剖分还是一个非常优雅的算法。


前置芝士:$DFS$,线段树

树链剖分可以把树上的区间操作通过把树剖成一条条链,利用线段树数据结构进行维护,从而达到$O(nlogn)$的优秀时间复杂度。

比如这样的操作:

在一棵树上,将$x$到$y$路径上点的点权加上$w$,并要求支持查询两个点$x,y$路径间的点权和。

乍一看,两个操作都很简单。修改操作可以用树上差分$O(1)$乱搞,静态查询可以用$LCA$完成。

但是合起来就没有办法了:每次查询之前都需要$O(n)$预处理,数据略大直接$T$飞。

于是树剖出场了。


区间修改&查询是线段树的强项,但是它只能对一段连续的区间进行查询。于是我们需要想办法让树上需要操作的路径变成一段连续的区间。

引入一个概念:重儿子,也就是一个节点的儿子中$size$最大的。连接到重儿子的边即为重边

重儿子组成的,就是重链

比如在这棵树中,连续的红边组成的就是一条条重链。我们用$top[u]$记录节点$u$所在重链的顶端。特别地,没有被重边连接的节点,$top[u]=u$,即它们所在重链的顶端就是自身。注意到,当$u$是一条重链的顶端($top[u]=u$)时,它的父节点一定在另一条重链上

始终记住我们的目标:把在树上区间操作转化为在一段连续的区间进行操作。

考虑如何用$DFS$给树上的每个节点在区间内找到一个合适的位置。我们发现,从根节点出发,优先走重边,这样的$dfs$序似乎有点特殊。

例如上图,优先走重边的$dfs$序为:$124798356$。很显然,这样的$dfs$序满足同一条重链上的点$dfs$序连续。所以用线段树维护的,就是重链上的信息

这样操作之后,我们可以做到的是:$O(logn)$对一条重链上的信息区间修改,区间查询。

对于两个节点$u,v$,我们可以通过不断地跳重链,直到两个节点在同一条重链上。这个是很好实现的,因为只需要跳到$fa[top[u]]$,就到了一条新的重链。

代码实现仅树剖部分是不麻烦的。我们需要维护的信息有$dep$(节点深度),$fa$(父节点),$son$(重儿子),$sz$(子树节点数,用来判重儿子),这些可以用一次$dfs$完成。

void dfs1(int u,int f,int d)//fa,dep,son,sz
{
    fa[u]=f;
    dep[u]=d;
    sz[u]=1;
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v!=f)
        {
            dfs1(v,u,d+1);
            sz[u]+=sz[v];
            if(sz[v]>sz[son[u]])son[u]=v;
        }

    }
}

接下来,就需要把这棵树每个节点压到线段树维护的序列的一个位置了。就像上文说的一样,按照优先重边的$dfs$序压入线段树即可。于是记录一个$id[i]$表示原树中节点$i$对应的线段树中的下标。$rk[i]$反过来记录线段树中下标为$i$的原数编号。

由于预处理了父节点,所以$dfs2$传参只需要$u$(当前节点)和$t$(当前重链顶端节点)。在遍历儿子之前先$dfs2(son[u],t)$,因为$u$和$u$的重儿子在同一条重链上。接下来才遍历轻(非重)儿子$v$,但是传参为$dfs2(v,v)$,因为$v$就是新的一条重链的起点。

void dfs2(int u,int t)//top,id,rk
{
    top[u]=t;
    id[u]=++tot;
    rk[tot]=u;
    if(!son[u])return;
    dfs2(son[u],t);
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v!=fa[u]&&v!=son[u])
            dfs2(v,v);
    }
}

再回到最开始的问题:

在一棵树上,将$x$到$y$路径上点的点权加上$w$,并要求支持查询两个点$x,y$路径间的点权和。

答案就显得很明了了。

如果是查询,先保证$dep[x]>dep[y]$,然后就和$LCA$类似的,利用重链加速:每次把$[top[x],x]$这条重链的和累加到答案上,再使$x$跳到另一条重链上,即$x=fa[top[x]]$,直到$x,y$在同一条重链上,再把两个点之间的信息统计累加一下即可。

int getsum(int x,int y)
{
    int res=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        sum=0;
        asksum(1,id[top[x]],id[x]);
        (res+=sum)%=mod;
        x=fa[top[x]];
    }
    if(id[x]>id[y])swap(x,y);
    sum=0;
    asksum(1,id[x],id[y]);
    (res+=sum)%=mod;
    return res;
}

修改同理。

于是我们发现,虽然我们采用了优先重边的$dfs$序,但它毕竟遍历的都是自己的儿子节点。所以...还可以支持子树操作。因为一棵子树在重边优先的$dfs$序中编号也是连续的。并且这个编号很容易算,因为我们维护了一个$sz$信息。所以树中$x$节点的子树对应的就是线段树维护的$[id[x],id[x]+sz[x]-1]$这个区间

于是还是板子一般的线段树区间修改&查询。


可以注意到线段树部分基本没讲,因为每个人写线段树的方法可能不太一样,蒟蒻我分享的只是树剖的思想。

另外,为什么树剖每次操作是$O(logn)$呢?利用线段树的子树操作自然是$O(logn)$,剩下的就是那个像$LCA$一样的跳重链。

证明:从任意节点向根节点跳重链,经过的重链和轻边(非重边)都是$log$级别的。

考虑到每走一条轻边,子树大小至少翻倍,否则这就不是条轻边了。于是经过的轻边就最多为$log_2 n$条。而重链和轻边的交替出现的,所以数量也在这个级别。

于是每次操作就只有$O(logn)$的时间复杂度。

模板题

以下是代码

#include<bits/stdc++.h>
#define int long long
#define ls (k<<1)
#define rs (k<<1|1)
using namespace std;
const int N=1e5+10;
struct node
{
    int l,r,w,f;
}t[N<<2];
int a[N];
int n,m,r,mod;
int sum;
int head[N<<1],to[N<<1],nxt[N<<1],cnt;
int sz[N],fa[N],dep[N],son[N];
int top[N],id[N],rk[N],tot;
inline int read()
{
   int x=0,f=1;
   char ch=getchar();
   while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
   while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
   return x*f;
}
void add(int u,int v)
{
    cnt++;
    to[cnt]=v;
    nxt[cnt]=head[u];
    head[u]=cnt;
}
void dfs1(int u,int f)
{
    fa[u]=f;
    sz[u]=1;
    dep[u]=dep[f]+1;
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==f)continue;
        dfs1(v,u);
        sz[u]+=sz[v];
        if(sz[v]>sz[son[u]])son[u]=v;
    }
    return;
}
void dfs2(int u,int t)
{
    top[u]=t;
    id[u]=++tot;
    rk[tot]=u;
    if(!son[u])return;
    dfs2(son[u],t);
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v!=fa[u]&&v!=son[u])dfs2(v,v);//新的重链 
    }
}
void build(int k,int l,int r)
{
    t[k].l=l,t[k].r=r;
    if(l==r)
    {
        t[k].w=a[rk[l]];
        return;
    }
    int m=l+r>>1;
    build(ls,l,m);
    build(rs,m+1,r);
    t[k].w=t[ls].w+t[rs].w;
    return;
}
void down(int k)
{
    t[ls].w+=(t[ls].r-t[ls].l+1)*t[k].f;
    t[rs].w+=(t[rs].r-t[rs].l+1)*t[k].f;
    t[ls].f+=t[k].f;
    t[rs].f+=t[k].f;
    t[k].f=0;
}
void addsum(int k,int x,int y,int p)
{
    int l=t[k].l,r=t[k].r;
    if(x<=l&&r<=y)
    {
        t[k].w+=(r-l+1)*p;
        t[k].f+=p;
        return;
    }
    down(k);
    int m=l+r>>1;
    if(x<=m)addsum(ls,x,y,p);
    if(y>m)addsum(rs,x,y,p);
    t[k].w=t[ls].w+t[rs].w;
    return;
}
void asksum(int k,int x,int y)
{
    int l=t[k].l,r=t[k].r;
    if(x<=l&&r<=y)
    {
        sum+=t[k].w;
        return;
    }
    down(k);
    int m=l+r>>1;
    if(x<=m)asksum(ls,x,y);
    if(y>m)asksum(rs,x,y);
    t[k].w=t[ls].w+t[rs].w;
    return;
}
//-----------------------------
int getsum(int x,int y)
{
    int res=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        sum=0;
        asksum(1,id[top[x]],id[x]);
        (res+=sum)%=mod;
        x=fa[top[x]];
    }
    if(id[x]>id[y])swap(x,y);
    sum=0;
    asksum(1,id[x],id[y]);
    (res+=sum)%=mod;
    return res;
}
void update(int x,int y,int p)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        addsum(1,id[top[x]],id[x],p);
        x=fa[top[x]];
    }
    if(id[x]>id[y])swap(x,y);
    addsum(1,id[x],id[y],p);
    return;
}
signed main()
{
    n=read(),m=read(),r=read(),mod=read();
    for(int i=1;i<=n;i++)a[i]=read();
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read();
        add(x,y),add(y,x);
    }
    dfs1(r,0);
    dfs2(r,r);
    build(1,1,n);
    for(int i=1;i<=m;i++)
    {
        int x,y,z;
        int opt=read();
        if(opt==1)
        {
            x=read(),y=read(),z=read();
            update(x,y,z);
        }
        if(opt==2)
        {
            x=read(),y=read();
            printf("%lld\n",getsum(x,y)%mod);
        }
        if(opt==3)
        {
            x=read(),z=read();
            addsum(1,id[x],id[x]+sz[x]-1,z);
        }
        if(opt==4)
        {
            x=read();
            sum=0;asksum(1,id[x],id[x]+sz[x]-1);
            printf("%lld\n",sum%mod);
        }
    }
    return 0;
}

代码的确是长,也不算容易调,但是真正妙的是利用轻重链的思想进行的化树为链。

感谢阅读。

最后修改:2020 年 08 月 17 日 08 : 10 PM