1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 10 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/AffineExpr.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/IRMapping.h" 19 #include "mlir/IR/Matchers.h" 20 #include "mlir/IR/OpImplementation.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/IR/Value.h" 24 #include "mlir/Support/MathExtras.h" 25 #include "mlir/Transforms/InliningUtils.h" 26 #include "llvm/ADT/APFloat.h" 27 #include "llvm/ADT/STLExtras.h" 28 #include "llvm/Support/FormatVariadic.h" 29 #include "llvm/Support/raw_ostream.h" 30 #include <numeric> 31 32 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc" 33 34 using namespace mlir; 35 using namespace mlir::cf; 36 37 //===----------------------------------------------------------------------===// 38 // ControlFlowDialect Interfaces 39 //===----------------------------------------------------------------------===// 40 namespace { 41 /// This class defines the interface for handling inlining with control flow 42 /// operations. 43 struct ControlFlowInlinerInterface : public DialectInlinerInterface { 44 using DialectInlinerInterface::DialectInlinerInterface; 45 ~ControlFlowInlinerInterface() override = default; 46 47 /// All control flow operations can be inlined. 48 bool isLegalToInline(Operation *call, Operation *callable, 49 bool wouldBeCloned) const final { 50 return true; 51 } 52 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 53 return true; 54 } 55 56 /// ControlFlow terminator operations don't really need any special handing. 57 void handleTerminator(Operation *op, Block *newDest) const final {} 58 }; 59 } // namespace 60 61 //===----------------------------------------------------------------------===// 62 // ControlFlowDialect 63 //===----------------------------------------------------------------------===// 64 65 void ControlFlowDialect::initialize() { 66 addOperations< 67 #define GET_OP_LIST 68 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 69 >(); 70 addInterfaces<ControlFlowInlinerInterface>(); 71 } 72 73 //===----------------------------------------------------------------------===// 74 // AssertOp 75 //===----------------------------------------------------------------------===// 76 77 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { 78 // Erase assertion if argument is constant true. 79 if (matchPattern(op.getArg(), m_One())) { 80 rewriter.eraseOp(op); 81 return success(); 82 } 83 return failure(); 84 } 85 86 //===----------------------------------------------------------------------===// 87 // BranchOp 88 //===----------------------------------------------------------------------===// 89 90 /// Given a successor, try to collapse it to a new destination if it only 91 /// contains a passthrough unconditional branch. If the successor is 92 /// collapsable, `successor` and `successorOperands` are updated to reference 93 /// the new destination and values. `argStorage` is used as storage if operands 94 /// to the collapsed successor need to be remapped. It must outlive uses of 95 /// successorOperands. 96 static LogicalResult collapseBranch(Block *&successor, 97 ValueRange &successorOperands, 98 SmallVectorImpl<Value> &argStorage) { 99 // Check that the successor only contains a unconditional branch. 100 if (std::next(successor->begin()) != successor->end()) 101 return failure(); 102 // Check that the terminator is an unconditional branch. 103 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator()); 104 if (!successorBranch) 105 return failure(); 106 // Check that the arguments are only used within the terminator. 107 for (BlockArgument arg : successor->getArguments()) { 108 for (Operation *user : arg.getUsers()) 109 if (user != successorBranch) 110 return failure(); 111 } 112 // Don't try to collapse branches to infinite loops. 113 Block *successorDest = successorBranch.getDest(); 114 if (successorDest == successor) 115 return failure(); 116 117 // Update the operands to the successor. If the branch parent has no 118 // arguments, we can use the branch operands directly. 119 OperandRange operands = successorBranch.getOperands(); 120 if (successor->args_empty()) { 121 successor = successorDest; 122 successorOperands = operands; 123 return success(); 124 } 125 126 // Otherwise, we need to remap any argument operands. 127 for (Value operand : operands) { 128 BlockArgument argOperand = operand.dyn_cast<BlockArgument>(); 129 if (argOperand && argOperand.getOwner() == successor) 130 argStorage.push_back(successorOperands[argOperand.getArgNumber()]); 131 else 132 argStorage.push_back(operand); 133 } 134 successor = successorDest; 135 successorOperands = argStorage; 136 return success(); 137 } 138 139 /// Simplify a branch to a block that has a single predecessor. This effectively 140 /// merges the two blocks. 141 static LogicalResult 142 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) { 143 // Check that the successor block has a single predecessor. 144 Block *succ = op.getDest(); 145 Block *opParent = op->getBlock(); 146 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) 147 return failure(); 148 149 // Merge the successor into the current block and erase the branch. 150 rewriter.mergeBlocks(succ, opParent, op.getOperands()); 151 rewriter.eraseOp(op); 152 return success(); 153 } 154 155 /// br ^bb1 156 /// ^bb1 157 /// br ^bbN(...) 158 /// 159 /// -> br ^bbN(...) 160 /// 161 static LogicalResult simplifyPassThroughBr(BranchOp op, 162 PatternRewriter &rewriter) { 163 Block *dest = op.getDest(); 164 ValueRange destOperands = op.getOperands(); 165 SmallVector<Value, 4> destOperandStorage; 166 167 // Try to collapse the successor if it points somewhere other than this 168 // block. 169 if (dest == op->getBlock() || 170 failed(collapseBranch(dest, destOperands, destOperandStorage))) 171 return failure(); 172 173 // Create a new branch with the collapsed successor. 174 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands); 175 return success(); 176 } 177 178 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) { 179 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) || 180 succeeded(simplifyPassThroughBr(op, rewriter))); 181 } 182 183 void BranchOp::setDest(Block *block) { return setSuccessor(block); } 184 185 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } 186 187 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { 188 assert(index == 0 && "invalid successor index"); 189 return SuccessorOperands(getDestOperandsMutable()); 190 } 191 192 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { 193 return getDest(); 194 } 195 196 //===----------------------------------------------------------------------===// 197 // CondBranchOp 198 //===----------------------------------------------------------------------===// 199 200 namespace { 201 /// cf.cond_br true, ^bb1, ^bb2 202 /// -> br ^bb1 203 /// cf.cond_br false, ^bb1, ^bb2 204 /// -> br ^bb2 205 /// 206 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { 207 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 208 209 LogicalResult matchAndRewrite(CondBranchOp condbr, 210 PatternRewriter &rewriter) const override { 211 if (matchPattern(condbr.getCondition(), m_NonZero())) { 212 // True branch taken. 213 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 214 condbr.getTrueOperands()); 215 return success(); 216 } 217 if (matchPattern(condbr.getCondition(), m_Zero())) { 218 // False branch taken. 219 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 220 condbr.getFalseOperands()); 221 return success(); 222 } 223 return failure(); 224 } 225 }; 226 227 /// cf.cond_br %cond, ^bb1, ^bb2 228 /// ^bb1 229 /// br ^bbN(...) 230 /// ^bb2 231 /// br ^bbK(...) 232 /// 233 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...) 234 /// 235 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> { 236 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 237 238 LogicalResult matchAndRewrite(CondBranchOp condbr, 239 PatternRewriter &rewriter) const override { 240 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest(); 241 ValueRange trueDestOperands = condbr.getTrueOperands(); 242 ValueRange falseDestOperands = condbr.getFalseOperands(); 243 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage; 244 245 // Try to collapse one of the current successors. 246 LogicalResult collapsedTrue = 247 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); 248 LogicalResult collapsedFalse = 249 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); 250 if (failed(collapsedTrue) && failed(collapsedFalse)) 251 return failure(); 252 253 // Create a new branch with the collapsed successors. 254 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(), 255 trueDest, trueDestOperands, 256 falseDest, falseDestOperands); 257 return success(); 258 } 259 }; 260 261 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) 262 /// -> br ^bb1(A, ..., N) 263 /// 264 /// cf.cond_br %cond, ^bb1(A), ^bb1(B) 265 /// -> %select = arith.select %cond, A, B 266 /// br ^bb1(%select) 267 /// 268 struct SimplifyCondBranchIdenticalSuccessors 269 : public OpRewritePattern<CondBranchOp> { 270 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 271 272 LogicalResult matchAndRewrite(CondBranchOp condbr, 273 PatternRewriter &rewriter) const override { 274 // Check that the true and false destinations are the same and have the same 275 // operands. 276 Block *trueDest = condbr.getTrueDest(); 277 if (trueDest != condbr.getFalseDest()) 278 return failure(); 279 280 // If all of the operands match, no selects need to be generated. 281 OperandRange trueOperands = condbr.getTrueOperands(); 282 OperandRange falseOperands = condbr.getFalseOperands(); 283 if (trueOperands == falseOperands) { 284 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands); 285 return success(); 286 } 287 288 // Otherwise, if the current block is the only predecessor insert selects 289 // for any mismatched branch operands. 290 if (trueDest->getUniquePredecessor() != condbr->getBlock()) 291 return failure(); 292 293 // Generate a select for any operands that differ between the two. 294 SmallVector<Value, 8> mergedOperands; 295 mergedOperands.reserve(trueOperands.size()); 296 Value condition = condbr.getCondition(); 297 for (auto it : llvm::zip(trueOperands, falseOperands)) { 298 if (std::get<0>(it) == std::get<1>(it)) 299 mergedOperands.push_back(std::get<0>(it)); 300 else 301 mergedOperands.push_back(rewriter.create<arith::SelectOp>( 302 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); 303 } 304 305 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands); 306 return success(); 307 } 308 }; 309 310 /// ... 311 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 312 /// ... 313 /// ^bb1: // has single predecessor 314 /// ... 315 /// cf.cond_br %cond, ^bb3(...), ^bb4(...) 316 /// 317 /// -> 318 /// 319 /// ... 320 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 321 /// ... 322 /// ^bb1: // has single predecessor 323 /// ... 324 /// br ^bb3(...) 325 /// 326 struct SimplifyCondBranchFromCondBranchOnSameCondition 327 : public OpRewritePattern<CondBranchOp> { 328 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 329 330 LogicalResult matchAndRewrite(CondBranchOp condbr, 331 PatternRewriter &rewriter) const override { 332 // Check that we have a single distinct predecessor. 333 Block *currentBlock = condbr->getBlock(); 334 Block *predecessor = currentBlock->getSinglePredecessor(); 335 if (!predecessor) 336 return failure(); 337 338 // Check that the predecessor terminates with a conditional branch to this 339 // block and that it branches on the same condition. 340 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator()); 341 if (!predBranch || condbr.getCondition() != predBranch.getCondition()) 342 return failure(); 343 344 // Fold this branch to an unconditional branch. 345 if (currentBlock == predBranch.getTrueDest()) 346 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 347 condbr.getTrueDestOperands()); 348 else 349 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 350 condbr.getFalseDestOperands()); 351 return success(); 352 } 353 }; 354 355 /// cf.cond_br %arg0, ^trueB, ^falseB 356 /// 357 /// ^trueB: 358 /// "test.consumer1"(%arg0) : (i1) -> () 359 /// ... 360 /// 361 /// ^falseB: 362 /// "test.consumer2"(%arg0) : (i1) -> () 363 /// ... 364 /// 365 /// -> 366 /// 367 /// cf.cond_br %arg0, ^trueB, ^falseB 368 /// ^trueB: 369 /// "test.consumer1"(%true) : (i1) -> () 370 /// ... 371 /// 372 /// ^falseB: 373 /// "test.consumer2"(%false) : (i1) -> () 374 /// ... 375 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { 376 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 377 378 LogicalResult matchAndRewrite(CondBranchOp condbr, 379 PatternRewriter &rewriter) const override { 380 // Check that we have a single distinct predecessor. 381 bool replaced = false; 382 Type ty = rewriter.getI1Type(); 383 384 // These variables serve to prevent creating duplicate constants 385 // and hold constant true or false values. 386 Value constantTrue = nullptr; 387 Value constantFalse = nullptr; 388 389 // TODO These checks can be expanded to encompas any use with only 390 // either the true of false edge as a predecessor. For now, we fall 391 // back to checking the single predecessor is given by the true/fasle 392 // destination, thereby ensuring that only that edge can reach the 393 // op. 394 if (condbr.getTrueDest()->getSinglePredecessor()) { 395 for (OpOperand &use : 396 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 397 if (use.getOwner()->getBlock() == condbr.getTrueDest()) { 398 replaced = true; 399 400 if (!constantTrue) 401 constantTrue = rewriter.create<arith::ConstantOp>( 402 condbr.getLoc(), ty, rewriter.getBoolAttr(true)); 403 404 rewriter.updateRootInPlace(use.getOwner(), 405 [&] { use.set(constantTrue); }); 406 } 407 } 408 } 409 if (condbr.getFalseDest()->getSinglePredecessor()) { 410 for (OpOperand &use : 411 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 412 if (use.getOwner()->getBlock() == condbr.getFalseDest()) { 413 replaced = true; 414 415 if (!constantFalse) 416 constantFalse = rewriter.create<arith::ConstantOp>( 417 condbr.getLoc(), ty, rewriter.getBoolAttr(false)); 418 419 rewriter.updateRootInPlace(use.getOwner(), 420 [&] { use.set(constantFalse); }); 421 } 422 } 423 } 424 return success(replaced); 425 } 426 }; 427 } // namespace 428 429 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, 430 MLIRContext *context) { 431 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, 432 SimplifyCondBranchIdenticalSuccessors, 433 SimplifyCondBranchFromCondBranchOnSameCondition, 434 CondBranchTruthPropagation>(context); 435 } 436 437 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { 438 assert(index < getNumSuccessors() && "invalid successor index"); 439 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable() 440 : getFalseDestOperandsMutable()); 441 } 442 443 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 444 if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) 445 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest(); 446 return nullptr; 447 } 448 449 //===----------------------------------------------------------------------===// 450 // SwitchOp 451 //===----------------------------------------------------------------------===// 452 453 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 454 Block *defaultDestination, ValueRange defaultOperands, 455 DenseIntElementsAttr caseValues, 456 BlockRange caseDestinations, 457 ArrayRef<ValueRange> caseOperands) { 458 build(builder, result, value, defaultOperands, caseOperands, caseValues, 459 defaultDestination, caseDestinations); 460 } 461 462 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 463 Block *defaultDestination, ValueRange defaultOperands, 464 ArrayRef<APInt> caseValues, BlockRange caseDestinations, 465 ArrayRef<ValueRange> caseOperands) { 466 DenseIntElementsAttr caseValuesAttr; 467 if (!caseValues.empty()) { 468 ShapedType caseValueType = VectorType::get( 469 static_cast<int64_t>(caseValues.size()), value.getType()); 470 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 471 } 472 build(builder, result, value, defaultDestination, defaultOperands, 473 caseValuesAttr, caseDestinations, caseOperands); 474 } 475 476 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 477 Block *defaultDestination, ValueRange defaultOperands, 478 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 479 ArrayRef<ValueRange> caseOperands) { 480 DenseIntElementsAttr caseValuesAttr; 481 if (!caseValues.empty()) { 482 ShapedType caseValueType = VectorType::get( 483 static_cast<int64_t>(caseValues.size()), value.getType()); 484 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 485 } 486 build(builder, result, value, defaultDestination, defaultOperands, 487 caseValuesAttr, caseDestinations, caseOperands); 488 } 489 490 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? 491 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* 492 static ParseResult parseSwitchOpCases( 493 OpAsmParser &parser, Type &flagType, Block *&defaultDestination, 494 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, 495 SmallVectorImpl<Type> &defaultOperandTypes, 496 DenseIntElementsAttr &caseValues, 497 SmallVectorImpl<Block *> &caseDestinations, 498 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 499 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 500 if (parser.parseKeyword("default") || parser.parseColon() || 501 parser.parseSuccessor(defaultDestination)) 502 return failure(); 503 if (succeeded(parser.parseOptionalLParen())) { 504 if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, 505 /*allowResultNumber=*/false) || 506 parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) 507 return failure(); 508 } 509 510 SmallVector<APInt> values; 511 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 512 while (succeeded(parser.parseOptionalComma())) { 513 int64_t value = 0; 514 if (failed(parser.parseInteger(value))) 515 return failure(); 516 values.push_back(APInt(bitWidth, value)); 517 518 Block *destination; 519 SmallVector<OpAsmParser::UnresolvedOperand> operands; 520 SmallVector<Type> operandTypes; 521 if (failed(parser.parseColon()) || 522 failed(parser.parseSuccessor(destination))) 523 return failure(); 524 if (succeeded(parser.parseOptionalLParen())) { 525 if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None, 526 /*allowResultNumber=*/false)) || 527 failed(parser.parseColonTypeList(operandTypes)) || 528 failed(parser.parseRParen())) 529 return failure(); 530 } 531 caseDestinations.push_back(destination); 532 caseOperands.emplace_back(operands); 533 caseOperandTypes.emplace_back(operandTypes); 534 } 535 536 if (!values.empty()) { 537 ShapedType caseValueType = 538 VectorType::get(static_cast<int64_t>(values.size()), flagType); 539 caseValues = DenseIntElementsAttr::get(caseValueType, values); 540 } 541 return success(); 542 } 543 544 static void printSwitchOpCases( 545 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, 546 OperandRange defaultOperands, TypeRange defaultOperandTypes, 547 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, 548 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { 549 p << " default: "; 550 p.printSuccessorAndUseList(defaultDestination, defaultOperands); 551 552 if (!caseValues) 553 return; 554 555 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) { 556 p << ','; 557 p.printNewline(); 558 p << " "; 559 p << it.value().getLimitedValue(); 560 p << ": "; 561 p.printSuccessorAndUseList(caseDestinations[it.index()], 562 caseOperands[it.index()]); 563 } 564 p.printNewline(); 565 } 566 567 LogicalResult SwitchOp::verify() { 568 auto caseValues = getCaseValues(); 569 auto caseDestinations = getCaseDestinations(); 570 571 if (!caseValues && caseDestinations.empty()) 572 return success(); 573 574 Type flagType = getFlag().getType(); 575 Type caseValueType = caseValues->getType().getElementType(); 576 if (caseValueType != flagType) 577 return emitOpError() << "'flag' type (" << flagType 578 << ") should match case value type (" << caseValueType 579 << ")"; 580 581 if (caseValues && 582 caseValues->size() != static_cast<int64_t>(caseDestinations.size())) 583 return emitOpError() << "number of case values (" << caseValues->size() 584 << ") should match number of " 585 "case destinations (" 586 << caseDestinations.size() << ")"; 587 return success(); 588 } 589 590 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 591 assert(index < getNumSuccessors() && "invalid successor index"); 592 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 593 : getCaseOperandsMutable(index - 1)); 594 } 595 596 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 597 std::optional<DenseIntElementsAttr> caseValues = getCaseValues(); 598 599 if (!caseValues) 600 return getDefaultDestination(); 601 602 SuccessorRange caseDests = getCaseDestinations(); 603 if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) { 604 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) 605 if (it.value() == value.getValue()) 606 return caseDests[it.index()]; 607 return getDefaultDestination(); 608 } 609 return nullptr; 610 } 611 612 /// switch %flag : i32, [ 613 /// default: ^bb1 614 /// ] 615 /// -> br ^bb1 616 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, 617 PatternRewriter &rewriter) { 618 if (!op.getCaseDestinations().empty()) 619 return failure(); 620 621 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 622 op.getDefaultOperands()); 623 return success(); 624 } 625 626 /// switch %flag : i32, [ 627 /// default: ^bb1, 628 /// 42: ^bb1, 629 /// 43: ^bb2 630 /// ] 631 /// -> 632 /// switch %flag : i32, [ 633 /// default: ^bb1, 634 /// 43: ^bb2 635 /// ] 636 static LogicalResult 637 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { 638 SmallVector<Block *> newCaseDestinations; 639 SmallVector<ValueRange> newCaseOperands; 640 SmallVector<APInt> newCaseValues; 641 bool requiresChange = false; 642 auto caseValues = op.getCaseValues(); 643 auto caseDests = op.getCaseDestinations(); 644 645 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 646 if (caseDests[it.index()] == op.getDefaultDestination() && 647 op.getCaseOperands(it.index()) == op.getDefaultOperands()) { 648 requiresChange = true; 649 continue; 650 } 651 newCaseDestinations.push_back(caseDests[it.index()]); 652 newCaseOperands.push_back(op.getCaseOperands(it.index())); 653 newCaseValues.push_back(it.value()); 654 } 655 656 if (!requiresChange) 657 return failure(); 658 659 rewriter.replaceOpWithNewOp<SwitchOp>( 660 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 661 newCaseValues, newCaseDestinations, newCaseOperands); 662 return success(); 663 } 664 665 /// Helper for folding a switch with a constant value. 666 /// switch %c_42 : i32, [ 667 /// default: ^bb1 , 668 /// 42: ^bb2, 669 /// 43: ^bb3 670 /// ] 671 /// -> br ^bb2 672 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, 673 const APInt &caseValue) { 674 auto caseValues = op.getCaseValues(); 675 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 676 if (it.value() == caseValue) { 677 rewriter.replaceOpWithNewOp<BranchOp>( 678 op, op.getCaseDestinations()[it.index()], 679 op.getCaseOperands(it.index())); 680 return; 681 } 682 } 683 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 684 op.getDefaultOperands()); 685 } 686 687 /// switch %c_42 : i32, [ 688 /// default: ^bb1, 689 /// 42: ^bb2, 690 /// 43: ^bb3 691 /// ] 692 /// -> br ^bb2 693 static LogicalResult simplifyConstSwitchValue(SwitchOp op, 694 PatternRewriter &rewriter) { 695 APInt caseValue; 696 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) 697 return failure(); 698 699 foldSwitch(op, rewriter, caseValue); 700 return success(); 701 } 702 703 /// switch %c_42 : i32, [ 704 /// default: ^bb1, 705 /// 42: ^bb2, 706 /// ] 707 /// ^bb2: 708 /// br ^bb3 709 /// -> 710 /// switch %c_42 : i32, [ 711 /// default: ^bb1, 712 /// 42: ^bb3, 713 /// ] 714 static LogicalResult simplifyPassThroughSwitch(SwitchOp op, 715 PatternRewriter &rewriter) { 716 SmallVector<Block *> newCaseDests; 717 SmallVector<ValueRange> newCaseOperands; 718 SmallVector<SmallVector<Value>> argStorage; 719 auto caseValues = op.getCaseValues(); 720 argStorage.reserve(caseValues->size() + 1); 721 auto caseDests = op.getCaseDestinations(); 722 bool requiresChange = false; 723 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { 724 Block *caseDest = caseDests[i]; 725 ValueRange caseOperands = op.getCaseOperands(i); 726 argStorage.emplace_back(); 727 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back()))) 728 requiresChange = true; 729 730 newCaseDests.push_back(caseDest); 731 newCaseOperands.push_back(caseOperands); 732 } 733 734 Block *defaultDest = op.getDefaultDestination(); 735 ValueRange defaultOperands = op.getDefaultOperands(); 736 argStorage.emplace_back(); 737 738 if (succeeded( 739 collapseBranch(defaultDest, defaultOperands, argStorage.back()))) 740 requiresChange = true; 741 742 if (!requiresChange) 743 return failure(); 744 745 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest, 746 defaultOperands, *caseValues, 747 newCaseDests, newCaseOperands); 748 return success(); 749 } 750 751 /// switch %flag : i32, [ 752 /// default: ^bb1, 753 /// 42: ^bb2, 754 /// ] 755 /// ^bb2: 756 /// switch %flag : i32, [ 757 /// default: ^bb3, 758 /// 42: ^bb4 759 /// ] 760 /// -> 761 /// switch %flag : i32, [ 762 /// default: ^bb1, 763 /// 42: ^bb2, 764 /// ] 765 /// ^bb2: 766 /// br ^bb4 767 /// 768 /// and 769 /// 770 /// switch %flag : i32, [ 771 /// default: ^bb1, 772 /// 42: ^bb2, 773 /// ] 774 /// ^bb2: 775 /// switch %flag : i32, [ 776 /// default: ^bb3, 777 /// 43: ^bb4 778 /// ] 779 /// -> 780 /// switch %flag : i32, [ 781 /// default: ^bb1, 782 /// 42: ^bb2, 783 /// ] 784 /// ^bb2: 785 /// br ^bb3 786 static LogicalResult 787 simplifySwitchFromSwitchOnSameCondition(SwitchOp op, 788 PatternRewriter &rewriter) { 789 // Check that we have a single distinct predecessor. 790 Block *currentBlock = op->getBlock(); 791 Block *predecessor = currentBlock->getSinglePredecessor(); 792 if (!predecessor) 793 return failure(); 794 795 // Check that the predecessor terminates with a switch branch to this block 796 // and that it branches on the same condition and that this branch isn't the 797 // default destination. 798 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 799 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 800 predSwitch.getDefaultDestination() == currentBlock) 801 return failure(); 802 803 // Fold this switch to an unconditional branch. 804 SuccessorRange predDests = predSwitch.getCaseDestinations(); 805 auto it = llvm::find(predDests, currentBlock); 806 if (it != predDests.end()) { 807 std::optional<DenseIntElementsAttr> predCaseValues = 808 predSwitch.getCaseValues(); 809 foldSwitch(op, rewriter, 810 predCaseValues->getValues<APInt>()[it - predDests.begin()]); 811 } else { 812 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 813 op.getDefaultOperands()); 814 } 815 return success(); 816 } 817 818 /// switch %flag : i32, [ 819 /// default: ^bb1, 820 /// 42: ^bb2 821 /// ] 822 /// ^bb1: 823 /// switch %flag : i32, [ 824 /// default: ^bb3, 825 /// 42: ^bb4, 826 /// 43: ^bb5 827 /// ] 828 /// -> 829 /// switch %flag : i32, [ 830 /// default: ^bb1, 831 /// 42: ^bb2, 832 /// ] 833 /// ^bb1: 834 /// switch %flag : i32, [ 835 /// default: ^bb3, 836 /// 43: ^bb5 837 /// ] 838 static LogicalResult 839 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, 840 PatternRewriter &rewriter) { 841 // Check that we have a single distinct predecessor. 842 Block *currentBlock = op->getBlock(); 843 Block *predecessor = currentBlock->getSinglePredecessor(); 844 if (!predecessor) 845 return failure(); 846 847 // Check that the predecessor terminates with a switch branch to this block 848 // and that it branches on the same condition and that this branch is the 849 // default destination. 850 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 851 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 852 predSwitch.getDefaultDestination() != currentBlock) 853 return failure(); 854 855 // Delete case values that are not possible here. 856 DenseSet<APInt> caseValuesToRemove; 857 auto predDests = predSwitch.getCaseDestinations(); 858 auto predCaseValues = predSwitch.getCaseValues(); 859 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) 860 if (currentBlock != predDests[i]) 861 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]); 862 863 SmallVector<Block *> newCaseDestinations; 864 SmallVector<ValueRange> newCaseOperands; 865 SmallVector<APInt> newCaseValues; 866 bool requiresChange = false; 867 868 auto caseValues = op.getCaseValues(); 869 auto caseDests = op.getCaseDestinations(); 870 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 871 if (caseValuesToRemove.contains(it.value())) { 872 requiresChange = true; 873 continue; 874 } 875 newCaseDestinations.push_back(caseDests[it.index()]); 876 newCaseOperands.push_back(op.getCaseOperands(it.index())); 877 newCaseValues.push_back(it.value()); 878 } 879 880 if (!requiresChange) 881 return failure(); 882 883 rewriter.replaceOpWithNewOp<SwitchOp>( 884 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 885 newCaseValues, newCaseDestinations, newCaseOperands); 886 return success(); 887 } 888 889 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, 890 MLIRContext *context) { 891 results.add(&simplifySwitchWithOnlyDefault) 892 .add(&dropSwitchCasesThatMatchDefault) 893 .add(&simplifyConstSwitchValue) 894 .add(&simplifyPassThroughSwitch) 895 .add(&simplifySwitchFromSwitchOnSameCondition) 896 .add(&simplifySwitchFromDefaultSwitchOnSameCondition); 897 } 898 899 //===----------------------------------------------------------------------===// 900 // TableGen'd op method definitions 901 //===----------------------------------------------------------------------===// 902 903 #define GET_OP_CLASSES 904 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 905