LOJ6681. yww 与树上的回文串

题意:给定一棵边上带 01 权值的树,求有多少对 $(x,y)$ 满足 $x<y$ 且 $x$ 到 $y$ 路径上的边权拼起来是回文串。

$n\leq 5\times 10^4$。

tag:点分治 AC自动机 根号分治。

点分治,分治中心为 $G$ 时有三种贡献可能:

  • $G\to x$ 回文。
  • $x\to G\to y$ 回文,且 $x,y$ 属于不同子树,深度相等。
  • $x\to G\to y$ 回文,且 $x,y$ 属于不同子树,深度不相等,设 $dep_x>dep_y$。

第一种情况用 hash 判断每个点对应的串是否是回文的,第二种情况用 umap 统计,第三种先建立整个分治范围内以 $G$ 为根的 AC 自动机,fail 树上的祖先都对应自己的真后缀。

考虑 $x\to G\to y$ 实际可以划分成三段:$x\to o\to G\to y$,满足 $o\to G$ 段回文,$o\to x$ 和 $G\to y$ 段相等,也就是说 $G\to y$ 是 $G\to x$ 的后缀,点 $y$ 在 fail 树上会是点 $x$ 的祖先。

引理:一个串的前缀回文串可以划分为 $O(\log n)$ 个值域不交的等差数列。

证明见 OI-wiki

先把等差数列预处理放到每个点上。在 fail 树上 dfs,维护出若干等差数列。逐棵把子树统计贡献再加入。

设当前点是 $x$ 计算贡献,则合法的 $y$ 一定在 dfs 栈中。

假设点 $x$ 长度为 $l$ 的前缀是回文的,那么 fail 树上的祖先中如果出现了 $dep_y=dep_x-l$ 的点 $y$,则 $(x,y)$ 会统计入答案。

对于维护出的等差数列首项,末项,公差分别为 $s,e,d$,表示当前长度在 $[s,e]$ 中,且 $\bmod d$ 意义下同余 $s$ 的前缀均回文,则应该统计 fail 树的祖先中原树深度为以 $a_1=dep_x-e$ 为首项,$a_2=dep_x-s$ 为末项,公差为 $d$ 中的元素的点 $y$ 的标记之和。

统计时注意到不能和之前一样逐棵统计贡献并加入,因为 Trie 树上的一棵以根节点为根的连通子树不一定在 fail 树上也是,所以一起统计之后对于每棵子树再算一次答案容斥掉。

设置一个阈值 $B$,维护一个 $B\times B$ 大小的数组 $c_{i,j}$ 表示当前 dfs 栈中的点,原树深度 $\bmod i$ 等于 $j$ 的标记之和,再维护一个大小为 $n$ 的数组 $t_i$ 表示原树深度为 $i$ 的标记和,都容易加入和删除点的贡献。

当 $d<B$ 时可以把需要计算的贡献先挂到树上,再扫一遍贡献到答案上;当 $d\ge B$ 时直接枚举合法深度在 $t$ 中统计答案。

理论上在 $B=O(\sqrt n)$ 时取得最优复杂度 $O(n\log^2 n+n\sqrt n)$。

为了稍微好写一点,可以把差分再离线计算贡献的过程换成树状数组,且 $B$ 取到 $2$ 左右的时候跑得比较快,原因是很难构造数据来卡。

