1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 10 11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" 14 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/IRMapping.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/OpImplementation.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/IR/TypeUtilities.h" 25 #include "mlir/IR/Value.h" 26 #include "mlir/Transforms/InliningUtils.h" 27 #include "llvm/ADT/APFloat.h" 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/Support/FormatVariadic.h" 30 #include "llvm/Support/raw_ostream.h" 31 #include <numeric> 32 33 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc" 34 35 using namespace mlir; 36 using namespace mlir::cf; 37 38 //===----------------------------------------------------------------------===// 39 // ControlFlowDialect Interfaces 40 //===----------------------------------------------------------------------===// 41 namespace { 42 /// This class defines the interface for handling inlining with control flow 43 /// operations. 44 struct ControlFlowInlinerInterface : public DialectInlinerInterface { 45 using DialectInlinerInterface::DialectInlinerInterface; 46 ~ControlFlowInlinerInterface() override = default; 47 48 /// All control flow operations can be inlined. 49 bool isLegalToInline(Operation *call, Operation *callable, 50 bool wouldBeCloned) const final { 51 return true; 52 } 53 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 54 return true; 55 } 56 57 /// ControlFlow terminator operations don't really need any special handing. 58 void handleTerminator(Operation *op, Block *newDest) const final {} 59 }; 60 } // namespace 61 62 //===----------------------------------------------------------------------===// 63 // ControlFlowDialect 64 //===----------------------------------------------------------------------===// 65 66 void ControlFlowDialect::initialize() { 67 addOperations< 68 #define GET_OP_LIST 69 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 70 >(); 71 addInterfaces<ControlFlowInlinerInterface>(); 72 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>(); 73 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp, 74 CondBranchOp>(); 75 declarePromisedInterface<bufferization::BufferDeallocationOpInterface, 76 CondBranchOp>(); 77 } 78 79 //===----------------------------------------------------------------------===// 80 // AssertOp 81 //===----------------------------------------------------------------------===// 82 83 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { 84 // Erase assertion if argument is constant true. 85 if (matchPattern(op.getArg(), m_One())) { 86 rewriter.eraseOp(op); 87 return success(); 88 } 89 return failure(); 90 } 91 92 //===----------------------------------------------------------------------===// 93 // BranchOp 94 //===----------------------------------------------------------------------===// 95 96 /// Given a successor, try to collapse it to a new destination if it only 97 /// contains a passthrough unconditional branch. If the successor is 98 /// collapsable, `successor` and `successorOperands` are updated to reference 99 /// the new destination and values. `argStorage` is used as storage if operands 100 /// to the collapsed successor need to be remapped. It must outlive uses of 101 /// successorOperands. 102 static LogicalResult collapseBranch(Block *&successor, 103 ValueRange &successorOperands, 104 SmallVectorImpl<Value> &argStorage) { 105 // Check that the successor only contains a unconditional branch. 106 if (std::next(successor->begin()) != successor->end()) 107 return failure(); 108 // Check that the terminator is an unconditional branch. 109 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator()); 110 if (!successorBranch) 111 return failure(); 112 // Check that the arguments are only used within the terminator. 113 for (BlockArgument arg : successor->getArguments()) { 114 for (Operation *user : arg.getUsers()) 115 if (user != successorBranch) 116 return failure(); 117 } 118 // Don't try to collapse branches to infinite loops. 119 Block *successorDest = successorBranch.getDest(); 120 if (successorDest == successor) 121 return failure(); 122 123 // Update the operands to the successor. If the branch parent has no 124 // arguments, we can use the branch operands directly. 125 OperandRange operands = successorBranch.getOperands(); 126 if (successor->args_empty()) { 127 successor = successorDest; 128 successorOperands = operands; 129 return success(); 130 } 131 132 // Otherwise, we need to remap any argument operands. 133 for (Value operand : operands) { 134 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand); 135 if (argOperand && argOperand.getOwner() == successor) 136 argStorage.push_back(successorOperands[argOperand.getArgNumber()]); 137 else 138 argStorage.push_back(operand); 139 } 140 successor = successorDest; 141 successorOperands = argStorage; 142 return success(); 143 } 144 145 /// Simplify a branch to a block that has a single predecessor. This effectively 146 /// merges the two blocks. 147 static LogicalResult 148 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) { 149 // Check that the successor block has a single predecessor. 150 Block *succ = op.getDest(); 151 Block *opParent = op->getBlock(); 152 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) 153 return failure(); 154 155 // Merge the successor into the current block and erase the branch. 156 SmallVector<Value> brOperands(op.getOperands()); 157 rewriter.eraseOp(op); 158 rewriter.mergeBlocks(succ, opParent, brOperands); 159 return success(); 160 } 161 162 /// br ^bb1 163 /// ^bb1 164 /// br ^bbN(...) 165 /// 166 /// -> br ^bbN(...) 167 /// 168 static LogicalResult simplifyPassThroughBr(BranchOp op, 169 PatternRewriter &rewriter) { 170 Block *dest = op.getDest(); 171 ValueRange destOperands = op.getOperands(); 172 SmallVector<Value, 4> destOperandStorage; 173 174 // Try to collapse the successor if it points somewhere other than this 175 // block. 176 if (dest == op->getBlock() || 177 failed(collapseBranch(dest, destOperands, destOperandStorage))) 178 return failure(); 179 180 // Create a new branch with the collapsed successor. 181 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands); 182 return success(); 183 } 184 185 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) { 186 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) || 187 succeeded(simplifyPassThroughBr(op, rewriter))); 188 } 189 190 void BranchOp::setDest(Block *block) { return setSuccessor(block); } 191 192 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } 193 194 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { 195 assert(index == 0 && "invalid successor index"); 196 return SuccessorOperands(getDestOperandsMutable()); 197 } 198 199 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { 200 return getDest(); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // CondBranchOp 205 //===----------------------------------------------------------------------===// 206 207 namespace { 208 /// cf.cond_br true, ^bb1, ^bb2 209 /// -> br ^bb1 210 /// cf.cond_br false, ^bb1, ^bb2 211 /// -> br ^bb2 212 /// 213 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { 214 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 215 216 LogicalResult matchAndRewrite(CondBranchOp condbr, 217 PatternRewriter &rewriter) const override { 218 if (matchPattern(condbr.getCondition(), m_NonZero())) { 219 // True branch taken. 220 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 221 condbr.getTrueOperands()); 222 return success(); 223 } 224 if (matchPattern(condbr.getCondition(), m_Zero())) { 225 // False branch taken. 226 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 227 condbr.getFalseOperands()); 228 return success(); 229 } 230 return failure(); 231 } 232 }; 233 234 /// cf.cond_br %cond, ^bb1, ^bb2 235 /// ^bb1 236 /// br ^bbN(...) 237 /// ^bb2 238 /// br ^bbK(...) 239 /// 240 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...) 241 /// 242 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> { 243 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 244 245 LogicalResult matchAndRewrite(CondBranchOp condbr, 246 PatternRewriter &rewriter) const override { 247 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest(); 248 ValueRange trueDestOperands = condbr.getTrueOperands(); 249 ValueRange falseDestOperands = condbr.getFalseOperands(); 250 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage; 251 252 // Try to collapse one of the current successors. 253 LogicalResult collapsedTrue = 254 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); 255 LogicalResult collapsedFalse = 256 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); 257 if (failed(collapsedTrue) && failed(collapsedFalse)) 258 return failure(); 259 260 // Create a new branch with the collapsed successors. 261 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(), 262 trueDest, trueDestOperands, 263 falseDest, falseDestOperands); 264 return success(); 265 } 266 }; 267 268 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) 269 /// -> br ^bb1(A, ..., N) 270 /// 271 /// cf.cond_br %cond, ^bb1(A), ^bb1(B) 272 /// -> %select = arith.select %cond, A, B 273 /// br ^bb1(%select) 274 /// 275 struct SimplifyCondBranchIdenticalSuccessors 276 : public OpRewritePattern<CondBranchOp> { 277 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 278 279 LogicalResult matchAndRewrite(CondBranchOp condbr, 280 PatternRewriter &rewriter) const override { 281 // Check that the true and false destinations are the same and have the same 282 // operands. 283 Block *trueDest = condbr.getTrueDest(); 284 if (trueDest != condbr.getFalseDest()) 285 return failure(); 286 287 // If all of the operands match, no selects need to be generated. 288 OperandRange trueOperands = condbr.getTrueOperands(); 289 OperandRange falseOperands = condbr.getFalseOperands(); 290 if (trueOperands == falseOperands) { 291 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands); 292 return success(); 293 } 294 295 // Otherwise, if the current block is the only predecessor insert selects 296 // for any mismatched branch operands. 297 if (trueDest->getUniquePredecessor() != condbr->getBlock()) 298 return failure(); 299 300 // Generate a select for any operands that differ between the two. 301 SmallVector<Value, 8> mergedOperands; 302 mergedOperands.reserve(trueOperands.size()); 303 Value condition = condbr.getCondition(); 304 for (auto it : llvm::zip(trueOperands, falseOperands)) { 305 if (std::get<0>(it) == std::get<1>(it)) 306 mergedOperands.push_back(std::get<0>(it)); 307 else 308 mergedOperands.push_back(rewriter.create<arith::SelectOp>( 309 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); 310 } 311 312 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands); 313 return success(); 314 } 315 }; 316 317 /// ... 318 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 319 /// ... 320 /// ^bb1: // has single predecessor 321 /// ... 322 /// cf.cond_br %cond, ^bb3(...), ^bb4(...) 323 /// 324 /// -> 325 /// 326 /// ... 327 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 328 /// ... 329 /// ^bb1: // has single predecessor 330 /// ... 331 /// br ^bb3(...) 332 /// 333 struct SimplifyCondBranchFromCondBranchOnSameCondition 334 : public OpRewritePattern<CondBranchOp> { 335 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 336 337 LogicalResult matchAndRewrite(CondBranchOp condbr, 338 PatternRewriter &rewriter) const override { 339 // Check that we have a single distinct predecessor. 340 Block *currentBlock = condbr->getBlock(); 341 Block *predecessor = currentBlock->getSinglePredecessor(); 342 if (!predecessor) 343 return failure(); 344 345 // Check that the predecessor terminates with a conditional branch to this 346 // block and that it branches on the same condition. 347 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator()); 348 if (!predBranch || condbr.getCondition() != predBranch.getCondition()) 349 return failure(); 350 351 // Fold this branch to an unconditional branch. 352 if (currentBlock == predBranch.getTrueDest()) 353 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 354 condbr.getTrueDestOperands()); 355 else 356 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 357 condbr.getFalseDestOperands()); 358 return success(); 359 } 360 }; 361 362 /// cf.cond_br %arg0, ^trueB, ^falseB 363 /// 364 /// ^trueB: 365 /// "test.consumer1"(%arg0) : (i1) -> () 366 /// ... 367 /// 368 /// ^falseB: 369 /// "test.consumer2"(%arg0) : (i1) -> () 370 /// ... 371 /// 372 /// -> 373 /// 374 /// cf.cond_br %arg0, ^trueB, ^falseB 375 /// ^trueB: 376 /// "test.consumer1"(%true) : (i1) -> () 377 /// ... 378 /// 379 /// ^falseB: 380 /// "test.consumer2"(%false) : (i1) -> () 381 /// ... 382 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { 383 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 384 385 LogicalResult matchAndRewrite(CondBranchOp condbr, 386 PatternRewriter &rewriter) const override { 387 // Check that we have a single distinct predecessor. 388 bool replaced = false; 389 Type ty = rewriter.getI1Type(); 390 391 // These variables serve to prevent creating duplicate constants 392 // and hold constant true or false values. 393 Value constantTrue = nullptr; 394 Value constantFalse = nullptr; 395 396 // TODO These checks can be expanded to encompas any use with only 397 // either the true of false edge as a predecessor. For now, we fall 398 // back to checking the single predecessor is given by the true/fasle 399 // destination, thereby ensuring that only that edge can reach the 400 // op. 401 if (condbr.getTrueDest()->getSinglePredecessor()) { 402 for (OpOperand &use : 403 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 404 if (use.getOwner()->getBlock() == condbr.getTrueDest()) { 405 replaced = true; 406 407 if (!constantTrue) 408 constantTrue = rewriter.create<arith::ConstantOp>( 409 condbr.getLoc(), ty, rewriter.getBoolAttr(true)); 410 411 rewriter.modifyOpInPlace(use.getOwner(), 412 [&] { use.set(constantTrue); }); 413 } 414 } 415 } 416 if (condbr.getFalseDest()->getSinglePredecessor()) { 417 for (OpOperand &use : 418 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 419 if (use.getOwner()->getBlock() == condbr.getFalseDest()) { 420 replaced = true; 421 422 if (!constantFalse) 423 constantFalse = rewriter.create<arith::ConstantOp>( 424 condbr.getLoc(), ty, rewriter.getBoolAttr(false)); 425 426 rewriter.modifyOpInPlace(use.getOwner(), 427 [&] { use.set(constantFalse); }); 428 } 429 } 430 } 431 return success(replaced); 432 } 433 }; 434 } // namespace 435 436 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, 437 MLIRContext *context) { 438 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, 439 SimplifyCondBranchIdenticalSuccessors, 440 SimplifyCondBranchFromCondBranchOnSameCondition, 441 CondBranchTruthPropagation>(context); 442 } 443 444 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { 445 assert(index < getNumSuccessors() && "invalid successor index"); 446 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable() 447 : getFalseDestOperandsMutable()); 448 } 449 450 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 451 if (IntegerAttr condAttr = 452 llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) 453 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest(); 454 return nullptr; 455 } 456 457 //===----------------------------------------------------------------------===// 458 // SwitchOp 459 //===----------------------------------------------------------------------===// 460 461 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 462 Block *defaultDestination, ValueRange defaultOperands, 463 DenseIntElementsAttr caseValues, 464 BlockRange caseDestinations, 465 ArrayRef<ValueRange> caseOperands) { 466 build(builder, result, value, defaultOperands, caseOperands, caseValues, 467 defaultDestination, caseDestinations); 468 } 469 470 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 471 Block *defaultDestination, ValueRange defaultOperands, 472 ArrayRef<APInt> caseValues, BlockRange caseDestinations, 473 ArrayRef<ValueRange> caseOperands) { 474 DenseIntElementsAttr caseValuesAttr; 475 if (!caseValues.empty()) { 476 ShapedType caseValueType = VectorType::get( 477 static_cast<int64_t>(caseValues.size()), value.getType()); 478 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 479 } 480 build(builder, result, value, defaultDestination, defaultOperands, 481 caseValuesAttr, caseDestinations, caseOperands); 482 } 483 484 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 485 Block *defaultDestination, ValueRange defaultOperands, 486 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 487 ArrayRef<ValueRange> caseOperands) { 488 DenseIntElementsAttr caseValuesAttr; 489 if (!caseValues.empty()) { 490 ShapedType caseValueType = VectorType::get( 491 static_cast<int64_t>(caseValues.size()), value.getType()); 492 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 493 } 494 build(builder, result, value, defaultDestination, defaultOperands, 495 caseValuesAttr, caseDestinations, caseOperands); 496 } 497 498 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? 499 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* 500 static ParseResult parseSwitchOpCases( 501 OpAsmParser &parser, Type &flagType, Block *&defaultDestination, 502 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, 503 SmallVectorImpl<Type> &defaultOperandTypes, 504 DenseIntElementsAttr &caseValues, 505 SmallVectorImpl<Block *> &caseDestinations, 506 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 507 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 508 if (parser.parseKeyword("default") || parser.parseColon() || 509 parser.parseSuccessor(defaultDestination)) 510 return failure(); 511 if (succeeded(parser.parseOptionalLParen())) { 512 if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, 513 /*allowResultNumber=*/false) || 514 parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) 515 return failure(); 516 } 517 518 SmallVector<APInt> values; 519 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 520 while (succeeded(parser.parseOptionalComma())) { 521 int64_t value = 0; 522 if (failed(parser.parseInteger(value))) 523 return failure(); 524 values.push_back(APInt(bitWidth, value)); 525 526 Block *destination; 527 SmallVector<OpAsmParser::UnresolvedOperand> operands; 528 SmallVector<Type> operandTypes; 529 if (failed(parser.parseColon()) || 530 failed(parser.parseSuccessor(destination))) 531 return failure(); 532 if (succeeded(parser.parseOptionalLParen())) { 533 if (failed(parser.parseOperandList(operands, 534 OpAsmParser::Delimiter::None)) || 535 failed(parser.parseColonTypeList(operandTypes)) || 536 failed(parser.parseRParen())) 537 return failure(); 538 } 539 caseDestinations.push_back(destination); 540 caseOperands.emplace_back(operands); 541 caseOperandTypes.emplace_back(operandTypes); 542 } 543 544 if (!values.empty()) { 545 ShapedType caseValueType = 546 VectorType::get(static_cast<int64_t>(values.size()), flagType); 547 caseValues = DenseIntElementsAttr::get(caseValueType, values); 548 } 549 return success(); 550 } 551 552 static void printSwitchOpCases( 553 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, 554 OperandRange defaultOperands, TypeRange defaultOperandTypes, 555 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, 556 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { 557 p << " default: "; 558 p.printSuccessorAndUseList(defaultDestination, defaultOperands); 559 560 if (!caseValues) 561 return; 562 563 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) { 564 p << ','; 565 p.printNewline(); 566 p << " "; 567 p << it.value().getLimitedValue(); 568 p << ": "; 569 p.printSuccessorAndUseList(caseDestinations[it.index()], 570 caseOperands[it.index()]); 571 } 572 p.printNewline(); 573 } 574 575 LogicalResult SwitchOp::verify() { 576 auto caseValues = getCaseValues(); 577 auto caseDestinations = getCaseDestinations(); 578 579 if (!caseValues && caseDestinations.empty()) 580 return success(); 581 582 Type flagType = getFlag().getType(); 583 Type caseValueType = caseValues->getType().getElementType(); 584 if (caseValueType != flagType) 585 return emitOpError() << "'flag' type (" << flagType 586 << ") should match case value type (" << caseValueType 587 << ")"; 588 589 if (caseValues && 590 caseValues->size() != static_cast<int64_t>(caseDestinations.size())) 591 return emitOpError() << "number of case values (" << caseValues->size() 592 << ") should match number of " 593 "case destinations (" 594 << caseDestinations.size() << ")"; 595 return success(); 596 } 597 598 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 599 assert(index < getNumSuccessors() && "invalid successor index"); 600 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 601 : getCaseOperandsMutable(index - 1)); 602 } 603 604 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 605 std::optional<DenseIntElementsAttr> caseValues = getCaseValues(); 606 607 if (!caseValues) 608 return getDefaultDestination(); 609 610 SuccessorRange caseDests = getCaseDestinations(); 611 if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) { 612 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) 613 if (it.value() == value.getValue()) 614 return caseDests[it.index()]; 615 return getDefaultDestination(); 616 } 617 return nullptr; 618 } 619 620 /// switch %flag : i32, [ 621 /// default: ^bb1 622 /// ] 623 /// -> br ^bb1 624 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, 625 PatternRewriter &rewriter) { 626 if (!op.getCaseDestinations().empty()) 627 return failure(); 628 629 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 630 op.getDefaultOperands()); 631 return success(); 632 } 633 634 /// switch %flag : i32, [ 635 /// default: ^bb1, 636 /// 42: ^bb1, 637 /// 43: ^bb2 638 /// ] 639 /// -> 640 /// switch %flag : i32, [ 641 /// default: ^bb1, 642 /// 43: ^bb2 643 /// ] 644 static LogicalResult 645 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { 646 SmallVector<Block *> newCaseDestinations; 647 SmallVector<ValueRange> newCaseOperands; 648 SmallVector<APInt> newCaseValues; 649 bool requiresChange = false; 650 auto caseValues = op.getCaseValues(); 651 auto caseDests = op.getCaseDestinations(); 652 653 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 654 if (caseDests[it.index()] == op.getDefaultDestination() && 655 op.getCaseOperands(it.index()) == op.getDefaultOperands()) { 656 requiresChange = true; 657 continue; 658 } 659 newCaseDestinations.push_back(caseDests[it.index()]); 660 newCaseOperands.push_back(op.getCaseOperands(it.index())); 661 newCaseValues.push_back(it.value()); 662 } 663 664 if (!requiresChange) 665 return failure(); 666 667 rewriter.replaceOpWithNewOp<SwitchOp>( 668 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 669 newCaseValues, newCaseDestinations, newCaseOperands); 670 return success(); 671 } 672 673 /// Helper for folding a switch with a constant value. 674 /// switch %c_42 : i32, [ 675 /// default: ^bb1 , 676 /// 42: ^bb2, 677 /// 43: ^bb3 678 /// ] 679 /// -> br ^bb2 680 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, 681 const APInt &caseValue) { 682 auto caseValues = op.getCaseValues(); 683 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 684 if (it.value() == caseValue) { 685 rewriter.replaceOpWithNewOp<BranchOp>( 686 op, op.getCaseDestinations()[it.index()], 687 op.getCaseOperands(it.index())); 688 return; 689 } 690 } 691 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 692 op.getDefaultOperands()); 693 } 694 695 /// switch %c_42 : i32, [ 696 /// default: ^bb1, 697 /// 42: ^bb2, 698 /// 43: ^bb3 699 /// ] 700 /// -> br ^bb2 701 static LogicalResult simplifyConstSwitchValue(SwitchOp op, 702 PatternRewriter &rewriter) { 703 APInt caseValue; 704 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) 705 return failure(); 706 707 foldSwitch(op, rewriter, caseValue); 708 return success(); 709 } 710 711 /// switch %c_42 : i32, [ 712 /// default: ^bb1, 713 /// 42: ^bb2, 714 /// ] 715 /// ^bb2: 716 /// br ^bb3 717 /// -> 718 /// switch %c_42 : i32, [ 719 /// default: ^bb1, 720 /// 42: ^bb3, 721 /// ] 722 static LogicalResult simplifyPassThroughSwitch(SwitchOp op, 723 PatternRewriter &rewriter) { 724 SmallVector<Block *> newCaseDests; 725 SmallVector<ValueRange> newCaseOperands; 726 SmallVector<SmallVector<Value>> argStorage; 727 auto caseValues = op.getCaseValues(); 728 argStorage.reserve(caseValues->size() + 1); 729 auto caseDests = op.getCaseDestinations(); 730 bool requiresChange = false; 731 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { 732 Block *caseDest = caseDests[i]; 733 ValueRange caseOperands = op.getCaseOperands(i); 734 argStorage.emplace_back(); 735 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back()))) 736 requiresChange = true; 737 738 newCaseDests.push_back(caseDest); 739 newCaseOperands.push_back(caseOperands); 740 } 741 742 Block *defaultDest = op.getDefaultDestination(); 743 ValueRange defaultOperands = op.getDefaultOperands(); 744 argStorage.emplace_back(); 745 746 if (succeeded( 747 collapseBranch(defaultDest, defaultOperands, argStorage.back()))) 748 requiresChange = true; 749 750 if (!requiresChange) 751 return failure(); 752 753 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest, 754 defaultOperands, *caseValues, 755 newCaseDests, newCaseOperands); 756 return success(); 757 } 758 759 /// switch %flag : i32, [ 760 /// default: ^bb1, 761 /// 42: ^bb2, 762 /// ] 763 /// ^bb2: 764 /// switch %flag : i32, [ 765 /// default: ^bb3, 766 /// 42: ^bb4 767 /// ] 768 /// -> 769 /// switch %flag : i32, [ 770 /// default: ^bb1, 771 /// 42: ^bb2, 772 /// ] 773 /// ^bb2: 774 /// br ^bb4 775 /// 776 /// and 777 /// 778 /// switch %flag : i32, [ 779 /// default: ^bb1, 780 /// 42: ^bb2, 781 /// ] 782 /// ^bb2: 783 /// switch %flag : i32, [ 784 /// default: ^bb3, 785 /// 43: ^bb4 786 /// ] 787 /// -> 788 /// switch %flag : i32, [ 789 /// default: ^bb1, 790 /// 42: ^bb2, 791 /// ] 792 /// ^bb2: 793 /// br ^bb3 794 static LogicalResult 795 simplifySwitchFromSwitchOnSameCondition(SwitchOp op, 796 PatternRewriter &rewriter) { 797 // Check that we have a single distinct predecessor. 798 Block *currentBlock = op->getBlock(); 799 Block *predecessor = currentBlock->getSinglePredecessor(); 800 if (!predecessor) 801 return failure(); 802 803 // Check that the predecessor terminates with a switch branch to this block 804 // and that it branches on the same condition and that this branch isn't the 805 // default destination. 806 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 807 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 808 predSwitch.getDefaultDestination() == currentBlock) 809 return failure(); 810 811 // Fold this switch to an unconditional branch. 812 SuccessorRange predDests = predSwitch.getCaseDestinations(); 813 auto it = llvm::find(predDests, currentBlock); 814 if (it != predDests.end()) { 815 std::optional<DenseIntElementsAttr> predCaseValues = 816 predSwitch.getCaseValues(); 817 foldSwitch(op, rewriter, 818 predCaseValues->getValues<APInt>()[it - predDests.begin()]); 819 } else { 820 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 821 op.getDefaultOperands()); 822 } 823 return success(); 824 } 825 826 /// switch %flag : i32, [ 827 /// default: ^bb1, 828 /// 42: ^bb2 829 /// ] 830 /// ^bb1: 831 /// switch %flag : i32, [ 832 /// default: ^bb3, 833 /// 42: ^bb4, 834 /// 43: ^bb5 835 /// ] 836 /// -> 837 /// switch %flag : i32, [ 838 /// default: ^bb1, 839 /// 42: ^bb2, 840 /// ] 841 /// ^bb1: 842 /// switch %flag : i32, [ 843 /// default: ^bb3, 844 /// 43: ^bb5 845 /// ] 846 static LogicalResult 847 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, 848 PatternRewriter &rewriter) { 849 // Check that we have a single distinct predecessor. 850 Block *currentBlock = op->getBlock(); 851 Block *predecessor = currentBlock->getSinglePredecessor(); 852 if (!predecessor) 853 return failure(); 854 855 // Check that the predecessor terminates with a switch branch to this block 856 // and that it branches on the same condition and that this branch is the 857 // default destination. 858 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 859 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 860 predSwitch.getDefaultDestination() != currentBlock) 861 return failure(); 862 863 // Delete case values that are not possible here. 864 DenseSet<APInt> caseValuesToRemove; 865 auto predDests = predSwitch.getCaseDestinations(); 866 auto predCaseValues = predSwitch.getCaseValues(); 867 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) 868 if (currentBlock != predDests[i]) 869 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]); 870 871 SmallVector<Block *> newCaseDestinations; 872 SmallVector<ValueRange> newCaseOperands; 873 SmallVector<APInt> newCaseValues; 874 bool requiresChange = false; 875 876 auto caseValues = op.getCaseValues(); 877 auto caseDests = op.getCaseDestinations(); 878 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 879 if (caseValuesToRemove.contains(it.value())) { 880 requiresChange = true; 881 continue; 882 } 883 newCaseDestinations.push_back(caseDests[it.index()]); 884 newCaseOperands.push_back(op.getCaseOperands(it.index())); 885 newCaseValues.push_back(it.value()); 886 } 887 888 if (!requiresChange) 889 return failure(); 890 891 rewriter.replaceOpWithNewOp<SwitchOp>( 892 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 893 newCaseValues, newCaseDestinations, newCaseOperands); 894 return success(); 895 } 896 897 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, 898 MLIRContext *context) { 899 results.add(&simplifySwitchWithOnlyDefault) 900 .add(&dropSwitchCasesThatMatchDefault) 901 .add(&simplifyConstSwitchValue) 902 .add(&simplifyPassThroughSwitch) 903 .add(&simplifySwitchFromSwitchOnSameCondition) 904 .add(&simplifySwitchFromDefaultSwitchOnSameCondition); 905 } 906 907 //===----------------------------------------------------------------------===// 908 // TableGen'd op method definitions 909 //===----------------------------------------------------------------------===// 910 911 #define GET_OP_CLASSES 912 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 913