13fef2d26SRiver Riddle //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===// 23fef2d26SRiver Riddle // 33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63fef2d26SRiver Riddle // 73fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 83fef2d26SRiver Riddle // 93fef2d26SRiver Riddle // This file implements a pass to test SCF dialect utils. 103fef2d26SRiver Riddle // 113fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 123fef2d26SRiver Riddle 13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1436550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 15f5fe92f6SChristopher Bate #include "mlir/Dialect/MemRef/IR/MemRef.h" 168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 174a6b31b8SAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h" 18f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h" 193fef2d26SRiver Riddle #include "mlir/IR/Builders.h" 20f6f88e66Sthomasraoux #include "mlir/IR/PatternMatch.h" 213fef2d26SRiver Riddle #include "mlir/Pass/Pass.h" 22f6f88e66Sthomasraoux #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 233fef2d26SRiver Riddle 243fef2d26SRiver Riddle using namespace mlir; 253fef2d26SRiver Riddle 263fef2d26SRiver Riddle namespace { 275e50dd04SRiver Riddle struct TestSCFForUtilsPass 2858ceae95SRiver Riddle : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> { 295e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass) 305e50dd04SRiver Riddle 31b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-scf-for-utils"; } 32b5e22e6dSMehdi Amini StringRef getDescription() const final { return "test scf.for utils"; } 33b5e22e6dSMehdi Amini explicit TestSCFForUtilsPass() = default; 34567fd523SMahesh Ravishankar TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {} 35567fd523SMahesh Ravishankar 36567fd523SMahesh Ravishankar Option<bool> testReplaceWithNewYields{ 37567fd523SMahesh Ravishankar *this, "test-replace-with-new-yields", 38567fd523SMahesh Ravishankar llvm::cl::desc("Test replacing a loop with a new loop that returns new " 3947fbb247SIngo Müller "additional yield values"), 40567fd523SMahesh Ravishankar llvm::cl::init(false)}; 413fef2d26SRiver Riddle 4241574554SRiver Riddle void runOnOperation() override { 4358ceae95SRiver Riddle func::FuncOp func = getOperation(); 443fef2d26SRiver Riddle SmallVector<scf::ForOp, 4> toErase; 453fef2d26SRiver Riddle 46567fd523SMahesh Ravishankar if (testReplaceWithNewYields) { 47567fd523SMahesh Ravishankar func.walk([&](scf::ForOp forOp) { 48567fd523SMahesh Ravishankar if (forOp.getNumResults() == 0) 49567fd523SMahesh Ravishankar return; 50567fd523SMahesh Ravishankar auto newInitValues = forOp.getInitArgs(); 51567fd523SMahesh Ravishankar if (newInitValues.empty()) 52567fd523SMahesh Ravishankar return; 53ab737a86SMatthias Springer SmallVector<Value> oldYieldValues = 54ab737a86SMatthias Springer llvm::to_vector(forOp.getYieldedValues()); 5563086d6aSMatthias Springer NewYieldValuesFn fn = [&](OpBuilder &b, Location loc, 56567fd523SMahesh Ravishankar ArrayRef<BlockArgument> newBBArgs) { 57567fd523SMahesh Ravishankar SmallVector<Value> newYieldValues; 5863086d6aSMatthias Springer for (auto yieldVal : oldYieldValues) { 59567fd523SMahesh Ravishankar newYieldValues.push_back( 60567fd523SMahesh Ravishankar b.create<arith::AddFOp>(loc, yieldVal, yieldVal)); 61567fd523SMahesh Ravishankar } 62567fd523SMahesh Ravishankar return newYieldValues; 63567fd523SMahesh Ravishankar }; 6463086d6aSMatthias Springer IRRewriter rewriter(forOp.getContext()); 6563086d6aSMatthias Springer if (failed(forOp.replaceWithAdditionalYields( 6663086d6aSMatthias Springer rewriter, newInitValues, /*replaceInitOperandUsesInLoop=*/true, 6763086d6aSMatthias Springer fn))) 6863086d6aSMatthias Springer signalPassFailure(); 69567fd523SMahesh Ravishankar }); 70567fd523SMahesh Ravishankar } 71567fd523SMahesh Ravishankar } 723fef2d26SRiver Riddle }; 733fef2d26SRiver Riddle 745e50dd04SRiver Riddle struct TestSCFIfUtilsPass 7511b67aafSNicolas Vasilache : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> { 765e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass) 775e50dd04SRiver Riddle 78b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-scf-if-utils"; } 79b5e22e6dSMehdi Amini StringRef getDescription() const final { return "test scf.if utils"; } 80b5e22e6dSMehdi Amini explicit TestSCFIfUtilsPass() = default; 813fef2d26SRiver Riddle 82701f2409Sdonald chen void getDependentDialects(DialectRegistry ®istry) const override { 83701f2409Sdonald chen registry.insert<func::FuncDialect>(); 84701f2409Sdonald chen } 85701f2409Sdonald chen 8611b67aafSNicolas Vasilache void runOnOperation() override { 873fef2d26SRiver Riddle int count = 0; 8811b67aafSNicolas Vasilache getOperation().walk([&](scf::IfOp ifOp) { 893fef2d26SRiver Riddle auto strCount = std::to_string(count++); 9058ceae95SRiver Riddle func::FuncOp thenFn, elseFn; 913fef2d26SRiver Riddle OpBuilder b(ifOp); 9211b67aafSNicolas Vasilache IRRewriter rewriter(b); 9311b67aafSNicolas Vasilache if (failed(outlineIfOp(rewriter, ifOp, &thenFn, 9411b67aafSNicolas Vasilache std::string("outlined_then") + strCount, &elseFn, 9511b67aafSNicolas Vasilache std::string("outlined_else") + strCount))) { 9611b67aafSNicolas Vasilache this->signalPassFailure(); 9711b67aafSNicolas Vasilache return WalkResult::interrupt(); 9811b67aafSNicolas Vasilache } 9911b67aafSNicolas Vasilache return WalkResult::advance(); 1003fef2d26SRiver Riddle }); 1013fef2d26SRiver Riddle } 1023fef2d26SRiver Riddle }; 103f6f88e66Sthomasraoux 104f6f88e66Sthomasraoux static const StringLiteral kTestPipeliningLoopMarker = 105f6f88e66Sthomasraoux "__test_pipelining_loop__"; 106f6f88e66Sthomasraoux static const StringLiteral kTestPipeliningStageMarker = 107f6f88e66Sthomasraoux "__test_pipelining_stage__"; 108567fd523SMahesh Ravishankar /// Marker to express the order in which operations should be after 109567fd523SMahesh Ravishankar /// pipelining. 110f6f88e66Sthomasraoux static const StringLiteral kTestPipeliningOpOrderMarker = 111f6f88e66Sthomasraoux "__test_pipelining_op_order__"; 112f6f88e66Sthomasraoux 1130736bbd7SThomas Raoux static const StringLiteral kTestPipeliningAnnotationPart = 1140736bbd7SThomas Raoux "__test_pipelining_part"; 1150736bbd7SThomas Raoux static const StringLiteral kTestPipeliningAnnotationIteration = 1160736bbd7SThomas Raoux "__test_pipelining_iteration"; 1170736bbd7SThomas Raoux 1185e50dd04SRiver Riddle struct TestSCFPipeliningPass 11958ceae95SRiver Riddle : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> { 1205e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass) 1215e50dd04SRiver Riddle 1220736bbd7SThomas Raoux TestSCFPipeliningPass() = default; 1230736bbd7SThomas Raoux TestSCFPipeliningPass(const TestSCFPipeliningPass &) {} 124f6f88e66Sthomasraoux StringRef getArgument() const final { return "test-scf-pipelining"; } 125f6f88e66Sthomasraoux StringRef getDescription() const final { return "test scf.forOp pipelining"; } 1260736bbd7SThomas Raoux 1270736bbd7SThomas Raoux Option<bool> annotatePipeline{ 1280736bbd7SThomas Raoux *this, "annotate", 1290736bbd7SThomas Raoux llvm::cl::desc("Annote operations during loop pipelining transformation"), 1300736bbd7SThomas Raoux llvm::cl::init(false)}; 131f6f88e66Sthomasraoux 132205c08b5SThomas Raoux Option<bool> noEpiloguePeeling{ 133205c08b5SThomas Raoux *this, "no-epilogue-peeling", 134205c08b5SThomas Raoux llvm::cl::desc("Use predicates instead of peeling the epilogue."), 135205c08b5SThomas Raoux llvm::cl::init(false)}; 136205c08b5SThomas Raoux 137f6f88e66Sthomasraoux static void 138f6f88e66Sthomasraoux getSchedule(scf::ForOp forOp, 139f6f88e66Sthomasraoux std::vector<std::pair<Operation *, unsigned>> &schedule) { 140f6f88e66Sthomasraoux if (!forOp->hasAttr(kTestPipeliningLoopMarker)) 141f6f88e66Sthomasraoux return; 142f5fe92f6SChristopher Bate 143f6f88e66Sthomasraoux schedule.resize(forOp.getBody()->getOperations().size() - 1); 144f6f88e66Sthomasraoux forOp.walk([&schedule](Operation *op) { 145f6f88e66Sthomasraoux auto attrStage = 146f6f88e66Sthomasraoux op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker); 147f6f88e66Sthomasraoux auto attrCycle = 148f6f88e66Sthomasraoux op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker); 149f6f88e66Sthomasraoux if (attrCycle && attrStage) { 150ed9194beSMatthias Springer // TODO: Index can be out-of-bounds if ops of the loop body disappear 151ed9194beSMatthias Springer // due to folding. 152f6f88e66Sthomasraoux schedule[attrCycle.getInt()] = 153f6f88e66Sthomasraoux std::make_pair(op, unsigned(attrStage.getInt())); 154f6f88e66Sthomasraoux } 155f6f88e66Sthomasraoux }); 156f6f88e66Sthomasraoux } 157f6f88e66Sthomasraoux 158205c08b5SThomas Raoux /// Helper to generate "predicated" version of `op`. For simplicity we just 159205c08b5SThomas Raoux /// wrap the operation in a scf.ifOp operation. 1601cff4cbdSNicolas Vasilache static Operation *predicateOp(RewriterBase &rewriter, Operation *op, 1611cff4cbdSNicolas Vasilache Value pred) { 162205c08b5SThomas Raoux Location loc = op->getLoc(); 163205c08b5SThomas Raoux auto ifOp = 164205c08b5SThomas Raoux rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true); 165205c08b5SThomas Raoux // True branch. 1665cc0f76dSMatthias Springer rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(), 167f5fe92f6SChristopher Bate ifOp.getThenRegion().front().begin()); 168205c08b5SThomas Raoux rewriter.setInsertionPointAfter(op); 169f5fe92f6SChristopher Bate if (op->getNumResults() > 0) 170205c08b5SThomas Raoux rewriter.create<scf::YieldOp>(loc, op->getResults()); 171205c08b5SThomas Raoux // False branch. 172205c08b5SThomas Raoux rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 173f5fe92f6SChristopher Bate SmallVector<Value> elseYieldOperands; 174f5fe92f6SChristopher Bate elseYieldOperands.reserve(ifOp.getNumResults()); 175f5fe92f6SChristopher Bate if (auto viewOp = dyn_cast<memref::SubViewOp>(op)) { 176f5fe92f6SChristopher Bate // For sub-views, just clone the op. 177f5fe92f6SChristopher Bate // NOTE: This is okay in the test because we use dynamic memref sizes, so 178f5fe92f6SChristopher Bate // the verifier will not complain. Otherwise, we may create a logically 179f5fe92f6SChristopher Bate // out-of-bounds view and a different technique should be used. 180f5fe92f6SChristopher Bate Operation *opClone = rewriter.clone(*op); 181f5fe92f6SChristopher Bate elseYieldOperands.append(opClone->result_begin(), opClone->result_end()); 182f5fe92f6SChristopher Bate } else { 183f5fe92f6SChristopher Bate // Default to assuming constant numeric values. 184205c08b5SThomas Raoux for (Type type : op->getResultTypes()) { 185f5fe92f6SChristopher Bate elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>( 186f5fe92f6SChristopher Bate loc, rewriter.getZeroAttr(type))); 187205c08b5SThomas Raoux } 188f5fe92f6SChristopher Bate } 189f5fe92f6SChristopher Bate if (op->getNumResults() > 0) 190f5fe92f6SChristopher Bate rewriter.create<scf::YieldOp>(loc, elseYieldOperands); 191205c08b5SThomas Raoux return ifOp.getOperation(); 192205c08b5SThomas Raoux } 193205c08b5SThomas Raoux 1940736bbd7SThomas Raoux static void annotate(Operation *op, 1950736bbd7SThomas Raoux mlir::scf::PipeliningOption::PipelinerPart part, 1960736bbd7SThomas Raoux unsigned iteration) { 1970736bbd7SThomas Raoux OpBuilder b(op); 1980736bbd7SThomas Raoux switch (part) { 1990736bbd7SThomas Raoux case mlir::scf::PipeliningOption::PipelinerPart::Prologue: 2000736bbd7SThomas Raoux op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue")); 2010736bbd7SThomas Raoux break; 2020736bbd7SThomas Raoux case mlir::scf::PipeliningOption::PipelinerPart::Kernel: 2030736bbd7SThomas Raoux op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel")); 2040736bbd7SThomas Raoux break; 2050736bbd7SThomas Raoux case mlir::scf::PipeliningOption::PipelinerPart::Epilogue: 2060736bbd7SThomas Raoux op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue")); 2070736bbd7SThomas Raoux break; 2080736bbd7SThomas Raoux } 2090736bbd7SThomas Raoux op->setAttr(kTestPipeliningAnnotationIteration, 2100736bbd7SThomas Raoux b.getI32IntegerAttr(iteration)); 2110736bbd7SThomas Raoux } 2120736bbd7SThomas Raoux 213a54f4eaeSMogball void getDependentDialects(DialectRegistry ®istry) const override { 214abc362a1SJakub Kuderski registry.insert<arith::ArithDialect, memref::MemRefDialect>(); 215a54f4eaeSMogball } 216a54f4eaeSMogball 21741574554SRiver Riddle void runOnOperation() override { 218f6f88e66Sthomasraoux RewritePatternSet patterns(&getContext()); 219f6f88e66Sthomasraoux mlir::scf::PipeliningOption options; 220f6f88e66Sthomasraoux options.getScheduleFn = getSchedule; 221ebf05993SSJW options.supportDynamicLoops = true; 222ebf05993SSJW options.predicateFn = predicateOp; 2230736bbd7SThomas Raoux if (annotatePipeline) 2240736bbd7SThomas Raoux options.annotateFn = annotate; 225205c08b5SThomas Raoux if (noEpiloguePeeling) { 226205c08b5SThomas Raoux options.peelEpilogue = false; 227205c08b5SThomas Raoux } 228f6f88e66Sthomasraoux scf::populateSCFLoopPipeliningPatterns(patterns, options); 229*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 23041574554SRiver Riddle getOperation().walk([](Operation *op) { 231f6f88e66Sthomasraoux // Clean up the markers. 232f6f88e66Sthomasraoux op->removeAttr(kTestPipeliningStageMarker); 233f6f88e66Sthomasraoux op->removeAttr(kTestPipeliningOpOrderMarker); 234f6f88e66Sthomasraoux }); 235f6f88e66Sthomasraoux } 236f6f88e66Sthomasraoux }; 2373fef2d26SRiver Riddle } // namespace 2383fef2d26SRiver Riddle 2393fef2d26SRiver Riddle namespace mlir { 2403fef2d26SRiver Riddle namespace test { 2413fef2d26SRiver Riddle void registerTestSCFUtilsPass() { 242b5e22e6dSMehdi Amini PassRegistration<TestSCFForUtilsPass>(); 243b5e22e6dSMehdi Amini PassRegistration<TestSCFIfUtilsPass>(); 244f6f88e66Sthomasraoux PassRegistration<TestSCFPipeliningPass>(); 2453fef2d26SRiver Riddle } 2463fef2d26SRiver Riddle } // namespace test 2473fef2d26SRiver Riddle } // namespace mlir 248