欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

相交(Access)

程序员文章站 2022-04-01 14:48:07
...

【一句话题意】有一棵n 个节点,n-1 条边的树。树上有m 条路径,定义两条路径相交仅当这两条路径经过至少一个相同的点。求这m 条路径中选择两条相交的路径的方案数。
【分析】两条树上的链相交的充要条件为:一条链上深度最小的节点(lca)被另一条链所包含。通过欧拉序预处理+树状数组,我们可以通过O(logN)的时间求出一个点到根的路径上的所有点,共有多少个lca。因此我们可以方便的求出一条链上所有点事包含了多少条深度最小的点。
由于当两条链的lca相同时答案会记录多次,所以我们要额外减去重复的情况,具体来说,若一个点是i条链的lca,那么我们会重复计算i*i-c(i,2)次。(换言之,只有c(i,2)次的运算是有用的。)
时间复杂度O(NlogN)
【code】

#include <bits/stdc++.h>
typedef long long ll;
const int N=2000010;
using namespace std;
const int inf = 0x3f3f3f3f, INF = 0x7fffffff;
const ll infll = 0x3f3f3f3f3f3f3f3fll;
const ll INFll = 0x7fffffffffffffffll;
inline int read(){
	int tmp=0,fh=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')fh=-fh;ch=getchar();}
	while(ch>='0'&&ch<='9')tmp=tmp*10+ch-'0',ch=getchar();
	return tmp*fh;
}
struct Node{int a, b, lca;}p[N];
struct Edge{int v, nxt;}edge[N * 2]; 
int n,head[N],tot,cnt[N],in[N],out[N],num[N],las[N],dep[N],m,f[N],use[N];
vector<int>qto[N],qid[N];
ll ans;
inline void add_edge(int u,int v){edge[tot].v=v,edge[tot].nxt=head[u],head[u]=tot++;}

int lowbit(int x){return x&-x;}
void modify(int x,int d){while(x<=n*2)num[x]=num[x]+d,x=x+lowbit(x);}
int query(int x){int sum=0;while(x>0)sum=sum+num[x],x=x-lowbit(x);return sum;}

int find(int x){if(f[x]!=x) f[x]=find(f[x]);return f[x];}

void dfs(int x, int fa){
	in[x]=++tot;dep[x]=dep[fa]+1;las[x]=fa;use[x]=true;
	for (int i=0;i<qto[x].size();i++)//求欧拉序&tarjan
		if(use[qto[x][i]]) p[qid[x][i]].lca=find(qto[x][i]);
	for (int i=head[x];i!=-1;i=edge[i].nxt)
		if(edge[i].v!=fa)
			dfs(edge[i].v,x);
	out[x]=++tot;
	f[x]=fa;
}
ll C(int n){if(n<2)return 0;return 1ll*n*(n-1)/2;}
int main(){
	memset(head,-1,sizeof(head));
	cin>>n>>m;
	for(int i=1;i<n;i++){
		int u=read(),v=read();
		add_edge(u,v),add_edge(v,u);
	}//build tree
	for(int i=1;i<=m;i++){
		p[i].a=read(),p[i].b=read();
		qto[p[i].a].push_back(p[i].b); qid[p[i].a].push_back(i);
		qto[p[i].b].push_back(p[i].a); qid[p[i].b].push_back(i);
	}
	for(int i=1;i<=n;i++)f[i]=i;
	tot=0;dfs(1,0);
	for(int i=1;i<=m;i++){
		modify(in[p[i].lca],1);
		modify(out[p[i].lca],-1);
		cnt[p[i].lca]++;
	}
	for(int i=1;i<=m;i++){
		ans=ans+query(in[p[i].a])+query(in[p[i].b]);
		ans=ans-query(in[p[i].lca])-query(in[las[p[i].lca]]);
	}
	for(int i=1;i<=n;i++)
		ans=ans-1ll*cnt[i]*cnt[i]+C(cnt[i]);//delete the repeating part
	printf("%lld\n",ans);
	return 0;
}

相关标签: 题解