// Hello! Here is a solution.
//
// Posted by guoguo, 4 hours ago.


#include <bits/stdc++.h>

using namespace std;

// Competitive-programming shortcut: every `int` below is 64-bit, so products of
// two values reduced modulo `mods` cannot overflow. `signed` is used where a
// genuine 32-bit int is wanted (segment-tree node fields, main's return type).
#define int long long

const int N = 5e5 + 5, mods = 1e9 + 7;

// Dynamically allocated segment tree over positions 1..n (the DFS order below).
// Supports lazy range multiplication, point addition, range-sum queries and
// merging two trees (hb). All arithmetic is modulo `mods`.
// Node 0 is the shared "absent child" sentinel and must keep sm == 0.
namespace tr {

struct node {
    // Child indices, subtree sum, and pending multiplier for the children.
    // Kept 32-bit to shrink the pool. sm may be negative because solve() adds -1.
    signed lson, rson, sm, laz;
} p[N * 50];  // NOTE(review): N*50 nodes * 16 bytes is about 400 MB; confirm it fits the memory limit.
int idx;      // Number of nodes allocated so far (p[1..idx] are in use).

// Allocate a fresh empty node whose pending multiplier is the identity (1).
int newnode() {
    ++idx;
    p[idx].laz = 1;
    return idx;
}

// Multiply every value in node x's subtree by sm: x's sum now, its children lazily.
void mul(int x, int sm) {
    if (!x) return;  // Absent subtree: nothing to scale, and keep the sentinel p[0] untouched.
    p[x].sm = 1ll * p[x].sm * sm % mods;
    p[x].laz = 1ll * p[x].laz * sm % mods;
}

// Push node x's pending multiplier down to its two children.
void dnset(int x) {
    if (p[x].laz != 1) {
        mul(p[x].lson, p[x].laz);
        mul(p[x].rson, p[x].laz);
        p[x].laz = 1;
    }
}

// Create whichever children of x do not exist yet.
void init(int x) {
    if (!p[x].lson) p[x].lson = newnode();
    if (!p[x].rson) p[x].rson = newnode();
}

// Recompute x's sum from its two children.
void upset(int x) {
    p[x].sm = (p[p[x].lson].sm + p[p[x].rson].sm) % mods;
}

// Multiply positions [l, r] by sm inside the subtree of x, which covers [nl, nr].
void mul(int x, int l, int r, int sm, int nl, int nr) {
    if (l <= nl && r >= nr) {
        mul(x, sm);
        return;
    }
    int mid = nl + nr >> 1;
    init(x);
    dnset(x);
    if (l <= mid) mul(p[x].lson, l, r, sm, nl, mid);
    if (r > mid) mul(p[x].rson, l, r, sm, mid + 1, nr);
    upset(x);
}

// Add sm to position d inside the subtree of x, which covers [nl, nr].
void add(int x, int d, int sm, int nl, int nr) {
    if (nl == nr) {
        p[x].sm = (p[x].sm + sm) % mods;
        return;
    }
    init(x);
    dnset(x);
    int mid = nl + nr >> 1;
    if (d <= mid) add(p[x].lson, d, sm, nl, mid);
    else add(p[x].rson, d, sm, mid + 1, nr);
    upset(x);
}

// Sum of positions [l, r] inside the subtree of x ([nl, nr]). Absent nodes count as 0.
int gets(int x, int l, int r, int nl, int nr) {
    if (!x) return 0;
    if (l <= nl && r >= nr) return p[x].sm;
    int mid = nl + nr >> 1;
    dnset(x);
    if (r <= mid) return gets(p[x].lson, l, r, nl, mid);
    if (l > mid) return gets(p[x].rson, l, r, mid + 1, nr);
    return (gets(p[x].lson, l, r, nl, mid) + gets(p[x].rson, l, r, mid + 1, nr)) % mods;
}

// Merge tree b into tree a position-wise (values at equal positions are added).
// Returns the merged root. Nodes of b may be reused by a.
int hb(int a, int b, int nl, int nr) {
    if (!a || !b) return a | b;
    if (nl == nr) {
        p[a].sm = (p[a].sm + p[b].sm) % mods;
        return a;
    }
    int mid = nl + nr >> 1;
    dnset(a);
    dnset(b);
    p[a].lson = hb(p[a].lson, p[b].lson, nl, mid);
    p[a].rson = hb(p[a].rson, p[b].rson, mid + 1, nr);
    upset(a);
    return a;
}

}  // namespace tr

