Problem statement

Solution 1

First, let us understand how to solve the 1-d problem: given a list nums and a number U, we need to find the maximum sum of a contiguous run of elements that is no more than U. Note that this is very similar to problem 327. Count of Range Sum, except that there the goal was to count such sums rather than find the largest one. We can still use a similar idea: add the cumulative sums one by one, that is, if we have nums = [3, 1, 4, 1, 5, 9, 2, 6], then we add the elements [3, 4, 8, 9, 14, 23, 25, 31] to a sorted list that starts as [0]. Each time, before we add a cumulative sum s, we binary search for s - U: we look for the smallest element already in the list that is greater than or equal to s - U. If its index idx is not len(SList), such an element exists, and we update our answer with s - SList[idx].
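To make the 1-d step concrete, here is a minimal standalone sketch of it on the example list. The function name max_sum_at_most is made up for this illustration, and it uses the standard-library bisect module rather than SortedList so that it runs without extra packages.

```python
from bisect import bisect_left, insort
from itertools import accumulate

def max_sum_at_most(nums, U):
    """Largest sum of a non-empty contiguous run of nums that is <= U."""
    prefixes = [0]            # sorted prefix sums seen so far (0 = empty prefix)
    best = -float("inf")      # stays -inf if no run fits under U
    for s in accumulate(nums):
        # Smallest earlier prefix p with p >= s - U, so that s - p <= U.
        idx = bisect_left(prefixes, s - U)
        if idx < len(prefixes):
            best = max(best, s - prefixes[idx])
        insort(prefixes, s)   # only now does s become an "earlier" prefix
    return best

nums = [3, 1, 4, 1, 5, 9, 2, 6]
print(max_sum_at_most(nums, 10))  # 10, from the run [4, 1, 5]
print(max_sum_at_most(nums, 7))   # 6, e.g. from the run [1, 4, 1]
```

Note the order of operations: the search happens before the insertion, so a prefix sum is never paired with itself and the empty run is never counted.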

Once we know how to solve the 1-d problem, it is time to work with the 2-d problem. We need to solve O(m^2) 1-d problems, one for each pair of rows i, j such that 1 <= i <= j <= m. To do this, we calculate cumulative sums down each column, and then for each pair of rows we create the list of differences and apply our countRangeSum function to it.
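The reduction from 2-d to 1-d can be sketched on its own as follows. The helper names column_prefix_sums and strip_sums are hypothetical, and unlike the solution code this version builds a new table instead of modifying the input matrix.

```python
def column_prefix_sums(M):
    """Return P where P[r][c] is the sum of M[0][c] .. M[r-1][c].

    P has one more row than M; its first row is all zeros.
    """
    n = len(M[0])
    P = [[0] * n]
    for row in M:
        P.append([above + x for above, x in zip(P[-1], row)])
    return P

def strip_sums(P, r1, r2):
    """Per-column sums of the original rows r1 .. r2-1 (requires r1 < r2)."""
    return [low - high for high, low in zip(P[r1], P[r2])]

M = [[1, 0, 1], [0, -2, 3]]
P = column_prefix_sums(M)
print(P)                    # [[0, 0, 0], [1, 0, 1], [1, -2, 4]]
print(strip_sums(P, 0, 2))  # [1, -2, 4]  (both rows)
print(strip_sums(P, 1, 2))  # [0, -2, 3]  (second row only)
```

Each list returned by strip_sums is one 1-d instance: a contiguous run of it corresponds to a rectangle spanning the chosen rows.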


Time complexity of the 1-d problem is O(n log n), so the time complexity of the whole algorithm is O(m^2*n log n). It can be made O(n^2 * m log m) if we rotate our matrix, but in practice it works about the same for me. Space complexity is O(mn).
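The rotation mentioned above could be done with a plain transpose, which does not change the set of rectangle sums. A small sketch, assuming we want the number of rows (the dimension that enters the cost quadratically) to be the smaller one; the function name orient is invented here, and whether this helps on a given input is not something this sketch measures:

```python
def orient(M):
    """Return M transposed if it has more rows than columns, else M itself."""
    if len(M) > len(M[0]):
        return [list(col) for col in zip(*M)]
    return M

print(orient([[1, 2], [3, 4], [5, 6]]))  # [[1, 3, 5], [2, 4, 6]]
print(orient([[1, 2, 3]]))               # [[1, 2, 3]]
```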


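from itertools import accumulate, combinations, product
from sortedcontainers import SortedList

class Solution:
    def maxSumSubmatrix(self, M, k):
        def countRangeSum(nums, U):
            # Largest sum of a contiguous run of nums that is <= U.
            SList, ans = SortedList([0]), -float("inf")
            for s in accumulate(nums):
                # Smallest earlier prefix sum that is >= s - U.
                idx = SList.bisect_left(s - U)
                if idx < len(SList): ans = max(ans, s - SList[idx])
                SList.add(s)
            return ans

        m, n, ans = len(M), len(M[0]), -float("inf")
        # Turn M into column-wise cumulative sums (this modifies M in place).
        for i, j in product(range(1, m), range(n)):
            M[i][j] += M[i-1][j]
        M = [[0]*n] + M
        # Each pair r1 < r2 selects rows r1..r2-1 of the original matrix.
        for r1, r2 in combinations(range(m + 1), 2):
            row = [j - i for i, j in zip(M[r1], M[r2])]
            ans = max(ans, countRangeSum(row, k))
        return ans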

Solution 2

We can use a slightly different countRangeSum function, where instead of SortedList we use an ordinary list and the insort function. The complexity of the 1-d problem becomes O(n^2), because each insertion into a list takes O(n); however, because n is not very big, it ran even faster than the previous method for me, about 2-3 times!


Time complexity is O(n^2*m^2), but with a very small constant. Space complexity is O(mn).


from bisect import bisect_left, insort
from itertools import accumulate

def countRangeSum(nums, U):
    SList, ans = [0], -float("inf")
    for s in accumulate(nums):
        idx = bisect_left(SList, s - U)
        if idx < len(SList): ans = max(ans, s - SList[idx])
        insort(SList, s)  # O(n) insertion, keeps SList sorted
    return ans
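For a quick sanity check, the pieces can be assembled into one standalone function that needs only the standard library. This is a sketch under the same idea rather than the exact submitted code: the names count_range_sum and max_sum_submatrix are illustrative, and the input matrix is left unmodified.

```python
from bisect import bisect_left, insort
from itertools import accumulate, combinations

def count_range_sum(nums, U):
    """Largest contiguous-run sum of nums that is <= U (or -inf if none)."""
    prefixes, best = [0], -float("inf")
    for s in accumulate(nums):
        idx = bisect_left(prefixes, s - U)
        if idx < len(prefixes):
            best = max(best, s - prefixes[idx])
        insort(prefixes, s)
    return best

def max_sum_submatrix(M, k):
    """Largest rectangle sum in M that is <= k."""
    n = len(M[0])
    # Column-wise cumulative sums with an extra zero row on top.
    P = [[0] * n]
    for row in M:
        P.append([above + x for above, x in zip(P[-1], row)])
    best = -float("inf")
    # Every pair r1 < r2 of rows of P defines one horizontal strip.
    for r1, r2 in combinations(range(len(P)), 2):
        strip = [b - a for a, b in zip(P[r1], P[r2])]
        best = max(best, count_range_sum(strip, k))
    return best

print(max_sum_submatrix([[1, 0, 1], [0, -2, 3]], 2))  # 2
print(max_sum_submatrix([[2, 2, -1]], 3))             # 3
```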