Problem statement

https://leetcode.com/problems/correct-a-binary-tree/

Solution

The idea is to traverse the tree with BFS, always visiting the right child before the left one, so that every level is processed from right to left. The queue holds pairs (parent, node). Each time we pop a pair (p, u), we mark u as visited and then check its children. If a child has already been visited, then u is the bad node: its right pointer leads to a node on the same level to its right, which was processed earlier. We need to delete u together with everything below it, so we look at the parent p of u and cut the connection from p to u.

Complexity

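Time complexity is O(n), where n is the number of nodes. The queue holds at most O(w) nodes, where w is the maximum width of the tree, but the visited set keeps every node processed so far, so the space complexity of the code as written is O(n).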
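
One way to keep the extra space proportional to the level width w is to remember only the nodes of the current level, because the bad node always points to a node on its own level. The sketch below is a hypothetical variant, not the code from this post: TreeNode is an assumed stand-in for LeetCode's class, and correct_binary_tree_level_set is an illustrative name.

```python
class TreeNode:
    """Minimal stand-in for LeetCode's TreeNode (assumed shape)."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def correct_binary_tree_level_set(root):
    """Hypothetical variant: the visited set is reset on every level."""
    level = [(None, root)]  # (parent, node) pairs, ordered right to left
    while level:
        seen = set()        # nodes of this level that were already processed
        next_level = []
        for p, u in level:
            if u.right in seen:
                # u points to a node on its own level: u is the bad node,
                # so detach it (and its subtree) from its parent.
                if p.left is u:
                    p.left = None
                else:
                    p.right = None
                return root
            seen.add(u)
            # Right child first keeps the next level in right-to-left order.
            for child in (u.right, u.left):
                if child:
                    next_level.append((u, child))
        level = next_level
    return root
```

Both seen and the two level lists hold at most one level of nodes at a time, which is where the O(w) bound would come from under these assumptions.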

Code

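from collections import deque

class Solution:
    def correctBinaryTree(self, root):
        # The queue holds (parent, node) pairs; seen holds nodes already popped.
        queue, seen = deque([(None, root)]), set()

        while queue:
            for _ in range(len(queue)):
                p, u = queue.popleft()
                seen.add(u)
                # Right child first, so each level is processed right to left.
                for child in filter(None, [u.right, u.left]):
                    if child in seen:
                        # u is the bad node: detach it from its parent and stop.
                        if p.left is u: p.left = None
                        else: p.right = None
                        return root
                    else:
                        queue.append((u, child))

        return root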