1 //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Transforms SCF.ForOp's into SCF.WhileOp's. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/Passes.h" 14 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 namespace mlir { 22 #define GEN_PASS_DEF_SCFFORTOWHILELOOP 23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 24 } // namespace mlir 25 26 using namespace llvm; 27 using namespace mlir; 28 using scf::ForOp; 29 using scf::WhileOp; 30 31 namespace { 32 33 struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> { 34 using OpRewritePattern<ForOp>::OpRewritePattern; 35 36 LogicalResult matchAndRewrite(ForOp forOp, 37 PatternRewriter &rewriter) const override { 38 // Generate type signature for the loop-carried values. The induction 39 // variable is placed first, followed by the forOp.iterArgs. 40 SmallVector<Type> lcvTypes; 41 SmallVector<Location> lcvLocs; 42 lcvTypes.push_back(forOp.getInductionVar().getType()); 43 lcvLocs.push_back(forOp.getInductionVar().getLoc()); 44 for (Value value : forOp.getInitArgs()) { 45 lcvTypes.push_back(value.getType()); 46 lcvLocs.push_back(value.getLoc()); 47 } 48 49 // Build scf.WhileOp 50 SmallVector<Value> initArgs; 51 initArgs.push_back(forOp.getLowerBound()); 52 llvm::append_range(initArgs, forOp.getInitArgs()); 53 auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs, 54 forOp->getAttrs()); 55 56 // 'before' region contains the loop condition and forwarding of iteration 57 // arguments to the 'after' region. 58 auto *beforeBlock = rewriter.createBlock( 59 &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs); 60 rewriter.setInsertionPointToStart(whileOp.getBeforeBody()); 61 auto cmpOp = rewriter.create<arith::CmpIOp>( 62 whileOp.getLoc(), arith::CmpIPredicate::slt, 63 beforeBlock->getArgument(0), forOp.getUpperBound()); 64 rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(), 65 beforeBlock->getArguments()); 66 67 // Inline for-loop body into an executeRegion operation in the "after" 68 // region. The return type of the execRegionOp does not contain the 69 // iv - yields in the source for-loop contain only iterArgs. 70 auto *afterBlock = rewriter.createBlock( 71 &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs); 72 73 // Add induction variable incrementation 74 rewriter.setInsertionPointToEnd(afterBlock); 75 auto ivIncOp = rewriter.create<arith::AddIOp>( 76 whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep()); 77 78 // Rewrite uses of the for-loop block arguments to the new while-loop 79 // "after" arguments 80 for (const auto &barg : enumerate(forOp.getBody(0)->getArguments())) 81 rewriter.replaceAllUsesWith(barg.value(), 82 afterBlock->getArgument(barg.index())); 83 84 // Inline for-loop body operations into 'after' region. 85 for (auto &arg : llvm::make_early_inc_range(*forOp.getBody())) 86 rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end()); 87 88 // Add incremented IV to yield operations 89 for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) { 90 SmallVector<Value> yieldOperands = yieldOp.getOperands(); 91 yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult()); 92 rewriter.modifyOpInPlace(yieldOp, 93 [&]() { yieldOp->setOperands(yieldOperands); }); 94 } 95 96 // We cannot do a direct replacement of the forOp since the while op returns 97 // an extra value (the induction variable escapes the loop through being 98 // carried in the set of iterargs). Instead, rewrite uses of the forOp 99 // results. 100 for (const auto &arg : llvm::enumerate(forOp.getResults())) 101 rewriter.replaceAllUsesWith(arg.value(), 102 whileOp.getResult(arg.index() + 1)); 103 104 rewriter.eraseOp(forOp); 105 return success(); 106 } 107 }; 108 109 struct ForToWhileLoop : public impl::SCFForToWhileLoopBase<ForToWhileLoop> { 110 void runOnOperation() override { 111 auto *parentOp = getOperation(); 112 MLIRContext *ctx = parentOp->getContext(); 113 RewritePatternSet patterns(ctx); 114 patterns.add<ForLoopLoweringPattern>(ctx); 115 (void)applyPatternsGreedily(parentOp, std::move(patterns)); 116 } 117 }; 118 } // namespace 119 120 std::unique_ptr<Pass> mlir::createForToWhileLoopPass() { 121 return std::make_unique<ForToWhileLoop>(); 122 } 123