跳转至

1577. 数的平方等于两数乘积的方法数

题目描述

给你两个整数数组 nums1nums2 ,请你返回根据以下规则形成的三元组的数目(类型 1 和类型 2 ):

  • 类型 1:三元组 (i, j, k) ,如果 nums1[i]2 == nums2[j] * nums2[k] 其中 0 <= i < nums1.length0 <= j < k < nums2.length
  • 类型 2:三元组 (i, j, k) ,如果 nums2[i]2 == nums1[j] * nums1[k] 其中 0 <= i < nums2.length0 <= j < k < nums1.length

 

示例 1:

输入:nums1 = [7,4], nums2 = [5,2,8,9]
输出:1
解释:类型 1:(1,1,2), nums1[1]^2 = nums2[1] * nums2[2] (4^2 = 2 * 8)

示例 2:

输入:nums1 = [1,1], nums2 = [1,1,1]
输出:9
解释:所有三元组都符合题目要求,因为 1^2 = 1 * 1
类型 1:(0,0,1), (0,0,2), (0,1,2), (1,0,1), (1,0,2), (1,1,2), nums1[i]^2 = nums2[j] * nums2[k]
类型 2:(0,0,1), (1,0,1), (2,0,1), nums2[i]^2 = nums1[j] * nums1[k]

示例 3:

输入:nums1 = [7,7,8,3], nums2 = [1,2,9,7]
输出:2
解释:有两个符合题目要求的三元组
类型 1:(3,0,2), nums1[3]^2 = nums2[0] * nums2[2]
类型 2:(3,0,1), nums2[3]^2 = nums1[0] * nums1[1]

示例 4:

输入:nums1 = [4,7,9,11,23], nums2 = [3,5,1024,12,18]
输出:0
解释:不存在符合题目要求的三元组

 

提示:

  • 1 <= nums1.length, nums2.length <= 1000
  • 1 <= nums1[i], nums2[i] <= 10^5

解法

方法一:哈希表 + 枚举

我们用哈希表 $\textit{cnt1}$ 统计 $\textit{nums1}$ 中每个数对 $(\textit{nums}[j], \textit{nums}[k])$ 出现的次数,其中 $0 \leq j \lt k < m$,其中 $m$ 为数组 $\textit{nums1}$ 的长度。用哈希表 $\textit{cnt2}$ 统计 $\textit{nums2}$ 中每个数对 $(\textit{nums}[j], \textit{nums}[k])$ 出现的次数,其中 $0 \leq j \lt k < n$,其中 $n$ 为数组 $\textit{nums2}$ 的长度。

接下来,我们枚举数组 $\textit{nums1}$ 中的每个数 $x$,计算 $\textit{cnt2}[x^2]$ 的值,即 $\textit{nums2}$ 中有多少对数 $(\textit{nums}[j], \textit{nums}[k])$ 满足 $\textit{nums}[j] \times \textit{nums}[k] = x^2$。同理,我们枚举数组 $\textit{nums2}$ 中的每个数 $x$,计算 $\textit{cnt1}[x^2]$ 的值,即 $\textit{nums1}$ 中有多少对数 $(\textit{nums}[j], \textit{nums}[k])$ 满足 $\textit{nums}[j] \times \textit{nums}[k] = x^2$,最后将两者相加返回即可。

