https://leetcode.com/problems/number-of-ways-to-reconstruct-a-tree

Quite a painful problem, to be honest. Not a lot of people managed to solve it during the contest.

The first idea we need to understand is that a pair [xi, yi] exists in pairs if and only if xi is an ancestor of yi or yi is an ancestor of xi. Let us illustrate this with an example tree:

    5
    | \
    4   3
    | \
    1   2

We will have the pairs [1, 4], [1, 5], [2, 4], [2, 5], [3, 5], [4, 5], that is, all pairs of nodes in which one is an ancestor of the other. What we can notice is that if it is possible to build such a tree, one node (the root) will be an ancestor of all other nodes, so it appears in a pair with every other node: in our case these are the pairs [1, 5], [2, 5], [3, 5], [4, 5].
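To make the ancestor rule concrete, here is a minimal sketch that generates the pairs from a tree. The function name ancestor_pairs and the parent-map representation are assumptions made for illustration; they are not part of the solution below.

```python
def ancestor_pairs(parent):
    """Return all (smaller, larger) pairs in which one node is an ancestor of the other.

    `parent` is a hypothetical representation: it maps each node to its
    parent, and the root maps to None.
    """
    pairs = set()
    for node in parent:
        anc = parent[node]
        # walk up from node to the root, pairing node with every ancestor
        while anc is not None:
            pairs.add((min(node, anc), max(node, anc)))
            anc = parent[anc]
    return pairs


# the example tree: 5 is the root, 4 and 3 are its children, 1 and 2 are children of 4
example_parent = {5: None, 4: 5, 3: 5, 1: 4, 2: 4}
example_pairs = ancestor_pairs(example_parent)
print(sorted(example_pairs))
```

For the example tree this gives exactly the six pairs listed above, and the root 5 takes part in 4 of them, one for each other node.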

Now, let us go to the algorithm:

  1. First, create a graph from our pairs array.
  2. The function helper(nodes) takes a set of nodes as input and tries to solve the subproblem for this set.
  3. As we discussed, we need to find a node which is connected with all other nodes, that is, its degree should be equal to m = len(nodes) - 1.
  4. If we did not find such a node, we can immediately return 0: it is not possible to reconstruct a tree.
  5. If there is such a node, let us call it root. Actually, it can happen that there is more than one node with this property. In this case any of them can be the root: if we have, say, U possible solutions with the first of them as root, then there will be exactly U possible solutions with another one as root, because we can just swap one of them with the other.
  6. Now, we clean our graph g: we remove all connections between root and all its neighbours (actually it is enough to cut them in one direction only, from each neighbour to root).
  7. The next step is to perform dfs to find connected components.
  8. Finally, we run our helper function on each of the found components and collect the answers in cands. If some subproblem has answer 0, then for the original problem we also need to return 0. If there is no 0 and we met 2 in cands, then we return 2: it means that we found 2 or more solutions for some subproblems and 1 for others, like cands = [1, 2, 1, 1, 2]. In this case we can be sure that we have at least 2 solutions (in fact at least 1*2*1*1*2 = 4, because 2 here stands for "2 or more"). Finally, an important point: if we have len(d[m]) >= 2, we also return 2, see step 5.
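The combining rule from steps 5 and 8 can be sketched in isolation. The names combine and num_root_candidates are hypothetical and used only for illustration; in the solution below this logic sits at the end of helper, with len(d[m]) playing the role of num_root_candidates.

```python
def combine(cands, num_root_candidates):
    """Merge the answers of the subproblems into the answer for the current set.

    cands: answers for the components, each 0, 1 or 2 (2 means "2 or more").
    num_root_candidates: how many nodes are connected to all the others
    (assumed to be at least 1, otherwise we would have returned 0 earlier).
    """
    if 0 in cands:
        # one impossible component makes the whole tree impossible
        return 0
    if 2 in cands or num_root_candidates >= 2:
        # several ways inside a component, or interchangeable roots
        return 2
    return 1


print(combine([1, 2, 1, 1, 2], 1))
print(combine([1, 1], 2))
```

Note that the 0 check comes first: interchangeable roots do not help if one of the components cannot be reconstructed at all.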

Complexity: this is the most interesting part. Suppose we have n nodes in our graph. On each step we split the nodes into several parts and run helper recursively on each part; the sizes of the parts sum to n - 1 on the first step, and so on. So each node can appear in helper(nodes) no more than n times, and this part can be estimated as O(n^2). We also need to account for the line for node in g[root]: g[node].remove(root), which for each run of helper can be estimated as O(E), where E is the total number of edges in g. Finally, there is the search for connected components, which can be estimated as O(E*n): there are no more than n levels of recursion, and on each level each edge is used no more than once. So the final time complexity is O(E*n). I think this estimate can be improved to O(n^2), but at the moment I do not know how. Space complexity is O(n^2), to keep the whole recursion stack.
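As a rough sanity check of these bounds, consider the case where every call removes the root and leaves one single component (a chain). This is an assumed worst case for the recursion depth, not something proved here. The call at depth k then works with n - k nodes, and each level touches each edge at most once:

```latex
\sum_{k=0}^{n-1} (n-k) = \frac{n(n+1)}{2} = O(n^2),
\qquad
\sum_{k=0}^{n-1} O(E) = O(E \cdot n)
```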

from collections import defaultdict

class Solution:
    def checkWays(self, P):
        g = defaultdict(set)
        for u, v in P:
            g[u].add(v)
            g[v].add(u)

        def helper(nodes):
            d, m = defaultdict(list), len(nodes) - 1
            for node in nodes:
                d[len(g[node])].append(node)

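            # no node is connected to all the others: no valid tree exists
            if len(d[m]) == 0: return 0
            root = d[m][0]

            # cut the edges pointing back to root (one direction is enough)
            for node in g[root]: g[node].remove(root)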
            
            comps, seen, i = defaultdict(set), set(), 0
            def dfs(node, i):
                comps[i].add(node)
                seen.add(node)
                for neib in g[node]:
                    if neib not in seen: dfs(neib, i)
                        
            for node in nodes:
                if node != root and node not in seen:
                    dfs(node, i)
                    i += 1
                    
            cands = [helper(comps[i]) for i in comps]
            if 0 in cands: return 0
            if 2 in cands: return 2
            if len(d[m]) >= 2: return 2
            return 1
            
        return helper(set(g.keys()))

If you like the solution, you can upvote it in the LeetCode discussion section: Problem 1719