ac

哪吒  •  1天前


// Accepted contest solution.
//
// Reconstructed from a scraped copy that had lost every preprocessor '#'
// and every single-word template argument list (e.g. "vectorg[N]" was
// "vector<int> g[N]"); <bits/stdc++.h> is replaced by the standard headers
// this program actually uses.
//
// Input (as read by main): type n m k, then n-1 tree edges, then m extra
// edges, then k "key" vertex ids. `type` is read but never used.
// Output: one value modulo 1e9+7, printed as (mod - ans) % mod.
// NOTE(review): the problem statement is not part of this file; comments
// below describe what the code computes mechanically, not what it means.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

using namespace std;
typedef long long ll;
typedef pair<int, int> pi;

// i2 = (mod + 1) / 2 is the modular inverse of 2 modulo mod.
const int N = 5e5 + 5, mod = 1e9 + 7, i2 = (mod + 1) >> 1;

// Global state (indexed by vertex unless noted):
//   e      - tree adjacency lists;  g - adjacency lists of the m extra edges
//   fa     - binary-lifting ancestor table (fa[0][root] = 0), dep - depth (root 1 has depth 1)
//   L, R   - DFS entry index of v, and the largest entry index in v's subtree
//   c1, c2 - per-vertex counters filled by dfs1 / dfs2
//   w      - extra edges in input order
//   d[v]   - extra edges whose endpoints are not ancestor-related and whose LCA is v
//   key    - marks the k key vertices
//   b      - scratch coordinate buffer used by calc
//   pw[i]  = 2^i mod p,  inv[i] = 2^{-i} mod p
//   f      - per-vertex DP value,  f1 - per-vertex scratch DP written by solve
int type, n, m, k, sn, fa[20][N], dep[N], c1[N], c2[N], L[N], R[N], b[N << 1];
ll ans, pw[N], inv[N], f[N], f1[4];
bool key[N];
vector<int> g[N], e[N];
pi w[N];
vector<pi> d[N];

// x = x * y mod p. Operands may be negative; the result stays in (-mod, mod).
void Mul(ll &x, ll y) {
    x = x * y % mod;
}

// x = (x + y) mod p. Negative addends are allowed (callers pass -s*s etc.);
// the result stays in (-mod, mod) and is normalised only when printed.
void Add(ll &x, ll y) {
    x = (x + y) % mod;
}

// Deferred range-multiply request used by calc: multiply positions [l, r] by k.
struct qry {
    int l, r, k;
};
vector<qry> c[N];

// Segment tree with lazy range multiplication, point assignment and range sum,
// all modulo `mod`.
struct sgt {
    struct tree {
        int l, r;
        ll s, mul;  // subtree sum, pending multiplier for the children
    } t[N << 2];

    // Builds nodes for [l, r] with all sums 0 and all multipliers 1.
    void build(int p, int l, int r) {
        t[p] = {l, r, 0, 1};
        if (l == r)
            return;
        int mid = (l + r) >> 1;
        build(p << 1, l, mid), build(p << 1 | 1, mid + 1, r);
    }

    // Multiplies the whole node p (sum and pending tag) by k.
    void upd(int p, ll k) {
        Mul(t[p].mul, k), Mul(t[p].s, k);
    }

    // Pushes node p's pending multiplier down to its children.
    void down(int p) {
        if (t[p].mul != 1)
            upd(p << 1, t[p].mul), upd(p << 1 | 1, t[p].mul), t[p].mul = 1;
    }

    // Recomputes node p's sum from its children.
    void pull(int p) {
        t[p].s = (t[p << 1].s + t[p << 1 | 1].s) % mod;
    }

    // Multiplies every position in [l, r] by k.
    void mdf(int p, int l, int r, int k) {
        if (t[p].l > r || t[p].r < l)
            return;
        if (t[p].l >= l && t[p].r <= r) {
            upd(p, k);
            return;
        }
        down(p), mdf(p << 1, l, r, k), mdf(p << 1 | 1, l, r, k), pull(p);
    }

    // Sets position x to k.
    void ins(int p, int x, ll k) {
        if (t[p].l > x || t[p].r < x)
            return;
        if (t[p].l == t[p].r) {
            t[p].s = k;
            return;
        }
        down(p), ins(p << 1, x, k), ins(p << 1 | 1, x, k), pull(p);
    }

