bit-dp
dp
bit manipulation
]
Leetcode 1681. Minimum Incompatibility
Problem statement
https://leetcode.com/problems/minimum-incompatibility/
Solution
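We keep the state in the form `(bitmask, last)`, where `bitmask` is the bitmask of all numbers already taken and `last` is the index of the last taken number. The value `dp[bitmask][last]` is the minimum total incompatibility over all ways to reach that state.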
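As a small illustration of this state, here is how a bitmask can be read and updated. The helper names `set_bits` and `without` are hypothetical and exist only for this sketch.

```python
def set_bits(mask, n):
    """Indexes j in range(n) whose bit is set in mask."""
    return [j for j in range(n) if mask & (1 << j)]


def without(mask, j):
    """The same mask with bit j cleared (assumes bit j is currently set)."""
    return mask ^ (1 << j)


mask = (1 << 1) | (1 << 2)   # numbers with indexes 1 and 2 are taken
print(set_bits(mask, 6))     # [1, 2]
print(without(mask, 2))      # 2, i.e. only index 1 remains taken
```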
How are we going to form our groups? Sort the numbers first and imagine we have `[1,1,2,3,3,4]` and `k = 3`. We create a first group of `2` elements, for example the elements with indexes `1,2`, which gives the bitmask `011000` (written here with index 0 as the leftmost digit). Then we form another group of `2` elements, for example the elements with indexes `0,5`, so the bitmask is now `111001`. Finally, we put the last two elements into the last group. We can interpret this as a TSP-like problem: we choose the path `1->2->0->5->3->4`. This particular path only illustrates the order of indexes: its last group holds two equal numbers (`3,3`), so a real solution would reject it. Note that inside each group we only choose increasing indexes, whereas between groups they can decrease (but can also increase).
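The walk through the path can be sketched as follows. The function `masks_along_path` is a hypothetical helper for illustration; it renders each mask with index 0 as the leftmost digit, matching the notation in the example.

```python
def masks_along_path(path, n, group_size):
    """Return the bitmask (as a string) after each completed group."""
    mask = 0
    snapshots = []
    for step, idx in enumerate(path, start=1):
        mask |= 1 << idx                      # mark index idx as taken
        if step % group_size == 0:            # a group has just been completed
            snapshots.append(
                "".join("1" if mask & (1 << j) else "0" for j in range(n))
            )
    return snapshots


print(masks_along_path([1, 2, 0, 5, 3, 4], n=6, group_size=2))
# ['011000', '111001', '111111']
```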
Now, what we have in our algorithm:
- We iterate over all possible values of `mask`.
- We calculate the places where this mask has a `1` (the list `n_z_bits`).
- Now we can have two cases. First, if `len(n_z_bits)%(n//k) == 1`, it is time to start a new group, so we can choose any index we want, not necessarily bigger than the previous one. We use `permutations` here: for `t` set bits there are exactly `t(t-1)` ordered pairs, and for each of them we update `dp[mask][l]`.
- In the other case, we need to continue building the current group, so the next index must come in a fixed order relative to the previous one, and we use `combinations` here (`t(t-1)/2` unordered pairs). The description above uses "bigger than previous"; the code takes the smaller index of each pair as the new one, which is the mirror image and gives the same cost. We also check that the new number is not equal to the last one, and update `dp[mask][j]`.
- Finally, we return the minimum of `dp[-1]`, or `-1` if no valid split exists.
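The two cases above can be illustrated in isolation. This is a minimal sketch: `starts_new_group` is a hypothetical helper, and it assumes a group size larger than 1 (the case `k == n` is handled separately in the solution).

```python
from itertools import combinations, permutations


def starts_new_group(mask, group_size):
    """True if the last taken number is the first element of a new group."""
    return bin(mask).count("1") % group_size == 1


bits = [0, 2, 5]                          # t = 3 set bits
ordered = list(permutations(bits, 2))     # t*(t-1) = 6 ordered pairs
unordered = list(combinations(bits, 2))   # t*(t-1)/2 = 3 pairs with j < l

print(starts_new_group(0b100101, 2))      # True: 3 bits taken, groups of 2
print(len(ordered), len(unordered))       # 6 3
print(unordered)                          # [(0, 2), (0, 5), (2, 5)]
```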
Complexity
Time complexity is `O(n*n*2^n)`, the same as for TSP: we have `2^n` bitmasks and need `O(n^2)` time to process each mask. Space complexity is `O(n*2^n)` for the `dp` table. In Python this is not very fast, I would say it is barely accepted, so I think there should be an `O(n*2^n)` solution as well. If you know such a solution (preferably with explanations), please let me know!
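To get a feeling for the constant behind this bound, one can count ordered pairs of set bits over all masks. This is only a rough upper bound on the inner-loop iterations, since many masks use unordered pairs instead. The function name `ordered_pair_count` is hypothetical, and `n = 16` is used here as an illustrative size.

```python
from math import comb


def ordered_pair_count(n):
    """Total number of ordered pairs of set bits, summed over all 2^n masks."""
    # comb(n, t) masks have exactly t set bits, each giving t*(t-1) ordered pairs
    return sum(comb(n, t) * t * (t - 1) for t in range(n + 1))


print(ordered_pair_count(16))   # 3932160, which equals 16 * 15 * 2**14
```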
Code
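from itertools import combinations, permutations

class Solution:
    def minimumIncompatibility(self, nums, k):
        n = len(nums)
        if k == n: return 0  # every group holds a single number
        # dp[mask][last]: minimum total incompatibility with the numbers in mask taken, ending at index last
        dp = [[float("inf")] * n for _ in range(1<<n)]
        nums.sort()
        for i in range(n): dp[1<<i][i] = 0

        for mask in range(1<<n):
            n_z_bits = [j for j in range(n) if mask&(1<<j)]
            if len(n_z_bits)%(n//k) == 1:
                # l starts a new group after previous last index j: any order, no extra cost
                for j, l in permutations(n_z_bits, 2):
                    dp[mask][l] = min(dp[mask][l], dp[mask^(1<<l)][j])
            else:
                # j < l: j is the new index, l the previous last one; the numbers must differ
                for j, l in combinations(n_z_bits, 2):
                    if nums[j] != nums[l]:
                        dp[mask][j] = min(dp[mask][j], dp[mask^(1<<j)][l] + nums[l] - nums[j])

        # dp[-1] is the row for the full mask
        return min(dp[-1]) if min(dp[-1]) != float("inf") else -1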
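To sanity-check the DP on small inputs, it can be compared with a plain brute-force search. The sketch below is not part of the original solution: `brute_force` is a hypothetical helper that always puts the smallest remaining index into the next group and tries every way to complete that group. The inputs and expected values are the ones I recall from the problem's examples.

```python
from itertools import combinations


def brute_force(nums, k):
    """Exhaustive search over all splits into k equal groups (small inputs only)."""
    nums = sorted(nums)
    size = len(nums) // k
    best = float("inf")

    def solve(remaining, total):
        nonlocal best
        if not remaining:
            best = min(best, total)
            return
        first, rest = remaining[0], remaining[1:]
        # the smallest remaining index always belongs to the next group
        for others in combinations(rest, size - 1):
            values = [nums[i] for i in (first,) + others]
            if len(set(values)) < size:       # equal numbers in one group
                continue
            left = tuple(i for i in rest if i not in others)
            solve(left, total + max(values) - min(values))

    solve(tuple(range(len(nums))), 0)
    return best if best != float("inf") else -1


print(brute_force([1, 2, 1, 4], 2))               # 4
print(brute_force([6, 3, 8, 1, 3, 1, 2, 2], 4))   # 6
print(brute_force([5, 3, 3, 6, 3, 3], 3))         # -1
```

Running both functions on random small arrays and comparing the answers is a cheap way to catch mistakes in the transitions.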