xref: /llvm-project/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (revision e692af85966903614d470a7742ed89d124baf1a6)
1ace01605SRiver Riddle //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2ace01605SRiver Riddle //
3ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ace01605SRiver Riddle //
7ace01605SRiver Riddle //===----------------------------------------------------------------------===//
8ace01605SRiver Riddle 
9ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10ace01605SRiver Riddle 
11b43c5049SJustin Fargnoli #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
13513cdb82SJustin Fargnoli #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
14513cdb82SJustin Fargnoli #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15ace01605SRiver Riddle #include "mlir/IR/AffineExpr.h"
16ace01605SRiver Riddle #include "mlir/IR/AffineMap.h"
17ace01605SRiver Riddle #include "mlir/IR/Builders.h"
18ace01605SRiver Riddle #include "mlir/IR/BuiltinOps.h"
19ace01605SRiver Riddle #include "mlir/IR/BuiltinTypes.h"
204d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
21ace01605SRiver Riddle #include "mlir/IR/Matchers.h"
22ace01605SRiver Riddle #include "mlir/IR/OpImplementation.h"
23ace01605SRiver Riddle #include "mlir/IR/PatternMatch.h"
24ace01605SRiver Riddle #include "mlir/IR/TypeUtilities.h"
25ace01605SRiver Riddle #include "mlir/IR/Value.h"
26ace01605SRiver Riddle #include "mlir/Transforms/InliningUtils.h"
27ace01605SRiver Riddle #include "llvm/ADT/APFloat.h"
28ace01605SRiver Riddle #include "llvm/ADT/STLExtras.h"
29ace01605SRiver Riddle #include "llvm/Support/FormatVariadic.h"
30ace01605SRiver Riddle #include "llvm/Support/raw_ostream.h"
31ace01605SRiver Riddle #include <numeric>
32ace01605SRiver Riddle 
33ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
34ace01605SRiver Riddle 
35ace01605SRiver Riddle using namespace mlir;
36ace01605SRiver Riddle using namespace mlir::cf;
37ace01605SRiver Riddle 
38ace01605SRiver Riddle //===----------------------------------------------------------------------===//
39ace01605SRiver Riddle // ControlFlowDialect Interfaces
40ace01605SRiver Riddle //===----------------------------------------------------------------------===//
41ace01605SRiver Riddle namespace {
42ace01605SRiver Riddle /// This class defines the interface for handling inlining with control flow
43ace01605SRiver Riddle /// operations.
44ace01605SRiver Riddle struct ControlFlowInlinerInterface : public DialectInlinerInterface {
45ace01605SRiver Riddle   using DialectInlinerInterface::DialectInlinerInterface;
46ace01605SRiver Riddle   ~ControlFlowInlinerInterface() override = default;
47ace01605SRiver Riddle 
48ace01605SRiver Riddle   /// All control flow operations can be inlined.
49ace01605SRiver Riddle   bool isLegalToInline(Operation *call, Operation *callable,
50ace01605SRiver Riddle                        bool wouldBeCloned) const final {
51ace01605SRiver Riddle     return true;
52ace01605SRiver Riddle   }
534d67b278SJeff Niu   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
54ace01605SRiver Riddle     return true;
55ace01605SRiver Riddle   }
56ace01605SRiver Riddle 
57ace01605SRiver Riddle   /// ControlFlow terminator operations don't really need any special handing.
58ace01605SRiver Riddle   void handleTerminator(Operation *op, Block *newDest) const final {}
59ace01605SRiver Riddle };
60ace01605SRiver Riddle } // namespace
61ace01605SRiver Riddle 
62ace01605SRiver Riddle //===----------------------------------------------------------------------===//
63ace01605SRiver Riddle // ControlFlowDialect
64ace01605SRiver Riddle //===----------------------------------------------------------------------===//
65ace01605SRiver Riddle 
66ace01605SRiver Riddle void ControlFlowDialect::initialize() {
67ace01605SRiver Riddle   addOperations<
68ace01605SRiver Riddle #define GET_OP_LIST
69ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
70ace01605SRiver Riddle       >();
71ace01605SRiver Riddle   addInterfaces<ControlFlowInlinerInterface>();
7235d55f28SJustin Fargnoli   declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
73513cdb82SJustin Fargnoli   declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
74513cdb82SJustin Fargnoli                             CondBranchOp>();
7535d55f28SJustin Fargnoli   declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
7635d55f28SJustin Fargnoli                            CondBranchOp>();
77ace01605SRiver Riddle }
78ace01605SRiver Riddle 
79ace01605SRiver Riddle //===----------------------------------------------------------------------===//
80ace01605SRiver Riddle // AssertOp
81ace01605SRiver Riddle //===----------------------------------------------------------------------===//
82ace01605SRiver Riddle 
83ace01605SRiver Riddle LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
84ace01605SRiver Riddle   // Erase assertion if argument is constant true.
85ace01605SRiver Riddle   if (matchPattern(op.getArg(), m_One())) {
86ace01605SRiver Riddle     rewriter.eraseOp(op);
87ace01605SRiver Riddle     return success();
88ace01605SRiver Riddle   }
89ace01605SRiver Riddle   return failure();
90ace01605SRiver Riddle }
91ace01605SRiver Riddle 
92a159b367SMcCowan Zhang // This side effect models "program termination".
93a159b367SMcCowan Zhang void AssertOp::getEffects(
94a159b367SMcCowan Zhang     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
95a159b367SMcCowan Zhang         &effects) {
96a159b367SMcCowan Zhang   effects.emplace_back(MemoryEffects::Write::get());
97a159b367SMcCowan Zhang }
98a159b367SMcCowan Zhang 
99ace01605SRiver Riddle //===----------------------------------------------------------------------===//
100ace01605SRiver Riddle // BranchOp
101ace01605SRiver Riddle //===----------------------------------------------------------------------===//
102ace01605SRiver Riddle 
103ace01605SRiver Riddle /// Given a successor, try to collapse it to a new destination if it only
104ace01605SRiver Riddle /// contains a passthrough unconditional branch. If the successor is
105ace01605SRiver Riddle /// collapsable, `successor` and `successorOperands` are updated to reference
106ace01605SRiver Riddle /// the new destination and values. `argStorage` is used as storage if operands
107ace01605SRiver Riddle /// to the collapsed successor need to be remapped. It must outlive uses of
108ace01605SRiver Riddle /// successorOperands.
109ace01605SRiver Riddle static LogicalResult collapseBranch(Block *&successor,
110ace01605SRiver Riddle                                     ValueRange &successorOperands,
111ace01605SRiver Riddle                                     SmallVectorImpl<Value> &argStorage) {
112ace01605SRiver Riddle   // Check that the successor only contains a unconditional branch.
113ace01605SRiver Riddle   if (std::next(successor->begin()) != successor->end())
114ace01605SRiver Riddle     return failure();
115ace01605SRiver Riddle   // Check that the terminator is an unconditional branch.
116ace01605SRiver Riddle   BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
117ace01605SRiver Riddle   if (!successorBranch)
118ace01605SRiver Riddle     return failure();
119ace01605SRiver Riddle   // Check that the arguments are only used within the terminator.
120ace01605SRiver Riddle   for (BlockArgument arg : successor->getArguments()) {
121ace01605SRiver Riddle     for (Operation *user : arg.getUsers())
122ace01605SRiver Riddle       if (user != successorBranch)
123ace01605SRiver Riddle         return failure();
124ace01605SRiver Riddle   }
125ace01605SRiver Riddle   // Don't try to collapse branches to infinite loops.
126ace01605SRiver Riddle   Block *successorDest = successorBranch.getDest();
127ace01605SRiver Riddle   if (successorDest == successor)
128ace01605SRiver Riddle     return failure();
129ace01605SRiver Riddle 
130ace01605SRiver Riddle   // Update the operands to the successor. If the branch parent has no
131ace01605SRiver Riddle   // arguments, we can use the branch operands directly.
132ace01605SRiver Riddle   OperandRange operands = successorBranch.getOperands();
133ace01605SRiver Riddle   if (successor->args_empty()) {
134ace01605SRiver Riddle     successor = successorDest;
135ace01605SRiver Riddle     successorOperands = operands;
136ace01605SRiver Riddle     return success();
137ace01605SRiver Riddle   }
138ace01605SRiver Riddle 
139ace01605SRiver Riddle   // Otherwise, we need to remap any argument operands.
140ace01605SRiver Riddle   for (Value operand : operands) {
141c1fa60b4STres Popp     BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
142ace01605SRiver Riddle     if (argOperand && argOperand.getOwner() == successor)
143ace01605SRiver Riddle       argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
144ace01605SRiver Riddle     else
145ace01605SRiver Riddle       argStorage.push_back(operand);
146ace01605SRiver Riddle   }
147ace01605SRiver Riddle   successor = successorDest;
148ace01605SRiver Riddle   successorOperands = argStorage;
149ace01605SRiver Riddle   return success();
150ace01605SRiver Riddle }
151ace01605SRiver Riddle 
152ace01605SRiver Riddle /// Simplify a branch to a block that has a single predecessor. This effectively
153ace01605SRiver Riddle /// merges the two blocks.
154ace01605SRiver Riddle static LogicalResult
155ace01605SRiver Riddle simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
156ace01605SRiver Riddle   // Check that the successor block has a single predecessor.
157ace01605SRiver Riddle   Block *succ = op.getDest();
158ace01605SRiver Riddle   Block *opParent = op->getBlock();
159ace01605SRiver Riddle   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
160ace01605SRiver Riddle     return failure();
161ace01605SRiver Riddle 
162ace01605SRiver Riddle   // Merge the successor into the current block and erase the branch.
16342c31d83SMatthias Springer   SmallVector<Value> brOperands(op.getOperands());
164ace01605SRiver Riddle   rewriter.eraseOp(op);
16542c31d83SMatthias Springer   rewriter.mergeBlocks(succ, opParent, brOperands);
166ace01605SRiver Riddle   return success();
167ace01605SRiver Riddle }
168ace01605SRiver Riddle 
169ace01605SRiver Riddle ///   br ^bb1
170ace01605SRiver Riddle /// ^bb1
171ace01605SRiver Riddle ///   br ^bbN(...)
172ace01605SRiver Riddle ///
173ace01605SRiver Riddle ///  -> br ^bbN(...)
174ace01605SRiver Riddle ///
175ace01605SRiver Riddle static LogicalResult simplifyPassThroughBr(BranchOp op,
176ace01605SRiver Riddle                                            PatternRewriter &rewriter) {
177ace01605SRiver Riddle   Block *dest = op.getDest();
178ace01605SRiver Riddle   ValueRange destOperands = op.getOperands();
179ace01605SRiver Riddle   SmallVector<Value, 4> destOperandStorage;
180ace01605SRiver Riddle 
181ace01605SRiver Riddle   // Try to collapse the successor if it points somewhere other than this
182ace01605SRiver Riddle   // block.
183ace01605SRiver Riddle   if (dest == op->getBlock() ||
184ace01605SRiver Riddle       failed(collapseBranch(dest, destOperands, destOperandStorage)))
185ace01605SRiver Riddle     return failure();
186ace01605SRiver Riddle 
187ace01605SRiver Riddle   // Create a new branch with the collapsed successor.
188ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
189ace01605SRiver Riddle   return success();
190ace01605SRiver Riddle }
191ace01605SRiver Riddle 
192ace01605SRiver Riddle LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
193ace01605SRiver Riddle   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
194ace01605SRiver Riddle                  succeeded(simplifyPassThroughBr(op, rewriter)));
195ace01605SRiver Riddle }
196ace01605SRiver Riddle 
197ace01605SRiver Riddle void BranchOp::setDest(Block *block) { return setSuccessor(block); }
198ace01605SRiver Riddle 
199ace01605SRiver Riddle void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
200ace01605SRiver Riddle 
2010c789db5SMarkus Böck SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
202ace01605SRiver Riddle   assert(index == 0 && "invalid successor index");
2030c789db5SMarkus Böck   return SuccessorOperands(getDestOperandsMutable());
204ace01605SRiver Riddle }
205ace01605SRiver Riddle 
206ace01605SRiver Riddle Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
207ace01605SRiver Riddle   return getDest();
208ace01605SRiver Riddle }
209ace01605SRiver Riddle 
210ace01605SRiver Riddle //===----------------------------------------------------------------------===//
211ace01605SRiver Riddle // CondBranchOp
212ace01605SRiver Riddle //===----------------------------------------------------------------------===//
213ace01605SRiver Riddle 
214ace01605SRiver Riddle namespace {
215ace01605SRiver Riddle /// cf.cond_br true, ^bb1, ^bb2
216ace01605SRiver Riddle ///  -> br ^bb1
217ace01605SRiver Riddle /// cf.cond_br false, ^bb1, ^bb2
218ace01605SRiver Riddle ///  -> br ^bb2
219ace01605SRiver Riddle ///
220ace01605SRiver Riddle struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
221ace01605SRiver Riddle   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
222ace01605SRiver Riddle 
223ace01605SRiver Riddle   LogicalResult matchAndRewrite(CondBranchOp condbr,
224ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override {
225ace01605SRiver Riddle     if (matchPattern(condbr.getCondition(), m_NonZero())) {
226ace01605SRiver Riddle       // True branch taken.
227ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
228ace01605SRiver Riddle                                             condbr.getTrueOperands());
229ace01605SRiver Riddle       return success();
230ace01605SRiver Riddle     }
231ace01605SRiver Riddle     if (matchPattern(condbr.getCondition(), m_Zero())) {
232ace01605SRiver Riddle       // False branch taken.
233ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
234ace01605SRiver Riddle                                             condbr.getFalseOperands());
235ace01605SRiver Riddle       return success();
236ace01605SRiver Riddle     }
237ace01605SRiver Riddle     return failure();
238ace01605SRiver Riddle   }
239ace01605SRiver Riddle };
240ace01605SRiver Riddle 
241ace01605SRiver Riddle ///   cf.cond_br %cond, ^bb1, ^bb2
242ace01605SRiver Riddle /// ^bb1
243ace01605SRiver Riddle ///   br ^bbN(...)
244ace01605SRiver Riddle /// ^bb2
245ace01605SRiver Riddle ///   br ^bbK(...)
246ace01605SRiver Riddle ///
247ace01605SRiver Riddle ///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
248ace01605SRiver Riddle ///
249ace01605SRiver Riddle struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
250ace01605SRiver Riddle   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
251ace01605SRiver Riddle 
252ace01605SRiver Riddle   LogicalResult matchAndRewrite(CondBranchOp condbr,
253ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override {
254ace01605SRiver Riddle     Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
255ace01605SRiver Riddle     ValueRange trueDestOperands = condbr.getTrueOperands();
256ace01605SRiver Riddle     ValueRange falseDestOperands = condbr.getFalseOperands();
257ace01605SRiver Riddle     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
258ace01605SRiver Riddle 
259ace01605SRiver Riddle     // Try to collapse one of the current successors.
260ace01605SRiver Riddle     LogicalResult collapsedTrue =
261ace01605SRiver Riddle         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
262ace01605SRiver Riddle     LogicalResult collapsedFalse =
263ace01605SRiver Riddle         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
264ace01605SRiver Riddle     if (failed(collapsedTrue) && failed(collapsedFalse))
265ace01605SRiver Riddle       return failure();
266ace01605SRiver Riddle 
267ace01605SRiver Riddle     // Create a new branch with the collapsed successors.
268ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
269ace01605SRiver Riddle                                               trueDest, trueDestOperands,
270ace01605SRiver Riddle                                               falseDest, falseDestOperands);
271ace01605SRiver Riddle     return success();
272ace01605SRiver Riddle   }
273ace01605SRiver Riddle };
274ace01605SRiver Riddle 
275ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
276ace01605SRiver Riddle ///  -> br ^bb1(A, ..., N)
277ace01605SRiver Riddle ///
278ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
279ace01605SRiver Riddle ///  -> %select = arith.select %cond, A, B
280ace01605SRiver Riddle ///     br ^bb1(%select)
281ace01605SRiver Riddle ///
282ace01605SRiver Riddle struct SimplifyCondBranchIdenticalSuccessors
283ace01605SRiver Riddle     : public OpRewritePattern<CondBranchOp> {
284ace01605SRiver Riddle   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
285ace01605SRiver Riddle 
286ace01605SRiver Riddle   LogicalResult matchAndRewrite(CondBranchOp condbr,
287ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override {
288ace01605SRiver Riddle     // Check that the true and false destinations are the same and have the same
289ace01605SRiver Riddle     // operands.
290ace01605SRiver Riddle     Block *trueDest = condbr.getTrueDest();
291ace01605SRiver Riddle     if (trueDest != condbr.getFalseDest())
292ace01605SRiver Riddle       return failure();
293ace01605SRiver Riddle 
294ace01605SRiver Riddle     // If all of the operands match, no selects need to be generated.
295ace01605SRiver Riddle     OperandRange trueOperands = condbr.getTrueOperands();
296ace01605SRiver Riddle     OperandRange falseOperands = condbr.getFalseOperands();
297ace01605SRiver Riddle     if (trueOperands == falseOperands) {
298ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
299ace01605SRiver Riddle       return success();
300ace01605SRiver Riddle     }
301ace01605SRiver Riddle 
302ace01605SRiver Riddle     // Otherwise, if the current block is the only predecessor insert selects
303ace01605SRiver Riddle     // for any mismatched branch operands.
304ace01605SRiver Riddle     if (trueDest->getUniquePredecessor() != condbr->getBlock())
305ace01605SRiver Riddle       return failure();
306ace01605SRiver Riddle 
307ace01605SRiver Riddle     // Generate a select for any operands that differ between the two.
308ace01605SRiver Riddle     SmallVector<Value, 8> mergedOperands;
309ace01605SRiver Riddle     mergedOperands.reserve(trueOperands.size());
310ace01605SRiver Riddle     Value condition = condbr.getCondition();
311ace01605SRiver Riddle     for (auto it : llvm::zip(trueOperands, falseOperands)) {
312ace01605SRiver Riddle       if (std::get<0>(it) == std::get<1>(it))
313ace01605SRiver Riddle         mergedOperands.push_back(std::get<0>(it));
314ace01605SRiver Riddle       else
315ace01605SRiver Riddle         mergedOperands.push_back(rewriter.create<arith::SelectOp>(
316ace01605SRiver Riddle             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
317ace01605SRiver Riddle     }
318ace01605SRiver Riddle 
319ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
320ace01605SRiver Riddle     return success();
321ace01605SRiver Riddle   }
322ace01605SRiver Riddle };
323ace01605SRiver Riddle 
324ace01605SRiver Riddle ///   ...
325ace01605SRiver Riddle ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
326ace01605SRiver Riddle /// ...
327ace01605SRiver Riddle /// ^bb1: // has single predecessor
328ace01605SRiver Riddle ///   ...
329ace01605SRiver Riddle ///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
330ace01605SRiver Riddle ///
331ace01605SRiver Riddle /// ->
332ace01605SRiver Riddle ///
333ace01605SRiver Riddle ///   ...
334ace01605SRiver Riddle ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
335ace01605SRiver Riddle /// ...
336ace01605SRiver Riddle /// ^bb1: // has single predecessor
337ace01605SRiver Riddle ///   ...
338ace01605SRiver Riddle ///   br ^bb3(...)
339ace01605SRiver Riddle ///
340ace01605SRiver Riddle struct SimplifyCondBranchFromCondBranchOnSameCondition
341ace01605SRiver Riddle     : public OpRewritePattern<CondBranchOp> {
342ace01605SRiver Riddle   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
343ace01605SRiver Riddle 
344ace01605SRiver Riddle   LogicalResult matchAndRewrite(CondBranchOp condbr,
345ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override {
346ace01605SRiver Riddle     // Check that we have a single distinct predecessor.
347ace01605SRiver Riddle     Block *currentBlock = condbr->getBlock();
348ace01605SRiver Riddle     Block *predecessor = currentBlock->getSinglePredecessor();
349ace01605SRiver Riddle     if (!predecessor)
350ace01605SRiver Riddle       return failure();
351ace01605SRiver Riddle 
352ace01605SRiver Riddle     // Check that the predecessor terminates with a conditional branch to this
353ace01605SRiver Riddle     // block and that it branches on the same condition.
354ace01605SRiver Riddle     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
355ace01605SRiver Riddle     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
356ace01605SRiver Riddle       return failure();
357ace01605SRiver Riddle 
358ace01605SRiver Riddle     // Fold this branch to an unconditional branch.
359ace01605SRiver Riddle     if (currentBlock == predBranch.getTrueDest())
360ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
361ace01605SRiver Riddle                                             condbr.getTrueDestOperands());
362ace01605SRiver Riddle     else
363ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
364ace01605SRiver Riddle                                             condbr.getFalseDestOperands());
365ace01605SRiver Riddle     return success();
366ace01605SRiver Riddle   }
367ace01605SRiver Riddle };
368ace01605SRiver Riddle 
369ace01605SRiver Riddle ///   cf.cond_br %arg0, ^trueB, ^falseB
370ace01605SRiver Riddle ///
371ace01605SRiver Riddle /// ^trueB:
372ace01605SRiver Riddle ///   "test.consumer1"(%arg0) : (i1) -> ()
373ace01605SRiver Riddle ///    ...
374ace01605SRiver Riddle ///
375ace01605SRiver Riddle /// ^falseB:
376ace01605SRiver Riddle ///   "test.consumer2"(%arg0) : (i1) -> ()
377ace01605SRiver Riddle ///   ...
378ace01605SRiver Riddle ///
379ace01605SRiver Riddle /// ->
380ace01605SRiver Riddle ///
381ace01605SRiver Riddle ///   cf.cond_br %arg0, ^trueB, ^falseB
382ace01605SRiver Riddle /// ^trueB:
383ace01605SRiver Riddle ///   "test.consumer1"(%true) : (i1) -> ()
384ace01605SRiver Riddle ///   ...
385ace01605SRiver Riddle ///
386ace01605SRiver Riddle /// ^falseB:
387ace01605SRiver Riddle ///   "test.consumer2"(%false) : (i1) -> ()
388ace01605SRiver Riddle ///   ...
389ace01605SRiver Riddle struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
390ace01605SRiver Riddle   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
391ace01605SRiver Riddle 
392ace01605SRiver Riddle   LogicalResult matchAndRewrite(CondBranchOp condbr,
393ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override {
394ace01605SRiver Riddle     // Check that we have a single distinct predecessor.
395ace01605SRiver Riddle     bool replaced = false;
396ace01605SRiver Riddle     Type ty = rewriter.getI1Type();
397ace01605SRiver Riddle 
398ace01605SRiver Riddle     // These variables serve to prevent creating duplicate constants
399ace01605SRiver Riddle     // and hold constant true or false values.
400ace01605SRiver Riddle     Value constantTrue = nullptr;
401ace01605SRiver Riddle     Value constantFalse = nullptr;
402ace01605SRiver Riddle 
403ace01605SRiver Riddle     // TODO These checks can be expanded to encompas any use with only
404ace01605SRiver Riddle     // either the true of false edge as a predecessor. For now, we fall
405ace01605SRiver Riddle     // back to checking the single predecessor is given by the true/fasle
406ace01605SRiver Riddle     // destination, thereby ensuring that only that edge can reach the
407ace01605SRiver Riddle     // op.
408ace01605SRiver Riddle     if (condbr.getTrueDest()->getSinglePredecessor()) {
409ace01605SRiver Riddle       for (OpOperand &use :
410ace01605SRiver Riddle            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
411ace01605SRiver Riddle         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
412ace01605SRiver Riddle           replaced = true;
413ace01605SRiver Riddle 
414ace01605SRiver Riddle           if (!constantTrue)
415ace01605SRiver Riddle             constantTrue = rewriter.create<arith::ConstantOp>(
416ace01605SRiver Riddle                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
417ace01605SRiver Riddle 
4185fcf907bSMatthias Springer           rewriter.modifyOpInPlace(use.getOwner(),
419ace01605SRiver Riddle                                    [&] { use.set(constantTrue); });
420ace01605SRiver Riddle         }
421ace01605SRiver Riddle       }
422ace01605SRiver Riddle     }
423ace01605SRiver Riddle     if (condbr.getFalseDest()->getSinglePredecessor()) {
424ace01605SRiver Riddle       for (OpOperand &use :
425ace01605SRiver Riddle            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
426ace01605SRiver Riddle         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
427ace01605SRiver Riddle           replaced = true;
428ace01605SRiver Riddle 
429ace01605SRiver Riddle           if (!constantFalse)
430ace01605SRiver Riddle             constantFalse = rewriter.create<arith::ConstantOp>(
431ace01605SRiver Riddle                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
432ace01605SRiver Riddle 
4335fcf907bSMatthias Springer           rewriter.modifyOpInPlace(use.getOwner(),
434ace01605SRiver Riddle                                    [&] { use.set(constantFalse); });
435ace01605SRiver Riddle         }
436ace01605SRiver Riddle       }
437ace01605SRiver Riddle     }
438ace01605SRiver Riddle     return success(replaced);
439ace01605SRiver Riddle   }
440ace01605SRiver Riddle };
441ace01605SRiver Riddle } // namespace
442ace01605SRiver Riddle 
443ace01605SRiver Riddle void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
444ace01605SRiver Riddle                                                MLIRContext *context) {
445ace01605SRiver Riddle   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
446ace01605SRiver Riddle               SimplifyCondBranchIdenticalSuccessors,
447ace01605SRiver Riddle               SimplifyCondBranchFromCondBranchOnSameCondition,
448ace01605SRiver Riddle               CondBranchTruthPropagation>(context);
449ace01605SRiver Riddle }
450ace01605SRiver Riddle 
4510c789db5SMarkus Böck SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
452ace01605SRiver Riddle   assert(index < getNumSuccessors() && "invalid successor index");
4530c789db5SMarkus Böck   return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
4540c789db5SMarkus Böck                                               : getFalseDestOperandsMutable());
455ace01605SRiver Riddle }
456ace01605SRiver Riddle 
457ace01605SRiver Riddle Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
458c1fa60b4STres Popp   if (IntegerAttr condAttr =
459c1fa60b4STres Popp           llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
4609e5d2495SKazu Hirata     return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
461ace01605SRiver Riddle   return nullptr;
462ace01605SRiver Riddle }
463ace01605SRiver Riddle 
464ace01605SRiver Riddle //===----------------------------------------------------------------------===//
465ace01605SRiver Riddle // SwitchOp
466ace01605SRiver Riddle //===----------------------------------------------------------------------===//
467ace01605SRiver Riddle 
468ace01605SRiver Riddle void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
469ace01605SRiver Riddle                      Block *defaultDestination, ValueRange defaultOperands,
470ace01605SRiver Riddle                      DenseIntElementsAttr caseValues,
471ace01605SRiver Riddle                      BlockRange caseDestinations,
472ace01605SRiver Riddle                      ArrayRef<ValueRange> caseOperands) {
473ace01605SRiver Riddle   build(builder, result, value, defaultOperands, caseOperands, caseValues,
474ace01605SRiver Riddle         defaultDestination, caseDestinations);
475ace01605SRiver Riddle }
476ace01605SRiver Riddle 
477ace01605SRiver Riddle void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
478ace01605SRiver Riddle                      Block *defaultDestination, ValueRange defaultOperands,
479ace01605SRiver Riddle                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
480ace01605SRiver Riddle                      ArrayRef<ValueRange> caseOperands) {
481ace01605SRiver Riddle   DenseIntElementsAttr caseValuesAttr;
482ace01605SRiver Riddle   if (!caseValues.empty()) {
483ace01605SRiver Riddle     ShapedType caseValueType = VectorType::get(
484ace01605SRiver Riddle         static_cast<int64_t>(caseValues.size()), value.getType());
485ace01605SRiver Riddle     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
486ace01605SRiver Riddle   }
487ace01605SRiver Riddle   build(builder, result, value, defaultDestination, defaultOperands,
488ace01605SRiver Riddle         caseValuesAttr, caseDestinations, caseOperands);
489ace01605SRiver Riddle }
490ace01605SRiver Riddle 
491b34fb277SAlexander Batashev void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
492b34fb277SAlexander Batashev                      Block *defaultDestination, ValueRange defaultOperands,
493b34fb277SAlexander Batashev                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
494b34fb277SAlexander Batashev                      ArrayRef<ValueRange> caseOperands) {
495b34fb277SAlexander Batashev   DenseIntElementsAttr caseValuesAttr;
496b34fb277SAlexander Batashev   if (!caseValues.empty()) {
497b34fb277SAlexander Batashev     ShapedType caseValueType = VectorType::get(
498b34fb277SAlexander Batashev         static_cast<int64_t>(caseValues.size()), value.getType());
499b34fb277SAlexander Batashev     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
500b34fb277SAlexander Batashev   }
501b34fb277SAlexander Batashev   build(builder, result, value, defaultDestination, defaultOperands,
502b34fb277SAlexander Batashev         caseValuesAttr, caseDestinations, caseOperands);
503b34fb277SAlexander Batashev }
504b34fb277SAlexander Batashev 
505ace01605SRiver Riddle /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
506ace01605SRiver Riddle ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
507ace01605SRiver Riddle static ParseResult parseSwitchOpCases(
508ace01605SRiver Riddle     OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
509e13d23bcSMarkus Böck     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
510ace01605SRiver Riddle     SmallVectorImpl<Type> &defaultOperandTypes,
511ace01605SRiver Riddle     DenseIntElementsAttr &caseValues,
512ace01605SRiver Riddle     SmallVectorImpl<Block *> &caseDestinations,
513e13d23bcSMarkus Böck     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
514ace01605SRiver Riddle     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
515ace01605SRiver Riddle   if (parser.parseKeyword("default") || parser.parseColon() ||
516ace01605SRiver Riddle       parser.parseSuccessor(defaultDestination))
517ace01605SRiver Riddle     return failure();
518ace01605SRiver Riddle   if (succeeded(parser.parseOptionalLParen())) {
5195dedf911SChris Lattner     if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
5205dedf911SChris Lattner                                 /*allowResultNumber=*/false) ||
521ace01605SRiver Riddle         parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
522ace01605SRiver Riddle       return failure();
523ace01605SRiver Riddle   }
524ace01605SRiver Riddle 
525ace01605SRiver Riddle   SmallVector<APInt> values;
526ace01605SRiver Riddle   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
527ace01605SRiver Riddle   while (succeeded(parser.parseOptionalComma())) {
528ace01605SRiver Riddle     int64_t value = 0;
529ace01605SRiver Riddle     if (failed(parser.parseInteger(value)))
530ace01605SRiver Riddle       return failure();
531*e692af85SNikita Popov     values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
532ace01605SRiver Riddle 
533ace01605SRiver Riddle     Block *destination;
534e13d23bcSMarkus Böck     SmallVector<OpAsmParser::UnresolvedOperand> operands;
535ace01605SRiver Riddle     SmallVector<Type> operandTypes;
536ace01605SRiver Riddle     if (failed(parser.parseColon()) ||
537ace01605SRiver Riddle         failed(parser.parseSuccessor(destination)))
538ace01605SRiver Riddle       return failure();
539ace01605SRiver Riddle     if (succeeded(parser.parseOptionalLParen())) {
5407e87d03bSKeyi Zhang       if (failed(parser.parseOperandList(operands,
5417e87d03bSKeyi Zhang                                          OpAsmParser::Delimiter::None)) ||
542ace01605SRiver Riddle           failed(parser.parseColonTypeList(operandTypes)) ||
543ace01605SRiver Riddle           failed(parser.parseRParen()))
544ace01605SRiver Riddle         return failure();
545ace01605SRiver Riddle     }
546ace01605SRiver Riddle     caseDestinations.push_back(destination);
547ace01605SRiver Riddle     caseOperands.emplace_back(operands);
548ace01605SRiver Riddle     caseOperandTypes.emplace_back(operandTypes);
549ace01605SRiver Riddle   }
550ace01605SRiver Riddle 
551ace01605SRiver Riddle   if (!values.empty()) {
552ace01605SRiver Riddle     ShapedType caseValueType =
553ace01605SRiver Riddle         VectorType::get(static_cast<int64_t>(values.size()), flagType);
554ace01605SRiver Riddle     caseValues = DenseIntElementsAttr::get(caseValueType, values);
555ace01605SRiver Riddle   }
556ace01605SRiver Riddle   return success();
557ace01605SRiver Riddle }
558ace01605SRiver Riddle 
559ace01605SRiver Riddle static void printSwitchOpCases(
560ace01605SRiver Riddle     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
561ace01605SRiver Riddle     OperandRange defaultOperands, TypeRange defaultOperandTypes,
562ace01605SRiver Riddle     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
563ace01605SRiver Riddle     OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
564ace01605SRiver Riddle   p << "  default: ";
565ace01605SRiver Riddle   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
566ace01605SRiver Riddle 
567ace01605SRiver Riddle   if (!caseValues)
568ace01605SRiver Riddle     return;
569ace01605SRiver Riddle 
570ace01605SRiver Riddle   for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
571ace01605SRiver Riddle     p << ',';
572ace01605SRiver Riddle     p.printNewline();
573ace01605SRiver Riddle     p << "  ";
574ace01605SRiver Riddle     p << it.value().getLimitedValue();
575ace01605SRiver Riddle     p << ": ";
576ace01605SRiver Riddle     p.printSuccessorAndUseList(caseDestinations[it.index()],
577ace01605SRiver Riddle                                caseOperands[it.index()]);
578ace01605SRiver Riddle   }
579ace01605SRiver Riddle   p.printNewline();
580ace01605SRiver Riddle }
581ace01605SRiver Riddle 
582ace01605SRiver Riddle LogicalResult SwitchOp::verify() {
583ace01605SRiver Riddle   auto caseValues = getCaseValues();
584ace01605SRiver Riddle   auto caseDestinations = getCaseDestinations();
585ace01605SRiver Riddle 
586ace01605SRiver Riddle   if (!caseValues && caseDestinations.empty())
587ace01605SRiver Riddle     return success();
588ace01605SRiver Riddle 
589ace01605SRiver Riddle   Type flagType = getFlag().getType();
590ace01605SRiver Riddle   Type caseValueType = caseValues->getType().getElementType();
591ace01605SRiver Riddle   if (caseValueType != flagType)
592ace01605SRiver Riddle     return emitOpError() << "'flag' type (" << flagType
593ace01605SRiver Riddle                          << ") should match case value type (" << caseValueType
594ace01605SRiver Riddle                          << ")";
595ace01605SRiver Riddle 
596ace01605SRiver Riddle   if (caseValues &&
597ace01605SRiver Riddle       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
598ace01605SRiver Riddle     return emitOpError() << "number of case values (" << caseValues->size()
599ace01605SRiver Riddle                          << ") should match number of "
600ace01605SRiver Riddle                             "case destinations ("
601ace01605SRiver Riddle                          << caseDestinations.size() << ")";
602ace01605SRiver Riddle   return success();
603ace01605SRiver Riddle }
604ace01605SRiver Riddle 
6050c789db5SMarkus Böck SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
606ace01605SRiver Riddle   assert(index < getNumSuccessors() && "invalid successor index");
6070c789db5SMarkus Böck   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
6080c789db5SMarkus Böck                                       : getCaseOperandsMutable(index - 1));
609ace01605SRiver Riddle }
610ace01605SRiver Riddle 
611ace01605SRiver Riddle Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
61222426110SRamkumar Ramachandra   std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
613ace01605SRiver Riddle 
614ace01605SRiver Riddle   if (!caseValues)
615ace01605SRiver Riddle     return getDefaultDestination();
616ace01605SRiver Riddle 
617ace01605SRiver Riddle   SuccessorRange caseDests = getCaseDestinations();
618c1fa60b4STres Popp   if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
619ace01605SRiver Riddle     for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
620ace01605SRiver Riddle       if (it.value() == value.getValue())
621ace01605SRiver Riddle         return caseDests[it.index()];
622ace01605SRiver Riddle     return getDefaultDestination();
623ace01605SRiver Riddle   }
624ace01605SRiver Riddle   return nullptr;
625ace01605SRiver Riddle }
626ace01605SRiver Riddle 
627ace01605SRiver Riddle /// switch %flag : i32, [
628ace01605SRiver Riddle ///   default:  ^bb1
629ace01605SRiver Riddle /// ]
630ace01605SRiver Riddle ///  -> br ^bb1
631ace01605SRiver Riddle static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
632ace01605SRiver Riddle                                                    PatternRewriter &rewriter) {
633ace01605SRiver Riddle   if (!op.getCaseDestinations().empty())
634ace01605SRiver Riddle     return failure();
635ace01605SRiver Riddle 
636ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
637ace01605SRiver Riddle                                         op.getDefaultOperands());
638ace01605SRiver Riddle   return success();
639ace01605SRiver Riddle }
640ace01605SRiver Riddle 
641ace01605SRiver Riddle /// switch %flag : i32, [
642ace01605SRiver Riddle ///   default: ^bb1,
643ace01605SRiver Riddle ///   42: ^bb1,
644ace01605SRiver Riddle ///   43: ^bb2
645ace01605SRiver Riddle /// ]
646ace01605SRiver Riddle /// ->
647ace01605SRiver Riddle /// switch %flag : i32, [
648ace01605SRiver Riddle ///   default: ^bb1,
649ace01605SRiver Riddle ///   43: ^bb2
650ace01605SRiver Riddle /// ]
651ace01605SRiver Riddle static LogicalResult
652ace01605SRiver Riddle dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
653ace01605SRiver Riddle   SmallVector<Block *> newCaseDestinations;
654ace01605SRiver Riddle   SmallVector<ValueRange> newCaseOperands;
655ace01605SRiver Riddle   SmallVector<APInt> newCaseValues;
656ace01605SRiver Riddle   bool requiresChange = false;
657ace01605SRiver Riddle   auto caseValues = op.getCaseValues();
658ace01605SRiver Riddle   auto caseDests = op.getCaseDestinations();
659ace01605SRiver Riddle 
660ace01605SRiver Riddle   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
661ace01605SRiver Riddle     if (caseDests[it.index()] == op.getDefaultDestination() &&
662ace01605SRiver Riddle         op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
663ace01605SRiver Riddle       requiresChange = true;
664ace01605SRiver Riddle       continue;
665ace01605SRiver Riddle     }
666ace01605SRiver Riddle     newCaseDestinations.push_back(caseDests[it.index()]);
667ace01605SRiver Riddle     newCaseOperands.push_back(op.getCaseOperands(it.index()));
668ace01605SRiver Riddle     newCaseValues.push_back(it.value());
669ace01605SRiver Riddle   }
670ace01605SRiver Riddle 
671ace01605SRiver Riddle   if (!requiresChange)
672ace01605SRiver Riddle     return failure();
673ace01605SRiver Riddle 
674ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<SwitchOp>(
675ace01605SRiver Riddle       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
676ace01605SRiver Riddle       newCaseValues, newCaseDestinations, newCaseOperands);
677ace01605SRiver Riddle   return success();
678ace01605SRiver Riddle }
679ace01605SRiver Riddle 
680ace01605SRiver Riddle /// Helper for folding a switch with a constant value.
681ace01605SRiver Riddle /// switch %c_42 : i32, [
682ace01605SRiver Riddle ///   default: ^bb1 ,
683ace01605SRiver Riddle ///   42: ^bb2,
684ace01605SRiver Riddle ///   43: ^bb3
685ace01605SRiver Riddle /// ]
686ace01605SRiver Riddle /// -> br ^bb2
687ace01605SRiver Riddle static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
688ace01605SRiver Riddle                        const APInt &caseValue) {
689ace01605SRiver Riddle   auto caseValues = op.getCaseValues();
690ace01605SRiver Riddle   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
691ace01605SRiver Riddle     if (it.value() == caseValue) {
692ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(
693ace01605SRiver Riddle           op, op.getCaseDestinations()[it.index()],
694ace01605SRiver Riddle           op.getCaseOperands(it.index()));
695ace01605SRiver Riddle       return;
696ace01605SRiver Riddle     }
697ace01605SRiver Riddle   }
698ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
699ace01605SRiver Riddle                                         op.getDefaultOperands());
700ace01605SRiver Riddle }
701ace01605SRiver Riddle 
702ace01605SRiver Riddle /// switch %c_42 : i32, [
703ace01605SRiver Riddle ///   default: ^bb1,
704ace01605SRiver Riddle ///   42: ^bb2,
705ace01605SRiver Riddle ///   43: ^bb3
706ace01605SRiver Riddle /// ]
707ace01605SRiver Riddle /// -> br ^bb2
708ace01605SRiver Riddle static LogicalResult simplifyConstSwitchValue(SwitchOp op,
709ace01605SRiver Riddle                                               PatternRewriter &rewriter) {
710ace01605SRiver Riddle   APInt caseValue;
711ace01605SRiver Riddle   if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
712ace01605SRiver Riddle     return failure();
713ace01605SRiver Riddle 
714ace01605SRiver Riddle   foldSwitch(op, rewriter, caseValue);
715ace01605SRiver Riddle   return success();
716ace01605SRiver Riddle }
717ace01605SRiver Riddle 
718ace01605SRiver Riddle /// switch %c_42 : i32, [
719ace01605SRiver Riddle ///   default: ^bb1,
720ace01605SRiver Riddle ///   42: ^bb2,
721ace01605SRiver Riddle /// ]
722ace01605SRiver Riddle /// ^bb2:
723ace01605SRiver Riddle ///   br ^bb3
724ace01605SRiver Riddle /// ->
725ace01605SRiver Riddle /// switch %c_42 : i32, [
726ace01605SRiver Riddle ///   default: ^bb1,
727ace01605SRiver Riddle ///   42: ^bb3,
728ace01605SRiver Riddle /// ]
729ace01605SRiver Riddle static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
730ace01605SRiver Riddle                                                PatternRewriter &rewriter) {
731ace01605SRiver Riddle   SmallVector<Block *> newCaseDests;
732ace01605SRiver Riddle   SmallVector<ValueRange> newCaseOperands;
733ace01605SRiver Riddle   SmallVector<SmallVector<Value>> argStorage;
734ace01605SRiver Riddle   auto caseValues = op.getCaseValues();
735f735b3a2SVitaly Buka   argStorage.reserve(caseValues->size() + 1);
736ace01605SRiver Riddle   auto caseDests = op.getCaseDestinations();
737ace01605SRiver Riddle   bool requiresChange = false;
738ace01605SRiver Riddle   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
739ace01605SRiver Riddle     Block *caseDest = caseDests[i];
740ace01605SRiver Riddle     ValueRange caseOperands = op.getCaseOperands(i);
741ace01605SRiver Riddle     argStorage.emplace_back();
742ace01605SRiver Riddle     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
743ace01605SRiver Riddle       requiresChange = true;
744ace01605SRiver Riddle 
745ace01605SRiver Riddle     newCaseDests.push_back(caseDest);
746ace01605SRiver Riddle     newCaseOperands.push_back(caseOperands);
747ace01605SRiver Riddle   }
748ace01605SRiver Riddle 
749ace01605SRiver Riddle   Block *defaultDest = op.getDefaultDestination();
750ace01605SRiver Riddle   ValueRange defaultOperands = op.getDefaultOperands();
751ace01605SRiver Riddle   argStorage.emplace_back();
752ace01605SRiver Riddle 
753ace01605SRiver Riddle   if (succeeded(
754ace01605SRiver Riddle           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
755ace01605SRiver Riddle     requiresChange = true;
756ace01605SRiver Riddle 
757ace01605SRiver Riddle   if (!requiresChange)
758ace01605SRiver Riddle     return failure();
759ace01605SRiver Riddle 
760ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
7616d5fc1e3SKazu Hirata                                         defaultOperands, *caseValues,
762ace01605SRiver Riddle                                         newCaseDests, newCaseOperands);
763ace01605SRiver Riddle   return success();
764ace01605SRiver Riddle }
765ace01605SRiver Riddle 
766ace01605SRiver Riddle /// switch %flag : i32, [
767ace01605SRiver Riddle ///   default: ^bb1,
768ace01605SRiver Riddle ///   42: ^bb2,
769ace01605SRiver Riddle /// ]
770ace01605SRiver Riddle /// ^bb2:
771ace01605SRiver Riddle ///   switch %flag : i32, [
772ace01605SRiver Riddle ///     default: ^bb3,
773ace01605SRiver Riddle ///     42: ^bb4
774ace01605SRiver Riddle ///   ]
775ace01605SRiver Riddle /// ->
776ace01605SRiver Riddle /// switch %flag : i32, [
777ace01605SRiver Riddle ///   default: ^bb1,
778ace01605SRiver Riddle ///   42: ^bb2,
779ace01605SRiver Riddle /// ]
780ace01605SRiver Riddle /// ^bb2:
781ace01605SRiver Riddle ///   br ^bb4
782ace01605SRiver Riddle ///
783ace01605SRiver Riddle ///  and
784ace01605SRiver Riddle ///
785ace01605SRiver Riddle /// switch %flag : i32, [
786ace01605SRiver Riddle ///   default: ^bb1,
787ace01605SRiver Riddle ///   42: ^bb2,
788ace01605SRiver Riddle /// ]
789ace01605SRiver Riddle /// ^bb2:
790ace01605SRiver Riddle ///   switch %flag : i32, [
791ace01605SRiver Riddle ///     default: ^bb3,
792ace01605SRiver Riddle ///     43: ^bb4
793ace01605SRiver Riddle ///   ]
794ace01605SRiver Riddle /// ->
795ace01605SRiver Riddle /// switch %flag : i32, [
796ace01605SRiver Riddle ///   default: ^bb1,
797ace01605SRiver Riddle ///   42: ^bb2,
798ace01605SRiver Riddle /// ]
799ace01605SRiver Riddle /// ^bb2:
800ace01605SRiver Riddle ///   br ^bb3
801ace01605SRiver Riddle static LogicalResult
802ace01605SRiver Riddle simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
803ace01605SRiver Riddle                                         PatternRewriter &rewriter) {
804ace01605SRiver Riddle   // Check that we have a single distinct predecessor.
805ace01605SRiver Riddle   Block *currentBlock = op->getBlock();
806ace01605SRiver Riddle   Block *predecessor = currentBlock->getSinglePredecessor();
807ace01605SRiver Riddle   if (!predecessor)
808ace01605SRiver Riddle     return failure();
809ace01605SRiver Riddle 
810ace01605SRiver Riddle   // Check that the predecessor terminates with a switch branch to this block
811ace01605SRiver Riddle   // and that it branches on the same condition and that this branch isn't the
812ace01605SRiver Riddle   // default destination.
813ace01605SRiver Riddle   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
814ace01605SRiver Riddle   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
815ace01605SRiver Riddle       predSwitch.getDefaultDestination() == currentBlock)
816ace01605SRiver Riddle     return failure();
817ace01605SRiver Riddle 
818ace01605SRiver Riddle   // Fold this switch to an unconditional branch.
819ace01605SRiver Riddle   SuccessorRange predDests = predSwitch.getCaseDestinations();
820ace01605SRiver Riddle   auto it = llvm::find(predDests, currentBlock);
821ace01605SRiver Riddle   if (it != predDests.end()) {
82222426110SRamkumar Ramachandra     std::optional<DenseIntElementsAttr> predCaseValues =
82322426110SRamkumar Ramachandra         predSwitch.getCaseValues();
824ace01605SRiver Riddle     foldSwitch(op, rewriter,
825ace01605SRiver Riddle                predCaseValues->getValues<APInt>()[it - predDests.begin()]);
826ace01605SRiver Riddle   } else {
827ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
828ace01605SRiver Riddle                                           op.getDefaultOperands());
829ace01605SRiver Riddle   }
830ace01605SRiver Riddle   return success();
831ace01605SRiver Riddle }
832ace01605SRiver Riddle 
833ace01605SRiver Riddle /// switch %flag : i32, [
834ace01605SRiver Riddle ///   default: ^bb1,
835ace01605SRiver Riddle ///   42: ^bb2
836ace01605SRiver Riddle /// ]
837ace01605SRiver Riddle /// ^bb1:
838ace01605SRiver Riddle ///   switch %flag : i32, [
839ace01605SRiver Riddle ///     default: ^bb3,
840ace01605SRiver Riddle ///     42: ^bb4,
841ace01605SRiver Riddle ///     43: ^bb5
842ace01605SRiver Riddle ///   ]
843ace01605SRiver Riddle /// ->
844ace01605SRiver Riddle /// switch %flag : i32, [
845ace01605SRiver Riddle ///   default: ^bb1,
846ace01605SRiver Riddle ///   42: ^bb2,
847ace01605SRiver Riddle /// ]
848ace01605SRiver Riddle /// ^bb1:
849ace01605SRiver Riddle ///   switch %flag : i32, [
850ace01605SRiver Riddle ///     default: ^bb3,
851ace01605SRiver Riddle ///     43: ^bb5
852ace01605SRiver Riddle ///   ]
853ace01605SRiver Riddle static LogicalResult
854ace01605SRiver Riddle simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
855ace01605SRiver Riddle                                                PatternRewriter &rewriter) {
856ace01605SRiver Riddle   // Check that we have a single distinct predecessor.
857ace01605SRiver Riddle   Block *currentBlock = op->getBlock();
858ace01605SRiver Riddle   Block *predecessor = currentBlock->getSinglePredecessor();
859ace01605SRiver Riddle   if (!predecessor)
860ace01605SRiver Riddle     return failure();
861ace01605SRiver Riddle 
862ace01605SRiver Riddle   // Check that the predecessor terminates with a switch branch to this block
863ace01605SRiver Riddle   // and that it branches on the same condition and that this branch is the
864ace01605SRiver Riddle   // default destination.
865ace01605SRiver Riddle   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
866ace01605SRiver Riddle   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
867ace01605SRiver Riddle       predSwitch.getDefaultDestination() != currentBlock)
868ace01605SRiver Riddle     return failure();
869ace01605SRiver Riddle 
870ace01605SRiver Riddle   // Delete case values that are not possible here.
871ace01605SRiver Riddle   DenseSet<APInt> caseValuesToRemove;
872ace01605SRiver Riddle   auto predDests = predSwitch.getCaseDestinations();
873ace01605SRiver Riddle   auto predCaseValues = predSwitch.getCaseValues();
874ace01605SRiver Riddle   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
875ace01605SRiver Riddle     if (currentBlock != predDests[i])
876ace01605SRiver Riddle       caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
877ace01605SRiver Riddle 
878ace01605SRiver Riddle   SmallVector<Block *> newCaseDestinations;
879ace01605SRiver Riddle   SmallVector<ValueRange> newCaseOperands;
880ace01605SRiver Riddle   SmallVector<APInt> newCaseValues;
881ace01605SRiver Riddle   bool requiresChange = false;
882ace01605SRiver Riddle 
883ace01605SRiver Riddle   auto caseValues = op.getCaseValues();
884ace01605SRiver Riddle   auto caseDests = op.getCaseDestinations();
885ace01605SRiver Riddle   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
886ace01605SRiver Riddle     if (caseValuesToRemove.contains(it.value())) {
887ace01605SRiver Riddle       requiresChange = true;
888ace01605SRiver Riddle       continue;
889ace01605SRiver Riddle     }
890ace01605SRiver Riddle     newCaseDestinations.push_back(caseDests[it.index()]);
891ace01605SRiver Riddle     newCaseOperands.push_back(op.getCaseOperands(it.index()));
892ace01605SRiver Riddle     newCaseValues.push_back(it.value());
893ace01605SRiver Riddle   }
894ace01605SRiver Riddle 
895ace01605SRiver Riddle   if (!requiresChange)
896ace01605SRiver Riddle     return failure();
897ace01605SRiver Riddle 
898ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<SwitchOp>(
899ace01605SRiver Riddle       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
900ace01605SRiver Riddle       newCaseValues, newCaseDestinations, newCaseOperands);
901ace01605SRiver Riddle   return success();
902ace01605SRiver Riddle }
903ace01605SRiver Riddle 
904ace01605SRiver Riddle void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
905ace01605SRiver Riddle                                            MLIRContext *context) {
906ace01605SRiver Riddle   results.add(&simplifySwitchWithOnlyDefault)
907ace01605SRiver Riddle       .add(&dropSwitchCasesThatMatchDefault)
908ace01605SRiver Riddle       .add(&simplifyConstSwitchValue)
909ace01605SRiver Riddle       .add(&simplifyPassThroughSwitch)
910ace01605SRiver Riddle       .add(&simplifySwitchFromSwitchOnSameCondition)
911ace01605SRiver Riddle       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
912ace01605SRiver Riddle }
913ace01605SRiver Riddle 
914ace01605SRiver Riddle //===----------------------------------------------------------------------===//
915ace01605SRiver Riddle // TableGen'd op method definitions
916ace01605SRiver Riddle //===----------------------------------------------------------------------===//
917ace01605SRiver Riddle 
918ace01605SRiver Riddle #define GET_OP_CLASSES
919ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
920