1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/AffineExpr.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/BlockAndValueMapping.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/Matchers.h" 20 #include "mlir/IR/OpImplementation.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/IR/Value.h" 24 #include "mlir/Support/MathExtras.h" 25 #include "mlir/Transforms/InliningUtils.h" 26 #include "llvm/ADT/APFloat.h" 27 #include "llvm/ADT/STLExtras.h" 28 #include "llvm/ADT/StringSwitch.h" 29 #include "llvm/Support/FormatVariadic.h" 30 #include "llvm/Support/raw_ostream.h" 31 #include <numeric> 32 33 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc" 34 35 using namespace mlir; 36 using namespace mlir::cf; 37 38 //===----------------------------------------------------------------------===// 39 // ControlFlowDialect Interfaces 40 //===----------------------------------------------------------------------===// 41 namespace { 42 /// This class defines the interface for handling inlining with control flow 43 /// operations. 44 struct ControlFlowInlinerInterface : public DialectInlinerInterface { 45 using DialectInlinerInterface::DialectInlinerInterface; 46 ~ControlFlowInlinerInterface() override = default; 47 48 /// All control flow operations can be inlined. 
49 bool isLegalToInline(Operation *call, Operation *callable, 50 bool wouldBeCloned) const final { 51 return true; 52 } 53 bool isLegalToInline(Operation *, Region *, bool, 54 BlockAndValueMapping &) const final { 55 return true; 56 } 57 58 /// ControlFlow terminator operations don't really need any special handing. 59 void handleTerminator(Operation *op, Block *newDest) const final {} 60 }; 61 } // namespace 62 63 //===----------------------------------------------------------------------===// 64 // ControlFlowDialect 65 //===----------------------------------------------------------------------===// 66 67 void ControlFlowDialect::initialize() { 68 addOperations< 69 #define GET_OP_LIST 70 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 71 >(); 72 addInterfaces<ControlFlowInlinerInterface>(); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // AssertOp 77 //===----------------------------------------------------------------------===// 78 79 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { 80 // Erase assertion if argument is constant true. 81 if (matchPattern(op.getArg(), m_One())) { 82 rewriter.eraseOp(op); 83 return success(); 84 } 85 return failure(); 86 } 87 88 //===----------------------------------------------------------------------===// 89 // BranchOp 90 //===----------------------------------------------------------------------===// 91 92 /// Given a successor, try to collapse it to a new destination if it only 93 /// contains a passthrough unconditional branch. If the successor is 94 /// collapsable, `successor` and `successorOperands` are updated to reference 95 /// the new destination and values. `argStorage` is used as storage if operands 96 /// to the collapsed successor need to be remapped. It must outlive uses of 97 /// successorOperands. 
static LogicalResult collapseBranch(Block *&successor,
                                    ValueRange &successorOperands,
                                    SmallVectorImpl<Value> &argStorage) {
  // Check that the successor only contains an unconditional branch: the block
  // must consist of exactly one operation (its terminator).
  if (std::next(successor->begin()) != successor->end())
    return failure();
  // Check that the terminator is an unconditional branch.
  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
  if (!successorBranch)
    return failure();
  // Check that the arguments are only used within the terminator; any other
  // user would lose its operand once the block is bypassed.
  for (BlockArgument arg : successor->getArguments()) {
    for (Operation *user : arg.getUsers())
      if (user != successorBranch)
        return failure();
  }
  // Don't try to collapse branches to infinite loops (a block branching
  // directly to itself).
  Block *successorDest = successorBranch.getDest();
  if (successorDest == successor)
    return failure();

  // Update the operands to the successor. If the branch parent has no
  // arguments, we can use the branch operands directly.
  OperandRange operands = successorBranch.getOperands();
  if (successor->args_empty()) {
    successor = successorDest;
    successorOperands = operands;
    return success();
  }

  // Otherwise, we need to remap any argument operands: an operand that is an
  // argument of `successor` is replaced by the value the incoming edge passed
  // for that argument; every other operand is forwarded unchanged. Values are
  // appended to `argStorage` (it is not cleared here).
  for (Value operand : operands) {
    BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
    if (argOperand && argOperand.getOwner() == successor)
      argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
    else
      argStorage.push_back(operand);
  }
  successor = successorDest;
  // `successorOperands` now views `argStorage`, hence the lifetime requirement.
  successorOperands = argStorage;
  return success();
}

