xref: /llvm-project/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (revision b43c50490c5964b3b1aa1b95a9025a5b5942a46e)
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10 
11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Support/MathExtras.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <numeric>
31 
32 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
33 
34 using namespace mlir;
35 using namespace mlir::cf;
36 
37 //===----------------------------------------------------------------------===//
38 // ControlFlowDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 namespace {
41 /// This class defines the interface for handling inlining with control flow
42 /// operations.
43 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
44   using DialectInlinerInterface::DialectInlinerInterface;
45   ~ControlFlowInlinerInterface() override = default;
46 
47   /// All control flow operations can be inlined.
48   bool isLegalToInline(Operation *call, Operation *callable,
49                        bool wouldBeCloned) const final {
50     return true;
51   }
52   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
53     return true;
54   }
55 
56   /// ControlFlow terminator operations don't really need any special handing.
57   void handleTerminator(Operation *op, Block *newDest) const final {}
58 };
59 } // namespace
60 
61 //===----------------------------------------------------------------------===//
62 // ControlFlowDialect
63 //===----------------------------------------------------------------------===//
64 
/// Dialect initialization hook: registers every ControlFlow operation, attaches
/// the inliner interface, and declares the promised LLVM-conversion interface.
void ControlFlowDialect::initialize() {
  // The operation list is expanded from the generated ControlFlowOps.cpp.inc
  // (its GET_OP_LIST section) directly into the template argument list.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
      >();
  addInterfaces<ControlFlowInlinerInterface>();
  // Only declares (does not attach) a ConvertToLLVMPatternInterface
  // implementation for this dialect; presumably it is provided by a separate
  // extension registered elsewhere — TODO confirm.
  declarePromisedInterface<ControlFlowDialect, ConvertToLLVMPatternInterface>();
}
73 
74 //===----------------------------------------------------------------------===//
75 // AssertOp
76 //===----------------------------------------------------------------------===//
77 
78 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
79   // Erase assertion if argument is constant true.
80   if (matchPattern(op.getArg(), m_One())) {
81     rewriter.eraseOp(op);
82     return success();
83   }
84   return failure();
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // BranchOp
89 //===----------------------------------------------------------------------===//
90 
91 /// Given a successor, try to collapse it to a new destination if it only
92 /// contains a passthrough unconditional branch. If the successor is
93 /// collapsable, `successor` and `successorOperands` are updated to reference
94 /// the new destination and values. `argStorage` is used as storage if operands
95 /// to the collapsed successor need to be remapped. It must outlive uses of
96 /// successorOperands.
97 static LogicalResult collapseBranch(Block *&successor,
98                                     ValueRange &successorOperands,
99                                     SmallVectorImpl<Value> &argStorage) {
100   // Check that the successor only contains a unconditional branch.
101   if (std::next(successor->begin()) != successor->end())
102     return failure();
103   // Check that the terminator is an unconditional branch.
104   BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
105   if (!successorBranch)
106     return failure();
107   // Check that the arguments are only used within the terminator.
108   for (BlockArgument arg : successor->getArguments()) {
109     for (Operation *user : arg.getUsers())
110       if (user != successorBranch)
111         return failure();
112   }
113   // Don't try to collapse branches to infinite loops.
114   Block *successorDest = successorBranch.getDest();
115   if (successorDest == successor)
116     return failure();
117 
118   // Update the operands to the successor. If the branch parent has no
119   // arguments, we can use the branch operands directly.
120   OperandRange operands = successorBranch.getOperands();
121   if (successor->args_empty()) {
122     successor = successorDest;
123     successorOperands = operands;
124     return success();
125   }
126 
127   // Otherwise, we need to remap any argument operands.
128   for (Value operand : operands) {
129     BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
130     if (argOperand && argOperand.getOwner() == successor)
131       argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
132     else
133       argStorage.push_back(operand);
134   }
135   successor = successorDest;
136   successorOperands = argStorage;
137   return success();
138 }
139 
140 /// Simplify a branch to a block that has a single predecessor. This effectively
141 /// merges the two blocks.
142 static LogicalResult
143 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
144   // Check that the successor block has a single predecessor.
145   Block *succ = op.getDest();
146   Block *opParent = op->getBlock();
147   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
148     return failure();
149 
150   // Merge the successor into the current block and erase the branch.
151   SmallVector<Value> brOperands(op.getOperands());
152   rewriter.eraseOp(op);
153   rewriter.mergeBlocks(succ, opParent, brOperands);
154   return success();
155 }
156 
157 ///   br ^bb1
158 /// ^bb1
159 ///   br ^bbN(...)
160 ///
161 ///  -> br ^bbN(...)
162 ///
163 static LogicalResult simplifyPassThroughBr(BranchOp op,
164                                            PatternRewriter &rewriter) {
165   Block *dest = op.getDest();
166   ValueRange destOperands = op.getOperands();
167   SmallVector<Value, 4> destOperandStorage;
168 
169   // Try to collapse the successor if it points somewhere other than this
170   // block.
171   if (dest == op->getBlock() ||
172       failed(collapseBranch(dest, destOperands, destOperandStorage)))
173     return failure();
174 
175   // Create a new branch with the collapsed successor.
176   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
177   return success();
178 }
179 
180 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
181   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
182                  succeeded(simplifyPassThroughBr(op, rewriter)));
183 }
184 
185 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
186 
187 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
188 
189 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
190   assert(index == 0 && "invalid successor index");
191   return SuccessorOperands(getDestOperandsMutable());
192 }
193 
194 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
195   return getDest();
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // CondBranchOp
200 //===----------------------------------------------------------------------===//
201 
202 namespace {
203 /// cf.cond_br true, ^bb1, ^bb2
204 ///  -> br ^bb1
205 /// cf.cond_br false, ^bb1, ^bb2
206 ///  -> br ^bb2
207 ///
208 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
209   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
210 
211   LogicalResult matchAndRewrite(CondBranchOp condbr,
212                                 PatternRewriter &rewriter) const override {
213     if (matchPattern(condbr.getCondition(), m_NonZero())) {
214       // True branch taken.
215       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
216                                             condbr.getTrueOperands());
217       return success();
218     }
219     if (matchPattern(condbr.getCondition(), m_Zero())) {
220       // False branch taken.
221       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
222                                             condbr.getFalseOperands());
223       return success();
224     }
225     return failure();
226   }
227 };
228 
229 ///   cf.cond_br %cond, ^bb1, ^bb2
230 /// ^bb1
231 ///   br ^bbN(...)
232 /// ^bb2
233 ///   br ^bbK(...)
234 ///
235 ///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
236 ///
237 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
238   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
239 
240   LogicalResult matchAndRewrite(CondBranchOp condbr,
241                                 PatternRewriter &rewriter) const override {
242     Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
243     ValueRange trueDestOperands = condbr.getTrueOperands();
244     ValueRange falseDestOperands = condbr.getFalseOperands();
245     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
246 
247     // Try to collapse one of the current successors.
248     LogicalResult collapsedTrue =
249         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
250     LogicalResult collapsedFalse =
251         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
252     if (failed(collapsedTrue) && failed(collapsedFalse))
253       return failure();
254 
255     // Create a new branch with the collapsed successors.
256     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
257                                               trueDest, trueDestOperands,
258                                               falseDest, falseDestOperands);
259     return success();
260   }
261 };
262 
263 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
264 ///  -> br ^bb1(A, ..., N)
265 ///
266 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
267 ///  -> %select = arith.select %cond, A, B
268 ///     br ^bb1(%select)
269 ///
270 struct SimplifyCondBranchIdenticalSuccessors
271     : public OpRewritePattern<CondBranchOp> {
272   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
273 
274   LogicalResult matchAndRewrite(CondBranchOp condbr,
275                                 PatternRewriter &rewriter) const override {
276     // Check that the true and false destinations are the same and have the same
277     // operands.
278     Block *trueDest = condbr.getTrueDest();
279     if (trueDest != condbr.getFalseDest())
280       return failure();
281 
282     // If all of the operands match, no selects need to be generated.
283     OperandRange trueOperands = condbr.getTrueOperands();
284     OperandRange falseOperands = condbr.getFalseOperands();
285     if (trueOperands == falseOperands) {
286       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
287       return success();
288     }
289 
290     // Otherwise, if the current block is the only predecessor insert selects
291     // for any mismatched branch operands.
292     if (trueDest->getUniquePredecessor() != condbr->getBlock())
293       return failure();
294 
295     // Generate a select for any operands that differ between the two.
296     SmallVector<Value, 8> mergedOperands;
297     mergedOperands.reserve(trueOperands.size());
298     Value condition = condbr.getCondition();
299     for (auto it : llvm::zip(trueOperands, falseOperands)) {
300       if (std::get<0>(it) == std::get<1>(it))
301         mergedOperands.push_back(std::get<0>(it));
302       else
303         mergedOperands.push_back(rewriter.create<arith::SelectOp>(
304             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
305     }
306 
307     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
308     return success();
309   }
310 };
311 
312 ///   ...
313 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
314 /// ...
315 /// ^bb1: // has single predecessor
316 ///   ...
317 ///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
318 ///
319 /// ->
320 ///
321 ///   ...
322 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
323 /// ...
324 /// ^bb1: // has single predecessor
325 ///   ...
326 ///   br ^bb3(...)
327 ///
328 struct SimplifyCondBranchFromCondBranchOnSameCondition
329     : public OpRewritePattern<CondBranchOp> {
330   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
331 
332   LogicalResult matchAndRewrite(CondBranchOp condbr,
333                                 PatternRewriter &rewriter) const override {
334     // Check that we have a single distinct predecessor.
335     Block *currentBlock = condbr->getBlock();
336     Block *predecessor = currentBlock->getSinglePredecessor();
337     if (!predecessor)
338       return failure();
339 
340     // Check that the predecessor terminates with a conditional branch to this
341     // block and that it branches on the same condition.
342     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
343     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
344       return failure();
345 
346     // Fold this branch to an unconditional branch.
347     if (currentBlock == predBranch.getTrueDest())
348       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
349                                             condbr.getTrueDestOperands());
350     else
351       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
352                                             condbr.getFalseDestOperands());
353     return success();
354   }
355 };
356 
357 ///   cf.cond_br %arg0, ^trueB, ^falseB
358 ///
359 /// ^trueB:
360 ///   "test.consumer1"(%arg0) : (i1) -> ()
361 ///    ...
362 ///
363 /// ^falseB:
364 ///   "test.consumer2"(%arg0) : (i1) -> ()
365 ///   ...
366 ///
367 /// ->
368 ///
369 ///   cf.cond_br %arg0, ^trueB, ^falseB
370 /// ^trueB:
371 ///   "test.consumer1"(%true) : (i1) -> ()
372 ///   ...
373 ///
374 /// ^falseB:
375 ///   "test.consumer2"(%false) : (i1) -> ()
376 ///   ...
377 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
378   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
379 
380   LogicalResult matchAndRewrite(CondBranchOp condbr,
381                                 PatternRewriter &rewriter) const override {
382     // Check that we have a single distinct predecessor.
383     bool replaced = false;
384     Type ty = rewriter.getI1Type();
385 
386     // These variables serve to prevent creating duplicate constants
387     // and hold constant true or false values.
388     Value constantTrue = nullptr;
389     Value constantFalse = nullptr;
390 
391     // TODO These checks can be expanded to encompas any use with only
392     // either the true of false edge as a predecessor. For now, we fall
393     // back to checking the single predecessor is given by the true/fasle
394     // destination, thereby ensuring that only that edge can reach the
395     // op.
396     if (condbr.getTrueDest()->getSinglePredecessor()) {
397       for (OpOperand &use :
398            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
399         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
400           replaced = true;
401 
402           if (!constantTrue)
403             constantTrue = rewriter.create<arith::ConstantOp>(
404                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
405 
406           rewriter.updateRootInPlace(use.getOwner(),
407                                      [&] { use.set(constantTrue); });
408         }
409       }
410     }
411     if (condbr.getFalseDest()->getSinglePredecessor()) {
412       for (OpOperand &use :
413            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
414         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
415           replaced = true;
416 
417           if (!constantFalse)
418             constantFalse = rewriter.create<arith::ConstantOp>(
419                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
420 
421           rewriter.updateRootInPlace(use.getOwner(),
422                                      [&] { use.set(constantFalse); });
423         }
424       }
425     }
426     return success(replaced);
427   }
428 };
429 } // namespace
430 
431 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
432                                                MLIRContext *context) {
433   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
434               SimplifyCondBranchIdenticalSuccessors,
435               SimplifyCondBranchFromCondBranchOnSameCondition,
436               CondBranchTruthPropagation>(context);
437 }
438 
439 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
440   assert(index < getNumSuccessors() && "invalid successor index");
441   return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
442                                               : getFalseDestOperandsMutable());
443 }
444 
445 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
446   if (IntegerAttr condAttr =
447           llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
448     return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
449   return nullptr;
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // SwitchOp
454 //===----------------------------------------------------------------------===//
455 
456 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
457                      Block *defaultDestination, ValueRange defaultOperands,
458                      DenseIntElementsAttr caseValues,
459                      BlockRange caseDestinations,
460                      ArrayRef<ValueRange> caseOperands) {
461   build(builder, result, value, defaultOperands, caseOperands, caseValues,
462         defaultDestination, caseDestinations);
463 }
464 
465 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
466                      Block *defaultDestination, ValueRange defaultOperands,
467                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
468                      ArrayRef<ValueRange> caseOperands) {
469   DenseIntElementsAttr caseValuesAttr;
470   if (!caseValues.empty()) {
471     ShapedType caseValueType = VectorType::get(
472         static_cast<int64_t>(caseValues.size()), value.getType());
473     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
474   }
475   build(builder, result, value, defaultDestination, defaultOperands,
476         caseValuesAttr, caseDestinations, caseOperands);
477 }
478 
479 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
480                      Block *defaultDestination, ValueRange defaultOperands,
481                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
482                      ArrayRef<ValueRange> caseOperands) {
483   DenseIntElementsAttr caseValuesAttr;
484   if (!caseValues.empty()) {
485     ShapedType caseValueType = VectorType::get(
486         static_cast<int64_t>(caseValues.size()), value.getType());
487     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
488   }
489   build(builder, result, value, defaultDestination, defaultOperands,
490         caseValuesAttr, caseDestinations, caseOperands);
491 }
492 
493 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
494 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
495 static ParseResult parseSwitchOpCases(
496     OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
497     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
498     SmallVectorImpl<Type> &defaultOperandTypes,
499     DenseIntElementsAttr &caseValues,
500     SmallVectorImpl<Block *> &caseDestinations,
501     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
502     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
503   if (parser.parseKeyword("default") || parser.parseColon() ||
504       parser.parseSuccessor(defaultDestination))
505     return failure();
506   if (succeeded(parser.parseOptionalLParen())) {
507     if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
508                                 /*allowResultNumber=*/false) ||
509         parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
510       return failure();
511   }
512 
513   SmallVector<APInt> values;
514   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
515   while (succeeded(parser.parseOptionalComma())) {
516     int64_t value = 0;
517     if (failed(parser.parseInteger(value)))
518       return failure();
519     values.push_back(APInt(bitWidth, value));
520 
521     Block *destination;
522     SmallVector<OpAsmParser::UnresolvedOperand> operands;
523     SmallVector<Type> operandTypes;
524     if (failed(parser.parseColon()) ||
525         failed(parser.parseSuccessor(destination)))
526       return failure();
527     if (succeeded(parser.parseOptionalLParen())) {
528       if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
529                                          /*allowResultNumber=*/false)) ||
530           failed(parser.parseColonTypeList(operandTypes)) ||
531           failed(parser.parseRParen()))
532         return failure();
533     }
534     caseDestinations.push_back(destination);
535     caseOperands.emplace_back(operands);
536     caseOperandTypes.emplace_back(operandTypes);
537   }
538 
539   if (!values.empty()) {
540     ShapedType caseValueType =
541         VectorType::get(static_cast<int64_t>(values.size()), flagType);
542     caseValues = DenseIntElementsAttr::get(caseValueType, values);
543   }
544   return success();
545 }
546 
547 static void printSwitchOpCases(
548     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
549     OperandRange defaultOperands, TypeRange defaultOperandTypes,
550     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
551     OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
552   p << "  default: ";
553   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
554 
555   if (!caseValues)
556     return;
557 
558   for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
559     p << ',';
560     p.printNewline();
561     p << "  ";
562     p << it.value().getLimitedValue();
563     p << ": ";
564     p.printSuccessorAndUseList(caseDestinations[it.index()],
565                                caseOperands[it.index()]);
566   }
567   p.printNewline();
568 }
569 
570 LogicalResult SwitchOp::verify() {
571   auto caseValues = getCaseValues();
572   auto caseDestinations = getCaseDestinations();
573 
574   if (!caseValues && caseDestinations.empty())
575     return success();
576 
577   Type flagType = getFlag().getType();
578   Type caseValueType = caseValues->getType().getElementType();
579   if (caseValueType != flagType)
580     return emitOpError() << "'flag' type (" << flagType
581                          << ") should match case value type (" << caseValueType
582                          << ")";
583 
584   if (caseValues &&
585       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
586     return emitOpError() << "number of case values (" << caseValues->size()
587                          << ") should match number of "
588                             "case destinations ("
589                          << caseDestinations.size() << ")";
590   return success();
591 }
592 
593 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
594   assert(index < getNumSuccessors() && "invalid successor index");
595   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
596                                       : getCaseOperandsMutable(index - 1));
597 }
598 
599 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
600   std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
601 
602   if (!caseValues)
603     return getDefaultDestination();
604 
605   SuccessorRange caseDests = getCaseDestinations();
606   if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
607     for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
608       if (it.value() == value.getValue())
609         return caseDests[it.index()];
610     return getDefaultDestination();
611   }
612   return nullptr;
613 }
614 
615 /// switch %flag : i32, [
616 ///   default:  ^bb1
617 /// ]
618 ///  -> br ^bb1
619 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
620                                                    PatternRewriter &rewriter) {
621   if (!op.getCaseDestinations().empty())
622     return failure();
623 
624   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
625                                         op.getDefaultOperands());
626   return success();
627 }
628 
629 /// switch %flag : i32, [
630 ///   default: ^bb1,
631 ///   42: ^bb1,
632 ///   43: ^bb2
633 /// ]
634 /// ->
635 /// switch %flag : i32, [
636 ///   default: ^bb1,
637 ///   43: ^bb2
638 /// ]
639 static LogicalResult
640 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
641   SmallVector<Block *> newCaseDestinations;
642   SmallVector<ValueRange> newCaseOperands;
643   SmallVector<APInt> newCaseValues;
644   bool requiresChange = false;
645   auto caseValues = op.getCaseValues();
646   auto caseDests = op.getCaseDestinations();
647 
648   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
649     if (caseDests[it.index()] == op.getDefaultDestination() &&
650         op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
651       requiresChange = true;
652       continue;
653     }
654     newCaseDestinations.push_back(caseDests[it.index()]);
655     newCaseOperands.push_back(op.getCaseOperands(it.index()));
656     newCaseValues.push_back(it.value());
657   }
658 
659   if (!requiresChange)
660     return failure();
661 
662   rewriter.replaceOpWithNewOp<SwitchOp>(
663       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
664       newCaseValues, newCaseDestinations, newCaseOperands);
665   return success();
666 }
667 
668 /// Helper for folding a switch with a constant value.
669 /// switch %c_42 : i32, [
670 ///   default: ^bb1 ,
671 ///   42: ^bb2,
672 ///   43: ^bb3
673 /// ]
674 /// -> br ^bb2
675 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
676                        const APInt &caseValue) {
677   auto caseValues = op.getCaseValues();
678   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
679     if (it.value() == caseValue) {
680       rewriter.replaceOpWithNewOp<BranchOp>(
681           op, op.getCaseDestinations()[it.index()],
682           op.getCaseOperands(it.index()));
683       return;
684     }
685   }
686   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
687                                         op.getDefaultOperands());
688 }
689 
690 /// switch %c_42 : i32, [
691 ///   default: ^bb1,
692 ///   42: ^bb2,
693 ///   43: ^bb3
694 /// ]
695 /// -> br ^bb2
696 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
697                                               PatternRewriter &rewriter) {
698   APInt caseValue;
699   if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
700     return failure();
701 
702   foldSwitch(op, rewriter, caseValue);
703   return success();
704 }
705 
706 /// switch %c_42 : i32, [
707 ///   default: ^bb1,
708 ///   42: ^bb2,
709 /// ]
710 /// ^bb2:
711 ///   br ^bb3
712 /// ->
713 /// switch %c_42 : i32, [
714 ///   default: ^bb1,
715 ///   42: ^bb3,
716 /// ]
717 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
718                                                PatternRewriter &rewriter) {
719   SmallVector<Block *> newCaseDests;
720   SmallVector<ValueRange> newCaseOperands;
721   SmallVector<SmallVector<Value>> argStorage;
722   auto caseValues = op.getCaseValues();
723   argStorage.reserve(caseValues->size() + 1);
724   auto caseDests = op.getCaseDestinations();
725   bool requiresChange = false;
726   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
727     Block *caseDest = caseDests[i];
728     ValueRange caseOperands = op.getCaseOperands(i);
729     argStorage.emplace_back();
730     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
731       requiresChange = true;
732 
733     newCaseDests.push_back(caseDest);
734     newCaseOperands.push_back(caseOperands);
735   }
736 
737   Block *defaultDest = op.getDefaultDestination();
738   ValueRange defaultOperands = op.getDefaultOperands();
739   argStorage.emplace_back();
740 
741   if (succeeded(
742           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
743     requiresChange = true;
744 
745   if (!requiresChange)
746     return failure();
747 
748   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
749                                         defaultOperands, *caseValues,
750                                         newCaseDests, newCaseOperands);
751   return success();
752 }
753 
754 /// switch %flag : i32, [
755 ///   default: ^bb1,
756 ///   42: ^bb2,
757 /// ]
758 /// ^bb2:
759 ///   switch %flag : i32, [
760 ///     default: ^bb3,
761 ///     42: ^bb4
762 ///   ]
763 /// ->
764 /// switch %flag : i32, [
765 ///   default: ^bb1,
766 ///   42: ^bb2,
767 /// ]
768 /// ^bb2:
769 ///   br ^bb4
770 ///
771 ///  and
772 ///
773 /// switch %flag : i32, [
774 ///   default: ^bb1,
775 ///   42: ^bb2,
776 /// ]
777 /// ^bb2:
778 ///   switch %flag : i32, [
779 ///     default: ^bb3,
780 ///     43: ^bb4
781 ///   ]
782 /// ->
783 /// switch %flag : i32, [
784 ///   default: ^bb1,
785 ///   42: ^bb2,
786 /// ]
787 /// ^bb2:
788 ///   br ^bb3
789 static LogicalResult
790 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
791                                         PatternRewriter &rewriter) {
792   // Check that we have a single distinct predecessor.
793   Block *currentBlock = op->getBlock();
794   Block *predecessor = currentBlock->getSinglePredecessor();
795   if (!predecessor)
796     return failure();
797 
798   // Check that the predecessor terminates with a switch branch to this block
799   // and that it branches on the same condition and that this branch isn't the
800   // default destination.
801   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
802   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
803       predSwitch.getDefaultDestination() == currentBlock)
804     return failure();
805 
806   // Fold this switch to an unconditional branch.
807   SuccessorRange predDests = predSwitch.getCaseDestinations();
808   auto it = llvm::find(predDests, currentBlock);
809   if (it != predDests.end()) {
810     std::optional<DenseIntElementsAttr> predCaseValues =
811         predSwitch.getCaseValues();
812     foldSwitch(op, rewriter,
813                predCaseValues->getValues<APInt>()[it - predDests.begin()]);
814   } else {
815     rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
816                                           op.getDefaultOperands());
817   }
818   return success();
819 }
820 
821 /// switch %flag : i32, [
822 ///   default: ^bb1,
823 ///   42: ^bb2
824 /// ]
825 /// ^bb1:
826 ///   switch %flag : i32, [
827 ///     default: ^bb3,
828 ///     42: ^bb4,
829 ///     43: ^bb5
830 ///   ]
831 /// ->
832 /// switch %flag : i32, [
833 ///   default: ^bb1,
834 ///   42: ^bb2,
835 /// ]
836 /// ^bb1:
837 ///   switch %flag : i32, [
838 ///     default: ^bb3,
839 ///     43: ^bb5
840 ///   ]
841 static LogicalResult
842 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
843                                                PatternRewriter &rewriter) {
844   // Check that we have a single distinct predecessor.
845   Block *currentBlock = op->getBlock();
846   Block *predecessor = currentBlock->getSinglePredecessor();
847   if (!predecessor)
848     return failure();
849 
850   // Check that the predecessor terminates with a switch branch to this block
851   // and that it branches on the same condition and that this branch is the
852   // default destination.
853   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
854   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
855       predSwitch.getDefaultDestination() != currentBlock)
856     return failure();
857 
858   // Delete case values that are not possible here.
859   DenseSet<APInt> caseValuesToRemove;
860   auto predDests = predSwitch.getCaseDestinations();
861   auto predCaseValues = predSwitch.getCaseValues();
862   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
863     if (currentBlock != predDests[i])
864       caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
865 
866   SmallVector<Block *> newCaseDestinations;
867   SmallVector<ValueRange> newCaseOperands;
868   SmallVector<APInt> newCaseValues;
869   bool requiresChange = false;
870 
871   auto caseValues = op.getCaseValues();
872   auto caseDests = op.getCaseDestinations();
873   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
874     if (caseValuesToRemove.contains(it.value())) {
875       requiresChange = true;
876       continue;
877     }
878     newCaseDestinations.push_back(caseDests[it.index()]);
879     newCaseOperands.push_back(op.getCaseOperands(it.index()));
880     newCaseValues.push_back(it.value());
881   }
882 
883   if (!requiresChange)
884     return failure();
885 
886   rewriter.replaceOpWithNewOp<SwitchOp>(
887       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
888       newCaseValues, newCaseDestinations, newCaseOperands);
889   return success();
890 }
891 
892 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
893                                            MLIRContext *context) {
894   results.add(&simplifySwitchWithOnlyDefault)
895       .add(&dropSwitchCasesThatMatchDefault)
896       .add(&simplifyConstSwitchValue)
897       .add(&simplifyPassThroughSwitch)
898       .add(&simplifySwitchFromSwitchOnSameCondition)
899       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
900 }
901 
902 //===----------------------------------------------------------------------===//
903 // TableGen'd op method definitions
904 //===----------------------------------------------------------------------===//
905 
906 #define GET_OP_CLASSES
907 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
908