Problem statement

https://leetcode.com/problems/number-of-squareful-arrays/

Solution 1

Let us consider states (last, BitMask), where last is the index of the last element in the array built so far and BitMask is the bitmask of visited indices. There are O(2^n * n) states, each with O(n) transitions. A transition from index i to index j is possible if A[i] + A[j] is a perfect square. Because the DP works on indices, equal values of A are counted as distinct, so in the end we need to divide the answer by the product of the factorials of the value frequencies.
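The two ingredients above, the perfect-square test and the correction for duplicates, can be sketched in isolation. The helper names is_square and duplicate_factor are made up for this illustration and are not part of the solution below.

```python
from collections import Counter
from math import factorial, isqrt, prod

def is_square(x):
    """Return True if the non-negative integer x is a perfect square."""
    r = isqrt(x)  # exact integer square root, no floating point involved
    return r * r == x

def duplicate_factor(A):
    """Product of the factorials of the value frequencies in A."""
    return prod(factorial(c) for c in Counter(A).values())
```

For A = [2, 2, 2] every pair sums to 4, so all 3! = 6 orderings of the indices are valid; dividing by duplicate_factor(A) = 6 leaves the single distinct array.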

Complexity

Time complexity is O(n^2 * 2^n), space complexity is O(n * 2^n).

Code

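from collections import Counter
from itertools import permutations
from math import factorial, isqrt, prod

class Solution:
    def numSquarefulPerms(self, A):
        n = len(A)
        # dp[mask][j]: number of valid orderings of the indices in mask that end at index j
        dp = [[0] * n for _ in range(1 << n)]
        for i in range(n):
            dp[1 << i][i] = 1

        for mask in range(1 << n):
            n_z_bits = [j for j in range(n) if mask & (1 << j)]
            for j, k in permutations(n_z_bits, 2):
                s = A[k] + A[j]
                # isqrt works on integers, so it avoids floating-point rounding
                if isqrt(s) ** 2 == s:
                    dp[mask][j] += dp[mask ^ (1 << j)][k]

        # Equal values were counted as distinct indices; divide out the overcount
        return sum(dp[-1]) // prod(factorial(i) for i in Counter(A).values())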

Solution 2

One optimization is to build a graph first and then traverse only existing edges: the vertices are indices, and there is an edge between k and j if A[k] + A[j] is a perfect square. The goal is then to count the Hamiltonian paths in this graph, which is very similar to the TSP.

Complexity

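Time complexity is O(E * 2^n), where E is the number of (directed) edges in the graph: for every mask, each edge (last, prev) is examined at most once. Since E <= n(n-1), this is never worse than Solution 1. Space complexity is O(n * 2^n).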
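To get a feel for how large E is on a given input, the sketch below counts the directed edges of the graph. The function name count_edges is made up for this illustration.

```python
from itertools import permutations
from math import isqrt

def count_edges(A):
    """Count ordered index pairs (k, j), k != j, with A[k] + A[j] a perfect square."""
    return sum(
        1
        for k, j in permutations(range(len(A)), 2)
        if isqrt(A[k] + A[j]) ** 2 == A[k] + A[j]
    )
```

For A = [1, 17, 8] only 1 + 8 = 9 and 17 + 8 = 25 are perfect squares, so there are 4 directed edges out of a possible n(n-1) = 6.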

Code

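from collections import Counter, defaultdict
from functools import lru_cache
from itertools import product
from math import factorial, isqrt, prod

class Solution:
    def numSquarefulPerms(self, A):
        n = len(A)
        # graph[k] lists the indices j != k such that A[k] + A[j] is a perfect square
        graph = defaultdict(list)
        for k, j in product(range(n), range(n)):
            if k != j and isqrt(A[k] + A[j]) ** 2 == A[k] + A[j]:
                graph[k].append(j)

        @lru_cache(None)
        def dfs(last, BitMask):
            # Number of valid orderings of the indices in BitMask that end at last
            if 1 << last == BitMask:
                return 1

            ans = 0
            for prev in graph[last]:
                if BitMask & (1 << prev):
                    ans += dfs(prev, BitMask ^ (1 << last))
            return ans

        ans = sum(dfs(i, (1 << n) - 1) for i in range(n))
        # Equal values were counted as distinct indices; divide out the overcount
        return ans // prod(factorial(i) for i in Counter(A).values())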
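
Either solution can be sanity-checked against a brute force on small inputs. The sketch below is not part of the original solutions: the name brute_force is made up for this illustration, and it simply enumerates the distinct value permutations, so it is only usable for small n.

```python
from itertools import permutations
from math import isqrt

def brute_force(A):
    """Count the distinct permutations of A whose adjacent sums are all perfect squares."""
    def is_square(x):
        r = isqrt(x)
        return r * r == x

    # set(...) collapses permutations that are equal as value sequences,
    # so no division by factorials is needed here
    return sum(
        all(is_square(a + b) for a, b in zip(p, p[1:]))
        for p in set(permutations(A))
    )
```

For example, brute_force([1, 17, 8]) counts [1, 8, 17] and [17, 8, 1], and brute_force([2, 2, 2]) counts the single array [2, 2, 2]. Comparing its result with numSquarefulPerms on small random arrays is a cheap way to catch mistakes in the bitmask logic.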