时间复杂度 $O(m^2 + n^2 + m + n)$,空间复杂度 $O(m^2 + n^2)$。其中 $m$ 和 $n$ 分别为数组 $\textit{nums1}$ 和 $\textit{nums2}$ 的长度。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
class Solution:
    def numTriplets(self, nums1: List[int], nums2: List[int]) -> int:
        def count(nums: List[int]) -> Counter:
            cnt = Counter()
            for j in range(len(nums)):
                for k in range(j + 1, len(nums)):
                    cnt[nums[j] * nums[k]] += 1
            return cnt

        def cal(nums: List[int], cnt: Counter) -> int:
            return sum(cnt[x * x] for x in nums)

        cnt1 = count(nums1)
        cnt2 = count(nums2)
        return cal(nums1, cnt2) + cal(nums2, cnt1)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Solution {
    public int numTriplets(int[] nums1, int[] nums2) {
        var cnt1 = count(nums1);
        var cnt2 = count(nums2);
        return cal(cnt1, nums2) + cal(cnt2, nums1);
    }

    private Map<Long, Integer> count(int[] nums) {
        Map<Long, Integer> cnt = new HashMap<>();
        int n = nums.length;
        for (int j = 0; j < n; ++j) {
            for (int k = j + 1; k < n; ++k) {
                long x = (long) nums[j] * nums[k];
                cnt.merge(x, 1, Integer::sum);
            }
        }
        return cnt;
    }

    private int cal(Map<Long, Integer> cnt, int[] nums) {
        int ans = 0;
        for (int x : nums) {
            long y = (long) x * x;
            ans += cnt.getOrDefault(y, 0);
        }
        return ans;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution {
public:
    int numTriplets(vector<int>& nums1, vector<int>& nums2) {
        auto cnt1 = count(nums1);
        auto cnt2 = count(nums2);
        return cal(cnt1, nums2) + cal(cnt2, nums1);
    }

    unordered_map<long long, int> count(vector<int>& nums) {
        unordered_map<long long, int> cnt;
        for (int i = 0; i < nums.size(); i++) {
            for (int j = i + 1; j < nums.size(); j++) {
                cnt[(long long) nums[i] * nums[j]]++;
            }
        }
        return cnt;
    }

    int cal(unordered_map<long long, int>& cnt, vector<int>& nums) {
        int ans = 0;
        for (int x : nums) {
            ans += cnt[(long long) x * x];
        }
        return ans;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
func numTriplets(nums1 []int, nums2 []int) int {
    cnt1 := count(nums1)
    cnt2 := count(nums2)
    return cal(cnt1, nums2) + cal(cnt2, nums1)
}

func count(nums []int) map[int]int {
    cnt := map[int]int{}
    for j, x := range nums {
        for _, y := range nums[j+1:] {
            cnt[x*y]++
        }
    }
    return cnt
}

func cal(cnt map[int]int, nums []int) (ans int) {
    for _, x := range nums {
        ans += cnt[x*x]
    }
    return
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
function numTriplets(nums1: number[], nums2: number[]): number {
    const cnt1 = count(nums1);
    const cnt2 = count(nums2);
    return cal(cnt1, nums2) + cal(cnt2, nums1);
}

function count(nums: number[]): Map<number, number> {
    const cnt: Map<number, number> = new Map();
    for (let j = 0; j < nums.length; ++j) {
        for (let k = j + 1; k < nums.length; ++k) {
            const x = nums[j] * nums[k];
            cnt.set(x, (cnt.get(x) || 0) + 1);
        }
    }
    return cnt;
}

function cal(cnt: Map<number, number>, nums: number[]): number {
    return nums.reduce((acc, x) => acc + (cnt.get(x * x) || 0), 0);
}

方法二:哈希表 + 枚举优化

我们用哈希表 $\textit{cnt1}$ 统计 $\textit{nums1}$ 中每个数出现的次数,用哈希表 $\textit{cnt2}$ 统计 $\textit{nums2}$ 中每个数出现的次数。

接下来,我们枚举数组 $\textit{nums1}$ 中的每个数 $x$,然后枚举 $\textit{cnt2}$ 中的每个数对 $(y, v1)$,其中 $y$ 为 $\textit{cnt2}$ 的键,$v1$ 为 $\textit{cnt2}$ 的值。我们计算 $z = x^2 / y$,如果 $y \times z = x^2$,此时如果 $y = z$,说明 $y$ 和 $z$ 是同一个数,那么 $v1 = v2$,从 $v1$ 个数中任选两个数的方案数为 $v1 \times (v1 - 1) = v1 \times (v2 - 1)$;如果 $y \neq z$,那么 $v1$ 个数中任选两个数的方案数为 $v1 \times v2$。最后将所有方案数相加并除以 $2$ 即可。这里除以 $2$ 是因为我们统计的是对数对 $(j, k)$ 的方案数,而实际上 $(j, k)$ 和 $(k, j)$ 是同一种方案。

时间复杂度 $O(m \times n)$,空间复杂度 $O(m + n)$。其中 $m$ 和 $n$ 分别为数组 $\textit{nums1}$ 和 $\textit{nums2}$ 的长度。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
class Solution:
    def numTriplets(self, nums1: List[int], nums2: List[int]) -> int:
        def cal(nums: List[int], cnt: Counter) -> int:
            ans = 0
            for x in nums:
                for y, v1 in cnt.items():
                    z = x * x // y
                    if y * z == x * x:
                        v2 = cnt[z]
                        ans += v1 * (v2 - int(y == z))
            return ans // 2

        cnt1 = Counter(nums1)
        cnt2 = Counter(nums2)
        return cal(nums1, cnt2) + cal(nums2, cnt1)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Solution {
    public int numTriplets(int[] nums1, int[] nums2) {
        var cnt1 = count(nums1);
        var cnt2 = count(nums2);
        return cal(cnt1, nums2) + cal(cnt2, nums1);
    }

    private Map<Integer, Integer> count(int[] nums) {
        Map<Integer, Integer> cnt = new HashMap<>();
        for (int x : nums) {
            cnt.merge(x, 1, Integer::sum);
        }
        return cnt;
    }

    private int cal(Map<Integer, Integer> cnt, int[] nums) {
        long ans = 0;
        for (int x : nums) {
            for (var e : cnt.entrySet()) {
                int y = e.getKey(), v1 = e.getValue();
                int z = (int) (1L * x * x / y);
                if (y * z == x * x) {
                    int v2 = cnt.getOrDefault(z, 0);
                    ans += v1 * (y == z ? v2 - 1 : v2);
                }
            }
        }
        return (int) (ans / 2);
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
func numTriplets(nums1 []int, nums2 []int) int {
    cnt1 := count(nums1)
    cnt2 := count(nums2)
    return cal(cnt1, nums2) + cal(cnt2, nums1)
}

func count(nums []int) map[int]int {
    cnt := map[int]int{}
    for _, x := range nums {
        cnt[x]++
    }
    return cnt
}

func cal(cnt map[int]int, nums []int) (ans int) {
    for _, x := range nums {
        for y, v1 := range cnt {
            z := x * x / y
            if y*z == x*x {
                if v2, ok := cnt[z]; ok {
                    if y == z {
                        v2--
                    }
                    ans += v1 * v2
                }
            }
        }
    }
    ans /= 2
    return
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
function numTriplets(nums1: number[], nums2: number[]): number {
    const cnt1 = count(nums1);
    const cnt2 = count(nums2);
    return cal(cnt1, nums2) + cal(cnt2, nums1);
}

function count(nums: number[]): Map<number, number> {
    const cnt: Map<number, number> = new Map();
    for (const x of nums) {
        cnt.set(x, (cnt.get(x) || 0) + 1);
    }
    return cnt;
}

function cal(cnt: Map<number, number>, nums: number[]): number {
    let ans: number = 0;
    for (const x of nums) {
        for (const [y, v1] of cnt) {
            const z = Math.floor((x * x) / y);
            if (y * z == x * x) {
                const v2 = cnt.get(z) || 0;
                ans += v1 * (y === z ? v2 - 1 : v2);
            }
        }
    }
    return ans / 2;
}

评论