Problem statement

https://leetcode.com/problems/find-k-pairs-with-smallest-sums/

Solution

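There is a brute-force solution: generate all $mn$ pairs and sort them by sum, which takes $O(mn \log(mn))$ time, where $n$ and $m$ are the lengths of nums1 and nums2.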
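
For reference, the brute-force idea can be sketched as below. This is only a minimal baseline for checking faster solutions on small inputs; the function name `k_smallest_pairs_bruteforce` is made up for this illustration.

```python
from itertools import product


def k_smallest_pairs_bruteforce(nums1, nums2, k):
    """Build every pair, sort by sum, keep the first k (hypothetical helper)."""
    pairs = [[a, b] for a, b in product(nums1, nums2)]
    # list.sort is stable, so pairs with equal sums keep their row-major order.
    pairs.sort(key=lambda p: p[0] + p[1])
    return pairs[:k]


if __name__ == "__main__":
    print(k_smallest_pairs_bruteforce([1, 7, 11], [2, 4, 6], 3))
```

It materializes all pairs, so it is practical only when both arrays are small.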

We can do better. Denote $x_{i,j} = nums1[i] + nums2[j]$; we are not going to evaluate all of these sums, the notation is just for convenience. Since both arrays are sorted, every row $x_{i,0}, x_{i,1}, \dots$ is nondecreasing. On the first step, put the first element of each row, $x_{0,0}, \dots, x_{n-1,0}$, into a heap. Then at each iteration pop the smallest element and push the next element of the same row: for example, remove $x_{3,0}$ and put $x_{3,1}$. At any moment the heap holds $x_{0,i_0}, \dots, x_{n-1,i_{n-1}}$, that is, at most one element from each row (none if the row is finished). Also, if $k<n$, it is enough to consider only the first $k$ rows: every element of a row $i \geqslant k$ is at least as large as each of $x_{0,0}, \dots, x_{k-1,0}$.
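
To make the "one element per row" invariant concrete, here is a small tracing sketch. It assumes both inputs are sorted, and the generator name `trace_frontier` is hypothetical; after each pop it reports which cell $(i, j)$ was removed and which column each row currently has in the heap.

```python
import heapq


def trace_frontier(nums1, nums2, k):
    """Yield ((i, j), frontier) after each pop (illustrative helper).

    frontier maps a row index to the column of that row's element
    currently stored in the heap; finished rows are absent.
    """
    rows = min(k, len(nums1))  # rows beyond the first k are never needed
    if rows == 0 or not nums2:
        return
    heap = [(nums1[i] + nums2[0], i, 0) for i in range(rows)]
    heapq.heapify(heap)
    for _ in range(min(k, len(nums1) * len(nums2))):
        _, i, j = heapq.heappop(heap)
        if j + 1 < len(nums2):
            # Replace the popped element by the next one in the same row.
            heapq.heappush(heap, (nums1[i] + nums2[j + 1], i, j + 1))
        yield (i, j), {row: col for _, row, col in heap}


if __name__ == "__main__":
    for cell, frontier in trace_frontier([1, 7, 11], [2, 4, 6], 3):
        print(cell, frontier)
```

In this run row 0 is exhausted after three pops, so it drops out of the frontier while rows 1 and 2 still sit at column 0.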

Complexity

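Time complexity is $O(n + k\log n)$ for $k \geqslant n$ and $O(k\log k)$ for $k<n$, where $n$ is the length of nums1 and, in the second case, only the first $k$ rows are put into the heap. We can keep only the second bound: for $k \geqslant n$ we have $n + k\log n \leqslant k + k\log k$, and when both arrays have length $n$ we also have $k \leqslant n^2$, so $\log k \leqslant \log n^2 = 2\log n$ and the two bounds differ by at most a constant factor.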
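
The two cases can be merged with a short calculation. This is a sketch that assumes the heap is seeded with only the first $r = \min(k, n)$ rows; $r$ and $t$ are notation introduced here, $m$ is the length of nums2, and $k \geqslant 2$ so that $\log k > 0$.

$$
\begin{aligned}
T &= \underbrace{O(r)}_{\text{heapify}} + \underbrace{t \cdot O(\log r)}_{\text{pop and push}}, \qquad r = \min(k, n), \quad t = \min(k, nm) \leqslant k, \\
k \geqslant n: \quad T &= O(n + k\log n) \subseteq O(k + k\log k) = O(k\log k), \\
k < n: \quad T &= O(k + k\log k) = O(k\log k).
\end{aligned}
$$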

Code

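from heapq import heapify, heappop, heappush

class Solution:
    def kSmallestPairs(self, nums1, nums2, k):
        n1, n2 = len(nums1), len(nums2)
        # Seed the heap with the first element of each row;
        # rows beyond the first k can never contribute to the answer.
        heap = [(nums1[i] + nums2[0], i, 0) for i in range(min(k, n1))]
        heapify(heap)
        result = []

        for _ in range(min(k, n1*n2)):
            # Pop the smallest remaining sum and record its pair.
            _, ind1, ind2 = heappop(heap)
            result.append([nums1[ind1], nums2[ind2]])
            # Push the next element of the same row, if the row is not finished.
            if ind2 < n2 - 1:
                heappush(heap, (nums1[ind1] + nums2[ind2+1], ind1, ind2+1))

        return result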