xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp (revision 1ca6b4475c02e5d022ec6b35dbb65d0f11409a88)
1 //===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms SCF.WhileOp's into SCF.ForOp's.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/Passes.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
18 #include "mlir/IR/Dominance.h"
19 #include "mlir/IR/PatternMatch.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
25   using OpRewritePattern::OpRewritePattern;
26 
matchAndRewrite__anon9ab013c10111::UpliftWhileOp27   LogicalResult matchAndRewrite(scf::WhileOp loop,
28                                 PatternRewriter &rewriter) const override {
29     return upliftWhileToForLoop(rewriter, loop);
30   }
31 };
32 } // namespace
33 
upliftWhileToForLoop(RewriterBase & rewriter,scf::WhileOp loop)34 FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
35                                                       scf::WhileOp loop) {
36   Block *beforeBody = loop.getBeforeBody();
37   if (!llvm::hasSingleElement(beforeBody->without_terminator()))
38     return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
39 
40   auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
41   if (!cmp)
42     return rewriter.notifyMatchFailure(loop,
43                                        "Loop body must have single cmp op");
44 
45   scf::ConditionOp beforeTerm = loop.getConditionOp();
46   if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
47     return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
48       diag << "Expected single condition use: " << *cmp;
49     });
50 
51   // All `before` block args must be directly forwarded to ConditionOp.
52   // They will be converted to `scf.for` `iter_vars` except induction var.
53   if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
54     return rewriter.notifyMatchFailure(loop, "Invalid args order");
55 
56   using Pred = arith::CmpIPredicate;
57   Pred predicate = cmp.getPredicate();
58   if (predicate != Pred::slt && predicate != Pred::sgt)
59     return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
60       diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
61     });
62 
63   BlockArgument inductionVar;
64   Value ub;
65   DominanceInfo dom;
66 
67   // Check if cmp has a suitable form. One of the arguments must be a `before`
68   // block arg, other must be defined outside `scf.while` and will be treated
69   // as upper bound.
70   for (bool reverse : {false, true}) {
71     auto expectedPred = reverse ? Pred::sgt : Pred::slt;
72     if (cmp.getPredicate() != expectedPred)
73       continue;
74 
75     auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
76     auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
77 
78     auto blockArg = dyn_cast<BlockArgument>(arg1);
79     if (!blockArg || blockArg.getOwner() != beforeBody)
80       continue;
81 
82     if (!dom.properlyDominates(arg2, loop))
83       continue;
84 
85     inductionVar = blockArg;
86     ub = arg2;
87     break;
88   }
89 
90   if (!inductionVar)
91     return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
92       diag << "Unrecognized cmp form: " << *cmp;
93     });
94 
95   // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
96   // arg.
97   if (!llvm::hasNItems(inductionVar.getUses(), 2))
98     return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
99       diag << "Unrecognized induction var: " << inductionVar;
100     });
101 
102   Block *afterBody = loop.getAfterBody();
103   scf::YieldOp afterTerm = loop.getYieldOp();
104   unsigned argNumber = inductionVar.getArgNumber();
105   Value afterTermIndArg = afterTerm.getResults()[argNumber];
106 
107   Value inductionVarAfter = afterBody->getArgument(argNumber);
108 
109   // Find suitable `addi` op inside `after` block, one of the args must be an
110   // Induction var passed from `before` block and second arg must be defined
111   // outside of the loop and will be considered step value.
112   // TODO: Add `subi` support?
113   auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
114   if (!addOp)
115     return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
116 
117   Value step;
118   if (addOp.getLhs() == inductionVarAfter) {
119     step = addOp.getRhs();
120   } else if (addOp.getRhs() == inductionVarAfter) {
121     step = addOp.getLhs();
122   }
123 
124   if (!step || !dom.properlyDominates(step, loop))
125     return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
126 
127   Value lb = loop.getInits()[argNumber];
128 
129   assert(lb.getType().isIntOrIndex());
130   assert(lb.getType() == ub.getType());
131   assert(lb.getType() == step.getType());
132 
133   llvm::SmallVector<Value> newArgs;
134 
135   // Populate inits for new `scf.for`, skip induction var.
136   newArgs.reserve(loop.getInits().size());
137   for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
138     if (i == argNumber)
139       continue;
140 
141     newArgs.emplace_back(init);
142   }
143 
144   Location loc = loop.getLoc();
145 
146   // With `builder == nullptr`, ForOp::build will try to insert terminator at
147   // the end of newly created block and we don't want it. Provide empty
148   // dummy builder instead.
149   auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
150   auto newLoop =
151       rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
152 
153   Block *newBody = newLoop.getBody();
154 
155   // Populate block args for `scf.for` body, move induction var to the front.
156   newArgs.clear();
157   ValueRange newBodyArgs = newBody->getArguments();
158   for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
159     if (i < argNumber) {
160       newArgs.emplace_back(newBodyArgs[i + 1]);
161     } else if (i == argNumber) {
162       newArgs.emplace_back(newBodyArgs.front());
163     } else {
164       newArgs.emplace_back(newBodyArgs[i]);
165     }
166   }
167 
168   rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
169                              newArgs);
170 
171   auto term = cast<scf::YieldOp>(newBody->getTerminator());
172 
173   // Populate new yield args, skipping the induction var.
174   newArgs.clear();
175   for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
176     if (i == argNumber)
177       continue;
178 
179     newArgs.emplace_back(arg);
180   }
181 
182   OpBuilder::InsertionGuard g(rewriter);
183   rewriter.setInsertionPoint(term);
184   rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
185 
186   // Compute induction var value after loop execution.
187   rewriter.setInsertionPointAfter(newLoop);
188   Value one;
189   if (isa<IndexType>(step.getType())) {
190     one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
191   } else {
192     one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
193   }
194 
195   Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
196   Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
197   len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
198   len = rewriter.create<arith::DivSIOp>(loc, len, step);
199   len = rewriter.create<arith::SubIOp>(loc, len, one);
200   Value res = rewriter.create<arith::MulIOp>(loc, len, step);
201   res = rewriter.create<arith::AddIOp>(loc, lb, res);
202 
203   // Reconstruct `scf.while` results, inserting final induction var value
204   // into proper place.
205   newArgs.clear();
206   llvm::append_range(newArgs, newLoop.getResults());
207   newArgs.insert(newArgs.begin() + argNumber, res);
208   rewriter.replaceOp(loop, newArgs);
209   return newLoop;
210 }
211 
populateUpliftWhileToForPatterns(RewritePatternSet & patterns)212 void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
213   patterns.add<UpliftWhileOp>(patterns.getContext());
214 }
215