Problem statement


Let us run dfs with parameters (par, node, state), where

  1. node is the current node.
  2. par is the parent of the current node.
  3. state is a list of m = max(nums) + 1 numbers, where state[i] is the closest ancestor whose value is coprime with i (or -1 if there is no such ancestor).

Each time we visit a node, we first set ans[node] = state[nums[node]] and then run dfs recursively for all children. Before recursing we make a copy state2 of state and update every index that is coprime with nums[node]: for these indexes state2[i] = node.
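The update step described above can be sketched on its own. This is only an illustration: the helper name updated_state is hypothetical and does not appear in the solution below.

```python
from math import gcd

def updated_state(state, node, value):
    """Return a copy of state in which every index coprime with value
    now points at node (hypothetical helper, for illustration only)."""
    return [node if gcd(i, value) == 1 else anc
            for i, anc in enumerate(state)]

# With m = 4 and a root node 0 whose value is 2,
# only indexes 1 and 3 are coprime with 2.
print(updated_state([-1] * 4, 0, 2))
```

Index 0 is updated only when the value is 1, because gcd(0, x) = x.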


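Time complexity is O(n*m), because for every node we scan all m indexes of state. Space is O(H*m), where H is the height of the tree, because every level of recursion keeps its own copy of state.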


from collections import defaultdict
from math import gcd

class Solution:
    def getCoprimes(self, nums, edges):
        n, m = len(nums), max(nums) + 1
        ans = [-1]*n
        G = defaultdict(list)
        for x, y in edges:
            G[x] += [y]
            G[y] += [x]
        def dfs(par, node, state):
            # state[v] is the closest ancestor whose value is coprime with v
            ans[node] = state[nums[node]]
            # the updated copy is the same for every child, so build it once
            state2 = state[:]
            for i in range(m):
                if gcd(i, nums[node]) == 1:
                    state2[i] = node
            for child in G[node]:
                if child == par: continue
                dfs(node, child, state2)
        dfs(-1, 0, [-1]*m)
        return ans
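
To sanity-check a solution like the one above on small inputs, one option is a brute-force version that walks up the parent chain of every node. This is a sketch under assumptions: the function name get_coprimes_bruteforce is hypothetical, node 0 is taken as the root (as in the dfs call above), and the running time is O(n*H), so it is only meant for testing.

```python
from math import gcd

def get_coprimes_bruteforce(nums, edges):
    """Hypothetical cross-check: for every node, climb towards the root
    and stop at the first ancestor with a coprime value."""
    n = len(nums)
    adj = [[] for _ in range(n)]
    for x, y in edges:
        adj[x].append(y)
        adj[y].append(x)

    # BFS from the root (node 0) to find the parent of every node.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order = [0]
    for node in order:          # the list grows while we iterate over it
        for neib in adj[node]:
            if not seen[neib]:
                seen[neib] = True
                parent[neib] = node
                order.append(neib)

    ans = [-1] * n
    for node in range(n):
        anc = parent[node]
        while anc != -1:
            if gcd(nums[anc], nums[node]) == 1:
                ans[node] = anc
                break
            anc = parent[anc]
    return ans

print(get_coprimes_bruteforce([2, 3, 3, 2], [[0, 1], [1, 2], [1, 3]]))
```
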

1617 Count Subtrees With Max Distance Between Cities

[graph, graph algo, dfs, bfs]


The constraints in this problem are quite small: n <= 15, so it is possible to check all 2^n subsets of nodes and keep only those which form a subtree.

  1. dfs(node, mask) is a function to traverse the tree, where we can visit only the nodes allowed by mask. Here self.V is the bitmask of already visited nodes. Each time we add node to the visited bitmask and then traverse all its neighbours. Here I use the condition if not ((~self.V & mask) >> neib) & 1, which checks whether the corresponding bit of mask is turned off or the corresponding bit of self.V is turned on: in both cases we can not visit this neighbour.
  2. diam(node, mask) is a function to find the diameter of the tree: it recursively finds the two longest paths going down from node, updates self.res with the longest path which goes through node, and returns the length of the longest path from node to a leaf, plus one for the edge to the parent.
  3. Finally, we run dfs for each of the 2^n possible bitmasks and if the subset is connected, that is the visited set is equal to mask, we calculate the diameter and increase the corresponding element of ans by one.
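
The bit test from step 1 and the connectivity check from step 3 can be sketched in isolation. This is an illustration only: the names can_visit and is_connected are hypothetical, and the traversal here uses an explicit stack instead of the recursive dfs of the solution.

```python
def can_visit(visited, mask, neib):
    """True if neib belongs to mask and is not yet in visited."""
    return bool(((~visited & mask) >> neib) & 1)

def is_connected(adj, mask):
    """Check that the nodes selected by mask form a connected subgraph.
    adj maps a 0-based node to the list of its neighbours."""
    start = mask.bit_length() - 1      # highest node present in mask
    visited = 1 << start
    stack = [start]
    while stack:
        node = stack.pop()
        for neib in adj[node]:
            if can_visit(visited, mask, neib):
                visited |= 1 << neib
                stack.append(neib)
    return visited == mask

# Path 0 - 1 - 2: nodes {0, 2} are not connected without node 1.
adj = {0: [1], 1: [0, 2], 2: [1]}
print(is_connected(adj, 0b101), is_connected(adj, 0b111))
```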


Time complexity is O(2^n * n), space complexity is O(n).


from collections import defaultdict

class Solution:
    def countSubgraphsForEachDiameter(self, n, edges):
        G = defaultdict(list)
        ans = [0]*(n-1)
        for x, y in edges:
            G[x-1] += [y-1]
            G[y-1] += [x-1]
        def dfs(node, mask):
            self.V |= (1<<node)
            for neib in G[node]:
                # skip neighbours outside of mask or already visited
                if not ((~self.V & mask) >> neib) & 1: continue
                dfs(neib, mask)
        def diam(node, mask):
            self.V |= (1<<node)
            # two longest downward paths starting from node
            maxs = [0, 0]
            for neib in G[node]:
                if not ((~self.V & mask) >> neib) & 1: continue
                depth = diam(neib, mask)
                maxs = sorted(maxs + [depth])[-2:]
            # longest path which goes through node
            self.res = max(self.res, sum(maxs))
            return maxs[1] + 1
        for mask in range(1, 1<<n):
            self.V = 0
            start = mask.bit_length() - 1
            dfs(start, mask)
            # subset is connected if dfs reached every node of mask
            if self.V == mask:
                self.V, self.res = 0, 0
                diam(start, mask)
                # single nodes have diameter 0 and are not counted
                if self.res > 0: ans[self.res - 1] += 1
        return ans

Solution 2

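It is a bit easier to write an O(2^n * n^2) algorithm, which should still pass the given constraints: we can calculate the distances between all pairs of nodes in advance and then, when we look for the diameter of a subset, check all O(n^2) pairs. A subset of k nodes is connected if it contains exactly k - 1 edges of the tree.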


It is O(2^n * n^2) for time and O(n^2) for space.
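
The precomputation of distances can be sketched with explicit Floyd-Warshall loops. This is a hedged illustration: the function names all_pairs_distances and subset_diameter are hypothetical, and the edges are assumed to be 1-based as in the problem input.

```python
from itertools import combinations

def all_pairs_distances(n, edges):
    """Floyd-Warshall on an unweighted tree given by 1-based edges."""
    INF = float("inf")
    dist = [[0 if i == j else INF for j in range(n)] for i in range(n)]
    for x, y in edges:
        dist[x - 1][y - 1] = dist[y - 1][x - 1] = 1
    for k in range(n):                 # intermediate node goes outermost
        for i in range(n):
            for j in range(n):
                if dist[i][k] + dist[k][j] < dist[i][j]:
                    dist[i][j] = dist[i][k] + dist[k][j]
    return dist

def subset_diameter(dist, nodes):
    """Largest pairwise distance among nodes (at least two 0-based nodes).
    Equals the diameter only if the nodes form a connected subtree."""
    return max(dist[i][j] for i, j in combinations(nodes, 2))

# Star with centre 2 and leaves 1, 3, 4 (1-based labels).
dist = all_pairs_distances(4, [[1, 2], [2, 3], [2, 4]])
print(subset_diameter(dist, (0, 1, 2)))
```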


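from itertools import combinations, permutations

class Solution:
    def countSubgraphsForEachDiameter(self, n, edges):
        # 999 plays the role of infinity, distances are at most n - 1
        g = [[999 for _ in range(n)] for _ in range(n)]
        for [i, j] in edges:
            g[i - 1][j - 1], g[j - 1][i - 1] = 1, 1

        # Floyd-Warshall: the first element k changes slowest
        for k, i, j in permutations(range(n), 3):
            g[i][j] = min(g[i][j], g[i][k] + g[k][j])

        ans = [0 for _ in range(n - 1)]
        for k in range(2, n + 1):
            for s in combinations(range(n), k):
                # e is the number of tree edges inside subset s
                e = sum(g[i][j] for i, j in combinations(s, 2) if g[i][j] == 1)
                d = max(g[i][j] for i, j in combinations(s, 2))
                # k nodes with k - 1 edges form a connected subtree
                if e == k - 1:
                    ans[d - 1] += 1
        return ans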