// Modular exponentiation: a^b mod mods, for b >= 0.
int pows(int a, int b) {
    if (b == 0) return 1;
    int res = pows(a, b >> 1);
    res = res * res % mods;
    if (b & 1) res = res * a % mods;
    return res;
}

int inv2 = (mods + 1) >> 1;  // Inverse of 2 modulo mods.
int op, n, m, k;             // Header of the input. NOTE(review): op is read but never used.
int rt[N];                   // rt[x]: segment-tree root built by solve(x).
int dfn[N], eds[N], dy[N];   // DFS entry index, last index in the subtree, and inverse of dfn.
int dep[N], fa[N][20];       // Depth and binary-lifting ancestor table (fa[x][i] = 2^i-th ancestor).
int idx;                     // DFS counter used by dfs().
int mk[N];                   // "On the current recursion path" marks used to skip the parent.
int cf[N], ff[N];            // Difference arrays over DFS positions, prefix-summed in main().
int sz[N], dp[N], inv[N];    // Per-vertex DP values. inv[x] = dp[x]^(mods-2).
int pw2[N];                  // pw2[i] = 2^i mod mods, filled for i <= m.
int bk[N];                   // bk[x] = 1 for the k vertices listed at the end of the input.
int res;                     // Running answer accumulator (printed negated).
// NOTE(review): js, f1, f2, sl, dy, ff, gs and the global jl are written but never read in this file.
int js[N], f1[N], f2[N], sl[N];
vector<int> p[N];  // Tree adjacency lists.
vector<int> g[N];  // Ancestor-descendant pairs (a, b): b, stored at a's child on the path to b.
vector<int> gs[N];
vector<pair<int, int>> jl[N];
// q[c][{ca, cb}]: pairs (a, b) whose LCA is c, keyed by c's children leading to a and to b.
map<pair<int, int>, vector<pair<int, int>>> q[N];

// Recursive DFS from x. Fills dfn/dy/eds, dep and the binary-lifting table fa.
void dfs(int x) {
    for (int i = 1; i <= 19; i++) fa[x][i] = fa[fa[x][i - 1]][i - 1];
    mk[x] = 1;
    dfn[x] = ++idx;
    dy[idx] = x;
    for (auto c : p[x]) {
        if (mk[c]) continue;  // Skip the parent, which is still on the recursion path.
        dep[c] = dep[x] + 1;
        fa[c][0] = x;
        dfs(c);
    }
    mk[x] = 0;
    eds[x] = idx;
}

// k-th ancestor of x, jumping one set bit of k at a time.
// NOTE(review): k must be >= 0. main() passes dep[b]-dep[a]-1, which is -1 if an
// input pair has a == b, and the loop would then index fa out of bounds.
int up(int x, int k) {
    while (k) {
        x = fa[x][__lg(k & -k)];
        k ^= k & -k;
    }
    return x;
}

// Lowest common ancestor of a and b via binary lifting.
int lca(int a, int b) {
    if (dep[a] > dep[b]) swap(a, b);
    b = up(b, dep[b] - dep[a]);
    if (a == b) return a;
    for (int i = 19; i >= 0; i--) {
        if (fa[a][i] != fa[b][i]) a = fa[a][i], b = fa[b][i];
    }
    return fa[a][0];
}

// True if DFS index a lies inside the subtree of vertex b. (Not called in this file.)
bool in(int a, int b) {
    return a >= dfn[b] && a <= eds[b];
}

// Sweep event used in solve(): at DFS position x, op 0 doubles positions [l, r]
// of a child's tree, op 1 undoes that with inv2, and op 2 is the end sentinel.
struct msg {
    int x, op, l, r;
};

