1c484c7ddSChia-hung Duan //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// 2c484c7ddSChia-hung Duan // 3c484c7ddSChia-hung Duan // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c484c7ddSChia-hung Duan // See https://llvm.org/LICENSE.txt for license information. 5c484c7ddSChia-hung Duan // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c484c7ddSChia-hung Duan // 7c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 8c484c7ddSChia-hung Duan // 9c484c7ddSChia-hung Duan // This file defines the Reduction Tree Pass class. It provides a framework for 10c484c7ddSChia-hung Duan // the implementation of different reduction passes in the MLIR Reduce tool. It 11c484c7ddSChia-hung Duan // allows for custom specification of the variant generation behavior. It 12c484c7ddSChia-hung Duan // implements methods that define the different possible traversals of the 13c484c7ddSChia-hung Duan // reduction tree. 14c484c7ddSChia-hung Duan // 15c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 16c484c7ddSChia-hung Duan 17c484c7ddSChia-hung Duan #include "mlir/IR/DialectInterface.h" 18c484c7ddSChia-hung Duan #include "mlir/IR/OpDefinition.h" 19c484c7ddSChia-hung Duan #include "mlir/Reducer/Passes.h" 20c484c7ddSChia-hung Duan #include "mlir/Reducer/ReductionNode.h" 21c484c7ddSChia-hung Duan #include "mlir/Reducer/ReductionPatternInterface.h" 22c484c7ddSChia-hung Duan #include "mlir/Reducer/Tester.h" 23c484c7ddSChia-hung Duan #include "mlir/Rewrite/FrozenRewritePatternSet.h" 24c484c7ddSChia-hung Duan #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25c484c7ddSChia-hung Duan 26c484c7ddSChia-hung Duan #include "llvm/ADT/ArrayRef.h" 27c484c7ddSChia-hung Duan #include "llvm/ADT/SmallVector.h" 28c484c7ddSChia-hung Duan #include "llvm/Support/Allocator.h" 29c484c7ddSChia-hung Duan #include "llvm/Support/ManagedStatic.h" 30c484c7ddSChia-hung Duan 3167d0d7acSMichele Scuttari namespace mlir { 3267d0d7acSMichele Scuttari #define GEN_PASS_DEF_REDUCTIONTREE 3367d0d7acSMichele Scuttari #include "mlir/Reducer/Passes.h.inc" 3467d0d7acSMichele Scuttari } // namespace mlir 3567d0d7acSMichele Scuttari 36c484c7ddSChia-hung Duan using namespace mlir; 37c484c7ddSChia-hung Duan 38c484c7ddSChia-hung Duan /// We implicitly number each operation in the region and if an operation's 39c484c7ddSChia-hung Duan /// number falls into rangeToKeep, we need to keep it and apply the given 40c484c7ddSChia-hung Duan /// rewrite patterns on it. 41c484c7ddSChia-hung Duan static void applyPatterns(Region ®ion, 42c484c7ddSChia-hung Duan const FrozenRewritePatternSet &patterns, 43c484c7ddSChia-hung Duan ArrayRef<ReductionNode::Range> rangeToKeep, 44c484c7ddSChia-hung Duan bool eraseOpNotInRange) { 45c484c7ddSChia-hung Duan std::vector<Operation *> opsNotInRange; 46c484c7ddSChia-hung Duan std::vector<Operation *> opsInRange; 47c484c7ddSChia-hung Duan size_t keepIndex = 0; 48e4853be2SMehdi Amini for (const auto &op : enumerate(region.getOps())) { 49c484c7ddSChia-hung Duan int index = op.index(); 50c484c7ddSChia-hung Duan if (keepIndex < rangeToKeep.size() && 51c484c7ddSChia-hung Duan index == rangeToKeep[keepIndex].second) 52c484c7ddSChia-hung Duan ++keepIndex; 53c484c7ddSChia-hung Duan if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first) 54c484c7ddSChia-hung Duan opsNotInRange.push_back(&op.value()); 55c484c7ddSChia-hung Duan else 56c484c7ddSChia-hung Duan opsInRange.push_back(&op.value()); 57c484c7ddSChia-hung Duan } 58c484c7ddSChia-hung Duan 59c484c7ddSChia-hung Duan // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern 60c484c7ddSChia-hung Duan // matching in above iteration. Besides, erase op not-in-range may end up in 61c484c7ddSChia-hung Duan // invalid module, so `applyOpPatternsAndFold` should come before that 62c484c7ddSChia-hung Duan // transform. 636bdecbcbSMatthias Springer for (Operation *op : opsInRange) { 64c484c7ddSChia-hung Duan // `applyOpPatternsAndFold` returns whether the op is convered. Omit it 65c484c7ddSChia-hung Duan // because we don't have expectation this reduction will be success or not. 666bdecbcbSMatthias Springer GreedyRewriteConfig config; 676bdecbcbSMatthias Springer config.strictMode = GreedyRewriteStrictness::ExistingOps; 68*09dfc571SJacques Pienaar (void)applyOpPatternsGreedily(op, patterns, config); 696bdecbcbSMatthias Springer } 70c484c7ddSChia-hung Duan 71c484c7ddSChia-hung Duan if (eraseOpNotInRange) 72c484c7ddSChia-hung Duan for (Operation *op : opsNotInRange) { 73c484c7ddSChia-hung Duan op->dropAllUses(); 74c484c7ddSChia-hung Duan op->erase(); 75c484c7ddSChia-hung Duan } 76c484c7ddSChia-hung Duan } 77c484c7ddSChia-hung Duan 78c484c7ddSChia-hung Duan /// We will apply the reducer patterns to the operations in the ranges specified 79c484c7ddSChia-hung Duan /// by ReductionNode. Note that we are not able to remove an operation without 80c484c7ddSChia-hung Duan /// replacing it with another valid operation. However, The validity of module 81c484c7ddSChia-hung Duan /// reduction is based on the Tester provided by the user and that means certain 82c484c7ddSChia-hung Duan /// invalid module is still interested by the use. Thus we provide an 83c484c7ddSChia-hung Duan /// alternative way to remove operations, which is using `eraseOpNotInRange` to 84c484c7ddSChia-hung Duan /// erase the operations not in the range specified by ReductionNode. 85c484c7ddSChia-hung Duan template <typename IteratorType> 861a001dedSChia-hung Duan static LogicalResult findOptimal(ModuleOp module, Region ®ion, 87c484c7ddSChia-hung Duan const FrozenRewritePatternSet &patterns, 88c484c7ddSChia-hung Duan const Tester &test, bool eraseOpNotInRange) { 89c484c7ddSChia-hung Duan std::pair<Tester::Interestingness, size_t> initStatus = 90c484c7ddSChia-hung Duan test.isInteresting(module); 91c484c7ddSChia-hung Duan // While exploring the reduction tree, we always branch from an interesting 92c484c7ddSChia-hung Duan // node. Thus the root node must be interesting. 93c484c7ddSChia-hung Duan if (initStatus.first != Tester::Interestingness::True) 941a001dedSChia-hung Duan return module.emitWarning() << "uninterested module will not be reduced"; 95c484c7ddSChia-hung Duan 96c484c7ddSChia-hung Duan llvm::SpecificBumpPtrAllocator<ReductionNode> allocator; 97c484c7ddSChia-hung Duan 98c484c7ddSChia-hung Duan std::vector<ReductionNode::Range> ranges{ 99c484c7ddSChia-hung Duan {0, std::distance(region.op_begin(), region.op_end())}}; 100c484c7ddSChia-hung Duan 101c484c7ddSChia-hung Duan ReductionNode *root = allocator.Allocate(); 102337c937dSMehdi Amini new (root) ReductionNode(nullptr, ranges, allocator); 103c484c7ddSChia-hung Duan // Duplicate the module for root node and locate the region in the copy. 104c484c7ddSChia-hung Duan if (failed(root->initialize(module, region))) 105c484c7ddSChia-hung Duan llvm_unreachable("unexpected initialization failure"); 106c484c7ddSChia-hung Duan root->update(initStatus); 107c484c7ddSChia-hung Duan 108c484c7ddSChia-hung Duan ReductionNode *smallestNode = root; 109c484c7ddSChia-hung Duan IteratorType iter(root); 110c484c7ddSChia-hung Duan 111c484c7ddSChia-hung Duan while (iter != IteratorType::end()) { 112c484c7ddSChia-hung Duan ReductionNode ¤tNode = *iter; 113c484c7ddSChia-hung Duan Region &curRegion = currentNode.getRegion(); 114c484c7ddSChia-hung Duan 115c484c7ddSChia-hung Duan applyPatterns(curRegion, patterns, currentNode.getRanges(), 116c484c7ddSChia-hung Duan eraseOpNotInRange); 117c484c7ddSChia-hung Duan currentNode.update(test.isInteresting(currentNode.getModule())); 118c484c7ddSChia-hung Duan 119c484c7ddSChia-hung Duan if (currentNode.isInteresting() == Tester::Interestingness::True && 120c484c7ddSChia-hung Duan currentNode.getSize() < smallestNode->getSize()) 121c484c7ddSChia-hung Duan smallestNode = ¤tNode; 122c484c7ddSChia-hung Duan 123c484c7ddSChia-hung Duan ++iter; 124c484c7ddSChia-hung Duan } 125c484c7ddSChia-hung Duan 126c484c7ddSChia-hung Duan // At here, we have found an optimal path to reduce the given region. Retrieve 127c484c7ddSChia-hung Duan // the path and apply the reducer to it. 128c484c7ddSChia-hung Duan SmallVector<ReductionNode *> trace; 129c484c7ddSChia-hung Duan ReductionNode *curNode = smallestNode; 130c484c7ddSChia-hung Duan trace.push_back(curNode); 131c484c7ddSChia-hung Duan while (curNode != root) { 132c484c7ddSChia-hung Duan curNode = curNode->getParent(); 133c484c7ddSChia-hung Duan trace.push_back(curNode); 134c484c7ddSChia-hung Duan } 135c484c7ddSChia-hung Duan 136c484c7ddSChia-hung Duan // Reduce the region through the optimal path. 137c484c7ddSChia-hung Duan while (!trace.empty()) { 138c484c7ddSChia-hung Duan ReductionNode *top = trace.pop_back_val(); 139c484c7ddSChia-hung Duan applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange); 140c484c7ddSChia-hung Duan } 141c484c7ddSChia-hung Duan 142c484c7ddSChia-hung Duan if (test.isInteresting(module).first != Tester::Interestingness::True) 143c484c7ddSChia-hung Duan llvm::report_fatal_error("Reduced module is not interesting"); 144c484c7ddSChia-hung Duan if (test.isInteresting(module).second != smallestNode->getSize()) 145c484c7ddSChia-hung Duan llvm::report_fatal_error( 146c484c7ddSChia-hung Duan "Reduced module doesn't have consistent size with smallestNode"); 1471a001dedSChia-hung Duan return success(); 148c484c7ddSChia-hung Duan } 149c484c7ddSChia-hung Duan 150c484c7ddSChia-hung Duan template <typename IteratorType> 1511a001dedSChia-hung Duan static LogicalResult findOptimal(ModuleOp module, Region ®ion, 152c484c7ddSChia-hung Duan const FrozenRewritePatternSet &patterns, 153c484c7ddSChia-hung Duan const Tester &test) { 154c484c7ddSChia-hung Duan // We separate the reduction process into 2 steps, the first one is to erase 155c484c7ddSChia-hung Duan // redundant operations and the second one is to apply the reducer patterns. 156c484c7ddSChia-hung Duan 157c484c7ddSChia-hung Duan // In the first phase, we don't apply any patterns so that we only select the 158c484c7ddSChia-hung Duan // range of operations to keep to the module stay interesting. 1591a001dedSChia-hung Duan if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test, 1601a001dedSChia-hung Duan /*eraseOpNotInRange=*/true))) 1611a001dedSChia-hung Duan return failure(); 162c484c7ddSChia-hung Duan // In the second phase, we suppose that no operation is redundant, so we try 163c484c7ddSChia-hung Duan // to rewrite the operation into simpler form. 1641a001dedSChia-hung Duan return findOptimal<IteratorType>(module, region, patterns, test, 165c484c7ddSChia-hung Duan /*eraseOpNotInRange=*/false); 166c484c7ddSChia-hung Duan } 167c484c7ddSChia-hung Duan 168c484c7ddSChia-hung Duan namespace { 169c484c7ddSChia-hung Duan 170c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 171c484c7ddSChia-hung Duan // Reduction Pattern Interface Collection 172c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 173c484c7ddSChia-hung Duan 174c484c7ddSChia-hung Duan class ReductionPatternInterfaceCollection 175c484c7ddSChia-hung Duan : public DialectInterfaceCollection<DialectReductionPatternInterface> { 176c484c7ddSChia-hung Duan public: 177c484c7ddSChia-hung Duan using Base::Base; 178c484c7ddSChia-hung Duan 179c484c7ddSChia-hung Duan // Collect the reduce patterns defined by each dialect. 180c484c7ddSChia-hung Duan void populateReductionPatterns(RewritePatternSet &pattern) const { 181c484c7ddSChia-hung Duan for (const DialectReductionPatternInterface &interface : *this) 182c484c7ddSChia-hung Duan interface.populateReductionPatterns(pattern); 183c484c7ddSChia-hung Duan } 184c484c7ddSChia-hung Duan }; 185c484c7ddSChia-hung Duan 186c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 187c484c7ddSChia-hung Duan // ReductionTreePass 188c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 189c484c7ddSChia-hung Duan 190c484c7ddSChia-hung Duan /// This class defines the Reduction Tree Pass. It provides a framework to 191c484c7ddSChia-hung Duan /// to implement a reduction pass using a tree structure to keep track of the 192c484c7ddSChia-hung Duan /// generated reduced variants. 19367d0d7acSMichele Scuttari class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> { 194c484c7ddSChia-hung Duan public: 195c484c7ddSChia-hung Duan ReductionTreePass() = default; 196c484c7ddSChia-hung Duan ReductionTreePass(const ReductionTreePass &pass) = default; 197c484c7ddSChia-hung Duan 198c484c7ddSChia-hung Duan LogicalResult initialize(MLIRContext *context) override; 199c484c7ddSChia-hung Duan 200c484c7ddSChia-hung Duan /// Runs the pass instance in the pass pipeline. 201c484c7ddSChia-hung Duan void runOnOperation() override; 202c484c7ddSChia-hung Duan 203c484c7ddSChia-hung Duan private: 2041a001dedSChia-hung Duan LogicalResult reduceOp(ModuleOp module, Region ®ion); 205c484c7ddSChia-hung Duan 206c484c7ddSChia-hung Duan FrozenRewritePatternSet reducerPatterns; 207c484c7ddSChia-hung Duan }; 208c484c7ddSChia-hung Duan 209be0a7e9fSMehdi Amini } // namespace 210c484c7ddSChia-hung Duan 211c484c7ddSChia-hung Duan LogicalResult ReductionTreePass::initialize(MLIRContext *context) { 212c484c7ddSChia-hung Duan RewritePatternSet patterns(context); 213c484c7ddSChia-hung Duan ReductionPatternInterfaceCollection reducePatternCollection(context); 214c484c7ddSChia-hung Duan reducePatternCollection.populateReductionPatterns(patterns); 215c484c7ddSChia-hung Duan reducerPatterns = std::move(patterns); 216c484c7ddSChia-hung Duan return success(); 217c484c7ddSChia-hung Duan } 218c484c7ddSChia-hung Duan 219c484c7ddSChia-hung Duan void ReductionTreePass::runOnOperation() { 220c484c7ddSChia-hung Duan Operation *topOperation = getOperation(); 221c484c7ddSChia-hung Duan while (topOperation->getParentOp() != nullptr) 222c484c7ddSChia-hung Duan topOperation = topOperation->getParentOp(); 22355300991Srkayaith ModuleOp module = dyn_cast<ModuleOp>(topOperation); 22455300991Srkayaith if (!module) { 22555300991Srkayaith emitError(getOperation()->getLoc()) 22655300991Srkayaith << "top-level op must be 'builtin.module'"; 22755300991Srkayaith return signalPassFailure(); 22855300991Srkayaith } 229c484c7ddSChia-hung Duan 230c484c7ddSChia-hung Duan SmallVector<Operation *, 8> workList; 231c484c7ddSChia-hung Duan workList.push_back(getOperation()); 232c484c7ddSChia-hung Duan 233c484c7ddSChia-hung Duan do { 234c484c7ddSChia-hung Duan Operation *op = workList.pop_back_val(); 235c484c7ddSChia-hung Duan 236c484c7ddSChia-hung Duan for (Region ®ion : op->getRegions()) 237c484c7ddSChia-hung Duan if (!region.empty()) 2381a001dedSChia-hung Duan if (failed(reduceOp(module, region))) 2391a001dedSChia-hung Duan return signalPassFailure(); 240c484c7ddSChia-hung Duan 241c484c7ddSChia-hung Duan for (Region ®ion : op->getRegions()) 242c484c7ddSChia-hung Duan for (Operation &op : region.getOps()) 243c484c7ddSChia-hung Duan if (op.getNumRegions() != 0) 244c484c7ddSChia-hung Duan workList.push_back(&op); 245c484c7ddSChia-hung Duan } while (!workList.empty()); 246c484c7ddSChia-hung Duan } 247c484c7ddSChia-hung Duan 2481a001dedSChia-hung Duan LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) { 249c484c7ddSChia-hung Duan Tester test(testerName, testerArgs); 250c484c7ddSChia-hung Duan switch (traversalModeId) { 251c484c7ddSChia-hung Duan case TraversalMode::SinglePath: 2521a001dedSChia-hung Duan return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>( 253c484c7ddSChia-hung Duan module, region, reducerPatterns, test); 254c484c7ddSChia-hung Duan default: 2551a001dedSChia-hung Duan return module.emitError() << "unsupported traversal mode detected"; 256c484c7ddSChia-hung Duan } 257c484c7ddSChia-hung Duan } 258c484c7ddSChia-hung Duan 259c484c7ddSChia-hung Duan std::unique_ptr<Pass> mlir::createReductionTreePass() { 260c484c7ddSChia-hung Duan return std::make_unique<ReductionTreePass>(); 261c484c7ddSChia-hung Duan } 262