Problem statement

https://leetcode.com/problems/find-the-kth-smallest-sum-of-a-matrix-with-sorted-rows/

Solution 1

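The idea is to keep a min-heap of tuples (sum, i_0, ..., i_{m-1}): the total sum and the column index of the element chosen from each row. We start with all indexes equal to zero, which gives the smallest possible sum. Each time we extract an element, we push up to m new candidates, each obtained by advancing one of the indexes by one. A visited set prevents the same index tuple from being pushed twice.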
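As a small illustration of the expansion step, the sketch below lists the candidates that follow one extracted state. The helper name `successors` is hypothetical and is not part of the solution; it assumes a state is a tuple holding the sum followed by one column index per row.

```python
def successors(mat, state):
    """Return the states reachable by advancing exactly one row index by one."""
    n = len(mat[0])
    total, idx = state[0], state[1:]
    result = []
    for row, col in enumerate(idx):
        if col + 1 >= n:
            continue  # this row is already at its last column
        # Swap mat[row][col] for mat[row][col + 1] in the running sum.
        new_total = total + mat[row][col + 1] - mat[row][col]
        new_idx = idx[:row] + (col + 1,) + idx[row + 1:]
        result.append((new_total,) + new_idx)
    return result


mat = [[1, 3, 11], [2, 4, 6]]
start = (mat[0][0] + mat[1][0], 0, 0)
print(successors(mat, start))  # [(5, 1, 0), (5, 0, 1)]
```

Both candidates have sum 5 here, so the heap breaks the tie by comparing the index entries of the tuples.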

Complexity

The heap can hold up to O(m*k) elements in the end, because we add at most m elements on each of the k steps. Each element is a tuple of length m + 1, so comparing two elements costs O(m) in the worst case. One extraction therefore costs O(m*log(mk)), and all extractions together cost O(mk*log(mk)). There are up to m times more pushes than extractions, which gives O(m^2*k*log(mk)). Building and hashing the candidate tuples costs O(m) per tuple, O(m^2*k) in total, which is dominated by the heap work. The final time complexity is O(m^2*k*log(mk)). Space complexity is O(m^2*k). The factor log(mk) can be reduced to log(k) if we keep the heap size at no more than k.

Code

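from heapq import heappop, heappush

class Solution:
    def kthSmallest(self, mat, k):
        m, n = len(mat), len(mat[0])
        # Heap entries are (sum, i_0, ..., i_{m-1}); start with column 0 in every row.
        h = [(sum(r[0] for r in mat),) + (0,)*m]
        visited = set()

        # After k - 1 pops, the k-th smallest sum is at the top of the heap.
        for _ in range(k - 1):
            elem = heappop(h)
            for i in range(m):
                # Skip rows whose index is already at the last column.
                if elem[i + 1] + 1 >= n: continue
                tmp = list(elem)
                tmp[0] += (mat[i][elem[i+1] + 1] - mat[i][elem[i+1]])
                tmp[i + 1] += 1
                new = tuple(tmp)
                if new not in visited:
                    heappush(h, new)
                    visited.add(new)

        return h[0][0]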

Solution 2

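See the similar problem 0373 Find K Pairs with Smallest Sums. We can apply its solution several times: merge the rows one by one and keep only the k smallest sums after each merge, because a larger partial sum can never be part of one of the k smallest totals.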
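To make the row-by-row reduction concrete, here is a minimal sketch that replaces the heap-based merge with plain sorting. It is slower than the approach from problem 0373 and is meant only to show why truncating to k sums after each row is enough; the function name `k_smallest_sums_naive` is an assumption made for this illustration.

```python
def k_smallest_sums_naive(mat, k):
    """Return the k smallest row-wise sums in ascending order (sorting-based sketch)."""
    # Rows are sorted, so the first k entries of row 0 are its k smallest.
    res = mat[0][:k]
    for row in mat[1:]:
        # Combine every kept partial sum with every element of the next row,
        # then keep only the k smallest results.
        res = sorted(a + b for a in res for b in row)[:k]
    return res


print(k_smallest_sums_naive([[1, 3, 11], [2, 4, 6]], 5))  # [3, 5, 5, 7, 7]
```

The answer to the problem is the last element of the returned list, provided k does not exceed the number of possible sums.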

Complexity

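Each of the m - 1 merges does at most k pops and k pushes on a heap of at most max(n, k) entries, so the time complexity is O(m*k*log(max(n, k))) plus O(n) for building the first heap. For n <= k this is O(m*k*log k). Space complexity is O(max(n, k)).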

Code

from heapq import heapify, heappop, heappush

class Solution:
    def kthSmallest(self, mat, k):
        def kSmallestPairs(nums1, nums2, k):
            # Return the k smallest values nums1[i] + nums2[j] in ascending order.
            if not nums1 or not nums2: return []
            n1, n2 = len(nums1), len(nums2)
            # Seed the heap with every element of nums1 paired with nums2[0].
            heap = [(num + nums2[0], i, 0) for i, num in enumerate(nums1)]
            heapify(heap)
            result = []

            for _ in range(min(k, n1*n2)):
                total, ind1, ind2 = heappop(heap)
                result.append(total)
                # Advance the second index for the pair just extracted.
                if ind2 < n2 - 1:
                    heappush(heap, (nums1[ind1] + nums2[ind2+1], ind1, ind2+1))

            return result

        # Fold the rows one by one, keeping only the k smallest partial sums.
        res = mat[0]
        for i in range(1, len(mat)):
            res = kSmallestPairs(res, mat[i], k)
        return res[k-1]
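
Either solution can be cross-checked on small inputs against a brute-force reference. The sketch below enumerates every way of picking one element per row, so it is only practical for tiny matrices; the name `kth_smallest_bruteforce` is hypothetical and not part of the original solutions.

```python
from itertools import product


def kth_smallest_bruteforce(mat, k):
    """Enumerate all n**m sums, sort them, and return the k-th smallest (1-indexed)."""
    sums = sorted(sum(choice) for choice in product(*mat))
    return sums[k - 1]


print(kth_smallest_bruteforce([[1, 3, 11], [2, 4, 6]], 5))  # 7
```

Comparing its output with `Solution().kthSmallest(mat, k)` on random small matrices is a quick way to catch indexing mistakes.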