哪吒 • 1天前
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

using namespace std;
typedef long long ll;
typedef pair<int, int> pi;

// N: capacity of every per-vertex / per-pair array.
// mod: prime modulus of all arithmetic.
// i2: (mod + 1) / 2, i.e. the modular inverse of 2.
const int N = 5e5 + 5, mod = 1e9 + 7, i2 = (mod + 1) >> 1;
// Number of binary-lifting levels kept in fa (2^19 > N).
const int LOG = 20;

int type, n, m, k;        // header values, read in this order by main(); `type` is not used afterwards
int sn;                   // running DFS counter used to assign L[]
int fa[LOG][N];           // fa[i][x]: the 2^i-th ancestor of x, 0 once the walk leaves the tree
int dep[N];               // depth of x; the root gets depth 1
int c1[N], c2[N];         // per-vertex counters filled by dfs1 / dfs2, later used as exponents of 2
int L[N], R[N];           // L[x]: DFS entry index of x; R[x]: largest entry index inside x's subtree
// Coordinate scratch buffer for calc(). calc() stores 4 * d[x].size() + 2 values
// (1-based) before deduplicating, so the buffer needs room for about 4 * m entries.
int b[N << 2];
ll ans;                   // accumulated result; main() prints (mod - ans) % mod
ll pw[N], inv[N];         // pw[i] = 2^i mod p, inv[i] = 2^(-i) mod p
ll f[N];                  // per-vertex value produced by solve() and stored in T at position L[x]
ll f1[4];                 // scratch written by solve() and read by calc() (f1[2])
bool key[N];              // vertices listed in the last k input values
vector<int> g[N], e[N];   // g: adjacency of the m extra pairs; e: adjacency of the n - 1 tree edges
pi w[N];                  // the m extra pairs in input order
vector<pi> d[N];          // d[v]: extra pairs whose lca is v and whose endpoints are not ancestor/descendant

// x <- x * y (mod p). The result keeps the sign of the product.
void Mul(ll &x, ll y) {
    x = x * y % mod;
}

// x <- x + y (mod p). The result may be negative when the operands are;
// the final output is normalised in main().
void Add(ll &x, ll y) {
    x = (x + y) % mod;
}

// A deferred "multiply positions [l, r] of T1 by k" operation, queued by calc().
struct qry {
    int l, r, k;
};

// c[i]: operations calc() applies to T1 right after processing compressed coordinate i.
vector<qry> c[N];

// Segment tree over positions [1, size] with point assignment, range multiplication
// and range sum, all modulo `mod`.
struct sgt {
    struct tree {
        int l, r;    // segment covered by this node
        ll s, mul;   // sum over the segment; multiplier not yet pushed to the children
    } t[N << 2];

    // Resets the subtree rooted at p to cover [l, r] with all sums 0 and no pending multiplier.
    void build(int p, int l, int r) {
        t[p] = {l, r, 0, 1};
        // `>=` instead of `==`: an empty range (l > r) would otherwise recurse without bound.
        if (l >= r) {
            return;
        }
        int mid = (l + r) >> 1;
        build(p << 1, l, mid);
        build(p << 1 | 1, mid + 1, r);
    }

    // Multiplies node p's sum and pending multiplier by v.
    void apply(int p, ll v) {
        Mul(t[p].mul, v);
        Mul(t[p].s, v);
    }

    // Pushes node p's pending multiplier to both children.
    void down(int p) {
        if (t[p].mul != 1) {
            apply(p << 1, t[p].mul);
            apply(p << 1 | 1, t[p].mul);
            t[p].mul = 1;
        }
    }

    // Recomputes node p's sum from its children.
    void pull(int p) {
        t[p].s = (t[p << 1].s + t[p << 1 | 1].s) % mod;
    }

    // Multiplies every position in [l, r] by k.
    void mdf(int p, int l, int r, int k) {
        if (t[p].l > r || t[p].r < l) {
            return;
        }
        if (t[p].l >= l && t[p].r <= r) {
            apply(p, k);
            return;
        }
        down(p);
        mdf(p << 1, l, r, k);
        mdf(p << 1 | 1, l, r, k);
        pull(p);
    }

    // Assigns value k to position x.
    void ins(int p, int x, ll k) {
        if (t[p].l > x || t[p].r < x) {
            return;
        }
        if (t[p].l == t[p].r) {
            t[p].s = k;
            return;
        }
        down(p);
        ins(p << 1, x, k);
        ins(p << 1 | 1, x, k);
        pull(p);
    }

    // Returns the sum of positions [l, r] modulo `mod`; 0 for a range that misses the node.
    ll ask(int p, int l, int r) {
        if (t[p].l > r || t[p].r < l) {
            return 0;
        }
        if (t[p].l >= l && t[p].r <= r) {
            return t[p].s;
        }
        down(p);
        return (ask(p << 1, l, r) + ask(p << 1 | 1, l, r)) % mod;
    }
} T, T1;   // T: indexed by DFS order over the whole tree; T1: rebuilt by every calc() call

