Welcome to Subscribe On Youtube

Formatted question description: https://leetcode.ca/all/1803.html

1803. Count Pairs With XOR in a Range

Level

Hard

Description

Given a (0-indexed) integer array nums and two integers low and high, return the number of nice pairs.

A nice pair is a pair (i, j) where 0 <= i < j < nums.length and low <= (nums[i] XOR nums[j]) <= high.

Example 1:

Input: nums = [1,4,2,7], low = 2, high = 6

Output: 6

Explanation: All nice pairs (i, j) are as follows:

- (0, 1): nums[0] XOR nums[1] = 5 
- (0, 2): nums[0] XOR nums[2] = 3
- (0, 3): nums[0] XOR nums[3] = 6
- (1, 2): nums[1] XOR nums[2] = 6
- (1, 3): nums[1] XOR nums[3] = 3
- (2, 3): nums[2] XOR nums[3] = 5

Example 2:

Input: nums = [9,8,4,2,1], low = 5, high = 14

Output: 8

Explanation: All nice pairs (i, j) are as follows:

- (0, 2): nums[0] XOR nums[2] = 13
- (0, 3): nums[0] XOR nums[3] = 11
- (0, 4): nums[0] XOR nums[4] = 8
- (1, 2): nums[1] XOR nums[2] = 12
- (1, 3): nums[1] XOR nums[3] = 10
- (1, 4): nums[1] XOR nums[4] = 9
- (2, 3): nums[2] XOR nums[3] = 6
- (2, 4): nums[2] XOR nums[4] = 5

Constraints:

  • 1 <= nums.length <= 2 * 10^4
  • 1 <= nums[i] <= 2 * 10^4
  • 1 <= low <= high <= 2 * 10^4

Solution

Let highBit be the value of the highest set bit of high (for example, highBit = 4 when high = 6). Sort the array nums and count the pairs using nested loops with pruning.

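Let 0 <= i < j < nums.length. If nums[i] < highBit * 2 and nums[j] >= highBit * 2, then nums[j] has a set bit at or above highBit * 2 that nums[i] lacks, so nums[i] XOR nums[j] >= highBit * 2 > high. Because nums is sorted, every later nums[j] satisfies the same condition, so stop searching further indices j for the current index i. This is the pruning operation.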

Finally, return the number of nice pairs.

  • class Solution {
        /**
         * Counts pairs (i, j) with i < j and low <= (nums[i] ^ nums[j]) <= high.
         *
         * Sorts nums in place and checks pairs with nested loops. When nums[i] has no bit
         * at or above {@code limit} but nums[j] does, their XOR is at least {@code limit},
         * which exceeds high. Because nums is sorted, every later j behaves the same, so the
         * inner loop stops early. This relies on the values being non-negative (the
         * constraints guarantee nums[i] >= 1). The worst case is still O(n^2) pair checks.
         *
         * @param nums the values; sorted in place as a side effect
         * @param low  inclusive lower bound for the XOR
         * @param high inclusive upper bound for the XOR
         * @return the number of nice pairs
         */
        public int countPairs(int[] nums, int low, int high) {
            // Smallest power of two strictly greater than high (for high >= 1).
            // Computed exactly with bit operations rather than Math.log(high) / Math.log(2),
            // whose floating-point result is not guaranteed to be the exact exponent.
            // Widened to long so it cannot overflow when high >= 2^30.
            long limit = (long) Integer.highestOneBit(high) << 1;
            int pairs = 0;
            Arrays.sort(nums);
            int length = nums.length;
            for (int i = 0; i < length; i++) {
                for (int j = i + 1; j < length; j++) {
                    int xor = nums[i] ^ nums[j];
                    if (xor >= low && xor <= high) {
                        pairs++;
                    } else if (nums[i] < limit && nums[j] >= limit) {
                        // XOR has a bit >= limit > high here and for every later j.
                        break;
                    }
                }
            }
            return pairs;
        }
    }
    
    ############
    
    class Trie {
        // children[b] is the subtree for bit value b at the next level down.
        private Trie[] children = new Trie[2];
        // Number of inserted values whose path passes through this node.
        private int cnt;

        /** Inserts x by walking bits 15..0, bumping cnt on every node along the way. */
        public void insert(int x) {
            Trie cur = this;
            for (int bit = 15; bit >= 0; bit--) {
                int b = (x >> bit) & 1;
                Trie next = cur.children[b];
                if (next == null) {
                    next = new Trie();
                    cur.children[b] = next;
                }
                next.cnt++;
                cur = next;
            }
        }

        /**
         * Counts inserted values y with (x ^ y) < limit, comparing bits 15..0.
         * Wherever limit has a 1 bit, every value that makes that XOR bit 0 is already
         * below limit, so that subtree's cnt is added in one step.
         */
        public int search(int x, int limit) {
            int total = 0;
            Trie cur = this;
            int bit = 15;
            while (cur != null && bit >= 0) {
                int xb = (x >> bit) & 1;
                boolean limitBitSet = ((limit >> bit) & 1) != 0;
                if (limitBitSet) {
                    Trie smaller = cur.children[xb];
                    if (smaller != null) {
                        total += smaller.cnt;
                    }
                    cur = cur.children[xb ^ 1];
                } else {
                    cur = cur.children[xb];
                }
                bit--;
            }
            return total;
        }
    }
    
    class Solution {
        /**
         * Counts pairs i < j with low <= (nums[i] ^ nums[j]) <= high.
         *
         * Each value x is queried against a trie of the values before it. Trie.search(x, limit)
         * counts stored y with x ^ y < limit. The count below high + 1 minus the count below low
         * is the number with x ^ y in [low, high]. x is inserted only afterwards, so each pair
         * is counted exactly once.
         */
        public int countPairs(int[] nums, int low, int high) {
            Trie seen = new Trie();
            int total = 0;
            for (int idx = 0; idx < nums.length; idx++) {
                int x = nums[idx];
                int atMostHigh = seen.search(x, high + 1);
                int belowLow = seen.search(x, low);
                total += atMostHigh - belowLow;
                seen.insert(x);
            }
            return total;
        }
    }
    
  • // OJ: https://leetcode.com/problems/count-pairs-with-xor-in-a-range/
    // Time: O(N * BIT)  -- each insert walks BIT + 1 levels; each range count visits
    //                      at most a few nodes per level (like a segment-tree query)
    // Space: O(N * BIT) -- each insert allocates at most BIT + 1 trie nodes
    //
    // Binary-trie node. Each node owns its children, so destroying the root frees
    // every node allocated beneath it.
    struct TrieNode {
        TrieNode *next[2] = {};
        int cnt = 0;  // how many inserted values pass through this node
        TrieNode() = default;
        TrieNode(const TrieNode &) = delete;             // a copy would double-free children
        TrieNode &operator=(const TrieNode &) = delete;
        ~TrieNode() {
            delete next[0];
            delete next[1];
        }
    };
    // Highest bit index examined; values are assumed to fit in bits BIT..0.
    const int BIT = 15;
    class Solution {
        // Inserts n, incrementing cnt on every node of its path, root included, so every
        // node's cnt equals the number of values stored in its subtree.
        void addNode(TrieNode *node, int n) {
            ++node->cnt;
            for (int i = BIT; i >= 0; --i) {
                int b = n >> i & 1;
                if (!node->next[b]) node->next[b] = new TrieNode();
                node = node->next[b];
                node->cnt++;
            }
        }
        // Counts values y stored under `node` with low <= (n ^ y) <= high.
        // [rangeMin, rangeMax] is the set of XOR values still reachable in this subtree:
        // bits above i are fixed by the path taken so far, bits i..0 are still free.
        int count(TrieNode *node, int n, int low, int high, int i = BIT, int rangeMin = 0, int rangeMax = (1 << (BIT + 1)) - 1) {
            if (rangeMin >= low && rangeMax <= high) return node->cnt;  // fully inside
            if (rangeMax < low || rangeMin > high) return 0;            // fully outside
            // Only partial overlap reaches here. That cannot happen once i == -1, because
            // then rangeMin == rangeMax, so `1 << i` is never evaluated with i < 0.
            int ans = 0, b = n >> i & 1, r = 1 - b, mask = 1 << i;
            // Child b makes XOR bit i == 0; child r makes XOR bit i == 1.
            if (node->next[b]) ans += count(node->next[b], n, low, high, i - 1, rangeMin & ~mask, rangeMax & ~mask);
            if (node->next[r]) ans += count(node->next[r], n, low, high, i - 1, rangeMin | mask, rangeMax | mask);
            return ans;
        }
    public:
        // Counts pairs i < j with low <= A[i] ^ A[j] <= high by querying each value
        // against the trie of earlier values before inserting it.
        int countPairs(vector<int>& A, int low, int high) {
            TrieNode root;  // its destructor frees every node allocated by addNode
            int ans = 0;
            for (int n : A) {
                ans += count(&root, n, low, high);
                addNode(&root, n);
            }
            return ans;
        }
    };
    
  • class Trie:
        def __init__(self):
            self.children = [None] * 2
            self.cnt = 0
    
        def insert(self, x):
            node = self
            for i in range(15, -1, -1):
                v = x >> i & 1
                if node.children[v] is None:
                    node.children[v] = Trie()
                node = node.children[v]
                node.cnt += 1
    
        def search(self, x, limit):
            node = self
            ans = 0
            for i in range(15, -1, -1):
                if node is None:
                    return ans
                v = x >> i & 1
                if limit >> i & 1:
                    if node.children[v]:
                        ans += node.children[v].cnt
                    node = node.children[v ^ 1]
                else:
                    node = node.children[v]
            return ans
    
    
    class Solution:
        def countPairs(self, nums: List[int], low: int, high: int) -> int:
            """Count pairs i < j with ``low <= nums[i] ^ nums[j] <= high``.

            Each value is checked against a trie of the values that precede it.
            ``Trie.search(x, limit)`` counts stored ``y`` with ``x ^ y < limit``. The
            count below ``high + 1`` minus the count below ``low`` leaves exactly
            those in ``[low, high]``. Inserting ``x`` only afterwards counts each
            pair once.
            """
            seen = Trie()
            total = 0
            for x in nums:
                at_most_high = seen.search(x, high + 1)
                below_low = seen.search(x, low)
                total += at_most_high - below_low
                seen.insert(x)
            return total
    
    
    
  • type Trie struct {
    	children [2]*Trie
    	cnt      int
    }
    
    func newTrie() *Trie {
    	return &Trie{}
    }
    
    func (this *Trie) insert(x int) {
    	node := this
    	for i := 15; i >= 0; i-- {
    		v := (x >> i) & 1
    		if node.children[v] == nil {
    			node.children[v] = newTrie()
    		}
    		node = node.children[v]
    		node.cnt++
    	}
    }
    
    func (this *Trie) search(x, limit int) (ans int) {
    	node := this
    	for i := 15; i >= 0 && node != nil; i-- {
    		v := (x >> i) & 1
    		if (limit >> i & 1) == 1 {
    			if node.children[v] != nil {
    				ans += node.children[v].cnt
    			}
    			node = node.children[v^1]
    		} else {
    			node = node.children[v]
    		}
    	}
    	return
    }
    
    // countPairs returns how many pairs i < j satisfy low <= nums[i]^nums[j] <= high.
    // Each value is compared against a trie of the values before it. The count of
    // XORs below high+1 minus the count below low is added. The value is inserted
    // only afterwards, so each pair is counted once.
    func countPairs(nums []int, low int, high int) (ans int) {
    	seen := newTrie()
    	for i := range nums {
    		x := nums[i]
    		ans += seen.search(x, high+1)
    		ans -= seen.search(x, low)
    		seen.insert(x)
    	}
    	return ans
    }
    

All Problems

All Solutions