xref: /llvm-project/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
18 #include "mlir/Dialect/SCF/Utils/Utils.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 struct TestSCFForUtilsPass
28     : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> {
29   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass)
30 
31   StringRef getArgument() const final { return "test-scf-for-utils"; }
32   StringRef getDescription() const final { return "test scf.for utils"; }
33   explicit TestSCFForUtilsPass() = default;
34   TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
35 
36   Option<bool> testReplaceWithNewYields{
37       *this, "test-replace-with-new-yields",
38       llvm::cl::desc("Test replacing a loop with a new loop that returns new "
39                      "additional yield values"),
40       llvm::cl::init(false)};
41 
42   void runOnOperation() override {
43     func::FuncOp func = getOperation();
44     SmallVector<scf::ForOp, 4> toErase;
45 
46     if (testReplaceWithNewYields) {
47       func.walk([&](scf::ForOp forOp) {
48         if (forOp.getNumResults() == 0)
49           return;
50         auto newInitValues = forOp.getInitArgs();
51         if (newInitValues.empty())
52           return;
53         SmallVector<Value> oldYieldValues =
54             llvm::to_vector(forOp.getYieldedValues());
55         NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
56                                   ArrayRef<BlockArgument> newBBArgs) {
57           SmallVector<Value> newYieldValues;
58           for (auto yieldVal : oldYieldValues) {
59             newYieldValues.push_back(
60                 b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
61           }
62           return newYieldValues;
63         };
64         IRRewriter rewriter(forOp.getContext());
65         if (failed(forOp.replaceWithAdditionalYields(
66                 rewriter, newInitValues, /*replaceInitOperandUsesInLoop=*/true,
67                 fn)))
68           signalPassFailure();
69       });
70     }
71   }
72 };
73 
74 struct TestSCFIfUtilsPass
75     : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
76   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass)
77 
78   StringRef getArgument() const final { return "test-scf-if-utils"; }
79   StringRef getDescription() const final { return "test scf.if utils"; }
80   explicit TestSCFIfUtilsPass() = default;
81 
82   void getDependentDialects(DialectRegistry &registry) const override {
83     registry.insert<func::FuncDialect>();
84   }
85 
86   void runOnOperation() override {
87     int count = 0;
88     getOperation().walk([&](scf::IfOp ifOp) {
89       auto strCount = std::to_string(count++);
90       func::FuncOp thenFn, elseFn;
91       OpBuilder b(ifOp);
92       IRRewriter rewriter(b);
93       if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
94                              std::string("outlined_then") + strCount, &elseFn,
95                              std::string("outlined_else") + strCount))) {
96         this->signalPassFailure();
97         return WalkResult::interrupt();
98       }
99       return WalkResult::advance();
100     });
101   }
102 };
103 
104 static const StringLiteral kTestPipeliningLoopMarker =
105     "__test_pipelining_loop__";
106 static const StringLiteral kTestPipeliningStageMarker =
107     "__test_pipelining_stage__";
108 /// Marker to express the order in which operations should be after
109 /// pipelining.
110 static const StringLiteral kTestPipeliningOpOrderMarker =
111     "__test_pipelining_op_order__";
112 
113 static const StringLiteral kTestPipeliningAnnotationPart =
114     "__test_pipelining_part";
115 static const StringLiteral kTestPipeliningAnnotationIteration =
116     "__test_pipelining_iteration";
117 
118 struct TestSCFPipeliningPass
119     : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> {
120   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass)
121 
122   TestSCFPipeliningPass() = default;
123   TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
124   StringRef getArgument() const final { return "test-scf-pipelining"; }
125   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
126 
127   Option<bool> annotatePipeline{
128       *this, "annotate",
129       llvm::cl::desc("Annote operations during loop pipelining transformation"),
130       llvm::cl::init(false)};
131 
132   Option<bool> noEpiloguePeeling{
133       *this, "no-epilogue-peeling",
134       llvm::cl::desc("Use predicates instead of peeling the epilogue."),
135       llvm::cl::init(false)};
136 
137   static void
138   getSchedule(scf::ForOp forOp,
139               std::vector<std::pair<Operation *, unsigned>> &schedule) {
140     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
141       return;
142 
143     schedule.resize(forOp.getBody()->getOperations().size() - 1);
144     forOp.walk([&schedule](Operation *op) {
145       auto attrStage =
146           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
147       auto attrCycle =
148           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
149       if (attrCycle && attrStage) {
150         // TODO: Index can be out-of-bounds if ops of the loop body disappear
151         // due to folding.
152         schedule[attrCycle.getInt()] =
153             std::make_pair(op, unsigned(attrStage.getInt()));
154       }
155     });
156   }
157 
158   /// Helper to generate "predicated" version of `op`. For simplicity we just
159   /// wrap the operation in a scf.ifOp operation.
160   static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
161                                 Value pred) {
162     Location loc = op->getLoc();
163     auto ifOp =
164         rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
165     // True branch.
166     rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
167                           ifOp.getThenRegion().front().begin());
168     rewriter.setInsertionPointAfter(op);
169     if (op->getNumResults() > 0)
170       rewriter.create<scf::YieldOp>(loc, op->getResults());
171     // False branch.
172     rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
173     SmallVector<Value> elseYieldOperands;
174     elseYieldOperands.reserve(ifOp.getNumResults());
175     if (auto viewOp = dyn_cast<memref::SubViewOp>(op)) {
176       // For sub-views, just clone the op.
177       // NOTE: This is okay in the test because we use dynamic memref sizes, so
178       // the verifier will not complain. Otherwise, we may create a logically
179       // out-of-bounds view and a different technique should be used.
180       Operation *opClone = rewriter.clone(*op);
181       elseYieldOperands.append(opClone->result_begin(), opClone->result_end());
182     } else {
183       // Default to assuming constant numeric values.
184       for (Type type : op->getResultTypes()) {
185         elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>(
186             loc, rewriter.getZeroAttr(type)));
187       }
188     }
189     if (op->getNumResults() > 0)
190       rewriter.create<scf::YieldOp>(loc, elseYieldOperands);
191     return ifOp.getOperation();
192   }
193 
194   static void annotate(Operation *op,
195                        mlir::scf::PipeliningOption::PipelinerPart part,
196                        unsigned iteration) {
197     OpBuilder b(op);
198     switch (part) {
199     case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
200       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
201       break;
202     case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
203       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
204       break;
205     case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
206       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
207       break;
208     }
209     op->setAttr(kTestPipeliningAnnotationIteration,
210                 b.getI32IntegerAttr(iteration));
211   }
212 
213   void getDependentDialects(DialectRegistry &registry) const override {
214     registry.insert<arith::ArithDialect, memref::MemRefDialect>();
215   }
216 
217   void runOnOperation() override {
218     RewritePatternSet patterns(&getContext());
219     mlir::scf::PipeliningOption options;
220     options.getScheduleFn = getSchedule;
221     options.supportDynamicLoops = true;
222     options.predicateFn = predicateOp;
223     if (annotatePipeline)
224       options.annotateFn = annotate;
225     if (noEpiloguePeeling) {
226       options.peelEpilogue = false;
227     }
228     scf::populateSCFLoopPipeliningPatterns(patterns, options);
229     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
230     getOperation().walk([](Operation *op) {
231       // Clean up the markers.
232       op->removeAttr(kTestPipeliningStageMarker);
233       op->removeAttr(kTestPipeliningOpOrderMarker);
234     });
235   }
236 };
237 } // namespace
238 
239 namespace mlir {
240 namespace test {
241 void registerTestSCFUtilsPass() {
242   PassRegistration<TestSCFForUtilsPass>();
243   PassRegistration<TestSCFIfUtilsPass>();
244   PassRegistration<TestSCFPipeliningPass>();
245 }
246 } // namespace test
247 } // namespace mlir
248