/// Simplify a branch to a block that has a single predecessor. This effectively
/// merges the two blocks.
static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
  // Check that the successor block has a single predecessor. A branch to its
  // own block is excluded since a block cannot be merged into itself.
  Block *succ = op.getDest();
  Block *opParent = op->getBlock();
  if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
    return failure();

  // Merge the successor into the current block and erase the branch. The
  // branch operands are passed as the values for `succ`'s block arguments.
  rewriter.mergeBlocks(succ, opParent, op.getOperands());
  rewriter.eraseOp(op);
  return success();
}

/// Bypass a destination block that only forwards to another block:
///
///   br ^bb1
/// ^bb1
///   br ^bbN(...)
///
///  -> br ^bbN(...)
///
static LogicalResult simplifyPassThroughBr(BranchOp op,
                                           PatternRewriter &rewriter) {
  Block *dest = op.getDest();
  ValueRange destOperands = op.getOperands();
  // Backing storage for `destOperands` if collapseBranch has to remap it; it
  // must stay alive until the new branch has been created below.
  SmallVector<Value, 4> destOperandStorage;

  // Try to collapse the successor if it points somewhere other than this
  // block.
  if (dest == op->getBlock() ||
      failed(collapseBranch(dest, destOperands, destOperandStorage)))
    return failure();

  // Create a new branch with the collapsed successor.
  rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
  return success();
}

/// Canonicalize a `cf.br`. Block merging is tried first; because `||`
/// short-circuits, the pass-through simplification is only attempted when
/// merging did not apply.
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
  return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
                 succeeded(simplifyPassThroughBr(op, rewriter)));
}

/// Retarget the (single) successor of this branch.
void BranchOp::setDest(Block *block) { return setSuccessor(block); }

/// Remove the operand at `index` from this branch.
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }

/// A `cf.br` has exactly one successor (index 0), fed by its dest operands.
Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return getDestOperandsMutable();
}

/// The destination does not depend on any operand value, so it is returned
/// regardless of the (ignored) constant operands.
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
  return getDest();
}

//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//

namespace {
/// Fold a conditional branch whose condition is a known constant:
///
/// cf.cond_br true, ^bb1, ^bb2
///  -> br ^bb1
/// cf.cond_br false, ^bb1, ^bb2
///  -> br ^bb2
///
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    if (matchPattern(condbr.getCondition(), m_NonZero())) {
      // True branch taken.
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
                                            condbr.getTrueOperands());
      return success();
    }
    if (matchPattern(condbr.getCondition(), m_Zero())) {
      // False branch taken.
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
                                            condbr.getFalseOperands());
      return success();
    }
    // Condition is not a constant: nothing to fold.
    return failure();
  }
};

/// Bypass successors that only forward to another block:
///
///   cf.cond_br %cond, ^bb1, ^bb2
/// ^bb1
///   br ^bbN(...)
/// ^bb2
///   br ^bbK(...)
///
///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
///
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
    ValueRange trueDestOperands = condbr.getTrueOperands();
    ValueRange falseDestOperands = condbr.getFalseOperands();
    // Backing storage for remapped operands; must outlive the rewrite below.
    SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;

    // Try to collapse one of the current successors. Both are attempted; the
    // op is rebuilt if at least one of them collapsed.
    LogicalResult collapsedTrue =
        collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
    LogicalResult collapsedFalse =
        collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
    if (failed(collapsedTrue) && failed(collapsedFalse))
      return failure();

    // Create a new branch with the collapsed successors.
    rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
                                              trueDest, trueDestOperands,
                                              falseDest, falseDestOperands);
    return success();
  }
};

