Question
Formatted question description: https://leetcode.ca/all/327.html
327. Count of Range Sum
Given an integer array nums, return the number of range sums that lie in [lower, upper] inclusive.
Range sum S(i, j) is defined as the sum of the elements in nums between indices i and j (i ≤ j), inclusive.
Note:
A naive algorithm of O(n^2) is trivial. You MUST do better than that.
Example:
Input: nums = [-2,5,-1], lower = -2, upper = 2
Output: 3
Explanation: The three ranges are [0,0], [2,2], and [0,2], and their respective sums are -2, -1, and 2.
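The example can be checked by hand with prefix sums, which the algorithm relies on. The symbol P is introduced here only for illustration: P_0 = 0 and P_{t+1} = P_t + nums[t]. Every range sum is then a difference of two prefix sums, and each of the three answers corresponds to one qualifying pair (a, b) with a < b:

```latex
\begin{aligned}
S(i, j) &= P_{j+1} - P_i, \qquad \text{count pairs } a < b \text{ with } \mathit{lower} \le P_b - P_a \le \mathit{upper} \\
P &= [\,0,\ -2,\ 3,\ 2\,] \\
P_1 - P_0 &= -2 \in [-2, 2] \quad\Rightarrow\quad S(0,0) \\
P_3 - P_0 &= 2 \in [-2, 2] \quad\Rightarrow\quad S(0,2) \\
P_3 - P_2 &= -1 \in [-2, 2] \quad\Rightarrow\quad S(2,2) \\
P_2 - P_0 &= 3,\quad P_2 - P_1 = 5,\quad P_3 - P_1 = 4 \quad (\text{all outside } [-2, 2])
\end{aligned}
```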
Algorithm
The most intuitive approach is to enumerate every subarray and check whether its sum lies in [lower, upper]. Even with a running sum this takes O(n^2) time, which the problem explicitly rules out and which times out on large inputs.
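The brute force is still handy as a correctness oracle when testing a faster solution. A minimal sketch, assuming a standalone function named count_range_sum_brute (a hypothetical name, not part of the original solutions):

```python
from typing import List


def count_range_sum_brute(nums: List[int], lower: int, upper: int) -> int:
    """O(n^2) reference: check every range sum S(i, j) directly."""
    count = 0
    for i in range(len(nums)):
        running = 0  # running holds S(i, j) as j advances
        for j in range(i, len(nums)):
            running += nums[j]
            if lower <= running <= upper:
                count += 1
    return count


print(count_range_sum_brute([-2, 5, -1], -2, 2))  # 3
```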
A natural first idea for this kind of problem is two pointers, but two pointers alone is not enough here: the array may contain negative numbers, so a window's sum does not change monotonically as its bounds move. Combining the two-pointer idea with merge sort, however, gives an O(n log n) solution.
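The array being sorted is the prefix-sum array sums, where sums[0] = 0 and sums[t + 1] = nums[0] + ... + nums[t], so every range sum is a difference sums[b] - sums[a] with a < b. In the merge step of merge sort, the left and right halves are each already sorted, and every element of the left half came from an earlier position in the original array than every element of the right half. For each index i in the left half, we advance two pointers, k and j, through the right half:
k is the first index such that sums[k] - sums[i] >= lower
j is the first index such that sums[j] - sums[i] > upper
Every index m in [k, j) then satisfies lower <= sums[m] - sums[i] <= upper, so index i contributes exactly j - k qualifying range sums.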
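Because both halves are sorted, sums[i] never decreases as i advances, so j and k only move forward and never have to be reset to the start of the right half. Each merge step is therefore linear, and the whole algorithm runs in O(n log n).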
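The counting step for one merge can be sketched on its own. This assumes a hypothetical helper count_cross_pairs that receives the two sorted halves as plain lists and that lower <= upper. For the example, the top-level merge sees the prefix sums [0, -2, 3, 2] split into [0, -2] and [3, 2], which become [-2, 0] and [2, 3] after sorting; the only cross pair is 2 - 0 = 2, which is S(0, 2).

```python
from typing import List


def count_cross_pairs(left: List[int], right: List[int], lower: int, upper: int) -> int:
    """Count pairs (a, b), a from left and b from right, with lower <= b - a <= upper.

    Both lists must be sorted ascending and lower <= upper is assumed.
    Pointers k and j only move forward, so the scan is O(len(left) + len(right)).
    """
    count = 0
    k = j = 0
    for a in left:
        # k: first index with right[k] - a >= lower
        while k < len(right) and right[k] - a < lower:
            k += 1
        # j: first index with right[j] - a > upper
        while j < len(right) and right[j] - a <= upper:
            j += 1
        count += j - k
    return count


print(count_cross_pairs([-2, 0], [2, 3], -2, 2))  # 1
```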
Finally, prefix sums can exceed the 32-bit integer range, so they should be stored as 64-bit integers (long long in C++, long in Java).
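Putting the pieces together, a minimal Python sketch of the merge-sort approach follows. It is not taken from the solutions in the Code section (which use a TreeMap and a Binary Indexed Tree); the names count_range_sum and sort_and_count are hypothetical, and the linear merge uses the standard-library heapq.merge. Python integers do not overflow, so the 64-bit concern applies only to the Java and C++ versions.

```python
import heapq
from typing import List


def count_range_sum(nums: List[int], lower: int, upper: int) -> int:
    # prefix[0] = 0 and prefix[t + 1] = prefix[t] + nums[t],
    # so S(i, j) = prefix[j + 1] - prefix[i].
    prefix = [0]
    for v in nums:
        prefix.append(prefix[-1] + v)

    def sort_and_count(lo: int, hi: int) -> int:
        """Sort prefix[lo:hi] in place and return the number of pairs
        a < b (original positions) in that range with
        lower <= prefix[b] - prefix[a] <= upper."""
        if hi - lo <= 1:
            return 0
        mid = (lo + hi) // 2
        count = sort_and_count(lo, mid) + sort_and_count(mid, hi)

        # Both halves are now sorted; count pairs with a in the left half
        # and b in the right half. k and j never move backwards.
        k = j = mid
        for a in range(lo, mid):
            while k < hi and prefix[k] - prefix[a] < lower:
                k += 1
            while j < hi and prefix[j] - prefix[a] <= upper:
                j += 1
            count += j - k

        # Linear merge of the two sorted halves.
        prefix[lo:hi] = list(heapq.merge(prefix[lo:mid], prefix[mid:hi]))
        return count

    return sort_and_count(0, len(prefix))


print(count_range_sum([-2, 5, -1], -2, 2))  # 3
```

Sorting reorders values only within a subrange after both of its halves have been counted, so every element of a left half always comes from an earlier original position than every element of the matching right half. That is why j - k counts only valid ranges.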
Code
Java
-
import java.util.Collection;
import java.util.TreeMap;

public class Count_of_Range_Sum {

    public static void main(String[] args) {
        Count_of_Range_Sum out = new Count_of_Range_Sum();
        Solution s = out.new Solution();

        System.out.println(s.countRangeSum(new int[]{-2, 5, -1}, -2, 2)); // 3
    }

    public class Solution {
        public int countRangeSum(int[] nums, int lower, int upper) {
            if (nums == null || nums.length == 0) {
                return 0;
            }

            // sums[i] = nums[0] + ... + nums[i], stored as long to avoid overflow
            long[] sums = new long[nums.length];
            sums[0] = nums[0];
            for (int i = 1; i < nums.length; i++) {
                sums[i] = sums[i - 1] + nums[i];
            }

            int total = 0;
            // prefix-sum value -> number of earlier indices with that value
            TreeMap<Long, Integer> treemap = new TreeMap<>();
            for (int i = 0; i < nums.length; i++) {
                // the range [0, i] itself
                if (lower <= sums[i] && sums[i] <= upper) {
                    total++;
                }

                // ranges [p + 1, i] for earlier p with
                // sums[i] - upper <= sums[p] <= sums[i] - lower.
                // Iterating the submap is linear in its size, so the worst case is O(n^2).
                Collection<Integer> search = treemap
                        .subMap(sums[i] - upper, true, sums[i] - lower, true)
                        .values();
                for (Integer count : search) {
                    total += count;
                }

                // update treemap
                Integer count = treemap.get(sums[i]);
                if (count == null) {
                    count = 1;
                } else {
                    count++;
                }
                treemap.put(sums[i], count);
            }
            return total;
        }
    }
}
-
Python
-
import bisect
from typing import List


class BinaryIndexedTree:
    def __init__(self, n):
        self.n = n
        self.c = [0] * (n + 1)

    @staticmethod
    def lowbit(x):
        return x & -x

    def update(self, x, delta):
        while x <= self.n:
            self.c[x] += delta
            x += BinaryIndexedTree.lowbit(x)

    def query(self, x):
        s = 0
        while x > 0:
            s += self.c[x]
            x -= BinaryIndexedTree.lowbit(x)
        return s


class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        presum = [0]
        for v in nums:
            presum.append(presum[-1] + v)
        # coordinate compression over every value that is stored or queried
        alls = set()
        for s in presum:
            alls.add(s)
            alls.add(s - lower)
            alls.add(s - upper)
        alls = sorted(alls)
        m = {v: i for i, v in enumerate(alls, 1)}
        tree = BinaryIndexedTree(len(m))
        ans = 0
        for s in presum:
            # count earlier prefix sums p with s - upper <= p <= s - lower
            i, j = m[s - upper], m[s - lower]
            ans += tree.query(j) - tree.query(i - 1)
            tree.update(m[s], 1)
        return ans


############


class Solution(object):
    def countRangeSum(self, nums, lower, upper):
        """
        :type nums: List[int]
        :type lower: int
        :type upper: int
        :rtype: int
        """
        def update(b, i, delta):
            while i < len(b):
                b[i] += delta
                i += (i & -i)

        def sumRange(b, i):
            ret = 0
            while i > 0:
                ret += b[i]
                i -= (i & -i)
            return ret

        ans = 0
        pres = [0] * (len(nums) + 1)
        b = [0] * (len(nums) + 2)
        for i in range(0, len(nums)):
            pres[i + 1] = pres[i] + nums[i]
        sortedPres = sorted(pres)
        for end in pres:
            # count earlier prefix sums in [end - upper, end - lower]
            count = sumRange(b, bisect.bisect_right(sortedPres, end - lower)) - sumRange(b, bisect.bisect_left(sortedPres, end - upper))
            ans += count
            # record this prefix sum at its 1-based sorted position
            update(b, bisect.bisect_left(sortedPres, end) + 1, 1)
        return ans