Problem statement

https://leetcode.com/problems/falling-squares/

Solution

There is a straightforward O(n^2) solution. First, instead of the raw coordinates of the left and right edges of the squares, we work with their ranks (for example, [1, 100], [3, 200] become [1, 3], [2, 4]). Then we iterate over positions and, for each square, update the heights in O(n).
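
As a rough illustration of the quadratic approach, here is a minimal sketch. The function name falling_squares_bruteforce and the 0-based ranks are assumptions made for this example, not part of the original solution.

```python
def falling_squares_bruteforce(positions):
    # Coordinate compression: map every left/right edge to its rank.
    coords = sorted({x for left, side in positions for x in (left, left + side)})
    rank = {x: i for i, x in enumerate(coords)}

    # heights[i] is the current height over the elementary interval
    # [coords[i], coords[i + 1]).
    heights = [0] * len(coords)
    res, best = [], 0
    for left, side in positions:
        lo, hi = rank[left], rank[left + side]  # square covers ranks lo..hi-1
        top = max(heights[lo:hi]) + side        # O(n) range maximum
        for i in range(lo, hi):                 # O(n) range assignment
            heights[i] = top
        best = max(best, top)
        res.append(best)
    return res


print(falling_squares_bruteforce([[1, 2], [2, 3], [6, 1]]))  # [2, 5, 5]
```

The segment tree replaces the two O(n) steps (range maximum and range assignment) with O(log n) operations.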

There is also an O(n log n) solution using a segment tree, which is quite difficult to code. We need a data structure that, given numbers $x_1,\dots, x_n$, supports the following operations:

  1. Find the maximum (or the value of another function query_fn, whose required properties are discussed below) on the range [l, r] in O(log n) time. This is query(self, v, tl, tr, l, r); we call it with arguments (1, 0, n-1, l, r) to get the answer on [l, r]. Here v is the index of a tree node and [tl, tr] is the range it covers: node 1 corresponds to the range [0, n-1], and each node v has two children with indices 2*v and 2*v+1, which is why indexing starts from 1, not 0.
  2. Update the values on the whole range [l, r] using the function update_fn. Note that this is more difficult than updating a single value, and we need so-called lazy updates. This is update(self, v, tl, tr, l, r, h); we call it with arguments (1, 0, n-1, l, r, h). Time complexity is also $O(\log n)$.
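
To make the node indexing concrete, the sketch below lists the tree nodes that exactly tile a range [l, r], using the same recursion pattern as query and update. The helper name covering_nodes is hypothetical and exists only for this illustration.

```python
def covering_nodes(v, tl, tr, l, r):
    """Return [(v, tl, tr), ...] for the nodes that exactly tile [l, r]."""
    if l > r:                      # empty intersection with this node
        return []
    if l == tl and r == tr:        # node range matches the clipped request
        return [(v, tl, tr)]
    tm = (tl + tr) // 2
    # Clip the request to each child's range and recurse.
    return (covering_nodes(2 * v, tl, tm, l, min(r, tm))
            + covering_nodes(2 * v + 1, tm + 1, tr, max(l, tm + 1), r))


# With n = 8, the range [2, 6] splits into three nodes.
print(covering_nodes(1, 0, 7, 2, 6))  # [(5, 2, 3), (6, 4, 5), (14, 6, 6)]
```

Only O(log n) nodes are needed to cover any range, which is where the logarithmic bound for both operations comes from.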

Now, let us discuss possible choices of query and update functions. Denote update_fn by $\otimes$ and query_fn by $\oplus$; they should have the following properties:

  1. $\oplus$ forms a semigroup, that is, it is associative: $(a\oplus b)\oplus c = a\oplus (b\oplus c)$;
  2. $\otimes$ forms a monoid, that is, it is associative and has an identity element $e$: $(a\otimes b)\otimes c = a\otimes (b\otimes c)$, $a\otimes e = e\otimes a = a$;
  3. $\otimes$ is right-distributive over $\oplus$: $(a\otimes c)\oplus (b\otimes c) = (a\oplus b) \otimes c$.
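
The three properties can be spot-checked by brute force on a few sample values. The sketch below is only a sanity check over a finite sample, not a proof, and the helper name check_properties is hypothetical.

```python
from itertools import product
from operator import add


def check_properties(update_fn, query_fn, identity, values):
    """Brute-force the three properties over a small set of sample values."""
    for a, b, c in product(values, repeat=3):
        # 1. query_fn is associative.
        if query_fn(query_fn(a, b), c) != query_fn(a, query_fn(b, c)):
            return False
        # 2. update_fn is associative.
        if update_fn(update_fn(a, b), c) != update_fn(a, update_fn(b, c)):
            return False
        # 3. update_fn is right-distributive over query_fn.
        if query_fn(update_fn(a, c), update_fn(b, c)) != update_fn(query_fn(a, b), c):
            return False
    # 2 (continued). identity is a two-sided identity for update_fn.
    return all(update_fn(a, identity) == a == update_fn(identity, a) for a in values)


samples = range(6)  # non-negative values only, so 0 is an identity for max
print(check_properties(max, max, 0, samples))  # True: the pair used in this solution
print(check_properties(add, max, 0, samples))  # True: range add, range max
print(check_properties(max, add, 0, samples))  # False: max(a,c) + max(b,c) != max(a+b, c)
```

The last pair fails on a = b = 0, c = 1, so a range "max-assign" update cannot be combined with a range sum query in this simple scheme.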

Note also that we can easily build a so-called sparse segment tree: instead of a list filled with zeroes, we keep a dictionary, so only the nodes that were actually touched are stored. Then the space complexity of the whole tree is O(k log N) instead of O(N), where k is the number of operations performed on the tree.
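
One consequence is that the tree can be built directly over raw coordinates, without ranks. The sketch below assumes the max/max pair and a coordinate bound of 10^8 + 10^6 (taken from the problem constraints as I recall them; treat it as an assumption). The names SparseMaxTree and falling_squares_sparse are hypothetical.

```python
from collections import defaultdict


class SparseMaxTree:
    """Dict-backed lazy segment tree, hard-wired to update = max, query = max."""

    def __init__(self, n):
        self.n = n
        self.T = defaultdict(int)  # node maxima, created on first touch
        self.L = defaultdict(int)  # pending lazy values, 0 means "nothing pending"

    def _push(self, v):
        for u in (2 * v, 2 * v + 1):
            self.T[u] = max(self.T[u], self.L[v])
            self.L[u] = max(self.L[u], self.L[v])
        self.L[v] = 0

    def update(self, l, r, h, v=1, tl=0, tr=None):
        if tr is None:
            tr = self.n - 1
        if l > r:
            return
        if l == tl and r == tr:
            self.T[v] = max(self.T[v], h)
            self.L[v] = max(self.L[v], h)
            return
        self._push(v)
        tm = (tl + tr) // 2
        self.update(l, min(r, tm), h, 2 * v, tl, tm)
        self.update(max(l, tm + 1), r, h, 2 * v + 1, tm + 1, tr)
        self.T[v] = max(self.T[2 * v], self.T[2 * v + 1])

    def query(self, l, r, v=1, tl=0, tr=None):
        if tr is None:
            tr = self.n - 1
        if l > r:
            return 0  # heights are non-negative, so 0 is neutral for max
        if l == tl and r == tr:
            return self.T[v]
        self._push(v)
        tm = (tl + tr) // 2
        return max(self.query(l, min(r, tm), 2 * v, tl, tm),
                   self.query(max(l, tm + 1), r, 2 * v + 1, tm + 1, tr))


def falling_squares_sparse(positions, n=10**8 + 10**6):
    tree = SparseMaxTree(n)  # no coordinate compression needed
    res, best = [], 0
    for left, side in positions:
        top = tree.query(left, left + side - 1) + side
        tree.update(left, left + side - 1, top)
        best = max(best, top)
        res.append(best)
    return res
```

Each operation touches O(log N) nodes, so the dictionaries stay small even though the index range is about 10^8.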

Complexity

Time complexity is O(n log n) and space complexity is O(n), where n is the number of squares.

Code

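from collections import defaultdict

class SegmentTree:
    def __init__(self, N, update_fn, query_fn):
        self.N = N  # number of leaves; callers pass the range (0, N-1) explicitly
        self.UF, self.QF = update_fn, query_fn
        self.T = defaultdict(int)  # T[v]: query_fn over the range of node v
        self.L = defaultdict(int)  # L[v]: lazy update not yet pushed to children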
    
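    def push(self, v):
        # Apply the pending update of v to both children, then clear it.
        for u in [2*v, 2*v+1]:
            self.T[u] = self.UF(self.T[u], self.L[v])
            self.L[u] = self.UF(self.L[u], self.L[v])
        # 0 plays the role of the identity e of update_fn
        # (valid for max only because heights are non-negative).
        self.L[v] = 0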

    def update(self, v, tl, tr, l, r, h):
        if l > r: return
        if l == tl and r == tr:
            self.T[v] = self.UF(self.T[v], h)
            self.L[v] = self.UF(self.L[v], h)
        else:
            self.push(v)
            tm = (tl + tr)//2
            self.update(v*2, tl, tm, l, min(r, tm), h)
            self.update(v*2+1, tm+1, tr, max(l, tm+1), r, h)
            self.T[v] = self.QF(self.T[v*2], self.T[v*2+1])

    def query(self, v, tl, tr, l, r):
        # Empty range: -inf is the neutral value for query_fn = max.
        if l > r: return -float("inf")
        if l <= tl and tr <= r: return self.T[v]
        self.push(v)
        tm = (tl + tr)//2
        left = self.query(v*2, tl, tm, l, min(r, tm))
        right = self.query(v*2+1, tm+1, tr, max(l, tm+1), r)
        return self.QF(left, right)
    
class Solution:
    def fallingSquares(self, positions):
        V = set()
        for p, w in positions:
            V.add(p)
            V.add(p + w)
        rank = {x: i for i, x in enumerate(sorted(V))}
        
        STree = SegmentTree(len(rank), max, max)
        res, best = [], 0
        
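        for p, w in positions:
            # The square covers [p, p + w); subtracting 1 from the right rank
            # keeps squares that only touch at an edge from overlapping.
            L, R = rank[p], rank[p + w] - 1
            # New top = highest point under the square + its side length.
            H = STree.query(1, 0, len(rank) - 1, L, R) + w
            STree.update(1, 0, len(rank) - 1, L, R, H)
            best = max(best, H)
            res.append(best)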
        return res