xref: /llvm-project/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp (revision 4dd744ac9c0f772a61dd91c84bc14d17e69aec51)
1a5c2f782SMatthias Springer //===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
2a5c2f782SMatthias Springer //
3a5c2f782SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a5c2f782SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5a5c2f782SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a5c2f782SMatthias Springer //
7a5c2f782SMatthias Springer //===----------------------------------------------------------------------===//
8a5c2f782SMatthias Springer 
9a5c2f782SMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h"
10a5c2f782SMatthias Springer #include "mlir/IR/BuiltinOps.h"
11a5c2f782SMatthias Springer #include "mlir/IR/Dialect.h"
12a5c2f782SMatthias Springer #include "mlir/IR/DialectImplementation.h"
13a5c2f782SMatthias Springer #include "mlir/IR/OpDefinition.h"
14a5c2f782SMatthias Springer #include "mlir/IR/OpImplementation.h"
159eaff423SRiver Riddle #include "mlir/Parser/Parser.h"
16a5c2f782SMatthias Springer 
17a5c2f782SMatthias Springer #include <gtest/gtest.h>
18a5c2f782SMatthias Springer 
19a5c2f782SMatthias Springer using namespace mlir;
20a5c2f782SMatthias Springer 
21a5c2f782SMatthias Springer /// A dummy op that is also a terminator.
22a5c2f782SMatthias Springer struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
23a5c2f782SMatthias Springer   using Op::Op;
getAttributeNamesDummyOp24a5c2f782SMatthias Springer   static ArrayRef<StringRef> getAttributeNames() { return {}; }
25a5c2f782SMatthias Springer 
getOperationNameDummyOp26a5c2f782SMatthias Springer   static StringRef getOperationName() { return "cftest.dummy_op"; }
27a5c2f782SMatthias Springer };
28a5c2f782SMatthias Springer 
29a5c2f782SMatthias Springer /// All regions of this op are mutually exclusive.
30a5c2f782SMatthias Springer struct MutuallyExclusiveRegionsOp
31a5c2f782SMatthias Springer     : public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> {
32a5c2f782SMatthias Springer   using Op::Op;
getAttributeNamesMutuallyExclusiveRegionsOp33a5c2f782SMatthias Springer   static ArrayRef<StringRef> getAttributeNames() { return {}; }
34a5c2f782SMatthias Springer 
getOperationNameMutuallyExclusiveRegionsOp35a5c2f782SMatthias Springer   static StringRef getOperationName() {
36a5c2f782SMatthias Springer     return "cftest.mutually_exclusive_regions_op";
37a5c2f782SMatthias Springer   }
38a5c2f782SMatthias Springer 
39a5c2f782SMatthias Springer   // Regions have no successors.
getSuccessorRegionsMutuallyExclusiveRegionsOp40*4dd744acSMarkus Böck   void getSuccessorRegions(RegionBranchPoint point,
41a5c2f782SMatthias Springer                            SmallVectorImpl<RegionSuccessor> &regions) {}
42a5c2f782SMatthias Springer };
43a5c2f782SMatthias Springer 
440f4ba02dSMatthias Springer /// All regions of this op call each other in a large circle.
450f4ba02dSMatthias Springer struct LoopRegionsOp
460f4ba02dSMatthias Springer     : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
470f4ba02dSMatthias Springer   using Op::Op;
480f4ba02dSMatthias Springer   static const unsigned kNumRegions = 3;
490f4ba02dSMatthias Springer 
getAttributeNamesLoopRegionsOp500f4ba02dSMatthias Springer   static ArrayRef<StringRef> getAttributeNames() { return {}; }
510f4ba02dSMatthias Springer 
getOperationNameLoopRegionsOp520f4ba02dSMatthias Springer   static StringRef getOperationName() { return "cftest.loop_regions_op"; }
530f4ba02dSMatthias Springer 
getSuccessorRegionsLoopRegionsOp54*4dd744acSMarkus Böck   void getSuccessorRegions(RegionBranchPoint point,
550f4ba02dSMatthias Springer                            SmallVectorImpl<RegionSuccessor> &regions) {
56*4dd744acSMarkus Böck     if (Region *region = point.getRegionOrNull()) {
57*4dd744acSMarkus Böck       if (point == (*this)->getRegion(1))
580f4ba02dSMatthias Springer         // This region also branches back to the parent.
590f4ba02dSMatthias Springer         regions.push_back(RegionSuccessor());
60*4dd744acSMarkus Böck       regions.push_back(RegionSuccessor(region));
610f4ba02dSMatthias Springer     }
620f4ba02dSMatthias Springer   }
630f4ba02dSMatthias Springer };
640f4ba02dSMatthias Springer 
65a3005a40SMatthias Springer /// Each region branches back it itself or the parent.
66a3005a40SMatthias Springer struct DoubleLoopRegionsOp
67a3005a40SMatthias Springer     : public Op<DoubleLoopRegionsOp, RegionBranchOpInterface::Trait> {
68a3005a40SMatthias Springer   using Op::Op;
69a3005a40SMatthias Springer 
getAttributeNamesDoubleLoopRegionsOp70a3005a40SMatthias Springer   static ArrayRef<StringRef> getAttributeNames() { return {}; }
71a3005a40SMatthias Springer 
getOperationNameDoubleLoopRegionsOp72a3005a40SMatthias Springer   static StringRef getOperationName() {
73a3005a40SMatthias Springer     return "cftest.double_loop_regions_op";
74a3005a40SMatthias Springer   }
75a3005a40SMatthias Springer 
getSuccessorRegionsDoubleLoopRegionsOp76*4dd744acSMarkus Böck   void getSuccessorRegions(RegionBranchPoint point,
77a3005a40SMatthias Springer                            SmallVectorImpl<RegionSuccessor> &regions) {
78*4dd744acSMarkus Böck     if (Region *region = point.getRegionOrNull()) {
79a3005a40SMatthias Springer       regions.push_back(RegionSuccessor());
80*4dd744acSMarkus Böck       regions.push_back(RegionSuccessor(region));
81a3005a40SMatthias Springer     }
82a3005a40SMatthias Springer   }
83a3005a40SMatthias Springer };
84a3005a40SMatthias Springer 
85a5c2f782SMatthias Springer /// Regions are executed sequentially.
86a5c2f782SMatthias Springer struct SequentialRegionsOp
87a5c2f782SMatthias Springer     : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
88a5c2f782SMatthias Springer   using Op::Op;
getAttributeNamesSequentialRegionsOp89a5c2f782SMatthias Springer   static ArrayRef<StringRef> getAttributeNames() { return {}; }
90a5c2f782SMatthias Springer 
getOperationNameSequentialRegionsOp91a5c2f782SMatthias Springer   static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
92a5c2f782SMatthias Springer 
93a5c2f782SMatthias Springer   // Region 0 has Region 1 as a successor.
getSuccessorRegionsSequentialRegionsOp94*4dd744acSMarkus Böck   void getSuccessorRegions(RegionBranchPoint point,
95a5c2f782SMatthias Springer                            SmallVectorImpl<RegionSuccessor> &regions) {
96*4dd744acSMarkus Böck     if (point == (*this)->getRegion(0)) {
97a5c2f782SMatthias Springer       Operation *thisOp = this->getOperation();
98a5c2f782SMatthias Springer       regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
99a5c2f782SMatthias Springer     }
100a5c2f782SMatthias Springer   }
101a5c2f782SMatthias Springer };
102a5c2f782SMatthias Springer 
103a5c2f782SMatthias Springer /// A dialect putting all the above together.
104a5c2f782SMatthias Springer struct CFTestDialect : Dialect {
CFTestDialectCFTestDialect105a5c2f782SMatthias Springer   explicit CFTestDialect(MLIRContext *ctx)
106a5c2f782SMatthias Springer       : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
1070f4ba02dSMatthias Springer     addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
108a3005a40SMatthias Springer                   DoubleLoopRegionsOp, SequentialRegionsOp>();
109a5c2f782SMatthias Springer   }
getDialectNamespaceCFTestDialect110a5c2f782SMatthias Springer   static StringRef getDialectNamespace() { return "cftest"; }
111a5c2f782SMatthias Springer };
112a5c2f782SMatthias Springer 
/// The two regions of `mutually_exclusive_regions_op` never branch to each
/// other, so ops in different regions are mutually exclusive (both orders).
TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.mutually_exclusive_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Stop with a test failure instead of dereferencing a null module if the IR
  // above does not parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
}
133a5c2f782SMatthias Springer 
/// In `double_loop_regions_op` each region only loops to itself or returns to
/// the parent, so its two regions are still mutually exclusive.
TEST(RegionBranchOpInterface, MutuallyExclusiveOps2) {
  const char *ir = R"MLIR(
"cftest.double_loop_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Stop with a test failure instead of dereferencing a null module if the IR
  // above does not parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
}
154a3005a40SMatthias Springer 
/// In `sequential_regions_op` region 0 flows into region 1, so ops in the two
/// regions are not mutually exclusive (both orders).
TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.sequential_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Stop with a test failure instead of dereferencing a null module if the IR
  // above does not parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1));
}
175a5c2f782SMatthias Springer 
/// op1 and op3 live in the two regions of a nested `sequential_regions_op`
/// (not mutually exclusive with each other). That nested op sits in region 0
/// of an outer `mutually_exclusive_regions_op`, whose region 1 holds op2, so
/// both op1 and op3 are mutually exclusive with op2.
TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.mutually_exclusive_regions_op"() (
      {
        "cftest.sequential_regions_op"() (
              {"cftest.dummy_op"() : () -> ()},  // op1
              {"cftest.dummy_op"() : () -> ()}   // op3
          ) : () -> ()
        "cftest.dummy_op"() : () -> ()
      },
      {"cftest.dummy_op"() : () -> ()}           // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Stop with a test failure instead of dereferencing a null module if the IR
  // above does not parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 =
      &testOp->getRegion(0).front().front().getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();
  Operation *op3 =
      &testOp->getRegion(0).front().front().getRegion(1).front().front();

  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
}
2060f4ba02dSMatthias Springer 
/// Every region of `loop_regions_op` can reach itself through the op's
/// successor graph, so each region is repetitive and each op inside one has
/// an enclosing repetitive region.
TEST(RegionBranchOpInterface, RecursiveRegions) {
  const char *ir = R"MLIR(
"cftest.loop_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()},  // op2
      {"cftest.dummy_op"() : () -> ()}   // op3
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Stop with a test failure instead of dereferencing a null module if the IR
  // above does not parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  auto regionOp = cast<RegionBranchOpInterface>(testOp);
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();
  Operation *op3 = &testOp->getRegion(2).front().front();

  EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
  EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
  EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
  EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr);
  EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr);
  EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr);
}
2340f4ba02dSMatthias Springer 
/// `sequential_regions_op` only flows from region 0 to region 1, so neither
/// region can reach itself and neither op has an enclosing repetitive region.
TEST(RegionBranchOpInterface, NotRecursiveRegions) {
  const char *ir = R"MLIR(
"cftest.sequential_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Stop with a test failure instead of dereferencing a null module if the IR
  // above does not parse.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
  EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
}
255