1 //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/BuiltinAttributes.h" 10 #include "mlir-c/Support.h" 11 #include "mlir/CAPI/AffineMap.h" 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/IntegerSet.h" 14 #include "mlir/CAPI/Support.h" 15 #include "mlir/IR/AsmState.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/BuiltinAttributes.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 20 using namespace mlir; 21 22 MlirAttribute mlirAttributeGetNull() { return {nullptr}; } 23 24 //===----------------------------------------------------------------------===// 25 // Location attribute. 26 //===----------------------------------------------------------------------===// 27 28 bool mlirAttributeIsALocation(MlirAttribute attr) { 29 return llvm::isa<LocationAttr>(unwrap(attr)); 30 } 31 32 //===----------------------------------------------------------------------===// 33 // Affine map attribute. 34 //===----------------------------------------------------------------------===// 35 36 bool mlirAttributeIsAAffineMap(MlirAttribute attr) { 37 return llvm::isa<AffineMapAttr>(unwrap(attr)); 38 } 39 40 MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { 41 return wrap(AffineMapAttr::get(unwrap(map))); 42 } 43 44 MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { 45 return wrap(llvm::cast<AffineMapAttr>(unwrap(attr)).getValue()); 46 } 47 48 MlirTypeID mlirAffineMapAttrGetTypeID(void) { 49 return wrap(AffineMapAttr::getTypeID()); 50 } 51 52 //===----------------------------------------------------------------------===// 53 // Array attribute. 54 //===----------------------------------------------------------------------===// 55 56 bool mlirAttributeIsAArray(MlirAttribute attr) { 57 return llvm::isa<ArrayAttr>(unwrap(attr)); 58 } 59 60 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, 61 MlirAttribute const *elements) { 62 SmallVector<Attribute, 8> attrs; 63 return wrap( 64 ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements), 65 elements, attrs))); 66 } 67 68 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { 69 return static_cast<intptr_t>(llvm::cast<ArrayAttr>(unwrap(attr)).size()); 70 } 71 72 MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { 73 return wrap(llvm::cast<ArrayAttr>(unwrap(attr)).getValue()[pos]); 74 } 75 76 MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); } 77 78 //===----------------------------------------------------------------------===// 79 // Dictionary attribute. 80 //===----------------------------------------------------------------------===// 81 82 bool mlirAttributeIsADictionary(MlirAttribute attr) { 83 return llvm::isa<DictionaryAttr>(unwrap(attr)); 84 } 85 86 MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, 87 MlirNamedAttribute const *elements) { 88 SmallVector<NamedAttribute, 8> attributes; 89 attributes.reserve(numElements); 90 for (intptr_t i = 0; i < numElements; ++i) 91 attributes.emplace_back(unwrap(elements[i].name), 92 unwrap(elements[i].attribute)); 93 return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); 94 } 95 96 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { 97 return static_cast<intptr_t>(llvm::cast<DictionaryAttr>(unwrap(attr)).size()); 98 } 99 100 MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, 101 intptr_t pos) { 102 NamedAttribute attribute = 103 llvm::cast<DictionaryAttr>(unwrap(attr)).getValue()[pos]; 104 return {wrap(attribute.getName()), wrap(attribute.getValue())}; 105 } 106 107 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, 108 MlirStringRef name) { 109 return wrap(llvm::cast<DictionaryAttr>(unwrap(attr)).get(unwrap(name))); 110 } 111 112 MlirTypeID mlirDictionaryAttrGetTypeID(void) { 113 return wrap(DictionaryAttr::getTypeID()); 114 } 115 116 //===----------------------------------------------------------------------===// 117 // Floating point attribute. 118 //===----------------------------------------------------------------------===// 119 120 bool mlirAttributeIsAFloat(MlirAttribute attr) { 121 return llvm::isa<FloatAttr>(unwrap(attr)); 122 } 123 124 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, 125 double value) { 126 return wrap(FloatAttr::get(unwrap(type), value)); 127 } 128 129 MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, 130 double value) { 131 return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value)); 132 } 133 134 double mlirFloatAttrGetValueDouble(MlirAttribute attr) { 135 return llvm::cast<FloatAttr>(unwrap(attr)).getValueAsDouble(); 136 } 137 138 MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); } 139 140 //===----------------------------------------------------------------------===// 141 // Integer attribute. 142 //===----------------------------------------------------------------------===// 143 144 bool mlirAttributeIsAInteger(MlirAttribute attr) { 145 return llvm::isa<IntegerAttr>(unwrap(attr)); 146 } 147 148 MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { 149 return wrap(IntegerAttr::get(unwrap(type), value)); 150 } 151 152 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { 153 return llvm::cast<IntegerAttr>(unwrap(attr)).getInt(); 154 } 155 156 int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) { 157 return llvm::cast<IntegerAttr>(unwrap(attr)).getSInt(); 158 } 159 160 uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { 161 return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt(); 162 } 163 164 MlirTypeID mlirIntegerAttrGetTypeID(void) { 165 return wrap(IntegerAttr::getTypeID()); 166 } 167 168 //===----------------------------------------------------------------------===// 169 // Bool attribute. 170 //===----------------------------------------------------------------------===// 171 172 bool mlirAttributeIsABool(MlirAttribute attr) { 173 return llvm::isa<BoolAttr>(unwrap(attr)); 174 } 175 176 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { 177 return wrap(BoolAttr::get(unwrap(ctx), value)); 178 } 179 180 bool mlirBoolAttrGetValue(MlirAttribute attr) { 181 return llvm::cast<BoolAttr>(unwrap(attr)).getValue(); 182 } 183 184 //===----------------------------------------------------------------------===// 185 // Integer set attribute. 186 //===----------------------------------------------------------------------===// 187 188 bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { 189 return llvm::isa<IntegerSetAttr>(unwrap(attr)); 190 } 191 192 MlirTypeID mlirIntegerSetAttrGetTypeID(void) { 193 return wrap(IntegerSetAttr::getTypeID()); 194 } 195 196 MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) { 197 return wrap(IntegerSetAttr::get(unwrap(set))); 198 } 199 200 MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) { 201 return wrap(llvm::cast<IntegerSetAttr>(unwrap(attr)).getValue()); 202 } 203 204 //===----------------------------------------------------------------------===// 205 // Opaque attribute. 206 //===----------------------------------------------------------------------===// 207 208 bool mlirAttributeIsAOpaque(MlirAttribute attr) { 209 return llvm::isa<OpaqueAttr>(unwrap(attr)); 210 } 211 212 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, 213 intptr_t dataLength, const char *data, 214 MlirType type) { 215 return wrap( 216 OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)), 217 StringRef(data, dataLength), unwrap(type))); 218 } 219 220 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { 221 return wrap( 222 llvm::cast<OpaqueAttr>(unwrap(attr)).getDialectNamespace().strref()); 223 } 224 225 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { 226 return wrap(llvm::cast<OpaqueAttr>(unwrap(attr)).getAttrData()); 227 } 228 229 MlirTypeID mlirOpaqueAttrGetTypeID(void) { 230 return wrap(OpaqueAttr::getTypeID()); 231 } 232 233 //===----------------------------------------------------------------------===// 234 // String attribute. 235 //===----------------------------------------------------------------------===// 236 237 bool mlirAttributeIsAString(MlirAttribute attr) { 238 return llvm::isa<StringAttr>(unwrap(attr)); 239 } 240 241 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { 242 return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str))); 243 } 244 245 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { 246 return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type))); 247 } 248 249 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { 250 return wrap(llvm::cast<StringAttr>(unwrap(attr)).getValue()); 251 } 252 253 MlirTypeID mlirStringAttrGetTypeID(void) { 254 return wrap(StringAttr::getTypeID()); 255 } 256 257 //===----------------------------------------------------------------------===// 258 // SymbolRef attribute. 259 //===----------------------------------------------------------------------===// 260 261 bool mlirAttributeIsASymbolRef(MlirAttribute attr) { 262 return llvm::isa<SymbolRefAttr>(unwrap(attr)); 263 } 264 265 MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, 266 intptr_t numReferences, 267 MlirAttribute const *references) { 268 SmallVector<FlatSymbolRefAttr, 4> refs; 269 refs.reserve(numReferences); 270 for (intptr_t i = 0; i < numReferences; ++i) 271 refs.push_back(llvm::cast<FlatSymbolRefAttr>(unwrap(references[i]))); 272 auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol)); 273 return wrap(SymbolRefAttr::get(symbolAttr, refs)); 274 } 275 276 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { 277 return wrap( 278 llvm::cast<SymbolRefAttr>(unwrap(attr)).getRootReference().getValue()); 279 } 280 281 MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { 282 return wrap( 283 llvm::cast<SymbolRefAttr>(unwrap(attr)).getLeafReference().getValue()); 284 } 285 286 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { 287 return static_cast<intptr_t>( 288 llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences().size()); 289 } 290 291 MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, 292 intptr_t pos) { 293 return wrap( 294 llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences()[pos]); 295 } 296 297 MlirTypeID mlirSymbolRefAttrGetTypeID(void) { 298 return wrap(SymbolRefAttr::getTypeID()); 299 } 300 301 MlirAttribute mlirDisctinctAttrCreate(MlirAttribute referencedAttr) { 302 return wrap(mlir::DistinctAttr::create(unwrap(referencedAttr))); 303 } 304 305 //===----------------------------------------------------------------------===// 306 // Flat SymbolRef attribute. 307 //===----------------------------------------------------------------------===// 308 309 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { 310 return llvm::isa<FlatSymbolRefAttr>(unwrap(attr)); 311 } 312 313 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { 314 return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol))); 315 } 316 317 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { 318 return wrap(llvm::cast<FlatSymbolRefAttr>(unwrap(attr)).getValue()); 319 } 320 321 //===----------------------------------------------------------------------===// 322 // Type attribute. 323 //===----------------------------------------------------------------------===// 324 325 bool mlirAttributeIsAType(MlirAttribute attr) { 326 return llvm::isa<TypeAttr>(unwrap(attr)); 327 } 328 329 MlirAttribute mlirTypeAttrGet(MlirType type) { 330 return wrap(TypeAttr::get(unwrap(type))); 331 } 332 333 MlirType mlirTypeAttrGetValue(MlirAttribute attr) { 334 return wrap(llvm::cast<TypeAttr>(unwrap(attr)).getValue()); 335 } 336 337 MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); } 338 339 //===----------------------------------------------------------------------===// 340 // Unit attribute. 341 //===----------------------------------------------------------------------===// 342 343 bool mlirAttributeIsAUnit(MlirAttribute attr) { 344 return llvm::isa<UnitAttr>(unwrap(attr)); 345 } 346 347 MlirAttribute mlirUnitAttrGet(MlirContext ctx) { 348 return wrap(UnitAttr::get(unwrap(ctx))); 349 } 350 351 MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); } 352 353 //===----------------------------------------------------------------------===// 354 // Elements attributes. 355 //===----------------------------------------------------------------------===// 356 357 bool mlirAttributeIsAElements(MlirAttribute attr) { 358 return llvm::isa<ElementsAttr>(unwrap(attr)); 359 } 360 361 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, 362 uint64_t *idxs) { 363 return wrap(llvm::cast<ElementsAttr>(unwrap(attr)) 364 .getValues<Attribute>()[llvm::ArrayRef(idxs, rank)]); 365 } 366 367 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, 368 uint64_t *idxs) { 369 return llvm::cast<ElementsAttr>(unwrap(attr)) 370 .isValidIndex(llvm::ArrayRef(idxs, rank)); 371 } 372 373 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { 374 return llvm::cast<ElementsAttr>(unwrap(attr)).getNumElements(); 375 } 376 377 //===----------------------------------------------------------------------===// 378 // Dense array attribute. 379 //===----------------------------------------------------------------------===// 380 381 MlirTypeID mlirDenseArrayAttrGetTypeID() { 382 return wrap(DenseArrayAttr::getTypeID()); 383 } 384 385 //===----------------------------------------------------------------------===// 386 // IsA support. 387 //===----------------------------------------------------------------------===// 388 389 bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { 390 return llvm::isa<DenseBoolArrayAttr>(unwrap(attr)); 391 } 392 bool mlirAttributeIsADenseI8Array(MlirAttribute attr) { 393 return llvm::isa<DenseI8ArrayAttr>(unwrap(attr)); 394 } 395 bool mlirAttributeIsADenseI16Array(MlirAttribute attr) { 396 return llvm::isa<DenseI16ArrayAttr>(unwrap(attr)); 397 } 398 bool mlirAttributeIsADenseI32Array(MlirAttribute attr) { 399 return llvm::isa<DenseI32ArrayAttr>(unwrap(attr)); 400 } 401 bool mlirAttributeIsADenseI64Array(MlirAttribute attr) { 402 return llvm::isa<DenseI64ArrayAttr>(unwrap(attr)); 403 } 404 bool mlirAttributeIsADenseF32Array(MlirAttribute attr) { 405 return llvm::isa<DenseF32ArrayAttr>(unwrap(attr)); 406 } 407 bool mlirAttributeIsADenseF64Array(MlirAttribute attr) { 408 return llvm::isa<DenseF64ArrayAttr>(unwrap(attr)); 409 } 410 411 //===----------------------------------------------------------------------===// 412 // Constructors. 413 //===----------------------------------------------------------------------===// 414 415 MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, 416 int const *values) { 417 SmallVector<bool, 4> elements(values, values + size); 418 return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements)); 419 } 420 MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size, 421 int8_t const *values) { 422 return wrap( 423 DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef<int8_t>(values, size))); 424 } 425 MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size, 426 int16_t const *values) { 427 return wrap( 428 DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef<int16_t>(values, size))); 429 } 430 MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, 431 int32_t const *values) { 432 return wrap( 433 DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef<int32_t>(values, size))); 434 } 435 MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size, 436 int64_t const *values) { 437 return wrap( 438 DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef<int64_t>(values, size))); 439 } 440 MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size, 441 float const *values) { 442 return wrap( 443 DenseF32ArrayAttr::get(unwrap(ctx), ArrayRef<float>(values, size))); 444 } 445 MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, 446 double const *values) { 447 return wrap( 448 DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef<double>(values, size))); 449 } 450 451 //===----------------------------------------------------------------------===// 452 // Accessors. 453 //===----------------------------------------------------------------------===// 454 455 intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { 456 return llvm::cast<DenseArrayAttr>(unwrap(attr)).size(); 457 } 458 459 //===----------------------------------------------------------------------===// 460 // Indexed accessors. 461 //===----------------------------------------------------------------------===// 462 463 bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { 464 return llvm::cast<DenseBoolArrayAttr>(unwrap(attr))[pos]; 465 } 466 int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) { 467 return llvm::cast<DenseI8ArrayAttr>(unwrap(attr))[pos]; 468 } 469 int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) { 470 return llvm::cast<DenseI16ArrayAttr>(unwrap(attr))[pos]; 471 } 472 int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) { 473 return llvm::cast<DenseI32ArrayAttr>(unwrap(attr))[pos]; 474 } 475 int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) { 476 return llvm::cast<DenseI64ArrayAttr>(unwrap(attr))[pos]; 477 } 478 float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) { 479 return llvm::cast<DenseF32ArrayAttr>(unwrap(attr))[pos]; 480 } 481 double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) { 482 return llvm::cast<DenseF64ArrayAttr>(unwrap(attr))[pos]; 483 } 484 485 //===----------------------------------------------------------------------===// 486 // Dense elements attribute. 487 //===----------------------------------------------------------------------===// 488 489 //===----------------------------------------------------------------------===// 490 // IsA support. 491 //===----------------------------------------------------------------------===// 492 493 bool mlirAttributeIsADenseElements(MlirAttribute attr) { 494 return llvm::isa<DenseElementsAttr>(unwrap(attr)); 495 } 496 497 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { 498 return llvm::isa<DenseIntElementsAttr>(unwrap(attr)); 499 } 500 501 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { 502 return llvm::isa<DenseFPElementsAttr>(unwrap(attr)); 503 } 504 505 MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) { 506 return wrap(DenseIntOrFPElementsAttr::getTypeID()); 507 } 508 509 //===----------------------------------------------------------------------===// 510 // Constructors. 511 //===----------------------------------------------------------------------===// 512 513 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, 514 intptr_t numElements, 515 MlirAttribute const *elements) { 516 SmallVector<Attribute, 8> attributes; 517 return wrap( 518 DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 519 unwrapList(numElements, elements, attributes))); 520 } 521 522 MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, 523 size_t rawBufferSize, 524 const void *rawBuffer) { 525 auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType)); 526 ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer), 527 rawBufferSize); 528 bool isSplat = false; 529 if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp, 530 isSplat)) 531 return mlirAttributeGetNull(); 532 return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp)); 533 } 534 535 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, 536 MlirAttribute element) { 537 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 538 unwrap(element))); 539 } 540 MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, 541 bool element) { 542 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 543 element)); 544 } 545 MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, 546 uint8_t element) { 547 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 548 element)); 549 } 550 MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, 551 int8_t element) { 552 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 553 element)); 554 } 555 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, 556 uint32_t element) { 557 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 558 element)); 559 } 560 MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, 561 int32_t element) { 562 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 563 element)); 564 } 565 MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType, 566 uint64_t element) { 567 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 568 element)); 569 } 570 MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType, 571 int64_t element) { 572 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 573 element)); 574 } 575 MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType, 576 float element) { 577 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 578 element)); 579 } 580 MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, 581 double element) { 582 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 583 element)); 584 } 585 586 MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType, 587 intptr_t numElements, 588 const int *elements) { 589 SmallVector<bool, 8> values(elements, elements + numElements); 590 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 591 values)); 592 } 593 594 /// Creates a dense attribute with elements of the type deduced by templates. 595 template <typename T> 596 static MlirAttribute getDenseAttribute(MlirType shapedType, 597 intptr_t numElements, 598 const T *elements) { 599 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 600 llvm::ArrayRef(elements, numElements))); 601 } 602 603 MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType, 604 intptr_t numElements, 605 const uint8_t *elements) { 606 return getDenseAttribute(shapedType, numElements, elements); 607 } 608 MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType, 609 intptr_t numElements, 610 const int8_t *elements) { 611 return getDenseAttribute(shapedType, numElements, elements); 612 } 613 MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType, 614 intptr_t numElements, 615 const uint16_t *elements) { 616 return getDenseAttribute(shapedType, numElements, elements); 617 } 618 MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType, 619 intptr_t numElements, 620 const int16_t *elements) { 621 return getDenseAttribute(shapedType, numElements, elements); 622 } 623 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType, 624 intptr_t numElements, 625 const uint32_t *elements) { 626 return getDenseAttribute(shapedType, numElements, elements); 627 } 628 MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType, 629 intptr_t numElements, 630 const int32_t *elements) { 631 return getDenseAttribute(shapedType, numElements, elements); 632 } 633 MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType, 634 intptr_t numElements, 635 const uint64_t *elements) { 636 return getDenseAttribute(shapedType, numElements, elements); 637 } 638 MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType, 639 intptr_t numElements, 640 const int64_t *elements) { 641 return getDenseAttribute(shapedType, numElements, elements); 642 } 643 MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType, 644 intptr_t numElements, 645 const float *elements) { 646 return getDenseAttribute(shapedType, numElements, elements); 647 } 648 MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType, 649 intptr_t numElements, 650 const double *elements) { 651 return getDenseAttribute(shapedType, numElements, elements); 652 } 653 MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType, 654 intptr_t numElements, 655 const uint16_t *elements) { 656 size_t bufferSize = numElements * 2; 657 const void *buffer = static_cast<const void *>(elements); 658 return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer); 659 } 660 MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType, 661 intptr_t numElements, 662 const uint16_t *elements) { 663 size_t bufferSize = numElements * 2; 664 const void *buffer = static_cast<const void *>(elements); 665 return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer); 666 } 667 668 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, 669 intptr_t numElements, 670 MlirStringRef *strs) { 671 SmallVector<StringRef, 8> values; 672 values.reserve(numElements); 673 for (intptr_t i = 0; i < numElements; ++i) 674 values.push_back(unwrap(strs[i])); 675 676 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 677 values)); 678 } 679 680 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, 681 MlirType shapedType) { 682 return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr)) 683 .reshape(llvm::cast<ShapedType>(unwrap(shapedType)))); 684 } 685 686 //===----------------------------------------------------------------------===// 687 // Splat accessors. 688 //===----------------------------------------------------------------------===// 689 690 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { 691 return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat(); 692 } 693 694 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { 695 return wrap( 696 llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<Attribute>()); 697 } 698 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { 699 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<bool>(); 700 } 701 int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) { 702 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int8_t>(); 703 } 704 uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) { 705 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint8_t>(); 706 } 707 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { 708 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int32_t>(); 709 } 710 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { 711 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint32_t>(); 712 } 713 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { 714 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int64_t>(); 715 } 716 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { 717 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint64_t>(); 718 } 719 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { 720 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<float>(); 721 } 722 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { 723 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<double>(); 724 } 725 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { 726 return wrap( 727 llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<StringRef>()); 728 } 729 730 //===----------------------------------------------------------------------===// 731 // Indexed accessors. 732 //===----------------------------------------------------------------------===// 733 734 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { 735 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<bool>()[pos]; 736 } 737 int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { 738 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int8_t>()[pos]; 739 } 740 uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { 741 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint8_t>()[pos]; 742 } 743 int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) { 744 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int16_t>()[pos]; 745 } 746 uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) { 747 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint16_t>()[pos]; 748 } 749 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { 750 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int32_t>()[pos]; 751 } 752 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { 753 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint32_t>()[pos]; 754 } 755 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { 756 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int64_t>()[pos]; 757 } 758 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { 759 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos]; 760 } 761 uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) { 762 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos]; 763 } 764 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { 765 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<float>()[pos]; 766 } 767 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { 768 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<double>()[pos]; 769 } 770 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, 771 intptr_t pos) { 772 return wrap( 773 llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<StringRef>()[pos]); 774 } 775 776 //===----------------------------------------------------------------------===// 777 // Raw data accessors. 778 //===----------------------------------------------------------------------===// 779 780 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { 781 return static_cast<const void *>( 782 llvm::cast<DenseElementsAttr>(unwrap(attr)).getRawData().data()); 783 } 784 785 //===----------------------------------------------------------------------===// 786 // Resource blob attributes. 787 //===----------------------------------------------------------------------===// 788 789 bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) { 790 return llvm::isa<DenseResourceElementsAttr>(unwrap(attr)); 791 } 792 793 MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet( 794 MlirType shapedType, MlirStringRef name, void *data, size_t dataLength, 795 size_t dataAlignment, bool dataIsMutable, 796 void (*deleter)(void *userData, const void *data, size_t size, 797 size_t align), 798 void *userData) { 799 AsmResourceBlob::DeleterFn cppDeleter = {}; 800 if (deleter) { 801 cppDeleter = [deleter, userData](void *data, size_t size, size_t align) { 802 deleter(userData, data, size, align); 803 }; 804 } 805 AsmResourceBlob blob( 806 llvm::ArrayRef(static_cast<const char *>(data), dataLength), 807 dataAlignment, std::move(cppDeleter), dataIsMutable); 808 return wrap( 809 DenseResourceElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), 810 unwrap(name), std::move(blob))); 811 } 812 813 template <typename U, typename T> 814 static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, 815 intptr_t numElements, const T *elements) { 816 return wrap(U::get(llvm::cast<ShapedType>(unwrap(shapedType)), unwrap(name), 817 UnmanagedAsmResourceBlob::allocateInferAlign( 818 llvm::ArrayRef(elements, numElements)))); 819 } 820 821 MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( 822 MlirType shapedType, MlirStringRef name, intptr_t numElements, 823 const int *elements) { 824 return getDenseResource<DenseBoolResourceElementsAttr>(shapedType, name, 825 numElements, elements); 826 } 827 MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet( 828 MlirType shapedType, MlirStringRef name, intptr_t numElements, 829 const uint8_t *elements) { 830 return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name, 831 numElements, elements); 832 } 833 MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet( 834 MlirType shapedType, MlirStringRef name, intptr_t numElements, 835 const uint16_t *elements) { 836 return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name, 837 numElements, elements); 838 } 839 MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet( 840 MlirType shapedType, MlirStringRef name, intptr_t numElements, 841 const uint32_t *elements) { 842 return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name, 843 numElements, elements); 844 } 845 MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet( 846 MlirType shapedType, MlirStringRef name, intptr_t numElements, 847 const uint64_t *elements) { 848 return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name, 849 numElements, elements); 850 } 851 MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet( 852 MlirType shapedType, MlirStringRef name, intptr_t numElements, 853 const int8_t *elements) { 854 return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name, 855 numElements, elements); 856 } 857 MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet( 858 MlirType shapedType, MlirStringRef name, intptr_t numElements, 859 const int16_t *elements) { 860 return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name, 861 numElements, elements); 862 } 863 MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet( 864 MlirType shapedType, MlirStringRef name, intptr_t numElements, 865 const int32_t *elements) { 866 return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name, 867 numElements, elements); 868 } 869 MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet( 870 MlirType shapedType, MlirStringRef name, intptr_t numElements, 871 const int64_t *elements) { 872 return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name, 873 numElements, elements); 874 } 875 MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet( 876 MlirType shapedType, MlirStringRef name, intptr_t numElements, 877 const float *elements) { 878 return getDenseResource<DenseF32ResourceElementsAttr>(shapedType, name, 879 numElements, elements); 880 } 881 MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet( 882 MlirType shapedType, MlirStringRef name, intptr_t numElements, 883 const double *elements) { 884 return getDenseResource<DenseF64ResourceElementsAttr>(shapedType, name, 885 numElements, elements); 886 } 887 template <typename U, typename T> 888 static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) { 889 return (*llvm::cast<U>(unwrap(attr)).tryGetAsArrayRef())[pos]; 890 } 891 892 bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, 893 intptr_t pos) { 894 return getDenseResourceVal<DenseBoolResourceElementsAttr, uint8_t>(attr, pos); 895 } 896 uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, 897 intptr_t pos) { 898 return getDenseResourceVal<DenseUI8ResourceElementsAttr, uint8_t>(attr, pos); 899 } 900 uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, 901 intptr_t pos) { 902 return getDenseResourceVal<DenseUI16ResourceElementsAttr, uint16_t>(attr, 903 pos); 904 } 905 uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, 906 intptr_t pos) { 907 return getDenseResourceVal<DenseUI32ResourceElementsAttr, uint32_t>(attr, 908 pos); 909 } 910 uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, 911 intptr_t pos) { 912 return getDenseResourceVal<DenseUI64ResourceElementsAttr, uint64_t>(attr, 913 pos); 914 } 915 int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, 916 intptr_t pos) { 917 return getDenseResourceVal<DenseUI8ResourceElementsAttr, int8_t>(attr, pos); 918 } 919 int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, 920 intptr_t pos) { 921 return getDenseResourceVal<DenseUI16ResourceElementsAttr, int16_t>(attr, pos); 922 } 923 int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, 924 intptr_t pos) { 925 return getDenseResourceVal<DenseUI32ResourceElementsAttr, int32_t>(attr, pos); 926 } 927 int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, 928 intptr_t pos) { 929 return getDenseResourceVal<DenseUI64ResourceElementsAttr, int64_t>(attr, pos); 930 } 931 float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, 932 intptr_t pos) { 933 return getDenseResourceVal<DenseF32ResourceElementsAttr, float>(attr, pos); 934 } 935 double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, 936 intptr_t pos) { 937 return getDenseResourceVal<DenseF64ResourceElementsAttr, double>(attr, pos); 938 } 939 940 //===----------------------------------------------------------------------===// 941 // Sparse elements attribute. 942 //===----------------------------------------------------------------------===// 943 944 bool mlirAttributeIsASparseElements(MlirAttribute attr) { 945 return llvm::isa<SparseElementsAttr>(unwrap(attr)); 946 } 947 948 MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, 949 MlirAttribute denseIndices, 950 MlirAttribute denseValues) { 951 return wrap(SparseElementsAttr::get( 952 llvm::cast<ShapedType>(unwrap(shapedType)), 953 llvm::cast<DenseElementsAttr>(unwrap(denseIndices)), 954 llvm::cast<DenseElementsAttr>(unwrap(denseValues)))); 955 } 956 957 MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { 958 return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getIndices()); 959 } 960 961 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { 962 return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getValues()); 963 } 964 965 MlirTypeID mlirSparseElementsAttrGetTypeID(void) { 966 return wrap(SparseElementsAttr::getTypeID()); 967 } 968 969 //===----------------------------------------------------------------------===// 970 // Strided layout attribute. 971 //===----------------------------------------------------------------------===// 972 973 bool mlirAttributeIsAStridedLayout(MlirAttribute attr) { 974 return llvm::isa<StridedLayoutAttr>(unwrap(attr)); 975 } 976 977 MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, 978 intptr_t numStrides, 979 const int64_t *strides) { 980 return wrap(StridedLayoutAttr::get(unwrap(ctx), offset, 981 ArrayRef<int64_t>(strides, numStrides))); 982 } 983 984 int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) { 985 return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getOffset(); 986 } 987 988 intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) { 989 return static_cast<intptr_t>( 990 llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides().size()); 991 } 992 993 int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { 994 return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides()[pos]; 995 } 996 997 MlirTypeID mlirStridedLayoutAttrGetTypeID(void) { 998 return wrap(StridedLayoutAttr::getTypeID()); 999 } 1000