
716. Max Stack


Design a max stack data structure that supports the stack operations and supports finding the stack's maximum element.

Implement the MaxStack class:

  • MaxStack() Initializes the stack object.
  • void push(int x) Pushes element x onto the stack.
  • int pop() Removes the element on top of the stack and returns it.
  • int top() Gets the element on the top of the stack without removing it.
  • int peekMax() Retrieves the maximum element in the stack without removing it.
  • int popMax() Retrieves the maximum element in the stack and removes it. If there is more than one maximum element, only remove the top-most one.

You must come up with a solution that runs in O(1) for each top call and O(log n) for each other call.
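The "top-most" rule in popMax only matters when the maximum occurs more than once, and plain integers hide which copy was removed. In the naive Python sketch below, each element carries a tag that records its push order. The tags and the helper `pop_max_naive` are illustrative only. The helper scans the whole stack, so it is O(n) and does not meet the bound above; it only pins down the expected behavior.

```python
from typing import List, Tuple

# Each element is (value, tag); the tags only record push order.
stack: List[Tuple[int, str]] = [(5, "a"), (1, "b"), (5, "c")]


def pop_max_naive(stack: List[Tuple[int, str]]) -> Tuple[int, str]:
    """O(n) reference: remove the top-most occurrence of the maximum value."""
    m = max(v for v, _ in stack)
    # The top of the stack is the end of the list, so take the largest index.
    i = max(j for j, (v, _) in enumerate(stack) if v == m)
    return stack.pop(i)


print(pop_max_naive(stack))  # (5, 'c')
print(stack)                 # [(5, 'a'), (1, 'b')]
```

The copy tagged "c" was pushed last, so it is the one removed, and the older 5 stays in place.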


Example 1:

["MaxStack", "push", "push", "push", "top", "popMax", "top", "peekMax", "pop", "top"]
[[], [5], [1], [5], [], [], [], [], [], []]
[null, null, null, null, 5, 5, 1, 5, 1, 5]

MaxStack stk = new MaxStack();
stk.push(5);   // [5] the top of the stack and the maximum number is 5.
stk.push(1);   // [5, 1] the top of the stack is 1, but the maximum is 5.
stk.push(5);   // [5, 1, 5] the top of the stack is 5, which is also the maximum (the topmost of the two 5s).
stk.top();     // return 5, [5, 1, 5] the stack did not change.
stk.popMax();  // return 5, [5, 1] the stack is changed now, and the top is different from the max.
stk.top();     // return 1, [5, 1] the stack did not change.
stk.peekMax(); // return 5, [5, 1] the stack did not change.
stk.pop();     // return 1, [5] the top of the stack and the max element is now 5.
stk.top();     // return 5, [5] the stack did not change.
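The three bracketed lines at the top of Example 1 are LeetCode's serialized form: method names, argument lists, and expected return values. A small replay helper makes the correspondence concrete. In this sketch the helper `run` and the class `ListMaxStack` are hypothetical. `ListMaxStack` is a deliberately naive list-based stand-in (O(n) `peekMax` and `popMax`) used only to check the example. Any class with the same method names can be passed in instead.

```python
from typing import Any, List


class ListMaxStack:
    """Naive stand-in: a plain list, so peekMax and popMax are O(n)."""

    def __init__(self):
        self.items: List[int] = []

    def push(self, x: int) -> None:
        self.items.append(x)

    def pop(self) -> int:
        return self.items.pop()

    def top(self) -> int:
        return self.items[-1]

    def peekMax(self) -> int:
        return max(self.items)

    def popMax(self) -> int:
        m = max(self.items)
        # Remove the last (top-most) occurrence of the maximum.
        i = len(self.items) - 1 - self.items[::-1].index(m)
        return self.items.pop(i)


def run(cls, ops: List[str], args: List[List[Any]]) -> List[Any]:
    """Replay a LeetCode-style call sequence; ops[0] names the constructor."""
    obj = cls(*args[0])
    out: List[Any] = [None]
    for op, a in zip(ops[1:], args[1:]):
        out.append(getattr(obj, op)(*a))
    return out


ops = ["MaxStack", "push", "push", "push", "top", "popMax", "top", "peekMax", "pop", "top"]
args = [[], [5], [1], [5], [], [], [], [], [], []]
print(run(ListMaxStack, ops, args))
# [None, None, None, None, 5, 5, 1, 5, 1, 5]
```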


Constraints:

  • -10^7 <= x <= 10^7
  • At most 10^5 calls will be made to push, pop, top, peekMax, and popMax.
  • There will be at least one element in the stack when pop, top, peekMax, or popMax is called.


A plain stack already supports push, pop, and top, so only peekMax and popMax need new logic. One approach uses two stacks:

  • the normal stack, which holds the elements in push order;
  • the maximum stack, whose entry at each level is the maximum of all normal-stack elements at or below that level (a running maximum).
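The "maximum so far" kept by the maximum stack is a running (prefix) maximum of the normal stack, read from bottom to top. Here is a quick Python illustration with arbitrarily chosen values:

```python
from itertools import accumulate

# Normal stack, bottom to top (illustrative values).
stack = [2, 7, 3, 9, 4]

# The maximum stack holds the running maximum at each level.
max_stack = list(accumulate(stack, max))
print(max_stack)  # [2, 7, 7, 9, 9]

# Invariant: max_stack[i] == max(stack[0..i]), so the top is the overall maximum.
assert all(max_stack[i] == max(stack[: i + 1]) for i in range(len(stack)))
assert max_stack[-1] == max(stack)
```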

Both stacks are created empty in the constructor. The operations then work as follows.

  1. The push() function. Push x onto the normal stack. For the maximum stack, push x if it is empty; otherwise push the larger of x and the current top of the maximum stack.

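  2. The pop() function. Pop both the normal stack and the maximum stack, and return the element popped from the normal stack. The two stacks always have the same height, so the maximum-stack entry for the removed level is discarded along with it.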

  3. The top() function. Simply return the top element of the normal stack.

  4. The peekMax() function. Return the top element of the maximum stack. This works because every entry pushed onto the maximum stack is the maximum of the normal stack up to that level, and pop() removes entries from both stacks together. As a result, the top of the maximum stack is always the maximum of the elements currently in the stack, not merely of everything ever pushed.

  5. The popMax() function. This is the only nontrivial operation, because the maximum is not necessarily on top of the normal stack. Read the target value from the top of the maximum stack. Then pop elements (from both stacks, via pop()) into a temporary buffer until the top of the normal stack equals the target. Pop the target itself, and push the buffered elements back in their original order via push(), which recomputes their maximum-stack entries. Every element above the maximum is moved twice, so popMax() takes O(n) time in the worst case, and this two-stack approach does not meet the O(log n) requirement. The Java, C++, and second Python solutions below meet it by pairing a doubly linked list (the stack order) with an ordered structure keyed by value (TreeMap, multimap, or SortedList) that points at the list nodes.
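Steps 1 to 4 can be sketched as free functions over two Python lists. The function names are illustrative and do not come from the solutions below. The point to notice is that pop() removes one entry from each list. After any sequence of pushes and pops, the top of the maximum stack therefore still describes exactly the elements that remain.

```python
from typing import List


def push(stack: List[int], max_stack: List[int], x: int) -> None:
    """Step 1: push x, and push the running maximum onto max_stack."""
    stack.append(x)
    max_stack.append(max(x, max_stack[-1]) if max_stack else x)


def pop(stack: List[int], max_stack: List[int]) -> int:
    """Step 2: pop both stacks so they keep the same height."""
    max_stack.pop()
    return stack.pop()


def top(stack: List[int]) -> int:
    """Step 3: the top of the normal stack."""
    return stack[-1]


def peek_max(max_stack: List[int]) -> int:
    """Step 4: O(1), the top of max_stack."""
    return max_stack[-1]


stack: List[int] = []
max_stack: List[int] = []
for x in [2, 7, 3, 9, 4]:
    push(stack, max_stack, x)
print(top(stack), peek_max(max_stack))  # 4 9

pop(stack, max_stack)  # removes 4
pop(stack, max_stack)  # removes 9
print(top(stack), peek_max(max_stack))  # 3 7
```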
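Step 5 is where the cost shows up. The sketch below implements the buffer approach; its helper names are again hypothetical. It also reports how many elements were buffered, which makes the O(n) worst case visible: if the maximum sits at the bottom, every other element is popped and pushed back.

```python
from typing import List, Tuple


def push(stack: List[int], max_stack: List[int], x: int) -> None:
    stack.append(x)
    max_stack.append(max(x, max_stack[-1]) if max_stack else x)


def pop_max(stack: List[int], max_stack: List[int]) -> Tuple[int, int]:
    """Remove the top-most maximum; return (value, number of elements buffered)."""
    target = max_stack[-1]
    buffer = []
    # Pop from both stacks until the maximum is on top of the normal stack.
    while stack[-1] != target:
        max_stack.pop()
        buffer.append(stack.pop())
    # Drop the maximum itself.
    max_stack.pop()
    stack.pop()
    moved = len(buffer)
    # Re-push in the original order; push() recomputes the running maxima,
    # so max_stack needs no buffer of its own.
    while buffer:
        push(stack, max_stack, buffer.pop())
    return target, moved


stack: List[int] = []
max_stack: List[int] = []
for x in [2, 7, 3, 9, 4]:
    push(stack, max_stack, x)
print(pop_max(stack, max_stack))  # (9, 1)
print(stack, max_stack)           # [2, 7, 3, 4] [2, 7, 7, 7]

# Worst case: the maximum is at the bottom, so everything above it moves.
stack, max_stack = [], []
for x in [100, 1, 2, 3, 4, 5]:
    push(stack, max_stack, x)
print(pop_max(stack, max_stack))  # (100, 5)
```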

  • import java.util.*;

    class Node {
        public int val;
        public Node prev, next;

        public Node() {
        }

        public Node(int val) {
            this.val = val;
        }
    }

    class DoubleLinkedList {
        private final Node head = new Node();
        private final Node tail = new Node();

        public DoubleLinkedList() {
            head.next = tail;
            tail.prev = head;
        }

        public Node append(int val) {
            Node node = new Node(val);
            node.next = tail;
            node.prev = tail.prev;
            tail.prev = node;
            node.prev.next = node;
            return node;
        }

        // Unlink a node in O(1); works on any node, not just the tail.
        public static Node remove(Node node) {
            node.prev.next = node.next;
            node.next.prev = node.prev;
            return node;
        }

        public Node pop() {
            return remove(tail.prev);
        }

        public int peek() {
            return tail.prev.val;
        }
    }

    class MaxStack {
        private DoubleLinkedList stk = new DoubleLinkedList();
        // value -> nodes holding that value, in push order (last = top-most)
        private TreeMap<Integer, List<Node>> tm = new TreeMap<>();

        public MaxStack() {
        }

        public void push(int x) {
            Node node = stk.append(x);
            tm.computeIfAbsent(x, k -> new ArrayList<>()).add(node);
        }

        public int pop() {
            Node node = stk.pop();
            List<Node> nodes = tm.get(node.val);
            int x = nodes.remove(nodes.size() - 1).val;
            if (nodes.isEmpty()) {
                tm.remove(node.val);
            }
            return x;
        }

        public int top() {
            return stk.peek();
        }

        public int peekMax() {
            return tm.lastKey();
        }

        public int popMax() {
            int x = peekMax();
            List<Node> nodes = tm.get(x);
            Node node = nodes.remove(nodes.size() - 1);
            if (nodes.isEmpty()) {
                tm.remove(x);
            }
            DoubleLinkedList.remove(node);
            return x;
        }
    }

    /**
     * Your MaxStack object will be instantiated and called as such:
     * MaxStack obj = new MaxStack();
     * obj.push(x);
     * int param_2 = obj.pop();
     * int param_3 = obj.top();
     * int param_4 = obj.peekMax();
     * int param_5 = obj.popMax();
     */
  • #include <list>
    #include <map>
    using namespace std;

    class MaxStack {
    public:
        MaxStack() {}

        void push(int x) {
            stk.push_back(x);
            // Equal keys are inserted at the end of their range, so the last one is top-most.
            tm.insert({x, --stk.end()});
        }

        int pop() {
            auto it = --stk.end();
            int ans = *it;
            auto mit = --tm.upper_bound(ans);
            tm.erase(mit);
            stk.erase(it);
            return ans;
        }

        int top() {
            return stk.back();
        }

        int peekMax() {
            return tm.rbegin()->first;
        }

        int popMax() {
            auto mit = --tm.end();
            auto it = mit->second;
            int ans = *it;
            tm.erase(mit);
            stk.erase(it);
            return ans;
        }
    private:
        multimap<int, list<int>::iterator> tm;
        list<int> stk;
    };
    /**
     * Your MaxStack object will be instantiated and called as such:
     * MaxStack* obj = new MaxStack();
     * obj->push(x);
     * int param_2 = obj->pop();
     * int param_3 = obj->top();
     * int param_4 = obj->peekMax();
     * int param_5 = obj->popMax();
     */
  • class MaxStack(object):
        def __init__(self):
            """Initialize your data structure here."""
            self.stack = []
            self.max_stack = []
            # or the same trick in min-stack
            #   self.max_stack = [-math.inf]

        def push(self, x):
            """
            :type x: int
            :rtype: void
            """
            self.stack.append(x)
            self.max_stack.append(max(x, self.max_stack[-1]) if self.max_stack else x)

        def pop(self):
            """:rtype: int"""
            if len(self.stack) != 0:
                self.max_stack.pop(-1)
                return self.stack.pop(-1)

        def top(self):
            """:rtype: int"""
            return self.stack[-1] if self.stack else None

        def peekMax(self):
            """:rtype: int"""
            if len(self.max_stack) != 0:
                return self.max_stack[-1]

        def popMax(self):
            """:rtype: int"""
            val = self.peekMax()
            buff = []
            while self.top() != val:
                # re-use pop(), ops on both max_stack and stack
                buff.append(self.pop())
            self.pop()  # remove the maximum itself
            while len(buff) != 0:
                # re-use push(), no need to save buffer for max-stack
                self.push(buff.pop(-1))
            return val
    from typing import Union

    from sortedcontainers import SortedList


    class Node:
        def __init__(self, val=0):
            self.val = val
            self.prev: Union['Node', None] = None
            self.next: Union['Node', None] = None


    class DoubleLinkedList:
        def __init__(self):
            self.head = Node()
            self.tail = Node()
            self.head.next = self.tail
            self.tail.prev = self.head

        def append(self, val) -> Node:
            node = Node(val)
            node.next = self.tail
            node.prev = self.tail.prev
            self.tail.prev = node
            node.prev.next = node
            return node

        @staticmethod
        def remove(node) -> Node:
            # Unlink a node in O(1); works on any node, not just the tail.
            node.prev.next = node.next
            node.next.prev = node.prev
            return node

        def pop(self) -> Node:
            return self.remove(self.tail.prev)

        def peek(self):
            return self.tail.prev.val


    class MaxStack:
        def __init__(self):
            self.stk = DoubleLinkedList()
            # Nodes sorted by value; equal values keep push order, so the
            # last node is the top-most copy of the maximum.
            self.sl = SortedList(key=lambda x: x.val)

        def push(self, x: int) -> None:
            node = self.stk.append(x)
            self.sl.add(node)

        def pop(self) -> int:
            node = self.stk.pop()
            self.sl.remove(node)
            return node.val

        def top(self) -> int:
            return self.stk.peek()

        def peekMax(self) -> int:
            return self.sl[-1].val

        def popMax(self) -> int:
            node = self.sl.pop()
            DoubleLinkedList.remove(node)
            return node.val


    # Your MaxStack object will be instantiated and called as such:
    # obj = MaxStack()
    # obj.push(x)
    # param_2 = obj.pop()
    # param_3 = obj.top()
    # param_4 = obj.peekMax()
    # param_5 = obj.popMax()
