Problem statement

https://leetcode.com/problems/reachable-nodes-in-subdivided-graph/

Solution

We need shortest distances here, so Dijkstra's algorithm is a natural choice. First, let us forget about the split points and treat the graph as a weighted graph, where the weight of each edge is the number of split points on it plus one. On this graph we run Dijkstra's algorithm from node 0, which gives the distance from the origin to every original node. The total number of reachable points is then the sum of:

  1. All original nodes with distance less than or equal to M.
  2. For every edge, the split points reachable from its two ends. From an end at distance d <= M we can reach the first M - d split points counted from that end. If both end distances are greater than M, we cannot reach any split point on this edge. If both are less than or equal to M, we can reach some points starting from one end and some starting from the other; these two sets can overlap, so we subtract the number of points in the overlap. If one end has distance greater than M and the other has distance less than or equal to M, there is only one set of reachable points.
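
The per-edge rule can be sketched as a small helper. This is only an illustration: the name `reachable_on_edge` is hypothetical, and it uses `min(w, from_i + from_j)`, which is an equivalent way of writing "add both sides, then subtract the overlap". The distances are assumed to be already computed on the weighted graph.

```python
def reachable_on_edge(w, dist_i, dist_j, M):
    """Count split points reachable on an edge with w split points.

    dist_i and dist_j are the shortest distances to the two ends
    (float('inf') for an unreachable end is fine).
    """
    from_i = max(0, M - dist_i)  # points reachable starting from end i
    from_j = max(0, M - dist_j)  # points reachable starting from end j
    # The two sets can overlap, but an edge never holds more than w points.
    return min(w, from_i + from_j)

# Small example: three nodes, M = 6, distances taken as given.
edges = [(0, 1, 10), (0, 2, 1), (1, 2, 2)]
dist = [0, 5, 2]
M = 6
nodes = sum(d <= M for d in dist)
points = sum(reachable_on_edge(w, dist[i], dist[j], M) for i, j, w in edges)
print(nodes + points)  # 13
```

Here the edge with 10 split points contributes 6 + 1 = 7 points with no overlap, while the other two edges are fully covered and contribute 1 and 2.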

Complexity

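Time complexity is O(E log N), as for classical Dijkstra with a binary heap, plus O(E) for the final pass over the edges. Space complexity is O(N + E): the distance array takes O(N), while the adjacency lists and the heap can hold O(E) entries.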
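
As a rough illustration of the Dijkstra step on the weighted graph, here is a minimal sketch. The function name `weighted_distances` is hypothetical. It includes a guard that skips outdated heap entries, so each node is expanded only once with its final distance; this is a common way to keep heap-based Dijkstra near the stated bound.

```python
import heapq

def weighted_distances(edges, n):
    """Shortest distances from node 0 when an edge with w split points costs w + 1."""
    graph = [[] for _ in range(n)]
    for i, j, w in edges:
        graph[i].append((j, w + 1))
        graph[j].append((i, w + 1))

    dist = [float("inf")] * n
    dist[0] = 0
    heap = [(0, 0)]  # (distance, node)
    while heap:
        d, node = heapq.heappop(heap)
        if d > dist[node]:
            continue  # outdated entry: a shorter path was already found
        for nxt, weight in graph[node]:
            cand = d + weight
            if cand < dist[nxt]:
                dist[nxt] = cand
                heapq.heappush(heap, (cand, nxt))
    return dist

print(weighted_distances([(0, 1, 10), (0, 2, 1), (1, 2, 2)], 3))  # [0, 5, 2]
```

Node 1 is reached through node 2 (cost 2 + 3 = 5) rather than directly (cost 11), which is why the direct push for node 1 becomes an outdated entry.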

Code

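from collections import defaultdict
from heapq import heappop, heappush

class Solution:
    def reachableNodes(self, edges, M, N):
        # Weighted graph: an edge with w split points costs w + 1 moves.
        G = defaultdict(set)
        dist = [float('inf')] * N
        dist[0] = 0

        for i, j, w in edges:
            G[i].add((j, w + 1))
            G[j].add((i, w + 1))

        # Dijkstra from node 0 with a binary heap.
        heap = [(0, 0)]

        while heap:
            min_dist, idx = heappop(heap)
            if min_dist > dist[idx]: continue  # skip outdated heap entries
            for neibh, weight in G[idx]:
                cand = min_dist + weight
                if cand < dist[neibh]:
                    dist[neibh] = cand
                    heappush(heap, (cand, neibh))

        # Original nodes reachable within M moves.
        ans = sum(dist[i] <= M for i in range(N))

        # Split points reachable from each end of every edge.
        for i, j, w in edges:
            w1, w2 = M - dist[i], M - dist[j]
            ans += (max(0, w1) + max(0, w2))
            # Both ends reachable: subtract the points counted twice.
            if w1 >= 0 and w2 >= 0: ans -= max(w1 + w2 - w, 0)

        return ans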
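
One way to sanity-check the result on small inputs is a brute-force count that builds the subdivided graph explicitly and runs a breadth-first search limited to M moves. This is a sketch for testing only, not part of the solution above; the name `brute_force_reachable` is hypothetical, and the approach is far too slow for large numbers of split points.

```python
from collections import deque

def brute_force_reachable(edges, max_moves, n):
    """Subdivide every edge explicitly, then BFS from node 0 up to max_moves steps."""
    adj = {v: [] for v in range(n)}
    next_id = n  # ids of split points start after the original nodes
    for u, v, cnt in edges:
        # Chain u - p1 - p2 - ... - p_cnt - v of unit-length edges.
        chain = [u] + list(range(next_id, next_id + cnt)) + [v]
        next_id += cnt
        for a, b in zip(chain, chain[1:]):
            adj.setdefault(a, []).append(b)
            adj.setdefault(b, []).append(a)

    steps = {0: 0}  # node -> number of moves needed to reach it
    queue = deque([0])
    while queue:
        node = queue.popleft()
        if steps[node] == max_moves:
            continue  # no moves left to go further from here
        for nxt in adj[node]:
            if nxt not in steps:
                steps[nxt] = steps[node] + 1
                queue.append(nxt)
    return len(steps)

print(brute_force_reachable([[0, 1, 10], [0, 2, 1], [1, 2, 2]], 6, 3))             # 13
print(brute_force_reachable([[0, 1, 4], [1, 2, 6], [0, 2, 8], [1, 3, 1]], 10, 4))  # 23
```

Comparing this count with the Dijkstra-based answer on small random graphs is a cheap way to catch mistakes in the overlap handling.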