diff --git a/compiler/rustc_data_structures/src/graph/iterate/mod.rs b/compiler/rustc_data_structures/src/graph/iterate/mod.rs index 09b91083a63..a9db3497b23 100644 --- a/compiler/rustc_data_structures/src/graph/iterate/mod.rs +++ b/compiler/rustc_data_structures/src/graph/iterate/mod.rs @@ -83,8 +83,58 @@ impl DepthFirstSearch<'graph, G> where G: ?Sized + DirectedGraph + WithNumNodes + WithSuccessors, { - pub fn new(graph: &'graph G, start_node: G::Node) -> Self { - Self { graph, stack: vec![start_node], visited: BitSet::new_empty(graph.num_nodes()) } + pub fn new(graph: &'graph G) -> Self { + Self { graph, stack: vec![], visited: BitSet::new_empty(graph.num_nodes()) } + } + + /// Version of `push_start_node` that is convenient for chained + /// use. + pub fn with_start_node(mut self, start_node: G::Node) -> Self { + self.push_start_node(start_node); + self + } + + /// Pushes another start node onto the stack. If the node + /// has not already been visited, then you will be able to + /// walk its successors (and so forth) after the current + /// contents of the stack are drained. If multiple start nodes + /// are added into the walk, then their mutual successors + /// will all be walked. You can use this method once the + /// iterator has been completely drained to add additional + /// start nodes. + pub fn push_start_node(&mut self, start_node: G::Node) { + if self.visited.insert(start_node) { + self.stack.push(start_node); + } + } + + /// Searches all nodes reachable from the current start nodes. + /// This is equivalent to just invoke `next` repeatedly until + /// you get a `None` result. + pub fn complete_search(&mut self) { + while let Some(_) = self.next() {} + } + + /// Returns true if node has been visited thus far. + /// A node is considered "visited" once it is pushed + /// onto the internal stack; it may not yet have been yielded + /// from the iterator. This method is best used after + /// the iterator is completely drained. + pub fn visited(&self, node: G::Node) -> bool { + self.visited.contains(node) + } +} + +impl std::fmt::Debug for DepthFirstSearch<'_, G> +where + G: ?Sized + DirectedGraph + WithNumNodes + WithSuccessors, +{ + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut f = fmt.debug_set(); + for n in self.visited.iter() { + f.entry(&n); + } + f.finish() } } diff --git a/compiler/rustc_data_structures/src/graph/iterate/tests.rs b/compiler/rustc_data_structures/src/graph/iterate/tests.rs index 0e038e88b22..c498c289337 100644 --- a/compiler/rustc_data_structures/src/graph/iterate/tests.rs +++ b/compiler/rustc_data_structures/src/graph/iterate/tests.rs @@ -20,3 +20,19 @@ fn is_cyclic() { assert!(!is_cyclic(&diamond_acyclic)); assert!(is_cyclic(&diamond_cyclic)); } + +#[test] +fn dfs() { + let graph = TestGraph::new(0, &[(0, 1), (0, 2), (1, 3), (2, 3), (3, 0)]); + + let result: Vec = DepthFirstSearch::new(&graph).with_start_node(0).collect(); + assert_eq!(result, vec![0, 2, 3, 1]); +} + +#[test] +fn dfs_debug() { + let graph = TestGraph::new(0, &[(0, 1), (0, 2), (1, 3), (2, 3), (3, 0)]); + let mut dfs = DepthFirstSearch::new(&graph).with_start_node(0); + dfs.complete_search(); + assert_eq!(format!("{{0, 1, 2, 3}}"), format!("{:?}", dfs)); +} diff --git a/compiler/rustc_data_structures/src/graph/mod.rs b/compiler/rustc_data_structures/src/graph/mod.rs index dff22855629..3560df6e5e2 100644 --- a/compiler/rustc_data_structures/src/graph/mod.rs +++ b/compiler/rustc_data_structures/src/graph/mod.rs @@ -32,7 +32,7 @@ where where Self: WithNumNodes, { - iterate::DepthFirstSearch::new(self, from) + iterate::DepthFirstSearch::new(self).with_start_node(from) } }