xref: /llvm-project/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp (revision c02fd17c1e20615c9e6174a3f8ad4ef0ec5ebbec)
1 //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert SCF dialect ops (scf.for, scf.if,
// scf.while, scf.parallel, ...) into ControlFlow dialect (cf) CFG ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/IRMapping.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/Passes.h"
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_SCFTOCONTROLFLOW
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 using namespace mlir;
35 using namespace mlir::scf;
36 
37 namespace {
38 
39 struct SCFToControlFlowPass
40     : public impl::SCFToControlFlowBase<SCFToControlFlowPass> {
41   void runOnOperation() override;
42 };
43 
44 // Create a CFG subgraph for the loop around its body blocks (if the body
45 // contained other loops, they have been already lowered to a flow of blocks).
46 // Maintain the invariants that a CFG subgraph created for any loop has a single
47 // entry and a single exit, and that the entry/exit blocks are respectively
48 // first/last blocks in the parent region.  The original loop operation is
49 // replaced by the initialization operations that set up the initial value of
50 // the loop induction variable (%iv) and computes the loop bounds that are loop-
51 // invariant for affine loops.  The operations following the original scf.for
52 // are split out into a separate continuation (exit) block. A condition block is
53 // created before the continuation block. It checks the exit condition of the
54 // loop and branches either to the continuation block, or to the first block of
55 // the body. The condition block takes as arguments the values of the induction
56 // variable followed by loop-carried values. Since it dominates both the body
57 // blocks and the continuation block, loop-carried values are visible in all of
58 // those blocks. Induction variable modification is appended to the last block
59 // of the body (which is the exit block from the body subgraph thanks to the
60 // invariant we maintain) along with a branch that loops back to the condition
61 // block. Loop-carried values are the loop terminator operands, which are
62 // forwarded to the branch.
63 //
64 //      +---------------------------------+
65 //      |   <code before the ForOp>       |
66 //      |   <definitions of %init...>     |
67 //      |   <compute initial %iv value>   |
68 //      |   cf.br cond(%iv, %init...)        |
69 //      +---------------------------------+
70 //             |
71 //  -------|   |
72 //  |      v   v
73 //  |   +--------------------------------+
74 //  |   | cond(%iv, %init...):           |
75 //  |   |   <compare %iv to upper bound> |
76 //  |   |   cf.cond_br %r, body, end        |
77 //  |   +--------------------------------+
78 //  |          |               |
79 //  |          |               -------------|
80 //  |          v                            |
81 //  |   +--------------------------------+  |
82 //  |   | body-first:                    |  |
83 //  |   |   <%init visible by dominance> |  |
84 //  |   |   <body contents>              |  |
85 //  |   +--------------------------------+  |
86 //  |                   |                   |
87 //  |                  ...                  |
88 //  |                   |                   |
89 //  |   +--------------------------------+  |
90 //  |   | body-last:                     |  |
91 //  |   |   <body contents>              |  |
92 //  |   |   <operands of yield = %yields>|  |
93 //  |   |   %new_iv =<add step to %iv>   |  |
94 //  |   |   cf.br cond(%new_iv, %yields)    |  |
95 //  |   +--------------------------------+  |
96 //  |          |                            |
97 //  |-----------        |--------------------
98 //                      v
99 //      +--------------------------------+
100 //      | end:                           |
101 //      |   <code after the ForOp>       |
102 //      |   <%init visible by dominance> |
103 //      +--------------------------------+
104 //
105 struct ForLowering : public OpRewritePattern<ForOp> {
106   using OpRewritePattern<ForOp>::OpRewritePattern;
107 
108   LogicalResult matchAndRewrite(ForOp forOp,
109                                 PatternRewriter &rewriter) const override;
110 };
111 
112 // Create a CFG subgraph for the scf.if operation (including its "then" and
113 // optional "else" operation blocks).  We maintain the invariants that the
114 // subgraph has a single entry and a single exit point, and that the entry/exit
115 // blocks are respectively the first/last block of the enclosing region. The
116 // operations following the scf.if are split into a continuation (subgraph
117 // exit) block. The condition is lowered to a chain of blocks that implement the
118 // short-circuit scheme. The "scf.if" operation is replaced with a conditional
119 // branch to either the first block of the "then" region, or to the first block
// of the "else" region. In these blocks, "scf.yield" becomes a "cf.br" branch
121 // to the post-dominating block. When the "scf.if" does not return values, the
122 // post-dominating block is the same as the continuation block. When it returns
123 // values, the post-dominating block is a new block with arguments that
124 // correspond to the values returned by the "scf.if" that unconditionally
125 // branches to the continuation block. This allows block arguments to dominate
126 // any uses of the hitherto "scf.if" results that they replaced. (Inserting a
127 // new block allows us to avoid modifying the argument list of an existing
128 // block, which is illegal in a conversion pattern). When the "else" region is
129 // empty, which is only allowed for "scf.if"s that don't return values, the
130 // condition branches directly to the continuation block.
131 //
132 // CFG for a scf.if with else and without results.
133 //
134 //      +--------------------------------+
135 //      | <code before the IfOp>         |
136 //      | cf.cond_br %cond, %then, %else    |
137 //      +--------------------------------+
138 //             |              |
139 //             |              --------------|
140 //             v                            |
141 //      +--------------------------------+  |
142 //      | then:                          |  |
143 //      |   <then contents>              |  |
144 //      |   cf.br continue                  |  |
145 //      +--------------------------------+  |
146 //             |                            |
147 //   |----------               |-------------
148 //   |                         V
149 //   |  +--------------------------------+
150 //   |  | else:                          |
151 //   |  |   <else contents>              |
152 //   |  |   cf.br continue                  |
153 //   |  +--------------------------------+
154 //   |         |
155 //   ------|   |
156 //         v   v
157 //      +--------------------------------+
158 //      | continue:                      |
159 //      |   <code after the IfOp>        |
160 //      +--------------------------------+
161 //
162 // CFG for a scf.if with results.
163 //
164 //      +--------------------------------+
165 //      | <code before the IfOp>         |
166 //      | cf.cond_br %cond, %then, %else    |
167 //      +--------------------------------+
168 //             |              |
169 //             |              --------------|
170 //             v                            |
171 //      +--------------------------------+  |
172 //      | then:                          |  |
173 //      |   <then contents>              |  |
174 //      |   cf.br dom(%args...)             |  |
175 //      +--------------------------------+  |
176 //             |                            |
177 //   |----------               |-------------
178 //   |                         V
179 //   |  +--------------------------------+
180 //   |  | else:                          |
181 //   |  |   <else contents>              |
182 //   |  |   cf.br dom(%args...)             |
183 //   |  +--------------------------------+
184 //   |         |
185 //   ------|   |
186 //         v   v
187 //      +--------------------------------+
188 //      | dom(%args...):                 |
189 //      |   cf.br continue                  |
190 //      +--------------------------------+
191 //             |
192 //             v
193 //      +--------------------------------+
194 //      | continue:                      |
195 //      | <code after the IfOp>          |
196 //      +--------------------------------+
197 //
198 struct IfLowering : public OpRewritePattern<IfOp> {
199   using OpRewritePattern<IfOp>::OpRewritePattern;
200 
201   LogicalResult matchAndRewrite(IfOp ifOp,
202                                 PatternRewriter &rewriter) const override;
203 };
204 
205 struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
206   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
207 
208   LogicalResult matchAndRewrite(ExecuteRegionOp op,
209                                 PatternRewriter &rewriter) const override;
210 };
211 
212 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
213   using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
214 
215   LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
216                                 PatternRewriter &rewriter) const override;
217 };
218 
219 /// Create a CFG subgraph for this loop construct. The regions of the loop need
220 /// not be a single block anymore (for example, if other SCF constructs that
221 /// they contain have been already converted to CFG), but need to be single-exit
222 /// from the last block of each region. The operations following the original
223 /// WhileOp are split into a new continuation block. Both regions of the WhileOp
224 /// are inlined, and their terminators are rewritten to organize the control
225 /// flow implementing the loop as follows.
226 ///
227 ///      +---------------------------------+
228 ///      |   <code before the WhileOp>     |
229 ///      |   cf.br ^before(%operands...)      |
230 ///      +---------------------------------+
231 ///             |
232 ///  -------|   |
233 ///  |      v   v
234 ///  |   +--------------------------------+
235 ///  |   | ^before(%bargs...):            |
236 ///  |   |   %vals... = <some payload>    |
237 ///  |   +--------------------------------+
238 ///  |                   |
239 ///  |                  ...
240 ///  |                   |
241 ///  |   +--------------------------------+
///  |   | ^before-last:                  |
243 ///  |   |   %cond = <compute condition>  |
244 ///  |   |   cf.cond_br %cond,               |
245 ///  |   |        ^after(%vals...), ^cont |
246 ///  |   +--------------------------------+
247 ///  |          |               |
248 ///  |          |               -------------|
249 ///  |          v                            |
250 ///  |   +--------------------------------+  |
251 ///  |   | ^after(%aargs...):             |  |
252 ///  |   |   <body contents>              |  |
253 ///  |   +--------------------------------+  |
254 ///  |                   |                   |
255 ///  |                  ...                  |
256 ///  |                   |                   |
257 ///  |   +--------------------------------+  |
258 ///  |   | ^after-last:                   |  |
259 ///  |   |   %yields... = <some payload>  |  |
260 ///  |   |   cf.br ^before(%yields...)       |  |
261 ///  |   +--------------------------------+  |
262 ///  |          |                            |
263 ///  |-----------        |--------------------
264 ///                      v
265 ///      +--------------------------------+
266 ///      | ^cont:                         |
267 ///      |   <code after the WhileOp>     |
268 ///      |   <%vals from 'before' region  |
269 ///      |          visible by dominance> |
270 ///      +--------------------------------+
271 ///
272 /// Values are communicated between ex-regions (the groups of blocks that used
273 /// to form a region before inlining) through block arguments of their
274 /// entry blocks, which are visible in all other dominated blocks. Similarly,
275 /// the results of the WhileOp are defined in the 'before' region, which is
276 /// required to have a single existing block, and are therefore accessible in
277 /// the continuation block due to dominance.
278 struct WhileLowering : public OpRewritePattern<WhileOp> {
279   using OpRewritePattern<WhileOp>::OpRewritePattern;
280 
281   LogicalResult matchAndRewrite(WhileOp whileOp,
282                                 PatternRewriter &rewriter) const override;
283 };
284 
285 /// Optimized version of the above for the case of the "after" region merely
286 /// forwarding its arguments back to the "before" region (i.e., a "do-while"
/// loop). This avoids inlining the "after" region completely and branches back
288 /// to the "before" entry instead.
289 struct DoWhileLowering : public OpRewritePattern<WhileOp> {
290   using OpRewritePattern<WhileOp>::OpRewritePattern;
291 
292   LogicalResult matchAndRewrite(WhileOp whileOp,
293                                 PatternRewriter &rewriter) const override;
294 };
295 
296 /// Lower an `scf.index_switch` operation to a `cf.switch` operation.
297 struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
298   using OpRewritePattern::OpRewritePattern;
299 
300   LogicalResult matchAndRewrite(IndexSwitchOp op,
301                                 PatternRewriter &rewriter) const override;
302 };
303 
304 /// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
305 /// has no shared outputs. Ops with shared outputs should be bufferized first.
306 /// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
307 /// dialects/passes.
308 struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
309   using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
310 
311   LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
312                                 PatternRewriter &rewriter) const override;
313 };
314 
315 } // namespace
316 
317 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
318                                            PatternRewriter &rewriter) const {
319   Location loc = forOp.getLoc();
320 
321   // Start by splitting the block containing the 'scf.for' into two parts.
322   // The part before will get the init code, the part after will be the end
323   // point.
324   auto *initBlock = rewriter.getInsertionBlock();
325   auto initPosition = rewriter.getInsertionPoint();
326   auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
327 
328   // Use the first block of the loop body as the condition block since it is the
329   // block that has the induction variable and loop-carried values as arguments.
330   // Split out all operations from the first block into a new block. Move all
331   // body blocks from the loop body region to the region containing the loop.
332   auto *conditionBlock = &forOp.getRegion().front();
333   auto *firstBodyBlock =
334       rewriter.splitBlock(conditionBlock, conditionBlock->begin());
335   auto *lastBodyBlock = &forOp.getRegion().back();
336   rewriter.inlineRegionBefore(forOp.getRegion(), endBlock);
337   auto iv = conditionBlock->getArgument(0);
338 
339   // Append the induction variable stepping logic to the last body block and
340   // branch back to the condition block. Loop-carried values are taken from
341   // operands of the loop terminator.
342   Operation *terminator = lastBodyBlock->getTerminator();
343   rewriter.setInsertionPointToEnd(lastBodyBlock);
344   auto step = forOp.getStep();
345   auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult();
346   if (!stepped)
347     return failure();
348 
349   SmallVector<Value, 8> loopCarried;
350   loopCarried.push_back(stepped);
351   loopCarried.append(terminator->operand_begin(), terminator->operand_end());
352   rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
353   rewriter.eraseOp(terminator);
354 
355   // Compute loop bounds before branching to the condition.
356   rewriter.setInsertionPointToEnd(initBlock);
357   Value lowerBound = forOp.getLowerBound();
358   Value upperBound = forOp.getUpperBound();
359   if (!lowerBound || !upperBound)
360     return failure();
361 
362   // The initial values of loop-carried values is obtained from the operands
363   // of the loop operation.
364   SmallVector<Value, 8> destOperands;
365   destOperands.push_back(lowerBound);
366   llvm::append_range(destOperands, forOp.getInitArgs());
367   rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
368 
369   // With the body block done, we can fill in the condition block.
370   rewriter.setInsertionPointToEnd(conditionBlock);
371   auto comparison = rewriter.create<arith::CmpIOp>(
372       loc, arith::CmpIPredicate::slt, iv, upperBound);
373 
374   auto condBranchOp = rewriter.create<cf::CondBranchOp>(
375       loc, comparison, firstBodyBlock, ArrayRef<Value>(), endBlock,
376       ArrayRef<Value>());
377 
378   // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
379   // llvm.loop_annotation attribute.
380   SmallVector<NamedAttribute> llvmAttrs;
381   llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs),
382                 [](auto attr) {
383                   return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
384                 });
385   condBranchOp->setDiscardableAttrs(llvmAttrs);
386   // The result of the loop operation is the values of the condition block
387   // arguments except the induction variable on the last iteration.
388   rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
389   return success();
390 }
391 
392 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
393                                           PatternRewriter &rewriter) const {
394   auto loc = ifOp.getLoc();
395 
396   // Start by splitting the block containing the 'scf.if' into two parts.
397   // The part before will contain the condition, the part after will be the
398   // continuation point.
399   auto *condBlock = rewriter.getInsertionBlock();
400   auto opPosition = rewriter.getInsertionPoint();
401   auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
402   Block *continueBlock;
403   if (ifOp.getNumResults() == 0) {
404     continueBlock = remainingOpsBlock;
405   } else {
406     continueBlock =
407         rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
408                              SmallVector<Location>(ifOp.getNumResults(), loc));
409     rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
410   }
411 
412   // Move blocks from the "then" region to the region containing 'scf.if',
413   // place it before the continuation block, and branch to it.
414   auto &thenRegion = ifOp.getThenRegion();
415   auto *thenBlock = &thenRegion.front();
416   Operation *thenTerminator = thenRegion.back().getTerminator();
417   ValueRange thenTerminatorOperands = thenTerminator->getOperands();
418   rewriter.setInsertionPointToEnd(&thenRegion.back());
419   rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
420   rewriter.eraseOp(thenTerminator);
421   rewriter.inlineRegionBefore(thenRegion, continueBlock);
422 
423   // Move blocks from the "else" region (if present) to the region containing
424   // 'scf.if', place it before the continuation block and branch to it.  It
425   // will be placed after the "then" regions.
426   auto *elseBlock = continueBlock;
427   auto &elseRegion = ifOp.getElseRegion();
428   if (!elseRegion.empty()) {
429     elseBlock = &elseRegion.front();
430     Operation *elseTerminator = elseRegion.back().getTerminator();
431     ValueRange elseTerminatorOperands = elseTerminator->getOperands();
432     rewriter.setInsertionPointToEnd(&elseRegion.back());
433     rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
434     rewriter.eraseOp(elseTerminator);
435     rewriter.inlineRegionBefore(elseRegion, continueBlock);
436   }
437 
438   rewriter.setInsertionPointToEnd(condBlock);
439   rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
440                                     /*trueArgs=*/ArrayRef<Value>(), elseBlock,
441                                     /*falseArgs=*/ArrayRef<Value>());
442 
443   // Ok, we're done!
444   rewriter.replaceOp(ifOp, continueBlock->getArguments());
445   return success();
446 }
447 
448 LogicalResult
449 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
450                                        PatternRewriter &rewriter) const {
451   auto loc = op.getLoc();
452 
453   auto *condBlock = rewriter.getInsertionBlock();
454   auto opPosition = rewriter.getInsertionPoint();
455   auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
456 
457   auto &region = op.getRegion();
458   rewriter.setInsertionPointToEnd(condBlock);
459   rewriter.create<cf::BranchOp>(loc, &region.front());
460 
461   for (Block &block : region) {
462     if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
463       ValueRange terminatorOperands = terminator->getOperands();
464       rewriter.setInsertionPointToEnd(&block);
465       rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
466       rewriter.eraseOp(terminator);
467     }
468   }
469 
470   rewriter.inlineRegionBefore(region, remainingOpsBlock);
471 
472   SmallVector<Value> vals;
473   SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
474   for (auto arg :
475        remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
476     vals.push_back(arg);
477   rewriter.replaceOp(op, vals);
478   return success();
479 }
480 
481 LogicalResult
482 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
483                                   PatternRewriter &rewriter) const {
484   Location loc = parallelOp.getLoc();
485   auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
486   if (!reductionOp) {
487     return failure();
488   }
489 
490   // For a parallel loop, we essentially need to create an n-dimensional loop
491   // nest. We do this by translating to scf.for ops and have those lowered in
492   // a further rewrite. If a parallel loop contains reductions (and thus returns
493   // values), forward the initial values for the reductions down the loop
494   // hierarchy and bubble up the results by modifying the "yield" terminator.
495   SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals());
496   SmallVector<Value, 4> ivs;
497   ivs.reserve(parallelOp.getNumLoops());
498   bool first = true;
499   SmallVector<Value, 4> loopResults(iterArgs);
500   for (auto [iv, lower, upper, step] :
501        llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
502                  parallelOp.getUpperBound(), parallelOp.getStep())) {
503     ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
504     ivs.push_back(forOp.getInductionVar());
505     auto iterRange = forOp.getRegionIterArgs();
506     iterArgs.assign(iterRange.begin(), iterRange.end());
507 
508     if (first) {
509       // Store the results of the outermost loop that will be used to replace
510       // the results of the parallel loop when it is fully rewritten.
511       loopResults.assign(forOp.result_begin(), forOp.result_end());
512       first = false;
513     } else if (!forOp.getResults().empty()) {
514       // A loop is constructed with an empty "yield" terminator if there are
515       // no results.
516       rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
517       rewriter.create<scf::YieldOp>(loc, forOp.getResults());
518     }
519 
520     rewriter.setInsertionPointToStart(forOp.getBody());
521   }
522 
523   // First, merge reduction blocks into the main region.
524   SmallVector<Value> yieldOperands;
525   yieldOperands.reserve(parallelOp.getNumResults());
526   for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
527     Block &reductionBody = reductionOp.getReductions()[i].front();
528     Value arg = iterArgs[yieldOperands.size()];
529     yieldOperands.push_back(
530         cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult());
531     rewriter.eraseOp(reductionBody.getTerminator());
532     rewriter.inlineBlockBefore(&reductionBody, reductionOp,
533                                {arg, reductionOp.getOperands()[i]});
534   }
535   rewriter.eraseOp(reductionOp);
536 
537   // Then merge the loop body without the terminator.
538   Block *newBody = rewriter.getInsertionBlock();
539   if (newBody->empty())
540     rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
541   else
542     rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
543                                ivs);
544 
545   // Finally, create the terminator if required (for loops with no results, it
546   // has been already created in loop construction).
547   if (!yieldOperands.empty()) {
548     rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
549     rewriter.create<scf::YieldOp>(loc, yieldOperands);
550   }
551 
552   rewriter.replaceOp(parallelOp, loopResults);
553 
554   return success();
555 }
556 
557 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
558                                              PatternRewriter &rewriter) const {
559   OpBuilder::InsertionGuard guard(rewriter);
560   Location loc = whileOp.getLoc();
561 
562   // Split the current block before the WhileOp to create the inlining point.
563   Block *currentBlock = rewriter.getInsertionBlock();
564   Block *continuation =
565       rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
566 
567   // Inline both regions.
568   Block *after = whileOp.getAfterBody();
569   Block *before = whileOp.getBeforeBody();
570   rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
571   rewriter.inlineRegionBefore(whileOp.getBefore(), after);
572 
573   // Branch to the "before" region.
574   rewriter.setInsertionPointToEnd(currentBlock);
575   rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
576 
577   // Replace terminators with branches. Assuming bodies are SESE, which holds
578   // given only the patterns from this file, we only need to look at the last
579   // block. This should be reconsidered if we allow break/continue in SCF.
580   rewriter.setInsertionPointToEnd(before);
581   auto condOp = cast<ConditionOp>(before->getTerminator());
582   rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
583                                                 after, condOp.getArgs(),
584                                                 continuation, ValueRange());
585 
586   rewriter.setInsertionPointToEnd(after);
587   auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
588   rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
589                                             yieldOp.getResults());
590 
591   // Replace the op with values "yielded" from the "before" region, which are
592   // visible by dominance.
593   rewriter.replaceOp(whileOp, condOp.getArgs());
594 
595   return success();
596 }
597 
598 LogicalResult
599 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
600                                  PatternRewriter &rewriter) const {
601   Block &afterBlock = *whileOp.getAfterBody();
602   if (!llvm::hasSingleElement(afterBlock))
603     return rewriter.notifyMatchFailure(whileOp,
604                                        "do-while simplification applicable "
605                                        "only if 'after' region has no payload");
606 
607   auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
608   if (!yield || yield.getResults() != afterBlock.getArguments())
609     return rewriter.notifyMatchFailure(whileOp,
610                                        "do-while simplification applicable "
611                                        "only to forwarding 'after' regions");
612 
613   // Split the current block before the WhileOp to create the inlining point.
614   OpBuilder::InsertionGuard guard(rewriter);
615   Block *currentBlock = rewriter.getInsertionBlock();
616   Block *continuation =
617       rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
618 
619   // Only the "before" region should be inlined.
620   Block *before = whileOp.getBeforeBody();
621   rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
622 
623   // Branch to the "before" region.
624   rewriter.setInsertionPointToEnd(currentBlock);
625   rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
626 
627   // Loop around the "before" region based on condition.
628   rewriter.setInsertionPointToEnd(before);
629   auto condOp = cast<ConditionOp>(before->getTerminator());
630   rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
631                                                 before, condOp.getArgs(),
632                                                 continuation, ValueRange());
633 
634   // Replace the op with values "yielded" from the "before" region, which are
635   // visible by dominance.
636   rewriter.replaceOp(whileOp, condOp.getArgs());
637 
638   return success();
639 }
640 
641 LogicalResult
642 IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
643                                      PatternRewriter &rewriter) const {
644   // Split the block at the op.
645   Block *condBlock = rewriter.getInsertionBlock();
646   Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
647 
648   // Create the arguments on the continue block with which to replace the
649   // results of the op.
650   SmallVector<Value> results;
651   results.reserve(op.getNumResults());
652   for (Type resultType : op.getResultTypes())
653     results.push_back(continueBlock->addArgument(resultType, op.getLoc()));
654 
655   // Handle the regions.
656   auto convertRegion = [&](Region &region) -> FailureOr<Block *> {
657     Block *block = &region.front();
658 
659     // Convert the yield terminator to a branch to the continue block.
660     auto yield = cast<scf::YieldOp>(block->getTerminator());
661     rewriter.setInsertionPoint(yield);
662     rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
663                                               yield.getOperands());
664 
665     // Inline the region.
666     rewriter.inlineRegionBefore(region, continueBlock);
667     return block;
668   };
669 
670   // Convert the case regions.
671   SmallVector<Block *> caseSuccessors;
672   SmallVector<int32_t> caseValues;
673   caseSuccessors.reserve(op.getCases().size());
674   caseValues.reserve(op.getCases().size());
675   for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
676     FailureOr<Block *> block = convertRegion(region);
677     if (failed(block))
678       return failure();
679     caseSuccessors.push_back(*block);
680     caseValues.push_back(value);
681   }
682 
683   // Convert the default region.
684   FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
685   if (failed(defaultBlock))
686     return failure();
687 
688   // Create the switch.
689   rewriter.setInsertionPointToEnd(condBlock);
690   SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
691 
692   // Cast switch index to integer case value.
693   Value caseValue = rewriter.create<arith::IndexCastOp>(
694       op.getLoc(), rewriter.getI32Type(), op.getArg());
695 
696   rewriter.create<cf::SwitchOp>(
697       op.getLoc(), caseValue, *defaultBlock, ValueRange(),
698       rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
699   rewriter.replaceOp(op, continueBlock->getArguments());
700   return success();
701 }
702 
703 LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
704                                               PatternRewriter &rewriter) const {
705   return scf::forallToParallelLoop(rewriter, forallOp);
706 }
707 
708 void mlir::populateSCFToControlFlowConversionPatterns(
709     RewritePatternSet &patterns) {
710   patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
711                WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
712       patterns.getContext());
713   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
714 }
715 
716 void SCFToControlFlowPass::runOnOperation() {
717   RewritePatternSet patterns(&getContext());
718   populateSCFToControlFlowConversionPatterns(patterns);
719 
720   // Configure conversion to lower out SCF operations.
721   ConversionTarget target(getContext());
722   target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
723                       scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
724   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
725   if (failed(
726           applyPartialConversion(getOperation(), target, std::move(patterns))))
727     signalPassFailure();
728 }
729 
730 std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
731   return std::make_unique<SCFToControlFlowPass>();
732 }
733