Problem statement

https://leetcode.com/problems/minimum-cost-to-merge-stones/

Solution

We use DP with states dfs(i, j, k), where i and j are the ends of a subarray and k is the number of parts into which this subarray still has to be split. The function answers the question: what is the minimum cost we can get on nums[i:j] (note that j is not inclusive; nums is the input array, called stones in the code)? Consider the last move on a given subarray: on the last step K piles are merged, so [i:j] is split into K parts [i:i_1], [i_1:i_2], ..., [i_{K-1}:j], and we recursively apply dfs to all of these parts. In effect we reverse the problem: we start from one big pile and on each step separate it into K piles. However, trying all such splits directly is not fast enough to pass, so we need to go deeper. Let us fix only the first split point i_1 = l and look at how many parts k we still have to produce:

  1. k >= 3: we compute dfs(i, l, K), the cost of merging the first part into a single pile, and dfs(l, j, k-1) for the rest of the subarray, which still has to be split into k-1 parts. We also add the sum of nums[i:l].
  2. k = 2: only two parts remain, so we evaluate dfs(i, l, K) and dfs(l, j, K) and add the sum of nums[i:j]. This needs a separate case because in case 1 we always split the second part further and add only the sum of the first part. In the end that accounts for every part except the last one, so here we also have to add the sum of nums[l:j]; together with the sum of nums[i:l] this gives the sum of nums[i:j].

Also note that we use l in range(i + 1, j, K - 1), so l - i - 1 is always divisible by K - 1. This is exactly the condition under which the first part nums[i:l] can be merged into a single pile.
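
To make the split points and the sums concrete, here is a minimal sketch. The helper names first_part_ends and range_sum are hypothetical and only illustrate the loop range and the prefix-sum lookup described above; the input is an arbitrary example with K = 3.

```python
from itertools import accumulate

def first_part_ends(i, j, K):
    """Candidate ends l of the first part nums[i:l] (hypothetical helper)."""
    return list(range(i + 1, j, K - 1))

def range_sum(acc, i, j):
    """Sum of nums[i:j] from prefix sums, with acc[0] == 0 (hypothetical helper)."""
    return acc[j] - acc[i]

nums = [3, 5, 1, 2, 6]               # example input, K = 3
acc = [0] + list(accumulate(nums))   # acc == [0, 3, 8, 9, 11, 17]

ends = first_part_ends(0, len(nums), 3)
print(ends)                          # [1, 3]: first parts of length 1 and 3
print(range_sum(acc, 0, 3))          # 3 + 5 + 1 = 9
```

Both candidate first parts have a length l - i with l - i - 1 divisible by K - 1, so each of them can be reduced to one pile.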

Complexity

Now let us discuss complexity: we can choose i in O(n) ways, j in O(n/K) ways because of the divisibility condition, and k in O(K) ways. So we have O(n * n/K * K) = O(n^2) states. From each state we can make O(n/K) different steps, so the overall time complexity is O(n^3/K). Space complexity is O(n^2).
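
One way to get a feel for these bounds is to count states and transitions empirically. The sketch below is an instrumented copy of the same recurrence; merge_stones_with_stats is a hypothetical helper, and the counts it prints are only an illustration on arbitrary inputs, not a proof of the bounds.

```python
from functools import lru_cache
from itertools import accumulate

def merge_stones_with_stats(stones, K):
    """Same recurrence as the solution, plus counters (hypothetical helper)."""
    n = len(stones)
    if (n - 1) % (K - 1) != 0:
        return -1, 0, 0
    acc = [0] + list(accumulate(stones))
    transitions = 0

    @lru_cache(None)
    def dfs(i, j, k):
        nonlocal transitions
        if i + 1 == j:
            return 0
        best = float("inf")
        for l in range(i + 1, j, K - 1):
            transitions += 1  # one step taken from state (i, j, k)
            k2 = K if k == 2 else k - 1
            l2 = j if k == 2 else l
            best = min(best, dfs(i, l, K) + dfs(l, j, k2) + acc[l2] - acc[i])
        return best

    cost = dfs(0, n, K)
    # currsize is the number of distinct (i, j, k) states that were memoized.
    return cost, dfs.cache_info().currsize, transitions

for n in (9, 17, 33):
    cost, states, steps = merge_stones_with_stats(list(range(1, n + 1)), 3)
    print(n, states, steps)
```

As n doubles, the number of memoized states should grow roughly like n^2 and the number of transitions roughly like n^3 for a fixed K.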

Code

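from functools import lru_cache
from itertools import accumulate

class Solution:
    def mergeStones(self, stones, K):
        # Each merge turns K piles into one, removing K - 1 piles,
        # so ending with a single pile needs (n - 1) % (K - 1) == 0.
        if (len(stones) - 1) % (K - 1) != 0: return -1
        # acc[j] - acc[i] is the sum of stones[i:j].
        acc = [0] + list(accumulate(stones))

        @lru_cache(None)
        def dfs(i, j, k):
            # A single pile needs no further merging.
            if i + 1 == j: return 0

            # k == 2: the rest becomes one pile and we pay for all of stones[i:j].
            # k >= 3: the rest needs k - 1 more parts and we pay only for stones[i:l].
            k2 = K if k == 2 else k - 1
            cand = float("inf")
            for l in range(i+1, j, K - 1):
                l2 = j if k == 2 else l
                cand = min(cand, dfs(i, l, K) + dfs(l, j, k2) + acc[l2] - acc[i])

            return cand

        result = dfs(0, len(stones), K)
        return result if result != float("inf") else -1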
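
To sanity-check a DP like this on small inputs, one option is to compare it against an exhaustive search that simulates every possible sequence of merges. The sketch below is such a checker; brute_force_merge is a hypothetical helper, not part of the solution above, and it is exponential, so it is only usable for very short arrays. The three inputs are the examples from the problem statement.

```python
from functools import lru_cache

def brute_force_merge(stones, K):
    """Try every merge order; exponential, for tiny inputs only (hypothetical helper)."""
    @lru_cache(None)
    def go(piles):
        if len(piles) == 1:
            return 0
        best = float("inf")
        # Merge each window of K consecutive piles and recurse on the result.
        for s in range(len(piles) - K + 1):
            merged = sum(piles[s:s + K])
            nxt = piles[:s] + (merged,) + piles[s + K:]
            best = min(best, merged + go(nxt))
        return best

    result = go(tuple(stones))
    return result if result != float("inf") else -1

print(brute_force_merge([3, 2, 4, 1], 2))     # 20
print(brute_force_merge([3, 2, 4, 1], 3))     # -1
print(brute_force_merge([3, 5, 1, 2, 6], 3))  # 25
```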