xref: /llvm-project/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp (revision 6df7cc7f47d280d550f41fc167bdd75fea726a06)
1 //===- RootOrdering.cpp - Optimal root ordering ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // An implementation of Edmonds' optimal branching algorithm. This is a
10 // directed analogue of the minimum spanning tree problem for a given root.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "RootOrdering.h"
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/DenseSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include <queue>
20 #include <utility>
21 
22 using namespace mlir;
23 using namespace mlir::pdl_to_pdl_interp;
24 
25 /// Returns the cycle implied by the specified parent relation, starting at the
26 /// given node.
27 static SmallVector<Value> getCycle(const DenseMap<Value, Value> &parents,
28                                    Value rep) {
29   SmallVector<Value> cycle;
30   Value node = rep;
31   do {
32     cycle.push_back(node);
33     node = parents.lookup(node);
34     assert(node && "got an empty value in the cycle");
35   } while (node != rep);
36   return cycle;
37 }
38 
39 /// Contracts the specified cycle in the given graph in-place.
40 /// The parentsCost map specifies, for each node in the cycle, the lowest cost
41 /// among the edges entering that node. Then, the nodes in the cycle C are
42 /// replaced with a single node v_C (the first node in the cycle). All edges
43 /// (u, v) entering the cycle, v \in C, are replaced with a single edge
44 /// (u, v_C) with an appropriately chosen cost, and the selected node v is
45 /// marked in the output map actualTarget[u]. All edges (u, v) leaving the
46 /// cycle, u \in C, are replaced with a single edge (v_C, v), and the selected
47 /// node u is marked in the ouptut map actualSource[v].
48 static void contract(RootOrderingGraph &graph, ArrayRef<Value> cycle,
49                      const DenseMap<Value, unsigned> &parentCosts,
50                      DenseMap<Value, Value> &actualSource,
51                      DenseMap<Value, Value> &actualTarget) {
52   Value rep = cycle.front();
53   DenseSet<Value> cycleSet(cycle.begin(), cycle.end());
54 
55   // Now, contract the cycle, marking the actual sources and targets.
56   DenseMap<Value, RootOrderingCost> repCosts;
57   for (auto outer = graph.begin(), e = graph.end(); outer != e; ++outer) {
58     Value target = outer->first;
59     if (cycleSet.contains(target)) {
60       // Target in the cycle => edges incoming to the cycle or within the cycle.
61       unsigned parentCost = parentCosts.lookup(target);
62       for (const auto &inner : outer->second) {
63         Value source = inner.first;
64         // Ignore edges within the cycle.
65         if (cycleSet.contains(source))
66           continue;
67 
68         // Edge incoming to the cycle.
69         std::pair<unsigned, unsigned> cost = inner.second.cost;
70         assert(parentCost <= cost.first && "invalid parent cost");
71 
72         // Subtract the cost of the parent within the cycle from the cost of
73         // the edge incoming to the cycle. This update ensures that the cost
74         // of the minimum-weight spanning arborescence of the entire graph is
75         // the cost of arborescence for the contracted graph plus the cost of
76         // the cycle, no matter which edge in the cycle we choose to drop.
77         cost.first -= parentCost;
78         auto it = repCosts.find(source);
79         if (it == repCosts.end() || it->second.cost > cost) {
80           actualTarget[source] = target;
81           // Do not bother populating the connector (the connector is only
82           // relevant for the final traversal, not for the optimal branching).
83           repCosts[source].cost = cost;
84         }
85       }
86       // Erase the node in the cycle.
87       graph.erase(outer);
88     } else {
89       // Target not in cycle => edges going away from or unrelated to the cycle.
90       DenseMap<Value, RootOrderingCost> &costs = outer->second;
91       Value bestSource;
92       std::pair<unsigned, unsigned> bestCost;
93       auto inner = costs.begin(), inner_e = costs.end();
94       while (inner != inner_e) {
95         Value source = inner->first;
96         if (cycleSet.contains(source)) {
97           // Going-away edge => get its cost and erase it.
98           if (!bestSource || bestCost > inner->second.cost) {
99             bestSource = source;
100             bestCost = inner->second.cost;
101           }
102           costs.erase(inner++);
103         } else {
104           ++inner;
105         }
106       }
107 
108       // There were going-away edges, contract them.
109       if (bestSource) {
110         costs[rep].cost = bestCost;
111         actualSource[target] = bestSource;
112       }
113     }
114   }
115 
116   // Store the edges to the representative.
117   graph[rep] = std::move(repCosts);
118 }
119 
120 OptimalBranching::OptimalBranching(RootOrderingGraph graph, Value root)
121     : graph(std::move(graph)), root(root) {}
122 
123 unsigned OptimalBranching::solve() {
124   // Initialize the parents and total cost.
125   parents.clear();
126   parents[root] = Value();
127   unsigned totalCost = 0;
128 
129   // A map that stores the cost of the optimal local choice for each node
130   // in a directed cycle. This map is cleared every time we seed the search.
131   DenseMap<Value, unsigned> parentCosts;
132   parentCosts.reserve(graph.size());
133 
134   // Determine if the optimal local choice results in an acyclic graph. This is
135   // done by computing the optimal local choice and traversing up the computed
136   // parents. On success, `parents` will contain the parent of each node.
137   for (const auto &outer : graph) {
138     Value node = outer.first;
139     if (parents.count(node)) // already visited
140       continue;
141 
142     // Follow the trail of best sources until we reach an already visited node.
143     // The code will assert if we cannot reach an already visited node, i.e.,
144     // the graph is not strongly connected.
145     parentCosts.clear();
146     do {
147       auto it = graph.find(node);
148       assert(it != graph.end() && "the graph is not strongly connected");
149 
150       Value &bestSource = parents[node];
151       unsigned &bestCost = parentCosts[node];
152       for (const auto &inner : it->second) {
153         const RootOrderingCost &cost = inner.second;
154         if (!bestSource /* initial */ || bestCost > cost.cost.first) {
155           bestSource = inner.first;
156           bestCost = cost.cost.first;
157         }
158       }
159       assert(bestSource && "the graph is not strongly connected");
160       node = bestSource;
161       totalCost += bestCost;
162     } while (!parents.count(node));
163 
164     // If we reached a non-root node, we have a cycle.
165     if (parentCosts.count(node)) {
166       // Determine the cycle starting at the representative node.
167       SmallVector<Value> cycle = getCycle(parents, node);
168 
169       // The following maps disambiguate the source / target of the edges
170       // going out of / into the cycle.
171       DenseMap<Value, Value> actualSource, actualTarget;
172 
173       // Contract the cycle and recurse.
174       contract(graph, cycle, parentCosts, actualSource, actualTarget);
175       totalCost = solve();
176 
177       // Redirect the going-away edges.
178       for (auto &p : parents)
179         if (p.second == node)
180           // The parent is the node representating the cycle; replace it
181           // with the actual (best) source in the cycle.
182           p.second = actualSource.lookup(p.first);
183 
184       // Redirect the unique incoming edge and copy the cycle.
185       Value parent = parents.lookup(node);
186       Value entry = actualTarget.lookup(parent);
187       cycle.push_back(node); // complete the cycle
188       for (size_t i = 0, e = cycle.size() - 1; i < e; ++i) {
189         totalCost += parentCosts.lookup(cycle[i]);
190         if (cycle[i] == entry)
191           parents[cycle[i]] = parent; // break the cycle
192         else
193           parents[cycle[i]] = cycle[i + 1];
194       }
195 
196       // `parents` has a complete solution.
197       break;
198     }
199   }
200 
201   return totalCost;
202 }
203 
204 OptimalBranching::EdgeList
205 OptimalBranching::preOrderTraversal(ArrayRef<Value> nodes) const {
206   // Invert the parent mapping.
207   DenseMap<Value, std::vector<Value>> children;
208   for (Value node : nodes) {
209     if (node != root) {
210       Value parent = parents.lookup(node);
211       assert(parent && "invalid parent");
212       children[parent].push_back(node);
213     }
214   }
215 
216   // The result which simultaneously acts as a queue.
217   EdgeList result;
218   result.reserve(nodes.size());
219   result.emplace_back(root, Value());
220 
221   // Perform a BFS, pushing into the queue.
222   for (size_t i = 0; i < result.size(); ++i) {
223     Value node = result[i].first;
224     for (Value child : children[node])
225       result.emplace_back(child, node);
226   }
227 
228   return result;
229 }
230