    // Sum of positions [l, r] modulo mod (0 for an empty intersection).
    ll ask(int p, int l, int r) {
        if (t[p].l > r || t[p].r < l)
            return 0;
        if (t[p].l >= l && t[p].r <= r)
            return t[p].s;
        down(p);
        return (ask(p << 1, l, r) + ask(p << 1 | 1, l, r)) % mod;
    }
} T, T1;  // T: over DFS positions 1..n; T1: scratch tree rebuilt by calc

// True iff y lies in the subtree of x (x itself included), via DFS indices.
bool in(int x, int y) {
    return L[x] <= L[y] && R[x] >= L[y];
}

// Fills dep, L, R and the binary-lifting table fa for the subtree of x.
// NOTE(review): recursion depth equals tree height (up to n) — assumes the
// judge provides a large enough stack.
void dfs(int x, int p) {
    dep[x] = dep[p] + 1, L[x] = ++sn, fa[0][x] = p;
    for (int i = 1; i <= 19; i++)
        fa[i][x] = fa[i - 1][fa[i - 1][x]];
    for (int y : e[x])
        if (y != p)
            dfs(y, x);
    R[x] = sn;
}

// k-th ancestor of x (climbing past the root yields 0). Assumes 0 <= k < 2^20.
int kth(int x, int k) {
    for (int i = 0; i <= 19; i++)
        if (k >> i & 1)
            x = fa[i][x];
    return x;
}

// Lowest common ancestor by binary lifting.
int lca(int x, int y) {
    if (dep[x] < dep[y])
        swap(x, y);
    for (int i = 19; i >= 0; i--)
        if (dep[x] - dep[y] >= 1 << i)
            x = fa[i][x];
    if (x == y)
        return x;
    for (int i = 19; i >= 0; i--)
        if (fa[i][x] != fa[i][y])
            x = fa[i][x], y = fa[i][y];
    return fa[0][x];
}

// c1[x] = sum of c1 over x's children, plus one for every extra-edge endpoint y
// of x lying in x's subtree. For each such y, the child of x on the path to y
// also gets +1. That increment happens after the children were summed into
// c1[x], so the order of the two loops matters.
void dfs1(int x, int p) {
    for (int y : e[x])
        if (y != p)
            dfs1(y, x), c1[x] += c1[y];
    for (int y : g[x])
        if (in(x, y))
            c1[x]++, c1[kth(y, dep[y] - dep[x] - 1)]++;
}

// c2[x] gains one per extra-edge endpoint of x outside x's subtree. Each child
// y then starts from c2[x] plus the c1 values of its siblings.
void dfs2(int x, int p) {
    for (int y : g[x])
        if (!in(x, y))
            c2[x]++;
    int s = 0;
    for (int y : e[x])
        if (y != p)
            s += c1[y];
    for (int y : e[x])
        if (y != p)
            c2[y] = c2[x] + s - c1[y], dfs2(y, x);
}

// Correction term for vertex x, using the extra edges in d[x] (LCA == x).
// Without such edges it returns f1[2] (as last written by solve for x).
// Otherwise it does the following:
//   - compresses the DFS coordinates of x's subtree and of both endpoint
//     subtrees of every edge in d[x] into b;
//   - loads T's sum over each compressed interval into T1;
//   - sweeps the intervals left to right. Entering one endpoint's subtree
//     doubles the other endpoint's subtree in T1, and leaving it halves it
//     again. Each step accumulates (T-sum of the interval) * (T1 total);
//   - subtracts the square of T's sum over each child subtree;
//   - divides by 2^(1 + sum of c1 over x's children).
// NOTE(review): b holds 2 + 4*|d[x]| entries and T1 is sized for N << 2 nodes;
// this is safe only if the problem constraints bound |d[x]| accordingly.
ll calc(int x, int p) {
    if (d[x].empty())
        return f1[2];

    int ct = 0, childC1 = 0;
    ll res = 0;
    b[++ct] = L[x] + 1, b[++ct] = R[x] + 1;
    for (const pi &i : d[x])
        b[++ct] = L[i.first], b[++ct] = R[i.first] + 1, b[++ct] = L[i.second], b[++ct] = R[i.second] + 1;

    sort(b + 1, b + ct + 1);
    ct = unique(b + 1, b + ct + 1) - b - 1;
    T1.build(1, 1, ct);
    for (int i = 1; i <= ct; i++)
        c[i].clear();

    // T1 position i holds T's sum over the DFS interval [b[i-1], b[i]-1].
    for (int i = 2; i <= ct; i++)
        T1.ins(1, i, T.ask(1, b[i - 1], b[i] - 1));

    for (const pi &i : d[x]) {
        int l1 = lower_bound(b + 1, b + ct + 1, L[i.first]) - b;
        int r1 = lower_bound(b + 1, b + ct + 1, R[i.first] + 1) - b;
        int l2 = lower_bound(b + 1, b + ct + 1, L[i.second]) - b;
        int r2 = lower_bound(b + 1, b + ct + 1, R[i.second] + 1) - b;
        c[l1].push_back({l2 + 1, r2, 2}), c[r1].push_back({l2 + 1, r2, i2});
        c[l2].push_back({l1 + 1, r1, 2}), c[r2].push_back({l1 + 1, r1, i2});
    }

    for (int i = 1; i <= ct; i++) {
        if (i > 1)
            Add(res, T.ask(1, b[i - 1], b[i] - 1) * T1.ask(1, 1, ct));
        for (const qry &j : c[i])
            T1.mdf(1, j.l, j.r, j.k);
    }

    for (int y : e[x])
        if (y != p) {
            ll s = T.ask(1, L[y], R[y]);
            Add(res, -s * s);
            childC1 += c1[y];
        }

    return res * inv[childC1 + 1] % mod;
}

