xref: /llvm-project/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp (revision aa8feeefd3ac6c78ee8f67bf033976fc7d68bc6d)
1 //===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ControlFlowInterfaces.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/Parser/Parser.h"
16 
17 #include <gtest/gtest.h>
18 
19 using namespace mlir;
20 
21 /// A dummy op that is also a terminator.
22 struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
23   using Op::Op;
24   static ArrayRef<StringRef> getAttributeNames() { return {}; }
25 
26   static StringRef getOperationName() { return "cftest.dummy_op"; }
27 };
28 
29 /// All regions of this op are mutually exclusive.
30 struct MutuallyExclusiveRegionsOp
31     : public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> {
32   using Op::Op;
33   static ArrayRef<StringRef> getAttributeNames() { return {}; }
34 
35   static StringRef getOperationName() {
36     return "cftest.mutually_exclusive_regions_op";
37   }
38 
39   // Regions have no successors.
40   void getSuccessorRegions(Optional<unsigned> index,
41                            ArrayRef<Attribute> operands,
42                            SmallVectorImpl<RegionSuccessor> &regions) {}
43 };
44 
45 /// All regions of this op call each other in a large circle.
46 struct LoopRegionsOp
47     : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
48   using Op::Op;
49   static const unsigned kNumRegions = 3;
50 
51   static ArrayRef<StringRef> getAttributeNames() { return {}; }
52 
53   static StringRef getOperationName() { return "cftest.loop_regions_op"; }
54 
55   void getSuccessorRegions(Optional<unsigned> index,
56                            ArrayRef<Attribute> operands,
57                            SmallVectorImpl<RegionSuccessor> &regions) {
58     if (index) {
59       if (*index == 1)
60         // This region also branches back to the parent.
61         regions.push_back(RegionSuccessor());
62       regions.push_back(
63           RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
64     }
65   }
66 };
67 
68 /// Each region branches back it itself or the parent.
69 struct DoubleLoopRegionsOp
70     : public Op<DoubleLoopRegionsOp, RegionBranchOpInterface::Trait> {
71   using Op::Op;
72 
73   static ArrayRef<StringRef> getAttributeNames() { return {}; }
74 
75   static StringRef getOperationName() {
76     return "cftest.double_loop_regions_op";
77   }
78 
79   void getSuccessorRegions(Optional<unsigned> index,
80                            ArrayRef<Attribute> operands,
81                            SmallVectorImpl<RegionSuccessor> &regions) {
82     if (index.has_value()) {
83       regions.push_back(RegionSuccessor());
84       regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
85     }
86   }
87 };
88 
89 /// Regions are executed sequentially.
90 struct SequentialRegionsOp
91     : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
92   using Op::Op;
93   static ArrayRef<StringRef> getAttributeNames() { return {}; }
94 
95   static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
96 
97   // Region 0 has Region 1 as a successor.
98   void getSuccessorRegions(Optional<unsigned> index,
99                            ArrayRef<Attribute> operands,
100                            SmallVectorImpl<RegionSuccessor> &regions) {
101     if (index == 0u) {
102       Operation *thisOp = this->getOperation();
103       regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
104     }
105   }
106 };
107 
108 /// A dialect putting all the above together.
109 struct CFTestDialect : Dialect {
110   explicit CFTestDialect(MLIRContext *ctx)
111       : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
112     addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
113                   DoubleLoopRegionsOp, SequentialRegionsOp>();
114   }
115   static StringRef getDialectNamespace() { return "cftest"; }
116 };
117 
118 TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
119   const char *ir = R"MLIR(
120 "cftest.mutually_exclusive_regions_op"() (
121       {"cftest.dummy_op"() : () -> ()},  // op1
122       {"cftest.dummy_op"() : () -> ()}   // op2
123   ) : () -> ()
124   )MLIR";
125 
126   DialectRegistry registry;
127   registry.insert<CFTestDialect>();
128   MLIRContext ctx(registry);
129 
130   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
131   Operation *testOp = &module->getBody()->getOperations().front();
132   Operation *op1 = &testOp->getRegion(0).front().front();
133   Operation *op2 = &testOp->getRegion(1).front().front();
134 
135   EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
136   EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
137 }
138 
139 TEST(RegionBranchOpInterface, MutuallyExclusiveOps2) {
140   const char *ir = R"MLIR(
141 "cftest.double_loop_regions_op"() (
142       {"cftest.dummy_op"() : () -> ()},  // op1
143       {"cftest.dummy_op"() : () -> ()}   // op2
144   ) : () -> ()
145   )MLIR";
146 
147   DialectRegistry registry;
148   registry.insert<CFTestDialect>();
149   MLIRContext ctx(registry);
150 
151   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
152   Operation *testOp = &module->getBody()->getOperations().front();
153   Operation *op1 = &testOp->getRegion(0).front().front();
154   Operation *op2 = &testOp->getRegion(1).front().front();
155 
156   EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
157   EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
158 }
159 
160 TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
161   const char *ir = R"MLIR(
162 "cftest.sequential_regions_op"() (
163       {"cftest.dummy_op"() : () -> ()},  // op1
164       {"cftest.dummy_op"() : () -> ()}   // op2
165   ) : () -> ()
166   )MLIR";
167 
168   DialectRegistry registry;
169   registry.insert<CFTestDialect>();
170   MLIRContext ctx(registry);
171 
172   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
173   Operation *testOp = &module->getBody()->getOperations().front();
174   Operation *op1 = &testOp->getRegion(0).front().front();
175   Operation *op2 = &testOp->getRegion(1).front().front();
176 
177   EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2));
178   EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1));
179 }
180 
181 TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
182   const char *ir = R"MLIR(
183 "cftest.mutually_exclusive_regions_op"() (
184       {
185         "cftest.sequential_regions_op"() (
186               {"cftest.dummy_op"() : () -> ()},  // op1
187               {"cftest.dummy_op"() : () -> ()}   // op3
188           ) : () -> ()
189         "cftest.dummy_op"() : () -> ()
190       },
191       {"cftest.dummy_op"() : () -> ()}           // op2
192   ) : () -> ()
193   )MLIR";
194 
195   DialectRegistry registry;
196   registry.insert<CFTestDialect>();
197   MLIRContext ctx(registry);
198 
199   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
200   Operation *testOp = &module->getBody()->getOperations().front();
201   Operation *op1 =
202       &testOp->getRegion(0).front().front().getRegion(0).front().front();
203   Operation *op2 = &testOp->getRegion(1).front().front();
204   Operation *op3 =
205       &testOp->getRegion(0).front().front().getRegion(1).front().front();
206 
207   EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
208   EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
209   EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
210 }
211 
212 TEST(RegionBranchOpInterface, RecursiveRegions) {
213   const char *ir = R"MLIR(
214 "cftest.loop_regions_op"() (
215       {"cftest.dummy_op"() : () -> ()},  // op1
216       {"cftest.dummy_op"() : () -> ()},  // op2
217       {"cftest.dummy_op"() : () -> ()}   // op3
218   ) : () -> ()
219   )MLIR";
220 
221   DialectRegistry registry;
222   registry.insert<CFTestDialect>();
223   MLIRContext ctx(registry);
224 
225   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
226   Operation *testOp = &module->getBody()->getOperations().front();
227   auto regionOp = cast<RegionBranchOpInterface>(testOp);
228   Operation *op1 = &testOp->getRegion(0).front().front();
229   Operation *op2 = &testOp->getRegion(1).front().front();
230   Operation *op3 = &testOp->getRegion(2).front().front();
231 
232   EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
233   EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
234   EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
235   EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr);
236   EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr);
237   EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr);
238 }
239 
240 TEST(RegionBranchOpInterface, NotRecursiveRegions) {
241   const char *ir = R"MLIR(
242 "cftest.sequential_regions_op"() (
243       {"cftest.dummy_op"() : () -> ()},  // op1
244       {"cftest.dummy_op"() : () -> ()}   // op2
245   ) : () -> ()
246   )MLIR";
247 
248   DialectRegistry registry;
249   registry.insert<CFTestDialect>();
250   MLIRContext ctx(registry);
251 
252   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
253   Operation *testOp = &module->getBody()->getOperations().front();
254   Operation *op1 = &testOp->getRegion(0).front().front();
255   Operation *op2 = &testOp->getRegion(1).front().front();
256 
257   EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
258   EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
259 }
260