Problem statement

https://leetcode.com/problems/all-oone-data-structure/

Solution

The idea is similar to problem 0460, LFU Cache.

For each value we keep the set of keys that currently have it. For example, if the original hash table is "cat": 1, "dog": 2, "bat": 1, "rat": 3, then the groups are 1: {"cat", "bat"}, 2: {"dog"}, 3: {"rat"}. We store these groups in a doubly linked list sorted by value, where each node looks like this:

[self.val, self.keys = set(), self.prev, self.next].
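
To make the grouping concrete, here is a minimal sketch that inverts the example hash table into value groups. The helper name `group_by_value` is an assumption made for illustration and is not part of the solution; the real structure maintains these groups incrementally instead of rebuilding them.

```python
from collections import defaultdict

def group_by_value(counts):
    """Invert a key -> value table into value -> set of keys (illustration only)."""
    buckets = defaultdict(set)
    for key, val in counts.items():
        buckets[val].add(key)
    return dict(buckets)

counts = {"cat": 1, "dog": 2, "bat": 1, "rat": 3}
buckets = group_by_value(counts)

# Walking the values in sorted order mirrors walking the linked list
# from the head sentinel to the tail sentinel.
for val in sorted(buckets):
    print(val, sorted(buckets[val]))
```

Each printed line corresponds to one node of the linked list: its `val` and its `keys`.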

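Because the list is doubly linked, when a value loses its last key we can unlink its node in O(1) and the list stays sorted. This is easier to say than to implement. When we increase the value of a key, we check whether the next node in the list holds value + 1: if it does, we add the key to that node, otherwise we create a new node right after the current one. Do not forget to delete nodes that become empty. Decreasing works the same way with the previous node and value - 1, except that a key whose value drops to 0 is removed entirely. To find the min or max key, just look at the first and last real nodes of the list. It is also a good idea to use sentinels before the head and after the tail, so that insertions and deletions need no special cases at the ends.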
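
The bookkeeping described above can be traced with a simplified model. In this sketch a plain dict stands in for the linked list, so it only illustrates how keys move between groups and when an empty group is dropped; finding the min or max here would require scanning the dict, which is exactly what the linked list avoids. The function names `inc` and `dec` and the two-dict layout are assumptions made for illustration.

```python
def inc(buckets, counts, key):
    """Move key from group v to group v + 1, dropping group v if it empties."""
    old = counts.get(key, 0)
    counts[key] = old + 1
    buckets.setdefault(old + 1, set()).add(key)  # reuse or create the v + 1 group
    if old:                                      # a brand-new key has no old group
        buckets[old].remove(key)
        if not buckets[old]:                     # never leave an empty group behind
            del buckets[old]

def dec(buckets, counts, key):
    """Move key from group v to group v - 1; a key reaching 0 disappears."""
    if key not in counts:
        return
    old = counts[key]
    buckets[old].remove(key)
    if not buckets[old]:
        del buckets[old]
    if old == 1:
        del counts[key]                          # value dropped to 0: forget the key
    else:
        counts[key] = old - 1
        buckets.setdefault(old - 1, set()).add(key)

buckets, counts = {}, {}
inc(buckets, counts, "cat")
inc(buckets, counts, "cat")
inc(buckets, counts, "dog")
print(buckets == {2: {"cat"}, 1: {"dog"}})   # group 1 was emptied by "cat", then recreated by "dog"
dec(buckets, counts, "cat")
print(buckets == {1: {"cat", "dog"}})        # group 2 became empty and was removed
```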

Complexity

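All four operations run in O(1) time on average, since each one does a constant number of hash table, set, and linked list updates. Space complexity is O(n), where n is the number of distinct keys.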

Code

class ListNode:
    def __init__(self, val=0, nxt=None, prv=None):
        self.val = val
        self.keys = set()
        self.next = nxt
        self.prev = prv

class DoubleLinkedList:
    def __init__(self):
        self.head = ListNode()
        self.tail = ListNode()
        self.head.next = self.tail
        self.tail.prev = self.head

    def insert(self, node, new_node):
        # Link new_node into the list right after node.
        new_node.next = node.next
        node.next.prev = new_node
        node.next = new_node
        new_node.prev = node

    def delete(self, node):
        node.prev.next = node.next
        node.next.prev = node.prev

class AllOne(object):
    def __init__(self):
        self.dll = DoubleLinkedList()
        self.mapping = {}  # key to node

    def inc(self, key):
        if key not in self.mapping:  # find current node and remove key
            cur_node = self.dll.head
        else:
            cur_node = self.mapping[key]
            cur_node.keys.remove(key)

        if cur_node.val + 1 != cur_node.next.val:  # insert new node
            new_node = ListNode(cur_node.val + 1)
            self.dll.insert(cur_node, new_node)
        else:
            new_node = cur_node.next

        new_node.keys.add(key)  # update new_node
        self.mapping[key] = new_node  # ... and mapping of key to new_node

        if not cur_node.keys and cur_node.val != 0:  # delete current node unless it is the sentinel
            self.dll.delete(cur_node)

    def dec(self, key):
        if key not in self.mapping: return

        cur_node = self.mapping[key]
        self.mapping.pop(key)
        cur_node.keys.remove(key)

        if cur_node.val != 1:
            if cur_node.val - 1 != cur_node.prev.val:  # insert new node
                new_node = ListNode(cur_node.val - 1)
                self.dll.insert(cur_node.prev, new_node)
            else:
                new_node = cur_node.prev
            new_node.keys.add(key)
            self.mapping[key] = new_node

        if not cur_node.keys:  # delete current node
            self.dll.delete(cur_node)
            
    def getMaxKey(self):
        if self.dll.tail.prev.val == 0: return ""
        return next(iter(self.dll.tail.prev.keys))

    def getMinKey(self):
        if self.dll.head.next.val == 0: return ""
        return next(iter(self.dll.head.next.keys))
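
One possible way to sanity-check the linked list version is to compare it against a brute-force reference. The sketch below is such a reference under the assumption that any key with the maximal (or minimal) count is an acceptable answer; the class name `NaiveAllOne` is hypothetical, and its min and max queries are O(n), so it is only meant for testing.

```python
from collections import Counter

class NaiveAllOne:
    """Brute-force reference: easy to trust, but getMaxKey/getMinKey are O(n)."""

    def __init__(self):
        self.counts = Counter()

    def inc(self, key):
        self.counts[key] += 1

    def dec(self, key):
        if key not in self.counts:
            return
        self.counts[key] -= 1
        if self.counts[key] == 0:   # keys with count 0 must disappear
            del self.counts[key]

    def getMaxKey(self):
        return max(self.counts, key=self.counts.get) if self.counts else ""

    def getMinKey(self):
        return min(self.counts, key=self.counts.get) if self.counts else ""

ref = NaiveAllOne()
ref.inc("hello")
ref.inc("hello")
ref.inc("leet")
print(ref.getMaxKey(), ref.getMinKey())
```

When several keys share the same count, the two implementations may return different keys, so a comparison should check the counts of the returned keys rather than the keys themselves.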