string
rolling hash
suffix array
sliding window
]
Leetcode 1316. Distinct Echo Substrings
Problem statement
https://leetcode.com/problems/distinct-echo-substrings/
Solution 1
The idea is calculate hashes and then for each hash check all indexes where we meet it and check if we have difference m
. In this problem we can not allow to compare strings all the time so to get AC, we need to choose some big q
, better quite rare number, so there will be no test specifically to ruin it. We can call it YOLO rolling hash, just for contests. Say q = (1 << 64) - 257
or (1<<31) - 159
.
Complexity
Time complexity is O(n^2)
, space complexity is O(n)
.
Code
class Solution:
def distinctEchoSubstrings(self, text):
def RabinKarp(text, M):
q = (1 << 31) - 7
h, t, d = (1<<(8*M-8))%q, 0, 256
dic = defaultdict(set)
for i in range(M):
t = (d * t + text[i])% q
dic[t].add(i-M+1)
for i in range(len(text) - M):
t = (d*(t-text[i]*h) + text[i + M])% q
dic[t].add(i+1)
return dic
n, ans = len(text), 0
nums = [ord(i) - 97 for i in text]
for m in range(1, n//2 + 1):
hashes = RabinKarp(nums, m)
for g in hashes.values():
ans += len(g & {t+m for t in g}) != 0
return ans
Solution 2
Idea is to build suffix array and then for each length compare pairs of adjacent element ins suffix array and check if we have numbers differ by m
in set.
Complexity
Time complexity is O(n^2)
. Unfortunatelly it gives TLE sometimes, because we jump between indexes a lot.
Code
class Solution:
def distinctEchoSubstrings(self, text):
def ranks(l):
index = {v: i for i, v in enumerate(sorted(set(l)))}
return [index[v] for v in l]
def suffixArray(s):
line = ranks(s)
n, k, ans, sa = len(s), 1, [line], [0]*len(s)
while k < n - 1:
line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))
ans, k = ans + [line], k << 1
for i, k in enumerate(ans[-1]): sa[k] = i
return ans, sa
def compare(i, j, l, k):
a = (c[k][i], c[k][(i+l-(1<<k))%n])
b = (c[k][j], c[k][(j+l-(1<<k))%n])
return 0 if a == b else 1 if a < b else -1
c, sa = suffixArray([ord(i) - ord("a") for i in text])
n, ans = len(text), 0
for m in range(1, n//2 + 1):
ml, a = floor(log2(m)), 0
groups = defaultdict(set)
for i in range(1, n):
groups[a].add(sa[i-1])
a += compare(sa[i-1], sa[i], m, ml)
groups[a].add(sa[-1])
for g in groups.values():
ans += len(g & {t+m for t in g}) != 0
return ans