xref: /llvm-project/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp (revision 0f4ba02db3985051adac07a87ca9da549c0eb8ad)
1 //===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ControlFlowInterfaces.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/Parser/Parser.h"
16 
17 #include <gtest/gtest.h>
18 
19 using namespace mlir;
20 
21 /// A dummy op that is also a terminator.
22 struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
23   using Op::Op;
24   static ArrayRef<StringRef> getAttributeNames() { return {}; }
25 
26   static StringRef getOperationName() { return "cftest.dummy_op"; }
27 };
28 
29 /// All regions of this op are mutually exclusive.
30 struct MutuallyExclusiveRegionsOp
31     : public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> {
32   using Op::Op;
33   static ArrayRef<StringRef> getAttributeNames() { return {}; }
34 
35   static StringRef getOperationName() {
36     return "cftest.mutually_exclusive_regions_op";
37   }
38 
39   // Regions have no successors.
40   void getSuccessorRegions(Optional<unsigned> index,
41                            ArrayRef<Attribute> operands,
42                            SmallVectorImpl<RegionSuccessor> &regions) {}
43 };
44 
45 /// All regions of this op call each other in a large circle.
46 struct LoopRegionsOp
47     : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
48   using Op::Op;
49   static const unsigned kNumRegions = 3;
50 
51   static ArrayRef<StringRef> getAttributeNames() { return {}; }
52 
53   static StringRef getOperationName() { return "cftest.loop_regions_op"; }
54 
55   void getSuccessorRegions(Optional<unsigned> index,
56                            ArrayRef<Attribute> operands,
57                            SmallVectorImpl<RegionSuccessor> &regions) {
58     if (index) {
59       if (*index == 1)
60         // This region also branches back to the parent.
61         regions.push_back(RegionSuccessor());
62       regions.push_back(
63           RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
64     }
65   }
66 };
67 
68 /// Regions are executed sequentially.
69 struct SequentialRegionsOp
70     : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
71   using Op::Op;
72   static ArrayRef<StringRef> getAttributeNames() { return {}; }
73 
74   static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
75 
76   // Region 0 has Region 1 as a successor.
77   void getSuccessorRegions(Optional<unsigned> index,
78                            ArrayRef<Attribute> operands,
79                            SmallVectorImpl<RegionSuccessor> &regions) {
80     if (index == 0u) {
81       Operation *thisOp = this->getOperation();
82       regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
83     }
84   }
85 };
86 
87 /// A dialect putting all the above together.
88 struct CFTestDialect : Dialect {
89   explicit CFTestDialect(MLIRContext *ctx)
90       : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
91     addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
92                   SequentialRegionsOp>();
93   }
94   static StringRef getDialectNamespace() { return "cftest"; }
95 };
96 
97 TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
98   const char *ir = R"MLIR(
99 "cftest.mutually_exclusive_regions_op"() (
100       {"cftest.dummy_op"() : () -> ()},  // op1
101       {"cftest.dummy_op"() : () -> ()}   // op2
102   ) : () -> ()
103   )MLIR";
104 
105   DialectRegistry registry;
106   registry.insert<CFTestDialect>();
107   MLIRContext ctx(registry);
108 
109   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
110   Operation *testOp = &module->getBody()->getOperations().front();
111   Operation *op1 = &testOp->getRegion(0).front().front();
112   Operation *op2 = &testOp->getRegion(1).front().front();
113 
114   EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
115   EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
116 }
117 
118 TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
119   const char *ir = R"MLIR(
120 "cftest.sequential_regions_op"() (
121       {"cftest.dummy_op"() : () -> ()},  // op1
122       {"cftest.dummy_op"() : () -> ()}   // op2
123   ) : () -> ()
124   )MLIR";
125 
126   DialectRegistry registry;
127   registry.insert<CFTestDialect>();
128   MLIRContext ctx(registry);
129 
130   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
131   Operation *testOp = &module->getBody()->getOperations().front();
132   Operation *op1 = &testOp->getRegion(0).front().front();
133   Operation *op2 = &testOp->getRegion(1).front().front();
134 
135   EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2));
136   EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1));
137 }
138 
139 TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
140   const char *ir = R"MLIR(
141 "cftest.mutually_exclusive_regions_op"() (
142       {
143         "cftest.sequential_regions_op"() (
144               {"cftest.dummy_op"() : () -> ()},  // op1
145               {"cftest.dummy_op"() : () -> ()}   // op3
146           ) : () -> ()
147         "cftest.dummy_op"() : () -> ()
148       },
149       {"cftest.dummy_op"() : () -> ()}           // op2
150   ) : () -> ()
151   )MLIR";
152 
153   DialectRegistry registry;
154   registry.insert<CFTestDialect>();
155   MLIRContext ctx(registry);
156 
157   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
158   Operation *testOp = &module->getBody()->getOperations().front();
159   Operation *op1 =
160       &testOp->getRegion(0).front().front().getRegion(0).front().front();
161   Operation *op2 = &testOp->getRegion(1).front().front();
162   Operation *op3 =
163       &testOp->getRegion(0).front().front().getRegion(1).front().front();
164 
165   EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
166   EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
167   EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
168 }
169 
170 TEST(RegionBranchOpInterface, RecursiveRegions) {
171   const char *ir = R"MLIR(
172 "cftest.loop_regions_op"() (
173       {"cftest.dummy_op"() : () -> ()},  // op1
174       {"cftest.dummy_op"() : () -> ()},  // op2
175       {"cftest.dummy_op"() : () -> ()}   // op3
176   ) : () -> ()
177   )MLIR";
178 
179   DialectRegistry registry;
180   registry.insert<CFTestDialect>();
181   MLIRContext ctx(registry);
182 
183   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
184   Operation *testOp = &module->getBody()->getOperations().front();
185   auto regionOp = cast<RegionBranchOpInterface>(testOp);
186   Operation *op1 = &testOp->getRegion(0).front().front();
187   Operation *op2 = &testOp->getRegion(1).front().front();
188   Operation *op3 = &testOp->getRegion(2).front().front();
189 
190   EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
191   EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
192   EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
193   EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr);
194   EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr);
195   EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr);
196 }
197 
198 TEST(RegionBranchOpInterface, NotRecursiveRegions) {
199   const char *ir = R"MLIR(
200 "cftest.sequential_regions_op"() (
201       {"cftest.dummy_op"() : () -> ()},  // op1
202       {"cftest.dummy_op"() : () -> ()}   // op2
203   ) : () -> ()
204   )MLIR";
205 
206   DialectRegistry registry;
207   registry.insert<CFTestDialect>();
208   MLIRContext ctx(registry);
209 
210   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
211   Operation *testOp = &module->getBody()->getOperations().front();
212   Operation *op1 = &testOp->getRegion(0).front().front();
213   Operation *op2 = &testOp->getRegion(1).front().front();
214 
215   EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
216   EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
217 }
218