graph
graph algo
spanning tree
dsu
sort
]
Leetcode 1489. Find Critical and Pseudo-Critical Edges in Minimum Spanning Tree
Problem statement
https://leetcode.com/problems/find-critical-and-pseudo-critical-edges-in-minimum-spanning-tree/
Solution
The idea is to run Kruskal's algorithm several times. Let us use a function `kruskal(n, edges, taken)`, where `n` is the number of nodes (the nodes are `0, 1, ..., n-1`), `edges` is the list of edges in our graph, and `taken` is the list of edges we force to be included in the spanning tree.
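Then, to classify an edge, we delete it and check whether the resulting weight still equals the weight of the MST. We also add this edge to `taken` and run the algorithm once again. Let `MST` be the original weight, `MST1` the weight with the edge deleted, and `MST2` the weight with the edge forced through `taken`. We can have 3 cases:
- If `MST1 == MST2 == MST`, the edge is pseudo-critical: there is an MST that contains it and an MST that does not.
- If `MST2 == MST != MST1`, the edge is critical: when we delete it, the weight is no longer equal to `MST`. If the deleted edge was a bridge, `kruskal` returns the weight of a spanning forest, which is smaller than `MST` because all weights are positive, so `MST1 != MST` holds in that case too. For a critical edge the condition `MST2 == MST` is redundant, since such an edge lies in every MST.
- If `MST2 != MST`, the edge belongs to no MST, so it is neither critical nor pseudo-critical.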
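The case analysis above can be checked on a small graph. The sketch below is a compact, self-contained version of the same idea; the helper name `classify`, the label strings and the 5-node sample graph are chosen here for illustration, and forced edges are simply processed before the sorted ones, which is equivalent when a single edge is forced.

```python
def kruskal(n, edges, taken):
    """Weight of the spanning forest Kruskal builds, with `taken` processed first."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    total = 0
    # Forced edges go first, then the remaining edges by non-decreasing weight.
    for u, v, w in list(taken) + sorted(edges, key=lambda e: e[2]):
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            total += w
    return total


def classify(n, edges):
    """Label every edge using the MST / MST1 / MST2 comparison."""
    mst = kruskal(n, edges, [])
    labels = []
    for i, e in enumerate(edges):
        rest = edges[:i] + edges[i + 1:]
        mst1 = kruskal(n, rest, [])   # edge deleted
        mst2 = kruskal(n, rest, [e])  # edge forced
        if mst2 != mst:
            labels.append("neither")
        elif mst1 != mst:
            labels.append("critical")
        else:
            labels.append("pseudo-critical")
    return labels


edges = [[0, 1, 1], [1, 2, 1], [2, 3, 2], [0, 3, 2],
         [0, 4, 3], [3, 4, 3], [1, 4, 6]]
for i, label in enumerate(classify(5, edges)):
    print(i, edges[i], label)
```

Here the MST weight is 7. Deleting edge 0 or edge 1 raises the weight to 8, so they are critical; edges 2 to 5 keep the weight at 7 both when deleted and when forced, so they are pseudo-critical; forcing edge 6 gives weight 10, so it is in no MST.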
Complexity
One Kruskal pass takes `O(E log E)`, and we do `O(E)` of them, so the total time complexity is `O(E^2 log E)`. Space complexity is `O(E)`.
Code
class DSU:
    def __init__(self, N):
        self.p = list(range(N))

    def find(self, x):
        # Path compression: point x directly at its root.
        if self.p[x] != x:
            self.p[x] = self.find(self.p[x])
        return self.p[x]

    def union(self, x, y):
        xr = self.find(x)
        yr = self.find(y)
        self.p[xr] = yr

class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n, edges):
        def kruskal(n, edges, taken):
            dsu, ans = DSU(n), 0
            # Forced edges are added first, unconditionally.
            for x, y, w in taken:
                dsu.union(x, y)
                ans += w
            # Then the usual greedy pass over edges sorted by weight.
            for x, y, w in sorted(edges, key=lambda e: e[2]):
                if dsu.find(x) == dsu.find(y): continue
                dsu.union(x, y)
                ans += w
            return ans

        ans = [[], []]
        MST = kruskal(n, edges, [])
        for i, (x, y, w) in enumerate(edges):
            edges2 = edges[:]
            edges2.remove([x, y, w])
            MST1 = kruskal(n, edges2, [])            # edge deleted
            MST2 = kruskal(n, edges2, [[x, y, w]])   # edge forced
            if MST2 == MST:
                # ans[0] collects critical edges, ans[1] pseudo-critical ones.
                ind = 1 if MST1 == MST else 0
                ans[ind].append(i)
        return ans
Solution 2
In fact it can be easily reduced to `O(E^2)` (up to the small union-find factor) if we sort our edges only once and use flags for deleted edges. However, in practice it will run in approximately the same time because the sort implementation is fast.
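A minimal sketch of the sort-once variant, under a few assumptions of mine: the function name `critical_and_pseudo` and the parameters `skip` and `force` are hypothetical, an edge is identified by its index rather than by a separate flag array, and a pass that fails to span the graph returns infinity instead of a forest weight.

```python
def critical_and_pseudo(n, edges):
    # Sort edge indices once; every Kruskal pass reuses this order.
    order = sorted(range(len(edges)), key=lambda i: edges[i][2])

    def run(skip=-1, force=-1):
        parent = list(range(n))

        def find(x):
            while parent[x] != x:
                parent[x] = parent[parent[x]]  # path halving
                x = parent[x]
            return x

        total, used = 0, 0
        if force >= 0:  # pre-take the forced edge
            u, v, w = edges[force]
            parent[find(u)] = find(v)
            total, used = w, 1
        for i in order:
            if i == skip:  # the "deleted" edge is just ignored
                continue
            u, v, w = edges[i]
            ru, rv = find(u), find(v)
            if ru != rv:
                parent[ru] = rv
                total += w
                used += 1
        # Assumption: treat a disconnected result as infinitely heavy.
        return total if used == n - 1 else float("inf")

    base = run()
    critical, pseudo = [], []
    for i in range(len(edges)):
        if run(skip=i) > base:
            critical.append(i)
        elif run(force=i) == base:
            pseudo.append(i)
    return [critical, pseudo]


print(critical_and_pseudo(4, [[0, 1, 1], [1, 2, 1], [2, 3, 1], [0, 3, 1]]))
```

On this 4-cycle with equal weights every edge can be swapped for another, so no edge is critical and all four are pseudo-critical.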
There is an `O(E log E)` solution as well! The idea is the following: run the usual Kruskal algorithm and group all edges by equal weight. If an edge `(x, y)` connects two parts that are already connected, it can be neither critical nor pseudo-critical. Otherwise, for each weight group we create a graph whose nodes are the connected components built so far and whose edges are the edges of this group between those components. If there is more than one edge between two nodes, all of these edges are pseudo-critical. Finally, we need to find all bridges in this graph: they are the critical edges, and all the other edges are pseudo-critical.
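To make the grouping step concrete, here is a small sketch that only performs the contraction: for each weight it reports which edge indices land between the same pair of current components. The function name `parallel_groups_by_weight` and the sample graph are illustrative assumptions, and the bridge search itself is left out.

```python
from collections import defaultdict


def parallel_groups_by_weight(n, edges):
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    by_weight = defaultdict(list)
    for idx, (u, v, w) in enumerate(edges):
        by_weight[w].append(idx)

    result = {}
    for w in sorted(by_weight):
        groups = defaultdict(list)  # (component, component) -> edge indices
        for idx in by_weight[w]:
            u, v, _ = edges[idx]
            ru, rv = find(u), find(v)
            if ru != rv:  # edges inside one component are in no MST
                groups[min(ru, rv), max(ru, rv)].append(idx)
        # Merge components only after the whole weight group is inspected.
        for ru, rv in groups:
            parent[find(ru)] = find(rv)
        result[w] = sorted(groups.values())
    return result


edges = [[0, 1, 1], [1, 2, 1], [2, 3, 2], [0, 3, 2],
         [0, 4, 3], [3, 4, 3], [1, 4, 6]]
print(parallel_groups_by_weight(5, edges))
# {1: [[0], [1]], 2: [[2, 3]], 3: [[4, 5]], 6: []}
```

Groups with several indices (`[2, 3]` and `[4, 5]`) are parallel edges in the component graph, hence pseudo-critical. Singleton groups (`[0]`, `[1]`) still need the bridge test, and the empty list for weight 6 means that edge joins nodes that are already connected.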
Complexity
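Time complexity is `O(E log E)` and space complexity is `O(E)`. Note that `bridges` as written allocates arrays of size `n` and scans all `n` nodes for every distinct weight, which adds an `O(n)` term per weight group; restricting the search to the nodes of the current component graph removes it.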
Code
from collections import defaultdict

class DSU:
    def __init__(self, N):
        self.p = list(range(N))

    def find(self, x):
        # Path compression: point x directly at its root.
        if self.p[x] != x:
            self.p[x] = self.find(self.p[x])
        return self.p[x]

    def union(self, x, y):
        xr = self.find(x)
        yr = self.find(y)
        self.p[xr] = yr

class Solution:
    def bridges(self, n, G):
        # tin: DFS entry time; fup: lowest entry time reachable from the subtree.
        used, tin, fup = [0]*n, [-1]*n, [-1]*n
        self.timer, ans = 0, []

        def dfs(node, par = -1):
            used[node] = 1
            tin[node] = fup[node] = self.timer + 1
            self.timer += 1
            for child in G[node]:
                if child == par: continue
                if used[child] == 1:
                    fup[node] = min(fup[node], tin[child])
                else:
                    dfs(child, node)
                    fup[node] = min(fup[node], fup[child])
                    # No back edge from the child's subtree: (node, child) is a bridge.
                    if fup[child] > tin[node]: ans.append((node, child))

        for i in range(n):
            if not used[i]: dfs(i)
        return ans

    def findCriticalAndPseudoCriticalEdges(self, n, edges):
        E = {(x, y): i for i, (x, y, _) in enumerate(edges)}
        # Group edges by weight.
        dic = defaultdict(list)
        for u, v, w in edges:
            dic[w].append([u, v])

        dsu, ans = DSU(n), [[], []]
        for w in sorted(dic):
            # seen: pair of current components -> original edges between them.
            seen = defaultdict(list)
            for u, v in dic[w]:
                pu, pv = dsu.find(u), dsu.find(v)
                if pu == pv: continue
                seen[min(pu, pv), max(pu, pv)].append((u, v))

            # Every edge joining two different components is in some MST.
            for e in seen:
                ans[1] += seen[e]

            # Component graph for this weight group (one edge per component pair).
            graph = defaultdict(list)
            for pu, pv in seen:
                graph[pu].append(pv)
                graph[pv].append(pu)
                dsu.union(pu, pv)

            brid = self.bridges(n, graph)
            brid = [(min(x, y), max(x, y)) for x, y in brid]
            # A bridge is critical only if it is not doubled by a parallel edge.
            ans[0] += [seen[e][0] for e in brid if len(seen[e]) == 1]

        return [[E[i] for i in ans[0]], [E[i] for i in set(ans[1]) - set(ans[0])]]