xref: /llvm-project/mlir/lib/Reducer/ReductionTreePass.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1c484c7ddSChia-hung Duan //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2c484c7ddSChia-hung Duan //
3c484c7ddSChia-hung Duan // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c484c7ddSChia-hung Duan // See https://llvm.org/LICENSE.txt for license information.
5c484c7ddSChia-hung Duan // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c484c7ddSChia-hung Duan //
7c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
8c484c7ddSChia-hung Duan //
9c484c7ddSChia-hung Duan // This file defines the Reduction Tree Pass class. It provides a framework for
10c484c7ddSChia-hung Duan // the implementation of different reduction passes in the MLIR Reduce tool. It
11c484c7ddSChia-hung Duan // allows for custom specification of the variant generation behavior. It
12c484c7ddSChia-hung Duan // implements methods that define the different possible traversals of the
13c484c7ddSChia-hung Duan // reduction tree.
14c484c7ddSChia-hung Duan //
15c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
16c484c7ddSChia-hung Duan 
17c484c7ddSChia-hung Duan #include "mlir/IR/DialectInterface.h"
18c484c7ddSChia-hung Duan #include "mlir/IR/OpDefinition.h"
19c484c7ddSChia-hung Duan #include "mlir/Reducer/Passes.h"
20c484c7ddSChia-hung Duan #include "mlir/Reducer/ReductionNode.h"
21c484c7ddSChia-hung Duan #include "mlir/Reducer/ReductionPatternInterface.h"
22c484c7ddSChia-hung Duan #include "mlir/Reducer/Tester.h"
23c484c7ddSChia-hung Duan #include "mlir/Rewrite/FrozenRewritePatternSet.h"
24c484c7ddSChia-hung Duan #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25c484c7ddSChia-hung Duan 
26c484c7ddSChia-hung Duan #include "llvm/ADT/ArrayRef.h"
27c484c7ddSChia-hung Duan #include "llvm/ADT/SmallVector.h"
28c484c7ddSChia-hung Duan #include "llvm/Support/Allocator.h"
29c484c7ddSChia-hung Duan #include "llvm/Support/ManagedStatic.h"
30c484c7ddSChia-hung Duan 
3167d0d7acSMichele Scuttari namespace mlir {
3267d0d7acSMichele Scuttari #define GEN_PASS_DEF_REDUCTIONTREE
3367d0d7acSMichele Scuttari #include "mlir/Reducer/Passes.h.inc"
3467d0d7acSMichele Scuttari } // namespace mlir
3567d0d7acSMichele Scuttari 
36c484c7ddSChia-hung Duan using namespace mlir;
37c484c7ddSChia-hung Duan 
38c484c7ddSChia-hung Duan /// We implicitly number each operation in the region and if an operation's
39c484c7ddSChia-hung Duan /// number falls into rangeToKeep, we need to keep it and apply the given
40c484c7ddSChia-hung Duan /// rewrite patterns on it.
41c484c7ddSChia-hung Duan static void applyPatterns(Region &region,
42c484c7ddSChia-hung Duan                           const FrozenRewritePatternSet &patterns,
43c484c7ddSChia-hung Duan                           ArrayRef<ReductionNode::Range> rangeToKeep,
44c484c7ddSChia-hung Duan                           bool eraseOpNotInRange) {
45c484c7ddSChia-hung Duan   std::vector<Operation *> opsNotInRange;
46c484c7ddSChia-hung Duan   std::vector<Operation *> opsInRange;
47c484c7ddSChia-hung Duan   size_t keepIndex = 0;
48e4853be2SMehdi Amini   for (const auto &op : enumerate(region.getOps())) {
49c484c7ddSChia-hung Duan     int index = op.index();
50c484c7ddSChia-hung Duan     if (keepIndex < rangeToKeep.size() &&
51c484c7ddSChia-hung Duan         index == rangeToKeep[keepIndex].second)
52c484c7ddSChia-hung Duan       ++keepIndex;
53c484c7ddSChia-hung Duan     if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
54c484c7ddSChia-hung Duan       opsNotInRange.push_back(&op.value());
55c484c7ddSChia-hung Duan     else
56c484c7ddSChia-hung Duan       opsInRange.push_back(&op.value());
57c484c7ddSChia-hung Duan   }
58c484c7ddSChia-hung Duan 
59c484c7ddSChia-hung Duan   // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
60c484c7ddSChia-hung Duan   // matching in above iteration. Besides, erase op not-in-range may end up in
61c484c7ddSChia-hung Duan   // invalid module, so `applyOpPatternsAndFold` should come before that
62c484c7ddSChia-hung Duan   // transform.
636bdecbcbSMatthias Springer   for (Operation *op : opsInRange) {
64c484c7ddSChia-hung Duan     // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
65c484c7ddSChia-hung Duan     // because we don't have expectation this reduction will be success or not.
666bdecbcbSMatthias Springer     GreedyRewriteConfig config;
676bdecbcbSMatthias Springer     config.strictMode = GreedyRewriteStrictness::ExistingOps;
68*09dfc571SJacques Pienaar     (void)applyOpPatternsGreedily(op, patterns, config);
696bdecbcbSMatthias Springer   }
70c484c7ddSChia-hung Duan 
71c484c7ddSChia-hung Duan   if (eraseOpNotInRange)
72c484c7ddSChia-hung Duan     for (Operation *op : opsNotInRange) {
73c484c7ddSChia-hung Duan       op->dropAllUses();
74c484c7ddSChia-hung Duan       op->erase();
75c484c7ddSChia-hung Duan     }
76c484c7ddSChia-hung Duan }
77c484c7ddSChia-hung Duan 
78c484c7ddSChia-hung Duan /// We will apply the reducer patterns to the operations in the ranges specified
79c484c7ddSChia-hung Duan /// by ReductionNode. Note that we are not able to remove an operation without
80c484c7ddSChia-hung Duan /// replacing it with another valid operation. However, The validity of module
81c484c7ddSChia-hung Duan /// reduction is based on the Tester provided by the user and that means certain
82c484c7ddSChia-hung Duan /// invalid module is still interested by the use. Thus we provide an
83c484c7ddSChia-hung Duan /// alternative way to remove operations, which is using `eraseOpNotInRange` to
84c484c7ddSChia-hung Duan /// erase the operations not in the range specified by ReductionNode.
85c484c7ddSChia-hung Duan template <typename IteratorType>
861a001dedSChia-hung Duan static LogicalResult findOptimal(ModuleOp module, Region &region,
87c484c7ddSChia-hung Duan                                  const FrozenRewritePatternSet &patterns,
88c484c7ddSChia-hung Duan                                  const Tester &test, bool eraseOpNotInRange) {
89c484c7ddSChia-hung Duan   std::pair<Tester::Interestingness, size_t> initStatus =
90c484c7ddSChia-hung Duan       test.isInteresting(module);
91c484c7ddSChia-hung Duan   // While exploring the reduction tree, we always branch from an interesting
92c484c7ddSChia-hung Duan   // node. Thus the root node must be interesting.
93c484c7ddSChia-hung Duan   if (initStatus.first != Tester::Interestingness::True)
941a001dedSChia-hung Duan     return module.emitWarning() << "uninterested module will not be reduced";
95c484c7ddSChia-hung Duan 
96c484c7ddSChia-hung Duan   llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
97c484c7ddSChia-hung Duan 
98c484c7ddSChia-hung Duan   std::vector<ReductionNode::Range> ranges{
99c484c7ddSChia-hung Duan       {0, std::distance(region.op_begin(), region.op_end())}};
100c484c7ddSChia-hung Duan 
101c484c7ddSChia-hung Duan   ReductionNode *root = allocator.Allocate();
102337c937dSMehdi Amini   new (root) ReductionNode(nullptr, ranges, allocator);
103c484c7ddSChia-hung Duan   // Duplicate the module for root node and locate the region in the copy.
104c484c7ddSChia-hung Duan   if (failed(root->initialize(module, region)))
105c484c7ddSChia-hung Duan     llvm_unreachable("unexpected initialization failure");
106c484c7ddSChia-hung Duan   root->update(initStatus);
107c484c7ddSChia-hung Duan 
108c484c7ddSChia-hung Duan   ReductionNode *smallestNode = root;
109c484c7ddSChia-hung Duan   IteratorType iter(root);
110c484c7ddSChia-hung Duan 
111c484c7ddSChia-hung Duan   while (iter != IteratorType::end()) {
112c484c7ddSChia-hung Duan     ReductionNode &currentNode = *iter;
113c484c7ddSChia-hung Duan     Region &curRegion = currentNode.getRegion();
114c484c7ddSChia-hung Duan 
115c484c7ddSChia-hung Duan     applyPatterns(curRegion, patterns, currentNode.getRanges(),
116c484c7ddSChia-hung Duan                   eraseOpNotInRange);
117c484c7ddSChia-hung Duan     currentNode.update(test.isInteresting(currentNode.getModule()));
118c484c7ddSChia-hung Duan 
119c484c7ddSChia-hung Duan     if (currentNode.isInteresting() == Tester::Interestingness::True &&
120c484c7ddSChia-hung Duan         currentNode.getSize() < smallestNode->getSize())
121c484c7ddSChia-hung Duan       smallestNode = &currentNode;
122c484c7ddSChia-hung Duan 
123c484c7ddSChia-hung Duan     ++iter;
124c484c7ddSChia-hung Duan   }
125c484c7ddSChia-hung Duan 
126c484c7ddSChia-hung Duan   // At here, we have found an optimal path to reduce the given region. Retrieve
127c484c7ddSChia-hung Duan   // the path and apply the reducer to it.
128c484c7ddSChia-hung Duan   SmallVector<ReductionNode *> trace;
129c484c7ddSChia-hung Duan   ReductionNode *curNode = smallestNode;
130c484c7ddSChia-hung Duan   trace.push_back(curNode);
131c484c7ddSChia-hung Duan   while (curNode != root) {
132c484c7ddSChia-hung Duan     curNode = curNode->getParent();
133c484c7ddSChia-hung Duan     trace.push_back(curNode);
134c484c7ddSChia-hung Duan   }
135c484c7ddSChia-hung Duan 
136c484c7ddSChia-hung Duan   // Reduce the region through the optimal path.
137c484c7ddSChia-hung Duan   while (!trace.empty()) {
138c484c7ddSChia-hung Duan     ReductionNode *top = trace.pop_back_val();
139c484c7ddSChia-hung Duan     applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
140c484c7ddSChia-hung Duan   }
141c484c7ddSChia-hung Duan 
142c484c7ddSChia-hung Duan   if (test.isInteresting(module).first != Tester::Interestingness::True)
143c484c7ddSChia-hung Duan     llvm::report_fatal_error("Reduced module is not interesting");
144c484c7ddSChia-hung Duan   if (test.isInteresting(module).second != smallestNode->getSize())
145c484c7ddSChia-hung Duan     llvm::report_fatal_error(
146c484c7ddSChia-hung Duan         "Reduced module doesn't have consistent size with smallestNode");
1471a001dedSChia-hung Duan   return success();
148c484c7ddSChia-hung Duan }
149c484c7ddSChia-hung Duan 
150c484c7ddSChia-hung Duan template <typename IteratorType>
1511a001dedSChia-hung Duan static LogicalResult findOptimal(ModuleOp module, Region &region,
152c484c7ddSChia-hung Duan                                  const FrozenRewritePatternSet &patterns,
153c484c7ddSChia-hung Duan                                  const Tester &test) {
154c484c7ddSChia-hung Duan   // We separate the reduction process into 2 steps, the first one is to erase
155c484c7ddSChia-hung Duan   // redundant operations and the second one is to apply the reducer patterns.
156c484c7ddSChia-hung Duan 
157c484c7ddSChia-hung Duan   // In the first phase, we don't apply any patterns so that we only select the
158c484c7ddSChia-hung Duan   // range of operations to keep to the module stay interesting.
1591a001dedSChia-hung Duan   if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
1601a001dedSChia-hung Duan                                        /*eraseOpNotInRange=*/true)))
1611a001dedSChia-hung Duan     return failure();
162c484c7ddSChia-hung Duan   // In the second phase, we suppose that no operation is redundant, so we try
163c484c7ddSChia-hung Duan   // to rewrite the operation into simpler form.
1641a001dedSChia-hung Duan   return findOptimal<IteratorType>(module, region, patterns, test,
165c484c7ddSChia-hung Duan                                    /*eraseOpNotInRange=*/false);
166c484c7ddSChia-hung Duan }
167c484c7ddSChia-hung Duan 
168c484c7ddSChia-hung Duan namespace {
169c484c7ddSChia-hung Duan 
170c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
171c484c7ddSChia-hung Duan // Reduction Pattern Interface Collection
172c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
173c484c7ddSChia-hung Duan 
174c484c7ddSChia-hung Duan class ReductionPatternInterfaceCollection
175c484c7ddSChia-hung Duan     : public DialectInterfaceCollection<DialectReductionPatternInterface> {
176c484c7ddSChia-hung Duan public:
177c484c7ddSChia-hung Duan   using Base::Base;
178c484c7ddSChia-hung Duan 
179c484c7ddSChia-hung Duan   // Collect the reduce patterns defined by each dialect.
180c484c7ddSChia-hung Duan   void populateReductionPatterns(RewritePatternSet &pattern) const {
181c484c7ddSChia-hung Duan     for (const DialectReductionPatternInterface &interface : *this)
182c484c7ddSChia-hung Duan       interface.populateReductionPatterns(pattern);
183c484c7ddSChia-hung Duan   }
184c484c7ddSChia-hung Duan };
185c484c7ddSChia-hung Duan 
186c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
187c484c7ddSChia-hung Duan // ReductionTreePass
188c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
189c484c7ddSChia-hung Duan 
190c484c7ddSChia-hung Duan /// This class defines the Reduction Tree Pass. It provides a framework to
191c484c7ddSChia-hung Duan /// to implement a reduction pass using a tree structure to keep track of the
192c484c7ddSChia-hung Duan /// generated reduced variants.
19367d0d7acSMichele Scuttari class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> {
194c484c7ddSChia-hung Duan public:
195c484c7ddSChia-hung Duan   ReductionTreePass() = default;
196c484c7ddSChia-hung Duan   ReductionTreePass(const ReductionTreePass &pass) = default;
197c484c7ddSChia-hung Duan 
198c484c7ddSChia-hung Duan   LogicalResult initialize(MLIRContext *context) override;
199c484c7ddSChia-hung Duan 
200c484c7ddSChia-hung Duan   /// Runs the pass instance in the pass pipeline.
201c484c7ddSChia-hung Duan   void runOnOperation() override;
202c484c7ddSChia-hung Duan 
203c484c7ddSChia-hung Duan private:
2041a001dedSChia-hung Duan   LogicalResult reduceOp(ModuleOp module, Region &region);
205c484c7ddSChia-hung Duan 
206c484c7ddSChia-hung Duan   FrozenRewritePatternSet reducerPatterns;
207c484c7ddSChia-hung Duan };
208c484c7ddSChia-hung Duan 
209be0a7e9fSMehdi Amini } // namespace
210c484c7ddSChia-hung Duan 
211c484c7ddSChia-hung Duan LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
212c484c7ddSChia-hung Duan   RewritePatternSet patterns(context);
213c484c7ddSChia-hung Duan   ReductionPatternInterfaceCollection reducePatternCollection(context);
214c484c7ddSChia-hung Duan   reducePatternCollection.populateReductionPatterns(patterns);
215c484c7ddSChia-hung Duan   reducerPatterns = std::move(patterns);
216c484c7ddSChia-hung Duan   return success();
217c484c7ddSChia-hung Duan }
218c484c7ddSChia-hung Duan 
219c484c7ddSChia-hung Duan void ReductionTreePass::runOnOperation() {
220c484c7ddSChia-hung Duan   Operation *topOperation = getOperation();
221c484c7ddSChia-hung Duan   while (topOperation->getParentOp() != nullptr)
222c484c7ddSChia-hung Duan     topOperation = topOperation->getParentOp();
22355300991Srkayaith   ModuleOp module = dyn_cast<ModuleOp>(topOperation);
22455300991Srkayaith   if (!module) {
22555300991Srkayaith     emitError(getOperation()->getLoc())
22655300991Srkayaith         << "top-level op must be 'builtin.module'";
22755300991Srkayaith     return signalPassFailure();
22855300991Srkayaith   }
229c484c7ddSChia-hung Duan 
230c484c7ddSChia-hung Duan   SmallVector<Operation *, 8> workList;
231c484c7ddSChia-hung Duan   workList.push_back(getOperation());
232c484c7ddSChia-hung Duan 
233c484c7ddSChia-hung Duan   do {
234c484c7ddSChia-hung Duan     Operation *op = workList.pop_back_val();
235c484c7ddSChia-hung Duan 
236c484c7ddSChia-hung Duan     for (Region &region : op->getRegions())
237c484c7ddSChia-hung Duan       if (!region.empty())
2381a001dedSChia-hung Duan         if (failed(reduceOp(module, region)))
2391a001dedSChia-hung Duan           return signalPassFailure();
240c484c7ddSChia-hung Duan 
241c484c7ddSChia-hung Duan     for (Region &region : op->getRegions())
242c484c7ddSChia-hung Duan       for (Operation &op : region.getOps())
243c484c7ddSChia-hung Duan         if (op.getNumRegions() != 0)
244c484c7ddSChia-hung Duan           workList.push_back(&op);
245c484c7ddSChia-hung Duan   } while (!workList.empty());
246c484c7ddSChia-hung Duan }
247c484c7ddSChia-hung Duan 
2481a001dedSChia-hung Duan LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
249c484c7ddSChia-hung Duan   Tester test(testerName, testerArgs);
250c484c7ddSChia-hung Duan   switch (traversalModeId) {
251c484c7ddSChia-hung Duan   case TraversalMode::SinglePath:
2521a001dedSChia-hung Duan     return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
253c484c7ddSChia-hung Duan         module, region, reducerPatterns, test);
254c484c7ddSChia-hung Duan   default:
2551a001dedSChia-hung Duan     return module.emitError() << "unsupported traversal mode detected";
256c484c7ddSChia-hung Duan   }
257c484c7ddSChia-hung Duan }
258c484c7ddSChia-hung Duan 
259c484c7ddSChia-hung Duan std::unique_ptr<Pass> mlir::createReductionTreePass() {
260c484c7ddSChia-hung Duan   return std::make_unique<ReductionTreePass>();
261c484c7ddSChia-hung Duan }
262