xref: /llvm-project/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (revision c1fa60b4cde512964544ab66404dea79dbc5dcb4)
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Support/MathExtras.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <numeric>
31 
32 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
33 
34 using namespace mlir;
35 using namespace mlir::cf;
36 
37 //===----------------------------------------------------------------------===//
38 // ControlFlowDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 namespace {
41 /// This class defines the interface for handling inlining with control flow
42 /// operations.
43 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
44   using DialectInlinerInterface::DialectInlinerInterface;
45   ~ControlFlowInlinerInterface() override = default;
46 
47   /// All control flow operations can be inlined.
48   bool isLegalToInline(Operation *call, Operation *callable,
49                        bool wouldBeCloned) const final {
50     return true;
51   }
52   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
53     return true;
54   }
55 
56   /// ControlFlow terminator operations don't really need any special handing.
57   void handleTerminator(Operation *op, Block *newDest) const final {}
58 };
59 } // namespace
60 
61 //===----------------------------------------------------------------------===//
62 // ControlFlowDialect
63 //===----------------------------------------------------------------------===//
64 
/// Registers the operations and interfaces of the ControlFlow dialect.
void ControlFlowDialect::initialize() {
  // The template argument list is the operation list pulled in from the
  // ControlFlowOps.cpp.inc file, selected by defining GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
      >();
  // Attach the inliner interface defined above.
  addInterfaces<ControlFlowInlinerInterface>();
}
72 
73 //===----------------------------------------------------------------------===//
74 // AssertOp
75 //===----------------------------------------------------------------------===//
76 
77 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
78   // Erase assertion if argument is constant true.
79   if (matchPattern(op.getArg(), m_One())) {
80     rewriter.eraseOp(op);
81     return success();
82   }
83   return failure();
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // BranchOp
88 //===----------------------------------------------------------------------===//
89 
90 /// Given a successor, try to collapse it to a new destination if it only
91 /// contains a passthrough unconditional branch. If the successor is
92 /// collapsable, `successor` and `successorOperands` are updated to reference
93 /// the new destination and values. `argStorage` is used as storage if operands
94 /// to the collapsed successor need to be remapped. It must outlive uses of
95 /// successorOperands.
96 static LogicalResult collapseBranch(Block *&successor,
97                                     ValueRange &successorOperands,
98                                     SmallVectorImpl<Value> &argStorage) {
99   // Check that the successor only contains a unconditional branch.
100   if (std::next(successor->begin()) != successor->end())
101     return failure();
102   // Check that the terminator is an unconditional branch.
103   BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
104   if (!successorBranch)
105     return failure();
106   // Check that the arguments are only used within the terminator.
107   for (BlockArgument arg : successor->getArguments()) {
108     for (Operation *user : arg.getUsers())
109       if (user != successorBranch)
110         return failure();
111   }
112   // Don't try to collapse branches to infinite loops.
113   Block *successorDest = successorBranch.getDest();
114   if (successorDest == successor)
115     return failure();
116 
117   // Update the operands to the successor. If the branch parent has no
118   // arguments, we can use the branch operands directly.
119   OperandRange operands = successorBranch.getOperands();
120   if (successor->args_empty()) {
121     successor = successorDest;
122     successorOperands = operands;
123     return success();
124   }
125 
126   // Otherwise, we need to remap any argument operands.
127   for (Value operand : operands) {
128     BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
129     if (argOperand && argOperand.getOwner() == successor)
130       argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
131     else
132       argStorage.push_back(operand);
133   }
134   successor = successorDest;
135   successorOperands = argStorage;
136   return success();
137 }
138 
139 /// Simplify a branch to a block that has a single predecessor. This effectively
140 /// merges the two blocks.
141 static LogicalResult
142 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
143   // Check that the successor block has a single predecessor.
144   Block *succ = op.getDest();
145   Block *opParent = op->getBlock();
146   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
147     return failure();
148 
149   // Merge the successor into the current block and erase the branch.
150   SmallVector<Value> brOperands(op.getOperands());
151   rewriter.eraseOp(op);
152   rewriter.mergeBlocks(succ, opParent, brOperands);
153   return success();
154 }
155 
156 ///   br ^bb1
157 /// ^bb1
158 ///   br ^bbN(...)
159 ///
160 ///  -> br ^bbN(...)
161 ///
162 static LogicalResult simplifyPassThroughBr(BranchOp op,
163                                            PatternRewriter &rewriter) {
164   Block *dest = op.getDest();
165   ValueRange destOperands = op.getOperands();
166   SmallVector<Value, 4> destOperandStorage;
167 
168   // Try to collapse the successor if it points somewhere other than this
169   // block.
170   if (dest == op->getBlock() ||
171       failed(collapseBranch(dest, destOperands, destOperandStorage)))
172     return failure();
173 
174   // Create a new branch with the collapsed successor.
175   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
176   return success();
177 }
178 
179 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
180   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
181                  succeeded(simplifyPassThroughBr(op, rewriter)));
182 }
183 
184 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
185 
186 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
187 
188 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
189   assert(index == 0 && "invalid successor index");
190   return SuccessorOperands(getDestOperandsMutable());
191 }
192 
193 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
194   return getDest();
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // CondBranchOp
199 //===----------------------------------------------------------------------===//
200 
201 namespace {
202 /// cf.cond_br true, ^bb1, ^bb2
203 ///  -> br ^bb1
204 /// cf.cond_br false, ^bb1, ^bb2
205 ///  -> br ^bb2
206 ///
207 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
208   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
209 
210   LogicalResult matchAndRewrite(CondBranchOp condbr,
211                                 PatternRewriter &rewriter) const override {
212     if (matchPattern(condbr.getCondition(), m_NonZero())) {
213       // True branch taken.
214       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
215                                             condbr.getTrueOperands());
216       return success();
217     }
218     if (matchPattern(condbr.getCondition(), m_Zero())) {
219       // False branch taken.
220       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
221                                             condbr.getFalseOperands());
222       return success();
223     }
224     return failure();
225   }
226 };
227 
228 ///   cf.cond_br %cond, ^bb1, ^bb2
229 /// ^bb1
230 ///   br ^bbN(...)
231 /// ^bb2
232 ///   br ^bbK(...)
233 ///
234 ///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
235 ///
236 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
237   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
238 
239   LogicalResult matchAndRewrite(CondBranchOp condbr,
240                                 PatternRewriter &rewriter) const override {
241     Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
242     ValueRange trueDestOperands = condbr.getTrueOperands();
243     ValueRange falseDestOperands = condbr.getFalseOperands();
244     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
245 
246     // Try to collapse one of the current successors.
247     LogicalResult collapsedTrue =
248         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
249     LogicalResult collapsedFalse =
250         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
251     if (failed(collapsedTrue) && failed(collapsedFalse))
252       return failure();
253 
254     // Create a new branch with the collapsed successors.
255     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
256                                               trueDest, trueDestOperands,
257                                               falseDest, falseDestOperands);
258     return success();
259   }
260 };
261 
262 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
263 ///  -> br ^bb1(A, ..., N)
264 ///
265 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
266 ///  -> %select = arith.select %cond, A, B
267 ///     br ^bb1(%select)
268 ///
269 struct SimplifyCondBranchIdenticalSuccessors
270     : public OpRewritePattern<CondBranchOp> {
271   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
272 
273   LogicalResult matchAndRewrite(CondBranchOp condbr,
274                                 PatternRewriter &rewriter) const override {
275     // Check that the true and false destinations are the same and have the same
276     // operands.
277     Block *trueDest = condbr.getTrueDest();
278     if (trueDest != condbr.getFalseDest())
279       return failure();
280 
281     // If all of the operands match, no selects need to be generated.
282     OperandRange trueOperands = condbr.getTrueOperands();
283     OperandRange falseOperands = condbr.getFalseOperands();
284     if (trueOperands == falseOperands) {
285       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
286       return success();
287     }
288 
289     // Otherwise, if the current block is the only predecessor insert selects
290     // for any mismatched branch operands.
291     if (trueDest->getUniquePredecessor() != condbr->getBlock())
292       return failure();
293 
294     // Generate a select for any operands that differ between the two.
295     SmallVector<Value, 8> mergedOperands;
296     mergedOperands.reserve(trueOperands.size());
297     Value condition = condbr.getCondition();
298     for (auto it : llvm::zip(trueOperands, falseOperands)) {
299       if (std::get<0>(it) == std::get<1>(it))
300         mergedOperands.push_back(std::get<0>(it));
301       else
302         mergedOperands.push_back(rewriter.create<arith::SelectOp>(
303             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
304     }
305 
306     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
307     return success();
308   }
309 };
310 
311 ///   ...
312 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
313 /// ...
314 /// ^bb1: // has single predecessor
315 ///   ...
316 ///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
317 ///
318 /// ->
319 ///
320 ///   ...
321 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
322 /// ...
323 /// ^bb1: // has single predecessor
324 ///   ...
325 ///   br ^bb3(...)
326 ///
327 struct SimplifyCondBranchFromCondBranchOnSameCondition
328     : public OpRewritePattern<CondBranchOp> {
329   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
330 
331   LogicalResult matchAndRewrite(CondBranchOp condbr,
332                                 PatternRewriter &rewriter) const override {
333     // Check that we have a single distinct predecessor.
334     Block *currentBlock = condbr->getBlock();
335     Block *predecessor = currentBlock->getSinglePredecessor();
336     if (!predecessor)
337       return failure();
338 
339     // Check that the predecessor terminates with a conditional branch to this
340     // block and that it branches on the same condition.
341     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
342     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
343       return failure();
344 
345     // Fold this branch to an unconditional branch.
346     if (currentBlock == predBranch.getTrueDest())
347       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
348                                             condbr.getTrueDestOperands());
349     else
350       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
351                                             condbr.getFalseDestOperands());
352     return success();
353   }
354 };
355 
356 ///   cf.cond_br %arg0, ^trueB, ^falseB
357 ///
358 /// ^trueB:
359 ///   "test.consumer1"(%arg0) : (i1) -> ()
360 ///    ...
361 ///
362 /// ^falseB:
363 ///   "test.consumer2"(%arg0) : (i1) -> ()
364 ///   ...
365 ///
366 /// ->
367 ///
368 ///   cf.cond_br %arg0, ^trueB, ^falseB
369 /// ^trueB:
370 ///   "test.consumer1"(%true) : (i1) -> ()
371 ///   ...
372 ///
373 /// ^falseB:
374 ///   "test.consumer2"(%false) : (i1) -> ()
375 ///   ...
376 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
377   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
378 
379   LogicalResult matchAndRewrite(CondBranchOp condbr,
380                                 PatternRewriter &rewriter) const override {
381     // Check that we have a single distinct predecessor.
382     bool replaced = false;
383     Type ty = rewriter.getI1Type();
384 
385     // These variables serve to prevent creating duplicate constants
386     // and hold constant true or false values.
387     Value constantTrue = nullptr;
388     Value constantFalse = nullptr;
389 
390     // TODO These checks can be expanded to encompas any use with only
391     // either the true of false edge as a predecessor. For now, we fall
392     // back to checking the single predecessor is given by the true/fasle
393     // destination, thereby ensuring that only that edge can reach the
394     // op.
395     if (condbr.getTrueDest()->getSinglePredecessor()) {
396       for (OpOperand &use :
397            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
398         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
399           replaced = true;
400 
401           if (!constantTrue)
402             constantTrue = rewriter.create<arith::ConstantOp>(
403                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
404 
405           rewriter.updateRootInPlace(use.getOwner(),
406                                      [&] { use.set(constantTrue); });
407         }
408       }
409     }
410     if (condbr.getFalseDest()->getSinglePredecessor()) {
411       for (OpOperand &use :
412            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
413         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
414           replaced = true;
415 
416           if (!constantFalse)
417             constantFalse = rewriter.create<arith::ConstantOp>(
418                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
419 
420           rewriter.updateRootInPlace(use.getOwner(),
421                                      [&] { use.set(constantFalse); });
422         }
423       }
424     }
425     return success(replaced);
426   }
427 };
428 } // namespace
429 
430 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
431                                                MLIRContext *context) {
432   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
433               SimplifyCondBranchIdenticalSuccessors,
434               SimplifyCondBranchFromCondBranchOnSameCondition,
435               CondBranchTruthPropagation>(context);
436 }
437 
438 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
439   assert(index < getNumSuccessors() && "invalid successor index");
440   return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
441                                               : getFalseDestOperandsMutable());
442 }
443 
444 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
445   if (IntegerAttr condAttr =
446           llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
447     return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
448   return nullptr;
449 }
450 
451 //===----------------------------------------------------------------------===//
452 // SwitchOp
453 //===----------------------------------------------------------------------===//
454 
455 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
456                      Block *defaultDestination, ValueRange defaultOperands,
457                      DenseIntElementsAttr caseValues,
458                      BlockRange caseDestinations,
459                      ArrayRef<ValueRange> caseOperands) {
460   build(builder, result, value, defaultOperands, caseOperands, caseValues,
461         defaultDestination, caseDestinations);
462 }
463 
464 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
465                      Block *defaultDestination, ValueRange defaultOperands,
466                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
467                      ArrayRef<ValueRange> caseOperands) {
468   DenseIntElementsAttr caseValuesAttr;
469   if (!caseValues.empty()) {
470     ShapedType caseValueType = VectorType::get(
471         static_cast<int64_t>(caseValues.size()), value.getType());
472     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
473   }
474   build(builder, result, value, defaultDestination, defaultOperands,
475         caseValuesAttr, caseDestinations, caseOperands);
476 }
477 
478 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
479                      Block *defaultDestination, ValueRange defaultOperands,
480                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
481                      ArrayRef<ValueRange> caseOperands) {
482   DenseIntElementsAttr caseValuesAttr;
483   if (!caseValues.empty()) {
484     ShapedType caseValueType = VectorType::get(
485         static_cast<int64_t>(caseValues.size()), value.getType());
486     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
487   }
488   build(builder, result, value, defaultDestination, defaultOperands,
489         caseValuesAttr, caseDestinations, caseOperands);
490 }
491 
492 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
493 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
494 static ParseResult parseSwitchOpCases(
495     OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
496     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
497     SmallVectorImpl<Type> &defaultOperandTypes,
498     DenseIntElementsAttr &caseValues,
499     SmallVectorImpl<Block *> &caseDestinations,
500     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
501     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
502   if (parser.parseKeyword("default") || parser.parseColon() ||
503       parser.parseSuccessor(defaultDestination))
504     return failure();
505   if (succeeded(parser.parseOptionalLParen())) {
506     if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
507                                 /*allowResultNumber=*/false) ||
508         parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
509       return failure();
510   }
511 
512   SmallVector<APInt> values;
513   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
514   while (succeeded(parser.parseOptionalComma())) {
515     int64_t value = 0;
516     if (failed(parser.parseInteger(value)))
517       return failure();
518     values.push_back(APInt(bitWidth, value));
519 
520     Block *destination;
521     SmallVector<OpAsmParser::UnresolvedOperand> operands;
522     SmallVector<Type> operandTypes;
523     if (failed(parser.parseColon()) ||
524         failed(parser.parseSuccessor(destination)))
525       return failure();
526     if (succeeded(parser.parseOptionalLParen())) {
527       if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
528                                          /*allowResultNumber=*/false)) ||
529           failed(parser.parseColonTypeList(operandTypes)) ||
530           failed(parser.parseRParen()))
531         return failure();
532     }
533     caseDestinations.push_back(destination);
534     caseOperands.emplace_back(operands);
535     caseOperandTypes.emplace_back(operandTypes);
536   }
537 
538   if (!values.empty()) {
539     ShapedType caseValueType =
540         VectorType::get(static_cast<int64_t>(values.size()), flagType);
541     caseValues = DenseIntElementsAttr::get(caseValueType, values);
542   }
543   return success();
544 }
545 
546 static void printSwitchOpCases(
547     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
548     OperandRange defaultOperands, TypeRange defaultOperandTypes,
549     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
550     OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
551   p << "  default: ";
552   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
553 
554   if (!caseValues)
555     return;
556 
557   for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
558     p << ',';
559     p.printNewline();
560     p << "  ";
561     p << it.value().getLimitedValue();
562     p << ": ";
563     p.printSuccessorAndUseList(caseDestinations[it.index()],
564                                caseOperands[it.index()]);
565   }
566   p.printNewline();
567 }
568 
569 LogicalResult SwitchOp::verify() {
570   auto caseValues = getCaseValues();
571   auto caseDestinations = getCaseDestinations();
572 
573   if (!caseValues && caseDestinations.empty())
574     return success();
575 
576   Type flagType = getFlag().getType();
577   Type caseValueType = caseValues->getType().getElementType();
578   if (caseValueType != flagType)
579     return emitOpError() << "'flag' type (" << flagType
580                          << ") should match case value type (" << caseValueType
581                          << ")";
582 
583   if (caseValues &&
584       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
585     return emitOpError() << "number of case values (" << caseValues->size()
586                          << ") should match number of "
587                             "case destinations ("
588                          << caseDestinations.size() << ")";
589   return success();
590 }
591 
592 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
593   assert(index < getNumSuccessors() && "invalid successor index");
594   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
595                                       : getCaseOperandsMutable(index - 1));
596 }
597 
598 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
599   std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
600 
601   if (!caseValues)
602     return getDefaultDestination();
603 
604   SuccessorRange caseDests = getCaseDestinations();
605   if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
606     for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
607       if (it.value() == value.getValue())
608         return caseDests[it.index()];
609     return getDefaultDestination();
610   }
611   return nullptr;
612 }
613 
614 /// switch %flag : i32, [
615 ///   default:  ^bb1
616 /// ]
617 ///  -> br ^bb1
618 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
619                                                    PatternRewriter &rewriter) {
620   if (!op.getCaseDestinations().empty())
621     return failure();
622 
623   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
624                                         op.getDefaultOperands());
625   return success();
626 }
627 
628 /// switch %flag : i32, [
629 ///   default: ^bb1,
630 ///   42: ^bb1,
631 ///   43: ^bb2
632 /// ]
633 /// ->
634 /// switch %flag : i32, [
635 ///   default: ^bb1,
636 ///   43: ^bb2
637 /// ]
638 static LogicalResult
639 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
640   SmallVector<Block *> newCaseDestinations;
641   SmallVector<ValueRange> newCaseOperands;
642   SmallVector<APInt> newCaseValues;
643   bool requiresChange = false;
644   auto caseValues = op.getCaseValues();
645   auto caseDests = op.getCaseDestinations();
646 
647   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
648     if (caseDests[it.index()] == op.getDefaultDestination() &&
649         op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
650       requiresChange = true;
651       continue;
652     }
653     newCaseDestinations.push_back(caseDests[it.index()]);
654     newCaseOperands.push_back(op.getCaseOperands(it.index()));
655     newCaseValues.push_back(it.value());
656   }
657 
658   if (!requiresChange)
659     return failure();
660 
661   rewriter.replaceOpWithNewOp<SwitchOp>(
662       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
663       newCaseValues, newCaseDestinations, newCaseOperands);
664   return success();
665 }
666 
667 /// Helper for folding a switch with a constant value.
668 /// switch %c_42 : i32, [
669 ///   default: ^bb1 ,
670 ///   42: ^bb2,
671 ///   43: ^bb3
672 /// ]
673 /// -> br ^bb2
674 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
675                        const APInt &caseValue) {
676   auto caseValues = op.getCaseValues();
677   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
678     if (it.value() == caseValue) {
679       rewriter.replaceOpWithNewOp<BranchOp>(
680           op, op.getCaseDestinations()[it.index()],
681           op.getCaseOperands(it.index()));
682       return;
683     }
684   }
685   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
686                                         op.getDefaultOperands());
687 }
688 
689 /// switch %c_42 : i32, [
690 ///   default: ^bb1,
691 ///   42: ^bb2,
692 ///   43: ^bb3
693 /// ]
694 /// -> br ^bb2
695 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
696                                               PatternRewriter &rewriter) {
697   APInt caseValue;
698   if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
699     return failure();
700 
701   foldSwitch(op, rewriter, caseValue);
702   return success();
703 }
704 
705 /// switch %c_42 : i32, [
706 ///   default: ^bb1,
707 ///   42: ^bb2,
708 /// ]
709 /// ^bb2:
710 ///   br ^bb3
711 /// ->
712 /// switch %c_42 : i32, [
713 ///   default: ^bb1,
714 ///   42: ^bb3,
715 /// ]
716 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
717                                                PatternRewriter &rewriter) {
718   SmallVector<Block *> newCaseDests;
719   SmallVector<ValueRange> newCaseOperands;
720   SmallVector<SmallVector<Value>> argStorage;
721   auto caseValues = op.getCaseValues();
722   argStorage.reserve(caseValues->size() + 1);
723   auto caseDests = op.getCaseDestinations();
724   bool requiresChange = false;
725   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
726     Block *caseDest = caseDests[i];
727     ValueRange caseOperands = op.getCaseOperands(i);
728     argStorage.emplace_back();
729     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
730       requiresChange = true;
731 
732     newCaseDests.push_back(caseDest);
733     newCaseOperands.push_back(caseOperands);
734   }
735 
736   Block *defaultDest = op.getDefaultDestination();
737   ValueRange defaultOperands = op.getDefaultOperands();
738   argStorage.emplace_back();
739 
740   if (succeeded(
741           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
742     requiresChange = true;
743 
744   if (!requiresChange)
745     return failure();
746 
747   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
748                                         defaultOperands, *caseValues,
749                                         newCaseDests, newCaseOperands);
750   return success();
751 }
752 
753 /// switch %flag : i32, [
754 ///   default: ^bb1,
755 ///   42: ^bb2,
756 /// ]
757 /// ^bb2:
758 ///   switch %flag : i32, [
759 ///     default: ^bb3,
760 ///     42: ^bb4
761 ///   ]
762 /// ->
763 /// switch %flag : i32, [
764 ///   default: ^bb1,
765 ///   42: ^bb2,
766 /// ]
767 /// ^bb2:
768 ///   br ^bb4
769 ///
770 ///  and
771 ///
772 /// switch %flag : i32, [
773 ///   default: ^bb1,
774 ///   42: ^bb2,
775 /// ]
776 /// ^bb2:
777 ///   switch %flag : i32, [
778 ///     default: ^bb3,
779 ///     43: ^bb4
780 ///   ]
781 /// ->
782 /// switch %flag : i32, [
783 ///   default: ^bb1,
784 ///   42: ^bb2,
785 /// ]
786 /// ^bb2:
787 ///   br ^bb3
788 static LogicalResult
789 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
790                                         PatternRewriter &rewriter) {
791   // Check that we have a single distinct predecessor.
792   Block *currentBlock = op->getBlock();
793   Block *predecessor = currentBlock->getSinglePredecessor();
794   if (!predecessor)
795     return failure();
796 
797   // Check that the predecessor terminates with a switch branch to this block
798   // and that it branches on the same condition and that this branch isn't the
799   // default destination.
800   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
801   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
802       predSwitch.getDefaultDestination() == currentBlock)
803     return failure();
804 
805   // Fold this switch to an unconditional branch.
806   SuccessorRange predDests = predSwitch.getCaseDestinations();
807   auto it = llvm::find(predDests, currentBlock);
808   if (it != predDests.end()) {
809     std::optional<DenseIntElementsAttr> predCaseValues =
810         predSwitch.getCaseValues();
811     foldSwitch(op, rewriter,
812                predCaseValues->getValues<APInt>()[it - predDests.begin()]);
813   } else {
814     rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
815                                           op.getDefaultOperands());
816   }
817   return success();
818 }
819 
820 /// switch %flag : i32, [
821 ///   default: ^bb1,
822 ///   42: ^bb2
823 /// ]
824 /// ^bb1:
825 ///   switch %flag : i32, [
826 ///     default: ^bb3,
827 ///     42: ^bb4,
828 ///     43: ^bb5
829 ///   ]
830 /// ->
831 /// switch %flag : i32, [
832 ///   default: ^bb1,
833 ///   42: ^bb2,
834 /// ]
835 /// ^bb1:
836 ///   switch %flag : i32, [
837 ///     default: ^bb3,
838 ///     43: ^bb5
839 ///   ]
840 static LogicalResult
841 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
842                                                PatternRewriter &rewriter) {
843   // Check that we have a single distinct predecessor.
844   Block *currentBlock = op->getBlock();
845   Block *predecessor = currentBlock->getSinglePredecessor();
846   if (!predecessor)
847     return failure();
848 
849   // Check that the predecessor terminates with a switch branch to this block
850   // and that it branches on the same condition and that this branch is the
851   // default destination.
852   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
853   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
854       predSwitch.getDefaultDestination() != currentBlock)
855     return failure();
856 
857   // Delete case values that are not possible here.
858   DenseSet<APInt> caseValuesToRemove;
859   auto predDests = predSwitch.getCaseDestinations();
860   auto predCaseValues = predSwitch.getCaseValues();
861   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
862     if (currentBlock != predDests[i])
863       caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
864 
865   SmallVector<Block *> newCaseDestinations;
866   SmallVector<ValueRange> newCaseOperands;
867   SmallVector<APInt> newCaseValues;
868   bool requiresChange = false;
869 
870   auto caseValues = op.getCaseValues();
871   auto caseDests = op.getCaseDestinations();
872   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
873     if (caseValuesToRemove.contains(it.value())) {
874       requiresChange = true;
875       continue;
876     }
877     newCaseDestinations.push_back(caseDests[it.index()]);
878     newCaseOperands.push_back(op.getCaseOperands(it.index()));
879     newCaseValues.push_back(it.value());
880   }
881 
882   if (!requiresChange)
883     return failure();
884 
885   rewriter.replaceOpWithNewOp<SwitchOp>(
886       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
887       newCaseValues, newCaseDestinations, newCaseOperands);
888   return success();
889 }
890 
891 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
892                                            MLIRContext *context) {
893   results.add(&simplifySwitchWithOnlyDefault)
894       .add(&dropSwitchCasesThatMatchDefault)
895       .add(&simplifyConstSwitchValue)
896       .add(&simplifyPassThroughSwitch)
897       .add(&simplifySwitchFromSwitchOnSameCondition)
898       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // TableGen'd op method definitions
903 //===----------------------------------------------------------------------===//
904 
905 #define GET_OP_CLASSES
906 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
907