Problem statement

https://leetcode.com/problems/build-binary-expression-tree-from-infix-expression/

Solution

This problem has an O(n) stack-based solution, but it is a bit difficult to digest. I prefer an alternative O(n log n) solution that I came up with, which is easier to code and to understand.

Define the level of a symbol as the number of bracket pairs it is inside: an opening bracket increases the level by one and a closing bracket decreases it by one. We can calculate the levels of all operations +-*/. We keep in d[0][level] all positions of +- and in d[1][level] all positions of */.
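To make the definition of levels concrete, here is a small sketch that builds d for a sample expression. The helper name operator_positions and the sample string are only illustrative, they are not part of the solution itself.

```python
from collections import defaultdict

def operator_positions(s):
    """Return d where d[0][level] lists indices of + and -,
    and d[1][level] lists indices of * and /."""
    d = [defaultdict(list) for _ in range(2)]
    level = 0
    for i, ch in enumerate(s):
        if ch == "(":
            level += 1          # entering a bracket pair
        elif ch == ")":
            level -= 1          # leaving a bracket pair
        elif ch in "+-":
            d[0][level].append(i)
        elif ch in "*/":
            d[1][level].append(i)
    return d

d = operator_positions("2-3/(5*2)+1")
print(dict(d[0]))  # {0: [1, 9]}
print(dict(d[1]))  # {0: [3], 1: [6]}
```

Each list is filled in increasing index order, so it is already sorted and can be searched with binary search.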

Now, let us define a recursive function calc(beg, end, level) which constructs the tree for s[beg:end+1], where level is the level of this substring.

  1. If beg == end, we have reached a leaf, so we return it.
  2. If beg > end, we return None. This can happen for the case (), which is in fact an empty expression.
  3. Otherwise, we look for the rightmost +- in our string at the given level. Why? Because it is the operation which is evaluated last (+- have the lowest priority and are applied from left to right), so it is the root of our tree. We recursively run calc for the left and right parts of the string and attach the results to the node. If we did not find +-, we look for */ and apply the same logic if we find it. Finally, if we did not find that either, we have the case (...), so we remove the brackets and increase the level.
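The "rightmost operator inside [beg, end]" lookup in step 3 can be sketched with a single binary search over a sorted list of positions. The function name rightmost_in_range and the -1 sentinel are assumptions made for this illustration only.

```python
from bisect import bisect_right

def rightmost_in_range(positions, beg, end):
    """Return the largest p in sorted `positions` with beg <= p <= end, or -1."""
    ind = bisect_right(positions, end)   # number of entries that are <= end
    if ind > 0 and positions[ind - 1] >= beg:
        return positions[ind - 1]
    return -1

plus_minus_level0 = [1, 9]               # positions of +- in "2-3/(5*2)+1"
print(rightmost_in_range(plus_minus_level0, 0, 10))  # 9
print(rightmost_in_range(plus_minus_level0, 0, 8))   # 1
print(rightmost_in_range(plus_minus_level0, 2, 8))   # -1
```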

Complexity

Time complexity to build d is O(n). When we build the tree, we make O(n) calls of calc, and each call does at most two binary searches, so overall time complexity is O(n log n). Space complexity is O(n).

Code

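from bisect import bisect
from collections import defaultdict

# Node is the tree node class provided by the LeetCode judge.
class Solution:
    def expTree(self, s):
        lvl, n = 0, len(s)
        # d[0][level]: positions of + and -, d[1][level]: positions of * and /
        d = [defaultdict(list) for _ in range(2)]
        for i, symb in enumerate(s):
            lvl += (int(symb == "(") - int(symb == ")"))
            if symb in "+-": d[0][lvl].append(i)
            if symb in "*/": d[1][lvl].append(i)

        def calc(beg, end, level):
            if beg == end: return Node(s[beg])   # single operand: leaf
            if beg > end: return None            # empty expression, e.g. ()
            for i in range(2):                   # try +- first, then */
                # rightmost operator of this type and level with index <= end
                ind = bisect(d[i][level], end)
                if ind > 0 and d[i][level][ind-1] >= beg:
                    mid = d[i][level][ind-1]
                    node = Node(s[mid])
                    node.left = calc(beg, mid-1, level)
                    node.right = calc(mid+1, end, level)
                    return node
            # no operator on this level: strip the outer brackets
            return calc(beg + 1, end - 1, level + 1)

        return calc(0, n-1, 0)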
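
To try the solution outside of the judge, a stand-in for Node is needed. The sketch below is an assumption about its shape (a val field plus left and right children); the postfix helper is hypothetical and is only there to inspect the resulting tree.

```python
class Node:
    """Minimal stand-in for the judge-provided tree node (assumed shape)."""
    def __init__(self, val=" ", left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def postfix(node):
    """Return the postorder traversal of the tree as a space-separated string."""
    if node is None:
        return ""
    parts = [postfix(node.left), postfix(node.right), node.val]
    return " ".join(p for p in parts if p)

# Hand-built tree for "2-3": root "-", leaves "2" and "3".
tree = Node("-", Node("2"), Node("3"))
print(postfix(tree))  # 2 3 -
```

With this Node defined before the Solution class, postfix(Solution().expTree("2-3/(5*2)+1")) should give 2 3 5 2 * / - 1 +, which matches the usual evaluation order of the expression.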