xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms SCF.ForOp's into SCF.WhileOp's.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/Passes.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_SCFFORTOWHILELOOP
23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace llvm;
27 using namespace mlir;
28 using scf::ForOp;
29 using scf::WhileOp;
30 
31 namespace {
32 
33 struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
34   using OpRewritePattern<ForOp>::OpRewritePattern;
35 
36   LogicalResult matchAndRewrite(ForOp forOp,
37                                 PatternRewriter &rewriter) const override {
38     // Generate type signature for the loop-carried values. The induction
39     // variable is placed first, followed by the forOp.iterArgs.
40     SmallVector<Type> lcvTypes;
41     SmallVector<Location> lcvLocs;
42     lcvTypes.push_back(forOp.getInductionVar().getType());
43     lcvLocs.push_back(forOp.getInductionVar().getLoc());
44     for (Value value : forOp.getInitArgs()) {
45       lcvTypes.push_back(value.getType());
46       lcvLocs.push_back(value.getLoc());
47     }
48 
49     // Build scf.WhileOp
50     SmallVector<Value> initArgs;
51     initArgs.push_back(forOp.getLowerBound());
52     llvm::append_range(initArgs, forOp.getInitArgs());
53     auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
54                                             forOp->getAttrs());
55 
56     // 'before' region contains the loop condition and forwarding of iteration
57     // arguments to the 'after' region.
58     auto *beforeBlock = rewriter.createBlock(
59         &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
60     rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
61     auto cmpOp = rewriter.create<arith::CmpIOp>(
62         whileOp.getLoc(), arith::CmpIPredicate::slt,
63         beforeBlock->getArgument(0), forOp.getUpperBound());
64     rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
65                                       beforeBlock->getArguments());
66 
67     // Inline for-loop body into an executeRegion operation in the "after"
68     // region. The return type of the execRegionOp does not contain the
69     // iv - yields in the source for-loop contain only iterArgs.
70     auto *afterBlock = rewriter.createBlock(
71         &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs);
72 
73     // Add induction variable incrementation
74     rewriter.setInsertionPointToEnd(afterBlock);
75     auto ivIncOp = rewriter.create<arith::AddIOp>(
76         whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep());
77 
78     // Rewrite uses of the for-loop block arguments to the new while-loop
79     // "after" arguments
80     for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
81       rewriter.replaceAllUsesWith(barg.value(),
82                                   afterBlock->getArgument(barg.index()));
83 
84     // Inline for-loop body operations into 'after' region.
85     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
86       rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
87 
88     // Add incremented IV to yield operations
89     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
90       SmallVector<Value> yieldOperands = yieldOp.getOperands();
91       yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
92       rewriter.modifyOpInPlace(yieldOp,
93                                [&]() { yieldOp->setOperands(yieldOperands); });
94     }
95 
96     // We cannot do a direct replacement of the forOp since the while op returns
97     // an extra value (the induction variable escapes the loop through being
98     // carried in the set of iterargs). Instead, rewrite uses of the forOp
99     // results.
100     for (const auto &arg : llvm::enumerate(forOp.getResults()))
101       rewriter.replaceAllUsesWith(arg.value(),
102                                   whileOp.getResult(arg.index() + 1));
103 
104     rewriter.eraseOp(forOp);
105     return success();
106   }
107 };
108 
109 struct ForToWhileLoop : public impl::SCFForToWhileLoopBase<ForToWhileLoop> {
110   void runOnOperation() override {
111     auto *parentOp = getOperation();
112     MLIRContext *ctx = parentOp->getContext();
113     RewritePatternSet patterns(ctx);
114     patterns.add<ForLoopLoweringPattern>(ctx);
115     (void)applyPatternsGreedily(parentOp, std::move(patterns));
116   }
117 };
118 } // namespace
119 
120 std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
121   return std::make_unique<ForToWhileLoop>();
122 }
123