1 //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert scf.for, scf.if and loop.terminator 10 // ops into standard CFG ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 18 #include "mlir/Dialect/SCF/IR/SCF.h" 19 #include "mlir/IR/BlockAndValueMapping.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/BuiltinOps.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "mlir/Transforms/Passes.h" 26 27 using namespace mlir; 28 using namespace mlir::scf; 29 30 namespace { 31 32 struct SCFToControlFlowPass 33 : public SCFToControlFlowBase<SCFToControlFlowPass> { 34 void runOnOperation() override; 35 }; 36 37 // Create a CFG subgraph for the loop around its body blocks (if the body 38 // contained other loops, they have been already lowered to a flow of blocks). 39 // Maintain the invariants that a CFG subgraph created for any loop has a single 40 // entry and a single exit, and that the entry/exit blocks are respectively 41 // first/last blocks in the parent region. The original loop operation is 42 // replaced by the initialization operations that set up the initial value of 43 // the loop induction variable (%iv) and computes the loop bounds that are loop- 44 // invariant for affine loops. The operations following the original scf.for 45 // are split out into a separate continuation (exit) block. A condition block is 46 // created before the continuation block. It checks the exit condition of the 47 // loop and branches either to the continuation block, or to the first block of 48 // the body. The condition block takes as arguments the values of the induction 49 // variable followed by loop-carried values. Since it dominates both the body 50 // blocks and the continuation block, loop-carried values are visible in all of 51 // those blocks. Induction variable modification is appended to the last block 52 // of the body (which is the exit block from the body subgraph thanks to the 53 // invariant we maintain) along with a branch that loops back to the condition 54 // block. Loop-carried values are the loop terminator operands, which are 55 // forwarded to the branch. 56 // 57 // +---------------------------------+ 58 // | <code before the ForOp> | 59 // | <definitions of %init...> | 60 // | <compute initial %iv value> | 61 // | cf.br cond(%iv, %init...) | 62 // +---------------------------------+ 63 // | 64 // -------| | 65 // | v v 66 // | +--------------------------------+ 67 // | | cond(%iv, %init...): | 68 // | | <compare %iv to upper bound> | 69 // | | cf.cond_br %r, body, end | 70 // | +--------------------------------+ 71 // | | | 72 // | | -------------| 73 // | v | 74 // | +--------------------------------+ | 75 // | | body-first: | | 76 // | | <%init visible by dominance> | | 77 // | | <body contents> | | 78 // | +--------------------------------+ | 79 // | | | 80 // | ... | 81 // | | | 82 // | +--------------------------------+ | 83 // | | body-last: | | 84 // | | <body contents> | | 85 // | | <operands of yield = %yields>| | 86 // | | %new_iv =<add step to %iv> | | 87 // | | cf.br cond(%new_iv, %yields) | | 88 // | +--------------------------------+ | 89 // | | | 90 // |----------- |-------------------- 91 // v 92 // +--------------------------------+ 93 // | end: | 94 // | <code after the ForOp> | 95 // | <%init visible by dominance> | 96 // +--------------------------------+ 97 // 98 struct ForLowering : public OpRewritePattern<ForOp> { 99 using OpRewritePattern<ForOp>::OpRewritePattern; 100 101 LogicalResult matchAndRewrite(ForOp forOp, 102 PatternRewriter &rewriter) const override; 103 }; 104 105 // Create a CFG subgraph for the scf.if operation (including its "then" and 106 // optional "else" operation blocks). We maintain the invariants that the 107 // subgraph has a single entry and a single exit point, and that the entry/exit 108 // blocks are respectively the first/last block of the enclosing region. The 109 // operations following the scf.if are split into a continuation (subgraph 110 // exit) block. The condition is lowered to a chain of blocks that implement the 111 // short-circuit scheme. The "scf.if" operation is replaced with a conditional 112 // branch to either the first block of the "then" region, or to the first block 113 // of the "else" region. In these blocks, "scf.yield" is unconditional branches 114 // to the post-dominating block. When the "scf.if" does not return values, the 115 // post-dominating block is the same as the continuation block. When it returns 116 // values, the post-dominating block is a new block with arguments that 117 // correspond to the values returned by the "scf.if" that unconditionally 118 // branches to the continuation block. This allows block arguments to dominate 119 // any uses of the hitherto "scf.if" results that they replaced. (Inserting a 120 // new block allows us to avoid modifying the argument list of an existing 121 // block, which is illegal in a conversion pattern). When the "else" region is 122 // empty, which is only allowed for "scf.if"s that don't return values, the 123 // condition branches directly to the continuation block. 124 // 125 // CFG for a scf.if with else and without results. 126 // 127 // +--------------------------------+ 128 // | <code before the IfOp> | 129 // | cf.cond_br %cond, %then, %else | 130 // +--------------------------------+ 131 // | | 132 // | --------------| 133 // v | 134 // +--------------------------------+ | 135 // | then: | | 136 // | <then contents> | | 137 // | cf.br continue | | 138 // +--------------------------------+ | 139 // | | 140 // |---------- |------------- 141 // | V 142 // | +--------------------------------+ 143 // | | else: | 144 // | | <else contents> | 145 // | | cf.br continue | 146 // | +--------------------------------+ 147 // | | 148 // ------| | 149 // v v 150 // +--------------------------------+ 151 // | continue: | 152 // | <code after the IfOp> | 153 // +--------------------------------+ 154 // 155 // CFG for a scf.if with results. 156 // 157 // +--------------------------------+ 158 // | <code before the IfOp> | 159 // | cf.cond_br %cond, %then, %else | 160 // +--------------------------------+ 161 // | | 162 // | --------------| 163 // v | 164 // +--------------------------------+ | 165 // | then: | | 166 // | <then contents> | | 167 // | cf.br dom(%args...) | | 168 // +--------------------------------+ | 169 // | | 170 // |---------- |------------- 171 // | V 172 // | +--------------------------------+ 173 // | | else: | 174 // | | <else contents> | 175 // | | cf.br dom(%args...) | 176 // | +--------------------------------+ 177 // | | 178 // ------| | 179 // v v 180 // +--------------------------------+ 181 // | dom(%args...): | 182 // | cf.br continue | 183 // +--------------------------------+ 184 // | 185 // v 186 // +--------------------------------+ 187 // | continue: | 188 // | <code after the IfOp> | 189 // +--------------------------------+ 190 // 191 struct IfLowering : public OpRewritePattern<IfOp> { 192 using OpRewritePattern<IfOp>::OpRewritePattern; 193 194 LogicalResult matchAndRewrite(IfOp ifOp, 195 PatternRewriter &rewriter) const override; 196 }; 197 198 struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> { 199 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; 200 201 LogicalResult matchAndRewrite(ExecuteRegionOp op, 202 PatternRewriter &rewriter) const override; 203 }; 204 205 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> { 206 using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern; 207 208 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp, 209 PatternRewriter &rewriter) const override; 210 }; 211 212 /// Create a CFG subgraph for this loop construct. The regions of the loop need 213 /// not be a single block anymore (for example, if other SCF constructs that 214 /// they contain have been already converted to CFG), but need to be single-exit 215 /// from the last block of each region. The operations following the original 216 /// WhileOp are split into a new continuation block. Both regions of the WhileOp 217 /// are inlined, and their terminators are rewritten to organize the control 218 /// flow implementing the loop as follows. 219 /// 220 /// +---------------------------------+ 221 /// | <code before the WhileOp> | 222 /// | cf.br ^before(%operands...) | 223 /// +---------------------------------+ 224 /// | 225 /// -------| | 226 /// | v v 227 /// | +--------------------------------+ 228 /// | | ^before(%bargs...): | 229 /// | | %vals... = <some payload> | 230 /// | +--------------------------------+ 231 /// | | 232 /// | ... 233 /// | | 234 /// | +--------------------------------+ 235 /// | | ^before-last: 236 /// | | %cond = <compute condition> | 237 /// | | cf.cond_br %cond, | 238 /// | | ^after(%vals...), ^cont | 239 /// | +--------------------------------+ 240 /// | | | 241 /// | | -------------| 242 /// | v | 243 /// | +--------------------------------+ | 244 /// | | ^after(%aargs...): | | 245 /// | | <body contents> | | 246 /// | +--------------------------------+ | 247 /// | | | 248 /// | ... | 249 /// | | | 250 /// | +--------------------------------+ | 251 /// | | ^after-last: | | 252 /// | | %yields... = <some payload> | | 253 /// | | cf.br ^before(%yields...) | | 254 /// | +--------------------------------+ | 255 /// | | | 256 /// |----------- |-------------------- 257 /// v 258 /// +--------------------------------+ 259 /// | ^cont: | 260 /// | <code after the WhileOp> | 261 /// | <%vals from 'before' region | 262 /// | visible by dominance> | 263 /// +--------------------------------+ 264 /// 265 /// Values are communicated between ex-regions (the groups of blocks that used 266 /// to form a region before inlining) through block arguments of their 267 /// entry blocks, which are visible in all other dominated blocks. Similarly, 268 /// the results of the WhileOp are defined in the 'before' region, which is 269 /// required to have a single existing block, and are therefore accessible in 270 /// the continuation block due to dominance. 271 struct WhileLowering : public OpRewritePattern<WhileOp> { 272 using OpRewritePattern<WhileOp>::OpRewritePattern; 273 274 LogicalResult matchAndRewrite(WhileOp whileOp, 275 PatternRewriter &rewriter) const override; 276 }; 277 278 /// Optimized version of the above for the case of the "after" region merely 279 /// forwarding its arguments back to the "before" region (i.e., a "do-while" 280 /// loop). This avoid inlining the "after" region completely and branches back 281 /// to the "before" entry instead. 282 struct DoWhileLowering : public OpRewritePattern<WhileOp> { 283 using OpRewritePattern<WhileOp>::OpRewritePattern; 284 285 LogicalResult matchAndRewrite(WhileOp whileOp, 286 PatternRewriter &rewriter) const override; 287 }; 288 } // namespace 289 290 LogicalResult ForLowering::matchAndRewrite(ForOp forOp, 291 PatternRewriter &rewriter) const { 292 Location loc = forOp.getLoc(); 293 294 // Start by splitting the block containing the 'scf.for' into two parts. 295 // The part before will get the init code, the part after will be the end 296 // point. 297 auto *initBlock = rewriter.getInsertionBlock(); 298 auto initPosition = rewriter.getInsertionPoint(); 299 auto *endBlock = rewriter.splitBlock(initBlock, initPosition); 300 301 // Use the first block of the loop body as the condition block since it is the 302 // block that has the induction variable and loop-carried values as arguments. 303 // Split out all operations from the first block into a new block. Move all 304 // body blocks from the loop body region to the region containing the loop. 305 auto *conditionBlock = &forOp.getRegion().front(); 306 auto *firstBodyBlock = 307 rewriter.splitBlock(conditionBlock, conditionBlock->begin()); 308 auto *lastBodyBlock = &forOp.getRegion().back(); 309 rewriter.inlineRegionBefore(forOp.getRegion(), endBlock); 310 auto iv = conditionBlock->getArgument(0); 311 312 // Append the induction variable stepping logic to the last body block and 313 // branch back to the condition block. Loop-carried values are taken from 314 // operands of the loop terminator. 315 Operation *terminator = lastBodyBlock->getTerminator(); 316 rewriter.setInsertionPointToEnd(lastBodyBlock); 317 auto step = forOp.getStep(); 318 auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult(); 319 if (!stepped) 320 return failure(); 321 322 SmallVector<Value, 8> loopCarried; 323 loopCarried.push_back(stepped); 324 loopCarried.append(terminator->operand_begin(), terminator->operand_end()); 325 rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried); 326 rewriter.eraseOp(terminator); 327 328 // Compute loop bounds before branching to the condition. 329 rewriter.setInsertionPointToEnd(initBlock); 330 Value lowerBound = forOp.getLowerBound(); 331 Value upperBound = forOp.getUpperBound(); 332 if (!lowerBound || !upperBound) 333 return failure(); 334 335 // The initial values of loop-carried values is obtained from the operands 336 // of the loop operation. 337 SmallVector<Value, 8> destOperands; 338 destOperands.push_back(lowerBound); 339 auto iterOperands = forOp.getIterOperands(); 340 destOperands.append(iterOperands.begin(), iterOperands.end()); 341 rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands); 342 343 // With the body block done, we can fill in the condition block. 344 rewriter.setInsertionPointToEnd(conditionBlock); 345 auto comparison = rewriter.create<arith::CmpIOp>( 346 loc, arith::CmpIPredicate::slt, iv, upperBound); 347 348 rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock, 349 ArrayRef<Value>(), endBlock, 350 ArrayRef<Value>()); 351 // The result of the loop operation is the values of the condition block 352 // arguments except the induction variable on the last iteration. 353 rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); 354 return success(); 355 } 356 357 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, 358 PatternRewriter &rewriter) const { 359 auto loc = ifOp.getLoc(); 360 361 // Start by splitting the block containing the 'scf.if' into two parts. 362 // The part before will contain the condition, the part after will be the 363 // continuation point. 364 auto *condBlock = rewriter.getInsertionBlock(); 365 auto opPosition = rewriter.getInsertionPoint(); 366 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); 367 Block *continueBlock; 368 if (ifOp.getNumResults() == 0) { 369 continueBlock = remainingOpsBlock; 370 } else { 371 continueBlock = 372 rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(), 373 SmallVector<Location>(ifOp.getNumResults(), loc)); 374 rewriter.create<cf::BranchOp>(loc, remainingOpsBlock); 375 } 376 377 // Move blocks from the "then" region to the region containing 'scf.if', 378 // place it before the continuation block, and branch to it. 379 auto &thenRegion = ifOp.getThenRegion(); 380 auto *thenBlock = &thenRegion.front(); 381 Operation *thenTerminator = thenRegion.back().getTerminator(); 382 ValueRange thenTerminatorOperands = thenTerminator->getOperands(); 383 rewriter.setInsertionPointToEnd(&thenRegion.back()); 384 rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands); 385 rewriter.eraseOp(thenTerminator); 386 rewriter.inlineRegionBefore(thenRegion, continueBlock); 387 388 // Move blocks from the "else" region (if present) to the region containing 389 // 'scf.if', place it before the continuation block and branch to it. It 390 // will be placed after the "then" regions. 391 auto *elseBlock = continueBlock; 392 auto &elseRegion = ifOp.getElseRegion(); 393 if (!elseRegion.empty()) { 394 elseBlock = &elseRegion.front(); 395 Operation *elseTerminator = elseRegion.back().getTerminator(); 396 ValueRange elseTerminatorOperands = elseTerminator->getOperands(); 397 rewriter.setInsertionPointToEnd(&elseRegion.back()); 398 rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands); 399 rewriter.eraseOp(elseTerminator); 400 rewriter.inlineRegionBefore(elseRegion, continueBlock); 401 } 402 403 rewriter.setInsertionPointToEnd(condBlock); 404 rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock, 405 /*trueArgs=*/ArrayRef<Value>(), elseBlock, 406 /*falseArgs=*/ArrayRef<Value>()); 407 408 // Ok, we're done! 409 rewriter.replaceOp(ifOp, continueBlock->getArguments()); 410 return success(); 411 } 412 413 LogicalResult 414 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op, 415 PatternRewriter &rewriter) const { 416 auto loc = op.getLoc(); 417 418 auto *condBlock = rewriter.getInsertionBlock(); 419 auto opPosition = rewriter.getInsertionPoint(); 420 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); 421 422 auto ®ion = op.getRegion(); 423 rewriter.setInsertionPointToEnd(condBlock); 424 rewriter.create<cf::BranchOp>(loc, ®ion.front()); 425 426 for (Block &block : region) { 427 if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) { 428 ValueRange terminatorOperands = terminator->getOperands(); 429 rewriter.setInsertionPointToEnd(&block); 430 rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands); 431 rewriter.eraseOp(terminator); 432 } 433 } 434 435 rewriter.inlineRegionBefore(region, remainingOpsBlock); 436 437 SmallVector<Value> vals; 438 SmallVector<Location> argLocs(op.getNumResults(), op->getLoc()); 439 for (auto arg : 440 remainingOpsBlock->addArguments(op->getResultTypes(), argLocs)) 441 vals.push_back(arg); 442 rewriter.replaceOp(op, vals); 443 return success(); 444 } 445 446 LogicalResult 447 ParallelLowering::matchAndRewrite(ParallelOp parallelOp, 448 PatternRewriter &rewriter) const { 449 Location loc = parallelOp.getLoc(); 450 451 // For a parallel loop, we essentially need to create an n-dimensional loop 452 // nest. We do this by translating to scf.for ops and have those lowered in 453 // a further rewrite. If a parallel loop contains reductions (and thus returns 454 // values), forward the initial values for the reductions down the loop 455 // hierarchy and bubble up the results by modifying the "yield" terminator. 456 SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals()); 457 SmallVector<Value, 4> ivs; 458 ivs.reserve(parallelOp.getNumLoops()); 459 bool first = true; 460 SmallVector<Value, 4> loopResults(iterArgs); 461 for (auto [iv, lower, upper, step] : 462 llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(), 463 parallelOp.getUpperBound(), parallelOp.getStep())) { 464 ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs); 465 ivs.push_back(forOp.getInductionVar()); 466 auto iterRange = forOp.getRegionIterArgs(); 467 iterArgs.assign(iterRange.begin(), iterRange.end()); 468 469 if (first) { 470 // Store the results of the outermost loop that will be used to replace 471 // the results of the parallel loop when it is fully rewritten. 472 loopResults.assign(forOp.result_begin(), forOp.result_end()); 473 first = false; 474 } else if (!forOp.getResults().empty()) { 475 // A loop is constructed with an empty "yield" terminator if there are 476 // no results. 477 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); 478 rewriter.create<scf::YieldOp>(loc, forOp.getResults()); 479 } 480 481 rewriter.setInsertionPointToStart(forOp.getBody()); 482 } 483 484 // First, merge reduction blocks into the main region. 485 SmallVector<Value, 4> yieldOperands; 486 yieldOperands.reserve(parallelOp.getNumResults()); 487 for (auto &op : *parallelOp.getBody()) { 488 auto reduce = dyn_cast<ReduceOp>(op); 489 if (!reduce) 490 continue; 491 492 Block &reduceBlock = reduce.getReductionOperator().front(); 493 Value arg = iterArgs[yieldOperands.size()]; 494 yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0)); 495 rewriter.eraseOp(reduceBlock.getTerminator()); 496 rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()}); 497 rewriter.eraseOp(reduce); 498 } 499 500 // Then merge the loop body without the terminator. 501 rewriter.eraseOp(parallelOp.getBody()->getTerminator()); 502 Block *newBody = rewriter.getInsertionBlock(); 503 if (newBody->empty()) 504 rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs); 505 else 506 rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(), 507 ivs); 508 509 // Finally, create the terminator if required (for loops with no results, it 510 // has been already created in loop construction). 511 if (!yieldOperands.empty()) { 512 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); 513 rewriter.create<scf::YieldOp>(loc, yieldOperands); 514 } 515 516 rewriter.replaceOp(parallelOp, loopResults); 517 518 return success(); 519 } 520 521 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, 522 PatternRewriter &rewriter) const { 523 OpBuilder::InsertionGuard guard(rewriter); 524 Location loc = whileOp.getLoc(); 525 526 // Split the current block before the WhileOp to create the inlining point. 527 Block *currentBlock = rewriter.getInsertionBlock(); 528 Block *continuation = 529 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); 530 531 // Inline both regions. 532 Block *after = &whileOp.getAfter().front(); 533 Block *afterLast = &whileOp.getAfter().back(); 534 Block *before = &whileOp.getBefore().front(); 535 Block *beforeLast = &whileOp.getBefore().back(); 536 rewriter.inlineRegionBefore(whileOp.getAfter(), continuation); 537 rewriter.inlineRegionBefore(whileOp.getBefore(), after); 538 539 // Branch to the "before" region. 540 rewriter.setInsertionPointToEnd(currentBlock); 541 rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits()); 542 543 // Replace terminators with branches. Assuming bodies are SESE, which holds 544 // given only the patterns from this file, we only need to look at the last 545 // block. This should be reconsidered if we allow break/continue in SCF. 546 rewriter.setInsertionPointToEnd(beforeLast); 547 auto condOp = cast<ConditionOp>(beforeLast->getTerminator()); 548 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), 549 after, condOp.getArgs(), 550 continuation, ValueRange()); 551 552 rewriter.setInsertionPointToEnd(afterLast); 553 auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator()); 554 rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before, 555 yieldOp.getResults()); 556 557 // Replace the op with values "yielded" from the "before" region, which are 558 // visible by dominance. 559 rewriter.replaceOp(whileOp, condOp.getArgs()); 560 561 return success(); 562 } 563 564 LogicalResult 565 DoWhileLowering::matchAndRewrite(WhileOp whileOp, 566 PatternRewriter &rewriter) const { 567 if (!llvm::hasSingleElement(whileOp.getAfter())) 568 return rewriter.notifyMatchFailure(whileOp, 569 "do-while simplification applicable to " 570 "single-block 'after' region only"); 571 572 Block &afterBlock = whileOp.getAfter().front(); 573 if (!llvm::hasSingleElement(afterBlock)) 574 return rewriter.notifyMatchFailure(whileOp, 575 "do-while simplification applicable " 576 "only if 'after' region has no payload"); 577 578 auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front()); 579 if (!yield || yield.getResults() != afterBlock.getArguments()) 580 return rewriter.notifyMatchFailure(whileOp, 581 "do-while simplification applicable " 582 "only to forwarding 'after' regions"); 583 584 // Split the current block before the WhileOp to create the inlining point. 585 OpBuilder::InsertionGuard guard(rewriter); 586 Block *currentBlock = rewriter.getInsertionBlock(); 587 Block *continuation = 588 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); 589 590 // Only the "before" region should be inlined. 591 Block *before = &whileOp.getBefore().front(); 592 Block *beforeLast = &whileOp.getBefore().back(); 593 rewriter.inlineRegionBefore(whileOp.getBefore(), continuation); 594 595 // Branch to the "before" region. 596 rewriter.setInsertionPointToEnd(currentBlock); 597 rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits()); 598 599 // Loop around the "before" region based on condition. 600 rewriter.setInsertionPointToEnd(beforeLast); 601 auto condOp = cast<ConditionOp>(beforeLast->getTerminator()); 602 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), 603 before, condOp.getArgs(), 604 continuation, ValueRange()); 605 606 // Replace the op with values "yielded" from the "before" region, which are 607 // visible by dominance. 608 rewriter.replaceOp(whileOp, condOp.getArgs()); 609 610 return success(); 611 } 612 613 void mlir::populateSCFToControlFlowConversionPatterns( 614 RewritePatternSet &patterns) { 615 patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering, 616 ExecuteRegionLowering>(patterns.getContext()); 617 patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2); 618 } 619 620 void SCFToControlFlowPass::runOnOperation() { 621 RewritePatternSet patterns(&getContext()); 622 populateSCFToControlFlowConversionPatterns(patterns); 623 624 // Configure conversion to lower out SCF operations. 625 ConversionTarget target(getContext()); 626 target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp, 627 scf::ExecuteRegionOp>(); 628 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 629 if (failed( 630 applyPartialConversion(getOperation(), target, std::move(patterns)))) 631 signalPassFailure(); 632 } 633 634 std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() { 635 return std::make_unique<SCFToControlFlowPass>(); 636 } 637