xref: /llvm-project/mlir/unittests/IR/OpPropertiesTest.cpp (revision ffc80de8643969ffa0dbbd377c5b33e3a7488f5e)
//===- OpPropertiesTest.cpp - Test all properties-related APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "mlir/IR/OpDefinition.h"
11 #include "mlir/IR/OperationSupport.h"
12 #include "mlir/Parser/Parser.h"
13 #include "gtest/gtest.h"
14 #include <optional>
15 
16 using namespace mlir;
17 
18 namespace {
19 /// Simple structure definining a struct to define "properties" for a given
20 /// operation. Default values are honored when creating an operation.
21 struct TestProperties {
22   int a = -1;
23   float b = -1.;
24   std::vector<int64_t> array = {-33};
25   /// A shared_ptr to a const object is safe: it is equivalent to a value-based
26   /// member. Here the label will be deallocated when the last operation
27   /// referring to it is destroyed. However there is no pool-allocation: this is
28   /// offloaded to the client.
29   std::shared_ptr<const std::string> label;
30   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestProperties)
31 };
32 
33 bool operator==(const TestProperties &lhs, TestProperties &rhs) {
34   return lhs.a == rhs.a && lhs.b == rhs.b && lhs.array == rhs.array &&
35          lhs.label == rhs.label;
36 }
37 
38 /// Convert a DictionaryAttr to a TestProperties struct, optionally emit errors
39 /// through the provided diagnostic if any. This is used for example during
40 /// parsing with the generic format.
41 static LogicalResult
42 setPropertiesFromAttribute(TestProperties &prop, Attribute attr,
43                            function_ref<InFlightDiagnostic()> emitError) {
44   DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
45   if (!dict) {
46     emitError() << "expected DictionaryAttr to set TestProperties";
47     return failure();
48   }
49   auto aAttr = dict.getAs<IntegerAttr>("a");
50   if (!aAttr) {
51     emitError() << "expected IntegerAttr for key `a`";
52     return failure();
53   }
54   auto bAttr = dict.getAs<FloatAttr>("b");
55   if (!bAttr ||
56       &bAttr.getValue().getSemantics() != &llvm::APFloatBase::IEEEsingle()) {
57     emitError() << "expected FloatAttr for key `b`";
58     return failure();
59   }
60 
61   auto arrayAttr = dict.getAs<DenseI64ArrayAttr>("array");
62   if (!arrayAttr) {
63     emitError() << "expected DenseI64ArrayAttr for key `array`";
64     return failure();
65   }
66 
67   auto label = dict.getAs<mlir::StringAttr>("label");
68   if (!label) {
69     emitError() << "expected StringAttr for key `label`";
70     return failure();
71   }
72 
73   prop.a = aAttr.getValue().getSExtValue();
74   prop.b = bAttr.getValue().convertToFloat();
75   prop.array.assign(arrayAttr.asArrayRef().begin(),
76                     arrayAttr.asArrayRef().end());
77   prop.label = std::make_shared<std::string>(label.getValue());
78   return success();
79 }
80 
81 /// Convert a TestProperties struct to a DictionaryAttr, this is used for
82 /// example during printing with the generic format.
83 static Attribute getPropertiesAsAttribute(MLIRContext *ctx,
84                                           const TestProperties &prop) {
85   SmallVector<NamedAttribute> attrs;
86   Builder b{ctx};
87   attrs.push_back(b.getNamedAttr("a", b.getI32IntegerAttr(prop.a)));
88   attrs.push_back(b.getNamedAttr("b", b.getF32FloatAttr(prop.b)));
89   attrs.push_back(b.getNamedAttr("array", b.getDenseI64ArrayAttr(prop.array)));
90   attrs.push_back(b.getNamedAttr(
91       "label", b.getStringAttr(prop.label ? *prop.label : "<nullptr>")));
92   return b.getDictionaryAttr(attrs);
93 }
94 
95 inline llvm::hash_code computeHash(const TestProperties &prop) {
96   // We hash `b` which is a float using its underlying array of char:
97   unsigned char const *p = reinterpret_cast<unsigned char const *>(&prop.b);
98   ArrayRef<unsigned char> bBytes{p, sizeof(prop.b)};
99   return llvm::hash_combine(
100       prop.a, llvm::hash_combine_range(bBytes.begin(), bBytes.end()),
101       llvm::hash_combine_range(prop.array.begin(), prop.array.end()),
102       StringRef(*prop.label));
103 }
104 
105 /// A custom operation for the purpose of showcasing how to use "properties".
106 class OpWithProperties : public Op<OpWithProperties> {
107 public:
108   // Begin boilerplate
109   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithProperties)
110   using Op::Op;
111   static ArrayRef<StringRef> getAttributeNames() { return {}; }
112   static StringRef getOperationName() {
113     return "test_op_properties.op_with_properties";
114   }
115   // End boilerplate
116 
117   // This alias is the only definition needed for enabling "properties" for this
118   // operation.
119   using Properties = TestProperties;
120   static std::optional<mlir::Attribute> getInherentAttr(MLIRContext *context,
121                                                         const Properties &prop,
122                                                         StringRef name) {
123     return std::nullopt;
124   }
125   static void setInherentAttr(Properties &prop, StringRef name,
126                               mlir::Attribute value) {}
127   static void populateInherentAttrs(MLIRContext *context,
128                                     const Properties &prop,
129                                     NamedAttrList &attrs) {}
130   static LogicalResult
131   verifyInherentAttrs(OperationName opName, NamedAttrList &attrs,
132                       function_ref<InFlightDiagnostic()> emitError) {
133     return success();
134   }
135 };
136 
137 /// A custom operation for the purpose of showcasing how discardable attributes
138 /// are handled in absence of properties.
139 class OpWithoutProperties : public Op<OpWithoutProperties> {
140 public:
141   // Begin boilerplate.
142   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithoutProperties)
143   using Op::Op;
144   static ArrayRef<StringRef> getAttributeNames() {
145     static StringRef attributeNames[] = {StringRef("inherent_attr")};
146     return ArrayRef(attributeNames);
147   };
148   static StringRef getOperationName() {
149     return "test_op_properties.op_without_properties";
150   }
151   // End boilerplate.
152 };
153 
154 // A trivial supporting dialect to register the above operation.
155 class TestOpPropertiesDialect : public Dialect {
156 public:
157   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOpPropertiesDialect)
158   static constexpr StringLiteral getDialectNamespace() {
159     return StringLiteral("test_op_properties");
160   }
161   explicit TestOpPropertiesDialect(MLIRContext *context)
162       : Dialect(getDialectNamespace(), context,
163                 TypeID::get<TestOpPropertiesDialect>()) {
164     addOperations<OpWithProperties, OpWithoutProperties>();
165   }
166 };
167 
168 constexpr StringLiteral mlirSrc = R"mlir(
169     "test_op_properties.op_with_properties"()
170       <{a = -42 : i32,
171         b = -4.200000e+01 : f32,
172         array = array<i64: 40, 41>,
173         label = "bar foo"}> : () -> ()
174 )mlir";
175 
176 TEST(OpPropertiesTest, Properties) {
177   MLIRContext context;
178   context.getOrLoadDialect<TestOpPropertiesDialect>();
179   ParserConfig config(&context);
180   // Parse the operation with some properties.
181   OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
182   ASSERT_TRUE(op.get() != nullptr);
183   auto opWithProp = dyn_cast<OpWithProperties>(op.get());
184   ASSERT_TRUE(opWithProp);
185   {
186     std::string output;
187     llvm::raw_string_ostream os(output);
188     opWithProp.print(os);
189     ASSERT_STREQ("\"test_op_properties.op_with_properties\"() "
190                  "<{a = -42 : i32, "
191                  "array = array<i64: 40, 41>, "
192                  "b = -4.200000e+01 : f32, "
193                  "label = \"bar foo\"}> : () -> ()\n",
194                  output.c_str());
195   }
196   // Get a mutable reference to the properties for this operation and modify it
197   // in place one member at a time.
198   TestProperties &prop = opWithProp.getProperties();
199   prop.a = 42;
200   {
201     std::string output;
202     llvm::raw_string_ostream os(output);
203     opWithProp.print(os);
204     StringRef view(output);
205     EXPECT_TRUE(view.contains("a = 42"));
206     EXPECT_TRUE(view.contains("b = -4.200000e+01"));
207     EXPECT_TRUE(view.contains("array = array<i64: 40, 41>"));
208     EXPECT_TRUE(view.contains("label = \"bar foo\""));
209   }
210   prop.b = 42.;
211   {
212     std::string output;
213     llvm::raw_string_ostream os(output);
214     opWithProp.print(os);
215     StringRef view(output);
216     EXPECT_TRUE(view.contains("a = 42"));
217     EXPECT_TRUE(view.contains("b = 4.200000e+01"));
218     EXPECT_TRUE(view.contains("array = array<i64: 40, 41>"));
219     EXPECT_TRUE(view.contains("label = \"bar foo\""));
220   }
221   prop.array.push_back(42);
222   {
223     std::string output;
224     llvm::raw_string_ostream os(output);
225     opWithProp.print(os);
226     StringRef view(output);
227     EXPECT_TRUE(view.contains("a = 42"));
228     EXPECT_TRUE(view.contains("b = 4.200000e+01"));
229     EXPECT_TRUE(view.contains("array = array<i64: 40, 41, 42>"));
230     EXPECT_TRUE(view.contains("label = \"bar foo\""));
231   }
232   prop.label = std::make_shared<std::string>("foo bar");
233   {
234     std::string output;
235     llvm::raw_string_ostream os(output);
236     opWithProp.print(os);
237     StringRef view(output);
238     EXPECT_TRUE(view.contains("a = 42"));
239     EXPECT_TRUE(view.contains("b = 4.200000e+01"));
240     EXPECT_TRUE(view.contains("array = array<i64: 40, 41, 42>"));
241     EXPECT_TRUE(view.contains("label = \"foo bar\""));
242   }
243 }
244 
245 // Test diagnostic emission when using invalid dictionary.
246 TEST(OpPropertiesTest, FailedProperties) {
247   MLIRContext context;
248   context.getOrLoadDialect<TestOpPropertiesDialect>();
249   std::string diagnosticStr;
250   context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
251     diagnosticStr += diag.str();
252     return success();
253   });
254 
255   // Parse the operation with some properties.
256   ParserConfig config(&context);
257 
258   // Parse an operation with invalid (incomplete) properties.
259   OwningOpRef<Operation *> owningOp =
260       parseSourceString("\"test_op_properties.op_with_properties\"() "
261                         "<{a = -42 : i32}> : () -> ()\n",
262                         config);
263   ASSERT_EQ(owningOp.get(), nullptr);
264   EXPECT_STREQ(
265       "invalid properties {a = -42 : i32} for op "
266       "test_op_properties.op_with_properties: expected FloatAttr for key `b`",
267       diagnosticStr.c_str());
268   diagnosticStr.clear();
269 
270   owningOp = parseSourceString(mlirSrc, config);
271   Operation *op = owningOp.get();
272   ASSERT_TRUE(op != nullptr);
273   Location loc = op->getLoc();
274   auto opWithProp = dyn_cast<OpWithProperties>(op);
275   ASSERT_TRUE(opWithProp);
276 
277   OperationState state(loc, op->getName());
278   Builder b{&context};
279   NamedAttrList attrs;
280   attrs.push_back(b.getNamedAttr("a", b.getStringAttr("foo")));
281   state.propertiesAttr = attrs.getDictionary(&context);
282   {
283     auto emitError = [&]() {
284       return op->emitError("setting properties failed: ");
285     };
286     auto result = state.setProperties(op, emitError);
287     EXPECT_TRUE(result.failed());
288   }
289   EXPECT_STREQ("setting properties failed: expected IntegerAttr for key `a`",
290                diagnosticStr.c_str());
291 }
292 
293 TEST(OpPropertiesTest, DefaultValues) {
294   MLIRContext context;
295   context.getOrLoadDialect<TestOpPropertiesDialect>();
296   OperationState state(UnknownLoc::get(&context),
297                        "test_op_properties.op_with_properties");
298   Operation *op = Operation::create(state);
299   ASSERT_TRUE(op != nullptr);
300   {
301     std::string output;
302     llvm::raw_string_ostream os(output);
303     op->print(os);
304     StringRef view(output);
305     EXPECT_TRUE(view.contains("a = -1"));
306     EXPECT_TRUE(view.contains("b = -1"));
307     EXPECT_TRUE(view.contains("array = array<i64: -33>"));
308   }
309   op->erase();
310 }
311 
312 TEST(OpPropertiesTest, Cloning) {
313   MLIRContext context;
314   context.getOrLoadDialect<TestOpPropertiesDialect>();
315   ParserConfig config(&context);
316   // Parse the operation with some properties.
317   OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
318   ASSERT_TRUE(op.get() != nullptr);
319   auto opWithProp = dyn_cast<OpWithProperties>(op.get());
320   ASSERT_TRUE(opWithProp);
321   Operation *clone = opWithProp->clone();
322 
323   // Check that op and its clone prints equally
324   std::string opStr;
325   std::string cloneStr;
326   {
327     llvm::raw_string_ostream os(opStr);
328     op.get()->print(os);
329   }
330   {
331     llvm::raw_string_ostream os(cloneStr);
332     clone->print(os);
333   }
334   clone->erase();
335   EXPECT_STREQ(opStr.c_str(), cloneStr.c_str());
336 }
337 
338 TEST(OpPropertiesTest, Equivalence) {
339   MLIRContext context;
340   context.getOrLoadDialect<TestOpPropertiesDialect>();
341   ParserConfig config(&context);
342   // Parse the operation with some properties.
343   OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
344   ASSERT_TRUE(op.get() != nullptr);
345   auto opWithProp = dyn_cast<OpWithProperties>(op.get());
346   ASSERT_TRUE(opWithProp);
347   llvm::hash_code reference = OperationEquivalence::computeHash(opWithProp);
348   TestProperties &prop = opWithProp.getProperties();
349   prop.a = 42;
350   EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
351   prop.a = -42;
352   EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
353   prop.b = 42.;
354   EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
355   prop.b = -42.;
356   EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
357   prop.array.push_back(42);
358   EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
359   prop.array.pop_back();
360   EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
361 }
362 
363 TEST(OpPropertiesTest, getOrAddProperties) {
364   MLIRContext context;
365   context.getOrLoadDialect<TestOpPropertiesDialect>();
366   OperationState state(UnknownLoc::get(&context),
367                        "test_op_properties.op_with_properties");
368   // Test `getOrAddProperties` API on OperationState.
369   TestProperties &prop = state.getOrAddProperties<TestProperties>();
370   prop.a = 1;
371   prop.b = 2;
372   prop.array = {3, 4, 5};
373   Operation *op = Operation::create(state);
374   ASSERT_TRUE(op != nullptr);
375   {
376     std::string output;
377     llvm::raw_string_ostream os(output);
378     op->print(os);
379     StringRef view(output);
380     EXPECT_TRUE(view.contains("a = 1"));
381     EXPECT_TRUE(view.contains("b = 2"));
382     EXPECT_TRUE(view.contains("array = array<i64: 3, 4, 5>"));
383   }
384   op->erase();
385 }
386 
387 constexpr StringLiteral withoutPropertiesAttrsSrc = R"mlir(
388     "test_op_properties.op_without_properties"()
389       {inherent_attr = 42, other_attr = 56} : () -> ()
390 )mlir";
391 
392 TEST(OpPropertiesTest, withoutPropertiesDiscardableAttrs) {
393   MLIRContext context;
394   context.getOrLoadDialect<TestOpPropertiesDialect>();
395   ParserConfig config(&context);
396   OwningOpRef<Operation *> op =
397       parseSourceString(withoutPropertiesAttrsSrc, config);
398   ASSERT_EQ(llvm::range_size(op->getDiscardableAttrs()), 1u);
399   EXPECT_EQ(op->getDiscardableAttrs().begin()->getName().getValue(),
400             "other_attr");
401 
402   EXPECT_EQ(op->getAttrs().size(), 2u);
403   EXPECT_TRUE(op->getInherentAttr("inherent_attr") != std::nullopt);
404   EXPECT_TRUE(op->getDiscardableAttr("other_attr") != Attribute());
405 
406   std::string output;
407   llvm::raw_string_ostream os(output);
408   op->print(os);
409   StringRef view(output);
410   EXPECT_TRUE(view.contains("inherent_attr = 42"));
411   EXPECT_TRUE(view.contains("other_attr = 56"));
412 
413   OwningOpRef<Operation *> reparsed = parseSourceString(os.str(), config);
414   auto trivialHash = [](Value v) { return hash_value(v); };
415   auto hash = [&](Operation *operation) {
416     return OperationEquivalence::computeHash(
417         operation, trivialHash, trivialHash,
418         OperationEquivalence::Flags::IgnoreLocations);
419   };
420   EXPECT_TRUE(hash(op.get()) == hash(reparsed.get()));
421 }
422 
423 } // namespace
424