xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp (revision 1ca6b4475c02e5d022ec6b35dbb65d0f11409a88)
1b153c05cSIvan Butygin //===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
2b153c05cSIvan Butygin //
3b153c05cSIvan Butygin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b153c05cSIvan Butygin // See https://llvm.org/LICENSE.txt for license information.
5b153c05cSIvan Butygin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b153c05cSIvan Butygin //
7b153c05cSIvan Butygin //===----------------------------------------------------------------------===//
8b153c05cSIvan Butygin //
9b153c05cSIvan Butygin // Transforms SCF.WhileOp's into SCF.ForOp's.
10b153c05cSIvan Butygin //
11b153c05cSIvan Butygin //===----------------------------------------------------------------------===//
12b153c05cSIvan Butygin 
13b153c05cSIvan Butygin #include "mlir/Dialect/SCF/Transforms/Passes.h"
14b153c05cSIvan Butygin 
15b153c05cSIvan Butygin #include "mlir/Dialect/Arith/IR/Arith.h"
16b153c05cSIvan Butygin #include "mlir/Dialect/SCF/IR/SCF.h"
17b153c05cSIvan Butygin #include "mlir/Dialect/SCF/Transforms/Patterns.h"
18b153c05cSIvan Butygin #include "mlir/IR/Dominance.h"
19b153c05cSIvan Butygin #include "mlir/IR/PatternMatch.h"
20b153c05cSIvan Butygin 
21b153c05cSIvan Butygin using namespace mlir;
22b153c05cSIvan Butygin 
23b153c05cSIvan Butygin namespace {
24b153c05cSIvan Butygin struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
25b153c05cSIvan Butygin   using OpRewritePattern::OpRewritePattern;
26b153c05cSIvan Butygin 
matchAndRewrite__anon9ab013c10111::UpliftWhileOp27b153c05cSIvan Butygin   LogicalResult matchAndRewrite(scf::WhileOp loop,
28b153c05cSIvan Butygin                                 PatternRewriter &rewriter) const override {
29b153c05cSIvan Butygin     return upliftWhileToForLoop(rewriter, loop);
30b153c05cSIvan Butygin   }
31b153c05cSIvan Butygin };
32b153c05cSIvan Butygin } // namespace
33b153c05cSIvan Butygin 
upliftWhileToForLoop(RewriterBase & rewriter,scf::WhileOp loop)34b153c05cSIvan Butygin FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
35b153c05cSIvan Butygin                                                       scf::WhileOp loop) {
36b153c05cSIvan Butygin   Block *beforeBody = loop.getBeforeBody();
37b153c05cSIvan Butygin   if (!llvm::hasSingleElement(beforeBody->without_terminator()))
38b153c05cSIvan Butygin     return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
39b153c05cSIvan Butygin 
40b153c05cSIvan Butygin   auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
41b153c05cSIvan Butygin   if (!cmp)
42b153c05cSIvan Butygin     return rewriter.notifyMatchFailure(loop,
43b153c05cSIvan Butygin                                        "Loop body must have single cmp op");
44b153c05cSIvan Butygin 
45b153c05cSIvan Butygin   scf::ConditionOp beforeTerm = loop.getConditionOp();
46b153c05cSIvan Butygin   if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
47b153c05cSIvan Butygin     return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
48b153c05cSIvan Butygin       diag << "Expected single condition use: " << *cmp;
49b153c05cSIvan Butygin     });
50b153c05cSIvan Butygin 
51b153c05cSIvan Butygin   // All `before` block args must be directly forwarded to ConditionOp.
52b153c05cSIvan Butygin   // They will be converted to `scf.for` `iter_vars` except induction var.
53b153c05cSIvan Butygin   if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
54b153c05cSIvan Butygin     return rewriter.notifyMatchFailure(loop, "Invalid args order");
55b153c05cSIvan Butygin 
56b153c05cSIvan Butygin   using Pred = arith::CmpIPredicate;
57b153c05cSIvan Butygin   Pred predicate = cmp.getPredicate();
58b153c05cSIvan Butygin   if (predicate != Pred::slt && predicate != Pred::sgt)
59b153c05cSIvan Butygin     return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
60b153c05cSIvan Butygin       diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
61b153c05cSIvan Butygin     });
62b153c05cSIvan Butygin 
63b153c05cSIvan Butygin   BlockArgument inductionVar;
64b153c05cSIvan Butygin   Value ub;
65b153c05cSIvan Butygin   DominanceInfo dom;
66b153c05cSIvan Butygin 
67b153c05cSIvan Butygin   // Check if cmp has a suitable form. One of the arguments must be a `before`
68b153c05cSIvan Butygin   // block arg, other must be defined outside `scf.while` and will be treated
69b153c05cSIvan Butygin   // as upper bound.
70b153c05cSIvan Butygin   for (bool reverse : {false, true}) {
71b153c05cSIvan Butygin     auto expectedPred = reverse ? Pred::sgt : Pred::slt;
72b153c05cSIvan Butygin     if (cmp.getPredicate() != expectedPred)
73b153c05cSIvan Butygin       continue;
74b153c05cSIvan Butygin 
75b153c05cSIvan Butygin     auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
76b153c05cSIvan Butygin     auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
77b153c05cSIvan Butygin 
78b153c05cSIvan Butygin     auto blockArg = dyn_cast<BlockArgument>(arg1);
79b153c05cSIvan Butygin     if (!blockArg || blockArg.getOwner() != beforeBody)
80b153c05cSIvan Butygin       continue;
81b153c05cSIvan Butygin 
82b153c05cSIvan Butygin     if (!dom.properlyDominates(arg2, loop))
83b153c05cSIvan Butygin       continue;
84b153c05cSIvan Butygin 
85b153c05cSIvan Butygin     inductionVar = blockArg;
86b153c05cSIvan Butygin     ub = arg2;
87b153c05cSIvan Butygin     break;
88b153c05cSIvan Butygin   }
89b153c05cSIvan Butygin 
90b153c05cSIvan Butygin   if (!inductionVar)
91b153c05cSIvan Butygin     return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
92b153c05cSIvan Butygin       diag << "Unrecognized cmp form: " << *cmp;
93b153c05cSIvan Butygin     });
94b153c05cSIvan Butygin 
95b153c05cSIvan Butygin   // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
96b153c05cSIvan Butygin   // arg.
97b153c05cSIvan Butygin   if (!llvm::hasNItems(inductionVar.getUses(), 2))
98b153c05cSIvan Butygin     return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
99b153c05cSIvan Butygin       diag << "Unrecognized induction var: " << inductionVar;
100b153c05cSIvan Butygin     });
101b153c05cSIvan Butygin 
102b153c05cSIvan Butygin   Block *afterBody = loop.getAfterBody();
103b153c05cSIvan Butygin   scf::YieldOp afterTerm = loop.getYieldOp();
104*1ca6b447SIvan Butygin   unsigned argNumber = inductionVar.getArgNumber();
105*1ca6b447SIvan Butygin   Value afterTermIndArg = afterTerm.getResults()[argNumber];
106b153c05cSIvan Butygin 
107*1ca6b447SIvan Butygin   Value inductionVarAfter = afterBody->getArgument(argNumber);
108b153c05cSIvan Butygin 
109b153c05cSIvan Butygin   // Find suitable `addi` op inside `after` block, one of the args must be an
110b153c05cSIvan Butygin   // Induction var passed from `before` block and second arg must be defined
111b153c05cSIvan Butygin   // outside of the loop and will be considered step value.
112b153c05cSIvan Butygin   // TODO: Add `subi` support?
113*1ca6b447SIvan Butygin   auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
114*1ca6b447SIvan Butygin   if (!addOp)
115b153c05cSIvan Butygin     return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
116b153c05cSIvan Butygin 
117*1ca6b447SIvan Butygin   Value step;
118*1ca6b447SIvan Butygin   if (addOp.getLhs() == inductionVarAfter) {
119*1ca6b447SIvan Butygin     step = addOp.getRhs();
120*1ca6b447SIvan Butygin   } else if (addOp.getRhs() == inductionVarAfter) {
121*1ca6b447SIvan Butygin     step = addOp.getLhs();
122*1ca6b447SIvan Butygin   }
123*1ca6b447SIvan Butygin 
124*1ca6b447SIvan Butygin   if (!step || !dom.properlyDominates(step, loop))
125*1ca6b447SIvan Butygin     return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
126*1ca6b447SIvan Butygin 
127*1ca6b447SIvan Butygin   Value lb = loop.getInits()[argNumber];
128b153c05cSIvan Butygin 
129b153c05cSIvan Butygin   assert(lb.getType().isIntOrIndex());
130b153c05cSIvan Butygin   assert(lb.getType() == ub.getType());
131b153c05cSIvan Butygin   assert(lb.getType() == step.getType());
132b153c05cSIvan Butygin 
133b153c05cSIvan Butygin   llvm::SmallVector<Value> newArgs;
134b153c05cSIvan Butygin 
135b153c05cSIvan Butygin   // Populate inits for new `scf.for`, skip induction var.
136b153c05cSIvan Butygin   newArgs.reserve(loop.getInits().size());
137b153c05cSIvan Butygin   for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
138b153c05cSIvan Butygin     if (i == argNumber)
139b153c05cSIvan Butygin       continue;
140b153c05cSIvan Butygin 
141b153c05cSIvan Butygin     newArgs.emplace_back(init);
142b153c05cSIvan Butygin   }
143b153c05cSIvan Butygin 
144b153c05cSIvan Butygin   Location loc = loop.getLoc();
145b153c05cSIvan Butygin 
146b153c05cSIvan Butygin   // With `builder == nullptr`, ForOp::build will try to insert terminator at
147b153c05cSIvan Butygin   // the end of newly created block and we don't want it. Provide empty
148b153c05cSIvan Butygin   // dummy builder instead.
149b153c05cSIvan Butygin   auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
150b153c05cSIvan Butygin   auto newLoop =
151b153c05cSIvan Butygin       rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
152b153c05cSIvan Butygin 
153b153c05cSIvan Butygin   Block *newBody = newLoop.getBody();
154b153c05cSIvan Butygin 
155b153c05cSIvan Butygin   // Populate block args for `scf.for` body, move induction var to the front.
156b153c05cSIvan Butygin   newArgs.clear();
157b153c05cSIvan Butygin   ValueRange newBodyArgs = newBody->getArguments();
158b153c05cSIvan Butygin   for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
159b153c05cSIvan Butygin     if (i < argNumber) {
160b153c05cSIvan Butygin       newArgs.emplace_back(newBodyArgs[i + 1]);
161b153c05cSIvan Butygin     } else if (i == argNumber) {
162b153c05cSIvan Butygin       newArgs.emplace_back(newBodyArgs.front());
163b153c05cSIvan Butygin     } else {
164b153c05cSIvan Butygin       newArgs.emplace_back(newBodyArgs[i]);
165b153c05cSIvan Butygin     }
166b153c05cSIvan Butygin   }
167b153c05cSIvan Butygin 
168b153c05cSIvan Butygin   rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
169b153c05cSIvan Butygin                              newArgs);
170b153c05cSIvan Butygin 
171b153c05cSIvan Butygin   auto term = cast<scf::YieldOp>(newBody->getTerminator());
172b153c05cSIvan Butygin 
173b153c05cSIvan Butygin   // Populate new yield args, skipping the induction var.
174b153c05cSIvan Butygin   newArgs.clear();
175b153c05cSIvan Butygin   for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
176b153c05cSIvan Butygin     if (i == argNumber)
177b153c05cSIvan Butygin       continue;
178b153c05cSIvan Butygin 
179b153c05cSIvan Butygin     newArgs.emplace_back(arg);
180b153c05cSIvan Butygin   }
181b153c05cSIvan Butygin 
182b153c05cSIvan Butygin   OpBuilder::InsertionGuard g(rewriter);
183b153c05cSIvan Butygin   rewriter.setInsertionPoint(term);
184b153c05cSIvan Butygin   rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
185b153c05cSIvan Butygin 
186b153c05cSIvan Butygin   // Compute induction var value after loop execution.
187b153c05cSIvan Butygin   rewriter.setInsertionPointAfter(newLoop);
188b153c05cSIvan Butygin   Value one;
189b153c05cSIvan Butygin   if (isa<IndexType>(step.getType())) {
190b153c05cSIvan Butygin     one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
191b153c05cSIvan Butygin   } else {
192b153c05cSIvan Butygin     one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
193b153c05cSIvan Butygin   }
194b153c05cSIvan Butygin 
195b153c05cSIvan Butygin   Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
196b153c05cSIvan Butygin   Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
197b153c05cSIvan Butygin   len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
198b153c05cSIvan Butygin   len = rewriter.create<arith::DivSIOp>(loc, len, step);
199b153c05cSIvan Butygin   len = rewriter.create<arith::SubIOp>(loc, len, one);
200b153c05cSIvan Butygin   Value res = rewriter.create<arith::MulIOp>(loc, len, step);
201b153c05cSIvan Butygin   res = rewriter.create<arith::AddIOp>(loc, lb, res);
202b153c05cSIvan Butygin 
203b153c05cSIvan Butygin   // Reconstruct `scf.while` results, inserting final induction var value
204b153c05cSIvan Butygin   // into proper place.
205b153c05cSIvan Butygin   newArgs.clear();
206b153c05cSIvan Butygin   llvm::append_range(newArgs, newLoop.getResults());
207b153c05cSIvan Butygin   newArgs.insert(newArgs.begin() + argNumber, res);
208b153c05cSIvan Butygin   rewriter.replaceOp(loop, newArgs);
209b153c05cSIvan Butygin   return newLoop;
210b153c05cSIvan Butygin }
211b153c05cSIvan Butygin 
populateUpliftWhileToForPatterns(RewritePatternSet & patterns)212b153c05cSIvan Butygin void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
213b153c05cSIvan Butygin   patterns.add<UpliftWhileOp>(patterns.getContext());
214b153c05cSIvan Butygin }
215