Tutorial for Codeforces 486D - Valid Sets

Solution #

Firstly, let’s ignore the third condition for now. Consider the tree rooted at node 1. Let $dp_i$ be the number of valid sets that contain node $i$ and otherwise consist only of nodes in the subtree of $i$. For each child $j$ of $i$ we either take none of its subtree or take one of the $dp_j$ sets containing $j$, so this can be easily calculated using a dfs: $dp_i=\prod_{j\in child(i)}(dp_j+1)$

Now consider the third condition. We can fix each node in turn as the node with the smallest value in the valid set. After fixing the smallest value, start a dfs from node $i$ and only visit nodes $j$ such that $a_i\leq a_j \leq a_i+d$. In this case, the third condition is satisfied, so we can calculate the answer using the formula above. Also be careful with duplicate counting, i.e. if $a_j=a_i$, only visit node $j$ if $j>i$.

Code #

#include <bits/stdc++.h>
 
// Loop shorthand: declares int `i` and runs it over 0, 1, ..., n-1.
#define forn(i, n) for (int i = 0; i < int(n); ++i)
// Loop shorthand: declares int `i` and runs it over 1, 2, ..., n (inclusive).
#define for1(i, n) for (int i = 1; i <= int(n); ++i)
// Fills `a` with byte value `x` via memset over sizeof(a) bytes; sizeof(a) is
// the whole object only when `a` is a true array, not a pointer or container.
#define ms(a, x) memset(a, x, sizeof(a))
// Single-letter aliases for the tokens `first` / `second` (e.g. pair members).
// Being macros, they rewrite every later identifier spelled F or S.
#define F first
#define S second
// Expands to the begin/end iterator pair of `x`, for passing to algorithms.
#define all(x) (x).begin(),(x).end()
// Alias for the token `push_back`.
#define pb push_back
 
using namespace std;

// Shorthand type aliases.
using ll = long long;
using pii = pair<int, int>;

// Sentinel constant 0x3f3f3f3f; not referenced by the code visible in this file.
constexpr int INF = 0x3f3f3f3f;

// Pseudo-random generator seeded with the steady clock's current tick count.
mt19937 gen(chrono::steady_clock::now().time_since_epoch().count());

// Reads every argument from standard input, in the order given.
template <typename... Ts>
void rd(Ts&... targets) {
    ((cin >> targets), ...);
}

// Writes every argument to standard output, each followed by one space, then
// terminates the line with endl (which also flushes the stream).
template <typename... Ts>
void wr(Ts... values) {
    ((cout << values << " "), ...);
    cout << endl;
}

// Problem state shared by dfs() and main().
vector<int> a;          // node values, indexed by node id
vector<vector<int>> G;  // adjacency lists, indexed by node id
int d;                  // allowed spread above the minimum value of a set
int n;                  // number of nodes
constexpr int mod = 1000000007;  // answers are reported modulo this value
// Counts, modulo `mod`, the connected node sets that contain `u` and grow only
// away from `fa`, under the restriction that `root` holds the smallest value
// of the set.
//
//   u    - node currently being expanded
//   root - node whose value a[root] is the minimum allowed in the set
//   fa   - node we arrived from (never re-entered); main() passes 0 for the
//          start node
//
// A neighbour is entered only if its value lies in [a[root], a[root] + d].
// Neighbours whose value equals a[root] are entered only when their index is
// not below `root`, so a set with several minimum-valued nodes is counted from
// exactly one of them.
int dfs(int u, int root, int fa) {
    const int lo = a[root];
    const int hi = lo + d;

    // Decides whether the walk may step from `u` into neighbour `v`.
    const auto admissible = [&](int v) {
        if (v == fa) return false;                  // do not walk back up
        if (a[v] < lo || a[v] > hi) return false;   // outside the value window
        return a[v] != lo || v >= root;             // tie on the minimum: index rule
    };

    long long ways = 1;
    for (const int next : G[u]) {
        if (!admissible(next)) continue;
        // Each admitted branch is either left out (the +1) or contributes one
        // of its own sets.
        ways = ways * (dfs(next, root, u) + 1) % mod;
    }
    return static_cast<int>(ways);
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin>>d>>n;
    G.resize(n+1);
    a.resize(n+1);
    for1(i,n) cin>>a[i];
    forn(i,n-1){
        int x,y;
        cin>>x>>y;
        G[x].pb(y);
        G[y].pb(x);
    }
    int ans=0;
    for1(i,n){
        ans=(ans+dfs(i,i,0))%mod;
    }
    cout<<ans;
    return 0;
}