1use std::collections::BTreeMap;
4
/// Topologically sorts every node reachable from `node_ids`, walking edges
/// backwards through `preds_fn`.
///
/// `preds_fn(id)` must yield the direct predecessors of `id`, i.e. the sources
/// of the edges that end at `id`. It is called exactly once for every node the
/// traversal reaches, and its iterator is consumed lazily, one predecessor at a
/// time, while the search descends.
///
/// Nodes yielded by `preds_fn` are visited and emitted even when they are not
/// listed in `node_ids`. Ids listed more than once are emitted only once.
///
/// # Returns
///
/// * `Ok(order)` — every visited node, each one placed after all of its
///   predecessors (depth-first post-order over predecessors).
/// * `Err(cycle)` — the first cycle encountered, listed in edge direction:
///   each element is a predecessor of the next one, and the last element is a
///   predecessor of the first. A self-loop is reported as a one-element cycle.
///
/// The search keeps its own stack on the heap instead of recursing, so long
/// predecessor chains cannot overflow the call stack. Apart from the cost of
/// `preds_fn` itself, it performs one map lookup per examined edge and a few
/// map updates per node.
pub fn topo_sort<Id, NodeIds, PredsFn, PredsIter>(
    node_ids: NodeIds,
    mut preds_fn: PredsFn,
) -> Result<Vec<Id>, Vec<Id>>
where
    Id: Copy + Eq + Ord,
    NodeIds: IntoIterator<Item = Id>,
    PredsFn: FnMut(Id) -> PredsIter,
    PredsIter: IntoIterator<Item = Id>,
{
    /// DFS state of a node that has been reached at least once.
    #[derive(Clone, Copy)]
    enum Mark {
        /// On the current DFS path: some of its predecessors are still being explored.
        OnPath,
        /// The node and all of its transitive predecessors are already in `order`.
        Done,
    }

    let mut marks: BTreeMap<Id, Mark> = BTreeMap::new();
    let mut order: Vec<Id> = Vec::new();
    // Explicit DFS stack. Each entry is a node on the current path, paired with
    // the predecessors of that node that have not been examined yet.
    let mut path: Vec<(Id, PredsIter::IntoIter)> = Vec::new();

    for root in node_ids {
        if marks.contains_key(&root) {
            // Already emitted by an earlier traversal, or listed twice.
            continue;
        }
        marks.insert(root, Mark::OnPath);
        path.push((root, preds_fn(root).into_iter()));

        while let Some((node, preds)) = path.last_mut() {
            match preds.next() {
                None => {
                    // All predecessors of `node` are in `order`, so it may follow them.
                    let node = *node;
                    path.pop();
                    marks.insert(node, Mark::Done);
                    order.push(node);
                }
                Some(pred) => match marks.get(&pred).copied() {
                    Some(Mark::Done) => {}
                    Some(Mark::OnPath) => {
                        // `pred` is an ancestor on the current path. The path from `pred`
                        // down to the top of the stack, closed by the edge top -> pred,
                        // is a cycle. The stack runs against edge direction (each entry
                        // is a predecessor of the one below it), so reverse it.
                        let start = path
                            .iter()
                            .position(|(id, _)| *id == pred)
                            .expect("every node marked OnPath is on the DFS path");
                        return Err(path[start..].iter().rev().map(|(id, _)| *id).collect());
                    }
                    None => {
                        marks.insert(pred, Mark::OnPath);
                        path.push((pred, preds_fn(pred).into_iter()));
                    }
                },
            }
        }
    }

    Ok(order)
}
73
#[cfg(test)]
mod test {
    use std::collections::{BTreeMap, BTreeSet};

    use itertools::Itertools;

    use super::*;

    /// An acyclic graph must sort successfully, contain every node exactly once,
    /// and place every edge source before its destination.
    #[test]
    fn test_toposort() {
        let edges = [
            (5, 11),
            (11, 2),
            (11, 9),
            (11, 10),
            (7, 11),
            (7, 8),
            (8, 9),
            (3, 8),
            (3, 10),
        ];

        // The predecessors of `v` are the sources of the edges that end at `v`.
        let sort = topo_sort([2, 3, 5, 7, 8, 9, 10, 11], |v| {
            edges
                .iter()
                .filter(move |&&(_, dst)| v == dst)
                .map(|&(src, _)| src)
        })
        .unwrap_or_else(|cycle| panic!("Did not expect cycle: {:?}", cycle));

        assert_eq!(sort.len(), 8);
        let position: BTreeMap<_, _> = sort.iter().enumerate().map(|(i, &x)| (x, i)).collect();
        assert_eq!(position.len(), sort.len(), "duplicate node in {:?}", sort);
        for (src, dst) in edges.iter() {
            assert!(
                position[src] < position[dst],
                "edge {} -> {} out of order in {:?}",
                src,
                dst,
                sort
            );
        }
    }

    /// Whatever order the ids are supplied in, the graph's only cycle
    /// (B -> C -> E -> D -> B) must be reported as one of its rotations.
    #[test]
    fn test_toposort_cycle() {
        let edges = [
            ('A', 'B'),
            ('B', 'C'),
            ('C', 'E'),
            ('D', 'B'),
            ('E', 'F'),
            ('E', 'D'),
        ];
        let ids = edges
            .iter()
            .flat_map(|&(a, b)| [a, b])
            .collect::<BTreeSet<_>>();
        let cycle_rotations = BTreeSet::from_iter([
            ['B', 'C', 'E', 'D'],
            ['C', 'E', 'D', 'B'],
            ['E', 'D', 'B', 'C'],
            ['D', 'B', 'C', 'E'],
        ]);

        for permutation in ids.iter().copied().permutations(ids.len()) {
            let result = topo_sort(permutation.iter().copied(), |v| {
                edges
                    .iter()
                    .filter(move |&&(_, dst)| v == dst)
                    .map(|&(src, _)| src)
            });
            let cycle = result.expect_err("graph contains a cycle");
            assert!(
                cycle_rotations.contains(&*cycle),
                "cycle: {:?}, vertex order: {:?}",
                cycle,
                permutation
            );
        }
    }

    /// An edge from a node to itself is a cycle of length one.
    #[test]
    fn test_toposort_self_loop() {
        let result = topo_sort([1u32, 2], |v| if v == 1 { vec![1] } else { vec![] });
        assert_eq!(result, Err(vec![1]));
    }

    /// No input ids yields an empty order.
    #[test]
    fn test_toposort_empty() {
        let result = topo_sort(Vec::<u32>::new(), |_| Vec::<u32>::new());
        assert_eq!(result, Ok(vec![]));
    }

    /// Ids listed several times are still emitted exactly once.
    #[test]
    fn test_toposort_duplicate_ids() {
        let result = topo_sort([2u32, 1, 2, 1], |v| if v == 2 { vec![1] } else { vec![] });
        assert_eq!(result, Ok(vec![1, 2]));
    }
}