跳转至

1803. 统计异或值在范围内的数对有多少

题目描述

给你一个整数数组 nums (下标 从 0 开始 计数)以及两个整数:lowhigh ,请返回 漂亮数对 的数目。

漂亮数对 是一个形如 (i, j) 的数对,其中 0 <= i < j < nums.lengthlow <= (nums[i] XOR nums[j]) <= high

 

示例 1:

输入:nums = [1,4,2,7], low = 2, high = 6
输出:6
解释:所有漂亮数对 (i, j) 列出如下:
    - (0, 1): nums[0] XOR nums[1] = 5 
    - (0, 2): nums[0] XOR nums[2] = 3
    - (0, 3): nums[0] XOR nums[3] = 6
    - (1, 2): nums[1] XOR nums[2] = 6
    - (1, 3): nums[1] XOR nums[3] = 3
    - (2, 3): nums[2] XOR nums[3] = 5

示例 2:

输入:nums = [9,8,4,2,1], low = 5, high = 14
输出:8
解释:所有漂亮数对 (i, j) 列出如下:
​​​​​    - (0, 2): nums[0] XOR nums[2] = 13
    - (0, 3): nums[0] XOR nums[3] = 11
    - (0, 4): nums[0] XOR nums[4] = 8
    - (1, 2): nums[1] XOR nums[2] = 12
    - (1, 3): nums[1] XOR nums[3] = 10
    - (1, 4): nums[1] XOR nums[4] = 9
    - (2, 3): nums[2] XOR nums[3] = 6
    - (2, 4): nums[2] XOR nums[4] = 5

 

提示:

  • 1 <= nums.length <= 2 * 104
  • 1 <= nums[i] <= 2 * 104
  • 1 <= low <= high <= 2 * 104

解法

方法一:0-1 字典树

对于这种区间 $[low, high]$ 统计的问题,我们可以考虑将其转换为统计 $[0, high]$ 和 $[0, low - 1]$ 的问题,然后相减即可得到答案。

在这道题中,我们可以统计有多少数对的异或值小于 $high+1$,然后再统计有多少数对的异或值小于 $low$,相减的结果就是异或值在区间 $[low, high]$ 之间的数对数量。

另外,对于数组异或计数问题,我们通常可以使用“0-1 字典树”来解决。

字典树的节点定义如下:

  • children[0]children[1] 分别表示当前节点的左右子节点;
  • cnt 表示以当前节点为结尾的数的数量。

在字典树中,我们还定义了以下两个函数:

其中一个函数是 $insert(x)$,表示将数 $x$ 插入到字典树中。该函数将数字 $x$ 按照二进制位从高到低的顺序,插入到“0-1 字典树”中。如果当前二进制位为 $0$,则插入到左子节点,否则插入到右子节点。然后将节点的计数值 $cnt$ 加 $1$。

另一个函数是 $search(x, limit)$,表示在字典树中查找与 $x$ 异或值小于 $limit$ 的数量。该函数从字典树的根节点 node 开始,遍历 $x$ 的二进制位,从高到低,记当前 $x$ 的二进制位的数为 $v$。如果当前 $limit$ 的二进制位为 $1$,此时我们可以直接将答案加上与 $x$ 的当前二进制位 $v$ 相同的子节点的计数值 $cnt$,然后将当前节点移动到与 $x$ 的当前二进制位 $v$ 不同的子节点,即 node = node.children[v ^ 1]。继续遍历下一位。如果当前 $limit$ 的二进制位为 $0$,此时我们只能将当前节点移动到与 $x$ 的当前二进制位 $v$ 相同的子节点,即 node = node.children[v]。继续遍历下一位。遍历完 $x$ 的二进制位后,返回答案。

有了以上两个函数,我们就可以解决本题了。

我们遍历数组 nums,对于每个数 $x$,我们先在字典树中查找与 $x$ 异或值小于 $high+1$ 的数量,然后在字典树中查找与 $x$ 异或值小于 $low$ 的数对数量,将两者的差值加到答案中。接着将 $x$ 插入到字典树中。继续遍历下一个数 $x$,直到遍历完数组 nums。最后返回答案即可。

