CodeForces 813E - Army Creation 题解

很神奇的技巧

题解 #

我们创建一个辅助数组 $b$,其中 $b_i$ 是 $a_i$ 之后第 $k$ 个 $a_i$ 的下标;如果后面没有 $k$ 个 $a_i$ 了,则 $b_i = n+1$。比如说,样例的辅助数组是 $[3, 7, 7, 6, 7, 7]$(下标从 $1$ 开始)。

考虑询问 $(l, r)$,对于 $i\in [l, r]$,如果 $b_i\le r$,说明在 $[i, r]$ 中 $a_i$ 出现了多于 $k$ 次,所以 $i$ 不应该在军队里。所以答案是 $r-l+1-|\{b_i \mid b_i\le r,\ i\in[l, r]\}|$。找区间里不超过 $x$ 的数的个数可以用主席树或者 wavelet 树解决。

Code #

#include <bits/stdc++.h>
 
using namespace std;
 
/**
 * Persistent segment tree over the value range [0, n-1].
 *
 * Every call to update(pos, val) creates a new version that shares all
 * untouched nodes with the previous one; roots[i] is the root of version i
 * (version 0 is the empty tree built by the constructor). With unit
 * increments, query(roots[a], roots[b], k) returns how many of the values
 * inserted by updates a+1..b are <= k.
 *
 * Node 0 is a sentinel with sum 0 and both children 0. Real nodes are
 * numbered from 1 in creation order, so the root of version 0 is node 1.
 * The node pool grows on demand, so the number of versions is not limited
 * by the initial capacity.
 */
struct PST {
    int n, tot=0;
    // Parallel node arrays: left child, right child, subtree sum; plus the
    // root node of every version.
    std::vector<int> lc, rc, sum, roots;

    /**
     * Builds the empty version 0 over [0, n_-1].
     * Throws std::invalid_argument when n_ is not positive.
     */
    PST(int n_) : n(n_), roots(1) {
        if (n <= 0) throw std::invalid_argument("PST: size must be positive");
        // Capacity hint only (enough for roughly n updates without
        // reallocating); the pool still grows past it when needed.
        const std::size_t hint = static_cast<std::size_t>(n) << 5;
        lc.reserve(hint);
        rc.reserve(hint);
        sum.reserve(hint);
        // Sentinel node 0.
        lc.push_back(0);
        rc.push_back(0);
        sum.push_back(0);
        int root = 0;
        build(0, n-1, root);
        roots[0] = root;
    }

    /** Appends a zeroed node to the pool and returns its index. */
    int new_node() {
        lc.push_back(0);
        rc.push_back(0);
        sum.push_back(0);
        return ++tot;
    }

    /** Recomputes the sum of node rt from its two children. */
    void pushup(int rt) {
        sum[rt] = sum[lc[rt]] + sum[rc[rt]];
    }

    /**
     * Builds an all-zero subtree covering [l, r] and stores its root in rt.
     * rt must not refer to an element of lc/rc/sum: the pool may reallocate
     * while the subtree is being built.
     */
    void build(int l, int r, int& rt) {
        const int cur = new_node();
        if (l < r) {
            const int mid = l + (r - l) / 2;
            // Children are collected in locals so no reference into the
            // growing pool is held across the recursive calls.
            int left = 0, right = 0;
            build(l, mid, left);
            build(mid + 1, r, right);
            lc[cur] = left;
            rc[cur] = right;
            pushup(cur);
        }
        rt = cur;
    }

    /**
     * Path-copying point update: creates a copy of node `old` (covering
     * [l, r]) in which position pos is increased by val, and stores the
     * index of the copy in rt. Expects l <= pos <= r. rt must not refer to
     * an element of lc/rc/sum (see build).
     */
    void update(int pos, int val, int l, int r, int old, int& rt) {
        const int cur = new_node();
        lc[cur] = lc[old];
        rc[cur] = rc[old];
        if (l == r) {
            sum[cur] = sum[old] + val;
        } else {
            const int mid = l + (r - l) / 2;
            int child = 0;
            if (pos <= mid) {
                update(pos, val, l, mid, lc[old], child);
                lc[cur] = child;
            } else {
                update(pos, val, mid + 1, r, rc[old], child);
                rc[cur] = child;
            }
            pushup(cur);
        }
        rt = cur;
    }

    /**
     * Adds val at position pos (expected in [0, n-1]) on top of the latest
     * version, records the new version in roots and returns its root.
     */
    int update(int pos, int val) { // return the root of the new version
        int new_root = 0;
        update(pos, val, 0, n-1, roots.back(), new_root);
        roots.push_back(new_root);
        return new_root;
    }

    /**
     * Difference (version v minus version u) of the sums over positions
     * [l, min(k, r)], where u and v are the nodes covering [l, r] in the
     * two versions. Returns 0 when k < l.
     */
    int query(int u, int v, int l, int r, int k) const {
        if (k < l) return 0;                     // nothing at or below k here
        if (k >= r) return sum[v] - sum[u];      // whole segment qualifies
        const int mid = l + (r - l) / 2;
        if (k <= mid) return query(lc[u], lc[v], l, mid, k);
        // The left half lies entirely at or below k; only the right half
        // needs to be split further.
        return sum[lc[v]] - sum[lc[u]] + query(rc[u], rc[v], mid + 1, r, k);
    }

    /**
     * Sum over positions [0, k] in version-root v minus the same sum in
     * version-root u. k below 0 yields 0; k at or above n-1 covers the
     * whole range.
     */
    int query(int u, int v, int k) const {
        return query(u, v, 0, n-1, k);
    }
};
/**
 * Reads n, k and the n values, then answers q online range queries.
 *
 * For every index i (0-based) the auxiliary value a[i] is the index of the
 * k-th later occurrence of the same value, or n when there is none. For a
 * decoded query [l, r], every i in [l, r] with a[i] <= r is excluded, so the
 * answer is (r - l + 1) minus the number of such i; that count comes from
 * the persistent segment tree. Each answer also shifts the next query.
 */
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, k;
    std::cin >> n >> k;

    // Input values go up to kMaxValue inclusive (per the problem's
    // constraints), so the table needs kMaxValue + 1 buckets for the
    // largest value to be a valid index.
    constexpr int kMaxValue = 100000;
    std::vector<std::vector<int>> pos(kMaxValue + 1);

    // a[i] defaults to n, meaning "no k-th later occurrence".
    std::vector<int> a(n, n);
    for (int i = 0; i < n; i++) {
        int x;
        std::cin >> x;
        pos[x].push_back(i);
        if (pos[x].size() > static_cast<std::size_t>(k)) {
            // The occurrence k steps before the current one now has its
            // k-th later occurrence at i.
            a[*(pos[x].rbegin() + k)] = i;
        }
    }

    // Values stored in the tree lie in [0, n], hence n + 1 positions.
    // After the loop, tr.roots[i] is the version containing a[0..i-1].
    PST tr(n + 1);
    for (int i = 0; i < n; i++) {
        tr.update(a[i], 1);
    }

    int q;
    std::cin >> q;
    int last = 0; // previous answer, used to decode the next query
    while (q--) {
        int x, y;
        std::cin >> x >> y;
        int l = (x + last) % n, r = (y + last) % n;
        if (l > r) std::swap(l, r);
        // Count of i in [l, r] with a[i] <= r, via versions l and r + 1.
        last = (r - l + 1) - tr.query(tr.roots[l], tr.roots[r + 1], r);
        std::cout << last << '\n';
    }
    return 0;
}