math
dp
matrix power
]
Leetcode 1223. Dice Roll Simulation
Problem statement
https://leetcode.com/problems/dice-roll-simulation/
Solution 1
The idea is to use dp with states (last face, number of consecutive occurrences). For simplicity, imagine that the die has not 6 sides but M = 3, and rollMax = [3, 2, 4]. Then we have states:
S11, S12, S13
S21, S22
S31, S32, S33, S34
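Here Sij is the number of sequences that end with exactly j consecutive rolls of face i. After one more roll we can recalculate: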
S11_new = S21 + S22 + S31 + S32 + S33 + S34
S12_new = S11
S13_new = S12
S21_new = S11 + S12 + S13 + S31 + S32 + S33 + S34
S22_new = S21
S31_new = S11 + S12 + S13 + S21 + S22
S32_new = S31
S33_new = S32
S34_new = S33
To update all states quickly, we also keep the sum of each row and the total sum. Then we can update all states in O(m), where m is the sum of rollMax.
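A minimal sketch of one such update on the M = 3 example above. The helper name `step` and the list-of-rows layout are my own choices for illustration, not part of the solution below; `states[i][j]` plays the role of S(i+1)(j+1).

```python
def step(states, mod=10**9 + 7):
    """Apply one roll: states[i][j] counts sequences ending in j+1 copies of face i."""
    row_sums = [sum(row) % mod for row in states]
    total = sum(row_sums) % mod
    new_states = []
    for row, row_sum in zip(states, row_sums):
        # A run of length 1 starts after any sequence ending in another face;
        # longer runs are the old runs shifted by one (the longest one is dropped).
        new_states.append([(total - row_sum) % mod] + row[:-1])
    return new_states


# After the first roll every face has exactly one run of length 1.
states = [[1, 0, 0], [1, 0], [1, 0, 0, 0]]
states = step(states)
print(states)  # [[2, 1, 0], [2, 1], [2, 1, 0, 0]]
```

For example, the new S11 is 2 because it collects S21 and S31 from the previous step, exactly as in the formulas above.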
Complexity
Time complexity is O(m*n). Space complexity is O(m).
Code
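class Solution:
    def dieSimulator(self, n, rollMax):
        M, MOD = 6, 10**9 + 7
        # S1[i][j]: sequences ending with j+1 consecutive rolls of face i
        # S2[i]: sum of row i; S3: total over all states
        S1, S2, S3 = [[1] + [0]*(t-1) for t in rollMax], [1]*M, M
        for _ in range(n - 1):
            for i in range(M):
                # extend every run of face i by one roll (shift right)
                for j in range(rollMax[i] - 1, 0, -1):
                    S1[i][j] = S1[i][j-1]
                # start a new run of face i after any other face
                S1[i][0] = (S3 - S2[i]) % MOD
            S2 = [sum(row) % MOD for row in S1]
            S3 = sum(S2) % MOD
        return S3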
Remark
Actually it can be done in O(n*6) time, but then we need to keep not only the last row but all of them (or the last max(rollMax) of them).
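One way this could look, as a sketch: I assume here the common formulation where we store, for every length, the total count and the count per last face, and subtract the sequences that would exceed the limit. The names `die_simulator_linear`, `total` and `ends` are mine, and this version keeps all rows rather than only the last max(rollMax) of them.

```python
def die_simulator_linear(n, roll_max, mod=10**9 + 7):
    faces = len(roll_max)
    # total[i]: valid sequences of length i (total[0] = 1 for the empty sequence)
    total = [1] + [0] * n
    # ends[i][f]: valid sequences of length i whose last roll is face f
    ends = [[0] * faces for _ in range(n + 1)]
    for i in range(1, n + 1):
        for f, r in enumerate(roll_max):
            ways = total[i - 1]  # append face f to any valid shorter sequence
            k = i - 1 - r
            if k >= 0:
                # remove sequences that already end with exactly r copies of f:
                # a length-k prefix not ending in f, followed by r copies of f
                ways -= total[k] - ends[k][f]
            ends[i][f] = ways % mod
        total[i] = sum(ends[i]) % mod
    return total[n]


print(die_simulator_linear(2, [1, 1, 2, 2, 2, 3]))  # 34
```

Each of the n steps does constant work per face, which gives the O(n*6) time; the price is O(n*6) memory for the stored rows.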
Solution 2
There is also a matrix power solution, where we build the transition matrix and evaluate a power of it. We need to be careful with overflow here, so I use list multiplication, not numpy.
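To make the transition matrix concrete, here is a small sketch that builds it for the M = 3 example with rollMax = [3, 2, 4], using the state order S11, S12, S13, S21, S22, S31, S32, S33, S34. The function name `transition_matrix` is only for illustration; `mat[k][j] = 1` means state k can move to state j with one roll.

```python
from itertools import accumulate


def transition_matrix(roll_max):
    m = sum(roll_max)
    starts = [0] + list(accumulate(roll_max))  # first state index of each face
    mat = [[0] * m for _ in range(m)]
    for lo, hi in zip(starts, starts[1:]):
        for k in range(m):
            if not lo <= k < hi:
                mat[k][lo] = 1  # switch to this face: run length becomes 1
        for k in range(lo, hi - 1):
            mat[k][k + 1] = 1  # same face again: run length grows by 1
    return mat


mat = transition_matrix([3, 2, 4])
start = [1, 0, 0, 1, 0, 1, 0, 0, 0]  # state vector after the first roll
after_two = [sum(s * row[j] for s, row in zip(start, mat)) for j in range(len(mat))]
print(after_two)  # [2, 1, 0, 2, 1, 2, 1, 0, 0]
```

Multiplying the start vector by the matrix once reproduces one step of the recalculation from Solution 1, so multiplying by the (n-1)-th power gives the states after n rolls.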
Complexity
Time complexity is O(m^3 * log n), space is O(m^2).
Code
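from itertools import accumulate, chain

class Solution:
    def dieSimulator(self, n, rollMax):
        def mult(X, Y, MOD):
            return [[sum(a*b for a, b in zip(x, y)) % MOD for y in zip(*Y)] for x in X]

        def power(M, n, MOD):
            # binary exponentiation, starting from the identity matrix
            result = [[0] * m for _ in range(m)]
            for i in range(m): result[i][i] = 1
            while n > 0:
                if n % 2: result = mult(M, result, MOD)
                M = mult(M, M, MOD)
                n //= 2
            return result

        m, MOD = sum(rollMax), 10**9 + 7
        mat, s = [[0] * m for _ in range(m)], [0] * m
        acc = [0] + list(accumulate(rollMax))  # acc[f]: index of the first state of face f
        for i, j in zip(acc, acc[1:]):
            # any state of another face can move to the first state of this face
            for k in chain(range(i), range(j, m)): mat[k][i] = 1
        # rolling the same face again moves to the next state
        # (at a block boundary this entry is already 1 from the loop above)
        for i in range(m-1):
            mat[i][i+1] = 1
        for i in acc[:-1]: s[i] = 1  # after the first roll every face has a run of length 1
        return sum(mult([s], power(mat, n-1, MOD), MOD)[0]) % MOD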