//===- RootOrdering.h - Optimal root ordering ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definition of a cost graph over candidate roots and
// an implementation of an algorithm to determine the optimal ordering over
// these roots. Each edge in this graph indicates that the target root can be
// connected (via a chain of positions) to the source root, and its cost
// indicates the estimated cost of such a traversal. The optimal root ordering
// problem is then formulated as that of finding a spanning arborescence (i.e.,
// a directed spanning tree) of minimal weight.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_
#define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_

#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <functional>
#include <vector>

namespace mlir {
namespace pdl_to_pdl_interp {

/// The information associated with an edge in the cost graph. Each node in
/// the cost graph corresponds to a candidate root detected in the pdl.pattern,
/// and each edge in the cost graph corresponds to connecting the two candidate
/// roots via a chain of operations. The cost of an edge is the smallest number
/// of upward traversals required to go from the source to the target root, and
/// the connector is a `Value` in the intersection of the two subtrees rooted at
/// the source and target root that results in that smallest number of upward
/// traversals. Consider the following pattern with 3 roots op3, op4, and op5:
///
///   argA ---> op1 ---> op2 ---> op3 ---> res3
///              ^        ^
///              |        |
///             argB     argC
///              |        |
///              v        v
///   res4 <--- op4      op5 ---> res5
///              ^        ^
///              |        |
///             op6      op7
///
/// The cost of the edge op3 -> op4 is 1 (the upward traversal argB -> op4),
/// with argB being the connector `Value` and similarly for op3 -> op5 (cost 1,
/// connector argC). The cost of the edge op4 -> op3 is 3 (upward traversals
/// argB -> op1 -> op2 -> op3, connector argB), while the cost of edge op5 ->
/// op3 is 2 (upward traversals argC -> op2 -> op3, connector argC). There are
/// no edges between op4 and op5 in the cost graph, because the subtrees rooted
/// at these two roots do not intersect. It is easy to see that the optimal root
/// for this pattern is op3, resulting in the spanning arborescence
/// op3 -> {op4, op5}.
struct RootOrderingEntry {
  /// The depth of the connector `Value` w.r.t. the target root.
  ///
  /// This is a pair where the first value is the additive cost (the depth of
  /// the connector), and the second value is a priority for breaking ties
  /// (with 0 being the highest). Typically, the priority is a unique edge ID.
  std::pair<unsigned, unsigned> cost;

  /// The connector value in the intersection of the two subtrees rooted at
  /// the source and target root that results in the smallest depth w.r.t.
  /// the target root.
  Value connector;
};

/// A directed graph representing the cost of ordering the roots in the
/// predicate tree. It is represented as an adjacency map, where the outer map
/// is indexed by the target node, and the inner map is indexed by the source
/// node. Each edge is associated with a cost and the underlying connector
/// value (see `RootOrderingEntry`).
using RootOrderingGraph = DenseMap<Value, DenseMap<Value, RootOrderingEntry>>;

/// The optimal branching algorithm solver.
/// This solver accepts a graph and the root in its constructor, and is invoked
/// via the solve() member function. This is a direct implementation of
/// Edmonds' algorithm, see https://en.wikipedia.org/wiki/Edmonds%27_algorithm.
/// The worst-case computational complexity of this algorithm is O(N^3), for a
/// single root. The PDL-to-PDLInterp lowering calls this N times (once for each
/// candidate root), so the overall complexity of root ordering is O(N^4). If
/// needed, this could be reduced to O(N^3) with a more efficient algorithm.
/// However, note that the underlying implementation is very efficient, and N
/// in our instances tends to be very small (<10).
class OptimalBranching {
public:
  /// A list of edges (child, parent). Note the order within each pair: the
  /// child comes first, the parent second.
  using EdgeList = std::vector<std::pair<Value, Value>>;

  /// Constructs the solver for the given graph and root value. The `graph` is
  /// passed by value.
  OptimalBranching(RootOrderingGraph graph, Value root);

  /// Runs Edmonds' algorithm for the current `graph`, returning the total
  /// cost of the minimum-weight spanning arborescence (sum of the edge costs).
  /// This function first determines the optimal local choice of the parents
  /// and stores this choice in the `parents` mapping. If this choice results
  /// in an acyclic graph, the function returns immediately. Otherwise, it
  /// takes an arbitrary cycle, contracts it, and recurses on the new graph
  /// (which is guaranteed to have fewer nodes than we began with). After we
  /// return from recursion, we redirect the edges to/from the contracted node,
  /// so the `parents` map contains a valid solution for the current graph.
  unsigned solve();

  /// Returns the computed parent map. This is the unique predecessor for each
  /// node (root) in the optimal branching. The result is a reference to the
  /// solver's own `parents` member.
  const DenseMap<Value, Value> &getRootOrderingParents() const {
    return parents;
  }

  /// Returns the computed edges as visited in the preorder traversal.
  /// The specified array determines the order for breaking any ties.
  /// Each returned edge is a (child, parent) pair, as described by `EdgeList`.
  EdgeList preOrderTraversal(ArrayRef<Value> nodes) const;

private:
  /// The graph whose optimal branching we wish to determine.
  RootOrderingGraph graph;

  /// The root of the optimal branching.
  Value root;

  /// The computed parent mapping. This is the unique predecessor for each node
  /// in the optimal branching. The keys of this map correspond to the keys of
  /// the outer map of the input graph, and each value is one of the keys of
  /// the inner map for this node. Also used as an intermediate (possibly
  /// cyclical) result in the optimal branching algorithm.
  DenseMap<Value, Value> parents;
};

} // namespace pdl_to_pdl_interp
} // namespace mlir

#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_