Formatted question description: https://leetcode.ca/all/996.html

# 996. Number of Squareful Arrays (Hard)

Given an array A of non-negative integers, the array is squareful if for every pair of adjacent elements, their sum is a perfect square.

Return the number of permutations of A that are squareful.  Two permutations A1 and A2 differ if and only if there is some index i such that A1[i] != A2[i].

Example 1:

Input: [1,17,8]
Output: 2
Explanation:
[1,8,17] and [17,8,1] are the valid permutations.


Example 2:

Input: [2,2,2]
Output: 1


Note:

1. 1 <= A.length <= 12
2. 0 <= A[i] <= 1e9

Related Topics:
Math, Backtracking, Graph

Similar Questions:

## Solution 1. Backtracking

Brute force: next_permutation + squareful check.

Issue: if the leading x elements are not squareful, the result can’t be squareful, no matter how the remaining elements are permuted. So we need to make sure the permutation stays squareful as it is built.

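Solution: instead of calling next_permutation, generate permutations with swap-based backtracking. Place one element at a time, and only accept an element whose sum with the previously placed element is a perfect square. This discards a leading section as soon as it stops being squareful. At each position, skip values that have already been tried there, so that duplicate values don't produce the same permutation twice.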

// OJ: https://leetcode.com/problems/number-of-squareful-arrays/

// Time: O(N!)
// Space: O(N^2)
// Swap-based backtracking over permutations of A. Position `start` is filled
// only with a value whose sum with A[start - 1] is a perfect square, so any
// prefix that is already not squareful is pruned immediately. A per-depth set
// of already-tried values skips duplicate values, so each distinct permutation
// is counted exactly once.
class Solution {
    int ans = 0;

    // Returns true if n is a perfect square. Callers pass sums of two values in
    // [0, 1e9], so n <= 2e9 and fits in int.
    bool isSquare(int n) {
        int r = static_cast<int>(std::sqrt(n));
        return r * r == n;
    }

    // Fills position `start` with each eligible value in turn and recurses.
    // A[0..start) is the already-settled squareful prefix; every swap is undone
    // before the next candidate is tried.
    void dfs(std::vector<int> &A, int start) {
        const int n = static_cast<int>(A.size());
        if (start == n) {
            ++ans;
            return;
        }
        std::unordered_set<int> tried;  // values already placed at `start`
        for (int i = start; i < n; ++i) {
            // The same value at the same position leads to an identical subtree.
            if (tried.count(A[i])) continue;
            // Prune: the new element must form a square sum with its predecessor.
            if (start > 0 && !isSquare(A[start - 1] + A[i])) continue;
            tried.insert(A[i]);
            std::swap(A[start], A[i]);
            dfs(A, start + 1);
            std::swap(A[start], A[i]);
        }
    }
public:
    // Returns the number of distinct permutations of A in which every adjacent
    // pair sums to a perfect square. A is permuted during the search and is back
    // in its original order on return.
    int numSquarefulPerms(std::vector<int>& A) {
        ans = 0;  // reset so repeated calls on one instance don't accumulate
        dfs(A, 0);
        return ans;
    }
};


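Java

Alternative approach (bitmask DP over Hamiltonian paths):
- Treat the indices as graph nodes, and join two nodes when their values sum to a perfect square.
- Count the paths that visit every node exactly once, using a memoized DFS over (current node, visited bitmask).
- Divide the result by the factorial of each value's multiplicity, because orderings that differ only by swapping equal values produce the same array.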

class Solution {
    /**
     * Counts the distinct permutations of {@code A} in which every pair of adjacent
     * elements sums to a perfect square.
     *
     * <p>Indices are treated as graph nodes, joined by an edge when their values sum to a
     * perfect square. Every squareful arrangement of indices is a path that visits each node
     * exactly once. Such paths are counted with a memoized DFS over (current node, visited
     * bitmask). Index paths that differ only by swapping equal values yield the same array,
     * so the total is divided by {@code c!} for each value that occurs {@code c} times.
     *
     * @param A input values; the memo has {@code 2^A.length} columns, so this is intended for
     *     short arrays (the problem guarantees {@code 1 <= A.length <= 12})
     * @return the number of distinct squareful permutations of {@code A}
     */
    public int numSquarefulPerms(int[] A) {
        int length = A.length;
        Map<Integer, List<Integer>> graphMap = new HashMap<Integer, List<Integer>>();
        for (int i = 0; i < length; i++) {
            for (int j = i + 1; j < length; j++) {
                if (isPerfectSquare((long) A[i] + A[j])) {
                    // Undirected edge: record each endpoint in the other's adjacency list.
                    graphMap.computeIfAbsent(i, k -> new ArrayList<Integer>()).add(j);
                    graphMap.computeIfAbsent(j, k -> new ArrayList<Integer>()).add(i);
                }
            }
        }

        Integer[][] memo = new Integer[length][1 << length];
        int permutations = 0;
        for (int start = 0; start < length; start++) {
            permutations += depthFirstSearch(start, 1 << start, length, graphMap, memo);
        }

        // Multiplicity of each distinct value.
        Map<Integer, Integer> countMap = new HashMap<Integer, Integer>();
        for (int num : A) {
            countMap.merge(num, 1, Integer::sum);
        }
        // Each division is exact: the index-path count equals (distinct count) * prod(c!).
        for (int count : countMap.values()) {
            permutations /= factorial(count);
        }
        return permutations;
    }

    /**
     * Counts the ways to extend a path ending at {@code node} into one that visits every
     * index exactly once.
     *
     * @param node the index at the end of the current path
     * @param visited bitmask of indices already on the path (including {@code node})
     * @param length number of indices in the graph
     * @param graphMap adjacency lists; an index with no square-sum partner has no entry
     * @param memo cache indexed by {@code [node][visited]}; {@code null} means not computed
     * @return the number of completions of the current path
     */
    public int depthFirstSearch(int node, int visited, int length, Map<Integer, List<Integer>> graphMap, Integer[][] memo) {
        if (visited == (1 << length) - 1) {
            return 1;
        }
        if (memo[node][visited] != null) {
            return memo[node][visited];
        }
        int permutations = 0;
        List<Integer> neighbors = graphMap.get(node);
        if (neighbors != null) {
            for (int next : neighbors) {
                if ((visited & (1 << next)) == 0) {
                    permutations += depthFirstSearch(next, visited | (1 << next), length, graphMap, memo);
                }
            }
        }
        memo[node][visited] = permutations;
        return permutations;
    }

    /** Returns true when the non-negative {@code n} is a perfect square. */
    private static boolean isPerfectSquare(long n) {
        long root = (long) Math.sqrt((double) n);
        return root * root == n;
    }

    /** Returns {@code n!}; callers pass value multiplicities, at most the array length. */
    private static int factorial(int n) {
        int result = 1;
        for (int i = 2; i <= n; i++) {
            result *= i;
        }
        return result;
    }
}