
Formatted question description: https://leetcode.ca/all/1648.html

1648. Sell Diminishing-Valued Colored Balls (Medium)

You have an inventory of different colored balls, and there is a customer that wants orders balls of any color.

The customer weirdly values the colored balls. Each colored ball's value is the number of balls of that color you currently have in your inventory. For example, if you own 6 yellow balls, the customer would pay 6 for the first yellow ball. After the transaction, there are only 5 yellow balls left, so the next yellow ball is then valued at 5 (i.e., the value of the balls decreases as you sell more to the customer).

You are given an integer array, inventory, where inventory[i] represents the number of balls of the ith color that you initially own. You are also given an integer orders, which represents the total number of balls that the customer wants. You can sell the balls in any order.

Return the maximum total value that you can attain after selling orders colored balls. As the answer may be too large, return it modulo 10^9 + 7.

Example 1:

Input: inventory = [2,5], orders = 4
Output: 14
Explanation: Sell the 1st color 1 time (2) and the 2nd color 3 times (5 + 4 + 3).
The maximum total value is 2 + 5 + 4 + 3 = 14.

Example 2:

Input: inventory = [3,5], orders = 6
Output: 19
Explanation: Sell the 1st color 2 times (3 + 2) and the 2nd color 4 times (5 + 4 + 3 + 2).
The maximum total value is 3 + 2 + 5 + 4 + 3 + 2 = 19.

Example 3:

Input: inventory = [2,8,4,10,6], orders = 20
Output: 110

Example 4:

Input: inventory = [1000000000], orders = 1000000000
Output: 21
Explanation: Sell the 1st color 1000000000 times for a total value of 500000000500000000. 500000000500000000 modulo 10^9 + 7 = 21.
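A quick sanity check of the arithmetic in Example 4, sketched in Python because its integers do not overflow. Selling all n balls of one color gives the arithmetic series n + (n - 1) + ... + 1 = n(n + 1) / 2. The helper name sell_all is hypothetical.

```python
MOD = 10**9 + 7

def sell_all(n: int) -> int:
    """Total value of selling every ball of a color that starts with n balls."""
    # n + (n - 1) + ... + 1 is an arithmetic series with n terms.
    return n * (n + 1) // 2

total = sell_all(10**9)
print(total)        # 500000000500000000
print(total % MOD)  # 21
```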

 

Constraints:

  • 1 <= inventory.length <= 10^5
  • 1 <= inventory[i] <= 10^9
  • 1 <= orders <= min(sum(inventory[i]), 10^9)

Related Topics:
Math, Greedy, Sort

Solution 1. Binary Answer

Intuition

Below, A is inventory and T is orders. We should greedily pick the greatest number A[i] from A, add A[i] to the answer, and then decrement A[i].

The simple solution is to simulate this process with a max heap. However, it sells one ball per heap operation, up to 10^9 of them, so it gets TLE.
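Here is a minimal Python sketch of that heap simulation, assuming Python's heapq (a min-heap, hence the negated values). It does one pop and at most one push per ball sold, so it runs in O(T log N). That makes it usable as a brute-force oracle on small inputs but far too slow for T = 10^9. The name max_profit_heap is illustrative only.

```python
import heapq
from typing import List

MOD = 10**9 + 7

def max_profit_heap(inventory: List[int], orders: int) -> int:
    """Brute-force greedy: always sell a ball of the color with the most balls left."""
    heap = [-n for n in inventory]  # negate values to get a max-heap from heapq
    heapq.heapify(heap)
    total = 0
    for _ in range(orders):
        top = -heapq.heappop(heap)  # current highest value
        total += top
        if top > 1:
            heapq.heappush(heap, -(top - 1))  # that color now has one ball fewer
    return total % MOD
```

The problem guarantees orders <= sum(inventory), so the heap never runs empty inside the loop.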

To solve it more efficiently, we can break this problem into two steps:

  1. Find the smallest number k such that, if we cut every A[i] > k down to k, the number of balls taken, sum(A[i] - k | A[i] > k), is smaller than T. We can find k using binary search.
  2. Based on k, we can calculate the maximum total value.

Algorithm

For Step 1, I first store the frequencies of A[i] in a map m whose keys are in descending order. The range of k is [0, max(A)], so we do binary search within this range.

Let L = 0, R = max(A).

We define a valid(M, T) function which returns true if sum(A[i] - M | A[i] > M) >= T.

valid(M, T) is monotonic in M. If valid(M, T) is true, then valid(M', T) is also true for every M' < M, because lowering the threshold only adds balls.
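To see the monotonicity concretely, here is a Python sketch of valid evaluated on Example 1 (A = [2, 5], T = 4). It assumes a collections.Counter stands in for the C++ map; the extra freq parameter is this sketch's own choice.

```python
from collections import Counter

def valid(freq: Counter, M: int, T: int) -> bool:
    """True if selling every ball valued above M yields at least T balls."""
    taken = 0
    for n, cnt in freq.items():
        if n > M:
            taken += cnt * (n - M)  # each of the cnt colors gives n - M balls
            if taken >= T:
                return True  # early exit, like the C++ helper
    return taken >= T

freq = Counter([2, 5])  # Example 1
print([valid(freq, M, 4) for M in range(6)])
# [True, True, False, False, False, False]
```

For this example the first invalid M is 2.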

If valid(M, T) is true, the answer is greater than M, so we set L = M + 1. Otherwise, we set R = M - 1.

When the loop ends, L is the smallest invalid M, which is the k we are looking for.
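Putting L = 0, R = max(A) and the update rule together, Step 1 could be sketched in Python as follows. find_k is a hypothetical name, and valid is inlined as a closure so the block runs on its own.

```python
from collections import Counter
from typing import List

def find_k(inventory: List[int], T: int) -> int:
    """Smallest M for which taking every ball above M yields fewer than T balls."""
    freq = Counter(inventory)

    def valid(M: int) -> bool:
        # Count the balls valued above M and compare with the order size T.
        return sum(cnt * (n - M) for n, cnt in freq.items() if n > M) >= T

    L, R = 0, max(inventory)
    while L <= R:
        M = (L + R) // 2
        if valid(M):
            L = M + 1  # enough balls above M: the threshold lies higher
        else:
            R = M - 1  # not enough: M may be the answer, keep looking lower
    return L
```

For Example 1 the probes are M = 2 (invalid), M = 0 (valid) and M = 1 (valid), so it returns 2. For Example 3 it returns 3.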

For Step 2, we traverse the map m in descending order. For each pair (n, cnt) with n > L, we take cnt * (n - L) balls. Their total value is (n + L + 1) * (n - L) / 2 * cnt, which is the sum L + 1, L + 2, ..., n for each of the cnt colors.
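The per-color value is an arithmetic series with n - L terms, running from L + 1 up to n. Below is a small Python sketch of that formula; stripped_value is a hypothetical helper name. Integer division by 2 is exact because n + L + 1 and n - L always have opposite parity, so their product is even. That is also why the C++ code can divide by 2 before taking the modulo.

```python
def stripped_value(n: int, L: int, cnt: int = 1) -> int:
    """Value of selling cnt colors that each start with n balls, down to L balls left."""
    # Sum of L + 1, L + 2, ..., n is (first + last) * terms / 2, times cnt colors.
    return (n + L + 1) * (n - L) // 2 * cnt

# Example 3 with L = 3: colors 10, 8, 6 and 4 are stripped down to 3 balls.
print([stripped_value(n, 3) for n in (10, 8, 6, 4)])  # [49, 30, 15, 4]
```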

If T is still not exhausted after traversing m, each remaining ball comes from the row of value L, so we add L * T to the answer.
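Step 2 on its own could look like the Python sketch below, given the threshold k found in Step 1 (passed in explicitly; total_value is a hypothetical name). Python integers are unbounded, so this sketch reduces modulo 10^9 + 7 only once at the end, whereas the C++ version has to reduce as it goes.

```python
from collections import Counter
from typing import List

MOD = 10**9 + 7

def total_value(inventory: List[int], orders: int, k: int) -> int:
    """Step 2: given the threshold k from Step 1, add up the value of the balls sold."""
    T = orders
    ans = 0
    for n, cnt in Counter(inventory).items():
        if n > k:
            T -= cnt * (n - k)                       # balls taken above row k
            ans += (n + k + 1) * (n - k) // 2 * cnt  # their arithmetic-series value
    ans += k * T  # each leftover ball is sold at value k
    return ans % MOD
```

With the thresholds for the examples (2 for Examples 1 and 2, 3 for Example 3, 1 for Example 4), it reproduces 14, 19, 110 and 21.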

// OJ: https://leetcode.com/contest/weekly-contest-214/problems/sell-diminishing-valued-colored-balls/
// Time: O(Nlog(max(A)))
// Space: O(N)
class Solution {
    map<int, int, greater<>> m;
    // Take 64-bit arguments: with int, cnt * (n - M) can reach 10^5 * 10^9 and overflow.
    bool valid(long M, long T) {
        for (auto &[n, cnt] : m) {
            if (n <= M) break;
            T -= cnt * (n - M);
            if (T <= 0) return true;
        }
        return T <= 0;
    }
public:
    int maxProfit(vector<int>& A, int T) {
        long ans = 0, mod = 1e9+7, L = 0, R = *max_element(begin(A), end(A));
        for (int n : A) m[n]++;
        while (L <= R) {
            long M = (L + R) / 2;
            if (valid(M, T)) L = M + 1;
            else R = M - 1;
        }
        for (auto &[n, cnt] : m) {
            if (n <= L) break;
            T -= cnt * (n - L);
            ans = (ans + (n + L + 1) * (n - L) / 2 % mod * cnt % mod) % mod;
        }
        if (T) ans = (ans + L * T % mod) % mod;
        return ans;
    }
};

About the valid function

We define a valid(M, T) function which returns true if sum(A[i] - M | A[i] > M) >= T.

The valid(M, T) function answers one question: if we pick all the balls with value greater than M, do we get at least T balls? So a better name for this function could be haveEnoughBalls.

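The binary search looks for the smallest M that is invalid, i.e. that can’t get us enough balls. The remaining balls must then all have value M. Since M - 1 is valid, the number of leftover balls is at most the number of colors that still have M balls, so the row at value M always has enough of them.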

Summary of my thinking process

  1. If I directly use a max heap to simulate the process, the time complexity is O(T * logN). That is too slow, because T can be up to 10^9 (the max heap solution can be improved if we don’t decrement the numbers one by one). Is there a way to improve it?
  2. We are always taking high value balls from top to bottom, like stripping the balls row by row from the top. What if I somehow know that I should strip all the balls above the k-th row? Can I efficiently calculate the values?
  3. Yes, I can. The formula for the sum of an arithmetic sequence gives the value of each column. The remaining balls must all be in the k-th row, so each of them is worth k.
  4. Now the problem is how to find this k efficiently. It is the smallest row number that doesn’t give me enough balls. Does this property have monotonicity?
  5. Yes! At the beginning we strip no balls, so we don’t have enough balls; let’s regard this as invalid. As we strip downwards, we get more and more balls until a breakpoint where we have as many balls as we need, or more. If we continue stripping, we stay valid. So it has perfect monotonicity, which means that I can use binary search!
  6. To use binary search in range [L, R] = [0, max(A)], I need to define a function valid(M) meaning having enough balls. It should return true if sum( A[i] - M | A[i] > M ) >= T. We can do this valid(M) in O(N) time.
  7. Since the binary search takes O(log(max(A))) iterations, the overall time complexity is O(Nlog(max(A))). Good to go.


Solution 2. Sort + Greedy

// OJ: https://leetcode.com/contest/weekly-contest-214/problems/sell-diminishing-valued-colored-balls/
// Time: O(NlogN)
// Space: O(1)
// Ref: https://www.youtube.com/watch?v=kBxqAnWo9Wo
class Solution {
public:
    int maxProfit(vector<int>& A, int T) {
        sort(rbegin(A), rend(A)); // sort in descending order
        long mod = 1e9+7, N = A.size(), cur = A[0], ans = 0, i = 0;
        while (T) {
            while (i < N && A[i] == cur) ++i; // handle those same numbers together
            long next = i == N ? 0 : A[i], h = cur - next, r = 0, cnt = min((long)T, i * h);
            if (T < i * h) { // the i * h balls are more than what we need.
                h = T / i; // we just reduce the height by `T / i` instead
                r = T % i; 
            }
            long val = cur - h; // each of the remainder `r` balls has value `cur - h`.
            ans = (ans + (cur + val + 1) * h / 2 * i + val * r) % mod;
            T -= cnt;
            cur = next;
        }
        return ans;
    }
};

Other implementations

  • class Solution {
        public int maxProfit(int[] inventory, int orders) {
            final int MODULO = 1000000007;
            long profit = 0;
            Map<Integer, Integer> map = new HashMap<Integer, Integer>();
            for (int balls : inventory) {
                int count = map.getOrDefault(balls, 0) + 1;
                map.put(balls, count);
            }
            PriorityQueue<Integer> priorityQueue = new PriorityQueue<Integer>(new Comparator<Integer>() {
                public int compare(Integer balls1, Integer balls2) {
                    return balls2 - balls1;
                }
            });
            for (int balls : map.keySet())
                priorityQueue.offer(balls);
            while (orders > 0) {
                int maxBalls = priorityQueue.poll();
                int count = map.get(maxBalls);
                if (priorityQueue.isEmpty()) {
                    int quotient = orders / count;
                    int remainder = orders % count;
                    profit += (long) (maxBalls + maxBalls - quotient + 1) * quotient / 2 * count;
                    int remainMaxBalls = maxBalls - quotient;
                    profit += (long) remainMaxBalls * remainder;
                    break;
                } else {
                    int max2Balls = priorityQueue.peek();
                    long curOrders = (long) (maxBalls - max2Balls) * count;
                    if (curOrders <= orders) {
                        profit += (long) (maxBalls + max2Balls + 1) * curOrders / 2;
                        orders -= (int) curOrders;
                        map.put(max2Balls, map.get(max2Balls) + count);
                    } else {
                        int quotient = orders / count;
                        int remainder = orders % count;
                        profit += (long) (maxBalls + maxBalls - quotient + 1) * quotient / 2 * count;
                        int remainMaxBalls = maxBalls - quotient;
                        profit += (long) remainMaxBalls * remainder;
                        break;
                    }
                }
            }
            return (int) (profit % MODULO);
        }
    }
    
  • from typing import List
    
    class Solution:
        def maxProfit(self, inventory: List[int], orders: int) -> int:
            inventory.sort(reverse=True)
            mod = 10**9 + 7
            ans = i = 0
            n = len(inventory)
            while orders > 0:
                while i < n and inventory[i] >= inventory[0]:
                    i += 1
                nxt = 0
                if i < n:
                    nxt = inventory[i]
                cnt = i
                x = inventory[0] - nxt
                tot = cnt * x
                if tot > orders:
                    decr = orders // cnt
                    a1, an = inventory[0] - decr + 1, inventory[0]
                    ans += (a1 + an) * decr // 2 * cnt
                    ans += (inventory[0] - decr) * (orders % cnt)
                else:
                    a1, an = nxt + 1, inventory[0]
                    ans += (a1 + an) * x // 2 * cnt
                    inventory[0] = nxt
                orders -= tot
                ans %= mod
            return ans
    
    ############
    
    # 1648. Sell Diminishing-Valued Colored Balls
    # https://leetcode.com/problems/sell-diminishing-valued-colored-balls/
    
    from heapq import heappop, heappush
    from typing import List
    
    class Solution:
        def maxProfit(self, inventory: List[int], orders: int) -> int:
            M = 1000000007
            
            pq = []
            for v in inventory:
                heappush(pq, (-v, 1))
    
            res = 0
            
            while pq and orders > 0:
                top, count = pq[0]
                top = -top
                heappop(pq)
                
                while pq and pq[0][0] == -top:
                    count += pq[0][1]
                    heappop(pq)
                
                next_top = 0
                if pq:
                    next_top = -pq[0][0]
                    
                delta = top - next_top
                
                if delta * count <= orders:
                    d = (top*(top+1)) //2
                    d -= (next_top*(next_top+1)) // 2
                    d *= count
                    orders -= delta * count
                    res += d
                    res %= M
                else:
                    num = orders // count
                    d = (top*(top+1)) //2
                    next_top = top - num
                    d -= (next_top*(next_top+1)) // 2
                    d *= count
                    
                    res += d
                    orders -= num * count
                    
                    leftover = orders % count
                    res += next_top * leftover
                    res %= M
                    orders -= leftover
                    
                if next_top:
                    heappush(pq, (-next_top, count))
    
            
            return res % M
    
