跳转至

3098. 求出所有子序列的能量和

题目描述

给你一个长度为 n 的整数数组 nums 和一个  整数 k 。

一个 子序列能量 定义为子序列中 任意 两个元素的差值绝对值的 最小值 。

请你返回 nums 中长度 等于 k 的 所有 子序列的 能量和 。

由于答案可能会很大,将答案对 109 + 7 取余 后返回。

 

示例 1:

输入:nums = [1,2,3,4], k = 3

输出:4

解释:

nums 中总共有 4 个长度为 3 的子序列:[1,2,3] ,[1,3,4] ,[1,2,4] 和 [2,3,4] 。能量和为 |2 - 3| + |3 - 4| + |2 - 1| + |3 - 4| = 4 。

示例 2:

输入:nums = [2,2], k = 2

输出:0

解释:

nums 中唯一一个长度为 2 的子序列是 [2,2] 。能量和为 |2 - 2| = 0 。

示例 3:

输入:nums = [4,3,-1], k = 2

输出:10

解释:

nums 总共有 3 个长度为 2 的子序列:[4,3] ,[4,-1] 和 [3,-1] 。能量和为 |4 - 3| + |4 - (-1)| + |3 - (-1)| = 10 。

 

提示:

  • 2 <= n == nums.length <= 50
  • -108 <= nums[i] <= 108
  • 2 <= k <= n

解法

方法一:记忆化搜索

由于题目涉及子序列元素的最小差值,我们不妨对数组 $\textit{nums}$ 进行排序,这样可以方便我们计算子序列元素的最小差值。

接下来,我们设计一个函数 $dfs(i, j, k, mi)$,表示当前处理到第 $i$ 个元素,上一个选取的是第 $j$ 个元素,还需要选取 $k$ 个元素,当前的最小差值为 $mi$ 时,能量和的值。那么答案就是 $dfs(0, n, k, +\infty)$。(若上一个选取的是第 $n$ 个元素,表示之前没有选取过元素)

函数 $dfs(i, j, k, mi)$ 的执行过程如下:

  • 如果 $i \geq n$,表示已经处理完了所有的元素,如果 $k = 0$,返回 $mi$,否则返回 $0$;
  • 如果剩余的元素个数 $n - i$ 不足 $k$ 个,返回 $0$;
  • 否则,我们可以选择不选取第 $i$ 个元素,可以获得的能量和为 $dfs(i + 1, j, k, mi)$;
  • 也可以选择选取第 $i$ 个元素。如果 $j = n$,表示之前没有选取过元素,那么可以获得的能量和为 $dfs(i + 1, i, k - 1, mi)$;否则,可以获得的能量和为 $dfs(i + 1, i, k - 1, \min(mi, \textit{nums}[i] - \textit{nums}[j]))$。
  • 我们累加上述结果,并对 $10^9 + 7$ 取模后返回。

为了避免重复计算,我们可以使用记忆化搜索的方法,将已经计算过的结果保存起来。

