xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms scf.for operations into equivalent scf.while operations.
//
//===----------------------------------------------------------------------===//
12032cb165SMorten Borup Petersen 
#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_SCFFORTOWHILELOOP
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir

using namespace llvm;
using namespace mlir;
using scf::ForOp;
using scf::WhileOp;
30032cb165SMorten Borup Petersen 
31032cb165SMorten Borup Petersen namespace {
32032cb165SMorten Borup Petersen 
33032cb165SMorten Borup Petersen struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
34032cb165SMorten Borup Petersen   using OpRewritePattern<ForOp>::OpRewritePattern;
35032cb165SMorten Borup Petersen 
36032cb165SMorten Borup Petersen   LogicalResult matchAndRewrite(ForOp forOp,
37032cb165SMorten Borup Petersen                                 PatternRewriter &rewriter) const override {
38032cb165SMorten Borup Petersen     // Generate type signature for the loop-carried values. The induction
39032cb165SMorten Borup Petersen     // variable is placed first, followed by the forOp.iterArgs.
40e084679fSRiver Riddle     SmallVector<Type> lcvTypes;
41e084679fSRiver Riddle     SmallVector<Location> lcvLocs;
42032cb165SMorten Borup Petersen     lcvTypes.push_back(forOp.getInductionVar().getType());
43e084679fSRiver Riddle     lcvLocs.push_back(forOp.getInductionVar().getLoc());
44e084679fSRiver Riddle     for (Value value : forOp.getInitArgs()) {
45e084679fSRiver Riddle       lcvTypes.push_back(value.getType());
46e084679fSRiver Riddle       lcvLocs.push_back(value.getLoc());
47e084679fSRiver Riddle     }
48032cb165SMorten Borup Petersen 
49032cb165SMorten Borup Petersen     // Build scf.WhileOp
50032cb165SMorten Borup Petersen     SmallVector<Value> initArgs;
51c0342a2dSJacques Pienaar     initArgs.push_back(forOp.getLowerBound());
52c0342a2dSJacques Pienaar     llvm::append_range(initArgs, forOp.getInitArgs());
53032cb165SMorten Borup Petersen     auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
54032cb165SMorten Borup Petersen                                             forOp->getAttrs());
55032cb165SMorten Borup Petersen 
56032cb165SMorten Borup Petersen     // 'before' region contains the loop condition and forwarding of iteration
57032cb165SMorten Borup Petersen     // arguments to the 'after' region.
58032cb165SMorten Borup Petersen     auto *beforeBlock = rewriter.createBlock(
59e084679fSRiver Riddle         &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
607c74a250SMatthias Springer     rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
61a54f4eaeSMogball     auto cmpOp = rewriter.create<arith::CmpIOp>(
62a54f4eaeSMogball         whileOp.getLoc(), arith::CmpIPredicate::slt,
63c0342a2dSJacques Pienaar         beforeBlock->getArgument(0), forOp.getUpperBound());
64032cb165SMorten Borup Petersen     rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
65032cb165SMorten Borup Petersen                                       beforeBlock->getArguments());
66032cb165SMorten Borup Petersen 
67032cb165SMorten Borup Petersen     // Inline for-loop body into an executeRegion operation in the "after"
68032cb165SMorten Borup Petersen     // region. The return type of the execRegionOp does not contain the
69032cb165SMorten Borup Petersen     // iv - yields in the source for-loop contain only iterArgs.
70032cb165SMorten Borup Petersen     auto *afterBlock = rewriter.createBlock(
71e084679fSRiver Riddle         &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs);
72032cb165SMorten Borup Petersen 
73032cb165SMorten Borup Petersen     // Add induction variable incrementation
74032cb165SMorten Borup Petersen     rewriter.setInsertionPointToEnd(afterBlock);
75a54f4eaeSMogball     auto ivIncOp = rewriter.create<arith::AddIOp>(
76c0342a2dSJacques Pienaar         whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep());
77032cb165SMorten Borup Petersen 
78032cb165SMorten Borup Petersen     // Rewrite uses of the for-loop block arguments to the new while-loop
79032cb165SMorten Borup Petersen     // "after" arguments
80e4853be2SMehdi Amini     for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
8161f37758SMatthias Springer       rewriter.replaceAllUsesWith(barg.value(),
8261f37758SMatthias Springer                                   afterBlock->getArgument(barg.index()));
83032cb165SMorten Borup Petersen 
84032cb165SMorten Borup Petersen     // Inline for-loop body operations into 'after' region.
85032cb165SMorten Borup Petersen     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
865cc0f76dSMatthias Springer       rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
87032cb165SMorten Borup Petersen 
88032cb165SMorten Borup Petersen     // Add incremented IV to yield operations
89032cb165SMorten Borup Petersen     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
90032cb165SMorten Borup Petersen       SmallVector<Value> yieldOperands = yieldOp.getOperands();
91032cb165SMorten Borup Petersen       yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
925fcf907bSMatthias Springer       rewriter.modifyOpInPlace(yieldOp,
935fcf907bSMatthias Springer                                [&]() { yieldOp->setOperands(yieldOperands); });
94032cb165SMorten Borup Petersen     }
95032cb165SMorten Borup Petersen 
96032cb165SMorten Borup Petersen     // We cannot do a direct replacement of the forOp since the while op returns
97032cb165SMorten Borup Petersen     // an extra value (the induction variable escapes the loop through being
98032cb165SMorten Borup Petersen     // carried in the set of iterargs). Instead, rewrite uses of the forOp
99032cb165SMorten Borup Petersen     // results.
100e4853be2SMehdi Amini     for (const auto &arg : llvm::enumerate(forOp.getResults()))
10161f37758SMatthias Springer       rewriter.replaceAllUsesWith(arg.value(),
10261f37758SMatthias Springer                                   whileOp.getResult(arg.index() + 1));
103032cb165SMorten Borup Petersen 
104032cb165SMorten Borup Petersen     rewriter.eraseOp(forOp);
105032cb165SMorten Borup Petersen     return success();
106032cb165SMorten Borup Petersen   }
107032cb165SMorten Borup Petersen };
108032cb165SMorten Borup Petersen 
10967d0d7acSMichele Scuttari struct ForToWhileLoop : public impl::SCFForToWhileLoopBase<ForToWhileLoop> {
11041574554SRiver Riddle   void runOnOperation() override {
11154998986SStella Laurenzo     auto *parentOp = getOperation();
11254998986SStella Laurenzo     MLIRContext *ctx = parentOp->getContext();
113032cb165SMorten Borup Petersen     RewritePatternSet patterns(ctx);
114032cb165SMorten Borup Petersen     patterns.add<ForLoopLoweringPattern>(ctx);
115*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(parentOp, std::move(patterns));
116032cb165SMorten Borup Petersen   }
117032cb165SMorten Borup Petersen };
118032cb165SMorten Borup Petersen } // namespace
119039b969bSMichele Scuttari 
120039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
121039b969bSMichele Scuttari   return std::make_unique<ForToWhileLoop>();
122039b969bSMichele Scuttari }
123