Welcome to Subscribe On Youtube

Question

Formatted question description: https://leetcode.ca/all/378.html

Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

You must find a solution with a memory complexity better than O(n^2).

 

Example 1:

Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13
Explanation: The elements in the matrix are [1,5,9,10,11,12,13,13,15], and the 8th smallest number is 13

Example 2:

Input: matrix = [[-5]], k = 1
Output: -5

 

Constraints:

  • n == matrix.length == matrix[i].length
  • 1 <= n <= 300
  • -10^9 <= matrix[i][j] <= 10^9
  • All the rows and columns of matrix are guaranteed to be sorted in non-decreasing order.
  • 1 <= k <= n^2

 

Follow up:

  • Could you solve the problem with a constant memory (i.e., O(1) memory complexity)?
  • Could you solve the problem in O(n) time complexity? The solution may be too advanced for an interview but you may find reading this paper fun.

Algorithm

PriorityQueue

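  • Similar to merging K sorted arrays/lists: use a PriorityQueue (min-heap).
  • A value popped from the heap does not tell us where its successor is, so a Node class stores the value together with its (x, y) position.
  • Seeding the heap with the first column costs O(n log n) (O(n) with a bulk heapify); each of the k pop/push steps costs O(log n), giving O((n + k) log n) time and O(n) heap space.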
  • variation: Kth Largest in N Arrays
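  • Binary search on values: the answer lies between the minimum (top-left) and maximum (bottom-right) entries.
  • Narrow that value range in binary-search fashion to locate the kth smallest value:
  • pick a mid value, and count the items <= mid by walking the sorted matrix from a corner (O(n) per count);
  • total time O(n log(max - min)), O(1) extra space.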

Code

  • import java.util.Comparator;
    import java.util.PriorityQueue;
    
    /**
     * LeetCode 378: k-th smallest element of an n x n matrix whose rows and columns are sorted
     * in non-decreasing order. Holds two independent solutions: a heap-based k-way merge of the
     * rows ({@link Solution}) and a binary search over the value range ({@link Solution_optimize}).
     */
    public class Kth_Smallest_Element_in_a_Sorted_Matrix {
    
        /** Heap entry: a matrix value with its row index {@code x} and column index {@code y}. */
        class Node {
            int x;
            int y;
            int val;
            public Node(int x, int y, int val) {
                this.x = x;
                this.y = y;
                this.val = val;
            }
        }
    
        /**
         * K-way merge of the matrix rows with a min-heap.
         * Time O((n + k) log n); extra space O(n), since the heap holds at most one entry per row.
         */
        class Solution {
            /**
             * Returns the k-th smallest element (1-indexed, duplicates counted).
             *
             * @param matrix matrix whose rows are sorted in non-decreasing order
             * @param k      rank of the requested element
             * @return the k-th smallest value, or 0 for a null/empty matrix, a non-positive k,
             *         or a k larger than the number of elements
             */
            public int kthSmallest(int[][] matrix, int k) {
    
                if (matrix == null || matrix.length == 0 || k <= 0) {
                    return 0;
                }
    
                // Integer.compare rather than o1.val - o2.val: the subtraction overflows when the
                // two values are far apart (e.g. Integer.MIN_VALUE against a positive value).
                PriorityQueue<Node> pq = new PriorityQueue<>(new Comparator<Node>() {
                    @Override
                    public int compare(Node o1, Node o2) {
                        return Integer.compare(o1.val, o2.val);
                    }
                });
    
                for (int i = 0; i < matrix.length; i++) {
                    pq.offer(new Node(i, 0, matrix[i][0])); // seed the heap with the first column
                }
    
                while (!pq.isEmpty()) {
                    Node n = pq.poll();
    
                    if (k == 1) {
                        return n.val;
                    }
    
                    // The only new candidate from this row is the element to the right.
                    if (n.y + 1 < matrix[n.x].length) {
                        pq.offer(new Node(n.x, n.y + 1, matrix[n.x][n.y + 1]));
                    }
    
                    k--;
                }
    
                return 0; // k exceeds the number of elements
            }
        }
    
        /**
         * Binary search over the value range [matrix[0][0], matrix[n-1][n-1]].
         * Time O(n log(max - min)); extra space O(1).
         */
        class Solution_optimize {
            /**
             * Returns the k-th smallest element (1-indexed, duplicates counted).
             *
             * @param matrix square matrix with rows and columns sorted in non-decreasing order
             * @param k      rank of the requested element
             * @return the k-th smallest value, or 0 for a null/empty matrix or a non-positive k
             */
            public int kthSmallest(int[][] matrix, int k) {
    
                if (matrix == null || matrix.length == 0 || k <= 0) {
                    return 0;
                }
    
                int n = matrix.length;
                int min = matrix[0][0];
                int max = matrix[n - 1][n - 1];
    
                // Invariant: the answer lies in [min, max].
                while (min < max) {
                    // Widen to long so max - min cannot overflow for extreme int values; the
                    // non-negative division floors, so mid is in [min, max - 1].
                    int mid = (int) (min + ((long) max - min) / 2);
    
                    int smallerCount = countSmallerThanMid(matrix, mid);
    
                    if (smallerCount < k) {
                        min = mid + 1;
                    } else {
                        max = mid;
                    }
                }
    
                return min; // min == max here
            }
    
            /**
             * Counts the elements that are less than or equal to {@code mid}.
             * Walks a staircase from the top-right corner in O(rows + columns) steps.
             */
            private int countSmallerThanMid(int[][] matrix, int mid) {
                int count = 0;
    
                // Start at the top-right corner: everything to its left is smaller or equal,
                // everything below it is larger or equal.
                int i = 0;
                int j = matrix[0].length - 1;
    
                while (i < matrix.length && j >= 0) {
    
                    if (matrix[i][j] > mid) {
                        j--; // the rest of column j (further down) is also > mid
                    } else {
                        count += j + 1; // columns 0..j of row i are all <= mid
                        i++;
                    }
                }
    
                return count;
            }
        }
    }
    
    //////
    
    /**
     * Binary search over the value range for LeetCode 378 (k-th smallest element of an n x n
     * matrix whose rows and columns are sorted in non-decreasing order).
     * Time O(n log(max - min)); extra space O(1).
     */
    class Solution {
        /**
         * Returns the k-th smallest element (1-indexed, duplicates counted).
         *
         * @param matrix square matrix with rows and columns sorted in non-decreasing order
         * @param k      rank of the requested element, assumed to be in [1, n * n]
         * @return the k-th smallest value
         */
        public int kthSmallest(int[][] matrix, int k) {
            int n = matrix.length;
            int left = matrix[0][0], right = matrix[n - 1][n - 1];
            // Invariant: the answer lies in [left, right].
            while (left < right) {
                // Floor midpoint computed in long. An unsigned shift (>>>) would turn a negative
                // sum into a huge positive value and the search would never converge.
                int mid = (int) (((long) left + right) >> 1);
                if (check(matrix, mid, k, n)) {
                    right = mid;
                } else {
                    left = mid + 1;
                }
            }
            return left;
        }
    
        /**
         * Returns true when at least {@code k} elements are less than or equal to {@code mid}.
         * Counts them with a staircase walk from the bottom-left corner in O(n) steps.
         */
        private boolean check(int[][] matrix, int mid, int k, int n) {
            int count = 0;
            int i = n - 1, j = 0;
            while (i >= 0 && j < n) {
                if (matrix[i][j] <= mid) {
                    count += (i + 1); // rows 0..i of column j are all <= mid
                    ++j;
                } else {
                    --i; // the rest of row i (further right) is also > mid
                }
            }
            return count >= k;
        }
    }
    
  • // OJ: https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix/
    // Time: O(Klog(N))
    // Space: O(N)
    class Solution {
        typedef tuple<int, int, int> Item; // (value, row, column)
    public:
        // Returns the k-th smallest element (1-indexed, duplicates counted) of an N x N matrix
        // whose rows and columns are sorted in non-decreasing order; assumes 1 <= k <= N * N.
        // K-way merge of the sorted rows: the heap holds at most one frontier cell per row,
        // so extra space is O(N) and time is O((N + K) log N).
        int kthSmallest(vector<vector<int>>& A, int k) {
            int N = A.size();
            priority_queue<Item, vector<Item>, greater<>> pq;
            for (int i = 0; i < N; ++i) pq.emplace(A[i][0], i, 0); // seed with the first column
            while (--k) { // discard the k - 1 smallest values
                auto [val, x, y] = pq.top();
                pq.pop();
                // The next candidate from row x is the cell to the right of the popped one.
                if (y + 1 < N) pq.emplace(A[x][y + 1], x, y + 1);
            }
            return get<0>(pq.top());
        }
    };
    
  • '''
    time complexity of the countSmallerThanMid(mid) function is O(n), 
    as it iterates through at most n rows and columns of the matrix.
    
    number of iterations in the binary search `while left < right:` 
    is O(log(maxVal - minVal)), where maxVal - minVal represents the range of values in the matrix.
    
    overall time complexity of the code is `O(n * log(maxVal - minVal))`
    
    no extra space
    '''
    class Solution:
        def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
            """Return the k-th smallest value (duplicates counted) of a square matrix
            whose rows and columns are sorted in non-decreasing order.

            Binary-searches the value range [top-left, bottom-right]; each probe counts
            the entries <= a candidate in O(n) with a staircase walk.
            """

            def count_at_most(target):
                # Staircase walk from the top-right corner: an entry > target rules out the
                # rest of its column below, so step left; otherwise the row prefix 0..col
                # qualifies in full, so add it and step down.
                total, row, col = 0, 0, size - 1
                while row < size and col >= 0:
                    if matrix[row][col] <= target:
                        total += col + 1  # col is an index; col + 1 entries qualify in this row
                        row += 1
                    else:
                        col -= 1
                return total

            size = len(matrix)
            lo, hi = matrix[0][0], matrix[size - 1][size - 1]
            # Invariant: the answer lies in [lo, hi]; stop once the interval is a single value.
            while lo < hi:
                mid = (lo + hi) // 2
                if count_at_most(mid) < k:
                    lo = mid + 1  # fewer than k entries are <= mid, so the answer is above mid
                else:
                    hi = mid  # at least k entries are <= mid, so the answer is <= mid
            return lo
    
    
    ############
    
    '''
    Best-first expansion from the top-left corner with a min-heap.

    Each pop yields the next-smallest remaining value; its lower and right
    neighbours are then pushed unless they were pushed before (tracked in `visited`).

    Time: O(k log k) -- k pops, and each pop pushes at most two cells, so the heap
    never holds more than O(k) entries.
    Space: O(k) for the heap plus the visited set (and never more than n^2 cells).
    '''
    import heapq
    
    class Solution(object):
      def kthSmallest(self, matrix, k):
        """
        Return the k-th smallest value (1-indexed, duplicates counted) of a matrix
        whose rows and columns are sorted in non-decreasing order.
        Returns None if k exceeds the number of cells.

        :type matrix: List[List[int]]
        :type k: int
        :rtype: int
        """
        rows, cols = len(matrix), len(matrix[0])
        visited = {(0, 0)}
        heap = [(matrix[0][0], (0, 0))]

        while heap:  # each iteration pops the next-smallest pushed value
          val, (i, j) = heapq.heappop(heap)
          k -= 1
          if k == 0:
            return val
          # matrix[i][j] is <= its lower and right neighbours, so they are the new candidates.
          for ni, nj in ((i + 1, j), (i, j + 1)):
            if ni < rows and nj < cols and (ni, nj) not in visited:
              heapq.heappush(heap, (matrix[ni][nj], (ni, nj)))
              visited.add((ni, nj))
    
    
  • func kthSmallest(matrix [][]int, k int) int {
    	n := len(matrix)
    	left, right := matrix[0][0], matrix[n-1][n-1]
    	for left < right {
    		mid := (left + right) >> 1
    		if check(matrix, mid, k, n) {
    			right = mid
    		} else {
    			left = mid + 1
    		}
    	}
    	return left
    }
    
    // check reports whether at least k entries of the n x n matrix are <= mid.
    // It walks a staircase from the bottom-left corner: when matrix[row][col] <= mid,
    // every entry above it in that column is too, so rows 0..row of that column all count.
    func check(matrix [][]int, mid, k, n int) bool {
    	row, col, total := n-1, 0, 0
    	for row >= 0 && col < n {
    		if matrix[row][col] > mid {
    			row--
    			continue
    		}
    		total += row + 1
    		col++
    	}
    	return total >= k
    }
    

All Problems

All Solutions