Welcome to Subscribe On Youtube
Formatted question description: https://leetcode.ca/all/1803.html
1803. Count Pairs With XOR in a Range
Level
Hard
Description
Given a (0-indexed) integer array nums and two integers low and high, return the number of nice pairs.
A nice pair is a pair (i, j) where 0 <= i < j < nums.length and low <= (nums[i] XOR nums[j]) <= high.
Example 1:
Input: nums = [1,4,2,7], low = 2, high = 6
Output: 6
Explanation: All nice pairs (i, j) are as follows:
- (0, 1): nums[0] XOR nums[1] = 5
- (0, 2): nums[0] XOR nums[2] = 3
- (0, 3): nums[0] XOR nums[3] = 6
- (1, 2): nums[1] XOR nums[2] = 6
- (1, 3): nums[1] XOR nums[3] = 3
- (2, 3): nums[2] XOR nums[3] = 5
Example 2:
Input: nums = [9,8,4,2,1], low = 5, high = 14
Output: 8
Explanation: All nice pairs (i, j) are as follows:
- (0, 2): nums[0] XOR nums[2] = 13
- (0, 3): nums[0] XOR nums[3] = 11
- (0, 4): nums[0] XOR nums[4] = 8
- (1, 2): nums[1] XOR nums[2] = 12
- (1, 3): nums[1] XOR nums[3] = 10
- (1, 4): nums[1] XOR nums[4] = 9
- (2, 3): nums[2] XOR nums[3] = 6
- (2, 4): nums[2] XOR nums[4] = 5
Constraints:
1 <= nums.length <= 2 * 10^4
1 <= nums[i] <= 2 * 10^4
1 <= low <= high <= 2 * 10^4
Solution
Let highBit be the highest power of two not exceeding high (the value of high's most significant set bit), so that high < highBit * 2. Sort the array nums and count the pairs using nested loops with pruning.
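Let 0 <= i < j < nums.length. If nums[i] < highBit * 2 and nums[j] >= highBit * 2, stop searching further indices j for the current index i.

This is safe because nums is sorted, so every later nums[j] is also at least highBit * 2. Each such nums[j] has a set bit at or above highBit * 2 that nums[i] lacks. Therefore nums[i] XOR nums[j] >= highBit * 2 > high, and no later j can form a nice pair. This is the pruning operation.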
Finally, return the number of nice pairs.
-
class Solution {
    /**
     * Counts nice pairs: index pairs (i, j) with i < j and
     * low <= (nums[i] ^ nums[j]) <= high.
     *
     * <p>Sorts {@code nums} and scans all pairs, pruning the inner loop once
     * nums[i] < threshold <= nums[j], where threshold = 2 * highBit > high.
     * Because the array is sorted, every later nums[j] also has a set bit at or
     * above threshold that nums[i] lacks. Their XOR therefore exceeds
     * {@code high}, and no further j can qualify for this i.
     *
     * @param nums values to pair up; NOTE: sorted in place as a side effect
     * @param low  inclusive lower bound on the XOR
     * @param high inclusive upper bound on the XOR
     * @return the number of nice pairs
     */
    public int countPairs(int[] nums, int low, int high) {
        // Twice the highest power of two <= high, computed with exact integer bit logic.
        // This avoids floating-point rounding in log(high) / log(2); the pruning is only
        // sound when threshold > high.
        int threshold = Integer.highestOneBit(high) << 1;
        int pairs = 0;
        Arrays.sort(nums);
        int length = nums.length;
        for (int i = 0; i < length; i++) {
            for (int j = i + 1; j < length; j++) {
                int xor = nums[i] ^ nums[j];
                if (xor >= low && xor <= high) {
                    pairs++;
                } else if (nums[i] < threshold && nums[j] >= threshold) {
                    // Sorted: all remaining nums[j] also produce XOR >= threshold > high.
                    break;
                }
            }
        }
        return pairs;
    }
}
-
// OJ: https://leetcode.com/problems/count-pairs-with-xor-in-a-range/
// Time: O(N * BIT)  -- one range query and one insertion per element
// Space: O(N * BIT) -- each insertion creates at most BIT + 1 trie nodes
struct TrieNode {
    int next[2] = {0, 0};  // pool indices of the 0/1 children; 0 = absent (index 0 is the root, never a child)
    int cnt = 0;           // how many inserted values pass through this node
};
const int BIT = 15;  // highest bit examined; assumes 0 <= value < 2^(BIT + 1) (problem bound: 2 * 10^4)

class Solution {
    // Trie storage owned by this object; pool[0] is the root. Indices into a vector
    // replace raw `new`, so every node is released automatically (no leak).
    std::vector<TrieNode> pool;

    // Inserts n, incrementing cnt on every node along its path, root included,
    // so that a query whose range covers everything can return the root's count.
    void addNode(int n) {
        int node = 0;
        ++pool[node].cnt;
        for (int i = BIT; i >= 0; --i) {
            const int b = n >> i & 1;
            if (pool[node].next[b] == 0) {
                // Record the index before growing: emplace_back may reallocate the pool.
                const int child = static_cast<int>(pool.size());
                pool.emplace_back();
                pool[node].next[b] = child;
            }
            node = pool[node].next[b];
            ++pool[node].cnt;
        }
    }

    // Counts inserted values y below `node` with low <= (n ^ y) <= high.
    // [rangeMin, rangeMax] is the set of XOR values still reachable from `node`:
    // bits above i are fixed by the path taken so far, bits i..0 are free.
    int count(int node, int n, int low, int high, int i, int rangeMin, int rangeMax) const {
        if (rangeMin >= low && rangeMax <= high) return pool[node].cnt;  // fully inside
        if (rangeMax < low || rangeMin > high) return 0;                  // disjoint
        // Partial overlap. Below the last bit rangeMin == rangeMax, so one of the tests
        // above always fires there and `1 << i` is never evaluated with i < 0.
        const int mask = 1 << i;
        const int b = n >> i & 1;
        int ans = 0;
        if (const int same = pool[node].next[b])  // XOR bit i is 0
            ans += count(same, n, low, high, i - 1, rangeMin & ~mask, rangeMax & ~mask);
        if (const int diff = pool[node].next[b ^ 1])  // XOR bit i is 1
            ans += count(diff, n, low, high, i - 1, rangeMin | mask, rangeMax | mask);
        return ans;
    }

public:
    // Returns the number of pairs i < j with low <= (A[i] ^ A[j]) <= high.
    // Each element is matched against all earlier elements already in the trie.
    int countPairs(std::vector<int>& A, int low, int high) {
        pool.assign(1, TrieNode{});  // fresh trie holding only the root
        int ans = 0;
        for (int n : A) {
            ans += count(0, n, low, high, BIT, 0, (1 << (BIT + 1)) - 1);
            addNode(n);
        }
        return ans;
    }
};
-
class Trie:
    """Binary trie over bits 15..0 of the inserted integers.

    Every node keeps ``cnt``: the number of inserted values whose path runs
    through it (the root's own ``cnt`` is never incremented).
    """

    def __init__(self):
        self.children = [None, None]
        self.cnt = 0

    def insert(self, x):
        """Add ``x``, bumping ``cnt`` on each node along its 16-bit path."""
        cur = self
        for bit in reversed(range(16)):
            side = (x >> bit) & 1
            nxt = cur.children[side]
            if nxt is None:
                nxt = Trie()
                cur.children[side] = nxt
            nxt.cnt += 1
            cur = nxt

    def search(self, x, limit):
        """Return how many inserted values ``y`` satisfy ``x ^ y < limit``.

        Walks the path where ``x ^ y`` equals ``limit`` bit by bit. Wherever
        ``limit`` has a 1, the branch making that XOR bit 0 lies entirely below
        ``limit``, so its whole count is added.
        """
        total = 0
        cur = self
        for bit in reversed(range(16)):
            if cur is None:
                break
            side = (x >> bit) & 1
            limit_bit = (limit >> bit) & 1
            if limit_bit == 0:
                # XOR bit must be 0 to stay tied with limit: follow the same side.
                cur = cur.children[side]
                continue
            same = cur.children[side]
            if same:
                total += same.cnt
            cur = cur.children[side ^ 1]
        return total


class Solution:
    def countPairs(self, nums: List[int], low: int, high: int) -> int:
        """Count pairs i < j with low <= nums[i] ^ nums[j] <= high.

        Each value is compared against the earlier values already in the trie.
        The count in [low, high] is (# XOR < high + 1) minus (# XOR < low).
        """
        seen = Trie()
        pairs = 0
        for value in nums:
            pairs += seen.search(value, high + 1) - seen.search(value, low)
            seen.insert(value)
        return pairs