Problem statement


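The task amounts to finding the lowest common ancestor (LCA) of the two nodes and then summing the distances from each node to that LCA. Equivalently, dist(p, q) = depth(p) + depth(q) - 2*depth(LCA), where depth is the distance from the root. We can compute this in one pass as well.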
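
For comparison, the straightforward two-pass idea (find the LCA first, then measure the distance from it to each node) could be sketched as follows. This is only an illustration: the TreeNode class is a minimal stand-in for the one the judge provides, the helper names find_lca, depth_of and find_distance_two_pass are made up for this sketch, and node values are assumed to be unique with both p and q present in the tree.

```python
class TreeNode:
    """Minimal stand-in for the judge's TreeNode (assumed fields: val, left, right)."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_lca(node, p, q):
    """Return the LCA node of values p and q (both assumed present)."""
    if node is None or node.val in (p, q):
        return node
    lft = find_lca(node.left, p, q)
    rgh = find_lca(node.right, p, q)
    if lft is not None and rgh is not None:
        return node  # p and q sit in different subtrees
    return lft if lft is not None else rgh


def depth_of(node, target, lvl=0):
    """Return the number of edges from node down to target, or -1 if absent."""
    if node is None:
        return -1
    if node.val == target:
        return lvl
    found = depth_of(node.left, target, lvl + 1)
    if found != -1:
        return found
    return depth_of(node.right, target, lvl + 1)


def find_distance_two_pass(root, p, q):
    """First pass finds the LCA, second pass measures both distances from it."""
    lca = find_lca(root, p, q)
    return depth_of(lca, p) + depth_of(lca, q)


# Example tree:      3
#                  /   \
#                 5     1
#                / \   / \
#               6   2 0   8
#                  / \
#                 7   4
root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))
print(find_distance_two_pass(root, 5, 0))  # 3
print(find_distance_two_pass(root, 5, 7))  # 2
```

This visits parts of the tree up to three times, which is what the one-pass version avoids.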

Let dfs(node, lvl) return 1 if at least one of the nodes p, q lies in the subtree rooted at node, and 0 otherwise, where lvl is the distance from node to the root of the whole tree.

  1. If p == q, we return 0.
  2. We run dfs recursively on the left and right children to get lft and rgh.
  3. mid is 1 if node is equal to p or q, and 0 otherwise.
  4. If mid > 0, we have found one of the nodes p, q, so we add lvl, the distance from the root to node, to the answer.
  5. If lft + rgh + mid >= 2, node is the LCA, so we subtract 2*lvl from the answer.
  6. Finally, we return max(lft, mid, rgh), which is 1 if any of p, q was found in this subtree and 0 otherwise.

Notice that there can be only one node where lft + rgh + mid >= 2, and it is our LCA: because dfs returns at most 1, every ancestor of the LCA sees a sum of exactly 1.


Time complexity is O(n), where n is the number of nodes; space complexity is O(h) for the recursion stack, where h is the height of the tree.


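class Solution:
    # TreeNode is supplied by the judge; the quoted annotation avoids a NameError elsewhere.
    def findDistance(self, root: "TreeNode", p: int, q: int) -> int:
        def dfs(node, lvl):
            # Returns 1 if p or q occurs in the subtree rooted at node, else 0.
            if not node:
                return 0
            lft = dfs(node.left, lvl + 1)
            rgh = dfs(node.right, lvl + 1)
            mid = int(node.val in (p, q))
            if mid > 0:
                self.ans += lvl          # add depth(p) or depth(q)
            if lft + rgh + mid >= 2:
                self.ans -= 2 * lvl      # node is the LCA
            return max(lft, mid, rgh)

        if p == q:
            return 0
        self.ans = 0
        dfs(root, 0)
        return self.ans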
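
To try the one-pass idea outside the judge, a TreeNode definition is needed. The sketch below restates the same logic as a standalone function with a nonlocal counter instead of self.ans. The TreeNode class here is an assumed minimal stand-in and find_distance is a name chosen for this example, not part of the original solution.

```python
class TreeNode:
    """Minimal stand-in for the judge's TreeNode (assumed fields: val, left, right)."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance(root, p, q):
    """One-pass distance between values p and q (assumed unique and present)."""
    if p == q:
        return 0
    ans = 0

    def dfs(node, lvl):
        nonlocal ans
        if node is None:
            return 0
        lft = dfs(node.left, lvl + 1)
        rgh = dfs(node.right, lvl + 1)
        mid = int(node.val in (p, q))
        if mid:
            ans += lvl            # depth of p or q
        if lft + rgh + mid >= 2:
            ans -= 2 * lvl        # this node is the LCA
        return max(lft, mid, rgh)

    dfs(root, 0)
    return ans


# Small tree:   1
#              / \
#             2   3
#            /
#           4
small = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3))
print(find_distance(small, 4, 3))  # 3 (path 4 -> 2 -> 1 -> 3)
print(find_distance(small, 2, 4))  # 1 (here node 2 is itself the LCA)
```

The second call shows the case where one of the two nodes is an ancestor of the other: mid and lft are both 1 at node 2, so the LCA check still fires there.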