时间复杂度 $O(n \times \log M)$,空间复杂度 $O(n \times \log M)$。其中 $n$ 为数组 nums 的长度,而 $M$ 为数组 nums 中的最大值。本题中我们直接取 $\log M = 16$。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class Trie:
    def __init__(self):
        self.children = [None] * 2
        self.cnt = 0

    def insert(self, x):
        node = self
        for i in range(15, -1, -1):
            v = x >> i & 1
            if node.children[v] is None:
                node.children[v] = Trie()
            node = node.children[v]
            node.cnt += 1

    def search(self, x, limit):
        node = self
        ans = 0
        for i in range(15, -1, -1):
            if node is None:
                return ans
            v = x >> i & 1
            if limit >> i & 1:
                if node.children[v]:
                    ans += node.children[v].cnt
                node = node.children[v ^ 1]
            else:
                node = node.children[v]
        return ans


class Solution:
    def countPairs(self, nums: List[int], low: int, high: int) -> int:
        ans = 0
        tree = Trie()
        for x in nums:
            ans += tree.search(x, high + 1) - tree.search(x, low)
            tree.insert(x)
        return ans
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class Trie {
    private Trie[] children = new Trie[2];
    private int cnt;

    public void insert(int x) {
        Trie node = this;
        for (int i = 15; i >= 0; --i) {
            int v = (x >> i) & 1;
            if (node.children[v] == null) {
                node.children[v] = new Trie();
            }
            node = node.children[v];
            ++node.cnt;
        }
    }

    public int search(int x, int limit) {
        Trie node = this;
        int ans = 0;
        for (int i = 15; i >= 0 && node != null; --i) {
            int v = (x >> i) & 1;
            if (((limit >> i) & 1) == 1) {
                if (node.children[v] != null) {
                    ans += node.children[v].cnt;
                }
                node = node.children[v ^ 1];
            } else {
                node = node.children[v];
            }
        }
        return ans;
    }
}

class Solution {
    public int countPairs(int[] nums, int low, int high) {
        Trie trie = new Trie();
        int ans = 0;
        for (int x : nums) {
            ans += trie.search(x, high + 1) - trie.search(x, low);
            trie.insert(x);
        }
        return ans;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class Trie {
public:
    Trie()
        : children(2)
        , cnt(0) {}

    void insert(int x) {
        Trie* node = this;
        for (int i = 15; ~i; --i) {
            int v = x >> i & 1;
            if (!node->children[v]) {
                node->children[v] = new Trie();
            }
            node = node->children[v];
            ++node->cnt;
        }
    }

    int search(int x, int limit) {
        Trie* node = this;
        int ans = 0;
        for (int i = 15; ~i && node; --i) {
            int v = x >> i & 1;
            if (limit >> i & 1) {
                if (node->children[v]) {
                    ans += node->children[v]->cnt;
                }
                node = node->children[v ^ 1];
            } else {
                node = node->children[v];
            }
        }
        return ans;
    }

private:
    vector<Trie*> children;
    int cnt;
};

class Solution {
public:
    int countPairs(vector<int>& nums, int low, int high) {
        Trie* tree = new Trie();
        int ans = 0;
        for (int& x : nums) {
            ans += tree->search(x, high + 1) - tree->search(x, low);
            tree->insert(x);
        }
        return ans;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
type Trie struct {
    children [2]*Trie
    cnt      int
}

func newTrie() *Trie {
    return &Trie{}
}

func (this *Trie) insert(x int) {
    node := this
    for i := 15; i >= 0; i-- {
        v := (x >> i) & 1
        if node.children[v] == nil {
            node.children[v] = newTrie()
        }
        node = node.children[v]
        node.cnt++
    }
}

func (this *Trie) search(x, limit int) (ans int) {
    node := this
    for i := 15; i >= 0 && node != nil; i-- {
        v := (x >> i) & 1
        if (limit >> i & 1) == 1 {
            if node.children[v] != nil {
                ans += node.children[v].cnt
            }
            node = node.children[v^1]
        } else {
            node = node.children[v]
        }
    }
    return
}

func countPairs(nums []int, low int, high int) (ans int) {
    tree := newTrie()
    for _, x := range nums {
        ans += tree.search(x, high+1) - tree.search(x, low)
        tree.insert(x)
    }
    return
}

评论