Problem statement

https://leetcode.com/problems/find-critical-and-pseudo-critical-edges-in-minimum-spanning-tree/

Solution

The idea is to run Kruskal's algorithm several times. We use a function kruskal(n, edges, taken), where:

  1. n is the number of nodes, labeled (0, 1, ..., n-1).
  2. edges is the list of edges in our graph.
  3. taken is the list of edges we force into the tree.

Then, to decide whether an edge is critical, we delete it and check whether the weight of the MST increases. We also add this edge to the taken edges and run the algorithm once again. Let MST be the original weight, MST1 the weight with the edge deleted and MST2 the weight with the edge forced in. We can have 3 cases:

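  1. If MST1 == MST2 == MST, the edge is pseudo-critical: there are minimum spanning trees both with this edge and without it.
  2. If MST2 == MST != MST1, the edge is critical: when we delete it, MST1 is no longer equal to MST (the condition MST2 == MST is redundant here, because an edge whose removal changes the weight belongs to every MST).
  3. If MST2 != MST, the edge does not belong to any MST, so it is neither critical nor pseudo-critical.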
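
The case analysis above can be written as a small decision function. This is only a sketch: the name classify and its string labels are hypothetical and do not appear in the solution code, which stores edge indices in two lists instead.

```python
def classify(mst, mst_without, mst_with):
    """Classify one edge from three Kruskal weights.

    mst         -- weight of the MST of the full graph
    mst_without -- weight obtained with the edge deleted (MST1)
    mst_with    -- weight obtained with the edge forced in (MST2)
    """
    if mst_with != mst:
        # Forcing the edge makes the tree heavier: it is in no MST.
        return "unused"
    if mst_without != mst:
        # The edge is in some MST and deleting it changes the weight.
        return "critical"
    # The weight stays the same with and without the edge.
    return "pseudo-critical"
```

For instance, with MST = 7, an edge with MST1 = 8 and MST2 = 7 is classified as critical.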

Complexity

One Kruskal pass takes O(E log E), and we do O(E) of them, so the total time complexity is O(E^2 log E). Space complexity is O(E).

Code

class DSU:
    def __init__(self, N):
        self.p = list(range(N))

    def find(self, x):
        if self.p[x] != x:
            self.p[x] = self.find(self.p[x])
        return self.p[x]

    def union(self, x, y):
        xr = self.find(x)
        yr = self.find(y)
        self.p[xr] = yr

class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n, edges):
        def kruskal(n, edges, taken):
            dsu, ans = DSU(n), 0
            for x, y, w in taken: 
                dsu.union(x, y)
                ans += w
            for x, y, w in sorted(edges, key = lambda x:x[2]):
                if dsu.find(x) == dsu.find(y): continue
                dsu.union(x, y)
                ans += w
            return ans
        
        ans = [[], []]
        MST = kruskal(n, edges, [])
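        for i, (x, y, w) in enumerate(edges):
            # Copy of the edge list without edge i.
            edges2 = edges[:]
            edges2.remove([x, y, w])
            MST1 = kruskal(n, edges2, [])            # weight with edge i deleted
            MST2 = kruskal(n, edges2, [[x, y, w]])   # weight with edge i forced in
            if MST2 == MST:
                ind = 1 if MST1 == MST else 0        # 0: critical, 1: pseudo-critical
                ans[ind].append(i)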
           
        return ans

Solution 2

In fact, it can easily be reduced to O(E^2), up to the cost of the DSU operations, if we sort the edges only once and use flags for deleted edges. However, in practice it runs in approximately the same time, because the built-in sort is fast.
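
The sort-once idea can be sketched as follows. This is a minimal illustration under assumptions, not the code used above: the function name critical_and_pseudo and the parameters skip and force are hypothetical, the DSU uses path halving, and a pass that fails to span the graph returns infinity instead of a forest weight.

```python
class DSU:
    def __init__(self, n):
        self.p = list(range(n))

    def find(self, x):
        while self.p[x] != x:
            self.p[x] = self.p[self.p[x]]  # path halving
            x = self.p[x]
        return x

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False  # already in the same component
        self.p[rx] = ry
        return True


def critical_and_pseudo(n, edges):
    # Sort edge indices by weight once; every pass reuses this order.
    order = sorted(range(len(edges)), key=lambda i: edges[i][2])

    def kruskal(skip=-1, force=-1):
        dsu, total, used = DSU(n), 0, 0
        if force >= 0:  # take the forced edge before anything else
            x, y, w = edges[force]
            dsu.union(x, y)
            total, used = w, 1
        for i in order:
            if i == skip:  # the "deleted" flag: just ignore this index
                continue
            x, y, w = edges[i]
            if dsu.union(x, y):
                total += w
                used += 1
        # Without n - 1 edges the graph is not spanned.
        return total if used == n - 1 else float("inf")

    base = kruskal()
    critical, pseudo = [], []
    for i in range(len(edges)):
        if kruskal(skip=i) > base:
            critical.append(i)
        elif kruskal(force=i) == base:
            pseudo.append(i)
    return [critical, pseudo]
```

No edge list is copied and nothing is re-sorted inside the loop, so each pass is linear in E apart from the DSU operations.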

There is also an O(E log E) solution. The idea is the following: run the usual Kruskal's algorithm, processing the edges in groups of equal weight. If an edge (x, y) of the current group connects parts that are already connected, it can be neither critical nor pseudo-critical. Otherwise, we build a graph whose nodes are the connected components so far and whose edges are the remaining edges of the group. If there is more than one edge between two nodes, all of these edges are pseudo-critical. Finally, we find all bridges in this graph: they are the critical edges, and all the other edges are pseudo-critical.
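
As a side illustration of the bridge step, here is a hedged sketch of bridge search on a multigraph that skips the incoming edge by its id rather than by the parent node. With this variant, parallel edges between two components are never reported as bridges, so no separate count of parallel edges is needed. The function name find_bridges and the edge-id bookkeeping are assumptions for this example; the solution code takes a different route and checks the number of parallel edges after the search.

```python
from collections import defaultdict


def find_bridges(edge_list):
    """Return the ids of bridge edges in an undirected multigraph.

    edge_list[i] = (u, v); parallel edges are allowed.
    """
    adj = defaultdict(list)
    for eid, (u, v) in enumerate(edge_list):
        adj[u].append((v, eid))
        adj[v].append((u, eid))

    tin, low, bridges = {}, {}, []
    timer = 0

    def dfs(node, parent_eid):
        nonlocal timer
        tin[node] = low[node] = timer
        timer += 1
        for nxt, eid in adj[node]:
            if eid == parent_eid:
                continue  # skip only the edge we arrived by, not its parallels
            if nxt in tin:
                # Back edge (or parallel edge) to a visited node.
                low[node] = min(low[node], tin[nxt])
            else:
                dfs(nxt, eid)
                low[node] = min(low[node], low[nxt])
                if low[nxt] > tin[node]:
                    bridges.append(eid)

    for node in list(adj):
        if node not in tin:
            dfs(node, -1)
    return bridges
```

On the edge list [(0, 1), (0, 1), (1, 2)], the two parallel edges between 0 and 1 protect each other, and only the edge with id 2 is a bridge.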

Complexity

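It is O(E log E) for time and O(E) for space, provided the bridge search touches only the nodes of the current weight group. The code below calls bridges over all n nodes for every distinct weight, which adds O(n * W) for W distinct weights.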

Code

from collections import defaultdict

class DSU:
    def __init__(self, N):
        self.p = list(range(N))

    def find(self, x):
        if self.p[x] != x:
            self.p[x] = self.find(self.p[x])
        return self.p[x]

    def union(self, x, y):
        xr = self.find(x)
        yr = self.find(y)
        self.p[xr] = yr

class Solution:
    def bridges(self, n, G):
        used, tin, fup = [0]*n, [-1]*n, [-1]*n
        self.timer, ans = 0, []

        def dfs(node, par = -1):
            used[node] = 1
            tin[node] = fup[node] = self.timer + 1
            self.timer += 1
            for child in G[node]:
                if child == par: continue
                if used[child] == 1:
                    fup[node] = min(fup[node], tin[child])
                else:
                    dfs(child, node)
                    fup[node] = min(fup[node], fup[child])
                    if fup[child] > tin[node]: ans.append((node, child))

        for i in range(n):
            if not used[i]: dfs(i)

        return ans
    
    def findCriticalAndPseudoCriticalEdges(self, n, edges):
        E = {(x, y): i for i, (x, y, _) in enumerate(edges)}
        dic = defaultdict(list)
        
        for u, v, w in edges:
            dic[w].append([u, v])
        
        dsu, ans = DSU(n), [[], []]
        
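        for w in sorted(dic):
            # Group the edges of weight w by the pair of components they join.
            seen = defaultdict(list)
            for u, v in dic[w]:
                pu, pv = dsu.find(u), dsu.find(v)
                if pu == pv: continue  # already connected: the edge is in no MST
                seen[min(pu, pv), max(pu, pv)].append((u, v))

            # Every remaining edge is at least pseudo-critical.
            for e in seen:
                ans[1] += seen[e]

            # Component graph for this weight; parallel edges collapse into one.
            graph = defaultdict(list)
            for pu, pv in seen:
                graph[pu].append(pv)
                graph[pv].append(pu)
                dsu.union(pu, pv)

            # A bridge backed by exactly one original edge is critical.
            brid = self.bridges(n, graph)
            brid = [(min(x, y), max(x, y)) for x, y in brid]
            ans[0] += [seen[e][0] for e in brid if len(seen[e]) == 1]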
        
        return [[E[i] for i in ans[0]], [E[i] for i in set(ans[1]) - set(ans[0])]]