// True when L[y] lies inside [L[x], R[x]], i.e. y is in the subtree of x (x itself included).
bool in(int x, int y) {
    return L[x] <= L[y] && R[x] >= L[y];
}

// Roots the tree: assigns depth, DFS interval [L, R] and the binary-lifting table.
void dfs(int x, int p) {
    dep[x] = dep[p] + 1;
    L[x] = ++sn;
    fa[0][x] = p;
    for (int i = 1; i < LOG; i++) {
        fa[i][x] = fa[i - 1][fa[i - 1][x]];
    }
    for (int y : e[x]) {
        if (y != p) {
            dfs(y, x);
        }
    }
    R[x] = sn;
}

// Returns the ancestor of x that is `k` levels above it (bits 0..LOG-1 of k are used).
int kth(int x, int k) {
    for (int i = 0; i < LOG; i++) {
        if (k >> i & 1) {
            x = fa[i][x];
        }
    }
    return x;
}

// Lowest common ancestor of x and y by binary lifting.
int lca(int x, int y) {
    if (dep[x] < dep[y]) {
        swap(x, y);
    }
    // Lift the deeper vertex to the depth of the other one.
    for (int i = LOG - 1; i >= 0; i--) {
        if (dep[x] - dep[y] >= (1 << i)) {
            x = fa[i][x];
        }
    }
    if (x == y) {
        return x;
    }
    // Lift both while their ancestors differ; they end up just below the lca.
    for (int i = LOG - 1; i >= 0; i--) {
        if (fa[i][x] != fa[i][y]) {
            x = fa[i][x];
            y = fa[i][y];
        }
    }
    return fa[0][x];
}

// Fills c1. A vertex first receives the sum of its children's c1; then, for every
// extra-pair neighbour y inside x's subtree, both c1[x] and c1 of the child of x on
// the path towards y are incremented (that child's increment happens after its value
// was already added to c1[x]).
// NOTE(review): a pair with x == y would call kth with a negative step — presumably
// the input has no such pairs; confirm against the problem statement.
void dfs1(int x, int p) {
    for (int y : e[x]) {
        if (y != p) {
            dfs1(y, x);
            c1[x] += c1[y];
        }
    }
    for (int y : g[x]) {
        if (in(x, y)) {
            c1[x]++;
            // Ancestor of y at depth dep[x] + 1.
            c1[kth(y, dep[y] - dep[x] - 1)]++;
        }
    }
}

// Fills c2 top-down: x adds one per extra-pair neighbour outside its subtree, then
// every child y starts from c2[x] plus the c1 total of x's other children.
void dfs2(int x, int p) {
    for (int y : g[x]) {
        if (!in(x, y)) {
            c2[x]++;
        }
    }
    int childTotal = 0;
    for (int y : e[x]) {
        if (y != p) {
            childTotal += c1[y];
        }
    }
    for (int y : e[x]) {
        if (y != p) {
            c2[y] = c2[x] + childTotal - c1[y];
            dfs2(y, x);
        }
    }
}

// Contribution of vertex x (parent p) computed from the pairs in d[x].
// Returns f1[2] unchanged when d[x] is empty. Otherwise:
//   1. compresses the subtree-interval boundaries of all pair endpoints (plus the
//      interval of x's proper descendants) into b[1..ct];
//   2. loads T1 with the sums T holds for each compressed slice;
//   3. sweeps the slices left to right, adding slice_sum * T1_total to the result and
//      then applying the multipliers (2 when an endpoint's interval opens, i2 when it
//      closes) that were queued for the partner endpoint's interval;
//   4. subtracts the square of every child subtree's sum and scales the total by
//      2^-(1 + sum of c1 over the children).
ll calc(int x, int p) {
    if (d[x].empty()) {
        return f1[2];
    }
    int ct = 0;
    ll res = 0;

    b[++ct] = L[x] + 1;
    b[++ct] = R[x] + 1;
    for (const pi &pr : d[x]) {
        b[++ct] = L[pr.first];
        b[++ct] = R[pr.first] + 1;
        b[++ct] = L[pr.second];
        b[++ct] = R[pr.second] + 1;
    }
    sort(b + 1, b + ct + 1);
    ct = unique(b + 1, b + ct + 1) - b - 1;

    T1.build(1, 1, ct);
    for (int i = 1; i <= ct; i++) {
        c[i].clear();
    }
    // Slot i of T1 holds the sum of T over the slice [b[i-1], b[i] - 1].
    for (int i = 2; i <= ct; i++) {
        T1.ins(1, i, T.ask(1, b[i - 1], b[i] - 1));
    }

    for (const pi &pr : d[x]) {
        int l1 = lower_bound(b + 1, b + ct + 1, L[pr.first]) - b;
        int r1 = lower_bound(b + 1, b + ct + 1, R[pr.first] + 1) - b;
        int l2 = lower_bound(b + 1, b + ct + 1, L[pr.second]) - b;
        int r2 = lower_bound(b + 1, b + ct + 1, R[pr.second] + 1) - b;
        c[l1].push_back({l2 + 1, r2, 2});
        c[r1].push_back({l2 + 1, r2, i2});
        c[l2].push_back({l1 + 1, r1, 2});
        c[r2].push_back({l1 + 1, r1, i2});
    }

    for (int i = 1; i <= ct; i++) {
        if (i > 1) {
            Add(res, T.ask(1, b[i - 1], b[i] - 1) * T1.ask(1, 1, ct));
        }
        for (const qry &op : c[i]) {
            T1.mdf(1, op.l, op.r, op.k);
        }
    }

    int childTotal = 0;
    for (int y : e[x]) {
        if (y != p) {
            ll s = T.ask(1, L[y], R[y]);
            Add(res, -s * s);
            childTotal += c1[y];
        }
    }
    return res * inv[childTotal + 1] % mod;
}

