1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 10 11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/IR/AffineExpr.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/IRMapping.h" 19 #include "mlir/IR/Matchers.h" 20 #include "mlir/IR/OpImplementation.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/IR/Value.h" 24 #include "mlir/Support/MathExtras.h" 25 #include "mlir/Transforms/InliningUtils.h" 26 #include "llvm/ADT/APFloat.h" 27 #include "llvm/ADT/STLExtras.h" 28 #include "llvm/Support/FormatVariadic.h" 29 #include "llvm/Support/raw_ostream.h" 30 #include <numeric> 31 32 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc" 33 34 using namespace mlir; 35 using namespace mlir::cf; 36 37 //===----------------------------------------------------------------------===// 38 // ControlFlowDialect Interfaces 39 //===----------------------------------------------------------------------===// 40 namespace { 41 /// This class defines the interface for handling inlining with control flow 42 /// operations. 43 struct ControlFlowInlinerInterface : public DialectInlinerInterface { 44 using DialectInlinerInterface::DialectInlinerInterface; 45 ~ControlFlowInlinerInterface() override = default; 46 47 /// All control flow operations can be inlined. 48 bool isLegalToInline(Operation *call, Operation *callable, 49 bool wouldBeCloned) const final { 50 return true; 51 } 52 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 53 return true; 54 } 55 56 /// ControlFlow terminator operations don't really need any special handing. 57 void handleTerminator(Operation *op, Block *newDest) const final {} 58 }; 59 } // namespace 60 61 //===----------------------------------------------------------------------===// 62 // ControlFlowDialect 63 //===----------------------------------------------------------------------===// 64 65 void ControlFlowDialect::initialize() { 66 addOperations< 67 #define GET_OP_LIST 68 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 69 >(); 70 addInterfaces<ControlFlowInlinerInterface>(); 71 declarePromisedInterface<ControlFlowDialect, ConvertToLLVMPatternInterface>(); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // AssertOp 76 //===----------------------------------------------------------------------===// 77 78 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { 79 // Erase assertion if argument is constant true. 80 if (matchPattern(op.getArg(), m_One())) { 81 rewriter.eraseOp(op); 82 return success(); 83 } 84 return failure(); 85 } 86 87 //===----------------------------------------------------------------------===// 88 // BranchOp 89 //===----------------------------------------------------------------------===// 90 91 /// Given a successor, try to collapse it to a new destination if it only 92 /// contains a passthrough unconditional branch. If the successor is 93 /// collapsable, `successor` and `successorOperands` are updated to reference 94 /// the new destination and values. `argStorage` is used as storage if operands 95 /// to the collapsed successor need to be remapped. It must outlive uses of 96 /// successorOperands. 97 static LogicalResult collapseBranch(Block *&successor, 98 ValueRange &successorOperands, 99 SmallVectorImpl<Value> &argStorage) { 100 // Check that the successor only contains a unconditional branch. 101 if (std::next(successor->begin()) != successor->end()) 102 return failure(); 103 // Check that the terminator is an unconditional branch. 104 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator()); 105 if (!successorBranch) 106 return failure(); 107 // Check that the arguments are only used within the terminator. 108 for (BlockArgument arg : successor->getArguments()) { 109 for (Operation *user : arg.getUsers()) 110 if (user != successorBranch) 111 return failure(); 112 } 113 // Don't try to collapse branches to infinite loops. 114 Block *successorDest = successorBranch.getDest(); 115 if (successorDest == successor) 116 return failure(); 117 118 // Update the operands to the successor. If the branch parent has no 119 // arguments, we can use the branch operands directly. 120 OperandRange operands = successorBranch.getOperands(); 121 if (successor->args_empty()) { 122 successor = successorDest; 123 successorOperands = operands; 124 return success(); 125 } 126 127 // Otherwise, we need to remap any argument operands. 128 for (Value operand : operands) { 129 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand); 130 if (argOperand && argOperand.getOwner() == successor) 131 argStorage.push_back(successorOperands[argOperand.getArgNumber()]); 132 else 133 argStorage.push_back(operand); 134 } 135 successor = successorDest; 136 successorOperands = argStorage; 137 return success(); 138 } 139 140 /// Simplify a branch to a block that has a single predecessor. This effectively 141 /// merges the two blocks. 142 static LogicalResult 143 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) { 144 // Check that the successor block has a single predecessor. 145 Block *succ = op.getDest(); 146 Block *opParent = op->getBlock(); 147 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) 148 return failure(); 149 150 // Merge the successor into the current block and erase the branch. 151 SmallVector<Value> brOperands(op.getOperands()); 152 rewriter.eraseOp(op); 153 rewriter.mergeBlocks(succ, opParent, brOperands); 154 return success(); 155 } 156 157 /// br ^bb1 158 /// ^bb1 159 /// br ^bbN(...) 160 /// 161 /// -> br ^bbN(...) 162 /// 163 static LogicalResult simplifyPassThroughBr(BranchOp op, 164 PatternRewriter &rewriter) { 165 Block *dest = op.getDest(); 166 ValueRange destOperands = op.getOperands(); 167 SmallVector<Value, 4> destOperandStorage; 168 169 // Try to collapse the successor if it points somewhere other than this 170 // block. 171 if (dest == op->getBlock() || 172 failed(collapseBranch(dest, destOperands, destOperandStorage))) 173 return failure(); 174 175 // Create a new branch with the collapsed successor. 176 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands); 177 return success(); 178 } 179 180 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) { 181 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) || 182 succeeded(simplifyPassThroughBr(op, rewriter))); 183 } 184 185 void BranchOp::setDest(Block *block) { return setSuccessor(block); } 186 187 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } 188 189 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { 190 assert(index == 0 && "invalid successor index"); 191 return SuccessorOperands(getDestOperandsMutable()); 192 } 193 194 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { 195 return getDest(); 196 } 197 198 //===----------------------------------------------------------------------===// 199 // CondBranchOp 200 //===----------------------------------------------------------------------===// 201 202 namespace { 203 /// cf.cond_br true, ^bb1, ^bb2 204 /// -> br ^bb1 205 /// cf.cond_br false, ^bb1, ^bb2 206 /// -> br ^bb2 207 /// 208 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { 209 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 210 211 LogicalResult matchAndRewrite(CondBranchOp condbr, 212 PatternRewriter &rewriter) const override { 213 if (matchPattern(condbr.getCondition(), m_NonZero())) { 214 // True branch taken. 215 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 216 condbr.getTrueOperands()); 217 return success(); 218 } 219 if (matchPattern(condbr.getCondition(), m_Zero())) { 220 // False branch taken. 221 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 222 condbr.getFalseOperands()); 223 return success(); 224 } 225 return failure(); 226 } 227 }; 228 229 /// cf.cond_br %cond, ^bb1, ^bb2 230 /// ^bb1 231 /// br ^bbN(...) 232 /// ^bb2 233 /// br ^bbK(...) 234 /// 235 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...) 236 /// 237 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> { 238 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 239 240 LogicalResult matchAndRewrite(CondBranchOp condbr, 241 PatternRewriter &rewriter) const override { 242 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest(); 243 ValueRange trueDestOperands = condbr.getTrueOperands(); 244 ValueRange falseDestOperands = condbr.getFalseOperands(); 245 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage; 246 247 // Try to collapse one of the current successors. 248 LogicalResult collapsedTrue = 249 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); 250 LogicalResult collapsedFalse = 251 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); 252 if (failed(collapsedTrue) && failed(collapsedFalse)) 253 return failure(); 254 255 // Create a new branch with the collapsed successors. 256 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(), 257 trueDest, trueDestOperands, 258 falseDest, falseDestOperands); 259 return success(); 260 } 261 }; 262 263 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) 264 /// -> br ^bb1(A, ..., N) 265 /// 266 /// cf.cond_br %cond, ^bb1(A), ^bb1(B) 267 /// -> %select = arith.select %cond, A, B 268 /// br ^bb1(%select) 269 /// 270 struct SimplifyCondBranchIdenticalSuccessors 271 : public OpRewritePattern<CondBranchOp> { 272 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 273 274 LogicalResult matchAndRewrite(CondBranchOp condbr, 275 PatternRewriter &rewriter) const override { 276 // Check that the true and false destinations are the same and have the same 277 // operands. 278 Block *trueDest = condbr.getTrueDest(); 279 if (trueDest != condbr.getFalseDest()) 280 return failure(); 281 282 // If all of the operands match, no selects need to be generated. 283 OperandRange trueOperands = condbr.getTrueOperands(); 284 OperandRange falseOperands = condbr.getFalseOperands(); 285 if (trueOperands == falseOperands) { 286 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands); 287 return success(); 288 } 289 290 // Otherwise, if the current block is the only predecessor insert selects 291 // for any mismatched branch operands. 292 if (trueDest->getUniquePredecessor() != condbr->getBlock()) 293 return failure(); 294 295 // Generate a select for any operands that differ between the two. 296 SmallVector<Value, 8> mergedOperands; 297 mergedOperands.reserve(trueOperands.size()); 298 Value condition = condbr.getCondition(); 299 for (auto it : llvm::zip(trueOperands, falseOperands)) { 300 if (std::get<0>(it) == std::get<1>(it)) 301 mergedOperands.push_back(std::get<0>(it)); 302 else 303 mergedOperands.push_back(rewriter.create<arith::SelectOp>( 304 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); 305 } 306 307 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands); 308 return success(); 309 } 310 }; 311 312 /// ... 313 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 314 /// ... 315 /// ^bb1: // has single predecessor 316 /// ... 317 /// cf.cond_br %cond, ^bb3(...), ^bb4(...) 318 /// 319 /// -> 320 /// 321 /// ... 322 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 323 /// ... 324 /// ^bb1: // has single predecessor 325 /// ... 326 /// br ^bb3(...) 327 /// 328 struct SimplifyCondBranchFromCondBranchOnSameCondition 329 : public OpRewritePattern<CondBranchOp> { 330 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 331 332 LogicalResult matchAndRewrite(CondBranchOp condbr, 333 PatternRewriter &rewriter) const override { 334 // Check that we have a single distinct predecessor. 335 Block *currentBlock = condbr->getBlock(); 336 Block *predecessor = currentBlock->getSinglePredecessor(); 337 if (!predecessor) 338 return failure(); 339 340 // Check that the predecessor terminates with a conditional branch to this 341 // block and that it branches on the same condition. 342 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator()); 343 if (!predBranch || condbr.getCondition() != predBranch.getCondition()) 344 return failure(); 345 346 // Fold this branch to an unconditional branch. 347 if (currentBlock == predBranch.getTrueDest()) 348 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 349 condbr.getTrueDestOperands()); 350 else 351 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 352 condbr.getFalseDestOperands()); 353 return success(); 354 } 355 }; 356 357 /// cf.cond_br %arg0, ^trueB, ^falseB 358 /// 359 /// ^trueB: 360 /// "test.consumer1"(%arg0) : (i1) -> () 361 /// ... 362 /// 363 /// ^falseB: 364 /// "test.consumer2"(%arg0) : (i1) -> () 365 /// ... 366 /// 367 /// -> 368 /// 369 /// cf.cond_br %arg0, ^trueB, ^falseB 370 /// ^trueB: 371 /// "test.consumer1"(%true) : (i1) -> () 372 /// ... 373 /// 374 /// ^falseB: 375 /// "test.consumer2"(%false) : (i1) -> () 376 /// ... 377 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { 378 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 379 380 LogicalResult matchAndRewrite(CondBranchOp condbr, 381 PatternRewriter &rewriter) const override { 382 // Check that we have a single distinct predecessor. 383 bool replaced = false; 384 Type ty = rewriter.getI1Type(); 385 386 // These variables serve to prevent creating duplicate constants 387 // and hold constant true or false values. 388 Value constantTrue = nullptr; 389 Value constantFalse = nullptr; 390 391 // TODO These checks can be expanded to encompas any use with only 392 // either the true of false edge as a predecessor. For now, we fall 393 // back to checking the single predecessor is given by the true/fasle 394 // destination, thereby ensuring that only that edge can reach the 395 // op. 396 if (condbr.getTrueDest()->getSinglePredecessor()) { 397 for (OpOperand &use : 398 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 399 if (use.getOwner()->getBlock() == condbr.getTrueDest()) { 400 replaced = true; 401 402 if (!constantTrue) 403 constantTrue = rewriter.create<arith::ConstantOp>( 404 condbr.getLoc(), ty, rewriter.getBoolAttr(true)); 405 406 rewriter.updateRootInPlace(use.getOwner(), 407 [&] { use.set(constantTrue); }); 408 } 409 } 410 } 411 if (condbr.getFalseDest()->getSinglePredecessor()) { 412 for (OpOperand &use : 413 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 414 if (use.getOwner()->getBlock() == condbr.getFalseDest()) { 415 replaced = true; 416 417 if (!constantFalse) 418 constantFalse = rewriter.create<arith::ConstantOp>( 419 condbr.getLoc(), ty, rewriter.getBoolAttr(false)); 420 421 rewriter.updateRootInPlace(use.getOwner(), 422 [&] { use.set(constantFalse); }); 423 } 424 } 425 } 426 return success(replaced); 427 } 428 }; 429 } // namespace 430 431 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, 432 MLIRContext *context) { 433 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, 434 SimplifyCondBranchIdenticalSuccessors, 435 SimplifyCondBranchFromCondBranchOnSameCondition, 436 CondBranchTruthPropagation>(context); 437 } 438 439 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { 440 assert(index < getNumSuccessors() && "invalid successor index"); 441 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable() 442 : getFalseDestOperandsMutable()); 443 } 444 445 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 446 if (IntegerAttr condAttr = 447 llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) 448 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest(); 449 return nullptr; 450 } 451 452 //===----------------------------------------------------------------------===// 453 // SwitchOp 454 //===----------------------------------------------------------------------===// 455 456 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 457 Block *defaultDestination, ValueRange defaultOperands, 458 DenseIntElementsAttr caseValues, 459 BlockRange caseDestinations, 460 ArrayRef<ValueRange> caseOperands) { 461 build(builder, result, value, defaultOperands, caseOperands, caseValues, 462 defaultDestination, caseDestinations); 463 } 464 465 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 466 Block *defaultDestination, ValueRange defaultOperands, 467 ArrayRef<APInt> caseValues, BlockRange caseDestinations, 468 ArrayRef<ValueRange> caseOperands) { 469 DenseIntElementsAttr caseValuesAttr; 470 if (!caseValues.empty()) { 471 ShapedType caseValueType = VectorType::get( 472 static_cast<int64_t>(caseValues.size()), value.getType()); 473 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 474 } 475 build(builder, result, value, defaultDestination, defaultOperands, 476 caseValuesAttr, caseDestinations, caseOperands); 477 } 478 479 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 480 Block *defaultDestination, ValueRange defaultOperands, 481 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 482 ArrayRef<ValueRange> caseOperands) { 483 DenseIntElementsAttr caseValuesAttr; 484 if (!caseValues.empty()) { 485 ShapedType caseValueType = VectorType::get( 486 static_cast<int64_t>(caseValues.size()), value.getType()); 487 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 488 } 489 build(builder, result, value, defaultDestination, defaultOperands, 490 caseValuesAttr, caseDestinations, caseOperands); 491 } 492 493 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? 494 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* 495 static ParseResult parseSwitchOpCases( 496 OpAsmParser &parser, Type &flagType, Block *&defaultDestination, 497 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, 498 SmallVectorImpl<Type> &defaultOperandTypes, 499 DenseIntElementsAttr &caseValues, 500 SmallVectorImpl<Block *> &caseDestinations, 501 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 502 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 503 if (parser.parseKeyword("default") || parser.parseColon() || 504 parser.parseSuccessor(defaultDestination)) 505 return failure(); 506 if (succeeded(parser.parseOptionalLParen())) { 507 if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, 508 /*allowResultNumber=*/false) || 509 parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) 510 return failure(); 511 } 512 513 SmallVector<APInt> values; 514 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 515 while (succeeded(parser.parseOptionalComma())) { 516 int64_t value = 0; 517 if (failed(parser.parseInteger(value))) 518 return failure(); 519 values.push_back(APInt(bitWidth, value)); 520 521 Block *destination; 522 SmallVector<OpAsmParser::UnresolvedOperand> operands; 523 SmallVector<Type> operandTypes; 524 if (failed(parser.parseColon()) || 525 failed(parser.parseSuccessor(destination))) 526 return failure(); 527 if (succeeded(parser.parseOptionalLParen())) { 528 if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None, 529 /*allowResultNumber=*/false)) || 530 failed(parser.parseColonTypeList(operandTypes)) || 531 failed(parser.parseRParen())) 532 return failure(); 533 } 534 caseDestinations.push_back(destination); 535 caseOperands.emplace_back(operands); 536 caseOperandTypes.emplace_back(operandTypes); 537 } 538 539 if (!values.empty()) { 540 ShapedType caseValueType = 541 VectorType::get(static_cast<int64_t>(values.size()), flagType); 542 caseValues = DenseIntElementsAttr::get(caseValueType, values); 543 } 544 return success(); 545 } 546 547 static void printSwitchOpCases( 548 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, 549 OperandRange defaultOperands, TypeRange defaultOperandTypes, 550 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, 551 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { 552 p << " default: "; 553 p.printSuccessorAndUseList(defaultDestination, defaultOperands); 554 555 if (!caseValues) 556 return; 557 558 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) { 559 p << ','; 560 p.printNewline(); 561 p << " "; 562 p << it.value().getLimitedValue(); 563 p << ": "; 564 p.printSuccessorAndUseList(caseDestinations[it.index()], 565 caseOperands[it.index()]); 566 } 567 p.printNewline(); 568 } 569 570 LogicalResult SwitchOp::verify() { 571 auto caseValues = getCaseValues(); 572 auto caseDestinations = getCaseDestinations(); 573 574 if (!caseValues && caseDestinations.empty()) 575 return success(); 576 577 Type flagType = getFlag().getType(); 578 Type caseValueType = caseValues->getType().getElementType(); 579 if (caseValueType != flagType) 580 return emitOpError() << "'flag' type (" << flagType 581 << ") should match case value type (" << caseValueType 582 << ")"; 583 584 if (caseValues && 585 caseValues->size() != static_cast<int64_t>(caseDestinations.size())) 586 return emitOpError() << "number of case values (" << caseValues->size() 587 << ") should match number of " 588 "case destinations (" 589 << caseDestinations.size() << ")"; 590 return success(); 591 } 592 593 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 594 assert(index < getNumSuccessors() && "invalid successor index"); 595 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 596 : getCaseOperandsMutable(index - 1)); 597 } 598 599 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 600 std::optional<DenseIntElementsAttr> caseValues = getCaseValues(); 601 602 if (!caseValues) 603 return getDefaultDestination(); 604 605 SuccessorRange caseDests = getCaseDestinations(); 606 if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) { 607 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) 608 if (it.value() == value.getValue()) 609 return caseDests[it.index()]; 610 return getDefaultDestination(); 611 } 612 return nullptr; 613 } 614 615 /// switch %flag : i32, [ 616 /// default: ^bb1 617 /// ] 618 /// -> br ^bb1 619 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, 620 PatternRewriter &rewriter) { 621 if (!op.getCaseDestinations().empty()) 622 return failure(); 623 624 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 625 op.getDefaultOperands()); 626 return success(); 627 } 628 629 /// switch %flag : i32, [ 630 /// default: ^bb1, 631 /// 42: ^bb1, 632 /// 43: ^bb2 633 /// ] 634 /// -> 635 /// switch %flag : i32, [ 636 /// default: ^bb1, 637 /// 43: ^bb2 638 /// ] 639 static LogicalResult 640 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { 641 SmallVector<Block *> newCaseDestinations; 642 SmallVector<ValueRange> newCaseOperands; 643 SmallVector<APInt> newCaseValues; 644 bool requiresChange = false; 645 auto caseValues = op.getCaseValues(); 646 auto caseDests = op.getCaseDestinations(); 647 648 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 649 if (caseDests[it.index()] == op.getDefaultDestination() && 650 op.getCaseOperands(it.index()) == op.getDefaultOperands()) { 651 requiresChange = true; 652 continue; 653 } 654 newCaseDestinations.push_back(caseDests[it.index()]); 655 newCaseOperands.push_back(op.getCaseOperands(it.index())); 656 newCaseValues.push_back(it.value()); 657 } 658 659 if (!requiresChange) 660 return failure(); 661 662 rewriter.replaceOpWithNewOp<SwitchOp>( 663 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 664 newCaseValues, newCaseDestinations, newCaseOperands); 665 return success(); 666 } 667 668 /// Helper for folding a switch with a constant value. 669 /// switch %c_42 : i32, [ 670 /// default: ^bb1 , 671 /// 42: ^bb2, 672 /// 43: ^bb3 673 /// ] 674 /// -> br ^bb2 675 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, 676 const APInt &caseValue) { 677 auto caseValues = op.getCaseValues(); 678 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 679 if (it.value() == caseValue) { 680 rewriter.replaceOpWithNewOp<BranchOp>( 681 op, op.getCaseDestinations()[it.index()], 682 op.getCaseOperands(it.index())); 683 return; 684 } 685 } 686 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 687 op.getDefaultOperands()); 688 } 689 690 /// switch %c_42 : i32, [ 691 /// default: ^bb1, 692 /// 42: ^bb2, 693 /// 43: ^bb3 694 /// ] 695 /// -> br ^bb2 696 static LogicalResult simplifyConstSwitchValue(SwitchOp op, 697 PatternRewriter &rewriter) { 698 APInt caseValue; 699 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) 700 return failure(); 701 702 foldSwitch(op, rewriter, caseValue); 703 return success(); 704 } 705 706 /// switch %c_42 : i32, [ 707 /// default: ^bb1, 708 /// 42: ^bb2, 709 /// ] 710 /// ^bb2: 711 /// br ^bb3 712 /// -> 713 /// switch %c_42 : i32, [ 714 /// default: ^bb1, 715 /// 42: ^bb3, 716 /// ] 717 static LogicalResult simplifyPassThroughSwitch(SwitchOp op, 718 PatternRewriter &rewriter) { 719 SmallVector<Block *> newCaseDests; 720 SmallVector<ValueRange> newCaseOperands; 721 SmallVector<SmallVector<Value>> argStorage; 722 auto caseValues = op.getCaseValues(); 723 argStorage.reserve(caseValues->size() + 1); 724 auto caseDests = op.getCaseDestinations(); 725 bool requiresChange = false; 726 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { 727 Block *caseDest = caseDests[i]; 728 ValueRange caseOperands = op.getCaseOperands(i); 729 argStorage.emplace_back(); 730 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back()))) 731 requiresChange = true; 732 733 newCaseDests.push_back(caseDest); 734 newCaseOperands.push_back(caseOperands); 735 } 736 737 Block *defaultDest = op.getDefaultDestination(); 738 ValueRange defaultOperands = op.getDefaultOperands(); 739 argStorage.emplace_back(); 740 741 if (succeeded( 742 collapseBranch(defaultDest, defaultOperands, argStorage.back()))) 743 requiresChange = true; 744 745 if (!requiresChange) 746 return failure(); 747 748 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest, 749 defaultOperands, *caseValues, 750 newCaseDests, newCaseOperands); 751 return success(); 752 } 753 754 /// switch %flag : i32, [ 755 /// default: ^bb1, 756 /// 42: ^bb2, 757 /// ] 758 /// ^bb2: 759 /// switch %flag : i32, [ 760 /// default: ^bb3, 761 /// 42: ^bb4 762 /// ] 763 /// -> 764 /// switch %flag : i32, [ 765 /// default: ^bb1, 766 /// 42: ^bb2, 767 /// ] 768 /// ^bb2: 769 /// br ^bb4 770 /// 771 /// and 772 /// 773 /// switch %flag : i32, [ 774 /// default: ^bb1, 775 /// 42: ^bb2, 776 /// ] 777 /// ^bb2: 778 /// switch %flag : i32, [ 779 /// default: ^bb3, 780 /// 43: ^bb4 781 /// ] 782 /// -> 783 /// switch %flag : i32, [ 784 /// default: ^bb1, 785 /// 42: ^bb2, 786 /// ] 787 /// ^bb2: 788 /// br ^bb3 789 static LogicalResult 790 simplifySwitchFromSwitchOnSameCondition(SwitchOp op, 791 PatternRewriter &rewriter) { 792 // Check that we have a single distinct predecessor. 793 Block *currentBlock = op->getBlock(); 794 Block *predecessor = currentBlock->getSinglePredecessor(); 795 if (!predecessor) 796 return failure(); 797 798 // Check that the predecessor terminates with a switch branch to this block 799 // and that it branches on the same condition and that this branch isn't the 800 // default destination. 801 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 802 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 803 predSwitch.getDefaultDestination() == currentBlock) 804 return failure(); 805 806 // Fold this switch to an unconditional branch. 807 SuccessorRange predDests = predSwitch.getCaseDestinations(); 808 auto it = llvm::find(predDests, currentBlock); 809 if (it != predDests.end()) { 810 std::optional<DenseIntElementsAttr> predCaseValues = 811 predSwitch.getCaseValues(); 812 foldSwitch(op, rewriter, 813 predCaseValues->getValues<APInt>()[it - predDests.begin()]); 814 } else { 815 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 816 op.getDefaultOperands()); 817 } 818 return success(); 819 } 820 821 /// switch %flag : i32, [ 822 /// default: ^bb1, 823 /// 42: ^bb2 824 /// ] 825 /// ^bb1: 826 /// switch %flag : i32, [ 827 /// default: ^bb3, 828 /// 42: ^bb4, 829 /// 43: ^bb5 830 /// ] 831 /// -> 832 /// switch %flag : i32, [ 833 /// default: ^bb1, 834 /// 42: ^bb2, 835 /// ] 836 /// ^bb1: 837 /// switch %flag : i32, [ 838 /// default: ^bb3, 839 /// 43: ^bb5 840 /// ] 841 static LogicalResult 842 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, 843 PatternRewriter &rewriter) { 844 // Check that we have a single distinct predecessor. 845 Block *currentBlock = op->getBlock(); 846 Block *predecessor = currentBlock->getSinglePredecessor(); 847 if (!predecessor) 848 return failure(); 849 850 // Check that the predecessor terminates with a switch branch to this block 851 // and that it branches on the same condition and that this branch is the 852 // default destination. 853 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 854 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 855 predSwitch.getDefaultDestination() != currentBlock) 856 return failure(); 857 858 // Delete case values that are not possible here. 859 DenseSet<APInt> caseValuesToRemove; 860 auto predDests = predSwitch.getCaseDestinations(); 861 auto predCaseValues = predSwitch.getCaseValues(); 862 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) 863 if (currentBlock != predDests[i]) 864 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]); 865 866 SmallVector<Block *> newCaseDestinations; 867 SmallVector<ValueRange> newCaseOperands; 868 SmallVector<APInt> newCaseValues; 869 bool requiresChange = false; 870 871 auto caseValues = op.getCaseValues(); 872 auto caseDests = op.getCaseDestinations(); 873 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 874 if (caseValuesToRemove.contains(it.value())) { 875 requiresChange = true; 876 continue; 877 } 878 newCaseDestinations.push_back(caseDests[it.index()]); 879 newCaseOperands.push_back(op.getCaseOperands(it.index())); 880 newCaseValues.push_back(it.value()); 881 } 882 883 if (!requiresChange) 884 return failure(); 885 886 rewriter.replaceOpWithNewOp<SwitchOp>( 887 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 888 newCaseValues, newCaseDestinations, newCaseOperands); 889 return success(); 890 } 891 892 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, 893 MLIRContext *context) { 894 results.add(&simplifySwitchWithOnlyDefault) 895 .add(&dropSwitchCasesThatMatchDefault) 896 .add(&simplifyConstSwitchValue) 897 .add(&simplifyPassThroughSwitch) 898 .add(&simplifySwitchFromSwitchOnSameCondition) 899 .add(&simplifySwitchFromDefaultSwitchOnSameCondition); 900 } 901 902 //===----------------------------------------------------------------------===// 903 // TableGen'd op method definitions 904 //===----------------------------------------------------------------------===// 905 906 #define GET_OP_CLASSES 907 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 908