1 //===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ControlFlowInterfaces.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/IR/Dialect.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/OpDefinition.h" 14 #include "mlir/IR/OpImplementation.h" 15 #include "mlir/Parser/Parser.h" 16 17 #include <gtest/gtest.h> 18 19 using namespace mlir; 20 21 /// A dummy op that is also a terminator. 22 struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> { 23 using Op::Op; 24 static ArrayRef<StringRef> getAttributeNames() { return {}; } 25 26 static StringRef getOperationName() { return "cftest.dummy_op"; } 27 }; 28 29 /// All regions of this op are mutually exclusive. 30 struct MutuallyExclusiveRegionsOp 31 : public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> { 32 using Op::Op; 33 static ArrayRef<StringRef> getAttributeNames() { return {}; } 34 35 static StringRef getOperationName() { 36 return "cftest.mutually_exclusive_regions_op"; 37 } 38 39 // Regions have no successors. 40 void getSuccessorRegions(Optional<unsigned> index, 41 ArrayRef<Attribute> operands, 42 SmallVectorImpl<RegionSuccessor> ®ions) {} 43 }; 44 45 /// All regions of this op call each other in a large circle. 46 struct LoopRegionsOp 47 : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> { 48 using Op::Op; 49 static const unsigned kNumRegions = 3; 50 51 static ArrayRef<StringRef> getAttributeNames() { return {}; } 52 53 static StringRef getOperationName() { return "cftest.loop_regions_op"; } 54 55 void getSuccessorRegions(Optional<unsigned> index, 56 ArrayRef<Attribute> operands, 57 SmallVectorImpl<RegionSuccessor> ®ions) { 58 if (index) { 59 if (*index == 1) 60 // This region also branches back to the parent. 61 regions.push_back(RegionSuccessor()); 62 regions.push_back( 63 RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions))); 64 } 65 } 66 }; 67 68 /// Regions are executed sequentially. 69 struct SequentialRegionsOp 70 : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> { 71 using Op::Op; 72 static ArrayRef<StringRef> getAttributeNames() { return {}; } 73 74 static StringRef getOperationName() { return "cftest.sequential_regions_op"; } 75 76 // Region 0 has Region 1 as a successor. 77 void getSuccessorRegions(Optional<unsigned> index, 78 ArrayRef<Attribute> operands, 79 SmallVectorImpl<RegionSuccessor> ®ions) { 80 if (index == 0u) { 81 Operation *thisOp = this->getOperation(); 82 regions.push_back(RegionSuccessor(&thisOp->getRegion(1))); 83 } 84 } 85 }; 86 87 /// A dialect putting all the above together. 88 struct CFTestDialect : Dialect { 89 explicit CFTestDialect(MLIRContext *ctx) 90 : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) { 91 addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp, 92 SequentialRegionsOp>(); 93 } 94 static StringRef getDialectNamespace() { return "cftest"; } 95 }; 96 97 TEST(RegionBranchOpInterface, MutuallyExclusiveOps) { 98 const char *ir = R"MLIR( 99 "cftest.mutually_exclusive_regions_op"() ( 100 {"cftest.dummy_op"() : () -> ()}, // op1 101 {"cftest.dummy_op"() : () -> ()} // op2 102 ) : () -> () 103 )MLIR"; 104 105 DialectRegistry registry; 106 registry.insert<CFTestDialect>(); 107 MLIRContext ctx(registry); 108 109 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 110 Operation *testOp = &module->getBody()->getOperations().front(); 111 Operation *op1 = &testOp->getRegion(0).front().front(); 112 Operation *op2 = &testOp->getRegion(1).front().front(); 113 114 EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2)); 115 EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1)); 116 } 117 118 TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) { 119 const char *ir = R"MLIR( 120 "cftest.sequential_regions_op"() ( 121 {"cftest.dummy_op"() : () -> ()}, // op1 122 {"cftest.dummy_op"() : () -> ()} // op2 123 ) : () -> () 124 )MLIR"; 125 126 DialectRegistry registry; 127 registry.insert<CFTestDialect>(); 128 MLIRContext ctx(registry); 129 130 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 131 Operation *testOp = &module->getBody()->getOperations().front(); 132 Operation *op1 = &testOp->getRegion(0).front().front(); 133 Operation *op2 = &testOp->getRegion(1).front().front(); 134 135 EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2)); 136 EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1)); 137 } 138 139 TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) { 140 const char *ir = R"MLIR( 141 "cftest.mutually_exclusive_regions_op"() ( 142 { 143 "cftest.sequential_regions_op"() ( 144 {"cftest.dummy_op"() : () -> ()}, // op1 145 {"cftest.dummy_op"() : () -> ()} // op3 146 ) : () -> () 147 "cftest.dummy_op"() : () -> () 148 }, 149 {"cftest.dummy_op"() : () -> ()} // op2 150 ) : () -> () 151 )MLIR"; 152 153 DialectRegistry registry; 154 registry.insert<CFTestDialect>(); 155 MLIRContext ctx(registry); 156 157 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 158 Operation *testOp = &module->getBody()->getOperations().front(); 159 Operation *op1 = 160 &testOp->getRegion(0).front().front().getRegion(0).front().front(); 161 Operation *op2 = &testOp->getRegion(1).front().front(); 162 Operation *op3 = 163 &testOp->getRegion(0).front().front().getRegion(1).front().front(); 164 165 EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2)); 166 EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2)); 167 EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3)); 168 } 169 170 TEST(RegionBranchOpInterface, RecursiveRegions) { 171 const char *ir = R"MLIR( 172 "cftest.loop_regions_op"() ( 173 {"cftest.dummy_op"() : () -> ()}, // op1 174 {"cftest.dummy_op"() : () -> ()}, // op2 175 {"cftest.dummy_op"() : () -> ()} // op3 176 ) : () -> () 177 )MLIR"; 178 179 DialectRegistry registry; 180 registry.insert<CFTestDialect>(); 181 MLIRContext ctx(registry); 182 183 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 184 Operation *testOp = &module->getBody()->getOperations().front(); 185 auto regionOp = cast<RegionBranchOpInterface>(testOp); 186 Operation *op1 = &testOp->getRegion(0).front().front(); 187 Operation *op2 = &testOp->getRegion(1).front().front(); 188 Operation *op3 = &testOp->getRegion(2).front().front(); 189 190 EXPECT_TRUE(regionOp.isRepetitiveRegion(0)); 191 EXPECT_TRUE(regionOp.isRepetitiveRegion(1)); 192 EXPECT_TRUE(regionOp.isRepetitiveRegion(2)); 193 EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr); 194 EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr); 195 EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr); 196 } 197 198 TEST(RegionBranchOpInterface, NotRecursiveRegions) { 199 const char *ir = R"MLIR( 200 "cftest.sequential_regions_op"() ( 201 {"cftest.dummy_op"() : () -> ()}, // op1 202 {"cftest.dummy_op"() : () -> ()} // op2 203 ) : () -> () 204 )MLIR"; 205 206 DialectRegistry registry; 207 registry.insert<CFTestDialect>(); 208 MLIRContext ctx(registry); 209 210 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); 211 Operation *testOp = &module->getBody()->getOperations().front(); 212 Operation *op1 = &testOp->getRegion(0).front().front(); 213 Operation *op2 = &testOp->getRegion(1).front().front(); 214 215 EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr); 216 EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr); 217 } 218