1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to test SCF dialect utils. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/SCF/Transforms/Patterns.h" 18 #include "mlir/Dialect/SCF/Utils/Utils.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 using namespace mlir; 25 26 namespace { 27 struct TestSCFForUtilsPass 28 : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> { 29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass) 30 31 StringRef getArgument() const final { return "test-scf-for-utils"; } 32 StringRef getDescription() const final { return "test scf.for utils"; } 33 explicit TestSCFForUtilsPass() = default; 34 TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {} 35 36 Option<bool> testReplaceWithNewYields{ 37 *this, "test-replace-with-new-yields", 38 llvm::cl::desc("Test replacing a loop with a new loop that returns new " 39 "additional yield values"), 40 llvm::cl::init(false)}; 41 42 void runOnOperation() override { 43 func::FuncOp func = getOperation(); 44 SmallVector<scf::ForOp, 4> toErase; 45 46 if (testReplaceWithNewYields) { 47 func.walk([&](scf::ForOp forOp) { 48 if (forOp.getNumResults() == 0) 49 return; 50 auto newInitValues = forOp.getInitArgs(); 51 if (newInitValues.empty()) 52 return; 53 SmallVector<Value> oldYieldValues = 54 llvm::to_vector(forOp.getYieldedValues()); 55 NewYieldValuesFn fn = [&](OpBuilder &b, Location loc, 56 ArrayRef<BlockArgument> newBBArgs) { 57 SmallVector<Value> newYieldValues; 58 for (auto yieldVal : oldYieldValues) { 59 newYieldValues.push_back( 60 b.create<arith::AddFOp>(loc, yieldVal, yieldVal)); 61 } 62 return newYieldValues; 63 }; 64 IRRewriter rewriter(forOp.getContext()); 65 if (failed(forOp.replaceWithAdditionalYields( 66 rewriter, newInitValues, /*replaceInitOperandUsesInLoop=*/true, 67 fn))) 68 signalPassFailure(); 69 }); 70 } 71 } 72 }; 73 74 struct TestSCFIfUtilsPass 75 : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> { 76 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass) 77 78 StringRef getArgument() const final { return "test-scf-if-utils"; } 79 StringRef getDescription() const final { return "test scf.if utils"; } 80 explicit TestSCFIfUtilsPass() = default; 81 82 void getDependentDialects(DialectRegistry ®istry) const override { 83 registry.insert<func::FuncDialect>(); 84 } 85 86 void runOnOperation() override { 87 int count = 0; 88 getOperation().walk([&](scf::IfOp ifOp) { 89 auto strCount = std::to_string(count++); 90 func::FuncOp thenFn, elseFn; 91 OpBuilder b(ifOp); 92 IRRewriter rewriter(b); 93 if (failed(outlineIfOp(rewriter, ifOp, &thenFn, 94 std::string("outlined_then") + strCount, &elseFn, 95 std::string("outlined_else") + strCount))) { 96 this->signalPassFailure(); 97 return WalkResult::interrupt(); 98 } 99 return WalkResult::advance(); 100 }); 101 } 102 }; 103 104 static const StringLiteral kTestPipeliningLoopMarker = 105 "__test_pipelining_loop__"; 106 static const StringLiteral kTestPipeliningStageMarker = 107 "__test_pipelining_stage__"; 108 /// Marker to express the order in which operations should be after 109 /// pipelining. 110 static const StringLiteral kTestPipeliningOpOrderMarker = 111 "__test_pipelining_op_order__"; 112 113 static const StringLiteral kTestPipeliningAnnotationPart = 114 "__test_pipelining_part"; 115 static const StringLiteral kTestPipeliningAnnotationIteration = 116 "__test_pipelining_iteration"; 117 118 struct TestSCFPipeliningPass 119 : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> { 120 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass) 121 122 TestSCFPipeliningPass() = default; 123 TestSCFPipeliningPass(const TestSCFPipeliningPass &) {} 124 StringRef getArgument() const final { return "test-scf-pipelining"; } 125 StringRef getDescription() const final { return "test scf.forOp pipelining"; } 126 127 Option<bool> annotatePipeline{ 128 *this, "annotate", 129 llvm::cl::desc("Annote operations during loop pipelining transformation"), 130 llvm::cl::init(false)}; 131 132 Option<bool> noEpiloguePeeling{ 133 *this, "no-epilogue-peeling", 134 llvm::cl::desc("Use predicates instead of peeling the epilogue."), 135 llvm::cl::init(false)}; 136 137 static void 138 getSchedule(scf::ForOp forOp, 139 std::vector<std::pair<Operation *, unsigned>> &schedule) { 140 if (!forOp->hasAttr(kTestPipeliningLoopMarker)) 141 return; 142 143 schedule.resize(forOp.getBody()->getOperations().size() - 1); 144 forOp.walk([&schedule](Operation *op) { 145 auto attrStage = 146 op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker); 147 auto attrCycle = 148 op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker); 149 if (attrCycle && attrStage) { 150 // TODO: Index can be out-of-bounds if ops of the loop body disappear 151 // due to folding. 152 schedule[attrCycle.getInt()] = 153 std::make_pair(op, unsigned(attrStage.getInt())); 154 } 155 }); 156 } 157 158 /// Helper to generate "predicated" version of `op`. For simplicity we just 159 /// wrap the operation in a scf.ifOp operation. 160 static Operation *predicateOp(RewriterBase &rewriter, Operation *op, 161 Value pred) { 162 Location loc = op->getLoc(); 163 auto ifOp = 164 rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true); 165 // True branch. 166 rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(), 167 ifOp.getThenRegion().front().begin()); 168 rewriter.setInsertionPointAfter(op); 169 if (op->getNumResults() > 0) 170 rewriter.create<scf::YieldOp>(loc, op->getResults()); 171 // False branch. 172 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 173 SmallVector<Value> elseYieldOperands; 174 elseYieldOperands.reserve(ifOp.getNumResults()); 175 if (auto viewOp = dyn_cast<memref::SubViewOp>(op)) { 176 // For sub-views, just clone the op. 177 // NOTE: This is okay in the test because we use dynamic memref sizes, so 178 // the verifier will not complain. Otherwise, we may create a logically 179 // out-of-bounds view and a different technique should be used. 180 Operation *opClone = rewriter.clone(*op); 181 elseYieldOperands.append(opClone->result_begin(), opClone->result_end()); 182 } else { 183 // Default to assuming constant numeric values. 184 for (Type type : op->getResultTypes()) { 185 elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>( 186 loc, rewriter.getZeroAttr(type))); 187 } 188 } 189 if (op->getNumResults() > 0) 190 rewriter.create<scf::YieldOp>(loc, elseYieldOperands); 191 return ifOp.getOperation(); 192 } 193 194 static void annotate(Operation *op, 195 mlir::scf::PipeliningOption::PipelinerPart part, 196 unsigned iteration) { 197 OpBuilder b(op); 198 switch (part) { 199 case mlir::scf::PipeliningOption::PipelinerPart::Prologue: 200 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue")); 201 break; 202 case mlir::scf::PipeliningOption::PipelinerPart::Kernel: 203 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel")); 204 break; 205 case mlir::scf::PipeliningOption::PipelinerPart::Epilogue: 206 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue")); 207 break; 208 } 209 op->setAttr(kTestPipeliningAnnotationIteration, 210 b.getI32IntegerAttr(iteration)); 211 } 212 213 void getDependentDialects(DialectRegistry ®istry) const override { 214 registry.insert<arith::ArithDialect, memref::MemRefDialect>(); 215 } 216 217 void runOnOperation() override { 218 RewritePatternSet patterns(&getContext()); 219 mlir::scf::PipeliningOption options; 220 options.getScheduleFn = getSchedule; 221 options.supportDynamicLoops = true; 222 options.predicateFn = predicateOp; 223 if (annotatePipeline) 224 options.annotateFn = annotate; 225 if (noEpiloguePeeling) { 226 options.peelEpilogue = false; 227 } 228 scf::populateSCFLoopPipeliningPatterns(patterns, options); 229 (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 230 getOperation().walk([](Operation *op) { 231 // Clean up the markers. 232 op->removeAttr(kTestPipeliningStageMarker); 233 op->removeAttr(kTestPipeliningOpOrderMarker); 234 }); 235 } 236 }; 237 } // namespace 238 239 namespace mlir { 240 namespace test { 241 void registerTestSCFUtilsPass() { 242 PassRegistration<TestSCFForUtilsPass>(); 243 PassRegistration<TestSCFIfUtilsPass>(); 244 PassRegistration<TestSCFPipeliningPass>(); 245 } 246 } // namespace test 247 } // namespace mlir 248