xref: /llvm-project/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (revision 0fb216fb2fbb49c1fe90c1c3267873a100b1c356)
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10 
11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
14 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/OpImplementation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Transforms/InliningUtils.h"
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <numeric>
32 
33 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
34 
35 using namespace mlir;
36 using namespace mlir::cf;
37 
38 //===----------------------------------------------------------------------===//
39 // ControlFlowDialect Interfaces
40 //===----------------------------------------------------------------------===//
41 namespace {
42 /// This class defines the interface for handling inlining with control flow
43 /// operations.
44 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
45   using DialectInlinerInterface::DialectInlinerInterface;
46   ~ControlFlowInlinerInterface() override = default;
47 
48   /// All control flow operations can be inlined.
49   bool isLegalToInline(Operation *call, Operation *callable,
50                        bool wouldBeCloned) const final {
51     return true;
52   }
53   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
54     return true;
55   }
56 
57   /// ControlFlow terminator operations don't really need any special handing.
58   void handleTerminator(Operation *op, Block *newDest) const final {}
59 };
60 } // namespace
61 
62 //===----------------------------------------------------------------------===//
63 // ControlFlowDialect
64 //===----------------------------------------------------------------------===//
65 
66 void ControlFlowDialect::initialize() {
67   addOperations<
68 #define GET_OP_LIST
69 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
70       >();
71   addInterfaces<ControlFlowInlinerInterface>();
72   declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
73   declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
74                             CondBranchOp>();
75   declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
76                            CondBranchOp>();
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // AssertOp
81 //===----------------------------------------------------------------------===//
82 
83 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
84   // Erase assertion if argument is constant true.
85   if (matchPattern(op.getArg(), m_One())) {
86     rewriter.eraseOp(op);
87     return success();
88   }
89   return failure();
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // BranchOp
94 //===----------------------------------------------------------------------===//
95 
96 /// Given a successor, try to collapse it to a new destination if it only
97 /// contains a passthrough unconditional branch. If the successor is
98 /// collapsable, `successor` and `successorOperands` are updated to reference
99 /// the new destination and values. `argStorage` is used as storage if operands
100 /// to the collapsed successor need to be remapped. It must outlive uses of
101 /// successorOperands.
102 static LogicalResult collapseBranch(Block *&successor,
103                                     ValueRange &successorOperands,
104                                     SmallVectorImpl<Value> &argStorage) {
105   // Check that the successor only contains a unconditional branch.
106   if (std::next(successor->begin()) != successor->end())
107     return failure();
108   // Check that the terminator is an unconditional branch.
109   BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
110   if (!successorBranch)
111     return failure();
112   // Check that the arguments are only used within the terminator.
113   for (BlockArgument arg : successor->getArguments()) {
114     for (Operation *user : arg.getUsers())
115       if (user != successorBranch)
116         return failure();
117   }
118   // Don't try to collapse branches to infinite loops.
119   Block *successorDest = successorBranch.getDest();
120   if (successorDest == successor)
121     return failure();
122 
123   // Update the operands to the successor. If the branch parent has no
124   // arguments, we can use the branch operands directly.
125   OperandRange operands = successorBranch.getOperands();
126   if (successor->args_empty()) {
127     successor = successorDest;
128     successorOperands = operands;
129     return success();
130   }
131 
132   // Otherwise, we need to remap any argument operands.
133   for (Value operand : operands) {
134     BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
135     if (argOperand && argOperand.getOwner() == successor)
136       argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
137     else
138       argStorage.push_back(operand);
139   }
140   successor = successorDest;
141   successorOperands = argStorage;
142   return success();
143 }
144 
145 /// Simplify a branch to a block that has a single predecessor. This effectively
146 /// merges the two blocks.
147 static LogicalResult
148 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
149   // Check that the successor block has a single predecessor.
150   Block *succ = op.getDest();
151   Block *opParent = op->getBlock();
152   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
153     return failure();
154 
155   // Merge the successor into the current block and erase the branch.
156   SmallVector<Value> brOperands(op.getOperands());
157   rewriter.eraseOp(op);
158   rewriter.mergeBlocks(succ, opParent, brOperands);
159   return success();
160 }
161 
162 ///   br ^bb1
163 /// ^bb1
164 ///   br ^bbN(...)
165 ///
166 ///  -> br ^bbN(...)
167 ///
168 static LogicalResult simplifyPassThroughBr(BranchOp op,
169                                            PatternRewriter &rewriter) {
170   Block *dest = op.getDest();
171   ValueRange destOperands = op.getOperands();
172   SmallVector<Value, 4> destOperandStorage;
173 
174   // Try to collapse the successor if it points somewhere other than this
175   // block.
176   if (dest == op->getBlock() ||
177       failed(collapseBranch(dest, destOperands, destOperandStorage)))
178     return failure();
179 
180   // Create a new branch with the collapsed successor.
181   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
182   return success();
183 }
184 
185 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
186   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
187                  succeeded(simplifyPassThroughBr(op, rewriter)));
188 }
189 
190 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
191 
192 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
193 
194 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
195   assert(index == 0 && "invalid successor index");
196   return SuccessorOperands(getDestOperandsMutable());
197 }
198 
199 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
200   return getDest();
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // CondBranchOp
205 //===----------------------------------------------------------------------===//
206 
207 namespace {
208 /// cf.cond_br true, ^bb1, ^bb2
209 ///  -> br ^bb1
210 /// cf.cond_br false, ^bb1, ^bb2
211 ///  -> br ^bb2
212 ///
213 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
214   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
215 
216   LogicalResult matchAndRewrite(CondBranchOp condbr,
217                                 PatternRewriter &rewriter) const override {
218     if (matchPattern(condbr.getCondition(), m_NonZero())) {
219       // True branch taken.
220       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
221                                             condbr.getTrueOperands());
222       return success();
223     }
224     if (matchPattern(condbr.getCondition(), m_Zero())) {
225       // False branch taken.
226       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
227                                             condbr.getFalseOperands());
228       return success();
229     }
230     return failure();
231   }
232 };
233 
234 ///   cf.cond_br %cond, ^bb1, ^bb2
235 /// ^bb1
236 ///   br ^bbN(...)
237 /// ^bb2
238 ///   br ^bbK(...)
239 ///
240 ///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
241 ///
242 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
243   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
244 
245   LogicalResult matchAndRewrite(CondBranchOp condbr,
246                                 PatternRewriter &rewriter) const override {
247     Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
248     ValueRange trueDestOperands = condbr.getTrueOperands();
249     ValueRange falseDestOperands = condbr.getFalseOperands();
250     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
251 
252     // Try to collapse one of the current successors.
253     LogicalResult collapsedTrue =
254         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
255     LogicalResult collapsedFalse =
256         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
257     if (failed(collapsedTrue) && failed(collapsedFalse))
258       return failure();
259 
260     // Create a new branch with the collapsed successors.
261     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
262                                               trueDest, trueDestOperands,
263                                               falseDest, falseDestOperands);
264     return success();
265   }
266 };
267 
268 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
269 ///  -> br ^bb1(A, ..., N)
270 ///
271 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
272 ///  -> %select = arith.select %cond, A, B
273 ///     br ^bb1(%select)
274 ///
275 struct SimplifyCondBranchIdenticalSuccessors
276     : public OpRewritePattern<CondBranchOp> {
277   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
278 
279   LogicalResult matchAndRewrite(CondBranchOp condbr,
280                                 PatternRewriter &rewriter) const override {
281     // Check that the true and false destinations are the same and have the same
282     // operands.
283     Block *trueDest = condbr.getTrueDest();
284     if (trueDest != condbr.getFalseDest())
285       return failure();
286 
287     // If all of the operands match, no selects need to be generated.
288     OperandRange trueOperands = condbr.getTrueOperands();
289     OperandRange falseOperands = condbr.getFalseOperands();
290     if (trueOperands == falseOperands) {
291       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
292       return success();
293     }
294 
295     // Otherwise, if the current block is the only predecessor insert selects
296     // for any mismatched branch operands.
297     if (trueDest->getUniquePredecessor() != condbr->getBlock())
298       return failure();
299 
300     // Generate a select for any operands that differ between the two.
301     SmallVector<Value, 8> mergedOperands;
302     mergedOperands.reserve(trueOperands.size());
303     Value condition = condbr.getCondition();
304     for (auto it : llvm::zip(trueOperands, falseOperands)) {
305       if (std::get<0>(it) == std::get<1>(it))
306         mergedOperands.push_back(std::get<0>(it));
307       else
308         mergedOperands.push_back(rewriter.create<arith::SelectOp>(
309             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
310     }
311 
312     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
313     return success();
314   }
315 };
316 
317 ///   ...
318 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
319 /// ...
320 /// ^bb1: // has single predecessor
321 ///   ...
322 ///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
323 ///
324 /// ->
325 ///
326 ///   ...
327 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
328 /// ...
329 /// ^bb1: // has single predecessor
330 ///   ...
331 ///   br ^bb3(...)
332 ///
333 struct SimplifyCondBranchFromCondBranchOnSameCondition
334     : public OpRewritePattern<CondBranchOp> {
335   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
336 
337   LogicalResult matchAndRewrite(CondBranchOp condbr,
338                                 PatternRewriter &rewriter) const override {
339     // Check that we have a single distinct predecessor.
340     Block *currentBlock = condbr->getBlock();
341     Block *predecessor = currentBlock->getSinglePredecessor();
342     if (!predecessor)
343       return failure();
344 
345     // Check that the predecessor terminates with a conditional branch to this
346     // block and that it branches on the same condition.
347     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
348     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
349       return failure();
350 
351     // Fold this branch to an unconditional branch.
352     if (currentBlock == predBranch.getTrueDest())
353       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
354                                             condbr.getTrueDestOperands());
355     else
356       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
357                                             condbr.getFalseDestOperands());
358     return success();
359   }
360 };
361 
362 ///   cf.cond_br %arg0, ^trueB, ^falseB
363 ///
364 /// ^trueB:
365 ///   "test.consumer1"(%arg0) : (i1) -> ()
366 ///    ...
367 ///
368 /// ^falseB:
369 ///   "test.consumer2"(%arg0) : (i1) -> ()
370 ///   ...
371 ///
372 /// ->
373 ///
374 ///   cf.cond_br %arg0, ^trueB, ^falseB
375 /// ^trueB:
376 ///   "test.consumer1"(%true) : (i1) -> ()
377 ///   ...
378 ///
379 /// ^falseB:
380 ///   "test.consumer2"(%false) : (i1) -> ()
381 ///   ...
382 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
383   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
384 
385   LogicalResult matchAndRewrite(CondBranchOp condbr,
386                                 PatternRewriter &rewriter) const override {
387     // Check that we have a single distinct predecessor.
388     bool replaced = false;
389     Type ty = rewriter.getI1Type();
390 
391     // These variables serve to prevent creating duplicate constants
392     // and hold constant true or false values.
393     Value constantTrue = nullptr;
394     Value constantFalse = nullptr;
395 
396     // TODO These checks can be expanded to encompas any use with only
397     // either the true of false edge as a predecessor. For now, we fall
398     // back to checking the single predecessor is given by the true/fasle
399     // destination, thereby ensuring that only that edge can reach the
400     // op.
401     if (condbr.getTrueDest()->getSinglePredecessor()) {
402       for (OpOperand &use :
403            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
404         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
405           replaced = true;
406 
407           if (!constantTrue)
408             constantTrue = rewriter.create<arith::ConstantOp>(
409                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
410 
411           rewriter.modifyOpInPlace(use.getOwner(),
412                                    [&] { use.set(constantTrue); });
413         }
414       }
415     }
416     if (condbr.getFalseDest()->getSinglePredecessor()) {
417       for (OpOperand &use :
418            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
419         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
420           replaced = true;
421 
422           if (!constantFalse)
423             constantFalse = rewriter.create<arith::ConstantOp>(
424                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
425 
426           rewriter.modifyOpInPlace(use.getOwner(),
427                                    [&] { use.set(constantFalse); });
428         }
429       }
430     }
431     return success(replaced);
432   }
433 };
434 } // namespace
435 
436 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
437                                                MLIRContext *context) {
438   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
439               SimplifyCondBranchIdenticalSuccessors,
440               SimplifyCondBranchFromCondBranchOnSameCondition,
441               CondBranchTruthPropagation>(context);
442 }
443 
444 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
445   assert(index < getNumSuccessors() && "invalid successor index");
446   return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
447                                               : getFalseDestOperandsMutable());
448 }
449 
450 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
451   if (IntegerAttr condAttr =
452           llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
453     return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
454   return nullptr;
455 }
456 
457 //===----------------------------------------------------------------------===//
458 // SwitchOp
459 //===----------------------------------------------------------------------===//
460 
461 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
462                      Block *defaultDestination, ValueRange defaultOperands,
463                      DenseIntElementsAttr caseValues,
464                      BlockRange caseDestinations,
465                      ArrayRef<ValueRange> caseOperands) {
466   build(builder, result, value, defaultOperands, caseOperands, caseValues,
467         defaultDestination, caseDestinations);
468 }
469 
470 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
471                      Block *defaultDestination, ValueRange defaultOperands,
472                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
473                      ArrayRef<ValueRange> caseOperands) {
474   DenseIntElementsAttr caseValuesAttr;
475   if (!caseValues.empty()) {
476     ShapedType caseValueType = VectorType::get(
477         static_cast<int64_t>(caseValues.size()), value.getType());
478     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
479   }
480   build(builder, result, value, defaultDestination, defaultOperands,
481         caseValuesAttr, caseDestinations, caseOperands);
482 }
483 
484 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
485                      Block *defaultDestination, ValueRange defaultOperands,
486                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
487                      ArrayRef<ValueRange> caseOperands) {
488   DenseIntElementsAttr caseValuesAttr;
489   if (!caseValues.empty()) {
490     ShapedType caseValueType = VectorType::get(
491         static_cast<int64_t>(caseValues.size()), value.getType());
492     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
493   }
494   build(builder, result, value, defaultDestination, defaultOperands,
495         caseValuesAttr, caseDestinations, caseOperands);
496 }
497 
498 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
499 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
500 static ParseResult parseSwitchOpCases(
501     OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
502     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
503     SmallVectorImpl<Type> &defaultOperandTypes,
504     DenseIntElementsAttr &caseValues,
505     SmallVectorImpl<Block *> &caseDestinations,
506     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
507     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
508   if (parser.parseKeyword("default") || parser.parseColon() ||
509       parser.parseSuccessor(defaultDestination))
510     return failure();
511   if (succeeded(parser.parseOptionalLParen())) {
512     if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
513                                 /*allowResultNumber=*/false) ||
514         parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
515       return failure();
516   }
517 
518   SmallVector<APInt> values;
519   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
520   while (succeeded(parser.parseOptionalComma())) {
521     int64_t value = 0;
522     if (failed(parser.parseInteger(value)))
523       return failure();
524     values.push_back(APInt(bitWidth, value));
525 
526     Block *destination;
527     SmallVector<OpAsmParser::UnresolvedOperand> operands;
528     SmallVector<Type> operandTypes;
529     if (failed(parser.parseColon()) ||
530         failed(parser.parseSuccessor(destination)))
531       return failure();
532     if (succeeded(parser.parseOptionalLParen())) {
533       if (failed(parser.parseOperandList(operands,
534                                          OpAsmParser::Delimiter::None)) ||
535           failed(parser.parseColonTypeList(operandTypes)) ||
536           failed(parser.parseRParen()))
537         return failure();
538     }
539     caseDestinations.push_back(destination);
540     caseOperands.emplace_back(operands);
541     caseOperandTypes.emplace_back(operandTypes);
542   }
543 
544   if (!values.empty()) {
545     ShapedType caseValueType =
546         VectorType::get(static_cast<int64_t>(values.size()), flagType);
547     caseValues = DenseIntElementsAttr::get(caseValueType, values);
548   }
549   return success();
550 }
551 
552 static void printSwitchOpCases(
553     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
554     OperandRange defaultOperands, TypeRange defaultOperandTypes,
555     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
556     OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
557   p << "  default: ";
558   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
559 
560   if (!caseValues)
561     return;
562 
563   for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
564     p << ',';
565     p.printNewline();
566     p << "  ";
567     p << it.value().getLimitedValue();
568     p << ": ";
569     p.printSuccessorAndUseList(caseDestinations[it.index()],
570                                caseOperands[it.index()]);
571   }
572   p.printNewline();
573 }
574 
575 LogicalResult SwitchOp::verify() {
576   auto caseValues = getCaseValues();
577   auto caseDestinations = getCaseDestinations();
578 
579   if (!caseValues && caseDestinations.empty())
580     return success();
581 
582   Type flagType = getFlag().getType();
583   Type caseValueType = caseValues->getType().getElementType();
584   if (caseValueType != flagType)
585     return emitOpError() << "'flag' type (" << flagType
586                          << ") should match case value type (" << caseValueType
587                          << ")";
588 
589   if (caseValues &&
590       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
591     return emitOpError() << "number of case values (" << caseValues->size()
592                          << ") should match number of "
593                             "case destinations ("
594                          << caseDestinations.size() << ")";
595   return success();
596 }
597 
598 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
599   assert(index < getNumSuccessors() && "invalid successor index");
600   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
601                                       : getCaseOperandsMutable(index - 1));
602 }
603 
604 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
605   std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
606 
607   if (!caseValues)
608     return getDefaultDestination();
609 
610   SuccessorRange caseDests = getCaseDestinations();
611   if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
612     for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
613       if (it.value() == value.getValue())
614         return caseDests[it.index()];
615     return getDefaultDestination();
616   }
617   return nullptr;
618 }
619 
620 /// switch %flag : i32, [
621 ///   default:  ^bb1
622 /// ]
623 ///  -> br ^bb1
624 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
625                                                    PatternRewriter &rewriter) {
626   if (!op.getCaseDestinations().empty())
627     return failure();
628 
629   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
630                                         op.getDefaultOperands());
631   return success();
632 }
633 
634 /// switch %flag : i32, [
635 ///   default: ^bb1,
636 ///   42: ^bb1,
637 ///   43: ^bb2
638 /// ]
639 /// ->
640 /// switch %flag : i32, [
641 ///   default: ^bb1,
642 ///   43: ^bb2
643 /// ]
644 static LogicalResult
645 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
646   SmallVector<Block *> newCaseDestinations;
647   SmallVector<ValueRange> newCaseOperands;
648   SmallVector<APInt> newCaseValues;
649   bool requiresChange = false;
650   auto caseValues = op.getCaseValues();
651   auto caseDests = op.getCaseDestinations();
652 
653   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
654     if (caseDests[it.index()] == op.getDefaultDestination() &&
655         op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
656       requiresChange = true;
657       continue;
658     }
659     newCaseDestinations.push_back(caseDests[it.index()]);
660     newCaseOperands.push_back(op.getCaseOperands(it.index()));
661     newCaseValues.push_back(it.value());
662   }
663 
664   if (!requiresChange)
665     return failure();
666 
667   rewriter.replaceOpWithNewOp<SwitchOp>(
668       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
669       newCaseValues, newCaseDestinations, newCaseOperands);
670   return success();
671 }
672 
673 /// Helper for folding a switch with a constant value.
674 /// switch %c_42 : i32, [
675 ///   default: ^bb1 ,
676 ///   42: ^bb2,
677 ///   43: ^bb3
678 /// ]
679 /// -> br ^bb2
680 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
681                        const APInt &caseValue) {
682   auto caseValues = op.getCaseValues();
683   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
684     if (it.value() == caseValue) {
685       rewriter.replaceOpWithNewOp<BranchOp>(
686           op, op.getCaseDestinations()[it.index()],
687           op.getCaseOperands(it.index()));
688       return;
689     }
690   }
691   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
692                                         op.getDefaultOperands());
693 }
694 
695 /// switch %c_42 : i32, [
696 ///   default: ^bb1,
697 ///   42: ^bb2,
698 ///   43: ^bb3
699 /// ]
700 /// -> br ^bb2
701 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
702                                               PatternRewriter &rewriter) {
703   APInt caseValue;
704   if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
705     return failure();
706 
707   foldSwitch(op, rewriter, caseValue);
708   return success();
709 }
710 
711 /// switch %c_42 : i32, [
712 ///   default: ^bb1,
713 ///   42: ^bb2,
714 /// ]
715 /// ^bb2:
716 ///   br ^bb3
717 /// ->
718 /// switch %c_42 : i32, [
719 ///   default: ^bb1,
720 ///   42: ^bb3,
721 /// ]
722 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
723                                                PatternRewriter &rewriter) {
724   SmallVector<Block *> newCaseDests;
725   SmallVector<ValueRange> newCaseOperands;
726   SmallVector<SmallVector<Value>> argStorage;
727   auto caseValues = op.getCaseValues();
728   argStorage.reserve(caseValues->size() + 1);
729   auto caseDests = op.getCaseDestinations();
730   bool requiresChange = false;
731   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
732     Block *caseDest = caseDests[i];
733     ValueRange caseOperands = op.getCaseOperands(i);
734     argStorage.emplace_back();
735     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
736       requiresChange = true;
737 
738     newCaseDests.push_back(caseDest);
739     newCaseOperands.push_back(caseOperands);
740   }
741 
742   Block *defaultDest = op.getDefaultDestination();
743   ValueRange defaultOperands = op.getDefaultOperands();
744   argStorage.emplace_back();
745 
746   if (succeeded(
747           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
748     requiresChange = true;
749 
750   if (!requiresChange)
751     return failure();
752 
753   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
754                                         defaultOperands, *caseValues,
755                                         newCaseDests, newCaseOperands);
756   return success();
757 }
758 
759 /// switch %flag : i32, [
760 ///   default: ^bb1,
761 ///   42: ^bb2,
762 /// ]
763 /// ^bb2:
764 ///   switch %flag : i32, [
765 ///     default: ^bb3,
766 ///     42: ^bb4
767 ///   ]
768 /// ->
769 /// switch %flag : i32, [
770 ///   default: ^bb1,
771 ///   42: ^bb2,
772 /// ]
773 /// ^bb2:
774 ///   br ^bb4
775 ///
776 ///  and
777 ///
778 /// switch %flag : i32, [
779 ///   default: ^bb1,
780 ///   42: ^bb2,
781 /// ]
782 /// ^bb2:
783 ///   switch %flag : i32, [
784 ///     default: ^bb3,
785 ///     43: ^bb4
786 ///   ]
787 /// ->
788 /// switch %flag : i32, [
789 ///   default: ^bb1,
790 ///   42: ^bb2,
791 /// ]
792 /// ^bb2:
793 ///   br ^bb3
794 static LogicalResult
795 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
796                                         PatternRewriter &rewriter) {
797   // Check that we have a single distinct predecessor.
798   Block *currentBlock = op->getBlock();
799   Block *predecessor = currentBlock->getSinglePredecessor();
800   if (!predecessor)
801     return failure();
802 
803   // Check that the predecessor terminates with a switch branch to this block
804   // and that it branches on the same condition and that this branch isn't the
805   // default destination.
806   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
807   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
808       predSwitch.getDefaultDestination() == currentBlock)
809     return failure();
810 
811   // Fold this switch to an unconditional branch.
812   SuccessorRange predDests = predSwitch.getCaseDestinations();
813   auto it = llvm::find(predDests, currentBlock);
814   if (it != predDests.end()) {
815     std::optional<DenseIntElementsAttr> predCaseValues =
816         predSwitch.getCaseValues();
817     foldSwitch(op, rewriter,
818                predCaseValues->getValues<APInt>()[it - predDests.begin()]);
819   } else {
820     rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
821                                           op.getDefaultOperands());
822   }
823   return success();
824 }
825 
826 /// switch %flag : i32, [
827 ///   default: ^bb1,
828 ///   42: ^bb2
829 /// ]
830 /// ^bb1:
831 ///   switch %flag : i32, [
832 ///     default: ^bb3,
833 ///     42: ^bb4,
834 ///     43: ^bb5
835 ///   ]
836 /// ->
837 /// switch %flag : i32, [
838 ///   default: ^bb1,
839 ///   42: ^bb2,
840 /// ]
841 /// ^bb1:
842 ///   switch %flag : i32, [
843 ///     default: ^bb3,
844 ///     43: ^bb5
845 ///   ]
846 static LogicalResult
847 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
848                                                PatternRewriter &rewriter) {
849   // Check that we have a single distinct predecessor.
850   Block *currentBlock = op->getBlock();
851   Block *predecessor = currentBlock->getSinglePredecessor();
852   if (!predecessor)
853     return failure();
854 
855   // Check that the predecessor terminates with a switch branch to this block
856   // and that it branches on the same condition and that this branch is the
857   // default destination.
858   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
859   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
860       predSwitch.getDefaultDestination() != currentBlock)
861     return failure();
862 
863   // Delete case values that are not possible here.
864   DenseSet<APInt> caseValuesToRemove;
865   auto predDests = predSwitch.getCaseDestinations();
866   auto predCaseValues = predSwitch.getCaseValues();
867   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
868     if (currentBlock != predDests[i])
869       caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
870 
871   SmallVector<Block *> newCaseDestinations;
872   SmallVector<ValueRange> newCaseOperands;
873   SmallVector<APInt> newCaseValues;
874   bool requiresChange = false;
875 
876   auto caseValues = op.getCaseValues();
877   auto caseDests = op.getCaseDestinations();
878   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
879     if (caseValuesToRemove.contains(it.value())) {
880       requiresChange = true;
881       continue;
882     }
883     newCaseDestinations.push_back(caseDests[it.index()]);
884     newCaseOperands.push_back(op.getCaseOperands(it.index()));
885     newCaseValues.push_back(it.value());
886   }
887 
888   if (!requiresChange)
889     return failure();
890 
891   rewriter.replaceOpWithNewOp<SwitchOp>(
892       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
893       newCaseValues, newCaseDestinations, newCaseOperands);
894   return success();
895 }
896 
897 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
898                                            MLIRContext *context) {
899   results.add(&simplifySwitchWithOnlyDefault)
900       .add(&dropSwitchCasesThatMatchDefault)
901       .add(&simplifyConstSwitchValue)
902       .add(&simplifyPassThroughSwitch)
903       .add(&simplifySwitchFromSwitchOnSameCondition)
904       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
905 }
906 
907 //===----------------------------------------------------------------------===//
908 // TableGen'd op method definitions
909 //===----------------------------------------------------------------------===//
910 
911 #define GET_OP_CLASSES
912 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
913