xref: /llvm-project/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp (revision 4dd744ac9c0f772a61dd91c84bc14d17e69aec51)
1 //===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ControlFlowInterfaces.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/Parser/Parser.h"
16 
17 #include <gtest/gtest.h>
18 
19 using namespace mlir;
20 
21 /// A dummy op that is also a terminator.
22 struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
23   using Op::Op;
getAttributeNamesDummyOp24   static ArrayRef<StringRef> getAttributeNames() { return {}; }
25 
getOperationNameDummyOp26   static StringRef getOperationName() { return "cftest.dummy_op"; }
27 };
28 
29 /// All regions of this op are mutually exclusive.
30 struct MutuallyExclusiveRegionsOp
31     : public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> {
32   using Op::Op;
getAttributeNamesMutuallyExclusiveRegionsOp33   static ArrayRef<StringRef> getAttributeNames() { return {}; }
34 
getOperationNameMutuallyExclusiveRegionsOp35   static StringRef getOperationName() {
36     return "cftest.mutually_exclusive_regions_op";
37   }
38 
39   // Regions have no successors.
getSuccessorRegionsMutuallyExclusiveRegionsOp40   void getSuccessorRegions(RegionBranchPoint point,
41                            SmallVectorImpl<RegionSuccessor> &regions) {}
42 };
43 
44 /// All regions of this op call each other in a large circle.
45 struct LoopRegionsOp
46     : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
47   using Op::Op;
48   static const unsigned kNumRegions = 3;
49 
getAttributeNamesLoopRegionsOp50   static ArrayRef<StringRef> getAttributeNames() { return {}; }
51 
getOperationNameLoopRegionsOp52   static StringRef getOperationName() { return "cftest.loop_regions_op"; }
53 
getSuccessorRegionsLoopRegionsOp54   void getSuccessorRegions(RegionBranchPoint point,
55                            SmallVectorImpl<RegionSuccessor> &regions) {
56     if (Region *region = point.getRegionOrNull()) {
57       if (point == (*this)->getRegion(1))
58         // This region also branches back to the parent.
59         regions.push_back(RegionSuccessor());
60       regions.push_back(RegionSuccessor(region));
61     }
62   }
63 };
64 
65 /// Each region branches back it itself or the parent.
66 struct DoubleLoopRegionsOp
67     : public Op<DoubleLoopRegionsOp, RegionBranchOpInterface::Trait> {
68   using Op::Op;
69 
getAttributeNamesDoubleLoopRegionsOp70   static ArrayRef<StringRef> getAttributeNames() { return {}; }
71 
getOperationNameDoubleLoopRegionsOp72   static StringRef getOperationName() {
73     return "cftest.double_loop_regions_op";
74   }
75 
getSuccessorRegionsDoubleLoopRegionsOp76   void getSuccessorRegions(RegionBranchPoint point,
77                            SmallVectorImpl<RegionSuccessor> &regions) {
78     if (Region *region = point.getRegionOrNull()) {
79       regions.push_back(RegionSuccessor());
80       regions.push_back(RegionSuccessor(region));
81     }
82   }
83 };
84 
85 /// Regions are executed sequentially.
86 struct SequentialRegionsOp
87     : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
88   using Op::Op;
getAttributeNamesSequentialRegionsOp89   static ArrayRef<StringRef> getAttributeNames() { return {}; }
90 
getOperationNameSequentialRegionsOp91   static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
92 
93   // Region 0 has Region 1 as a successor.
getSuccessorRegionsSequentialRegionsOp94   void getSuccessorRegions(RegionBranchPoint point,
95                            SmallVectorImpl<RegionSuccessor> &regions) {
96     if (point == (*this)->getRegion(0)) {
97       Operation *thisOp = this->getOperation();
98       regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
99     }
100   }
101 };
102 
103 /// A dialect putting all the above together.
104 struct CFTestDialect : Dialect {
CFTestDialectCFTestDialect105   explicit CFTestDialect(MLIRContext *ctx)
106       : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
107     addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
108                   DoubleLoopRegionsOp, SequentialRegionsOp>();
109   }
getDialectNamespaceCFTestDialect110   static StringRef getDialectNamespace() { return "cftest"; }
111 };
112 
/// Ops in two regions with no successors are reported as mutually exclusive
/// in both argument orders.
TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.mutually_exclusive_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fatal assertion: stop here rather than dereference a null module if the
  // IR fails to parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
}
133 
/// Regions that only loop on themselves or exit to the parent never reach
/// each other, so their ops are mutually exclusive.
TEST(RegionBranchOpInterface, MutuallyExclusiveOps2) {
  const char *ir = R"MLIR(
"cftest.double_loop_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fatal assertion: stop here rather than dereference a null module if the
  // IR fails to parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
}
154 
/// Region 0 branches to region 1, so their ops are not mutually exclusive in
/// either argument order.
TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.sequential_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fatal assertion: stop here rather than dereference a null module if the
  // IR fails to parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1));
}
175 
/// Mutual exclusivity is decided by the closest common RegionBranchOp
/// ancestor.
///
/// op1 and op3 share the sequential op, whose region 0 reaches region 1, so
/// they are not exclusive. op2 only shares the outer mutually exclusive op
/// with either of them, so it is exclusive with both. Every pair is checked
/// in both argument orders.
TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.mutually_exclusive_regions_op"() (
      {
        "cftest.sequential_regions_op"() (
              {"cftest.dummy_op"() : () -> ()},  // op1
              {"cftest.dummy_op"() : () -> ()}   // op3
          ) : () -> ()
        "cftest.dummy_op"() : () -> ()
      },
      {"cftest.dummy_op"() : () -> ()}           // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fatal assertion: stop here rather than dereference a null module if the
  // IR fails to parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 =
      &testOp->getRegion(0).front().front().getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();
  Operation *op3 =
      &testOp->getRegion(0).front().front().getRegion(1).front().front();

  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op3));
  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
  EXPECT_FALSE(insideMutuallyExclusiveRegions(op3, op1));
}
206 
/// Every region of LoopRegionsOp can reach itself again, so each is
/// repetitive. For each op, the closest enclosing repetitive region is the
/// region that directly contains it.
TEST(RegionBranchOpInterface, RecursiveRegions) {
  const char *ir = R"MLIR(
"cftest.loop_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()},  // op2
      {"cftest.dummy_op"() : () -> ()}   // op3
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fatal assertion: stop here rather than dereference a null module if the
  // IR fails to parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  auto regionOp = cast<RegionBranchOpInterface>(testOp);
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();
  Operation *op3 = &testOp->getRegion(2).front().front();

  EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
  EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
  EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
  // Pin the exact region, not just "some" repetitive region.
  EXPECT_EQ(getEnclosingRepetitiveRegion(op1), &testOp->getRegion(0));
  EXPECT_EQ(getEnclosingRepetitiveRegion(op2), &testOp->getRegion(1));
  EXPECT_EQ(getEnclosingRepetitiveRegion(op3), &testOp->getRegion(2));
}
234 
/// In SequentialRegionsOp, region 0 only reaches region 1, and region 1 has
/// no successors. Neither region can reach itself again, so neither is
/// repetitive and no op has an enclosing repetitive region.
TEST(RegionBranchOpInterface, NotRecursiveRegions) {
  const char *ir = R"MLIR(
"cftest.sequential_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fatal assertion: stop here rather than dereference a null module if the
  // IR fails to parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  auto regionOp = cast<RegionBranchOpInterface>(testOp);
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_FALSE(regionOp.isRepetitiveRegion(0));
  EXPECT_FALSE(regionOp.isRepetitiveRegion(1));
  EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
  EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
}
255