1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 10 11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" 14 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/IRMapping.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/OpImplementation.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/IR/TypeUtilities.h" 25 #include "mlir/IR/Value.h" 26 #include "mlir/Transforms/InliningUtils.h" 27 #include "llvm/ADT/APFloat.h" 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/Support/FormatVariadic.h" 30 #include "llvm/Support/raw_ostream.h" 31 #include <numeric> 32 33 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc" 34 35 using namespace mlir; 36 using namespace mlir::cf; 37 38 //===----------------------------------------------------------------------===// 39 // ControlFlowDialect Interfaces 40 //===----------------------------------------------------------------------===// 41 namespace { 42 /// This class defines the interface for handling inlining with control flow 43 /// operations. 44 struct ControlFlowInlinerInterface : public DialectInlinerInterface { 45 using DialectInlinerInterface::DialectInlinerInterface; 46 ~ControlFlowInlinerInterface() override = default; 47 48 /// All control flow operations can be inlined. 49 bool isLegalToInline(Operation *call, Operation *callable, 50 bool wouldBeCloned) const final { 51 return true; 52 } 53 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 54 return true; 55 } 56 57 /// ControlFlow terminator operations don't really need any special handing. 58 void handleTerminator(Operation *op, Block *newDest) const final {} 59 }; 60 } // namespace 61 62 //===----------------------------------------------------------------------===// 63 // ControlFlowDialect 64 //===----------------------------------------------------------------------===// 65 66 void ControlFlowDialect::initialize() { 67 addOperations< 68 #define GET_OP_LIST 69 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 70 >(); 71 addInterfaces<ControlFlowInlinerInterface>(); 72 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>(); 73 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp, 74 CondBranchOp>(); 75 declarePromisedInterface<bufferization::BufferDeallocationOpInterface, 76 CondBranchOp>(); 77 } 78 79 //===----------------------------------------------------------------------===// 80 // AssertOp 81 //===----------------------------------------------------------------------===// 82 83 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { 84 // Erase assertion if argument is constant true. 85 if (matchPattern(op.getArg(), m_One())) { 86 rewriter.eraseOp(op); 87 return success(); 88 } 89 return failure(); 90 } 91 92 // This side effect models "program termination". 93 void AssertOp::getEffects( 94 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 95 &effects) { 96 effects.emplace_back(MemoryEffects::Write::get()); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // BranchOp 101 //===----------------------------------------------------------------------===// 102 103 /// Given a successor, try to collapse it to a new destination if it only 104 /// contains a passthrough unconditional branch. If the successor is 105 /// collapsable, `successor` and `successorOperands` are updated to reference 106 /// the new destination and values. `argStorage` is used as storage if operands 107 /// to the collapsed successor need to be remapped. It must outlive uses of 108 /// successorOperands. 109 static LogicalResult collapseBranch(Block *&successor, 110 ValueRange &successorOperands, 111 SmallVectorImpl<Value> &argStorage) { 112 // Check that the successor only contains a unconditional branch. 113 if (std::next(successor->begin()) != successor->end()) 114 return failure(); 115 // Check that the terminator is an unconditional branch. 116 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator()); 117 if (!successorBranch) 118 return failure(); 119 // Check that the arguments are only used within the terminator. 120 for (BlockArgument arg : successor->getArguments()) { 121 for (Operation *user : arg.getUsers()) 122 if (user != successorBranch) 123 return failure(); 124 } 125 // Don't try to collapse branches to infinite loops. 126 Block *successorDest = successorBranch.getDest(); 127 if (successorDest == successor) 128 return failure(); 129 130 // Update the operands to the successor. If the branch parent has no 131 // arguments, we can use the branch operands directly. 132 OperandRange operands = successorBranch.getOperands(); 133 if (successor->args_empty()) { 134 successor = successorDest; 135 successorOperands = operands; 136 return success(); 137 } 138 139 // Otherwise, we need to remap any argument operands. 140 for (Value operand : operands) { 141 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand); 142 if (argOperand && argOperand.getOwner() == successor) 143 argStorage.push_back(successorOperands[argOperand.getArgNumber()]); 144 else 145 argStorage.push_back(operand); 146 } 147 successor = successorDest; 148 successorOperands = argStorage; 149 return success(); 150 } 151 152 /// Simplify a branch to a block that has a single predecessor. This effectively 153 /// merges the two blocks. 154 static LogicalResult 155 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) { 156 // Check that the successor block has a single predecessor. 157 Block *succ = op.getDest(); 158 Block *opParent = op->getBlock(); 159 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) 160 return failure(); 161 162 // Merge the successor into the current block and erase the branch. 163 SmallVector<Value> brOperands(op.getOperands()); 164 rewriter.eraseOp(op); 165 rewriter.mergeBlocks(succ, opParent, brOperands); 166 return success(); 167 } 168 169 /// br ^bb1 170 /// ^bb1 171 /// br ^bbN(...) 172 /// 173 /// -> br ^bbN(...) 174 /// 175 static LogicalResult simplifyPassThroughBr(BranchOp op, 176 PatternRewriter &rewriter) { 177 Block *dest = op.getDest(); 178 ValueRange destOperands = op.getOperands(); 179 SmallVector<Value, 4> destOperandStorage; 180 181 // Try to collapse the successor if it points somewhere other than this 182 // block. 183 if (dest == op->getBlock() || 184 failed(collapseBranch(dest, destOperands, destOperandStorage))) 185 return failure(); 186 187 // Create a new branch with the collapsed successor. 188 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands); 189 return success(); 190 } 191 192 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) { 193 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) || 194 succeeded(simplifyPassThroughBr(op, rewriter))); 195 } 196 197 void BranchOp::setDest(Block *block) { return setSuccessor(block); } 198 199 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } 200 201 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { 202 assert(index == 0 && "invalid successor index"); 203 return SuccessorOperands(getDestOperandsMutable()); 204 } 205 206 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { 207 return getDest(); 208 } 209 210 //===----------------------------------------------------------------------===// 211 // CondBranchOp 212 //===----------------------------------------------------------------------===// 213 214 namespace { 215 /// cf.cond_br true, ^bb1, ^bb2 216 /// -> br ^bb1 217 /// cf.cond_br false, ^bb1, ^bb2 218 /// -> br ^bb2 219 /// 220 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { 221 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 222 223 LogicalResult matchAndRewrite(CondBranchOp condbr, 224 PatternRewriter &rewriter) const override { 225 if (matchPattern(condbr.getCondition(), m_NonZero())) { 226 // True branch taken. 227 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 228 condbr.getTrueOperands()); 229 return success(); 230 } 231 if (matchPattern(condbr.getCondition(), m_Zero())) { 232 // False branch taken. 233 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 234 condbr.getFalseOperands()); 235 return success(); 236 } 237 return failure(); 238 } 239 }; 240 241 /// cf.cond_br %cond, ^bb1, ^bb2 242 /// ^bb1 243 /// br ^bbN(...) 244 /// ^bb2 245 /// br ^bbK(...) 246 /// 247 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...) 248 /// 249 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> { 250 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 251 252 LogicalResult matchAndRewrite(CondBranchOp condbr, 253 PatternRewriter &rewriter) const override { 254 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest(); 255 ValueRange trueDestOperands = condbr.getTrueOperands(); 256 ValueRange falseDestOperands = condbr.getFalseOperands(); 257 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage; 258 259 // Try to collapse one of the current successors. 260 LogicalResult collapsedTrue = 261 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); 262 LogicalResult collapsedFalse = 263 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); 264 if (failed(collapsedTrue) && failed(collapsedFalse)) 265 return failure(); 266 267 // Create a new branch with the collapsed successors. 268 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(), 269 trueDest, trueDestOperands, 270 falseDest, falseDestOperands); 271 return success(); 272 } 273 }; 274 275 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) 276 /// -> br ^bb1(A, ..., N) 277 /// 278 /// cf.cond_br %cond, ^bb1(A), ^bb1(B) 279 /// -> %select = arith.select %cond, A, B 280 /// br ^bb1(%select) 281 /// 282 struct SimplifyCondBranchIdenticalSuccessors 283 : public OpRewritePattern<CondBranchOp> { 284 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 285 286 LogicalResult matchAndRewrite(CondBranchOp condbr, 287 PatternRewriter &rewriter) const override { 288 // Check that the true and false destinations are the same and have the same 289 // operands. 290 Block *trueDest = condbr.getTrueDest(); 291 if (trueDest != condbr.getFalseDest()) 292 return failure(); 293 294 // If all of the operands match, no selects need to be generated. 295 OperandRange trueOperands = condbr.getTrueOperands(); 296 OperandRange falseOperands = condbr.getFalseOperands(); 297 if (trueOperands == falseOperands) { 298 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands); 299 return success(); 300 } 301 302 // Otherwise, if the current block is the only predecessor insert selects 303 // for any mismatched branch operands. 304 if (trueDest->getUniquePredecessor() != condbr->getBlock()) 305 return failure(); 306 307 // Generate a select for any operands that differ between the two. 308 SmallVector<Value, 8> mergedOperands; 309 mergedOperands.reserve(trueOperands.size()); 310 Value condition = condbr.getCondition(); 311 for (auto it : llvm::zip(trueOperands, falseOperands)) { 312 if (std::get<0>(it) == std::get<1>(it)) 313 mergedOperands.push_back(std::get<0>(it)); 314 else 315 mergedOperands.push_back(rewriter.create<arith::SelectOp>( 316 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); 317 } 318 319 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands); 320 return success(); 321 } 322 }; 323 324 /// ... 325 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 326 /// ... 327 /// ^bb1: // has single predecessor 328 /// ... 329 /// cf.cond_br %cond, ^bb3(...), ^bb4(...) 330 /// 331 /// -> 332 /// 333 /// ... 334 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 335 /// ... 336 /// ^bb1: // has single predecessor 337 /// ... 338 /// br ^bb3(...) 339 /// 340 struct SimplifyCondBranchFromCondBranchOnSameCondition 341 : public OpRewritePattern<CondBranchOp> { 342 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 343 344 LogicalResult matchAndRewrite(CondBranchOp condbr, 345 PatternRewriter &rewriter) const override { 346 // Check that we have a single distinct predecessor. 347 Block *currentBlock = condbr->getBlock(); 348 Block *predecessor = currentBlock->getSinglePredecessor(); 349 if (!predecessor) 350 return failure(); 351 352 // Check that the predecessor terminates with a conditional branch to this 353 // block and that it branches on the same condition. 354 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator()); 355 if (!predBranch || condbr.getCondition() != predBranch.getCondition()) 356 return failure(); 357 358 // Fold this branch to an unconditional branch. 359 if (currentBlock == predBranch.getTrueDest()) 360 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 361 condbr.getTrueDestOperands()); 362 else 363 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 364 condbr.getFalseDestOperands()); 365 return success(); 366 } 367 }; 368 369 /// cf.cond_br %arg0, ^trueB, ^falseB 370 /// 371 /// ^trueB: 372 /// "test.consumer1"(%arg0) : (i1) -> () 373 /// ... 374 /// 375 /// ^falseB: 376 /// "test.consumer2"(%arg0) : (i1) -> () 377 /// ... 378 /// 379 /// -> 380 /// 381 /// cf.cond_br %arg0, ^trueB, ^falseB 382 /// ^trueB: 383 /// "test.consumer1"(%true) : (i1) -> () 384 /// ... 385 /// 386 /// ^falseB: 387 /// "test.consumer2"(%false) : (i1) -> () 388 /// ... 389 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { 390 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 391 392 LogicalResult matchAndRewrite(CondBranchOp condbr, 393 PatternRewriter &rewriter) const override { 394 // Check that we have a single distinct predecessor. 395 bool replaced = false; 396 Type ty = rewriter.getI1Type(); 397 398 // These variables serve to prevent creating duplicate constants 399 // and hold constant true or false values. 400 Value constantTrue = nullptr; 401 Value constantFalse = nullptr; 402 403 // TODO These checks can be expanded to encompas any use with only 404 // either the true of false edge as a predecessor. For now, we fall 405 // back to checking the single predecessor is given by the true/fasle 406 // destination, thereby ensuring that only that edge can reach the 407 // op. 408 if (condbr.getTrueDest()->getSinglePredecessor()) { 409 for (OpOperand &use : 410 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 411 if (use.getOwner()->getBlock() == condbr.getTrueDest()) { 412 replaced = true; 413 414 if (!constantTrue) 415 constantTrue = rewriter.create<arith::ConstantOp>( 416 condbr.getLoc(), ty, rewriter.getBoolAttr(true)); 417 418 rewriter.modifyOpInPlace(use.getOwner(), 419 [&] { use.set(constantTrue); }); 420 } 421 } 422 } 423 if (condbr.getFalseDest()->getSinglePredecessor()) { 424 for (OpOperand &use : 425 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 426 if (use.getOwner()->getBlock() == condbr.getFalseDest()) { 427 replaced = true; 428 429 if (!constantFalse) 430 constantFalse = rewriter.create<arith::ConstantOp>( 431 condbr.getLoc(), ty, rewriter.getBoolAttr(false)); 432 433 rewriter.modifyOpInPlace(use.getOwner(), 434 [&] { use.set(constantFalse); }); 435 } 436 } 437 } 438 return success(replaced); 439 } 440 }; 441 } // namespace 442 443 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, 444 MLIRContext *context) { 445 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, 446 SimplifyCondBranchIdenticalSuccessors, 447 SimplifyCondBranchFromCondBranchOnSameCondition, 448 CondBranchTruthPropagation>(context); 449 } 450 451 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { 452 assert(index < getNumSuccessors() && "invalid successor index"); 453 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable() 454 : getFalseDestOperandsMutable()); 455 } 456 457 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 458 if (IntegerAttr condAttr = 459 llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) 460 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest(); 461 return nullptr; 462 } 463 464 //===----------------------------------------------------------------------===// 465 // SwitchOp 466 //===----------------------------------------------------------------------===// 467 468 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 469 Block *defaultDestination, ValueRange defaultOperands, 470 DenseIntElementsAttr caseValues, 471 BlockRange caseDestinations, 472 ArrayRef<ValueRange> caseOperands) { 473 build(builder, result, value, defaultOperands, caseOperands, caseValues, 474 defaultDestination, caseDestinations); 475 } 476 477 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 478 Block *defaultDestination, ValueRange defaultOperands, 479 ArrayRef<APInt> caseValues, BlockRange caseDestinations, 480 ArrayRef<ValueRange> caseOperands) { 481 DenseIntElementsAttr caseValuesAttr; 482 if (!caseValues.empty()) { 483 ShapedType caseValueType = VectorType::get( 484 static_cast<int64_t>(caseValues.size()), value.getType()); 485 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 486 } 487 build(builder, result, value, defaultDestination, defaultOperands, 488 caseValuesAttr, caseDestinations, caseOperands); 489 } 490 491 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 492 Block *defaultDestination, ValueRange defaultOperands, 493 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 494 ArrayRef<ValueRange> caseOperands) { 495 DenseIntElementsAttr caseValuesAttr; 496 if (!caseValues.empty()) { 497 ShapedType caseValueType = VectorType::get( 498 static_cast<int64_t>(caseValues.size()), value.getType()); 499 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 500 } 501 build(builder, result, value, defaultDestination, defaultOperands, 502 caseValuesAttr, caseDestinations, caseOperands); 503 } 504 505 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? 506 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* 507 static ParseResult parseSwitchOpCases( 508 OpAsmParser &parser, Type &flagType, Block *&defaultDestination, 509 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, 510 SmallVectorImpl<Type> &defaultOperandTypes, 511 DenseIntElementsAttr &caseValues, 512 SmallVectorImpl<Block *> &caseDestinations, 513 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 514 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 515 if (parser.parseKeyword("default") || parser.parseColon() || 516 parser.parseSuccessor(defaultDestination)) 517 return failure(); 518 if (succeeded(parser.parseOptionalLParen())) { 519 if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, 520 /*allowResultNumber=*/false) || 521 parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) 522 return failure(); 523 } 524 525 SmallVector<APInt> values; 526 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 527 while (succeeded(parser.parseOptionalComma())) { 528 int64_t value = 0; 529 if (failed(parser.parseInteger(value))) 530 return failure(); 531 values.push_back(APInt(bitWidth, value, /*isSigned=*/true)); 532 533 Block *destination; 534 SmallVector<OpAsmParser::UnresolvedOperand> operands; 535 SmallVector<Type> operandTypes; 536 if (failed(parser.parseColon()) || 537 failed(parser.parseSuccessor(destination))) 538 return failure(); 539 if (succeeded(parser.parseOptionalLParen())) { 540 if (failed(parser.parseOperandList(operands, 541 OpAsmParser::Delimiter::None)) || 542 failed(parser.parseColonTypeList(operandTypes)) || 543 failed(parser.parseRParen())) 544 return failure(); 545 } 546 caseDestinations.push_back(destination); 547 caseOperands.emplace_back(operands); 548 caseOperandTypes.emplace_back(operandTypes); 549 } 550 551 if (!values.empty()) { 552 ShapedType caseValueType = 553 VectorType::get(static_cast<int64_t>(values.size()), flagType); 554 caseValues = DenseIntElementsAttr::get(caseValueType, values); 555 } 556 return success(); 557 } 558 559 static void printSwitchOpCases( 560 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, 561 OperandRange defaultOperands, TypeRange defaultOperandTypes, 562 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, 563 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { 564 p << " default: "; 565 p.printSuccessorAndUseList(defaultDestination, defaultOperands); 566 567 if (!caseValues) 568 return; 569 570 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) { 571 p << ','; 572 p.printNewline(); 573 p << " "; 574 p << it.value().getLimitedValue(); 575 p << ": "; 576 p.printSuccessorAndUseList(caseDestinations[it.index()], 577 caseOperands[it.index()]); 578 } 579 p.printNewline(); 580 } 581 582 LogicalResult SwitchOp::verify() { 583 auto caseValues = getCaseValues(); 584 auto caseDestinations = getCaseDestinations(); 585 586 if (!caseValues && caseDestinations.empty()) 587 return success(); 588 589 Type flagType = getFlag().getType(); 590 Type caseValueType = caseValues->getType().getElementType(); 591 if (caseValueType != flagType) 592 return emitOpError() << "'flag' type (" << flagType 593 << ") should match case value type (" << caseValueType 594 << ")"; 595 596 if (caseValues && 597 caseValues->size() != static_cast<int64_t>(caseDestinations.size())) 598 return emitOpError() << "number of case values (" << caseValues->size() 599 << ") should match number of " 600 "case destinations (" 601 << caseDestinations.size() << ")"; 602 return success(); 603 } 604 605 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 606 assert(index < getNumSuccessors() && "invalid successor index"); 607 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 608 : getCaseOperandsMutable(index - 1)); 609 } 610 611 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 612 std::optional<DenseIntElementsAttr> caseValues = getCaseValues(); 613 614 if (!caseValues) 615 return getDefaultDestination(); 616 617 SuccessorRange caseDests = getCaseDestinations(); 618 if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) { 619 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) 620 if (it.value() == value.getValue()) 621 return caseDests[it.index()]; 622 return getDefaultDestination(); 623 } 624 return nullptr; 625 } 626 627 /// switch %flag : i32, [ 628 /// default: ^bb1 629 /// ] 630 /// -> br ^bb1 631 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, 632 PatternRewriter &rewriter) { 633 if (!op.getCaseDestinations().empty()) 634 return failure(); 635 636 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 637 op.getDefaultOperands()); 638 return success(); 639 } 640 641 /// switch %flag : i32, [ 642 /// default: ^bb1, 643 /// 42: ^bb1, 644 /// 43: ^bb2 645 /// ] 646 /// -> 647 /// switch %flag : i32, [ 648 /// default: ^bb1, 649 /// 43: ^bb2 650 /// ] 651 static LogicalResult 652 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { 653 SmallVector<Block *> newCaseDestinations; 654 SmallVector<ValueRange> newCaseOperands; 655 SmallVector<APInt> newCaseValues; 656 bool requiresChange = false; 657 auto caseValues = op.getCaseValues(); 658 auto caseDests = op.getCaseDestinations(); 659 660 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 661 if (caseDests[it.index()] == op.getDefaultDestination() && 662 op.getCaseOperands(it.index()) == op.getDefaultOperands()) { 663 requiresChange = true; 664 continue; 665 } 666 newCaseDestinations.push_back(caseDests[it.index()]); 667 newCaseOperands.push_back(op.getCaseOperands(it.index())); 668 newCaseValues.push_back(it.value()); 669 } 670 671 if (!requiresChange) 672 return failure(); 673 674 rewriter.replaceOpWithNewOp<SwitchOp>( 675 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 676 newCaseValues, newCaseDestinations, newCaseOperands); 677 return success(); 678 } 679 680 /// Helper for folding a switch with a constant value. 681 /// switch %c_42 : i32, [ 682 /// default: ^bb1 , 683 /// 42: ^bb2, 684 /// 43: ^bb3 685 /// ] 686 /// -> br ^bb2 687 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, 688 const APInt &caseValue) { 689 auto caseValues = op.getCaseValues(); 690 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 691 if (it.value() == caseValue) { 692 rewriter.replaceOpWithNewOp<BranchOp>( 693 op, op.getCaseDestinations()[it.index()], 694 op.getCaseOperands(it.index())); 695 return; 696 } 697 } 698 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 699 op.getDefaultOperands()); 700 } 701 702 /// switch %c_42 : i32, [ 703 /// default: ^bb1, 704 /// 42: ^bb2, 705 /// 43: ^bb3 706 /// ] 707 /// -> br ^bb2 708 static LogicalResult simplifyConstSwitchValue(SwitchOp op, 709 PatternRewriter &rewriter) { 710 APInt caseValue; 711 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) 712 return failure(); 713 714 foldSwitch(op, rewriter, caseValue); 715 return success(); 716 } 717 718 /// switch %c_42 : i32, [ 719 /// default: ^bb1, 720 /// 42: ^bb2, 721 /// ] 722 /// ^bb2: 723 /// br ^bb3 724 /// -> 725 /// switch %c_42 : i32, [ 726 /// default: ^bb1, 727 /// 42: ^bb3, 728 /// ] 729 static LogicalResult simplifyPassThroughSwitch(SwitchOp op, 730 PatternRewriter &rewriter) { 731 SmallVector<Block *> newCaseDests; 732 SmallVector<ValueRange> newCaseOperands; 733 SmallVector<SmallVector<Value>> argStorage; 734 auto caseValues = op.getCaseValues(); 735 argStorage.reserve(caseValues->size() + 1); 736 auto caseDests = op.getCaseDestinations(); 737 bool requiresChange = false; 738 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { 739 Block *caseDest = caseDests[i]; 740 ValueRange caseOperands = op.getCaseOperands(i); 741 argStorage.emplace_back(); 742 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back()))) 743 requiresChange = true; 744 745 newCaseDests.push_back(caseDest); 746 newCaseOperands.push_back(caseOperands); 747 } 748 749 Block *defaultDest = op.getDefaultDestination(); 750 ValueRange defaultOperands = op.getDefaultOperands(); 751 argStorage.emplace_back(); 752 753 if (succeeded( 754 collapseBranch(defaultDest, defaultOperands, argStorage.back()))) 755 requiresChange = true; 756 757 if (!requiresChange) 758 return failure(); 759 760 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest, 761 defaultOperands, *caseValues, 762 newCaseDests, newCaseOperands); 763 return success(); 764 } 765 766 /// switch %flag : i32, [ 767 /// default: ^bb1, 768 /// 42: ^bb2, 769 /// ] 770 /// ^bb2: 771 /// switch %flag : i32, [ 772 /// default: ^bb3, 773 /// 42: ^bb4 774 /// ] 775 /// -> 776 /// switch %flag : i32, [ 777 /// default: ^bb1, 778 /// 42: ^bb2, 779 /// ] 780 /// ^bb2: 781 /// br ^bb4 782 /// 783 /// and 784 /// 785 /// switch %flag : i32, [ 786 /// default: ^bb1, 787 /// 42: ^bb2, 788 /// ] 789 /// ^bb2: 790 /// switch %flag : i32, [ 791 /// default: ^bb3, 792 /// 43: ^bb4 793 /// ] 794 /// -> 795 /// switch %flag : i32, [ 796 /// default: ^bb1, 797 /// 42: ^bb2, 798 /// ] 799 /// ^bb2: 800 /// br ^bb3 801 static LogicalResult 802 simplifySwitchFromSwitchOnSameCondition(SwitchOp op, 803 PatternRewriter &rewriter) { 804 // Check that we have a single distinct predecessor. 805 Block *currentBlock = op->getBlock(); 806 Block *predecessor = currentBlock->getSinglePredecessor(); 807 if (!predecessor) 808 return failure(); 809 810 // Check that the predecessor terminates with a switch branch to this block 811 // and that it branches on the same condition and that this branch isn't the 812 // default destination. 813 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 814 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 815 predSwitch.getDefaultDestination() == currentBlock) 816 return failure(); 817 818 // Fold this switch to an unconditional branch. 819 SuccessorRange predDests = predSwitch.getCaseDestinations(); 820 auto it = llvm::find(predDests, currentBlock); 821 if (it != predDests.end()) { 822 std::optional<DenseIntElementsAttr> predCaseValues = 823 predSwitch.getCaseValues(); 824 foldSwitch(op, rewriter, 825 predCaseValues->getValues<APInt>()[it - predDests.begin()]); 826 } else { 827 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 828 op.getDefaultOperands()); 829 } 830 return success(); 831 } 832 833 /// switch %flag : i32, [ 834 /// default: ^bb1, 835 /// 42: ^bb2 836 /// ] 837 /// ^bb1: 838 /// switch %flag : i32, [ 839 /// default: ^bb3, 840 /// 42: ^bb4, 841 /// 43: ^bb5 842 /// ] 843 /// -> 844 /// switch %flag : i32, [ 845 /// default: ^bb1, 846 /// 42: ^bb2, 847 /// ] 848 /// ^bb1: 849 /// switch %flag : i32, [ 850 /// default: ^bb3, 851 /// 43: ^bb5 852 /// ] 853 static LogicalResult 854 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, 855 PatternRewriter &rewriter) { 856 // Check that we have a single distinct predecessor. 857 Block *currentBlock = op->getBlock(); 858 Block *predecessor = currentBlock->getSinglePredecessor(); 859 if (!predecessor) 860 return failure(); 861 862 // Check that the predecessor terminates with a switch branch to this block 863 // and that it branches on the same condition and that this branch is the 864 // default destination. 865 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 866 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 867 predSwitch.getDefaultDestination() != currentBlock) 868 return failure(); 869 870 // Delete case values that are not possible here. 871 DenseSet<APInt> caseValuesToRemove; 872 auto predDests = predSwitch.getCaseDestinations(); 873 auto predCaseValues = predSwitch.getCaseValues(); 874 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) 875 if (currentBlock != predDests[i]) 876 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]); 877 878 SmallVector<Block *> newCaseDestinations; 879 SmallVector<ValueRange> newCaseOperands; 880 SmallVector<APInt> newCaseValues; 881 bool requiresChange = false; 882 883 auto caseValues = op.getCaseValues(); 884 auto caseDests = op.getCaseDestinations(); 885 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 886 if (caseValuesToRemove.contains(it.value())) { 887 requiresChange = true; 888 continue; 889 } 890 newCaseDestinations.push_back(caseDests[it.index()]); 891 newCaseOperands.push_back(op.getCaseOperands(it.index())); 892 newCaseValues.push_back(it.value()); 893 } 894 895 if (!requiresChange) 896 return failure(); 897 898 rewriter.replaceOpWithNewOp<SwitchOp>( 899 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 900 newCaseValues, newCaseDestinations, newCaseOperands); 901 return success(); 902 } 903 904 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, 905 MLIRContext *context) { 906 results.add(&simplifySwitchWithOnlyDefault) 907 .add(&dropSwitchCasesThatMatchDefault) 908 .add(&simplifyConstSwitchValue) 909 .add(&simplifyPassThroughSwitch) 910 .add(&simplifySwitchFromSwitchOnSameCondition) 911 .add(&simplifySwitchFromDefaultSwitchOnSameCondition); 912 } 913 914 //===----------------------------------------------------------------------===// 915 // TableGen'd op method definitions 916 //===----------------------------------------------------------------------===// 917 918 #define GET_OP_CLASSES 919 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 920