1 //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert scf.for, scf.if and loop.terminator 10 // ops into standard CFG ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" 15 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/SCF/IR/SCF.h" 20 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/IRMapping.h" 24 #include "mlir/IR/MLIRContext.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "mlir/Transforms/DialectConversion.h" 27 #include "mlir/Transforms/Passes.h" 28 29 namespace mlir { 30 #define GEN_PASS_DEF_SCFTOCONTROLFLOW 31 #include "mlir/Conversion/Passes.h.inc" 32 } // namespace mlir 33 34 using namespace mlir; 35 using namespace mlir::scf; 36 37 namespace { 38 39 struct SCFToControlFlowPass 40 : public impl::SCFToControlFlowBase<SCFToControlFlowPass> { 41 void runOnOperation() override; 42 }; 43 44 // Create a CFG subgraph for the loop around its body blocks (if the body 45 // contained other loops, they have been already lowered to a flow of blocks). 46 // Maintain the invariants that a CFG subgraph created for any loop has a single 47 // entry and a single exit, and that the entry/exit blocks are respectively 48 // first/last blocks in the parent region. The original loop operation is 49 // replaced by the initialization operations that set up the initial value of 50 // the loop induction variable (%iv) and computes the loop bounds that are loop- 51 // invariant for affine loops. The operations following the original scf.for 52 // are split out into a separate continuation (exit) block. A condition block is 53 // created before the continuation block. It checks the exit condition of the 54 // loop and branches either to the continuation block, or to the first block of 55 // the body. The condition block takes as arguments the values of the induction 56 // variable followed by loop-carried values. Since it dominates both the body 57 // blocks and the continuation block, loop-carried values are visible in all of 58 // those blocks. Induction variable modification is appended to the last block 59 // of the body (which is the exit block from the body subgraph thanks to the 60 // invariant we maintain) along with a branch that loops back to the condition 61 // block. Loop-carried values are the loop terminator operands, which are 62 // forwarded to the branch. 63 // 64 // +---------------------------------+ 65 // | <code before the ForOp> | 66 // | <definitions of %init...> | 67 // | <compute initial %iv value> | 68 // | cf.br cond(%iv, %init...) | 69 // +---------------------------------+ 70 // | 71 // -------| | 72 // | v v 73 // | +--------------------------------+ 74 // | | cond(%iv, %init...): | 75 // | | <compare %iv to upper bound> | 76 // | | cf.cond_br %r, body, end | 77 // | +--------------------------------+ 78 // | | | 79 // | | -------------| 80 // | v | 81 // | +--------------------------------+ | 82 // | | body-first: | | 83 // | | <%init visible by dominance> | | 84 // | | <body contents> | | 85 // | +--------------------------------+ | 86 // | | | 87 // | ... | 88 // | | | 89 // | +--------------------------------+ | 90 // | | body-last: | | 91 // | | <body contents> | | 92 // | | <operands of yield = %yields>| | 93 // | | %new_iv =<add step to %iv> | | 94 // | | cf.br cond(%new_iv, %yields) | | 95 // | +--------------------------------+ | 96 // | | | 97 // |----------- |-------------------- 98 // v 99 // +--------------------------------+ 100 // | end: | 101 // | <code after the ForOp> | 102 // | <%init visible by dominance> | 103 // +--------------------------------+ 104 // 105 struct ForLowering : public OpRewritePattern<ForOp> { 106 using OpRewritePattern<ForOp>::OpRewritePattern; 107 108 LogicalResult matchAndRewrite(ForOp forOp, 109 PatternRewriter &rewriter) const override; 110 }; 111 112 // Create a CFG subgraph for the scf.if operation (including its "then" and 113 // optional "else" operation blocks). We maintain the invariants that the 114 // subgraph has a single entry and a single exit point, and that the entry/exit 115 // blocks are respectively the first/last block of the enclosing region. The 116 // operations following the scf.if are split into a continuation (subgraph 117 // exit) block. The condition is lowered to a chain of blocks that implement the 118 // short-circuit scheme. The "scf.if" operation is replaced with a conditional 119 // branch to either the first block of the "then" region, or to the first block 120 // of the "else" region. In these blocks, "scf.yield" is unconditional branches 121 // to the post-dominating block. When the "scf.if" does not return values, the 122 // post-dominating block is the same as the continuation block. When it returns 123 // values, the post-dominating block is a new block with arguments that 124 // correspond to the values returned by the "scf.if" that unconditionally 125 // branches to the continuation block. This allows block arguments to dominate 126 // any uses of the hitherto "scf.if" results that they replaced. (Inserting a 127 // new block allows us to avoid modifying the argument list of an existing 128 // block, which is illegal in a conversion pattern). When the "else" region is 129 // empty, which is only allowed for "scf.if"s that don't return values, the 130 // condition branches directly to the continuation block. 131 // 132 // CFG for a scf.if with else and without results. 133 // 134 // +--------------------------------+ 135 // | <code before the IfOp> | 136 // | cf.cond_br %cond, %then, %else | 137 // +--------------------------------+ 138 // | | 139 // | --------------| 140 // v | 141 // +--------------------------------+ | 142 // | then: | | 143 // | <then contents> | | 144 // | cf.br continue | | 145 // +--------------------------------+ | 146 // | | 147 // |---------- |------------- 148 // | V 149 // | +--------------------------------+ 150 // | | else: | 151 // | | <else contents> | 152 // | | cf.br continue | 153 // | +--------------------------------+ 154 // | | 155 // ------| | 156 // v v 157 // +--------------------------------+ 158 // | continue: | 159 // | <code after the IfOp> | 160 // +--------------------------------+ 161 // 162 // CFG for a scf.if with results. 163 // 164 // +--------------------------------+ 165 // | <code before the IfOp> | 166 // | cf.cond_br %cond, %then, %else | 167 // +--------------------------------+ 168 // | | 169 // | --------------| 170 // v | 171 // +--------------------------------+ | 172 // | then: | | 173 // | <then contents> | | 174 // | cf.br dom(%args...) | | 175 // +--------------------------------+ | 176 // | | 177 // |---------- |------------- 178 // | V 179 // | +--------------------------------+ 180 // | | else: | 181 // | | <else contents> | 182 // | | cf.br dom(%args...) | 183 // | +--------------------------------+ 184 // | | 185 // ------| | 186 // v v 187 // +--------------------------------+ 188 // | dom(%args...): | 189 // | cf.br continue | 190 // +--------------------------------+ 191 // | 192 // v 193 // +--------------------------------+ 194 // | continue: | 195 // | <code after the IfOp> | 196 // +--------------------------------+ 197 // 198 struct IfLowering : public OpRewritePattern<IfOp> { 199 using OpRewritePattern<IfOp>::OpRewritePattern; 200 201 LogicalResult matchAndRewrite(IfOp ifOp, 202 PatternRewriter &rewriter) const override; 203 }; 204 205 struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> { 206 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; 207 208 LogicalResult matchAndRewrite(ExecuteRegionOp op, 209 PatternRewriter &rewriter) const override; 210 }; 211 212 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> { 213 using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern; 214 215 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp, 216 PatternRewriter &rewriter) const override; 217 }; 218 219 /// Create a CFG subgraph for this loop construct. The regions of the loop need 220 /// not be a single block anymore (for example, if other SCF constructs that 221 /// they contain have been already converted to CFG), but need to be single-exit 222 /// from the last block of each region. The operations following the original 223 /// WhileOp are split into a new continuation block. Both regions of the WhileOp 224 /// are inlined, and their terminators are rewritten to organize the control 225 /// flow implementing the loop as follows. 226 /// 227 /// +---------------------------------+ 228 /// | <code before the WhileOp> | 229 /// | cf.br ^before(%operands...) | 230 /// +---------------------------------+ 231 /// | 232 /// -------| | 233 /// | v v 234 /// | +--------------------------------+ 235 /// | | ^before(%bargs...): | 236 /// | | %vals... = <some payload> | 237 /// | +--------------------------------+ 238 /// | | 239 /// | ... 240 /// | | 241 /// | +--------------------------------+ 242 /// | | ^before-last: 243 /// | | %cond = <compute condition> | 244 /// | | cf.cond_br %cond, | 245 /// | | ^after(%vals...), ^cont | 246 /// | +--------------------------------+ 247 /// | | | 248 /// | | -------------| 249 /// | v | 250 /// | +--------------------------------+ | 251 /// | | ^after(%aargs...): | | 252 /// | | <body contents> | | 253 /// | +--------------------------------+ | 254 /// | | | 255 /// | ... | 256 /// | | | 257 /// | +--------------------------------+ | 258 /// | | ^after-last: | | 259 /// | | %yields... = <some payload> | | 260 /// | | cf.br ^before(%yields...) | | 261 /// | +--------------------------------+ | 262 /// | | | 263 /// |----------- |-------------------- 264 /// v 265 /// +--------------------------------+ 266 /// | ^cont: | 267 /// | <code after the WhileOp> | 268 /// | <%vals from 'before' region | 269 /// | visible by dominance> | 270 /// +--------------------------------+ 271 /// 272 /// Values are communicated between ex-regions (the groups of blocks that used 273 /// to form a region before inlining) through block arguments of their 274 /// entry blocks, which are visible in all other dominated blocks. Similarly, 275 /// the results of the WhileOp are defined in the 'before' region, which is 276 /// required to have a single existing block, and are therefore accessible in 277 /// the continuation block due to dominance. 278 struct WhileLowering : public OpRewritePattern<WhileOp> { 279 using OpRewritePattern<WhileOp>::OpRewritePattern; 280 281 LogicalResult matchAndRewrite(WhileOp whileOp, 282 PatternRewriter &rewriter) const override; 283 }; 284 285 /// Optimized version of the above for the case of the "after" region merely 286 /// forwarding its arguments back to the "before" region (i.e., a "do-while" 287 /// loop). This avoid inlining the "after" region completely and branches back 288 /// to the "before" entry instead. 289 struct DoWhileLowering : public OpRewritePattern<WhileOp> { 290 using OpRewritePattern<WhileOp>::OpRewritePattern; 291 292 LogicalResult matchAndRewrite(WhileOp whileOp, 293 PatternRewriter &rewriter) const override; 294 }; 295 296 /// Lower an `scf.index_switch` operation to a `cf.switch` operation. 297 struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> { 298 using OpRewritePattern::OpRewritePattern; 299 300 LogicalResult matchAndRewrite(IndexSwitchOp op, 301 PatternRewriter &rewriter) const override; 302 }; 303 304 /// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it 305 /// has no shared outputs. Ops with shared outputs should be bufferized first. 306 /// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other 307 /// dialects/passes. 308 struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> { 309 using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern; 310 311 LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp, 312 PatternRewriter &rewriter) const override; 313 }; 314 315 } // namespace 316 317 LogicalResult ForLowering::matchAndRewrite(ForOp forOp, 318 PatternRewriter &rewriter) const { 319 Location loc = forOp.getLoc(); 320 321 // Start by splitting the block containing the 'scf.for' into two parts. 322 // The part before will get the init code, the part after will be the end 323 // point. 324 auto *initBlock = rewriter.getInsertionBlock(); 325 auto initPosition = rewriter.getInsertionPoint(); 326 auto *endBlock = rewriter.splitBlock(initBlock, initPosition); 327 328 // Use the first block of the loop body as the condition block since it is the 329 // block that has the induction variable and loop-carried values as arguments. 330 // Split out all operations from the first block into a new block. Move all 331 // body blocks from the loop body region to the region containing the loop. 332 auto *conditionBlock = &forOp.getRegion().front(); 333 auto *firstBodyBlock = 334 rewriter.splitBlock(conditionBlock, conditionBlock->begin()); 335 auto *lastBodyBlock = &forOp.getRegion().back(); 336 rewriter.inlineRegionBefore(forOp.getRegion(), endBlock); 337 auto iv = conditionBlock->getArgument(0); 338 339 // Append the induction variable stepping logic to the last body block and 340 // branch back to the condition block. Loop-carried values are taken from 341 // operands of the loop terminator. 342 Operation *terminator = lastBodyBlock->getTerminator(); 343 rewriter.setInsertionPointToEnd(lastBodyBlock); 344 auto step = forOp.getStep(); 345 auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult(); 346 if (!stepped) 347 return failure(); 348 349 SmallVector<Value, 8> loopCarried; 350 loopCarried.push_back(stepped); 351 loopCarried.append(terminator->operand_begin(), terminator->operand_end()); 352 rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried); 353 rewriter.eraseOp(terminator); 354 355 // Compute loop bounds before branching to the condition. 356 rewriter.setInsertionPointToEnd(initBlock); 357 Value lowerBound = forOp.getLowerBound(); 358 Value upperBound = forOp.getUpperBound(); 359 if (!lowerBound || !upperBound) 360 return failure(); 361 362 // The initial values of loop-carried values is obtained from the operands 363 // of the loop operation. 364 SmallVector<Value, 8> destOperands; 365 destOperands.push_back(lowerBound); 366 llvm::append_range(destOperands, forOp.getInitArgs()); 367 rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands); 368 369 // With the body block done, we can fill in the condition block. 370 rewriter.setInsertionPointToEnd(conditionBlock); 371 auto comparison = rewriter.create<arith::CmpIOp>( 372 loc, arith::CmpIPredicate::slt, iv, upperBound); 373 374 auto condBranchOp = rewriter.create<cf::CondBranchOp>( 375 loc, comparison, firstBodyBlock, ArrayRef<Value>(), endBlock, 376 ArrayRef<Value>()); 377 378 // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the 379 // llvm.loop_annotation attribute. 380 SmallVector<NamedAttribute> llvmAttrs; 381 llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs), 382 [](auto attr) { 383 return isa<LLVM::LLVMDialect>(attr.getValue().getDialect()); 384 }); 385 condBranchOp->setDiscardableAttrs(llvmAttrs); 386 // The result of the loop operation is the values of the condition block 387 // arguments except the induction variable on the last iteration. 388 rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); 389 return success(); 390 } 391 392 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, 393 PatternRewriter &rewriter) const { 394 auto loc = ifOp.getLoc(); 395 396 // Start by splitting the block containing the 'scf.if' into two parts. 397 // The part before will contain the condition, the part after will be the 398 // continuation point. 399 auto *condBlock = rewriter.getInsertionBlock(); 400 auto opPosition = rewriter.getInsertionPoint(); 401 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); 402 Block *continueBlock; 403 if (ifOp.getNumResults() == 0) { 404 continueBlock = remainingOpsBlock; 405 } else { 406 continueBlock = 407 rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(), 408 SmallVector<Location>(ifOp.getNumResults(), loc)); 409 rewriter.create<cf::BranchOp>(loc, remainingOpsBlock); 410 } 411 412 // Move blocks from the "then" region to the region containing 'scf.if', 413 // place it before the continuation block, and branch to it. 414 auto &thenRegion = ifOp.getThenRegion(); 415 auto *thenBlock = &thenRegion.front(); 416 Operation *thenTerminator = thenRegion.back().getTerminator(); 417 ValueRange thenTerminatorOperands = thenTerminator->getOperands(); 418 rewriter.setInsertionPointToEnd(&thenRegion.back()); 419 rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands); 420 rewriter.eraseOp(thenTerminator); 421 rewriter.inlineRegionBefore(thenRegion, continueBlock); 422 423 // Move blocks from the "else" region (if present) to the region containing 424 // 'scf.if', place it before the continuation block and branch to it. It 425 // will be placed after the "then" regions. 426 auto *elseBlock = continueBlock; 427 auto &elseRegion = ifOp.getElseRegion(); 428 if (!elseRegion.empty()) { 429 elseBlock = &elseRegion.front(); 430 Operation *elseTerminator = elseRegion.back().getTerminator(); 431 ValueRange elseTerminatorOperands = elseTerminator->getOperands(); 432 rewriter.setInsertionPointToEnd(&elseRegion.back()); 433 rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands); 434 rewriter.eraseOp(elseTerminator); 435 rewriter.inlineRegionBefore(elseRegion, continueBlock); 436 } 437 438 rewriter.setInsertionPointToEnd(condBlock); 439 rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock, 440 /*trueArgs=*/ArrayRef<Value>(), elseBlock, 441 /*falseArgs=*/ArrayRef<Value>()); 442 443 // Ok, we're done! 444 rewriter.replaceOp(ifOp, continueBlock->getArguments()); 445 return success(); 446 } 447 448 LogicalResult 449 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op, 450 PatternRewriter &rewriter) const { 451 auto loc = op.getLoc(); 452 453 auto *condBlock = rewriter.getInsertionBlock(); 454 auto opPosition = rewriter.getInsertionPoint(); 455 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); 456 457 auto ®ion = op.getRegion(); 458 rewriter.setInsertionPointToEnd(condBlock); 459 rewriter.create<cf::BranchOp>(loc, ®ion.front()); 460 461 for (Block &block : region) { 462 if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) { 463 ValueRange terminatorOperands = terminator->getOperands(); 464 rewriter.setInsertionPointToEnd(&block); 465 rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands); 466 rewriter.eraseOp(terminator); 467 } 468 } 469 470 rewriter.inlineRegionBefore(region, remainingOpsBlock); 471 472 SmallVector<Value> vals; 473 SmallVector<Location> argLocs(op.getNumResults(), op->getLoc()); 474 for (auto arg : 475 remainingOpsBlock->addArguments(op->getResultTypes(), argLocs)) 476 vals.push_back(arg); 477 rewriter.replaceOp(op, vals); 478 return success(); 479 } 480 481 LogicalResult 482 ParallelLowering::matchAndRewrite(ParallelOp parallelOp, 483 PatternRewriter &rewriter) const { 484 Location loc = parallelOp.getLoc(); 485 auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator()); 486 if (!reductionOp) { 487 return failure(); 488 } 489 490 // For a parallel loop, we essentially need to create an n-dimensional loop 491 // nest. We do this by translating to scf.for ops and have those lowered in 492 // a further rewrite. If a parallel loop contains reductions (and thus returns 493 // values), forward the initial values for the reductions down the loop 494 // hierarchy and bubble up the results by modifying the "yield" terminator. 495 SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals()); 496 SmallVector<Value, 4> ivs; 497 ivs.reserve(parallelOp.getNumLoops()); 498 bool first = true; 499 SmallVector<Value, 4> loopResults(iterArgs); 500 for (auto [iv, lower, upper, step] : 501 llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(), 502 parallelOp.getUpperBound(), parallelOp.getStep())) { 503 ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs); 504 ivs.push_back(forOp.getInductionVar()); 505 auto iterRange = forOp.getRegionIterArgs(); 506 iterArgs.assign(iterRange.begin(), iterRange.end()); 507 508 if (first) { 509 // Store the results of the outermost loop that will be used to replace 510 // the results of the parallel loop when it is fully rewritten. 511 loopResults.assign(forOp.result_begin(), forOp.result_end()); 512 first = false; 513 } else if (!forOp.getResults().empty()) { 514 // A loop is constructed with an empty "yield" terminator if there are 515 // no results. 516 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); 517 rewriter.create<scf::YieldOp>(loc, forOp.getResults()); 518 } 519 520 rewriter.setInsertionPointToStart(forOp.getBody()); 521 } 522 523 // First, merge reduction blocks into the main region. 524 SmallVector<Value> yieldOperands; 525 yieldOperands.reserve(parallelOp.getNumResults()); 526 for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) { 527 Block &reductionBody = reductionOp.getReductions()[i].front(); 528 Value arg = iterArgs[yieldOperands.size()]; 529 yieldOperands.push_back( 530 cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult()); 531 rewriter.eraseOp(reductionBody.getTerminator()); 532 rewriter.inlineBlockBefore(&reductionBody, reductionOp, 533 {arg, reductionOp.getOperands()[i]}); 534 } 535 rewriter.eraseOp(reductionOp); 536 537 // Then merge the loop body without the terminator. 538 Block *newBody = rewriter.getInsertionBlock(); 539 if (newBody->empty()) 540 rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs); 541 else 542 rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(), 543 ivs); 544 545 // Finally, create the terminator if required (for loops with no results, it 546 // has been already created in loop construction). 547 if (!yieldOperands.empty()) { 548 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); 549 rewriter.create<scf::YieldOp>(loc, yieldOperands); 550 } 551 552 rewriter.replaceOp(parallelOp, loopResults); 553 554 return success(); 555 } 556 557 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, 558 PatternRewriter &rewriter) const { 559 OpBuilder::InsertionGuard guard(rewriter); 560 Location loc = whileOp.getLoc(); 561 562 // Split the current block before the WhileOp to create the inlining point. 563 Block *currentBlock = rewriter.getInsertionBlock(); 564 Block *continuation = 565 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); 566 567 // Inline both regions. 568 Block *after = whileOp.getAfterBody(); 569 Block *before = whileOp.getBeforeBody(); 570 rewriter.inlineRegionBefore(whileOp.getAfter(), continuation); 571 rewriter.inlineRegionBefore(whileOp.getBefore(), after); 572 573 // Branch to the "before" region. 574 rewriter.setInsertionPointToEnd(currentBlock); 575 rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits()); 576 577 // Replace terminators with branches. Assuming bodies are SESE, which holds 578 // given only the patterns from this file, we only need to look at the last 579 // block. This should be reconsidered if we allow break/continue in SCF. 580 rewriter.setInsertionPointToEnd(before); 581 auto condOp = cast<ConditionOp>(before->getTerminator()); 582 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), 583 after, condOp.getArgs(), 584 continuation, ValueRange()); 585 586 rewriter.setInsertionPointToEnd(after); 587 auto yieldOp = cast<scf::YieldOp>(after->getTerminator()); 588 rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before, 589 yieldOp.getResults()); 590 591 // Replace the op with values "yielded" from the "before" region, which are 592 // visible by dominance. 593 rewriter.replaceOp(whileOp, condOp.getArgs()); 594 595 return success(); 596 } 597 598 LogicalResult 599 DoWhileLowering::matchAndRewrite(WhileOp whileOp, 600 PatternRewriter &rewriter) const { 601 Block &afterBlock = *whileOp.getAfterBody(); 602 if (!llvm::hasSingleElement(afterBlock)) 603 return rewriter.notifyMatchFailure(whileOp, 604 "do-while simplification applicable " 605 "only if 'after' region has no payload"); 606 607 auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front()); 608 if (!yield || yield.getResults() != afterBlock.getArguments()) 609 return rewriter.notifyMatchFailure(whileOp, 610 "do-while simplification applicable " 611 "only to forwarding 'after' regions"); 612 613 // Split the current block before the WhileOp to create the inlining point. 614 OpBuilder::InsertionGuard guard(rewriter); 615 Block *currentBlock = rewriter.getInsertionBlock(); 616 Block *continuation = 617 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); 618 619 // Only the "before" region should be inlined. 620 Block *before = whileOp.getBeforeBody(); 621 rewriter.inlineRegionBefore(whileOp.getBefore(), continuation); 622 623 // Branch to the "before" region. 624 rewriter.setInsertionPointToEnd(currentBlock); 625 rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits()); 626 627 // Loop around the "before" region based on condition. 628 rewriter.setInsertionPointToEnd(before); 629 auto condOp = cast<ConditionOp>(before->getTerminator()); 630 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), 631 before, condOp.getArgs(), 632 continuation, ValueRange()); 633 634 // Replace the op with values "yielded" from the "before" region, which are 635 // visible by dominance. 636 rewriter.replaceOp(whileOp, condOp.getArgs()); 637 638 return success(); 639 } 640 641 LogicalResult 642 IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, 643 PatternRewriter &rewriter) const { 644 // Split the block at the op. 645 Block *condBlock = rewriter.getInsertionBlock(); 646 Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op)); 647 648 // Create the arguments on the continue block with which to replace the 649 // results of the op. 650 SmallVector<Value> results; 651 results.reserve(op.getNumResults()); 652 for (Type resultType : op.getResultTypes()) 653 results.push_back(continueBlock->addArgument(resultType, op.getLoc())); 654 655 // Handle the regions. 656 auto convertRegion = [&](Region ®ion) -> FailureOr<Block *> { 657 Block *block = ®ion.front(); 658 659 // Convert the yield terminator to a branch to the continue block. 660 auto yield = cast<scf::YieldOp>(block->getTerminator()); 661 rewriter.setInsertionPoint(yield); 662 rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock, 663 yield.getOperands()); 664 665 // Inline the region. 666 rewriter.inlineRegionBefore(region, continueBlock); 667 return block; 668 }; 669 670 // Convert the case regions. 671 SmallVector<Block *> caseSuccessors; 672 SmallVector<int32_t> caseValues; 673 caseSuccessors.reserve(op.getCases().size()); 674 caseValues.reserve(op.getCases().size()); 675 for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) { 676 FailureOr<Block *> block = convertRegion(region); 677 if (failed(block)) 678 return failure(); 679 caseSuccessors.push_back(*block); 680 caseValues.push_back(value); 681 } 682 683 // Convert the default region. 684 FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion()); 685 if (failed(defaultBlock)) 686 return failure(); 687 688 // Create the switch. 689 rewriter.setInsertionPointToEnd(condBlock); 690 SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {}); 691 692 // Cast switch index to integer case value. 693 Value caseValue = rewriter.create<arith::IndexCastOp>( 694 op.getLoc(), rewriter.getI32Type(), op.getArg()); 695 696 rewriter.create<cf::SwitchOp>( 697 op.getLoc(), caseValue, *defaultBlock, ValueRange(), 698 rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands); 699 rewriter.replaceOp(op, continueBlock->getArguments()); 700 return success(); 701 } 702 703 LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp, 704 PatternRewriter &rewriter) const { 705 return scf::forallToParallelLoop(rewriter, forallOp); 706 } 707 708 void mlir::populateSCFToControlFlowConversionPatterns( 709 RewritePatternSet &patterns) { 710 patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering, 711 WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>( 712 patterns.getContext()); 713 patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2); 714 } 715 716 void SCFToControlFlowPass::runOnOperation() { 717 RewritePatternSet patterns(&getContext()); 718 populateSCFToControlFlowConversionPatterns(patterns); 719 720 // Configure conversion to lower out SCF operations. 721 ConversionTarget target(getContext()); 722 target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp, 723 scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>(); 724 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 725 if (failed( 726 applyPartialConversion(getOperation(), target, std::move(patterns)))) 727 signalPassFailure(); 728 } 729 730 std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() { 731 return std::make_unique<SCFToControlFlowPass>(); 732 } 733