#include <bits/stdc++.h>
using namespace std;

// Array capacities (presumably chosen for n <= 1e6 -- confirm against the task limits).
// N sizes the per-vertex tables (h, dep, dfn, anc).
// M sizes the edge list (two directed edges per tree edge), the key-point buffer
// (key points plus their LCAs) and the factorial tables.
const int N = 1e6 + 6;
const int M = 2e6 + 6;
struct Edge{ int to, nxt; }edge[M]; int h[N], cnt; void _add(int u, int v) {edge[++cnt] = {v, h[u]}; h[u] = cnt; } void add(int u, int v) {_add(u, v), _add(v, u);} int dep[N], dfn[N], dfnidx, anc[N][21], keyp[M]; int lim, m, len; void dfs(int x, int fa){ dfn[x] = ++dfnidx; dep[x] = dep[fa] + 1; anc[x][0] = fa; for(int i=1;i<=lim;i++) anc[x][i] = anc[anc[x][i-1]][i-1]; for(int i=h[x];i;i=edge[i].nxt){ int v = edge[i].to; if(v == fa) continue; dfs(v, x); } } int getlca(int x, int y){ if(dep[x] < dep[y]) swap(x, y); int sa = dep[x] - dep[y]; for(int i=lim;i>=0&&sa;i--){ if(sa&(1<<i)) sa -= (1<<i), x = anc[x][i]; } if(x==y) return x; for(int i=lim;i>=0;i--){ if(anc[x][i]^anc[y][i]) x = anc[x][i], y = anc[y][i]; } return anc[x][0]; } bool cmp(int x, int y){ return dfn[x] < dfn[y]; } void init_imagtree(){ for(int i=1;i<=m;i++) cin>>keyp[i]; sort(keyp+1, keyp+1+m, cmp); len = m; for(int i=1;i<m;i++) keyp[++len] = getlca(keyp[i], keyp[i+1]); sort(keyp+1, keyp+1+len); len = unique(keyp+1, keyp+1+len) - keyp - 1; }
using ll = long long;

// Prime modulus used for all counting arithmetic.
const ll mod = 998244353;

// fact[i] = i! mod `mod` and invfact[i] = (i!)^-1 mod `mod`; populated by init_fact.
ll fact[M], invfact[M];
// Binary (square-and-multiply) exponentiation: returns bas^x mod `mod`.
// The exponent x must be non-negative.
// The base is reduced into [0, mod) first, so every intermediate product is at
// most (mod-1)^2 < 2^63, even for large or negative bases. Without this, the
// first squaring could overflow.
ll ksm(ll bas, ll x){
    bas %= mod;
    if(bas < 0) bas += mod;
    ll result = 1;
    for(; x > 0; x >>= 1){
        if(x & 1) result = result * bas % mod;
        bas = bas * bas % mod;
    }
    return result;
}
// Binomial coefficient C(nn, mm) mod `mod`, read from the factorial tables.
// Returns 0 when mm lies outside [0, nn]. Requires fact/invfact filled up to nn.
ll C(ll nn, ll mm){
    if(mm < 0 || mm > nn) return 0;
    return fact[nn] * invfact[mm] % mod * invfact[nn - mm] % mod;
}

// calc(n, k) = sum over i = 0..n-k of  C(n-k, i) * R(k, i) * g(k+i),  where
//   R(k, i) = (k-1) * k * ... * (k+i-2)   (rising factorial; empty product when i == 0)
//   g(s)    = s * n^(n-s-1) for s < n, and g(n) = 1.
// Presumably g(s) counts rooted forests on n labeled vertices with s given roots,
// and R(k, i) counts ways to subdivide the edges of a k-vertex tree with i new
// labeled vertices -- interpretation inferred from the formula, confirm.
//
// For k >= 2, R(k, i) equals (k+i-2)!/(k-2)!, which the original code read from
// fact/invfact. That read fact[-1]/invfact[-1] (out of bounds) when k == 1, a
// case main reaches when the virtual tree has a single vertex. The running
// product gives the same value for k >= 2 and is well-defined for k == 1
// (1 at i == 0, then 0).
// Requires k >= 1 and fact/invfact filled up to n - k.
ll calc(int n, int k){
    ll ans = 0;
    ll rise = 1;  // R(k, i) mod `mod` for the current i
    for(int i = 0; i <= n - k; i++){
        int s = k + i;
        ll tail = (i == n - k) ? 1 : ksm(n, n - s - 1) * s % mod;
        ans = (ans + C(n - k, i) * rise % mod * tail) % mod;
        rise = rise * (k - 1 + i) % mod;  // advance to R(k, i + 1)
    }
    return ans;
}
int n; void init_fact(){ fact[0] = invfact[0] = 1; for(int i=1;i<=n;i++) fact[i] = fact[i-1] * i % mod; invfact[n] = ksm(fact[n], mod-2); for(int i=n-1;i>=1;i--) invfact[i] = invfact[i+1] * (i+1) % mod; } int main(){ ios::sync_with_stdio(0); cin.tie(0), cout.tie(0); cin>>n>>m; lim = __lg(n) + (__builtin_popcount(n) != 1); int u; for(int i=2;i<=n;i++){ cin>>u; add(u, i); } dfs(1, 0); init_imagtree(); init_fact(); ll ans = 0; ans += calc(n, len); if(len < n) ans = (ans + (n-len) * calc(n, len+1) % mod) % mod; cout<<ans; return 0; }
|