/// Fold a conditional branch whose two edges target the same block:
///
/// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
///  -> br ^bb1(A, ..., N)
///
/// cf.cond_br %cond, ^bb1(A), ^bb1(B)
///  -> %select = arith.select %cond, A, B
///     br ^bb1(%select)
///
struct SimplifyCondBranchIdenticalSuccessors
    : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Check that the true and false destinations are the same and have the same
    // operands.
    Block *trueDest = condbr.getTrueDest();
    if (trueDest != condbr.getFalseDest())
      return failure();

    // If all of the operands match, no selects need to be generated.
    OperandRange trueOperands = condbr.getTrueOperands();
    OperandRange falseOperands = condbr.getFalseOperands();
    if (trueOperands == falseOperands) {
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
      return success();
    }

    // Otherwise, if the current block is the only predecessor insert selects
    // for any mismatched branch operands.
    if (trueDest->getUniquePredecessor() != condbr->getBlock())
      return failure();

    // Generate a select for any operands that differ between the two; operands
    // that are already identical are forwarded as-is.
    SmallVector<Value, 8> mergedOperands;
    mergedOperands.reserve(trueOperands.size());
    Value condition = condbr.getCondition();
    for (auto it : llvm::zip(trueOperands, falseOperands)) {
      if (std::get<0>(it) == std::get<1>(it))
        mergedOperands.push_back(std::get<0>(it));
      else
        mergedOperands.push_back(rewriter.create<arith::SelectOp>(
            condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
    }

    rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
    return success();
  }
};

/// Fold a conditional branch that re-tests the condition its single
/// predecessor already branched on:
///
///   ...
///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
///   ...
///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
///
/// ->
///
///   ...
///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
///   ...
///   br ^bb3(...)
///
struct SimplifyCondBranchFromCondBranchOnSameCondition
    : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Check that we have a single distinct predecessor.
    Block *currentBlock = condbr->getBlock();
    Block *predecessor = currentBlock->getSinglePredecessor();
    if (!predecessor)
      return failure();

    // Check that the predecessor terminates with a conditional branch to this
    // block and that it branches on the same condition.
    auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
    if (!predBranch || condbr.getCondition() != predBranch.getCondition())
      return failure();

    // Fold this branch to an unconditional branch: the edge we arrived on
    // fixes the condition's value inside this block.
    if (currentBlock == predBranch.getTrueDest())
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
                                            condbr.getTrueDestOperands());
    else
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
                                            condbr.getFalseDestOperands());
    return success();
  }
};

/// Replace uses of the branch condition inside a successor that is only
/// reachable through one edge with the constant implied by that edge:
///
///   cf.cond_br %arg0, ^trueB, ^falseB
///
/// ^trueB:
///   "test.consumer1"(%arg0) : (i1) -> ()
///    ...
///
/// ^falseB:
///   "test.consumer2"(%arg0) : (i1) -> ()
///   ...
///
/// ->
///
///   cf.cond_br %arg0, ^trueB, ^falseB
/// ^trueB:
///   "test.consumer1"(%true) : (i1) -> ()
///   ...
///
/// ^falseB:
///   "test.consumer2"(%false) : (i1) -> ()
///   ...
struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Tracks whether any use was rewritten; the pattern fails if none was.
    bool replaced = false;
    Type ty = rewriter.getI1Type();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values. They are created lazily (at the
    // rewriter's current insertion point) only once a use is actually found.
    Value constantTrue = nullptr;
    Value constantFalse = nullptr;

    // TODO: These checks can be expanded to encompass any use with only
    // either the true or false edge as a predecessor. For now, we fall
    // back to checking that the true/false destination has a single
    // predecessor, thereby ensuring that only that edge can reach the
    // op. Only uses whose owning op sits directly in the destination block
    // are considered.
    if (condbr.getTrueDest()->getSinglePredecessor()) {
      // Early-increment iteration: `use.set` unlinks the use from the list.
      for (OpOperand &use :
           llvm::make_early_inc_range(condbr.getCondition().getUses())) {
        if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
          replaced = true;

          if (!constantTrue)
            constantTrue = rewriter.create<arith::ConstantOp>(
                condbr.getLoc(), ty, rewriter.getBoolAttr(true));

          rewriter.updateRootInPlace(use.getOwner(),
                                     [&] { use.set(constantTrue); });
        }
      }
    }
    if (condbr.getFalseDest()->getSinglePredecessor()) {
      for (OpOperand &use :
           llvm::make_early_inc_range(condbr.getCondition().getUses())) {
        if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
          replaced = true;

          if (!constantFalse)
            constantFalse = rewriter.create<arith::ConstantOp>(
                condbr.getLoc(), ty, rewriter.getBoolAttr(false));

          rewriter.updateRootInPlace(use.getOwner(),
                                     [&] { use.set(constantFalse); });
        }
      }
    }
    return success(replaced);
  }
};
} // namespace

