题目描述
给你一个数组 nums
表示 1
到 n
的一个排列。我们按照元素在 nums
中的顺序依次插入一个初始为空的二叉搜索树(BST)。请你统计将 nums
重新排序后,统计满足如下条件的方案数:重排后得到的二叉搜索树与 nums
原本数字顺序得到的二叉搜索树相同。
比方说,给你 nums = [2,1,3]
,我们得到一棵 2 为根,1 为左孩子,3 为右孩子的树。数组 [2,3,1]
也能得到相同的 BST,但 [3,2,1]
会得到一棵不同的 BST 。
请你返回重排 nums
后,与原数组 nums
得到相同二叉搜索树的方案数。
由于答案可能会很大,请将结果对 10^9 + 7
取余数。
示例 1:
输入:nums = [2,1,3]
输出:1
解释:我们将 nums 重排, [2,3,1] 能得到相同的 BST 。没有其他得到相同 BST 的方案了。
示例 2:
输入:nums = [3,4,5,1,2]
输出:5
解释:下面 5 个数组会得到相同的 BST:
[3,1,2,4,5]
[3,1,4,2,5]
[3,1,4,5,2]
[3,4,1,2,5]
[3,4,1,5,2]
示例 3:
输入:nums = [1,2,3]
输出:0
解释:没有别的排列顺序能得到相同的 BST 。
提示:
1 <= nums.length <= 1000
1 <= nums[i] <= nums.length
nums
中所有数 互不相同 。
解法
方法一:组合计数 + 递归
我们设计一个函数 $dfs(nums)$,它的功能是计算以 $nums$ 为节点构成的二叉搜索树的方案数。那么答案就是 $dfs(nums)-1$,因为 $dfs(nums)$ 计算的是以 $nums$ 为节点构成的二叉搜索树的方案数,而题目要求的是重排后与原数组 $nums$ 得到相同二叉搜索树的方案数,因此答案需要减去一。
接下来,我们来看一下 $dfs(nums)$ 的计算方法。
对于一个数组 $nums$,它的第一个元素是根节点,那么它的左子树的元素都小于它,右子树的元素都大于它。因此,我们可以将数组分为三部分,第一部分是根节点,第二部分是左子树的元素,记为 $left$,第三部分是右子树的元素,记为 $right$。那么,左子树的元素个数为 $m$,右子树的元素个数为 $n$,那么 $left$ 和 $right$ 的方案数分别为 $dfs(left)$ 和 $dfs(right)$。我们可以在数组 $nums$ 的 $m + n$ 个位置中选择 $m$ 个位置放置左子树的元素,剩下的 $n$ 个位置放置右子树的元素,这样就能保证重排后与原数组 $nums$ 得到相同二叉搜索树。因此,$dfs(nums)$ 的计算方法为:
$$
dfs(nums) = C_{m+n}^m \times dfs(left) \times dfs(right)
$$
其中 $C_{m+n}^m$ 表示从 $m + n$ 个位置中选择 $m$ 个位置的方案数,我们可以通过预处理得到。
注意答案的取模运算,因为 $dfs(nums)$ 的值可能会很大,所以我们需要在计算过程中对每一步的结果取模,最后再对整个结果取模。
时间复杂度 $O(n^2)$,空间复杂度 $O(n^2)$。其中 $n$ 是数组 $nums$ 的长度。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 | class Solution:
def numOfWays(self, nums: List[int]) -> int:
def dfs(nums):
if len(nums) < 2:
return 1
left = [x for x in nums if x < nums[0]]
right = [x for x in nums if x > nums[0]]
m, n = len(left), len(right)
a, b = dfs(left), dfs(right)
return (((c[m + n][m] * a) % mod) * b) % mod
n = len(nums)
mod = 10**9 + 7
c = [[0] * n for _ in range(n)]
c[0][0] = 1
for i in range(1, n):
c[i][0] = 1
for j in range(1, i + 1):
c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod
return (dfs(nums) - 1 + mod) % mod
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39 | class Solution {
private int[][] c;
private final int mod = (int) 1e9 + 7;
public int numOfWays(int[] nums) {
int n = nums.length;
c = new int[n][n];
c[0][0] = 1;
for (int i = 1; i < n; ++i) {
c[i][0] = 1;
for (int j = 1; j <= i; ++j) {
c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
}
}
List<Integer> list = new ArrayList<>();
for (int x : nums) {
list.add(x);
}
return (dfs(list) - 1 + mod) % mod;
}
private int dfs(List<Integer> nums) {
if (nums.size() < 2) {
return 1;
}
List<Integer> left = new ArrayList<>();
List<Integer> right = new ArrayList<>();
for (int x : nums) {
if (x < nums.get(0)) {
left.add(x);
} else if (x > nums.get(0)) {
right.add(x);
}
}
int m = left.size(), n = right.size();
int a = dfs(left), b = dfs(right);
return (int) ((long) a * b % mod * c[m + n][n] % mod);
}
}
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33 | class Solution {
public:
int numOfWays(vector<int>& nums) {
int n = nums.size();
const int mod = 1e9 + 7;
int c[n][n];
memset(c, 0, sizeof(c));
c[0][0] = 1;
for (int i = 1; i < n; ++i) {
c[i][0] = 1;
for (int j = 1; j <= i; ++j) {
c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
}
}
function<int(vector<int>)> dfs = [&](vector<int> nums) -> int {
if (nums.size() < 2) {
return 1;
}
vector<int> left, right;
for (int& x : nums) {
if (x < nums[0]) {
left.push_back(x);
} else if (x > nums[0]) {
right.push_back(x);
}
}
int m = left.size(), n = right.size();
int a = dfs(left), b = dfs(right);
return c[m + n][m] * 1ll * a % mod * b % mod;
};
return (dfs(nums) - 1 + mod) % mod;
}
};
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33 | func numOfWays(nums []int) int {
n := len(nums)
const mod = 1e9 + 7
c := make([][]int, n)
for i := range c {
c[i] = make([]int, n)
}
c[0][0] = 1
for i := 1; i < n; i++ {
c[i][0] = 1
for j := 1; j <= i; j++ {
c[i][j] = (c[i-1][j] + c[i-1][j-1]) % mod
}
}
var dfs func(nums []int) int
dfs = func(nums []int) int {
if len(nums) < 2 {
return 1
}
var left, right []int
for _, x := range nums[1:] {
if x < nums[0] {
left = append(left, x)
} else {
right = append(right, x)
}
}
m, n := len(left), len(right)
a, b := dfs(left), dfs(right)
return c[m+n][m] * a % mod * b % mod
}
return (dfs(nums) - 1 + mod) % mod
}
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32 | function numOfWays(nums: number[]): number {
const n = nums.length;
const mod = 1e9 + 7;
const c = new Array(n).fill(0).map(() => new Array(n).fill(0));
c[0][0] = 1;
for (let i = 1; i < n; ++i) {
c[i][0] = 1;
for (let j = 1; j <= i; ++j) {
c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod;
}
}
const dfs = (nums: number[]): number => {
if (nums.length < 2) {
return 1;
}
const left: number[] = [];
const right: number[] = [];
for (let i = 1; i < nums.length; ++i) {
if (nums[i] < nums[0]) {
left.push(nums[i]);
} else {
right.push(nums[i]);
}
}
const m = left.length;
const n = right.length;
const a = dfs(left);
const b = dfs(right);
return Number((BigInt(c[m + n][m]) * BigInt(a) * BigInt(b)) % BigInt(mod));
};
return (dfs(nums) - 1 + mod) % mod;
}
|