This solution is inspired by several other solutions I saw here in the discussion. The main question to ask yourself when you see this problem is: what do we do with XOR? Look for familiar patterns; the closest problem is 421. Maximum XOR of Two Numbers in an Array. There are different ways to solve that problem, and it is important to understand them so that you can extend them to more general cases like the one we have here.

The general idea is the following: we traverse the tree using dfs and keep a trie with the binary representations of the numbers on the current root-to-node path. We need both to insert elements into the trie and to remove them, because once dfs has finished with a node, its value is no longer needed. I spent some time in Python avoiding TLE, and ended up with the following Trie class:

  1. We keep a nested dictionary of dictionaries. The key -1 holds the frequency: we need it when we add and delete binary numbers from the trie.
  2. We also use the fact that all vals in queries are no bigger than 2*10^5 < 2^18, so we always add binary numbers with leading zeroes: for example, when we add 11101, we in fact add 000000000000011101. (The code below iterates over bits 18..0, which adds one more leading zero and is harmless.) We also update frequencies: increase when we add and decrease when we delete.
  3. find is classical if you have already solved 421: at each step we try to pick the bit such that the XOR of bits equals 1 and go to the corresponding branch of the trie. If that is not possible, we go to the other branch. Note that some frequencies can be 0, so we go to a branch only if its frequency is nonzero.
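
To make the fixed-width representation from item 2 concrete, here is a small sketch. The names `WIDTH` and `bits_msb_first` are chosen for illustration only and do not appear in the solution; the sketch assumes every value is below 2^18.

```python
WIDTH = 18  # assumption: every value is below 2**18, as noted in item 2

def bits_msb_first(num, width=WIDTH):
    """Return the bits of num from most to least significant, zero-padded."""
    return [(num >> i) & 1 for i in range(width - 1, -1, -1)]

padded = format(0b11101, "018b")
print(padded)  # 000000000000011101
print(bits_msb_first(0b11101) == [int(c) for c in padded])  # True
```

Because every number is padded to the same width, all paths in the trie have the same depth, so insert and find can walk the bits in lockstep.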

Now, let us talk about the original function:

  1. First of all, we regroup queries into a defaultdict Q, where for each node we keep pairs (index, value).
  2. We also build the graph from the parents array.
  3. Then we traverse the tree using dfs: we insert the node's value when we first visit it, then for each pair i, val for this node we update ans, run the function recursively for all children, and finally remove the value from the trie.
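
When testing the dfs + trie approach, it can help to have a slow reference to compare against. The following is a brute-force sketch, not the solution described here: for every query it simply walks from the node up to the root and tries every value on the way. The function name `max_genetic_difference_bruteforce` is hypothetical, and the sketch assumes, as the solution does, that a node's value is its index and that the root's parent is -1.

```python
def max_genetic_difference_bruteforce(parents, queries):
    """For each (node, val), try every node on the path from node up to the root."""
    ans = []
    for node, val in queries:
        best = 0
        cur = node
        while cur != -1:  # the root's parent is -1
            best = max(best, val ^ cur)
            cur = parents[cur]
        ans.append(best)
    return ans

print(max_genetic_difference_bruteforce([-1, 0, 1, 1], [[0, 2], [3, 2], [2, 5]]))
# [2, 3, 7]
```

This takes time proportional to the depth of the node for each query, so it is only suitable for checking small inputs.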


Let M be the maximum number of bits in queries. Then the time complexity is O((n+q)*M), where n is the size of the tree and q is the number of queries. Space complexity is O(n*M). In Python, however, it works quite slowly even though I use ints, not strings. I was a bit surprised that this did not make my code faster; does anybody have any ideas?


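from collections import defaultdict


class Trie:
    def __init__(self):
        # Nested dicts: keys 0/1 are child branches, key -1 is the frequency.
        self.root = {-1: 0}

    def insert(self, num, f):
        # f = 1 adds num, f = -1 removes it; frequencies along the path change by f.
        d = self.root
        for i in range(18, -1, -1):
            bit = (num >> i) & 1
            d = d.setdefault(bit, dict())
            d[-1] = d.get(-1, 0) + f

    def find(self, num):
        # Greedy from the top bit: take the opposite bit whenever that branch is non-empty.
        node = self.root
        res = 0
        for i in range(18, -1, -1):
            bit = (num >> i) & 1
            desired = 1-bit if 1-bit in node and node[1-bit][-1] > 0 else bit
            res += (desired ^ bit) << i
            node = node[desired]
        return res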

class Solution:
    def maxGeneticDifference(self, parents, queries):
        Q, graph = defaultdict(list), defaultdict(list)
        for i, (node, val) in enumerate(queries):
            Q[node].append((i, val))
        for i, x in enumerate(parents):
            graph[x].append(i)  # x is the parent of node i; the root has parent -1
        ans, trie = [-1 for _ in queries], Trie()
        def dfs(node):
            trie.insert(node, 1)
            for i, val in Q[node]: ans[i] = trie.find(val)
            for neib in graph[node]: dfs(neib)
            trie.insert(node, -1)
        dfs(graph[-1][0])  # start from the root, the only child of -1
        return ans

Solution 2

My trie solution was not very fast, so I decided to rewrite it using the following idea: imagine that we have the number 1011; then we add all its prefixes to a list of Counters: 1 to prefixes of length 1, 10 to prefixes of length 2, and so on. Then it is easier to update all counters when we add or remove a number. The logic for finding the maximum is exactly the same as before, and the dfs part is almost the same. UPD: if we use lists instead of Counters, we gain more speed; now it runs in around 5000ms (choose one of the two lines below to define cnts).

  1. cnts = [Counter() for _ in range(M)]
  2. cnts = [[0]*(1<<(M-i)) for i in range(M)]
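
As a small illustration of the list variant, here is how the tables look for the number 1011 from the text. This is a sketch with `M = 4` chosen only to keep the tables readable, and `add` is a hypothetical helper name; the solution itself uses M = 18 and updates the tables inline in dfs.

```python
M = 4  # assumption: a tiny bit width so the tables are easy to read
cnts = [[0] * (1 << (M - i)) for i in range(M)]

def add(num, delta):
    # cnts[i] is indexed by num >> i, i.e. by the prefix that keeps the top M - i bits
    for i in range(M):
        cnts[i][num >> i] += delta

add(0b1011, 1)
print(cnts[3][0b1], cnts[2][0b10], cnts[1][0b101], cnts[0][0b1011])  # 1 1 1 1
add(0b1011, -1)
print(sum(map(sum, cnts)))  # 0
```

Note that cnts[i] holds prefixes of length M - i, so the shortest prefixes live in the last table; that is why the query loop runs i from M-1 down to 0.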


Complexity is the same as in the previous approach.


from collections import defaultdict


class Solution:
    def maxGeneticDifference(self, parents, queries):
        def query(num):
            # ans is the best prefix found so far; try to extend it with the opposite bit.
            ans = 0
            for i in range(M-1, -1, -1):
                bit = (num >> i) & 1
                c1, c2 = ans * 2 + (1-bit), ans * 2 + bit
                ans = c1 if cnts[i][c1] > 0 else c2
            return ans ^ num
        Q, graph = defaultdict(list), defaultdict(list)
        for i, (node, val) in enumerate(queries):
            Q[node].append((i, val))
        for i, x in enumerate(parents):
            graph[x].append(i)  # x is the parent of node i; the root has parent -1
        ans, M = [-1 for _ in queries], 18
        cnts = [[0]*(1<<(M-i)) for i in range(M)]
        def dfs(node):
            for i in range(M): cnts[i][node>>i] += 1
            for i, val in Q[node]: ans[i] = query(val)
            for neib in graph[node]: dfs(neib)
            for i in range(M): cnts[i][node>>i] -= 1
        dfs(graph[-1][0])  # start from the root, the only child of -1
        return ans
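
One practical note: the recursive dfs can hit Python's default recursion limit (1000) if the tree is a long chain. Below is a sketch of the same traversal with an explicit stack, built on the prefix-count lists. It is an alternative I am suggesting, not part of the solution above; the function name `max_genetic_difference_iterative` and the `(node, leaving)` stack entries are my own choices.

```python
from collections import defaultdict

def max_genetic_difference_iterative(parents, queries, M=18):
    # Group queries by node and build child lists, as in the recursive version.
    Q, graph = defaultdict(list), defaultdict(list)
    for i, (node, val) in enumerate(queries):
        Q[node].append((i, val))
    for child, parent in enumerate(parents):
        graph[parent].append(child)

    cnts = [[0] * (1 << (M - i)) for i in range(M)]
    ans = [-1] * len(queries)

    def query(num):
        best = 0
        for i in range(M - 1, -1, -1):
            bit = (num >> i) & 1
            c1, c2 = best * 2 + (1 - bit), best * 2 + bit
            best = c1 if cnts[i][c1] > 0 else c2
        return best ^ num

    # Each stack entry is (node, leaving); leaving=True undoes the insertion.
    stack = [(graph[-1][0], False)]
    while stack:
        node, leaving = stack.pop()
        if leaving:
            for i in range(M):
                cnts[i][node >> i] -= 1
            continue
        for i in range(M):
            cnts[i][node >> i] += 1
        for i, val in Q[node]:
            ans[i] = query(val)
        stack.append((node, True))  # popped only after the whole subtree is done
        for child in graph[node]:
            stack.append((child, False))
    return ans
```

The leaving entry sits below the node's children on the stack, so the node's counts are removed only after its entire subtree has been processed, which mirrors the insert / recurse / remove order of the recursive dfs.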

