Tutorial 2018-2019 ACM-ICPC, Asia Nanjing Regional Contest M - Mediocre String Problem

Solution #

First, for each position $i$, count how many palindromic substrings of $s$ begin at $s_i$, and let this number be $f_i$.

Then, for each position $i$, find the maximum length $d$ such that $s_{i-k} = t_k$ for each $k = 1, 2, \dots, d$, and let this length be $g_i$.

The answer is $\sum_{i=1}^{\lvert s\rvert} f_i \cdot g_i$.

The first part can be solved using Manacher's algorithm. The second part is equivalent to calculating the longest common prefix (LCP) of $t$ with every suffix of the reversed string $s$, which can be solved using the Z algorithm.

Code #

#include <bits/stdc++.h>
 
#define forn(i, n) for (int i = 0; i < int(n); ++i)
#define for1(i, n) for (int i = 1; i <= int(n); ++i)
#define F first
#define S second
#define all(x) (x).begin(),(x).end()
#define sz(x) int(x.size())
#define pb push_back
 
using namespace std;
using ll=long long;
 
// Manacher's algorithm on the '#'-interleaved form of `ss`.
//
// The input "abc" is expanded to "#a#b#c#", so that odd- and even-length
// palindromes of `ss` are both odd-length palindromes of the expanded string.
//
// Returns a vector of size 2 * ss.size() + 1 where entry i is the palindrome
// radius at center i of the expanded string: the largest k such that
// expanded[i-k+1 .. i+k-1] is a palindrome (so every entry is >= 1).
//
// The parameter is taken by const reference to avoid copying the input.
vector<int> manacher(const string& ss) {
    string s;
    s.reserve(2 * ss.size() + 1);  // final size is known up front
    for (const char ch : ss) {
        s += '#';
        s += ch;
    }
    s += '#';

    const int n = static_cast<int>(s.size());
    vector<int> d1(n);
    // [l, r] is the palindrome found so far that reaches furthest to the right.
    for (int i = 0, l = 0, r = -1; i < n; ++i) {
        // Inside [l, r] the mirrored center l + r - i gives a lower bound on the
        // radius, capped so the seed does not extend past r.
        int k = (i > r) ? 1 : min(d1[l + r - i], r - i + 1);
        while (0 <= i - k && i + k < n && s[i - k] == s[i + k]) {
            ++k;
        }
        d1[i] = k;
        if (i + k - 1 > r) {
            l = i - k + 1;
            r = i + k - 1;
        }
    }
    return d1;
}
 
// Z-function of `s`: entry i (for i >= 1) is the length of the longest common
// prefix of `s` and the suffix of `s` starting at i. Entry 0 is left as 0.
vector<int> z_function(const string s) {
    const int len = static_cast<int>(s.size());
    vector<int> z(len, 0);
    // [boxL, boxR] is the rightmost segment found so far that matches a prefix.
    int boxL = 0;
    int boxR = 0;
    for (int pos = 1; pos < len; ++pos) {
        int match = 0;
        if (pos <= boxR) {
            // Reuse the value already computed for the mirrored prefix position,
            // limited to what is still inside the box.
            match = min(boxR - pos + 1, z[pos - boxL]);
        }
        while (pos + match < len && s[match] == s[pos + match]) {
            ++match;
        }
        z[pos] = match;
        const int last = pos + match - 1;
        if (last > boxR) {
            boxL = pos;
            boxR = last;
        }
    }
    return z;
}
// Reads strings s and t from stdin and prints sum over i of f[i] * g[i], where
//   f[i] = number of palindromic substrings of s that start at index i, and
//   g[i] = length of the longest common prefix of t with s[i-1], s[i-2], ...
// (0-based indices; g[0] = 0).
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    string s, t;
    cin >> s >> t;

    // A failed read or an empty string means there is nothing to count. Without
    // this guard an empty s would lead to erasing from an empty vector below.
    if (s.empty() || t.empty()) {
        cout << 0;
        return 0;
    }

    const int n = static_cast<int>(s.size());
    const int m = static_cast<int>(t.size());

    // Each center of the '#'-interleaved string contributes one palindrome start
    // for every index in [l, r]; record these as range increments in a
    // difference array and prefix-sum it to obtain f.
    const auto man = manacher(s);
    vector<int> f(n);
    for (int i = 1; i < static_cast<int>(man.size()) - 1; i++) {
        const int l = (i - man[i] + 1) / 2;
        const int r = (i - 1) / 2;
        f[l]++;
        if (r < n - 1) f[r + 1]--;
    }
    partial_sum(f.begin(), f.end(), f.begin());

    // z over t + '#' + reverse(s): the entry at offset m + 1 + j is the LCP of t
    // with the suffix of reverse(s) starting at j.
    reverse(s.begin(), s.end());
    const auto z = z_function(t + "#" + s);

    long long ans = 0;
    for (int i = 1; i < n; ++i) {
        // s[i-1] sits at index n - i of the reversed string. Clamp to |t| so a
        // '#' inside the input can never make a match run past the end of t.
        const int g = min(z[m + 1 + (n - i)], m);
        ans += static_cast<long long>(f[i]) * g;
    }
    cout << ans;
    return 0;
}