Problem statement

https://binarysearch.com/problems/Counting-Maximal-Value-Roots-in-Binary-Tree/

Solution

The idea is to traverse the tree in post-order and compute, for each node, the maximum value in its subtree. A node is counted when its own value equals that maximum.

Complexity

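Time is O(n), where n is the number of nodes, since each node is visited once. Space is O(h) for the recursion stack, where h is the height of the tree.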

Code

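class Solution:
    def solve(self, root):
        def dfs(node):
            # Return the largest value in the subtree rooted at node.
            if not node:
                return float("-inf")  # an empty subtree never beats a real value
            ans = max(node.val, dfs(node.left), dfs(node.right))
            # Count the node when it holds the maximum of its own subtree.
            if node.val == ans:
                self.ans += 1
            return ans

        self.ans = 0
        dfs(root)
        return self.ans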
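
Since the recursion depth equals the height of the tree, a very skewed tree could exceed Python's recursion limit (1000 by default in CPython). Below is a minimal sketch of the same post-order idea with an explicit stack. The `Tree` class and the `count_max_roots` name are hypothetical, introduced only for illustration; the sketch assumes nodes expose `val`, `left` and `right`, as the solution above does.

```python
class Tree:
    """Hypothetical node type; assumes nodes expose val, left and right."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_max_roots(root):
    """Count nodes whose value is the maximum of their own subtree."""
    count = 0
    subtree_max = {}  # id(node) -> maximum value in that node's subtree
    stack = [(root, False)] if root else []
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after both children have been processed.
            stack.append((node, True))
            for child in (node.left, node.right):
                if child:
                    stack.append((child, False))
        else:
            best = node.val
            for child in (node.left, node.right):
                if child:
                    # The child's maximum is no longer needed once consumed.
                    best = max(best, subtree_max.pop(id(child)))
            if node.val == best:
                count += 1
            subtree_max[id(node)] = best
    return count


# Small hypothetical tree:   6
#                           / \
#                          3   2
#                             / \
#                            6   4
example = Tree(6, Tree(3), Tree(2, Tree(6), Tree(4)))
print(count_max_roots(example))  # prints 4: the root and the three leaves
```

Each node is still visited a constant number of times, so the time stays O(n). The explicit stack and the dictionary of pending maxima should both stay on the order of the tree height, since entries are dropped as soon as the parent consumes them.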