#include<bits/stdc++.h>
#define For(i,a,b) for(int i=(a),i##END=(b);i<=i##END;i++)
#define Rof(i,b,a) for(int i=(b),i##END=(a);i>=i##END;i--)
#define go(u) for(int i=head[u];i;i=nxt[i])
#define pi pair<int,int>
#define fi first
#define se second
using namespace std;
inline int read(){
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    return x*f;
}
const int N=5e4+10,p1=147744151,p2=666528221,base=131;
int pw1[N],pw2[N],h1[N],h2[N],H1[N],H2[N];//hash
int n,ans;vector<pi> g[N]; 
void add(int u,int v,int w){g[u].push_back(pi(v,w));}
int used[N];
int dep[N],ok[N];//深度,是否回文 
int son[N][2],idx,vis[N];//Trie
int head[N],to[N],nxt[N],cnt;
void add(int u,int v){
    to[++cnt]=v,nxt[cnt]=head[u],head[u]=cnt;
}
int q[N];int L,R; int fail[N];//ACAM
void build(){
    For(i,0,idx)fail[i]=head[i]=0;L=1,R=0;
    For(i,0,1)if(son[0][i])q[++R]=son[0][i];
    while(L<=R){
        int u=q[L++];
        For(i,0,1)if(son[u][i])fail[son[u][i]]=son[fail[u]][i],q[++R]=(son[u][i]);
        else son[u][i]=son[fail[u]][i];
    }cnt=0;For(i,1,idx)add(fail[i],i);
}
struct node{int s,e,d;}; 
vector<node> pal[N];
int Dep[N];
#define rb(x) x.back()
int ban[N];
void prework(int u,int U,int f,int op=0){
    assert(dep[u]==Dep[U]);
    if(!op){
        pal[u]=pal[f];
        ok[u]=0;if(f){
            if(h1[u]==H1[u]&&h2[u]==H2[u]){
                ok[u]=1;
                ans++;//#1
                if(!pal[u].size())pal[u].push_back((node){dep[u],dep[u],1});
                else if(rb(pal[u]).e+rb(pal[u]).d==dep[u])rb(pal[u]).e+=rb(pal[u]).d;
                else{
                    auto it=rb(pal[u]);
                    if(it.s==it.e)rb(pal[u])=(node){it.s,dep[u],dep[u]-it.s};
                    else pal[u].push_back((node){dep[u],dep[u],1});
                }
            }
        }
    }
    for(auto x:g[u]){
        int v=x.fi,w=x.se;if(v==f||used[v])continue;
        if(!op){ 
            h1[v]=(1ll*h1[u]*base+w)%p1;
            h2[v]=(1ll*h2[u]*base+w)%p2;
            H1[v]=(1ll*w*pw1[dep[u]]+H1[u])%p1;
            H2[v]=(1ll*w*pw2[dep[u]]+H2[u])%p2;
        }
        dep[v]=dep[u]+1;
        int to=son[U][w-1];
        if(!to)to=son[U][w-1]=++idx,
            vis[idx]=0,son[idx][0]=son[idx][1]=0,Dep[idx]=Dep[U]+1;
        prework(v,to,u,op);
    }
}
void calc(int u,int U,int f){
    ans+=vis[U];//#2
    assert(dep[u]==Dep[U]);
    for(auto x:g[u]){
        int v=x.fi;if(v==f||used[v])continue;
        calc(v,son[U][x.se-1],u);
    }
}
void addin(int u,int U,int f){
    if(U)vis[U]++;
    for(auto x:g[u]){
        int v=x.fi;if(v==f||used[v])continue;
        addin(v,son[U][x.se-1],u);
    }
}
vector<node> PAL[N];
void find(int u,int U,int f){
    for(auto p:pal[u])PAL[U].push_back(p);
    for(auto x:g[u]){
        int v=x.fi;if(v==f||used[v])continue;
        find(v,son[U][x.se-1],u);
    }
}
const int B=2;
int t[N];
#define lowbit(x) (x&-x)
int cc[N];
inline void add(int u,int v,int* c){for(int i=u;i<=idx;i+=lowbit(i))c[i]+=v;}
inline int ask(int u,int *c,int s=0){for(int i=u;i;i-=lowbit(i))s+=c[i];return s;}
void dfs(int u,int f,int op=1){
    for(auto p:PAL[u]){
        int l=Dep[u]-p.e,r=Dep[u]-p.s,d=p.d;
        if(d==1)ans+=op*(ask(r,cc)-(l?ask(l-1,cc):0));
        else for(int i=l;i<=r;i+=d)ans+=op*t[i];
    }
    t[Dep[u]]+=vis[u];
    if(vis[u])add(Dep[u],vis[u],cc);
    go(u)dfs(to[i],u,op);
    t[Dep[u]]-=vis[u];
    if(vis[u])add(Dep[u],-vis[u],cc);
}
int all_num,fa[N],sz[N],mx[N],rt;
void getr(int u,int f){
    sz[u]=1,mx[u]=0;for(auto x:g[u]){
        int v=x.fi;
        if(!used[v]&&v!=f)getr(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
    }mx[u]=max(mx[u],all_num-sz[u]);if(mx[u]<mx[rt])rt=u;
}
void solve(int u){
    used[u]=1;
    h1[u]=h2[u]=H1[u]=H2[u]=dep[u]=Dep[0]=0,son[0][0]=son[0][1]=0,idx=0;
    prework(u,0,0),build();
    For(i,0,idx)vis[i]=0,vector<node>().swap(PAL[i]);
    for(auto x:g[u]){
        int v=x.fi,V=son[0][x.se-1];if(used[v])continue;
        calc(v,V,u);
        addin(v,V,u);
    }
    find(u,0,0),dfs(0,0);
    for(auto x:g[u]){
        int v=x.fi,w=x.se;if(used[v])continue;
        h1[u]=h2[u]=H1[u]=H2[u]=dep[u]=Dep[0]=0,son[0][0]=son[0][1]=0,idx=0;
        int to=son[0][w-1];
        if(!to)to=son[0][w-1]=++idx,
            vis[idx]=0,son[idx][0]=son[idx][1]=0,Dep[idx]=Dep[0]+1;
        prework(v,to,u,1),build();
        For(i,0,idx)vis[i]=0,vector<node>().swap(PAL[i]);
        addin(v,to,u),find(v,to,u),dfs(0,0,-1);
    }
    for(auto x:g[u]){int v=x.fi;if(!used[v])getr(v,u),all_num=sz[v],rt=0,getr(v,u),solve(rt);}
}
signed main(){
    For(i,2,n=read()){int u=read(),v=read(),w=read()+1;add(u,v,w),add(v,u,w);}
    pw1[0]=pw2[0]=1,mx[0]=1e9;
    For(i,1,n)pw1[i]=1ll*pw1[i-1]*base%p1,pw2[i]=1ll*pw2[i-1]*base%p2;
    all_num=n,rt=0,getr(1,0),solve(rt);cout<<ans<<endl;
    return 0;
}
最后修改:2022 年 10 月 03 日
如果觉得我的文章对你有用,请随意赞赏