//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>

33ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc" 34ace01605SRiver Riddle 35ace01605SRiver Riddle using namespace mlir; 36ace01605SRiver Riddle using namespace mlir::cf; 37ace01605SRiver Riddle 38ace01605SRiver Riddle //===----------------------------------------------------------------------===// 39ace01605SRiver Riddle // ControlFlowDialect Interfaces 40ace01605SRiver Riddle //===----------------------------------------------------------------------===// 41ace01605SRiver Riddle namespace { 42ace01605SRiver Riddle /// This class defines the interface for handling inlining with control flow 43ace01605SRiver Riddle /// operations. 44ace01605SRiver Riddle struct ControlFlowInlinerInterface : public DialectInlinerInterface { 45ace01605SRiver Riddle using DialectInlinerInterface::DialectInlinerInterface; 46ace01605SRiver Riddle ~ControlFlowInlinerInterface() override = default; 47ace01605SRiver Riddle 48ace01605SRiver Riddle /// All control flow operations can be inlined. 49ace01605SRiver Riddle bool isLegalToInline(Operation *call, Operation *callable, 50ace01605SRiver Riddle bool wouldBeCloned) const final { 51ace01605SRiver Riddle return true; 52ace01605SRiver Riddle } 534d67b278SJeff Niu bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 54ace01605SRiver Riddle return true; 55ace01605SRiver Riddle } 56ace01605SRiver Riddle 57ace01605SRiver Riddle /// ControlFlow terminator operations don't really need any special handing. 
58ace01605SRiver Riddle void handleTerminator(Operation *op, Block *newDest) const final {} 59ace01605SRiver Riddle }; 60ace01605SRiver Riddle } // namespace 61ace01605SRiver Riddle 62ace01605SRiver Riddle //===----------------------------------------------------------------------===// 63ace01605SRiver Riddle // ControlFlowDialect 64ace01605SRiver Riddle //===----------------------------------------------------------------------===// 65ace01605SRiver Riddle 66ace01605SRiver Riddle void ControlFlowDialect::initialize() { 67ace01605SRiver Riddle addOperations< 68ace01605SRiver Riddle #define GET_OP_LIST 69ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 70ace01605SRiver Riddle >(); 71ace01605SRiver Riddle addInterfaces<ControlFlowInlinerInterface>(); 7235d55f28SJustin Fargnoli declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>(); 73513cdb82SJustin Fargnoli declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp, 74513cdb82SJustin Fargnoli CondBranchOp>(); 7535d55f28SJustin Fargnoli declarePromisedInterface<bufferization::BufferDeallocationOpInterface, 7635d55f28SJustin Fargnoli CondBranchOp>(); 77ace01605SRiver Riddle } 78ace01605SRiver Riddle 79ace01605SRiver Riddle //===----------------------------------------------------------------------===// 80ace01605SRiver Riddle // AssertOp 81ace01605SRiver Riddle //===----------------------------------------------------------------------===// 82ace01605SRiver Riddle 83ace01605SRiver Riddle LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { 84ace01605SRiver Riddle // Erase assertion if argument is constant true. 
85ace01605SRiver Riddle if (matchPattern(op.getArg(), m_One())) { 86ace01605SRiver Riddle rewriter.eraseOp(op); 87ace01605SRiver Riddle return success(); 88ace01605SRiver Riddle } 89ace01605SRiver Riddle return failure(); 90ace01605SRiver Riddle } 91ace01605SRiver Riddle 92a159b367SMcCowan Zhang // This side effect models "program termination". 93a159b367SMcCowan Zhang void AssertOp::getEffects( 94a159b367SMcCowan Zhang SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 95a159b367SMcCowan Zhang &effects) { 96a159b367SMcCowan Zhang effects.emplace_back(MemoryEffects::Write::get()); 97a159b367SMcCowan Zhang } 98a159b367SMcCowan Zhang 99ace01605SRiver Riddle //===----------------------------------------------------------------------===// 100ace01605SRiver Riddle // BranchOp 101ace01605SRiver Riddle //===----------------------------------------------------------------------===// 102ace01605SRiver Riddle 103ace01605SRiver Riddle /// Given a successor, try to collapse it to a new destination if it only 104ace01605SRiver Riddle /// contains a passthrough unconditional branch. If the successor is 105ace01605SRiver Riddle /// collapsable, `successor` and `successorOperands` are updated to reference 106ace01605SRiver Riddle /// the new destination and values. `argStorage` is used as storage if operands 107ace01605SRiver Riddle /// to the collapsed successor need to be remapped. It must outlive uses of 108ace01605SRiver Riddle /// successorOperands. 109ace01605SRiver Riddle static LogicalResult collapseBranch(Block *&successor, 110ace01605SRiver Riddle ValueRange &successorOperands, 111ace01605SRiver Riddle SmallVectorImpl<Value> &argStorage) { 112ace01605SRiver Riddle // Check that the successor only contains a unconditional branch. 113ace01605SRiver Riddle if (std::next(successor->begin()) != successor->end()) 114ace01605SRiver Riddle return failure(); 115ace01605SRiver Riddle // Check that the terminator is an unconditional branch. 
116ace01605SRiver Riddle BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator()); 117ace01605SRiver Riddle if (!successorBranch) 118ace01605SRiver Riddle return failure(); 119ace01605SRiver Riddle // Check that the arguments are only used within the terminator. 120ace01605SRiver Riddle for (BlockArgument arg : successor->getArguments()) { 121ace01605SRiver Riddle for (Operation *user : arg.getUsers()) 122ace01605SRiver Riddle if (user != successorBranch) 123ace01605SRiver Riddle return failure(); 124ace01605SRiver Riddle } 125ace01605SRiver Riddle // Don't try to collapse branches to infinite loops. 126ace01605SRiver Riddle Block *successorDest = successorBranch.getDest(); 127ace01605SRiver Riddle if (successorDest == successor) 128ace01605SRiver Riddle return failure(); 129ace01605SRiver Riddle 130ace01605SRiver Riddle // Update the operands to the successor. If the branch parent has no 131ace01605SRiver Riddle // arguments, we can use the branch operands directly. 132ace01605SRiver Riddle OperandRange operands = successorBranch.getOperands(); 133ace01605SRiver Riddle if (successor->args_empty()) { 134ace01605SRiver Riddle successor = successorDest; 135ace01605SRiver Riddle successorOperands = operands; 136ace01605SRiver Riddle return success(); 137ace01605SRiver Riddle } 138ace01605SRiver Riddle 139ace01605SRiver Riddle // Otherwise, we need to remap any argument operands. 
140ace01605SRiver Riddle for (Value operand : operands) { 141c1fa60b4STres Popp BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand); 142ace01605SRiver Riddle if (argOperand && argOperand.getOwner() == successor) 143ace01605SRiver Riddle argStorage.push_back(successorOperands[argOperand.getArgNumber()]); 144ace01605SRiver Riddle else 145ace01605SRiver Riddle argStorage.push_back(operand); 146ace01605SRiver Riddle } 147ace01605SRiver Riddle successor = successorDest; 148ace01605SRiver Riddle successorOperands = argStorage; 149ace01605SRiver Riddle return success(); 150ace01605SRiver Riddle } 151ace01605SRiver Riddle 152ace01605SRiver Riddle /// Simplify a branch to a block that has a single predecessor. This effectively 153ace01605SRiver Riddle /// merges the two blocks. 154ace01605SRiver Riddle static LogicalResult 155ace01605SRiver Riddle simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) { 156ace01605SRiver Riddle // Check that the successor block has a single predecessor. 157ace01605SRiver Riddle Block *succ = op.getDest(); 158ace01605SRiver Riddle Block *opParent = op->getBlock(); 159ace01605SRiver Riddle if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) 160ace01605SRiver Riddle return failure(); 161ace01605SRiver Riddle 162ace01605SRiver Riddle // Merge the successor into the current block and erase the branch. 16342c31d83SMatthias Springer SmallVector<Value> brOperands(op.getOperands()); 164ace01605SRiver Riddle rewriter.eraseOp(op); 16542c31d83SMatthias Springer rewriter.mergeBlocks(succ, opParent, brOperands); 166ace01605SRiver Riddle return success(); 167ace01605SRiver Riddle } 168ace01605SRiver Riddle 169ace01605SRiver Riddle /// br ^bb1 170ace01605SRiver Riddle /// ^bb1 171ace01605SRiver Riddle /// br ^bbN(...) 172ace01605SRiver Riddle /// 173ace01605SRiver Riddle /// -> br ^bbN(...) 
174ace01605SRiver Riddle /// 175ace01605SRiver Riddle static LogicalResult simplifyPassThroughBr(BranchOp op, 176ace01605SRiver Riddle PatternRewriter &rewriter) { 177ace01605SRiver Riddle Block *dest = op.getDest(); 178ace01605SRiver Riddle ValueRange destOperands = op.getOperands(); 179ace01605SRiver Riddle SmallVector<Value, 4> destOperandStorage; 180ace01605SRiver Riddle 181ace01605SRiver Riddle // Try to collapse the successor if it points somewhere other than this 182ace01605SRiver Riddle // block. 183ace01605SRiver Riddle if (dest == op->getBlock() || 184ace01605SRiver Riddle failed(collapseBranch(dest, destOperands, destOperandStorage))) 185ace01605SRiver Riddle return failure(); 186ace01605SRiver Riddle 187ace01605SRiver Riddle // Create a new branch with the collapsed successor. 188ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands); 189ace01605SRiver Riddle return success(); 190ace01605SRiver Riddle } 191ace01605SRiver Riddle 192ace01605SRiver Riddle LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) { 193ace01605SRiver Riddle return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) || 194ace01605SRiver Riddle succeeded(simplifyPassThroughBr(op, rewriter))); 195ace01605SRiver Riddle } 196ace01605SRiver Riddle 197ace01605SRiver Riddle void BranchOp::setDest(Block *block) { return setSuccessor(block); } 198ace01605SRiver Riddle 199ace01605SRiver Riddle void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } 200ace01605SRiver Riddle 2010c789db5SMarkus Böck SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { 202ace01605SRiver Riddle assert(index == 0 && "invalid successor index"); 2030c789db5SMarkus Böck return SuccessorOperands(getDestOperandsMutable()); 204ace01605SRiver Riddle } 205ace01605SRiver Riddle 206ace01605SRiver Riddle Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { 207ace01605SRiver Riddle return getDest(); 
208ace01605SRiver Riddle } 209ace01605SRiver Riddle 210ace01605SRiver Riddle //===----------------------------------------------------------------------===// 211ace01605SRiver Riddle // CondBranchOp 212ace01605SRiver Riddle //===----------------------------------------------------------------------===// 213ace01605SRiver Riddle 214ace01605SRiver Riddle namespace { 215ace01605SRiver Riddle /// cf.cond_br true, ^bb1, ^bb2 216ace01605SRiver Riddle /// -> br ^bb1 217ace01605SRiver Riddle /// cf.cond_br false, ^bb1, ^bb2 218ace01605SRiver Riddle /// -> br ^bb2 219ace01605SRiver Riddle /// 220ace01605SRiver Riddle struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { 221ace01605SRiver Riddle using OpRewritePattern<CondBranchOp>::OpRewritePattern; 222ace01605SRiver Riddle 223ace01605SRiver Riddle LogicalResult matchAndRewrite(CondBranchOp condbr, 224ace01605SRiver Riddle PatternRewriter &rewriter) const override { 225ace01605SRiver Riddle if (matchPattern(condbr.getCondition(), m_NonZero())) { 226ace01605SRiver Riddle // True branch taken. 227ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 228ace01605SRiver Riddle condbr.getTrueOperands()); 229ace01605SRiver Riddle return success(); 230ace01605SRiver Riddle } 231ace01605SRiver Riddle if (matchPattern(condbr.getCondition(), m_Zero())) { 232ace01605SRiver Riddle // False branch taken. 233ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 234ace01605SRiver Riddle condbr.getFalseOperands()); 235ace01605SRiver Riddle return success(); 236ace01605SRiver Riddle } 237ace01605SRiver Riddle return failure(); 238ace01605SRiver Riddle } 239ace01605SRiver Riddle }; 240ace01605SRiver Riddle 241ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1, ^bb2 242ace01605SRiver Riddle /// ^bb1 243ace01605SRiver Riddle /// br ^bbN(...) 244ace01605SRiver Riddle /// ^bb2 245ace01605SRiver Riddle /// br ^bbK(...) 
246ace01605SRiver Riddle /// 247ace01605SRiver Riddle /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...) 248ace01605SRiver Riddle /// 249ace01605SRiver Riddle struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> { 250ace01605SRiver Riddle using OpRewritePattern<CondBranchOp>::OpRewritePattern; 251ace01605SRiver Riddle 252ace01605SRiver Riddle LogicalResult matchAndRewrite(CondBranchOp condbr, 253ace01605SRiver Riddle PatternRewriter &rewriter) const override { 254ace01605SRiver Riddle Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest(); 255ace01605SRiver Riddle ValueRange trueDestOperands = condbr.getTrueOperands(); 256ace01605SRiver Riddle ValueRange falseDestOperands = condbr.getFalseOperands(); 257ace01605SRiver Riddle SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage; 258ace01605SRiver Riddle 259ace01605SRiver Riddle // Try to collapse one of the current successors. 260ace01605SRiver Riddle LogicalResult collapsedTrue = 261ace01605SRiver Riddle collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); 262ace01605SRiver Riddle LogicalResult collapsedFalse = 263ace01605SRiver Riddle collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); 264ace01605SRiver Riddle if (failed(collapsedTrue) && failed(collapsedFalse)) 265ace01605SRiver Riddle return failure(); 266ace01605SRiver Riddle 267ace01605SRiver Riddle // Create a new branch with the collapsed successors. 
268ace01605SRiver Riddle rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(), 269ace01605SRiver Riddle trueDest, trueDestOperands, 270ace01605SRiver Riddle falseDest, falseDestOperands); 271ace01605SRiver Riddle return success(); 272ace01605SRiver Riddle } 273ace01605SRiver Riddle }; 274ace01605SRiver Riddle 275ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) 276ace01605SRiver Riddle /// -> br ^bb1(A, ..., N) 277ace01605SRiver Riddle /// 278ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(A), ^bb1(B) 279ace01605SRiver Riddle /// -> %select = arith.select %cond, A, B 280ace01605SRiver Riddle /// br ^bb1(%select) 281ace01605SRiver Riddle /// 282ace01605SRiver Riddle struct SimplifyCondBranchIdenticalSuccessors 283ace01605SRiver Riddle : public OpRewritePattern<CondBranchOp> { 284ace01605SRiver Riddle using OpRewritePattern<CondBranchOp>::OpRewritePattern; 285ace01605SRiver Riddle 286ace01605SRiver Riddle LogicalResult matchAndRewrite(CondBranchOp condbr, 287ace01605SRiver Riddle PatternRewriter &rewriter) const override { 288ace01605SRiver Riddle // Check that the true and false destinations are the same and have the same 289ace01605SRiver Riddle // operands. 290ace01605SRiver Riddle Block *trueDest = condbr.getTrueDest(); 291ace01605SRiver Riddle if (trueDest != condbr.getFalseDest()) 292ace01605SRiver Riddle return failure(); 293ace01605SRiver Riddle 294ace01605SRiver Riddle // If all of the operands match, no selects need to be generated. 
295ace01605SRiver Riddle OperandRange trueOperands = condbr.getTrueOperands(); 296ace01605SRiver Riddle OperandRange falseOperands = condbr.getFalseOperands(); 297ace01605SRiver Riddle if (trueOperands == falseOperands) { 298ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands); 299ace01605SRiver Riddle return success(); 300ace01605SRiver Riddle } 301ace01605SRiver Riddle 302ace01605SRiver Riddle // Otherwise, if the current block is the only predecessor insert selects 303ace01605SRiver Riddle // for any mismatched branch operands. 304ace01605SRiver Riddle if (trueDest->getUniquePredecessor() != condbr->getBlock()) 305ace01605SRiver Riddle return failure(); 306ace01605SRiver Riddle 307ace01605SRiver Riddle // Generate a select for any operands that differ between the two. 308ace01605SRiver Riddle SmallVector<Value, 8> mergedOperands; 309ace01605SRiver Riddle mergedOperands.reserve(trueOperands.size()); 310ace01605SRiver Riddle Value condition = condbr.getCondition(); 311ace01605SRiver Riddle for (auto it : llvm::zip(trueOperands, falseOperands)) { 312ace01605SRiver Riddle if (std::get<0>(it) == std::get<1>(it)) 313ace01605SRiver Riddle mergedOperands.push_back(std::get<0>(it)); 314ace01605SRiver Riddle else 315ace01605SRiver Riddle mergedOperands.push_back(rewriter.create<arith::SelectOp>( 316ace01605SRiver Riddle condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); 317ace01605SRiver Riddle } 318ace01605SRiver Riddle 319ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands); 320ace01605SRiver Riddle return success(); 321ace01605SRiver Riddle } 322ace01605SRiver Riddle }; 323ace01605SRiver Riddle 324ace01605SRiver Riddle /// ... 325ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 326ace01605SRiver Riddle /// ... 327ace01605SRiver Riddle /// ^bb1: // has single predecessor 328ace01605SRiver Riddle /// ... 
329ace01605SRiver Riddle /// cf.cond_br %cond, ^bb3(...), ^bb4(...) 330ace01605SRiver Riddle /// 331ace01605SRiver Riddle /// -> 332ace01605SRiver Riddle /// 333ace01605SRiver Riddle /// ... 334ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(...), ^bb2(...) 335ace01605SRiver Riddle /// ... 336ace01605SRiver Riddle /// ^bb1: // has single predecessor 337ace01605SRiver Riddle /// ... 338ace01605SRiver Riddle /// br ^bb3(...) 339ace01605SRiver Riddle /// 340ace01605SRiver Riddle struct SimplifyCondBranchFromCondBranchOnSameCondition 341ace01605SRiver Riddle : public OpRewritePattern<CondBranchOp> { 342ace01605SRiver Riddle using OpRewritePattern<CondBranchOp>::OpRewritePattern; 343ace01605SRiver Riddle 344ace01605SRiver Riddle LogicalResult matchAndRewrite(CondBranchOp condbr, 345ace01605SRiver Riddle PatternRewriter &rewriter) const override { 346ace01605SRiver Riddle // Check that we have a single distinct predecessor. 347ace01605SRiver Riddle Block *currentBlock = condbr->getBlock(); 348ace01605SRiver Riddle Block *predecessor = currentBlock->getSinglePredecessor(); 349ace01605SRiver Riddle if (!predecessor) 350ace01605SRiver Riddle return failure(); 351ace01605SRiver Riddle 352ace01605SRiver Riddle // Check that the predecessor terminates with a conditional branch to this 353ace01605SRiver Riddle // block and that it branches on the same condition. 354ace01605SRiver Riddle auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator()); 355ace01605SRiver Riddle if (!predBranch || condbr.getCondition() != predBranch.getCondition()) 356ace01605SRiver Riddle return failure(); 357ace01605SRiver Riddle 358ace01605SRiver Riddle // Fold this branch to an unconditional branch. 
359ace01605SRiver Riddle if (currentBlock == predBranch.getTrueDest()) 360ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), 361ace01605SRiver Riddle condbr.getTrueDestOperands()); 362ace01605SRiver Riddle else 363ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), 364ace01605SRiver Riddle condbr.getFalseDestOperands()); 365ace01605SRiver Riddle return success(); 366ace01605SRiver Riddle } 367ace01605SRiver Riddle }; 368ace01605SRiver Riddle 369ace01605SRiver Riddle /// cf.cond_br %arg0, ^trueB, ^falseB 370ace01605SRiver Riddle /// 371ace01605SRiver Riddle /// ^trueB: 372ace01605SRiver Riddle /// "test.consumer1"(%arg0) : (i1) -> () 373ace01605SRiver Riddle /// ... 374ace01605SRiver Riddle /// 375ace01605SRiver Riddle /// ^falseB: 376ace01605SRiver Riddle /// "test.consumer2"(%arg0) : (i1) -> () 377ace01605SRiver Riddle /// ... 378ace01605SRiver Riddle /// 379ace01605SRiver Riddle /// -> 380ace01605SRiver Riddle /// 381ace01605SRiver Riddle /// cf.cond_br %arg0, ^trueB, ^falseB 382ace01605SRiver Riddle /// ^trueB: 383ace01605SRiver Riddle /// "test.consumer1"(%true) : (i1) -> () 384ace01605SRiver Riddle /// ... 385ace01605SRiver Riddle /// 386ace01605SRiver Riddle /// ^falseB: 387ace01605SRiver Riddle /// "test.consumer2"(%false) : (i1) -> () 388ace01605SRiver Riddle /// ... 389ace01605SRiver Riddle struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { 390ace01605SRiver Riddle using OpRewritePattern<CondBranchOp>::OpRewritePattern; 391ace01605SRiver Riddle 392ace01605SRiver Riddle LogicalResult matchAndRewrite(CondBranchOp condbr, 393ace01605SRiver Riddle PatternRewriter &rewriter) const override { 394ace01605SRiver Riddle // Check that we have a single distinct predecessor. 
395ace01605SRiver Riddle bool replaced = false; 396ace01605SRiver Riddle Type ty = rewriter.getI1Type(); 397ace01605SRiver Riddle 398ace01605SRiver Riddle // These variables serve to prevent creating duplicate constants 399ace01605SRiver Riddle // and hold constant true or false values. 400ace01605SRiver Riddle Value constantTrue = nullptr; 401ace01605SRiver Riddle Value constantFalse = nullptr; 402ace01605SRiver Riddle 403ace01605SRiver Riddle // TODO These checks can be expanded to encompas any use with only 404ace01605SRiver Riddle // either the true of false edge as a predecessor. For now, we fall 405ace01605SRiver Riddle // back to checking the single predecessor is given by the true/fasle 406ace01605SRiver Riddle // destination, thereby ensuring that only that edge can reach the 407ace01605SRiver Riddle // op. 408ace01605SRiver Riddle if (condbr.getTrueDest()->getSinglePredecessor()) { 409ace01605SRiver Riddle for (OpOperand &use : 410ace01605SRiver Riddle llvm::make_early_inc_range(condbr.getCondition().getUses())) { 411ace01605SRiver Riddle if (use.getOwner()->getBlock() == condbr.getTrueDest()) { 412ace01605SRiver Riddle replaced = true; 413ace01605SRiver Riddle 414ace01605SRiver Riddle if (!constantTrue) 415ace01605SRiver Riddle constantTrue = rewriter.create<arith::ConstantOp>( 416ace01605SRiver Riddle condbr.getLoc(), ty, rewriter.getBoolAttr(true)); 417ace01605SRiver Riddle 4185fcf907bSMatthias Springer rewriter.modifyOpInPlace(use.getOwner(), 419ace01605SRiver Riddle [&] { use.set(constantTrue); }); 420ace01605SRiver Riddle } 421ace01605SRiver Riddle } 422ace01605SRiver Riddle } 423ace01605SRiver Riddle if (condbr.getFalseDest()->getSinglePredecessor()) { 424ace01605SRiver Riddle for (OpOperand &use : 425ace01605SRiver Riddle llvm::make_early_inc_range(condbr.getCondition().getUses())) { 426ace01605SRiver Riddle if (use.getOwner()->getBlock() == condbr.getFalseDest()) { 427ace01605SRiver Riddle replaced = true; 428ace01605SRiver Riddle 
429ace01605SRiver Riddle if (!constantFalse) 430ace01605SRiver Riddle constantFalse = rewriter.create<arith::ConstantOp>( 431ace01605SRiver Riddle condbr.getLoc(), ty, rewriter.getBoolAttr(false)); 432ace01605SRiver Riddle 4335fcf907bSMatthias Springer rewriter.modifyOpInPlace(use.getOwner(), 434ace01605SRiver Riddle [&] { use.set(constantFalse); }); 435ace01605SRiver Riddle } 436ace01605SRiver Riddle } 437ace01605SRiver Riddle } 438ace01605SRiver Riddle return success(replaced); 439ace01605SRiver Riddle } 440ace01605SRiver Riddle }; 441ace01605SRiver Riddle } // namespace 442ace01605SRiver Riddle 443ace01605SRiver Riddle void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, 444ace01605SRiver Riddle MLIRContext *context) { 445ace01605SRiver Riddle results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, 446ace01605SRiver Riddle SimplifyCondBranchIdenticalSuccessors, 447ace01605SRiver Riddle SimplifyCondBranchFromCondBranchOnSameCondition, 448ace01605SRiver Riddle CondBranchTruthPropagation>(context); 449ace01605SRiver Riddle } 450ace01605SRiver Riddle 4510c789db5SMarkus Böck SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { 452ace01605SRiver Riddle assert(index < getNumSuccessors() && "invalid successor index"); 4530c789db5SMarkus Böck return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable() 4540c789db5SMarkus Böck : getFalseDestOperandsMutable()); 455ace01605SRiver Riddle } 456ace01605SRiver Riddle 457ace01605SRiver Riddle Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 458c1fa60b4STres Popp if (IntegerAttr condAttr = 459c1fa60b4STres Popp llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) 4609e5d2495SKazu Hirata return condAttr.getValue().isOne() ? 
getTrueDest() : getFalseDest(); 461ace01605SRiver Riddle return nullptr; 462ace01605SRiver Riddle } 463ace01605SRiver Riddle 464ace01605SRiver Riddle //===----------------------------------------------------------------------===// 465ace01605SRiver Riddle // SwitchOp 466ace01605SRiver Riddle //===----------------------------------------------------------------------===// 467ace01605SRiver Riddle 468ace01605SRiver Riddle void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 469ace01605SRiver Riddle Block *defaultDestination, ValueRange defaultOperands, 470ace01605SRiver Riddle DenseIntElementsAttr caseValues, 471ace01605SRiver Riddle BlockRange caseDestinations, 472ace01605SRiver Riddle ArrayRef<ValueRange> caseOperands) { 473ace01605SRiver Riddle build(builder, result, value, defaultOperands, caseOperands, caseValues, 474ace01605SRiver Riddle defaultDestination, caseDestinations); 475ace01605SRiver Riddle } 476ace01605SRiver Riddle 477ace01605SRiver Riddle void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 478ace01605SRiver Riddle Block *defaultDestination, ValueRange defaultOperands, 479ace01605SRiver Riddle ArrayRef<APInt> caseValues, BlockRange caseDestinations, 480ace01605SRiver Riddle ArrayRef<ValueRange> caseOperands) { 481ace01605SRiver Riddle DenseIntElementsAttr caseValuesAttr; 482ace01605SRiver Riddle if (!caseValues.empty()) { 483ace01605SRiver Riddle ShapedType caseValueType = VectorType::get( 484ace01605SRiver Riddle static_cast<int64_t>(caseValues.size()), value.getType()); 485ace01605SRiver Riddle caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 486ace01605SRiver Riddle } 487ace01605SRiver Riddle build(builder, result, value, defaultDestination, defaultOperands, 488ace01605SRiver Riddle caseValuesAttr, caseDestinations, caseOperands); 489ace01605SRiver Riddle } 490ace01605SRiver Riddle 491b34fb277SAlexander Batashev void SwitchOp::build(OpBuilder &builder, OperationState 
&result, Value value, 492b34fb277SAlexander Batashev Block *defaultDestination, ValueRange defaultOperands, 493b34fb277SAlexander Batashev ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 494b34fb277SAlexander Batashev ArrayRef<ValueRange> caseOperands) { 495b34fb277SAlexander Batashev DenseIntElementsAttr caseValuesAttr; 496b34fb277SAlexander Batashev if (!caseValues.empty()) { 497b34fb277SAlexander Batashev ShapedType caseValueType = VectorType::get( 498b34fb277SAlexander Batashev static_cast<int64_t>(caseValues.size()), value.getType()); 499b34fb277SAlexander Batashev caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); 500b34fb277SAlexander Batashev } 501b34fb277SAlexander Batashev build(builder, result, value, defaultDestination, defaultOperands, 502b34fb277SAlexander Batashev caseValuesAttr, caseDestinations, caseOperands); 503b34fb277SAlexander Batashev } 504b34fb277SAlexander Batashev 505ace01605SRiver Riddle /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? 506ace01605SRiver Riddle /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 
)* 507ace01605SRiver Riddle static ParseResult parseSwitchOpCases( 508ace01605SRiver Riddle OpAsmParser &parser, Type &flagType, Block *&defaultDestination, 509e13d23bcSMarkus Böck SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, 510ace01605SRiver Riddle SmallVectorImpl<Type> &defaultOperandTypes, 511ace01605SRiver Riddle DenseIntElementsAttr &caseValues, 512ace01605SRiver Riddle SmallVectorImpl<Block *> &caseDestinations, 513e13d23bcSMarkus Böck SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 514ace01605SRiver Riddle SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 515ace01605SRiver Riddle if (parser.parseKeyword("default") || parser.parseColon() || 516ace01605SRiver Riddle parser.parseSuccessor(defaultDestination)) 517ace01605SRiver Riddle return failure(); 518ace01605SRiver Riddle if (succeeded(parser.parseOptionalLParen())) { 5195dedf911SChris Lattner if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, 5205dedf911SChris Lattner /*allowResultNumber=*/false) || 521ace01605SRiver Riddle parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) 522ace01605SRiver Riddle return failure(); 523ace01605SRiver Riddle } 524ace01605SRiver Riddle 525ace01605SRiver Riddle SmallVector<APInt> values; 526ace01605SRiver Riddle unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 527ace01605SRiver Riddle while (succeeded(parser.parseOptionalComma())) { 528ace01605SRiver Riddle int64_t value = 0; 529ace01605SRiver Riddle if (failed(parser.parseInteger(value))) 530ace01605SRiver Riddle return failure(); 531*e692af85SNikita Popov values.push_back(APInt(bitWidth, value, /*isSigned=*/true)); 532ace01605SRiver Riddle 533ace01605SRiver Riddle Block *destination; 534e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand> operands; 535ace01605SRiver Riddle SmallVector<Type> operandTypes; 536ace01605SRiver Riddle if (failed(parser.parseColon()) || 537ace01605SRiver Riddle 
failed(parser.parseSuccessor(destination))) 538ace01605SRiver Riddle return failure(); 539ace01605SRiver Riddle if (succeeded(parser.parseOptionalLParen())) { 5407e87d03bSKeyi Zhang if (failed(parser.parseOperandList(operands, 5417e87d03bSKeyi Zhang OpAsmParser::Delimiter::None)) || 542ace01605SRiver Riddle failed(parser.parseColonTypeList(operandTypes)) || 543ace01605SRiver Riddle failed(parser.parseRParen())) 544ace01605SRiver Riddle return failure(); 545ace01605SRiver Riddle } 546ace01605SRiver Riddle caseDestinations.push_back(destination); 547ace01605SRiver Riddle caseOperands.emplace_back(operands); 548ace01605SRiver Riddle caseOperandTypes.emplace_back(operandTypes); 549ace01605SRiver Riddle } 550ace01605SRiver Riddle 551ace01605SRiver Riddle if (!values.empty()) { 552ace01605SRiver Riddle ShapedType caseValueType = 553ace01605SRiver Riddle VectorType::get(static_cast<int64_t>(values.size()), flagType); 554ace01605SRiver Riddle caseValues = DenseIntElementsAttr::get(caseValueType, values); 555ace01605SRiver Riddle } 556ace01605SRiver Riddle return success(); 557ace01605SRiver Riddle } 558ace01605SRiver Riddle 559ace01605SRiver Riddle static void printSwitchOpCases( 560ace01605SRiver Riddle OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, 561ace01605SRiver Riddle OperandRange defaultOperands, TypeRange defaultOperandTypes, 562ace01605SRiver Riddle DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, 563ace01605SRiver Riddle OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { 564ace01605SRiver Riddle p << " default: "; 565ace01605SRiver Riddle p.printSuccessorAndUseList(defaultDestination, defaultOperands); 566ace01605SRiver Riddle 567ace01605SRiver Riddle if (!caseValues) 568ace01605SRiver Riddle return; 569ace01605SRiver Riddle 570ace01605SRiver Riddle for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) { 571ace01605SRiver Riddle p << ','; 572ace01605SRiver Riddle p.printNewline(); 
573ace01605SRiver Riddle p << " "; 574ace01605SRiver Riddle p << it.value().getLimitedValue(); 575ace01605SRiver Riddle p << ": "; 576ace01605SRiver Riddle p.printSuccessorAndUseList(caseDestinations[it.index()], 577ace01605SRiver Riddle caseOperands[it.index()]); 578ace01605SRiver Riddle } 579ace01605SRiver Riddle p.printNewline(); 580ace01605SRiver Riddle } 581ace01605SRiver Riddle 582ace01605SRiver Riddle LogicalResult SwitchOp::verify() { 583ace01605SRiver Riddle auto caseValues = getCaseValues(); 584ace01605SRiver Riddle auto caseDestinations = getCaseDestinations(); 585ace01605SRiver Riddle 586ace01605SRiver Riddle if (!caseValues && caseDestinations.empty()) 587ace01605SRiver Riddle return success(); 588ace01605SRiver Riddle 589ace01605SRiver Riddle Type flagType = getFlag().getType(); 590ace01605SRiver Riddle Type caseValueType = caseValues->getType().getElementType(); 591ace01605SRiver Riddle if (caseValueType != flagType) 592ace01605SRiver Riddle return emitOpError() << "'flag' type (" << flagType 593ace01605SRiver Riddle << ") should match case value type (" << caseValueType 594ace01605SRiver Riddle << ")"; 595ace01605SRiver Riddle 596ace01605SRiver Riddle if (caseValues && 597ace01605SRiver Riddle caseValues->size() != static_cast<int64_t>(caseDestinations.size())) 598ace01605SRiver Riddle return emitOpError() << "number of case values (" << caseValues->size() 599ace01605SRiver Riddle << ") should match number of " 600ace01605SRiver Riddle "case destinations (" 601ace01605SRiver Riddle << caseDestinations.size() << ")"; 602ace01605SRiver Riddle return success(); 603ace01605SRiver Riddle } 604ace01605SRiver Riddle 6050c789db5SMarkus Böck SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 606ace01605SRiver Riddle assert(index < getNumSuccessors() && "invalid successor index"); 6070c789db5SMarkus Böck return SuccessorOperands(index == 0 ? 
getDefaultOperandsMutable() 6080c789db5SMarkus Böck : getCaseOperandsMutable(index - 1)); 609ace01605SRiver Riddle } 610ace01605SRiver Riddle 611ace01605SRiver Riddle Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 61222426110SRamkumar Ramachandra std::optional<DenseIntElementsAttr> caseValues = getCaseValues(); 613ace01605SRiver Riddle 614ace01605SRiver Riddle if (!caseValues) 615ace01605SRiver Riddle return getDefaultDestination(); 616ace01605SRiver Riddle 617ace01605SRiver Riddle SuccessorRange caseDests = getCaseDestinations(); 618c1fa60b4STres Popp if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) { 619ace01605SRiver Riddle for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) 620ace01605SRiver Riddle if (it.value() == value.getValue()) 621ace01605SRiver Riddle return caseDests[it.index()]; 622ace01605SRiver Riddle return getDefaultDestination(); 623ace01605SRiver Riddle } 624ace01605SRiver Riddle return nullptr; 625ace01605SRiver Riddle } 626ace01605SRiver Riddle 627ace01605SRiver Riddle /// switch %flag : i32, [ 628ace01605SRiver Riddle /// default: ^bb1 629ace01605SRiver Riddle /// ] 630ace01605SRiver Riddle /// -> br ^bb1 631ace01605SRiver Riddle static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, 632ace01605SRiver Riddle PatternRewriter &rewriter) { 633ace01605SRiver Riddle if (!op.getCaseDestinations().empty()) 634ace01605SRiver Riddle return failure(); 635ace01605SRiver Riddle 636ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 637ace01605SRiver Riddle op.getDefaultOperands()); 638ace01605SRiver Riddle return success(); 639ace01605SRiver Riddle } 640ace01605SRiver Riddle 641ace01605SRiver Riddle /// switch %flag : i32, [ 642ace01605SRiver Riddle /// default: ^bb1, 643ace01605SRiver Riddle /// 42: ^bb1, 644ace01605SRiver Riddle /// 43: ^bb2 645ace01605SRiver Riddle /// ] 646ace01605SRiver Riddle /// -> 647ace01605SRiver Riddle /// switch 
%flag : i32, [
///   default: ^bb1,
///   43: ^bb2
/// ]
///
/// Removes every case whose destination and operands are identical to the
/// default successor, since taking such a case is indistinguishable from
/// taking the default.
static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
  // A default-only switch has no `case_values` attribute and nothing to drop.
  // Bail out before dereferencing the empty optional. Canonicalization pattern
  // order is not guaranteed, so simplifySwitchWithOnlyDefault may not have run
  // yet.
  std::optional<DenseIntElementsAttr> caseValues = op.getCaseValues();
  if (!caseValues)
    return failure();

  SmallVector<Block *> newCaseDestinations;
  SmallVector<ValueRange> newCaseOperands;
  SmallVector<APInt> newCaseValues;
  bool requiresChange = false;
  auto caseDests = op.getCaseDestinations();
  Block *defaultDest = op.getDefaultDestination();
  auto defaultOperands = op.getDefaultOperands();

  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
    if (caseDests[it.index()] == defaultDest &&
        op.getCaseOperands(it.index()) == defaultOperands) {
      requiresChange = true;
      continue;
    }
    newCaseDestinations.push_back(caseDests[it.index()]);
    newCaseOperands.push_back(op.getCaseOperands(it.index()));
    newCaseValues.push_back(it.value());
  }

  if (!requiresChange)
    return failure();

  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
                                        defaultOperands, newCaseValues,
                                        newCaseDestinations, newCaseOperands);
  return success();
}

