这道题是Sums of Distances in Tree的带权版本,那么做法仍然是类似的。我们选择一个节点root当根,算出其他节点到root的距离,和每个以节点i为根的子树的大小size[i]。我们用dp[i]表示其他所有节点到i的距离和。那么我们有递推公式:
- dp[i] = dp[j] - size[i] * w(i, j) + (n - size[i]) * w(i, j), where j is i's parent
- 以i为根的子树的节点和i的距离比和j的距离减小了size[i] * w(i, j)
- 除了在i的子树中的节点和i的距离比和j的距离增加了(n - size[i]) * w(i, j)
其中n为所有节点的数量,w(i, j)表示连接i和j的边的权值。然我们dp是树形从上到下的dfs顺序。时间复杂度和空间复杂度均为O(n),代码如下:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Solution { | |
public: | |
/** | |
* @param x: The end points set of edges | |
* @param y: The end points set of edges | |
* @param d: The length of edges | |
* @return: Return the index of the fermat point | |
*/ | |
int getFermatPoint(vector<int> &x, vector<int> &y, vector<int> &d) | |
{ | |
//acyclic connected graph, V = E + 1 | |
int n = x.size() + 1; | |
vector<vector<pair<int, int>>> g(n + 1); | |
for (int i = 0; i < x.size(); ++i) | |
{ | |
int u = x[i], v = y[i], w = d[i]; | |
g[u].emplace_back(v, w); | |
g[v].emplace_back(u, w); | |
} | |
vector<long long> dp(n + 1, 0); | |
vector<int> sz(n + 1, 0); | |
countSubTree(0, 1, g, dp, sz); | |
long long minDist = dp[1]; | |
int idx = 1; | |
findPoint(0, 1, g, sz, dp[1], idx, minDist); | |
return idx; | |
} | |
private: | |
void countSubTree(int from, int to, vector<vector<pair<int, int>>>& g, vector<long long>& dp, vector<int>& sz) | |
{ | |
sz[to] = 1; | |
for (const auto& adj : g[to]) | |
{ | |
int v = adj.first; | |
long long w = adj.second; | |
if (v != from) | |
{ | |
countSubTree(to, v, g, dp, sz); | |
sz[to] += sz[v]; | |
dp[to] += dp[v] + sz[v] * w; | |
} | |
} | |
} | |
void findPoint(int from, int to, vector<vector<pair<int, int>>>& g, vector<int>& sz, long long dist, int& res, long long& minDist) | |
{ | |
int n = sz.size() - 1; | |
for (const auto& adj : g[to]) | |
{ | |
int v = adj.first; | |
long long w = adj.second; | |
if (v != from) | |
{ | |
long long currDist = dist - sz[v] * w + (n - sz[v]) * w; | |
if (currDist < minDist) | |
{ | |
minDist = currDist; | |
res = v; | |
} | |
findPoint(to, v, g, sz, currDist, res, minDist); | |
} | |
} | |
} | |
}; |
No comments:
Post a Comment