xref: /llvm-project/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (revision e692af85966903614d470a7742ed89d124baf1a6)
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10 
11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
14 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/OpImplementation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Transforms/InliningUtils.h"
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <numeric>
32 
33 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
34 
35 using namespace mlir;
36 using namespace mlir::cf;
37 
38 //===----------------------------------------------------------------------===//
39 // ControlFlowDialect Interfaces
40 //===----------------------------------------------------------------------===//
41 namespace {
42 /// This class defines the interface for handling inlining with control flow
43 /// operations.
44 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
45   using DialectInlinerInterface::DialectInlinerInterface;
46   ~ControlFlowInlinerInterface() override = default;
47 
48   /// All control flow operations can be inlined.
49   bool isLegalToInline(Operation *call, Operation *callable,
50                        bool wouldBeCloned) const final {
51     return true;
52   }
53   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
54     return true;
55   }
56 
57   /// ControlFlow terminator operations don't really need any special handing.
58   void handleTerminator(Operation *op, Block *newDest) const final {}
59 };
60 } // namespace
61 
62 //===----------------------------------------------------------------------===//
63 // ControlFlowDialect
64 //===----------------------------------------------------------------------===//
65 
66 void ControlFlowDialect::initialize() {
67   addOperations<
68 #define GET_OP_LIST
69 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
70       >();
71   addInterfaces<ControlFlowInlinerInterface>();
72   declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
73   declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
74                             CondBranchOp>();
75   declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
76                            CondBranchOp>();
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // AssertOp
81 //===----------------------------------------------------------------------===//
82 
83 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
84   // Erase assertion if argument is constant true.
85   if (matchPattern(op.getArg(), m_One())) {
86     rewriter.eraseOp(op);
87     return success();
88   }
89   return failure();
90 }
91 
92 // This side effect models "program termination".
93 void AssertOp::getEffects(
94     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
95         &effects) {
96   effects.emplace_back(MemoryEffects::Write::get());
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // BranchOp
101 //===----------------------------------------------------------------------===//
102 
103 /// Given a successor, try to collapse it to a new destination if it only
104 /// contains a passthrough unconditional branch. If the successor is
105 /// collapsable, `successor` and `successorOperands` are updated to reference
106 /// the new destination and values. `argStorage` is used as storage if operands
107 /// to the collapsed successor need to be remapped. It must outlive uses of
108 /// successorOperands.
109 static LogicalResult collapseBranch(Block *&successor,
110                                     ValueRange &successorOperands,
111                                     SmallVectorImpl<Value> &argStorage) {
112   // Check that the successor only contains a unconditional branch.
113   if (std::next(successor->begin()) != successor->end())
114     return failure();
115   // Check that the terminator is an unconditional branch.
116   BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
117   if (!successorBranch)
118     return failure();
119   // Check that the arguments are only used within the terminator.
120   for (BlockArgument arg : successor->getArguments()) {
121     for (Operation *user : arg.getUsers())
122       if (user != successorBranch)
123         return failure();
124   }
125   // Don't try to collapse branches to infinite loops.
126   Block *successorDest = successorBranch.getDest();
127   if (successorDest == successor)
128     return failure();
129 
130   // Update the operands to the successor. If the branch parent has no
131   // arguments, we can use the branch operands directly.
132   OperandRange operands = successorBranch.getOperands();
133   if (successor->args_empty()) {
134     successor = successorDest;
135     successorOperands = operands;
136     return success();
137   }
138 
139   // Otherwise, we need to remap any argument operands.
140   for (Value operand : operands) {
141     BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
142     if (argOperand && argOperand.getOwner() == successor)
143       argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
144     else
145       argStorage.push_back(operand);
146   }
147   successor = successorDest;
148   successorOperands = argStorage;
149   return success();
150 }
151 
152 /// Simplify a branch to a block that has a single predecessor. This effectively
153 /// merges the two blocks.
154 static LogicalResult
155 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
156   // Check that the successor block has a single predecessor.
157   Block *succ = op.getDest();
158   Block *opParent = op->getBlock();
159   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
160     return failure();
161 
162   // Merge the successor into the current block and erase the branch.
163   SmallVector<Value> brOperands(op.getOperands());
164   rewriter.eraseOp(op);
165   rewriter.mergeBlocks(succ, opParent, brOperands);
166   return success();
167 }
168 
169 ///   br ^bb1
170 /// ^bb1
171 ///   br ^bbN(...)
172 ///
173 ///  -> br ^bbN(...)
174 ///
175 static LogicalResult simplifyPassThroughBr(BranchOp op,
176                                            PatternRewriter &rewriter) {
177   Block *dest = op.getDest();
178   ValueRange destOperands = op.getOperands();
179   SmallVector<Value, 4> destOperandStorage;
180 
181   // Try to collapse the successor if it points somewhere other than this
182   // block.
183   if (dest == op->getBlock() ||
184       failed(collapseBranch(dest, destOperands, destOperandStorage)))
185     return failure();
186 
187   // Create a new branch with the collapsed successor.
188   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
189   return success();
190 }
191 
192 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
193   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
194                  succeeded(simplifyPassThroughBr(op, rewriter)));
195 }
196 
197 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
198 
199 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
200 
201 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
202   assert(index == 0 && "invalid successor index");
203   return SuccessorOperands(getDestOperandsMutable());
204 }
205 
206 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
207   return getDest();
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // CondBranchOp
212 //===----------------------------------------------------------------------===//
213 
214 namespace {
215 /// cf.cond_br true, ^bb1, ^bb2
216 ///  -> br ^bb1
217 /// cf.cond_br false, ^bb1, ^bb2
218 ///  -> br ^bb2
219 ///
220 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
221   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
222 
223   LogicalResult matchAndRewrite(CondBranchOp condbr,
224                                 PatternRewriter &rewriter) const override {
225     if (matchPattern(condbr.getCondition(), m_NonZero())) {
226       // True branch taken.
227       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
228                                             condbr.getTrueOperands());
229       return success();
230     }
231     if (matchPattern(condbr.getCondition(), m_Zero())) {
232       // False branch taken.
233       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
234                                             condbr.getFalseOperands());
235       return success();
236     }
237     return failure();
238   }
239 };
240 
241 ///   cf.cond_br %cond, ^bb1, ^bb2
242 /// ^bb1
243 ///   br ^bbN(...)
244 /// ^bb2
245 ///   br ^bbK(...)
246 ///
247 ///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
248 ///
249 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
250   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
251 
252   LogicalResult matchAndRewrite(CondBranchOp condbr,
253                                 PatternRewriter &rewriter) const override {
254     Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
255     ValueRange trueDestOperands = condbr.getTrueOperands();
256     ValueRange falseDestOperands = condbr.getFalseOperands();
257     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
258 
259     // Try to collapse one of the current successors.
260     LogicalResult collapsedTrue =
261         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
262     LogicalResult collapsedFalse =
263         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
264     if (failed(collapsedTrue) && failed(collapsedFalse))
265       return failure();
266 
267     // Create a new branch with the collapsed successors.
268     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
269                                               trueDest, trueDestOperands,
270                                               falseDest, falseDestOperands);
271     return success();
272   }
273 };
274 
275 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
276 ///  -> br ^bb1(A, ..., N)
277 ///
278 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
279 ///  -> %select = arith.select %cond, A, B
280 ///     br ^bb1(%select)
281 ///
282 struct SimplifyCondBranchIdenticalSuccessors
283     : public OpRewritePattern<CondBranchOp> {
284   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
285 
286   LogicalResult matchAndRewrite(CondBranchOp condbr,
287                                 PatternRewriter &rewriter) const override {
288     // Check that the true and false destinations are the same and have the same
289     // operands.
290     Block *trueDest = condbr.getTrueDest();
291     if (trueDest != condbr.getFalseDest())
292       return failure();
293 
294     // If all of the operands match, no selects need to be generated.
295     OperandRange trueOperands = condbr.getTrueOperands();
296     OperandRange falseOperands = condbr.getFalseOperands();
297     if (trueOperands == falseOperands) {
298       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
299       return success();
300     }
301 
302     // Otherwise, if the current block is the only predecessor insert selects
303     // for any mismatched branch operands.
304     if (trueDest->getUniquePredecessor() != condbr->getBlock())
305       return failure();
306 
307     // Generate a select for any operands that differ between the two.
308     SmallVector<Value, 8> mergedOperands;
309     mergedOperands.reserve(trueOperands.size());
310     Value condition = condbr.getCondition();
311     for (auto it : llvm::zip(trueOperands, falseOperands)) {
312       if (std::get<0>(it) == std::get<1>(it))
313         mergedOperands.push_back(std::get<0>(it));
314       else
315         mergedOperands.push_back(rewriter.create<arith::SelectOp>(
316             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
317     }
318 
319     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
320     return success();
321   }
322 };
323 
324 ///   ...
325 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
326 /// ...
327 /// ^bb1: // has single predecessor
328 ///   ...
329 ///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
330 ///
331 /// ->
332 ///
333 ///   ...
334 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
335 /// ...
336 /// ^bb1: // has single predecessor
337 ///   ...
338 ///   br ^bb3(...)
339 ///
340 struct SimplifyCondBranchFromCondBranchOnSameCondition
341     : public OpRewritePattern<CondBranchOp> {
342   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
343 
344   LogicalResult matchAndRewrite(CondBranchOp condbr,
345                                 PatternRewriter &rewriter) const override {
346     // Check that we have a single distinct predecessor.
347     Block *currentBlock = condbr->getBlock();
348     Block *predecessor = currentBlock->getSinglePredecessor();
349     if (!predecessor)
350       return failure();
351 
352     // Check that the predecessor terminates with a conditional branch to this
353     // block and that it branches on the same condition.
354     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
355     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
356       return failure();
357 
358     // Fold this branch to an unconditional branch.
359     if (currentBlock == predBranch.getTrueDest())
360       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
361                                             condbr.getTrueDestOperands());
362     else
363       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
364                                             condbr.getFalseDestOperands());
365     return success();
366   }
367 };
368 
369 ///   cf.cond_br %arg0, ^trueB, ^falseB
370 ///
371 /// ^trueB:
372 ///   "test.consumer1"(%arg0) : (i1) -> ()
373 ///    ...
374 ///
375 /// ^falseB:
376 ///   "test.consumer2"(%arg0) : (i1) -> ()
377 ///   ...
378 ///
379 /// ->
380 ///
381 ///   cf.cond_br %arg0, ^trueB, ^falseB
382 /// ^trueB:
383 ///   "test.consumer1"(%true) : (i1) -> ()
384 ///   ...
385 ///
386 /// ^falseB:
387 ///   "test.consumer2"(%false) : (i1) -> ()
388 ///   ...
389 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
390   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
391 
392   LogicalResult matchAndRewrite(CondBranchOp condbr,
393                                 PatternRewriter &rewriter) const override {
394     // Check that we have a single distinct predecessor.
395     bool replaced = false;
396     Type ty = rewriter.getI1Type();
397 
398     // These variables serve to prevent creating duplicate constants
399     // and hold constant true or false values.
400     Value constantTrue = nullptr;
401     Value constantFalse = nullptr;
402 
403     // TODO These checks can be expanded to encompas any use with only
404     // either the true of false edge as a predecessor. For now, we fall
405     // back to checking the single predecessor is given by the true/fasle
406     // destination, thereby ensuring that only that edge can reach the
407     // op.
408     if (condbr.getTrueDest()->getSinglePredecessor()) {
409       for (OpOperand &use :
410            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
411         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
412           replaced = true;
413 
414           if (!constantTrue)
415             constantTrue = rewriter.create<arith::ConstantOp>(
416                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
417 
418           rewriter.modifyOpInPlace(use.getOwner(),
419                                    [&] { use.set(constantTrue); });
420         }
421       }
422     }
423     if (condbr.getFalseDest()->getSinglePredecessor()) {
424       for (OpOperand &use :
425            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
426         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
427           replaced = true;
428 
429           if (!constantFalse)
430             constantFalse = rewriter.create<arith::ConstantOp>(
431                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
432 
433           rewriter.modifyOpInPlace(use.getOwner(),
434                                    [&] { use.set(constantFalse); });
435         }
436       }
437     }
438     return success(replaced);
439   }
440 };
441 } // namespace
442 
443 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
444                                                MLIRContext *context) {
445   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
446               SimplifyCondBranchIdenticalSuccessors,
447               SimplifyCondBranchFromCondBranchOnSameCondition,
448               CondBranchTruthPropagation>(context);
449 }
450 
451 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
452   assert(index < getNumSuccessors() && "invalid successor index");
453   return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
454                                               : getFalseDestOperandsMutable());
455 }
456 
457 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
458   if (IntegerAttr condAttr =
459           llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
460     return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
461   return nullptr;
462 }
463 
464 //===----------------------------------------------------------------------===//
465 // SwitchOp
466 //===----------------------------------------------------------------------===//
467 
468 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
469                      Block *defaultDestination, ValueRange defaultOperands,
470                      DenseIntElementsAttr caseValues,
471                      BlockRange caseDestinations,
472                      ArrayRef<ValueRange> caseOperands) {
473   build(builder, result, value, defaultOperands, caseOperands, caseValues,
474         defaultDestination, caseDestinations);
475 }
476 
477 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
478                      Block *defaultDestination, ValueRange defaultOperands,
479                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
480                      ArrayRef<ValueRange> caseOperands) {
481   DenseIntElementsAttr caseValuesAttr;
482   if (!caseValues.empty()) {
483     ShapedType caseValueType = VectorType::get(
484         static_cast<int64_t>(caseValues.size()), value.getType());
485     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
486   }
487   build(builder, result, value, defaultDestination, defaultOperands,
488         caseValuesAttr, caseDestinations, caseOperands);
489 }
490 
491 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
492                      Block *defaultDestination, ValueRange defaultOperands,
493                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
494                      ArrayRef<ValueRange> caseOperands) {
495   DenseIntElementsAttr caseValuesAttr;
496   if (!caseValues.empty()) {
497     ShapedType caseValueType = VectorType::get(
498         static_cast<int64_t>(caseValues.size()), value.getType());
499     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
500   }
501   build(builder, result, value, defaultDestination, defaultOperands,
502         caseValuesAttr, caseDestinations, caseOperands);
503 }
504 
505 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
506 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
507 static ParseResult parseSwitchOpCases(
508     OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
509     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
510     SmallVectorImpl<Type> &defaultOperandTypes,
511     DenseIntElementsAttr &caseValues,
512     SmallVectorImpl<Block *> &caseDestinations,
513     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
514     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
515   if (parser.parseKeyword("default") || parser.parseColon() ||
516       parser.parseSuccessor(defaultDestination))
517     return failure();
518   if (succeeded(parser.parseOptionalLParen())) {
519     if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
520                                 /*allowResultNumber=*/false) ||
521         parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
522       return failure();
523   }
524 
525   SmallVector<APInt> values;
526   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
527   while (succeeded(parser.parseOptionalComma())) {
528     int64_t value = 0;
529     if (failed(parser.parseInteger(value)))
530       return failure();
531     values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
532 
533     Block *destination;
534     SmallVector<OpAsmParser::UnresolvedOperand> operands;
535     SmallVector<Type> operandTypes;
536     if (failed(parser.parseColon()) ||
537         failed(parser.parseSuccessor(destination)))
538       return failure();
539     if (succeeded(parser.parseOptionalLParen())) {
540       if (failed(parser.parseOperandList(operands,
541                                          OpAsmParser::Delimiter::None)) ||
542           failed(parser.parseColonTypeList(operandTypes)) ||
543           failed(parser.parseRParen()))
544         return failure();
545     }
546     caseDestinations.push_back(destination);
547     caseOperands.emplace_back(operands);
548     caseOperandTypes.emplace_back(operandTypes);
549   }
550 
551   if (!values.empty()) {
552     ShapedType caseValueType =
553         VectorType::get(static_cast<int64_t>(values.size()), flagType);
554     caseValues = DenseIntElementsAttr::get(caseValueType, values);
555   }
556   return success();
557 }
558 
559 static void printSwitchOpCases(
560     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
561     OperandRange defaultOperands, TypeRange defaultOperandTypes,
562     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
563     OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
564   p << "  default: ";
565   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
566 
567   if (!caseValues)
568     return;
569 
570   for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
571     p << ',';
572     p.printNewline();
573     p << "  ";
574     p << it.value().getLimitedValue();
575     p << ": ";
576     p.printSuccessorAndUseList(caseDestinations[it.index()],
577                                caseOperands[it.index()]);
578   }
579   p.printNewline();
580 }
581 
582 LogicalResult SwitchOp::verify() {
583   auto caseValues = getCaseValues();
584   auto caseDestinations = getCaseDestinations();
585 
586   if (!caseValues && caseDestinations.empty())
587     return success();
588 
589   Type flagType = getFlag().getType();
590   Type caseValueType = caseValues->getType().getElementType();
591   if (caseValueType != flagType)
592     return emitOpError() << "'flag' type (" << flagType
593                          << ") should match case value type (" << caseValueType
594                          << ")";
595 
596   if (caseValues &&
597       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
598     return emitOpError() << "number of case values (" << caseValues->size()
599                          << ") should match number of "
600                             "case destinations ("
601                          << caseDestinations.size() << ")";
602   return success();
603 }
604 
605 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
606   assert(index < getNumSuccessors() && "invalid successor index");
607   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
608                                       : getCaseOperandsMutable(index - 1));
609 }
610 
611 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
612   std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
613 
614   if (!caseValues)
615     return getDefaultDestination();
616 
617   SuccessorRange caseDests = getCaseDestinations();
618   if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
619     for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
620       if (it.value() == value.getValue())
621         return caseDests[it.index()];
622     return getDefaultDestination();
623   }
624   return nullptr;
625 }
626 
627 /// switch %flag : i32, [
628 ///   default:  ^bb1
629 /// ]
630 ///  -> br ^bb1
631 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
632                                                    PatternRewriter &rewriter) {
633   if (!op.getCaseDestinations().empty())
634     return failure();
635 
636   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
637                                         op.getDefaultOperands());
638   return success();
639 }
640 
641 /// switch %flag : i32, [
642 ///   default: ^bb1,
643 ///   42: ^bb1,
644 ///   43: ^bb2
645 /// ]
646 /// ->
647 /// switch %flag : i32, [
648 ///   default: ^bb1,
649 ///   43: ^bb2
650 /// ]
651 static LogicalResult
652 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
653   SmallVector<Block *> newCaseDestinations;
654   SmallVector<ValueRange> newCaseOperands;
655   SmallVector<APInt> newCaseValues;
656   bool requiresChange = false;
657   auto caseValues = op.getCaseValues();
658   auto caseDests = op.getCaseDestinations();
659 
660   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
661     if (caseDests[it.index()] == op.getDefaultDestination() &&
662         op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
663       requiresChange = true;
664       continue;
665     }
666     newCaseDestinations.push_back(caseDests[it.index()]);
667     newCaseOperands.push_back(op.getCaseOperands(it.index()));
668     newCaseValues.push_back(it.value());
669   }
670 
671   if (!requiresChange)
672     return failure();
673 
674   rewriter.replaceOpWithNewOp<SwitchOp>(
675       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
676       newCaseValues, newCaseDestinations, newCaseOperands);
677   return success();
678 }
679 
680 /// Helper for folding a switch with a constant value.
681 /// switch %c_42 : i32, [
682 ///   default: ^bb1 ,
683 ///   42: ^bb2,
684 ///   43: ^bb3
685 /// ]
686 /// -> br ^bb2
687 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
688                        const APInt &caseValue) {
689   auto caseValues = op.getCaseValues();
690   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
691     if (it.value() == caseValue) {
692       rewriter.replaceOpWithNewOp<BranchOp>(
693           op, op.getCaseDestinations()[it.index()],
694           op.getCaseOperands(it.index()));
695       return;
696     }
697   }
698   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
699                                         op.getDefaultOperands());
700 }
701 
702 /// switch %c_42 : i32, [
703 ///   default: ^bb1,
704 ///   42: ^bb2,
705 ///   43: ^bb3
706 /// ]
707 /// -> br ^bb2
708 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
709                                               PatternRewriter &rewriter) {
710   APInt caseValue;
711   if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
712     return failure();
713 
714   foldSwitch(op, rewriter, caseValue);
715   return success();
716 }
717 
718 /// switch %c_42 : i32, [
719 ///   default: ^bb1,
720 ///   42: ^bb2,
721 /// ]
722 /// ^bb2:
723 ///   br ^bb3
724 /// ->
725 /// switch %c_42 : i32, [
726 ///   default: ^bb1,
727 ///   42: ^bb3,
728 /// ]
729 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
730                                                PatternRewriter &rewriter) {
731   SmallVector<Block *> newCaseDests;
732   SmallVector<ValueRange> newCaseOperands;
733   SmallVector<SmallVector<Value>> argStorage;
734   auto caseValues = op.getCaseValues();
735   argStorage.reserve(caseValues->size() + 1);
736   auto caseDests = op.getCaseDestinations();
737   bool requiresChange = false;
738   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
739     Block *caseDest = caseDests[i];
740     ValueRange caseOperands = op.getCaseOperands(i);
741     argStorage.emplace_back();
742     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
743       requiresChange = true;
744 
745     newCaseDests.push_back(caseDest);
746     newCaseOperands.push_back(caseOperands);
747   }
748 
749   Block *defaultDest = op.getDefaultDestination();
750   ValueRange defaultOperands = op.getDefaultOperands();
751   argStorage.emplace_back();
752 
753   if (succeeded(
754           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
755     requiresChange = true;
756 
757   if (!requiresChange)
758     return failure();
759 
760   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
761                                         defaultOperands, *caseValues,
762                                         newCaseDests, newCaseOperands);
763   return success();
764 }
765 
766 /// switch %flag : i32, [
767 ///   default: ^bb1,
768 ///   42: ^bb2,
769 /// ]
770 /// ^bb2:
771 ///   switch %flag : i32, [
772 ///     default: ^bb3,
773 ///     42: ^bb4
774 ///   ]
775 /// ->
776 /// switch %flag : i32, [
777 ///   default: ^bb1,
778 ///   42: ^bb2,
779 /// ]
780 /// ^bb2:
781 ///   br ^bb4
782 ///
783 ///  and
784 ///
785 /// switch %flag : i32, [
786 ///   default: ^bb1,
787 ///   42: ^bb2,
788 /// ]
789 /// ^bb2:
790 ///   switch %flag : i32, [
791 ///     default: ^bb3,
792 ///     43: ^bb4
793 ///   ]
794 /// ->
795 /// switch %flag : i32, [
796 ///   default: ^bb1,
797 ///   42: ^bb2,
798 /// ]
799 /// ^bb2:
800 ///   br ^bb3
801 static LogicalResult
802 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
803                                         PatternRewriter &rewriter) {
804   // Check that we have a single distinct predecessor.
805   Block *currentBlock = op->getBlock();
806   Block *predecessor = currentBlock->getSinglePredecessor();
807   if (!predecessor)
808     return failure();
809 
810   // Check that the predecessor terminates with a switch branch to this block
811   // and that it branches on the same condition and that this branch isn't the
812   // default destination.
813   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
814   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
815       predSwitch.getDefaultDestination() == currentBlock)
816     return failure();
817 
818   // Fold this switch to an unconditional branch.
819   SuccessorRange predDests = predSwitch.getCaseDestinations();
820   auto it = llvm::find(predDests, currentBlock);
821   if (it != predDests.end()) {
822     std::optional<DenseIntElementsAttr> predCaseValues =
823         predSwitch.getCaseValues();
824     foldSwitch(op, rewriter,
825                predCaseValues->getValues<APInt>()[it - predDests.begin()]);
826   } else {
827     rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
828                                           op.getDefaultOperands());
829   }
830   return success();
831 }
832 
833 /// switch %flag : i32, [
834 ///   default: ^bb1,
835 ///   42: ^bb2
836 /// ]
837 /// ^bb1:
838 ///   switch %flag : i32, [
839 ///     default: ^bb3,
840 ///     42: ^bb4,
841 ///     43: ^bb5
842 ///   ]
843 /// ->
844 /// switch %flag : i32, [
845 ///   default: ^bb1,
846 ///   42: ^bb2,
847 /// ]
848 /// ^bb1:
849 ///   switch %flag : i32, [
850 ///     default: ^bb3,
851 ///     43: ^bb5
852 ///   ]
853 static LogicalResult
854 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
855                                                PatternRewriter &rewriter) {
856   // Check that we have a single distinct predecessor.
857   Block *currentBlock = op->getBlock();
858   Block *predecessor = currentBlock->getSinglePredecessor();
859   if (!predecessor)
860     return failure();
861 
862   // Check that the predecessor terminates with a switch branch to this block
863   // and that it branches on the same condition and that this branch is the
864   // default destination.
865   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
866   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
867       predSwitch.getDefaultDestination() != currentBlock)
868     return failure();
869 
870   // Delete case values that are not possible here.
871   DenseSet<APInt> caseValuesToRemove;
872   auto predDests = predSwitch.getCaseDestinations();
873   auto predCaseValues = predSwitch.getCaseValues();
874   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
875     if (currentBlock != predDests[i])
876       caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
877 
878   SmallVector<Block *> newCaseDestinations;
879   SmallVector<ValueRange> newCaseOperands;
880   SmallVector<APInt> newCaseValues;
881   bool requiresChange = false;
882 
883   auto caseValues = op.getCaseValues();
884   auto caseDests = op.getCaseDestinations();
885   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
886     if (caseValuesToRemove.contains(it.value())) {
887       requiresChange = true;
888       continue;
889     }
890     newCaseDestinations.push_back(caseDests[it.index()]);
891     newCaseOperands.push_back(op.getCaseOperands(it.index()));
892     newCaseValues.push_back(it.value());
893   }
894 
895   if (!requiresChange)
896     return failure();
897 
898   rewriter.replaceOpWithNewOp<SwitchOp>(
899       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
900       newCaseValues, newCaseDestinations, newCaseOperands);
901   return success();
902 }
903 
904 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
905                                            MLIRContext *context) {
906   results.add(&simplifySwitchWithOnlyDefault)
907       .add(&dropSwitchCasesThatMatchDefault)
908       .add(&simplifyConstSwitchValue)
909       .add(&simplifyPassThroughSwitch)
910       .add(&simplifySwitchFromSwitchOnSameCondition)
911       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
912 }
913 
914 //===----------------------------------------------------------------------===//
915 // TableGen'd op method definitions
916 //===----------------------------------------------------------------------===//
917 
918 #define GET_OP_CLASSES
919 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
920