CodeForces 294E Shaass the Great:极其变态的树形 DP 与思维题

题目链接:CodeForces 294E Shaass the Great
这题真的太麻烦了……

Problem

The great Shaass is the new king of the Drakht empire. The empire has n cities which are connected by n - 1 bidirectional roads. Each road has an specific length and connects a pair of cities. There’s a unique simple path connecting each pair of cities.

His majesty the great Shaass has decided to tear down one of the roads and build another road with the same length between some pair of cities. He should build such road that it’s still possible to travel from each city to any other city. He might build the same road again.

You as his advisor should help him to find a way to make the described action. You should find the way that minimize the total sum of pairwise distances between cities after the action. So calculate the minimum sum.

Input

The first line of the input contains an integer n denoting the number of cities in the empire, (2 ≤ n ≤ 5000). The next n - 1 lines each contains three integers a_i, b_i and w_i showing that two cities a_i and b_i are connected using a road of length w_i, (1 ≤ a_i, b_i ≤ n, a_i ≠ b_i, 1 ≤ w_i ≤ 10^6).

Output

On the only line of the output print the minimum pairwise sum of distances between the cities.

Please do not use the %lld specificator to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specificator.

Examples

Input #1

3
1 2 2
1 3 4

Output #1

12

Input #2

6
1 2 1
2 3 1
3 4 1
4 5 1
5 6 1

Output #2

29

Input #3

6
1 3 1
2 3 1
3 4 100
4 5 2
4 6 1

Output #3

825

Translation

给你一颗树,现在你要移除其中一条边,选两个点,在其之间加上一条边权相等的边,并且不能成环,使得新的树中两两点之间距离之和最小。(也就是 \displaystyle \frac {\sum_{i=1,j=1}^{i=n,j=n,i \not = j} dist(i,j)} 2 最小)

Analysis

这题规定了数据范围:(2 ≤ n ≤ 5000)……说明必须 N^2 搞出来,不能乱搞了。首先我们要枚举删除的边,假设枚举到的删除的边是 e_k,那么删除这条边以后形成了两棵树,如何建立新的边(与原来边权相等)连接这两棵树,使得连接后两两点之间距离最小呢?

仔细思考可以发现,对于两棵分开的树中任意一棵,我们要找到点 p_i 使得树上所有点到这个点的距离之和最小。显然,p_i 是重心!说明只要连接两棵树的重心就可以了!

接下来我们需要计算出两棵树里两两点对距离之和,可以用 DFS,十分麻烦,相当于一个树形 DP 去维护……(详见代码)假设算出的两树结果是 total \_ sum \_ lefttotal\_ sum\_ right
我们还要得到这两棵树结点数量(不然重心没法求),设为:num\_ leftnum\_ right;两棵树重心设为 cen\_ leftcen\_ right,现在删除的边长度是 len
我们最后要求出左边的树上所有点到重心距离之和,右边树上所有点到重心距离之和,记为:sum\_ leftsum\_ right
左边的树记为 Tree_{left},右边的树记为 Tree_{right}

重头戏来了,那么当前答案怎么求呢?首先左边所有点要经过我们新加上的两重心之间的边到达右边的点,右边的所有点亦然;不难得出答案:

\sum_{i\in Tree_{left}}^{} (\sum_{j\in Tree_{right}} dist(i,cen\_ left)+len+dist(cen\_right,j))

化简得到:

num\_left\ast sum\_right + num\_right \ast sum\_left+num\_left \ast num\_right \ast len + total\_ sum\_left + total\_sum\_right

最后挑出最小值就是答案了!

Code

