1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
| #include<bits/stdc++.h> #define int long long #define pt putchar(' ') #define nl puts("") #define pi pair<int,int> #define pb push_back #define go(it) for(auto &it:as[x]) using namespace std;
const int N=3e3+10,Q=998244353; int n,u,v,cnt,ans; int a[N],c[N],sz[N],f[N][N*3],ct[N],L[N],R[N]; vector<int> as[N];
int fr(){ int x=0,flag=1; char ch=getchar(); while(ch<'0' || ch>'9'){ if(ch=='-') flag=-1; ch=getchar(); } while(ch>='0' && ch<='9'){ x=x*10+(ch-'0'); ch=getchar(); } return x*flag; } void fw(int x){ if(x<0) putchar('-'),x=-x; if(x>9) fw(x/10); putchar(x%10+'0'); } int max(int a,int b){return a>b?a:b;} int min(int a,int b){return a<b?a:b;} void mod(int &x,int y){if((x+=y)>=Q) x-=Q;}
void dfs(int x,int rt) { sz[x]=1; f[x][c[x]+N]=1; int g[N<<1]; go(v) { if(v==rt) continue; dfs(v,x); int vr=min(sz[x],cnt),vl=max(-sz[x],-cnt); int svr=min(sz[v],cnt),svl=max(-sz[v],-cnt); for(int j=vr;j>=vl;j--) g[j+N]=f[x][j+N]; for(int j=vr;j>=vl;j--) for(int k=svr;k>=svl;k--) if(j+k>=-cnt) mod(f[x][j+k+N],g[j+N]*f[v][k+N]%Q); sz[x]+=sz[v]; } for(int j=1;j<=min(cnt,sz[x]);j++) mod(ans,f[x][j+N]); L[x]=max(-sz[x],-cnt),R[x]=min(sz[x],cnt); }
signed main() { n=fr(); for(int i=1;i<=n;i++) ct[a[i]=fr()]++; for(int i=1;i<n;i++) { u=fr(),v=fr(); as[u].pb(v),as[v].pb(u); }
for(int i=1;i<=n;i++) { cnt=ct[i]; if(!cnt) continue; else if(cnt==1) ans++; else { for(int x=1;x<=n;x++) for(int j=L[x];j<=R[x];j++) f[x][j+N]=0; for(int x=1;x<=n;x++) c[x]=(a[x]==i)?1:-1; dfs(1,-1); } } fw(ans); return 0; }
|