Welcome to Subscribe On Youtube
Question
Formatted question description: https://leetcode.ca/all/378.html
Given an n x n
matrix
where each of the rows and columns is sorted in ascending order, return the kth
smallest element in the matrix.
Note that it is the kth
smallest element in the sorted order, not the kth
distinct element.
You must find a solution with a memory complexity better than O(n^2)
.
Example 1:
Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8 Output: 13 Explanation: The elements in the matrix are [1,5,9,10,11,12,13,13,15], and the 8th smallest number is 13
Example 2:
Input: matrix = [[-5]], k = 1 Output: -5
Constraints:
n == matrix.length == matrix[i].length
1 <= n <= 300
-10^9 <= matrix[i][j] <= 10^9
- All the rows and columns of
matrix
are guaranteed to be sorted in non-decreasing order. 1 <= k <= n^2
Follow up:
- Could you solve the problem with a constant memory (i.e.,
O(1)
memory complexity)? - Could you solve the problem in
O(n)
time complexity? The solution may be too advanced for an interview but you may find reading this paper fun.
Algorithm
PriorityQueue
- similar to
Merge K sorted Array/ List
, use a PriorityQueue. - Because an array element has no direct link to the next element in its row, a Node class is used to store the value together with its x, y position.
- Initializing the heap takes
O(n)
time; finding the kth element then takes k polls, O(k)
, each heap operation costing O(log n)
=> O(n + k log n)
- variation: Kth Largest in N Arrays
Binary Search
- We know where the search boundaries are: start/end are the min/max values of the matrix.
- locate the kth smallest item
(x, y)
by cutting off partitions in a binary fashion: - find the mid-value, and count the # of items <= mid-value based on the ascending matrix
O(n log(max - min)), where max/min are the largest/smallest matrix values
Code
-
import java.util.Comparator;
import java.util.PriorityQueue;

/**
 * Solutions for "Kth Smallest Element in a Sorted Matrix": given a matrix whose rows and
 * columns are each sorted in non-decreasing order, return the kth smallest element
 * (counting duplicates).
 */
public class Kth_Smallest_Element_in_a_Sorted_Matrix {

    /** Heap entry: a matrix value together with its row ({@code x}) and column ({@code y}). */
    class Node {
        int x;
        int y;
        int val;

        public Node(int x, int y, int val) {
            this.x = x;
            this.y = y;
            this.val = val;
        }
    }

    /** Heap-based solution: a k-way merge of the sorted rows. */
    class Solution {

        /**
         * Returns the kth smallest value of {@code matrix}.
         *
         * @param matrix matrix with rows and columns sorted in non-decreasing order
         * @param k      1-based rank of the wanted element
         * @return the kth smallest value, or 0 when the matrix is null/empty, {@code k <= 0},
         *         or k exceeds the number of elements
         */
        public int kthSmallest(int[][] matrix, int k) {
            if (matrix == null || matrix.length == 0 || k <= 0) {
                return 0;
            }

            PriorityQueue<Node> pq = new PriorityQueue<>(new Comparator<Node>() {
                @Override
                public int compare(Node o1, Node o2) {
                    // Integer.compare instead of subtraction: a difference of two ints can overflow.
                    return Integer.compare(o1.val, o2.val);
                }
            });

            // Seed the heap with the first column, i.e. the smallest element of every row.
            for (int i = 0; i < matrix.length; i++) {
                pq.offer(new Node(i, 0, matrix[i][0]));
            }

            while (!pq.isEmpty()) {
                Node n = pq.poll();
                if (k == 1) {
                    return n.val;
                }
                // Push the next element of the row the polled node came from.
                if (n.y + 1 < matrix[n.x].length) {
                    pq.offer(new Node(n.x, n.y + 1, matrix[n.x][n.y + 1]));
                }
                k--;
            }

            return 0;
        }
    }

    /** Binary search over the value range; uses no extra memory. */
    class Solution_optimize {

        /**
         * Returns the kth smallest value of {@code matrix}.
         *
         * @param matrix square matrix with rows and columns sorted in non-decreasing order
         * @param k      1-based rank of the wanted element
         * @return the kth smallest value, or 0 when the matrix is null/empty or {@code k <= 0}
         */
        public int kthSmallest(int[][] matrix, int k) {
            if (matrix == null || matrix.length == 0 || k <= 0) {
                return 0;
            }

            int n = matrix.length;
            int min = matrix[0][0];
            int max = matrix[n - 1][n - 1];

            while (min < max) {
                // Floor of the average, computed in long so that neither the sum nor the
                // difference of the bounds can overflow.
                int mid = (int) (((long) min + max) >> 1);
                int smallerCount = countSmallerThanMid(matrix, mid);
                if (smallerCount < k) {
                    min = mid + 1;
                } else {
                    max = mid;
                }
            }

            return min; // min == max here, so returning max would work as well
        }

        /** Counts the elements that are less than or equal to {@code mid}. */
        private int countSmallerThanMid(int[][] matrix, int mid) {
            int count = 0;

            // Start from the top-right corner: everything to its left is not larger,
            // everything below it is not smaller.
            int i = 0;
            int j = matrix[0].length - 1;

            while (i < matrix.length && j >= 0) {
                if (matrix[i][j] > mid) {
                    j--;
                } else {
                    count += j + 1; // columns 0..j of row i are all <= mid
                    i++;
                }
            }

            return count;
        }
    }
}

//////

/** Binary search over the value range, walking the matrix from its bottom-left corner. */
class Solution {