// Post-order pass over the tree. For vertex x (parent p) it
//   - doubles the T entries of every extra-pair neighbour's subtree that lies below x,
//   - combines the children's subtree sums into f1[0..3], where the index is the number
//     of children that contributed their sum, capped at 3 (a child that does not
//     contribute multiplies by 2^c1[child] instead),
//   - rescales each child's subtree in T by 2^(c1 total of the other children),
//   - updates f[x] and `ans` (the formula depends on key[x]) and stores f[x] in T.
void solve(int x, int p) {
    for (int y : e[x]) {
        if (y != p) {
            solve(y, x);
        }
    }
    for (int y : g[x]) {
        if (in(x, y)) {
            T.mdf(1, L[y], R[y], 2);
        }
    }

    for (int i = 0; i <= 3; i++) {
        f1[i] = 0;
    }
    f1[0] = 1;
    ll f2[4] = {0};
    int ct = 0;
    for (int y : e[x]) {
        if (y == p) {
            continue;
        }
        ct += c1[y];
        ll sf = T.ask(1, L[y], R[y]);
        for (int i = 0; i <= 3; i++) {
            Add(f2[i], f1[i] * pw[c1[y]]);
            Add(f2[min(i + 1, 3)], f1[i] * sf);
        }
        for (int i = 0; i <= 3; i++) {
            f1[i] = f2[i];
            f2[i] = 0;
        }
    }
    for (int y : e[x]) {
        if (y != p) {
            T.mdf(1, L[y], R[y], pw[ct - c1[y]]);
        }
    }

    if (!key[x]) {
        Add(f[x], f1[2] + f1[3]);
        Add(ans, (calc(x, p) + f1[3]) * pw[c2[x]]);
    } else {
        Add(f[x], -f1[0] - f1[1]);
        Add(ans, (calc(x, p) - f1[2] - f1[1] - f1[0]) * pw[c2[x]]);
    }
    T.ins(1, L[x], f[x]);
}

// Reads the next non-negative decimal integer from stdin, skipping any non-digit
// characters in front of it. Returns 0 when the input ends before a digit is found.
int rd() {
    int x = 0;
    // getchar() returns int so that EOF can be told apart from every character.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x;
}

// Input: type n m k, then n - 1 tree edges, m extra pairs and k key vertices.
// Output: (mod - ans) % mod, which lies in [0, mod) even when ans is negative.
int main() {
    type = rd();
    n = rd();
    m = rd();
    k = rd();

    pw[0] = inv[0] = 1;
    for (int i = 1; i <= max(n, m); i++) {
        pw[i] = pw[i - 1] * 2 % mod;
        inv[i] = inv[i - 1] * i2 % mod;
    }

    for (int i = 1; i < n; i++) {
        int x = rd();
        int y = rd();
        e[x].push_back(y);
        e[y].push_back(x);
    }
    for (int i = 1; i <= m; i++) {
        int x = rd();
        int y = rd();
        g[x].push_back(y);
        g[y].push_back(x);
        w[i] = {x, y};
    }
    for (int i = 1; i <= k; i++) {
        key[rd()] = 1;
    }

    dfs(1, 0);
    dfs1(1, 0);
    dfs2(1, 0);
    T.build(1, 1, n);

    // Pairs whose endpoints are not in an ancestor/descendant relation are grouped by lca.
    for (int i = 1; i <= m; i++) {
        int x = w[i].first, y = w[i].second;
        if (!in(x, y) && !in(y, x)) {
            d[lca(x, y)].push_back({x, y});
        }
    }

    solve(1, 0);
    printf("%lld", (mod - ans) % mod);
}
评论:
请先登录,才能进行评论