// guoguo • 4小时前  (forum post header: author "guoguo", posted 4 hours ago)
#include <algorithm>
#include <iostream>
#include <map>
#include <utility>
#include <vector>
using namespace std;

// Every quantity the original spelled `int` is kept in 64 bits: the product of
// two residues below `mods` is about 1e18 and does not fit in 32 bits. Only the
// tree-node fields stay 32-bit (they were `signed` in the original).
using ll = long long;

const ll N = 500005, mods = 1000000007;

// Dynamically allocated segment tree over positions [nl, nr] storing sums
// modulo `mods`, with lazy multiplication and merging of two trees.
// All trees share one node pool; index 0 means "no node".
namespace tr {

// lson/rson: child indices into `p` (0 when absent); sm: sum of the subtree;
// laz: multiplier that still has to be applied to both children.
struct node {
    int lson, rson, sm, laz;
} p[N * 50];

// Index of the most recently allocated node (allocation is `++idx`).
ll idx;

// Multiplies the whole subtree of node x by sm: the stored sum immediately,
// the descendants lazily through `laz`.
void mul(ll x, ll sm) {
    p[x].sm = 1ll * p[x].sm * sm % mods;
    p[x].laz = 1ll * p[x].laz * sm % mods;
}

// Pushes the pending multiplier of x down to both children and resets it to 1.
// A missing child is index 0, so the push then lands on the dummy node 0.
void dnset(ll x) {
    if (p[x].laz != 1) {
        mul(p[x].lson, p[x].laz);
        mul(p[x].rson, p[x].laz);
        p[x].laz = 1;
    }
}

// Allocates whichever children of x do not exist yet; new nodes get laz = 1.
void init(ll x) {
    if (!p[x].lson) p[x].lson = ++idx, p[idx].laz = 1;
    if (!p[x].rson) p[x].rson = ++idx, p[idx].laz = 1;
}

// Recomputes the sum of x from its two children.
void upset(ll x) {
    p[x].sm = (p[p[x].lson].sm + p[p[x].rson].sm) % mods;
}

// Multiplies every position in [l, r] by sm. Node x covers [nl, nr].
void mul(ll x, ll l, ll r, ll sm, ll nl, ll nr) {
    if (l <= nl && r >= nr) {
        mul(x, sm);
        return;
    }
    ll mid = (nl + nr) >> 1;
    init(x);
    dnset(x);
    if (l <= mid) mul(p[x].lson, l, r, sm, nl, mid);
    if (r > mid) mul(p[x].rson, l, r, sm, mid + 1, nr);
    upset(x);
}

// Adds sm to position d. Node x covers [nl, nr].
void add(ll x, ll d, ll sm, ll nl, ll nr) {
    if (nl == nr) {
        p[x].sm = (p[x].sm + sm) % mods;
        return;
    }
    init(x);
    dnset(x);
    ll mid = (nl + nr) >> 1;
    if (d <= mid) add(p[x].lson, d, sm, nl, mid);
    else add(p[x].rson, d, sm, mid + 1, nr);
    upset(x);
}

// Returns the sum over [l, r] modulo `mods` (the result may be negative because
// negative values are stored as-is). A missing node contributes 0.
ll gets(ll x, ll l, ll r, ll nl, ll nr) {
    if (!x) return 0;
    if (l <= nl && r >= nr) return p[x].sm;
    ll mid = (nl + nr) >> 1;
    dnset(x);
    if (r <= mid) return gets(p[x].lson, l, r, nl, mid);
    if (l > mid) return gets(p[x].rson, l, r, mid + 1, nr);
    return (gets(p[x].lson, l, r, nl, mid) + gets(p[x].rson, l, r, mid + 1, nr)) % mods;
}

// Merges tree b into tree a position by position (sums are added) and returns
// the root of the merged tree. If either side is empty the other is returned.
ll hb(ll a, ll b, ll nl, ll nr) {
    if (!a || !b) return a | b;
    if (nl == nr) {
        p[a].sm = (p[a].sm + p[b].sm) % mods;
        return a;
    }
    ll mid = (nl + nr) >> 1;
    dnset(a);
    dnset(b);
    p[a].lson = hb(p[a].lson, p[b].lson, nl, mid);
    p[a].rson = hb(p[a].rson, p[b].rson, mid + 1, nr);
    upset(a);
    return a;
}

}  // namespace tr

// Returns a^b modulo `mods` by recursive squaring (b >= 0).
ll pows(ll a, ll b) {
    if (b == 0) return 1;
    ll res = pows(a, b >> 1);
    res = res * res % mods;
    if (b & 1) res = res * a % mods;
    return res;
}