void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
              SimplifyCondBranchIdenticalSuccessors,
              SimplifyCondBranchFromCondBranchOnSameCondition,
              CondBranchTruthPropagation>(context);
}

/// Return the operands forwarded to successor `index`; `trueIndex` (declared
/// outside this file section) identifies the true successor, any other valid
/// index the false one.
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  return index == trueIndex ? getTrueDestOperandsMutable()
                            : getFalseDestOperandsMutable();
}

/// If the condition operand folded to an integer attribute, return the
/// successor it selects (value one -> true dest, otherwise false dest).
/// Returns nullptr when the condition is not a known constant.
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
    return condAttr.getValue().isOneValue() ? getTrueDest() : getFalseDest();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//

/// Convenience builder that takes successors next to their operands and
/// forwards to the builder with operands first, then case values, then
/// successors.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  build(builder, result, value, defaultOperands, caseOperands, caseValues,
        defaultDestination, caseDestinations);
}

/// Builder taking case values as APInts. They are packed into a 1-D vector
/// attribute whose element type is the flag's type; with no cases the
/// attribute is left null.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  DenseIntElementsAttr caseValuesAttr;
  if (!caseValues.empty()) {
    ShapedType caseValueType = VectorType::get(
        static_cast<int64_t>(caseValues.size()), value.getType());
    caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
  }
  build(builder, result, value, defaultDestination, defaultOperands,
        caseValuesAttr, caseDestinations, caseOperands);
}