// Post-order DP. After the children are solved:
//   - every extra-edge endpoint subtree below x is doubled in T;
//   - f1[j] (j < 3) sums, over the ways to mark children, the product of
//     T-subtree sums of marked children and 2^c1 of unmarked ones, with
//     exactly j children marked; f1[3] collects 3 or more marked;
//   - each child subtree is then scaled by 2^(sum of the other children's c1);
//   - f[x] and ans are updated depending on key[x], and f[x] is stored at
//     DFS position L[x] in T.
void solve(int x, int p) {
    for (int y : e[x])
        if (y != p)
            solve(y, x);

    for (int y : g[x])
        if (in(x, y))
            T.mdf(1, L[y], R[y], 2);

    for (int i = 0; i <= 3; i++)
        f1[i] = 0;
    f1[0] = 1;
    ll f2[4] = {0};
    int ct = 0;

    for (int y : e[x]) {
        if (y == p)
            continue;
        ct += c1[y];
        ll sf = T.ask(1, L[y], R[y]);
        for (int i = 0; i <= 3; i++) {
            Add(f2[i], f1[i] * pw[c1[y]]);
            Add(f2[min(i + 1, 3)], f1[i] * sf);
        }
        for (int i = 0; i <= 3; i++)
            f1[i] = f2[i], f2[i] = 0;
    }

    for (int y : e[x])
        if (y != p)
            T.mdf(1, L[y], R[y], pw[ct - c1[y]]);

    if (!key[x])
        Add(f[x], f1[2] + f1[3]), Add(ans, (calc(x, p) + f1[3]) * pw[c2[x]]);
    else
        Add(f[x], -f1[0] - f1[1]), Add(ans, (calc(x, p) - f1[2] - f1[1] - f1[0]) * pw[c2[x]]);

    T.ins(1, L[x], f[x]);
}

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters. Returns 0 (or the digits read so far) at end of input
// instead of spinning forever; `ch` is an int so EOF is distinguishable.
int rd() {
    int x = 0;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
        ch = getchar();
    while (ch >= '0' && ch <= '9')
        x = x * 10 + (ch - '0'), ch = getchar();
    return x;
}

int main() {
    type = rd(), n = rd(), m = rd(), k = rd(), pw[0] = inv[0] = 1;

    const int lim = max(n, m);
    for (int i = 1; i <= lim; i++)
        pw[i] = pw[i - 1] * 2 % mod, inv[i] = inv[i - 1] * i2 % mod;

    for (int i = 1, x, y; i < n; i++)
        x = rd(), y = rd(), e[x].push_back(y), e[y].push_back(x);

    for (int i = 1, x, y; i <= m; i++)
        x = rd(), y = rd(), g[x].push_back(y), g[y].push_back(x), w[i] = {x, y};

    for (int i = 1; i <= k; i++)
        key[rd()] = 1;

    dfs(1, 0), dfs1(1, 0), dfs2(1, 0), T.build(1, 1, n);

    // Extra edges with neither endpoint an ancestor of the other are grouped by LCA.
    for (int i = 1; i <= m; i++) {
        int x = w[i].first, y = w[i].second;
        if (!in(x, y) && !in(y, x))
            d[lca(x, y)].push_back({x, y});
    }

    solve(1, 0);
    // ans lies in (-mod, mod), so this lands in [0, mod).
    printf("%lld", (mod - ans) % mod);
}


评论:

请先登录,才能进行评论