1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <cstdint> 10 #include <optional> 11 #include <string> 12 #include <string_view> 13 #include <utility> 14 15 #include "IRModule.h" 16 #include "NanobindUtils.h" 17 #include "mlir-c/BuiltinAttributes.h" 18 #include "mlir-c/BuiltinTypes.h" 19 #include "mlir/Bindings/Python/NanobindAdaptors.h" 20 #include "mlir/Bindings/Python/Nanobind.h" 21 #include "llvm/ADT/ScopeExit.h" 22 #include "llvm/Support/raw_ostream.h" 23 24 namespace nb = nanobind; 25 using namespace nanobind::literals; 26 using namespace mlir; 27 using namespace mlir::python; 28 29 using llvm::SmallVector; 30 31 //------------------------------------------------------------------------------ 32 // Docstrings (trivial, non-duplicated docstrings are included inline). 33 //------------------------------------------------------------------------------ 34 35 static const char kDenseElementsAttrGetDocstring[] = 36 R"(Gets a DenseElementsAttr from a Python buffer or array. 37 38 When `type` is not provided, then some limited type inferencing is done based 39 on the buffer format. Support presently exists for 8/16/32/64 signed and 40 unsigned integers and float16/float32/float64. DenseElementsAttrs of these 41 types can also be converted back to a corresponding buffer. 42 43 For conversions outside of these types, a `type=` must be explicitly provided 44 and the buffer contents must be bit-castable to the MLIR internal 45 representation: 46 47 * Integer types (except for i1): the buffer must be byte aligned to the 48 next byte boundary. 49 * Floating point types: Must be bit-castable to the given floating point 50 size. 51 * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 52 row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 53 this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 54 55 If a single element buffer is passed (or for i1, a single byte with value 0 56 or 255), then a splat will be created. 57 58 Args: 59 array: The array or buffer to convert. 60 signless: If inferring an appropriate MLIR type, use signless types for 61 integers (defaults True). 62 type: Skips inference of the MLIR element type and uses this instead. The 63 storage size must be consistent with the actual contents of the buffer. 64 shape: Overrides the shape of the buffer when constructing the MLIR 65 shaped type. This is needed when the physical and logical shape differ (as 66 for i1). 67 context: Explicit context, if not from context manager. 68 69 Returns: 70 DenseElementsAttr on success. 71 72 Raises: 73 ValueError: If the type of the buffer or array cannot be matched to an MLIR 74 type or if the buffer does not meet expectations. 75 )"; 76 77 static const char kDenseElementsAttrGetFromListDocstring[] = 78 R"(Gets a DenseElementsAttr from a Python list of attributes. 79 80 Note that it can be expensive to construct attributes individually. 81 For a large number of elements, consider using a Python buffer or array instead. 82 83 Args: 84 attrs: A list of attributes. 85 type: The desired shape and type of the resulting DenseElementsAttr. 86 If not provided, the element type is determined based on the type 87 of the 0th attribute and the shape is `[len(attrs)]`. 88 context: Explicit context, if not from context manager. 89 90 Returns: 91 DenseElementsAttr on success. 92 93 Raises: 94 ValueError: If the type of the attributes does not match the type 95 specified by `shaped_type`. 96 )"; 97 98 static const char kDenseResourceElementsAttrGetFromBufferDocstring[] = 99 R"(Gets a DenseResourceElementsAttr from a Python buffer or array. 100 101 This function does minimal validation or massaging of the data, and it is 102 up to the caller to ensure that the buffer meets the characteristics 103 implied by the shape. 104 105 The backing buffer and any user objects will be retained for the lifetime 106 of the resource blob. This is typically bounded to the context but the 107 resource can have a shorter lifespan depending on how it is used in 108 subsequent processing. 109 110 Args: 111 buffer: The array or buffer to convert. 112 name: Name to provide to the resource (may be changed upon collision). 113 type: The explicit ShapedType to construct the attribute with. 114 context: Explicit context, if not from context manager. 115 116 Returns: 117 DenseResourceElementsAttr on success. 118 119 Raises: 120 ValueError: If the type of the buffer or array cannot be matched to an MLIR 121 type or if the buffer does not meet expectations. 122 )"; 123 124 namespace { 125 126 struct nb_buffer_info { 127 void *ptr = nullptr; 128 ssize_t itemsize = 0; 129 ssize_t size = 0; 130 const char *format = nullptr; 131 ssize_t ndim = 0; 132 SmallVector<ssize_t, 4> shape; 133 SmallVector<ssize_t, 4> strides; 134 bool readonly = false; 135 136 nb_buffer_info( 137 void *ptr, ssize_t itemsize, const char *format, ssize_t ndim, 138 SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in, 139 bool readonly = false, 140 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in = 141 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr)) 142 : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim), 143 shape(std::move(shape_in)), strides(std::move(strides_in)), 144 readonly(readonly), owned_view(std::move(owned_view_in)) { 145 size = 1; 146 for (ssize_t i = 0; i < ndim; ++i) { 147 size *= shape[i]; 148 } 149 } 150 151 explicit nb_buffer_info(Py_buffer *view) 152 : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim, 153 {view->shape, view->shape + view->ndim}, 154 // TODO(phawkins): check for null strides 155 {view->strides, view->strides + view->ndim}, 156 view->readonly != 0, 157 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>( 158 view, PyBuffer_Release)) {} 159 160 nb_buffer_info(const nb_buffer_info &) = delete; 161 nb_buffer_info(nb_buffer_info &&) = default; 162 nb_buffer_info &operator=(const nb_buffer_info &) = delete; 163 nb_buffer_info &operator=(nb_buffer_info &&) = default; 164 165 private: 166 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view; 167 }; 168 169 class nb_buffer : public nb::object { 170 NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer); 171 172 nb_buffer_info request() const { 173 int flags = PyBUF_STRIDES | PyBUF_FORMAT; 174 auto *view = new Py_buffer(); 175 if (PyObject_GetBuffer(ptr(), view, flags) != 0) { 176 delete view; 177 throw nb::python_error(); 178 } 179 return nb_buffer_info(view); 180 } 181 }; 182 183 template <typename T> 184 struct nb_format_descriptor {}; 185 186 template <> 187 struct nb_format_descriptor<bool> { 188 static const char *format() { return "?"; } 189 }; 190 template <> 191 struct nb_format_descriptor<int8_t> { 192 static const char *format() { return "b"; } 193 }; 194 template <> 195 struct nb_format_descriptor<uint8_t> { 196 static const char *format() { return "B"; } 197 }; 198 template <> 199 struct nb_format_descriptor<int16_t> { 200 static const char *format() { return "h"; } 201 }; 202 template <> 203 struct nb_format_descriptor<uint16_t> { 204 static const char *format() { return "H"; } 205 }; 206 template <> 207 struct nb_format_descriptor<int32_t> { 208 static const char *format() { return "i"; } 209 }; 210 template <> 211 struct nb_format_descriptor<uint32_t> { 212 static const char *format() { return "I"; } 213 }; 214 template <> 215 struct nb_format_descriptor<int64_t> { 216 static const char *format() { return "q"; } 217 }; 218 template <> 219 struct nb_format_descriptor<uint64_t> { 220 static const char *format() { return "Q"; } 221 }; 222 template <> 223 struct nb_format_descriptor<float> { 224 static const char *format() { return "f"; } 225 }; 226 template <> 227 struct nb_format_descriptor<double> { 228 static const char *format() { return "d"; } 229 }; 230 231 static MlirStringRef toMlirStringRef(const std::string &s) { 232 return mlirStringRefCreate(s.data(), s.size()); 233 } 234 235 static MlirStringRef toMlirStringRef(const nb::bytes &s) { 236 return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size()); 237 } 238 239 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 240 public: 241 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 242 static constexpr const char *pyClassName = "AffineMapAttr"; 243 using PyConcreteAttribute::PyConcreteAttribute; 244 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 245 mlirAffineMapAttrGetTypeID; 246 247 static void bindDerived(ClassTy &c) { 248 c.def_static( 249 "get", 250 [](PyAffineMap &affineMap) { 251 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 252 return PyAffineMapAttribute(affineMap.getContext(), attr); 253 }, 254 nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 255 c.def_prop_ro("value", mlirAffineMapAttrGetValue, 256 "Returns the value of the AffineMap attribute"); 257 } 258 }; 259 260 class PyIntegerSetAttribute 261 : public PyConcreteAttribute<PyIntegerSetAttribute> { 262 public: 263 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet; 264 static constexpr const char *pyClassName = "IntegerSetAttr"; 265 using PyConcreteAttribute::PyConcreteAttribute; 266 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 267 mlirIntegerSetAttrGetTypeID; 268 269 static void bindDerived(ClassTy &c) { 270 c.def_static( 271 "get", 272 [](PyIntegerSet &integerSet) { 273 MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); 274 return PyIntegerSetAttribute(integerSet.getContext(), attr); 275 }, 276 nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); 277 } 278 }; 279 280 template <typename T> 281 static T pyTryCast(nb::handle object) { 282 try { 283 return nb::cast<T>(object); 284 } catch (nb::cast_error &err) { 285 std::string msg = std::string("Invalid attribute when attempting to " 286 "create an ArrayAttribute (") + 287 err.what() + ")"; 288 throw std::runtime_error(msg.c_str()); 289 } catch (std::runtime_error &err) { 290 std::string msg = std::string("Invalid attribute (None?) when attempting " 291 "to create an ArrayAttribute (") + 292 err.what() + ")"; 293 throw std::runtime_error(msg.c_str()); 294 } 295 } 296 297 /// A python-wrapped dense array attribute with an element type and a derived 298 /// implementation class. 299 template <typename EltTy, typename DerivedT> 300 class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> { 301 public: 302 using PyConcreteAttribute<DerivedT>::PyConcreteAttribute; 303 304 /// Iterator over the integer elements of a dense array. 305 class PyDenseArrayIterator { 306 public: 307 PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} 308 309 /// Return a copy of the iterator. 310 PyDenseArrayIterator dunderIter() { return *this; } 311 312 /// Return the next element. 313 EltTy dunderNext() { 314 // Throw if the index has reached the end. 315 if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) 316 throw nb::stop_iteration(); 317 return DerivedT::getElement(attr.get(), nextIndex++); 318 } 319 320 /// Bind the iterator class. 321 static void bind(nb::module_ &m) { 322 nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName) 323 .def("__iter__", &PyDenseArrayIterator::dunderIter) 324 .def("__next__", &PyDenseArrayIterator::dunderNext); 325 } 326 327 private: 328 /// The referenced dense array attribute. 329 PyAttribute attr; 330 /// The next index to read. 331 int nextIndex = 0; 332 }; 333 334 /// Get the element at the given index. 335 EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } 336 337 /// Bind the attribute class. 338 static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) { 339 // Bind the constructor. 340 if constexpr (std::is_same_v<EltTy, bool>) { 341 c.def_static( 342 "get", 343 [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) { 344 std::vector<bool> values; 345 for (nb::handle py_value : py_values) { 346 int is_true = PyObject_IsTrue(py_value.ptr()); 347 if (is_true < 0) { 348 throw nb::python_error(); 349 } 350 values.push_back(is_true); 351 } 352 return getAttribute(values, ctx->getRef()); 353 }, 354 nb::arg("values"), nb::arg("context").none() = nb::none(), 355 "Gets a uniqued dense array attribute"); 356 } else { 357 c.def_static( 358 "get", 359 [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { 360 return getAttribute(values, ctx->getRef()); 361 }, 362 nb::arg("values"), nb::arg("context").none() = nb::none(), 363 "Gets a uniqued dense array attribute"); 364 } 365 // Bind the array methods. 366 c.def("__getitem__", [](DerivedT &arr, intptr_t i) { 367 if (i >= mlirDenseArrayGetNumElements(arr)) 368 throw nb::index_error("DenseArray index out of range"); 369 return arr.getItem(i); 370 }); 371 c.def("__len__", [](const DerivedT &arr) { 372 return mlirDenseArrayGetNumElements(arr); 373 }); 374 c.def("__iter__", 375 [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); 376 c.def("__add__", [](DerivedT &arr, const nb::list &extras) { 377 std::vector<EltTy> values; 378 intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); 379 values.reserve(numOldElements + nb::len(extras)); 380 for (intptr_t i = 0; i < numOldElements; ++i) 381 values.push_back(arr.getItem(i)); 382 for (nb::handle attr : extras) 383 values.push_back(pyTryCast<EltTy>(attr)); 384 return getAttribute(values, arr.getContext()); 385 }); 386 } 387 388 private: 389 static DerivedT getAttribute(const std::vector<EltTy> &values, 390 PyMlirContextRef ctx) { 391 if constexpr (std::is_same_v<EltTy, bool>) { 392 std::vector<int> intValues(values.begin(), values.end()); 393 MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(), 394 intValues.data()); 395 return DerivedT(ctx, attr); 396 } else { 397 MlirAttribute attr = 398 DerivedT::getAttribute(ctx->get(), values.size(), values.data()); 399 return DerivedT(ctx, attr); 400 } 401 } 402 }; 403 404 /// Instantiate the python dense array classes. 405 struct PyDenseBoolArrayAttribute 406 : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> { 407 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; 408 static constexpr auto getAttribute = mlirDenseBoolArrayGet; 409 static constexpr auto getElement = mlirDenseBoolArrayGetElement; 410 static constexpr const char *pyClassName = "DenseBoolArrayAttr"; 411 static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; 412 using PyDenseArrayAttribute::PyDenseArrayAttribute; 413 }; 414 struct PyDenseI8ArrayAttribute 415 : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { 416 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; 417 static constexpr auto getAttribute = mlirDenseI8ArrayGet; 418 static constexpr auto getElement = mlirDenseI8ArrayGetElement; 419 static constexpr const char *pyClassName = "DenseI8ArrayAttr"; 420 static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; 421 using PyDenseArrayAttribute::PyDenseArrayAttribute; 422 }; 423 struct PyDenseI16ArrayAttribute 424 : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { 425 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; 426 static constexpr auto getAttribute = mlirDenseI16ArrayGet; 427 static constexpr auto getElement = mlirDenseI16ArrayGetElement; 428 static constexpr const char *pyClassName = "DenseI16ArrayAttr"; 429 static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; 430 using PyDenseArrayAttribute::PyDenseArrayAttribute; 431 }; 432 struct PyDenseI32ArrayAttribute 433 : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { 434 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; 435 static constexpr auto getAttribute = mlirDenseI32ArrayGet; 436 static constexpr auto getElement = mlirDenseI32ArrayGetElement; 437 static constexpr const char *pyClassName = "DenseI32ArrayAttr"; 438 static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; 439 using PyDenseArrayAttribute::PyDenseArrayAttribute; 440 }; 441 struct PyDenseI64ArrayAttribute 442 : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { 443 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; 444 static constexpr auto getAttribute = mlirDenseI64ArrayGet; 445 static constexpr auto getElement = mlirDenseI64ArrayGetElement; 446 static constexpr const char *pyClassName = "DenseI64ArrayAttr"; 447 static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; 448 using PyDenseArrayAttribute::PyDenseArrayAttribute; 449 }; 450 struct PyDenseF32ArrayAttribute 451 : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { 452 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; 453 static constexpr auto getAttribute = mlirDenseF32ArrayGet; 454 static constexpr auto getElement = mlirDenseF32ArrayGetElement; 455 static constexpr const char *pyClassName = "DenseF32ArrayAttr"; 456 static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; 457 using PyDenseArrayAttribute::PyDenseArrayAttribute; 458 }; 459 struct PyDenseF64ArrayAttribute 460 : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { 461 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; 462 static constexpr auto getAttribute = mlirDenseF64ArrayGet; 463 static constexpr auto getElement = mlirDenseF64ArrayGetElement; 464 static constexpr const char *pyClassName = "DenseF64ArrayAttr"; 465 static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; 466 using PyDenseArrayAttribute::PyDenseArrayAttribute; 467 }; 468 469 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 470 public: 471 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 472 static constexpr const char *pyClassName = "ArrayAttr"; 473 using PyConcreteAttribute::PyConcreteAttribute; 474 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 475 mlirArrayAttrGetTypeID; 476 477 class PyArrayAttributeIterator { 478 public: 479 PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 480 481 PyArrayAttributeIterator &dunderIter() { return *this; } 482 483 MlirAttribute dunderNext() { 484 // TODO: Throw is an inefficient way to stop iteration. 485 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) 486 throw nb::stop_iteration(); 487 return mlirArrayAttrGetElement(attr.get(), nextIndex++); 488 } 489 490 static void bind(nb::module_ &m) { 491 nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator") 492 .def("__iter__", &PyArrayAttributeIterator::dunderIter) 493 .def("__next__", &PyArrayAttributeIterator::dunderNext); 494 } 495 496 private: 497 PyAttribute attr; 498 int nextIndex = 0; 499 }; 500 501 MlirAttribute getItem(intptr_t i) { 502 return mlirArrayAttrGetElement(*this, i); 503 } 504 505 static void bindDerived(ClassTy &c) { 506 c.def_static( 507 "get", 508 [](nb::list attributes, DefaultingPyMlirContext context) { 509 SmallVector<MlirAttribute> mlirAttributes; 510 mlirAttributes.reserve(nb::len(attributes)); 511 for (auto attribute : attributes) { 512 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 513 } 514 MlirAttribute attr = mlirArrayAttrGet( 515 context->get(), mlirAttributes.size(), mlirAttributes.data()); 516 return PyArrayAttribute(context->getRef(), attr); 517 }, 518 nb::arg("attributes"), nb::arg("context").none() = nb::none(), 519 "Gets a uniqued Array attribute"); 520 c.def("__getitem__", 521 [](PyArrayAttribute &arr, intptr_t i) { 522 if (i >= mlirArrayAttrGetNumElements(arr)) 523 throw nb::index_error("ArrayAttribute index out of range"); 524 return arr.getItem(i); 525 }) 526 .def("__len__", 527 [](const PyArrayAttribute &arr) { 528 return mlirArrayAttrGetNumElements(arr); 529 }) 530 .def("__iter__", [](const PyArrayAttribute &arr) { 531 return PyArrayAttributeIterator(arr); 532 }); 533 c.def("__add__", [](PyArrayAttribute arr, nb::list extras) { 534 std::vector<MlirAttribute> attributes; 535 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 536 attributes.reserve(numOldElements + nb::len(extras)); 537 for (intptr_t i = 0; i < numOldElements; ++i) 538 attributes.push_back(arr.getItem(i)); 539 for (nb::handle attr : extras) 540 attributes.push_back(pyTryCast<PyAttribute>(attr)); 541 MlirAttribute arrayAttr = mlirArrayAttrGet( 542 arr.getContext()->get(), attributes.size(), attributes.data()); 543 return PyArrayAttribute(arr.getContext(), arrayAttr); 544 }); 545 } 546 }; 547 548 /// Float Point Attribute subclass - FloatAttr. 549 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 550 public: 551 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 552 static constexpr const char *pyClassName = "FloatAttr"; 553 using PyConcreteAttribute::PyConcreteAttribute; 554 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 555 mlirFloatAttrGetTypeID; 556 557 static void bindDerived(ClassTy &c) { 558 c.def_static( 559 "get", 560 [](PyType &type, double value, DefaultingPyLocation loc) { 561 PyMlirContext::ErrorCapture errors(loc->getContext()); 562 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 563 if (mlirAttributeIsNull(attr)) 564 throw MLIRError("Invalid attribute", errors.take()); 565 return PyFloatAttribute(type.getContext(), attr); 566 }, 567 nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(), 568 "Gets an uniqued float point attribute associated to a type"); 569 c.def_static( 570 "get_f32", 571 [](double value, DefaultingPyMlirContext context) { 572 MlirAttribute attr = mlirFloatAttrDoubleGet( 573 context->get(), mlirF32TypeGet(context->get()), value); 574 return PyFloatAttribute(context->getRef(), attr); 575 }, 576 nb::arg("value"), nb::arg("context").none() = nb::none(), 577 "Gets an uniqued float point attribute associated to a f32 type"); 578 c.def_static( 579 "get_f64", 580 [](double value, DefaultingPyMlirContext context) { 581 MlirAttribute attr = mlirFloatAttrDoubleGet( 582 context->get(), mlirF64TypeGet(context->get()), value); 583 return PyFloatAttribute(context->getRef(), attr); 584 }, 585 nb::arg("value"), nb::arg("context").none() = nb::none(), 586 "Gets an uniqued float point attribute associated to a f64 type"); 587 c.def_prop_ro("value", mlirFloatAttrGetValueDouble, 588 "Returns the value of the float attribute"); 589 c.def("__float__", mlirFloatAttrGetValueDouble, 590 "Converts the value of the float attribute to a Python float"); 591 } 592 }; 593 594 /// Integer Attribute subclass - IntegerAttr. 595 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 596 public: 597 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 598 static constexpr const char *pyClassName = "IntegerAttr"; 599 using PyConcreteAttribute::PyConcreteAttribute; 600 601 static void bindDerived(ClassTy &c) { 602 c.def_static( 603 "get", 604 [](PyType &type, int64_t value) { 605 MlirAttribute attr = mlirIntegerAttrGet(type, value); 606 return PyIntegerAttribute(type.getContext(), attr); 607 }, 608 nb::arg("type"), nb::arg("value"), 609 "Gets an uniqued integer attribute associated to a type"); 610 c.def_prop_ro("value", toPyInt, 611 "Returns the value of the integer attribute"); 612 c.def("__int__", toPyInt, 613 "Converts the value of the integer attribute to a Python int"); 614 c.def_prop_ro_static("static_typeid", 615 [](nb::object & /*class*/) -> MlirTypeID { 616 return mlirIntegerAttrGetTypeID(); 617 }); 618 } 619 620 private: 621 static int64_t toPyInt(PyIntegerAttribute &self) { 622 MlirType type = mlirAttributeGetType(self); 623 if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 624 return mlirIntegerAttrGetValueInt(self); 625 if (mlirIntegerTypeIsSigned(type)) 626 return mlirIntegerAttrGetValueSInt(self); 627 return mlirIntegerAttrGetValueUInt(self); 628 } 629 }; 630 631 /// Bool Attribute subclass - BoolAttr. 632 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 633 public: 634 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 635 static constexpr const char *pyClassName = "BoolAttr"; 636 using PyConcreteAttribute::PyConcreteAttribute; 637 638 static void bindDerived(ClassTy &c) { 639 c.def_static( 640 "get", 641 [](bool value, DefaultingPyMlirContext context) { 642 MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 643 return PyBoolAttribute(context->getRef(), attr); 644 }, 645 nb::arg("value"), nb::arg("context").none() = nb::none(), 646 "Gets an uniqued bool attribute"); 647 c.def_prop_ro("value", mlirBoolAttrGetValue, 648 "Returns the value of the bool attribute"); 649 c.def("__bool__", mlirBoolAttrGetValue, 650 "Converts the value of the bool attribute to a Python bool"); 651 } 652 }; 653 654 class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> { 655 public: 656 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; 657 static constexpr const char *pyClassName = "SymbolRefAttr"; 658 using PyConcreteAttribute::PyConcreteAttribute; 659 660 static MlirAttribute fromList(const std::vector<std::string> &symbols, 661 PyMlirContext &context) { 662 if (symbols.empty()) 663 throw std::runtime_error("SymbolRefAttr must be composed of at least " 664 "one symbol."); 665 MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); 666 SmallVector<MlirAttribute, 3> referenceAttrs; 667 for (size_t i = 1; i < symbols.size(); ++i) { 668 referenceAttrs.push_back( 669 mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); 670 } 671 return mlirSymbolRefAttrGet(context.get(), rootSymbol, 672 referenceAttrs.size(), referenceAttrs.data()); 673 } 674 675 static void bindDerived(ClassTy &c) { 676 c.def_static( 677 "get", 678 [](const std::vector<std::string> &symbols, 679 DefaultingPyMlirContext context) { 680 return PySymbolRefAttribute::fromList(symbols, context.resolve()); 681 }, 682 nb::arg("symbols"), nb::arg("context").none() = nb::none(), 683 "Gets a uniqued SymbolRef attribute from a list of symbol names"); 684 c.def_prop_ro( 685 "value", 686 [](PySymbolRefAttribute &self) { 687 std::vector<std::string> symbols = { 688 unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; 689 for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); 690 ++i) 691 symbols.push_back( 692 unwrap(mlirSymbolRefAttrGetRootReference( 693 mlirSymbolRefAttrGetNestedReference(self, i))) 694 .str()); 695 return symbols; 696 }, 697 "Returns the value of the SymbolRef attribute as a list[str]"); 698 } 699 }; 700 701 class PyFlatSymbolRefAttribute 702 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 703 public: 704 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 705 static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 706 using PyConcreteAttribute::PyConcreteAttribute; 707 708 static void bindDerived(ClassTy &c) { 709 c.def_static( 710 "get", 711 [](std::string value, DefaultingPyMlirContext context) { 712 MlirAttribute attr = 713 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 714 return PyFlatSymbolRefAttribute(context->getRef(), attr); 715 }, 716 nb::arg("value"), nb::arg("context").none() = nb::none(), 717 "Gets a uniqued FlatSymbolRef attribute"); 718 c.def_prop_ro( 719 "value", 720 [](PyFlatSymbolRefAttribute &self) { 721 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 722 return nb::str(stringRef.data, stringRef.length); 723 }, 724 "Returns the value of the FlatSymbolRef attribute as a string"); 725 } 726 }; 727 728 class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { 729 public: 730 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; 731 static constexpr const char *pyClassName = "OpaqueAttr"; 732 using PyConcreteAttribute::PyConcreteAttribute; 733 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 734 mlirOpaqueAttrGetTypeID; 735 736 static void bindDerived(ClassTy &c) { 737 c.def_static( 738 "get", 739 [](std::string dialectNamespace, nb_buffer buffer, PyType &type, 740 DefaultingPyMlirContext context) { 741 const nb_buffer_info bufferInfo = buffer.request(); 742 intptr_t bufferSize = bufferInfo.size; 743 MlirAttribute attr = mlirOpaqueAttrGet( 744 context->get(), toMlirStringRef(dialectNamespace), bufferSize, 745 static_cast<char *>(bufferInfo.ptr), type); 746 return PyOpaqueAttribute(context->getRef(), attr); 747 }, 748 nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"), 749 nb::arg("context").none() = nb::none(), "Gets an Opaque attribute."); 750 c.def_prop_ro( 751 "dialect_namespace", 752 [](PyOpaqueAttribute &self) { 753 MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); 754 return nb::str(stringRef.data, stringRef.length); 755 }, 756 "Returns the dialect namespace for the Opaque attribute as a string"); 757 c.def_prop_ro( 758 "data", 759 [](PyOpaqueAttribute &self) { 760 MlirStringRef stringRef = mlirOpaqueAttrGetData(self); 761 return nb::bytes(stringRef.data, stringRef.length); 762 }, 763 "Returns the data for the Opaqued attributes as `bytes`"); 764 } 765 }; 766 767 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 768 public: 769 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 770 static constexpr const char *pyClassName = "StringAttr"; 771 using PyConcreteAttribute::PyConcreteAttribute; 772 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 773 mlirStringAttrGetTypeID; 774 775 static void bindDerived(ClassTy &c) { 776 c.def_static( 777 "get", 778 [](std::string value, DefaultingPyMlirContext context) { 779 MlirAttribute attr = 780 mlirStringAttrGet(context->get(), toMlirStringRef(value)); 781 return PyStringAttribute(context->getRef(), attr); 782 }, 783 nb::arg("value"), nb::arg("context").none() = nb::none(), 784 "Gets a uniqued string attribute"); 785 c.def_static( 786 "get", 787 [](nb::bytes value, DefaultingPyMlirContext context) { 788 MlirAttribute attr = 789 mlirStringAttrGet(context->get(), toMlirStringRef(value)); 790 return PyStringAttribute(context->getRef(), attr); 791 }, 792 nb::arg("value"), nb::arg("context").none() = nb::none(), 793 "Gets a uniqued string attribute"); 794 c.def_static( 795 "get_typed", 796 [](PyType &type, std::string value) { 797 MlirAttribute attr = 798 mlirStringAttrTypedGet(type, toMlirStringRef(value)); 799 return PyStringAttribute(type.getContext(), attr); 800 }, 801 nb::arg("type"), nb::arg("value"), 802 "Gets a uniqued string attribute associated to a type"); 803 c.def_prop_ro( 804 "value", 805 [](PyStringAttribute &self) { 806 MlirStringRef stringRef = mlirStringAttrGetValue(self); 807 return nb::str(stringRef.data, stringRef.length); 808 }, 809 "Returns the value of the string attribute"); 810 c.def_prop_ro( 811 "value_bytes", 812 [](PyStringAttribute &self) { 813 MlirStringRef stringRef = mlirStringAttrGetValue(self); 814 return nb::bytes(stringRef.data, stringRef.length); 815 }, 816 "Returns the value of the string attribute as `bytes`"); 817 } 818 }; 819 820 // TODO: Support construction of string elements. 821 class PyDenseElementsAttribute 822 : public PyConcreteAttribute<PyDenseElementsAttribute> { 823 public: 824 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 825 static constexpr const char *pyClassName = "DenseElementsAttr"; 826 using PyConcreteAttribute::PyConcreteAttribute; 827 828 static PyDenseElementsAttribute 829 getFromList(nb::list attributes, std::optional<PyType> explicitType, 830 DefaultingPyMlirContext contextWrapper) { 831 const size_t numAttributes = nb::len(attributes); 832 if (numAttributes == 0) 833 throw nb::value_error("Attributes list must be non-empty."); 834 835 MlirType shapedType; 836 if (explicitType) { 837 if ((!mlirTypeIsAShaped(*explicitType) || 838 !mlirShapedTypeHasStaticShape(*explicitType))) { 839 840 std::string message; 841 llvm::raw_string_ostream os(message); 842 os << "Expected a static ShapedType for the shaped_type parameter: " 843 << nb::cast<std::string>(nb::repr(nb::cast(*explicitType))); 844 throw nb::value_error(message.c_str()); 845 } 846 shapedType = *explicitType; 847 } else { 848 SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)}; 849 shapedType = mlirRankedTensorTypeGet( 850 shape.size(), shape.data(), 851 mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])), 852 mlirAttributeGetNull()); 853 } 854 855 SmallVector<MlirAttribute> mlirAttributes; 856 mlirAttributes.reserve(numAttributes); 857 for (const nb::handle &attribute : attributes) { 858 MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute); 859 MlirType attrType = mlirAttributeGetType(mlirAttribute); 860 mlirAttributes.push_back(mlirAttribute); 861 862 if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { 863 std::string message; 864 llvm::raw_string_ostream os(message); 865 os << "All attributes must be of the same type and match " 866 << "the type parameter: expected=" 867 << nb::cast<std::string>(nb::repr(nb::cast(shapedType))) 868 << ", but got=" 869 << nb::cast<std::string>(nb::repr(nb::cast(attrType))); 870 throw nb::value_error(message.c_str()); 871 } 872 } 873 874 MlirAttribute elements = mlirDenseElementsAttrGet( 875 shapedType, mlirAttributes.size(), mlirAttributes.data()); 876 877 return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 878 } 879 880 static PyDenseElementsAttribute 881 getFromBuffer(nb_buffer array, bool signless, 882 std::optional<PyType> explicitType, 883 std::optional<std::vector<int64_t>> explicitShape, 884 DefaultingPyMlirContext contextWrapper) { 885 // Request a contiguous view. In exotic cases, this will cause a copy. 886 int flags = PyBUF_ND; 887 if (!explicitType) { 888 flags |= PyBUF_FORMAT; 889 } 890 Py_buffer view; 891 if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { 892 throw nb::python_error(); 893 } 894 auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); 895 896 MlirContext context = contextWrapper->get(); 897 MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, 898 explicitShape, context); 899 if (mlirAttributeIsNull(attr)) { 900 throw std::invalid_argument( 901 "DenseElementsAttr could not be constructed from the given buffer. " 902 "This may mean that the Python buffer layout does not match that " 903 "MLIR expected layout and is a bug."); 904 } 905 return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 906 } 907 908 static PyDenseElementsAttribute getSplat(const PyType &shapedType, 909 PyAttribute &elementAttr) { 910 auto contextWrapper = 911 PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 912 if (!mlirAttributeIsAInteger(elementAttr) && 913 !mlirAttributeIsAFloat(elementAttr)) { 914 std::string message = "Illegal element type for DenseElementsAttr: "; 915 message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr)))); 916 throw nb::value_error(message.c_str()); 917 } 918 if (!mlirTypeIsAShaped(shapedType) || 919 !mlirShapedTypeHasStaticShape(shapedType)) { 920 std::string message = 921 "Expected a static ShapedType for the shaped_type parameter: "; 922 message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType)))); 923 throw nb::value_error(message.c_str()); 924 } 925 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 926 MlirType attrType = mlirAttributeGetType(elementAttr); 927 if (!mlirTypeEqual(shapedElementType, attrType)) { 928 std::string message = 929 "Shaped element type and attribute type must be equal: shaped="; 930 message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType)))); 931 message.append(", element="); 932 message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr)))); 933 throw nb::value_error(message.c_str()); 934 } 935 936 MlirAttribute elements = 937 mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 938 return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 939 } 940 941 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 942 943 std::unique_ptr<nb_buffer_info> accessBuffer() { 944 MlirType shapedType = mlirAttributeGetType(*this); 945 MlirType elementType = mlirShapedTypeGetElementType(shapedType); 946 std::string format; 947 948 if (mlirTypeIsAF32(elementType)) { 949 // f32 950 return bufferInfo<float>(shapedType); 951 } 952 if (mlirTypeIsAF64(elementType)) { 953 // f64 954 return bufferInfo<double>(shapedType); 955 } 956 if (mlirTypeIsAF16(elementType)) { 957 // f16 958 return bufferInfo<uint16_t>(shapedType, "e"); 959 } 960 if (mlirTypeIsAIndex(elementType)) { 961 // Same as IndexType::kInternalStorageBitWidth 962 return bufferInfo<int64_t>(shapedType); 963 } 964 if (mlirTypeIsAInteger(elementType) && 965 mlirIntegerTypeGetWidth(elementType) == 32) { 966 if (mlirIntegerTypeIsSignless(elementType) || 967 mlirIntegerTypeIsSigned(elementType)) { 968 // i32 969 return bufferInfo<int32_t>(shapedType); 970 } 971 if (mlirIntegerTypeIsUnsigned(elementType)) { 972 // unsigned i32 973 return bufferInfo<uint32_t>(shapedType); 974 } 975 } else if (mlirTypeIsAInteger(elementType) && 976 mlirIntegerTypeGetWidth(elementType) == 64) { 977 if (mlirIntegerTypeIsSignless(elementType) || 978 mlirIntegerTypeIsSigned(elementType)) { 979 // i64 980 return bufferInfo<int64_t>(shapedType); 981 } 982 if (mlirIntegerTypeIsUnsigned(elementType)) { 983 // unsigned i64 984 return bufferInfo<uint64_t>(shapedType); 985 } 986 } else if (mlirTypeIsAInteger(elementType) && 987 mlirIntegerTypeGetWidth(elementType) == 8) { 988 if (mlirIntegerTypeIsSignless(elementType) || 989 mlirIntegerTypeIsSigned(elementType)) { 990 // i8 991 return bufferInfo<int8_t>(shapedType); 992 } 993 if (mlirIntegerTypeIsUnsigned(elementType)) { 994 // unsigned i8 995 return bufferInfo<uint8_t>(shapedType); 996 } 997 } else if (mlirTypeIsAInteger(elementType) && 998 mlirIntegerTypeGetWidth(elementType) == 16) { 999 if (mlirIntegerTypeIsSignless(elementType) || 1000 mlirIntegerTypeIsSigned(elementType)) { 1001 // i16 1002 return bufferInfo<int16_t>(shapedType); 1003 } 1004 if (mlirIntegerTypeIsUnsigned(elementType)) { 1005 // unsigned i16 1006 return bufferInfo<uint16_t>(shapedType); 1007 } 1008 } else if (mlirTypeIsAInteger(elementType) && 1009 mlirIntegerTypeGetWidth(elementType) == 1) { 1010 // i1 / bool 1011 // We can not send the buffer directly back to Python, because the i1 1012 // values are bitpacked within MLIR. We call numpy's unpackbits function 1013 // to convert the bytes. 1014 return getBooleanBufferFromBitpackedAttribute(); 1015 } 1016 1017 // TODO: Currently crashes the program. 1018 // Reported as https://github.com/pybind/pybind11/issues/3336 1019 throw std::invalid_argument( 1020 "unsupported data type for conversion to Python buffer"); 1021 } 1022 1023 static void bindDerived(ClassTy &c) { 1024 #if PY_VERSION_HEX < 0x03090000 1025 PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr()); 1026 tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer; 1027 tp->tp_as_buffer->bf_releasebuffer = 1028 PyDenseElementsAttribute::bf_releasebuffer; 1029 #endif 1030 c.def("__len__", &PyDenseElementsAttribute::dunderLen) 1031 .def_static("get", PyDenseElementsAttribute::getFromBuffer, 1032 nb::arg("array"), nb::arg("signless") = true, 1033 nb::arg("type").none() = nb::none(), 1034 nb::arg("shape").none() = nb::none(), 1035 nb::arg("context").none() = nb::none(), 1036 kDenseElementsAttrGetDocstring) 1037 .def_static("get", PyDenseElementsAttribute::getFromList, 1038 nb::arg("attrs"), nb::arg("type").none() = nb::none(), 1039 nb::arg("context").none() = nb::none(), 1040 kDenseElementsAttrGetFromListDocstring) 1041 .def_static("get_splat", PyDenseElementsAttribute::getSplat, 1042 nb::arg("shaped_type"), nb::arg("element_attr"), 1043 "Gets a DenseElementsAttr where all values are the same") 1044 .def_prop_ro("is_splat", 1045 [](PyDenseElementsAttribute &self) -> bool { 1046 return mlirDenseElementsAttrIsSplat(self); 1047 }) 1048 .def("get_splat_value", [](PyDenseElementsAttribute &self) { 1049 if (!mlirDenseElementsAttrIsSplat(self)) 1050 throw nb::value_error( 1051 "get_splat_value called on a non-splat attribute"); 1052 return mlirDenseElementsAttrGetSplatValue(self); 1053 }); 1054 } 1055 1056 static PyType_Slot slots[]; 1057 1058 private: 1059 static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags); 1060 static void bf_releasebuffer(PyObject *, Py_buffer *buffer); 1061 1062 static bool isUnsignedIntegerFormat(std::string_view format) { 1063 if (format.empty()) 1064 return false; 1065 char code = format[0]; 1066 return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 1067 code == 'Q'; 1068 } 1069 1070 static bool isSignedIntegerFormat(std::string_view format) { 1071 if (format.empty()) 1072 return false; 1073 char code = format[0]; 1074 return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 1075 code == 'q'; 1076 } 1077 1078 static MlirType 1079 getShapedType(std::optional<MlirType> bulkLoadElementType, 1080 std::optional<std::vector<int64_t>> explicitShape, 1081 Py_buffer &view) { 1082 SmallVector<int64_t> shape; 1083 if (explicitShape) { 1084 shape.append(explicitShape->begin(), explicitShape->end()); 1085 } else { 1086 shape.append(view.shape, view.shape + view.ndim); 1087 } 1088 1089 if (mlirTypeIsAShaped(*bulkLoadElementType)) { 1090 if (explicitShape) { 1091 throw std::invalid_argument("Shape can only be specified explicitly " 1092 "when the type is not a shaped type."); 1093 } 1094 return *bulkLoadElementType; 1095 } else { 1096 MlirAttribute encodingAttr = mlirAttributeGetNull(); 1097 return mlirRankedTensorTypeGet(shape.size(), shape.data(), 1098 *bulkLoadElementType, encodingAttr); 1099 } 1100 } 1101 1102 static MlirAttribute getAttributeFromBuffer( 1103 Py_buffer &view, bool signless, std::optional<PyType> explicitType, 1104 std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) { 1105 // Detect format codes that are suitable for bulk loading. This includes 1106 // all byte aligned integer and floating point types up to 8 bytes. 1107 // Notably, this excludes exotics types which do not have a direct 1108 // representation in the buffer protocol (i.e. complex, etc). 1109 std::optional<MlirType> bulkLoadElementType; 1110 if (explicitType) { 1111 bulkLoadElementType = *explicitType; 1112 } else { 1113 std::string_view format(view.format); 1114 if (format == "f") { 1115 // f32 1116 assert(view.itemsize == 4 && "mismatched array itemsize"); 1117 bulkLoadElementType = mlirF32TypeGet(context); 1118 } else if (format == "d") { 1119 // f64 1120 assert(view.itemsize == 8 && "mismatched array itemsize"); 1121 bulkLoadElementType = mlirF64TypeGet(context); 1122 } else if (format == "e") { 1123 // f16 1124 assert(view.itemsize == 2 && "mismatched array itemsize"); 1125 bulkLoadElementType = mlirF16TypeGet(context); 1126 } else if (format == "?") { 1127 // i1 1128 // The i1 type needs to be bit-packed, so we will handle it seperately 1129 return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, 1130 context); 1131 } else if (isSignedIntegerFormat(format)) { 1132 if (view.itemsize == 4) { 1133 // i32 1134 bulkLoadElementType = signless 1135 ? mlirIntegerTypeGet(context, 32) 1136 : mlirIntegerTypeSignedGet(context, 32); 1137 } else if (view.itemsize == 8) { 1138 // i64 1139 bulkLoadElementType = signless 1140 ? mlirIntegerTypeGet(context, 64) 1141 : mlirIntegerTypeSignedGet(context, 64); 1142 } else if (view.itemsize == 1) { 1143 // i8 1144 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 1145 : mlirIntegerTypeSignedGet(context, 8); 1146 } else if (view.itemsize == 2) { 1147 // i16 1148 bulkLoadElementType = signless 1149 ? mlirIntegerTypeGet(context, 16) 1150 : mlirIntegerTypeSignedGet(context, 16); 1151 } 1152 } else if (isUnsignedIntegerFormat(format)) { 1153 if (view.itemsize == 4) { 1154 // unsigned i32 1155 bulkLoadElementType = signless 1156 ? mlirIntegerTypeGet(context, 32) 1157 : mlirIntegerTypeUnsignedGet(context, 32); 1158 } else if (view.itemsize == 8) { 1159 // unsigned i64 1160 bulkLoadElementType = signless 1161 ? mlirIntegerTypeGet(context, 64) 1162 : mlirIntegerTypeUnsignedGet(context, 64); 1163 } else if (view.itemsize == 1) { 1164 // i8 1165 bulkLoadElementType = signless 1166 ? mlirIntegerTypeGet(context, 8) 1167 : mlirIntegerTypeUnsignedGet(context, 8); 1168 } else if (view.itemsize == 2) { 1169 // i16 1170 bulkLoadElementType = signless 1171 ? mlirIntegerTypeGet(context, 16) 1172 : mlirIntegerTypeUnsignedGet(context, 16); 1173 } 1174 } 1175 if (!bulkLoadElementType) { 1176 throw std::invalid_argument( 1177 std::string("unimplemented array format conversion from format: ") + 1178 std::string(format)); 1179 } 1180 } 1181 1182 MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); 1183 return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); 1184 } 1185 1186 // There is a complication for boolean numpy arrays, as numpy represents 1187 // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 1188 // booleans per byte. 1189 static MlirAttribute getBitpackedAttributeFromBooleanBuffer( 1190 Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape, 1191 MlirContext &context) { 1192 if (llvm::endianness::native != llvm::endianness::little) { 1193 // Given we have no good way of testing the behavior on big-endian 1194 // systems we will throw 1195 throw nb::type_error("Constructing a bit-packed MLIR attribute is " 1196 "unsupported on big-endian systems"); 1197 } 1198 nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray( 1199 /*data=*/static_cast<uint8_t *>(view.buf), 1200 /*shape=*/{static_cast<size_t>(view.len)}); 1201 1202 nb::module_ numpy = nb::module_::import_("numpy"); 1203 nb::object packbitsFunc = numpy.attr("packbits"); 1204 nb::object packedBooleans = 1205 packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little"); 1206 nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request(); 1207 1208 MlirType bitpackedType = 1209 getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); 1210 assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8"); 1211 // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of 1212 // packedBooleans, hence the MlirAttribute will remain valid even when 1213 // packedBooleans get reclaimed by the end of the function. 1214 return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, 1215 pythonBuffer.ptr); 1216 } 1217 1218 // This does the opposite transformation of 1219 // `getBitpackedAttributeFromBooleanBuffer` 1220 std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() { 1221 if (llvm::endianness::native != llvm::endianness::little) { 1222 // Given we have no good way of testing the behavior on big-endian 1223 // systems we will throw 1224 throw nb::type_error("Constructing a numpy array from a MLIR attribute " 1225 "is unsupported on big-endian systems"); 1226 } 1227 1228 int64_t numBooleans = mlirElementsAttrGetNumElements(*this); 1229 int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); 1230 uint8_t *bitpackedData = static_cast<uint8_t *>( 1231 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 1232 nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray( 1233 /*data=*/bitpackedData, 1234 /*shape=*/{static_cast<size_t>(numBitpackedBytes)}); 1235 1236 nb::module_ numpy = nb::module_::import_("numpy"); 1237 nb::object unpackbitsFunc = numpy.attr("unpackbits"); 1238 nb::object equalFunc = numpy.attr("equal"); 1239 nb::object reshapeFunc = numpy.attr("reshape"); 1240 nb::object unpackedBooleans = 1241 unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little"); 1242 1243 // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. 1244 // We need to: 1245 // 1. Slice away the padded bits 1246 // 2. Make the boolean array have the correct shape 1247 // 3. Convert the array to a boolean array 1248 unpackedBooleans = unpackedBooleans[nb::slice( 1249 nb::int_(0), nb::int_(numBooleans), nb::int_(1))]; 1250 unpackedBooleans = equalFunc(unpackedBooleans, 1); 1251 1252 MlirType shapedType = mlirAttributeGetType(*this); 1253 intptr_t rank = mlirShapedTypeGetRank(shapedType); 1254 std::vector<intptr_t> shape(rank); 1255 for (intptr_t i = 0; i < rank; ++i) { 1256 shape[i] = mlirShapedTypeGetDimSize(shapedType, i); 1257 } 1258 unpackedBooleans = reshapeFunc(unpackedBooleans, shape); 1259 1260 // Make sure the returned nb::buffer_view claims ownership of the data in 1261 // `pythonBuffer` so it remains valid when Python reads it 1262 nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans); 1263 return std::make_unique<nb_buffer_info>(pythonBuffer.request()); 1264 } 1265 1266 template <typename Type> 1267 std::unique_ptr<nb_buffer_info> 1268 bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { 1269 intptr_t rank = mlirShapedTypeGetRank(shapedType); 1270 // Prepare the data for the buffer_info. 1271 // Buffer is configured for read-only access below. 1272 Type *data = static_cast<Type *>( 1273 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 1274 // Prepare the shape for the buffer_info. 1275 SmallVector<intptr_t, 4> shape; 1276 for (intptr_t i = 0; i < rank; ++i) 1277 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 1278 // Prepare the strides for the buffer_info. 1279 SmallVector<intptr_t, 4> strides; 1280 if (mlirDenseElementsAttrIsSplat(*this)) { 1281 // Splats are special, only the single value is stored. 1282 strides.assign(rank, 0); 1283 } else { 1284 for (intptr_t i = 1; i < rank; ++i) { 1285 intptr_t strideFactor = 1; 1286 for (intptr_t j = i; j < rank; ++j) 1287 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 1288 strides.push_back(sizeof(Type) * strideFactor); 1289 } 1290 strides.push_back(sizeof(Type)); 1291 } 1292 const char *format; 1293 if (explicitFormat) { 1294 format = explicitFormat; 1295 } else { 1296 format = nb_format_descriptor<Type>::format(); 1297 } 1298 return std::make_unique<nb_buffer_info>( 1299 data, sizeof(Type), format, rank, std::move(shape), std::move(strides), 1300 /*readonly=*/true); 1301 } 1302 }; // namespace 1303 1304 PyType_Slot PyDenseElementsAttribute::slots[] = { 1305 // Python 3.8 doesn't allow setting the buffer protocol slots from a type spec. 1306 #if PY_VERSION_HEX >= 0x03090000 1307 {Py_bf_getbuffer, 1308 reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)}, 1309 {Py_bf_releasebuffer, 1310 reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)}, 1311 #endif 1312 {0, nullptr}, 1313 }; 1314 1315 /*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, 1316 Py_buffer *view, 1317 int flags) { 1318 view->obj = nullptr; 1319 std::unique_ptr<nb_buffer_info> info; 1320 try { 1321 auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj)); 1322 info = attr->accessBuffer(); 1323 } catch (nb::python_error &e) { 1324 e.restore(); 1325 nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer"); 1326 return -1; 1327 } 1328 view->obj = obj; 1329 view->ndim = 1; 1330 view->buf = info->ptr; 1331 view->itemsize = info->itemsize; 1332 view->len = info->itemsize; 1333 for (auto s : info->shape) { 1334 view->len *= s; 1335 } 1336 view->readonly = info->readonly; 1337 if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { 1338 view->format = const_cast<char *>(info->format); 1339 } 1340 if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { 1341 view->ndim = static_cast<int>(info->ndim); 1342 view->strides = info->strides.data(); 1343 view->shape = info->shape.data(); 1344 } 1345 view->suboffsets = nullptr; 1346 view->internal = info.release(); 1347 Py_INCREF(obj); 1348 return 0; 1349 } 1350 1351 /*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, 1352 Py_buffer *view) { 1353 delete reinterpret_cast<nb_buffer_info *>(view->internal); 1354 } 1355 1356 /// Refinement of the PyDenseElementsAttribute for attributes containing 1357 /// integer (and boolean) values. Supports element access. 1358 class PyDenseIntElementsAttribute 1359 : public PyConcreteAttribute<PyDenseIntElementsAttribute, 1360 PyDenseElementsAttribute> { 1361 public: 1362 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 1363 static constexpr const char *pyClassName = "DenseIntElementsAttr"; 1364 using PyConcreteAttribute::PyConcreteAttribute; 1365 1366 /// Returns the element at the given linear position. Asserts if the index 1367 /// is out of range. 1368 nb::object dunderGetItem(intptr_t pos) { 1369 if (pos < 0 || pos >= dunderLen()) { 1370 throw nb::index_error("attempt to access out of bounds element"); 1371 } 1372 1373 MlirType type = mlirAttributeGetType(*this); 1374 type = mlirShapedTypeGetElementType(type); 1375 // Index type can also appear as a DenseIntElementsAttr and therefore can be 1376 // casted to integer. 1377 assert(mlirTypeIsAInteger(type) || 1378 mlirTypeIsAIndex(type) && "expected integer/index element type in " 1379 "dense int elements attribute"); 1380 // Dispatch element extraction to an appropriate C function based on the 1381 // elemental type of the attribute. nb::int_ is implicitly constructible 1382 // from any C++ integral type and handles bitwidth correctly. 1383 // TODO: consider caching the type properties in the constructor to avoid 1384 // querying them on each element access. 1385 if (mlirTypeIsAIndex(type)) { 1386 return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos)); 1387 } 1388 unsigned width = mlirIntegerTypeGetWidth(type); 1389 bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 1390 if (isUnsigned) { 1391 if (width == 1) { 1392 return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); 1393 } 1394 if (width == 8) { 1395 return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos)); 1396 } 1397 if (width == 16) { 1398 return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos)); 1399 } 1400 if (width == 32) { 1401 return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos)); 1402 } 1403 if (width == 64) { 1404 return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos)); 1405 } 1406 } else { 1407 if (width == 1) { 1408 return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); 1409 } 1410 if (width == 8) { 1411 return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos)); 1412 } 1413 if (width == 16) { 1414 return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos)); 1415 } 1416 if (width == 32) { 1417 return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos)); 1418 } 1419 if (width == 64) { 1420 return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos)); 1421 } 1422 } 1423 throw nb::type_error("Unsupported integer type"); 1424 } 1425 1426 static void bindDerived(ClassTy &c) { 1427 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 1428 } 1429 }; 1430 1431 class PyDenseResourceElementsAttribute 1432 : public PyConcreteAttribute<PyDenseResourceElementsAttribute> { 1433 public: 1434 static constexpr IsAFunctionTy isaFunction = 1435 mlirAttributeIsADenseResourceElements; 1436 static constexpr const char *pyClassName = "DenseResourceElementsAttr"; 1437 using PyConcreteAttribute::PyConcreteAttribute; 1438 1439 static PyDenseResourceElementsAttribute 1440 getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type, 1441 std::optional<size_t> alignment, bool isMutable, 1442 DefaultingPyMlirContext contextWrapper) { 1443 if (!mlirTypeIsAShaped(type)) { 1444 throw std::invalid_argument( 1445 "Constructing a DenseResourceElementsAttr requires a ShapedType."); 1446 } 1447 1448 // Do not request any conversions as we must ensure to use caller 1449 // managed memory. 1450 int flags = PyBUF_STRIDES; 1451 std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>(); 1452 if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { 1453 throw nb::python_error(); 1454 } 1455 1456 // This scope releaser will only release if we haven't yet transferred 1457 // ownership. 1458 auto freeBuffer = llvm::make_scope_exit([&]() { 1459 if (view) 1460 PyBuffer_Release(view.get()); 1461 }); 1462 1463 if (!PyBuffer_IsContiguous(view.get(), 'A')) { 1464 throw std::invalid_argument("Contiguous buffer is required."); 1465 } 1466 1467 // Infer alignment to be the stride of one element if not explicit. 1468 size_t inferredAlignment; 1469 if (alignment) 1470 inferredAlignment = *alignment; 1471 else 1472 inferredAlignment = view->strides[view->ndim - 1]; 1473 1474 // The userData is a Py_buffer* that the deleter owns. 1475 auto deleter = [](void *userData, const void *data, size_t size, 1476 size_t align) { 1477 if (!Py_IsInitialized()) 1478 Py_Initialize(); 1479 Py_buffer *ownedView = static_cast<Py_buffer *>(userData); 1480 nb::gil_scoped_acquire gil; 1481 PyBuffer_Release(ownedView); 1482 delete ownedView; 1483 }; 1484 1485 size_t rawBufferSize = view->len; 1486 MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( 1487 type, toMlirStringRef(name), view->buf, rawBufferSize, 1488 inferredAlignment, isMutable, deleter, static_cast<void *>(view.get())); 1489 if (mlirAttributeIsNull(attr)) { 1490 throw std::invalid_argument( 1491 "DenseResourceElementsAttr could not be constructed from the given " 1492 "buffer. " 1493 "This may mean that the Python buffer layout does not match that " 1494 "MLIR expected layout and is a bug."); 1495 } 1496 view.release(); 1497 return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); 1498 } 1499 1500 static void bindDerived(ClassTy &c) { 1501 c.def_static( 1502 "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer, 1503 nb::arg("array"), nb::arg("name"), nb::arg("type"), 1504 nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false, 1505 nb::arg("context").none() = nb::none(), 1506 kDenseResourceElementsAttrGetFromBufferDocstring); 1507 } 1508 }; 1509 1510 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 1511 public: 1512 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 1513 static constexpr const char *pyClassName = "DictAttr"; 1514 using PyConcreteAttribute::PyConcreteAttribute; 1515 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 1516 mlirDictionaryAttrGetTypeID; 1517 1518 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 1519 1520 bool dunderContains(const std::string &name) { 1521 return !mlirAttributeIsNull( 1522 mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 1523 } 1524 1525 static void bindDerived(ClassTy &c) { 1526 c.def("__contains__", &PyDictAttribute::dunderContains); 1527 c.def("__len__", &PyDictAttribute::dunderLen); 1528 c.def_static( 1529 "get", 1530 [](nb::dict attributes, DefaultingPyMlirContext context) { 1531 SmallVector<MlirNamedAttribute> mlirNamedAttributes; 1532 mlirNamedAttributes.reserve(attributes.size()); 1533 for (std::pair<nb::handle, nb::handle> it : attributes) { 1534 auto &mlirAttr = nb::cast<PyAttribute &>(it.second); 1535 auto name = nb::cast<std::string>(it.first); 1536 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1537 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 1538 toMlirStringRef(name)), 1539 mlirAttr)); 1540 } 1541 MlirAttribute attr = 1542 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 1543 mlirNamedAttributes.data()); 1544 return PyDictAttribute(context->getRef(), attr); 1545 }, 1546 nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(), 1547 "Gets an uniqued dict attribute"); 1548 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 1549 MlirAttribute attr = 1550 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 1551 if (mlirAttributeIsNull(attr)) 1552 throw nb::key_error("attempt to access a non-existent attribute"); 1553 return attr; 1554 }); 1555 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 1556 if (index < 0 || index >= self.dunderLen()) { 1557 throw nb::index_error("attempt to access out of bounds attribute"); 1558 } 1559 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 1560 return PyNamedAttribute( 1561 namedAttr.attribute, 1562 std::string(mlirIdentifierStr(namedAttr.name).data)); 1563 }); 1564 } 1565 }; 1566 1567 /// Refinement of PyDenseElementsAttribute for attributes containing 1568 /// floating-point values. Supports element access. 1569 class PyDenseFPElementsAttribute 1570 : public PyConcreteAttribute<PyDenseFPElementsAttribute, 1571 PyDenseElementsAttribute> { 1572 public: 1573 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 1574 static constexpr const char *pyClassName = "DenseFPElementsAttr"; 1575 using PyConcreteAttribute::PyConcreteAttribute; 1576 1577 nb::float_ dunderGetItem(intptr_t pos) { 1578 if (pos < 0 || pos >= dunderLen()) { 1579 throw nb::index_error("attempt to access out of bounds element"); 1580 } 1581 1582 MlirType type = mlirAttributeGetType(*this); 1583 type = mlirShapedTypeGetElementType(type); 1584 // Dispatch element extraction to an appropriate C function based on the 1585 // elemental type of the attribute. nb::float_ is implicitly constructible 1586 // from float and double. 1587 // TODO: consider caching the type properties in the constructor to avoid 1588 // querying them on each element access. 1589 if (mlirTypeIsAF32(type)) { 1590 return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos)); 1591 } 1592 if (mlirTypeIsAF64(type)) { 1593 return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos)); 1594 } 1595 throw nb::type_error("Unsupported floating-point type"); 1596 } 1597 1598 static void bindDerived(ClassTy &c) { 1599 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 1600 } 1601 }; 1602 1603 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 1604 public: 1605 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 1606 static constexpr const char *pyClassName = "TypeAttr"; 1607 using PyConcreteAttribute::PyConcreteAttribute; 1608 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 1609 mlirTypeAttrGetTypeID; 1610 1611 static void bindDerived(ClassTy &c) { 1612 c.def_static( 1613 "get", 1614 [](PyType value, DefaultingPyMlirContext context) { 1615 MlirAttribute attr = mlirTypeAttrGet(value.get()); 1616 return PyTypeAttribute(context->getRef(), attr); 1617 }, 1618 nb::arg("value"), nb::arg("context").none() = nb::none(), 1619 "Gets a uniqued Type attribute"); 1620 c.def_prop_ro("value", [](PyTypeAttribute &self) { 1621 return mlirTypeAttrGetValue(self.get()); 1622 }); 1623 } 1624 }; 1625 1626 /// Unit Attribute subclass. Unit attributes don't have values. 1627 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 1628 public: 1629 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 1630 static constexpr const char *pyClassName = "UnitAttr"; 1631 using PyConcreteAttribute::PyConcreteAttribute; 1632 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 1633 mlirUnitAttrGetTypeID; 1634 1635 static void bindDerived(ClassTy &c) { 1636 c.def_static( 1637 "get", 1638 [](DefaultingPyMlirContext context) { 1639 return PyUnitAttribute(context->getRef(), 1640 mlirUnitAttrGet(context->get())); 1641 }, 1642 nb::arg("context").none() = nb::none(), "Create a Unit attribute."); 1643 } 1644 }; 1645 1646 /// Strided layout attribute subclass. 1647 class PyStridedLayoutAttribute 1648 : public PyConcreteAttribute<PyStridedLayoutAttribute> { 1649 public: 1650 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; 1651 static constexpr const char *pyClassName = "StridedLayoutAttr"; 1652 using PyConcreteAttribute::PyConcreteAttribute; 1653 static constexpr GetTypeIDFunctionTy getTypeIdFunction = 1654 mlirStridedLayoutAttrGetTypeID; 1655 1656 static void bindDerived(ClassTy &c) { 1657 c.def_static( 1658 "get", 1659 [](int64_t offset, const std::vector<int64_t> strides, 1660 DefaultingPyMlirContext ctx) { 1661 MlirAttribute attr = mlirStridedLayoutAttrGet( 1662 ctx->get(), offset, strides.size(), strides.data()); 1663 return PyStridedLayoutAttribute(ctx->getRef(), attr); 1664 }, 1665 nb::arg("offset"), nb::arg("strides"), 1666 nb::arg("context").none() = nb::none(), 1667 "Gets a strided layout attribute."); 1668 c.def_static( 1669 "get_fully_dynamic", 1670 [](int64_t rank, DefaultingPyMlirContext ctx) { 1671 auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); 1672 std::vector<int64_t> strides(rank); 1673 std::fill(strides.begin(), strides.end(), dynamic); 1674 MlirAttribute attr = mlirStridedLayoutAttrGet( 1675 ctx->get(), dynamic, strides.size(), strides.data()); 1676 return PyStridedLayoutAttribute(ctx->getRef(), attr); 1677 }, 1678 nb::arg("rank"), nb::arg("context").none() = nb::none(), 1679 "Gets a strided layout attribute with dynamic offset and strides of " 1680 "a " 1681 "given rank."); 1682 c.def_prop_ro( 1683 "offset", 1684 [](PyStridedLayoutAttribute &self) { 1685 return mlirStridedLayoutAttrGetOffset(self); 1686 }, 1687 "Returns the value of the float point attribute"); 1688 c.def_prop_ro( 1689 "strides", 1690 [](PyStridedLayoutAttribute &self) { 1691 intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); 1692 std::vector<int64_t> strides(size); 1693 for (intptr_t i = 0; i < size; i++) { 1694 strides[i] = mlirStridedLayoutAttrGetStride(self, i); 1695 } 1696 return strides; 1697 }, 1698 "Returns the value of the float point attribute"); 1699 } 1700 }; 1701 1702 nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { 1703 if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) 1704 return nb::cast(PyDenseBoolArrayAttribute(pyAttribute)); 1705 if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) 1706 return nb::cast(PyDenseI8ArrayAttribute(pyAttribute)); 1707 if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) 1708 return nb::cast(PyDenseI16ArrayAttribute(pyAttribute)); 1709 if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) 1710 return nb::cast(PyDenseI32ArrayAttribute(pyAttribute)); 1711 if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) 1712 return nb::cast(PyDenseI64ArrayAttribute(pyAttribute)); 1713 if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) 1714 return nb::cast(PyDenseF32ArrayAttribute(pyAttribute)); 1715 if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) 1716 return nb::cast(PyDenseF64ArrayAttribute(pyAttribute)); 1717 std::string msg = 1718 std::string("Can't cast unknown element type DenseArrayAttr (") + 1719 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")"; 1720 throw nb::type_error(msg.c_str()); 1721 } 1722 1723 nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { 1724 if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) 1725 return nb::cast(PyDenseFPElementsAttribute(pyAttribute)); 1726 if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) 1727 return nb::cast(PyDenseIntElementsAttribute(pyAttribute)); 1728 std::string msg = 1729 std::string( 1730 "Can't cast unknown element type DenseIntOrFPElementsAttr (") + 1731 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")"; 1732 throw nb::type_error(msg.c_str()); 1733 } 1734 1735 nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { 1736 if (PyBoolAttribute::isaFunction(pyAttribute)) 1737 return nb::cast(PyBoolAttribute(pyAttribute)); 1738 if (PyIntegerAttribute::isaFunction(pyAttribute)) 1739 return nb::cast(PyIntegerAttribute(pyAttribute)); 1740 std::string msg = 1741 std::string("Can't cast unknown element type DenseArrayAttr (") + 1742 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")"; 1743 throw nb::type_error(msg.c_str()); 1744 } 1745 1746 nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { 1747 if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) 1748 return nb::cast(PyFlatSymbolRefAttribute(pyAttribute)); 1749 if (PySymbolRefAttribute::isaFunction(pyAttribute)) 1750 return nb::cast(PySymbolRefAttribute(pyAttribute)); 1751 std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + 1752 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + 1753 ")"; 1754 throw nb::type_error(msg.c_str()); 1755 } 1756 1757 } // namespace 1758 1759 void mlir::python::populateIRAttributes(nb::module_ &m) { 1760 PyAffineMapAttribute::bind(m); 1761 PyDenseBoolArrayAttribute::bind(m); 1762 PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); 1763 PyDenseI8ArrayAttribute::bind(m); 1764 PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); 1765 PyDenseI16ArrayAttribute::bind(m); 1766 PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); 1767 PyDenseI32ArrayAttribute::bind(m); 1768 PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); 1769 PyDenseI64ArrayAttribute::bind(m); 1770 PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); 1771 PyDenseF32ArrayAttribute::bind(m); 1772 PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); 1773 PyDenseF64ArrayAttribute::bind(m); 1774 PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); 1775 PyGlobals::get().registerTypeCaster( 1776 mlirDenseArrayAttrGetTypeID(), 1777 nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster))); 1778 1779 PyArrayAttribute::bind(m); 1780 PyArrayAttribute::PyArrayAttributeIterator::bind(m); 1781 PyBoolAttribute::bind(m); 1782 PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots); 1783 PyDenseFPElementsAttribute::bind(m); 1784 PyDenseIntElementsAttribute::bind(m); 1785 PyGlobals::get().registerTypeCaster( 1786 mlirDenseIntOrFPElementsAttrGetTypeID(), 1787 nb::cast<nb::callable>( 1788 nb::cpp_function(denseIntOrFPElementsAttributeCaster))); 1789 PyDenseResourceElementsAttribute::bind(m); 1790 1791 PyDictAttribute::bind(m); 1792 PySymbolRefAttribute::bind(m); 1793 PyGlobals::get().registerTypeCaster( 1794 mlirSymbolRefAttrGetTypeID(), 1795 nb::cast<nb::callable>( 1796 nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster))); 1797 1798 PyFlatSymbolRefAttribute::bind(m); 1799 PyOpaqueAttribute::bind(m); 1800 PyFloatAttribute::bind(m); 1801 PyIntegerAttribute::bind(m); 1802 PyIntegerSetAttribute::bind(m); 1803 PyStringAttribute::bind(m); 1804 PyTypeAttribute::bind(m); 1805 PyGlobals::get().registerTypeCaster( 1806 mlirIntegerAttrGetTypeID(), 1807 nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster))); 1808 PyUnitAttribute::bind(m); 1809 1810 PyStridedLayoutAttribute::bind(m); 1811 } 1812