/// Helper for folding a switch with a constant value.
681ace01605SRiver Riddle /// switch %c_42 : i32, [ 682ace01605SRiver Riddle /// default: ^bb1 , 683ace01605SRiver Riddle /// 42: ^bb2, 684ace01605SRiver Riddle /// 43: ^bb3 685ace01605SRiver Riddle /// ] 686ace01605SRiver Riddle /// -> br ^bb2 687ace01605SRiver Riddle static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, 688ace01605SRiver Riddle const APInt &caseValue) { 689ace01605SRiver Riddle auto caseValues = op.getCaseValues(); 690ace01605SRiver Riddle for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 691ace01605SRiver Riddle if (it.value() == caseValue) { 692ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>( 693ace01605SRiver Riddle op, op.getCaseDestinations()[it.index()], 694ace01605SRiver Riddle op.getCaseOperands(it.index())); 695ace01605SRiver Riddle return; 696ace01605SRiver Riddle } 697ace01605SRiver Riddle } 698ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 699ace01605SRiver Riddle op.getDefaultOperands()); 700ace01605SRiver Riddle } 701ace01605SRiver Riddle 702ace01605SRiver Riddle /// switch %c_42 : i32, [ 703ace01605SRiver Riddle /// default: ^bb1, 704ace01605SRiver Riddle /// 42: ^bb2, 705ace01605SRiver Riddle /// 43: ^bb3 706ace01605SRiver Riddle /// ] 707ace01605SRiver Riddle /// -> br ^bb2 708ace01605SRiver Riddle static LogicalResult simplifyConstSwitchValue(SwitchOp op, 709ace01605SRiver Riddle PatternRewriter &rewriter) { 710ace01605SRiver Riddle APInt caseValue; 711ace01605SRiver Riddle if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) 712ace01605SRiver Riddle return failure(); 713ace01605SRiver Riddle 714ace01605SRiver Riddle foldSwitch(op, rewriter, caseValue); 715ace01605SRiver Riddle return success(); 716ace01605SRiver Riddle } 717ace01605SRiver Riddle 718ace01605SRiver Riddle /// switch %c_42 : i32, [ 719ace01605SRiver Riddle /// default: ^bb1, 720ace01605SRiver Riddle /// 42: ^bb2, 721ace01605SRiver Riddle /// ] 722ace01605SRiver Riddle 
/// ^bb2: 723ace01605SRiver Riddle /// br ^bb3 724ace01605SRiver Riddle /// -> 725ace01605SRiver Riddle /// switch %c_42 : i32, [ 726ace01605SRiver Riddle /// default: ^bb1, 727ace01605SRiver Riddle /// 42: ^bb3, 728ace01605SRiver Riddle /// ] 729ace01605SRiver Riddle static LogicalResult simplifyPassThroughSwitch(SwitchOp op, 730ace01605SRiver Riddle PatternRewriter &rewriter) { 731ace01605SRiver Riddle SmallVector<Block *> newCaseDests; 732ace01605SRiver Riddle SmallVector<ValueRange> newCaseOperands; 733ace01605SRiver Riddle SmallVector<SmallVector<Value>> argStorage; 734ace01605SRiver Riddle auto caseValues = op.getCaseValues(); 735f735b3a2SVitaly Buka argStorage.reserve(caseValues->size() + 1); 736ace01605SRiver Riddle auto caseDests = op.getCaseDestinations(); 737ace01605SRiver Riddle bool requiresChange = false; 738ace01605SRiver Riddle for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { 739ace01605SRiver Riddle Block *caseDest = caseDests[i]; 740ace01605SRiver Riddle ValueRange caseOperands = op.getCaseOperands(i); 741ace01605SRiver Riddle argStorage.emplace_back(); 742ace01605SRiver Riddle if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back()))) 743ace01605SRiver Riddle requiresChange = true; 744ace01605SRiver Riddle 745ace01605SRiver Riddle newCaseDests.push_back(caseDest); 746ace01605SRiver Riddle newCaseOperands.push_back(caseOperands); 747ace01605SRiver Riddle } 748ace01605SRiver Riddle 749ace01605SRiver Riddle Block *defaultDest = op.getDefaultDestination(); 750ace01605SRiver Riddle ValueRange defaultOperands = op.getDefaultOperands(); 751ace01605SRiver Riddle argStorage.emplace_back(); 752ace01605SRiver Riddle 753ace01605SRiver Riddle if (succeeded( 754ace01605SRiver Riddle collapseBranch(defaultDest, defaultOperands, argStorage.back()))) 755ace01605SRiver Riddle requiresChange = true; 756ace01605SRiver Riddle 757ace01605SRiver Riddle if (!requiresChange) 758ace01605SRiver Riddle return failure(); 759ace01605SRiver Riddle 
760ace01605SRiver Riddle rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest, 7616d5fc1e3SKazu Hirata defaultOperands, *caseValues, 762ace01605SRiver Riddle newCaseDests, newCaseOperands); 763ace01605SRiver Riddle return success(); 764ace01605SRiver Riddle } 765ace01605SRiver Riddle 766ace01605SRiver Riddle /// switch %flag : i32, [ 767ace01605SRiver Riddle /// default: ^bb1, 768ace01605SRiver Riddle /// 42: ^bb2, 769ace01605SRiver Riddle /// ] 770ace01605SRiver Riddle /// ^bb2: 771ace01605SRiver Riddle /// switch %flag : i32, [ 772ace01605SRiver Riddle /// default: ^bb3, 773ace01605SRiver Riddle /// 42: ^bb4 774ace01605SRiver Riddle /// ] 775ace01605SRiver Riddle /// -> 776ace01605SRiver Riddle /// switch %flag : i32, [ 777ace01605SRiver Riddle /// default: ^bb1, 778ace01605SRiver Riddle /// 42: ^bb2, 779ace01605SRiver Riddle /// ] 780ace01605SRiver Riddle /// ^bb2: 781ace01605SRiver Riddle /// br ^bb4 782ace01605SRiver Riddle /// 783ace01605SRiver Riddle /// and 784ace01605SRiver Riddle /// 785ace01605SRiver Riddle /// switch %flag : i32, [ 786ace01605SRiver Riddle /// default: ^bb1, 787ace01605SRiver Riddle /// 42: ^bb2, 788ace01605SRiver Riddle /// ] 789ace01605SRiver Riddle /// ^bb2: 790ace01605SRiver Riddle /// switch %flag : i32, [ 791ace01605SRiver Riddle /// default: ^bb3, 792ace01605SRiver Riddle /// 43: ^bb4 793ace01605SRiver Riddle /// ] 794ace01605SRiver Riddle /// -> 795ace01605SRiver Riddle /// switch %flag : i32, [ 796ace01605SRiver Riddle /// default: ^bb1, 797ace01605SRiver Riddle /// 42: ^bb2, 798ace01605SRiver Riddle /// ] 799ace01605SRiver Riddle /// ^bb2: 800ace01605SRiver Riddle /// br ^bb3 801ace01605SRiver Riddle static LogicalResult 802ace01605SRiver Riddle simplifySwitchFromSwitchOnSameCondition(SwitchOp op, 803ace01605SRiver Riddle PatternRewriter &rewriter) { 804ace01605SRiver Riddle // Check that we have a single distinct predecessor. 
805ace01605SRiver Riddle Block *currentBlock = op->getBlock(); 806ace01605SRiver Riddle Block *predecessor = currentBlock->getSinglePredecessor(); 807ace01605SRiver Riddle if (!predecessor) 808ace01605SRiver Riddle return failure(); 809ace01605SRiver Riddle 810ace01605SRiver Riddle // Check that the predecessor terminates with a switch branch to this block 811ace01605SRiver Riddle // and that it branches on the same condition and that this branch isn't the 812ace01605SRiver Riddle // default destination. 813ace01605SRiver Riddle auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 814ace01605SRiver Riddle if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 815ace01605SRiver Riddle predSwitch.getDefaultDestination() == currentBlock) 816ace01605SRiver Riddle return failure(); 817ace01605SRiver Riddle 818ace01605SRiver Riddle // Fold this switch to an unconditional branch. 819ace01605SRiver Riddle SuccessorRange predDests = predSwitch.getCaseDestinations(); 820ace01605SRiver Riddle auto it = llvm::find(predDests, currentBlock); 821ace01605SRiver Riddle if (it != predDests.end()) { 82222426110SRamkumar Ramachandra std::optional<DenseIntElementsAttr> predCaseValues = 82322426110SRamkumar Ramachandra predSwitch.getCaseValues(); 824ace01605SRiver Riddle foldSwitch(op, rewriter, 825ace01605SRiver Riddle predCaseValues->getValues<APInt>()[it - predDests.begin()]); 826ace01605SRiver Riddle } else { 827ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), 828ace01605SRiver Riddle op.getDefaultOperands()); 829ace01605SRiver Riddle } 830ace01605SRiver Riddle return success(); 831ace01605SRiver Riddle } 832ace01605SRiver Riddle 833ace01605SRiver Riddle /// switch %flag : i32, [ 834ace01605SRiver Riddle /// default: ^bb1, 835ace01605SRiver Riddle /// 42: ^bb2 836ace01605SRiver Riddle /// ] 837ace01605SRiver Riddle /// ^bb1: 838ace01605SRiver Riddle /// switch %flag : i32, [ 839ace01605SRiver Riddle /// default: ^bb3, 
840ace01605SRiver Riddle /// 42: ^bb4, 841ace01605SRiver Riddle /// 43: ^bb5 842ace01605SRiver Riddle /// ] 843ace01605SRiver Riddle /// -> 844ace01605SRiver Riddle /// switch %flag : i32, [ 845ace01605SRiver Riddle /// default: ^bb1, 846ace01605SRiver Riddle /// 42: ^bb2, 847ace01605SRiver Riddle /// ] 848ace01605SRiver Riddle /// ^bb1: 849ace01605SRiver Riddle /// switch %flag : i32, [ 850ace01605SRiver Riddle /// default: ^bb3, 851ace01605SRiver Riddle /// 43: ^bb5 852ace01605SRiver Riddle /// ] 853ace01605SRiver Riddle static LogicalResult 854ace01605SRiver Riddle simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, 855ace01605SRiver Riddle PatternRewriter &rewriter) { 856ace01605SRiver Riddle // Check that we have a single distinct predecessor. 857ace01605SRiver Riddle Block *currentBlock = op->getBlock(); 858ace01605SRiver Riddle Block *predecessor = currentBlock->getSinglePredecessor(); 859ace01605SRiver Riddle if (!predecessor) 860ace01605SRiver Riddle return failure(); 861ace01605SRiver Riddle 862ace01605SRiver Riddle // Check that the predecessor terminates with a switch branch to this block 863ace01605SRiver Riddle // and that it branches on the same condition and that this branch is the 864ace01605SRiver Riddle // default destination. 865ace01605SRiver Riddle auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); 866ace01605SRiver Riddle if (!predSwitch || op.getFlag() != predSwitch.getFlag() || 867ace01605SRiver Riddle predSwitch.getDefaultDestination() != currentBlock) 868ace01605SRiver Riddle return failure(); 869ace01605SRiver Riddle 870ace01605SRiver Riddle // Delete case values that are not possible here. 
871ace01605SRiver Riddle DenseSet<APInt> caseValuesToRemove; 872ace01605SRiver Riddle auto predDests = predSwitch.getCaseDestinations(); 873ace01605SRiver Riddle auto predCaseValues = predSwitch.getCaseValues(); 874ace01605SRiver Riddle for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) 875ace01605SRiver Riddle if (currentBlock != predDests[i]) 876ace01605SRiver Riddle caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]); 877ace01605SRiver Riddle 878ace01605SRiver Riddle SmallVector<Block *> newCaseDestinations; 879ace01605SRiver Riddle SmallVector<ValueRange> newCaseOperands; 880ace01605SRiver Riddle SmallVector<APInt> newCaseValues; 881ace01605SRiver Riddle bool requiresChange = false; 882ace01605SRiver Riddle 883ace01605SRiver Riddle auto caseValues = op.getCaseValues(); 884ace01605SRiver Riddle auto caseDests = op.getCaseDestinations(); 885ace01605SRiver Riddle for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { 886ace01605SRiver Riddle if (caseValuesToRemove.contains(it.value())) { 887ace01605SRiver Riddle requiresChange = true; 888ace01605SRiver Riddle continue; 889ace01605SRiver Riddle } 890ace01605SRiver Riddle newCaseDestinations.push_back(caseDests[it.index()]); 891ace01605SRiver Riddle newCaseOperands.push_back(op.getCaseOperands(it.index())); 892ace01605SRiver Riddle newCaseValues.push_back(it.value()); 893ace01605SRiver Riddle } 894ace01605SRiver Riddle 895ace01605SRiver Riddle if (!requiresChange) 896ace01605SRiver Riddle return failure(); 897ace01605SRiver Riddle 898ace01605SRiver Riddle rewriter.replaceOpWithNewOp<SwitchOp>( 899ace01605SRiver Riddle op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), 900ace01605SRiver Riddle newCaseValues, newCaseDestinations, newCaseOperands); 901ace01605SRiver Riddle return success(); 902ace01605SRiver Riddle } 903ace01605SRiver Riddle 904ace01605SRiver Riddle void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, 905ace01605SRiver 
Riddle MLIRContext *context) { 906ace01605SRiver Riddle results.add(&simplifySwitchWithOnlyDefault) 907ace01605SRiver Riddle .add(&dropSwitchCasesThatMatchDefault) 908ace01605SRiver Riddle .add(&simplifyConstSwitchValue) 909ace01605SRiver Riddle .add(&simplifyPassThroughSwitch) 910ace01605SRiver Riddle .add(&simplifySwitchFromSwitchOnSameCondition) 911ace01605SRiver Riddle .add(&simplifySwitchFromDefaultSwitchOnSameCondition); 912ace01605SRiver Riddle } 913ace01605SRiver Riddle 914ace01605SRiver Riddle //===----------------------------------------------------------------------===// 915ace01605SRiver Riddle // TableGen'd op method definitions 916ace01605SRiver Riddle //===----------------------------------------------------------------------===// 917ace01605SRiver Riddle 918ace01605SRiver Riddle #define GET_OP_CLASSES 919ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" 920