Problem statement

https://binarysearch.com/problems/Tree-with-Distinct-Parities/

Solution

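Traverse the tree in post-order with a DFS that returns the sum of the values in each node's subtree. Once the sums of the left and right subtrees are known, count the node if it has both children and the two child sums have different parities. This happens exactly when the two child sums add up to an odd number.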
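The parity test works because a sum of two integers is odd exactly when one of them is even and the other is odd. Below is a minimal standalone sketch of the same DFS. It assumes a hypothetical `TreeNode` class with `val`, `left` and `right` attributes (binarysearch supplies its own tree class) and a hypothetical function name `count_distinct_parity`.

```python
from typing import Optional


class TreeNode:
    """Hypothetical minimal tree node; binarysearch provides its own class."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def count_distinct_parity(root: Optional[TreeNode]) -> int:
    """Count nodes with two children whose child subtree sums differ in parity."""
    count = 0

    def subtree_sum(node: Optional[TreeNode]) -> int:
        nonlocal count
        if node is None:
            return 0
        lft = subtree_sum(node.left)
        rgh = subtree_sum(node.right)
        # lft + rgh is odd exactly when one child sum is even and the other odd.
        if node.left is not None and node.right is not None and (lft + rgh) % 2 == 1:
            count += 1
        return lft + rgh + node.val

    subtree_sum(root)
    return count


# Example tree:
#         5
#        / \
#       2   4
#      / \
#     1   6
tree = TreeNode(5, TreeNode(2, TreeNode(1), TreeNode(6)), TreeNode(4))
print(count_distinct_parity(tree))  # 2: node 2 (1 vs 6) and node 5 (9 vs 4)
```

Using `nonlocal` keeps the counter local to a single call; the `self.ans` attribute in the original code does the same job inside the platform's `Solution` class.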

Complexity

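It is O(n) time, since every node is visited once, and O(h) space for the recursion stack, where h is the height of the tree (h can be as large as n for a skewed tree).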
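The O(h) space comes from the recursion stack. CPython's default recursion limit is 1000 (see `sys.getrecursionlimit()`), so a skewed tree deeper than that can raise `RecursionError` with the recursive `dfs` unless the limit is raised. If that matters, the same post-order computation can be done with an explicit stack. This is a sketch, not the original solution; the `TreeNode` class and the name `count_distinct_parity_iterative` are hypothetical. Each child's sum is removed from the dictionary as soon as its parent consumes it, so extra space should stay O(h).

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    """Hypothetical minimal tree node; binarysearch provides its own class."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def count_distinct_parity_iterative(root: Optional[TreeNode]) -> int:
    """Iterative post-order version of the recursive dfs (no recursion)."""
    if root is None:
        return 0
    count = 0
    # Finished subtree sums not yet consumed by their parent. The parent pops
    # them, so this dictionary stays proportional to the tree height.
    sums: Dict[TreeNode, int] = {}
    # (node, expanded): expanded is True once the node's children were pushed.
    stack: List[Tuple[TreeNode, bool]] = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if not expanded:
            # Revisit this node after both children have been processed.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue
        lft = sums.pop(node.left) if node.left is not None else 0
        rgh = sums.pop(node.right) if node.right is not None else 0
        if node.left is not None and node.right is not None and (lft + rgh) % 2 == 1:
            count += 1
        sums[node] = lft + rgh + node.val
    return count


# A left-skewed chain of 5000 nodes: deeper than CPython's default recursion
# limit, but fine here because no recursion is used.
deep = None
for v in range(5000):
    deep = TreeNode(v, left=deep)
print(count_distinct_parity_iterative(deep))  # 0: no node has two children
```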

Code

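class Solution:
    def solve(self, root):
        # Post-order DFS: returns the sum of all values in node's subtree.
        def dfs(node):
            if not node:
                return 0
            lft = dfs(node.left)   # sum of the left subtree
            rgh = dfs(node.right)  # sum of the right subtree
            # lft + rgh is odd exactly when the two child sums have distinct
            # parities; Python's % returns 0 or 1 here even for negative sums.
            if node.left and node.right and (lft + rgh) % 2 == 1:
                self.ans += 1
            return lft + rgh + node.val

        self.ans = 0  # number of nodes that satisfy the condition
        dfs(root)
        return self.ans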