# 1803. Count Pairs With XOR in a Range

Hard

## Description

Given a (0-indexed) integer array nums and two integers low and high, return the number of nice pairs.

A nice pair is a pair (i, j) where 0 <= i < j < nums.length and low <= (nums[i] XOR nums[j]) <= high.

Example 1:

Input: nums = [1,4,2,7], low = 2, high = 6

Output: 6

Explanation: All nice pairs (i, j) are as follows:

- (0, 1): nums[0] XOR nums[1] = 5
- (0, 2): nums[0] XOR nums[2] = 3
- (0, 3): nums[0] XOR nums[3] = 6
- (1, 2): nums[1] XOR nums[2] = 6
- (1, 3): nums[1] XOR nums[3] = 3
- (2, 3): nums[2] XOR nums[3] = 5


Example 2:

Input: nums = [9,8,4,2,1], low = 5, high = 14

Output: 8

Explanation: All nice pairs (i, j) are as follows:

- (0, 2): nums[0] XOR nums[2] = 13
- (0, 3): nums[0] XOR nums[3] = 11
- (0, 4): nums[0] XOR nums[4] = 8
- (1, 2): nums[1] XOR nums[2] = 12
- (1, 3): nums[1] XOR nums[3] = 10
- (1, 4): nums[1] XOR nums[4] = 9
- (2, 3): nums[2] XOR nums[3] = 6
- (2, 4): nums[2] XOR nums[4] = 5


Constraints:

• 1 <= nums.length <= 2 * 10^4
• 1 <= nums[i] <= 2 * 10^4
• 1 <= low <= high <= 2 * 10^4

## Solution

Let highBit be the highest set bit of high, i.e. the largest power of two not exceeding high. Sort nums and count the pairs with two nested loops, pruning the inner loop as described below.

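For indices 0 <= i < j < nums.length in the sorted array, if nums[i] < highBit * 2 and nums[j] >= highBit * 2, stop searching further indices j for the current index i. This pruning is safe for three reasons. First, nums[j] has a set bit worth at least highBit * 2 that nums[i] lacks, so nums[i] XOR nums[j] >= highBit * 2 > high. Second, the array is sorted, so every later nums[j] is also >= highBit * 2. Third, none of those later indices can therefore form a nice pair with nums[i].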

Finally, return the number of nice pairs.

class Solution {
    /**
     * Counts pairs (i, j), i < j, with low <= nums[i] ^ nums[j] <= high.
     *
     * Brute force over a sorted copy of the input (O(n^2) worst case), with
     * pruning: let limit = 2 * highestOneBit(high), so high < limit. Once
     * sorted[i] < limit <= sorted[j], sorted[j] has a set bit worth at least
     * limit that sorted[i] lacks, so their XOR is >= limit > high; every later
     * (larger) sorted[j] behaves the same, so the inner loop can stop.
     * Assumes non-negative values, as in the problem constraints.
     *
     * @param nums input values (not modified)
     * @param low  inclusive lower bound on the XOR
     * @param high inclusive upper bound on the XOR
     * @return the number of nice pairs
     */
    public int countPairs(int[] nums, int low, int high) {
        // Exact integer computation of the highest set bit of high. A
        // floating-point log ratio can round just below an exact power of two,
        // which would shrink the limit and prune pairs that are actually nice.
        int highBit = Integer.highestOneBit(high);
        // If this overflows (high >= 2^30) it turns negative, the pruning
        // condition below can never hold, and the loop simply does not prune.
        int limit = highBit * 2;
        // Sort a copy so the caller's array is left untouched.
        int[] sorted = nums.clone();
        Arrays.sort(sorted);
        int pairs = 0;
        int length = sorted.length;
        for (int i = 0; i < length; i++) {
            for (int j = i + 1; j < length; j++) {
                int xor = sorted[i] ^ sorted[j];
                if (xor >= low && xor <= high) {
                    pairs++;
                } else if (sorted[i] < limit && sorted[j] >= limit) {
                    break;
                }
            }
        }
        return pairs;
    }
}

/**
 * Binary trie keyed on bits 15..0 of an int. Every node below the root
 * records how many inserted values pass through it.
 */
class Trie {
    /** Most significant bit examined by insert and search. */
    private static final int TOP_BIT = 15;

    private Trie[] children = new Trie[2];
    private int cnt;

    /**
     * Adds x to the trie, incrementing the counter of each node on its path
     * (the root's counter is not touched).
     *
     * @param x value to insert; only bits 15..0 are used
     */
    public void insert(int x) {
        Trie cur = this;
        int bit = TOP_BIT;
        while (bit >= 0) {
            int b = (x >> bit) & 1;
            Trie next = cur.children[b];
            if (next == null) {
                next = new Trie();
                cur.children[b] = next;
            }
            next.cnt++;
            cur = next;
            bit--;
        }
    }

    /**
     * Counts inserted values y with (x ^ y) < limit, comparing bits 15..0
     * from most to least significant.
     *
     * @param x     value to XOR against the stored values
     * @param limit exclusive upper bound on the XOR
     * @return number of stored values whose XOR with x is below limit
     */
    public int search(int x, int limit) {
        int total = 0;
        Trie cur = this;
        for (int bit = TOP_BIT; cur != null && bit >= 0; bit--) {
            int b = (x >> bit) & 1;
            boolean limitBitSet = ((limit >> bit) & 1) != 0;
            if (!limitBitSet) {
                // XOR bit must be 0 to stay tied with limit: follow x's own bit.
                cur = cur.children[b];
                continue;
            }
            // limit has a 1 here. Matching x's bit makes this XOR bit 0, so
            // that entire subtree is strictly below limit.
            Trie same = cur.children[b];
            total += (same == null) ? 0 : same.cnt;
            // Keep going along the branch where this XOR bit ties limit's 1.
            cur = cur.children[1 - b];
        }
        return total;
    }
}

class Solution {
    /**
     * Counts pairs (i, j), i < j, with low <= nums[i] ^ nums[j] <= high.
     *
     * Each value is queried against the values inserted before it, so every
     * pair is counted exactly once. Trie.search(x, k) counts XORs strictly
     * below k, so the difference of two queries gives the closed range.
     *
     * @param nums input values
     * @param low  inclusive lower bound on the XOR
     * @param high inclusive upper bound on the XOR
     * @return the number of nice pairs
     */
    public int countPairs(int[] nums, int low, int high) {
        Trie seen = new Trie();
        int total = 0;
        for (int idx = 0; idx < nums.length; idx++) {
            int value = nums[idx];
            int atMostHigh = seen.search(value, high + 1);
            int belowLow = seen.search(value, low);
            total += atMostHigh - belowLow;
            seen.insert(value);
        }
        return total;
    }
}

// OJ: https://leetcode.com/problems/count-pairs-with-xor-in-a-range/
// Time: O(N * BIT)
// Space: O(N * BIT)
// Binary trie node: next[b] is the child for bit value b; cnt is the number
// of inserted values whose path passes through this node.
struct TrieNode {
    TrieNode *next[2] = {};
    int cnt = 0;
    TrieNode() = default;
    // A node owns its children; forbid copies so they are never freed twice.
    TrieNode(const TrieNode &) = delete;
    TrieNode &operator=(const TrieNode &) = delete;
    ~TrieNode() {
        delete next[0];
        delete next[1];
    }
};
// Highest bit index stored in the trie (values are keyed on bits BIT..0).
const int BIT = 15;
class Solution {
    // Inserts n below `node`, incrementing cnt on every node of the path,
    // including `node` itself so a subtree fully inside the range reports
    // the right total.
    void addNode(TrieNode *node, int n) {
        ++node->cnt;
        for (int i = BIT; i >= 0; --i) {
            int b = n >> i & 1;
            if (!node->next[b]) node->next[b] = new TrieNode();
            node = node->next[b];
            node->cnt++;
        }
    }
    // Counts stored values y under `node` with low <= (n ^ y) <= high.
    // [rangeMin, rangeMax] bounds every XOR reachable in this subtree: bits
    // above i are fixed by the path taken so far, bits i..0 are still free.
    int count(TrieNode *node, int n, int low, int high, int i = BIT, int rangeMin = 0, int rangeMax = (1 << (BIT + 1)) - 1) {
        if (rangeMin >= low && rangeMax <= high) return node->cnt;
        if (rangeMax < low || rangeMin > high) return 0;
        // Not reached at i == -1: there rangeMin == rangeMax, so one of the
        // two checks above has already returned.
        int ans = 0, b = n >> i & 1, r = 1 - b, mask = 1 << i;
        // Child matching n's bit: XOR bit i is 0.
        if (node->next[b]) ans += count(node->next[b], n, low, high, i - 1, rangeMin & ~mask, rangeMax & ~mask);
        // Opposite child: XOR bit i is 1.
        if (node->next[r]) ans += count(node->next[r], n, low, high, i - 1, rangeMin | mask, rangeMax | mask);
        return ans;
    }
public:
    int countPairs(vector<int>& A, int low, int high) {
        TrieNode root;
        int ans = 0;
        for (int n : A) {
            // Query only the values seen so far, then insert n, so each
            // pair (i < j) is counted exactly once.
            ans += count(&root, n, low, high);
            addNode(&root, n);
        }
        return ans;
    }
};

class Trie:
    """Binary trie keyed on bits 15..0 of an int.

    Every node below the root records in ``cnt`` how many inserted values
    pass through it.
    """

    def __init__(self):
        self.children = [None, None]  # child for bit value 0 / 1
        self.cnt = 0  # number of inserted values passing through this node

    def insert(self, x):
        """Add ``x``, incrementing ``cnt`` on each node of its path below the root."""
        cur = self
        for bit in reversed(range(16)):
            b = (x >> bit) & 1
            nxt = cur.children[b]
            if nxt is None:
                nxt = cur.children[b] = Trie()
            nxt.cnt += 1
            cur = nxt

    def search(self, x, limit):
        """Return how many inserted values ``y`` satisfy ``(x ^ y) < limit``.

        Bits 15..0 are compared from most to least significant.
        """
        total = 0
        cur = self
        for bit in reversed(range(16)):
            if cur is None:
                break
            b = (x >> bit) & 1
            if not (limit >> bit) & 1:
                # XOR bit must be 0 to stay tied with limit.
                cur = cur.children[b]
                continue
            # limit has a 1 here: matching x's bit makes this XOR bit 0, so
            # that whole subtree is strictly below limit.
            same = cur.children[b]
            if same is not None:
                total += same.cnt
            cur = cur.children[1 - b]
        return total

class Solution:
    def countPairs(self, nums: List[int], low: int, high: int) -> int:
        """Count pairs ``i < j`` with ``low <= nums[i] ^ nums[j] <= high``.

        Each value is queried against the values inserted before it, so every
        pair is counted exactly once. ``Trie.search(x, k)`` counts XORs
        strictly below ``k``; the difference of two such counts yields the
        closed range ``[low, high]``.
        """
        seen = Trie()
        total = 0
        for value in nums:
            at_most_high = seen.search(value, high + 1)
            below_low = seen.search(value, low)
            total += at_most_high - below_low
            seen.insert(value)
        return total


// Trie is a binary trie keyed on bits 15..0 of an int. Every node below the
// root records in cnt how many inserted values pass through it.
type Trie struct {
	children [2]*Trie
	cnt      int
}

// newTrie returns an empty trie root.
func newTrie() *Trie {
	return new(Trie)
}

// insert adds x, incrementing cnt on each node of its path below the root.
func (t *Trie) insert(x int) {
	cur := t
	for bit := 15; bit >= 0; bit-- {
		b := (x >> bit) & 1
		next := cur.children[b]
		if next == nil {
			next = newTrie()
			cur.children[b] = next
		}
		next.cnt++
		cur = next
	}
}

// search returns how many inserted values y satisfy x^y < limit, comparing
// bits 15..0 from most to least significant.
func (t *Trie) search(x, limit int) int {
	total := 0
	cur := t
	for bit := 15; bit >= 0; bit-- {
		if cur == nil {
			break
		}
		b := (x >> bit) & 1
		if (limit>>bit)&1 == 0 {
			// XOR bit must be 0 to stay tied with limit.
			cur = cur.children[b]
			continue
		}
		// limit has a 1 here: matching x's bit makes this XOR bit 0, so that
		// whole subtree is strictly below limit.
		if same := cur.children[b]; same != nil {
			total += same.cnt
		}
		cur = cur.children[b^1]
	}
	return total
}

// countPairs counts pairs i < j with low <= nums[i]^nums[j] <= high.
// Each value is queried against the values inserted before it, so every pair
// is counted exactly once; search(x, k) counts XORs strictly below k, and the
// difference of two such counts yields the closed range [low, high].
func countPairs(nums []int, low int, high int) (ans int) {
	seen := newTrie()
	for i := range nums {
		x := nums[i]
		atMostHigh := seen.search(x, high+1)
		belowLow := seen.search(x, low)
		ans += atMostHigh - belowLow
		seen.insert(x)
	}
	return ans
}