1032cb165SMorten Borup Petersen //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===// 2032cb165SMorten Borup Petersen // 3032cb165SMorten Borup Petersen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4032cb165SMorten Borup Petersen // See https://llvm.org/LICENSE.txt for license information. 5032cb165SMorten Borup Petersen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6032cb165SMorten Borup Petersen // 7032cb165SMorten Borup Petersen //===----------------------------------------------------------------------===// 8032cb165SMorten Borup Petersen // 9032cb165SMorten Borup Petersen // Transforms SCF.ForOp's into SCF.WhileOp's. 10032cb165SMorten Borup Petersen // 11032cb165SMorten Borup Petersen //===----------------------------------------------------------------------===// 12032cb165SMorten Borup Petersen 1367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h" 1467d0d7acSMichele Scuttari 15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18032cb165SMorten Borup Petersen #include "mlir/IR/PatternMatch.h" 19032cb165SMorten Borup Petersen #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20032cb165SMorten Borup Petersen 2167d0d7acSMichele Scuttari namespace mlir { 2267d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFFORTOWHILELOOP 2367d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 2467d0d7acSMichele Scuttari } // namespace mlir 2567d0d7acSMichele Scuttari 26032cb165SMorten Borup Petersen using namespace llvm; 27032cb165SMorten Borup Petersen using namespace mlir; 28032cb165SMorten Borup Petersen using scf::ForOp; 29032cb165SMorten Borup Petersen using scf::WhileOp; 30032cb165SMorten Borup Petersen 31032cb165SMorten Borup Petersen namespace { 32032cb165SMorten Borup Petersen 33032cb165SMorten Borup Petersen struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> { 34032cb165SMorten Borup Petersen using OpRewritePattern<ForOp>::OpRewritePattern; 35032cb165SMorten Borup Petersen 36032cb165SMorten Borup Petersen LogicalResult matchAndRewrite(ForOp forOp, 37032cb165SMorten Borup Petersen PatternRewriter &rewriter) const override { 38032cb165SMorten Borup Petersen // Generate type signature for the loop-carried values. The induction 39032cb165SMorten Borup Petersen // variable is placed first, followed by the forOp.iterArgs. 40e084679fSRiver Riddle SmallVector<Type> lcvTypes; 41e084679fSRiver Riddle SmallVector<Location> lcvLocs; 42032cb165SMorten Borup Petersen lcvTypes.push_back(forOp.getInductionVar().getType()); 43e084679fSRiver Riddle lcvLocs.push_back(forOp.getInductionVar().getLoc()); 44e084679fSRiver Riddle for (Value value : forOp.getInitArgs()) { 45e084679fSRiver Riddle lcvTypes.push_back(value.getType()); 46e084679fSRiver Riddle lcvLocs.push_back(value.getLoc()); 47e084679fSRiver Riddle } 48032cb165SMorten Borup Petersen 49032cb165SMorten Borup Petersen // Build scf.WhileOp 50032cb165SMorten Borup Petersen SmallVector<Value> initArgs; 51c0342a2dSJacques Pienaar initArgs.push_back(forOp.getLowerBound()); 52c0342a2dSJacques Pienaar llvm::append_range(initArgs, forOp.getInitArgs()); 53032cb165SMorten Borup Petersen auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs, 54032cb165SMorten Borup Petersen forOp->getAttrs()); 55032cb165SMorten Borup Petersen 56032cb165SMorten Borup Petersen // 'before' region contains the loop condition and forwarding of iteration 57032cb165SMorten Borup Petersen // arguments to the 'after' region. 58032cb165SMorten Borup Petersen auto *beforeBlock = rewriter.createBlock( 59e084679fSRiver Riddle &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs); 607c74a250SMatthias Springer rewriter.setInsertionPointToStart(whileOp.getBeforeBody()); 61a54f4eaeSMogball auto cmpOp = rewriter.create<arith::CmpIOp>( 62a54f4eaeSMogball whileOp.getLoc(), arith::CmpIPredicate::slt, 63c0342a2dSJacques Pienaar beforeBlock->getArgument(0), forOp.getUpperBound()); 64032cb165SMorten Borup Petersen rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(), 65032cb165SMorten Borup Petersen beforeBlock->getArguments()); 66032cb165SMorten Borup Petersen 67032cb165SMorten Borup Petersen // Inline for-loop body into an executeRegion operation in the "after" 68032cb165SMorten Borup Petersen // region. The return type of the execRegionOp does not contain the 69032cb165SMorten Borup Petersen // iv - yields in the source for-loop contain only iterArgs. 70032cb165SMorten Borup Petersen auto *afterBlock = rewriter.createBlock( 71e084679fSRiver Riddle &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs); 72032cb165SMorten Borup Petersen 73032cb165SMorten Borup Petersen // Add induction variable incrementation 74032cb165SMorten Borup Petersen rewriter.setInsertionPointToEnd(afterBlock); 75a54f4eaeSMogball auto ivIncOp = rewriter.create<arith::AddIOp>( 76c0342a2dSJacques Pienaar whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep()); 77032cb165SMorten Borup Petersen 78032cb165SMorten Borup Petersen // Rewrite uses of the for-loop block arguments to the new while-loop 79032cb165SMorten Borup Petersen // "after" arguments 80e4853be2SMehdi Amini for (const auto &barg : enumerate(forOp.getBody(0)->getArguments())) 8161f37758SMatthias Springer rewriter.replaceAllUsesWith(barg.value(), 8261f37758SMatthias Springer afterBlock->getArgument(barg.index())); 83032cb165SMorten Borup Petersen 84032cb165SMorten Borup Petersen // Inline for-loop body operations into 'after' region. 85032cb165SMorten Borup Petersen for (auto &arg : llvm::make_early_inc_range(*forOp.getBody())) 865cc0f76dSMatthias Springer rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end()); 87032cb165SMorten Borup Petersen 88032cb165SMorten Borup Petersen // Add incremented IV to yield operations 89032cb165SMorten Borup Petersen for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) { 90032cb165SMorten Borup Petersen SmallVector<Value> yieldOperands = yieldOp.getOperands(); 91032cb165SMorten Borup Petersen yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult()); 925fcf907bSMatthias Springer rewriter.modifyOpInPlace(yieldOp, 935fcf907bSMatthias Springer [&]() { yieldOp->setOperands(yieldOperands); }); 94032cb165SMorten Borup Petersen } 95032cb165SMorten Borup Petersen 96032cb165SMorten Borup Petersen // We cannot do a direct replacement of the forOp since the while op returns 97032cb165SMorten Borup Petersen // an extra value (the induction variable escapes the loop through being 98032cb165SMorten Borup Petersen // carried in the set of iterargs). Instead, rewrite uses of the forOp 99032cb165SMorten Borup Petersen // results. 100e4853be2SMehdi Amini for (const auto &arg : llvm::enumerate(forOp.getResults())) 10161f37758SMatthias Springer rewriter.replaceAllUsesWith(arg.value(), 10261f37758SMatthias Springer whileOp.getResult(arg.index() + 1)); 103032cb165SMorten Borup Petersen 104032cb165SMorten Borup Petersen rewriter.eraseOp(forOp); 105032cb165SMorten Borup Petersen return success(); 106032cb165SMorten Borup Petersen } 107032cb165SMorten Borup Petersen }; 108032cb165SMorten Borup Petersen 10967d0d7acSMichele Scuttari struct ForToWhileLoop : public impl::SCFForToWhileLoopBase<ForToWhileLoop> { 11041574554SRiver Riddle void runOnOperation() override { 11154998986SStella Laurenzo auto *parentOp = getOperation(); 11254998986SStella Laurenzo MLIRContext *ctx = parentOp->getContext(); 113032cb165SMorten Borup Petersen RewritePatternSet patterns(ctx); 114032cb165SMorten Borup Petersen patterns.add<ForLoopLoweringPattern>(ctx); 115*09dfc571SJacques Pienaar (void)applyPatternsGreedily(parentOp, std::move(patterns)); 116032cb165SMorten Borup Petersen } 117032cb165SMorten Borup Petersen }; 118032cb165SMorten Borup Petersen } // namespace 119039b969bSMichele Scuttari 120039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createForToWhileLoopPass() { 121039b969bSMichele Scuttari return std::make_unique<ForToWhileLoop>(); 122039b969bSMichele Scuttari } 123