// inv2 is (mods + 1) / 2, i.e. the modular inverse of 2.
ll inv2 = (mods + 1) >> 1;
ll op, n, m, k;
// rt[x]: root of the segment tree built for vertex x in solve().
// dfn/eds: first and last preorder index inside the subtree; dy: inverse of dfn.
// dep/fa: depth and binary-lifting ancestors; mk: "currently on the DFS stack".
// cf/ff: difference arrays filled in main() and prefix-summed before solve().
// f1, f2, sl, js, ff, gs and the global jl are written but never read here.
ll rt[N], dfn[N], dy[N], js[N], f1[N], f2[N], sl[N], ff[N], sz[N], res, mk[N], cf[N];
ll pw2[N], inv[N], dp[N], eds[N], dep[N], fa[N][20], idx, bk[N];
// p: adjacency lists of the tree. g[c]: lower endpoints of the ancestor-descendant
// input pairs whose upper endpoint is the parent of c.
// NOTE(review): the element type was lost in the paste (`vectorp[N]`); restored
// as the integer type used everywhere else.
vector<ll> p[N], g[N], gs[N];
vector<pair<ll, ll> > jl[N];
// q[c][{ca, cb}]: non-ancestor input pairs (a, b) whose LCA is c, grouped by the
// children ca, cb of c that contain a and b respectively.
map<pair<ll, ll>, vector<pair<ll, ll> > > q[N];

// Depth-first traversal from x that fills fa, dfn, dy, dep and eds.
// fa[x][0] must already be set by the caller (it is 0 for the root).
void dfs(ll x) {
    for (ll i = 1; i <= 19; i++) fa[x][i] = fa[fa[x][i - 1]][i - 1];
    mk[x] = 1;
    dfn[x] = ++idx;
    dy[idx] = x;
    for (auto c : p[x]) {
        if (mk[c]) continue;  // the only marked neighbour is the parent
        dep[c] = dep[x] + 1;
        fa[c][0] = x;
        dfs(c);
    }
    mk[x] = 0;
    eds[x] = idx;
}

// Returns the k-th ancestor of x by jumping over the set bits of k, lowest first.
int up_bit(ll k) {
    // k != 0 here; the index of the lowest set bit equals __lg(k & -k).
    return __builtin_ctzll(static_cast<unsigned long long>(k));
}
ll up(ll x, ll k) {
    while (k) {
        x = fa[x][up_bit(k)];
        k ^= k & -k;
    }
    return x;
}

// Lowest common ancestor of a and b via binary lifting.
ll lca(ll a, ll b) {
    if (dep[a] > dep[b]) swap(a, b);
    b = up(b, dep[b] - dep[a]);
    if (a == b) return a;
    for (ll i = 19; i >= 0; i--) {
        if (fa[a][i] != fa[b][i]) a = fa[a][i], b = fa[b][i];
    }
    return fa[a][0];
}

// True when preorder index a lies inside the subtree range of vertex b.
bool in(ll a, ll b) {
    return a >= dfn[b] && a <= eds[b];
}

// Sweep event used in solve(): at position x, op 0 multiplies [l, r] by 2,
// op 1 multiplies [l, r] by inv2, op 2 is the terminating sentinel.
struct msg {
    ll x, op, l, r;
};

