Formatted question description: https://leetcode.ca/all/1439.html

1439. Find the Kth Smallest Sum of a Matrix With Sorted Rows (Hard)

You are given an m x n matrix mat, whose rows are each sorted in non-decreasing order, and an integer k.

You are allowed to choose exactly 1 element from each row to form an array. Return the Kth smallest array sum among all possible arrays.

 

Example 1:

Input: mat = [[1,3,11],[2,4,6]], k = 5
Output: 7
Explanation: Choosing one element from each row, the k smallest sums are:
[1,2], [1,4], [3,2], [3,4], [1,6]. Where the 5th sum is 7.  

Example 2:

Input: mat = [[1,3,11],[2,4,6]], k = 9
Output: 17

Example 3:

Input: mat = [[1,10,10],[1,4,5],[2,3,6]], k = 7
Output: 9
Explanation: Choosing one element from each row, the k smallest sums are:
[1,1,2], [1,1,3], [1,4,2], [1,4,3], [1,1,6], [1,5,2], [1,5,3]. Where the 7th sum is 9.  

Example 4:

Input: mat = [[1,1,10],[2,2,9]], k = 7
Output: 12

 

Constraints:

  • m == mat.length
  • n == mat[i].length
  • 1 <= m, n <= 40
  • 1 <= k <= min(200, n ^ m)
  • 1 <= mat[i][j] <= 5000
  • mat[i] is a non-decreasing array.

Related Topics:
Heap

Solution 1.

Ugly answer written during the contest.

// OJ: https://leetcode.com/problems/find-the-kth-smallest-sum-of-a-matrix-with-sorted-rows/

// (sum, string key) pair; Cmp below orders such pairs by their sum only.
using Pair = pair<int, string>;

// Comparator for priority_queue<Pair, vector<Pair>, Cmp>: reports `lhs` as
// lower priority when its sum is strictly larger, so top() yields the
// smallest sum (a min-heap on .first). The string component is ignored.
struct Cmp {
    bool operator()(const Pair &lhs, const Pair &rhs) {
        return rhs.first < lhs.first;
    }
};
class Solution {
public:
    /**
     * Returns the k-th smallest sum obtainable by choosing exactly one element
     * from each row of A.
     *
     * Works row by row, keeping only the k smallest partial sums: the k
     * smallest sums over rows 0..i can all be formed by extending the k
     * smallest partial sums over rows 0..i-1, so anything beyond those k is
     * safely discarded. Each step collects candidates in a max-heap bounded to
     * k entries; because every row is non-decreasing, scanning a row stops as
     * soon as a candidate cannot beat the current k-th best.
     *
     * Unlike a heap over index tuples keyed by a fixed-width string, this
     * places no limit on the row width and allocates no per-state strings.
     *
     * @param A  matrix whose rows are sorted in non-decreasing order; every
     *           row is assumed non-empty
     * @param k  1-based rank; assumed to satisfy 1 <= k <= product of the
     *           row lengths (otherwise the final index is out of range)
     * @return   the k-th smallest array sum
     */
    int kthSmallest(vector<vector<int>>& A, int k) {
        const size_t limit = static_cast<size_t>(k);
        vector<int> sums{0};  // k smallest partial sums so far, ascending
        for (const auto &row : A) {
            priority_queue<int> best;  // max-heap: top() is the worst kept sum
            for (int prefix : sums) {
                for (int value : row) {
                    const int candidate = prefix + value;
                    if (best.size() < limit) {
                        best.push(candidate);
                    } else if (candidate < best.top()) {
                        best.pop();
                        best.push(candidate);
                    } else {
                        // Row is non-decreasing: later values cannot do better.
                        break;
                    }
                }
            }
            // Drain the heap (largest first) back into ascending order.
            sums.resize(best.size());
            for (size_t i = sums.size(); i > 0; --i) {
                sums[i - 1] = best.top();
                best.pop();
            }
        }
        return sums[k - 1];
    }
};

Java

class Solution {
    /**
     * Returns the k-th smallest sum obtainable by picking exactly one element from each row.
     *
     * <p>Index tuples are enumerated as a mixed-radix counter (base {@code columns}, with
     * {@code indices[0]} as the least-significant digit) while a max-heap holds the k smallest
     * sums seen so far. When the current tuple's sum is not smaller than the heap's maximum,
     * every following tuple that keeps the digits above the lowest non-zero digit unchanged has
     * an equal or larger sum (rows are non-decreasing), so that whole run is skipped via
     * {@code nextState(..., false)}.
     *
     * @param mat m x n matrix whose rows are sorted in non-decreasing order
     * @param k   1-based rank; assumed to satisfy 1 <= k <= n^m
     * @return the k-th smallest array sum
     */
    public int kthSmallest(int[][] mat, int k) {
        int columns = mat[0].length;
        int[] indices = new int[mat.length];
        if (k == 1)
            return sumAt(mat, indices); // all-zero tuple gives the minimum sum
        // Max-heap: peek() is the largest of the k smallest sums found so far.
        // reverseOrder() avoids the overflow-prone (b - a) comparator idiom.
        PriorityQueue<Integer> best = new PriorityQueue<>(Collections.reverseOrder());
        boolean hasNext = true;
        // Seed the heap with the first k tuples in counter order.
        for (int i = 0; i < k; i++) {
            best.offer(sumAt(mat, indices));
            hasNext = nextState(indices, columns, true);
        }
        while (hasNext) {
            int sum = sumAt(mat, indices);
            if (sum < best.peek()) {
                best.poll();
                best.offer(sum);
                hasNext = nextState(indices, columns, true);
            } else {
                hasNext = nextState(indices, columns, false);
            }
        }
        return best.peek();
    }

    /**
     * Advances {@code indices} in place, treating it as a mixed-radix counter with base
     * {@code columns} and {@code indices[0]} as the least-significant digit.
     *
     * @param indices current tuple of column indices, one per row
     * @param columns number of columns (the radix of every digit)
     * @param flag    {@code true} to step to the immediate successor; {@code false} to clear the
     *                lowest non-zero digit and carry into the next one, skipping every tuple that
     *                only differs in that digit or the digits below it
     * @return {@code false} once the counter has run past the last tuple, or when skipping finds
     *         no digit below the most-significant one to clear
     */
    public boolean nextState(int[] indices, int columns, boolean flag) {
        int rows = indices.length;
        int index = 0;
        if (!flag) {
            while (index < rows - 1 && indices[index] == 0)
                index++;
            if (index == rows - 1)
                return false; // only the top digit could still grow: nothing smaller remains
            indices[index] = 0;
            index++;
        }
        indices[index]++;
        // Propagate carries toward the most-significant digit.
        while (index < rows - 1 && indices[index] >= columns) {
            indices[index] -= columns;
            index++;
            indices[index]++;
        }
        return indices[rows - 1] < columns;
    }

    /** Returns the sum of {@code mat[row][indices[row]]} over all rows. */
    private static int sumAt(int[][] mat, int[] indices) {
        int sum = 0;
        for (int row = 0; row < indices.length; row++)
            sum += mat[row][indices[row]];
        return sum;
    }
}

All Problems

All Solutions