1 //===- TestOpProperties.cpp - Test all properties-related APIs ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Attributes.h" 10 #include "mlir/IR/OpDefinition.h" 11 #include "mlir/IR/OperationSupport.h" 12 #include "mlir/Parser/Parser.h" 13 #include "gtest/gtest.h" 14 #include <optional> 15 16 using namespace mlir; 17 18 namespace { 19 /// Simple structure definining a struct to define "properties" for a given 20 /// operation. Default values are honored when creating an operation. 21 struct TestProperties { 22 int a = -1; 23 float b = -1.; 24 std::vector<int64_t> array = {-33}; 25 /// A shared_ptr to a const object is safe: it is equivalent to a value-based 26 /// member. Here the label will be deallocated when the last operation 27 /// referring to it is destroyed. However there is no pool-allocation: this is 28 /// offloaded to the client. 29 std::shared_ptr<const std::string> label; 30 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestProperties) 31 }; 32 33 bool operator==(const TestProperties &lhs, TestProperties &rhs) { 34 return lhs.a == rhs.a && lhs.b == rhs.b && lhs.array == rhs.array && 35 lhs.label == rhs.label; 36 } 37 38 /// Convert a DictionaryAttr to a TestProperties struct, optionally emit errors 39 /// through the provided diagnostic if any. This is used for example during 40 /// parsing with the generic format. 41 static LogicalResult 42 setPropertiesFromAttribute(TestProperties &prop, Attribute attr, 43 function_ref<InFlightDiagnostic()> emitError) { 44 DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr); 45 if (!dict) { 46 emitError() << "expected DictionaryAttr to set TestProperties"; 47 return failure(); 48 } 49 auto aAttr = dict.getAs<IntegerAttr>("a"); 50 if (!aAttr) { 51 emitError() << "expected IntegerAttr for key `a`"; 52 return failure(); 53 } 54 auto bAttr = dict.getAs<FloatAttr>("b"); 55 if (!bAttr || 56 &bAttr.getValue().getSemantics() != &llvm::APFloatBase::IEEEsingle()) { 57 emitError() << "expected FloatAttr for key `b`"; 58 return failure(); 59 } 60 61 auto arrayAttr = dict.getAs<DenseI64ArrayAttr>("array"); 62 if (!arrayAttr) { 63 emitError() << "expected DenseI64ArrayAttr for key `array`"; 64 return failure(); 65 } 66 67 auto label = dict.getAs<mlir::StringAttr>("label"); 68 if (!label) { 69 emitError() << "expected StringAttr for key `label`"; 70 return failure(); 71 } 72 73 prop.a = aAttr.getValue().getSExtValue(); 74 prop.b = bAttr.getValue().convertToFloat(); 75 prop.array.assign(arrayAttr.asArrayRef().begin(), 76 arrayAttr.asArrayRef().end()); 77 prop.label = std::make_shared<std::string>(label.getValue()); 78 return success(); 79 } 80 81 /// Convert a TestProperties struct to a DictionaryAttr, this is used for 82 /// example during printing with the generic format. 83 static Attribute getPropertiesAsAttribute(MLIRContext *ctx, 84 const TestProperties &prop) { 85 SmallVector<NamedAttribute> attrs; 86 Builder b{ctx}; 87 attrs.push_back(b.getNamedAttr("a", b.getI32IntegerAttr(prop.a))); 88 attrs.push_back(b.getNamedAttr("b", b.getF32FloatAttr(prop.b))); 89 attrs.push_back(b.getNamedAttr("array", b.getDenseI64ArrayAttr(prop.array))); 90 attrs.push_back(b.getNamedAttr( 91 "label", b.getStringAttr(prop.label ? *prop.label : "<nullptr>"))); 92 return b.getDictionaryAttr(attrs); 93 } 94 95 inline llvm::hash_code computeHash(const TestProperties &prop) { 96 // We hash `b` which is a float using its underlying array of char: 97 unsigned char const *p = reinterpret_cast<unsigned char const *>(&prop.b); 98 ArrayRef<unsigned char> bBytes{p, sizeof(prop.b)}; 99 return llvm::hash_combine( 100 prop.a, llvm::hash_combine_range(bBytes.begin(), bBytes.end()), 101 llvm::hash_combine_range(prop.array.begin(), prop.array.end()), 102 StringRef(*prop.label)); 103 } 104 105 /// A custom operation for the purpose of showcasing how to use "properties". 106 class OpWithProperties : public Op<OpWithProperties> { 107 public: 108 // Begin boilerplate 109 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithProperties) 110 using Op::Op; 111 static ArrayRef<StringRef> getAttributeNames() { return {}; } 112 static StringRef getOperationName() { 113 return "test_op_properties.op_with_properties"; 114 } 115 // End boilerplate 116 117 // This alias is the only definition needed for enabling "properties" for this 118 // operation. 119 using Properties = TestProperties; 120 static std::optional<mlir::Attribute> getInherentAttr(MLIRContext *context, 121 const Properties &prop, 122 StringRef name) { 123 return std::nullopt; 124 } 125 static void setInherentAttr(Properties &prop, StringRef name, 126 mlir::Attribute value) {} 127 static void populateInherentAttrs(MLIRContext *context, 128 const Properties &prop, 129 NamedAttrList &attrs) {} 130 static LogicalResult 131 verifyInherentAttrs(OperationName opName, NamedAttrList &attrs, 132 function_ref<InFlightDiagnostic()> emitError) { 133 return success(); 134 } 135 }; 136 137 /// A custom operation for the purpose of showcasing how discardable attributes 138 /// are handled in absence of properties. 139 class OpWithoutProperties : public Op<OpWithoutProperties> { 140 public: 141 // Begin boilerplate. 142 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithoutProperties) 143 using Op::Op; 144 static ArrayRef<StringRef> getAttributeNames() { 145 static StringRef attributeNames[] = {StringRef("inherent_attr")}; 146 return ArrayRef(attributeNames); 147 }; 148 static StringRef getOperationName() { 149 return "test_op_properties.op_without_properties"; 150 } 151 // End boilerplate. 152 }; 153 154 // A trivial supporting dialect to register the above operation. 155 class TestOpPropertiesDialect : public Dialect { 156 public: 157 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOpPropertiesDialect) 158 static constexpr StringLiteral getDialectNamespace() { 159 return StringLiteral("test_op_properties"); 160 } 161 explicit TestOpPropertiesDialect(MLIRContext *context) 162 : Dialect(getDialectNamespace(), context, 163 TypeID::get<TestOpPropertiesDialect>()) { 164 addOperations<OpWithProperties, OpWithoutProperties>(); 165 } 166 }; 167 168 constexpr StringLiteral mlirSrc = R"mlir( 169 "test_op_properties.op_with_properties"() 170 <{a = -42 : i32, 171 b = -4.200000e+01 : f32, 172 array = array<i64: 40, 41>, 173 label = "bar foo"}> : () -> () 174 )mlir"; 175 176 TEST(OpPropertiesTest, Properties) { 177 MLIRContext context; 178 context.getOrLoadDialect<TestOpPropertiesDialect>(); 179 ParserConfig config(&context); 180 // Parse the operation with some properties. 181 OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config); 182 ASSERT_TRUE(op.get() != nullptr); 183 auto opWithProp = dyn_cast<OpWithProperties>(op.get()); 184 ASSERT_TRUE(opWithProp); 185 { 186 std::string output; 187 llvm::raw_string_ostream os(output); 188 opWithProp.print(os); 189 ASSERT_STREQ("\"test_op_properties.op_with_properties\"() " 190 "<{a = -42 : i32, " 191 "array = array<i64: 40, 41>, " 192 "b = -4.200000e+01 : f32, " 193 "label = \"bar foo\"}> : () -> ()\n", 194 output.c_str()); 195 } 196 // Get a mutable reference to the properties for this operation and modify it 197 // in place one member at a time. 198 TestProperties &prop = opWithProp.getProperties(); 199 prop.a = 42; 200 { 201 std::string output; 202 llvm::raw_string_ostream os(output); 203 opWithProp.print(os); 204 StringRef view(output); 205 EXPECT_TRUE(view.contains("a = 42")); 206 EXPECT_TRUE(view.contains("b = -4.200000e+01")); 207 EXPECT_TRUE(view.contains("array = array<i64: 40, 41>")); 208 EXPECT_TRUE(view.contains("label = \"bar foo\"")); 209 } 210 prop.b = 42.; 211 { 212 std::string output; 213 llvm::raw_string_ostream os(output); 214 opWithProp.print(os); 215 StringRef view(output); 216 EXPECT_TRUE(view.contains("a = 42")); 217 EXPECT_TRUE(view.contains("b = 4.200000e+01")); 218 EXPECT_TRUE(view.contains("array = array<i64: 40, 41>")); 219 EXPECT_TRUE(view.contains("label = \"bar foo\"")); 220 } 221 prop.array.push_back(42); 222 { 223 std::string output; 224 llvm::raw_string_ostream os(output); 225 opWithProp.print(os); 226 StringRef view(output); 227 EXPECT_TRUE(view.contains("a = 42")); 228 EXPECT_TRUE(view.contains("b = 4.200000e+01")); 229 EXPECT_TRUE(view.contains("array = array<i64: 40, 41, 42>")); 230 EXPECT_TRUE(view.contains("label = \"bar foo\"")); 231 } 232 prop.label = std::make_shared<std::string>("foo bar"); 233 { 234 std::string output; 235 llvm::raw_string_ostream os(output); 236 opWithProp.print(os); 237 StringRef view(output); 238 EXPECT_TRUE(view.contains("a = 42")); 239 EXPECT_TRUE(view.contains("b = 4.200000e+01")); 240 EXPECT_TRUE(view.contains("array = array<i64: 40, 41, 42>")); 241 EXPECT_TRUE(view.contains("label = \"foo bar\"")); 242 } 243 } 244 245 // Test diagnostic emission when using invalid dictionary. 246 TEST(OpPropertiesTest, FailedProperties) { 247 MLIRContext context; 248 context.getOrLoadDialect<TestOpPropertiesDialect>(); 249 std::string diagnosticStr; 250 context.getDiagEngine().registerHandler([&](Diagnostic &diag) { 251 diagnosticStr += diag.str(); 252 return success(); 253 }); 254 255 // Parse the operation with some properties. 256 ParserConfig config(&context); 257 258 // Parse an operation with invalid (incomplete) properties. 259 OwningOpRef<Operation *> owningOp = 260 parseSourceString("\"test_op_properties.op_with_properties\"() " 261 "<{a = -42 : i32}> : () -> ()\n", 262 config); 263 ASSERT_EQ(owningOp.get(), nullptr); 264 EXPECT_STREQ( 265 "invalid properties {a = -42 : i32} for op " 266 "test_op_properties.op_with_properties: expected FloatAttr for key `b`", 267 diagnosticStr.c_str()); 268 diagnosticStr.clear(); 269 270 owningOp = parseSourceString(mlirSrc, config); 271 Operation *op = owningOp.get(); 272 ASSERT_TRUE(op != nullptr); 273 Location loc = op->getLoc(); 274 auto opWithProp = dyn_cast<OpWithProperties>(op); 275 ASSERT_TRUE(opWithProp); 276 277 OperationState state(loc, op->getName()); 278 Builder b{&context}; 279 NamedAttrList attrs; 280 attrs.push_back(b.getNamedAttr("a", b.getStringAttr("foo"))); 281 state.propertiesAttr = attrs.getDictionary(&context); 282 { 283 auto emitError = [&]() { 284 return op->emitError("setting properties failed: "); 285 }; 286 auto result = state.setProperties(op, emitError); 287 EXPECT_TRUE(result.failed()); 288 } 289 EXPECT_STREQ("setting properties failed: expected IntegerAttr for key `a`", 290 diagnosticStr.c_str()); 291 } 292 293 TEST(OpPropertiesTest, DefaultValues) { 294 MLIRContext context; 295 context.getOrLoadDialect<TestOpPropertiesDialect>(); 296 OperationState state(UnknownLoc::get(&context), 297 "test_op_properties.op_with_properties"); 298 Operation *op = Operation::create(state); 299 ASSERT_TRUE(op != nullptr); 300 { 301 std::string output; 302 llvm::raw_string_ostream os(output); 303 op->print(os); 304 StringRef view(output); 305 EXPECT_TRUE(view.contains("a = -1")); 306 EXPECT_TRUE(view.contains("b = -1")); 307 EXPECT_TRUE(view.contains("array = array<i64: -33>")); 308 } 309 op->erase(); 310 } 311 312 TEST(OpPropertiesTest, Cloning) { 313 MLIRContext context; 314 context.getOrLoadDialect<TestOpPropertiesDialect>(); 315 ParserConfig config(&context); 316 // Parse the operation with some properties. 317 OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config); 318 ASSERT_TRUE(op.get() != nullptr); 319 auto opWithProp = dyn_cast<OpWithProperties>(op.get()); 320 ASSERT_TRUE(opWithProp); 321 Operation *clone = opWithProp->clone(); 322 323 // Check that op and its clone prints equally 324 std::string opStr; 325 std::string cloneStr; 326 { 327 llvm::raw_string_ostream os(opStr); 328 op.get()->print(os); 329 } 330 { 331 llvm::raw_string_ostream os(cloneStr); 332 clone->print(os); 333 } 334 clone->erase(); 335 EXPECT_STREQ(opStr.c_str(), cloneStr.c_str()); 336 } 337 338 TEST(OpPropertiesTest, Equivalence) { 339 MLIRContext context; 340 context.getOrLoadDialect<TestOpPropertiesDialect>(); 341 ParserConfig config(&context); 342 // Parse the operation with some properties. 343 OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config); 344 ASSERT_TRUE(op.get() != nullptr); 345 auto opWithProp = dyn_cast<OpWithProperties>(op.get()); 346 ASSERT_TRUE(opWithProp); 347 llvm::hash_code reference = OperationEquivalence::computeHash(opWithProp); 348 TestProperties &prop = opWithProp.getProperties(); 349 prop.a = 42; 350 EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp)); 351 prop.a = -42; 352 EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp)); 353 prop.b = 42.; 354 EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp)); 355 prop.b = -42.; 356 EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp)); 357 prop.array.push_back(42); 358 EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp)); 359 prop.array.pop_back(); 360 EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp)); 361 } 362 363 TEST(OpPropertiesTest, getOrAddProperties) { 364 MLIRContext context; 365 context.getOrLoadDialect<TestOpPropertiesDialect>(); 366 OperationState state(UnknownLoc::get(&context), 367 "test_op_properties.op_with_properties"); 368 // Test `getOrAddProperties` API on OperationState. 369 TestProperties &prop = state.getOrAddProperties<TestProperties>(); 370 prop.a = 1; 371 prop.b = 2; 372 prop.array = {3, 4, 5}; 373 Operation *op = Operation::create(state); 374 ASSERT_TRUE(op != nullptr); 375 { 376 std::string output; 377 llvm::raw_string_ostream os(output); 378 op->print(os); 379 StringRef view(output); 380 EXPECT_TRUE(view.contains("a = 1")); 381 EXPECT_TRUE(view.contains("b = 2")); 382 EXPECT_TRUE(view.contains("array = array<i64: 3, 4, 5>")); 383 } 384 op->erase(); 385 } 386 387 constexpr StringLiteral withoutPropertiesAttrsSrc = R"mlir( 388 "test_op_properties.op_without_properties"() 389 {inherent_attr = 42, other_attr = 56} : () -> () 390 )mlir"; 391 392 TEST(OpPropertiesTest, withoutPropertiesDiscardableAttrs) { 393 MLIRContext context; 394 context.getOrLoadDialect<TestOpPropertiesDialect>(); 395 ParserConfig config(&context); 396 OwningOpRef<Operation *> op = 397 parseSourceString(withoutPropertiesAttrsSrc, config); 398 ASSERT_EQ(llvm::range_size(op->getDiscardableAttrs()), 1u); 399 EXPECT_EQ(op->getDiscardableAttrs().begin()->getName().getValue(), 400 "other_attr"); 401 402 EXPECT_EQ(op->getAttrs().size(), 2u); 403 EXPECT_TRUE(op->getInherentAttr("inherent_attr") != std::nullopt); 404 EXPECT_TRUE(op->getDiscardableAttr("other_attr") != Attribute()); 405 406 std::string output; 407 llvm::raw_string_ostream os(output); 408 op->print(os); 409 StringRef view(output); 410 EXPECT_TRUE(view.contains("inherent_attr = 42")); 411 EXPECT_TRUE(view.contains("other_attr = 56")); 412 413 OwningOpRef<Operation *> reparsed = parseSourceString(os.str(), config); 414 auto trivialHash = [](Value v) { return hash_value(v); }; 415 auto hash = [&](Operation *operation) { 416 return OperationEquivalence::computeHash( 417 operation, trivialHash, trivialHash, 418 OperationEquivalence::Flags::IgnoreLocations); 419 }; 420 EXPECT_TRUE(hash(op.get()) == hash(reparsed.get())); 421 } 422 423 } // namespace 424