xref: /llvm-project/mlir/lib/Reducer/ReductionTreePass.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/IR/DialectInterface.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/Reducer/Passes.h"
20 #include "mlir/Reducer/ReductionNode.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Reducer/Tester.h"
23 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/Allocator.h"
29 #include "llvm/Support/ManagedStatic.h"
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_REDUCTIONTREE
33 #include "mlir/Reducer/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 /// We implicitly number each operation in the region and if an operation's
39 /// number falls into rangeToKeep, we need to keep it and apply the given
40 /// rewrite patterns on it.
41 static void applyPatterns(Region &region,
42                           const FrozenRewritePatternSet &patterns,
43                           ArrayRef<ReductionNode::Range> rangeToKeep,
44                           bool eraseOpNotInRange) {
45   std::vector<Operation *> opsNotInRange;
46   std::vector<Operation *> opsInRange;
47   size_t keepIndex = 0;
48   for (const auto &op : enumerate(region.getOps())) {
49     int index = op.index();
50     if (keepIndex < rangeToKeep.size() &&
51         index == rangeToKeep[keepIndex].second)
52       ++keepIndex;
53     if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
54       opsNotInRange.push_back(&op.value());
55     else
56       opsInRange.push_back(&op.value());
57   }
58 
59   // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
60   // matching in above iteration. Besides, erase op not-in-range may end up in
61   // invalid module, so `applyOpPatternsAndFold` should come before that
62   // transform.
63   for (Operation *op : opsInRange) {
64     // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
65     // because we don't have expectation this reduction will be success or not.
66     GreedyRewriteConfig config;
67     config.strictMode = GreedyRewriteStrictness::ExistingOps;
68     (void)applyOpPatternsGreedily(op, patterns, config);
69   }
70 
71   if (eraseOpNotInRange)
72     for (Operation *op : opsNotInRange) {
73       op->dropAllUses();
74       op->erase();
75     }
76 }
77 
78 /// We will apply the reducer patterns to the operations in the ranges specified
79 /// by ReductionNode. Note that we are not able to remove an operation without
80 /// replacing it with another valid operation. However, The validity of module
81 /// reduction is based on the Tester provided by the user and that means certain
82 /// invalid module is still interested by the use. Thus we provide an
83 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
84 /// erase the operations not in the range specified by ReductionNode.
85 template <typename IteratorType>
86 static LogicalResult findOptimal(ModuleOp module, Region &region,
87                                  const FrozenRewritePatternSet &patterns,
88                                  const Tester &test, bool eraseOpNotInRange) {
89   std::pair<Tester::Interestingness, size_t> initStatus =
90       test.isInteresting(module);
91   // While exploring the reduction tree, we always branch from an interesting
92   // node. Thus the root node must be interesting.
93   if (initStatus.first != Tester::Interestingness::True)
94     return module.emitWarning() << "uninterested module will not be reduced";
95 
96   llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
97 
98   std::vector<ReductionNode::Range> ranges{
99       {0, std::distance(region.op_begin(), region.op_end())}};
100 
101   ReductionNode *root = allocator.Allocate();
102   new (root) ReductionNode(nullptr, ranges, allocator);
103   // Duplicate the module for root node and locate the region in the copy.
104   if (failed(root->initialize(module, region)))
105     llvm_unreachable("unexpected initialization failure");
106   root->update(initStatus);
107 
108   ReductionNode *smallestNode = root;
109   IteratorType iter(root);
110 
111   while (iter != IteratorType::end()) {
112     ReductionNode &currentNode = *iter;
113     Region &curRegion = currentNode.getRegion();
114 
115     applyPatterns(curRegion, patterns, currentNode.getRanges(),
116                   eraseOpNotInRange);
117     currentNode.update(test.isInteresting(currentNode.getModule()));
118 
119     if (currentNode.isInteresting() == Tester::Interestingness::True &&
120         currentNode.getSize() < smallestNode->getSize())
121       smallestNode = &currentNode;
122 
123     ++iter;
124   }
125 
126   // At here, we have found an optimal path to reduce the given region. Retrieve
127   // the path and apply the reducer to it.
128   SmallVector<ReductionNode *> trace;
129   ReductionNode *curNode = smallestNode;
130   trace.push_back(curNode);
131   while (curNode != root) {
132     curNode = curNode->getParent();
133     trace.push_back(curNode);
134   }
135 
136   // Reduce the region through the optimal path.
137   while (!trace.empty()) {
138     ReductionNode *top = trace.pop_back_val();
139     applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
140   }
141 
142   if (test.isInteresting(module).first != Tester::Interestingness::True)
143     llvm::report_fatal_error("Reduced module is not interesting");
144   if (test.isInteresting(module).second != smallestNode->getSize())
145     llvm::report_fatal_error(
146         "Reduced module doesn't have consistent size with smallestNode");
147   return success();
148 }
149 
150 template <typename IteratorType>
151 static LogicalResult findOptimal(ModuleOp module, Region &region,
152                                  const FrozenRewritePatternSet &patterns,
153                                  const Tester &test) {
154   // We separate the reduction process into 2 steps, the first one is to erase
155   // redundant operations and the second one is to apply the reducer patterns.
156 
157   // In the first phase, we don't apply any patterns so that we only select the
158   // range of operations to keep to the module stay interesting.
159   if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
160                                        /*eraseOpNotInRange=*/true)))
161     return failure();
162   // In the second phase, we suppose that no operation is redundant, so we try
163   // to rewrite the operation into simpler form.
164   return findOptimal<IteratorType>(module, region, patterns, test,
165                                    /*eraseOpNotInRange=*/false);
166 }
167 
168 namespace {
169 
170 //===----------------------------------------------------------------------===//
171 // Reduction Pattern Interface Collection
172 //===----------------------------------------------------------------------===//
173 
174 class ReductionPatternInterfaceCollection
175     : public DialectInterfaceCollection<DialectReductionPatternInterface> {
176 public:
177   using Base::Base;
178 
179   // Collect the reduce patterns defined by each dialect.
180   void populateReductionPatterns(RewritePatternSet &pattern) const {
181     for (const DialectReductionPatternInterface &interface : *this)
182       interface.populateReductionPatterns(pattern);
183   }
184 };
185 
186 //===----------------------------------------------------------------------===//
187 // ReductionTreePass
188 //===----------------------------------------------------------------------===//
189 
190 /// This class defines the Reduction Tree Pass. It provides a framework to
191 /// to implement a reduction pass using a tree structure to keep track of the
192 /// generated reduced variants.
193 class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> {
194 public:
195   ReductionTreePass() = default;
196   ReductionTreePass(const ReductionTreePass &pass) = default;
197 
198   LogicalResult initialize(MLIRContext *context) override;
199 
200   /// Runs the pass instance in the pass pipeline.
201   void runOnOperation() override;
202 
203 private:
204   LogicalResult reduceOp(ModuleOp module, Region &region);
205 
206   FrozenRewritePatternSet reducerPatterns;
207 };
208 
209 } // namespace
210 
211 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
212   RewritePatternSet patterns(context);
213   ReductionPatternInterfaceCollection reducePatternCollection(context);
214   reducePatternCollection.populateReductionPatterns(patterns);
215   reducerPatterns = std::move(patterns);
216   return success();
217 }
218 
219 void ReductionTreePass::runOnOperation() {
220   Operation *topOperation = getOperation();
221   while (topOperation->getParentOp() != nullptr)
222     topOperation = topOperation->getParentOp();
223   ModuleOp module = dyn_cast<ModuleOp>(topOperation);
224   if (!module) {
225     emitError(getOperation()->getLoc())
226         << "top-level op must be 'builtin.module'";
227     return signalPassFailure();
228   }
229 
230   SmallVector<Operation *, 8> workList;
231   workList.push_back(getOperation());
232 
233   do {
234     Operation *op = workList.pop_back_val();
235 
236     for (Region &region : op->getRegions())
237       if (!region.empty())
238         if (failed(reduceOp(module, region)))
239           return signalPassFailure();
240 
241     for (Region &region : op->getRegions())
242       for (Operation &op : region.getOps())
243         if (op.getNumRegions() != 0)
244           workList.push_back(&op);
245   } while (!workList.empty());
246 }
247 
248 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
249   Tester test(testerName, testerArgs);
250   switch (traversalModeId) {
251   case TraversalMode::SinglePath:
252     return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
253         module, region, reducerPatterns, test);
254   default:
255     return module.emitError() << "unsupported traversal mode detected";
256   }
257 }
258 
259 std::unique_ptr<Pass> mlir::createReductionTreePass() {
260   return std::make_unique<ReductionTreePass>();
261 }
262