xref: /llvm-project/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (revision 5dce74817b71a1f646fb2857c037b3a66f41c7cd)
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/IR/AffineExpr.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/IRMapping.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/OpImplementation.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/IR/Value.h"
23 #include "mlir/Support/MathExtras.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/APFloat.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <numeric>
30 
31 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
32 
33 using namespace mlir;
34 using namespace mlir::cf;
35 
36 //===----------------------------------------------------------------------===//
37 // ControlFlowDialect Interfaces
38 //===----------------------------------------------------------------------===//
39 namespace {
40 /// This class defines the interface for handling inlining with control flow
41 /// operations.
42 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
43   using DialectInlinerInterface::DialectInlinerInterface;
44   ~ControlFlowInlinerInterface() override = default;
45 
46   /// All control flow operations can be inlined.
47   bool isLegalToInline(Operation *call, Operation *callable,
48                        bool wouldBeCloned) const final {
49     return true;
50   }
51   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
52     return true;
53   }
54 
55   /// ControlFlow terminator operations don't really need any special handing.
56   void handleTerminator(Operation *op, Block *newDest) const final {}
57 };
58 } // namespace
59 
60 //===----------------------------------------------------------------------===//
61 // ControlFlowDialect
62 //===----------------------------------------------------------------------===//
63 
/// Dialect initialization hook: registers the dialect's operations and its
/// inliner interface.
void ControlFlowDialect::initialize() {
  // The template argument list of `addOperations` is supplied by the
  // generated .inc file, which is included with GET_OP_LIST defined.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
      >();
  // Attach the inlining interface declared earlier in this file.
  addInterfaces<ControlFlowInlinerInterface>();
}
71 
72 //===----------------------------------------------------------------------===//
73 // AssertOp
74 //===----------------------------------------------------------------------===//
75 
76 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
77   // Erase assertion if argument is constant true.
78   if (matchPattern(op.getArg(), m_One())) {
79     rewriter.eraseOp(op);
80     return success();
81   }
82   return failure();
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // BranchOp
87 //===----------------------------------------------------------------------===//
88 
89 /// Given a successor, try to collapse it to a new destination if it only
90 /// contains a passthrough unconditional branch. If the successor is
91 /// collapsable, `successor` and `successorOperands` are updated to reference
92 /// the new destination and values. `argStorage` is used as storage if operands
93 /// to the collapsed successor need to be remapped. It must outlive uses of
94 /// successorOperands.
95 static LogicalResult collapseBranch(Block *&successor,
96                                     ValueRange &successorOperands,
97                                     SmallVectorImpl<Value> &argStorage) {
98   // Check that the successor only contains a unconditional branch.
99   if (std::next(successor->begin()) != successor->end())
100     return failure();
101   // Check that the terminator is an unconditional branch.
102   BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
103   if (!successorBranch)
104     return failure();
105   // Check that the arguments are only used within the terminator.
106   for (BlockArgument arg : successor->getArguments()) {
107     for (Operation *user : arg.getUsers())
108       if (user != successorBranch)
109         return failure();
110   }
111   // Don't try to collapse branches to infinite loops.
112   Block *successorDest = successorBranch.getDest();
113   if (successorDest == successor)
114     return failure();
115 
116   // Update the operands to the successor. If the branch parent has no
117   // arguments, we can use the branch operands directly.
118   OperandRange operands = successorBranch.getOperands();
119   if (successor->args_empty()) {
120     successor = successorDest;
121     successorOperands = operands;
122     return success();
123   }
124 
125   // Otherwise, we need to remap any argument operands.
126   for (Value operand : operands) {
127     BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
128     if (argOperand && argOperand.getOwner() == successor)
129       argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
130     else
131       argStorage.push_back(operand);
132   }
133   successor = successorDest;
134   successorOperands = argStorage;
135   return success();
136 }
137 
138 /// Simplify a branch to a block that has a single predecessor. This effectively
139 /// merges the two blocks.
140 static LogicalResult
141 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
142   // Check that the successor block has a single predecessor.
143   Block *succ = op.getDest();
144   Block *opParent = op->getBlock();
145   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
146     return failure();
147 
148   // Merge the successor into the current block and erase the branch.
149   SmallVector<Value> brOperands(op.getOperands());
150   rewriter.eraseOp(op);
151   rewriter.mergeBlocks(succ, opParent, brOperands);
152   return success();
153 }
154 
155 ///   br ^bb1
156 /// ^bb1
157 ///   br ^bbN(...)
158 ///
159 ///  -> br ^bbN(...)
160 ///
161 static LogicalResult simplifyPassThroughBr(BranchOp op,
162                                            PatternRewriter &rewriter) {
163   Block *dest = op.getDest();
164   ValueRange destOperands = op.getOperands();
165   SmallVector<Value, 4> destOperandStorage;
166 
167   // Try to collapse the successor if it points somewhere other than this
168   // block.
169   if (dest == op->getBlock() ||
170       failed(collapseBranch(dest, destOperands, destOperandStorage)))
171     return failure();
172 
173   // Create a new branch with the collapsed successor.
174   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
175   return success();
176 }
177 
178 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
179   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
180                  succeeded(simplifyPassThroughBr(op, rewriter)));
181 }
182 
183 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
184 
185 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
186 
187 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
188   assert(index == 0 && "invalid successor index");
189   return SuccessorOperands(getDestOperandsMutable());
190 }
191 
192 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
193   return getDest();
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // CondBranchOp
198 //===----------------------------------------------------------------------===//
199 
200 namespace {
201 /// cf.cond_br true, ^bb1, ^bb2
202 ///  -> br ^bb1
203 /// cf.cond_br false, ^bb1, ^bb2
204 ///  -> br ^bb2
205 ///
206 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
207   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
208 
209   LogicalResult matchAndRewrite(CondBranchOp condbr,
210                                 PatternRewriter &rewriter) const override {
211     if (matchPattern(condbr.getCondition(), m_NonZero())) {
212       // True branch taken.
213       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
214                                             condbr.getTrueOperands());
215       return success();
216     }
217     if (matchPattern(condbr.getCondition(), m_Zero())) {
218       // False branch taken.
219       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
220                                             condbr.getFalseOperands());
221       return success();
222     }
223     return failure();
224   }
225 };
226 
227 ///   cf.cond_br %cond, ^bb1, ^bb2
228 /// ^bb1
229 ///   br ^bbN(...)
230 /// ^bb2
231 ///   br ^bbK(...)
232 ///
233 ///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
234 ///
235 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
236   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
237 
238   LogicalResult matchAndRewrite(CondBranchOp condbr,
239                                 PatternRewriter &rewriter) const override {
240     Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
241     ValueRange trueDestOperands = condbr.getTrueOperands();
242     ValueRange falseDestOperands = condbr.getFalseOperands();
243     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
244 
245     // Try to collapse one of the current successors.
246     LogicalResult collapsedTrue =
247         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
248     LogicalResult collapsedFalse =
249         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
250     if (failed(collapsedTrue) && failed(collapsedFalse))
251       return failure();
252 
253     // Create a new branch with the collapsed successors.
254     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
255                                               trueDest, trueDestOperands,
256                                               falseDest, falseDestOperands);
257     return success();
258   }
259 };
260 
261 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
262 ///  -> br ^bb1(A, ..., N)
263 ///
264 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
265 ///  -> %select = arith.select %cond, A, B
266 ///     br ^bb1(%select)
267 ///
268 struct SimplifyCondBranchIdenticalSuccessors
269     : public OpRewritePattern<CondBranchOp> {
270   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
271 
272   LogicalResult matchAndRewrite(CondBranchOp condbr,
273                                 PatternRewriter &rewriter) const override {
274     // Check that the true and false destinations are the same and have the same
275     // operands.
276     Block *trueDest = condbr.getTrueDest();
277     if (trueDest != condbr.getFalseDest())
278       return failure();
279 
280     // If all of the operands match, no selects need to be generated.
281     OperandRange trueOperands = condbr.getTrueOperands();
282     OperandRange falseOperands = condbr.getFalseOperands();
283     if (trueOperands == falseOperands) {
284       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
285       return success();
286     }
287 
288     // Otherwise, if the current block is the only predecessor insert selects
289     // for any mismatched branch operands.
290     if (trueDest->getUniquePredecessor() != condbr->getBlock())
291       return failure();
292 
293     // Generate a select for any operands that differ between the two.
294     SmallVector<Value, 8> mergedOperands;
295     mergedOperands.reserve(trueOperands.size());
296     Value condition = condbr.getCondition();
297     for (auto it : llvm::zip(trueOperands, falseOperands)) {
298       if (std::get<0>(it) == std::get<1>(it))
299         mergedOperands.push_back(std::get<0>(it));
300       else
301         mergedOperands.push_back(rewriter.create<arith::SelectOp>(
302             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
303     }
304 
305     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
306     return success();
307   }
308 };
309 
310 ///   ...
311 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
312 /// ...
313 /// ^bb1: // has single predecessor
314 ///   ...
315 ///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
316 ///
317 /// ->
318 ///
319 ///   ...
320 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
321 /// ...
322 /// ^bb1: // has single predecessor
323 ///   ...
324 ///   br ^bb3(...)
325 ///
326 struct SimplifyCondBranchFromCondBranchOnSameCondition
327     : public OpRewritePattern<CondBranchOp> {
328   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
329 
330   LogicalResult matchAndRewrite(CondBranchOp condbr,
331                                 PatternRewriter &rewriter) const override {
332     // Check that we have a single distinct predecessor.
333     Block *currentBlock = condbr->getBlock();
334     Block *predecessor = currentBlock->getSinglePredecessor();
335     if (!predecessor)
336       return failure();
337 
338     // Check that the predecessor terminates with a conditional branch to this
339     // block and that it branches on the same condition.
340     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
341     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
342       return failure();
343 
344     // Fold this branch to an unconditional branch.
345     if (currentBlock == predBranch.getTrueDest())
346       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
347                                             condbr.getTrueDestOperands());
348     else
349       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
350                                             condbr.getFalseDestOperands());
351     return success();
352   }
353 };
354 
355 ///   cf.cond_br %arg0, ^trueB, ^falseB
356 ///
357 /// ^trueB:
358 ///   "test.consumer1"(%arg0) : (i1) -> ()
359 ///    ...
360 ///
361 /// ^falseB:
362 ///   "test.consumer2"(%arg0) : (i1) -> ()
363 ///   ...
364 ///
365 /// ->
366 ///
367 ///   cf.cond_br %arg0, ^trueB, ^falseB
368 /// ^trueB:
369 ///   "test.consumer1"(%true) : (i1) -> ()
370 ///   ...
371 ///
372 /// ^falseB:
373 ///   "test.consumer2"(%false) : (i1) -> ()
374 ///   ...
375 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
376   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
377 
378   LogicalResult matchAndRewrite(CondBranchOp condbr,
379                                 PatternRewriter &rewriter) const override {
380     // Check that we have a single distinct predecessor.
381     bool replaced = false;
382     Type ty = rewriter.getI1Type();
383 
384     // These variables serve to prevent creating duplicate constants
385     // and hold constant true or false values.
386     Value constantTrue = nullptr;
387     Value constantFalse = nullptr;
388 
389     // TODO These checks can be expanded to encompas any use with only
390     // either the true of false edge as a predecessor. For now, we fall
391     // back to checking the single predecessor is given by the true/fasle
392     // destination, thereby ensuring that only that edge can reach the
393     // op.
394     if (condbr.getTrueDest()->getSinglePredecessor()) {
395       for (OpOperand &use :
396            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
397         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
398           replaced = true;
399 
400           if (!constantTrue)
401             constantTrue = rewriter.create<arith::ConstantOp>(
402                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
403 
404           rewriter.updateRootInPlace(use.getOwner(),
405                                      [&] { use.set(constantTrue); });
406         }
407       }
408     }
409     if (condbr.getFalseDest()->getSinglePredecessor()) {
410       for (OpOperand &use :
411            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
412         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
413           replaced = true;
414 
415           if (!constantFalse)
416             constantFalse = rewriter.create<arith::ConstantOp>(
417                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
418 
419           rewriter.updateRootInPlace(use.getOwner(),
420                                      [&] { use.set(constantFalse); });
421         }
422       }
423     }
424     return success(replaced);
425   }
426 };
427 } // namespace
428 
429 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
430                                                MLIRContext *context) {
431   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
432               SimplifyCondBranchIdenticalSuccessors,
433               SimplifyCondBranchFromCondBranchOnSameCondition,
434               CondBranchTruthPropagation>(context);
435 }
436 
437 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
438   assert(index < getNumSuccessors() && "invalid successor index");
439   return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
440                                               : getFalseDestOperandsMutable());
441 }
442 
443 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
444   if (IntegerAttr condAttr =
445           llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
446     return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
447   return nullptr;
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // SwitchOp
452 //===----------------------------------------------------------------------===//
453 
/// Builder taking the case values as an already-constructed attribute. It
/// forwards to another `build` overload, passing the same arguments in a
/// different order (operands first, then the case values, then the
/// successors).
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  build(builder, result, value, defaultOperands, caseOperands, caseValues,
        defaultDestination, caseDestinations);
}
462 
463 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
464                      Block *defaultDestination, ValueRange defaultOperands,
465                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
466                      ArrayRef<ValueRange> caseOperands) {
467   DenseIntElementsAttr caseValuesAttr;
468   if (!caseValues.empty()) {
469     ShapedType caseValueType = VectorType::get(
470         static_cast<int64_t>(caseValues.size()), value.getType());
471     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
472   }
473   build(builder, result, value, defaultDestination, defaultOperands,
474         caseValuesAttr, caseDestinations, caseOperands);
475 }
476 
477 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
478                      Block *defaultDestination, ValueRange defaultOperands,
479                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
480                      ArrayRef<ValueRange> caseOperands) {
481   DenseIntElementsAttr caseValuesAttr;
482   if (!caseValues.empty()) {
483     ShapedType caseValueType = VectorType::get(
484         static_cast<int64_t>(caseValues.size()), value.getType());
485     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
486   }
487   build(builder, result, value, defaultDestination, defaultOperands,
488         caseValuesAttr, caseDestinations, caseOperands);
489 }
490 
491 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
492 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
493 static ParseResult parseSwitchOpCases(
494     OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
495     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
496     SmallVectorImpl<Type> &defaultOperandTypes,
497     DenseIntElementsAttr &caseValues,
498     SmallVectorImpl<Block *> &caseDestinations,
499     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
500     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
501   if (parser.parseKeyword("default") || parser.parseColon() ||
502       parser.parseSuccessor(defaultDestination))
503     return failure();
504   if (succeeded(parser.parseOptionalLParen())) {
505     if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
506                                 /*allowResultNumber=*/false) ||
507         parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
508       return failure();
509   }
510 
511   SmallVector<APInt> values;
512   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
513   while (succeeded(parser.parseOptionalComma())) {
514     int64_t value = 0;
515     if (failed(parser.parseInteger(value)))
516       return failure();
517     values.push_back(APInt(bitWidth, value));
518 
519     Block *destination;
520     SmallVector<OpAsmParser::UnresolvedOperand> operands;
521     SmallVector<Type> operandTypes;
522     if (failed(parser.parseColon()) ||
523         failed(parser.parseSuccessor(destination)))
524       return failure();
525     if (succeeded(parser.parseOptionalLParen())) {
526       if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
527                                          /*allowResultNumber=*/false)) ||
528           failed(parser.parseColonTypeList(operandTypes)) ||
529           failed(parser.parseRParen()))
530         return failure();
531     }
532     caseDestinations.push_back(destination);
533     caseOperands.emplace_back(operands);
534     caseOperandTypes.emplace_back(operandTypes);
535   }
536 
537   if (!values.empty()) {
538     ShapedType caseValueType =
539         VectorType::get(static_cast<int64_t>(values.size()), flagType);
540     caseValues = DenseIntElementsAttr::get(caseValueType, values);
541   }
542   return success();
543 }
544 
545 static void printSwitchOpCases(
546     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
547     OperandRange defaultOperands, TypeRange defaultOperandTypes,
548     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
549     OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
550   p << "  default: ";
551   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
552 
553   if (!caseValues)
554     return;
555 
556   for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
557     p << ',';
558     p.printNewline();
559     p << "  ";
560     p << it.value().getLimitedValue();
561     p << ": ";
562     p.printSuccessorAndUseList(caseDestinations[it.index()],
563                                caseOperands[it.index()]);
564   }
565   p.printNewline();
566 }
567 
568 LogicalResult SwitchOp::verify() {
569   auto caseValues = getCaseValues();
570   auto caseDestinations = getCaseDestinations();
571 
572   if (!caseValues && caseDestinations.empty())
573     return success();
574 
575   Type flagType = getFlag().getType();
576   Type caseValueType = caseValues->getType().getElementType();
577   if (caseValueType != flagType)
578     return emitOpError() << "'flag' type (" << flagType
579                          << ") should match case value type (" << caseValueType
580                          << ")";
581 
582   if (caseValues &&
583       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
584     return emitOpError() << "number of case values (" << caseValues->size()
585                          << ") should match number of "
586                             "case destinations ("
587                          << caseDestinations.size() << ")";
588   return success();
589 }
590 
591 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
592   assert(index < getNumSuccessors() && "invalid successor index");
593   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
594                                       : getCaseOperandsMutable(index - 1));
595 }
596 
597 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
598   std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
599 
600   if (!caseValues)
601     return getDefaultDestination();
602 
603   SuccessorRange caseDests = getCaseDestinations();
604   if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
605     for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
606       if (it.value() == value.getValue())
607         return caseDests[it.index()];
608     return getDefaultDestination();
609   }
610   return nullptr;
611 }
612 
613 /// switch %flag : i32, [
614 ///   default:  ^bb1
615 /// ]
616 ///  -> br ^bb1
617 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
618                                                    PatternRewriter &rewriter) {
619   if (!op.getCaseDestinations().empty())
620     return failure();
621 
622   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
623                                         op.getDefaultOperands());
624   return success();
625 }
626 
627 /// switch %flag : i32, [
628 ///   default: ^bb1,
629 ///   42: ^bb1,
630 ///   43: ^bb2
631 /// ]
632 /// ->
633 /// switch %flag : i32, [
634 ///   default: ^bb1,
635 ///   43: ^bb2
636 /// ]
637 static LogicalResult
638 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
639   SmallVector<Block *> newCaseDestinations;
640   SmallVector<ValueRange> newCaseOperands;
641   SmallVector<APInt> newCaseValues;
642   bool requiresChange = false;
643   auto caseValues = op.getCaseValues();
644   auto caseDests = op.getCaseDestinations();
645 
646   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
647     if (caseDests[it.index()] == op.getDefaultDestination() &&
648         op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
649       requiresChange = true;
650       continue;
651     }
652     newCaseDestinations.push_back(caseDests[it.index()]);
653     newCaseOperands.push_back(op.getCaseOperands(it.index()));
654     newCaseValues.push_back(it.value());
655   }
656 
657   if (!requiresChange)
658     return failure();
659 
660   rewriter.replaceOpWithNewOp<SwitchOp>(
661       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
662       newCaseValues, newCaseDestinations, newCaseOperands);
663   return success();
664 }
665 
666 /// Helper for folding a switch with a constant value.
667 /// switch %c_42 : i32, [
668 ///   default: ^bb1 ,
669 ///   42: ^bb2,
670 ///   43: ^bb3
671 /// ]
672 /// -> br ^bb2
673 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
674                        const APInt &caseValue) {
675   auto caseValues = op.getCaseValues();
676   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
677     if (it.value() == caseValue) {
678       rewriter.replaceOpWithNewOp<BranchOp>(
679           op, op.getCaseDestinations()[it.index()],
680           op.getCaseOperands(it.index()));
681       return;
682     }
683   }
684   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
685                                         op.getDefaultOperands());
686 }
687 
688 /// switch %c_42 : i32, [
689 ///   default: ^bb1,
690 ///   42: ^bb2,
691 ///   43: ^bb3
692 /// ]
693 /// -> br ^bb2
694 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
695                                               PatternRewriter &rewriter) {
696   APInt caseValue;
697   if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
698     return failure();
699 
700   foldSwitch(op, rewriter, caseValue);
701   return success();
702 }
703 
704 /// switch %c_42 : i32, [
705 ///   default: ^bb1,
706 ///   42: ^bb2,
707 /// ]
708 /// ^bb2:
709 ///   br ^bb3
710 /// ->
711 /// switch %c_42 : i32, [
712 ///   default: ^bb1,
713 ///   42: ^bb3,
714 /// ]
715 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
716                                                PatternRewriter &rewriter) {
717   SmallVector<Block *> newCaseDests;
718   SmallVector<ValueRange> newCaseOperands;
719   SmallVector<SmallVector<Value>> argStorage;
720   auto caseValues = op.getCaseValues();
721   argStorage.reserve(caseValues->size() + 1);
722   auto caseDests = op.getCaseDestinations();
723   bool requiresChange = false;
724   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
725     Block *caseDest = caseDests[i];
726     ValueRange caseOperands = op.getCaseOperands(i);
727     argStorage.emplace_back();
728     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
729       requiresChange = true;
730 
731     newCaseDests.push_back(caseDest);
732     newCaseOperands.push_back(caseOperands);
733   }
734 
735   Block *defaultDest = op.getDefaultDestination();
736   ValueRange defaultOperands = op.getDefaultOperands();
737   argStorage.emplace_back();
738 
739   if (succeeded(
740           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
741     requiresChange = true;
742 
743   if (!requiresChange)
744     return failure();
745 
746   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
747                                         defaultOperands, *caseValues,
748                                         newCaseDests, newCaseOperands);
749   return success();
750 }
751 
752 /// switch %flag : i32, [
753 ///   default: ^bb1,
754 ///   42: ^bb2,
755 /// ]
756 /// ^bb2:
757 ///   switch %flag : i32, [
758 ///     default: ^bb3,
759 ///     42: ^bb4
760 ///   ]
761 /// ->
762 /// switch %flag : i32, [
763 ///   default: ^bb1,
764 ///   42: ^bb2,
765 /// ]
766 /// ^bb2:
767 ///   br ^bb4
768 ///
769 ///  and
770 ///
771 /// switch %flag : i32, [
772 ///   default: ^bb1,
773 ///   42: ^bb2,
774 /// ]
775 /// ^bb2:
776 ///   switch %flag : i32, [
777 ///     default: ^bb3,
778 ///     43: ^bb4
779 ///   ]
780 /// ->
781 /// switch %flag : i32, [
782 ///   default: ^bb1,
783 ///   42: ^bb2,
784 /// ]
785 /// ^bb2:
786 ///   br ^bb3
787 static LogicalResult
788 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
789                                         PatternRewriter &rewriter) {
790   // Check that we have a single distinct predecessor.
791   Block *currentBlock = op->getBlock();
792   Block *predecessor = currentBlock->getSinglePredecessor();
793   if (!predecessor)
794     return failure();
795 
796   // Check that the predecessor terminates with a switch branch to this block
797   // and that it branches on the same condition and that this branch isn't the
798   // default destination.
799   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
800   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
801       predSwitch.getDefaultDestination() == currentBlock)
802     return failure();
803 
804   // Fold this switch to an unconditional branch.
805   SuccessorRange predDests = predSwitch.getCaseDestinations();
806   auto it = llvm::find(predDests, currentBlock);
807   if (it != predDests.end()) {
808     std::optional<DenseIntElementsAttr> predCaseValues =
809         predSwitch.getCaseValues();
810     foldSwitch(op, rewriter,
811                predCaseValues->getValues<APInt>()[it - predDests.begin()]);
812   } else {
813     rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
814                                           op.getDefaultOperands());
815   }
816   return success();
817 }
818 
819 /// switch %flag : i32, [
820 ///   default: ^bb1,
821 ///   42: ^bb2
822 /// ]
823 /// ^bb1:
824 ///   switch %flag : i32, [
825 ///     default: ^bb3,
826 ///     42: ^bb4,
827 ///     43: ^bb5
828 ///   ]
829 /// ->
830 /// switch %flag : i32, [
831 ///   default: ^bb1,
832 ///   42: ^bb2,
833 /// ]
834 /// ^bb1:
835 ///   switch %flag : i32, [
836 ///     default: ^bb3,
837 ///     43: ^bb5
838 ///   ]
839 static LogicalResult
840 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
841                                                PatternRewriter &rewriter) {
842   // Check that we have a single distinct predecessor.
843   Block *currentBlock = op->getBlock();
844   Block *predecessor = currentBlock->getSinglePredecessor();
845   if (!predecessor)
846     return failure();
847 
848   // Check that the predecessor terminates with a switch branch to this block
849   // and that it branches on the same condition and that this branch is the
850   // default destination.
851   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
852   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
853       predSwitch.getDefaultDestination() != currentBlock)
854     return failure();
855 
856   // Delete case values that are not possible here.
857   DenseSet<APInt> caseValuesToRemove;
858   auto predDests = predSwitch.getCaseDestinations();
859   auto predCaseValues = predSwitch.getCaseValues();
860   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
861     if (currentBlock != predDests[i])
862       caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
863 
864   SmallVector<Block *> newCaseDestinations;
865   SmallVector<ValueRange> newCaseOperands;
866   SmallVector<APInt> newCaseValues;
867   bool requiresChange = false;
868 
869   auto caseValues = op.getCaseValues();
870   auto caseDests = op.getCaseDestinations();
871   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
872     if (caseValuesToRemove.contains(it.value())) {
873       requiresChange = true;
874       continue;
875     }
876     newCaseDestinations.push_back(caseDests[it.index()]);
877     newCaseOperands.push_back(op.getCaseOperands(it.index()));
878     newCaseValues.push_back(it.value());
879   }
880 
881   if (!requiresChange)
882     return failure();
883 
884   rewriter.replaceOpWithNewOp<SwitchOp>(
885       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
886       newCaseValues, newCaseDestinations, newCaseOperands);
887   return success();
888 }
889 
890 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
891                                            MLIRContext *context) {
892   results.add(&simplifySwitchWithOnlyDefault)
893       .add(&dropSwitchCasesThatMatchDefault)
894       .add(&simplifyConstSwitchValue)
895       .add(&simplifyPassThroughSwitch)
896       .add(&simplifySwitchFromSwitchOnSameCondition)
897       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
898 }
899 
900 //===----------------------------------------------------------------------===//
901 // TableGen'd op method definitions
902 //===----------------------------------------------------------------------===//
903 
904 #define GET_OP_CLASSES
905 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
906