
3165. Maximum Sum of Subsequence With Non-adjacent Elements

Description

You are given an array nums consisting of integers. You are also given a 2D array queries, where queries[i] = [pos_i, x_i].

For query i, we first set nums[pos_i] equal to x_i, then we calculate the answer to query i, which is the maximum sum of a subsequence of nums where no two adjacent elements are selected.

Return the sum of the answers to all queries.

Since the final answer may be very large, return it modulo 10^9 + 7.

A subsequence is an array that can be derived from another array by deleting some or no elements without changing the order of the remaining elements.


Example 1:

Input: nums = [3,5,9], queries = [[1,-2],[0,-3]]

Output: 21

Explanation:
After the 1st query, nums = [3,-2,9] and the maximum sum of a subsequence with non-adjacent elements is 3 + 9 = 12.
After the 2nd query, nums = [-3,-2,9] and the maximum sum of a subsequence with non-adjacent elements is 9.

Example 2:

Input: nums = [0,-1], queries = [[0,-5]]

Output: 0

Explanation:
After the 1st query, nums = [-5,-1] and the maximum sum of a subsequence with non-adjacent elements is 0 (choosing an empty subsequence).
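A direct way to answer each query, useful as a reference when testing a faster solution, is to apply the update and rerun the classic linear "house robber" DP over the whole array. The sketch below does exactly that; the function names are our own, not part of the original solution. At O(n) per query it needs about 5 × 10^4 × 5 × 10^4 = 2.5 × 10^9 steps at the upper limits in the constraints, so it only serves as a correctness check.

```python
from typing import List

MOD = 10**9 + 7


def max_non_adjacent_sum(nums: List[int]) -> int:
    # take: best sum over the prefix seen so far with its last element chosen
    # skip: best sum over the prefix seen so far with its last element not chosen
    # Both start at 0, the value of the empty subsequence.
    take, skip = 0, 0
    for x in nums:
        take, skip = skip + x, max(take, skip)
    return max(take, skip)


def brute_force_answer(nums: List[int], queries: List[List[int]]) -> int:
    arr = list(nums)  # copy so the caller's list is left untouched
    total = 0
    for pos, x in queries:
        arr[pos] = x
        total = (total + max_non_adjacent_sum(arr)) % MOD
    return total
```

On Example 1 it produces 12 and then 9 for the two queries, for a total of 21; on Example 2 it returns 0 because the empty subsequence beats every negative choice.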


Constraints:

  • 1 <= nums.length <= 5 * 10^4
  • -10^5 <= nums[i] <= 10^5
  • 1 <= queries.length <= 5 * 10^4
  • queries[i] == [pos_i, x_i]
  • 0 <= pos_i <= nums.length - 1
  • -10^5 <= x_i <= 10^5

Solutions

Solution 1: Segment Tree

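According to the problem description, each query is a point update followed by a query for the answer over the whole array. Recomputing that answer from scratch costs $O(n)$ per query, which is too slow when both $n$ and $q$ reach $5 \times 10^4$, so we use a segment tree, which supports both operations in $O(\log n)$ time.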

First, we define a $\textit{Node}$ class to store the information of the segment tree nodes, including the left and right endpoints $l$ and $r$, as well as four state values $s_{00}$, $s_{01}$, $s_{10}$, and $s_{11}$. Specifically:

  • $s_{00}$ represents the maximum sum of a valid subsequence of the node's interval that uses neither the left endpoint nor the right endpoint;
  • $s_{01}$ represents the maximum sum of a valid subsequence that does not use the left endpoint (the right endpoint may or may not be used);
  • $s_{10}$ represents the maximum sum of a valid subsequence that does not use the right endpoint (the left endpoint may or may not be used);
  • $s_{11}$ represents the maximum sum of a valid subsequence with no restriction on either endpoint.

In every case the empty subsequence is allowed, so each state value is at least $0$.
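To make the four definitions concrete, the states of a short segment can be computed straight from the definitions by exhaustive search. This is a checking aid of our own (the function name is hypothetical), exponential in the segment length, so it is only meant for tiny inputs. For the segment [3, -2, 9] it gives $(s_{00}, s_{01}, s_{10}, s_{11}) = (0, 9, 3, 12)$.

```python
from itertools import combinations
from typing import List, Tuple


def states_by_definition(a: List[int]) -> Tuple[int, int, int, int]:
    """Return (s00, s01, s10, s11) for the segment a by exhaustive search."""
    n = len(a)
    best = {"00": 0, "01": 0, "10": 0, "11": 0}  # the empty subsequence is always allowed
    for k in range(1, n + 1):
        for idx in combinations(range(n), k):  # index tuples in increasing order
            # discard choices that contain two adjacent indices
            if any(j - i == 1 for i, j in zip(idx, idx[1:])):
                continue
            total = sum(a[i] for i in idx)
            uses_left = idx[0] == 0
            uses_right = idx[-1] == n - 1
            for key in best:
                # digit 0 forbids using that endpoint; digit 1 leaves it unrestricted
                if key[0] == "0" and uses_left:
                    continue
                if key[1] == "0" and uses_right:
                    continue
                best[key] = max(best[key], total)
    return best["00"], best["01"], best["10"], best["11"]
```

A single element has both endpoints at the same index, so only $s_{11}$ can take it: `states_by_definition([5])` is `(0, 0, 0, 5)`, which matches how the leaves are initialized in the code.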

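Next, we define a $\textit{SegmentTree}$ class. Its $\textit{build}$ method recursively creates the left and right subtrees, and whenever a child changes, $\textit{pushup}$ recomputes the parent's state values from its two children. A leaf holding the value $v$ has $s_{11} = \max(0, v)$ and the other three states equal to $0$. For an internal node with left child $L$ and right child $R$, the only new constraint is that the right endpoint of $L$ and the left endpoint of $R$ are adjacent, so at most one of them can be chosen. Writing $i, j \in \{0, 1\}$ for the constraints on the node's own left and right endpoints, we take the better of excluding the right endpoint of $L$ or excluding the left endpoint of $R$:

$$
s_{ij} = \max\left(s^{L}_{i0} + s^{R}_{1j},\ s^{L}_{i1} + s^{R}_{0j}\right)
$$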
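In the listing that follows, $\textit{build}$ only allocates nodes with every state at $0$, so the array values have to be written in afterwards. As a variant (our own sketch, not the original code), the states can be filled in during construction itself: store each node's state as a tuple (s00, s01, s10, s11) in a flat list with the same `u << 1` / `u << 1 | 1` layout, initialize leaves directly from nums, and merge the two children on the way back up. Each node is then visited once, so the build takes $O(n)$ time. The names `merge`, `build`, and `build_tree` are hypothetical.

```python
from typing import List, Tuple

State = Tuple[int, int, int, int]  # (s00, s01, s10, s11)


def merge(a: State, b: State) -> State:
    # a's right endpoint and b's left endpoint are adjacent, so exclude one of them:
    # either a uses s?0 and b uses s1?, or a uses s?1 and b uses s0?.
    a00, a01, a10, a11 = a
    b00, b01, b10, b11 = b
    return (
        max(a00 + b10, a01 + b00),
        max(a00 + b11, a01 + b01),
        max(a10 + b10, a11 + b00),
        max(a10 + b11, a11 + b01),
    )


def build(tr: List[State], nums: List[int], u: int, l: int, r: int) -> None:
    # Node u covers positions l..r (1-indexed, inclusive), i.e. nums[l-1 .. r-1].
    if l == r:
        tr[u] = (0, 0, 0, max(0, nums[l - 1]))  # leaf: only s11 may take the element
        return
    mid = (l + r) >> 1
    build(tr, nums, u << 1, l, mid)
    build(tr, nums, u << 1 | 1, mid + 1, r)
    tr[u] = merge(tr[u << 1], tr[u << 1 | 1])  # pushup once both children are ready


def build_tree(nums: List[int]) -> List[State]:
    n = len(nums)
    tr: List[State] = [(0, 0, 0, 0)] * (n << 2)  # 4n slots, root at index 1
    build(tr, nums, 1, 1, n)
    return tr
```

For nums = [3, -2, 9] the root `build_tree(nums)[1]` is `(0, 9, 3, 12)`, and its last component, 12, is the answer for the whole array.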

In the main function, we first construct the segment tree from the given array $\textit{nums}$ and then process the queries one by one. For each query, we perform a point update, read $s_{11}$ of the root (the answer for the whole array, since neither of its endpoints is restricted), and add it to the answer modulo $10^9 + 7$.
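Putting the pieces together, here is one possible end-to-end version of the main loop. It is a sketch under our own assumptions rather than the original code: instead of recursive `Node` objects it uses an iterative, bottom-up tree over (s00, s01, s10, s11) tuples, with the number of leaves rounded up to a power of two and positions 0-indexed as in the queries. The class and function names are hypothetical.

```python
from typing import List, Tuple

MOD = 10**9 + 7
State = Tuple[int, int, int, int]  # (s00, s01, s10, s11)


def merge(a: State, b: State) -> State:
    # a's right endpoint and b's left endpoint are adjacent: exclude one of them.
    a00, a01, a10, a11 = a
    b00, b01, b10, b11 = b
    return (
        max(a00 + b10, a01 + b00),
        max(a00 + b11, a01 + b01),
        max(a10 + b10, a11 + b00),
        max(a10 + b11, a11 + b01),
    )


class MaxNonAdjacentTree:
    """Bottom-up segment tree over per-segment states (hypothetical helper)."""

    def __init__(self, nums: List[int]) -> None:
        size = 1
        while size < len(nums):
            size <<= 1
        self.size = size
        # Leaves live at indices [size, 2 * size). Padding leaves hold the value 0:
        # trailing zeros never change the best sum, so the root's s11 is still the
        # answer for nums.
        self.tr: List[State] = [(0, 0, 0, 0)] * (2 * size)
        for i, x in enumerate(nums):
            self.tr[size + i] = (0, 0, 0, max(0, x))
        for u in range(size - 1, 0, -1):
            self.tr[u] = merge(self.tr[2 * u], self.tr[2 * u + 1])

    def update(self, pos: int, x: int) -> None:
        # Set nums[pos] = x (0-indexed), then recompute every ancestor up to the root.
        u = self.size + pos
        self.tr[u] = (0, 0, 0, max(0, x))
        u >>= 1
        while u:
            self.tr[u] = merge(self.tr[2 * u], self.tr[2 * u + 1])
            u >>= 1

    def best(self) -> int:
        # s11 of the root: no restriction on either end of the whole array.
        return self.tr[1][3]


def maximum_sum_subsequence(nums: List[int], queries: List[List[int]]) -> int:
    tree = MaxNonAdjacentTree(nums)
    ans = 0
    for pos, x in queries:
        tree.update(pos, x)
        ans = (ans + tree.best()) % MOD
    return ans
```

The initial fill touches each of the `2 * size` slots once, and every update walks about $\log_2(\text{size})$ ancestors, so this variant runs in $O(n + q \log n)$ time, within the bound stated for the recursive version.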

The time complexity is $O((n + q) \times \log n)$, and the space complexity is $O(n)$. Here, $n$ represents the length of the array $\textit{nums}$, and $q$ represents the number of queries.

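from typing import List


# Two-argument max used throughout; it shadows the built-in max.
def max(a: int, b: int) -> int:
    return a if a > b else b


class Node:
    # l, r: the interval covered by this node (1-indexed, inclusive).
    # sXY: best sum on [l, r]; X = 0 forbids using l, Y = 0 forbids using r,
    # and a digit 1 leaves that endpoint unrestricted.
    __slots__ = "l", "r", "s00", "s01", "s10", "s11"

    def __init__(self, l: int, r: int):
        self.l = l
        self.r = r
        self.s00 = self.s01 = self.s10 = self.s11 = 0


class SegmentTree:
    __slots__ = "tr"

    def __init__(self, n: int):
        # 4n slots are enough for a recursive segment tree rooted at index 1.
        self.tr: List[Node | None] = [None] * (n << 2)
        self.build(1, 1, n)

    def build(self, u: int, l: int, r: int):
        # Only allocates the nodes: every state starts at 0 (the states of an
        # all-zero array), and the real values are written in later.
        self.tr[u] = Node(l, r)
        if l == r:
            return
        mid = (l + r) >> 1
        self.build(u << 1, l, mid)
        self.build(u << 1 | 1, mid + 1, r)

    def query(self, u: int, l: int, r: int) -> int:
        # Intended for the whole range [1, n]: the root covers it, so the root's
        # s11 is returned immediately. A range that straddles mid is not split
        # and combined here, so partial ranges are not supported.
        if self.tr[u].l >= l and self.tr[u].r <= r:
            return self.tr[u].s11
        mid = (self.tr[u].l + self.tr[u].r) >> 1
        ans = 0
        if r <= mid:
            ans = self.query(u << 1, l, r)
        if l > mid:
            ans = max(ans, self.query(u << 1 | 1, l, r))
        return ans

    def pushup(self, u: int):
        # Recompute node u from its children. The left child's right endpoint and
        # the right child's left endpoint are adjacent, so at most one is used:
        # either left uses s?0 and right uses s1?, or left uses s?1 and right s0?.
        left, right = self.tr[u << 1], self.tr[u << 1 | 1]
        self.tr[u].s00 = max(left.s00 + right.s10, left.s01 + right.s00)
        self.tr[u].s01 = max(left.s00 + right.s11, left.s01 + right.s01)
        self.tr[u].s10 = max(left.s10 + right.s10, left.s11 + right.s00)
        self.tr[u].s11 = max(left.s10 + right.s11, left.s11 + right.s01)

    def modify(self, u: int, x: int, v: int):
        # Point update: set position x (1-indexed) to the value v.
        if self.tr[u].l == self.tr[u].r:
            # Leaf: only s11 may take the element; taking nothing gives 0.
            self.tr[u].s11 = max(0, v)
            return
        mid = (self.tr[u].l + self.tr[u].r) >> 1
        if x <= mid:
            self.modify(u << 1, x, v)
        else: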
            self.modify(u << 1 | 1,