Formatted question description: https://leetcode.ca/all/996.html

996. Number of Squareful Arrays (Hard)

Given an array A of non-negative integers, the array is squareful if for every pair of adjacent elements, their sum is a perfect square.

Return the number of permutations of A that are squareful.  Two permutations A1 and A2 differ if and only if there is some index i such that A1[i] != A2[i].

 

Example 1:

Input: [1,17,8]
Output: 2
Explanation: 
[1,8,17] and [17,8,1] are the valid permutations.

Example 2:

Input: [2,2,2]
Output: 1

 

Note:

  1. 1 <= A.length <= 12
  2. 0 <= A[i] <= 1e9

Related Topics:
Math, Backtracking, Graph

Similar Questions:

Solution 1. Backtracking

Brute force: next_permutation + squareful check.

Issue: if the leading x elements are not squareful, the result can’t be squareful no matter how the remaining elements are permuted. So we need to make sure the permutation stays squareful as it is being built.

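Solution: instead of calling next_permutation, build each permutation one position at a time with swap-based backtracking. Each newly placed element is checked against its predecessor, and the branch is pruned as soon as the settled prefix is no longer squareful. Duplicate values are skipped at the same position, so each distinct permutation is counted once.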

// OJ: https://leetcode.com/problems/number-of-squareful-arrays/
// Time: O(N!)
// Space: O(N^2)
class Solution {
    // Number of squareful permutations found by the current call.
    int ans = 0;
    // Returns true when n is a perfect square. Takes long long so the sum of
    // two ints can be passed in without overflowing.
    inline bool isSquare(long long n) {
        if (n < 0) return false;
        long long r = (long long)sqrt((double)n);
        // Nudge r to the exact floor(sqrt(n)) in case the floating-point root
        // was rounded to a neighbouring integer.
        while (r * r > n) --r;
        while ((r + 1) * (r + 1) <= n) ++r;
        return r * r == n;
    }
    // Fills position `start` with each distinct remaining value that forms a
    // square sum with A[start - 1], recurses, and counts every full arrangement.
    // A is restored by the paired swaps before returning.
    void dfs(vector<int> &A, int start) {
        if (start == (int)A.size()) {
            ++ans;
            return;
        }
        unordered_set<int> s; // values already tried at position `start`
        for (int i = start; i < (int)A.size(); ++i) {
            if (s.count(A[i])) continue; // same value here => same permutations
            if (start > 0 && !isSquare((long long)A[start - 1] + A[i])) continue;
            s.insert(A[i]);
            swap(A[start], A[i]);
            dfs(A, start + 1);
            swap(A[start], A[i]);
        }
    }
public:
    int numSquarefulPerms(vector<int>& A) {
        ans = 0; // reset so repeated calls on the same object don't accumulate
        dfs(A, 0);
        return ans;
    }
};

Java

  • class Solution {
        /**
         * Counts the distinct permutations of {@code A} in which every pair of
         * adjacent elements sums to a perfect square.
         *
         * <p>Each index is a graph node. Two indices are connected when their
         * values sum to a perfect square. Hamiltonian paths over indices are
         * counted with bitmask memoization. Paths that differ only by swapping
         * equal values describe the same arrangement, so the total is divided by
         * count! for each distinct value.
         */
        public int numSquarefulPerms(int[] A) {
            int n = A.length;

            // Adjacency lists keyed by index; only indices with at least one
            // square-sum partner get an entry.
            Map<Integer, List<Integer>> adjacency = new HashMap<Integer, List<Integer>>();
            for (int u = 0; u < n; u++) {
                for (int v = u + 1; v < n; v++) {
                    int total = A[u] + A[v];
                    int root = (int) Math.sqrt(total);
                    if (root * root != total)
                        continue;
                    adjacency.computeIfAbsent(u, key -> new ArrayList<Integer>()).add(v);
                    adjacency.computeIfAbsent(v, key -> new ArrayList<Integer>()).add(u);
                }
            }

            // fact[k] == k!; 12! still fits in an int.
            int[] fact = new int[13];
            fact[0] = 1;
            for (int k = 1; k < fact.length; k++)
                fact[k] = fact[k - 1] * k;

            Integer[][] memo = new Integer[n][1 << n];
            int paths = 0;
            for (int start = 0; start < n; start++)
                paths += depthFirstSearch(start, 1 << start, n, adjacency, memo);

            // Collapse index-level paths into value-level permutations.
            Map<Integer, Integer> freq = new HashMap<Integer, Integer>();
            for (int value : A)
                freq.merge(value, 1, Integer::sum);
            for (int c : freq.values())
                paths /= fact[c];
            return paths;
        }

        /**
         * Returns the number of ways to extend a path that currently ends at
         * {@code node} and has visited the index set {@code visited} (a bitmask)
         * until every one of the {@code length} indices has been visited.
         * Results are cached in {@code memo[node][visited]}.
         */
        public int depthFirstSearch(int node, int visited, int length, Map<Integer, List<Integer>> graphMap, Integer[][] memo) {
            int full = (1 << length) - 1;
            if (visited == full)
                return 1;
            Integer cached = memo[node][visited];
            if (cached != null)
                return cached;
            int count = 0;
            for (int next : graphMap.getOrDefault(node, new ArrayList<Integer>())) {
                int bit = 1 << next;
                if ((visited & bit) != 0)
                    continue; // index already on the path
                count += depthFirstSearch(next, visited | bit, length, graphMap, memo);
            }
            memo[node][visited] = count;
            return count;
        }
    }
    
  • // OJ: https://leetcode.com/problems/number-of-squareful-arrays/
    // Time: O(N!)
    // Space: O(N^2)
    class Solution {
        // Number of squareful permutations found by the current call.
        int ans = 0;
        // Returns true when n is a perfect square. Takes long long so the sum of
        // two ints can be passed in without overflowing.
        inline bool isSquare(long long n) {
            if (n < 0) return false;
            long long r = (long long)sqrt((double)n);
            // Nudge r to the exact floor(sqrt(n)) in case the floating-point
            // root was rounded to a neighbouring integer.
            while (r * r > n) --r;
            while ((r + 1) * (r + 1) <= n) ++r;
            return r * r == n;
        }
        // Fills position `start` with each distinct remaining value that forms
        // a square sum with A[start - 1], recurses, and counts every full
        // arrangement. A is restored by the paired swaps before returning.
        void dfs(vector<int> &A, int start) {
            if (start == (int)A.size()) {
                ++ans;
                return;
            }
            unordered_set<int> s; // values already tried at position `start`
            for (int i = start; i < (int)A.size(); ++i) {
                if (s.count(A[i])) continue; // same value here => same permutations
                if (start > 0 && !isSquare((long long)A[start - 1] + A[i])) continue;
                s.insert(A[i]);
                swap(A[start], A[i]);
                dfs(A, start + 1);
                swap(A[start], A[i]);
            }
        }
    public:
        int numSquarefulPerms(vector<int>& A) {
            ans = 0; // reset so repeated calls on the same object don't accumulate
            dfs(A, 0);
            return ans;
        }
    };
    
  • # 996. Number of Squareful Arrays
    # https://leetcode.com/problems/number-of-squareful-arrays/
    
    class Solution:
        def numSquarefulPerms(self, nums: List[int]) -> int:
            """Count the distinct permutations of ``nums`` that are squareful.

            An arrangement is squareful when every pair of adjacent elements
            sums to a perfect square. Permutations that differ only by
            swapping equal values are counted once.

            Args:
                nums: Non-negative integers. The caller's list is left
                    unmodified; a sorted copy is used internally.

            Returns:
                The number of distinct squareful permutations.
            """
            # Sort a copy so equal values are adjacent (required by the
            # duplicate-skipping rule below) without mutating the caller's list.
            vals = sorted(nums)
            n = len(vals)
            used = [False] * n

            @cache
            def is_square(m: int) -> bool:
                # math.isqrt is exact for arbitrarily large ints, unlike
                # int(math.sqrt(m)), which goes through a float.
                root = math.isqrt(m)
                return root * root == m

            def count(depth: int, last: int) -> int:
                """Return the completions of a prefix of length ``depth`` ending in ``last``."""
                if depth == n:
                    return 1
                total = 0
                for i in range(n):
                    if used[i]:
                        continue
                    # Among equal values, only the leftmost unused copy may be
                    # placed at this position; this avoids counting duplicates.
                    if i > 0 and vals[i] == vals[i - 1] and not used[i - 1]:
                        continue
                    if depth > 0 and not is_square(last + vals[i]):
                        continue
                    used[i] = True
                    total += count(depth + 1, vals[i])
                    used[i] = False
                return total

            return count(0, 0)
    
    

All Problems

All Solutions