Problem statement


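The idea is to traverse the tree with BFS, level by level, always visiting the right child before the left one, so that every level is scanned from right to left. The queue stores pairs (parent, node). When we pop a node u, we mark it as visited and then check its children. The invalid node's right child points to a node on the same level further to the right, and because we scan right to left, that node has already been visited. So if some child of u is already visited, u is the bad node: we must delete u together with its whole subtree, which we do by looking at u's parent p and cutting the link from p to u.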


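Time complexity is O(n), where n is the number of nodes, since each node is visited at most once. Space complexity is O(w), where w is the maximum width of the tree (the largest number of nodes on a single level).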


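from collections import deque

class Solution:
    def correctBinaryTree(self, root):
        # Queue of (parent, node) pairs; the root has no parent.
        queue = deque([(None, root)])

        while queue:
            # The bad pointer targets a node on the same level, so only
            # the current level needs to be remembered: O(w) memory.
            seen = set()
            for _ in range(len(queue)):
                p, u = queue.popleft()
                seen.add(u)
                # Right child first, so every level is scanned right to left.
                for child in filter(None, [u.right, u.left]):
                    if child in seen:
                        # u is the invalid node: cut it (and its subtree) off p.
                        if p.left is u: p.left = None
                        else: p.right = None
                        return root
                    queue.append((u, child))
        return root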
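To try the approach outside LeetCode, here is a self-contained sketch. It assumes a minimal `TreeNode` with `val`, `left` and `right` (mirroring LeetCode's class); `correct_binary_tree` and `level_order` are hypothetical helper names used only for this demo. The tree is Example 1 from the problem, where node 2's right pointer wrongly points to node 3 on the same level.

```python
from collections import deque
from typing import Optional


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumed shape).
    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def correct_binary_tree(root: TreeNode) -> TreeNode:
    # Level-by-level BFS, right child first, so each level is scanned right to left.
    queue = deque([(None, root)])
    while queue:
        seen = set()  # nodes of the current level only
        for _ in range(len(queue)):
            p, u = queue.popleft()
            seen.add(u)
            for child in filter(None, [u.right, u.left]):
                if child in seen:  # u points sideways: u is the invalid node
                    if p.left is u:
                        p.left = None
                    else:
                        p.right = None
                    return root
                queue.append((u, child))
    return root


def level_order(root: Optional[TreeNode]) -> list:
    # Plain left-to-right level order, used to inspect the corrected tree.
    out, queue = [], deque([root] if root else [])
    while queue:
        node = queue.popleft()
        out.append(node.val)
        queue.extend(c for c in (node.left, node.right) if c)
    return out


# Example 1: root = [1,2,3], node 2's right pointer wrongly goes to node 3.
three = TreeNode(3)
two = TreeNode(2, right=three)
root = TreeNode(1, two, three)
print(level_order(correct_binary_tree(root)))  # [1, 3]
```

Node 3 is popped before node 2 on the second level, so when node 2's right pointer is checked, node 3 is already in `seen`, and the link from 1 to 2 is cut.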