Question

Formatted question description: https://leetcode.ca/all/378.html

 378	Kth Smallest Element in a Sorted Matrix

 Given an n x n matrix where each of the rows and columns is sorted in ascending order,
 find the kth smallest element in the matrix.

 Note that it is the kth smallest element in the sorted order, not the kth distinct element.

 Example:

 matrix = [
             [ 1,  5,  9],
             [10, 11, 13],
             [12, 13, 15]
             ],
 k = 8,

 return 13.

 Note:
    You may assume k is always valid, 1 ≤ k ≤ n².

 @tag-binarysearch
 @tag-heap

Algorithm

PriorityQueue

  • Similar to merging K sorted arrays/lists: use a PriorityQueue (min-heap).
  • A heap entry on its own does not know which element follows it, so a Node class stores the value together with its (x, y) position; after popping (x, y), push (x, y + 1) from the same row.
  • Seeding the heap with the first column takes O(n); each of the k pops/pushes costs O(log n), giving O(n + k log n) overall.
  • Variation: Kth Largest in N Arrays.
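  • Binary search on the value range: the bounds start/end are the minimum and maximum values, matrix[0][0] and matrix[n-1][n-1].
  • Locate the kth smallest value by repeatedly halving this range:
  • pick mid-value and count the items <= mid-value by walking the sorted matrix from the top-right corner in O(n); if the count is < k the answer is > mid-value, otherwise it is <= mid-value.
  • O(n log(max - min)) overall, where max - min is the size of the value range.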

Code

Java

  • import java.util.Comparator;
    import java.util.PriorityQueue;
    
    /**
     * LeetCode 378: the k-th smallest element (1-based, duplicates counted separately) of a matrix
     * whose rows and columns are each sorted in ascending order.
     *
     * <p>Two independent solutions are provided: {@link Solution} merges the rows with a min-heap,
     * and {@link Solution_optimize} binary-searches over the value range. Both accept square or
     * rectangular matrices and return {@code 0} for a null/empty matrix or a non-positive {@code k}.
     */
    public class Kth_Smallest_Element_in_a_Sorted_Matrix {

        /** A heap entry: the value stored at {@code matrix[x][y]}. */
        class Node {
            int x;
            int y;
            int val;

            public Node(int x, int y, int val) {
                this.x = x;
                this.y = y;
                this.val = val;
            }
        }

        /**
         * Heap-based k-way merge of the matrix rows.
         *
         * <p>Time O(min(r, k) + k log min(r, k)), space O(min(r, k)), where r is the row count.
         */
        class Solution {
            /**
             * Returns the k-th smallest element (1-based).
             *
             * @param matrix matrix whose rows and columns are sorted ascending
             * @param k 1-based rank of the element to return
             * @return the k-th smallest value, or {@code 0} if the matrix is null/empty,
             *     {@code k <= 0}, or {@code k} exceeds the number of elements
             */
            public int kthSmallest(int[][] matrix, int k) {

                if (matrix == null || matrix.length == 0 || matrix[0].length == 0 || k <= 0) {
                    return 0;
                }

                // Integer.compare instead of subtraction: o1.val - o2.val overflows for operands of
                // opposite sign and large magnitude, which silently corrupts the heap order.
                PriorityQueue<Node> pq = new PriorityQueue<>(new Comparator<Node>() {
                    @Override
                    public int compare(Node o1, Node o2) {
                        return Integer.compare(o1.val, o2.val);
                    }
                });

                // Seed the heap with the head of each row. Every element in a row at index >= k is
                // >= the k first-column values matrix[0..k-1][0], so the k-th smallest value
                // already occurs among the first k rows.
                int seededRows = Math.min(matrix.length, k);
                for (int i = 0; i < seededRows; i++) {
                    pq.offer(new Node(i, 0, matrix[i][0]));
                }

                while (!pq.isEmpty()) {
                    Node n = pq.poll();

                    if (k == 1) {
                        return n.val;
                    }

                    // Replace the popped cell with the next element of the same row, if any.
                    // Bound by the row's own length, not the row count.
                    if (n.y + 1 < matrix[n.x].length) {
                        pq.offer(new Node(n.x, n.y + 1, matrix[n.x][n.y + 1]));
                    }

                    k--;
                }

                // k exceeded the number of elements.
                return 0;
            }
        }

        /**
         * Binary search over the value range [matrix[0][0], matrix[rows-1][cols-1]].
         *
         * <p>Finds the smallest value v such that at least k elements are {@code <= v}; that value
         * is always present in the matrix. Time O((rows + cols) * log(max - min)), space O(1).
         */
        class Solution_optimize {
            /**
             * Returns the k-th smallest element (1-based).
             *
             * @param matrix matrix whose rows and columns are sorted ascending
             * @param k 1-based rank of the element to return
             * @return the k-th smallest value; {@code 0} if the matrix is null/empty or
             *     {@code k <= 0}; the largest element if {@code k} exceeds the element count
             */
            public int kthSmallest(int[][] matrix, int k) {

                if (matrix == null || matrix.length == 0 || matrix[0].length == 0 || k <= 0) {
                    return 0;
                }

                int rows = matrix.length;
                int cols = matrix[0].length;
                int min = matrix[0][0];
                int max = matrix[rows - 1][cols - 1];

                while (min < max) {
                    // (max - min) can exceed Integer.MAX_VALUE. Read as an unsigned 32-bit value it
                    // is still exact, so >>> 1 yields the correct half and mid stays in [min, max).
                    int mid = min + ((max - min) >>> 1);

                    int notGreaterCount = countAtMost(matrix, mid);

                    if (notGreaterCount < k) {
                        min = mid + 1; // mid < max, so this cannot overflow
                    } else {
                        max = mid;
                    }
                }

                return min; // min == max here
            }

            /**
             * Counts the elements {@code <= target} in O(rows + cols) with a staircase walk.
             *
             * <p>Starts at the top-right cell. If it is greater than target, so is everything
             * below it in that column, so move left. Otherwise the whole row prefix up to and
             * including it qualifies, so count it and move down.
             */
            private int countAtMost(int[][] matrix, int target) {
                int count = 0;

                int i = 0;
                int j = matrix[0].length - 1;

                while (i < matrix.length && j >= 0) {

                    if (matrix[i][j] > target) {
                        j--;
                    } else {
                        count += j + 1;
                        i++;
                    }
                }

                return count;
            }
        }
    }
    
  • // OJ: https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix/
    // Time: O(Klog(N))
    // Space: O(N)
    class Solution {
        // Heap entry: (value, row, col). Tuples compare lexicographically, so the
        // min-heap orders by value first and breaks ties by position.
        using Cell = tuple<int, int, int>;
    public:
        // Returns the k-th smallest element (1-based) of an N x N matrix whose rows
        // and columns are sorted ascending. Expands outward from the top-left cell:
        // each popped cell enqueues its right and lower neighbours, each at most once.
        // Assumes 1 <= k <= N*N (heap.top() on an empty heap is undefined).
        int kthSmallest(vector<vector<int>>& A, int k) {
            const int N = A.size();
            vector<vector<bool>> queued(N, vector<bool>(N));
            priority_queue<Cell, vector<Cell>, greater<Cell>> heap;

            // Push (r, c) if it lies inside the matrix and was not pushed before.
            auto push = [&](int r, int c) {
                if (r >= N || c >= N || queued[r][c]) return;
                queued[r][c] = true;
                heap.emplace(A[r][c], r, c);
            };

            queued[0][0] = true;
            heap.emplace(A[0][0], 0, 0);

            // Discard the k-1 smallest cells; the heap top is then the answer.
            for (int remaining = k - 1; remaining != 0; --remaining) {
                int r, c;
                tie(ignore, r, c) = heap.top();
                heap.pop();
                push(r, c + 1);
                push(r + 1, c);
            }
            return get<0>(heap.top());
        }
    };
    
  • import heapq
    
    
    class Solution(object):
      def kthSmallest(self, matrix, k):
        """Return the k-th smallest value (1-based, duplicates counted) of a sorted matrix.

        Every row and every column of ``matrix`` must be sorted ascending; the
        matrix may be square or rectangular. The search starts at the top-left
        cell (the global minimum) and repeatedly pops the smallest frontier cell,
        pushing its lower and right neighbours at most once each, so values are
        popped in non-decreasing order.

        :type matrix: List[List[int]]
        :type k: int
        :rtype: int, or None if the matrix is empty or k is outside 1..rows*cols
        """
        if not matrix or not matrix[0]:
          return None

        # Rows and columns are bounded separately; using len(matrix) for both
        # would drop columns of a wide matrix.
        rows, cols = len(matrix), len(matrix[0])
        visited = {(0, 0)}
        heap = [(matrix[0][0], 0, 0)]

        while heap:
          val, i, j = heapq.heappop(heap)
          k -= 1
          if k == 0:
            return val
          for ni, nj in ((i + 1, j), (i, j + 1)):
            if ni < rows and nj < cols and (ni, nj) not in visited:
              visited.add((ni, nj))
              heapq.heappush(heap, (matrix[ni][nj], ni, nj))
        return None
    
    

All Problems

All Solutions