1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 10 11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" 14 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/IRMapping.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/OpImplementation.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/IR/TypeUtilities.h" 25 #include "mlir/IR/Value.h" 26 #include "mlir/Support/MathExtras.h" 27 #include "mlir/Transforms/InliningUtils.h" 28 #include "llvm/ADT/APFloat.h" 29 #include "llvm/ADT/STLExtras.h" 30 #include "llvm/Support/FormatVariadic.h" 31 #include "llvm/Support/raw_ostream.h" 32 #include <numeric> 33 34 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc" 35 36 using namespace mlir; 37 using namespace mlir::cf; 38 39 //===----------------------------------------------------------------------===// 40 // ControlFlowDialect Interfaces 41 //===----------------------------------------------------------------------===// 42 namespace { 43 /// This class defines the interface for handling inlining with control flow 44 /// operations. 45 struct ControlFlowInlinerInterface : public DialectInlinerInterface { 46 using DialectInlinerInterface::DialectInlinerInterface; 47 ~ControlFlowInlinerInterface() override = default; 48 49 /// All control flow operations can be inlined. 50 bool isLegalToInline(Operation *call, Operation *callable, 51 bool wouldBeCloned) const final { 52 return true; 53 } 54 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 55 return true; 56 } 57 58 /// ControlFlow terminator operations don't really need any special handing. 59 void handleTerminator(Operation *op, Block *newDest) const final {} 60 }; 61 } // namespace 62 63 //===----------------------------------------------------------------------===// 64 // ControlFlowDialect 65 //===----------------------------------------------------------------------===// 66 67 void ControlFlowDialect::initialize() { 68 addOperations< 69 #define GET_OP_LIST 70 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 71 >(); 72 addInterfaces<ControlFlowInlinerInterface>(); 73 declarePromisedInterface<ControlFlowDialect, ConvertToLLVMPatternInterface>(); 74 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp, 75 CondBranchOp>(); 76 declarePromisedInterface<CondBranchOp, 77 bufferization::BufferDeallocationOpInterface>(); 78 } 79 80 //===----------------------------------------------------------------------===// 81 // AssertOp 82 //===----------------------------------------------------------------------===// 83 84 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { 85 // Erase assertion if argument is constant true. 86 if (matchPattern(op.getArg(), m_One())) { 87 rewriter.eraseOp(op); 88 return success(); 89 } 90 return failure(); 91 } 92 93 //===----------------------------------------------------------------------===// 94 // BranchOp 95 //===----------------------------------------------------------------------===// 96 97 /// Given a successor, try to collapse it to a new destination if it only 98 /// contains a passthrough unconditional branch. If the successor is 99 /// collapsable, `successor` and `successorOperands` are updated to reference 100 /// the new destination and values. `argStorage` is used as storage if operands 101 /// to the collapsed successor need to be remapped. It must outlive uses of 102 /// successorOperands. 103 static LogicalResult collapseBranch(Block *&successor, 104 ValueRange &successorOperands, 105 SmallVectorImpl<Value> &argStorage) { 106 // Check that the successor only contains a unconditional branch. 107 if (std::next(successor->begin()) != successor->end()) 108 return failure(); 109 // Check that the terminator is an unconditional branch. 110 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator()); 111 if (!successorBranch) 112 return failure(); 113 // Check that the arguments are only used within the terminator. 114 for (BlockArgument arg : successor->getArguments()) { 115 for (Operation *user : arg.getUsers()) 116 if (user != successorBranch) 117 return failure(); 118 } 119 // Don't try to collapse branches to infinite loops. 120 Block *successorDest = successorBranch.getDest(); 121 if (successorDest == successor) 122 return failure(); 123 124 // Update the operands to the successor. If the branch parent has no 125 // arguments, we can use the branch operands directly. 126 OperandRange operands = successorBranch.getOperands(); 127 if (successor->args_empty()) { 128 successor = successorDest; 129 successorOperands = operands; 130 return success(); 131 } 132 133 // Otherwise, we need to remap any argument operands. 134 for (Value operand : operands) { 135 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand); 136 if (argOperand && argOperand.getOwner() == successor) 137 argStorage.push_back(successorOperands[argOperand.getArgNumber()]); 138 else 139 argStorage.push_back(operand); 140 } 141 successor = successorDest; 142 successorOperands = argStorage; 143 return success(); 144 } 145 146 /// Simplify a branch to a block that has a single predecessor. This effectively 147 /// merges the two blocks. 148 static LogicalResult 149 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) { 150 // Check that the successor block has a single predecessor. 151 Block *succ = op.getDest(); 152 Block *opParent = op->getBlock(); 153 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) 154 return failure(); 155 156 // Merge the successor into the current block and erase the branch. 157 SmallVector<Value> brOperands(op.getOperands()); 158 rewriter.eraseOp(op); 159 rewriter.mergeBlocks(succ, opParent, brOperands); 160 return success(); 161 } 162 163 /// br ^bb1 164 /// ^bb1 165 /// br ^bbN(...) 166 /// 167 /// -> br ^bbN(...) 168 /// 169 static LogicalResult simplifyPassThroughBr(BranchOp op, 170 PatternRewriter &rewriter) { 171 Block *dest = op.getDest(); 172 ValueRange destOperands = op.getOperands(); 173 SmallVector<Value, 4> destOperandStorage; 174 175 // Try to collapse the successor if it points somewhere other than this 176 // block. 177 if (dest == op->getBlock() || 178 failed(collapseBranch(dest, destOperands, destOperandStorage))) 179 return failure(); 180 181 // Create a new branch with the collapsed successor. 182 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands); 183 return success(); 184 } 185 186 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) { 187 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) || 188 succeeded(simplifyPassThroughBr(op, rewriter))); 189 } 190 191 void BranchOp::setDest(Block *block) { return setSuccessor(block); } 192 193 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } 194 195 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { 196 assert(index == 0 && "invalid successor index"); 197 return SuccessorOperands(getDestOperandsMutable()); 198 } 199 200 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { 201 return getDest(); 202 } 203 204 //===----------------------------------------------------------------------===// 205 // CondBranchOp 206 //===----------------------------------------------------------------------===// 207 208 namespace { 209 /// cf.cond_br true, ^bb1, ^bb2 210 /// -> br ^bb1 211 /// cf.cond_br false, ^bb1, ^bb2 212 /// -> br ^bb2 213 /// 214 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { 215 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 216 217 LogicalResult matchAndRewrite(CondBranchOp condbr, 218 PatternRewriter &rewriter) const override { 219 if (matchPattern(condbr.getCondition(), m_NonZero())) { 220 // True branch taken. 221 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 222 condbr.getTrueOperands()); 223 return success(); 224 } 225 if (matchPattern(condbr.getCondition(), m_Zero())) { 226 // False branch taken. 227 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 228 condbr.getFalseOperands()); 229 return success(); 230 } 231 return failure(); 232 } 233 }; 234 235 /// cf.cond_br %cond, ^bb1, ^bb2 236 /// ^bb1 237 /// br ^bbN(...) 238 /// ^bb2 239 /// br ^bbK(...) 240 /// 241 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...) 242 /// 243 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> { 244 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 245 246 LogicalResult matchAndRewrite(CondBranchOp condbr, 247 PatternRewriter &rewriter) const override { 248 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest(); 249 ValueRange trueDestOperands = condbr.getTrueOperands(); 250 ValueRange falseDestOperands = condbr.getFalseOperands(); 251 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage; 252 253 // Try to collapse one of the current successors. 254 LogicalResult collapsedTrue = 255 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); 256 LogicalResult collapsedFalse = 257 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); 258 if (failed(collapsedTrue) && failed(collapsedFalse)) 259 return failure(); 260 261 // Create a new branch with the collapsed successors. 262 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(), 263 trueDest, trueDestOperands, 264 falseDest, falseDestOperands); 265 return success(); 266 } 267 }; 268 269 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) 270 /// -> br ^bb1(A, ..., N) 271 /// 272 /// cf.cond_br %cond, ^bb1(A), ^bb1(B) 273 /// -> %select = arith.select %cond, A, B 274 /// br ^bb1(%select) 275 /// 276 struct SimplifyCondBranchIdenticalSuccessors 277 : public OpRewritePattern<CondBranchOp> { 278 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 279 280 LogicalResult matchAndRewrite(CondBranchOp condbr, 281 PatternRewriter &rewriter) const override { 282 // Check that the true and false destinations are the same and have the same 283 // operands. 284 Block *trueDest = condbr.getTrueDest(); 285 if (trueDest != condbr.getFalseDest()) 286 return failure(); 287 288 // If all of the operands match, no selects need to be generated. 289 OperandRange trueOperands = condbr.getTrueOperands(); 290 OperandRange falseOperands = condbr.getFalseOperands(); 291 if (trueOperands == falseOperands) { 292 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands); 293 return success(); 294 } 295 296 // Otherwise, if the current block is the only predecessor insert selects 297 // for any mismatched branch operands. 298 if (trueDest->getUniquePredecessor() != condbr->getBlock()) 299 return failure(); 300 301 // Generate a select for any operands that differ between the two. 302 SmallVector<Value, 8> mergedOperands; 303 mergedOperands.reserve(trueOperands.size()); 304 Value condition = condbr.getCondition(); 305 for (auto it : llvm::zip(trueOperands, falseOperands)) { 306 if (std::get<0>(it) == std::get<1>(it)) 307 mergedOperands.push_back(std::get<0>(it)); 308 else 309 mergedOperands.push_back(rewriter.create<arith::SelectOp>( 310 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); 311 } 312 313 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands); 314 return success(); 315 } 316 }; 317 318 /// ... 319 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 320 /// ... 321 /// ^bb1: // has single predecessor 322 /// ... 323 /// cf.cond_br %cond, ^bb3(...), ^bb4(...) 324 /// 325 /// -> 326 /// 327 /// ... 328 /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 329 /// ... 330 /// ^bb1: // has single predecessor 331 /// ... 332 /// br ^bb3(...) 333 /// 334 struct SimplifyCondBranchFromCondBranchOnSameCondition 335 : public OpRewritePattern<CondBranchOp> { 336 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 337 338 LogicalResult matchAndRewrite(CondBranchOp condbr, 339 PatternRewriter &rewriter) const override { 340 // Check that we have a single distinct predecessor. 341 Block *currentBlock = condbr->getBlock(); 342 Block *predecessor = currentBlock->getSinglePredecessor(); 343 if (!predecessor) 344 return failure(); 345 346 // Check that the predecessor terminates with a conditional branch to this 347 // block and that it branches on the same condition. 348 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator()); 349 if (!predBranch || condbr.getCondition() != predBranch.getCondition()) 350 return failure(); 351 352 // Fold this branch to an unconditional branch. 353 if (currentBlock == predBranch.getTrueDest()) 354 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 355 condbr.getTrueDestOperands()); 356 else 357 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 358 condbr.getFalseDestOperands()); 359 return success(); 360 } 361 }; 362 363 /// cf.cond_br %arg0, ^trueB, ^falseB 364 /// 365 /// ^trueB: 366 /// "test.consumer1"(%arg0) : (i1) -> () 367 /// ... 368 /// 369 /// ^falseB: 370 /// "test.consumer2"(%arg0) : (i1) -> () 371 /// ... 372 /// 373 /// -> 374 /// 375 /// cf.cond_br %arg0, ^trueB, ^falseB 376 /// ^trueB: 377 /// "test.consumer1"(%true) : (i1) -> () 378 /// ... 379 /// 380 /// ^falseB: 381 /// "test.consumer2"(%false) : (i1) -> () 382 /// ... 383 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { 384 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 385 386 LogicalResult matchAndRewrite(CondBranchOp condbr, 387 PatternRewriter &rewriter) const override { 388 // Check that we have a single distinct predecessor. 389 bool replaced = false; 390 Type ty = rewriter.getI1Type(); 391 392 // These variables serve to prevent creating duplicate constants 393 // and hold constant true or false values. 394 Value constantTrue = nullptr; 395 Value constantFalse = nullptr; 396 397 // TODO These checks can be expanded to encompas any use with only 398 // either the true of false edge as a predecessor. For now, we fall 399 // back to checking the single predecessor is given by the true/fasle 400 // destination, thereby ensuring that only that edge can reach the 401 // op. 402 if (condbr.getTrueDest()->getSinglePredecessor()) { 403 for (OpOperand &use : 404 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 405 if (use.getOwner()->getBlock() == condbr.getTrueDest()) { 406 replaced = true; 407 408 if (!constantTrue) 409 constantTrue = rewriter.create<arith::ConstantOp>( 410 condbr.getLoc(), ty, rewriter.getBoolAttr(true)); 411 412 rewriter.modifyOpInPlace(use.getOwner(), 413 [&] { use.set(constantTrue); }); 414 } 415 } 416 } 417 if (condbr.getFalseDest()->getSinglePredecessor()) { 418 for (OpOperand &use : 419 llvm::make_early_inc_range(condbr.getCondition().getUses())) { 420 if (use.getOwner()->getBlock() == condbr.getFalseDest()) { 421 replaced = true; 422 423 if (!constantFalse) 424 constantFalse = rewriter.create<arith::ConstantOp>( 425 condbr.getLoc(), ty, rewriter.getBoolAttr(false)); 426 427 rewriter.modifyOpInPlace(use.getOwner(), 428 [&] { use.set(constantFalse); }); 429 } 430 } 431 } 432 return success(replaced); 433 } 434 }; 435 } // namespace 436 437 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, 438 MLIRContext *context) { 439 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, 440 SimplifyCondBranchIdenticalSuccessors, 441 SimplifyCondBranchFromCondBranchOnSameCondition, 442 CondBranchTruthPropagation>(context); 443 } 444 445 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { 446 assert(index < getNumSuccessors() && "invalid successor index"); 447 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable() 448 : getFalseDestOperandsMutable()); 449 } 450 451 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 452 if (IntegerAttr condAttr = 453 llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) 454 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest(); 455 return nullptr; 456 } 457 458 //===----------------------------------------------------------------------===// 459 // SwitchOp 460 //===----------------------------------------------------------------------===// 461 462 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 463 Block *defaultDestination, ValueRange defaultOperands, 464 DenseIntElementsAttr caseValues, 465 BlockRange caseDestinations, 466 ArrayRef<ValueRange> caseOperands) { 467 build(builder, result, value, defaultOperands, caseOperands, caseValues, 468 defaultDestination, caseDestinations); 469 } 470 471 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 472 Block *defaultDestination, ValueRange defaultOperands, 473 ArrayRef<APInt> caseValues, BlockRange caseDestinations, 474 ArrayRef<ValueRange> caseOperands) { 475 DenseIntElementsAttr caseValuesAttr; 476 if (!caseValues.empty()) { 477 ShapedType caseValueType = VectorType::get( 478 static_cast<int64_t>(caseValues.size()), value.getType()); 479 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 480 } 481 build(builder, result, value, defaultDestination, defaultOperands, 482 caseValuesAttr, caseDestinations, caseOperands); 483 } 484 485 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 486 Block *defaultDestination, ValueRange defaultOperands, 487 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 488 ArrayRef<ValueRange> caseOperands) { 489 DenseIntElementsAttr caseValuesAttr; 490 if (!caseValues.empty()) { 491 ShapedType caseValueType = VectorType::get( 492 static_cast<int64_t>(caseValues.size()), value.getType()); 493 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 494 } 495 build(builder, result, value, defaultDestination, defaultOperands, 496 caseValuesAttr, caseDestinations, caseOperands); 497 } 498 499 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? 500 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* 501 static ParseResult parseSwitchOpCases( 502 OpAsmParser &parser, Type &flagType, Block *&defaultDestination, 503 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, 504 SmallVectorImpl<Type> &defaultOperandTypes, 505 DenseIntElementsAttr &caseValues, 506 SmallVectorImpl<Block *> &caseDestinations, 507 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 508 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 509 if (parser.parseKeyword("default") || parser.parseColon() || 510 parser.parseSuccessor(defaultDestination)) 511 return failure(); 512 if (succeeded(parser.parseOptionalLParen())) { 513 if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, 514 /*allowResultNumber=*/false) || 515 parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) 516 return failure(); 517 } 518 519 SmallVector<APInt> values; 520 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 521 while (succeeded(parser.parseOptionalComma())) { 522 int64_t value = 0; 523 if (failed(parser.parseInteger(value))) 524 return failure(); 525 values.push_back(APInt(bitWidth, value)); 526 527 Block *destination; 528 SmallVector<OpAsmParser::UnresolvedOperand> operands; 529 SmallVector<Type> operandTypes; 530 if (failed(parser.parseColon()) || 531 failed(parser.parseSuccessor(destination))) 532 return failure(); 533 if (succeeded(parser.parseOptionalLParen())) { 534 if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None, 535 /*allowResultNumber=*/false)) || 536 failed(parser.parseColonTypeList(operandTypes)) || 537 failed(parser.parseRParen())) 538 return failure(); 539 } 540 caseDestinations.push_back(destination); 541 caseOperands.emplace_back(operands); 542 caseOperandTypes.emplace_back(operandTypes); 543 } 544 545 if (!values.empty()) { 546 ShapedType caseValueType = 547 VectorType::get(static_cast<int64_t>(values.size()), flagType); 548 caseValues = DenseIntElementsAttr::get(caseValueType, values); 549 } 550 return success(); 551 } 552 553 static void printSwitchOpCases( 554 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, 555 OperandRange defaultOperands, TypeRange defaultOperandTypes, 556 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, 557 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { 558 p << " default: "; 559 p.printSuccessorAndUseList(defaultDestination, defaultOperands); 560 561 if (!caseValues) 562 return; 563 564 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) { 565 p << ','; 566 p.printNewline(); 567 p << " "; 568 p << it.value().getLimitedValue(); 569 p << ": "; 570 p.printSuccessorAndUseList(caseDestinations[it.index()], 571 caseOperands[it.index()]); 572 } 573 p.printNewline(); 574 } 575 576 LogicalResult SwitchOp::verify() { 577 auto caseValues = getCaseValues(); 578 auto caseDestinations = getCaseDestinations(); 579 580 if (!caseValues && caseDestinations.empty()) 581 return success(); 582 583 Type flagType = getFlag().getType(); 584 Type caseValueType = caseValues->getType().getElementType(); 585 if (caseValueType != flagType) 586 return emitOpError() << "'flag' type (" << flagType 587 << ") should match case value type (" << caseValueType 588 << ")"; 589 590 if (caseValues && 591 caseValues->size() != static_cast<int64_t>(caseDestinations.size())) 592 return emitOpError() << "number of case values (" << caseValues->size() 593 << ") should match number of " 594 "case destinations (" 595 << caseDestinations.size() << ")"; 596 return success(); 597 } 598 599 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 600 assert(index < getNumSuccessors() && "invalid successor index"); 601 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 602 : getCaseOperandsMutable(index - 1)); 603 } 604 605 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 606 std::optional<DenseIntElementsAttr> caseValues = getCaseValues(); 607 608 if (!caseValues) 609 return getDefaultDestination(); 610 611 SuccessorRange caseDests = getCaseDestinations(); 612 if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) { 613 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) 614 if (it.value() == value.getValue()) 615 return caseDests[it.index()]; 616 return getDefaultDestination(); 617 } 618 return nullptr; 619 } 620 621 /// switch %flag : i32, [ 622 /// default: ^bb1 623 /// ] 624 /// -> br ^bb1 625 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, 626 PatternRewriter &rewriter) { 627 if (!op.getCaseDestinations().empty()) 628 return failure(); 629 630 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 631 op.getDefaultOperands()); 632 return success(); 633 } 634 635 /// switch %flag : i32, [ 636 /// default: ^bb1, 637 /// 42: ^bb1, 638 /// 43: ^bb2 639 /// ] 640 /// -> 641 /// switch %flag : i32, [ 642 /// default: ^bb1, 643 /// 43: ^bb2 644 /// ] 645 static LogicalResult 646 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { 647 SmallVector<Block *> newCaseDestinations; 648 SmallVector<ValueRange> newCaseOperands; 649 SmallVector<APInt> newCaseValues; 650 bool requiresChange = false; 651 auto caseValues = op.getCaseValues(); 652 auto caseDests = op.getCaseDestinations(); 653 654 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 655 if (caseDests[it.index()] == op.getDefaultDestination() && 656 op.getCaseOperands(it.index()) == op.getDefaultOperands()) { 657 requiresChange = true; 658 continue; 659 } 660 newCaseDestinations.push_back(caseDests[it.index()]); 661 newCaseOperands.push_back(op.getCaseOperands(it.index())); 662 newCaseValues.push_back(it.value()); 663 } 664 665 if (!requiresChange) 666 return failure(); 667 668 rewriter.replaceOpWithNewOp<SwitchOp>( 669 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 670 newCaseValues, newCaseDestinations, newCaseOperands); 671 return success(); 672 } 673 674 /// Helper for folding a switch with a constant value. 675 /// switch %c_42 : i32, [ 676 /// default: ^bb1 , 677 /// 42: ^bb2, 678 /// 43: ^bb3 679 /// ] 680 /// -> br ^bb2 681 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, 682 const APInt &caseValue) { 683 auto caseValues = op.getCaseValues(); 684 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 685 if (it.value() == caseValue) { 686 rewriter.replaceOpWithNewOp<BranchOp>( 687 op, op.getCaseDestinations()[it.index()], 688 op.getCaseOperands(it.index())); 689 return; 690 } 691 } 692 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 693 op.getDefaultOperands()); 694 } 695 696 /// switch %c_42 : i32, [ 697 /// default: ^bb1, 698 /// 42: ^bb2, 699 /// 43: ^bb3 700 /// ] 701 /// -> br ^bb2 702 static LogicalResult simplifyConstSwitchValue(SwitchOp op, 703 PatternRewriter &rewriter) { 704 APInt caseValue; 705 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) 706 return failure(); 707 708 foldSwitch(op, rewriter, caseValue); 709 return success(); 710 } 711 712 /// switch %c_42 : i32, [ 713 /// default: ^bb1, 714 /// 42: ^bb2, 715 /// ] 716 /// ^bb2: 717 /// br ^bb3 718 /// -> 719 /// switch %c_42 : i32, [ 720 /// default: ^bb1, 721 /// 42: ^bb3, 722 /// ] 723 static LogicalResult simplifyPassThroughSwitch(SwitchOp op, 724 PatternRewriter &rewriter) { 725 SmallVector<Block *> newCaseDests; 726 SmallVector<ValueRange> newCaseOperands; 727 SmallVector<SmallVector<Value>> argStorage; 728 auto caseValues = op.getCaseValues(); 729 argStorage.reserve(caseValues->size() + 1); 730 auto caseDests = op.getCaseDestinations(); 731 bool requiresChange = false; 732 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { 733 Block *caseDest = caseDests[i]; 734 ValueRange caseOperands = op.getCaseOperands(i); 735 argStorage.emplace_back(); 736 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back()))) 737 requiresChange = true; 738 739 newCaseDests.push_back(caseDest); 740 newCaseOperands.push_back(caseOperands); 741 } 742 743 Block *defaultDest = op.getDefaultDestination(); 744 ValueRange defaultOperands = op.getDefaultOperands(); 745 argStorage.emplace_back(); 746 747 if (succeeded( 748 collapseBranch(defaultDest, defaultOperands, argStorage.back()))) 749 requiresChange = true; 750 751 if (!requiresChange) 752 return failure(); 753 754 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest, 755 defaultOperands, *caseValues, 756 newCaseDests, newCaseOperands); 757 return success(); 758 } 759 760 /// switch %flag : i32, [ 761 /// default: ^bb1, 762 /// 42: ^bb2, 763 /// ] 764 /// ^bb2: 765 /// switch %flag : i32, [ 766 /// default: ^bb3, 767 /// 42: ^bb4 768 /// ] 769 /// -> 770 /// switch %flag : i32, [ 771 /// default: ^bb1, 772 /// 42: ^bb2, 773 /// ] 774 /// ^bb2: 775 /// br ^bb4 776 /// 777 /// and 778 /// 779 /// switch %flag : i32, [ 780 /// default: ^bb1, 781 /// 42: ^bb2, 782 /// ] 783 /// ^bb2: 784 /// switch %flag : i32, [ 785 /// default: ^bb3, 786 /// 43: ^bb4 787 /// ] 788 /// -> 789 /// switch %flag : i32, [ 790 /// default: ^bb1, 791 /// 42: ^bb2, 792 /// ] 793 /// ^bb2: 794 /// br ^bb3 795 static LogicalResult 796 simplifySwitchFromSwitchOnSameCondition(SwitchOp op, 797 PatternRewriter &rewriter) { 798 // Check that we have a single distinct predecessor. 799 Block *currentBlock = op->getBlock(); 800 Block *predecessor = currentBlock->getSinglePredecessor(); 801 if (!predecessor) 802 return failure(); 803 804 // Check that the predecessor terminates with a switch branch to this block 805 // and that it branches on the same condition and that this branch isn't the 806 // default destination. 807 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 808 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 809 predSwitch.getDefaultDestination() == currentBlock) 810 return failure(); 811 812 // Fold this switch to an unconditional branch. 813 SuccessorRange predDests = predSwitch.getCaseDestinations(); 814 auto it = llvm::find(predDests, currentBlock); 815 if (it != predDests.end()) { 816 std::optional<DenseIntElementsAttr> predCaseValues = 817 predSwitch.getCaseValues(); 818 foldSwitch(op, rewriter, 819 predCaseValues->getValues<APInt>()[it - predDests.begin()]); 820 } else { 821 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 822 op.getDefaultOperands()); 823 } 824 return success(); 825 } 826 827 /// switch %flag : i32, [ 828 /// default: ^bb1, 829 /// 42: ^bb2 830 /// ] 831 /// ^bb1: 832 /// switch %flag : i32, [ 833 /// default: ^bb3, 834 /// 42: ^bb4, 835 /// 43: ^bb5 836 /// ] 837 /// -> 838 /// switch %flag : i32, [ 839 /// default: ^bb1, 840 /// 42: ^bb2, 841 /// ] 842 /// ^bb1: 843 /// switch %flag : i32, [ 844 /// default: ^bb3, 845 /// 43: ^bb5 846 /// ] 847 static LogicalResult 848 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, 849 PatternRewriter &rewriter) { 850 // Check that we have a single distinct predecessor. 851 Block *currentBlock = op->getBlock(); 852 Block *predecessor = currentBlock->getSinglePredecessor(); 853 if (!predecessor) 854 return failure(); 855 856 // Check that the predecessor terminates with a switch branch to this block 857 // and that it branches on the same condition and that this branch is the 858 // default destination. 859 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 860 if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 861 predSwitch.getDefaultDestination() != currentBlock) 862 return failure(); 863 864 // Delete case values that are not possible here. 865 DenseSet<APInt> caseValuesToRemove; 866 auto predDests = predSwitch.getCaseDestinations(); 867 auto predCaseValues = predSwitch.getCaseValues(); 868 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) 869 if (currentBlock != predDests[i]) 870 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]); 871 872 SmallVector<Block *> newCaseDestinations; 873 SmallVector<ValueRange> newCaseOperands; 874 SmallVector<APInt> newCaseValues; 875 bool requiresChange = false; 876 877 auto caseValues = op.getCaseValues(); 878 auto caseDests = op.getCaseDestinations(); 879 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 880 if (caseValuesToRemove.contains(it.value())) { 881 requiresChange = true; 882 continue; 883 } 884 newCaseDestinations.push_back(caseDests[it.index()]); 885 newCaseOperands.push_back(op.getCaseOperands(it.index())); 886 newCaseValues.push_back(it.value()); 887 } 888 889 if (!requiresChange) 890 return failure(); 891 892 rewriter.replaceOpWithNewOp<SwitchOp>( 893 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 894 newCaseValues, newCaseDestinations, newCaseOperands); 895 return success(); 896 } 897 898 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, 899 MLIRContext *context) { 900 results.add(&simplifySwitchWithOnlyDefault) 901 .add(&dropSwitchCasesThatMatchDefault) 902 .add(&simplifyConstSwitchValue) 903 .add(&simplifyPassThroughSwitch) 904 .add(&simplifySwitchFromSwitchOnSameCondition) 905 .add(&simplifySwitchFromDefaultSwitchOnSameCondition); 906 } 907 908 //===----------------------------------------------------------------------===// 909 // TableGen'd op method definitions 910 //===----------------------------------------------------------------------===// 911 912 #define GET_OP_CLASSES 913 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 914