gf

 •  5小时前


#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Whole-program solution: reads from stdin, prints one number to stdout.
//
// Input, in the order main() reads it:
//   type n m k, then n-1 tree edges, then m extra edges, then k "key" vertices.
// The tree is rooted at vertex 1. The printed value is (-ans) mod 1e9+7.
// NOTE(review): the problem statement is not part of this file. The comments
// below describe what the code computes, not why the formulas are correct.

using ll = long long;
using pi = std::pair<int, int>;

constexpr int N = 500000 + 5;
constexpr int mod = 1000000007;
constexpr int i2 = (mod + 1) >> 1;  // modular inverse of 2

int type, n, m, k;  // header values; `type` is read but never used below
int sn;             // DFS timestamp counter
int fa[20][N];      // fa[i][x] = 2^i-th ancestor of x (0 above the root)
int dep[N];         // depth; the root has depth 1
int c1[N], c2[N];   // per-vertex extra-edge counters filled by dfs1 / dfs2
int L[N], R[N];     // DFS entry stamp of x / largest stamp inside x's subtree
// Scratch coordinates for calc(). It writes 2 + 4*|d[x]| values before
// deduplicating, and |d[x]| can approach m, so 2N slots are not enough.
int b[N << 2];
ll ans;
ll pw[N];   // pw[i]  = 2^i  (mod p)
ll inv[N];  // inv[i] = 2^-i (mod p) -- inverse powers of two, not 1/i
ll f[N];    // value written at position L[x] of T once solve(x) finishes
ll f1[4];   // child-combination DP of the vertex currently being solved
bool key[N];
std::vector<int> g[N];  // extra edges, stored at both endpoints
std::vector<int> e[N];  // tree edges, stored at both endpoints
pi w[N];                // extra edges in input order
// d[z]: extra edges (u, v) where neither endpoint is an ancestor of the
// other and lca(u, v) == z.
std::vector<pi> d[N];

// x <- x * y (mod p).
void Mul(ll &x, ll y) { x = x * y % mod; }

// x <- x + y (mod p), normalized into [0, mod).
// Callers subtract through Add, so without the normalization negative
// residues would leak into the segment trees.
void Add(ll &x, ll y) { x = ((x + y) % mod + mod) % mod; }

// Deferred range-multiply event used by calc(): multiply T1 positions [l, r] by k.
struct qry {
    int l, r, k;
};
std::vector<qry> c[N];  // c[i]: events applied when calc()'s sweep reaches position i

// Segment tree with range multiply, point assign and range sum (all mod p).
struct sgt {
    struct tree {
        int l, r;  // covered index range
        ll s;      // sum of the range
        ll mul;    // pending multiplier for the children
    } t[N << 2];

    // Build over [l, r] with all sums 0 and no pending multiplier.
    void build(int p, int l, int r) {
        t[p] = {l, r, 0, 1};
        if (l == r) return;
        int mid = (l + r) >> 1;
        build(p << 1, l, mid);
        build(p << 1 | 1, mid + 1, r);
    }

    // Multiply node p's whole range by k lazily.
    void apply(int p, ll k) {
        Mul(t[p].mul, k);
        Mul(t[p].s, k);
    }

    // Hand a pending multiplier down to both children.
    void push_down(int p) {
        if (t[p].mul != 1) {
            apply(p << 1, t[p].mul);
            apply(p << 1 | 1, t[p].mul);
            t[p].mul = 1;
        }
    }

    // Recompute node p's sum from its children.
    void pull_up(int p) { t[p].s = (t[p << 1].s + t[p << 1 | 1].s) % mod; }

    // Multiply every position in [l, r] by k.
    void mdf(int p, int l, int r, ll k) {
        if (t[p].l > r || t[p].r < l) return;
        if (t[p].l >= l && t[p].r <= r) {
            apply(p, k);
            return;
        }
        push_down(p);
        mdf(p << 1, l, r, k);
        mdf(p << 1 | 1, l, r, k);
        pull_up(p);
    }

    // Overwrite position x with value k.
    void ins(int p, int x, ll k) {
        if (t[p].l > x || t[p].r < x) return;
        if (t[p].l == t[p].r) {
            t[p].s = k;
            return;
        }
        push_down(p);
        ins(p << 1, x, k);
        ins(p << 1 | 1, x, k);
        pull_up(p);
    }

    // Sum of positions in [l, r].
    ll ask(int p, int l, int r) {
        if (t[p].l > r || t[p].r < l) return 0;
        if (t[p].l >= l && t[p].r <= r) return t[p].s;
        push_down(p);
        return (ask(p << 1, l, r) + ask(p << 1 | 1, l, r)) % mod;
    }
};

