Problem statement

Solution 1

The first solution keeps, for each node, the set of distinct colors in its subtree. When we merge the sets of two children, we follow the strategy: merge smaller into bigger. This gives good complexity: each time a color is moved, it ends up in a set at least twice as large as the one it came from, so it is moved at most O(log n) times and the total work is O(n log n). In Python, when we traverse the tree with dfs, what is returned is a reference to the set, not a copy, which is why passing sets up the tree costs nothing extra.


It is O(n log n) for time and O(n) for space.
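The merge step on its own can be sketched as follows. The helper name merge_small_to_large is hypothetical and is not part of the solution; it only illustrates that the bigger set is reused by reference and only the elements of the smaller set are inserted.

```python
def merge_small_to_large(a, b):
    """Merge two sets, always copying the smaller one into the bigger one.

    Hypothetical helper for illustration only. Returns the bigger set
    object itself (a reference, not a copy), now holding the union.
    """
    if len(a) < len(b):
        a, b = b, a      # make a the bigger set
    a |= b               # in-place union: only len(b) elements are inserted
    return a


big = {1, 2, 3, 4}
small = {4, 5}
merged = merge_small_to_large(small, big)
print(merged is big)     # True: the bigger set was reused, not copied
print(sorted(merged))    # [1, 2, 3, 4, 5]
```

Because the bigger set is never copied, the cost of one merge is proportional to the size of the smaller set only.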


class Solution:
    def solve(self, tree, color):
        self.ans = 0

        def dfs(par, node):
            ans, cnt = set([color[node]]), 1
            for child in tree[node]:
                if child == par: continue
                x, cnt_x = dfs(node, child)
                if len(x) > len(ans):
                    ans, x = x, ans  # keep the bigger set in ans
                ans |= x  # merge the smaller set into the bigger one
                cnt += cnt_x
            self.ans += (len(ans) == cnt)
            return ans, cnt

        dfs(-1, 0)
        return self.ans

Solution 2

Another solution uses an Euler tour (pre-order traversal) of the tree.

  1. col is the colors of the nodes in traversal order.
  2. beg is the index at which node i is first encountered.
  3. end is one past the index of the last node in the subtree of i. Now the subtree of i is kept in the half-open range [beg, end).
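As a small illustration of these three arrays, here is a sketch on a four-node tree. The function name euler_tour and the sample tree are assumptions made for this example; tree is taken to be a list of adjacency lists with node 0 as the root, and end is exclusive, as in the solution code further down.

```python
def euler_tour(tree, color, root=0):
    """Return (col, beg, end) for a pre-order traversal from root.

    Illustrative sketch: the subtree of node v occupies col[beg[v]:end[v]].
    """
    n = len(color)
    beg, end, col = [0] * n, [0] * n, []

    def dfs(par, node):
        beg[node] = len(col)          # position where node is first visited
        col.append(color[node])
        for child in tree[node]:
            if child != par:
                dfs(node, child)
        end[node] = len(col)          # one past the last index of the subtree

    dfs(-1, root)
    return col, beg, end


# Edges: 0-1, 0-2, 1-3 (adjacency lists, node 0 is the root)
tree = [[1, 2], [0, 3], [0], [1]]
color = [5, 7, 5, 9]
col, beg, end = euler_tour(tree, color)
print(col)   # [5, 7, 9, 5]
print(beg)   # [0, 1, 3, 2]
print(end)   # [4, 3, 4, 3]
```

For example, node 1 owns col[1:3] = [7, 9], which is exactly the colors of nodes 1 and 3.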

Now, we have two main steps:

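  1. Traverse the tree to get col, beg, end.
  2. Use a sliding window (two pointers approach) to find, for each right end i, how far to the left we can reach such that all colors in the window are different.

The sliding window uses the following names:

  1. f is the color counter for the current window.
  2. d[i] is the look-back index: the smallest j such that all colors in col[j+1..i] are different (-1 if the window reaches the start of the array).
  3. j is the back pointer of the two pointers; it only moves forward, and the current window is col[j+1..i].

The subtree of x has all different colors exactly when d[end[x] - 1] < beg[x].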
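The roles of f, d and j can be seen in isolation on a plain array. This is a minimal sketch; the function name look_back is an assumption made for this example, and the loop is the same sliding window as in the solution code.

```python
from collections import Counter


def look_back(col):
    """For each i, return the index j just before the longest window
    col[j+1..i] in which all values are different (illustrative sketch)."""
    f = Counter()        # how many times each value occurs in the window
    j = -1               # the window is col[j+1..i]
    d = []
    for i, c in enumerate(col):
        while f[c]:      # c is already in the window: shrink from the left
            j += 1
            f[col[j]] -= 1
        f[c] += 1
        d.append(j)
    return d


print(look_back([5, 7, 9, 5]))   # [-1, -1, -1, 0]
```

Since j never moves backwards, the inner loop runs at most n times in total, which is where the linear bound comes from.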


It is O(n) for time and space.


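from collections import Counter


class Solution:
    def solve(self, tree, color):
        n = len(color)
        beg, end, col, d = [0]*n, [0]*n, [0]*n, [0]*n
        self.idx = 0

        def dfs(par, node):
            # pre-order position of node; its subtree fills [beg, end)
            col[self.idx] = color[node]
            beg[node] = self.idx
            self.idx += 1
            for child in tree[node]:
                if child != par: dfs(node, child)
            end[node] = self.idx

        dfs(-1, 0)
        f = Counter()  # color counts inside the window col[j+1..i]
        j = -1

        for i in range(n):
            # drop elements from the left until col[i] is not in the window
            while f[col[i]]:
                j += 1
                f[col[j]] -= 1
            f[col[i]] += 1
            d[i] = j

        # subtree of x is col[beg[x]..end[x]-1]; it is unique if the window
        # ending at end[x] - 1 reaches back at least to beg[x]
        return sum(d[end[x] - 1] < beg[x] for x in range(n))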
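
To sanity-check either solution on small inputs, a brute-force counter can help. This sketch assumes, as both solutions suggest, that the task is to count the nodes whose subtree contains no repeated color, that tree is a list of adjacency lists, and that node 0 is the root. The function name count_unique_color_subtrees is hypothetical.

```python
def count_unique_color_subtrees(tree, color, root=0):
    """Brute force: collect all colors of each subtree and compare the
    number of distinct colors with the subtree size. O(n^2) worst case."""
    total = 0

    def dfs(par, node):
        nonlocal total
        colors = [color[node]]
        for child in tree[node]:
            if child != par:
                colors.extend(dfs(node, child))
        if len(set(colors)) == len(colors):   # no color repeats
            total += 1
        return colors

    dfs(-1, root)
    return total


# Edges: 0-1, 0-2, 1-3; nodes 0 and 2 share color 5
tree = [[1, 2], [0, 3], [0], [1]]
color = [5, 7, 5, 9]
print(count_unique_color_subtrees(tree, color))   # 3
```

Here the subtrees of nodes 1, 2 and 3 have all different colors, while the subtree of node 0 contains color 5 twice.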


There is also a solution using union find.