// Processes the subtree of x bottom-up: builds rt[x] and dp[x] from the children
// and accumulates contributions into the global `res`.
// NOTE(review): the combinatorial meaning of the accumulated terms cannot be
// derived from the code alone; the comments below describe the mechanics only.
void solve(ll x) {
    rt[x] = ++tr::idx;
    mk[x] = 1;
    dp[x] = 1;
    if (bk[x]) tr::add(rt[x], dfn[x], -1, 1, n);
    sz[x] = 0;
    for (auto c : p[x]) {
        if (mk[c]) continue;
        solve(c);
        dp[x] = dp[x] * dp[c] % mods;
        sz[x] += sz[c] + g[c].size();
    }
    // For every pair of child subtrees (a, b) linked by input pairs through x:
    // sweep the positions of tree a, while the pair events scale ranges of
    // tree b by 2 and later undo that with inv2.
    for (const auto& [t, c] : q[x]) {
        const ll a = t.first, b = t.second;
        ll ans = -tr::gets(rt[a], 1, n, 1, n) * tr::gets(rt[b], 1, n, 1, n) % mods;
        vector<msg> events;
        for (const auto& [s1, s2] : c) {
            events.push_back({dfn[s1], 0, dfn[s2], eds[s2]});
            events.push_back({eds[s1] + 1, 1, dfn[s2], eds[s2]});
        }
        events.push_back({n + 1, 2});
        sort(events.begin(), events.end(), [](const msg& lhs, const msg& rhs) {
            return lhs.x < rhs.x;
        });
        ll lst = 0;
        for (const auto& [pos, kind, lo, hi] : events) {
            if (lst < pos) ans += tr::gets(rt[a], lst, pos - 1, 1, n) * tr::gets(rt[b], 1, n, 1, n) % mods;
            lst = pos;
            if (kind == 0) tr::mul(rt[b], lo, hi, 2, 1, n);
            if (kind == 1) tr::mul(rt[b], lo, hi, inv2, 1, n);
        }
        ans %= mods;
        // dp[x] is still the product over all children here; inv[a] and inv[b]
        // divide out the factors stored for the two children involved.
        res += ans * pw2[cf[dfn[x]] - sz[x]] % mods * dp[x] % mods * inv[a] % mods * inv[b] % mods;
    }
    // Fold the children's trees into rt[x] one by one; dp[x] is rebuilt as the
    // running product so each side is scaled by the other side's factor.
    dp[x] = 1;
    for (auto c : p[x]) {
        if (mk[c]) continue;
        ll he = tr::gets(rt[x], 1, n, 1, n);
        ll tmp = he * tr::gets(rt[c], 1, n, 1, n) % mods;
        tr::mul(rt[c], dp[x]);
        tr::mul(rt[x], dp[c]);
        rt[x] = tr::hb(rt[x], rt[c], 1, n);
        tr::add(rt[x], dfn[x], tmp, 1, n);
        dp[x] = dp[x] * dp[c] % mods;
    }
    res += tr::gets(rt[x], dfn[x], dfn[x], 1, n) * pw2[cf[dfn[x]] - sz[x]] % mods;
    res %= mods;
    for (auto c : g[x]) {
        tr::mul(rt[x], dfn[c], eds[c], 2, 1, n);
        dp[x] = dp[x] * 2 % mods;
        if (c == x) js[x]++;
    }
    inv[x] = pows(dp[x], mods - 2);
    mk[x] = 0;
}

// Reads op, n, m, k, the n-1 tree edges, m vertex pairs and k flagged vertices
// from stdin, then prints (-res) reduced into [0, mods).
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cin >> op >> n >> m >> k;
    pw2[0] = 1;
    for (ll i = 1; i <= m; i++) pw2[i] = pw2[i - 1] * 2 % mods;
    for (ll i = 1; i < n; i++) {
        ll x, y;
        cin >> x >> y;
        p[x].push_back(y);
        p[y].push_back(x);
    }
    dfs(1);
    for (ll i = 1; i <= m; i++) {
        ll a, b;
        cin >> a >> b;
        if (dfn[a] > dfn[b]) swap(a, b);
        if (in(dfn[b], a)) {
            // b lies in the subtree of a; tmp is the child of a on the path to b.
            const ll tmp = up(b, dep[b] - dep[a] - 1);
            g[tmp].push_back(b);
            cf[dfn[b]]++;
            cf[eds[b] + 1]--;
            cf[1]++;
            f1[i] = a;
            f2[i] = b;
            cf[dfn[tmp]]--;
            cf[eds[tmp] + 1]++;
            ff[dfn[b]]++;
            ff[eds[b] + 1]--;
        } else {
            // a and b are in different child subtrees (ca, cb) of their LCA c.
            const ll c = lca(a, b);
            const ll ca = up(a, dep[a] - dep[c] - 1);
            const ll cb = up(b, dep[b] - dep[c] - 1);
            jl[c].push_back({a, b});
            q[c][{ca, cb}].push_back({a, b});
            gs[ca].push_back(a);
            gs[cb].push_back(b);
            cf[dfn[a]]++;
            cf[eds[a] + 1]--;
            cf[dfn[b]]++;
            cf[eds[b] + 1]--;
        }
    }
    // Turn the difference arrays into per-position values.
    for (ll i = 1; i <= n; i++) cf[i] += cf[i - 1], ff[i] += ff[i - 1];
    for (ll i = 1; i <= k; i++) {
        ll x;
        cin >> x;
        bk[x] = 1;
    }
    solve(1);
    cout << (-res % mods + mods) % mods << "\n";
}
// 评论:  (Comments:)
// 请先登录,才能进行评论  (Please log in before commenting.)