We use a function dfs(lvl, node), where:

  1. lvl is the distance between the root of the tree and the current node.
  2. The result of the function is the distance from node to its farthest descendant leaf, counted in nodes: the number of nodes on the longest downward path starting at node. A leaf gives 1 and an empty subtree gives 0.
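
To make the return value concrete, here is a minimal sketch that computes only that quantity. The TreeNode class and the name height are assumptions made for illustration and are not part of the solution itself; the value is counted in nodes, which matches what dfs returns.

```python
class TreeNode:
    """Hypothetical minimal binary tree node, mirroring the usual LeetCode shape."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height(node):
    """Number of nodes on the longest downward path from node (0 if node is None)."""
    if not node:
        return 0
    return max(height(node.left), height(node.right)) + 1


# A root whose left child has one more level below it than its right child.
root = TreeNode(1, TreeNode(2, TreeNode(3)), TreeNode(4))
print(height(root))        # 3
print(height(root.right))  # 1
print(height(None))        # 0
```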

Here is how the function works:

  1. If node is None, we have stepped past a leaf, and the result is 0.
  2. Otherwise, we recursively run dfs on node.left and node.right.
  3. We then build cand, a tuple of three values: the first is the sum of the distance to the root and the distance to the farthest leaf, the second is the distance to the root, and the third is the current node. Note that cand[0] is the length (counted in nodes) of the longest root-to-leaf path that passes through node.
  4. If cand[0] > self.ans[0], we have found a node on a longer root-to-leaf path, so we update self.ans.
  5. Also, if cand[0] == self.ans[0] and lft == rgh, we are on a longest root-to-leaf path and at a fork: going either left or right leads to a longest path. Note that a node is compared against self.ans only after both of its children have been processed, so an ancestor is always checked after its descendants. The last such update is therefore the fork closest to the root, and this is exactly what we want in the end.
  6. Finally, we return cand[0] - cand[1]: the distance from node to its farthest descendant leaf.
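
The steps above can be traced on a small tree. The sketch below is only an illustration: TreeNode and trace_updates are hypothetical names, and the function repeats the same update rule while recording the value of every node that overwrites ans.

```python
class TreeNode:
    """Hypothetical minimal binary tree node."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trace_updates(root):
    """Run the dfs described above and record node values in the order ans is updated."""
    ans = (0, 0, None)
    updates = []

    def dfs(lvl, node):
        nonlocal ans
        if not node:
            return 0
        lft = dfs(lvl + 1, node.left)
        rgh = dfs(lvl + 1, node.right)
        cand = (max(lft, rgh) + 1 + lvl, lvl, node)
        if cand[0] > ans[0] or (cand[0] == ans[0] and lft == rgh):
            ans = cand
            updates.append(node.val)
        return cand[0] - cand[1]

    dfs(0, root)
    return updates


# Tree: 3 -> (5, 1); 5 -> (6, 2); 2 -> (7, 4); 1 -> (0, 8)
root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))
print(trace_updates(root))  # [6, 7, 4, 2]
```

Node 2 is the last update. Nodes 5 and 3 also lie on an equally long path, but for them lft != rgh, so they do not overwrite it.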

Complexity: time complexity is O(n), since every node in the tree is visited once. Space complexity is O(h) for the recursion stack, where h is the height of the tree.

class Solution:
    def subtreeWithAllDeepest(self, root):
        self.ans = (0, 0, None)
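        def dfs(lvl, node):
            # Empty subtree: nothing below, so the distance is 0.
            if not node:
                return 0
            lft = dfs(lvl + 1, node.left)
            rgh = dfs(lvl + 1, node.right)
            # (longest root-to-leaf path through node, depth of node, node)
            cand = (max(lft, rgh) + 1 + lvl, lvl, node)

            # Update on a strictly longer path, or on a fork lying on an equally long path.
            if cand[0] > self.ans[0] or (cand[0] == self.ans[0] and lft == rgh):
                self.ans = cand
            # Distance from node down to its farthest leaf, counted in nodes.
            return cand[0] - cand[1]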
        dfs(0, root)
        return self.ans[2]
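
As a cross-check, the same answer can be obtained without the shared self.ans state. The sketch below is a different formulation, not the code above: each call returns a pair (height, node), and the names TreeNode and subtree_with_all_deepest are assumptions chosen for illustration.

```python
class TreeNode:
    """Hypothetical minimal binary tree node."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_with_all_deepest(root):
    """Return the root of the smallest subtree that contains all deepest nodes."""
    def dfs(node):
        # Returns (height in nodes, answer for the subtree rooted at node).
        if not node:
            return 0, None
        lh, ln = dfs(node.left)
        rh, rn = dfs(node.right)
        if lh > rh:
            return lh + 1, ln   # all deepest nodes are on the left
        if rh > lh:
            return rh + 1, rn   # all deepest nodes are on the right
        return lh + 1, node     # fork: deepest nodes on both sides (or a leaf)

    return dfs(root)[1]


# Tree: 3 -> (5, 1); 5 -> (6, 2); 2 -> (7, 4); 1 -> (0, 8)
root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))
print(subtree_with_all_deepest(root).val)  # 2
```

This variant has the same O(n) time and O(h) space bounds. It trades the lvl parameter and the shared tuple for a second return value.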

