Welcome to Subscribe On Youtube

Formatted question description: https://leetcode.ca/all/2281.html

2281. Sum of Total Strength of Wizards

  • Difficulty: Hard.
  • Related Topics: Array, Stack, Monotonic Stack, Prefix Sum.
  • Similar Questions: Next Greater Element I, Sum of Subarray Minimums, Number of Visible People in a Queue, Sum of Subarray Ranges.

Problem

As the ruler of a kingdom, you have an army of wizards at your command.

You are given a 0-indexed integer array strength, where strength[i] denotes the strength of the i-th wizard. For a contiguous group of wizards (i.e. the wizards’ strengths form a subarray of strength), the total strength is defined as the product of the following two values:

  • The strength of the weakest wizard in the group.

  • The total of all the individual strengths of the wizards in the group.

Return the **sum of the total strengths of all contiguous groups of wizards**. Since the answer may be very large, return it **modulo** $10^9 + 7$.

A subarray is a contiguous non-empty sequence of elements within an array.

  Example 1:

Input: strength = [1,3,1,2]
Output: 44
Explanation: The following are all the contiguous groups of wizards:
- [1] from [1,3,1,2] has a total strength of min([1]) * sum([1]) = 1 * 1 = 1
- [3] from [1,3,1,2] has a total strength of min([3]) * sum([3]) = 3 * 3 = 9
- [1] from [1,3,1,2] has a total strength of min([1]) * sum([1]) = 1 * 1 = 1
- [2] from [1,3,1,2] has a total strength of min([2]) * sum([2]) = 2 * 2 = 4
- [1,3] from [1,3,1,2] has a total strength of min([1,3]) * sum([1,3]) = 1 * 4 = 4
- [3,1] from [1,3,1,2] has a total strength of min([3,1]) * sum([3,1]) = 1 * 4 = 4
- [1,2] from [1,3,1,2] has a total strength of min([1,2]) * sum([1,2]) = 1 * 3 = 3
- [1,3,1] from [1,3,1,2] has a total strength of min([1,3,1]) * sum([1,3,1]) = 1 * 5 = 5
- [3,1,2] from [1,3,1,2] has a total strength of min([3,1,2]) * sum([3,1,2]) = 1 * 6 = 6
- [1,3,1,2] from [1,3,1,2] has a total strength of min([1,3,1,2]) * sum([1,3,1,2]) = 1 * 7 = 7
The sum of all the total strengths is 1 + 9 + 1 + 4 + 4 + 4 + 3 + 5 + 6 + 7 = 44.

Example 2:

Input: strength = [5,4,6]
Output: 213
Explanation: The following are all the contiguous groups of wizards: 
- [5] from [5,4,6] has a total strength of min([5]) * sum([5]) = 5 * 5 = 25
- [4] from [5,4,6] has a total strength of min([4]) * sum([4]) = 4 * 4 = 16
- [6] from [5,4,6] has a total strength of min([6]) * sum([6]) = 6 * 6 = 36
- [5,4] from [5,4,6] has a total strength of min([5,4]) * sum([5,4]) = 4 * 9 = 36
- [4,6] from [5,4,6] has a total strength of min([4,6]) * sum([4,6]) = 4 * 10 = 40
- [5,4,6] from [5,4,6] has a total strength of min([5,4,6]) * sum([5,4,6]) = 4 * 15 = 60
The sum of all the total strengths is 25 + 16 + 36 + 36 + 40 + 60 = 213.

  Constraints:

  • 1 <= strength.length <= 10^5

  • 1 <= strength[i] <= 10^9

Solution

  • class Solution {
        private static int mod = (int) 1e9 + 7;
    
        public int totalStrength(int[] nums) {
            int n = nums.length;
            long[] forward = new long[n];
            long[] backward = new long[n];
            long[] prefix = new long[n + 1];
            long[] suffix = new long[n + 1];
            forward[0] = prefix[1] = nums[0];
            backward[n - 1] = suffix[n - 1] = nums[n - 1];
            for (int i = 1; i < n; ++i) {
                forward[i] = nums[i] + forward[i - 1];
                prefix[i + 1] = prefix[i] + forward[i];
            }
            for (int i = n - 2; 0 <= i; --i) {
                backward[i] = nums[i] + backward[i + 1];
                suffix[i] = suffix[i + 1] + backward[i];
            }
            long res = 0;
            Deque<Integer> dq = new LinkedList<>();
            for (int i = 0; i < n; ++i) {
                while (!dq.isEmpty() && nums[dq.peekLast()] >= nums[i]) {
                    int cur = dq.pollLast();
                    int prev = dq.isEmpty() ? -1 : dq.peekLast();
                    res =
                            (res
                                            + getSum(
                                                            nums, forward, prefix, backward, suffix,
                                                            prev, cur, i)
                                                    * nums[cur])
                                    % mod;
                }
                dq.add(i);
            }
            while (!dq.isEmpty()) {
                int cur = dq.pollLast();
                int prev = dq.isEmpty() ? -1 : dq.peekLast();
                res =
                        (res
                                        + getSum(nums, forward, prefix, backward, suffix, prev, cur, n)
                                                * nums[cur])
                                % mod;
            }
            return (int) res;
        }
    
        private long getSum(
                int[] nums,
                long[] forward,
                long[] prefix,
                long[] backward,
                long[] suffix,
                int prev,
                int cur,
                int next) {
            long sum = ((cur - prev) * (long) nums[cur] % mod) * (next - cur) % mod;
            long preSum = getPresum(backward, suffix, prev + 1, cur - 1, next - cur);
            long postSum = getPostsum(forward, prefix, cur + 1, next - 1, cur - prev);
            return (sum + preSum + postSum) % mod;
        }
    
        private long getPresum(long[] backward, long[] suffix, int from, int to, int m) {
            int n = backward.length;
            long cnt = to - from + 1L;
            return (suffix[from] - suffix[to + 1] - cnt * (to + 1 == n ? 0 : backward[to + 1]) % mod)
                    % mod
                    * m
                    % mod;
        }
    
        private long getPostsum(long[] forward, long[] prefix, int from, int to, int m) {
            long cnt = to - from + 1L;
            return (prefix[to + 1] - prefix[from] - cnt * (0 == from ? 0 : forward[from - 1]) % mod)
                    % mod
                    * m
                    % mod;
        }
    }
    
    ############
    
    class Solution {
        /**
         * Computes the sum of {@code min(g) * sum(g)} over every contiguous group {@code g} of
         * {@code strength}, reduced modulo 10^9 + 7.
         *
         * <p>For each index {@code i} the window {@code [first, last]} in which {@code strength[i]}
         * is the (rightmost) minimum is found with two monotonic stacks. With {@code pre} the
         * prefix sums and {@code prePre} the prefix sums of {@code pre}, the sum of all subarray
         * sums in that window is obtained in O(1).
         *
         * @param strength wizard strengths
         * @return the total strength modulo 10^9 + 7
         */
        public int totalStrength(int[] strength) {
            final int mod = (int) 1e9 + 7;
            int n = strength.length;
            int[] lo = previousStrictlySmaller(strength);
            int[] hi = nextSmallerOrEqual(strength);

            // pre[k] = (strength[0] + ... + strength[k - 1]) mod p
            long[] pre = new long[n + 1];
            for (int k = 1; k <= n; k++) {
                pre[k] = (pre[k - 1] + strength[k - 1]) % mod;
            }
            // prePre[k] = (pre[0] + ... + pre[k - 1]) mod p
            long[] prePre = new long[n + 2];
            for (int k = 1; k <= n + 1; k++) {
                prePre[k] = (prePre[k - 1] + pre[k - 1]) % mod;
            }

            long total = 0;
            for (int i = 0; i < n; i++) {
                int first = lo[i] + 1;
                int last = hi[i] - 1;
                long leftChoices = i - first + 1;
                long rightChoices = last - i + 1;
                long rightPart = leftChoices * (prePre[last + 2] - prePre[i + 1]);
                long leftPart = rightChoices * (prePre[i + 1] - prePre[first]);
                long contribution = (rightPart - leftPart) % mod;
                total = (total + strength[i] * contribution) % mod;
            }
            // total may be negative here; shift it into [0, mod).
            return (int) ((total + mod) % mod);
        }

        /** For each index, the nearest index to its left holding a strictly smaller value, or -1. */
        private static int[] previousStrictlySmaller(int[] values) {
            int n = values.length;
            int[] result = new int[n];
            Deque<Integer> stack = new ArrayDeque<>();
            for (int i = 0; i < n; i++) {
                while (!stack.isEmpty() && values[stack.peek()] >= values[i]) {
                    stack.pop();
                }
                result[i] = stack.isEmpty() ? -1 : stack.peek();
                stack.push(i);
            }
            return result;
        }

        /** For each index, the nearest index to its right holding a smaller-or-equal value, or n. */
        private static int[] nextSmallerOrEqual(int[] values) {
            int n = values.length;
            int[] result = new int[n];
            Deque<Integer> stack = new ArrayDeque<>();
            for (int i = n - 1; i >= 0; i--) {
                while (!stack.isEmpty() && values[stack.peek()] > values[i]) {
                    stack.pop();
                }
                result[i] = stack.isEmpty() ? n : stack.peek();
                stack.push(i);
            }
            return result;
        }
    }
    
  • class Solution:
        def totalStrength(self, strength: List[int]) -> int:
            n = len(strength)
            left = [-1] * n
            right = [n] * n
            stk = []
            for i, v in enumerate(strength):
                while stk and strength[stk[-1]] >= v:
                    stk.pop()
                if stk:
                    left[i] = stk[-1]
                stk.append(i)
            stk = []
            for i in range(n - 1, -1, -1):
                while stk and strength[stk[-1]] > strength[i]:
                    stk.pop()
                if stk:
                    right[i] = stk[-1]
                stk.append(i)
    
            ss = list(accumulate(list(accumulate(strength, initial=0)), initial=0))
            mod = int(1e9) + 7
            ans = 0
            for i, v in enumerate(strength):
                l, r = left[i] + 1, right[i] - 1
                a = (ss[r + 2] - ss[i + 1]) * (i - l + 1)
                b = (ss[i + 1] - ss[l]) * (r - i + 1)
                ans = (ans + (a - b) * v) % mod
            return ans
    
    ############
    
    # 2281. Sum of Total Strength of Wizards
    # https://leetcode.com/problems/sum-of-total-strength-of-wizards
    
    class Solution:
        def totalStrength(self, A: List[int]) -> int:
            """Return the sum over all subarrays of min(sub) * sum(sub), modulo 10**9 + 7.

            Each subarray is charged to its leftmost minimum: ``right[i]`` is the first index
            after ``i`` with a strictly smaller value, ``left[i]`` the last index before ``i``
            with a smaller-or-equal value.
            """
            n = len(A)
            M = 10 ** 9 + 7

            right = [n] * n
            stack = []
            for i, x in enumerate(A):
                while stack and x < A[stack[-1]]:
                    right[stack.pop()] = i
                stack.append(i)

            left = [-1] * n
            stack = []
            for i in reversed(range(n)):
                while stack and A[i] <= A[stack[-1]]:
                    left[stack.pop()] = i
                stack.append(i)

            # acc[k] = P[0] + ... + P[k], where P is the prefix-sum array with P[0] == 0.
            acc = [0]
            running_prefix = 0
            running_acc = 0
            for x in A:
                running_prefix += x
                running_acc += running_prefix
                acc.append(running_acc)

            res = 0
            for i, x in enumerate(A):
                l, r = left[i], right[i]
                left_sum = acc[i] - (acc[l] if l >= 0 else 0)
                right_sum = acc[r] - acc[i]
                span_left, span_right = i - l, r - i
                res = (res + x * (right_sum * span_left - left_sum * span_right)) % M
            return res
    
    
  • class Solution {
    public:
        int totalStrength(vector<int>& strength) {
            int n = strength.size();
            vector<int> left(n, -1);
            vector<int> right(n, n);
            stack<int> stk;
            for (int i = 0; i < n; ++i) {
                while (!stk.empty() && strength[stk.top()] >= strength[i]) stk.pop();
                if (!stk.empty()) left[i] = stk.top();
                stk.push(i);
            }
            stk = stack<int>();
            for (int i = n - 1; i >= 0; --i) {
                while (!stk.empty() && strength[stk.top()] > strength[i]) stk.pop();
                if (!stk.empty()) right[i] = stk.top();
                stk.push(i);
            }
            int mod = 1e9 + 7;
            vector<int> s(n + 1);
            for (int i = 0; i < n; ++i) s[i + 1] = (s[i] + strength[i]) % mod;
            vector<int> ss(n + 2);
            for (int i = 0; i < n + 1; ++i) ss[i + 1] = (ss[i] + s[i]) % mod;
            int ans = 0;
            for (int i = 0; i < n; ++i) {
                int v = strength[i];
                int l = left[i] + 1, r = right[i] - 1;
                long a = (long) (i - l + 1) * (ss[r + 2] - ss[i + 1]);
                long b = (long) (r - i + 1) * (ss[i + 1] - ss[l]);
                ans = (ans + v * ((a - b) % mod)) % mod;
            }
            return (int) (ans + mod) % mod;
        }
    };
    
  • func totalStrength(strength []int) int {
    	n := len(strength)
    	left := make([]int, n)
    	right := make([]int, n)
    	for i := range left {
    		left[i] = -1
    		right[i] = n
    	}
    	stk := []int{}
    	for i, v := range strength {
    		for len(stk) > 0 && strength[stk[len(stk)-1]] >= v {
    			stk = stk[:len(stk)-1]
    		}
    		if len(stk) > 0 {
    			left[i] = stk[len(stk)-1]
    		}
    		stk = append(stk, i)
    	}
    	stk = []int{}
    	for i := n - 1; i >= 0; i-- {
    		for len(stk) > 0 && strength[stk[len(stk)-1]] > strength[i] {
    			stk = stk[:len(stk)-1]
    		}
    		if len(stk) > 0 {
    			right[i] = stk[len(stk)-1]
    		}
    		stk = append(stk, i)
    	}
    	mod := int(1e9) + 7
    	s := make([]int, n+1)
    	for i, v := range strength {
    		s[i+1] = (s[i] + v) % mod
    	}
    	ss := make([]int, n+2)
    	for i, v := range s {
    		ss[i+1] = (ss[i] + v) % mod
    	}
    	ans := 0
    	for i, v := range strength {
    		l, r := left[i]+1, right[i]-1
    		a := (ss[r+2] - ss[i+1]) * (i - l + 1)
    		b := (ss[i+1] - ss[l]) * (r - i + 1)
    		ans = (ans + v*((a-b)%mod)) % mod
    	}
    	return (ans + mod) % mod
    }
    

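Explain:

For each index $i$, use monotonic stacks to find the maximal window $[l, r]$ in which `strength[i]` is the minimum. $l - 1$ is the nearest index to the left with a strictly smaller value, and $r + 1$ is the nearest index to the right with a smaller-or-equal value. The asymmetric tie-breaking attributes each subarray to exactly one of its minima.

Let $P$ be the prefix sums, with $P_0 = 0$ and $P_{k+1} = P_k + \mathrm{strength}[k]$. The sum of all subarray sums in the window is

$$\sum_{L=l}^{i}\sum_{R=i}^{r}\left(P_{R+1}-P_L\right) = (i-l+1)\sum_{j=i+1}^{r+1}P_j \;-\; (r-i+1)\sum_{j=l}^{i}P_j .$$

Both sums on the right are read in $O(1)$ from the prefix sums of $P$. Multiplying by `strength[i]` and adding over all $i$ (modulo $10^9 + 7$) gives the answer.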

Complexity:

  • Time complexity : O(n).
  • Space complexity : O(n).

All Problems

All Solutions