diff --git a/src/taskgraph/graph.py b/src/taskgraph/graph.py index 573275087..a521cdfc4 100644 --- a/src/taskgraph/graph.py +++ b/src/taskgraph/graph.py @@ -61,21 +61,29 @@ def transitive_closure(self, nodes, reverse=False): f"Unknown nodes in transitive closure: {nodes - self.nodes}" ) - # generate a new graph by expanding along edges until reaching a fixed - # point - new_nodes, new_edges = nodes, set() - nodes, edges = set(), set() - while (new_nodes, new_edges) != (nodes, edges): - nodes, edges = new_nodes, new_edges - add_edges = { - (left, right, name) - for (left, right, name) in self.edges - if (right if reverse else left) in nodes - } - add_nodes = {(left if reverse else right) for (left, right, _) in add_edges} - new_nodes = nodes | add_nodes - new_edges = edges | add_edges - return Graph(new_nodes, new_edges) + # Build an adjacency list once, then BFS — O(Vertices + Edges) + adjacency = collections.defaultdict(set) + for left, right, _name in self.edges: + if reverse: + adjacency[right].add(left) + else: + adjacency[left].add(right) + + result_nodes = set(nodes) + queue = collections.deque(nodes) + while queue: + node = queue.popleft() + for neighbor in adjacency.get(node, ()): + if neighbor not in result_nodes: + result_nodes.add(neighbor) + queue.append(neighbor) + + result_edges = frozenset( + (left, right, name) + for left, right, name in self.edges + if left in result_nodes and right in result_nodes + ) + return Graph(frozenset(result_nodes), result_edges) def _visit(self, reverse): forward_links, reverse_links = self.links_and_reverse_links_dict()