Skip to content

3186. Maximum Total Damage With Spell Casting

Description

A magician has various spells.

You are given an array power, where each element represents the damage of a spell. Multiple spells can have the same damage value.

It is a known fact that if a magician decides to cast a spell with a damage of power[i], they cannot cast any spell with a damage of power[i] - 2, power[i] - 1, power[i] + 1, or power[i] + 2.

Each spell can be cast only once.

Return the maximum possible total damage that a magician can cast.

 

Example 1:

Input: power = [1,1,3,4]

Output: 6

Explanation:

The maximum possible damage of 6 is produced by casting spells 0, 1, 3 with damage 1, 1, 4.

Example 2:

Input: power = [7,1,6,6]

Output: 13

Explanation:

The maximum possible damage of 13 is produced by casting spells 1, 2, 3 with damage 1, 6, 6.

 

Constraints:

  • 1 <= power.length <= 10^5
  • 1 <= power[i] <= 10^9

Solutions

Solution 1: Binary Search + Memoization

We can first sort the array $\textit{power}$, use a hash table $\textit{cnt}$ to record the occurrence count of each damage value, and then iterate through the array $\textit{power}$. For each damage value $x$, we can determine the index of the next damage value that can be used when using a spell with damage value $x$, which is the index of the first damage value greater than $x + 2$. We can use binary search to find this index and record it in the array $\textit{nxt}$.

Next, we define a function $\textit{dfs}$ to calculate the maximum damage value that can be obtained starting from the $i$-th damage value.

In the $\textit{dfs}$ function, we can choose to skip the current damage value, so we can skip all the same damage values of the current one and directly jump to $i + \textit{cnt}[x]$, obtaining a damage value of $\textit{dfs}(i + \textit{cnt}[x])$; or we can choose to use the current damage value, so we can use all the same damage values of the current one and then jump to the index of the next damage value, obtaining a damage value of $x \times \textit{cnt}[x] + \textit{dfs}(\textit{nxt}[i])$, where $\textit{nxt}[i]$ represents the index of the first damage value greater than $x + 2$. We take the maximum of these two cases as the return value of the function.

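To avoid repeated calculations, we can use memoization, storing the results that have already been calculated in an array $\textit{f}$. Thus, when calculating $\textit{dfs}(i)$, if $\textit{f}[i]$ is not $0$, we directly return $\textit{f}[i]$. This sentinel is safe because every damage value is at least $1$, so any computed $\textit{dfs}(i)$ with $i < n$ is strictly positive.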

The answer is $\textit{dfs}(0)$.

The time complexity is $O(n \log n)$, and the space complexity is $O(n)$. Here, $n$ is the length of the array $\textit{power}$.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
class Solution:
    def maximumTotalDamage(self, power: List[int]) -> int:
        """Return the maximum total damage obtainable from the spells in ``power``.

        Casting a spell of damage ``x`` forbids every spell of damage ``x - 2``,
        ``x - 1``, ``x + 1`` or ``x + 2``. Equal damages never conflict, so all
        copies of a chosen value are cast together.

        This is the bottom-up form of the memoized recurrence. Let ``vals`` be the
        sorted distinct damages and ``f[k]`` the best total using only ``vals[k:]``.
        Then ``f[k] = max(f[k + 1], vals[k] * cnt[vals[k]] + f[j])``, where ``j`` is
        the first index whose value exceeds ``vals[k] + 2``. Iterating instead of
        recursing avoids Python's recursion limit when there are many distinct
        damages. ``power`` itself is left unmodified.

        Args:
            power: damage of each spell; may contain duplicates.

        Returns:
            The maximum achievable total damage (0 for an empty list).
        """
        cnt = Counter(power)
        vals = sorted(cnt)
        m = len(vals)
        f = [0] * (m + 1)  # f[m] == 0: no values left to cast
        for k in range(m - 1, -1, -1):
            x = vals[k]
            # First distinct value compatible with x, i.e. strictly greater than x + 2.
            j = bisect_right(vals, x + 2, lo=k + 1)
            # Either skip every copy of x, or cast them all and jump past the conflict zone.
            f[k] = max(f[k + 1], x * cnt[x] + f[j])
        return f[0]
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Solution {
    private Long[] f;
    private int[] power;
    private Map<Integer, Integer> cnt;
    private int[] nxt;
    private int n;

    public long maximumTotalDamage(int[] power) {
        Arrays.sort(power);
        this.power = power;
        n = power.length;
        f = new Long[n];
        cnt = new HashMap<>(n);
        nxt = new int[n];
        for (int i = 0; i < n; ++i) {
            cnt.merge(power[i], 1, Integer::sum);
            int l = Arrays.binarySearch(power, power[i] + 3);
            l = l < 0 ? -l - 1 : l;
            nxt[i] = l;
        }
        return dfs(0);
    }

    private long dfs(int i) {
        if (i >= n) {
            return 0;
        }
        if (f[i] != null) {
            return f[i];
        }
        long a = dfs(i + cnt.get(power[i]));
        long b = 1L * power[i] * cnt.get(power[i]) + dfs(nxt[i]);
        return f[i] = Math.max(a, b);
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Solution {
public:
    // Returns the maximum total damage obtainable from the spells in `power`.
    //
    // Casting a spell of damage x forbids every spell of damage x-2, x-1, x+1 and x+2.
    // Equal damages never conflict, so all copies of a chosen value are cast together.
    //
    // The solution is a bottom-up DP over the sorted distinct damages `vals`:
    //   f[k] = max(f[k + 1], sum[k] + f[j])
    // Here sum[k] is the total damage of all spells equal to vals[k], and j is the first
    // index with vals[j] > vals[k] + 2.
    // All state is local, so the same Solution object can be reused safely across calls.
    // `power` is not modified, and there is no recursion depth to worry about.
    long long maximumTotalDamage(vector<int>& power) {
        vector<int> sorted(power);
        sort(sorted.begin(), sorted.end());

        // Collapse equal damages: vals is strictly increasing and
        // sum[k] is vals[k] multiplied by its number of occurrences.
        vector<int> vals;
        vector<long long> sum;
        for (int x : sorted) {
            if (vals.empty() || vals.back() != x) {
                vals.push_back(x);
                sum.push_back(0);
            }
            sum.back() += x;
        }

        int m = vals.size();
        vector<long long> f(m + 1, 0);  // f[m] == 0: no values left to cast
        for (int k = m - 1; k >= 0; --k) {
            // First distinct value compatible with vals[k]; widened to avoid int overflow.
            long long limit = static_cast<long long>(vals[k]) + 2;
            int j = upper_bound(vals.begin() + k + 1, vals.end(), limit) - vals.begin();
            f[k] = max(f[k + 1], sum[k] + f[j]);
        }
        return f[0];
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// maximumTotalDamage returns the maximum total damage obtainable from the spells in power.
//
// Casting a spell of damage x forbids every spell of damage x-2, x-1, x+1 and x+2.
// Equal damages never conflict, so all copies of a chosen value are cast together.
// dfs(i) is the best total using the spells from sorted index i onward:
//   - skip every copy of vals[i] and move to the next group, or
//   - cast them all and jump to nxt[i], the first index whose value exceeds vals[i]+2.
//
// The caller's slice is not reordered.
func maximumTotalDamage(power []int) int64 {
    // Sort a copy so the caller's slice is left untouched.
    vals := append([]int(nil), power...)
    sort.Ints(vals)
    n := len(vals)
    cnt := map[int]int{}
    nxt := make([]int, n)
    // f[i] == 0 means "not computed yet"; safe because every computed dfs(i) with i < n
    // is at least the positive damage of one spell.
    f := make([]int64, n)
    for i, x := range vals {
        cnt[x]++
        // Lower bound of x+3 is the first index whose value is strictly greater than x+2.
        nxt[i] = sort.SearchInts(vals, x+3)
    }
    var dfs func(int) int64
    dfs = func(i int) int64 {
        if i >= n {
            return 0
        }
        if f[i] != 0 {
            return f[i]
        }
        x := vals[i]
        skip := dfs(i + cnt[x])
        // Multiply in int64: the product can reach 1e14, which overflows a 32-bit int.
        take := int64(x)*int64(cnt[x]) + dfs(nxt[i])
        f[i] = max(skip, take)
        return f[i]
    }
    return dfs(0)
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/**
 * Returns the maximum total damage obtainable from the spells in `power`.
 *
 * Casting a spell of damage x forbids every spell of damage x-2, x-1, x+1 and x+2.
 * Equal damages never conflict, so all copies of a chosen value are cast together.
 *
 * The solution is a bottom-up DP over the sorted distinct damages `vals`:
 *   f[k] = max(f[k + 1], sum[k] + f[j])
 * Here sum[k] is the total damage of all spells equal to vals[k], and j is the first
 * index with vals[j] > vals[k] + 2.
 * Iterating avoids exceeding the JS call stack when there are ~1e5 distinct damages,
 * and `power` is not reordered.
 *
 * @param power damage of each spell; may contain duplicates
 * @returns the maximum achievable total damage (0 for an empty array)
 */
function maximumTotalDamage(power: number[]): number {
    const sorted = [...power].sort((a, b) => a - b);

    // Collapse equal damages: vals is strictly increasing and
    // sum[k] is vals[k] multiplied by its number of occurrences.
    const vals: number[] = [];
    const sum: number[] = [];
    for (const x of sorted) {
        if (vals.length === 0 || vals[vals.length - 1] !== x) {
            vals.push(x);
            sum.push(0);
        }
        sum[sum.length - 1] += x;
    }

    const m = vals.length;
    const f: number[] = Array(m + 1).fill(0); // f[m] === 0: no values left to cast
    for (let k = m - 1; k >= 0; --k) {
        // Binary search for the first distinct value strictly greater than vals[k] + 2.
        let [l, r] = [k + 1, m];
        while (l < r) {
            const mid = (l + r) >> 1;
            if (vals[mid] > vals[k] + 2) {
                r = mid;
            } else {
                l = mid + 1;
            }
        }
        f[k] = Math.max(f[k + 1], sum[k] + f[l]);
    }
    return f[0];
}

Comments