Problem statement

https://leetcode.com/problems/dice-roll-simulation/

Solution 1

The idea is to use DP over states (last face, number of consecutive occurrences of that face). For simplicity, imagine the die has not 6 sides but M = 3, with rollMax = [3, 2, 4]. Then we have the states:

S11, S12, S13
S21, S22
S31, S32, S33, S34

And we can recalculate

S11_new = S21 + S22 + S31 + S32 + S33 + S34
S12_new = S11
S13_new = S12
S21_new = S11 + S12 + S13 + S31 + S32 + S33 + S34
S22_new = S21
S31_new = S11 + S12 + S13 + S21 + S22
S32_new = S31
S33_new = S32
S34_new = S33

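To update all states quickly, we also keep the sum of each row and the total sum. Then the first state of each row is the total sum minus the sum of that row, every other state is a shift of the previous one, and we can update all states in O(m), where m is the sum of rollMax.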
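To make the update concrete, here is a minimal sketch of one step on the M = 3 example. The helper name `step` is hypothetical, and the modulo is left out to keep the numbers readable.

```python
def step(states):
    """Advance every (face, run length) count by one roll."""
    row_sums = [sum(row) for row in states]
    total = sum(row_sums)
    new_states = []
    for row, row_sum in zip(states, row_sums):
        # A new run of length 1 can follow any state of another face.
        # A run of length k + 1 comes from a run of length k; the longest run is dropped.
        new_states.append([total - row_sum] + row[:-1])
    return new_states


roll_max = [3, 2, 4]
states = [[1] + [0] * (t - 1) for t in roll_max]  # counts after the first roll
states = step(states)  # counts after the second roll
states = step(states)  # counts after the third roll
print(states)  # [[6, 2, 1], [6, 2], [6, 2, 1, 0]]
```

The counts after three rolls add up to 26, which is 27 minus the single forbidden sequence of three consecutive rolls of the second face.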

Complexity

Time complexity is O(m*n). Space complexity is O(m).

Code

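class Solution:
    def dieSimulator(self, n, rollMax):
        M, MOD = 6, 10**9 + 7
        # S1[i][j]: sequences ending with j+1 consecutive rolls of face i;
        # S2[i]: sum of row i; S3: total over all states.
        S1, S2, S3 = [[1] + [0]*(t-1) for t in rollMax], [1]*M, M

        for _ in range(n - 1):
            for i in range(M):
                # Extend every run of face i by one roll (the longest run is dropped).
                for j in range(rollMax[i] - 1, 0, -1):
                    S1[i][j] = S1[i][j-1]
                # A new run of face i can follow any state of another face.
                S1[i][0] = (S3 - S2[i]) % MOD

            S2 = [sum(row) % MOD for row in S1]
            S3 = sum(S2) % MOD

        return S3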

Remark

Actually it can be done in O(6n) time, but then we need to keep not only the last row of per-face counts but all of them (or the last max(rollMax) of them).
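One way this remark could be realized is sketched below. It is my reading of the idea, not necessarily the intended one: the function name `die_simulator_fast` and the arrays `dp` and `total` are assumptions, and for simplicity it keeps all n rows rather than only the last max(rollMax).

```python
def die_simulator_fast(n, roll_max, mod=10**9 + 7):
    faces = len(roll_max)
    # dp[i][f]: valid sequences of length i ending in face f.
    # total[i]: all valid sequences of length i.
    dp = [[0] * faces for _ in range(n + 1)]
    total = [0] * (n + 1)
    total[0] = 1  # the empty sequence
    for i in range(1, n + 1):
        for f in range(faces):
            # Append face f to any valid sequence of length i - 1 ...
            dp[i][f] = total[i - 1]
            k = i - roll_max[f] - 1
            if k >= 0:
                # ... and remove those whose last roll_max[f] + 1 rolls are all f:
                # a valid prefix of length k not ending in f, followed by only f.
                dp[i][f] -= total[k] - dp[k][f]
            dp[i][f] %= mod
        total[i] = sum(dp[i]) % mod
    return total[n]


print(die_simulator_fast(2, [1, 1, 2, 2, 2, 3]))  # 34
print(die_simulator_fast(3, [1, 1, 1, 2, 2, 3]))  # 181
```

Each of the 6n entries is computed in O(1), at the price of O(6n) memory in this version.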

Solution 2

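There is also a matrix power solution, where we build the transition matrix between states and raise it to a power. We need to be careful with overflow here, so I multiply plain Python lists (Python integers have arbitrary precision) instead of using numpy, whose fixed-width integers can overflow.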
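To see what the transition matrix looks like, here is a small sketch that builds it for any rollMax. The helper name `build_transition` is hypothetical; states are numbered row by row, so for rollMax = [2, 1] they are S11, S12, S21.

```python
from itertools import accumulate


def build_transition(roll_max):
    """Return mat with mat[a][b] == 1 iff state a can be followed by state b."""
    m = sum(roll_max)
    # starts[f] is the index of the run-of-length-1 state of face f.
    starts = [0] + list(accumulate(roll_max))
    mat = [[0] * m for _ in range(m)]
    for lo, hi in zip(starts, starts[1:]):
        # Any state of a different face can start a new run of this face.
        for a in range(m):
            if not lo <= a < hi:
                mat[a][lo] = 1
        # Rolling the same face again extends the run by one, up to its limit.
        for a in range(lo, hi - 1):
            mat[a][a + 1] = 1
    return mat


for row in build_transition([2, 1]):
    print(row)
# [0, 1, 1]
# [0, 0, 1]
# [1, 0, 0]
```

Row S12 has a single 1: after two rolls of face 1 the only allowed roll is face 2.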

Complexity

Time complexity is O(m^3 * log n), space is O(m^2).

Code

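from itertools import accumulate, chain

class Solution:
    def dieSimulator(self, n, rollMax):
        def mult(X, Y, MOD):
            return [[sum(a*b for a,b in zip(x, y))%MOD for y in zip(*Y)] for x in X]

        def power(M, n, MOD):
            # Binary exponentiation, starting from the m x m identity matrix.
            result = [[0] * m for _ in range(m)]
            for i in range(m): result[i][i] = 1
            while n > 0:
                if n%2: result = mult(M, result, MOD)
                M = mult(M, M, MOD)
                n //= 2
            return result

        m, MOD = sum(rollMax), 10**9 + 7
        mat, s = [[0] * m for _ in range(m)], [0] * m
        # acc[f] is the index of the first state (run of length 1) of face f.
        acc = [0] + list(accumulate(rollMax))

        # Every state of another face can move to the first state of this face.
        for i, j in zip(acc, acc[1:]):
            for k in chain(range(i), range(j, m)): mat[k][i] = 1

        # Within a face, a run extends by one (entries that cross into the
        # next face are already 1 from the loop above).
        for i in range(m-1):
            mat[i][i+1] = 1

        # After the first roll, each face is in its run-of-length-1 state.
        for i in acc[:-1]: s[i] = 1

        return sum(mult([s], power(mat, n-1, MOD), MOD)[0]) % MOD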