思路:
首先发现如果根确定的话,我们要求的答案就是根和关键点们组成的虚树的边长之和的2倍减去根到所有关键点的距离最大值.
这个应该不难理解.
然后觉得虚树很难处理,就抛弃虚树,弄了一个繁琐至极的两遍dp.(这个我就不写是怎么dp的啦,大家应该都会.)
然后看到某个题解说最上面说的那个东西bfs就行.
如何呢?
一个关键的性质:树上距离一个点最远的点必定是直径的一端!这个想想很显然.
然后就简单了.
下面是我的傻叉dp:
#include<cstdio>
#include<cstring>
#include<cctype>
#include<iostream>
#include<algorithm>
using namespace std;
#define INF 1ll<<59
#define N 500010
int head[N],next[N<<1],end[N<<1],len[N<<1];
inline void addedge(int a,int b,int c){static int q=1;end[q]=b,next[q]=head[a],head[a]=q,len[q++]=c;}
inline void make(int a,int b,int c){addedge(a,b,c),addedge(b,a,c);}
bool c[N];
long long max_dis[N],cnt[N],total_dis[N];
inline void dp1(int x,int fa){
for(int j=head[x];j;j=next[j])if(end[j]!=fa)dp1(end[j],x);
cnt[x]=c[x],max_dis[x]=c[x]?0:-INF,total_dis[x]=0;
for(int j=head[x];j;j=next[j])if(end[j]!=fa){
cnt[x]+=cnt[end[j]];
max_dis[x]=max(max_dis[x],max_dis[end[j]]+len[j]);
total_dis[x]+=total_dis[end[j]]+(cnt[end[j]]?1ll:0ll)*len[j];
}
}
long long re[N];
long long seq[N],pre[N],suc[N],g[N];
inline void dp2(int x,int fa,int fadis){
long long real_total_dis,real_max_dis;
if(x==1)real_total_dis=total_dis[x],real_max_dis=max_dis[x];
else{
real_total_dis=total_dis[x]+total_dis[fa]+(cnt[fa]?1ll:0ll)*fadis;
real_max_dis=max(max_dis[x],max_dis[fa]+fadis);
}
re[x]=real_total_dis*2-real_max_dis;
int num=0;
for(int j=head[x];j;j=next[j])if(end[j]!=fa)seq[++num]=max_dis[end[j]]+len[j];
pre[0]=suc[num+1]=-INF;
for(int i=1;i<=num;++i)pre[i]=max(seq[i],pre[i-1]);
for(int i=num;i>=1;--i)suc[i]=max(seq[i],suc[i+1]);
num=0;
for(int j=head[x];j;j=next[j])if(end[j]!=fa){
++num;
g[end[j]]=max(pre[num-1],suc[num+1]);
}
int real_cnt=cnt[x];
for(int j=head[x];j;j=next[j])if(end[j]!=fa){
total_dis[x]=real_total_dis-total_dis[end[j]];
if(cnt[end[j]])total_dis[x]-=len[j];
max_dis[x]=g[end[j]];
if(x!=1)max_dis[x]=max(max_dis[x],max_dis[fa]+fadis);
if(c[x])max_dis[x]=max(max_dis[x],0ll);
cnt[x]=real_cnt-cnt[end[j]]+cnt[fa];
dp2(end[j],x,len[j]);
}
}
int main(){
int n,m;scanf("%d%d",&n,&m);register int i,j;
if(m==0){puts("0");return 0;}
int a,b,x;
for(i=1;i<n;++i)scanf("%d%d%d",&a,&b,&x),make(a,b,x);
while(m--)scanf("%d",&x),c[x]=1;
dp1(1,-1);
dp2(1,-1,0);
for(i=1;i<=n;++i)printf("%lld\n",re[i]);
return 0;
}