sgt T;   // indexed by DFS entry stamp (1..n)
sgt T1;  // indexed by compressed coordinates, rebuilt inside calc()

// True iff y lies in the subtree of x (x itself included).
bool in(int x, int y) { return L[x] <= L[y] && R[x] >= L[y]; }

// Fill depth, DFS stamps L/R and the binary-lifting table, starting at x
// whose parent is p.
void dfs(int x, int p) {
    dep[x] = dep[p] + 1;
    L[x] = ++sn;
    fa[0][x] = p;
    for (int i = 1; i <= 19; i++) fa[i][x] = fa[i - 1][fa[i - 1][x]];
    for (int y : e[x])
        if (y != p) dfs(y, x);
    R[x] = sn;
}

// Ancestor of x that is `steps` levels higher (binary lifting).
int kth(int x, int steps) {
    for (int i = 0; i <= 19; i++)
        if ((steps >> i) & 1) x = fa[i][x];
    return x;
}

// Lowest common ancestor of x and y.
int lca(int x, int y) {
    if (dep[x] < dep[y]) std::swap(x, y);
    for (int i = 19; i >= 0; i--)
        if (dep[x] - dep[y] >= (1 << i)) x = fa[i][x];
    if (x == y) return x;
    for (int i = 19; i >= 0; i--)
        if (fa[i][x] != fa[i][y]) {
            x = fa[i][x];
            y = fa[i][y];
        }
    return fa[0][x];
}

// Post-order pass computing c1.
// Steps for vertex x:
//   1. Add every child's (already final) c1 into c1[x].
//   2. For each extra edge from x down to a descendant y, add 1 to c1[x].
//   3. For the same edge, add 1 to the child of x on the path to y.
// The increment in step 3 happens after that child was summed in step 1,
// so it is not propagated upward.
void dfs1(int x, int p) {
    for (int y : e[x])
        if (y != p) {
            dfs1(y, x);
            c1[x] += c1[y];
        }
    for (int y : g[x])
        if (in(x, y)) {
            c1[x]++;
            c1[kth(y, dep[y] - dep[x] - 1)]++;
        }
}

// Pre-order pass computing c2.
// c2[x] is first seeded by the parent. x then adds 1 for every extra edge
// whose other endpoint lies outside x's subtree. Each child y is seeded with
// c2[x] plus the c1 sum of y's siblings.
void dfs2(int x, int p) {
    for (int y : g[x])
        if (!in(x, y)) c2[x]++;
    int s = 0;
    for (int y : e[x])
        if (y != p) s += c1[y];
    for (int y : e[x])
        if (y != p) {
            c2[y] = c2[x] + s - c1[y];
            dfs2(y, x);
        }
}

// Correction term for vertex x (p is its parent). Reads T and the global f1.
//
// With no edges in d[x], returns f1[2]. Otherwise:
//  1. Compress the stamp boundaries of x's subtree minus x itself, plus the
//     subtrees of every endpoint in d[x].
//  2. Load T's sum over each elementary segment into T1.
//  3. Sweep the segments left to right. Each segment's T-sum is multiplied
//     by T1's current total. Entering an endpoint's subtree doubles the
//     other endpoint's segments in T1; leaving it halves them again.
//  4. Subtract the square of every child's subtree sum.
//  5. Scale the result by 2^-(1 + sum of c1 over the children).
ll calc(int x, int p) {
    if (d[x].empty()) return f1[2];

    int ct = 0;
    int child_edges = 0;
    ll res = 0;
    b[++ct] = L[x] + 1;
    b[++ct] = R[x] + 1;
    for (const pi &ed : d[x]) {
        b[++ct] = L[ed.first];
        b[++ct] = R[ed.first] + 1;
        b[++ct] = L[ed.second];
        b[++ct] = R[ed.second] + 1;
    }
    std::sort(b + 1, b + ct + 1);
    ct = int(std::unique(b + 1, b + ct + 1) - b - 1);
    T1.build(1, 1, ct);
    for (int i = 1; i <= ct; i++) c[i].clear();
    // Position i (i >= 2) of T1 holds T's sum over stamps [b[i-1], b[i]-1].
    for (int i = 2; i <= ct; i++) T1.ins(1, i, T.ask(1, b[i - 1], b[i] - 1));

    // Index of stamp value v among the compressed coordinates.
    auto pos = [ct](int v) { return int(std::lower_bound(b + 1, b + ct + 1, v) - b); };
    for (const pi &ed : d[x]) {
        int l1 = pos(L[ed.first]), r1 = pos(R[ed.first] + 1);
        int l2 = pos(L[ed.second]), r2 = pos(R[ed.second] + 1);
        c[l1].push_back({l2 + 1, r2, 2});
        c[r1].push_back({l2 + 1, r2, i2});
        c[l2].push_back({l1 + 1, r1, 2});
        c[r2].push_back({l1 + 1, r1, i2});
    }

    for (int i = 1; i <= ct; i++) {
        if (i > 1) Add(res, T.ask(1, b[i - 1], b[i] - 1) * T1.ask(1, 1, ct));
        for (const qry &ev : c[i]) T1.mdf(1, ev.l, ev.r, ev.k);
    }

    for (int y : e[x])
        if (y != p) {
            ll s = T.ask(1, L[y], R[y]);
            Add(res, -s * s);
            child_edges += c1[y];
        }
    return res * inv[child_edges + 1] % mod;
}

