思路:
首先发现如果根确定的话,我们要求的答案就是根和关键点们组成的虚树的边长之和的2倍减去根到所有关键点的距离最大值.
这个应该不难理解.
然后觉得虚树很难处理,就抛弃虚树,弄了一个繁琐至极的两遍dp.(这个我就不写是怎么dp的啦,大家应该都会.)
然后看到某个题解说最上面说的那个东西bfs就行.
如何呢?
一个关键的性质:树上距离一个点最远的点必定是直径的一端!这个想想很显然.
然后就简单了.
下面是我的傻叉dp:
#include<cstdio> #include<cstring> #include<cctype> #include<iostream> #include<algorithm> using namespace std; #define INF 1ll<<59 #define N 500010 int head[N],next[N<<1],end[N<<1],len[N<<1]; inline void addedge(int a,int b,int c){static int q=1;end[q]=b,next[q]=head[a],head[a]=q,len[q++]=c;} inline void make(int a,int b,int c){addedge(a,b,c),addedge(b,a,c);} bool c[N]; long long max_dis[N],cnt[N],total_dis[N]; inline void dp1(int x,int fa){ for(int j=head[x];j;j=next[j])if(end[j]!=fa)dp1(end[j],x); cnt[x]=c[x],max_dis[x]=c[x]?0:-INF,total_dis[x]=0; for(int j=head[x];j;j=next[j])if(end[j]!=fa){ cnt[x]+=cnt[end[j]]; max_dis[x]=max(max_dis[x],max_dis[end[j]]+len[j]); total_dis[x]+=total_dis[end[j]]+(cnt[end[j]]?1ll:0ll)*len[j]; } } long long re[N]; long long seq[N],pre[N],suc[N],g[N]; inline void dp2(int x,int fa,int fadis){ long long real_total_dis,real_max_dis; if(x==1)real_total_dis=total_dis[x],real_max_dis=max_dis[x]; else{ real_total_dis=total_dis[x]+total_dis[fa]+(cnt[fa]?1ll:0ll)*fadis; real_max_dis=max(max_dis[x],max_dis[fa]+fadis); } re[x]=real_total_dis*2-real_max_dis; int num=0; for(int j=head[x];j;j=next[j])if(end[j]!=fa)seq[++num]=max_dis[end[j]]+len[j]; pre[0]=suc[num+1]=-INF; for(int i=1;i<=num;++i)pre[i]=max(seq[i],pre[i-1]); for(int i=num;i>=1;--i)suc[i]=max(seq[i],suc[i+1]); num=0; for(int j=head[x];j;j=next[j])if(end[j]!=fa){ ++num; g[end[j]]=max(pre[num-1],suc[num+1]); } int real_cnt=cnt[x]; for(int j=head[x];j;j=next[j])if(end[j]!=fa){ total_dis[x]=real_total_dis-total_dis[end[j]]; if(cnt[end[j]])total_dis[x]-=len[j]; max_dis[x]=g[end[j]]; if(x!=1)max_dis[x]=max(max_dis[x],max_dis[fa]+fadis); if(c[x])max_dis[x]=max(max_dis[x],0ll); cnt[x]=real_cnt-cnt[end[j]]+cnt[fa]; dp2(end[j],x,len[j]); } } int main(){ int n,m;scanf("%d%d",&n,&m);register int i,j; if(m==0){puts("0");return 0;} int a,b,x; for(i=1;i<n;++i)scanf("%d%d%d",&a,&b,&x),make(a,b,x); while(m--)scanf("%d",&x),c[x]=1; dp1(1,-1); dp2(1,-1,0); for(i=1;i<=n;++i)printf("%lld\n",re[i]); return 0; }