    /**
     * Returns the kth smallest value of {@code matrix}.
     *
     * @param matrix non-empty square matrix with rows and columns sorted in non-decreasing order
     * @param k      1-based rank of the wanted element
     * @return the kth smallest value
     */
    public int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        int left = matrix[0][0], right = matrix[n - 1][n - 1];
        while (left < right) {
            // An unsigned shift (>>>) is only a valid midpoint for non-negative bounds; with a
            // negative sum it yields a huge positive value and the search never terminates.
            // Take the floor of the average in long arithmetic instead.
            int mid = (int) (((long) left + right) >> 1);
            if (check(matrix, mid, k, n)) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }

    /** Returns whether at least {@code k} elements are less than or equal to {@code mid}. */
    private boolean check(int[][] matrix, int mid, int k, int n) {
        int count = 0;
        int i = n - 1, j = 0;
        while (i >= 0 && j < n) {
            if (matrix[i][j] <= mid) {
                count += (i + 1); // rows 0..i of column j are all <= mid
                ++j;
            } else {
                --i;
            }
        }
        return count >= k;
    }
}
-
// OJ: https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix/
// Time: O(K log(min(N, K)))
// Space: O(min(N, K)) -- the heap holds at most one entry per started row; no N x N bookkeeping
class Solution {
    typedef tuple<int, int, int> Item; // (value, row, column)
public:
    // Returns the k-th smallest value (1-based, duplicates counted) of the N x N matrix A,
    // whose rows and columns are sorted in non-decreasing order.
    // Expects a non-empty matrix and 1 <= k <= N * N.
    int kthSmallest(vector<vector<int>>& A, int k) {
        int N = A.size();
        priority_queue<Item, vector<Item>, greater<>> pq; // min-heap ordered by value
        pq.emplace(A[0][0], 0, 0);
        while (--k) {
            Item cur = pq.top();
            pq.pop();
            int x = get<1>(cur), y = get<2>(cur);
            // Every cell has exactly one predecessor, so nothing is pushed twice and no
            // "seen" matrix is needed: (x, y) with y > 0 is pushed by (x, y - 1), and
            // (x, 0) with x > 0 is pushed by (x - 1, 0). Either predecessor is <= the cell
            // it pushes, so the heap top is always the smallest value not yet popped.
            if (y + 1 < N) pq.emplace(A[x][y + 1], x, y + 1);
            if (y == 0 && x + 1 < N) pq.emplace(A[x + 1][0], x + 1, 0);
        }
        return get<0>(pq.top());
    }
};
-
'''
Binary search on the value range.

countSmallerThanMid(mid) counts the elements <= mid in O(n): it walks a staircase
from the top-right corner and visits at most n rows and n columns.

The binary search `while left < right:` runs O(log(maxVal - minVal)) iterations, where
maxVal - minVal is the range of values in the matrix.

Overall time complexity is O(n * log(maxVal - minVal)); no extra space is used.
'''
class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        def countSmallerThanMid(mid):
            """Return how many matrix elements are less than or equal to mid."""
            count = 0
            # start from the top-right corner: everything to its left is not larger,
            # everything below it is not smaller
            i, j = 0, n - 1
            while i < n and j >= 0:
                if matrix[i][j] > mid:
                    j -= 1
                else:
                    count += j + 1  # j is an index, so row i contributes j + 1 elements
                    i += 1
            return count

        n = len(matrix)
        left, right = matrix[0][0], matrix[n - 1][n - 1]
        while left < right:
            mid = (left + right) >> 1
            if countSmallerThanMid(mid) >= k:
                # the kth smallest element is <= mid, so pull in the right boundary
                right = mid
            else:
                # the kth smallest element is > mid, so pull in the left boundary
                left = mid + 1

        # the loop ends with left == right, which is the kth smallest element
        return left

############

'''
Heap-based solution.

Each popped cell pushes its right neighbour; only cells of the first column also push the
cell below them. Every cell therefore has exactly one predecessor, so no cell is pushed
twice and no visited set is required.

Time complexity is O(k * log(min(n, k))): k pops, each on a heap holding at most one entry
per started row. Space complexity is O(min(n, k)) for the heap.
'''
import heapq


class Solution(object):
    def kthSmallest(self, matrix, k):
        """
        :type matrix: List[List[int]]
        :type k: int
        :rtype: int

        Returns None when k exceeds the number of elements in the matrix.
        """
        rows, cols = len(matrix), len(matrix[0])
        heap = [(matrix[0][0], 0, 0)]
        while heap:
            val, i, j = heapq.heappop(heap)
            k -= 1
            if k == 0:
                return val
            if j + 1 < cols:
                heapq.heappush(heap, (matrix[i][j + 1], i, j + 1))
            # Only the head of a row introduces the next row; its value is <= that
            # row's head, so the heap minimum stays the global minimum.
            if j == 0 and i + 1 < rows:
                heapq.heappush(heap, (matrix[i + 1][0], i + 1, 0))
-
// kthSmallest returns the k-th smallest value (1-based, duplicates counted) of an
// n x n matrix whose rows and columns are sorted in non-decreasing order.
// It binary-searches the value range between the top-left and bottom-right cells.
func kthSmallest(matrix [][]int, k int) int {
	size := len(matrix)
	lo := matrix[0][0]
	hi := matrix[size-1][size-1]
	for lo < hi {
		pivot := (lo + hi) >> 1
		if !check(matrix, pivot, k, size) {
			// Fewer than k values are <= pivot: the answer lies above it.
			lo = pivot + 1
			continue
		}
		hi = pivot
	}
	return lo
}

// check reports whether at least k values of the n x n matrix are <= mid.
// It walks a staircase starting from the bottom-left corner.
func check(matrix [][]int, mid, k, n int) bool {
	total := 0
	row, col := n-1, 0
	for row >= 0 && col < n {
		if matrix[row][col] > mid {
			row--
			continue
		}
		// Rows 0..row of this column are all <= mid.
		total += row + 1
		col++
	}
	return total >= k
}