xref: /llvm-project/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
13fef2d26SRiver Riddle //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle //
93fef2d26SRiver Riddle // This file implements a pass to test SCF dialect utils.
103fef2d26SRiver Riddle //
113fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
123fef2d26SRiver Riddle 
13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1436550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
15f5fe92f6SChristopher Bate #include "mlir/Dialect/MemRef/IR/MemRef.h"
168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
174a6b31b8SAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h"
18f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
193fef2d26SRiver Riddle #include "mlir/IR/Builders.h"
20f6f88e66Sthomasraoux #include "mlir/IR/PatternMatch.h"
213fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
22f6f88e66Sthomasraoux #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
233fef2d26SRiver Riddle 
243fef2d26SRiver Riddle using namespace mlir;
253fef2d26SRiver Riddle 
263fef2d26SRiver Riddle namespace {
275e50dd04SRiver Riddle struct TestSCFForUtilsPass
2858ceae95SRiver Riddle     : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> {
295e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass)
305e50dd04SRiver Riddle 
31b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-scf-for-utils"; }
32b5e22e6dSMehdi Amini   StringRef getDescription() const final { return "test scf.for utils"; }
33b5e22e6dSMehdi Amini   explicit TestSCFForUtilsPass() = default;
34567fd523SMahesh Ravishankar   TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
35567fd523SMahesh Ravishankar 
36567fd523SMahesh Ravishankar   Option<bool> testReplaceWithNewYields{
37567fd523SMahesh Ravishankar       *this, "test-replace-with-new-yields",
38567fd523SMahesh Ravishankar       llvm::cl::desc("Test replacing a loop with a new loop that returns new "
3947fbb247SIngo Müller                      "additional yield values"),
40567fd523SMahesh Ravishankar       llvm::cl::init(false)};
413fef2d26SRiver Riddle 
4241574554SRiver Riddle   void runOnOperation() override {
4358ceae95SRiver Riddle     func::FuncOp func = getOperation();
443fef2d26SRiver Riddle     SmallVector<scf::ForOp, 4> toErase;
453fef2d26SRiver Riddle 
46567fd523SMahesh Ravishankar     if (testReplaceWithNewYields) {
47567fd523SMahesh Ravishankar       func.walk([&](scf::ForOp forOp) {
48567fd523SMahesh Ravishankar         if (forOp.getNumResults() == 0)
49567fd523SMahesh Ravishankar           return;
50567fd523SMahesh Ravishankar         auto newInitValues = forOp.getInitArgs();
51567fd523SMahesh Ravishankar         if (newInitValues.empty())
52567fd523SMahesh Ravishankar           return;
53ab737a86SMatthias Springer         SmallVector<Value> oldYieldValues =
54ab737a86SMatthias Springer             llvm::to_vector(forOp.getYieldedValues());
5563086d6aSMatthias Springer         NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
56567fd523SMahesh Ravishankar                                   ArrayRef<BlockArgument> newBBArgs) {
57567fd523SMahesh Ravishankar           SmallVector<Value> newYieldValues;
5863086d6aSMatthias Springer           for (auto yieldVal : oldYieldValues) {
59567fd523SMahesh Ravishankar             newYieldValues.push_back(
60567fd523SMahesh Ravishankar                 b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
61567fd523SMahesh Ravishankar           }
62567fd523SMahesh Ravishankar           return newYieldValues;
63567fd523SMahesh Ravishankar         };
6463086d6aSMatthias Springer         IRRewriter rewriter(forOp.getContext());
6563086d6aSMatthias Springer         if (failed(forOp.replaceWithAdditionalYields(
6663086d6aSMatthias Springer                 rewriter, newInitValues, /*replaceInitOperandUsesInLoop=*/true,
6763086d6aSMatthias Springer                 fn)))
6863086d6aSMatthias Springer           signalPassFailure();
69567fd523SMahesh Ravishankar       });
70567fd523SMahesh Ravishankar     }
71567fd523SMahesh Ravishankar   }
723fef2d26SRiver Riddle };
733fef2d26SRiver Riddle 
745e50dd04SRiver Riddle struct TestSCFIfUtilsPass
7511b67aafSNicolas Vasilache     : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
765e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass)
775e50dd04SRiver Riddle 
78b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-scf-if-utils"; }
79b5e22e6dSMehdi Amini   StringRef getDescription() const final { return "test scf.if utils"; }
80b5e22e6dSMehdi Amini   explicit TestSCFIfUtilsPass() = default;
813fef2d26SRiver Riddle 
82701f2409Sdonald chen   void getDependentDialects(DialectRegistry &registry) const override {
83701f2409Sdonald chen     registry.insert<func::FuncDialect>();
84701f2409Sdonald chen   }
85701f2409Sdonald chen 
8611b67aafSNicolas Vasilache   void runOnOperation() override {
873fef2d26SRiver Riddle     int count = 0;
8811b67aafSNicolas Vasilache     getOperation().walk([&](scf::IfOp ifOp) {
893fef2d26SRiver Riddle       auto strCount = std::to_string(count++);
9058ceae95SRiver Riddle       func::FuncOp thenFn, elseFn;
913fef2d26SRiver Riddle       OpBuilder b(ifOp);
9211b67aafSNicolas Vasilache       IRRewriter rewriter(b);
9311b67aafSNicolas Vasilache       if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
9411b67aafSNicolas Vasilache                              std::string("outlined_then") + strCount, &elseFn,
9511b67aafSNicolas Vasilache                              std::string("outlined_else") + strCount))) {
9611b67aafSNicolas Vasilache         this->signalPassFailure();
9711b67aafSNicolas Vasilache         return WalkResult::interrupt();
9811b67aafSNicolas Vasilache       }
9911b67aafSNicolas Vasilache       return WalkResult::advance();
1003fef2d26SRiver Riddle     });
1013fef2d26SRiver Riddle   }
1023fef2d26SRiver Riddle };
103f6f88e66Sthomasraoux 
104f6f88e66Sthomasraoux static const StringLiteral kTestPipeliningLoopMarker =
105f6f88e66Sthomasraoux     "__test_pipelining_loop__";
106f6f88e66Sthomasraoux static const StringLiteral kTestPipeliningStageMarker =
107f6f88e66Sthomasraoux     "__test_pipelining_stage__";
108567fd523SMahesh Ravishankar /// Marker to express the order in which operations should be after
109567fd523SMahesh Ravishankar /// pipelining.
110f6f88e66Sthomasraoux static const StringLiteral kTestPipeliningOpOrderMarker =
111f6f88e66Sthomasraoux     "__test_pipelining_op_order__";
112f6f88e66Sthomasraoux 
1130736bbd7SThomas Raoux static const StringLiteral kTestPipeliningAnnotationPart =
1140736bbd7SThomas Raoux     "__test_pipelining_part";
1150736bbd7SThomas Raoux static const StringLiteral kTestPipeliningAnnotationIteration =
1160736bbd7SThomas Raoux     "__test_pipelining_iteration";
1170736bbd7SThomas Raoux 
1185e50dd04SRiver Riddle struct TestSCFPipeliningPass
11958ceae95SRiver Riddle     : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> {
1205e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass)
1215e50dd04SRiver Riddle 
1220736bbd7SThomas Raoux   TestSCFPipeliningPass() = default;
1230736bbd7SThomas Raoux   TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
124f6f88e66Sthomasraoux   StringRef getArgument() const final { return "test-scf-pipelining"; }
125f6f88e66Sthomasraoux   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
1260736bbd7SThomas Raoux 
1270736bbd7SThomas Raoux   Option<bool> annotatePipeline{
1280736bbd7SThomas Raoux       *this, "annotate",
1290736bbd7SThomas Raoux       llvm::cl::desc("Annote operations during loop pipelining transformation"),
1300736bbd7SThomas Raoux       llvm::cl::init(false)};
131f6f88e66Sthomasraoux 
132205c08b5SThomas Raoux   Option<bool> noEpiloguePeeling{
133205c08b5SThomas Raoux       *this, "no-epilogue-peeling",
134205c08b5SThomas Raoux       llvm::cl::desc("Use predicates instead of peeling the epilogue."),
135205c08b5SThomas Raoux       llvm::cl::init(false)};
136205c08b5SThomas Raoux 
137f6f88e66Sthomasraoux   static void
138f6f88e66Sthomasraoux   getSchedule(scf::ForOp forOp,
139f6f88e66Sthomasraoux               std::vector<std::pair<Operation *, unsigned>> &schedule) {
140f6f88e66Sthomasraoux     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
141f6f88e66Sthomasraoux       return;
142f5fe92f6SChristopher Bate 
143f6f88e66Sthomasraoux     schedule.resize(forOp.getBody()->getOperations().size() - 1);
144f6f88e66Sthomasraoux     forOp.walk([&schedule](Operation *op) {
145f6f88e66Sthomasraoux       auto attrStage =
146f6f88e66Sthomasraoux           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
147f6f88e66Sthomasraoux       auto attrCycle =
148f6f88e66Sthomasraoux           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
149f6f88e66Sthomasraoux       if (attrCycle && attrStage) {
150ed9194beSMatthias Springer         // TODO: Index can be out-of-bounds if ops of the loop body disappear
151ed9194beSMatthias Springer         // due to folding.
152f6f88e66Sthomasraoux         schedule[attrCycle.getInt()] =
153f6f88e66Sthomasraoux             std::make_pair(op, unsigned(attrStage.getInt()));
154f6f88e66Sthomasraoux       }
155f6f88e66Sthomasraoux     });
156f6f88e66Sthomasraoux   }
157f6f88e66Sthomasraoux 
158205c08b5SThomas Raoux   /// Helper to generate "predicated" version of `op`. For simplicity we just
159205c08b5SThomas Raoux   /// wrap the operation in a scf.ifOp operation.
1601cff4cbdSNicolas Vasilache   static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
1611cff4cbdSNicolas Vasilache                                 Value pred) {
162205c08b5SThomas Raoux     Location loc = op->getLoc();
163205c08b5SThomas Raoux     auto ifOp =
164205c08b5SThomas Raoux         rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
165205c08b5SThomas Raoux     // True branch.
1665cc0f76dSMatthias Springer     rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
167f5fe92f6SChristopher Bate                           ifOp.getThenRegion().front().begin());
168205c08b5SThomas Raoux     rewriter.setInsertionPointAfter(op);
169f5fe92f6SChristopher Bate     if (op->getNumResults() > 0)
170205c08b5SThomas Raoux       rewriter.create<scf::YieldOp>(loc, op->getResults());
171205c08b5SThomas Raoux     // False branch.
172205c08b5SThomas Raoux     rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
173f5fe92f6SChristopher Bate     SmallVector<Value> elseYieldOperands;
174f5fe92f6SChristopher Bate     elseYieldOperands.reserve(ifOp.getNumResults());
175f5fe92f6SChristopher Bate     if (auto viewOp = dyn_cast<memref::SubViewOp>(op)) {
176f5fe92f6SChristopher Bate       // For sub-views, just clone the op.
177f5fe92f6SChristopher Bate       // NOTE: This is okay in the test because we use dynamic memref sizes, so
178f5fe92f6SChristopher Bate       // the verifier will not complain. Otherwise, we may create a logically
179f5fe92f6SChristopher Bate       // out-of-bounds view and a different technique should be used.
180f5fe92f6SChristopher Bate       Operation *opClone = rewriter.clone(*op);
181f5fe92f6SChristopher Bate       elseYieldOperands.append(opClone->result_begin(), opClone->result_end());
182f5fe92f6SChristopher Bate     } else {
183f5fe92f6SChristopher Bate       // Default to assuming constant numeric values.
184205c08b5SThomas Raoux       for (Type type : op->getResultTypes()) {
185f5fe92f6SChristopher Bate         elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>(
186f5fe92f6SChristopher Bate             loc, rewriter.getZeroAttr(type)));
187205c08b5SThomas Raoux       }
188f5fe92f6SChristopher Bate     }
189f5fe92f6SChristopher Bate     if (op->getNumResults() > 0)
190f5fe92f6SChristopher Bate       rewriter.create<scf::YieldOp>(loc, elseYieldOperands);
191205c08b5SThomas Raoux     return ifOp.getOperation();
192205c08b5SThomas Raoux   }
193205c08b5SThomas Raoux 
1940736bbd7SThomas Raoux   static void annotate(Operation *op,
1950736bbd7SThomas Raoux                        mlir::scf::PipeliningOption::PipelinerPart part,
1960736bbd7SThomas Raoux                        unsigned iteration) {
1970736bbd7SThomas Raoux     OpBuilder b(op);
1980736bbd7SThomas Raoux     switch (part) {
1990736bbd7SThomas Raoux     case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
2000736bbd7SThomas Raoux       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
2010736bbd7SThomas Raoux       break;
2020736bbd7SThomas Raoux     case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
2030736bbd7SThomas Raoux       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
2040736bbd7SThomas Raoux       break;
2050736bbd7SThomas Raoux     case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
2060736bbd7SThomas Raoux       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
2070736bbd7SThomas Raoux       break;
2080736bbd7SThomas Raoux     }
2090736bbd7SThomas Raoux     op->setAttr(kTestPipeliningAnnotationIteration,
2100736bbd7SThomas Raoux                 b.getI32IntegerAttr(iteration));
2110736bbd7SThomas Raoux   }
2120736bbd7SThomas Raoux 
213a54f4eaeSMogball   void getDependentDialects(DialectRegistry &registry) const override {
214abc362a1SJakub Kuderski     registry.insert<arith::ArithDialect, memref::MemRefDialect>();
215a54f4eaeSMogball   }
216a54f4eaeSMogball 
21741574554SRiver Riddle   void runOnOperation() override {
218f6f88e66Sthomasraoux     RewritePatternSet patterns(&getContext());
219f6f88e66Sthomasraoux     mlir::scf::PipeliningOption options;
220f6f88e66Sthomasraoux     options.getScheduleFn = getSchedule;
221ebf05993SSJW     options.supportDynamicLoops = true;
222ebf05993SSJW     options.predicateFn = predicateOp;
2230736bbd7SThomas Raoux     if (annotatePipeline)
2240736bbd7SThomas Raoux       options.annotateFn = annotate;
225205c08b5SThomas Raoux     if (noEpiloguePeeling) {
226205c08b5SThomas Raoux       options.peelEpilogue = false;
227205c08b5SThomas Raoux     }
228f6f88e66Sthomasraoux     scf::populateSCFLoopPipeliningPatterns(patterns, options);
229*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
23041574554SRiver Riddle     getOperation().walk([](Operation *op) {
231f6f88e66Sthomasraoux       // Clean up the markers.
232f6f88e66Sthomasraoux       op->removeAttr(kTestPipeliningStageMarker);
233f6f88e66Sthomasraoux       op->removeAttr(kTestPipeliningOpOrderMarker);
234f6f88e66Sthomasraoux     });
235f6f88e66Sthomasraoux   }
236f6f88e66Sthomasraoux };
2373fef2d26SRiver Riddle } // namespace
2383fef2d26SRiver Riddle 
2393fef2d26SRiver Riddle namespace mlir {
2403fef2d26SRiver Riddle namespace test {
2413fef2d26SRiver Riddle void registerTestSCFUtilsPass() {
242b5e22e6dSMehdi Amini   PassRegistration<TestSCFForUtilsPass>();
243b5e22e6dSMehdi Amini   PassRegistration<TestSCFIfUtilsPass>();
244f6f88e66Sthomasraoux   PassRegistration<TestSCFPipeliningPass>();
2453fef2d26SRiver Riddle }
2463fef2d26SRiver Riddle } // namespace test
2473fef2d26SRiver Riddle } // namespace mlir
248