猜猜这题我改了多久?_(´ཀ`」 ∠)_

#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
const int maxn=5005,maxe=10005;
int n,sum_son[maxn],max_sub[maxn],c[2],tmp[2],flag[maxn],num[2],sum_dfs=0;
long long ans=(long long)1<<60,total_sum[2],lst[maxn],lstw[maxn];
int tot=0,lnk[maxn],nxt[maxe],son[maxe],w[maxe],dst[2][maxn];
bool can_use[maxe],vis[maxn];
struct EdgeData{
    int x,y,w;
}e[maxn];
inline void add(int x,int y,int z){
    tot++;son[tot]=y;w[tot]=z;nxt[tot]=lnk[x];lnk[x]=tot;
}
inline int read(){
    int ret=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar();
    return ret*f;
}
inline int BuildNumber(int x){ // 先要求出两个联通块点的个数,不然后面求重心没法做……
    // 我是拒绝的 qwq
    vis[x]=true;
    int ret=1;
    for (int i=lnk[x];i;i=nxt[i]) if (!vis[son[i]]&&can_use[i]) ret+=BuildNumber(son[i]);
    return ret;
}
inline void Build(int x,int k){ // 求重心……
    flag[x]=k;
    vis[x]=true;
    max_sub[x]=0;sum_son[x]=0;
    for (int i=lnk[x];i;i=nxt[i]) if (!vis[son[i]]&&can_use[i]){
        Build(son[i],k);
        max_sub[x]=max(max_sub[x],max(max_sub[son[i]],sum_son[son[i]]+1));
        sum_son[x]+=sum_son[son[i]]+1;
    }
    if (max(num[k]-sum_son[x],max_sub[x])<tmp[k]) tmp[k]=max(num[k]-sum_son[x],max_sub[x]),c[k]=x;
}
inline void GetDist(int x,int k){
    vis[x]=true; 
    for (int i=lnk[x];i;i=nxt[i]) if (!vis[son[i]]&&flag[son[i]]==k){
        dst[k][son[i]]=dst[k][x]+w[i];
        GetDist(son[i],k);
    }
}
inline int GetSum(int x,int k){ // 还要写个构造联通块内两两节点距离之和的函数 (吐血)
    sum_dfs++;
    vis[x]=1;
    total_sum[k]+=lst[x];
    int van=0;
    for (int i=lnk[x];i;i=nxt[i]) if (!vis[son[i]]&&can_use[i]){
        lst[son[i]]=lst[x]+(long long)sum_dfs*w[i];
        int nowvan=GetSum(son[i],k);
        lst[x]+=lstw[son[i]]+(long long)(nowvan)*w[i];     // !!!!!!!!!!
        lstw[x]+=lstw[son[i]]+(long long)(nowvan)*w[i];
        van+=nowvan;
    }
    van++;
    return van;
}
int main(){
    n=read();
    for (int i=1;i<n;i++){
        int x=read(),y=read(),z=read();
        add(x,y,z);add(y,x,z);
        e[i].x=x;e[i].y=y;e[i].w=z;
    }
    memset(can_use,true,sizeof(can_use));
    for (int i=1;i<n;i++){
        memset(c,0,sizeof(c));
        memset(tmp,63,sizeof(tmp));
        memset(dst,0,sizeof(dst));
        memset(flag,255,sizeof(flag));
        memset(sum_son,0,sizeof(sum_son));
        memset(max_sub,0,sizeof(max_sub));

        long long cen_left=0,cen_right=0,sum_left=0,sum_right=0,num_left=0,num_right=0;
        can_use[i*2]=can_use[i*2-1]=false;

        memset(vis,0,sizeof(vis));
        num_left=num[0]=BuildNumber(e[i].x);num_right=num[1]=BuildNumber(e[i].y);
        memset(vis,0,sizeof(vis));
        Build(e[i].x,0);
        Build(e[i].y,1);
        cen_left=c[0];cen_right=c[1];
        //printf("Cen_Left: %d  Cen_Right: %d  Num_Left: %d  Num_Right: %d\n",cen_left,cen_right,num_left,num_right);

        memset(vis,0,sizeof(vis));
        GetDist(cen_left,0);
        GetDist(cen_right,1);
        for (int j=1;j<=n;j++){
            if (flag[j]==0) sum_left +=dst[0][j]; else
            if (flag[j]==1) sum_right+=dst[1][j];
        }

        memset(total_sum,0,sizeof(total_sum));
        memset(vis,0,sizeof(vis));
        memset(lst,0,sizeof(lst));
        memset(lstw,0,sizeof(lstw));
        sum_dfs=0;GetSum(e[i].x,0);
        sum_dfs=0;GetSum(e[i].y,1);
        //printf("TOTSUM0: %d  TOTSUM1: %d\n",total_sum[0],total_sum[1]);
        long long now=(long long)num_right*sum_left+(long long)num_left*sum_right+(long long)num_left*num_right*e[i].w+total_sum[0]+total_sum[1];
        //printf("Result: %d\n",now);
        if (now<ans) ans=now;
        can_use[i*2]=can_use[i*2-1]=true;
    }
    printf("%lld\n",ans);
    return 0;
}

本文采用 BY-NC-SA 4.0 协议,欢迎转载。如有错误欢迎指出。
本文链接: https://skywt.cn/posts/cf294e/

发表评论

电子邮件地址不会被公开。 必填项已用*标注

相关文章

开始在上面输入您的搜索词,然后按回车进行搜索。按ESC取消。