Problem statement

https://binarysearch.com/problems/Flipped-Matrix-Prequel/

Solution

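The idea is to precompute the number of 1s in every row and every column. With those sums, the change in the total number of 1s caused by flipping row i and then column j can be computed in O(1), so we evaluate it for every cell (i, j) and keep the best one.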
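As a rough sketch of where the O(1) update comes from, the change can be split into the gain from the row flip and the gain from the column flip. The helper name `flip_delta` and the sample matrix are illustrative assumptions, not part of the original solution.

```python
def flip_delta(M, rows, cols, i, j):
    """Change in the number of 1s after flipping row i, then column j.

    rows[i] and cols[j] are the precomputed counts of 1s (hypothetical helper).
    """
    height, width = len(M), len(M[0])
    # Row i: its rows[i] ones become zeros and the other cells become ones.
    row_gain = width - 2 * rows[i]
    # The row flip toggles cell (i, j), so column j now holds this many ones.
    col_ones = cols[j] + 1 - 2 * M[i][j]
    col_gain = height - 2 * col_ones
    return row_gain + col_gain


M = [[1, 0, 1],
     [0, 1, 0],
     [1, 0, 0]]
rows = [sum(r) for r in M]        # [2, 1, 1]
cols = [sum(c) for c in zip(*M)]  # [2, 1, 1]
print(flip_delta(M, rows, cols, 1, 1))  # 4, so the matrix ends with 4 + 4 = 8 ones
```

Expanding `row_gain + col_gain` gives `width + height - 2*rows[i] - 2*cols[j] + 4*M[i][j] - 2`, which is the expression maximized in the code.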

Complexity

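Time is O(mn) and space is O(m + n) for a matrix with n rows and m columns. Computing the row and column sums and scanning every cell each take O(mn) time, and the two arrays of sums take O(m + n) extra space.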

Code

class Solution:
    def solve(self, M):
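        # m = number of columns (row length), n = number of rows (column length)
        m, n = len(M[0]), len(M)
        rows = [sum(row) for row in M]        # ones in each row
        cols = [sum(col) for col in zip(*M)]  # ones in each column
        upd = -float("inf")  # best change in the count of ones
        for i in range(n):
            for j in range(m):
                # Flipping row i changes the count by m - 2*rows[i]. Column j then
                # holds cols[j] + 1 - 2*M[i][j] ones, so flipping it adds
                # n - 2*(cols[j] + 1 - 2*M[i][j]).
                upd = max(upd, n + m - 2*rows[i] - 2*cols[j] + 4*M[i][j] - 2)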

        return sum(rows) + upd
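
To cross-check the solution on small inputs, a brute-force reference that actually performs both flips may be useful. This is only a testing sketch: the name `brute_force` is an assumption, and it copies the matrix for every (row, column) pair, so it runs in O((mn)^2) time.

```python
def brute_force(M):
    """Try every (row, column) pair, flip on a copy, and count the 1s."""
    best = -1
    for i in range(len(M)):
        for j in range(len(M[0])):
            G = [row[:] for row in M]      # work on a copy
            G[i] = [1 - v for v in G[i]]   # flip row i
            for row in G:                  # then flip column j
                row[j] = 1 - row[j]
            best = max(best, sum(map(sum, G)))
    return best


print(brute_force([[1, 0, 1], [0, 1, 0], [1, 0, 0]]))  # 8
```

Comparing its result with `Solution().solve(...)` on random small matrices is a simple way to gain confidence in the closed-form update.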