Problem statement

https://leetcode.com/problems/distinct-echo-substrings/

Solution 1

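The idea is to compute a rolling hash for every substring of length m, group the start indexes by hash value, and for each hash check whether two of its indexes differ by exactly m. If indexes t and t + m share a hash, then text[t:t+m] and text[t+m:t+2m] are (very likely) equal, so text[t:t+2m] is an echo substring. In this problem we cannot afford to confirm every hash match by comparing the actual strings, so to get AC we skip that check and choose a big modulus q, preferably an uncommon number, so that no test is likely to have been written specifically to break it. We can call it a YOLO rolling hash, suitable just for contests. Say q = (1 << 64) - 257 or (1 << 31) - 159; the code below uses (1 << 31) - 7.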
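
The grouping step for a single length m can be sketched as follows. This is only an illustration: the helper names window_hashes and has_echo are made up for this example, the base and modulus are copied from the solution code, and, as in the solution, hash matches are trusted without comparing the strings.

```python
from collections import defaultdict

def window_hashes(nums, m, q=(1 << 31) - 7, d=256):
    """Return {hash: set of start indexes} for all length-m windows of nums."""
    h = pow(d, m - 1, q)            # weight of the first symbol in a window
    t = 0
    for x in nums[:m]:              # hash of the first window nums[0:m]
        t = (d * t + x) % q
    groups = defaultdict(set)
    groups[t].add(0)
    for i in range(len(nums) - m):
        # Drop nums[i], shift by one position, append nums[i + m].
        t = (d * (t - nums[i] * h) + nums[i + m]) % q
        groups[t].add(i + 1)
    return groups

def has_echo(starts, m):
    """True if two start indexes in the group are exactly m apart."""
    return any(s + m in starts for s in starts)

nums = [ord(c) - 97 for c in "abcabcabc"]
groups = window_hashes(nums, 3)
echo_count = sum(has_echo(g, 3) for g in groups.values())
print(echo_count)  # 3: "abc", "bca" and "cab" each repeat at distance 3
```

Summing this count over every m from 1 to n // 2 gives the answer, which is what the full solution does.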

Complexity

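Time complexity is O(n^2): there are n/2 candidate lengths m, and each one takes O(n) to hash all windows and check the index groups. Space complexity is O(n), for the index groups of a single length.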

Code

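from collections import defaultdict

class Solution:
    def distinctEchoSubstrings(self, text):
        def RabinKarp(text, M):
            # Map each rolling-hash value to the start indexes of the
            # length-M windows that produce it.
            q, d = (1 << 31) - 7, 256
            # h = d^(M-1) mod q is the weight of the window's first symbol.
            h, t = pow(d, M - 1, q), 0
            dic = defaultdict(set)

            # Hash of the first window text[0:M].
            for i in range(M):
                t = (d * t + text[i]) % q

            dic[t].add(0)

            # Slide the window: drop text[i], append text[i + M].
            for i in range(len(text) - M):
                t = (d * (t - text[i] * h) + text[i + M]) % q
                dic[t].add(i + 1)
            return dic

        n, ans = len(text), 0
        nums = [ord(i) - 97 for i in text]
        for m in range(1, n//2 + 1):
            hashes = RabinKarp(nums, m)
            for g in hashes.values():
                # Indexes t and t + m in the same group: text[t:t+2*m] is an echo.
                ans += len(g & {t + m for t in g}) != 0

        return ans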

Solution 2

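The idea is to build a suffix array and then, for each length m, compare adjacent entries of the suffix array to group together the start indexes whose substrings of length m are equal. For each group we then check whether the set contains two indexes that differ by m.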
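
A deliberately naive sketch of this grouping is shown here. It is an illustration only, not the solution itself: it sorts whole suffixes with sorted() instead of using prefix doubling and compares slices directly, so it is much slower. The function names are made up for this example, and the input is assumed to be non-empty.

```python
from collections import defaultdict

def echo_count_for_length(text, m):
    """Count distinct strings a of length m such that a + a occurs in text."""
    n = len(text)
    # Naive suffix array: start indexes sorted by suffix (for illustration only).
    sa = sorted(range(n), key=lambda i: text[i:])
    groups = defaultdict(set)
    group_id = 0
    for prev, cur in zip(sa, sa[1:]):
        groups[group_id].add(prev)
        # Start a new group when the length-m prefixes of adjacent suffixes differ.
        if text[prev:prev + m] != text[cur:cur + m]:
            group_id += 1
    groups[group_id].add(sa[-1])
    # A group yields an echo if it contains two start indexes exactly m apart.
    return sum(any(s + m in g for s in g) for g in groups.values())

def distinct_echo_substrings(text):
    """Sum the per-length counts over every possible half-length m."""
    return sum(echo_count_for_length(text, m)
               for m in range(1, len(text) // 2 + 1))

print(distinct_echo_substrings("abcabcabc"))  # 3
```

Because equal prefixes of length m are always contiguous in suffix-array order, one pass over adjacent pairs is enough to form the groups. The full solution replaces the slice comparison with a constant-time comparison of precomputed ranks.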

Complexity

Time complexity is O(n^2). Unfortunately, it sometimes gives TLE, because the rank lookups jump between indexes a lot.

Code

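from collections import defaultdict
from itertools import islice, zip_longest

class Solution:
    def distinctEchoSubstrings(self, text):
        def ranks(l):
            # Replace each value by its rank among the distinct values.
            index = {v: i for i, v in enumerate(sorted(set(l)))}
            return [index[v] for v in l]

        def suffixArray(s):
            # Prefix doubling: ans[j][i] is the rank of the substring of
            # length 2^j starting at i (padded with -1 past the end).
            line = ranks(s)
            n, k, ans, sa = len(s), 1, [line], [0]*len(s)
            while k < n - 1:
                line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))
                ans, k = ans + [line], k << 1
            # Invert the final ranks to get the suffix array.
            for i, r in enumerate(ans[-1]): sa[r] = i
            return ans, sa

        def compare(i, j, l, k):
            # Compare the length-l substrings at i and j using two
            # overlapping blocks of length 2^k, where 2^k <= l < 2^(k+1).
            a = (c[k][i], c[k][(i+l-(1<<k))%n])
            b = (c[k][j], c[k][(j+l-(1<<k))%n])
            return 0 if a == b else 1 if a < b else -1

        c, sa = suffixArray([ord(i) - ord("a") for i in text])
        n, ans = len(text), 0

        for m in range(1, n//2 + 1):
            # ml = floor(log2(m)), computed without floating point.
            ml, a = m.bit_length() - 1, 0

            # Walk the suffix array; adjacent suffixes with equal
            # length-m prefixes end up in the same group.
            groups = defaultdict(set)
            for i in range(1, n):
                groups[a].add(sa[i-1])
                a += compare(sa[i-1], sa[i], m, ml)
            groups[a].add(sa[-1])

            for g in groups.values():
                # Indexes t and t + m in the same group: text[t:t+2*m] is an echo.
                ans += len(g & {t+m for t in g}) != 0

        return ans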