题目描述
给你一棵 n
个节点的 无向 树,节点编号为 0
到 n - 1
,树的根节点在节点 0
处。同时给你一个长度为 n - 1
的二维整数数组 edges
,其中 edges[i] = [ai, bi]
表示树中节点 ai
和 bi
之间有一条边。
给你一个长度为 n
下标从 0 开始的整数数组 cost
,其中 cost[i]
是第 i
个节点的 开销 。
你需要在树中每个节点都放置金币,在节点 i
处的金币数目计算方法如下:
- 如果节点
i
对应的子树中的节点数目小于 3
,那么放 1
个金币。
- 否则,计算节点
i
对应的子树内 3
个不同节点的开销乘积的 最大值 ,并在节点 i
处放置对应数目的金币。如果最大乘积是 负数 ,那么放置 0
个金币。
请你返回一个长度为 n
的数组 coin
,coin[i]
是节点 i
处的金币数目。
示例 1:
输入:edges = [[0,1],[0,2],[0,3],[0,4],[0,5]], cost = [1,2,3,4,5,6]
输出:[120,1,1,1,1,1]
解释:在节点 0 处放置 6 * 5 * 4 = 120 个金币。所有其他节点都是叶子节点,子树中只有 1 个节点,所以其他每个节点都放 1 个金币。
示例 2:
输入:edges = [[0,1],[0,2],[1,3],[1,4],[1,5],[2,6],[2,7],[2,8]], cost = [1,4,2,3,5,7,8,-4,2]
输出:[280,140,32,1,1,1,1,1,1]
解释:每个节点放置的金币数分别为:
- 节点 0 处放置 8 * 7 * 5 = 280 个金币。
- 节点 1 处放置 7 * 5 * 4 = 140 个金币。
- 节点 2 处放置 8 * 2 * 2 = 32 个金币。
- 其他节点都是叶子节点,子树内节点数目为 1 ,所以其他每个节点都放 1 个金币。
示例 3:
输入:edges = [[0,1],[0,2]], cost = [1,2,-2]
输出:[0,1,1]
解释:节点 1 和 2 都是叶子节点,子树内节点数目为 1 ,各放置 1 个金币。节点 0 处唯一的开销乘积是 2 * 1 * -2 = -4 。所以在节点 0 处放置 0 个金币。
提示:
2 <= n <= 2 * 104
edges.length == n - 1
edges[i].length == 2
0 <= ai, bi < n
cost.length == n
1 <= |cost[i]| <= 104
edges
一定是一棵合法的树。
解法
方法一:DFS + 排序
根据题目描述,每个节点 $a$ 的放置的金币数有两种情况:
- 如果节点 $a$ 对应的子树中的节点数目小于 $3$,那么放 $1$ 个金币;
- 如果节点 $a$ 对应的子树中的节点数目大于等于 $3$,那么我们需要取出子树中的 $3$ 个不同节点,计算它们的开销乘积的最大值,然后在节点 $a$ 处放置对应数目的金币,如果最大乘积是负数,那么放置 $0$ 个金币。
第一种情况比较简单,我们只需要在遍历的过程中,统计每个节点的子树中的节点数目即可。
而对于第二种情况,如果开销都是正数,那么应该取开销最大的 $3$ 个节点;如果开销中有负数,那么应该取开销最小的 $2$ 个节点和开销最大的 $1$ 个节点。因此,我们需要维护每个子树最小的 $2$ 个开销和最大的 $3$ 个开销。
我们先根据题目给定的二维数组 $edges$ 构建邻接表 $g$,其中 $g[a]$ 表示节点 $a$ 的所有邻居节点。
接下来,我们设计一个函数 $dfs(a, fa)$,该函数返回一个数组 $res$,其中存储了节点 $a$ 的子树中最小的 $2$ 个开销和最大的 $3$ 个开销(可能不足 $5$ 个)。
在函数 $dfs(a, fa)$ 中,我们将节点 $a$ 的开销 $cost[a]$ 加入数组 $res$ 中,然后遍历节点 $a$ 的所有邻居节点 $b$,如果 $b$ 不是节点 $a$ 的父节点 $fa$,那么我们将 $dfs(b, a)$ 的结果加入数组 $res$ 中。
然后,我们对数组 $res$ 进行排序,然后根据数组 $res$ 的长度 $m$ 计算节点 $a$ 的放置金币数目,更新 $ans[a]$:
- 如果 $m \ge 3$,那么节点 $a$ 的放置金币数目为 $\max(0, res[m - 1] \times res[m - 2] \times res[m - 3], res[0] \times res[1] \times res[m - 1])$,否则节点 $a$ 的放置金币数目为 $1$;
- 如果 $m > 5$,那么我们只需要保留数组 $res$ 的前 $2$ 个元素和后 $3$ 个元素。
最后,我们调用函数 $dfs(0, -1)$,并且返回答案数组 $ans$ 即可。
时间复杂度 $O(n \times \log n)$,空间复杂度 $O(n)$。其中 $n$ 是节点的数目。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 | class Solution:
def placedCoins(self, edges: List[List[int]], cost: List[int]) -> List[int]:
def dfs(a: int, fa: int) -> List[int]:
res = [cost[a]]
for b in g[a]:
if b != fa:
res.extend(dfs(b, a))
res.sort()
if len(res) >= 3:
ans[a] = max(res[-3] * res[-2] * res[-1], res[0] * res[1] * res[-1], 0)
if len(res) > 5:
res = res[:2] + res[-3:]
return res
n = len(cost)
g = [[] for _ in range(n)]
for a, b in edges:
g[a].append(b)
g[b].append(a)
ans = [1] * n
dfs(0, -1)
return ans
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42 | class Solution {
private int[] cost;
private List<Integer>[] g;
private long[] ans;
public long[] placedCoins(int[][] edges, int[] cost) {
int n = cost.length;
this.cost = cost;
ans = new long[n];
g = new List[n];
Arrays.fill(ans, 1);
Arrays.setAll(g, i -> new ArrayList<>());
for (int[] e : edges) {
int a = e[0], b = e[1];
g[a].add(b);
g[b].add(a);
}
dfs(0, -1);
return ans;
}
private List<Integer> dfs(int a, int fa) {
List<Integer> res = new ArrayList<>();
res.add(cost[a]);
for (int b : g[a]) {
if (b != fa) {
res.addAll(dfs(b, a));
}
}
Collections.sort(res);
int m = res.size();
if (m >= 3) {
long x = (long) res.get(m - 1) * res.get(m - 2) * res.get(m - 3);
long y = (long) res.get(0) * res.get(1) * res.get(m - 1);
ans[a] = Math.max(0, Math.max(x, y));
}
if (m >= 5) {
res = List.of(res.get(0), res.get(1), res.get(m - 3), res.get(m - 2), res.get(m - 1));
}
return res;
}
}
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35 | class Solution {
public:
vector<long long> placedCoins(vector<vector<int>>& edges, vector<int>& cost) {
int n = cost.size();
vector<long long> ans(n, 1);
vector<int> g[n];
for (auto& e : edges) {
int a = e[0], b = e[1];
g[a].push_back(b);
g[b].push_back(a);
}
function<vector<int>(int, int)> dfs = [&](int a, int fa) -> vector<int> {
vector<int> res = {cost[a]};
for (int b : g[a]) {
if (b != fa) {
auto t = dfs(b, a);
res.insert(res.end(), t.begin(), t.end());
}
}
sort(res.begin(), res.end());
int m = res.size();
if (m >= 3) {
long long x = 1LL * res[m - 1] * res[m - 2] * res[m - 3];
long long y = 1LL * res[0] * res[1] * res[m - 1];
ans[a] = max({0LL, x, y});
}
if (m >= 5) {
res = {res[0], res[1], res[m - 1], res[m - 2], res[m - 3]};
}
return res;
};
dfs(0, -1);
return ans;
}
};
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35 | func placedCoins(edges [][]int, cost []int) []int64 {
n := len(cost)
g := make([][]int, n)
for _, e := range edges {
a, b := e[0], e[1]
g[a] = append(g[a], b)
g[b] = append(g[b], a)
}
ans := make([]int64, n)
for i := range ans {
ans[i] = int64(1)
}
var dfs func(a, fa int) []int
dfs = func(a, fa int) []int {
res := []int{cost[a]}
for _, b := range g[a] {
if b != fa {
res = append(res, dfs(b, a)...)
}
}
sort.Ints(res)
m := len(res)
if m >= 3 {
x := res[m-1] * res[m-2] * res[m-3]
y := res[0] * res[1] * res[m-1]
ans[a] = max(0, int64(x), int64(y))
}
if m >= 5 {
res = append(res[:2], res[m-3:]...)
}
return res
}
dfs(0, -1)
return ans
}
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 | function placedCoins(edges: number[][], cost: number[]): number[] {
const n = cost.length;
const ans: number[] = Array(n).fill(1);
const g: number[][] = Array.from({ length: n }, () => []);
for (const [a, b] of edges) {
g[a].push(b);
g[b].push(a);
}
const dfs = (a: number, fa: number): number[] => {
const res: number[] = [cost[a]];
for (const b of g[a]) {
if (b !== fa) {
res.push(...dfs(b, a));
}
}
res.sort((a, b) => a - b);
const m = res.length;
if (m >= 3) {
const x = res[m - 1] * res[m - 2] * res[m - 3];
const y = res[0] * res[1] * res[m - 1];
ans[a] = Math.max(0, x, y);
}
if (m > 5) {
res.splice(2, m - 5);
}
return res;
};
dfs(0, -1);
return ans;
}
|