时间复杂度 $O(n^4 \times k)$,空间复杂度 $O(n^4 \times k)$。其中 $n$ 为数组的长度。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class Solution:
    def sumOfPowers(self, nums: List[int], k: int) -> int:
        @cache
        def dfs(i: int, j: int, k: int, mi: int) -> int:
            if i >= n:
                return mi if k == 0 else 0
            if n - i < k:
                return 0
            ans = dfs(i + 1, j, k, mi)
            if j == n:
                ans += dfs(i + 1, i, k - 1, mi)
            else:
                ans += dfs(i + 1, i, k - 1, min(mi, nums[i] - nums[j]))
            ans %= mod
            return ans

        mod = 10**9 + 7
        n = len(nums)
        nums.sort()
        return dfs(0, n, k, inf)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution {
    private Map<Long, Integer> f = new HashMap<>();
    private final int mod = (int) 1e9 + 7;
    private int[] nums;

    public int sumOfPowers(int[] nums, int k) {
        Arrays.sort(nums);
        this.nums = nums;
        return dfs(0, nums.length, k, Integer.MAX_VALUE);
    }

    private int dfs(int i, int j, int k, int mi) {
        if (i >= nums.length) {
            return k == 0 ? mi : 0;
        }
        if (nums.length - i < k) {
            return 0;
        }
        long key = (1L * mi) << 18 | (i << 12) | (j << 6) | k;
        if (f.containsKey(key)) {
            return f.get(key);
        }
        int ans = dfs(i + 1, j, k, mi);
        if (j == nums.length) {
            ans += dfs(i + 1, i, k - 1, mi);
        } else {
            ans += dfs(i + 1, i, k - 1, Math.min(mi, nums[i] - nums[j]));
        }
        ans %= mod;
        f.put(key, ans);
        return ans;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Solution {
public:
    int sumOfPowers(vector<int>& nums, int k) {
        unordered_map<long long, int> f;
        const int mod = 1e9 + 7;
        int n = nums.size();
        sort(nums.begin(), nums.end());
        auto dfs = [&](auto&& dfs, int i, int j, int k, int mi) -> int {
            if (i >= n) {
                return k == 0 ? mi : 0;
            }
            if (n - i < k) {
                return 0;
            }
            long long key = (1LL * mi) << 18 | (i << 12) | (j << 6) | k;
            if (f.contains(key)) {
                return f[key];
            }
            long long ans = dfs(dfs, i + 1, j, k, mi);
            if (j == n) {
                ans += dfs(dfs, i + 1, i, k - 1, mi);
            } else {
                ans += dfs(dfs, i + 1, i, k - 1, min(mi, nums[i] - nums[j]));
            }
            ans %= mod;
            f[key] = ans;
            return f[key];
        };
        return dfs(dfs, 0, n, k, INT_MAX);
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
func sumOfPowers(nums []int, k int) int {
    const mod int = 1e9 + 7
    sort.Ints(nums)
    n := len(nums)
    f := map[int]int{}
    var dfs func(i, j, k, mi int) int
    dfs = func(i, j, k, mi int) int {
        if i >= n {
            if k == 0 {
                return mi
            }
            return 0
        }
        if n-i < k {
            return 0
        }
        key := mi<<18 | (i << 12) | (j << 6) | k
        if v, ok := f[key]; ok {
            return v
        }
        ans := dfs(i+1, j, k, mi)
        if j == n {
            ans += dfs(i+1, i, k-1, mi)
        } else {
            ans += dfs(i+1, i, k-1, min(mi, nums[i]-nums[j]))
        }
        ans %= mod
        f[key] = ans
        return ans
    }
    return dfs(0, n, k, math.MaxInt)
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
function sumOfPowers(nums: number[], k: number): number {
    const mod = BigInt(1e9 + 7);
    nums.sort((a, b) => a - b);
    const n = nums.length;
    const f: Map<bigint, bigint> = new Map();
    function dfs(i: number, j: number, k: number, mi: number): bigint {
        if (i >= n) {
            if (k === 0) {
                return BigInt(mi);
            }
            return BigInt(0);
        }
        if (n - i < k) {
            return BigInt(0);
        }
        const key =
            (BigInt(mi) << BigInt(18)) |
            (BigInt(i) << BigInt(12)) |
            (BigInt(j) << BigInt(6)) |
            BigInt(k);
        if (f.has(key)) {
            return f.get(key)!;
        }
        let ans = dfs(i + 1, j, k, mi);
        if (j === n) {
            ans += dfs(i + 1, i, k - 1, mi);
        } else {
            ans += dfs(i + 1, i, k - 1, Math.min(mi, nums[i] - nums[j]));
        }
        ans %= mod;
        f.set(key, ans);
        return ans;
    }

    return Number(dfs(0, n, k, Number.MAX_SAFE_INTEGER));
}

评论