Problem statement

https://leetcode.com/problems/bricks-falling-when-hit/

Solution

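The idea of this problem is to use Union Find, but in the opposite direction: a DSU can merge sets but cannot split them, so instead of removing bricks we start from the grid with all hit bricks already removed and put them back in reverse order. We also need two additional pieces of information in our data structure: the size of each disjoint set and the size of the top set. We add one virtual brick which is connected to every brick in the top row, so we will have $mn + 1$ elements in our DSU.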
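
To make the virtual brick idea concrete, here is a minimal sketch of a size-tracking DSU. The names RoofDSU, roof and stable_count are hypothetical and introduced only for illustration, and the sketch uses union by size with path halving, which is a variation rather than the exact structure used in the final code.

```python
class RoofDSU:
    """Union-find with component sizes; the last index is a virtual roof node."""

    def __init__(self, cells):
        self.parent = list(range(cells + 1))
        self.size = [1] * (cells + 1)
        self.roof = cells  # hypothetical name for the virtual top cell

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx  # attach the smaller tree under the larger one
        self.parent[ry] = rx
        self.size[rx] += self.size[ry]

    def stable_count(self):
        # Bricks connected to the roof, excluding the roof node itself.
        return self.size[self.find(self.roof)] - 1


# 2 x 2 grid; cell (x, y) has index x * 2 + y
dsu = RoofDSU(4)
dsu.union(0, dsu.roof)      # (0, 0) is in the top row
dsu.union(0, 2)             # (1, 0) hangs below (0, 0)
print(dsu.stable_count())   # 2
```

Cells 1 and 3 were never attached to anything, so they do not count towards the roof component.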

Now, our algorithm will look like:

  1. First, we create an array A that has zeroes in the places where bricks will be removed (we use A = [row[:] for row in grid] to copy the grid row by row, so the original grid stays intact).
  2. Next, we create a DSU on mn + 1 elements: mn for the cells of our grid and the last one for the virtual cell which is connected to the top row. Then we build the connected components: we iterate over A and union each non-zero cell with its non-zero neighbors, and with the virtual cell if it lies in the top row.
  3. The final step is to traverse hits[::-1]. For each hit we first count how many cells are connected to our virtual top cell (we need this later). If the cell was empty in the original grid, we just add 0 to the answer. Otherwise, we union the cell with its non-empty neighbors, union it with the virtual top cell if it is in the top row, and mark it as a brick in A. Finally, we count the cells connected to the virtual top cell again and add to the answer the difference between this count and the previous one, minus one for the restored brick itself (or 0 if the result is negative, which happens when the restored brick does not reach the top). Why? Because in the original order, the number of bricks dropped by a hit is the size of the top component before the hit minus its size after, not counting the hit brick. Since we processed the hits in reverse, we reverse the answer at the end.
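
As a hedged worked example of the last step, take the small grid [[1,0,0,0],[1,1,1,0]] with hits = [[1,0]]. The symbols $r_{\text{before}}$ and $r_{\text{after}}$ are names introduced here for the number of bricks attached to the virtual top cell before and after the brick is put back.

$$
\text{dropped} = \max(0,\ r_{\text{after}} - r_{\text{before}} - 1) = \max(0,\ 4 - 1 - 1) = 2
$$

With (1,0) removed, only (0,0) is attached to the top, so $r_{\text{before}} = 1$. Putting (1,0) back joins (0,0), (1,0), (1,1) and (1,2) into the top component, so $r_{\text{after}} = 4$. The extra $-1$ accounts for the hit brick itself, which is erased rather than dropped.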

Complexity

Time complexity is O((mn + Q) * log(mn)) if we use a DSU with union by rank but without path compression, where Q is the length of hits. If we also use path compression, as the code below does, it will be O((mn + Q) * α(mn)), where α(.) is the inverse Ackermann function. Space complexity is O(mn).

Code

from itertools import product

class DSU:
    def __init__(self, N):
        self.par = list(range(N))
        self.rnk = [0] * N
        self.sz = [1] * N
        self.N = N

    def find(self, x):
        if self.par[x] != x:
            self.par[x] = self.find(self.par[x])
        return self.par[x]

    def union(self, x, y):
        xr, yr = self.find(x), self.find(y)
        if xr == yr: return
        if self.rnk[xr] < self.rnk[yr]:
            xr, yr = yr, xr
        if self.rnk[xr] == self.rnk[yr]:
            self.rnk[xr] += 1

        self.par[yr] = xr
        self.sz[xr] += self.sz[yr]

    def size(self, x):
        return self.sz[self.find(x)]

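    def top(self):
        # Number of bricks connected to the virtual top cell (index N - 1),
        # not counting the virtual cell itself.
        return self.size(self.N - 1) - 1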

class Solution(object):
    def hitBricks(self, grid, hits):
        m, n = len(grid), len(grid[0])
        neibs = [[1,0],[-1,0],[0,1],[0,-1]]
        def index(x, y):
            return x * n + y

        # Copy the grid and remove every brick that will be hit.
        A = [row[:] for row in grid]
        for i, j in hits: A[i][j] = 0

        dsu = DSU(m*n + 1)
        for x, y in product(range(m), range(n)):
            if not A[x][y]: continue
            i = index(x, y)
            if x == 0: dsu.union(i, m*n)
            for dx, dy in neibs:
                if 0 <= x+dx < m and 0 <= y+dy < n and A[x+dx][y+dy]:
                    dsu.union(i, index(x+dx, y+dy))
        
        ans = []
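        # Put the bricks back in reverse order of the hits.
        for x, y in hits[::-1]:
            pre_roof = dsu.top()
            if grid[x][y] == 0:
                # There was no brick here originally, so nothing fell.
                ans.append(0)
            else:
                i = index(x, y)
                for dx, dy in neibs:
                    if 0 <= x+dx < m and 0 <= y+dy < n and A[x+dx][y+dy]:
                        dsu.union(i, index(x+dx, y+dy))
                if x == 0:
                    dsu.union(i, m*n)
                A[x][y] = 1
                # Newly attached bricks, minus the restored brick itself.
                ans.append(max(0, dsu.top() - pre_roof - 1))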
        return ans[::-1]
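
For checking the solution on small inputs, a brute-force reference can be handy. The sketch below is not the method described above: it simulates every hit directly and recounts the stable bricks with a BFS from the top row, which costs roughly O(Q * mn). The function name hit_bricks_brute_force is hypothetical.

```python
from collections import deque


def hit_bricks_brute_force(grid, hits):
    """Simulate each hit directly and recount stable bricks with a BFS."""
    m, n = len(grid), len(grid[0])
    A = [row[:] for row in grid]

    def stable_cells():
        # BFS from every brick in the top row.
        seen = {(0, y) for y in range(n) if A[0][y]}
        queue = deque(seen)
        while queue:
            x, y = queue.popleft()
            for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                nx, ny = x + dx, y + dy
                if (0 <= nx < m and 0 <= ny < n
                        and A[nx][ny] and (nx, ny) not in seen):
                    seen.add((nx, ny))
                    queue.append((nx, ny))
        return seen

    ans = []
    for x, y in hits:
        if not A[x][y]:
            ans.append(0)  # nothing to erase, nothing falls
            continue
        before = stable_cells()
        A[x][y] = 0
        after = stable_cells()
        # Bricks that were stable before the hit and are not any more,
        # excluding the hit brick itself.
        ans.append(len(before - after - {(x, y)}))
        # Erase the bricks that just fell.
        for i in range(m):
            for j in range(n):
                if A[i][j] and (i, j) not in after:
                    A[i][j] = 0
    return ans


print(hit_bricks_brute_force([[1, 0, 0, 0], [1, 1, 1, 0]], [[1, 0]]))  # [2]
```

Comparing its output with Solution().hitBricks on random small grids is one way to gain confidence in the DSU version.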