前置知识:


来看这样一道题目:

对于长度为$n\leq 2*10^5$的数列a,要求进行$m\leq 10^5$次询问,给定$l,r,x$,求$i∈[l,r],max\{a_i⊕x\}$,$a_i\leq 2^{32}$

如果是初学01Trie,可能一眼就觉得这是个裸题,但是并没有那么简单,回忆我们01Trie的操作,每次查询时是在一棵完整的Trie树上进行查找,这里的完整指的是,如果我们把$a_1$加进Trie,我们就无法做到一定不走$a_1$。

于是我们有一个朴素的想法:每次插入一个值前备份当前Trie树上的信息。这里的信息即val,val[u]记录结点u被经过的次数,那么就用valu记录$a_x$插入后$u$结点的$val$。

每次查询[l,r]的信息,我们依然是从根节点出发,当我们想要走到结点u时,需要先判断 val[u][r]-val[u][l-1]是否大于0,如果是,那么才能走,因为这说明在[l,r]区间内有这样一个数$a_i$支持我们走向结点u,这是显然的。

但是Trie的空间消耗本来就有点大,1e5的数据为了保险要开到3e6左右的结点。对于每个结点我们还要记录n个版本的信息,空间直接爆炸掉。

那么我们考虑利用可持久化的思想来优化这个过程。

我们回忆起,每次插入一个数,最多新增的结点数目是log级别的,正如我们的线段树。

那么也就是说,每次插入会引起val改变的结点数目只有log个,其它结点的信息是没有必要再进行一次备份的。

考虑插入每个新的结点时加入一个新的root,即对于每个数$a_i$都有不同的$rt_i$。

假设插入数字的这一位是v,那么我们总是新建v分支的结点,并把!v分支的结点连向上一个版本的儿子。

假设我们要查询区间[l,r],我们就同时从$rt_r,rt_{l-1}$出发,按照x的每一位依次递归,判断一个结点u是否合法还是和上面一样,用第r棵的减去第l-1棵的即可。

并且由于我们的val也是从上一棵trie继承过来的,每次插入依然是log(n)的,新增的结点也是log(n)级别的。信息的继承也不会出现问题,为什么呢?朴素的Trie其实也是这样的,只是$rt_{1...n}$都是同一个位置,上一个版本的继承过来时也就不需要两个结点一起递归,所以继承信息这一块并没有本质上的区别,自然可以用上述方法进行计算。

代码:

#include<bits/stdc++.h>
using namespace std;
inline int read(){
    int x(0),f(1);char ch(getchar());
    while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    while(ch<='9'&&ch>='0')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    return x*f;
}
int idx=1;
int c[7000010][2];
int val[7000010];
int rt[200010];
void insert(int x,int b){
    int u=rt[b],u2=rt[b-1];
    for(int i=31;~i;i--){
        int v=(x>>i)&1;
        c[u][v]=++idx;
        if(u2)
            val[c[u][v]]=val[c[u2][v]]+1,
            c[u][!v]=c[u2][!v],val[c[u][!v]]=val[c[u2][!v]];
        else val[c[u][v]]=1,c[u][!v]=0;
        u=c[u][v];
        u2=c[u2][v];
    }
}
int ask(int l,int r,int x){
    int ans=0;
    int u1=rt[r],u2=rt[l-1];
    for(int i=31;~i;i--){
        int v=(x>>i)&1;
        if(val[c[u1][!v]]>val[c[u2][!v]]){
            u1=c[u1][!v],u2=c[u2][!v];
            ans+=(1<<i);
        }
        else u1=c[u1][v],u2=c[u2][v];
    }
    return ans;
}
int main(){
    freopen("hugclose.in","r",stdin);
    freopen("hugclose.out","w",stdout);
    rt[0]=1;
    int n(read()),q(read());
    for(int i=1;i<=n;i++){
        int a(read());
        rt[i]=++idx;
        insert(a,i);
    }
    while(q--){
        int x(read()),l(read()),r(read());
        printf("%d\n",ask(l+1,r+1,x));
    }
    return 0;
}
最后修改:2020 年 10 月 31 日 09 : 23 PM