思路:
首先发现如果根确定的话,我们要求的答案就是根和关键点们组成的虚树的边长之和的2倍减去根到所有关键点的距离最大值.
这个应该不难理解.
然后觉得虚树很难处理,就抛弃虚树,弄了一个繁琐至极的两遍dp.(这个我就不写是怎么dp的啦,大家应该都会.)
然后看到某个题解说最上面说的那个东西bfs就行.
如何呢?
一个关键的性质:树上距离一个点最远的点必定是直径的一端!这个想想很显然.
然后就简单了.
下面是我的傻叉dp:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | #include<cstdio> #include<cstring> #include<cctype> #include<iostream> #include<algorithm> using namespace std; #define INF 1ll<<59 #define N 500010 int head[N],next[N<<1],end[N<<1],len[N<<1]; inline void addedge( int a, int b, int c){ static int q=1;end[q]=b,next[q]=head[a],head[a]=q,len[q++]=c;} inline void make( int a, int b, int c){addedge(a,b,c),addedge(b,a,c);} bool c[N]; long long max_dis[N],cnt[N],total_dis[N]; inline void dp1( int x, int fa){ for ( int j=head[x];j;j=next[j]) if (end[j]!=fa)dp1(end[j],x); cnt[x]=c[x],max_dis[x]=c[x]?0:-INF,total_dis[x]=0; for ( int j=head[x];j;j=next[j]) if (end[j]!=fa){ cnt[x]+=cnt[end[j]]; max_dis[x]=max(max_dis[x],max_dis[end[j]]+len[j]); total_dis[x]+=total_dis[end[j]]+(cnt[end[j]]?1ll:0ll)*len[j]; } } long long re[N]; long long seq[N],pre[N],suc[N],g[N]; inline void dp2( int x, int fa, int fadis){ long long real_total_dis,real_max_dis; if (x==1)real_total_dis=total_dis[x],real_max_dis=max_dis[x]; else { real_total_dis=total_dis[x]+total_dis[fa]+(cnt[fa]?1ll:0ll)*fadis; real_max_dis=max(max_dis[x],max_dis[fa]+fadis); } re[x]=real_total_dis*2-real_max_dis; int num=0; for ( int j=head[x];j;j=next[j]) if (end[j]!=fa)seq[++num]=max_dis[end[j]]+len[j]; pre[0]=suc[num+1]=-INF; for ( int i=1;i<=num;++i)pre[i]=max(seq[i],pre[i-1]); for ( int i=num;i>=1;--i)suc[i]=max(seq[i],suc[i+1]); num=0; for ( int j=head[x];j;j=next[j]) if (end[j]!=fa){ ++num; g[end[j]]=max(pre[num-1],suc[num+1]); } int real_cnt=cnt[x]; for ( int j=head[x];j;j=next[j]) if (end[j]!=fa){ total_dis[x]=real_total_dis-total_dis[end[j]]; if (cnt[end[j]])total_dis[x]-=len[j]; max_dis[x]=g[end[j]]; if (x!=1)max_dis[x]=max(max_dis[x],max_dis[fa]+fadis); if (c[x])max_dis[x]=max(max_dis[x],0ll); cnt[x]=real_cnt-cnt[end[j]]+cnt[fa]; dp2(end[j],x,len[j]); } } int main(){ int n,m; scanf ( "%d%d" ,&n,&m); register int i,j; if (m==0){ puts ( "0" ); return 0;} int a,b,x; for (i=1;i<n;++i) scanf ( "%d%d%d" ,&a,&b,&x),make(a,b,x); while (m--) scanf ( "%d" ,&x),c[x]=1; dp1(1,-1); dp2(1,-1,0); for (i=1;i<=n;++i) printf ( "%lld\n" ,re[i]); return 0; } |