xref: /llvm-project/mlir/lib/Reducer/ReductionTreePass.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/IR/DialectInterface.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/Reducer/PassDetail.h"
20 #include "mlir/Reducer/Passes.h"
21 #include "mlir/Reducer/ReductionNode.h"
22 #include "mlir/Reducer/ReductionPatternInterface.h"
23 #include "mlir/Reducer/Tester.h"
24 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/Support/Allocator.h"
30 #include "llvm/Support/ManagedStatic.h"
31 
32 using namespace mlir;
33 
34 /// We implicitly number each operation in the region and if an operation's
35 /// number falls into rangeToKeep, we need to keep it and apply the given
36 /// rewrite patterns on it.
37 static void applyPatterns(Region &region,
38                           const FrozenRewritePatternSet &patterns,
39                           ArrayRef<ReductionNode::Range> rangeToKeep,
40                           bool eraseOpNotInRange) {
41   std::vector<Operation *> opsNotInRange;
42   std::vector<Operation *> opsInRange;
43   size_t keepIndex = 0;
44   for (const auto &op : enumerate(region.getOps())) {
45     int index = op.index();
46     if (keepIndex < rangeToKeep.size() &&
47         index == rangeToKeep[keepIndex].second)
48       ++keepIndex;
49     if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
50       opsNotInRange.push_back(&op.value());
51     else
52       opsInRange.push_back(&op.value());
53   }
54 
55   // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
56   // matching in above iteration. Besides, erase op not-in-range may end up in
57   // invalid module, so `applyOpPatternsAndFold` should come before that
58   // transform.
59   for (Operation *op : opsInRange)
60     // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
61     // because we don't have expectation this reduction will be success or not.
62     (void)applyOpPatternsAndFold(op, patterns);
63 
64   if (eraseOpNotInRange)
65     for (Operation *op : opsNotInRange) {
66       op->dropAllUses();
67       op->erase();
68     }
69 }
70 
71 /// We will apply the reducer patterns to the operations in the ranges specified
72 /// by ReductionNode. Note that we are not able to remove an operation without
73 /// replacing it with another valid operation. However, The validity of module
74 /// reduction is based on the Tester provided by the user and that means certain
75 /// invalid module is still interested by the use. Thus we provide an
76 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
77 /// erase the operations not in the range specified by ReductionNode.
78 template <typename IteratorType>
79 static LogicalResult findOptimal(ModuleOp module, Region &region,
80                                  const FrozenRewritePatternSet &patterns,
81                                  const Tester &test, bool eraseOpNotInRange) {
82   std::pair<Tester::Interestingness, size_t> initStatus =
83       test.isInteresting(module);
84   // While exploring the reduction tree, we always branch from an interesting
85   // node. Thus the root node must be interesting.
86   if (initStatus.first != Tester::Interestingness::True)
87     return module.emitWarning() << "uninterested module will not be reduced";
88 
89   llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
90 
91   std::vector<ReductionNode::Range> ranges{
92       {0, std::distance(region.op_begin(), region.op_end())}};
93 
94   ReductionNode *root = allocator.Allocate();
95   new (root) ReductionNode(nullptr, ranges, allocator);
96   // Duplicate the module for root node and locate the region in the copy.
97   if (failed(root->initialize(module, region)))
98     llvm_unreachable("unexpected initialization failure");
99   root->update(initStatus);
100 
101   ReductionNode *smallestNode = root;
102   IteratorType iter(root);
103 
104   while (iter != IteratorType::end()) {
105     ReductionNode &currentNode = *iter;
106     Region &curRegion = currentNode.getRegion();
107 
108     applyPatterns(curRegion, patterns, currentNode.getRanges(),
109                   eraseOpNotInRange);
110     currentNode.update(test.isInteresting(currentNode.getModule()));
111 
112     if (currentNode.isInteresting() == Tester::Interestingness::True &&
113         currentNode.getSize() < smallestNode->getSize())
114       smallestNode = &currentNode;
115 
116     ++iter;
117   }
118 
119   // At here, we have found an optimal path to reduce the given region. Retrieve
120   // the path and apply the reducer to it.
121   SmallVector<ReductionNode *> trace;
122   ReductionNode *curNode = smallestNode;
123   trace.push_back(curNode);
124   while (curNode != root) {
125     curNode = curNode->getParent();
126     trace.push_back(curNode);
127   }
128 
129   // Reduce the region through the optimal path.
130   while (!trace.empty()) {
131     ReductionNode *top = trace.pop_back_val();
132     applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
133   }
134 
135   if (test.isInteresting(module).first != Tester::Interestingness::True)
136     llvm::report_fatal_error("Reduced module is not interesting");
137   if (test.isInteresting(module).second != smallestNode->getSize())
138     llvm::report_fatal_error(
139         "Reduced module doesn't have consistent size with smallestNode");
140   return success();
141 }
142 
143 template <typename IteratorType>
144 static LogicalResult findOptimal(ModuleOp module, Region &region,
145                                  const FrozenRewritePatternSet &patterns,
146                                  const Tester &test) {
147   // We separate the reduction process into 2 steps, the first one is to erase
148   // redundant operations and the second one is to apply the reducer patterns.
149 
150   // In the first phase, we don't apply any patterns so that we only select the
151   // range of operations to keep to the module stay interesting.
152   if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
153                                        /*eraseOpNotInRange=*/true)))
154     return failure();
155   // In the second phase, we suppose that no operation is redundant, so we try
156   // to rewrite the operation into simpler form.
157   return findOptimal<IteratorType>(module, region, patterns, test,
158                                    /*eraseOpNotInRange=*/false);
159 }
160 
161 namespace {
162 
163 //===----------------------------------------------------------------------===//
164 // Reduction Pattern Interface Collection
165 //===----------------------------------------------------------------------===//
166 
167 class ReductionPatternInterfaceCollection
168     : public DialectInterfaceCollection<DialectReductionPatternInterface> {
169 public:
170   using Base::Base;
171 
172   // Collect the reduce patterns defined by each dialect.
173   void populateReductionPatterns(RewritePatternSet &pattern) const {
174     for (const DialectReductionPatternInterface &interface : *this)
175       interface.populateReductionPatterns(pattern);
176   }
177 };
178 
179 //===----------------------------------------------------------------------===//
180 // ReductionTreePass
181 //===----------------------------------------------------------------------===//
182 
183 /// This class defines the Reduction Tree Pass. It provides a framework to
184 /// to implement a reduction pass using a tree structure to keep track of the
185 /// generated reduced variants.
186 class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
187 public:
188   ReductionTreePass() = default;
189   ReductionTreePass(const ReductionTreePass &pass) = default;
190 
191   LogicalResult initialize(MLIRContext *context) override;
192 
193   /// Runs the pass instance in the pass pipeline.
194   void runOnOperation() override;
195 
196 private:
197   LogicalResult reduceOp(ModuleOp module, Region &region);
198 
199   FrozenRewritePatternSet reducerPatterns;
200 };
201 
202 } // namespace
203 
204 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
205   RewritePatternSet patterns(context);
206   ReductionPatternInterfaceCollection reducePatternCollection(context);
207   reducePatternCollection.populateReductionPatterns(patterns);
208   reducerPatterns = std::move(patterns);
209   return success();
210 }
211 
212 void ReductionTreePass::runOnOperation() {
213   Operation *topOperation = getOperation();
214   while (topOperation->getParentOp() != nullptr)
215     topOperation = topOperation->getParentOp();
216   ModuleOp module = cast<ModuleOp>(topOperation);
217 
218   SmallVector<Operation *, 8> workList;
219   workList.push_back(getOperation());
220 
221   do {
222     Operation *op = workList.pop_back_val();
223 
224     for (Region &region : op->getRegions())
225       if (!region.empty())
226         if (failed(reduceOp(module, region)))
227           return signalPassFailure();
228 
229     for (Region &region : op->getRegions())
230       for (Operation &op : region.getOps())
231         if (op.getNumRegions() != 0)
232           workList.push_back(&op);
233   } while (!workList.empty());
234 }
235 
236 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
237   Tester test(testerName, testerArgs);
238   switch (traversalModeId) {
239   case TraversalMode::SinglePath:
240     return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
241         module, region, reducerPatterns, test);
242   default:
243     return module.emitError() << "unsupported traversal mode detected";
244   }
245 }
246 
247 std::unique_ptr<Pass> mlir::createReductionTreePass() {
248   return std::make_unique<ReductionTreePass>();
249 }
250