CodeForces 1539F - Strange Array 题解

题目链接

思路不难但线段树维护的内容需要一定的技巧。

题解 #

首先我们看如何取 l 和 r 才能使得答案最大。如果l=r=il=r=i的话,aia_i在正中间,如果加入一个小于等于aia_i的数会让aia_i往右偏,反之,加入大于等于的数会往左偏,所以设[l, r]中大于 aia_i 的数的个数为xx,小于aia_i的数的个数为yy,(等于aia_i的数可以算入xxyy其中之一)我们要调整l,rl,r找到 xyx-y的最大和最小值(分别对应aia_i在最左和最右), 此时的答案为max(xy+12,yx2)\max(\frac{x-y+1}{2}, \frac{y-x}{2})。由于l,rl, r互相独立,所以我们可以分别看[1,i],[i,n][1, i], [i, n]两个区间,找到l[1,i],r[i,n]l\in[1, i], r\in [i, n], 使得[l,i1],[i+1,r][l, i-1], [i+1, r]中的xyx-y最大或最小。

如果只找一个aia_i的答案的话,可以非常轻松的用线段树解决。但是对于整个数组的答案就行不通了,对于处理大小关系的题目一种常用的技巧是将整个数组排序,从小到大进行处理,这样就能保证之前的数都比当前数小,处理起来就会简单很多。对于本题我们需要一个数组,其中大于aia_i的位置设为11,小于 aia_i的位置设为 1-1,对于每个位置 ii,我们只要找到sum(l,i1),l[1,i]sum(l, i-1), l\in [1, i]sum(i+1,r),r[i,n]sum(i+1, r), r\in [i, n]的最大值与最小值。由于我们是从小到大处理的所以每次只改动一个位置 (将 1 变成 -1),这样数组就变得非常易于维护。

维护数组的题我们很容易想到用线段树,但这题的询问比较特别,看似是区间最值但sumsum函数对于不同的 i 也会有不同的值。这里我们用到一种在最大子段和的递归实现中用到的技巧,即对于线段树中的每个区间,维护区间和,从左/右端点开始的最大/最小子段和mnl, mnr, mxl, mxr,用数学语言描述就是:令当前维护的区间是[l,r][l, r]

mnl=minlir(sum(l,i)) mnr=minlir(sum(i,r)) mxl=maxlir(sum(l,i)) mxr=maxlir(sum(i,r))\begin{align*}mnl&=\min_{l\le i\le r}(\operatorname{sum}(l, i))\\\ mnr&=\min_{l\le i\le r}(\operatorname{sum}(i, r))\\\ mxl&=\max_{l\le i\le r}(\operatorname{sum}(l, i))\\\ mxr&=\max_{l\le i\le r}(\operatorname{sum}(i, r))\end{align*}

了解了定义之后,如何合并区间也就很容易想到了(具体看代码),此外为了方便实现,我们在代码中允许最大/最小子段不包含任何数。还有,由于相等的数可以随意排列,所以既可以算作大的数又可以算作小的数,所以要询问两遍一次当作小的数,一次当作大的数。

代码: #

#include <bits/stdc++.h>
 
using namespace std;
 
template <typename T> struct SegTree {
    int n;
    vector<T> t;
 
    SegTree(int n_) : n(n_), t(4 * n) { build(1, 0, n - 1, vector(n, T())); }
 
    template <typename U> SegTree(const vector<T> &v) : SegTree((int)v.size()) {
        build(1, 0, n - 1, v);
    }
 
    void pull(int node) { t[node] = t[node << 1] + t[node << 1 | 1]; }
 
    template <typename U>
    void build(int node, int l, int r, const vector<U> &v) {
        if (l == r) {
            t[node] = T(v[l]);
            return;
        }
        int mid = (l + r) >> 1;
        build(node << 1, l, mid, v);
        build(node << 1 | 1, mid + 1, r, v);
        pull(node);
    }
 
    void set(int node, int i, T x, int l, int r) {
        if (l == r) {
            t[node] = x;
            return;
        }
        int mid = (l + r) / 2;
        if (i <= mid) set(node << 1, i, x, l, mid);
        else
            set(node << 1 | 1, i, x, mid + 1, r);
        pull(node);
    }
 
    T get(int node, int ql, int qr, int l, int r) {
        if (ql <= l && qr >= r) return t[node];
 
        int mid = (l + r) >> 1;
        if (qr <= mid) return get(node << 1, ql, qr, l, mid);
        if (ql > mid) return get(node << 1 | 1, ql, qr, mid + 1, r);
        return get(node << 1, ql, qr, l, mid) +
               get(node << 1 | 1, ql, qr, mid + 1, r);
    }
    // wrapper
    void set(int i, T x) {
        assert(i >= 0 && i < n);
        set(1, i, x, 0, n - 1);
    }
 
    T get(int l, int r) {
        // assert(l >= 0 && l <= r && r < n);
        if (l > r) return T();
        return get(1, l, r, 0, n - 1);
    }
};
struct node {
    int sum = 0;
    int mxl = 0, mxr = 0, mnl = 0, mnr = 0;
 
    node(int x = 0)
        : sum(x), mxl(max(0, x)), mxr(mxl), mnl(min(0, x)), mnr(mnl) {}
    node(int a, int b, int c, int d, int e)
        : sum(a), mxl(b), mxr(c), mnl(d), mnr(e) {}
 
    node operator+(const node &b) const {
        return {
            sum + b.sum,
            max(mxl, sum + b.mxl),
            max(b.mxr, b.sum + mxr),
            min(mnl, sum + b.mnl),
            min(b.mnr, b.sum + mnr),
        };
    }
};
 
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    vector<vector<int>> pos(n + 1);
    SegTree<node> st(n);
    for (int i = 0; i < n; i++) {
        st.set(i, node(1));
    }
    for (int i = 0; i < n; i++) {
        int x;
        cin >> x;
        pos[x].push_back(i);
    }
    vector<int> ans(n);
    for (int i = 1; i <= n; i++) {
        for (auto p : pos[i]) {
            auto r = st.get(p + 1, n - 1);
            auto l = st.get(0, p - 1);
            ans[p] = max(ans[p], (r.mxl + l.mxr + 1) / 2);
        }
        for (auto p : pos[i])
            st.set(p, node(-1));
        for (auto p : pos[i]) {
            auto r = st.get(p + 1, n - 1);
            auto l = st.get(0, p - 1);
            ans[p] = max(ans[p], (-r.mnl - l.mnr) / 2);
        }
    }
    for (auto x : ans)
        cout << x << ' ';
    return 0;
}