xref: /llvm-project/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (revision 42c31d8302ff8f716601df9c276a5cd9ace6b158)
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Support/MathExtras.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <numeric>
31 
32 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
33 
34 using namespace mlir;
35 using namespace mlir::cf;
36 
37 //===----------------------------------------------------------------------===//
38 // ControlFlowDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 namespace {
41 /// This class defines the interface for handling inlining with control flow
42 /// operations.
43 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
44   using DialectInlinerInterface::DialectInlinerInterface;
45   ~ControlFlowInlinerInterface() override = default;
46 
47   /// All control flow operations can be inlined.
48   bool isLegalToInline(Operation *call, Operation *callable,
49                        bool wouldBeCloned) const final {
50     return true;
51   }
52   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
53     return true;
54   }
55 
56   /// ControlFlow terminator operations don't really need any special handing.
57   void handleTerminator(Operation *op, Block *newDest) const final {}
58 };
59 } // namespace
60 
61 //===----------------------------------------------------------------------===//
62 // ControlFlowDialect
63 //===----------------------------------------------------------------------===//
64 
65 void ControlFlowDialect::initialize() {
66   addOperations<
67 #define GET_OP_LIST
68 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
69       >();
70   addInterfaces<ControlFlowInlinerInterface>();
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // AssertOp
75 //===----------------------------------------------------------------------===//
76 
77 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
78   // Erase assertion if argument is constant true.
79   if (matchPattern(op.getArg(), m_One())) {
80     rewriter.eraseOp(op);
81     return success();
82   }
83   return failure();
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // BranchOp
88 //===----------------------------------------------------------------------===//
89 
90 /// Given a successor, try to collapse it to a new destination if it only
91 /// contains a passthrough unconditional branch. If the successor is
92 /// collapsable, `successor` and `successorOperands` are updated to reference
93 /// the new destination and values. `argStorage` is used as storage if operands
94 /// to the collapsed successor need to be remapped. It must outlive uses of
95 /// successorOperands.
96 static LogicalResult collapseBranch(Block *&successor,
97                                     ValueRange &successorOperands,
98                                     SmallVectorImpl<Value> &argStorage) {
99   // Check that the successor only contains a unconditional branch.
100   if (std::next(successor->begin()) != successor->end())
101     return failure();
102   // Check that the terminator is an unconditional branch.
103   BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
104   if (!successorBranch)
105     return failure();
106   // Check that the arguments are only used within the terminator.
107   for (BlockArgument arg : successor->getArguments()) {
108     for (Operation *user : arg.getUsers())
109       if (user != successorBranch)
110         return failure();
111   }
112   // Don't try to collapse branches to infinite loops.
113   Block *successorDest = successorBranch.getDest();
114   if (successorDest == successor)
115     return failure();
116 
117   // Update the operands to the successor. If the branch parent has no
118   // arguments, we can use the branch operands directly.
119   OperandRange operands = successorBranch.getOperands();
120   if (successor->args_empty()) {
121     successor = successorDest;
122     successorOperands = operands;
123     return success();
124   }
125 
126   // Otherwise, we need to remap any argument operands.
127   for (Value operand : operands) {
128     BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
129     if (argOperand && argOperand.getOwner() == successor)
130       argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
131     else
132       argStorage.push_back(operand);
133   }
134   successor = successorDest;
135   successorOperands = argStorage;
136   return success();
137 }
138 
139 /// Simplify a branch to a block that has a single predecessor. This effectively
140 /// merges the two blocks.
141 static LogicalResult
142 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
143   // Check that the successor block has a single predecessor.
144   Block *succ = op.getDest();
145   Block *opParent = op->getBlock();
146   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
147     return failure();
148 
149   // Merge the successor into the current block and erase the branch.
150   SmallVector<Value> brOperands(op.getOperands());
151   rewriter.eraseOp(op);
152   rewriter.mergeBlocks(succ, opParent, brOperands);
153   return success();
154 }
155 
156 ///   br ^bb1
157 /// ^bb1
158 ///   br ^bbN(...)
159 ///
160 ///  -> br ^bbN(...)
161 ///
162 static LogicalResult simplifyPassThroughBr(BranchOp op,
163                                            PatternRewriter &rewriter) {
164   Block *dest = op.getDest();
165   ValueRange destOperands = op.getOperands();
166   SmallVector<Value, 4> destOperandStorage;
167 
168   // Try to collapse the successor if it points somewhere other than this
169   // block.
170   if (dest == op->getBlock() ||
171       failed(collapseBranch(dest, destOperands, destOperandStorage)))
172     return failure();
173 
174   // Create a new branch with the collapsed successor.
175   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
176   return success();
177 }
178 
179 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
180   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
181                  succeeded(simplifyPassThroughBr(op, rewriter)));
182 }
183 
184 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
185 
186 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
187 
188 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
189   assert(index == 0 && "invalid successor index");
190   return SuccessorOperands(getDestOperandsMutable());
191 }
192 
193 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
194   return getDest();
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // CondBranchOp
199 //===----------------------------------------------------------------------===//
200 
201 namespace {
202 /// cf.cond_br true, ^bb1, ^bb2
203 ///  -> br ^bb1
204 /// cf.cond_br false, ^bb1, ^bb2
205 ///  -> br ^bb2
206 ///
207 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
208   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
209 
210   LogicalResult matchAndRewrite(CondBranchOp condbr,
211                                 PatternRewriter &rewriter) const override {
212     if (matchPattern(condbr.getCondition(), m_NonZero())) {
213       // True branch taken.
214       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
215                                             condbr.getTrueOperands());
216       return success();
217     }
218     if (matchPattern(condbr.getCondition(), m_Zero())) {
219       // False branch taken.
220       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
221                                             condbr.getFalseOperands());
222       return success();
223     }
224     return failure();
225   }
226 };
227 
228 ///   cf.cond_br %cond, ^bb1, ^bb2
229 /// ^bb1
230 ///   br ^bbN(...)
231 /// ^bb2
232 ///   br ^bbK(...)
233 ///
234 ///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
235 ///
236 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
237   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
238 
239   LogicalResult matchAndRewrite(CondBranchOp condbr,
240                                 PatternRewriter &rewriter) const override {
241     Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
242     ValueRange trueDestOperands = condbr.getTrueOperands();
243     ValueRange falseDestOperands = condbr.getFalseOperands();
244     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
245 
246     // Try to collapse one of the current successors.
247     LogicalResult collapsedTrue =
248         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
249     LogicalResult collapsedFalse =
250         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
251     if (failed(collapsedTrue) && failed(collapsedFalse))
252       return failure();
253 
254     // Create a new branch with the collapsed successors.
255     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
256                                               trueDest, trueDestOperands,
257                                               falseDest, falseDestOperands);
258     return success();
259   }
260 };
261 
262 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
263 ///  -> br ^bb1(A, ..., N)
264 ///
265 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
266 ///  -> %select = arith.select %cond, A, B
267 ///     br ^bb1(%select)
268 ///
269 struct SimplifyCondBranchIdenticalSuccessors
270     : public OpRewritePattern<CondBranchOp> {
271   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
272 
273   LogicalResult matchAndRewrite(CondBranchOp condbr,
274                                 PatternRewriter &rewriter) const override {
275     // Check that the true and false destinations are the same and have the same
276     // operands.
277     Block *trueDest = condbr.getTrueDest();
278     if (trueDest != condbr.getFalseDest())
279       return failure();
280 
281     // If all of the operands match, no selects need to be generated.
282     OperandRange trueOperands = condbr.getTrueOperands();
283     OperandRange falseOperands = condbr.getFalseOperands();
284     if (trueOperands == falseOperands) {
285       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
286       return success();
287     }
288 
289     // Otherwise, if the current block is the only predecessor insert selects
290     // for any mismatched branch operands.
291     if (trueDest->getUniquePredecessor() != condbr->getBlock())
292       return failure();
293 
294     // Generate a select for any operands that differ between the two.
295     SmallVector<Value, 8> mergedOperands;
296     mergedOperands.reserve(trueOperands.size());
297     Value condition = condbr.getCondition();
298     for (auto it : llvm::zip(trueOperands, falseOperands)) {
299       if (std::get<0>(it) == std::get<1>(it))
300         mergedOperands.push_back(std::get<0>(it));
301       else
302         mergedOperands.push_back(rewriter.create<arith::SelectOp>(
303             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
304     }
305 
306     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
307     return success();
308   }
309 };
310 
311 ///   ...
312 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
313 /// ...
314 /// ^bb1: // has single predecessor
315 ///   ...
316 ///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
317 ///
318 /// ->
319 ///
320 ///   ...
321 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
322 /// ...
323 /// ^bb1: // has single predecessor
324 ///   ...
325 ///   br ^bb3(...)
326 ///
327 struct SimplifyCondBranchFromCondBranchOnSameCondition
328     : public OpRewritePattern<CondBranchOp> {
329   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
330 
331   LogicalResult matchAndRewrite(CondBranchOp condbr,
332                                 PatternRewriter &rewriter) const override {
333     // Check that we have a single distinct predecessor.
334     Block *currentBlock = condbr->getBlock();
335     Block *predecessor = currentBlock->getSinglePredecessor();
336     if (!predecessor)
337       return failure();
338 
339     // Check that the predecessor terminates with a conditional branch to this
340     // block and that it branches on the same condition.
341     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
342     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
343       return failure();
344 
345     // Fold this branch to an unconditional branch.
346     if (currentBlock == predBranch.getTrueDest())
347       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
348                                             condbr.getTrueDestOperands());
349     else
350       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
351                                             condbr.getFalseDestOperands());
352     return success();
353   }
354 };
355 
356 ///   cf.cond_br %arg0, ^trueB, ^falseB
357 ///
358 /// ^trueB:
359 ///   "test.consumer1"(%arg0) : (i1) -> ()
360 ///    ...
361 ///
362 /// ^falseB:
363 ///   "test.consumer2"(%arg0) : (i1) -> ()
364 ///   ...
365 ///
366 /// ->
367 ///
368 ///   cf.cond_br %arg0, ^trueB, ^falseB
369 /// ^trueB:
370 ///   "test.consumer1"(%true) : (i1) -> ()
371 ///   ...
372 ///
373 /// ^falseB:
374 ///   "test.consumer2"(%false) : (i1) -> ()
375 ///   ...
376 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
377   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
378 
379   LogicalResult matchAndRewrite(CondBranchOp condbr,
380                                 PatternRewriter &rewriter) const override {
381     // Check that we have a single distinct predecessor.
382     bool replaced = false;
383     Type ty = rewriter.getI1Type();
384 
385     // These variables serve to prevent creating duplicate constants
386     // and hold constant true or false values.
387     Value constantTrue = nullptr;
388     Value constantFalse = nullptr;
389 
390     // TODO These checks can be expanded to encompas any use with only
391     // either the true of false edge as a predecessor. For now, we fall
392     // back to checking the single predecessor is given by the true/fasle
393     // destination, thereby ensuring that only that edge can reach the
394     // op.
395     if (condbr.getTrueDest()->getSinglePredecessor()) {
396       for (OpOperand &use :
397            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
398         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
399           replaced = true;
400 
401           if (!constantTrue)
402             constantTrue = rewriter.create<arith::ConstantOp>(
403                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
404 
405           rewriter.updateRootInPlace(use.getOwner(),
406                                      [&] { use.set(constantTrue); });
407         }
408       }
409     }
410     if (condbr.getFalseDest()->getSinglePredecessor()) {
411       for (OpOperand &use :
412            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
413         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
414           replaced = true;
415 
416           if (!constantFalse)
417             constantFalse = rewriter.create<arith::ConstantOp>(
418                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
419 
420           rewriter.updateRootInPlace(use.getOwner(),
421                                      [&] { use.set(constantFalse); });
422         }
423       }
424     }
425     return success(replaced);
426   }
427 };
428 } // namespace
429 
430 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
431                                                MLIRContext *context) {
432   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
433               SimplifyCondBranchIdenticalSuccessors,
434               SimplifyCondBranchFromCondBranchOnSameCondition,
435               CondBranchTruthPropagation>(context);
436 }
437 
438 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
439   assert(index < getNumSuccessors() && "invalid successor index");
440   return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
441                                               : getFalseDestOperandsMutable());
442 }
443 
444 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
445   if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
446     return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
447   return nullptr;
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // SwitchOp
452 //===----------------------------------------------------------------------===//
453 
454 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
455                      Block *defaultDestination, ValueRange defaultOperands,
456                      DenseIntElementsAttr caseValues,
457                      BlockRange caseDestinations,
458                      ArrayRef<ValueRange> caseOperands) {
459   build(builder, result, value, defaultOperands, caseOperands, caseValues,
460         defaultDestination, caseDestinations);
461 }
462 
463 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
464                      Block *defaultDestination, ValueRange defaultOperands,
465                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
466                      ArrayRef<ValueRange> caseOperands) {
467   DenseIntElementsAttr caseValuesAttr;
468   if (!caseValues.empty()) {
469     ShapedType caseValueType = VectorType::get(
470         static_cast<int64_t>(caseValues.size()), value.getType());
471     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
472   }
473   build(builder, result, value, defaultDestination, defaultOperands,
474         caseValuesAttr, caseDestinations, caseOperands);
475 }
476 
477 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
478                      Block *defaultDestination, ValueRange defaultOperands,
479                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
480                      ArrayRef<ValueRange> caseOperands) {
481   DenseIntElementsAttr caseValuesAttr;
482   if (!caseValues.empty()) {
483     ShapedType caseValueType = VectorType::get(
484         static_cast<int64_t>(caseValues.size()), value.getType());
485     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
486   }
487   build(builder, result, value, defaultDestination, defaultOperands,
488         caseValuesAttr, caseDestinations, caseOperands);
489 }
490 
491 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
492 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
493 static ParseResult parseSwitchOpCases(
494     OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
495     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
496     SmallVectorImpl<Type> &defaultOperandTypes,
497     DenseIntElementsAttr &caseValues,
498     SmallVectorImpl<Block *> &caseDestinations,
499     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
500     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
501   if (parser.parseKeyword("default") || parser.parseColon() ||
502       parser.parseSuccessor(defaultDestination))
503     return failure();
504   if (succeeded(parser.parseOptionalLParen())) {
505     if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
506                                 /*allowResultNumber=*/false) ||
507         parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
508       return failure();
509   }
510 
511   SmallVector<APInt> values;
512   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
513   while (succeeded(parser.parseOptionalComma())) {
514     int64_t value = 0;
515     if (failed(parser.parseInteger(value)))
516       return failure();
517     values.push_back(APInt(bitWidth, value));
518 
519     Block *destination;
520     SmallVector<OpAsmParser::UnresolvedOperand> operands;
521     SmallVector<Type> operandTypes;
522     if (failed(parser.parseColon()) ||
523         failed(parser.parseSuccessor(destination)))
524       return failure();
525     if (succeeded(parser.parseOptionalLParen())) {
526       if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
527                                          /*allowResultNumber=*/false)) ||
528           failed(parser.parseColonTypeList(operandTypes)) ||
529           failed(parser.parseRParen()))
530         return failure();
531     }
532     caseDestinations.push_back(destination);
533     caseOperands.emplace_back(operands);
534     caseOperandTypes.emplace_back(operandTypes);
535   }
536 
537   if (!values.empty()) {
538     ShapedType caseValueType =
539         VectorType::get(static_cast<int64_t>(values.size()), flagType);
540     caseValues = DenseIntElementsAttr::get(caseValueType, values);
541   }
542   return success();
543 }
544 
545 static void printSwitchOpCases(
546     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
547     OperandRange defaultOperands, TypeRange defaultOperandTypes,
548     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
549     OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
550   p << "  default: ";
551   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
552 
553   if (!caseValues)
554     return;
555 
556   for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
557     p << ',';
558     p.printNewline();
559     p << "  ";
560     p << it.value().getLimitedValue();
561     p << ": ";
562     p.printSuccessorAndUseList(caseDestinations[it.index()],
563                                caseOperands[it.index()]);
564   }
565   p.printNewline();
566 }
567 
568 LogicalResult SwitchOp::verify() {
569   auto caseValues = getCaseValues();
570   auto caseDestinations = getCaseDestinations();
571 
572   if (!caseValues && caseDestinations.empty())
573     return success();
574 
575   Type flagType = getFlag().getType();
576   Type caseValueType = caseValues->getType().getElementType();
577   if (caseValueType != flagType)
578     return emitOpError() << "'flag' type (" << flagType
579                          << ") should match case value type (" << caseValueType
580                          << ")";
581 
582   if (caseValues &&
583       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
584     return emitOpError() << "number of case values (" << caseValues->size()
585                          << ") should match number of "
586                             "case destinations ("
587                          << caseDestinations.size() << ")";
588   return success();
589 }
590 
591 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
592   assert(index < getNumSuccessors() && "invalid successor index");
593   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
594                                       : getCaseOperandsMutable(index - 1));
595 }
596 
597 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
598   std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
599 
600   if (!caseValues)
601     return getDefaultDestination();
602 
603   SuccessorRange caseDests = getCaseDestinations();
604   if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
605     for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
606       if (it.value() == value.getValue())
607         return caseDests[it.index()];
608     return getDefaultDestination();
609   }
610   return nullptr;
611 }
612 
613 /// switch %flag : i32, [
614 ///   default:  ^bb1
615 /// ]
616 ///  -> br ^bb1
617 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
618                                                    PatternRewriter &rewriter) {
619   if (!op.getCaseDestinations().empty())
620     return failure();
621 
622   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
623                                         op.getDefaultOperands());
624   return success();
625 }
626 
627 /// switch %flag : i32, [
628 ///   default: ^bb1,
629 ///   42: ^bb1,
630 ///   43: ^bb2
631 /// ]
632 /// ->
633 /// switch %flag : i32, [
634 ///   default: ^bb1,
635 ///   43: ^bb2
636 /// ]
637 static LogicalResult
638 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
639   SmallVector<Block *> newCaseDestinations;
640   SmallVector<ValueRange> newCaseOperands;
641   SmallVector<APInt> newCaseValues;
642   bool requiresChange = false;
643   auto caseValues = op.getCaseValues();
644   auto caseDests = op.getCaseDestinations();
645 
646   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
647     if (caseDests[it.index()] == op.getDefaultDestination() &&
648         op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
649       requiresChange = true;
650       continue;
651     }
652     newCaseDestinations.push_back(caseDests[it.index()]);
653     newCaseOperands.push_back(op.getCaseOperands(it.index()));
654     newCaseValues.push_back(it.value());
655   }
656 
657   if (!requiresChange)
658     return failure();
659 
660   rewriter.replaceOpWithNewOp<SwitchOp>(
661       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
662       newCaseValues, newCaseDestinations, newCaseOperands);
663   return success();
664 }
665 
666 /// Helper for folding a switch with a constant value.
667 /// switch %c_42 : i32, [
668 ///   default: ^bb1 ,
669 ///   42: ^bb2,
670 ///   43: ^bb3
671 /// ]
672 /// -> br ^bb2
673 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
674                        const APInt &caseValue) {
675   auto caseValues = op.getCaseValues();
676   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
677     if (it.value() == caseValue) {
678       rewriter.replaceOpWithNewOp<BranchOp>(
679           op, op.getCaseDestinations()[it.index()],
680           op.getCaseOperands(it.index()));
681       return;
682     }
683   }
684   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
685                                         op.getDefaultOperands());
686 }
687 
688 /// switch %c_42 : i32, [
689 ///   default: ^bb1,
690 ///   42: ^bb2,
691 ///   43: ^bb3
692 /// ]
693 /// -> br ^bb2
694 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
695                                               PatternRewriter &rewriter) {
696   APInt caseValue;
697   if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
698     return failure();
699 
700   foldSwitch(op, rewriter, caseValue);
701   return success();
702 }
703 
704 /// switch %c_42 : i32, [
705 ///   default: ^bb1,
706 ///   42: ^bb2,
707 /// ]
708 /// ^bb2:
709 ///   br ^bb3
710 /// ->
711 /// switch %c_42 : i32, [
712 ///   default: ^bb1,
713 ///   42: ^bb3,
714 /// ]
715 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
716                                                PatternRewriter &rewriter) {
717   SmallVector<Block *> newCaseDests;
718   SmallVector<ValueRange> newCaseOperands;
719   SmallVector<SmallVector<Value>> argStorage;
720   auto caseValues = op.getCaseValues();
721   argStorage.reserve(caseValues->size() + 1);
722   auto caseDests = op.getCaseDestinations();
723   bool requiresChange = false;
724   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
725     Block *caseDest = caseDests[i];
726     ValueRange caseOperands = op.getCaseOperands(i);
727     argStorage.emplace_back();
728     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
729       requiresChange = true;
730 
731     newCaseDests.push_back(caseDest);
732     newCaseOperands.push_back(caseOperands);
733   }
734 
735   Block *defaultDest = op.getDefaultDestination();
736   ValueRange defaultOperands = op.getDefaultOperands();
737   argStorage.emplace_back();
738 
739   if (succeeded(
740           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
741     requiresChange = true;
742 
743   if (!requiresChange)
744     return failure();
745 
746   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
747                                         defaultOperands, *caseValues,
748                                         newCaseDests, newCaseOperands);
749   return success();
750 }
751 
752 /// switch %flag : i32, [
753 ///   default: ^bb1,
754 ///   42: ^bb2,
755 /// ]
756 /// ^bb2:
757 ///   switch %flag : i32, [
758 ///     default: ^bb3,
759 ///     42: ^bb4
760 ///   ]
761 /// ->
762 /// switch %flag : i32, [
763 ///   default: ^bb1,
764 ///   42: ^bb2,
765 /// ]
766 /// ^bb2:
767 ///   br ^bb4
768 ///
769 ///  and
770 ///
771 /// switch %flag : i32, [
772 ///   default: ^bb1,
773 ///   42: ^bb2,
774 /// ]
775 /// ^bb2:
776 ///   switch %flag : i32, [
777 ///     default: ^bb3,
778 ///     43: ^bb4
779 ///   ]
780 /// ->
781 /// switch %flag : i32, [
782 ///   default: ^bb1,
783 ///   42: ^bb2,
784 /// ]
785 /// ^bb2:
786 ///   br ^bb3
787 static LogicalResult
788 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
789                                         PatternRewriter &rewriter) {
790   // Check that we have a single distinct predecessor.
791   Block *currentBlock = op->getBlock();
792   Block *predecessor = currentBlock->getSinglePredecessor();
793   if (!predecessor)
794     return failure();
795 
796   // Check that the predecessor terminates with a switch branch to this block
797   // and that it branches on the same condition and that this branch isn't the
798   // default destination.
799   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
800   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
801       predSwitch.getDefaultDestination() == currentBlock)
802     return failure();
803 
804   // Fold this switch to an unconditional branch.
805   SuccessorRange predDests = predSwitch.getCaseDestinations();
806   auto it = llvm::find(predDests, currentBlock);
807   if (it != predDests.end()) {
808     std::optional<DenseIntElementsAttr> predCaseValues =
809         predSwitch.getCaseValues();
810     foldSwitch(op, rewriter,
811                predCaseValues->getValues<APInt>()[it - predDests.begin()]);
812   } else {
813     rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
814                                           op.getDefaultOperands());
815   }
816   return success();
817 }
818 
819 /// switch %flag : i32, [
820 ///   default: ^bb1,
821 ///   42: ^bb2
822 /// ]
823 /// ^bb1:
824 ///   switch %flag : i32, [
825 ///     default: ^bb3,
826 ///     42: ^bb4,
827 ///     43: ^bb5
828 ///   ]
829 /// ->
830 /// switch %flag : i32, [
831 ///   default: ^bb1,
832 ///   42: ^bb2,
833 /// ]
834 /// ^bb1:
835 ///   switch %flag : i32, [
836 ///     default: ^bb3,
837 ///     43: ^bb5
838 ///   ]
839 static LogicalResult
840 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
841                                                PatternRewriter &rewriter) {
842   // Check that we have a single distinct predecessor.
843   Block *currentBlock = op->getBlock();
844   Block *predecessor = currentBlock->getSinglePredecessor();
845   if (!predecessor)
846     return failure();
847 
848   // Check that the predecessor terminates with a switch branch to this block
849   // and that it branches on the same condition and that this branch is the
850   // default destination.
851   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
852   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
853       predSwitch.getDefaultDestination() != currentBlock)
854     return failure();
855 
856   // Delete case values that are not possible here.
857   DenseSet<APInt> caseValuesToRemove;
858   auto predDests = predSwitch.getCaseDestinations();
859   auto predCaseValues = predSwitch.getCaseValues();
860   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
861     if (currentBlock != predDests[i])
862       caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
863 
864   SmallVector<Block *> newCaseDestinations;
865   SmallVector<ValueRange> newCaseOperands;
866   SmallVector<APInt> newCaseValues;
867   bool requiresChange = false;
868 
869   auto caseValues = op.getCaseValues();
870   auto caseDests = op.getCaseDestinations();
871   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
872     if (caseValuesToRemove.contains(it.value())) {
873       requiresChange = true;
874       continue;
875     }
876     newCaseDestinations.push_back(caseDests[it.index()]);
877     newCaseOperands.push_back(op.getCaseOperands(it.index()));
878     newCaseValues.push_back(it.value());
879   }
880 
881   if (!requiresChange)
882     return failure();
883 
884   rewriter.replaceOpWithNewOp<SwitchOp>(
885       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
886       newCaseValues, newCaseDestinations, newCaseOperands);
887   return success();
888 }
889 
890 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
891                                            MLIRContext *context) {
892   results.add(&simplifySwitchWithOnlyDefault)
893       .add(&dropSwitchCasesThatMatchDefault)
894       .add(&simplifyConstSwitchValue)
895       .add(&simplifyPassThroughSwitch)
896       .add(&simplifySwitchFromSwitchOnSameCondition)
897       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
898 }
899 
900 //===----------------------------------------------------------------------===//
901 // TableGen'd op method definitions
902 //===----------------------------------------------------------------------===//
903 
904 #define GET_OP_CLASSES
905 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
906