xref: /llvm-project/mlir/unittests/IR/OperationSupportTest.cpp (revision e95e94adc6bb748de015ac3053e7f0786b65f351)
1 //===- OperationSupportTest.cpp - Operation support unit tests ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/OperationSupport.h"
10 #include "../../test/lib/Dialect/Test/TestDialect.h"
11 #include "../../test/lib/Dialect/Test/TestOps.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "llvm/ADT/BitVector.h"
15 #include "llvm/Support/FormatVariadic.h"
16 #include "gtest/gtest.h"
17 
18 using namespace mlir;
19 using namespace mlir::detail;
20 
createOp(MLIRContext * context,ArrayRef<Value> operands=std::nullopt,ArrayRef<Type> resultTypes=std::nullopt,unsigned int numRegions=0)21 static Operation *createOp(MLIRContext *context,
22                            ArrayRef<Value> operands = std::nullopt,
23                            ArrayRef<Type> resultTypes = std::nullopt,
24                            unsigned int numRegions = 0) {
25   context->allowUnregisteredDialects();
26   return Operation::create(
27       UnknownLoc::get(context), OperationName("foo.bar", context), resultTypes,
28       operands, std::nullopt, nullptr, std::nullopt, numRegions);
29 }
30 
31 namespace {
TEST(OperandStorageTest, NonResizable) {
  MLIRContext context;
  Builder builder(&context);

  // A producer op whose single i16 result is used as the operand below.
  Type i16 = builder.getIntegerType(16);
  Operation *producer = createOp(&context, /*operands=*/std::nullopt, i16);
  Value value = producer->getResult(0);

  // Create a non-resizable operation with one operand.
  Operation *consumer = createOp(&context, value);

  // Re-setting the same number of operands keeps the count unchanged.
  consumer->setOperands(value);
  EXPECT_EQ(consumer->getNumOperands(), 1u);

  // Dropping every operand is also allowed.
  consumer->setOperands(std::nullopt);
  EXPECT_EQ(consumer->getNumOperands(), 0u);

  // Destroy the user first so the producer's result is no longer referenced
  // when the producer itself is destroyed.
  consumer->destroy();
  producer->destroy();
}
55 
TEST(OperandStorageTest, Resizable) {
  MLIRContext context;
  Builder builder(&context);

  // Producer of the value that every operand slot below refers to.
  Type i16 = builder.getIntegerType(16);
  Operation *producer = createOp(&context, /*operands=*/std::nullopt, i16);
  Value value = producer->getResult(0);

  // Create a resizable operation with one operand.
  Operation *consumer = createOp(&context, value);

  // Same count: one operand in, one operand out.
  consumer->setOperands(value);
  EXPECT_EQ(consumer->getNumOperands(), 1u);

  // Shrinking down to zero operands.
  consumer->setOperands(std::nullopt);
  EXPECT_EQ(consumer->getNumOperands(), 0u);

  // Growing past the original operand count.
  consumer->setOperands({value, value, value});
  EXPECT_EQ(consumer->getNumOperands(), 3u);

  // Destroy the user first so the producer's result is no longer referenced
  // when the producer itself is destroyed.
  consumer->destroy();
  producer->destroy();
}
83 
TEST(OperandStorageTest, RangeReplace) {
  MLIRContext context;
  Builder builder(&context);

  Type i16 = builder.getIntegerType(16);
  Operation *producer = createOp(&context, /*operands=*/std::nullopt, i16);
  Value value = producer->getResult(0);

  // Create a resizable operation with one operand.
  Operation *consumer = createOp(&context, value);

  // Replace [0, 1) with one value: the operand count is unchanged.
  consumer->setOperands(/*start=*/0, /*length=*/1, value);
  EXPECT_EQ(consumer->getNumOperands(), 1u);

  // Replace [0, 1) with three values: the operand list grows to 3.
  consumer->setOperands(/*start=*/0, /*length=*/1, {value, value, value});
  EXPECT_EQ(consumer->getNumOperands(), 3u);

  // Replace [1, 3) with a single value: fewer operands, the list shrinks to 2.
  consumer->setOperands(/*start=*/1, /*length=*/2, {value});
  EXPECT_EQ(consumer->getNumOperands(), 2u);

  // A zero-length range at index 2 inserts without replacing anything.
  consumer->setOperands(/*start=*/2, /*length=*/0, {value});
  EXPECT_EQ(consumer->getNumOperands(), 3u);

  // Replacing [0, 3) with nothing erases every operand.
  consumer->setOperands(/*start=*/0, /*length=*/3, {});
  EXPECT_EQ(consumer->getNumOperands(), 0u);

  // Destroy the user before the producer of its operand value.
  consumer->destroy();
  producer->destroy();
}
119 
TEST(OperandStorageTest, MutableRange) {
  MLIRContext context;
  Builder builder(&context);

  Type i16 = builder.getIntegerType(16);
  Operation *producer = createOp(&context, /*operands=*/std::nullopt, i16);
  Value value = producer->getResult(0);

  // Create a resizable operation with one operand.
  Operation *consumer = createOp(&context, value);

  // A mutable range over `consumer`'s operands. Every edit made through it is
  // checked against both the range's size and the operation's operand count.
  MutableOperandRange range(consumer);

  // Assign the same number of operands.
  range.assign(value);
  EXPECT_EQ(range.size(), 1u);
  EXPECT_EQ(consumer->getNumOperands(), 1u);

  // Assign more operands than the range currently holds.
  range.assign({value, value, value});
  EXPECT_EQ(range.size(), 3u);
  EXPECT_EQ(consumer->getNumOperands(), 3u);

  // Append two further operands at the end of the range.
  range.append({value, value});
  EXPECT_EQ(range.size(), 5u);
  EXPECT_EQ(consumer->getNumOperands(), 5u);

  // Clearing the range erases every operand.
  range.clear();
  EXPECT_EQ(range.size(), 0u);
  EXPECT_EQ(consumer->getNumOperands(), 0u);

  // Destroy the user before the producer of its operand value.
  consumer->destroy();
  producer->destroy();
}
156 
TEST(OperandStorageTest, RangeErase) {
  MLIRContext context;
  Builder builder(&context);

  // A producer with two `none`-typed results, used as two distinct operands.
  Type noneTy = builder.getNoneType();
  Operation *producer =
      createOp(&context, /*operands=*/std::nullopt, {noneTy, noneTy});
  Value first = producer->getResult(0);
  Value second = producer->getResult(1);

  // The operand list starts as [second, first, second, first].
  Operation *consumer = createOp(&context, {second, first, second, first});
  BitVector toErase(consumer->getNumOperands());

  // An all-clear mask erases nothing.
  consumer->eraseOperands(toErase);
  EXPECT_EQ(consumer->getNumOperands(), 4u);

  // Erase the first and last operands; the middle two survive, in order.
  toErase.set(0);
  toErase.set(3);
  consumer->eraseOperands(toErase);
  EXPECT_EQ(consumer->getNumOperands(), 2u);
  EXPECT_EQ(consumer->getOperand(0), first);
  EXPECT_EQ(consumer->getOperand(1), second);

  // Destroy the user before the producer of its operand values.
  consumer->destroy();
  producer->destroy();
}
188 
TEST(OperationOrderTest, OrderIsAlwaysValid) {
  MLIRContext context;

  // A container op with one region holding a single block. Every op created
  // below is inserted into that block, and only the container is destroyed at
  // the end of the test.
  Operation *containerOp = createOp(&context, /*operands=*/std::nullopt,
                                    /*resultTypes=*/std::nullopt,
                                    /*numRegions=*/1);
  Region &region = containerOp->getRegion(0);
  Block *block = new Block();
  region.push_back(block);

  // Insert two operations, then repeatedly add more operations between them
  // (each new op goes directly before `backOp`). Eventually we will insert
  // more than kOrderStride operations and the block order will need to be
  // recomputed.
  Operation *frontOp = createOp(&context);
  Operation *backOp = createOp(&context);
  block->push_back(frontOp);
  block->push_back(backOp);

  // Chosen to be larger than Operation::kOrderStride.
  constexpr int kNumOpsToInsert = 10;
  for (int i = 0; i < kNumOpsToInsert; ++i) {
    Operation *op = createOp(&context);
    block->getOperations().insert(backOp->getIterator(), op);
    ASSERT_TRUE(op->isBeforeInBlock(backOp));
    // Note verifyOpOrder() returns false if the order is valid.
    ASSERT_FALSE(block->verifyOpOrder());
  }

  containerOp->destroy();
}
220 
TEST(OperationFormatPrintTest, CanUseVariadicFormat) {
  MLIRContext context;
  Operation *op = createOp(&context);

  // Formatting `*op` through `formatv` should produce the op's generic
  // printed form.
  std::string str = formatv("{0}", *op).str();
  // Use EXPECT (non-fatal) rather than ASSERT so that `op` is still destroyed
  // when the comparison fails; a fatal assertion would return early and leak.
  EXPECT_EQ(str, "\"foo.bar\"() : () -> ()");

  op->destroy();
}
232 
TEST(NamedAttrListTest, TestAppendAssign) {
  MLIRContext ctx;
  NamedAttrList attrs;
  Builder b(&ctx);

  // Append one attribute named by a StringAttr and one named by a string.
  attrs.append(b.getStringAttr("foo"), b.getStringAttr("bar"));
  attrs.append("baz", b.getStringAttr("boo"));

  // Check the size before walking the iterator. If an append went wrong,
  // dereferencing past the end would be UB rather than a clean test failure.
  ASSERT_EQ(attrs.getAttrs().size(), 2u);
  {
    auto *it = attrs.begin();
    EXPECT_EQ(it->getName(), b.getStringAttr("foo"));
    EXPECT_EQ(it->getValue(), b.getStringAttr("bar"));
    ++it;
    EXPECT_EQ(it->getName(), b.getStringAttr("baz"));
    EXPECT_EQ(it->getValue(), b.getStringAttr("boo"));
  }

  // Appending a second "foo" must be reported as a duplicate of that name.
  attrs.append("foo", b.getStringAttr("zoo"));
  {
    auto dup = attrs.findDuplicate();
    ASSERT_TRUE(dup.has_value());
    EXPECT_EQ(dup->getName(), b.getStringAttr("foo"));
  }

  // assign() replaces the whole contents, so the duplicate is gone.
  SmallVector<NamedAttribute> newAttrs = {
      b.getNamedAttr("foo", b.getStringAttr("f")),
      b.getNamedAttr("zoo", b.getStringAttr("z")),
  };
  attrs.assign(newAttrs);

  auto dup = attrs.findDuplicate();
  ASSERT_FALSE(dup.has_value());

  // Same guard as above before dereferencing the iterator.
  ASSERT_EQ(attrs.getAttrs().size(), 2u);
  {
    auto *it = attrs.begin();
    EXPECT_EQ(it->getName(), b.getStringAttr("foo"));
    EXPECT_EQ(it->getValue(), b.getStringAttr("f"));
    ++it;
    EXPECT_EQ(it->getName(), b.getStringAttr("zoo"));
    EXPECT_EQ(it->getValue(), b.getStringAttr("z"));
  }

  // Assigning an empty list clears everything.
  attrs.assign({});
  ASSERT_TRUE(attrs.empty());
}
277 
TEST(OperandStorageTest, PopulateDefaultAttrs) {
  MLIRContext context;
  context.getOrLoadDialect<test::TestDialect>();

  OpBuilder b(&context);
  auto req1 = b.getI32IntegerAttr(10);
  auto req2 = b.getI32IntegerAttr(60);
  // Verify default attributes populated post op creation.
  Operation *op = b.create<test::OpAttrMatch1>(b.getUnknownLoc(), req1, nullptr,
                                               nullptr, req2);
  auto opt = op->getInherentAttr("default_valued_attr");
  // `EXPECT_NE(opt, nullptr)` on its own passes vacuously when `opt` is an
  // empty optional, because an empty std::optional compares unequal to every
  // value. Check presence first, then check that the attribute is non-null.
  // Both checks are non-fatal so that `op` is always destroyed below.
  EXPECT_TRUE(opt.has_value()) << *op;
  if (opt.has_value()) {
    EXPECT_NE(*opt, nullptr) << *op;
  }

  op->destroy();
}
294 
TEST(OperationEquivalenceTest, HashWorksWithFlags) {
  MLIRContext context;
  context.getOrLoadDialect<test::TestDialect>();

  // Two otherwise identical ops that differ only in location: `op1` keeps the
  // unknown location assigned by createOp, and `op2` is given a NameLoc.
  auto *op1 = createOp(&context);
  auto *op2 = createOp(&context);
  op2->setLoc(NameLoc::get(StringAttr::get(&context, "foo")));

  // Hash an op under `flags`, using ignoreHashValue for both its operands and
  // its results.
  auto hashOf = [](Operation *op, OperationEquivalence::Flags flags) {
    return OperationEquivalence::computeHash(
        op, /*hashOperands=*/OperationEquivalence::ignoreHashValue,
        /*hashResults=*/OperationEquivalence::ignoreHashValue, flags);
  };

  // With locations ignored the hashes agree; with no flags they differ.
  EXPECT_EQ(hashOf(op1, OperationEquivalence::IgnoreLocations),
            hashOf(op2, OperationEquivalence::IgnoreLocations));
  EXPECT_NE(hashOf(op1, OperationEquivalence::None),
            hashOf(op2, OperationEquivalence::None));

  op1->destroy();
  op2->destroy();
}
315 
316 } // namespace
317