Problem statement

https://binarysearch.com/problems/Adjacent-Square-Sums/

Solution

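This problem is equivalent to LeetCode 996, Number of Squareful Arrays: count the distinct permutations of A in which every pair of adjacent elements sums to a perfect square. The solution below uses a bitmask DP over subsets of indices. It then divides out the orderings that differ only by swapping equal values.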
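To make the counted quantity concrete, here is a minimal brute-force sketch. It is not the author's solution, and the helper names `is_square` and `brute_force` are hypothetical. It enumerates every ordering, collapses duplicates with `set`, and counts the orderings whose adjacent sums are all perfect squares. It assumes the inputs are integers and is only practical for small n (it is O(n! * n)).

```python
from itertools import permutations
from math import isqrt


def is_square(s):
    # Negative sums can never be perfect squares; isqrt avoids float rounding.
    return s >= 0 and isqrt(s) ** 2 == s


def brute_force(A):
    # set() collapses index orderings that yield the same value sequence,
    # so each distinct permutation is counted once.
    distinct = set(permutations(A))
    return sum(
        all(is_square(p[i] + p[i + 1]) for i in range(len(p) - 1))
        for p in distinct
    )
```

For example, `brute_force([1, 17, 8])` returns 2 (the orderings [1, 8, 17] and [17, 8, 1]), and `brute_force([2, 2, 2])` returns 1.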

Complexity

Time complexity is O(n^2 * 2^n) and space complexity is O(n * 2^n), where n = len(A).
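The bounds follow from counting the loop work and the table size in the DP. Each of the 2^n masks is visited once. For a mask with m set bits, the inner loop runs over m(m - 1) ordered pairs (j, k), which is at most n(n - 1), and each pair does constant work. The table stores one count per (mask, last index) pair:

```latex
\underbrace{2^n}_{\text{masks}} \cdot \underbrace{n(n-1)}_{\text{ordered pairs }(j,k)} \cdot O(1) = O\!\left(n^2 \, 2^n\right),
\qquad
\underbrace{2^n}_{\text{masks}} \cdot \underbrace{n}_{\text{last index } j} = O\!\left(n \, 2^n\right)
```

Building the list of set bits costs O(n) per mask, which the pair loop dominates.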

Code

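from collections import Counter
from itertools import permutations
from math import factorial, isqrt, prod


class Solution:
    def solve(self, A):
        n = len(A)
        # dp[mask][j]: number of orderings of the indices in mask that end at
        # index j and in which every adjacent sum is a perfect square
        dp = [[0] * n for _ in range(1 << n)]
        for i in range(n):
            dp[1 << i][i] = 1

        for mask in range(1 << n):
            bits = [j for j in range(n) if mask & (1 << j)]
            # Append j to an ordering of mask without j that ends at k.
            for j, k in permutations(bits, 2):
                s = A[k] + A[j]
                # isqrt is exact for large ints, unlike int(sqrt(s))
                if s >= 0 and isqrt(s) ** 2 == s:
                    dp[mask][j] += dp[mask ^ (1 << j)][k]

        # Orderings that only swap equal values give the same permutation, so
        # divide by the product of the factorials of each value's multiplicity.
        return sum(dp[-1]) // prod(factorial(c) for c in Counter(A).values())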
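For comparison, here is a sketch of a different technique that is commonly used for LeetCode 996. It is not the method above. It backtracks over the multiset of distinct values with a `Counter`, so duplicate orderings are never generated and no division by factorials is needed. The function names are hypothetical. The code assumes integer inputs, and its worst case is still exponential, although the adjacency check prunes most branches.

```python
from collections import Counter
from math import isqrt


def is_square(s: int) -> bool:
    return s >= 0 and isqrt(s) ** 2 == s


def count_squareful(nums) -> int:
    counts = Counter(nums)
    # For each distinct value, the distinct values that may follow it.
    neighbors = {x: [y for y in counts if is_square(x + y)] for x in counts}

    def extend(last: int, remaining: int) -> int:
        # Number of ways to place `remaining` more values after `last`.
        if remaining == 0:
            return 1
        total = 0
        for y in neighbors[last]:
            if counts[y] > 0:
                counts[y] -= 1
                total += extend(y, remaining - 1)
                counts[y] += 1  # undo the choice (backtrack)
        return total

    total = 0
    # Each distinct value is tried once as the first element.
    for x in list(counts):
        counts[x] -= 1
        total += extend(x, len(nums) - 1)
        counts[x] += 1
    return total
```

Because the loops run over distinct values rather than indices, `count_squareful([2, 2, 2])` returns 1 directly. The DP above, in contrast, first counts 3! = 6 index orderings and then divides by 3!.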