Problem statement

https://leetcode.com/problems/design-search-autocomplete-system/

Solution

One possible solution is to use a trie in which every node has one extra field: the set of all sentences that pass through this node, that is, all sentences starting with the prefix spelled out by the path from the root. This is not very memory efficient, so we could instead store indexes of sentences and/or keep only the 3 hottest of them in each node (not done in the code below).
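To illustrate the "keep only 3 of them" idea, here is a minimal sketch of one way it could look. It is not the solution from the Code section: the names TopKNode, TopKTrie, add and suggest are hypothetical, and the sketch assumes that counts only ever grow, so the new top 3 of a node can always be found among its old top 3 plus the sentence being added.

```python
from collections import defaultdict


class TopKNode:
    """Trie node that remembers only its k hottest sentences (hypothetical name)."""

    def __init__(self):
        self.children = {}  # symbol -> TopKNode
        self.top = []       # at most k sentences, hottest first


class TopKTrie:
    """Sketch of a trie that keeps k candidates per node instead of a full set."""

    def __init__(self, k=3):
        self.k = k
        self.root = TopKNode()
        self.counts = defaultdict(int)  # sentence -> frequency

    def add(self, sentence, times=1):
        """Raise the count of `sentence` by `times` and refresh every node on its path."""
        self.counts[sentence] += times
        node = self.root
        for symbol in sentence:
            node = node.children.setdefault(symbol, TopKNode())
            if sentence not in node.top:
                node.top.append(sentence)
            # Hottest first, ties broken by string order; then cut back to k entries.
            node.top.sort(key=lambda s: (-self.counts[s], s))
            del node.top[self.k:]

    def suggest(self, prefix):
        """Return the stored candidates for `prefix`, or [] if no sentence has it."""
        node = self.root
        for symbol in prefix:
            node = node.children.get(symbol)
            if node is None:
                return []
        return list(node.top)
```

With this layout a lookup no longer sorts anything, at the price of a small sort of at most k + 1 items in every node on the path during each add.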

  1. init(): we have buffer, where we keep the sentence currently being typed, stimes: a dictionary with the frequencies of sentences, and trie: our Trie. When we initialize, we go through all given sentences and their frequencies, add each sentence to the Trie using the addSentence function (Trie.insert in the code), and record its frequency in the stimes dictionary. We also have node, which is the Trie node we are currently at.
  2. input(c): we check whether we are finishing our sentence. If not, we add the symbol to buffer and move to the corresponding child node, then choose the 3 hottest candidates among all sentences stored in this node. If there is no such child, no stored sentence has this prefix and we return an empty list. If the symbol is #, we need to finish, so we update the frequency for our buffer, add it to the Trie, make buffer empty and return to the root of our Trie.
  3. addSentence(sentence) (Trie.insert in the code): we start from the root of our Trie and iterate over the letters, updating the children dictionary and adding the sentence to the set of sentences of every node on the path.

Complexity

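Time complexity of addSentence (Trie.insert in the code) is O(m^2) in the worst case to add a sentence of length m: we add it to the sets of m different nodes, and each set operation can cost up to O(m) for hashing or comparing the string. Time complexity of input depends on the size of the set in the current node rather than only on m: if the node holds k sentences, sorting them takes O(k log k) comparisons, each of which can cost up to O(m), before we take the hottest 3.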
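Since only the hottest 3 sentences are needed, the full sort is more work than strictly necessary. A minimal sketch of an alternative selection step, using heapq.nsmallest from the standard library, is shown here; the helper name hottest and its parameters are assumptions made for illustration and are not part of the solution in the Code section.

```python
import heapq


def hottest(sentences, counts, k=3):
    """Pick the k hottest sentences without sorting the whole collection.

    `sentences` is any iterable of strings and `counts` maps each sentence
    to its frequency (both are hypothetical parameters for this sketch).
    Ordering matches the solution: higher count first, then string order.
    """
    return heapq.nsmallest(k, sentences, key=lambda s: (-counts[s], s))
```

For a fixed k this makes a single pass over the set with a small heap, so the number of comparisons grows roughly linearly with the size of the set instead of as k log k for a full sort. The result is the same list that sorted(...)[:k] with the same key would give.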

Code

from collections import defaultdict


class TrieNode:
    def __init__(self):
        self.children = {}      # symbol -> TrieNode
        self.sentences = set()  # all sentences passing through this node
        
class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        node = self.root
        for symbol in word:
            node = node.children.setdefault(symbol, TrieNode())
            node.sentences.add(word)


class AutocompleteSystem(object):
    def __init__(self, sentences, times):
        self.buffer = ""
        self.stimes = defaultdict(int)
        self.trie = Trie()
        for s, t in zip(sentences, times):
            self.stimes[s] = t
            self.trie.insert(s)
        self.node = self.trie.root
 
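    def input(self, c):
        ans = []
        if c != "#":
            self.buffer += c
            # Follow the child for c; node becomes None once the prefix leaves the Trie.
            if self.node:
                self.node = self.node.children.get(c)
            if self.node:
                # Hottest first, ties broken by string order of the sentence.
                ans = sorted(self.node.sentences, key=lambda x: (-self.stimes[x], x))[:3]
        else:
            # End of sentence: record it, then reset the typing state.
            self.stimes[self.buffer] += 1
            self.trie.insert(self.buffer)
            self.buffer = ""
            self.node = self.trie.root
        return ans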