// Bottom-up tree DP at vertex x. It builds rt[x] and dp[x] and adds contributions to res.
void solve(int x) {
    rt[x] = tr::newnode();
    mk[x] = 1;
    dp[x] = 1;
    if (bk[x]) tr::add(rt[x], dfn[x], -1, 1, n);
    sz[x] = 0;

    // Phase 1: solve the children. dp[x] = product of dp[c]; sz[x] = sum of sz[c] + |g[c]|.
    for (auto c : p[x]) {
        if (mk[c]) continue;
        solve(c);
        dp[x] = dp[x] * dp[c] % mods;
        sz[x] += sz[c] + g[c].size();
    }

    // Phase 2: pairs whose LCA is x, grouped by the children (a, b) they descend through.
    for (auto [t, c] : q[x]) {
        int a = t.first, b = t.second;
        int ans = -tr::gets(rt[a], 1, n, 1, n) * tr::gets(rt[b], 1, n, 1, n) % mods;
        vector<msg> events;
        for (auto [s1, s2] : c) {
            events.push_back({dfn[s1], 0, dfn[s2], eds[s2]});
            events.push_back({eds[s1] + 1, 1, dfn[s2], eds[s2]});
        }
        events.push_back({n + 1, 2});
        // The order among equal positions does not matter: the multiplications commute
        // and every query happens before any event at that position.
        sort(events.begin(), events.end(), [&](msg u, msg v) { return u.x < v.x; });
        int lst = 0;
        for (auto [pos, type, l, r] : events) {
            if (lst < pos) ans += tr::gets(rt[a], lst, pos - 1, 1, n) * tr::gets(rt[b], 1, n, 1, n) % mods;
            lst = pos;
            // Every doubling is paired with an inv2 multiplication, so rt[b] is restored after the sweep.
            if (type == 0) tr::mul(rt[b], l, r, 2, 1, n);
            if (type == 1) tr::mul(rt[b], l, r, inv2, 1, n);
        }
        ans %= mods;
        // NOTE(review): assumes cf[dfn[x]] >= sz[x], so the pw2 index is non-negative.
        res += ans * pw2[cf[dfn[x]] - sz[x]] % mods * dp[x] % mods * inv[a] % mods * inv[b] % mods;
    }

    // Phase 3: merge the children's trees into rt[x], recording each cross term at dfn[x].
    dp[x] = 1;
    for (auto c : p[x]) {
        if (mk[c]) continue;
        int he = tr::gets(rt[x], 1, n, 1, n);
        int tmp = he * tr::gets(rt[c], 1, n, 1, n) % mods;
        tr::mul(rt[c], dp[x]);
        tr::mul(rt[x], dp[c]);
        rt[x] = tr::hb(rt[x], rt[c], 1, n);
        tr::add(rt[x], dfn[x], tmp, 1, n);
        dp[x] = dp[x] * dp[c] % mods;
    }
    res += tr::gets(rt[x], dfn[x], dfn[x], 1, n) * pw2[cf[dfn[x]] - sz[x]] % mods;
    res %= mods;

    // Phase 4: each ancestor-descendant pair stored at x doubles its subtree range and dp[x].
    for (auto c : g[x]) {
        tr::mul(rt[x], dfn[c], eds[c], 2, 1, n);
        dp[x] = dp[x] * 2 % mods;
        if (c == x) js[x]++;
    }
    inv[x] = pows(dp[x], mods - 2);  // Fermat inverse. mods is prime.
    mk[x] = 0;
}

// Input: op n m k, then n-1 tree edges, then m vertex pairs, then k vertices.
// Output: -res mod mods, in [0, mods).
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin >> op >> n >> m >> k;
    pw2[0] = 1;
    // NOTE(review): pw2 has N entries; this assumes m < N.
    for (int i = 1; i <= m; i++) pw2[i] = pw2[i - 1] * 2 % mods;
    for (int i = 1; i < n; i++) {
        int x, y;
        cin >> x >> y;
        p[x].push_back(y);
        p[y].push_back(x);
    }
    dfs(1);
    for (int i = 1; i <= m; i++) {
        int a, b;
        cin >> a >> b;
        if (dfn[a] > dfn[b]) swap(a, b);
        if (dfn[b] >= dfn[a] && dfn[b] <= eds[a]) {
            // b lies in a's subtree. tmp is a's child on the path to b.
            int tmp = up(b, dep[b] - dep[a] - 1);
            g[tmp].push_back(b);
            cf[dfn[b]]++;
            cf[eds[b] + 1]--;
            cf[1]++;
            f1[i] = a;
            f2[i] = b;
            cf[dfn[tmp]]--;
            cf[eds[tmp] + 1]++;
            ff[dfn[b]]++;
            ff[eds[b] + 1]--;
        } else {
            // Neither is an ancestor of the other. Group the pair at its LCA c.
            int c = lca(a, b);
            int ca = up(a, dep[a] - dep[c] - 1);
            int cb = up(b, dep[b] - dep[c] - 1);
            jl[c].push_back({a, b});
            q[c][{ca, cb}].push_back({a, b});
            gs[ca].push_back(a);
            gs[cb].push_back(b);
            cf[dfn[a]]++;
            cf[eds[a] + 1]--;
            cf[dfn[b]]++;
            cf[eds[b] + 1]--;
        }
    }
    for (int i = 1; i <= n; i++) cf[i] += cf[i - 1], ff[i] += ff[i - 1];
    for (int i = 1; i <= k; i++) {
        int x;
        cin >> x;
        bk[x] = 1;
    }
    solve(1);
    cout << (-res % mods + mods) % mods << "\n";
}


// Comments:
//
// Please log in before commenting.