1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Reduction Tree Pass class. It provides a framework for 10 // the implementation of different reduction passes in the MLIR Reduce tool. It 11 // allows for custom specification of the variant generation behavior. It 12 // implements methods that define the different possible traversals of the 13 // reduction tree. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/IR/DialectInterface.h" 18 #include "mlir/IR/OpDefinition.h" 19 #include "mlir/Reducer/Passes.h" 20 #include "mlir/Reducer/ReductionNode.h" 21 #include "mlir/Reducer/ReductionPatternInterface.h" 22 #include "mlir/Reducer/Tester.h" 23 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 26 #include "llvm/ADT/ArrayRef.h" 27 #include "llvm/ADT/SmallVector.h" 28 #include "llvm/Support/Allocator.h" 29 #include "llvm/Support/ManagedStatic.h" 30 31 namespace mlir { 32 #define GEN_PASS_DEF_REDUCTIONTREE 33 #include "mlir/Reducer/Passes.h.inc" 34 } // namespace mlir 35 36 using namespace mlir; 37 38 /// We implicitly number each operation in the region and if an operation's 39 /// number falls into rangeToKeep, we need to keep it and apply the given 40 /// rewrite patterns on it. 41 static void applyPatterns(Region ®ion, 42 const FrozenRewritePatternSet &patterns, 43 ArrayRef<ReductionNode::Range> rangeToKeep, 44 bool eraseOpNotInRange) { 45 std::vector<Operation *> opsNotInRange; 46 std::vector<Operation *> opsInRange; 47 size_t keepIndex = 0; 48 for (const auto &op : enumerate(region.getOps())) { 49 int index = op.index(); 50 if (keepIndex < rangeToKeep.size() && 51 index == rangeToKeep[keepIndex].second) 52 ++keepIndex; 53 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first) 54 opsNotInRange.push_back(&op.value()); 55 else 56 opsInRange.push_back(&op.value()); 57 } 58 59 // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern 60 // matching in above iteration. Besides, erase op not-in-range may end up in 61 // invalid module, so `applyOpPatternsAndFold` should come before that 62 // transform. 63 for (Operation *op : opsInRange) { 64 // `applyOpPatternsAndFold` returns whether the op is convered. Omit it 65 // because we don't have expectation this reduction will be success or not. 66 GreedyRewriteConfig config; 67 config.strictMode = GreedyRewriteStrictness::ExistingOps; 68 (void)applyOpPatternsGreedily(op, patterns, config); 69 } 70 71 if (eraseOpNotInRange) 72 for (Operation *op : opsNotInRange) { 73 op->dropAllUses(); 74 op->erase(); 75 } 76 } 77 78 /// We will apply the reducer patterns to the operations in the ranges specified 79 /// by ReductionNode. Note that we are not able to remove an operation without 80 /// replacing it with another valid operation. However, The validity of module 81 /// reduction is based on the Tester provided by the user and that means certain 82 /// invalid module is still interested by the use. Thus we provide an 83 /// alternative way to remove operations, which is using `eraseOpNotInRange` to 84 /// erase the operations not in the range specified by ReductionNode. 85 template <typename IteratorType> 86 static LogicalResult findOptimal(ModuleOp module, Region ®ion, 87 const FrozenRewritePatternSet &patterns, 88 const Tester &test, bool eraseOpNotInRange) { 89 std::pair<Tester::Interestingness, size_t> initStatus = 90 test.isInteresting(module); 91 // While exploring the reduction tree, we always branch from an interesting 92 // node. Thus the root node must be interesting. 93 if (initStatus.first != Tester::Interestingness::True) 94 return module.emitWarning() << "uninterested module will not be reduced"; 95 96 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator; 97 98 std::vector<ReductionNode::Range> ranges{ 99 {0, std::distance(region.op_begin(), region.op_end())}}; 100 101 ReductionNode *root = allocator.Allocate(); 102 new (root) ReductionNode(nullptr, ranges, allocator); 103 // Duplicate the module for root node and locate the region in the copy. 104 if (failed(root->initialize(module, region))) 105 llvm_unreachable("unexpected initialization failure"); 106 root->update(initStatus); 107 108 ReductionNode *smallestNode = root; 109 IteratorType iter(root); 110 111 while (iter != IteratorType::end()) { 112 ReductionNode ¤tNode = *iter; 113 Region &curRegion = currentNode.getRegion(); 114 115 applyPatterns(curRegion, patterns, currentNode.getRanges(), 116 eraseOpNotInRange); 117 currentNode.update(test.isInteresting(currentNode.getModule())); 118 119 if (currentNode.isInteresting() == Tester::Interestingness::True && 120 currentNode.getSize() < smallestNode->getSize()) 121 smallestNode = ¤tNode; 122 123 ++iter; 124 } 125 126 // At here, we have found an optimal path to reduce the given region. Retrieve 127 // the path and apply the reducer to it. 128 SmallVector<ReductionNode *> trace; 129 ReductionNode *curNode = smallestNode; 130 trace.push_back(curNode); 131 while (curNode != root) { 132 curNode = curNode->getParent(); 133 trace.push_back(curNode); 134 } 135 136 // Reduce the region through the optimal path. 137 while (!trace.empty()) { 138 ReductionNode *top = trace.pop_back_val(); 139 applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange); 140 } 141 142 if (test.isInteresting(module).first != Tester::Interestingness::True) 143 llvm::report_fatal_error("Reduced module is not interesting"); 144 if (test.isInteresting(module).second != smallestNode->getSize()) 145 llvm::report_fatal_error( 146 "Reduced module doesn't have consistent size with smallestNode"); 147 return success(); 148 } 149 150 template <typename IteratorType> 151 static LogicalResult findOptimal(ModuleOp module, Region ®ion, 152 const FrozenRewritePatternSet &patterns, 153 const Tester &test) { 154 // We separate the reduction process into 2 steps, the first one is to erase 155 // redundant operations and the second one is to apply the reducer patterns. 156 157 // In the first phase, we don't apply any patterns so that we only select the 158 // range of operations to keep to the module stay interesting. 159 if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test, 160 /*eraseOpNotInRange=*/true))) 161 return failure(); 162 // In the second phase, we suppose that no operation is redundant, so we try 163 // to rewrite the operation into simpler form. 164 return findOptimal<IteratorType>(module, region, patterns, test, 165 /*eraseOpNotInRange=*/false); 166 } 167 168 namespace { 169 170 //===----------------------------------------------------------------------===// 171 // Reduction Pattern Interface Collection 172 //===----------------------------------------------------------------------===// 173 174 class ReductionPatternInterfaceCollection 175 : public DialectInterfaceCollection<DialectReductionPatternInterface> { 176 public: 177 using Base::Base; 178 179 // Collect the reduce patterns defined by each dialect. 180 void populateReductionPatterns(RewritePatternSet &pattern) const { 181 for (const DialectReductionPatternInterface &interface : *this) 182 interface.populateReductionPatterns(pattern); 183 } 184 }; 185 186 //===----------------------------------------------------------------------===// 187 // ReductionTreePass 188 //===----------------------------------------------------------------------===// 189 190 /// This class defines the Reduction Tree Pass. It provides a framework to 191 /// to implement a reduction pass using a tree structure to keep track of the 192 /// generated reduced variants. 193 class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> { 194 public: 195 ReductionTreePass() = default; 196 ReductionTreePass(const ReductionTreePass &pass) = default; 197 198 LogicalResult initialize(MLIRContext *context) override; 199 200 /// Runs the pass instance in the pass pipeline. 201 void runOnOperation() override; 202 203 private: 204 LogicalResult reduceOp(ModuleOp module, Region ®ion); 205 206 FrozenRewritePatternSet reducerPatterns; 207 }; 208 209 } // namespace 210 211 LogicalResult ReductionTreePass::initialize(MLIRContext *context) { 212 RewritePatternSet patterns(context); 213 ReductionPatternInterfaceCollection reducePatternCollection(context); 214 reducePatternCollection.populateReductionPatterns(patterns); 215 reducerPatterns = std::move(patterns); 216 return success(); 217 } 218 219 void ReductionTreePass::runOnOperation() { 220 Operation *topOperation = getOperation(); 221 while (topOperation->getParentOp() != nullptr) 222 topOperation = topOperation->getParentOp(); 223 ModuleOp module = dyn_cast<ModuleOp>(topOperation); 224 if (!module) { 225 emitError(getOperation()->getLoc()) 226 << "top-level op must be 'builtin.module'"; 227 return signalPassFailure(); 228 } 229 230 SmallVector<Operation *, 8> workList; 231 workList.push_back(getOperation()); 232 233 do { 234 Operation *op = workList.pop_back_val(); 235 236 for (Region ®ion : op->getRegions()) 237 if (!region.empty()) 238 if (failed(reduceOp(module, region))) 239 return signalPassFailure(); 240 241 for (Region ®ion : op->getRegions()) 242 for (Operation &op : region.getOps()) 243 if (op.getNumRegions() != 0) 244 workList.push_back(&op); 245 } while (!workList.empty()); 246 } 247 248 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) { 249 Tester test(testerName, testerArgs); 250 switch (traversalModeId) { 251 case TraversalMode::SinglePath: 252 return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>( 253 module, region, reducerPatterns, test); 254 default: 255 return module.emitError() << "unsupported traversal mode detected"; 256 } 257 } 258 259 std::unique_ptr<Pass> mlir::createReductionTreePass() { 260 return std::make_unique<ReductionTreePass>(); 261 } 262