2387. Median of a Row Wise Sorted Matrix 🔒

Description

Given an m x n matrix grid containing an odd number of integers where each row is sorted in non-decreasing order, return the median of the matrix.

You must solve the problem in less than O(m * n) time complexity.

 

Example 1:

Input: grid = [[1,1,2],[2,3,3],[1,3,4]]
Output: 2
Explanation: The elements of the matrix in sorted order are 1,1,1,2,2,3,3,3,4. The median is 2.

Example 2:

Input: grid = [[1,1,3,3,4]]
Output: 3
Explanation: The elements of the matrix in sorted order are 1,1,3,3,4. The median is 3.

 

Constraints:

  • m == grid.length
  • n == grid[i].length
  • 1 <= m, n <= 500
  • m and n are both odd.
  • 1 <= grid[i][j] <= 10^6
  • grid[i] is sorted in non-decreasing order.

Solutions

Solution 1: Two Binary Searches

The median is the $target$-th smallest element of the matrix, where $target = \left \lceil \frac{m \times n}{2} \right \rceil$.
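As a quick sanity check of this rank, a brute-force sketch can flatten and sort the grid, then pick the $target$-th element. This is not the intended sub-$O(m \times n)$ approach, and the helper name `brute_force_median` is introduced here purely for illustration.

```python
from typing import List


def brute_force_median(grid: List[List[int]]) -> int:
    """Reference answer: sort every element and take the target-th (1-indexed)."""
    m, n = len(grid), len(grid[0])
    target = (m * n + 1) // 2  # ceil(m * n / 2)
    flat = sorted(x for row in grid for x in row)
    return flat[target - 1]  # convert the 1-indexed rank to a 0-indexed position


print(brute_force_median([[1, 1, 2], [2, 3, 3], [1, 3, 4]]))  # 2
print(brute_force_median([[1, 1, 3, 3, 4]]))  # 3
```

Both calls reproduce the outputs of Example 1 and Example 2.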

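We binary search over the range of possible values for a candidate $x$, counting the number of elements in the grid that are less than or equal to $x$, denoted as $cnt$. Because each row is sorted, the count for a single row is itself found with a binary search. If $cnt \ge target$, the median is at most $x$, so we continue in the left half (including $x$); otherwise, the median is greater than $x$, so we continue in the right half. The smallest $x$ with $cnt \ge target$ is the median.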
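To see why a binary search over $x$ is valid, the sketch below tabulates, for the grid from Example 1, how many elements are at most $x$ for each candidate $x$. The helper name `count_le` is an assumption made for this illustration. The linear scan at the end is only for readability; the solutions replace it with a binary search.

```python
from bisect import bisect_right
from typing import List


def count_le(grid: List[List[int]], x: int) -> int:
    """Number of elements <= x, using one binary search per sorted row."""
    return sum(bisect_right(row, x) for row in grid)


grid = [[1, 1, 2], [2, 3, 3], [1, 3, 4]]
target = (len(grid) * len(grid[0]) + 1) // 2  # 5

# The counts never decrease as x grows, which is what makes binary search valid.
counts = {x: count_le(grid, x) for x in range(0, 5)}
print(counts)  # {0: 0, 1: 3, 2: 5, 3: 8, 4: 9}

# The median is the smallest x whose count reaches the target.
median = min(x for x, c in counts.items() if c >= target)
print(median)  # 2
```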

The time complexity is $O(m \times \log n \times \log M)$, where $m$ and $n$ are the number of rows and columns of the grid, respectively, and $M$ is the upper bound of the value range being searched ($10^6$ under the constraints). The space complexity is $O(1)$.
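As a rough, back-of-the-envelope illustration of this bound at the largest sizes the constraints allow ($m = n = 499$, $M = 10^6$), ignoring constant factors:

$$
m \times \lceil \log_2 n \rceil \times \lceil \log_2 M \rceil = 499 \times 9 \times 20 = 89{,}820
$$

compared with $m \times n = 499 \times 499 = 249{,}001$ elements in the grid. These are step counts under the stated assumptions, not measured running times.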

Python:

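from bisect import bisect_left, bisect_right
from typing import List


class Solution:
    def matrixMedian(self, grid: List[List[int]]) -> int:
        # Number of elements in the grid that are <= x.
        def count(x):
            return sum(bisect_right(row, x) for row in grid)

        m, n = len(grid), len(grid[0])
        target = (m * n + 1) >> 1
        # Smallest x with count(x) >= target; the `key` argument needs Python 3.10+.
        return bisect_left(range(10**6 + 1), target, key=count)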

Java:

class Solution {
    private int[][] grid;

    public int matrixMedian(int[][] grid) {
        this.grid = grid;
        int m = grid.length, n = grid[0].length;
        int target = (m * n + 1) >> 1;
        int left = 0, right = 1000010;
        while (left < right) {
            int mid = (left + right) >> 1;
            if (count(mid) >= target) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }

    private int count(int x) {
        int cnt = 0;
        for (var row : grid) {
            int left = 0, right = row.length;
            while (left < right) {
                int mid = (left + right) >> 1;
                if (row[mid] > x) {
                    right = mid;
                } else {
                    left = mid + 1;
                }
            }
            cnt += left;
        }
        return cnt;
    }
}

C++:

#include <algorithm>
#include <vector>

using namespace std;

class Solution {
public:
    int matrixMedian(vector<vector<int>>& grid) {
        int m = grid.size(), n = grid[0].size();
        int left = 0, right = 1e6 + 1;
        int target = (m * n + 1) >> 1;
        auto count = [&](int x) {
            int cnt = 0;
            for (auto& row : grid) {
                cnt += (upper_bound(row.begin(), row.end(), x) - row.begin());
            }
            return cnt;
        };
        while (left < right) {
            int mid = (left + right) >> 1;
            if (count(mid) >= target) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }
};

Go:

func matrixMedian(grid [][]int) int {
    m, n := len(grid), len(grid[0])

    count := func(x int) int {
        cnt := 0
        for _, row := range grid {
            left, right := 0, n
            for left < right {
                mid := (left + right) >> 1
                if row[mid] > x {
                    right = mid
                } else {
                    left = mid + 1
                }
            }
            cnt += left
        }
        return cnt
    }
    left, right := 0, 1000010
    target := (m*n + 1) >> 1
    for left < right {
        mid := (left + right) >> 1
        if count(mid) >= target {
            right = mid
        } else {
            left = mid + 1
        }
    }
    return left
}
