string
rolling hash
suffix array
sliding window
]
Leetcode 1316. Distinct Echo Substrings
Problem statement
https://leetcode.com/problems/distinct-echo-substrings/
Solution 1
The idea is to compute a rolling hash for every substring of a fixed length m, group the start indexes by hash value, and then check whether any group contains two indexes that differ by exactly m. Such a pair means the substring is immediately followed by a copy of itself. Comparing the actual strings on every hash match is too slow in this problem, so to get AC we trust the hash and choose a big modulus q, preferably an uncommon number, so that no test is built specifically to break it. We can call it a YOLO rolling hash, suitable only for contests. Say q = (1 << 64) - 257 or (1 << 31) - 159.
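As a minimal sketch of the rolling update described above: the helper names `window_hashes` and `echo_starts`, the base 256, and the modulus 2^31 - 1 are illustrative assumptions, not the values used in the solution itself. Like the approach above, it trusts hash equality without comparing the strings.

```python
def window_hashes(s, m, base=256, q=(1 << 31) - 1):
    """Return the polynomial hash of every length-m window of s, left to right.

    base and q are illustrative choices (assumptions), not tuned values.
    """
    if m <= 0 or m > len(s):
        return []
    codes = [ord(ch) for ch in s]
    high = pow(base, m - 1, q)  # weight of the character that leaves the window
    h = 0
    for c in codes[:m]:
        h = (h * base + c) % q
    out = [h]
    for i in range(len(s) - m):
        # Remove codes[i], shift the window by one, append codes[i + m].
        h = ((h - codes[i] * high) * base + codes[i + m]) % q
        out.append(h)
    return out


def echo_starts(s, m):
    """Start indexes i where the hashes of s[i:i+m] and s[i+m:i+2m] agree."""
    hs = window_hashes(s, m)
    return [i for i in range(len(hs) - m) if hs[i] == hs[i + m]]
```

For example, `echo_starts("abab", 2)` reports index 0, because the windows starting at 0 and 2 hash to the same value.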
Complexity
Time complexity is O(n^2), space complexity is O(n).
Code
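from collections import defaultdict

class Solution:
    def distinctEchoSubstrings(self, text):
        def RabinKarp(text, M):
            # Map each hash of a length-M window to the set of its start indexes.
            q = (1 << 31) - 7
            # h = d^(M-1) mod q is the weight of the window's leading character.
            h, t, d = (1<<(8*M-8))%q, 0, 256
            dic = defaultdict(set)

            # Hash of the first window.
            for i in range(M):
                t = (d * t + text[i])% q

            dic[t].add(i-M+1)

            # Roll the window: drop text[i], append text[i + M].
            for i in range(len(text) - M):
                t = (d*(t-text[i]*h) + text[i + M])% q
                dic[t].add(i+1)
            return dic

        n, ans = len(text), 0
        nums = [ord(i) - 97 for i in text]
        for m in range(1, n//2 + 1):
            hashes = RabinKarp(nums, m)
            for g in hashes.values():
                # Count the group if it has two start indexes exactly m apart.
                ans += len(g & {t+m for t in g}) != 0
        return ans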
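When experimenting with the modulus, a brute-force reference can help to cross-check answers on small inputs. The function below is a hypothetical helper (the name `distinct_echo_bruteforce` is an assumption, it is not part of the solution) that compares the actual substrings, so it is slow but has no collision risk.

```python
def distinct_echo_bruteforce(text):
    """Count distinct substrings of the form a + a by direct string comparison.

    Roughly O(n^3) time: intended only as a checker for short strings.
    """
    n = len(text)
    found = set()
    for m in range(1, n // 2 + 1):           # m is the length of one half
        for i in range(n - 2 * m + 1):       # i is the start of the echo
            if text[i:i + m] == text[i + m:i + 2 * m]:
                found.add(text[i:i + 2 * m])
    return len(found)
```

On the two examples from the problem statement, "abcabcabc" and "leetcodeleetcode", it gives 3 and 2 respectively.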
Solution 2
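The idea is to build a suffix array and then, for each length m, compare pairs of adjacent elements in the suffix array to group the suffixes that share the same first m characters. For each group we check whether its set of start indexes contains two numbers that differ by m.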
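The grouping step can be illustrated with a deliberately naive sketch. It assumes a suffix array built by sorting the suffixes directly and compares prefixes by slicing, so it is much slower than the real solution; the name `distinct_echo_by_sorted_suffixes` is hypothetical. It relies on the fact that suffixes sharing the same first m characters are contiguous in sorted order.

```python
from collections import defaultdict


def distinct_echo_by_sorted_suffixes(text):
    """Illustrative only: naive suffix sorting plus per-length grouping."""
    n = len(text)
    # Naive suffix array: start indexes ordered by the suffix they begin.
    sa = sorted(range(n), key=lambda i: text[i:])
    ans = 0
    for m in range(1, n // 2 + 1):
        groups = defaultdict(set)
        gid, prev = 0, None
        for i in sa:
            if i + m > n:
                continue  # suffix is shorter than m, so it has no length-m prefix
            cur = text[i:i + m]
            if prev is not None and cur != prev:
                gid += 1  # a new block of equal length-m prefixes starts here
            groups[gid].add(i)
            prev = cur
        for g in groups.values():
            # The block counts once if two of its start indexes are m apart.
            ans += bool(g & {t + m for t in g})
    return ans
```

The real solution replaces the slicing comparison with rank lookups, which is what makes each adjacent comparison O(1).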
Complexity
Time complexity is O(n^2). Unfortunately, it sometimes gives TLE, because we jump between indexes a lot.
Code
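from collections import defaultdict
from itertools import islice, zip_longest
from math import floor, log2

class Solution:
    def distinctEchoSubstrings(self, text):
        def ranks(l):
            # Replace each value by its rank among the distinct sorted values.
            index = {v: i for i, v in enumerate(sorted(set(l)))}
            return [index[v] for v in l]

        def suffixArray(s):
            # Prefix doubling: ans[j][i] is the rank of the length-2^j block at i.
            line = ranks(s)
            n, k, ans, sa = len(s), 1, [line], [0]*len(s)
            while k < n - 1:
                line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))
                ans, k = ans + [line], k << 1
            for i, r in enumerate(ans[-1]): sa[r] = i
            return ans, sa

        def compare(i, j, l, k):
            # Compare the length-l substrings at i and j using two overlapping
            # blocks of length 2^k: one at the start, one ending at offset l.
            a = (c[k][i], c[k][(i+l-(1<<k))%n])
            b = (c[k][j], c[k][(j+l-(1<<k))%n])
            return 0 if a == b else 1 if a < b else -1

        c, sa = suffixArray([ord(i) - ord("a") for i in text])

        n, ans = len(text), 0
        for m in range(1, n//2 + 1):
            ml, a = floor(log2(m)), 0
            groups = defaultdict(set)
            # Walk the suffix array; the group id changes when adjacent
            # suffixes differ within their first m characters.
            for i in range(1, n):
                groups[a].add(sa[i-1])
                a += compare(sa[i-1], sa[i], m, ml)
            groups[a].add(sa[-1])
            for g in groups.values():
                ans += len(g & {t+m for t in g}) != 0
        return ans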