Solution for CodeForces 813E - Army Creation

Interesting technique.

Solution

We create an auxiliary array $b$ where $b_i$ is the index of the $k$-th next occurrence of $a_i$ (that is, the $k$-th position $j > i$ with $a_j = a_i$), or $n+1$ if such an occurrence doesn't exist. For example, the auxiliary array of the example input should be $[3, 7, 7, 6, 7, 7]$ (1-indexed).

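Consider a query $(l, r)$. For $i \in [l, r]$, if $b_i \le r$, then the value $a_i$ occurs at least $k$ more times in $(i, r]$. Counting position $i$ itself, the segment therefore contains more than $k$ copies of this value. Since we keep only the last $k$ copies of each value, $i$ should not be in the army. Conversely, if $b_i > r$, fewer than $k$ copies follow $i$ inside the segment, so $i$ is one of the last $k$ and is kept. Thus the answer to the query is $r - l + 1 - |\{\, i \in [l, r] : b_i \le r \,\}|$. Counting the elements of a subarray that are at most $x$ (here $x = r$) is a classic problem that can be solved with a persistent segment tree or a wavelet tree.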

Code

#include <bits/stdc++.h>
using namespace std;
/// Persistent segment tree over the value range [0, n-1] storing, per value, how
/// many times it has been inserted. Every point update creates a new version
/// (root) that shares all untouched nodes with the previous version, so
/// "how many values <= k were inserted between version u and version v" is
/// answered in O(log n).
///
/// Node 0 is never allocated (tot starts at 0 and nodes are taken with ++tot),
/// so it keeps lc = rc = sum = 0 and acts as an empty tree.
struct PST {
    int n;        // size of the value range [0, n-1]; requires n >= 1
    int tot = 0;  // number of nodes allocated so far
    std::vector<int> lc, rc, sum, roots;  // left child, right child, subtree count, root of each version

    /// Builds version 0 (all counts zero) over values [0, n_-1].
    /// The arrays hold n_ << 5 nodes: the initial build uses 2*n_ - 1 nodes and
    /// each point update adds ceil(log2(n_)) + 1 more. That is sized for roughly
    /// n_ updates. NOTE(review): substantially more updates would overflow the
    /// preallocated arrays.
    explicit PST(int n_) : n(n_), lc(n << 5), rc(n << 5), sum(n << 5), roots(1) {
        build(0, n - 1, roots[0]);
    }

    /// Recomputes the count stored in node rt from its two children.
    void pushup(int rt) {
        sum[rt] = sum[lc[rt]] + sum[rc[rt]];
    }

    /// Allocates a fresh zero-count subtree covering [l, r] and writes its root to rt.
    /// rt may alias an element of lc/rc; the arrays are never resized, so the
    /// reference stays valid.
    void build(int l, int r, int& rt) {
        rt = ++tot;
        if (l == r) return;
        int mid = (l + r) >> 1;
        build(l, mid, lc[rt]);
        build(mid + 1, r, rc[rt]);
        pushup(rt);
    }

    /// Copies the root-to-leaf path for value pos from the subtree `old` (covering
    /// [l, r]) and adds val to that leaf's count. The new subtree root is written to rt.
    void update(int pos, int val, int l, int r, int old, int& rt) {
        rt = ++tot;
        lc[rt] = lc[old];
        rc[rt] = rc[old];
        if (l == r) {
            sum[rt] = sum[old] + val;
            return;  // leaf reached: nothing below to copy
        }
        int mid = (l + r) >> 1;
        if (pos <= mid) update(pos, val, l, mid, lc[old], lc[rt]);
        else update(pos, val, mid + 1, r, rc[old], rc[rt]);
        pushup(rt);
    }

    /// Adds val to value pos on top of the latest version, records the new
    /// version in `roots`, and returns its root.
    int update(int pos, int val) {
        int new_root = 0;
        update(pos, val, 0, n - 1, roots.back(), new_root);
        roots.push_back(new_root);  // chain versions: the next update builds on this one
        return new_root;
    }

    /// Counts values <= k present in version v but not in version u (u older than v),
    /// restricted to the subtree that covers [l, r].
    int query(int u, int v, int l, int r, int k) const {
        if (l == r) return sum[v] - sum[u];
        int mid = (l + r) / 2, x = sum[lc[v]] - sum[lc[u]];
        if (mid < k) return x + query(rc[u], rc[v], mid + 1, r, k);  // whole left half counts
        return query(lc[u], lc[v], l, mid, k);
    }

    /// Counts values <= k inserted between version roots u (exclusive) and v (inclusive).
    int query(int u, int v, int k) const {
        if (k < 0) return 0;  // no value in [0, n-1] can be <= a negative bound
        return query(u, v, 0, n - 1, k);
    }
};
int main() {
    int n, k;
    constexpr int M=1e5;
    vector<vector<int>> pos(M);
    vector<int> a(n, n);
    for (int i=0; i<n; i++) {
        int x;
        if (pos[x].size()>k) {
    int last=0;
    vector<int> roots(n+1);
    PST tr(n+1);
    for (int i=0; i<n; i++) {
        roots[i+1]=tr.update(a[i], 1);
    int q;
    while (q--) {
        int x, y;
        int l=(x+last)%n, r=(y+last)%n;
        if (l>r) swap(l, r);
        last=(r-l+1)-tr.query(roots[l], roots[r+1], r);
    return 0;