Problem statement

https://leetcode.com/problems/the-number-of-good-subsets/

Solution

The key observation in this problem is that the values are small: they lie in the range [1, 30]. Also notice that if we take some number x > 1, we cannot take another x, because the product would then be divisible by x^2 and some prime would appear twice. Finally, if we decide to take x and there are k occurrences of x in nums, then we have k ways to pick it, each giving a different subset.

Let us create the list of all primes up to 30, P = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29], and for each number from 2 to 30 create a bitmask of its prime divisors, for example bm[2] = 0000000001, bm[3] = 0000000010, bm[5] = 0000000100, bm[6] = 0000000011 and so on. Why do we need these bitmasks? To quickly check whether we can take a number or not.
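
As a minimal sketch of how these bitmasks can be built and used (the helper name prime_mask is an assumption made for illustration; the solution below builds bm inline):

```python
P = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

def prime_mask(x):
    """Hypothetical helper: bit i is set when P[i] divides x."""
    return sum(1 << i for i, p in enumerate(P) if x % p == 0)

# Same table as in the solution; index 0 and 1 are never looked up there.
bm = [prime_mask(x) for x in range(31)]

print(format(bm[6], "010b"))   # 0000000011 (primes 2 and 3)
print(format(bm[30], "010b"))  # 0000000111 (primes 2, 3 and 5)

# Two numbers can be taken together only if their masks do not overlap.
print(bm[6] & bm[10] == 0)  # False: 6 and 10 share the prime 2
print(bm[6] & bm[5] == 0)   # True: 6 and 5 share no prime
```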

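Let us use a function dp(mask, num) with the following arguments:

  1. mask is the mask of primes that are still available: a 1 bit means the prime has not been booked yet. For example, 01011 00010 means that only 3, 13, 17, 23 are still available. We start from 11111 11111, where nothing is booked.
  2. num is the number we are currently processing.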

How can we calculate dp(mask, num)?

  1. The first option is dp(mask, num - 1): in this case we just do not take num and continue to the smaller one.
  2. What if we want to take num? We need to make sure it is not in the bad list, that is, the numbers divisible by the square of a prime. We also need to be sure that mask | bm[num] == mask; this condition means that the places we need to book are still available. Finally, how many options do we have in this case? It is dp(mask ^ bm[num], num - 1) * cnt[num], because we can take num at any position where it occurs in nums.
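
The check in the second option can be sketched in isolation. The helper try_take is hypothetical and only illustrates one transition; the bad set is derived here instead of being hard-coded, on the assumption that only 2, 3 and 5 have squares not exceeding 30:

```python
P = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
bm = [sum(1 << i for i, p in enumerate(P) if x % p == 0) for x in range(31)]

# Numbers divisible by the square of a prime (4, 9 or 25 within this range).
bad = {x for x in range(2, 31) if any(x % (p * p) == 0 for p in (2, 3, 5))}

def try_take(mask, num):
    """Hypothetical helper: return the mask after taking num, or None if
    num cannot be taken. A 1 bit in mask marks a prime that is still free."""
    if num in bad or mask | bm[num] != mask:
        return None
    return mask ^ bm[num]  # clear the bits of the primes we just booked

full = (1 << len(P)) - 1          # 1023: every prime is available
after_6 = try_take(full, 6)       # books primes 2 and 3

print(sorted(bad))                # [4, 8, 9, 12, 16, 18, 20, 24, 25, 27, 28]
print(format(after_6, "010b"))    # 1111111100
print(try_take(after_6, 10))      # None: prime 2 is already booked
print(try_take(after_6, 5))       # 1016: prime 5 was still free
```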

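We also need to deal with 1. As noted above, every other number can be taken at most once. However, 1 does not change the product, so any subset of the ones can be attached to a good subset, which gives a factor of 2^cnt[1]. The value dp returns also counts the empty choice, which is not a good subset, so we subtract 1 before multiplying.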
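
A small illustration of the factor for ones, assuming the number of good subsets that contain no 1 is already known (attach_ones is a hypothetical helper, not part of the solution):

```python
from collections import Counter

M = 10**9 + 7

def attach_ones(good_without_ones, nums):
    """Hypothetical helper: scale a count of good subsets that contain no 1
    by the number of ways to add any subset of the ones in nums."""
    cnt = Counter(nums)
    return good_without_ones * pow(2, cnt[1], M) % M

# nums = [1, 1, 2]: the only good subset without ones is {2}; it can be
# extended by no ones, the first 1, the second 1, or both.
print(attach_ones(1, [1, 1, 2]))     # 4

# nums = [1, 2, 3, 4]: {2}, {3}, {2, 3} are good, each with or without the 1.
print(attach_ones(3, [1, 2, 3, 4]))  # 6
```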

Complexity

We have 2^10 possible masks and 30 possible values of num, so there are O(2^10 * 30) states with at most 2 transitions from each state. This gives time complexity O(2^10 * 30), plus O(n) to count the occurrences in nums. Space complexity is O(2^10 * 30) as well.

Code

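from collections import Counter
from functools import lru_cache

class Solution:
    def numberOfGoodSubsets(self, nums):
        P = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
        cnt = Counter(nums)
        # bm[x]: bit i is set when P[i] divides x
        bm = [sum(1 << i for i, p in enumerate(P) if x % p == 0) for x in range(31)]
        # numbers divisible by the square of a prime can never be taken
        bad = {4, 8, 9, 12, 16, 18, 20, 24, 25, 27, 28}
        M = 10**9 + 7

        @lru_cache(None)
        def dp(mask, num):
            # mask: primes still available; num: largest number left to consider
            if num == 1:
                return 1
            ans = dp(mask, num - 1)  # skip num
            if num not in bad and mask | bm[num] == mask:
                # take one of the cnt[num] copies of num and book its primes
                ans += dp(mask ^ bm[num], num - 1) * cnt[num]
            return ans % M

        # drop the empty choice, then attach any subset of the ones
        return ((dp(1023, 30) - 1) * pow(2, cnt[1], M)) % M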
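
To sanity-check the approach on small inputs, one option is a brute-force enumeration of all subsets. This is only a sketch for testing, not part of the solution; the function name brute_force is an assumption, and it relies on the fact that a product of values up to 30 is good exactly when it is greater than 1 and not divisible by the square of any prime in P:

```python
from itertools import combinations

P = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

def brute_force(nums):
    """Count good subsets by trying every non-empty set of indices.
    Exponential in len(nums), so only usable for tiny inputs."""
    total = 0
    for size in range(1, len(nums) + 1):
        for idx in combinations(range(len(nums)), size):
            prod = 1
            for i in idx:
                prod *= nums[i]
            # good: more than 1 and no prime appears twice in the product
            if prod > 1 and all(prod % (p * p) != 0 for p in P):
                total += 1
    return total

print(brute_force([1, 2, 3, 4]))   # 6
print(brute_force([4, 2, 3, 15]))  # 5
```

Comparing these counts with the output of numberOfGoodSubsets on the same small arrays is a quick way to catch mistakes in the mask handling or in the treatment of ones.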