Problem statement

https://leetcode.com/problems/max-stack/

Solution

This problem extends problem 0155 Min Stack: here we must also remove the maximum, not just peek at it. One option is to reuse the approach from problem 0155, which gives O(1) for push, pop, top and peekMax, but O(n) for popMax.
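As a rough sketch of that simpler approach (not the solution developed in the rest of this write-up), one can keep a second stack of running maximums next to the values. The class name `TwoStackMaxStack` and its attribute names are made up for illustration. popMax is the O(n) operation here: it unwinds the stack down to the most recent maximum and then pushes the removed elements back.

```python
class TwoStackMaxStack:
    def __init__(self):
        self.values = []  # the stack itself
        self.maxes = []   # maxes[i] is the maximum of values[:i + 1]

    def push(self, x):
        self.values.append(x)
        self.maxes.append(x if not self.maxes else max(x, self.maxes[-1]))

    def pop(self):
        self.maxes.pop()
        return self.values.pop()

    def top(self):
        return self.values[-1]

    def peekMax(self):
        return self.maxes[-1]

    def popMax(self):
        target = self.peekMax()
        buffer = []
        # Unwind until the top-most copy of the maximum is on top: O(n).
        while self.top() != target:
            buffer.append(self.pop())
        self.pop()
        # Restore the unwound elements in their original order.
        while buffer:
            self.push(buffer.pop())
        return target
```

Every operation except popMax touches only the ends of the two lists, which is where the O(1) bounds come from.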

To remove elements quickly from the middle of the stack, we store the stack as a doubly linked list. We also need more than just the current maximum: once the maximum is removed, we must know where the next maximum is, so we need a sorted structure. So let self.map be a SortedDict whose keys are the numbers on the stack and whose values are lists of references to the matching doubly linked list nodes (lists, because the same number can be pushed more than once).
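To see why each key maps to a list, here is a small standard-library illustration of the bookkeeping only. It is a stand-in, not the SortedDict used in the solution: a plain dict plus a sorted list of keys maintained with `bisect.insort`, with stack positions playing the role of node references. The names `positions`, `keys` and `record` are hypothetical. Inserting a new key into a plain list costs O(n), which is the cost SortedDict is meant to avoid.

```python
from bisect import insort

positions = {}  # value -> stack positions holding that value, oldest first
keys = []       # distinct values, kept in ascending order

def record(value, position):
    """Register that `value` was pushed at stack position `position`."""
    if value not in positions:
        positions[value] = []
        insort(keys, value)  # O(n) here; a stand-in for the sorted dictionary
    positions[value].append(position)

# Push 5, 1, 5: the value 5 appears twice, so its list has two entries.
for position, value in enumerate([5, 1, 5]):
    record(value, position)

largest = keys[-1]                 # current maximum
top_most = positions[largest][-1]  # most recently pushed copy of the maximum
```

The last entry of the list for the largest key is the copy closest to the top of the stack, which is the one popMax has to remove.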

Complexity

Time complexity of push, pop, peekMax and popMax is O(log n), since each of them reads or updates the sorted dictionary; top is O(1), since it only reads the last node of the linked list. Space complexity is O(n).

Code

from sortedcontainers import SortedDict

class ListNode:
    def __init__(self, val=0, nxt=None, prv=None):
        self.val = val
        self.next = nxt
        self.prev = prv

class DoubleLinkedList:
    def __init__(self):
        self.head = ListNode(None)
        self.tail = ListNode(None)
        self.head.next = self.tail
        self.tail.prev = self.head

    def add(self, val):
        x = ListNode(val)
        x.next = self.tail
        x.prev = self.tail.prev
        self.tail.prev.next = x
        self.tail.prev = x
        return x

    def delete(self, node):
        # Unlink node in O(1); the head/tail sentinels guarantee both neighbours exist.
        node.prev.next = node.next
        node.next.prev = node.prev

class MaxStack:
    def __init__(self):
        self.map = SortedDict()
        self.dll = DoubleLinkedList()

    def push(self, x):
        node = self.dll.add(x)
        if x not in self.map:
            self.map[x] = [node]
        else:
            self.map[x].append(node)

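    def pop(self):
        # The top of the stack is the last real node, just before the tail sentinel.
        val = self.dll.tail.prev.val
        self.dll.delete(self.dll.tail.prev)
        # The top node is also the most recently pushed node with this value,
        # so it is the last entry in its list.
        self.map[val].pop()
        if not self.map[val]:
            self.map.pop(val)
        return val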

    def top(self):
        return self.dll.tail.prev.val

    def peekMax(self):
        # peekitem() defaults to the last item, i.e. the largest key.
        return self.map.peekitem()[0]

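    def popMax(self):
        max_val = self.peekMax()  # renamed from max to avoid shadowing the builtin
        # Among equal maximums, remove the one closest to the top of the stack.
        node = self.map[max_val].pop()
        self.dll.delete(node)
        if not self.map[max_val]:
            self.map.pop(max_val)
        return max_val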