xref: /llvm-project/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (revision 513cdb82223a106f183b49a40d9acb1f7efbbe7e)
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
33 
34 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
35 
36 using namespace mlir;
37 using namespace mlir::cf;
38 
39 //===----------------------------------------------------------------------===//
40 // ControlFlowDialect Interfaces
41 //===----------------------------------------------------------------------===//
42 namespace {
43 /// This class defines the interface for handling inlining with control flow
44 /// operations.
45 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
46   using DialectInlinerInterface::DialectInlinerInterface;
47   ~ControlFlowInlinerInterface() override = default;
48 
49   /// All control flow operations can be inlined.
50   bool isLegalToInline(Operation *call, Operation *callable,
51                        bool wouldBeCloned) const final {
52     return true;
53   }
54   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
55     return true;
56   }
57 
58   /// ControlFlow terminator operations don't really need any special handing.
59   void handleTerminator(Operation *op, Block *newDest) const final {}
60 };
61 } // namespace
62 
63 //===----------------------------------------------------------------------===//
64 // ControlFlowDialect
65 //===----------------------------------------------------------------------===//
66 
67 void ControlFlowDialect::initialize() {
68   addOperations<
69 #define GET_OP_LIST
70 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
71       >();
72   addInterfaces<ControlFlowInlinerInterface>();
73   declarePromisedInterface<ControlFlowDialect, ConvertToLLVMPatternInterface>();
74   declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
75                             CondBranchOp>();
76   declarePromisedInterface<CondBranchOp,
77                            bufferization::BufferDeallocationOpInterface>();
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // AssertOp
82 //===----------------------------------------------------------------------===//
83 
84 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
85   // Erase assertion if argument is constant true.
86   if (matchPattern(op.getArg(), m_One())) {
87     rewriter.eraseOp(op);
88     return success();
89   }
90   return failure();
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // BranchOp
95 //===----------------------------------------------------------------------===//
96 
97 /// Given a successor, try to collapse it to a new destination if it only
98 /// contains a passthrough unconditional branch. If the successor is
99 /// collapsable, `successor` and `successorOperands` are updated to reference
100 /// the new destination and values. `argStorage` is used as storage if operands
101 /// to the collapsed successor need to be remapped. It must outlive uses of
102 /// successorOperands.
103 static LogicalResult collapseBranch(Block *&successor,
104                                     ValueRange &successorOperands,
105                                     SmallVectorImpl<Value> &argStorage) {
106   // Check that the successor only contains a unconditional branch.
107   if (std::next(successor->begin()) != successor->end())
108     return failure();
109   // Check that the terminator is an unconditional branch.
110   BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
111   if (!successorBranch)
112     return failure();
113   // Check that the arguments are only used within the terminator.
114   for (BlockArgument arg : successor->getArguments()) {
115     for (Operation *user : arg.getUsers())
116       if (user != successorBranch)
117         return failure();
118   }
119   // Don't try to collapse branches to infinite loops.
120   Block *successorDest = successorBranch.getDest();
121   if (successorDest == successor)
122     return failure();
123 
124   // Update the operands to the successor. If the branch parent has no
125   // arguments, we can use the branch operands directly.
126   OperandRange operands = successorBranch.getOperands();
127   if (successor->args_empty()) {
128     successor = successorDest;
129     successorOperands = operands;
130     return success();
131   }
132 
133   // Otherwise, we need to remap any argument operands.
134   for (Value operand : operands) {
135     BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
136     if (argOperand && argOperand.getOwner() == successor)
137       argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
138     else
139       argStorage.push_back(operand);
140   }
141   successor = successorDest;
142   successorOperands = argStorage;
143   return success();
144 }
145 
146 /// Simplify a branch to a block that has a single predecessor. This effectively
147 /// merges the two blocks.
148 static LogicalResult
149 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
150   // Check that the successor block has a single predecessor.
151   Block *succ = op.getDest();
152   Block *opParent = op->getBlock();
153   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
154     return failure();
155 
156   // Merge the successor into the current block and erase the branch.
157   SmallVector<Value> brOperands(op.getOperands());
158   rewriter.eraseOp(op);
159   rewriter.mergeBlocks(succ, opParent, brOperands);
160   return success();
161 }
162 
163 ///   br ^bb1
164 /// ^bb1
165 ///   br ^bbN(...)
166 ///
167 ///  -> br ^bbN(...)
168 ///
169 static LogicalResult simplifyPassThroughBr(BranchOp op,
170                                            PatternRewriter &rewriter) {
171   Block *dest = op.getDest();
172   ValueRange destOperands = op.getOperands();
173   SmallVector<Value, 4> destOperandStorage;
174 
175   // Try to collapse the successor if it points somewhere other than this
176   // block.
177   if (dest == op->getBlock() ||
178       failed(collapseBranch(dest, destOperands, destOperandStorage)))
179     return failure();
180 
181   // Create a new branch with the collapsed successor.
182   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
183   return success();
184 }
185 
186 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
187   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
188                  succeeded(simplifyPassThroughBr(op, rewriter)));
189 }
190 
191 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
192 
193 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
194 
195 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
196   assert(index == 0 && "invalid successor index");
197   return SuccessorOperands(getDestOperandsMutable());
198 }
199 
200 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
201   return getDest();
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // CondBranchOp
206 //===----------------------------------------------------------------------===//
207 
208 namespace {
209 /// cf.cond_br true, ^bb1, ^bb2
210 ///  -> br ^bb1
211 /// cf.cond_br false, ^bb1, ^bb2
212 ///  -> br ^bb2
213 ///
214 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
215   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
216 
217   LogicalResult matchAndRewrite(CondBranchOp condbr,
218                                 PatternRewriter &rewriter) const override {
219     if (matchPattern(condbr.getCondition(), m_NonZero())) {
220       // True branch taken.
221       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
222                                             condbr.getTrueOperands());
223       return success();
224     }
225     if (matchPattern(condbr.getCondition(), m_Zero())) {
226       // False branch taken.
227       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
228                                             condbr.getFalseOperands());
229       return success();
230     }
231     return failure();
232   }
233 };
234 
235 ///   cf.cond_br %cond, ^bb1, ^bb2
236 /// ^bb1
237 ///   br ^bbN(...)
238 /// ^bb2
239 ///   br ^bbK(...)
240 ///
241 ///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
242 ///
243 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
244   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
245 
246   LogicalResult matchAndRewrite(CondBranchOp condbr,
247                                 PatternRewriter &rewriter) const override {
248     Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
249     ValueRange trueDestOperands = condbr.getTrueOperands();
250     ValueRange falseDestOperands = condbr.getFalseOperands();
251     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
252 
253     // Try to collapse one of the current successors.
254     LogicalResult collapsedTrue =
255         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
256     LogicalResult collapsedFalse =
257         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
258     if (failed(collapsedTrue) && failed(collapsedFalse))
259       return failure();
260 
261     // Create a new branch with the collapsed successors.
262     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
263                                               trueDest, trueDestOperands,
264                                               falseDest, falseDestOperands);
265     return success();
266   }
267 };
268 
269 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
270 ///  -> br ^bb1(A, ..., N)
271 ///
272 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
273 ///  -> %select = arith.select %cond, A, B
274 ///     br ^bb1(%select)
275 ///
276 struct SimplifyCondBranchIdenticalSuccessors
277     : public OpRewritePattern<CondBranchOp> {
278   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
279 
280   LogicalResult matchAndRewrite(CondBranchOp condbr,
281                                 PatternRewriter &rewriter) const override {
282     // Check that the true and false destinations are the same and have the same
283     // operands.
284     Block *trueDest = condbr.getTrueDest();
285     if (trueDest != condbr.getFalseDest())
286       return failure();
287 
288     // If all of the operands match, no selects need to be generated.
289     OperandRange trueOperands = condbr.getTrueOperands();
290     OperandRange falseOperands = condbr.getFalseOperands();
291     if (trueOperands == falseOperands) {
292       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
293       return success();
294     }
295 
296     // Otherwise, if the current block is the only predecessor insert selects
297     // for any mismatched branch operands.
298     if (trueDest->getUniquePredecessor() != condbr->getBlock())
299       return failure();
300 
301     // Generate a select for any operands that differ between the two.
302     SmallVector<Value, 8> mergedOperands;
303     mergedOperands.reserve(trueOperands.size());
304     Value condition = condbr.getCondition();
305     for (auto it : llvm::zip(trueOperands, falseOperands)) {
306       if (std::get<0>(it) == std::get<1>(it))
307         mergedOperands.push_back(std::get<0>(it));
308       else
309         mergedOperands.push_back(rewriter.create<arith::SelectOp>(
310             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
311     }
312 
313     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
314     return success();
315   }
316 };
317 
318 ///   ...
319 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
320 /// ...
321 /// ^bb1: // has single predecessor
322 ///   ...
323 ///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
324 ///
325 /// ->
326 ///
327 ///   ...
328 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
329 /// ...
330 /// ^bb1: // has single predecessor
331 ///   ...
332 ///   br ^bb3(...)
333 ///
334 struct SimplifyCondBranchFromCondBranchOnSameCondition
335     : public OpRewritePattern<CondBranchOp> {
336   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
337 
338   LogicalResult matchAndRewrite(CondBranchOp condbr,
339                                 PatternRewriter &rewriter) const override {
340     // Check that we have a single distinct predecessor.
341     Block *currentBlock = condbr->getBlock();
342     Block *predecessor = currentBlock->getSinglePredecessor();
343     if (!predecessor)
344       return failure();
345 
346     // Check that the predecessor terminates with a conditional branch to this
347     // block and that it branches on the same condition.
348     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
349     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
350       return failure();
351 
352     // Fold this branch to an unconditional branch.
353     if (currentBlock == predBranch.getTrueDest())
354       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
355                                             condbr.getTrueDestOperands());
356     else
357       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
358                                             condbr.getFalseDestOperands());
359     return success();
360   }
361 };
362 
363 ///   cf.cond_br %arg0, ^trueB, ^falseB
364 ///
365 /// ^trueB:
366 ///   "test.consumer1"(%arg0) : (i1) -> ()
367 ///    ...
368 ///
369 /// ^falseB:
370 ///   "test.consumer2"(%arg0) : (i1) -> ()
371 ///   ...
372 ///
373 /// ->
374 ///
375 ///   cf.cond_br %arg0, ^trueB, ^falseB
376 /// ^trueB:
377 ///   "test.consumer1"(%true) : (i1) -> ()
378 ///   ...
379 ///
380 /// ^falseB:
381 ///   "test.consumer2"(%false) : (i1) -> ()
382 ///   ...
383 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
384   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
385 
386   LogicalResult matchAndRewrite(CondBranchOp condbr,
387                                 PatternRewriter &rewriter) const override {
388     // Check that we have a single distinct predecessor.
389     bool replaced = false;
390     Type ty = rewriter.getI1Type();
391 
392     // These variables serve to prevent creating duplicate constants
393     // and hold constant true or false values.
394     Value constantTrue = nullptr;
395     Value constantFalse = nullptr;
396 
397     // TODO These checks can be expanded to encompas any use with only
398     // either the true of false edge as a predecessor. For now, we fall
399     // back to checking the single predecessor is given by the true/fasle
400     // destination, thereby ensuring that only that edge can reach the
401     // op.
402     if (condbr.getTrueDest()->getSinglePredecessor()) {
403       for (OpOperand &use :
404            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
405         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
406           replaced = true;
407 
408           if (!constantTrue)
409             constantTrue = rewriter.create<arith::ConstantOp>(
410                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
411 
412           rewriter.modifyOpInPlace(use.getOwner(),
413                                    [&] { use.set(constantTrue); });
414         }
415       }
416     }
417     if (condbr.getFalseDest()->getSinglePredecessor()) {
418       for (OpOperand &use :
419            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
420         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
421           replaced = true;
422 
423           if (!constantFalse)
424             constantFalse = rewriter.create<arith::ConstantOp>(
425                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
426 
427           rewriter.modifyOpInPlace(use.getOwner(),
428                                    [&] { use.set(constantFalse); });
429         }
430       }
431     }
432     return success(replaced);
433   }
434 };
435 } // namespace
436 
437 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
438                                                MLIRContext *context) {
439   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
440               SimplifyCondBranchIdenticalSuccessors,
441               SimplifyCondBranchFromCondBranchOnSameCondition,
442               CondBranchTruthPropagation>(context);
443 }
444 
445 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
446   assert(index < getNumSuccessors() && "invalid successor index");
447   return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
448                                               : getFalseDestOperandsMutable());
449 }
450 
451 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
452   if (IntegerAttr condAttr =
453           llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
454     return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
455   return nullptr;
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // SwitchOp
460 //===----------------------------------------------------------------------===//
461 
462 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
463                      Block *defaultDestination, ValueRange defaultOperands,
464                      DenseIntElementsAttr caseValues,
465                      BlockRange caseDestinations,
466                      ArrayRef<ValueRange> caseOperands) {
467   build(builder, result, value, defaultOperands, caseOperands, caseValues,
468         defaultDestination, caseDestinations);
469 }
470 
471 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
472                      Block *defaultDestination, ValueRange defaultOperands,
473                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
474                      ArrayRef<ValueRange> caseOperands) {
475   DenseIntElementsAttr caseValuesAttr;
476   if (!caseValues.empty()) {
477     ShapedType caseValueType = VectorType::get(
478         static_cast<int64_t>(caseValues.size()), value.getType());
479     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
480   }
481   build(builder, result, value, defaultDestination, defaultOperands,
482         caseValuesAttr, caseDestinations, caseOperands);
483 }
484 
485 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
486                      Block *defaultDestination, ValueRange defaultOperands,
487                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
488                      ArrayRef<ValueRange> caseOperands) {
489   DenseIntElementsAttr caseValuesAttr;
490   if (!caseValues.empty()) {
491     ShapedType caseValueType = VectorType::get(
492         static_cast<int64_t>(caseValues.size()), value.getType());
493     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
494   }
495   build(builder, result, value, defaultDestination, defaultOperands,
496         caseValuesAttr, caseDestinations, caseOperands);
497 }
498 
499 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
500 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
501 static ParseResult parseSwitchOpCases(
502     OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
503     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
504     SmallVectorImpl<Type> &defaultOperandTypes,
505     DenseIntElementsAttr &caseValues,
506     SmallVectorImpl<Block *> &caseDestinations,
507     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
508     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
509   if (parser.parseKeyword("default") || parser.parseColon() ||
510       parser.parseSuccessor(defaultDestination))
511     return failure();
512   if (succeeded(parser.parseOptionalLParen())) {
513     if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
514                                 /*allowResultNumber=*/false) ||
515         parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
516       return failure();
517   }
518 
519   SmallVector<APInt> values;
520   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
521   while (succeeded(parser.parseOptionalComma())) {
522     int64_t value = 0;
523     if (failed(parser.parseInteger(value)))
524       return failure();
525     values.push_back(APInt(bitWidth, value));
526 
527     Block *destination;
528     SmallVector<OpAsmParser::UnresolvedOperand> operands;
529     SmallVector<Type> operandTypes;
530     if (failed(parser.parseColon()) ||
531         failed(parser.parseSuccessor(destination)))
532       return failure();
533     if (succeeded(parser.parseOptionalLParen())) {
534       if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
535                                          /*allowResultNumber=*/false)) ||
536           failed(parser.parseColonTypeList(operandTypes)) ||
537           failed(parser.parseRParen()))
538         return failure();
539     }
540     caseDestinations.push_back(destination);
541     caseOperands.emplace_back(operands);
542     caseOperandTypes.emplace_back(operandTypes);
543   }
544 
545   if (!values.empty()) {
546     ShapedType caseValueType =
547         VectorType::get(static_cast<int64_t>(values.size()), flagType);
548     caseValues = DenseIntElementsAttr::get(caseValueType, values);
549   }
550   return success();
551 }
552 
553 static void printSwitchOpCases(
554     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
555     OperandRange defaultOperands, TypeRange defaultOperandTypes,
556     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
557     OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
558   p << "  default: ";
559   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
560 
561   if (!caseValues)
562     return;
563 
564   for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
565     p << ',';
566     p.printNewline();
567     p << "  ";
568     p << it.value().getLimitedValue();
569     p << ": ";
570     p.printSuccessorAndUseList(caseDestinations[it.index()],
571                                caseOperands[it.index()]);
572   }
573   p.printNewline();
574 }
575 
576 LogicalResult SwitchOp::verify() {
577   auto caseValues = getCaseValues();
578   auto caseDestinations = getCaseDestinations();
579 
580   if (!caseValues && caseDestinations.empty())
581     return success();
582 
583   Type flagType = getFlag().getType();
584   Type caseValueType = caseValues->getType().getElementType();
585   if (caseValueType != flagType)
586     return emitOpError() << "'flag' type (" << flagType
587                          << ") should match case value type (" << caseValueType
588                          << ")";
589 
590   if (caseValues &&
591       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
592     return emitOpError() << "number of case values (" << caseValues->size()
593                          << ") should match number of "
594                             "case destinations ("
595                          << caseDestinations.size() << ")";
596   return success();
597 }
598 
599 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
600   assert(index < getNumSuccessors() && "invalid successor index");
601   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
602                                       : getCaseOperandsMutable(index - 1));
603 }
604 
605 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
606   std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
607 
608   if (!caseValues)
609     return getDefaultDestination();
610 
611   SuccessorRange caseDests = getCaseDestinations();
612   if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
613     for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
614       if (it.value() == value.getValue())
615         return caseDests[it.index()];
616     return getDefaultDestination();
617   }
618   return nullptr;
619 }
620 
621 /// switch %flag : i32, [
622 ///   default:  ^bb1
623 /// ]
624 ///  -> br ^bb1
625 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
626                                                    PatternRewriter &rewriter) {
627   if (!op.getCaseDestinations().empty())
628     return failure();
629 
630   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
631                                         op.getDefaultOperands());
632   return success();
633 }
634 
635 /// switch %flag : i32, [
636 ///   default: ^bb1,
637 ///   42: ^bb1,
638 ///   43: ^bb2
639 /// ]
640 /// ->
641 /// switch %flag : i32, [
642 ///   default: ^bb1,
643 ///   43: ^bb2
644 /// ]
645 static LogicalResult
646 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
647   SmallVector<Block *> newCaseDestinations;
648   SmallVector<ValueRange> newCaseOperands;
649   SmallVector<APInt> newCaseValues;
650   bool requiresChange = false;
651   auto caseValues = op.getCaseValues();
652   auto caseDests = op.getCaseDestinations();
653 
654   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
655     if (caseDests[it.index()] == op.getDefaultDestination() &&
656         op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
657       requiresChange = true;
658       continue;
659     }
660     newCaseDestinations.push_back(caseDests[it.index()]);
661     newCaseOperands.push_back(op.getCaseOperands(it.index()));
662     newCaseValues.push_back(it.value());
663   }
664 
665   if (!requiresChange)
666     return failure();
667 
668   rewriter.replaceOpWithNewOp<SwitchOp>(
669       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
670       newCaseValues, newCaseDestinations, newCaseOperands);
671   return success();
672 }
673 
674 /// Helper for folding a switch with a constant value.
675 /// switch %c_42 : i32, [
676 ///   default: ^bb1 ,
677 ///   42: ^bb2,
678 ///   43: ^bb3
679 /// ]
680 /// -> br ^bb2
681 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
682                        const APInt &caseValue) {
683   auto caseValues = op.getCaseValues();
684   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
685     if (it.value() == caseValue) {
686       rewriter.replaceOpWithNewOp<BranchOp>(
687           op, op.getCaseDestinations()[it.index()],
688           op.getCaseOperands(it.index()));
689       return;
690     }
691   }
692   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
693                                         op.getDefaultOperands());
694 }
695 
696 /// switch %c_42 : i32, [
697 ///   default: ^bb1,
698 ///   42: ^bb2,
699 ///   43: ^bb3
700 /// ]
701 /// -> br ^bb2
702 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
703                                               PatternRewriter &rewriter) {
704   APInt caseValue;
705   if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
706     return failure();
707 
708   foldSwitch(op, rewriter, caseValue);
709   return success();
710 }
711 
712 /// switch %c_42 : i32, [
713 ///   default: ^bb1,
714 ///   42: ^bb2,
715 /// ]
716 /// ^bb2:
717 ///   br ^bb3
718 /// ->
719 /// switch %c_42 : i32, [
720 ///   default: ^bb1,
721 ///   42: ^bb3,
722 /// ]
723 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
724                                                PatternRewriter &rewriter) {
725   SmallVector<Block *> newCaseDests;
726   SmallVector<ValueRange> newCaseOperands;
727   SmallVector<SmallVector<Value>> argStorage;
728   auto caseValues = op.getCaseValues();
729   argStorage.reserve(caseValues->size() + 1);
730   auto caseDests = op.getCaseDestinations();
731   bool requiresChange = false;
732   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
733     Block *caseDest = caseDests[i];
734     ValueRange caseOperands = op.getCaseOperands(i);
735     argStorage.emplace_back();
736     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
737       requiresChange = true;
738 
739     newCaseDests.push_back(caseDest);
740     newCaseOperands.push_back(caseOperands);
741   }
742 
743   Block *defaultDest = op.getDefaultDestination();
744   ValueRange defaultOperands = op.getDefaultOperands();
745   argStorage.emplace_back();
746 
747   if (succeeded(
748           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
749     requiresChange = true;
750 
751   if (!requiresChange)
752     return failure();
753 
754   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
755                                         defaultOperands, *caseValues,
756                                         newCaseDests, newCaseOperands);
757   return success();
758 }
759 
760 /// switch %flag : i32, [
761 ///   default: ^bb1,
762 ///   42: ^bb2,
763 /// ]
764 /// ^bb2:
765 ///   switch %flag : i32, [
766 ///     default: ^bb3,
767 ///     42: ^bb4
768 ///   ]
769 /// ->
770 /// switch %flag : i32, [
771 ///   default: ^bb1,
772 ///   42: ^bb2,
773 /// ]
774 /// ^bb2:
775 ///   br ^bb4
776 ///
777 ///  and
778 ///
779 /// switch %flag : i32, [
780 ///   default: ^bb1,
781 ///   42: ^bb2,
782 /// ]
783 /// ^bb2:
784 ///   switch %flag : i32, [
785 ///     default: ^bb3,
786 ///     43: ^bb4
787 ///   ]
788 /// ->
789 /// switch %flag : i32, [
790 ///   default: ^bb1,
791 ///   42: ^bb2,
792 /// ]
793 /// ^bb2:
794 ///   br ^bb3
795 static LogicalResult
796 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
797                                         PatternRewriter &rewriter) {
798   // Check that we have a single distinct predecessor.
799   Block *currentBlock = op->getBlock();
800   Block *predecessor = currentBlock->getSinglePredecessor();
801   if (!predecessor)
802     return failure();
803 
804   // Check that the predecessor terminates with a switch branch to this block
805   // and that it branches on the same condition and that this branch isn't the
806   // default destination.
807   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
808   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
809       predSwitch.getDefaultDestination() == currentBlock)
810     return failure();
811 
812   // Fold this switch to an unconditional branch.
813   SuccessorRange predDests = predSwitch.getCaseDestinations();
814   auto it = llvm::find(predDests, currentBlock);
815   if (it != predDests.end()) {
816     std::optional<DenseIntElementsAttr> predCaseValues =
817         predSwitch.getCaseValues();
818     foldSwitch(op, rewriter,
819                predCaseValues->getValues<APInt>()[it - predDests.begin()]);
820   } else {
821     rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
822                                           op.getDefaultOperands());
823   }
824   return success();
825 }
826 
827 /// switch %flag : i32, [
828 ///   default: ^bb1,
829 ///   42: ^bb2
830 /// ]
831 /// ^bb1:
832 ///   switch %flag : i32, [
833 ///     default: ^bb3,
834 ///     42: ^bb4,
835 ///     43: ^bb5
836 ///   ]
837 /// ->
838 /// switch %flag : i32, [
839 ///   default: ^bb1,
840 ///   42: ^bb2,
841 /// ]
842 /// ^bb1:
843 ///   switch %flag : i32, [
844 ///     default: ^bb3,
845 ///     43: ^bb5
846 ///   ]
847 static LogicalResult
848 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
849                                                PatternRewriter &rewriter) {
850   // Check that we have a single distinct predecessor.
851   Block *currentBlock = op->getBlock();
852   Block *predecessor = currentBlock->getSinglePredecessor();
853   if (!predecessor)
854     return failure();
855 
856   // Check that the predecessor terminates with a switch branch to this block
857   // and that it branches on the same condition and that this branch is the
858   // default destination.
859   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
860   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
861       predSwitch.getDefaultDestination() != currentBlock)
862     return failure();
863 
864   // Delete case values that are not possible here.
865   DenseSet<APInt> caseValuesToRemove;
866   auto predDests = predSwitch.getCaseDestinations();
867   auto predCaseValues = predSwitch.getCaseValues();
868   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
869     if (currentBlock != predDests[i])
870       caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
871 
872   SmallVector<Block *> newCaseDestinations;
873   SmallVector<ValueRange> newCaseOperands;
874   SmallVector<APInt> newCaseValues;
875   bool requiresChange = false;
876 
877   auto caseValues = op.getCaseValues();
878   auto caseDests = op.getCaseDestinations();
879   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
880     if (caseValuesToRemove.contains(it.value())) {
881       requiresChange = true;
882       continue;
883     }
884     newCaseDestinations.push_back(caseDests[it.index()]);
885     newCaseOperands.push_back(op.getCaseOperands(it.index()));
886     newCaseValues.push_back(it.value());
887   }
888 
889   if (!requiresChange)
890     return failure();
891 
892   rewriter.replaceOpWithNewOp<SwitchOp>(
893       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
894       newCaseValues, newCaseDestinations, newCaseOperands);
895   return success();
896 }
897 
898 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
899                                            MLIRContext *context) {
900   results.add(&simplifySwitchWithOnlyDefault)
901       .add(&dropSwitchCasesThatMatchDefault)
902       .add(&simplifyConstSwitchValue)
903       .add(&simplifyPassThroughSwitch)
904       .add(&simplifySwitchFromSwitchOnSameCondition)
905       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
906 }
907 
908 //===----------------------------------------------------------------------===//
909 // TableGen'd op method definitions
910 //===----------------------------------------------------------------------===//
911 
912 #define GET_OP_CLASSES
913 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
914