2387. Median of a Row Wise Sorted Matrix 🔒

Description

Given an m x n matrix grid containing an odd number of integers where each row is sorted in non-decreasing order, return the median of the matrix.

You must solve the problem in less than O(m * n) time complexity.

 

Example 1:

Input: grid = [[1,1,2],[2,3,3],[1,3,4]]
Output: 2
Explanation: The elements of the matrix in sorted order are 1,1,1,2,2,3,3,3,4. The median is 2.

Example 2:

Input: grid = [[1,1,3,3,4]]
Output: 3
Explanation: The elements of the matrix in sorted order are 1,1,3,3,4. The median is 3.

 

Constraints:

  • m == grid.length
  • n == grid[i].length
  • 1 <= m, n <= 500
  • m and n are both odd.
  • 1 <= grid[i][j] <= 10^6
  • grid[i] is sorted in non-decreasing order.

Solutions

Solution 1: Two Binary Searches

Since \(m \times n\) is odd, the median is the \(target\)-th smallest element in sorted order (1-based), where \(target = \left\lceil \frac{m \times n}{2} \right\rceil = \frac{m \times n + 1}{2}\).
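As a sanity check on the target index, here is a brute-force reference. `median_bruteforce` is a hypothetical helper, not part of the original solution: it flattens and sorts the grid in \(O(mn \log(mn))\) time, which violates the problem's complexity requirement, so it is only useful for verifying answers.

```python
from typing import List


def median_bruteforce(grid: List[List[int]]) -> int:
    """Reference only: flatten, sort, and take the target-th smallest value.

    Runs in O(m*n*log(m*n)) time, so it does NOT meet the problem's
    complexity requirement; use it to check answers, not to solve.
    """
    values = sorted(v for row in grid for v in row)
    target = (len(values) + 1) // 2  # ceil(m*n / 2), 1-based position
    return values[target - 1]        # convert to 0-based index


# Example 1: sorted order is 1,1,1,2,2,3,3,3,4 and target = 5
print(median_bruteforce([[1, 1, 2], [2, 3, 3], [1, 3, 4]]))  # 2
# Example 2: sorted order is 1,1,3,3,4 and target = 3
print(median_bruteforce([[1, 1, 3, 3, 4]]))  # 3
```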

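We binary search on the value \(x\), not on positions in the grid. For a candidate \(x\), let \(cnt\) be the number of elements in the grid that are less than or equal to \(x\); because every row is sorted, each row's contribution is the result of a single upper-bound binary search on that row. Since \(cnt\) never decreases as \(x\) grows, if \(cnt \ge target\) the median is at most \(x\) (continue on the left half, keeping \(x\)); otherwise the median is greater than \(x\) (continue on the right half). The search ends at the smallest \(x\) with \(cnt \ge target\), which is the median. This value always occurs in the grid: \(cnt\) at \(x - 1\) is below \(target\) while \(cnt\) at \(x\) is not, so at least one element equals \(x\).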
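The counting step can be traced on Example 1. In this sketch, `count_not_greater` is a hypothetical name for the inner counting function; it uses `bisect.bisect_right`, which returns the index of the first element greater than \(x\), i.e. how many elements of that row are \(\le x\). Printing \(cnt\) for a few candidate values shows the monotone sequence the outer binary search relies on.

```python
from bisect import bisect_right
from typing import List


def count_not_greater(grid: List[List[int]], x: int) -> int:
    """Number of grid elements <= x, using one binary search per sorted row."""
    # bisect_right(row, x) == number of elements in `row` that are <= x
    return sum(bisect_right(row, x) for row in grid)


grid = [[1, 1, 2], [2, 3, 3], [1, 3, 4]]
target = 5  # (3 * 3 + 1) // 2
for x in range(5):
    cnt = count_not_greater(grid, x)
    print(x, cnt, cnt >= target)
# 0 0 False
# 1 3 False
# 2 5 True   <- smallest x with cnt >= target, so the median is 2
# 3 8 True
# 4 9 True
```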

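The time complexity is \(O(m \times \log n \times \log M)\), where \(m\) and \(n\) are the number of rows and columns of the grid, respectively, and \(M\) is the upper end of the searched value range (about \(10^6\) under the constraints). With \(m, n \le 500\) this is roughly \(500 \times 9 \times 20 \approx 9 \times 10^4\) comparisons, well under the \(m \times n = 2.5 \times 10^5\) steps needed just to read every element. The space complexity is \(O(1)\).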
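The implementations below search the fixed range from 0 to about \(10^6\) taken from the constraints. A possible variant (an assumption of this sketch, not the source's method) narrows the range to the grid's actual minimum and maximum, which sit in the first and last columns because rows are sorted, so they cost only \(O(m)\) to find. Writing the outer loop by hand also avoids the `key` parameter of `bisect`, which only exists from Python 3.10 onward. `matrix_median` is a hypothetical name.

```python
from bisect import bisect_right
from typing import List


def matrix_median(grid: List[List[int]]) -> int:
    m, n = len(grid), len(grid[0])
    target = (m * n + 1) // 2
    # Rows are sorted, so the global min/max are in the first/last columns.
    lo = min(row[0] for row in grid)
    hi = max(row[-1] for row in grid)
    # Invariant: the median lies in [lo, hi] (count(hi) == m*n >= target).
    while lo < hi:
        mid = (lo + hi) // 2
        cnt = sum(bisect_right(row, mid) for row in grid)  # elements <= mid
        if cnt >= target:
            hi = mid       # median <= mid
        else:
            lo = mid + 1   # median > mid
    return lo
```

The complexity becomes \(O(m \times \log n \times \log(\max - \min))\), which is never worse than the fixed-range version.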

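from bisect import bisect_left, bisect_right
from typing import List


class Solution:
    def matrixMedian(self, grid: List[List[int]]) -> int:
        def count(x):
            # Number of elements <= x: one binary search per sorted row.
            return sum(bisect_right(row, x) for row in grid)

        m, n = len(grid), len(grid[0])
        target = (m * n + 1) >> 1
        # Smallest x in [0, 10**6] with count(x) >= target (key= needs Python 3.10+).
        return bisect_left(range(10**6 + 1), target, key=count)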
class Solution {
    private int[][] grid;

    public int matrixMedian(int[][] grid) {
        this.grid = grid;
        int m = grid.length, n = grid[0].length;
        int target = (m * n + 1) >> 1;
        int left = 0, right = 1000010;
        while (left < right) {
            int mid = (left + right) >> 1;
            if (count(mid) >= target) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }

    private int count(int x) {
        int cnt = 0;
        for (var row : grid) {
            int left = 0, right = row.length;
            while (left < right) {
                int mid = (left + right) >> 1;
                if (row[mid] > x) {
                    right = mid;
                } else {
                    left = mid + 1;
                }
            }
            cnt += left;
        }
        return cnt;
    }
}
#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    int matrixMedian(vector<vector<int>>& grid) {
        int m = grid.size(), n = grid[0].size();
        int left = 0, right = 1e6 + 1;
        int target = (m * n + 1) >> 1;
        auto count = [&](int x) {
            int cnt = 0;
            for (auto& row : grid) {
                cnt += (upper_bound(row.begin(), row.end(), x) - row.begin());
            }
            return cnt;
        };
        while (left < right) {
            int mid = (left + right) >> 1;
            if (count(mid) >= target) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }
};
func matrixMedian(grid [][]int) int {
    m, n := len(grid), len(grid[0])

    count := func(x int) int {
        cnt := 0
        for _, row := range grid {
            left, right := 0, n
            for left < right {
                mid := (left + right) >> 1
                if row[mid] > x {
                    right = mid
                } else {
                    left = mid + 1
                }
            }
            cnt += left
        }
        return cnt
    }
    left, right := 0, 1000010
    target := (m*n + 1) >> 1
    for left < right {
        mid := (left + right) >> 1
        if count(mid) >= target {
            right = mid
        } else {
            left = mid + 1
        }
    }
    return left
}
