1 //===- RootOrdering.cpp - Optimal root ordering ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // An implementation of Edmonds' optimal branching algorithm. This is a 10 // directed analogue of the minimum spanning tree problem for a given root. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "RootOrdering.h" 15 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/DenseSet.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include <queue> 20 #include <utility> 21 22 using namespace mlir; 23 using namespace mlir::pdl_to_pdl_interp; 24 25 /// Returns the cycle implied by the specified parent relation, starting at the 26 /// given node. 27 static SmallVector<Value> getCycle(const DenseMap<Value, Value> &parents, 28 Value rep) { 29 SmallVector<Value> cycle; 30 Value node = rep; 31 do { 32 cycle.push_back(node); 33 node = parents.lookup(node); 34 assert(node && "got an empty value in the cycle"); 35 } while (node != rep); 36 return cycle; 37 } 38 39 /// Contracts the specified cycle in the given graph in-place. 40 /// The parentsCost map specifies, for each node in the cycle, the lowest cost 41 /// among the edges entering that node. Then, the nodes in the cycle C are 42 /// replaced with a single node v_C (the first node in the cycle). All edges 43 /// (u, v) entering the cycle, v \in C, are replaced with a single edge 44 /// (u, v_C) with an appropriately chosen cost, and the selected node v is 45 /// marked in the output map actualTarget[u]. All edges (u, v) leaving the 46 /// cycle, u \in C, are replaced with a single edge (v_C, v), and the selected 47 /// node u is marked in the ouptut map actualSource[v]. 48 static void contract(RootOrderingGraph &graph, ArrayRef<Value> cycle, 49 const DenseMap<Value, unsigned> &parentCosts, 50 DenseMap<Value, Value> &actualSource, 51 DenseMap<Value, Value> &actualTarget) { 52 Value rep = cycle.front(); 53 DenseSet<Value> cycleSet(cycle.begin(), cycle.end()); 54 55 // Now, contract the cycle, marking the actual sources and targets. 56 DenseMap<Value, RootOrderingCost> repCosts; 57 for (auto outer = graph.begin(), e = graph.end(); outer != e; ++outer) { 58 Value target = outer->first; 59 if (cycleSet.contains(target)) { 60 // Target in the cycle => edges incoming to the cycle or within the cycle. 61 unsigned parentCost = parentCosts.lookup(target); 62 for (const auto &inner : outer->second) { 63 Value source = inner.first; 64 // Ignore edges within the cycle. 65 if (cycleSet.contains(source)) 66 continue; 67 68 // Edge incoming to the cycle. 69 std::pair<unsigned, unsigned> cost = inner.second.cost; 70 assert(parentCost <= cost.first && "invalid parent cost"); 71 72 // Subtract the cost of the parent within the cycle from the cost of 73 // the edge incoming to the cycle. This update ensures that the cost 74 // of the minimum-weight spanning arborescence of the entire graph is 75 // the cost of arborescence for the contracted graph plus the cost of 76 // the cycle, no matter which edge in the cycle we choose to drop. 77 cost.first -= parentCost; 78 auto it = repCosts.find(source); 79 if (it == repCosts.end() || it->second.cost > cost) { 80 actualTarget[source] = target; 81 // Do not bother populating the connector (the connector is only 82 // relevant for the final traversal, not for the optimal branching). 83 repCosts[source].cost = cost; 84 } 85 } 86 // Erase the node in the cycle. 87 graph.erase(outer); 88 } else { 89 // Target not in cycle => edges going away from or unrelated to the cycle. 90 DenseMap<Value, RootOrderingCost> &costs = outer->second; 91 Value bestSource; 92 std::pair<unsigned, unsigned> bestCost; 93 auto inner = costs.begin(), inner_e = costs.end(); 94 while (inner != inner_e) { 95 Value source = inner->first; 96 if (cycleSet.contains(source)) { 97 // Going-away edge => get its cost and erase it. 98 if (!bestSource || bestCost > inner->second.cost) { 99 bestSource = source; 100 bestCost = inner->second.cost; 101 } 102 costs.erase(inner++); 103 } else { 104 ++inner; 105 } 106 } 107 108 // There were going-away edges, contract them. 109 if (bestSource) { 110 costs[rep].cost = bestCost; 111 actualSource[target] = bestSource; 112 } 113 } 114 } 115 116 // Store the edges to the representative. 117 graph[rep] = std::move(repCosts); 118 } 119 120 OptimalBranching::OptimalBranching(RootOrderingGraph graph, Value root) 121 : graph(std::move(graph)), root(root) {} 122 123 unsigned OptimalBranching::solve() { 124 // Initialize the parents and total cost. 125 parents.clear(); 126 parents[root] = Value(); 127 unsigned totalCost = 0; 128 129 // A map that stores the cost of the optimal local choice for each node 130 // in a directed cycle. This map is cleared every time we seed the search. 131 DenseMap<Value, unsigned> parentCosts; 132 parentCosts.reserve(graph.size()); 133 134 // Determine if the optimal local choice results in an acyclic graph. This is 135 // done by computing the optimal local choice and traversing up the computed 136 // parents. On success, `parents` will contain the parent of each node. 137 for (const auto &outer : graph) { 138 Value node = outer.first; 139 if (parents.count(node)) // already visited 140 continue; 141 142 // Follow the trail of best sources until we reach an already visited node. 143 // The code will assert if we cannot reach an already visited node, i.e., 144 // the graph is not strongly connected. 145 parentCosts.clear(); 146 do { 147 auto it = graph.find(node); 148 assert(it != graph.end() && "the graph is not strongly connected"); 149 150 Value &bestSource = parents[node]; 151 unsigned &bestCost = parentCosts[node]; 152 for (const auto &inner : it->second) { 153 const RootOrderingCost &cost = inner.second; 154 if (!bestSource /* initial */ || bestCost > cost.cost.first) { 155 bestSource = inner.first; 156 bestCost = cost.cost.first; 157 } 158 } 159 assert(bestSource && "the graph is not strongly connected"); 160 node = bestSource; 161 totalCost += bestCost; 162 } while (!parents.count(node)); 163 164 // If we reached a non-root node, we have a cycle. 165 if (parentCosts.count(node)) { 166 // Determine the cycle starting at the representative node. 167 SmallVector<Value> cycle = getCycle(parents, node); 168 169 // The following maps disambiguate the source / target of the edges 170 // going out of / into the cycle. 171 DenseMap<Value, Value> actualSource, actualTarget; 172 173 // Contract the cycle and recurse. 174 contract(graph, cycle, parentCosts, actualSource, actualTarget); 175 totalCost = solve(); 176 177 // Redirect the going-away edges. 178 for (auto &p : parents) 179 if (p.second == node) 180 // The parent is the node representating the cycle; replace it 181 // with the actual (best) source in the cycle. 182 p.second = actualSource.lookup(p.first); 183 184 // Redirect the unique incoming edge and copy the cycle. 185 Value parent = parents.lookup(node); 186 Value entry = actualTarget.lookup(parent); 187 cycle.push_back(node); // complete the cycle 188 for (size_t i = 0, e = cycle.size() - 1; i < e; ++i) { 189 totalCost += parentCosts.lookup(cycle[i]); 190 if (cycle[i] == entry) 191 parents[cycle[i]] = parent; // break the cycle 192 else 193 parents[cycle[i]] = cycle[i + 1]; 194 } 195 196 // `parents` has a complete solution. 197 break; 198 } 199 } 200 201 return totalCost; 202 } 203 204 OptimalBranching::EdgeList 205 OptimalBranching::preOrderTraversal(ArrayRef<Value> nodes) const { 206 // Invert the parent mapping. 207 DenseMap<Value, std::vector<Value>> children; 208 for (Value node : nodes) { 209 if (node != root) { 210 Value parent = parents.lookup(node); 211 assert(parent && "invalid parent"); 212 children[parent].push_back(node); 213 } 214 } 215 216 // The result which simultaneously acts as a queue. 217 EdgeList result; 218 result.reserve(nodes.size()); 219 result.emplace_back(root, Value()); 220 221 // Perform a BFS, pushing into the queue. 222 for (size_t i = 0; i < result.size(); ++i) { 223 Value node = result[i].first; 224 for (Value child : children[node]) 225 result.emplace_back(child, node); 226 } 227 228 return result; 229 } 230