1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Reduction Tree Pass class. It provides a framework for 10 // the implementation of different reduction passes in the MLIR Reduce tool. It 11 // allows for custom specification of the variant generation behavior. It 12 // implements methods that define the different possible traversals of the 13 // reduction tree. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/IR/DialectInterface.h" 18 #include "mlir/IR/OpDefinition.h" 19 #include "mlir/Reducer/PassDetail.h" 20 #include "mlir/Reducer/Passes.h" 21 #include "mlir/Reducer/ReductionNode.h" 22 #include "mlir/Reducer/ReductionPatternInterface.h" 23 #include "mlir/Reducer/Tester.h" 24 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 #include "llvm/ADT/ArrayRef.h" 28 #include "llvm/ADT/SmallVector.h" 29 #include "llvm/Support/Allocator.h" 30 #include "llvm/Support/ManagedStatic.h" 31 32 using namespace mlir; 33 34 /// We implicitly number each operation in the region and if an operation's 35 /// number falls into rangeToKeep, we need to keep it and apply the given 36 /// rewrite patterns on it. 37 static void applyPatterns(Region ®ion, 38 const FrozenRewritePatternSet &patterns, 39 ArrayRef<ReductionNode::Range> rangeToKeep, 40 bool eraseOpNotInRange) { 41 std::vector<Operation *> opsNotInRange; 42 std::vector<Operation *> opsInRange; 43 size_t keepIndex = 0; 44 for (const auto &op : enumerate(region.getOps())) { 45 int index = op.index(); 46 if (keepIndex < rangeToKeep.size() && 47 index == rangeToKeep[keepIndex].second) 48 ++keepIndex; 49 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first) 50 opsNotInRange.push_back(&op.value()); 51 else 52 opsInRange.push_back(&op.value()); 53 } 54 55 // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern 56 // matching in above iteration. Besides, erase op not-in-range may end up in 57 // invalid module, so `applyOpPatternsAndFold` should come before that 58 // transform. 59 for (Operation *op : opsInRange) 60 // `applyOpPatternsAndFold` returns whether the op is convered. Omit it 61 // because we don't have expectation this reduction will be success or not. 62 (void)applyOpPatternsAndFold(op, patterns); 63 64 if (eraseOpNotInRange) 65 for (Operation *op : opsNotInRange) { 66 op->dropAllUses(); 67 op->erase(); 68 } 69 } 70 71 /// We will apply the reducer patterns to the operations in the ranges specified 72 /// by ReductionNode. Note that we are not able to remove an operation without 73 /// replacing it with another valid operation. However, The validity of module 74 /// reduction is based on the Tester provided by the user and that means certain 75 /// invalid module is still interested by the use. Thus we provide an 76 /// alternative way to remove operations, which is using `eraseOpNotInRange` to 77 /// erase the operations not in the range specified by ReductionNode. 78 template <typename IteratorType> 79 static LogicalResult findOptimal(ModuleOp module, Region ®ion, 80 const FrozenRewritePatternSet &patterns, 81 const Tester &test, bool eraseOpNotInRange) { 82 std::pair<Tester::Interestingness, size_t> initStatus = 83 test.isInteresting(module); 84 // While exploring the reduction tree, we always branch from an interesting 85 // node. Thus the root node must be interesting. 86 if (initStatus.first != Tester::Interestingness::True) 87 return module.emitWarning() << "uninterested module will not be reduced"; 88 89 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator; 90 91 std::vector<ReductionNode::Range> ranges{ 92 {0, std::distance(region.op_begin(), region.op_end())}}; 93 94 ReductionNode *root = allocator.Allocate(); 95 new (root) ReductionNode(nullptr, ranges, allocator); 96 // Duplicate the module for root node and locate the region in the copy. 97 if (failed(root->initialize(module, region))) 98 llvm_unreachable("unexpected initialization failure"); 99 root->update(initStatus); 100 101 ReductionNode *smallestNode = root; 102 IteratorType iter(root); 103 104 while (iter != IteratorType::end()) { 105 ReductionNode ¤tNode = *iter; 106 Region &curRegion = currentNode.getRegion(); 107 108 applyPatterns(curRegion, patterns, currentNode.getRanges(), 109 eraseOpNotInRange); 110 currentNode.update(test.isInteresting(currentNode.getModule())); 111 112 if (currentNode.isInteresting() == Tester::Interestingness::True && 113 currentNode.getSize() < smallestNode->getSize()) 114 smallestNode = ¤tNode; 115 116 ++iter; 117 } 118 119 // At here, we have found an optimal path to reduce the given region. Retrieve 120 // the path and apply the reducer to it. 121 SmallVector<ReductionNode *> trace; 122 ReductionNode *curNode = smallestNode; 123 trace.push_back(curNode); 124 while (curNode != root) { 125 curNode = curNode->getParent(); 126 trace.push_back(curNode); 127 } 128 129 // Reduce the region through the optimal path. 130 while (!trace.empty()) { 131 ReductionNode *top = trace.pop_back_val(); 132 applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange); 133 } 134 135 if (test.isInteresting(module).first != Tester::Interestingness::True) 136 llvm::report_fatal_error("Reduced module is not interesting"); 137 if (test.isInteresting(module).second != smallestNode->getSize()) 138 llvm::report_fatal_error( 139 "Reduced module doesn't have consistent size with smallestNode"); 140 return success(); 141 } 142 143 template <typename IteratorType> 144 static LogicalResult findOptimal(ModuleOp module, Region ®ion, 145 const FrozenRewritePatternSet &patterns, 146 const Tester &test) { 147 // We separate the reduction process into 2 steps, the first one is to erase 148 // redundant operations and the second one is to apply the reducer patterns. 149 150 // In the first phase, we don't apply any patterns so that we only select the 151 // range of operations to keep to the module stay interesting. 152 if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test, 153 /*eraseOpNotInRange=*/true))) 154 return failure(); 155 // In the second phase, we suppose that no operation is redundant, so we try 156 // to rewrite the operation into simpler form. 157 return findOptimal<IteratorType>(module, region, patterns, test, 158 /*eraseOpNotInRange=*/false); 159 } 160 161 namespace { 162 163 //===----------------------------------------------------------------------===// 164 // Reduction Pattern Interface Collection 165 //===----------------------------------------------------------------------===// 166 167 class ReductionPatternInterfaceCollection 168 : public DialectInterfaceCollection<DialectReductionPatternInterface> { 169 public: 170 using Base::Base; 171 172 // Collect the reduce patterns defined by each dialect. 173 void populateReductionPatterns(RewritePatternSet &pattern) const { 174 for (const DialectReductionPatternInterface &interface : *this) 175 interface.populateReductionPatterns(pattern); 176 } 177 }; 178 179 //===----------------------------------------------------------------------===// 180 // ReductionTreePass 181 //===----------------------------------------------------------------------===// 182 183 /// This class defines the Reduction Tree Pass. It provides a framework to 184 /// to implement a reduction pass using a tree structure to keep track of the 185 /// generated reduced variants. 186 class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> { 187 public: 188 ReductionTreePass() = default; 189 ReductionTreePass(const ReductionTreePass &pass) = default; 190 191 LogicalResult initialize(MLIRContext *context) override; 192 193 /// Runs the pass instance in the pass pipeline. 194 void runOnOperation() override; 195 196 private: 197 LogicalResult reduceOp(ModuleOp module, Region ®ion); 198 199 FrozenRewritePatternSet reducerPatterns; 200 }; 201 202 } // namespace 203 204 LogicalResult ReductionTreePass::initialize(MLIRContext *context) { 205 RewritePatternSet patterns(context); 206 ReductionPatternInterfaceCollection reducePatternCollection(context); 207 reducePatternCollection.populateReductionPatterns(patterns); 208 reducerPatterns = std::move(patterns); 209 return success(); 210 } 211 212 void ReductionTreePass::runOnOperation() { 213 Operation *topOperation = getOperation(); 214 while (topOperation->getParentOp() != nullptr) 215 topOperation = topOperation->getParentOp(); 216 ModuleOp module = cast<ModuleOp>(topOperation); 217 218 SmallVector<Operation *, 8> workList; 219 workList.push_back(getOperation()); 220 221 do { 222 Operation *op = workList.pop_back_val(); 223 224 for (Region ®ion : op->getRegions()) 225 if (!region.empty()) 226 if (failed(reduceOp(module, region))) 227 return signalPassFailure(); 228 229 for (Region ®ion : op->getRegions()) 230 for (Operation &op : region.getOps()) 231 if (op.getNumRegions() != 0) 232 workList.push_back(&op); 233 } while (!workList.empty()); 234 } 235 236 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) { 237 Tester test(testerName, testerArgs); 238 switch (traversalModeId) { 239 case TraversalMode::SinglePath: 240 return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>( 241 module, region, reducerPatterns, test); 242 default: 243 return module.emitError() << "unsupported traversal mode detected"; 244 } 245 } 246 247 std::unique_ptr<Pass> mlir::createReductionTreePass() { 248 return std::make_unique<ReductionTreePass>(); 249 } 250