/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
)* 482 static ParseResult parseSwitchOpCases( 483 OpAsmParser &parser, Type &flagType, Block *&defaultDestination, 484 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, 485 SmallVectorImpl<Type> &defaultOperandTypes, 486 DenseIntElementsAttr &caseValues, 487 SmallVectorImpl<Block *> &caseDestinations, 488 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 489 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 490 if (parser.parseKeyword("default") || parser.parseColon() || 491 parser.parseSuccessor(defaultDestination)) 492 return failure(); 493 if (succeeded(parser.parseOptionalLParen())) { 494 if (parser.parseRegionArgumentList(defaultOperands) || 495 parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) 496 return failure(); 497 } 498 499 SmallVector<APInt> values; 500 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 501 while (succeeded(parser.parseOptionalComma())) { 502 int64_t value = 0; 503 if (failed(parser.parseInteger(value))) 504 return failure(); 505 values.push_back(APInt(bitWidth, value)); 506 507 Block *destination; 508 SmallVector<OpAsmParser::UnresolvedOperand> operands; 509 SmallVector<Type> operandTypes; 510 if (failed(parser.parseColon()) || 511 failed(parser.parseSuccessor(destination))) 512 return failure(); 513 if (succeeded(parser.parseOptionalLParen())) { 514 if (failed(parser.parseRegionArgumentList(operands)) || 515 failed(parser.parseColonTypeList(operandTypes)) || 516 failed(parser.parseRParen())) 517 return failure(); 518 } 519 caseDestinations.push_back(destination); 520 caseOperands.emplace_back(operands); 521 caseOperandTypes.emplace_back(operandTypes); 522 } 523 524 if (!values.empty()) { 525 ShapedType caseValueType = 526 VectorType::get(static_cast<int64_t>(values.size()), flagType); 527 caseValues = DenseIntElementsAttr::get(caseValueType, values); 528 } 529 return success(); 530 } 531 532 static void printSwitchOpCases( 533 OpAsmPrinter &p, SwitchOp op, Type 
flagType, Block *defaultDestination, 534 OperandRange defaultOperands, TypeRange defaultOperandTypes, 535 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, 536 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { 537 p << " default: "; 538 p.printSuccessorAndUseList(defaultDestination, defaultOperands); 539 540 if (!caseValues) 541 return; 542 543 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) { 544 p << ','; 545 p.printNewline(); 546 p << " "; 547 p << it.value().getLimitedValue(); 548 p << ": "; 549 p.printSuccessorAndUseList(caseDestinations[it.index()], 550 caseOperands[it.index()]); 551 } 552 p.printNewline(); 553 } 554 555 LogicalResult SwitchOp::verify() { 556 auto caseValues = getCaseValues(); 557 auto caseDestinations = getCaseDestinations(); 558 559 if (!caseValues && caseDestinations.empty()) 560 return success(); 561 562 Type flagType = getFlag().getType(); 563 Type caseValueType = caseValues->getType().getElementType(); 564 if (caseValueType != flagType) 565 return emitOpError() << "'flag' type (" << flagType 566 << ") should match case value type (" << caseValueType 567 << ")"; 568 569 if (caseValues && 570 caseValues->size() != static_cast<int64_t>(caseDestinations.size())) 571 return emitOpError() << "number of case values (" << caseValues->size() 572 << ") should match number of " 573 "case destinations (" 574 << caseDestinations.size() << ")"; 575 return success(); 576 } 577 578 Optional<MutableOperandRange> 579 SwitchOp::getMutableSuccessorOperands(unsigned index) { 580 assert(index < getNumSuccessors() && "invalid successor index"); 581 return index == 0 ? 
getDefaultOperandsMutable() 582 : getCaseOperandsMutable(index - 1); 583 } 584 585 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 586 Optional<DenseIntElementsAttr> caseValues = getCaseValues(); 587 588 if (!caseValues) 589 return getDefaultDestination(); 590 591 SuccessorRange caseDests = getCaseDestinations(); 592 if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) { 593 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) 594 if (it.value() == value.getValue()) 595 return caseDests[it.index()]; 596 return getDefaultDestination(); 597 } 598 return nullptr; 599 } 600 601 /// switch %flag : i32, [ 602 /// default: ^bb1 603 /// ] 604 /// -> br ^bb1 605 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, 606 PatternRewriter &rewriter) { 607 if (!op.getCaseDestinations().empty()) 608 return failure(); 609 610 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 611 op.getDefaultOperands()); 612 return success(); 613 } 614 615 /// switch %flag : i32, [ 616 /// default: ^bb1, 617 /// 42: ^bb1, 618 /// 43: ^bb2 619 /// ] 620 /// -> 621 /// switch %flag : i32, [ 622 /// default: ^bb1, 623 /// 43: ^bb2 624 /// ] 625 static LogicalResult 626 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { 627 SmallVector<Block *> newCaseDestinations; 628 SmallVector<ValueRange> newCaseOperands; 629 SmallVector<APInt> newCaseValues; 630 bool requiresChange = false; 631 auto caseValues = op.getCaseValues(); 632 auto caseDests = op.getCaseDestinations(); 633 634 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 635 if (caseDests[it.index()] == op.getDefaultDestination() && 636 op.getCaseOperands(it.index()) == op.getDefaultOperands()) { 637 requiresChange = true; 638 continue; 639 } 640 newCaseDestinations.push_back(caseDests[it.index()]); 641 newCaseOperands.push_back(op.getCaseOperands(it.index())); 642 newCaseValues.push_back(it.value()); 643 } 644 645 if 
(!requiresChange) 646 return failure(); 647 648 rewriter.replaceOpWithNewOp<SwitchOp>( 649 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 650 newCaseValues, newCaseDestinations, newCaseOperands); 651 return success(); 652 } 653 654 /// Helper for folding a switch with a constant value. 655 /// switch %c_42 : i32, [ 656 /// default: ^bb1 , 657 /// 42: ^bb2, 658 /// 43: ^bb3 659 /// ] 660 /// -> br ^bb2 661 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, 662 const APInt &caseValue) { 663 auto caseValues = op.getCaseValues(); 664 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 665 if (it.value() == caseValue) { 666 rewriter.replaceOpWithNewOp<BranchOp>( 667 op, op.getCaseDestinations()[it.index()], 668 op.getCaseOperands(it.index())); 669 return; 670 } 671 } 672 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 673 op.getDefaultOperands()); 674 } 675 676 /// switch %c_42 : i32, [ 677 /// default: ^bb1, 678 /// 42: ^bb2, 679 /// 43: ^bb3 680 /// ] 681 /// -> br ^bb2 682 static LogicalResult simplifyConstSwitchValue(SwitchOp op, 683 PatternRewriter &rewriter) { 684 APInt caseValue; 685 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) 686 return failure(); 687 688 foldSwitch(op, rewriter, caseValue); 689 return success(); 690 } 691 692 /// switch %c_42 : i32, [ 693 /// default: ^bb1, 694 /// 42: ^bb2, 695 /// ] 696 /// ^bb2: 697 /// br ^bb3 698 /// -> 699 /// switch %c_42 : i32, [ 700 /// default: ^bb1, 701 /// 42: ^bb3, 702 /// ] 703 static LogicalResult simplifyPassThroughSwitch(SwitchOp op, 704 PatternRewriter &rewriter) { 705 SmallVector<Block *> newCaseDests; 706 SmallVector<ValueRange> newCaseOperands; 707 SmallVector<SmallVector<Value>> argStorage; 708 auto caseValues = op.getCaseValues(); 709 auto caseDests = op.getCaseDestinations(); 710 bool requiresChange = false; 711 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { 712 Block *caseDest = caseDests[i]; 
713 ValueRange caseOperands = op.getCaseOperands(i); 714 argStorage.emplace_back(); 715 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back()))) 716 requiresChange = true; 717 718 newCaseDests.push_back(caseDest); 719 newCaseOperands.push_back(caseOperands); 720 } 721 722 Block *defaultDest = op.getDefaultDestination(); 723 ValueRange defaultOperands = op.getDefaultOperands(); 724 argStorage.emplace_back(); 725 726 if (succeeded( 727 collapseBranch(defaultDest, defaultOperands, argStorage.back()))) 728 requiresChange = true; 729 730 if (!requiresChange) 731 return failure(); 732 733 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest, 734 defaultOperands, caseValues.getValue(), 735 newCaseDests, newCaseOperands); 736 return success(); 737 } 738 739 /// switch %flag : i32, [ 740 /// default: ^bb1, 741 /// 42: ^bb2, 742 /// ] 743 /// ^bb2: 744 /// switch %flag : i32, [ 745 /// default: ^bb3, 746 /// 42: ^bb4 747 /// ] 748 /// -> 749 /// switch %flag : i32, [ 750 /// default: ^bb1, 751 /// 42: ^bb2, 752 /// ] 753 /// ^bb2: 754 /// br ^bb4 755 /// 756 /// and 757 /// 758 /// switch %flag : i32, [ 759 /// default: ^bb1, 760 /// 42: ^bb2, 761 /// ] 762 /// ^bb2: 763 /// switch %flag : i32, [ 764 /// default: ^bb3, 765 /// 43: ^bb4 766 /// ] 767 /// -> 768 /// switch %flag : i32, [ 769 /// default: ^bb1, 770 /// 42: ^bb2, 771 /// ] 772 /// ^bb2: 773 /// br ^bb3 774 static LogicalResult 775 simplifySwitchFromSwitchOnSameCondition(SwitchOp op, 776 PatternRewriter &rewriter) { 777 // Check that we have a single distinct predecessor. 778 Block *currentBlock = op->getBlock(); 779 Block *predecessor = currentBlock->getSinglePredecessor(); 780 if (!predecessor) 781 return failure(); 782 783 // Check that the predecessor terminates with a switch branch to this block 784 // and that it branches on the same condition and that this branch isn't the 785 // default destination. 
786 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 787 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 788 predSwitch.getDefaultDestination() == currentBlock) 789 return failure(); 790 791 // Fold this switch to an unconditional branch. 792 SuccessorRange predDests = predSwitch.getCaseDestinations(); 793 auto it = llvm::find(predDests, currentBlock); 794 if (it != predDests.end()) { 795 Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues(); 796 foldSwitch(op, rewriter, 797 predCaseValues->getValues<APInt>()[it - predDests.begin()]); 798 } else { 799 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 800 op.getDefaultOperands()); 801 } 802 return success(); 803 } 804 805 /// switch %flag : i32, [ 806 /// default: ^bb1, 807 /// 42: ^bb2 808 /// ] 809 /// ^bb1: 810 /// switch %flag : i32, [ 811 /// default: ^bb3, 812 /// 42: ^bb4, 813 /// 43: ^bb5 814 /// ] 815 /// -> 816 /// switch %flag : i32, [ 817 /// default: ^bb1, 818 /// 42: ^bb2, 819 /// ] 820 /// ^bb1: 821 /// switch %flag : i32, [ 822 /// default: ^bb3, 823 /// 43: ^bb5 824 /// ] 825 static LogicalResult 826 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, 827 PatternRewriter &rewriter) { 828 // Check that we have a single distinct predecessor. 829 Block *currentBlock = op->getBlock(); 830 Block *predecessor = currentBlock->getSinglePredecessor(); 831 if (!predecessor) 832 return failure(); 833 834 // Check that the predecessor terminates with a switch branch to this block 835 // and that it branches on the same condition and that this branch is the 836 // default destination. 837 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 838 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 839 predSwitch.getDefaultDestination() != currentBlock) 840 return failure(); 841 842 // Delete case values that are not possible here. 
843 DenseSet<APInt> caseValuesToRemove; 844 auto predDests = predSwitch.getCaseDestinations(); 845 auto predCaseValues = predSwitch.getCaseValues(); 846 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) 847 if (currentBlock != predDests[i]) 848 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]); 849 850 SmallVector<Block *> newCaseDestinations; 851 SmallVector<ValueRange> newCaseOperands; 852 SmallVector<APInt> newCaseValues; 853 bool requiresChange = false; 854 855 auto caseValues = op.getCaseValues(); 856 auto caseDests = op.getCaseDestinations(); 857 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 858 if (caseValuesToRemove.contains(it.value())) { 859 requiresChange = true; 860 continue; 861 } 862 newCaseDestinations.push_back(caseDests[it.index()]); 863 newCaseOperands.push_back(op.getCaseOperands(it.index())); 864 newCaseValues.push_back(it.value()); 865 } 866 867 if (!requiresChange) 868 return failure(); 869 870 rewriter.replaceOpWithNewOp<SwitchOp>( 871 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 872 newCaseValues, newCaseDestinations, newCaseOperands); 873 return success(); 874 } 875 876 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, 877 MLIRContext *context) { 878 results.add(&simplifySwitchWithOnlyDefault) 879 .add(&dropSwitchCasesThatMatchDefault) 880 .add(&simplifyConstSwitchValue) 881 .add(&simplifyPassThroughSwitch) 882 .add(&simplifySwitchFromSwitchOnSameCondition) 883 .add(&simplifySwitchFromDefaultSwitchOnSameCondition); 884 } 885 886 //===----------------------------------------------------------------------===// 887 // TableGen'd op method definitions 888 //===----------------------------------------------------------------------===// 889 890 #define GET_OP_CLASSES 891 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 892