// Post-order driver.
// For vertex x:
//   1. Solve all children first.
//   2. Double T on the subtree of every extra-edge endpoint below x.
//   3. Build the child DP f1 (bucket = number of "taken" children, capped at 3).
//   4. Rescale each child's subtree in T by 2^(sum of its siblings' c1).
//   5. Accumulate into ans and f[x], then store f[x] at L[x] in T.
void solve(int x, int p) {
    for (int y : e[x])
        if (y != p) solve(y, x);
    for (int y : g[x])
        if (in(x, y)) T.mdf(1, L[y], R[y], 2);

    for (int i = 0; i <= 3; i++) f1[i] = 0;
    f1[0] = 1;
    ll f2[4] = {0};
    int ct = 0;
    for (int y : e[x]) {
        if (y == p) continue;
        ct += c1[y];
        ll sf = T.ask(1, L[y], R[y]);
        // Child y is either skipped (weight 2^c1[y]) or taken (weight = its
        // subtree sum in T).
        for (int i = 0; i <= 3; i++) {
            Add(f2[i], f1[i] * pw[c1[y]]);
            Add(f2[std::min(i + 1, 3)], f1[i] * sf);
        }
        for (int i = 0; i <= 3; i++) {
            f1[i] = f2[i];
            f2[i] = 0;
        }
    }

    for (int y : e[x])
        if (y != p) T.mdf(1, L[y], R[y], pw[ct - c1[y]]);

    if (!key[x]) {
        Add(f[x], f1[2] + f1[3]);
        Add(ans, (calc(x, p) + f1[3]) * pw[c2[x]]);
    } else {
        Add(f[x], -f1[0] - f1[1]);
        Add(ans, (calc(x, p) - f1[2] - f1[1] - f1[0]) * pw[c2[x]]);
    }
    T.ins(1, L[x], f[x]);
}

// Read the next non-negative decimal integer from stdin.
// Non-digit characters are skipped. Returns 0 if EOF is reached first.
int rd() {
    int x = 0;
    int ch = std::getchar();  // int, not char, so EOF stays distinguishable
    while (ch != EOF && (ch < '0' || ch > '9')) ch = std::getchar();
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x;
}

int main() {
    type = rd();
    n = rd();
    m = rd();
    k = rd();

    pw[0] = inv[0] = 1;
    for (int i = 1; i <= std::max(n, m); i++) {
        pw[i] = pw[i - 1] * 2 % mod;
        inv[i] = inv[i - 1] * i2 % mod;
    }

    for (int i = 1; i < n; i++) {
        int x = rd();
        int y = rd();
        e[x].push_back(y);
        e[y].push_back(x);
    }
    for (int i = 1; i <= m; i++) {
        int x = rd();
        int y = rd();
        g[x].push_back(y);
        g[y].push_back(x);
        w[i] = {x, y};
    }
    for (int i = 1; i <= k; i++) key[rd()] = true;

    dfs(1, 0);
    dfs1(1, 0);
    dfs2(1, 0);
    T.build(1, 1, n);

    // Group extra edges whose endpoints are not in an ancestor relation by their LCA.
    for (int i = 1; i <= m; i++) {
        int x = w[i].first;
        int y = w[i].second;
        if (!in(x, y) && !in(y, x)) d[lca(x, y)].push_back({x, y});
    }

    solve(1, 0);
    std::printf("%lld", (mod - ans) % mod);
}


评论:

请先登录,才能进行评论