Problem statement

https://binarysearch.com/problems/Strings-Down-Under/

Solution

This problem is similar to LeetCode problem 1397, Find All Good Strings. The idea is to keep a dp over states (i, last, q, border), where:

  1. i is the index we have reached in string s.
  2. last is the last symbol added.
  3. q is how many times this symbol is repeated at the end.
  4. border says whether we are on the border or not. For example, if s is azc and we have az so far, we are on the border; if we have ay or something smaller, we are no longer on the border, so we can use any next symbol.
  5. The value stored for each state is the number of strings with the given parameters, modulo 10**9 + 7.
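
To make the state concrete, here is a minimal sketch that computes (i, last, q, border) for a given prefix. The helper name prefix_state is hypothetical and is not part of the solution; it assumes the prefix is not lexicographically greater than the matching prefix of s, since the dp never builds such prefixes.

```python
def prefix_state(prefix, s):
    """Return the state (i, last, q, border) described in the list above."""
    i = len(prefix)
    last = prefix[-1] if prefix else ""
    # q counts how many times `last` is repeated at the end of the prefix.
    q = 0
    for ch in reversed(prefix):
        if ch != last:
            break
        q += 1
    # We are on the border while the prefix still coincides with s[:i].
    border = prefix == s[:i]
    return i, last, q, border
```

With s = "azc", the prefix "az" is on the border and the prefix "ay" is not, which matches the example in item 4.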

In practice we keep in dp only the items for level i, from which we build level i + 1. Here we have the following options:

  1. If we are not on the border, we can add any symbol from the alphabet.
  2. If we are on the border, we can add any symbol smaller than s[i], and after that we are no longer on the border.
  3. If we are on the border, there is also exactly one way to stay on the border: add s[i] itself.
  4. In every case we need to check that the number of repetitions does not exceed k.
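
The options above can be sketched as a function that lists the states reachable from a single state. The name next_states and its signature are assumptions made for illustration only; the solution code performs the same steps inline.

```python
ALPHABET = "abcdefghijklmnopqrstuvwxyz"


def next_states(last, q, border, s_i, k):
    """List the states (last, q, border) reachable by appending one symbol."""
    result = []
    # Off the border any symbol is allowed; on the border, only symbols
    # strictly smaller than s_i take us off the border.
    limit = ALPHABET.index(s_i) if border else len(ALPHABET)
    for ch in ALPHABET[:limit]:
        run = q + 1 if ch == last else 1
        if run <= k:
            result.append((ch, run, False))
    # Appending s_i itself is the only way to stay on the border.
    if border:
        run = q + 1 if s_i == last else 1
        if run <= k:
            result.append((s_i, run, True))
    return result
```

For example, from the border state ("a", 1, True) with s_i = "c" and k = 1, the symbol a is rejected because it would repeat twice, b leaves the border, and c stays on it.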

Complexity

In total we have O(n*26*n*2) states, and from each state there are at most 26 transitions. So the total complexity is O(n^2s^2), where s is the size of the alphabet. Since q never exceeds min(n, k), O(n*k*s^2) is a tighter bound. Given constraint

Code

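from collections import Counter


class Solution:
    def solve(self, s, k):
        n, M = len(s), 10**9 + 7
        alphabet = "abcdefghijklmnopqrstuvwxyz"
        # State (last, q, border) -> number of prefixes; start with the empty prefix.
        dp = {("", 0, True): 1}

        for i in range(n):
            dp2 = Counter()
            for (last, q, border), val in dp.items():
                # On the border only symbols strictly smaller than s[i] leave it.
                end = ord(s[i]) - ord("a") if border else 26

                for l in alphabet[:end]:
                    q2 = 1 + q*(l == last)  # run length after adding l
                    if q2 <= k:
                        dp2[l, q2, False] = (dp2[l, q2, False] + val) % M

                if border:
                    # Adding s[i] itself is the only way to stay on the border.
                    q2 = 1 + q*(last == s[i])
                    if q2 <= k:
                        dp2[s[i], q2, True] = (dp2[s[i], q2, True] + val) % M
            dp = dp2

        # Border states are included, so s itself is counted when it is valid.
        return sum(dp.values()) % M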
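
For small inputs the dp can be cross-checked against a brute force. The sketch below assumes the task is to count strings of the same length as s that are not lexicographically greater than s and contain no symbol repeated more than k times in a row, which is what the dp above computes; the problem statement itself is only linked, so treat this reading as an assumption. The names brute_force and longest_run are hypothetical.

```python
from itertools import groupby, product

ALPHABET = "abcdefghijklmnopqrstuvwxyz"


def longest_run(t):
    """Length of the longest block of equal consecutive symbols in t."""
    return max((len(list(group)) for _, group in groupby(t)), default=0)


def brute_force(s, k):
    """Count strings t with len(t) == len(s), t <= s and longest_run(t) <= k."""
    count = 0
    # Enumerates 26**len(s) candidates, so this is only usable for tiny inputs.
    for letters in product(ALPHABET, repeat=len(s)):
        t = "".join(letters)
        if t <= s and longest_run(t) <= k:
            count += 1
    return count % (10**9 + 7)
```

For instance, with s = "bb" and k = 1 the valid strings are the 25 strings ab, ac, ..., az plus ba, giving 26, and the dp returns the same value.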

Remark

There is also a mathematical solution with complexity O(n)!