xref: /llvm-project/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1 //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
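// This file implements a pass that lowers the structured control flow ops
// scf.for, scf.if, scf.parallel (via an intermediate scf.for nest), scf.while
// and scf.execute_region into unstructured ControlFlow dialect (cf) branches.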
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "mlir/Transforms/Passes.h"
26 
27 using namespace mlir;
28 using namespace mlir::scf;
29 
30 namespace {
31 
32 struct SCFToControlFlowPass
33     : public SCFToControlFlowBase<SCFToControlFlowPass> {
34   void runOnOperation() override;
35 };
36 
37 // Create a CFG subgraph for the loop around its body blocks (if the body
38 // contained other loops, they have been already lowered to a flow of blocks).
39 // Maintain the invariants that a CFG subgraph created for any loop has a single
40 // entry and a single exit, and that the entry/exit blocks are respectively
41 // first/last blocks in the parent region.  The original loop operation is
42 // replaced by the initialization operations that set up the initial value of
43 // the loop induction variable (%iv) and computes the loop bounds that are loop-
44 // invariant for affine loops.  The operations following the original scf.for
45 // are split out into a separate continuation (exit) block. A condition block is
46 // created before the continuation block. It checks the exit condition of the
47 // loop and branches either to the continuation block, or to the first block of
48 // the body. The condition block takes as arguments the values of the induction
49 // variable followed by loop-carried values. Since it dominates both the body
50 // blocks and the continuation block, loop-carried values are visible in all of
51 // those blocks. Induction variable modification is appended to the last block
52 // of the body (which is the exit block from the body subgraph thanks to the
53 // invariant we maintain) along with a branch that loops back to the condition
54 // block. Loop-carried values are the loop terminator operands, which are
55 // forwarded to the branch.
56 //
57 //      +---------------------------------+
58 //      |   <code before the ForOp>       |
59 //      |   <definitions of %init...>     |
60 //      |   <compute initial %iv value>   |
61 //      |   cf.br cond(%iv, %init...)        |
62 //      +---------------------------------+
63 //             |
64 //  -------|   |
65 //  |      v   v
66 //  |   +--------------------------------+
67 //  |   | cond(%iv, %init...):           |
68 //  |   |   <compare %iv to upper bound> |
69 //  |   |   cf.cond_br %r, body, end        |
70 //  |   +--------------------------------+
71 //  |          |               |
72 //  |          |               -------------|
73 //  |          v                            |
74 //  |   +--------------------------------+  |
75 //  |   | body-first:                    |  |
76 //  |   |   <%init visible by dominance> |  |
77 //  |   |   <body contents>              |  |
78 //  |   +--------------------------------+  |
79 //  |                   |                   |
80 //  |                  ...                  |
81 //  |                   |                   |
82 //  |   +--------------------------------+  |
83 //  |   | body-last:                     |  |
84 //  |   |   <body contents>              |  |
85 //  |   |   <operands of yield = %yields>|  |
86 //  |   |   %new_iv =<add step to %iv>   |  |
87 //  |   |   cf.br cond(%new_iv, %yields)    |  |
88 //  |   +--------------------------------+  |
89 //  |          |                            |
90 //  |-----------        |--------------------
91 //                      v
92 //      +--------------------------------+
93 //      | end:                           |
94 //      |   <code after the ForOp>       |
95 //      |   <%init visible by dominance> |
96 //      +--------------------------------+
97 //
98 struct ForLowering : public OpRewritePattern<ForOp> {
99   using OpRewritePattern<ForOp>::OpRewritePattern;
100 
101   LogicalResult matchAndRewrite(ForOp forOp,
102                                 PatternRewriter &rewriter) const override;
103 };
104 
105 // Create a CFG subgraph for the scf.if operation (including its "then" and
106 // optional "else" operation blocks).  We maintain the invariants that the
107 // subgraph has a single entry and a single exit point, and that the entry/exit
108 // blocks are respectively the first/last block of the enclosing region. The
109 // operations following the scf.if are split into a continuation (subgraph
110 // exit) block. The condition is lowered to a chain of blocks that implement the
111 // short-circuit scheme. The "scf.if" operation is replaced with a conditional
112 // branch to either the first block of the "then" region, or to the first block
113 // of the "else" region. In these blocks, "scf.yield" is unconditional branches
114 // to the post-dominating block. When the "scf.if" does not return values, the
115 // post-dominating block is the same as the continuation block. When it returns
116 // values, the post-dominating block is a new block with arguments that
117 // correspond to the values returned by the "scf.if" that unconditionally
118 // branches to the continuation block. This allows block arguments to dominate
119 // any uses of the hitherto "scf.if" results that they replaced. (Inserting a
120 // new block allows us to avoid modifying the argument list of an existing
121 // block, which is illegal in a conversion pattern). When the "else" region is
122 // empty, which is only allowed for "scf.if"s that don't return values, the
123 // condition branches directly to the continuation block.
124 //
125 // CFG for a scf.if with else and without results.
126 //
127 //      +--------------------------------+
128 //      | <code before the IfOp>         |
129 //      | cf.cond_br %cond, %then, %else    |
130 //      +--------------------------------+
131 //             |              |
132 //             |              --------------|
133 //             v                            |
134 //      +--------------------------------+  |
135 //      | then:                          |  |
136 //      |   <then contents>              |  |
137 //      |   cf.br continue                  |  |
138 //      +--------------------------------+  |
139 //             |                            |
140 //   |----------               |-------------
141 //   |                         V
142 //   |  +--------------------------------+
143 //   |  | else:                          |
144 //   |  |   <else contents>              |
145 //   |  |   cf.br continue                  |
146 //   |  +--------------------------------+
147 //   |         |
148 //   ------|   |
149 //         v   v
150 //      +--------------------------------+
151 //      | continue:                      |
152 //      |   <code after the IfOp>        |
153 //      +--------------------------------+
154 //
155 // CFG for a scf.if with results.
156 //
157 //      +--------------------------------+
158 //      | <code before the IfOp>         |
159 //      | cf.cond_br %cond, %then, %else    |
160 //      +--------------------------------+
161 //             |              |
162 //             |              --------------|
163 //             v                            |
164 //      +--------------------------------+  |
165 //      | then:                          |  |
166 //      |   <then contents>              |  |
167 //      |   cf.br dom(%args...)             |  |
168 //      +--------------------------------+  |
169 //             |                            |
170 //   |----------               |-------------
171 //   |                         V
172 //   |  +--------------------------------+
173 //   |  | else:                          |
174 //   |  |   <else contents>              |
175 //   |  |   cf.br dom(%args...)             |
176 //   |  +--------------------------------+
177 //   |         |
178 //   ------|   |
179 //         v   v
180 //      +--------------------------------+
181 //      | dom(%args...):                 |
182 //      |   cf.br continue                  |
183 //      +--------------------------------+
184 //             |
185 //             v
186 //      +--------------------------------+
187 //      | continue:                      |
188 //      | <code after the IfOp>          |
189 //      +--------------------------------+
190 //
191 struct IfLowering : public OpRewritePattern<IfOp> {
192   using OpRewritePattern<IfOp>::OpRewritePattern;
193 
194   LogicalResult matchAndRewrite(IfOp ifOp,
195                                 PatternRewriter &rewriter) const override;
196 };
197 
198 struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
199   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
200 
201   LogicalResult matchAndRewrite(ExecuteRegionOp op,
202                                 PatternRewriter &rewriter) const override;
203 };
204 
205 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
206   using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
207 
208   LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
209                                 PatternRewriter &rewriter) const override;
210 };
211 
212 /// Create a CFG subgraph for this loop construct. The regions of the loop need
213 /// not be a single block anymore (for example, if other SCF constructs that
214 /// they contain have been already converted to CFG), but need to be single-exit
215 /// from the last block of each region. The operations following the original
216 /// WhileOp are split into a new continuation block. Both regions of the WhileOp
217 /// are inlined, and their terminators are rewritten to organize the control
218 /// flow implementing the loop as follows.
219 ///
220 ///      +---------------------------------+
221 ///      |   <code before the WhileOp>     |
222 ///      |   cf.br ^before(%operands...)      |
223 ///      +---------------------------------+
224 ///             |
225 ///  -------|   |
226 ///  |      v   v
227 ///  |   +--------------------------------+
228 ///  |   | ^before(%bargs...):            |
229 ///  |   |   %vals... = <some payload>    |
230 ///  |   +--------------------------------+
231 ///  |                   |
232 ///  |                  ...
233 ///  |                   |
234 ///  |   +--------------------------------+
235 ///  |   | ^before-last:
236 ///  |   |   %cond = <compute condition>  |
237 ///  |   |   cf.cond_br %cond,               |
238 ///  |   |        ^after(%vals...), ^cont |
239 ///  |   +--------------------------------+
240 ///  |          |               |
241 ///  |          |               -------------|
242 ///  |          v                            |
243 ///  |   +--------------------------------+  |
244 ///  |   | ^after(%aargs...):             |  |
245 ///  |   |   <body contents>              |  |
246 ///  |   +--------------------------------+  |
247 ///  |                   |                   |
248 ///  |                  ...                  |
249 ///  |                   |                   |
250 ///  |   +--------------------------------+  |
251 ///  |   | ^after-last:                   |  |
252 ///  |   |   %yields... = <some payload>  |  |
253 ///  |   |   cf.br ^before(%yields...)       |  |
254 ///  |   +--------------------------------+  |
255 ///  |          |                            |
256 ///  |-----------        |--------------------
257 ///                      v
258 ///      +--------------------------------+
259 ///      | ^cont:                         |
260 ///      |   <code after the WhileOp>     |
261 ///      |   <%vals from 'before' region  |
262 ///      |          visible by dominance> |
263 ///      +--------------------------------+
264 ///
265 /// Values are communicated between ex-regions (the groups of blocks that used
266 /// to form a region before inlining) through block arguments of their
267 /// entry blocks, which are visible in all other dominated blocks. Similarly,
268 /// the results of the WhileOp are defined in the 'before' region, which is
269 /// required to have a single existing block, and are therefore accessible in
270 /// the continuation block due to dominance.
271 struct WhileLowering : public OpRewritePattern<WhileOp> {
272   using OpRewritePattern<WhileOp>::OpRewritePattern;
273 
274   LogicalResult matchAndRewrite(WhileOp whileOp,
275                                 PatternRewriter &rewriter) const override;
276 };
277 
278 /// Optimized version of the above for the case of the "after" region merely
279 /// forwarding its arguments back to the "before" region (i.e., a "do-while"
280 /// loop). This avoid inlining the "after" region completely and branches back
281 /// to the "before" entry instead.
282 struct DoWhileLowering : public OpRewritePattern<WhileOp> {
283   using OpRewritePattern<WhileOp>::OpRewritePattern;
284 
285   LogicalResult matchAndRewrite(WhileOp whileOp,
286                                 PatternRewriter &rewriter) const override;
287 };
288 } // namespace
289 
290 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
291                                            PatternRewriter &rewriter) const {
292   Location loc = forOp.getLoc();
293 
294   // Start by splitting the block containing the 'scf.for' into two parts.
295   // The part before will get the init code, the part after will be the end
296   // point.
297   auto *initBlock = rewriter.getInsertionBlock();
298   auto initPosition = rewriter.getInsertionPoint();
299   auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
300 
301   // Use the first block of the loop body as the condition block since it is the
302   // block that has the induction variable and loop-carried values as arguments.
303   // Split out all operations from the first block into a new block. Move all
304   // body blocks from the loop body region to the region containing the loop.
305   auto *conditionBlock = &forOp.getRegion().front();
306   auto *firstBodyBlock =
307       rewriter.splitBlock(conditionBlock, conditionBlock->begin());
308   auto *lastBodyBlock = &forOp.getRegion().back();
309   rewriter.inlineRegionBefore(forOp.getRegion(), endBlock);
310   auto iv = conditionBlock->getArgument(0);
311 
312   // Append the induction variable stepping logic to the last body block and
313   // branch back to the condition block. Loop-carried values are taken from
314   // operands of the loop terminator.
315   Operation *terminator = lastBodyBlock->getTerminator();
316   rewriter.setInsertionPointToEnd(lastBodyBlock);
317   auto step = forOp.getStep();
318   auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult();
319   if (!stepped)
320     return failure();
321 
322   SmallVector<Value, 8> loopCarried;
323   loopCarried.push_back(stepped);
324   loopCarried.append(terminator->operand_begin(), terminator->operand_end());
325   rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
326   rewriter.eraseOp(terminator);
327 
328   // Compute loop bounds before branching to the condition.
329   rewriter.setInsertionPointToEnd(initBlock);
330   Value lowerBound = forOp.getLowerBound();
331   Value upperBound = forOp.getUpperBound();
332   if (!lowerBound || !upperBound)
333     return failure();
334 
335   // The initial values of loop-carried values is obtained from the operands
336   // of the loop operation.
337   SmallVector<Value, 8> destOperands;
338   destOperands.push_back(lowerBound);
339   auto iterOperands = forOp.getIterOperands();
340   destOperands.append(iterOperands.begin(), iterOperands.end());
341   rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
342 
343   // With the body block done, we can fill in the condition block.
344   rewriter.setInsertionPointToEnd(conditionBlock);
345   auto comparison = rewriter.create<arith::CmpIOp>(
346       loc, arith::CmpIPredicate::slt, iv, upperBound);
347 
348   rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
349                                     ArrayRef<Value>(), endBlock,
350                                     ArrayRef<Value>());
351   // The result of the loop operation is the values of the condition block
352   // arguments except the induction variable on the last iteration.
353   rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
354   return success();
355 }
356 
357 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
358                                           PatternRewriter &rewriter) const {
359   auto loc = ifOp.getLoc();
360 
361   // Start by splitting the block containing the 'scf.if' into two parts.
362   // The part before will contain the condition, the part after will be the
363   // continuation point.
364   auto *condBlock = rewriter.getInsertionBlock();
365   auto opPosition = rewriter.getInsertionPoint();
366   auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
367   Block *continueBlock;
368   if (ifOp.getNumResults() == 0) {
369     continueBlock = remainingOpsBlock;
370   } else {
371     continueBlock =
372         rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
373                              SmallVector<Location>(ifOp.getNumResults(), loc));
374     rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
375   }
376 
377   // Move blocks from the "then" region to the region containing 'scf.if',
378   // place it before the continuation block, and branch to it.
379   auto &thenRegion = ifOp.getThenRegion();
380   auto *thenBlock = &thenRegion.front();
381   Operation *thenTerminator = thenRegion.back().getTerminator();
382   ValueRange thenTerminatorOperands = thenTerminator->getOperands();
383   rewriter.setInsertionPointToEnd(&thenRegion.back());
384   rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
385   rewriter.eraseOp(thenTerminator);
386   rewriter.inlineRegionBefore(thenRegion, continueBlock);
387 
388   // Move blocks from the "else" region (if present) to the region containing
389   // 'scf.if', place it before the continuation block and branch to it.  It
390   // will be placed after the "then" regions.
391   auto *elseBlock = continueBlock;
392   auto &elseRegion = ifOp.getElseRegion();
393   if (!elseRegion.empty()) {
394     elseBlock = &elseRegion.front();
395     Operation *elseTerminator = elseRegion.back().getTerminator();
396     ValueRange elseTerminatorOperands = elseTerminator->getOperands();
397     rewriter.setInsertionPointToEnd(&elseRegion.back());
398     rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
399     rewriter.eraseOp(elseTerminator);
400     rewriter.inlineRegionBefore(elseRegion, continueBlock);
401   }
402 
403   rewriter.setInsertionPointToEnd(condBlock);
404   rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
405                                     /*trueArgs=*/ArrayRef<Value>(), elseBlock,
406                                     /*falseArgs=*/ArrayRef<Value>());
407 
408   // Ok, we're done!
409   rewriter.replaceOp(ifOp, continueBlock->getArguments());
410   return success();
411 }
412 
413 LogicalResult
414 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
415                                        PatternRewriter &rewriter) const {
416   auto loc = op.getLoc();
417 
418   auto *condBlock = rewriter.getInsertionBlock();
419   auto opPosition = rewriter.getInsertionPoint();
420   auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
421 
422   auto &region = op.getRegion();
423   rewriter.setInsertionPointToEnd(condBlock);
424   rewriter.create<cf::BranchOp>(loc, &region.front());
425 
426   for (Block &block : region) {
427     if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
428       ValueRange terminatorOperands = terminator->getOperands();
429       rewriter.setInsertionPointToEnd(&block);
430       rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
431       rewriter.eraseOp(terminator);
432     }
433   }
434 
435   rewriter.inlineRegionBefore(region, remainingOpsBlock);
436 
437   SmallVector<Value> vals;
438   SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
439   for (auto arg :
440        remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
441     vals.push_back(arg);
442   rewriter.replaceOp(op, vals);
443   return success();
444 }
445 
446 LogicalResult
447 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
448                                   PatternRewriter &rewriter) const {
449   Location loc = parallelOp.getLoc();
450 
451   // For a parallel loop, we essentially need to create an n-dimensional loop
452   // nest. We do this by translating to scf.for ops and have those lowered in
453   // a further rewrite. If a parallel loop contains reductions (and thus returns
454   // values), forward the initial values for the reductions down the loop
455   // hierarchy and bubble up the results by modifying the "yield" terminator.
456   SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals());
457   SmallVector<Value, 4> ivs;
458   ivs.reserve(parallelOp.getNumLoops());
459   bool first = true;
460   SmallVector<Value, 4> loopResults(iterArgs);
461   for (auto [iv, lower, upper, step] :
462        llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
463                  parallelOp.getUpperBound(), parallelOp.getStep())) {
464     ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
465     ivs.push_back(forOp.getInductionVar());
466     auto iterRange = forOp.getRegionIterArgs();
467     iterArgs.assign(iterRange.begin(), iterRange.end());
468 
469     if (first) {
470       // Store the results of the outermost loop that will be used to replace
471       // the results of the parallel loop when it is fully rewritten.
472       loopResults.assign(forOp.result_begin(), forOp.result_end());
473       first = false;
474     } else if (!forOp.getResults().empty()) {
475       // A loop is constructed with an empty "yield" terminator if there are
476       // no results.
477       rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
478       rewriter.create<scf::YieldOp>(loc, forOp.getResults());
479     }
480 
481     rewriter.setInsertionPointToStart(forOp.getBody());
482   }
483 
484   // First, merge reduction blocks into the main region.
485   SmallVector<Value, 4> yieldOperands;
486   yieldOperands.reserve(parallelOp.getNumResults());
487   for (auto &op : *parallelOp.getBody()) {
488     auto reduce = dyn_cast<ReduceOp>(op);
489     if (!reduce)
490       continue;
491 
492     Block &reduceBlock = reduce.getReductionOperator().front();
493     Value arg = iterArgs[yieldOperands.size()];
494     yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0));
495     rewriter.eraseOp(reduceBlock.getTerminator());
496     rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()});
497     rewriter.eraseOp(reduce);
498   }
499 
500   // Then merge the loop body without the terminator.
501   rewriter.eraseOp(parallelOp.getBody()->getTerminator());
502   Block *newBody = rewriter.getInsertionBlock();
503   if (newBody->empty())
504     rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
505   else
506     rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
507                               ivs);
508 
509   // Finally, create the terminator if required (for loops with no results, it
510   // has been already created in loop construction).
511   if (!yieldOperands.empty()) {
512     rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
513     rewriter.create<scf::YieldOp>(loc, yieldOperands);
514   }
515 
516   rewriter.replaceOp(parallelOp, loopResults);
517 
518   return success();
519 }
520 
521 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
522                                              PatternRewriter &rewriter) const {
523   OpBuilder::InsertionGuard guard(rewriter);
524   Location loc = whileOp.getLoc();
525 
526   // Split the current block before the WhileOp to create the inlining point.
527   Block *currentBlock = rewriter.getInsertionBlock();
528   Block *continuation =
529       rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
530 
531   // Inline both regions.
532   Block *after = &whileOp.getAfter().front();
533   Block *afterLast = &whileOp.getAfter().back();
534   Block *before = &whileOp.getBefore().front();
535   Block *beforeLast = &whileOp.getBefore().back();
536   rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
537   rewriter.inlineRegionBefore(whileOp.getBefore(), after);
538 
539   // Branch to the "before" region.
540   rewriter.setInsertionPointToEnd(currentBlock);
541   rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
542 
543   // Replace terminators with branches. Assuming bodies are SESE, which holds
544   // given only the patterns from this file, we only need to look at the last
545   // block. This should be reconsidered if we allow break/continue in SCF.
546   rewriter.setInsertionPointToEnd(beforeLast);
547   auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
548   rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
549                                                 after, condOp.getArgs(),
550                                                 continuation, ValueRange());
551 
552   rewriter.setInsertionPointToEnd(afterLast);
553   auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
554   rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
555                                             yieldOp.getResults());
556 
557   // Replace the op with values "yielded" from the "before" region, which are
558   // visible by dominance.
559   rewriter.replaceOp(whileOp, condOp.getArgs());
560 
561   return success();
562 }
563 
564 LogicalResult
565 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
566                                  PatternRewriter &rewriter) const {
567   if (!llvm::hasSingleElement(whileOp.getAfter()))
568     return rewriter.notifyMatchFailure(whileOp,
569                                        "do-while simplification applicable to "
570                                        "single-block 'after' region only");
571 
572   Block &afterBlock = whileOp.getAfter().front();
573   if (!llvm::hasSingleElement(afterBlock))
574     return rewriter.notifyMatchFailure(whileOp,
575                                        "do-while simplification applicable "
576                                        "only if 'after' region has no payload");
577 
578   auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
579   if (!yield || yield.getResults() != afterBlock.getArguments())
580     return rewriter.notifyMatchFailure(whileOp,
581                                        "do-while simplification applicable "
582                                        "only to forwarding 'after' regions");
583 
584   // Split the current block before the WhileOp to create the inlining point.
585   OpBuilder::InsertionGuard guard(rewriter);
586   Block *currentBlock = rewriter.getInsertionBlock();
587   Block *continuation =
588       rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
589 
590   // Only the "before" region should be inlined.
591   Block *before = &whileOp.getBefore().front();
592   Block *beforeLast = &whileOp.getBefore().back();
593   rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
594 
595   // Branch to the "before" region.
596   rewriter.setInsertionPointToEnd(currentBlock);
597   rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
598 
599   // Loop around the "before" region based on condition.
600   rewriter.setInsertionPointToEnd(beforeLast);
601   auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
602   rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
603                                                 before, condOp.getArgs(),
604                                                 continuation, ValueRange());
605 
606   // Replace the op with values "yielded" from the "before" region, which are
607   // visible by dominance.
608   rewriter.replaceOp(whileOp, condOp.getArgs());
609 
610   return success();
611 }
612 
613 void mlir::populateSCFToControlFlowConversionPatterns(
614     RewritePatternSet &patterns) {
615   patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
616                ExecuteRegionLowering>(patterns.getContext());
617   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
618 }
619 
620 void SCFToControlFlowPass::runOnOperation() {
621   RewritePatternSet patterns(&getContext());
622   populateSCFToControlFlowConversionPatterns(patterns);
623 
624   // Configure conversion to lower out SCF operations.
625   ConversionTarget target(getContext());
626   target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
627                       scf::ExecuteRegionOp>();
628   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
629   if (failed(
630           applyPartialConversion(getOperation(), target, std::move(patterns))))
631     signalPassFailure();
632 }
633 
634 std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
635   return std::make_unique<SCFToControlFlowPass>();
636 }
637