1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 10 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/AffineExpr.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/IRMapping.h" 19 #include "mlir/IR/Matchers.h" 20 #include "mlir/IR/OpImplementation.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/IR/Value.h" 24 #include "mlir/Support/MathExtras.h" 25 #include "mlir/Transforms/InliningUtils.h" 26 #include "llvm/ADT/APFloat.h" 27 #include "llvm/ADT/STLExtras.h" 28 #include "llvm/Support/FormatVariadic.h" 29 #include "llvm/Support/raw_ostream.h" 30 #include <numeric> 31 32 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc" 33 34 using namespace mlir; 35 using namespace mlir::cf; 36 37 //===----------------------------------------------------------------------===// 38 // ControlFlowDialect Interfaces 39 //===----------------------------------------------------------------------===// 40 namespace { 41 /// This class defines the interface for handling inlining with control flow 42 /// operations. 43 struct ControlFlowInlinerInterface : public DialectInlinerInterface { 44 using DialectInlinerInterface::DialectInlinerInterface; 45 ~ControlFlowInlinerInterface() override = default; 46 47 /// All control flow operations can be inlined. 48 bool isLegalToInline(Operation *call, Operation *callable, 49 bool wouldBeCloned) const final { 50 return true; 51 } 52 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 53 return true; 54 } 55 56 /// ControlFlow terminator operations don't really need any special handing. 57 void handleTerminator(Operation *op, Block *newDest) const final {} 58 }; 59 } // namespace 60 61 //===----------------------------------------------------------------------===// 62 // ControlFlowDialect 63 //===----------------------------------------------------------------------===// 64 65 void ControlFlowDialect::initialize() { 66 addOperations< 67 #define GET_OP_LIST 68 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 69 >(); 70 addInterfaces<ControlFlowInlinerInterface>(); 71 } 72 73 //===----------------------------------------------------------------------===// 74 // AssertOp 75 //===----------------------------------------------------------------------===// 76 77 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { 78 // Erase assertion if argument is constant true. 79 if (matchPattern(op.getArg(), m_One())) { 80 rewriter.eraseOp(op); 81 return success(); 82 } 83 return failure(); 84 } 85 86 //===----------------------------------------------------------------------===// 87 // BranchOp 88 //===----------------------------------------------------------------------===// 89 90 /// Given a successor, try to collapse it to a new destination if it only 91 /// contains a passthrough unconditional branch. If the successor is 92 /// collapsable, `successor` and `successorOperands` are updated to reference 93 /// the new destination and values. `argStorage` is used as storage if operands 94 /// to the collapsed successor need to be remapped. It must outlive uses of 95 /// successorOperands. 96 static LogicalResult collapseBranch(Block *&successor, 97 ValueRange &successorOperands, 98 SmallVectorImpl<Value> &argStorage) { 99 // Check that the successor only contains a unconditional branch. 100 if (std::next(successor->begin()) != successor->end()) 101 return failure(); 102 // Check that the terminator is an unconditional branch. 103 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator()); 104 if (!successorBranch) 105 return failure(); 106 // Check that the arguments are only used within the terminator. 107 for (BlockArgument arg : successor->getArguments()) { 108 for (Operation *user : arg.getUsers()) 109 if (user != successorBranch) 110 return failure(); 111 } 112 // Don't try to collapse branches to infinite loops. 113 Block *successorDest = successorBranch.getDest(); 114 if (successorDest == successor) 115 return failure(); 116 117 // Update the operands to the successor. If the branch parent has no 118 // arguments, we can use the branch operands directly. 119 OperandRange operands = successorBranch.getOperands(); 120 if (successor->args_empty()) { 121 successor = successorDest; 122 successorOperands = operands; 123 return success(); 124 } 125 126 // Otherwise, we need to remap any argument operands. 127 for (Value operand : operands) { 128 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand); 129 if (argOperand && argOperand.getOwner() == successor) 130 argStorage.push_back(successorOperands[argOperand.getArgNumber()]); 131 else 132 argStorage.push_back(operand); 133 } 134 successor = successorDest; 135 successorOperands = argStorage; 136 return success(); 137 } 138 139 /// Simplify a branch to a block that has a single predecessor. This effectively 140 /// merges the two blocks. 141 static LogicalResult 142 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) { 143 // Check that the successor block has a single predecessor. 144 Block *succ = op.getDest(); 145 Block *opParent = op->getBlock(); 146 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) 147 return failure(); 148 149 // Merge the successor into the current block and erase the branch. 150 SmallVector<Value> brOperands(op.getOperands()); 151 rewriter.eraseOp(op); 152 rewriter.mergeBlocks(succ, opParent, brOperands); 153 return success(); 154 } 155 156 /// br ^bb1 157 /// ^bb1 158 /// br ^bbN(...) 159 /// 160 /// -> br ^bbN(...) 161 /// 162 static LogicalResult simplifyPassThroughBr(BranchOp op, 163 PatternRewriter &rewriter) { 164 Block *dest = op.getDest(); 165 ValueRange destOperands = op.getOperands(); 166 SmallVector<Value, 4> destOperandStorage; 167 168 // Try to collapse the successor if it points somewhere other than this 169 // block. 170 if (dest == op->getBlock() || 171 failed(collapseBranch(dest, destOperands, destOperandStorage))) 172 return failure(); 173 174 // Create a new branch with the collapsed successor. 175 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands); 176 return success(); 177 } 178 179 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) { 180 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) || 181 succeeded(simplifyPassThroughBr(op, rewriter))); 182 } 183 184 void BranchOp::setDest(Block *block) { return setSuccessor(block); } 185 186 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } 187 188 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { 189 assert(index == 0 && "invalid successor index"); 190 return SuccessorOperands(getDestOperandsMutable()); 191 } 192 193 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { 194 return getDest(); 195 } 196 197 //===----------------------------------------------------------------------===// 198 // CondBranchOp 199 //===----------------------------------------------------------------------===// 200 201 namespace { 202 /// cf.cond_br true, ^bb1, ^bb2 203 /// -> br ^bb1 204 /// cf.cond_br false, ^bb1, ^bb2 205 /// -> br ^bb2 206 /// 207 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { 208 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 209 210 LogicalResult matchAndRewrite(CondBranchOp condbr, 211 PatternRewriter &rewriter) const override { 212 if (matchPattern(condbr.getCondition(), m_NonZero())) { 213 // True branch taken. 214 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 215 condbr.getTrueOperands()); 216 return success(); 217 } 218 if (matchPattern(condbr.getCondition(), m_Zero())) { 219 // False branch taken. 220 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 221 condbr.getFalseOperands()); 222 return success(); 223 } 224 return failure(); 225 } 226 }; 227 228 /// cf.cond_br %cond, ^bb1, ^bb2 229 /// ^bb1 230 /// br ^bbN(...) 231 /// ^bb2 232 /// br ^bbK(...) 233 /// 234 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...) 235 /// 236 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> { 237 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 238 239 LogicalResult matchAndRewrite(CondBranchOp condbr, 240 PatternRewriter &rewriter) const override { 241 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest(); 242 ValueRange trueDestOperands = condbr.getTrueOperands(); 243 ValueRange falseDestOperands = condbr.getFalseOperands(); 244 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage; 245 246 // Try to collapse one of the current successors. 247 LogicalResult collapsedTrue = 248 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); 249 LogicalResult collapsedFalse = 250 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); 251 if (failed(collapsedTrue) && failed(collapsedFalse)) 252 return failure(); 253 254 // Create a new branch with the collapsed successors. 255 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(), 256 trueDest, trueDestOperands, 257 falseDest, falseDestOperands); 258 return success(); 259 } 260 }; 261 262 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) 263 /// -> br ^bb1(A, ..., N) 264 /// 265 /// cf.cond_br %cond, ^bb1(A), ^bb1(B) 266 /// -> %select = arith.select %cond, A, B 267 /// br ^bb1(%select) 268 /// 269 struct SimplifyCondBranchIdenticalSuccessors 270 : public OpRewritePattern<CondBranchOp> { 271 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 272 273 LogicalResult matchAndRewrite(CondBranchOp condbr, 274 PatternRewriter &rewriter) const override { 275 // Check that the true and false destinations are the same and have the same 276 // operands. 277 Block *trueDest = condbr.getTrueDest(); 278 if (trueDest != condbr.getFalseDest()) 279 return failure(); 280 281 // If all of the operands match, no selects need to be generated. 282 OperandRange trueOperands = condbr.getTrueOperands(); 283 OperandRange falseOperands = condbr.getFalseOperands(); 284 if (trueOperands == falseOperands) { 285 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands); 286 return success(); 287 } 288 289 // Otherwise, if the current block is the only predecessor insert selects 290 // for any mismatched branch operands. 291 if (trueDest->getUniquePredecessor() != condbr->getBlock()) 292 return failure(); 293 294 // Generate a select for any operands that differ between the two. 295 SmallVector<Value, 8> mergedOperands; 296 mergedOperands.reserve(trueOperands.size()); 297 Value condition = condbr.getCondition(); 298 for (auto it : llvm::zip(trueOperands, falseOperands)) { 299 if (std::get<0>(it) == std::get<1>(it)) 300 mergedOperands.push_back(std::get<0>(it)); 301 else 302 mergedOperands.push_back(rewriter.create<arith::SelectOp>( 303 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); 304 } 305 306 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands); 307 return success(); 308 } 309 }; 310 311 /// ... 312 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 313 /// ... 314 /// ^bb1: // has single predecessor 315 /// ... 316 /// cf.cond_br %cond, ^bb3(...), ^bb4(...) 317 /// 318 /// -> 319 /// 320 /// ... 321 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 322 /// ... 323 /// ^bb1: // has single predecessor 324 /// ... 325 /// br ^bb3(...) 326 /// 327 struct SimplifyCondBranchFromCondBranchOnSameCondition 328 : public OpRewritePattern<CondBranchOp> { 329 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 330 331 LogicalResult matchAndRewrite(CondBranchOp condbr, 332 PatternRewriter &rewriter) const override { 333 // Check that we have a single distinct predecessor. 334 Block *currentBlock = condbr->getBlock(); 335 Block *predecessor = currentBlock->getSinglePredecessor(); 336 if (!predecessor) 337 return failure(); 338 339 // Check that the predecessor terminates with a conditional branch to this 340 // block and that it branches on the same condition. 341 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator()); 342 if (!predBranch || condbr.getCondition() != predBranch.getCondition()) 343 return failure(); 344 345 // Fold this branch to an unconditional branch. 346 if (currentBlock == predBranch.getTrueDest()) 347 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 348 condbr.getTrueDestOperands()); 349 else 350 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 351 condbr.getFalseDestOperands()); 352 return success(); 353 } 354 }; 355 356 /// cf.cond_br %arg0, ^trueB, ^falseB 357 /// 358 /// ^trueB: 359 /// "test.consumer1"(%arg0) : (i1) -> () 360 /// ... 361 /// 362 /// ^falseB: 363 /// "test.consumer2"(%arg0) : (i1) -> () 364 /// ... 365 /// 366 /// -> 367 /// 368 /// cf.cond_br %arg0, ^trueB, ^falseB 369 /// ^trueB: 370 /// "test.consumer1"(%true) : (i1) -> () 371 /// ... 372 /// 373 /// ^falseB: 374 /// "test.consumer2"(%false) : (i1) -> () 375 /// ... 376 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { 377 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 378 379 LogicalResult matchAndRewrite(CondBranchOp condbr, 380 PatternRewriter &rewriter) const override { 381 // Check that we have a single distinct predecessor. 382 bool replaced = false; 383 Type ty = rewriter.getI1Type(); 384 385 // These variables serve to prevent creating duplicate constants 386 // and hold constant true or false values. 387 Value constantTrue = nullptr; 388 Value constantFalse = nullptr; 389 390 // TODO These checks can be expanded to encompas any use with only 391 // either the true of false edge as a predecessor. For now, we fall 392 // back to checking the single predecessor is given by the true/fasle 393 // destination, thereby ensuring that only that edge can reach the 394 // op. 395 if (condbr.getTrueDest()->getSinglePredecessor()) { 396 for (OpOperand &use : 397 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 398 if (use.getOwner()->getBlock() == condbr.getTrueDest()) { 399 replaced = true; 400 401 if (!constantTrue) 402 constantTrue = rewriter.create<arith::ConstantOp>( 403 condbr.getLoc(), ty, rewriter.getBoolAttr(true)); 404 405 rewriter.updateRootInPlace(use.getOwner(), 406 [&] { use.set(constantTrue); }); 407 } 408 } 409 } 410 if (condbr.getFalseDest()->getSinglePredecessor()) { 411 for (OpOperand &use : 412 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 413 if (use.getOwner()->getBlock() == condbr.getFalseDest()) { 414 replaced = true; 415 416 if (!constantFalse) 417 constantFalse = rewriter.create<arith::ConstantOp>( 418 condbr.getLoc(), ty, rewriter.getBoolAttr(false)); 419 420 rewriter.updateRootInPlace(use.getOwner(), 421 [&] { use.set(constantFalse); }); 422 } 423 } 424 } 425 return success(replaced); 426 } 427 }; 428 } // namespace 429 430 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, 431 MLIRContext *context) { 432 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, 433 SimplifyCondBranchIdenticalSuccessors, 434 SimplifyCondBranchFromCondBranchOnSameCondition, 435 CondBranchTruthPropagation>(context); 436 } 437 438 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { 439 assert(index < getNumSuccessors() && "invalid successor index"); 440 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable() 441 : getFalseDestOperandsMutable()); 442 } 443 444 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 445 if (IntegerAttr condAttr = 446 llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) 447 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest(); 448 return nullptr; 449 } 450 451 //===----------------------------------------------------------------------===// 452 // SwitchOp 453 //===----------------------------------------------------------------------===// 454 455 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 456 Block *defaultDestination, ValueRange defaultOperands, 457 DenseIntElementsAttr caseValues, 458 BlockRange caseDestinations, 459 ArrayRef<ValueRange> caseOperands) { 460 build(builder, result, value, defaultOperands, caseOperands, caseValues, 461 defaultDestination, caseDestinations); 462 } 463 464 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 465 Block *defaultDestination, ValueRange defaultOperands, 466 ArrayRef<APInt> caseValues, BlockRange caseDestinations, 467 ArrayRef<ValueRange> caseOperands) { 468 DenseIntElementsAttr caseValuesAttr; 469 if (!caseValues.empty()) { 470 ShapedType caseValueType = VectorType::get( 471 static_cast<int64_t>(caseValues.size()), value.getType()); 472 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 473 } 474 build(builder, result, value, defaultDestination, defaultOperands, 475 caseValuesAttr, caseDestinations, caseOperands); 476 } 477 478 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 479 Block *defaultDestination, ValueRange defaultOperands, 480 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 481 ArrayRef<ValueRange> caseOperands) { 482 DenseIntElementsAttr caseValuesAttr; 483 if (!caseValues.empty()) { 484 ShapedType caseValueType = VectorType::get( 485 static_cast<int64_t>(caseValues.size()), value.getType()); 486 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 487 } 488 build(builder, result, value, defaultDestination, defaultOperands, 489 caseValuesAttr, caseDestinations, caseOperands); 490 } 491 492 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? 493 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* 494 static ParseResult parseSwitchOpCases( 495 OpAsmParser &parser, Type &flagType, Block *&defaultDestination, 496 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, 497 SmallVectorImpl<Type> &defaultOperandTypes, 498 DenseIntElementsAttr &caseValues, 499 SmallVectorImpl<Block *> &caseDestinations, 500 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 501 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 502 if (parser.parseKeyword("default") || parser.parseColon() || 503 parser.parseSuccessor(defaultDestination)) 504 return failure(); 505 if (succeeded(parser.parseOptionalLParen())) { 506 if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, 507 /*allowResultNumber=*/false) || 508 parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) 509 return failure(); 510 } 511 512 SmallVector<APInt> values; 513 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 514 while (succeeded(parser.parseOptionalComma())) { 515 int64_t value = 0; 516 if (failed(parser.parseInteger(value))) 517 return failure(); 518 values.push_back(APInt(bitWidth, value)); 519 520 Block *destination; 521 SmallVector<OpAsmParser::UnresolvedOperand> operands; 522 SmallVector<Type> operandTypes; 523 if (failed(parser.parseColon()) || 524 failed(parser.parseSuccessor(destination))) 525 return failure(); 526 if (succeeded(parser.parseOptionalLParen())) { 527 if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None, 528 /*allowResultNumber=*/false)) || 529 failed(parser.parseColonTypeList(operandTypes)) || 530 failed(parser.parseRParen())) 531 return failure(); 532 } 533 caseDestinations.push_back(destination); 534 caseOperands.emplace_back(operands); 535 caseOperandTypes.emplace_back(operandTypes); 536 } 537 538 if (!values.empty()) { 539 ShapedType caseValueType = 540 VectorType::get(static_cast<int64_t>(values.size()), flagType); 541 caseValues = DenseIntElementsAttr::get(caseValueType, values); 542 } 543 return success(); 544 } 545 546 static void printSwitchOpCases( 547 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, 548 OperandRange defaultOperands, TypeRange defaultOperandTypes, 549 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, 550 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { 551 p << " default: "; 552 p.printSuccessorAndUseList(defaultDestination, defaultOperands); 553 554 if (!caseValues) 555 return; 556 557 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) { 558 p << ','; 559 p.printNewline(); 560 p << " "; 561 p << it.value().getLimitedValue(); 562 p << ": "; 563 p.printSuccessorAndUseList(caseDestinations[it.index()], 564 caseOperands[it.index()]); 565 } 566 p.printNewline(); 567 } 568 569 LogicalResult SwitchOp::verify() { 570 auto caseValues = getCaseValues(); 571 auto caseDestinations = getCaseDestinations(); 572 573 if (!caseValues && caseDestinations.empty()) 574 return success(); 575 576 Type flagType = getFlag().getType(); 577 Type caseValueType = caseValues->getType().getElementType(); 578 if (caseValueType != flagType) 579 return emitOpError() << "'flag' type (" << flagType 580 << ") should match case value type (" << caseValueType 581 << ")"; 582 583 if (caseValues && 584 caseValues->size() != static_cast<int64_t>(caseDestinations.size())) 585 return emitOpError() << "number of case values (" << caseValues->size() 586 << ") should match number of " 587 "case destinations (" 588 << caseDestinations.size() << ")"; 589 return success(); 590 } 591 592 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 593 assert(index < getNumSuccessors() && "invalid successor index"); 594 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 595 : getCaseOperandsMutable(index - 1)); 596 } 597 598 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 599 std::optional<DenseIntElementsAttr> caseValues = getCaseValues(); 600 601 if (!caseValues) 602 return getDefaultDestination(); 603 604 SuccessorRange caseDests = getCaseDestinations(); 605 if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) { 606 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) 607 if (it.value() == value.getValue()) 608 return caseDests[it.index()]; 609 return getDefaultDestination(); 610 } 611 return nullptr; 612 } 613 614 /// switch %flag : i32, [ 615 /// default: ^bb1 616 /// ] 617 /// -> br ^bb1 618 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, 619 PatternRewriter &rewriter) { 620 if (!op.getCaseDestinations().empty()) 621 return failure(); 622 623 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 624 op.getDefaultOperands()); 625 return success(); 626 } 627 628 /// switch %flag : i32, [ 629 /// default: ^bb1, 630 /// 42: ^bb1, 631 /// 43: ^bb2 632 /// ] 633 /// -> 634 /// switch %flag : i32, [ 635 /// default: ^bb1, 636 /// 43: ^bb2 637 /// ] 638 static LogicalResult 639 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { 640 SmallVector<Block *> newCaseDestinations; 641 SmallVector<ValueRange> newCaseOperands; 642 SmallVector<APInt> newCaseValues; 643 bool requiresChange = false; 644 auto caseValues = op.getCaseValues(); 645 auto caseDests = op.getCaseDestinations(); 646 647 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 648 if (caseDests[it.index()] == op.getDefaultDestination() && 649 op.getCaseOperands(it.index()) == op.getDefaultOperands()) { 650 requiresChange = true; 651 continue; 652 } 653 newCaseDestinations.push_back(caseDests[it.index()]); 654 newCaseOperands.push_back(op.getCaseOperands(it.index())); 655 newCaseValues.push_back(it.value()); 656 } 657 658 if (!requiresChange) 659 return failure(); 660 661 rewriter.replaceOpWithNewOp<SwitchOp>( 662 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 663 newCaseValues, newCaseDestinations, newCaseOperands); 664 return success(); 665 } 666 667 /// Helper for folding a switch with a constant value. 668 /// switch %c_42 : i32, [ 669 /// default: ^bb1 , 670 /// 42: ^bb2, 671 /// 43: ^bb3 672 /// ] 673 /// -> br ^bb2 674 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, 675 const APInt &caseValue) { 676 auto caseValues = op.getCaseValues(); 677 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 678 if (it.value() == caseValue) { 679 rewriter.replaceOpWithNewOp<BranchOp>( 680 op, op.getCaseDestinations()[it.index()], 681 op.getCaseOperands(it.index())); 682 return; 683 } 684 } 685 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 686 op.getDefaultOperands()); 687 } 688 689 /// switch %c_42 : i32, [ 690 /// default: ^bb1, 691 /// 42: ^bb2, 692 /// 43: ^bb3 693 /// ] 694 /// -> br ^bb2 695 static LogicalResult simplifyConstSwitchValue(SwitchOp op, 696 PatternRewriter &rewriter) { 697 APInt caseValue; 698 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) 699 return failure(); 700 701 foldSwitch(op, rewriter, caseValue); 702 return success(); 703 } 704 705 /// switch %c_42 : i32, [ 706 /// default: ^bb1, 707 /// 42: ^bb2, 708 /// ] 709 /// ^bb2: 710 /// br ^bb3 711 /// -> 712 /// switch %c_42 : i32, [ 713 /// default: ^bb1, 714 /// 42: ^bb3, 715 /// ] 716 static LogicalResult simplifyPassThroughSwitch(SwitchOp op, 717 PatternRewriter &rewriter) { 718 SmallVector<Block *> newCaseDests; 719 SmallVector<ValueRange> newCaseOperands; 720 SmallVector<SmallVector<Value>> argStorage; 721 auto caseValues = op.getCaseValues(); 722 argStorage.reserve(caseValues->size() + 1); 723 auto caseDests = op.getCaseDestinations(); 724 bool requiresChange = false; 725 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { 726 Block *caseDest = caseDests[i]; 727 ValueRange caseOperands = op.getCaseOperands(i); 728 argStorage.emplace_back(); 729 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back()))) 730 requiresChange = true; 731 732 newCaseDests.push_back(caseDest); 733 newCaseOperands.push_back(caseOperands); 734 } 735 736 Block *defaultDest = op.getDefaultDestination(); 737 ValueRange defaultOperands = op.getDefaultOperands(); 738 argStorage.emplace_back(); 739 740 if (succeeded( 741 collapseBranch(defaultDest, defaultOperands, argStorage.back()))) 742 requiresChange = true; 743 744 if (!requiresChange) 745 return failure(); 746 747 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest, 748 defaultOperands, *caseValues, 749 newCaseDests, newCaseOperands); 750 return success(); 751 } 752 753 /// switch %flag : i32, [ 754 /// default: ^bb1, 755 /// 42: ^bb2, 756 /// ] 757 /// ^bb2: 758 /// switch %flag : i32, [ 759 /// default: ^bb3, 760 /// 42: ^bb4 761 /// ] 762 /// -> 763 /// switch %flag : i32, [ 764 /// default: ^bb1, 765 /// 42: ^bb2, 766 /// ] 767 /// ^bb2: 768 /// br ^bb4 769 /// 770 /// and 771 /// 772 /// switch %flag : i32, [ 773 /// default: ^bb1, 774 /// 42: ^bb2, 775 /// ] 776 /// ^bb2: 777 /// switch %flag : i32, [ 778 /// default: ^bb3, 779 /// 43: ^bb4 780 /// ] 781 /// -> 782 /// switch %flag : i32, [ 783 /// default: ^bb1, 784 /// 42: ^bb2, 785 /// ] 786 /// ^bb2: 787 /// br ^bb3 788 static LogicalResult 789 simplifySwitchFromSwitchOnSameCondition(SwitchOp op, 790 PatternRewriter &rewriter) { 791 // Check that we have a single distinct predecessor. 792 Block *currentBlock = op->getBlock(); 793 Block *predecessor = currentBlock->getSinglePredecessor(); 794 if (!predecessor) 795 return failure(); 796 797 // Check that the predecessor terminates with a switch branch to this block 798 // and that it branches on the same condition and that this branch isn't the 799 // default destination. 800 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 801 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 802 predSwitch.getDefaultDestination() == currentBlock) 803 return failure(); 804 805 // Fold this switch to an unconditional branch. 806 SuccessorRange predDests = predSwitch.getCaseDestinations(); 807 auto it = llvm::find(predDests, currentBlock); 808 if (it != predDests.end()) { 809 std::optional<DenseIntElementsAttr> predCaseValues = 810 predSwitch.getCaseValues(); 811 foldSwitch(op, rewriter, 812 predCaseValues->getValues<APInt>()[it - predDests.begin()]); 813 } else { 814 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 815 op.getDefaultOperands()); 816 } 817 return success(); 818 } 819 820 /// switch %flag : i32, [ 821 /// default: ^bb1, 822 /// 42: ^bb2 823 /// ] 824 /// ^bb1: 825 /// switch %flag : i32, [ 826 /// default: ^bb3, 827 /// 42: ^bb4, 828 /// 43: ^bb5 829 /// ] 830 /// -> 831 /// switch %flag : i32, [ 832 /// default: ^bb1, 833 /// 42: ^bb2, 834 /// ] 835 /// ^bb1: 836 /// switch %flag : i32, [ 837 /// default: ^bb3, 838 /// 43: ^bb5 839 /// ] 840 static LogicalResult 841 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, 842 PatternRewriter &rewriter) { 843 // Check that we have a single distinct predecessor. 844 Block *currentBlock = op->getBlock(); 845 Block *predecessor = currentBlock->getSinglePredecessor(); 846 if (!predecessor) 847 return failure(); 848 849 // Check that the predecessor terminates with a switch branch to this block 850 // and that it branches on the same condition and that this branch is the 851 // default destination. 852 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 853 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 854 predSwitch.getDefaultDestination() != currentBlock) 855 return failure(); 856 857 // Delete case values that are not possible here. 858 DenseSet<APInt> caseValuesToRemove; 859 auto predDests = predSwitch.getCaseDestinations(); 860 auto predCaseValues = predSwitch.getCaseValues(); 861 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) 862 if (currentBlock != predDests[i]) 863 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]); 864 865 SmallVector<Block *> newCaseDestinations; 866 SmallVector<ValueRange> newCaseOperands; 867 SmallVector<APInt> newCaseValues; 868 bool requiresChange = false; 869 870 auto caseValues = op.getCaseValues(); 871 auto caseDests = op.getCaseDestinations(); 872 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 873 if (caseValuesToRemove.contains(it.value())) { 874 requiresChange = true; 875 continue; 876 } 877 newCaseDestinations.push_back(caseDests[it.index()]); 878 newCaseOperands.push_back(op.getCaseOperands(it.index())); 879 newCaseValues.push_back(it.value()); 880 } 881 882 if (!requiresChange) 883 return failure(); 884 885 rewriter.replaceOpWithNewOp<SwitchOp>( 886 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 887 newCaseValues, newCaseDestinations, newCaseOperands); 888 return success(); 889 } 890 891 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, 892 MLIRContext *context) { 893 results.add(&simplifySwitchWithOnlyDefault) 894 .add(&dropSwitchCasesThatMatchDefault) 895 .add(&simplifyConstSwitchValue) 896 .add(&simplifyPassThroughSwitch) 897 .add(&simplifySwitchFromSwitchOnSameCondition) 898 .add(&simplifySwitchFromDefaultSwitchOnSameCondition); 899 } 900 901 //===----------------------------------------------------------------------===// 902 // TableGen'd op method definitions 903 //===----------------------------------------------------------------------===// 904 905 #define GET_OP_CLASSES 906 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 907