Skip to content

2040. Kth Smallest Product of Two Sorted Arrays

Description

Given two sorted 0-indexed integer arrays nums1 and nums2 as well as an integer k, return the kth (1-based) smallest product of nums1[i] * nums2[j] where 0 <= i < nums1.length and 0 <= j < nums2.length.

 

Example 1:

Input: nums1 = [2,5], nums2 = [3,4], k = 2
Output: 8
Explanation: The 2 smallest products are:
- nums1[0] * nums2[0] = 2 * 3 = 6
- nums1[0] * nums2[1] = 2 * 4 = 8
The 2nd smallest product is 8.

Example 2:

Input: nums1 = [-4,-2,0,3], nums2 = [2,4], k = 6
Output: 0
Explanation: The 6 smallest products are:
- nums1[0] * nums2[1] = (-4) * 4 = -16
- nums1[0] * nums2[0] = (-4) * 2 = -8
- nums1[1] * nums2[1] = (-2) * 4 = -8
- nums1[1] * nums2[0] = (-2) * 2 = -4
- nums1[2] * nums2[0] = 0 * 2 = 0
- nums1[2] * nums2[1] = 0 * 4 = 0
The 6th smallest product is 0.

Example 3:

Input: nums1 = [-2,-1,0,1,2], nums2 = [-3,-1,2,4,5], k = 3
Output: -6
Explanation: The 3 smallest products are:
- nums1[0] * nums2[4] = (-2) * 5 = -10
- nums1[0] * nums2[3] = (-2) * 4 = -8
- nums1[4] * nums2[0] = 2 * (-3) = -6
The 3rd smallest product is -6.

 

Constraints:

  • \(1 \leq \textit{nums1.length}, \textit{nums2.length} \leq 5 \times 10^4\)
  • \(-10^5 \leq \textit{nums1}[i], \textit{nums2}[j] \leq 10^5\)
  • 1 <= k <= nums1.length * nums2.length
  • nums1 and nums2 are sorted.

Solutions

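We binary-search over the value of the product \(p\). Let \(m\) and \(n\) be the lengths of \(\textit{nums1}\) and \(\textit{nums2}\), respectively. Because both arrays are sorted, the largest absolute value in each lies at one of its two ends, so the search interval is \([l, r]\) with \(l = -\max(|\textit{nums1}[0]|, |\textit{nums1}[m - 1]|) \times \max(|\textit{nums2}[0]|, |\textit{nums2}[n - 1]|)\) and \(r = -l\).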

For each \(p\), we count the products less than or equal to \(p\). This count never decreases as \(p\) grows, which is what makes binary search valid. If the count is greater than or equal to \(k\), the \(k\)-th smallest product must be less than or equal to \(p\), so we move the right endpoint of the interval to \(p\). Otherwise, we move the left endpoint to \(p + 1\).

The key to the problem is counting the products less than or equal to \(p\). For each number \(x\) in \(\textit{nums1}\), we count the matching elements of \(\textit{nums2}\) case by case:

  • If \(x > 0\), then \(x \times \textit{nums2}[i]\) is monotonically increasing as \(i\) increases. We can use binary search to find the smallest \(i\) such that \(x \times \textit{nums2}[i] > p\). Then, \(i\) is the number of products less than or equal to \(p\), which is accumulated into the count \(\textit{cnt}\);
  • If \(x < 0\), then \(x \times \textit{nums2}[i]\) is monotonically decreasing as \(i\) increases. We can use binary search to find the smallest \(i\) such that \(x \times \textit{nums2}[i] \leq p\). Then, \(n - i\) is the number of products less than or equal to \(p\), which is accumulated into the count \(\textit{cnt}\);
  • If \(x = 0\), then \(x \times \textit{nums2}[i] = 0\). If \(p \geq 0\), then \(n\) is the number of products less than or equal to \(p\), which is accumulated into the count \(\textit{cnt}\).

This way, we can find the \(k\)-th smallest product through binary search.

The time complexity is \(O(m \times \log n \times \log M)\), where \(m\) and \(n\) are the lengths of \(\textit{nums1}\) and \(\textit{nums2}\), respectively, and \(M\) is the maximum absolute value in \(\textit{nums1}\) and \(\textit{nums2}\).

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
class Solution:
    def kthSmallestProduct(self, nums1: List[int], nums2: List[int], k: int) -> int:
        def count(p: int) -> int:
            cnt = 0
            n = len(nums2)
            for x in nums1:
                if x > 0:
                    cnt += bisect_right(nums2, p / x)
                elif x < 0:
                    cnt += n - bisect_left(nums2, p / x)
                else:
                    cnt += n * int(p >= 0)
            return cnt

        mx = max(abs(nums1[0]), abs(nums1[-1])) * max(abs(nums2[0]), abs(nums2[-1]))
        return bisect_left(range(-mx, mx + 1), k, key=count) - mx
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class Solution {
    private int[] nums1;
    private int[] nums2;

    public long kthSmallestProduct(int[] nums1, int[] nums2, long k) {
        this.nums1 = nums1;
        this.nums2 = nums2;
        int m = nums1.length;
        int n = nums2.length;
        int a = Math.max(Math.abs(nums1[0]), Math.abs(nums1[m - 1]));
        int b = Math.max(Math.abs(nums2[0]), Math.abs(nums2[n - 1]));
        long r = (long) a * b;
        long l = (long) -a * b;
        while (l < r) {
            long mid = (l + r) >> 1;
            if (count(mid) >= k) {
                r = mid;
            } else {
                l = mid + 1;
            }
        }
        return l;
    }

    private long count(long p) {
        long cnt = 0;
        int n = nums2.length;
        for (int x : nums1) {
            if (x > 0) {
                int l = 0, r = n;
                while (l < r) {
                    int mid = (l + r) >> 1;
                    if ((long) x * nums2[mid] > p) {
                        r = mid;
                    } else {
                        l = mid + 1;
                    }
                }
                cnt += l;
            } else if (x < 0) {
                int l = 0, r = n;
                while (l < r) {
                    int mid = (l + r) >> 1;
                    if ((long) x * nums2[mid] <= p) {
                        r = mid;
                    } else {
                        l = mid + 1;
                    }
                }
                cnt += n - l;
            } else if (p >= 0) {
                cnt += n;
            }
        }
        return cnt;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class Solution {
public:
    // Returns the k-th (1-based) smallest product nums1[i] * nums2[j].
    // Both vectors must be sorted ascending. The answer is binary-searched over
    // [-A*B, A*B], where A and B are the largest absolute values (found at the
    // vector ends). The predicate counts products that are <= the candidate.
    long long kthSmallestProduct(vector<int>& nums1, vector<int>& nums2, long long k) {
        int widest1 = max(abs(nums1.front()), abs(nums1.back()));
        int widest2 = max(abs(nums2.front()), abs(nums2.back()));
        long long hi = 1LL * widest1 * widest2;
        long long lo = -hi;
        while (lo < hi) {
            // >> 1 floors toward negative infinity, so lo = guess + 1 always advances.
            long long guess = (lo + hi) >> 1;
            if (countAtMost(nums1, nums2, guess) < k) {
                lo = guess + 1;
            } else {
                hi = guess;
            }
        }
        return lo;
    }

private:
    // Number of index pairs (i, j) with nums1[i] * nums2[j] <= p.
    static long long countAtMost(const vector<int>& nums1, const vector<int>& nums2, long long p) {
        const int n = static_cast<int>(nums2.size());
        long long total = 0;
        for (int x : nums1) {
            if (x == 0) {
                if (p >= 0) {
                    total += n;
                }
                continue;
            }
            // For x > 0 products rise with j: find the first product > p.
            // For x < 0 products fall with j: find the first product <= p.
            int lo = 0, hi = n;
            while (lo < hi) {
                int mid = lo + (hi - lo) / 2;
                long long product = 1LL * x * nums2[mid];
                bool crossed = x > 0 ? product > p : product <= p;
                if (crossed) {
                    hi = mid;
                } else {
                    lo = mid + 1;
                }
            }
            total += x > 0 ? lo : n - lo;
        }
        return total;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
// kthSmallestProduct returns the k-th (1-based) smallest product nums1[i]*nums2[j]
// over all index pairs. Both slices must be sorted ascending.
//
// The answer is binary-searched over [-bound, bound], where bound is the largest
// possible absolute product (the largest absolute values sit at the slice ends).
// The predicate counts products that are <= the candidate.
func kthSmallestProduct(nums1 []int, nums2 []int, k int64) int64 {
	n := len(nums2)
	bound := int64(max(abs(nums1[0]), abs(nums1[len(nums1)-1]))) *
		int64(max(abs(nums2[0]), abs(nums2[n-1])))

	// countAtMost reports how many pairs (i, j) satisfy nums1[i]*nums2[j] <= p.
	countAtMost := func(p int64) int64 {
		var total int64
		for _, x := range nums1 {
			if x == 0 {
				if p >= 0 {
					total += int64(n)
				}
				continue
			}
			// x > 0: products rise with j, so find the first product > p.
			// x < 0: products fall with j, so find the first product <= p.
			lo, hi := 0, n
			for lo < hi {
				mid := lo + (hi-lo)/2
				product := int64(x) * int64(nums2[mid])
				var crossed bool
				if x > 0 {
					crossed = product > p
				} else {
					crossed = product <= p
				}
				if crossed {
					hi = mid
				} else {
					lo = mid + 1
				}
			}
			if x > 0 {
				total += int64(lo)
			} else {
				total += int64(n - lo)
			}
		}
		return total
	}

	lo, hi := -bound, bound
	for lo < hi {
		// >> 1 floors toward negative infinity, so lo = mid + 1 always advances.
		mid := (lo + hi) >> 1
		if countAtMost(mid) < k {
			lo = mid + 1
		} else {
			hi = mid
		}
	}
	return lo
}

// abs returns the absolute value of x. Go signed arithmetic wraps, so the
// negation of the most negative int yields that same value unchanged.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}

Comments