xref: /llvm-project/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (revision ace01605e04d094c243b0cad873e8919b80a0ced)
1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Support/MathExtras.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/StringSwitch.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <numeric>
32 
33 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
34 
35 using namespace mlir;
36 using namespace mlir::cf;
37 
38 //===----------------------------------------------------------------------===//
39 // ControlFlowDialect Interfaces
40 //===----------------------------------------------------------------------===//
41 namespace {
42 /// This class defines the interface for handling inlining with control flow
43 /// operations.
44 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
45   using DialectInlinerInterface::DialectInlinerInterface;
46   ~ControlFlowInlinerInterface() override = default;
47 
48   /// All control flow operations can be inlined.
49   bool isLegalToInline(Operation *call, Operation *callable,
50                        bool wouldBeCloned) const final {
51     return true;
52   }
53   bool isLegalToInline(Operation *, Region *, bool,
54                        BlockAndValueMapping &) const final {
55     return true;
56   }
57 
58   /// ControlFlow terminator operations don't really need any special handing.
59   void handleTerminator(Operation *op, Block *newDest) const final {}
60 };
61 } // namespace
62 
63 //===----------------------------------------------------------------------===//
64 // ControlFlowDialect
65 //===----------------------------------------------------------------------===//
66 
67 void ControlFlowDialect::initialize() {
68   addOperations<
69 #define GET_OP_LIST
70 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
71       >();
72   addInterfaces<ControlFlowInlinerInterface>();
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // AssertOp
77 //===----------------------------------------------------------------------===//
78 
79 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
80   // Erase assertion if argument is constant true.
81   if (matchPattern(op.getArg(), m_One())) {
82     rewriter.eraseOp(op);
83     return success();
84   }
85   return failure();
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // BranchOp
90 //===----------------------------------------------------------------------===//
91 
92 /// Given a successor, try to collapse it to a new destination if it only
93 /// contains a passthrough unconditional branch. If the successor is
94 /// collapsable, `successor` and `successorOperands` are updated to reference
95 /// the new destination and values. `argStorage` is used as storage if operands
96 /// to the collapsed successor need to be remapped. It must outlive uses of
97 /// successorOperands.
98 static LogicalResult collapseBranch(Block *&successor,
99                                     ValueRange &successorOperands,
100                                     SmallVectorImpl<Value> &argStorage) {
101   // Check that the successor only contains a unconditional branch.
102   if (std::next(successor->begin()) != successor->end())
103     return failure();
104   // Check that the terminator is an unconditional branch.
105   BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
106   if (!successorBranch)
107     return failure();
108   // Check that the arguments are only used within the terminator.
109   for (BlockArgument arg : successor->getArguments()) {
110     for (Operation *user : arg.getUsers())
111       if (user != successorBranch)
112         return failure();
113   }
114   // Don't try to collapse branches to infinite loops.
115   Block *successorDest = successorBranch.getDest();
116   if (successorDest == successor)
117     return failure();
118 
119   // Update the operands to the successor. If the branch parent has no
120   // arguments, we can use the branch operands directly.
121   OperandRange operands = successorBranch.getOperands();
122   if (successor->args_empty()) {
123     successor = successorDest;
124     successorOperands = operands;
125     return success();
126   }
127 
128   // Otherwise, we need to remap any argument operands.
129   for (Value operand : operands) {
130     BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
131     if (argOperand && argOperand.getOwner() == successor)
132       argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
133     else
134       argStorage.push_back(operand);
135   }
136   successor = successorDest;
137   successorOperands = argStorage;
138   return success();
139 }
140 
141 /// Simplify a branch to a block that has a single predecessor. This effectively
142 /// merges the two blocks.
143 static LogicalResult
144 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
145   // Check that the successor block has a single predecessor.
146   Block *succ = op.getDest();
147   Block *opParent = op->getBlock();
148   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
149     return failure();
150 
151   // Merge the successor into the current block and erase the branch.
152   rewriter.mergeBlocks(succ, opParent, op.getOperands());
153   rewriter.eraseOp(op);
154   return success();
155 }
156 
157 ///   br ^bb1
158 /// ^bb1
159 ///   br ^bbN(...)
160 ///
161 ///  -> br ^bbN(...)
162 ///
163 static LogicalResult simplifyPassThroughBr(BranchOp op,
164                                            PatternRewriter &rewriter) {
165   Block *dest = op.getDest();
166   ValueRange destOperands = op.getOperands();
167   SmallVector<Value, 4> destOperandStorage;
168 
169   // Try to collapse the successor if it points somewhere other than this
170   // block.
171   if (dest == op->getBlock() ||
172       failed(collapseBranch(dest, destOperands, destOperandStorage)))
173     return failure();
174 
175   // Create a new branch with the collapsed successor.
176   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
177   return success();
178 }
179 
180 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
181   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
182                  succeeded(simplifyPassThroughBr(op, rewriter)));
183 }
184 
185 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
186 
187 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
188 
189 Optional<MutableOperandRange>
190 BranchOp::getMutableSuccessorOperands(unsigned index) {
191   assert(index == 0 && "invalid successor index");
192   return getDestOperandsMutable();
193 }
194 
195 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
196   return getDest();
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // CondBranchOp
201 //===----------------------------------------------------------------------===//
202 
203 namespace {
204 /// cf.cond_br true, ^bb1, ^bb2
205 ///  -> br ^bb1
206 /// cf.cond_br false, ^bb1, ^bb2
207 ///  -> br ^bb2
208 ///
209 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
210   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
211 
212   LogicalResult matchAndRewrite(CondBranchOp condbr,
213                                 PatternRewriter &rewriter) const override {
214     if (matchPattern(condbr.getCondition(), m_NonZero())) {
215       // True branch taken.
216       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
217                                             condbr.getTrueOperands());
218       return success();
219     }
220     if (matchPattern(condbr.getCondition(), m_Zero())) {
221       // False branch taken.
222       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
223                                             condbr.getFalseOperands());
224       return success();
225     }
226     return failure();
227   }
228 };
229 
230 ///   cf.cond_br %cond, ^bb1, ^bb2
231 /// ^bb1
232 ///   br ^bbN(...)
233 /// ^bb2
234 ///   br ^bbK(...)
235 ///
236 ///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
237 ///
238 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
239   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
240 
241   LogicalResult matchAndRewrite(CondBranchOp condbr,
242                                 PatternRewriter &rewriter) const override {
243     Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
244     ValueRange trueDestOperands = condbr.getTrueOperands();
245     ValueRange falseDestOperands = condbr.getFalseOperands();
246     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
247 
248     // Try to collapse one of the current successors.
249     LogicalResult collapsedTrue =
250         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
251     LogicalResult collapsedFalse =
252         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
253     if (failed(collapsedTrue) && failed(collapsedFalse))
254       return failure();
255 
256     // Create a new branch with the collapsed successors.
257     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
258                                               trueDest, trueDestOperands,
259                                               falseDest, falseDestOperands);
260     return success();
261   }
262 };
263 
264 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
265 ///  -> br ^bb1(A, ..., N)
266 ///
267 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
268 ///  -> %select = arith.select %cond, A, B
269 ///     br ^bb1(%select)
270 ///
271 struct SimplifyCondBranchIdenticalSuccessors
272     : public OpRewritePattern<CondBranchOp> {
273   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
274 
275   LogicalResult matchAndRewrite(CondBranchOp condbr,
276                                 PatternRewriter &rewriter) const override {
277     // Check that the true and false destinations are the same and have the same
278     // operands.
279     Block *trueDest = condbr.getTrueDest();
280     if (trueDest != condbr.getFalseDest())
281       return failure();
282 
283     // If all of the operands match, no selects need to be generated.
284     OperandRange trueOperands = condbr.getTrueOperands();
285     OperandRange falseOperands = condbr.getFalseOperands();
286     if (trueOperands == falseOperands) {
287       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
288       return success();
289     }
290 
291     // Otherwise, if the current block is the only predecessor insert selects
292     // for any mismatched branch operands.
293     if (trueDest->getUniquePredecessor() != condbr->getBlock())
294       return failure();
295 
296     // Generate a select for any operands that differ between the two.
297     SmallVector<Value, 8> mergedOperands;
298     mergedOperands.reserve(trueOperands.size());
299     Value condition = condbr.getCondition();
300     for (auto it : llvm::zip(trueOperands, falseOperands)) {
301       if (std::get<0>(it) == std::get<1>(it))
302         mergedOperands.push_back(std::get<0>(it));
303       else
304         mergedOperands.push_back(rewriter.create<arith::SelectOp>(
305             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
306     }
307 
308     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
309     return success();
310   }
311 };
312 
313 ///   ...
314 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
315 /// ...
316 /// ^bb1: // has single predecessor
317 ///   ...
318 ///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
319 ///
320 /// ->
321 ///
322 ///   ...
323 ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
324 /// ...
325 /// ^bb1: // has single predecessor
326 ///   ...
327 ///   br ^bb3(...)
328 ///
329 struct SimplifyCondBranchFromCondBranchOnSameCondition
330     : public OpRewritePattern<CondBranchOp> {
331   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
332 
333   LogicalResult matchAndRewrite(CondBranchOp condbr,
334                                 PatternRewriter &rewriter) const override {
335     // Check that we have a single distinct predecessor.
336     Block *currentBlock = condbr->getBlock();
337     Block *predecessor = currentBlock->getSinglePredecessor();
338     if (!predecessor)
339       return failure();
340 
341     // Check that the predecessor terminates with a conditional branch to this
342     // block and that it branches on the same condition.
343     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
344     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
345       return failure();
346 
347     // Fold this branch to an unconditional branch.
348     if (currentBlock == predBranch.getTrueDest())
349       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
350                                             condbr.getTrueDestOperands());
351     else
352       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
353                                             condbr.getFalseDestOperands());
354     return success();
355   }
356 };
357 
358 ///   cf.cond_br %arg0, ^trueB, ^falseB
359 ///
360 /// ^trueB:
361 ///   "test.consumer1"(%arg0) : (i1) -> ()
362 ///    ...
363 ///
364 /// ^falseB:
365 ///   "test.consumer2"(%arg0) : (i1) -> ()
366 ///   ...
367 ///
368 /// ->
369 ///
370 ///   cf.cond_br %arg0, ^trueB, ^falseB
371 /// ^trueB:
372 ///   "test.consumer1"(%true) : (i1) -> ()
373 ///   ...
374 ///
375 /// ^falseB:
376 ///   "test.consumer2"(%false) : (i1) -> ()
377 ///   ...
378 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
379   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
380 
381   LogicalResult matchAndRewrite(CondBranchOp condbr,
382                                 PatternRewriter &rewriter) const override {
383     // Check that we have a single distinct predecessor.
384     bool replaced = false;
385     Type ty = rewriter.getI1Type();
386 
387     // These variables serve to prevent creating duplicate constants
388     // and hold constant true or false values.
389     Value constantTrue = nullptr;
390     Value constantFalse = nullptr;
391 
392     // TODO These checks can be expanded to encompas any use with only
393     // either the true of false edge as a predecessor. For now, we fall
394     // back to checking the single predecessor is given by the true/fasle
395     // destination, thereby ensuring that only that edge can reach the
396     // op.
397     if (condbr.getTrueDest()->getSinglePredecessor()) {
398       for (OpOperand &use :
399            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
400         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
401           replaced = true;
402 
403           if (!constantTrue)
404             constantTrue = rewriter.create<arith::ConstantOp>(
405                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
406 
407           rewriter.updateRootInPlace(use.getOwner(),
408                                      [&] { use.set(constantTrue); });
409         }
410       }
411     }
412     if (condbr.getFalseDest()->getSinglePredecessor()) {
413       for (OpOperand &use :
414            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
415         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
416           replaced = true;
417 
418           if (!constantFalse)
419             constantFalse = rewriter.create<arith::ConstantOp>(
420                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
421 
422           rewriter.updateRootInPlace(use.getOwner(),
423                                      [&] { use.set(constantFalse); });
424         }
425       }
426     }
427     return success(replaced);
428   }
429 };
430 } // namespace
431 
432 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
433                                                MLIRContext *context) {
434   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
435               SimplifyCondBranchIdenticalSuccessors,
436               SimplifyCondBranchFromCondBranchOnSameCondition,
437               CondBranchTruthPropagation>(context);
438 }
439 
440 Optional<MutableOperandRange>
441 CondBranchOp::getMutableSuccessorOperands(unsigned index) {
442   assert(index < getNumSuccessors() && "invalid successor index");
443   return index == trueIndex ? getTrueDestOperandsMutable()
444                             : getFalseDestOperandsMutable();
445 }
446 
447 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
448   if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
449     return condAttr.getValue().isOneValue() ? getTrueDest() : getFalseDest();
450   return nullptr;
451 }
452 
453 //===----------------------------------------------------------------------===//
454 // SwitchOp
455 //===----------------------------------------------------------------------===//
456 
457 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
458                      Block *defaultDestination, ValueRange defaultOperands,
459                      DenseIntElementsAttr caseValues,
460                      BlockRange caseDestinations,
461                      ArrayRef<ValueRange> caseOperands) {
462   build(builder, result, value, defaultOperands, caseOperands, caseValues,
463         defaultDestination, caseDestinations);
464 }
465 
466 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
467                      Block *defaultDestination, ValueRange defaultOperands,
468                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
469                      ArrayRef<ValueRange> caseOperands) {
470   DenseIntElementsAttr caseValuesAttr;
471   if (!caseValues.empty()) {
472     ShapedType caseValueType = VectorType::get(
473         static_cast<int64_t>(caseValues.size()), value.getType());
474     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
475   }
476   build(builder, result, value, defaultDestination, defaultOperands,
477         caseValuesAttr, caseDestinations, caseOperands);
478 }
479 
480 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
481 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
482 static ParseResult parseSwitchOpCases(
483     OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
484     SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
485     SmallVectorImpl<Type> &defaultOperandTypes,
486     DenseIntElementsAttr &caseValues,
487     SmallVectorImpl<Block *> &caseDestinations,
488     SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
489     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
490   if (parser.parseKeyword("default") || parser.parseColon() ||
491       parser.parseSuccessor(defaultDestination))
492     return failure();
493   if (succeeded(parser.parseOptionalLParen())) {
494     if (parser.parseRegionArgumentList(defaultOperands) ||
495         parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
496       return failure();
497   }
498 
499   SmallVector<APInt> values;
500   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
501   while (succeeded(parser.parseOptionalComma())) {
502     int64_t value = 0;
503     if (failed(parser.parseInteger(value)))
504       return failure();
505     values.push_back(APInt(bitWidth, value));
506 
507     Block *destination;
508     SmallVector<OpAsmParser::OperandType> operands;
509     SmallVector<Type> operandTypes;
510     if (failed(parser.parseColon()) ||
511         failed(parser.parseSuccessor(destination)))
512       return failure();
513     if (succeeded(parser.parseOptionalLParen())) {
514       if (failed(parser.parseRegionArgumentList(operands)) ||
515           failed(parser.parseColonTypeList(operandTypes)) ||
516           failed(parser.parseRParen()))
517         return failure();
518     }
519     caseDestinations.push_back(destination);
520     caseOperands.emplace_back(operands);
521     caseOperandTypes.emplace_back(operandTypes);
522   }
523 
524   if (!values.empty()) {
525     ShapedType caseValueType =
526         VectorType::get(static_cast<int64_t>(values.size()), flagType);
527     caseValues = DenseIntElementsAttr::get(caseValueType, values);
528   }
529   return success();
530 }
531 
532 static void printSwitchOpCases(
533     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
534     OperandRange defaultOperands, TypeRange defaultOperandTypes,
535     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
536     OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
537   p << "  default: ";
538   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
539 
540   if (!caseValues)
541     return;
542 
543   for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
544     p << ',';
545     p.printNewline();
546     p << "  ";
547     p << it.value().getLimitedValue();
548     p << ": ";
549     p.printSuccessorAndUseList(caseDestinations[it.index()],
550                                caseOperands[it.index()]);
551   }
552   p.printNewline();
553 }
554 
555 LogicalResult SwitchOp::verify() {
556   auto caseValues = getCaseValues();
557   auto caseDestinations = getCaseDestinations();
558 
559   if (!caseValues && caseDestinations.empty())
560     return success();
561 
562   Type flagType = getFlag().getType();
563   Type caseValueType = caseValues->getType().getElementType();
564   if (caseValueType != flagType)
565     return emitOpError() << "'flag' type (" << flagType
566                          << ") should match case value type (" << caseValueType
567                          << ")";
568 
569   if (caseValues &&
570       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
571     return emitOpError() << "number of case values (" << caseValues->size()
572                          << ") should match number of "
573                             "case destinations ("
574                          << caseDestinations.size() << ")";
575   return success();
576 }
577 
578 Optional<MutableOperandRange>
579 SwitchOp::getMutableSuccessorOperands(unsigned index) {
580   assert(index < getNumSuccessors() && "invalid successor index");
581   return index == 0 ? getDefaultOperandsMutable()
582                     : getCaseOperandsMutable(index - 1);
583 }
584 
585 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
586   Optional<DenseIntElementsAttr> caseValues = getCaseValues();
587 
588   if (!caseValues)
589     return getDefaultDestination();
590 
591   SuccessorRange caseDests = getCaseDestinations();
592   if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
593     for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
594       if (it.value() == value.getValue())
595         return caseDests[it.index()];
596     return getDefaultDestination();
597   }
598   return nullptr;
599 }
600 
601 /// switch %flag : i32, [
602 ///   default:  ^bb1
603 /// ]
604 ///  -> br ^bb1
605 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
606                                                    PatternRewriter &rewriter) {
607   if (!op.getCaseDestinations().empty())
608     return failure();
609 
610   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
611                                         op.getDefaultOperands());
612   return success();
613 }
614 
615 /// switch %flag : i32, [
616 ///   default: ^bb1,
617 ///   42: ^bb1,
618 ///   43: ^bb2
619 /// ]
620 /// ->
621 /// switch %flag : i32, [
622 ///   default: ^bb1,
623 ///   43: ^bb2
624 /// ]
625 static LogicalResult
626 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
627   SmallVector<Block *> newCaseDestinations;
628   SmallVector<ValueRange> newCaseOperands;
629   SmallVector<APInt> newCaseValues;
630   bool requiresChange = false;
631   auto caseValues = op.getCaseValues();
632   auto caseDests = op.getCaseDestinations();
633 
634   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
635     if (caseDests[it.index()] == op.getDefaultDestination() &&
636         op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
637       requiresChange = true;
638       continue;
639     }
640     newCaseDestinations.push_back(caseDests[it.index()]);
641     newCaseOperands.push_back(op.getCaseOperands(it.index()));
642     newCaseValues.push_back(it.value());
643   }
644 
645   if (!requiresChange)
646     return failure();
647 
648   rewriter.replaceOpWithNewOp<SwitchOp>(
649       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
650       newCaseValues, newCaseDestinations, newCaseOperands);
651   return success();
652 }
653 
654 /// Helper for folding a switch with a constant value.
655 /// switch %c_42 : i32, [
656 ///   default: ^bb1 ,
657 ///   42: ^bb2,
658 ///   43: ^bb3
659 /// ]
660 /// -> br ^bb2
661 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
662                        const APInt &caseValue) {
663   auto caseValues = op.getCaseValues();
664   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
665     if (it.value() == caseValue) {
666       rewriter.replaceOpWithNewOp<BranchOp>(
667           op, op.getCaseDestinations()[it.index()],
668           op.getCaseOperands(it.index()));
669       return;
670     }
671   }
672   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
673                                         op.getDefaultOperands());
674 }
675 
676 /// switch %c_42 : i32, [
677 ///   default: ^bb1,
678 ///   42: ^bb2,
679 ///   43: ^bb3
680 /// ]
681 /// -> br ^bb2
682 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
683                                               PatternRewriter &rewriter) {
684   APInt caseValue;
685   if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
686     return failure();
687 
688   foldSwitch(op, rewriter, caseValue);
689   return success();
690 }
691 
692 /// switch %c_42 : i32, [
693 ///   default: ^bb1,
694 ///   42: ^bb2,
695 /// ]
696 /// ^bb2:
697 ///   br ^bb3
698 /// ->
699 /// switch %c_42 : i32, [
700 ///   default: ^bb1,
701 ///   42: ^bb3,
702 /// ]
703 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
704                                                PatternRewriter &rewriter) {
705   SmallVector<Block *> newCaseDests;
706   SmallVector<ValueRange> newCaseOperands;
707   SmallVector<SmallVector<Value>> argStorage;
708   auto caseValues = op.getCaseValues();
709   auto caseDests = op.getCaseDestinations();
710   bool requiresChange = false;
711   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
712     Block *caseDest = caseDests[i];
713     ValueRange caseOperands = op.getCaseOperands(i);
714     argStorage.emplace_back();
715     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
716       requiresChange = true;
717 
718     newCaseDests.push_back(caseDest);
719     newCaseOperands.push_back(caseOperands);
720   }
721 
722   Block *defaultDest = op.getDefaultDestination();
723   ValueRange defaultOperands = op.getDefaultOperands();
724   argStorage.emplace_back();
725 
726   if (succeeded(
727           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
728     requiresChange = true;
729 
730   if (!requiresChange)
731     return failure();
732 
733   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
734                                         defaultOperands, caseValues.getValue(),
735                                         newCaseDests, newCaseOperands);
736   return success();
737 }
738 
739 /// switch %flag : i32, [
740 ///   default: ^bb1,
741 ///   42: ^bb2,
742 /// ]
743 /// ^bb2:
744 ///   switch %flag : i32, [
745 ///     default: ^bb3,
746 ///     42: ^bb4
747 ///   ]
748 /// ->
749 /// switch %flag : i32, [
750 ///   default: ^bb1,
751 ///   42: ^bb2,
752 /// ]
753 /// ^bb2:
754 ///   br ^bb4
755 ///
756 ///  and
757 ///
758 /// switch %flag : i32, [
759 ///   default: ^bb1,
760 ///   42: ^bb2,
761 /// ]
762 /// ^bb2:
763 ///   switch %flag : i32, [
764 ///     default: ^bb3,
765 ///     43: ^bb4
766 ///   ]
767 /// ->
768 /// switch %flag : i32, [
769 ///   default: ^bb1,
770 ///   42: ^bb2,
771 /// ]
772 /// ^bb2:
773 ///   br ^bb3
774 static LogicalResult
775 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
776                                         PatternRewriter &rewriter) {
777   // Check that we have a single distinct predecessor.
778   Block *currentBlock = op->getBlock();
779   Block *predecessor = currentBlock->getSinglePredecessor();
780   if (!predecessor)
781     return failure();
782 
783   // Check that the predecessor terminates with a switch branch to this block
784   // and that it branches on the same condition and that this branch isn't the
785   // default destination.
786   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
787   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
788       predSwitch.getDefaultDestination() == currentBlock)
789     return failure();
790 
791   // Fold this switch to an unconditional branch.
792   SuccessorRange predDests = predSwitch.getCaseDestinations();
793   auto it = llvm::find(predDests, currentBlock);
794   if (it != predDests.end()) {
795     Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
796     foldSwitch(op, rewriter,
797                predCaseValues->getValues<APInt>()[it - predDests.begin()]);
798   } else {
799     rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
800                                           op.getDefaultOperands());
801   }
802   return success();
803 }
804 
805 /// switch %flag : i32, [
806 ///   default: ^bb1,
807 ///   42: ^bb2
808 /// ]
809 /// ^bb1:
810 ///   switch %flag : i32, [
811 ///     default: ^bb3,
812 ///     42: ^bb4,
813 ///     43: ^bb5
814 ///   ]
815 /// ->
816 /// switch %flag : i32, [
817 ///   default: ^bb1,
818 ///   42: ^bb2,
819 /// ]
820 /// ^bb1:
821 ///   switch %flag : i32, [
822 ///     default: ^bb3,
823 ///     43: ^bb5
824 ///   ]
825 static LogicalResult
826 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
827                                                PatternRewriter &rewriter) {
828   // Check that we have a single distinct predecessor.
829   Block *currentBlock = op->getBlock();
830   Block *predecessor = currentBlock->getSinglePredecessor();
831   if (!predecessor)
832     return failure();
833 
834   // Check that the predecessor terminates with a switch branch to this block
835   // and that it branches on the same condition and that this branch is the
836   // default destination.
837   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
838   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
839       predSwitch.getDefaultDestination() != currentBlock)
840     return failure();
841 
842   // Delete case values that are not possible here.
843   DenseSet<APInt> caseValuesToRemove;
844   auto predDests = predSwitch.getCaseDestinations();
845   auto predCaseValues = predSwitch.getCaseValues();
846   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
847     if (currentBlock != predDests[i])
848       caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
849 
850   SmallVector<Block *> newCaseDestinations;
851   SmallVector<ValueRange> newCaseOperands;
852   SmallVector<APInt> newCaseValues;
853   bool requiresChange = false;
854 
855   auto caseValues = op.getCaseValues();
856   auto caseDests = op.getCaseDestinations();
857   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
858     if (caseValuesToRemove.contains(it.value())) {
859       requiresChange = true;
860       continue;
861     }
862     newCaseDestinations.push_back(caseDests[it.index()]);
863     newCaseOperands.push_back(op.getCaseOperands(it.index()));
864     newCaseValues.push_back(it.value());
865   }
866 
867   if (!requiresChange)
868     return failure();
869 
870   rewriter.replaceOpWithNewOp<SwitchOp>(
871       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
872       newCaseValues, newCaseDestinations, newCaseOperands);
873   return success();
874 }
875 
876 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
877                                            MLIRContext *context) {
878   results.add(&simplifySwitchWithOnlyDefault)
879       .add(&dropSwitchCasesThatMatchDefault)
880       .add(&simplifyConstSwitchValue)
881       .add(&simplifyPassThroughSwitch)
882       .add(&simplifySwitchFromSwitchOnSameCondition)
883       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
884 }
885 
886 //===----------------------------------------------------------------------===//
887 // TableGen'd op method definitions
888 //===----------------------------------------------------------------------===//
889 
890 #define GET_OP_CLASSES
891 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
892