Formatted question description: https://leetcode.ca/all/1648.html
1648. Sell Diminishing-Valued Colored Balls (Medium)
You have an inventory of different colored balls, and there is a customer that wants orders balls of any color.
The customer weirdly values the colored balls. Each colored ball's value is the number of balls of that color you currently have in your inventory. For example, if you own 6 yellow balls, the customer would pay 6 for the first yellow ball. After the transaction, there are only 5 yellow balls left, so the next yellow ball is then valued at 5 (i.e., the value of the balls decreases as you sell more to the customer).
You are given an integer array, inventory, where inventory[i] represents the number of balls of the i-th color that you initially own. You are also given an integer orders, which represents the total number of balls that the customer wants. You can sell the balls in any order.
Return the maximum total value that you can attain after selling orders colored balls. As the answer may be too large, return it modulo 10^9 + 7.
Example 1:
Input: inventory = [2,5], orders = 4
Output: 14
Explanation: Sell the 1st color 1 time (2) and the 2nd color 3 times (5 + 4 + 3). The maximum total value is 2 + 5 + 4 + 3 = 14.
Example 2:
Input: inventory = [3,5], orders = 6
Output: 19
Explanation: Sell the 1st color 2 times (3 + 2) and the 2nd color 4 times (5 + 4 + 3 + 2). The maximum total value is 3 + 2 + 5 + 4 + 3 + 2 = 19.
Example 3:
Input: inventory = [2,8,4,10,6], orders = 20
Output: 110
Example 4:
Input: inventory = [1000000000], orders = 1000000000
Output: 21
Explanation: Sell the 1st color 1000000000 times for a total value of 500000000500000000. 500000000500000000 modulo 10^9 + 7 = 21.
Constraints:
1 <= inventory.length <= 10^5
1 <= inventory[i] <= 10^9
1 <= orders <= min(sum(inventory[i]), 10^9)
Related Topics:
Math, Greedy, Sort
Solution 1. Binary Answer
Intuition
We should greedily pick the greatest number A[i] from A, add A[i] to the answer, and then decrement A[i] (A[i]--). Here A is inventory and T is orders, and we repeat this T times.
The simple solution is to simulate this process with a max heap, but it gets TLE: it performs one heap operation per ball sold, and T can be as large as 10^9.
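For reference, a minimal sketch of that naive max-heap simulation, assuming std::priority_queue (the function name maxProfitNaive is hypothetical). It sells exactly one ball per loop iteration, so it is only usable for small orders:

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Naive greedy: always sell one ball from the color that currently has the
// most balls. Correct, but it loops `orders` times, which is too slow when
// orders is close to 10^9.
int maxProfitNaive(const std::vector<int>& inventory, int orders) {
    const long long mod = 1000000007;
    std::priority_queue<int> pq(inventory.begin(), inventory.end());  // max heap
    long long ans = 0;
    while (orders-- > 0) {
        int top = pq.top();  // value of the most valuable ball right now
        pq.pop();
        ans = (ans + top) % mod;
        if (top > 1) pq.push(top - 1);  // that color now has one ball fewer
    }
    return static_cast<int>(ans);
}
```

On Examples 1 to 3 it returns 14, 19 and 110; Example 4 would need 10^9 iterations.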
To solve it more efficiently, we can break this problem into two steps:
1. Find the smallest number k such that, if we lower every A[i] > k down to k, the number of balls we take, sum(A[i] - k | A[i] > k), is smaller than T. We can do this using binary search.
2. Based on k, calculate the maximum total value.
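As a hand-worked check, here are both steps applied to Example 2 (A = [3,5], T = 6):

```latex
\begin{aligned}
\text{Step 1: } & k = 1:\ (3-1) + (5-1) = 6 \ge 6 \ \text{(enough balls)} \\
& k = 2:\ (3-2) + (5-2) = 4 < 6 \ \text{(not enough)} \;\Rightarrow\; k = 2 \\
\text{Step 2: } & \text{balls above } k:\ 3 + (5 + 4 + 3) = 15 \text{ from } 4 \text{ balls} \\
& \text{remaining } 6 - 4 = 2 \text{ balls at value } k:\ 2 \cdot 2 = 4 \\
& \text{total} = 15 + 4 = 19
\end{aligned}
```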
Algorithm
For Step 1, I first store the frequencies of A[i] in a map. The range of k is [0, max(A)], so we do binary search within this range.
Let L = 0, R = max(A).
We define a valid(M, T) function which returns true if sum(A[i] - M | A[i] > M) >= T.
The smaller M is, the more likely valid(M, T) returns true.
If valid(M, T) is true, we should increase M, so L = M + 1. Otherwise, we set R = M - 1.
In the end, L is the k we are looking for.
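A minimal sketch of Step 1 on its own, assuming a plain scan over the vector instead of the frequency map (findK is a hypothetical name, and this valid takes the array explicitly). The counters are 64-bit and the predicate stops as soon as it has T balls:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// valid(M): taking every ball valued above M, do we collect at least T balls?
// 64-bit sum, because (A[i] - M) over 10^5 colors can exceed INT_MAX.
bool valid(const std::vector<int>& A, long long M, long long T) {
    long long taken = 0;
    for (int a : A) {
        if (a > M) taken += a - M;
        if (taken >= T) return true;  // early exit: already enough balls
    }
    return false;
}

// Hypothetical helper: binary search on [0, max(A)] for the smallest M with
// valid(M) == false. Under the constraints, valid(0) is true and
// valid(max(A)) is false, so the result lies in [1, max(A)].
long long findK(const std::vector<int>& A, long long T) {
    long long L = 0, R = *std::max_element(A.begin(), A.end());
    while (L <= R) {
        long long M = L + (R - L) / 2;
        if (valid(A, M, T)) L = M + 1;  // enough balls: k must be larger
        else R = M - 1;                 // not enough: try a smaller M
    }
    return L;
}
```

For Examples 1 to 4 this gives k = 2, 2, 3 and 1 respectively.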
For Step 2, we traverse m in descending order. For each pair (n, cnt) with n > L, we take cnt * (n - L) balls. For each of the cnt colors these balls are worth L + 1, L + 2, ..., n, so their total value is (n + L + 1) * (n - L) / 2 * cnt.
If, after traversing m, T (reduced by the balls taken so far) is still not exhausted, each remaining ball is worth L, so we add L * T to the answer.
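A sketch of Step 2 in isolation, assuming k is the value produced by Step 1 (totalValue is a hypothetical name). It scans the array directly rather than the frequency map; grouping equal counts, as the map does, only saves work and gives the same sum:

```cpp
#include <cassert>
#include <vector>

// Step 2 sketch: given k from Step 1, compute the answer modulo 1e9+7.
// Each color with n > k is lowered to k; its n - k sold balls are worth
// k+1, ..., n, which sum to (n + k + 1) * (n - k) / 2.
int totalValue(const std::vector<int>& A, long long T, long long k) {
    const long long mod = 1000000007;
    long long ans = 0;
    for (long long n : A) {
        if (n <= k) continue;  // this color has no balls above row k
        T -= n - k;            // balls sold from this color
        ans = (ans + (n + k + 1) * (n - k) / 2 % mod) % mod;
    }
    // The T balls still needed all come from row k, so each is worth k.
    ans = (ans + k % mod * (T % mod)) % mod;
    return static_cast<int>(ans);
}
```

For Example 2, totalValue({3, 5}, 6, 2) adds 3 + (5 + 4 + 3) + 2 * 2 = 19.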
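// OJ: https://leetcode.com/contest/weekly-contest-214/problems/sell-diminishing-valued-colored-balls/
// Time: O(Nlog(max(A)))
// Space: O(N)
#include <algorithm>
#include <functional>
#include <map>
#include <vector>
using namespace std;

class Solution {
    map<int, int, greater<>> m; // ball count -> number of colors with that count, largest first
    // Returns true if taking every ball valued above M yields at least T balls.
    bool valid(long long M, long long T) {
        for (auto &[n, cnt] : m) {
            if (n <= M) break;
            T -= (long long)cnt * (n - M); // 64-bit: this product can exceed INT_MAX
            if (T <= 0) return true;
        }
        return T <= 0;
    }
public:
    int maxProfit(vector<int>& A, int T) {
        long long ans = 0, mod = 1e9 + 7, L = 0, R = *max_element(begin(A), end(A));
        for (int n : A) m[n]++;
        // Step 1: L ends as the smallest M for which valid(M, T) is false.
        while (L <= R) {
            long long M = (L + R) / 2;
            if (valid(M, T)) L = M + 1;
            else R = M - 1;
        }
        // Step 2: each color with n > L contributes L + 1, ..., n.
        for (auto &[n, cnt] : m) {
            if (n <= L) break;
            T -= cnt * (n - L);
            ans = (ans + (n + L + 1) * (n - L) / 2 % mod * cnt % mod) % mod;
        }
        // The T balls still needed are each worth L.
        if (T) ans = (ans + L * T % mod) % mod;
        return ans;
    }
};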
About the valid function
We define a valid(M, T) function which returns true if sum(A[i] - M | A[i] > M) >= T.
The meaning of this valid(M, T) function is: if we pick all the balls with value greater than M, the number of balls we get is greater than or equal to our target amount of balls T. So a better function name could be haveEnoughBalls.
The binary search looks for the smallest M which is invalid, i.e., which can't get us enough balls. Because M - 1 is still valid, the row of balls at value M holds at least as many balls as we still need, so the remaining balls are all of value M.
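To see the monotonicity the binary search relies on, this illustrative sketch (haveEnoughBalls and predicateTable are hypothetical names) evaluates the predicate for every M in [0, max(A)]. It costs O(N * max(A)), so it is only meant for tiny inputs:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// haveEnoughBalls(M): do the balls valued above M add up to at least T?
bool haveEnoughBalls(const std::vector<int>& A, long long M, long long T) {
    long long taken = 0;
    for (int a : A)
        if (a > M) taken += a - M;
    return taken >= T;
}

// Evaluates the predicate for every M in [0, max(A)]. Under the problem's
// constraints the sequence is true...true, false...false, and the first
// false position is the k the binary search finds.
std::vector<bool> predicateTable(const std::vector<int>& A, long long T) {
    std::vector<bool> table;
    long long maxA = *std::max_element(A.begin(), A.end());
    for (long long M = 0; M <= maxA; ++M)
        table.push_back(haveEnoughBalls(A, M, T));
    return table;
}
```

For Example 1 (A = [2,5], T = 4) the table is true, true, false, false, false, false, so the first invalid M is 2: the balls valued 5, 4, 3 are taken, and the one remaining ball is worth 2.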
Summary of my thinking process
- If I directly use a max heap to simulate the process, the time complexity is O(T * logN) with T up to 10^9, which is too slow (the max heap solution can be improved as well if we don't decrement the numbers by one). Is there a way to improve it?
- We are always taking high-value balls from top to bottom, like stripping the balls row by row from the top. What if I somehow know that I should strip all the balls above the k-th row? Can I efficiently calculate the values?
- Yes, I can. I can use the formula for the sum of an arithmetic sequence to get the values of each column. And the remaining balls must be on the k-th row, so their values are k.
- Now the problem is how to find this k efficiently. It is the smallest row number which doesn't give me enough balls. Does it have monotonicity?
- Yes! At the beginning we strip no balls, so we don't have enough balls; let's regard this as invalid. As we strip downwards, we get more and more balls until a breakpoint where we have as many balls as we need, or more. If we continue stripping, we stay valid. So it has perfect monotonicity, which means that I can use binary search!
- To use binary search in the range [L, R] = [0, max(A)], I need to define a function valid(M) meaning "having enough balls". It should return true if sum(A[i] - M | A[i] > M) >= T. We can compute valid(M) in O(N) time.
- Since the binary search takes O(log(max(A))) iterations, the overall time complexity is O(N log(max(A))). Good to go.
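The column values mentioned in the summary follow from the arithmetic-series formula. For a color with n > k balls, the stripped balls are worth k + 1, ..., n, and the remaining balls all sit on row k:

```latex
\begin{aligned}
\sum_{v=k+1}^{n} v &= \frac{\bigl((k+1) + n\bigr)(n-k)}{2} = \frac{(n+k+1)(n-k)}{2} \\
\text{total} &= \sum_{A_i > k} \frac{(A_i+k+1)(A_i-k)}{2} \;+\; k\Bigl(T - \sum_{A_i > k} (A_i - k)\Bigr)
\end{aligned}
```

Since (n + k + 1) + (n - k) = 2n + 1 is odd, exactly one of the two factors is even, so the division by 2 is exact in integer arithmetic.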
More Binary Answer Problems
- 410. Split Array Largest Sum (Hard)
- 1482. Minimum Number of Days to Make m Bouquets (Medium)
- 1300. Sum of Mutated Array Closest to Target (Medium)
- 1044. Longest Duplicate Substring (Hard)
- 668. Kth Smallest Number in Multiplication Table (Hard)
- 719. Find Kth Smallest Pair Distance (Hard)
- 1283. Find the Smallest Divisor Given a Threshold (Medium)
Solution 2. Sort + Greedy
// OJ: https://leetcode.com/contest/weekly-contest-214/problems/sell-diminishing-valued-colored-balls/
// Time: O(NlogN)
// Space: O(1)
// Ref: https://www.youtube.com/watch?v=kBxqAnWo9Wo
#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    int maxProfit(vector<int>& A, int T) {
        sort(rbegin(A), rend(A)); // sort in descending order
        long long mod = 1e9 + 7, N = A.size(), cur = A[0], ans = 0, i = 0;
        while (T) {
            while (i < N && A[i] == cur) ++i; // handle those same numbers together
            long long next = i == N ? 0 : A[i], h = cur - next, r = 0, cnt = min((long long)T, i * h);
            if (T < i * h) { // the i * h balls are more than what we need.
                h = T / i; // we just reduce the height by `T / i` instead
                r = T % i;
            }
            long long val = cur - h; // each of the remainder `r` balls has value `cur - h`.
            ans = (ans + (cur + val + 1) * h / 2 * i + val * r) % mod;
            T -= cnt;
            cur = next;
        }
        return ans;
    }
};
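A hand trace of Solution 2 on Example 3 (sorted A = [10,8,6,4,2], T = 20) shows how the i equal tallest columns are lowered together to the next distinct height. In the last row i * h = 8 equals the remaining T, so no partial row is needed:

```latex
\begin{array}{c|c|c|c|l|c}
\text{cur} & i & \text{next} & h & \text{value added} & T \text{ after} \\ \hline
10 & 1 & 8 & 2 & 1 \cdot (10 + 9) = 19 & 18 \\
8  & 2 & 6 & 2 & 2 \cdot (8 + 7) = 30  & 14 \\
6  & 3 & 4 & 2 & 3 \cdot (6 + 5) = 33  & 8 \\
4  & 4 & 2 & 2 & 4 \cdot (4 + 3) = 28  & 0
\end{array}
\qquad 19 + 30 + 33 + 28 = 110
```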
Java

import java.util.*;

class Solution {
    public int maxProfit(int[] inventory, int orders) {
        final int MODULO = 1000000007;
        long profit = 0;
        // ball count -> number of colors that currently have that many balls
        Map<Integer, Integer> map = new HashMap<Integer, Integer>();
        for (int balls : inventory) {
            int count = map.getOrDefault(balls, 0) + 1;
            map.put(balls, count);
        }
        // max heap over the distinct ball counts
        PriorityQueue<Integer> priorityQueue = new PriorityQueue<Integer>(new Comparator<Integer>() {
            public int compare(Integer balls1, Integer balls2) {
                return balls2 - balls1;
            }
        });
        for (int balls : map.keySet())
            priorityQueue.offer(balls);
        while (orders > 0) {
            int maxBalls = priorityQueue.poll();
            int count = map.get(maxBalls);
            if (priorityQueue.isEmpty()) {
                int quotient = orders / count;
                int remainder = orders % count;
                profit += (long) (maxBalls + maxBalls - quotient + 1) * quotient / 2 * count;
                int remainMaxBalls = maxBalls - quotient;
                profit += (long) remainMaxBalls * remainder;
                break;
            } else {
                int max2Balls = priorityQueue.peek();
                long curOrders = (long) (maxBalls - max2Balls) * count;
                if (curOrders <= orders) {
                    profit += (long) (maxBalls + max2Balls + 1) * curOrders / 2;
                    orders -= (int) curOrders;
                    map.put(max2Balls, map.get(max2Balls) + count);
                } else {
                    int quotient = orders / count;
                    int remainder = orders % count;
                    profit += (long) (maxBalls + maxBalls - quotient + 1) * quotient / 2 * count;
                    int remainMaxBalls = maxBalls - quotient;
                    profit += (long) remainMaxBalls * remainder;
                    break;
                }
            }
        }
        return (int) (profit % MODULO);
    }
}


Python

from typing import List


class Solution:
    def maxProfit(self, inventory: List[int], orders: int) -> int:
        inventory.sort(reverse=True)
        mod = 10**9 + 7
        ans = i = 0
        n = len(inventory)
        while orders > 0:
            while i < n and inventory[i] >= inventory[0]:
                i += 1
            nxt = 0
            if i < n:
                nxt = inventory[i]
            cnt = i
            x = inventory[0] - nxt
            tot = cnt * x
            if tot > orders:
                decr = orders // cnt
                a1, an = inventory[0] - decr + 1, inventory[0]
                ans += (a1 + an) * decr // 2 * cnt
                ans += (inventory[0] - decr) * (orders % cnt)
            else:
                a1, an = nxt + 1, inventory[0]
                ans += (a1 + an) * x // 2 * cnt
                inventory[0] = nxt
            orders -= tot
            ans %= mod
        return ans

############

# 1648. Sell Diminishing-Valued Colored Balls
# https://leetcode.com/problems/sell-diminishing-valued-colored-balls/
from heapq import heappop, heappush


class Solution:
    def maxProfit(self, inventory: List[int], orders: int) -> int:
        M = 1000000007
        pq = []
        for v in inventory:
            heappush(pq, (-v, 1))
        res = 0
        while pq and orders > 0:
            top, count = pq[0]
            top = -top
            heappop(pq)
            while pq and -pq[0][0] == top:
                count += pq[0][1]
                heappop(pq)
            next_top = 0
            if pq:
                next_top = -pq[0][0]
            delta = top - next_top
            if delta * count <= orders:
                d = (top * (top + 1)) // 2
                d -= (next_top * (next_top + 1)) // 2
                d *= count
                orders -= delta * count
                res += d
                res %= M
            else:
                num = orders // count
                d = (top * (top + 1)) // 2
                next_top = top - num
                d -= (next_top * (next_top + 1)) // 2
                d *= count
                res += d
                orders -= num * count
                leftover = orders % count
                res += next_top * leftover
                res %= M
                orders -= leftover
            if next_top:
                heappush(pq, (-next_top, count))
        return res % M