xref: /llvm-project/mlir/lib/Reducer/ReductionTreePass.cpp (revision c484c7dd9d2382f07216ae9142ceb76272e21dc4)
1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/IR/DialectInterface.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/Reducer/PassDetail.h"
20 #include "mlir/Reducer/Passes.h"
21 #include "mlir/Reducer/ReductionNode.h"
22 #include "mlir/Reducer/ReductionPatternInterface.h"
23 #include "mlir/Reducer/Tester.h"
24 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/Support/Allocator.h"
30 #include "llvm/Support/ManagedStatic.h"
31 
32 using namespace mlir;
33 
34 /// We implicitly number each operation in the region and if an operation's
35 /// number falls into rangeToKeep, we need to keep it and apply the given
36 /// rewrite patterns on it.
37 static void applyPatterns(Region &region,
38                           const FrozenRewritePatternSet &patterns,
39                           ArrayRef<ReductionNode::Range> rangeToKeep,
40                           bool eraseOpNotInRange) {
41   std::vector<Operation *> opsNotInRange;
42   std::vector<Operation *> opsInRange;
43   size_t keepIndex = 0;
44   for (auto op : enumerate(region.getOps())) {
45     int index = op.index();
46     if (keepIndex < rangeToKeep.size() &&
47         index == rangeToKeep[keepIndex].second)
48       ++keepIndex;
49     if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
50       opsNotInRange.push_back(&op.value());
51     else
52       opsInRange.push_back(&op.value());
53   }
54 
55   // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
56   // matching in above iteration. Besides, erase op not-in-range may end up in
57   // invalid module, so `applyOpPatternsAndFold` should come before that
58   // transform.
59   for (Operation *op : opsInRange)
60     // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
61     // because we don't have expectation this reduction will be success or not.
62     (void)applyOpPatternsAndFold(op, patterns);
63 
64   if (eraseOpNotInRange)
65     for (Operation *op : opsNotInRange) {
66       op->dropAllUses();
67       op->erase();
68     }
69 }
70 
71 /// We will apply the reducer patterns to the operations in the ranges specified
72 /// by ReductionNode. Note that we are not able to remove an operation without
73 /// replacing it with another valid operation. However, The validity of module
74 /// reduction is based on the Tester provided by the user and that means certain
75 /// invalid module is still interested by the use. Thus we provide an
76 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
77 /// erase the operations not in the range specified by ReductionNode.
78 template <typename IteratorType>
79 static void findOptimal(ModuleOp module, Region &region,
80                         const FrozenRewritePatternSet &patterns,
81                         const Tester &test, bool eraseOpNotInRange) {
82   std::pair<Tester::Interestingness, size_t> initStatus =
83       test.isInteresting(module);
84   // While exploring the reduction tree, we always branch from an interesting
85   // node. Thus the root node must be interesting.
86   if (initStatus.first != Tester::Interestingness::True)
87     return;
88 
89   llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
90 
91   std::vector<ReductionNode::Range> ranges{
92       {0, std::distance(region.op_begin(), region.op_end())}};
93 
94   ReductionNode *root = allocator.Allocate();
95   new (root) ReductionNode(nullptr, std::move(ranges), allocator);
96   // Duplicate the module for root node and locate the region in the copy.
97   if (failed(root->initialize(module, region)))
98     llvm_unreachable("unexpected initialization failure");
99   root->update(initStatus);
100 
101   ReductionNode *smallestNode = root;
102   IteratorType iter(root);
103 
104   while (iter != IteratorType::end()) {
105     ReductionNode &currentNode = *iter;
106     Region &curRegion = currentNode.getRegion();
107 
108     applyPatterns(curRegion, patterns, currentNode.getRanges(),
109                   eraseOpNotInRange);
110     currentNode.update(test.isInteresting(currentNode.getModule()));
111 
112     if (currentNode.isInteresting() == Tester::Interestingness::True &&
113         currentNode.getSize() < smallestNode->getSize())
114       smallestNode = &currentNode;
115 
116     ++iter;
117   }
118 
119   // At here, we have found an optimal path to reduce the given region. Retrieve
120   // the path and apply the reducer to it.
121   SmallVector<ReductionNode *> trace;
122   ReductionNode *curNode = smallestNode;
123   trace.push_back(curNode);
124   while (curNode != root) {
125     curNode = curNode->getParent();
126     trace.push_back(curNode);
127   }
128 
129   // Reduce the region through the optimal path.
130   while (!trace.empty()) {
131     ReductionNode *top = trace.pop_back_val();
132     applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
133   }
134 
135   if (test.isInteresting(module).first != Tester::Interestingness::True)
136     llvm::report_fatal_error("Reduced module is not interesting");
137   if (test.isInteresting(module).second != smallestNode->getSize())
138     llvm::report_fatal_error(
139         "Reduced module doesn't have consistent size with smallestNode");
140 }
141 
142 template <typename IteratorType>
143 static void findOptimal(ModuleOp module, Region &region,
144                         const FrozenRewritePatternSet &patterns,
145                         const Tester &test) {
146   // We separate the reduction process into 2 steps, the first one is to erase
147   // redundant operations and the second one is to apply the reducer patterns.
148 
149   // In the first phase, we don't apply any patterns so that we only select the
150   // range of operations to keep to the module stay interesting.
151   findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
152                             /*eraseOpNotInRange=*/true);
153   // In the second phase, we suppose that no operation is redundant, so we try
154   // to rewrite the operation into simpler form.
155   findOptimal<IteratorType>(module, region, patterns, test,
156                             /*eraseOpNotInRange=*/false);
157 }
158 
159 namespace {
160 
161 //===----------------------------------------------------------------------===//
162 // Reduction Pattern Interface Collection
163 //===----------------------------------------------------------------------===//
164 
165 class ReductionPatternInterfaceCollection
166     : public DialectInterfaceCollection<DialectReductionPatternInterface> {
167 public:
168   using Base::Base;
169 
170   // Collect the reduce patterns defined by each dialect.
171   void populateReductionPatterns(RewritePatternSet &pattern) const {
172     for (const DialectReductionPatternInterface &interface : *this)
173       interface.populateReductionPatterns(pattern);
174   }
175 };
176 
177 //===----------------------------------------------------------------------===//
178 // ReductionTreePass
179 //===----------------------------------------------------------------------===//
180 
181 /// This class defines the Reduction Tree Pass. It provides a framework to
182 /// to implement a reduction pass using a tree structure to keep track of the
183 /// generated reduced variants.
184 class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
185 public:
186   ReductionTreePass() = default;
187   ReductionTreePass(const ReductionTreePass &pass) = default;
188 
189   LogicalResult initialize(MLIRContext *context) override;
190 
191   /// Runs the pass instance in the pass pipeline.
192   void runOnOperation() override;
193 
194 private:
195   void reduceOp(ModuleOp module, Region &region);
196 
197   FrozenRewritePatternSet reducerPatterns;
198 };
199 
200 } // end anonymous namespace
201 
202 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
203   RewritePatternSet patterns(context);
204   ReductionPatternInterfaceCollection reducePatternCollection(context);
205   reducePatternCollection.populateReductionPatterns(patterns);
206   reducerPatterns = std::move(patterns);
207   return success();
208 }
209 
210 void ReductionTreePass::runOnOperation() {
211   Operation *topOperation = getOperation();
212   while (topOperation->getParentOp() != nullptr)
213     topOperation = topOperation->getParentOp();
214   ModuleOp module = cast<ModuleOp>(topOperation);
215 
216   SmallVector<Operation *, 8> workList;
217   workList.push_back(getOperation());
218 
219   do {
220     Operation *op = workList.pop_back_val();
221 
222     for (Region &region : op->getRegions())
223       if (!region.empty())
224         reduceOp(module, region);
225 
226     for (Region &region : op->getRegions())
227       for (Operation &op : region.getOps())
228         if (op.getNumRegions() != 0)
229           workList.push_back(&op);
230   } while (!workList.empty());
231 }
232 
233 void ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
234   Tester test(testerName, testerArgs);
235   switch (traversalModeId) {
236   case TraversalMode::SinglePath:
237     findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
238         module, region, reducerPatterns, test);
239     break;
240   default:
241     llvm_unreachable("Unsupported mode");
242   }
243 }
244 
245 std::unique_ptr<Pass> mlir::createReductionTreePass() {
246   return std::make_unique<ReductionTreePass>();
247 }
248