Problem statement

https://leetcode.com/problems/minimum-incompatibility/

Solution

We keep the state in the form (bitmask, last), where:

  1. bitmask marks which numbers (by index) have already been taken.
  2. last is the index of the last taken number.

How do we form the groups? Sort the numbers first and imagine we have nums = [1,1,2,3,3,4] and k = 3. We create the first group of 2 elements, for example the elements with indexes 1,2, which gives the bitmask 011000 (written here with index 0 as the leftmost bit). Then we form another group of 2 elements, for example indexes 0,5, so the bitmask becomes 111001. Finally we put the last two elements in the last group. We can interpret this as a TSP-like problem: we choose the path 1->2->0->5->3->4. Note that inside each group we only choose increasing indexes, whereas between groups the index can decrease (but can also increase). This particular path only illustrates the mechanics: its last group holds indexes 3,4, whose values are both 3, so the duplicate check would reject it.
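To make the path interpretation concrete, here is a minimal sketch that scores a given path, assuming the path is cut into consecutive chunks of n//k indexes and each chunk is one group. `path_cost` is a hypothetical helper used only for illustration, it is not part of the solution.

```python
def path_cost(nums, path, group_size):
    """Cost of visiting sorted `nums` along `path`, or None if a group is invalid."""
    total = 0
    for start in range(0, len(path), group_size):
        group = path[start:start + group_size]
        for prev, cur in zip(group, group[1:]):
            # Inside a group indexes must increase and values must differ.
            if cur <= prev or nums[cur] == nums[prev]:
                return None
            # Consecutive differences telescope to max - min of the group.
            total += nums[cur] - nums[prev]
    return total

nums = [1, 1, 2, 3, 3, 4]
print(path_cost(nums, [1, 2, 0, 5, 3, 4], 2))  # None: the last group is 3,3
print(path_cost(nums, [0, 2, 1, 3, 4, 5], 2))  # 4
```

The DP effectively searches over all such paths and keeps the cheapest valid one.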

Now, what we have in our algorithm:

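  1. We iterate over all possible masks.
  2. For each mask we collect the positions of its 1 bits (n_z_bits).
  3. Now there are two cases. First, if len(n_z_bits)%(n//k) == 1, the most recently taken number starts a new group, so its index can be anything, not necessarily bigger than the previous one. We iterate over permutations here: with t set bits there are exactly t(t-1) ordered pairs (j, l), and for each of them we update dp[mask][l] from dp[mask^(1<<l)][j] at no extra cost.
  4. Otherwise we are continuing the current group, so the new index must be ordered relative to the previous one and we iterate over combinations. In the code each pair has j < l, where l is the previous index and j is the new one, so a group is walked from its largest index down to its smallest; by symmetry this is equivalent to the increasing order described above. We also check that the new number is not equal to the previous one, and then update dp[mask][j], adding nums[l] - nums[j].
  5. Finally, we return the minimum of dp[-1], or -1 if no valid split exists.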
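The same (bitmask, last) state can also be written top-down, which some readers may find easier to follow. This is only a sketch under a few assumptions: it uses `functools.lru_cache` instead of an explicit table, it builds each group in increasing index order, and the name `min_incompatibility` is made up for illustration.

```python
from functools import lru_cache

def min_incompatibility(nums, k):
    nums = sorted(nums)
    n = len(nums)
    size = n // k                      # numbers per group
    full = (1 << n) - 1

    @lru_cache(maxsize=None)
    def go(mask, last):
        """Min extra cost to finish, given taken `mask` and last index `last`."""
        if mask == full:
            return 0
        taken = bin(mask).count("1")
        new_group = taken % size == 0  # the next number opens a new group
        best = float("inf")
        for i in range(n):
            if mask & (1 << i):
                continue
            if new_group:
                # Any free index may start a group, at no cost.
                best = min(best, go(mask | (1 << i), i))
            elif i > last and nums[i] != nums[last]:
                # Continue the group: bigger index, different value.
                best = min(best, nums[i] - nums[last] + go(mask | (1 << i), i))
        return best

    ans = go(0, -1)
    return -1 if ans == float("inf") else ans

print(min_incompatibility([1, 2, 1, 4], 2))        # 4
print(min_incompatibility([5, 3, 3, 6, 3, 3], 3))  # -1
```

It has the same O(n*n*2^n) bound, so it is meant as a readability aid rather than a speed-up.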

Complexity

Time complexity is O(n*n*2^n), the same as for TSP: there are 2^n bitmasks and each takes O(n^2) time to process. Space complexity is O(n*2^n) for the dp table. In Python this is not very fast, I would say it is barely accepted, so I think there should be an O(n*2^n) solution as well. If you know such a solution (preferably with explanations), please let me know!
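To get a feeling for the constant hidden in that bound, the number of (j, l) pairs visited by the two inner loops can be counted directly. This is a rough sketch, assuming k < n and counting only the pair iterations (not the cost of building the list of set bits); `count_pair_updates` is a hypothetical helper.

```python
from math import comb

def count_pair_updates(n, k):
    """Number of (j, l) pairs the nested loops visit over all masks."""
    size = n // k
    total = 0
    for t in range(n + 1):              # t = number of set bits in a mask
        if t % size == 1:
            pairs = t * (t - 1)         # permutations(bits, 2)
        else:
            pairs = t * (t - 1) // 2    # combinations(bits, 2)
        total += comb(n, t) * pairs     # comb(n, t) masks have t set bits
    return total

# Compare the exact pair count with the n*n*2^n bound for n = 16, k = 8.
print(count_pair_updates(16, 8), 16 * 16 * 2 ** 16)
```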

Code

from itertools import combinations, permutations

class Solution:
    def minimumIncompatibility(self, nums, k):
        n = len(nums)
        if k == n: return 0  # every group holds a single number
        # dp[mask][last]: min cost with the numbers in mask taken, index last taken last
        dp = [[float("inf")] * n for _ in range(1<<n)]
        nums.sort()
        for i in range(n): dp[1<<i][i] = 0

        for mask in range(1<<n):
            n_z_bits = [j for j in range(n) if mask&(1<<j)]
            if len(n_z_bits)%(n//k) == 1:
                # l starts a new group: any previous index j, no extra cost
                for j, l in permutations(n_z_bits, 2):
                    dp[mask][l] = min(dp[mask][l], dp[mask^(1<<l)][j])
            else:
                # j < l: extend the current group from l down to j
                for j, l in combinations(n_z_bits, 2):
                    if nums[j] != nums[l]:
                        dp[mask][j] = min(dp[mask][j], dp[mask^(1<<j)][l] + nums[l] - nums[j])

        return min(dp[-1]) if min(dp[-1]) != float("inf") else -1