1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9a1fe1f5fSKazu Hirata #include <optional> 1071a25454SPeter Hawkins #include <string_view> 114811270bSmax #include <utility> 121fc096afSMehdi Amini 13436c6c9cSStella Laurenzo #include "IRModule.h" 14436c6c9cSStella Laurenzo 15436c6c9cSStella Laurenzo #include "PybindUtils.h" 161824e45cSKasper Nielsen #include <pybind11/numpy.h> 17436c6c9cSStella Laurenzo 1871a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h" 19c912f0e7Spranavm-nvidia #include "llvm/Support/raw_ostream.h" 2071a25454SPeter Hawkins 21436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 22436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 23bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h" 24436c6c9cSStella Laurenzo 25436c6c9cSStella Laurenzo namespace py = pybind11; 26436c6c9cSStella Laurenzo using namespace mlir; 27436c6c9cSStella Laurenzo using namespace mlir::python; 28436c6c9cSStella Laurenzo 29436c6c9cSStella Laurenzo using llvm::SmallVector; 30436c6c9cSStella Laurenzo 315d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 325d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 335d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 345d6d30edSStella Laurenzo 355d6d30edSStella Laurenzo static const char kDenseElementsAttrGetDocstring[] = 365d6d30edSStella Laurenzo R"(Gets a DenseElementsAttr from a Python buffer or array. 375d6d30edSStella Laurenzo 385d6d30edSStella Laurenzo When `type` is not provided, then some limited type inferencing is done based 395d6d30edSStella Laurenzo on the buffer format. Support presently exists for 8/16/32/64 signed and 405d6d30edSStella Laurenzo unsigned integers and float16/float32/float64. DenseElementsAttrs of these 415d6d30edSStella Laurenzo types can also be converted back to a corresponding buffer. 425d6d30edSStella Laurenzo 435d6d30edSStella Laurenzo For conversions outside of these types, a `type=` must be explicitly provided 445d6d30edSStella Laurenzo and the buffer contents must be bit-castable to the MLIR internal 455d6d30edSStella Laurenzo representation: 465d6d30edSStella Laurenzo 475d6d30edSStella Laurenzo * Integer types (except for i1): the buffer must be byte aligned to the 485d6d30edSStella Laurenzo next byte boundary. 495d6d30edSStella Laurenzo * Floating point types: Must be bit-castable to the given floating point 505d6d30edSStella Laurenzo size. 515d6d30edSStella Laurenzo * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 525d6d30edSStella Laurenzo row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 535d6d30edSStella Laurenzo this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 545d6d30edSStella Laurenzo 555d6d30edSStella Laurenzo If a single element buffer is passed (or for i1, a single byte with value 0 565d6d30edSStella Laurenzo or 255), then a splat will be created. 575d6d30edSStella Laurenzo 585d6d30edSStella Laurenzo Args: 595d6d30edSStella Laurenzo array: The array or buffer to convert. 605d6d30edSStella Laurenzo signless: If inferring an appropriate MLIR type, use signless types for 615d6d30edSStella Laurenzo integers (defaults True). 625d6d30edSStella Laurenzo type: Skips inference of the MLIR element type and uses this instead. The 635d6d30edSStella Laurenzo storage size must be consistent with the actual contents of the buffer. 645d6d30edSStella Laurenzo shape: Overrides the shape of the buffer when constructing the MLIR 655d6d30edSStella Laurenzo shaped type. This is needed when the physical and logical shape differ (as 665d6d30edSStella Laurenzo for i1). 675d6d30edSStella Laurenzo context: Explicit context, if not from context manager. 685d6d30edSStella Laurenzo 695d6d30edSStella Laurenzo Returns: 705d6d30edSStella Laurenzo DenseElementsAttr on success. 715d6d30edSStella Laurenzo 725d6d30edSStella Laurenzo Raises: 735d6d30edSStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 745d6d30edSStella Laurenzo type or if the buffer does not meet expectations. 755d6d30edSStella Laurenzo )"; 765d6d30edSStella Laurenzo 77c912f0e7Spranavm-nvidia static const char kDenseElementsAttrGetFromListDocstring[] = 78c912f0e7Spranavm-nvidia R"(Gets a DenseElementsAttr from a Python list of attributes. 79c912f0e7Spranavm-nvidia 80c912f0e7Spranavm-nvidia Note that it can be expensive to construct attributes individually. 81c912f0e7Spranavm-nvidia For a large number of elements, consider using a Python buffer or array instead. 82c912f0e7Spranavm-nvidia 83c912f0e7Spranavm-nvidia Args: 84c912f0e7Spranavm-nvidia attrs: A list of attributes. 85c912f0e7Spranavm-nvidia type: The desired shape and type of the resulting DenseElementsAttr. 86c912f0e7Spranavm-nvidia If not provided, the element type is determined based on the type 87c912f0e7Spranavm-nvidia of the 0th attribute and the shape is `[len(attrs)]`. 88c912f0e7Spranavm-nvidia context: Explicit context, if not from context manager. 89c912f0e7Spranavm-nvidia 90c912f0e7Spranavm-nvidia Returns: 91c912f0e7Spranavm-nvidia DenseElementsAttr on success. 92c912f0e7Spranavm-nvidia 93c912f0e7Spranavm-nvidia Raises: 94c912f0e7Spranavm-nvidia ValueError: If the type of the attributes does not match the type 95c912f0e7Spranavm-nvidia specified by `shaped_type`. 96c912f0e7Spranavm-nvidia )"; 97c912f0e7Spranavm-nvidia 98f66cd9e9SStella Laurenzo static const char kDenseResourceElementsAttrGetFromBufferDocstring[] = 99f66cd9e9SStella Laurenzo R"(Gets a DenseResourceElementsAttr from a Python buffer or array. 100f66cd9e9SStella Laurenzo 101f66cd9e9SStella Laurenzo This function does minimal validation or massaging of the data, and it is 102f66cd9e9SStella Laurenzo up to the caller to ensure that the buffer meets the characteristics 103f66cd9e9SStella Laurenzo implied by the shape. 104f66cd9e9SStella Laurenzo 105f66cd9e9SStella Laurenzo The backing buffer and any user objects will be retained for the lifetime 106f66cd9e9SStella Laurenzo of the resource blob. This is typically bounded to the context but the 107f66cd9e9SStella Laurenzo resource can have a shorter lifespan depending on how it is used in 108f66cd9e9SStella Laurenzo subsequent processing. 109f66cd9e9SStella Laurenzo 110f66cd9e9SStella Laurenzo Args: 111f66cd9e9SStella Laurenzo buffer: The array or buffer to convert. 112f66cd9e9SStella Laurenzo name: Name to provide to the resource (may be changed upon collision). 113f66cd9e9SStella Laurenzo type: The explicit ShapedType to construct the attribute with. 114f66cd9e9SStella Laurenzo context: Explicit context, if not from context manager. 115f66cd9e9SStella Laurenzo 116f66cd9e9SStella Laurenzo Returns: 117f66cd9e9SStella Laurenzo DenseResourceElementsAttr on success. 118f66cd9e9SStella Laurenzo 119f66cd9e9SStella Laurenzo Raises: 120f66cd9e9SStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 121f66cd9e9SStella Laurenzo type or if the buffer does not meet expectations. 122f66cd9e9SStella Laurenzo )"; 123f66cd9e9SStella Laurenzo 124436c6c9cSStella Laurenzo namespace { 125436c6c9cSStella Laurenzo 126436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 127436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 128436c6c9cSStella Laurenzo } 129436c6c9cSStella Laurenzo 130436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 131436c6c9cSStella Laurenzo public: 132436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 133436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 134436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1359566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 1369566ee28Smax mlirAffineMapAttrGetTypeID; 137436c6c9cSStella Laurenzo 138436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 139436c6c9cSStella Laurenzo c.def_static( 140436c6c9cSStella Laurenzo "get", 141436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 142436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 143436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 144436c6c9cSStella Laurenzo }, 145436c6c9cSStella Laurenzo py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 146c36b4248SBimo c.def_property_readonly("value", mlirAffineMapAttrGetValue, 147c36b4248SBimo "Returns the value of the AffineMap attribute"); 148436c6c9cSStella Laurenzo } 149436c6c9cSStella Laurenzo }; 150436c6c9cSStella Laurenzo 151334873feSAmy Wang class PyIntegerSetAttribute 152334873feSAmy Wang : public PyConcreteAttribute<PyIntegerSetAttribute> { 153334873feSAmy Wang public: 154334873feSAmy Wang static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet; 155334873feSAmy Wang static constexpr const char *pyClassName = "IntegerSetAttr"; 156334873feSAmy Wang using PyConcreteAttribute::PyConcreteAttribute; 157334873feSAmy Wang static constexpr GetTypeIDFunctionTy getTypeIdFunction = 158334873feSAmy Wang mlirIntegerSetAttrGetTypeID; 159334873feSAmy Wang 160334873feSAmy Wang static void bindDerived(ClassTy &c) { 161334873feSAmy Wang c.def_static( 162334873feSAmy Wang "get", 163334873feSAmy Wang [](PyIntegerSet &integerSet) { 164334873feSAmy Wang MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); 165334873feSAmy Wang return PyIntegerSetAttribute(integerSet.getContext(), attr); 166334873feSAmy Wang }, 167334873feSAmy Wang py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); 168334873feSAmy Wang } 169334873feSAmy Wang }; 170334873feSAmy Wang 171ed9e52f3SAlex Zinenko template <typename T> 172ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) { 173ed9e52f3SAlex Zinenko try { 174ed9e52f3SAlex Zinenko return object.cast<T>(); 175ed9e52f3SAlex Zinenko } catch (py::cast_error &err) { 176ed9e52f3SAlex Zinenko std::string msg = 177ed9e52f3SAlex Zinenko std::string( 178ed9e52f3SAlex Zinenko "Invalid attribute when attempting to create an ArrayAttribute (") + 179ed9e52f3SAlex Zinenko err.what() + ")"; 180ed9e52f3SAlex Zinenko throw py::cast_error(msg); 181ed9e52f3SAlex Zinenko } catch (py::reference_cast_error &err) { 182ed9e52f3SAlex Zinenko std::string msg = std::string("Invalid attribute (None?) when attempting " 183ed9e52f3SAlex Zinenko "to create an ArrayAttribute (") + 184ed9e52f3SAlex Zinenko err.what() + ")"; 185ed9e52f3SAlex Zinenko throw py::cast_error(msg); 186ed9e52f3SAlex Zinenko } 187ed9e52f3SAlex Zinenko } 188ed9e52f3SAlex Zinenko 189619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived 190619fd8c2SJeff Niu /// implementation class. 191619fd8c2SJeff Niu template <typename EltTy, typename DerivedT> 192133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> { 193619fd8c2SJeff Niu public: 194133624acSJeff Niu using PyConcreteAttribute<DerivedT>::PyConcreteAttribute; 195619fd8c2SJeff Niu 196619fd8c2SJeff Niu /// Iterator over the integer elements of a dense array. 197619fd8c2SJeff Niu class PyDenseArrayIterator { 198619fd8c2SJeff Niu public: 1994a1b1196SMehdi Amini PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} 200619fd8c2SJeff Niu 201619fd8c2SJeff Niu /// Return a copy of the iterator. 202619fd8c2SJeff Niu PyDenseArrayIterator dunderIter() { return *this; } 203619fd8c2SJeff Niu 204619fd8c2SJeff Niu /// Return the next element. 205619fd8c2SJeff Niu EltTy dunderNext() { 206619fd8c2SJeff Niu // Throw if the index has reached the end. 207619fd8c2SJeff Niu if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) 208619fd8c2SJeff Niu throw py::stop_iteration(); 209619fd8c2SJeff Niu return DerivedT::getElement(attr.get(), nextIndex++); 210619fd8c2SJeff Niu } 211619fd8c2SJeff Niu 212619fd8c2SJeff Niu /// Bind the iterator class. 213619fd8c2SJeff Niu static void bind(py::module &m) { 214619fd8c2SJeff Niu py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName, 215619fd8c2SJeff Niu py::module_local()) 216619fd8c2SJeff Niu .def("__iter__", &PyDenseArrayIterator::dunderIter) 217619fd8c2SJeff Niu .def("__next__", &PyDenseArrayIterator::dunderNext); 218619fd8c2SJeff Niu } 219619fd8c2SJeff Niu 220619fd8c2SJeff Niu private: 221619fd8c2SJeff Niu /// The referenced dense array attribute. 222619fd8c2SJeff Niu PyAttribute attr; 223619fd8c2SJeff Niu /// The next index to read. 224619fd8c2SJeff Niu int nextIndex = 0; 225619fd8c2SJeff Niu }; 226619fd8c2SJeff Niu 227619fd8c2SJeff Niu /// Get the element at the given index. 228619fd8c2SJeff Niu EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } 229619fd8c2SJeff Niu 230619fd8c2SJeff Niu /// Bind the attribute class. 231133624acSJeff Niu static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) { 232619fd8c2SJeff Niu // Bind the constructor. 233619fd8c2SJeff Niu c.def_static( 234619fd8c2SJeff Niu "get", 235619fd8c2SJeff Niu [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { 2368dcb6722SIngo Müller return getAttribute(values, ctx->getRef()); 237619fd8c2SJeff Niu }, 238619fd8c2SJeff Niu py::arg("values"), py::arg("context") = py::none(), 239619fd8c2SJeff Niu "Gets a uniqued dense array attribute"); 240619fd8c2SJeff Niu // Bind the array methods. 241133624acSJeff Niu c.def("__getitem__", [](DerivedT &arr, intptr_t i) { 242619fd8c2SJeff Niu if (i >= mlirDenseArrayGetNumElements(arr)) 243619fd8c2SJeff Niu throw py::index_error("DenseArray index out of range"); 244619fd8c2SJeff Niu return arr.getItem(i); 245619fd8c2SJeff Niu }); 246133624acSJeff Niu c.def("__len__", [](const DerivedT &arr) { 247619fd8c2SJeff Niu return mlirDenseArrayGetNumElements(arr); 248619fd8c2SJeff Niu }); 249133624acSJeff Niu c.def("__iter__", 250133624acSJeff Niu [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); 2514a1b1196SMehdi Amini c.def("__add__", [](DerivedT &arr, const py::list &extras) { 252619fd8c2SJeff Niu std::vector<EltTy> values; 253619fd8c2SJeff Niu intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); 254619fd8c2SJeff Niu values.reserve(numOldElements + py::len(extras)); 255619fd8c2SJeff Niu for (intptr_t i = 0; i < numOldElements; ++i) 256619fd8c2SJeff Niu values.push_back(arr.getItem(i)); 257619fd8c2SJeff Niu for (py::handle attr : extras) 258619fd8c2SJeff Niu values.push_back(pyTryCast<EltTy>(attr)); 2598dcb6722SIngo Müller return getAttribute(values, arr.getContext()); 260619fd8c2SJeff Niu }); 261619fd8c2SJeff Niu } 2628dcb6722SIngo Müller 2638dcb6722SIngo Müller private: 2648dcb6722SIngo Müller static DerivedT getAttribute(const std::vector<EltTy> &values, 2658dcb6722SIngo Müller PyMlirContextRef ctx) { 2668dcb6722SIngo Müller if constexpr (std::is_same_v<EltTy, bool>) { 2678dcb6722SIngo Müller std::vector<int> intValues(values.begin(), values.end()); 2688dcb6722SIngo Müller MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(), 2698dcb6722SIngo Müller intValues.data()); 2708dcb6722SIngo Müller return DerivedT(ctx, attr); 2718dcb6722SIngo Müller } else { 2728dcb6722SIngo Müller MlirAttribute attr = 2738dcb6722SIngo Müller DerivedT::getAttribute(ctx->get(), values.size(), values.data()); 2748dcb6722SIngo Müller return DerivedT(ctx, attr); 2758dcb6722SIngo Müller } 2768dcb6722SIngo Müller } 277619fd8c2SJeff Niu }; 278619fd8c2SJeff Niu 279619fd8c2SJeff Niu /// Instantiate the python dense array classes. 280619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute 2818dcb6722SIngo Müller : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> { 282619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; 283619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseBoolArrayGet; 284619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseBoolArrayGetElement; 285619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseBoolArrayAttr"; 286619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; 287619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 288619fd8c2SJeff Niu }; 289619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute 290619fd8c2SJeff Niu : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { 291619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; 292619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI8ArrayGet; 293619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI8ArrayGetElement; 294619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI8ArrayAttr"; 295619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; 296619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 297619fd8c2SJeff Niu }; 298619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute 299619fd8c2SJeff Niu : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { 300619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; 301619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI16ArrayGet; 302619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI16ArrayGetElement; 303619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI16ArrayAttr"; 304619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; 305619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 306619fd8c2SJeff Niu }; 307619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute 308619fd8c2SJeff Niu : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { 309619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; 310619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI32ArrayGet; 311619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI32ArrayGetElement; 312619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI32ArrayAttr"; 313619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; 314619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 315619fd8c2SJeff Niu }; 316619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute 317619fd8c2SJeff Niu : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { 318619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; 319619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI64ArrayGet; 320619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI64ArrayGetElement; 321619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI64ArrayAttr"; 322619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; 323619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 324619fd8c2SJeff Niu }; 325619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute 326619fd8c2SJeff Niu : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { 327619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; 328619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF32ArrayGet; 329619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF32ArrayGetElement; 330619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF32ArrayAttr"; 331619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; 332619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 333619fd8c2SJeff Niu }; 334619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute 335619fd8c2SJeff Niu : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { 336619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; 337619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF64ArrayGet; 338619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF64ArrayGetElement; 339619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF64ArrayAttr"; 340619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; 341619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 342619fd8c2SJeff Niu }; 343619fd8c2SJeff Niu 344436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 345436c6c9cSStella Laurenzo public: 346436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 347436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 348436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 3499566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 3509566ee28Smax mlirArrayAttrGetTypeID; 351436c6c9cSStella Laurenzo 352436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 353436c6c9cSStella Laurenzo public: 3541fc096afSMehdi Amini PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 355436c6c9cSStella Laurenzo 356436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 357436c6c9cSStella Laurenzo 358974c1596SRahul Kayaith MlirAttribute dunderNext() { 359bca88952SJeff Niu // TODO: Throw is an inefficient way to stop iteration. 360bca88952SJeff Niu if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) 361436c6c9cSStella Laurenzo throw py::stop_iteration(); 362974c1596SRahul Kayaith return mlirArrayAttrGetElement(attr.get(), nextIndex++); 363436c6c9cSStella Laurenzo } 364436c6c9cSStella Laurenzo 365436c6c9cSStella Laurenzo static void bind(py::module &m) { 366f05ff4f7SStella Laurenzo py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 367f05ff4f7SStella Laurenzo py::module_local()) 368436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 369436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 370436c6c9cSStella Laurenzo } 371436c6c9cSStella Laurenzo 372436c6c9cSStella Laurenzo private: 373436c6c9cSStella Laurenzo PyAttribute attr; 374436c6c9cSStella Laurenzo int nextIndex = 0; 375436c6c9cSStella Laurenzo }; 376436c6c9cSStella Laurenzo 377974c1596SRahul Kayaith MlirAttribute getItem(intptr_t i) { 378974c1596SRahul Kayaith return mlirArrayAttrGetElement(*this, i); 379ed9e52f3SAlex Zinenko } 380ed9e52f3SAlex Zinenko 381436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 382436c6c9cSStella Laurenzo c.def_static( 383436c6c9cSStella Laurenzo "get", 384436c6c9cSStella Laurenzo [](py::list attributes, DefaultingPyMlirContext context) { 385436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 386436c6c9cSStella Laurenzo mlirAttributes.reserve(py::len(attributes)); 387436c6c9cSStella Laurenzo for (auto attribute : attributes) { 388ed9e52f3SAlex Zinenko mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 389436c6c9cSStella Laurenzo } 390436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 391436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 392436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 393436c6c9cSStella Laurenzo }, 394436c6c9cSStella Laurenzo py::arg("attributes"), py::arg("context") = py::none(), 395436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 396436c6c9cSStella Laurenzo c.def("__getitem__", 397436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 398436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 399436c6c9cSStella Laurenzo throw py::index_error("ArrayAttribute index out of range"); 400ed9e52f3SAlex Zinenko return arr.getItem(i); 401436c6c9cSStella Laurenzo }) 402436c6c9cSStella Laurenzo .def("__len__", 403436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 404436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 405436c6c9cSStella Laurenzo }) 406436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 407436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 408436c6c9cSStella Laurenzo }); 409ed9e52f3SAlex Zinenko c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 410ed9e52f3SAlex Zinenko std::vector<MlirAttribute> attributes; 411ed9e52f3SAlex Zinenko intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 412ed9e52f3SAlex Zinenko attributes.reserve(numOldElements + py::len(extras)); 413ed9e52f3SAlex Zinenko for (intptr_t i = 0; i < numOldElements; ++i) 414ed9e52f3SAlex Zinenko attributes.push_back(arr.getItem(i)); 415ed9e52f3SAlex Zinenko for (py::handle attr : extras) 416ed9e52f3SAlex Zinenko attributes.push_back(pyTryCast<PyAttribute>(attr)); 417ed9e52f3SAlex Zinenko MlirAttribute arrayAttr = mlirArrayAttrGet( 418ed9e52f3SAlex Zinenko arr.getContext()->get(), attributes.size(), attributes.data()); 419ed9e52f3SAlex Zinenko return PyArrayAttribute(arr.getContext(), arrayAttr); 420ed9e52f3SAlex Zinenko }); 421436c6c9cSStella Laurenzo } 422436c6c9cSStella Laurenzo }; 423436c6c9cSStella Laurenzo 424436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 425436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 426436c6c9cSStella Laurenzo public: 427436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 428436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 429436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 4309566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 4319566ee28Smax mlirFloatAttrGetTypeID; 432436c6c9cSStella Laurenzo 433436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 434436c6c9cSStella Laurenzo c.def_static( 435436c6c9cSStella Laurenzo "get", 436436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 4373ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(loc->getContext()); 438436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 4393ea4c501SRahul Kayaith if (mlirAttributeIsNull(attr)) 4403ea4c501SRahul Kayaith throw MLIRError("Invalid attribute", errors.take()); 441436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 442436c6c9cSStella Laurenzo }, 443436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 444436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 445436c6c9cSStella Laurenzo c.def_static( 446436c6c9cSStella Laurenzo "get_f32", 447436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 448436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 449436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 450436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 451436c6c9cSStella Laurenzo }, 452436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 453436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 454436c6c9cSStella Laurenzo c.def_static( 455436c6c9cSStella Laurenzo "get_f64", 456436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 457436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 458436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 459436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 460436c6c9cSStella Laurenzo }, 461436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 462436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 4632a5d4974SIngo Müller c.def_property_readonly("value", mlirFloatAttrGetValueDouble, 4642a5d4974SIngo Müller "Returns the value of the float attribute"); 4652a5d4974SIngo Müller c.def("__float__", mlirFloatAttrGetValueDouble, 4662a5d4974SIngo Müller "Converts the value of the float attribute to a Python float"); 467436c6c9cSStella Laurenzo } 468436c6c9cSStella Laurenzo }; 469436c6c9cSStella Laurenzo 470436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 471436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 472436c6c9cSStella Laurenzo public: 473436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 474436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 475436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 476436c6c9cSStella Laurenzo 477436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 478436c6c9cSStella Laurenzo c.def_static( 479436c6c9cSStella Laurenzo "get", 480436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 481436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 482436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 483436c6c9cSStella Laurenzo }, 484436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), 485436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 4862a5d4974SIngo Müller c.def_property_readonly("value", toPyInt, 4872a5d4974SIngo Müller "Returns the value of the integer attribute"); 4882a5d4974SIngo Müller c.def("__int__", toPyInt, 4892a5d4974SIngo Müller "Converts the value of the integer attribute to a Python int"); 4902a5d4974SIngo Müller c.def_property_readonly_static("static_typeid", 4912a5d4974SIngo Müller [](py::object & /*class*/) -> MlirTypeID { 4922a5d4974SIngo Müller return mlirIntegerAttrGetTypeID(); 4932a5d4974SIngo Müller }); 4942a5d4974SIngo Müller } 4952a5d4974SIngo Müller 4962a5d4974SIngo Müller private: 4972a5d4974SIngo Müller static py::int_ toPyInt(PyIntegerAttribute &self) { 498e9db306dSrkayaith MlirType type = mlirAttributeGetType(self); 499e9db306dSrkayaith if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 500436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 501e9db306dSrkayaith if (mlirIntegerTypeIsSigned(type)) 502e9db306dSrkayaith return mlirIntegerAttrGetValueSInt(self); 503e9db306dSrkayaith return mlirIntegerAttrGetValueUInt(self); 504436c6c9cSStella Laurenzo } 505436c6c9cSStella Laurenzo }; 506436c6c9cSStella Laurenzo 507436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 508436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 509436c6c9cSStella Laurenzo public: 510436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 511436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 512436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 513436c6c9cSStella Laurenzo 514436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 515436c6c9cSStella Laurenzo c.def_static( 516436c6c9cSStella Laurenzo "get", 517436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 518436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 519436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 520436c6c9cSStella Laurenzo }, 521436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 522436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 5232a5d4974SIngo Müller c.def_property_readonly("value", mlirBoolAttrGetValue, 524436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 5252a5d4974SIngo Müller c.def("__bool__", mlirBoolAttrGetValue, 5262a5d4974SIngo Müller "Converts the value of the bool attribute to a Python bool"); 527436c6c9cSStella Laurenzo } 528436c6c9cSStella Laurenzo }; 529436c6c9cSStella Laurenzo 5304eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> { 5314eee9ef9Smax public: 5324eee9ef9Smax static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; 5334eee9ef9Smax static constexpr const char *pyClassName = "SymbolRefAttr"; 5344eee9ef9Smax using PyConcreteAttribute::PyConcreteAttribute; 5354eee9ef9Smax 5364eee9ef9Smax static MlirAttribute fromList(const std::vector<std::string> &symbols, 5374eee9ef9Smax PyMlirContext &context) { 5384eee9ef9Smax if (symbols.empty()) 5394eee9ef9Smax throw std::runtime_error("SymbolRefAttr must be composed of at least " 5404eee9ef9Smax "one symbol."); 5414eee9ef9Smax MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); 5424eee9ef9Smax SmallVector<MlirAttribute, 3> referenceAttrs; 5434eee9ef9Smax for (size_t i = 1; i < symbols.size(); ++i) { 5444eee9ef9Smax referenceAttrs.push_back( 5454eee9ef9Smax mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); 5464eee9ef9Smax } 5474eee9ef9Smax return mlirSymbolRefAttrGet(context.get(), rootSymbol, 5484eee9ef9Smax referenceAttrs.size(), referenceAttrs.data()); 5494eee9ef9Smax } 5504eee9ef9Smax 5514eee9ef9Smax static void bindDerived(ClassTy &c) { 5524eee9ef9Smax c.def_static( 5534eee9ef9Smax "get", 5544eee9ef9Smax [](const std::vector<std::string> &symbols, 5554eee9ef9Smax DefaultingPyMlirContext context) { 5564eee9ef9Smax return PySymbolRefAttribute::fromList(symbols, context.resolve()); 5574eee9ef9Smax }, 5584eee9ef9Smax py::arg("symbols"), py::arg("context") = py::none(), 5594eee9ef9Smax "Gets a uniqued SymbolRef attribute from a list of symbol names"); 5604eee9ef9Smax c.def_property_readonly( 5614eee9ef9Smax "value", 5624eee9ef9Smax [](PySymbolRefAttribute &self) { 5634eee9ef9Smax std::vector<std::string> symbols = { 5644eee9ef9Smax unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; 5654eee9ef9Smax for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); 5664eee9ef9Smax ++i) 5674eee9ef9Smax symbols.push_back( 5684eee9ef9Smax unwrap(mlirSymbolRefAttrGetRootReference( 5694eee9ef9Smax mlirSymbolRefAttrGetNestedReference(self, i))) 5704eee9ef9Smax .str()); 5714eee9ef9Smax return symbols; 5724eee9ef9Smax }, 5734eee9ef9Smax "Returns the value of the SymbolRef attribute as a list[str]"); 5744eee9ef9Smax } 5754eee9ef9Smax }; 5764eee9ef9Smax 577436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 578436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 579436c6c9cSStella Laurenzo public: 580436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 581436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 582436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 583436c6c9cSStella Laurenzo 584436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 585436c6c9cSStella Laurenzo c.def_static( 586436c6c9cSStella Laurenzo "get", 587436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 588436c6c9cSStella Laurenzo MlirAttribute attr = 589436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 590436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 591436c6c9cSStella Laurenzo }, 592436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 593436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 594436c6c9cSStella Laurenzo c.def_property_readonly( 595436c6c9cSStella Laurenzo "value", 596436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 597436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 598436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 599436c6c9cSStella Laurenzo }, 600436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 601436c6c9cSStella Laurenzo } 602436c6c9cSStella Laurenzo }; 603436c6c9cSStella Laurenzo 6045c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { 6055c3861b2SYun Long public: 6065c3861b2SYun Long static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; 6075c3861b2SYun Long static constexpr const char *pyClassName = "OpaqueAttr"; 6085c3861b2SYun Long using PyConcreteAttribute::PyConcreteAttribute; 6099566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 6109566ee28Smax mlirOpaqueAttrGetTypeID; 6115c3861b2SYun Long 6125c3861b2SYun Long static void bindDerived(ClassTy &c) { 6135c3861b2SYun Long c.def_static( 6145c3861b2SYun Long "get", 6155c3861b2SYun Long [](std::string dialectNamespace, py::buffer buffer, PyType &type, 6165c3861b2SYun Long DefaultingPyMlirContext context) { 6175c3861b2SYun Long const py::buffer_info bufferInfo = buffer.request(); 6185c3861b2SYun Long intptr_t bufferSize = bufferInfo.size; 6195c3861b2SYun Long MlirAttribute attr = mlirOpaqueAttrGet( 6205c3861b2SYun Long context->get(), toMlirStringRef(dialectNamespace), bufferSize, 6215c3861b2SYun Long static_cast<char *>(bufferInfo.ptr), type); 6225c3861b2SYun Long return PyOpaqueAttribute(context->getRef(), attr); 6235c3861b2SYun Long }, 6245c3861b2SYun Long py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), 6255c3861b2SYun Long py::arg("context") = py::none(), "Gets an Opaque attribute."); 6265c3861b2SYun Long c.def_property_readonly( 6275c3861b2SYun Long "dialect_namespace", 6285c3861b2SYun Long [](PyOpaqueAttribute &self) { 6295c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); 6305c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 6315c3861b2SYun Long }, 6325c3861b2SYun Long "Returns the dialect namespace for the Opaque attribute as a string"); 6335c3861b2SYun Long c.def_property_readonly( 6345c3861b2SYun Long "data", 6355c3861b2SYun Long [](PyOpaqueAttribute &self) { 6365c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetData(self); 63762bf6c2eSChris Jones return py::bytes(stringRef.data, stringRef.length); 6385c3861b2SYun Long }, 63962bf6c2eSChris Jones "Returns the data for the Opaqued attributes as `bytes`"); 6405c3861b2SYun Long } 6415c3861b2SYun Long }; 6425c3861b2SYun Long 643436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 644436c6c9cSStella Laurenzo public: 645436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 646436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 647436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 6489566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 6499566ee28Smax mlirStringAttrGetTypeID; 650436c6c9cSStella Laurenzo 651436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 652436c6c9cSStella Laurenzo c.def_static( 653436c6c9cSStella Laurenzo "get", 654436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 655436c6c9cSStella Laurenzo MlirAttribute attr = 656436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 657436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 658436c6c9cSStella Laurenzo }, 659436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 660436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 661436c6c9cSStella Laurenzo c.def_static( 662436c6c9cSStella Laurenzo "get_typed", 663436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 664436c6c9cSStella Laurenzo MlirAttribute attr = 665436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 666436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 667436c6c9cSStella Laurenzo }, 668a6e7d024SStella Laurenzo py::arg("type"), py::arg("value"), 669436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 6709f533548SIngo Müller c.def_property_readonly( 6719f533548SIngo Müller "value", 6729f533548SIngo Müller [](PyStringAttribute &self) { 6739f533548SIngo Müller MlirStringRef stringRef = mlirStringAttrGetValue(self); 6749f533548SIngo Müller return py::str(stringRef.data, stringRef.length); 6759f533548SIngo Müller }, 676436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 67762bf6c2eSChris Jones c.def_property_readonly( 67862bf6c2eSChris Jones "value_bytes", 67962bf6c2eSChris Jones [](PyStringAttribute &self) { 68062bf6c2eSChris Jones MlirStringRef stringRef = mlirStringAttrGetValue(self); 68162bf6c2eSChris Jones return py::bytes(stringRef.data, stringRef.length); 68262bf6c2eSChris Jones }, 68362bf6c2eSChris Jones "Returns the value of the string attribute as `bytes`"); 684436c6c9cSStella Laurenzo } 685436c6c9cSStella Laurenzo }; 686436c6c9cSStella Laurenzo 687436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 688436c6c9cSStella Laurenzo class PyDenseElementsAttribute 689436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 690436c6c9cSStella Laurenzo public: 691436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 692436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 693436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 694436c6c9cSStella Laurenzo 695436c6c9cSStella Laurenzo static PyDenseElementsAttribute 696c912f0e7Spranavm-nvidia getFromList(py::list attributes, std::optional<PyType> explicitType, 697c912f0e7Spranavm-nvidia DefaultingPyMlirContext contextWrapper) { 698c912f0e7Spranavm-nvidia 699c912f0e7Spranavm-nvidia const size_t numAttributes = py::len(attributes); 700c912f0e7Spranavm-nvidia if (numAttributes == 0) 701c912f0e7Spranavm-nvidia throw py::value_error("Attributes list must be non-empty."); 702c912f0e7Spranavm-nvidia 703c912f0e7Spranavm-nvidia MlirType shapedType; 704c912f0e7Spranavm-nvidia if (explicitType) { 705c912f0e7Spranavm-nvidia if ((!mlirTypeIsAShaped(*explicitType) || 706c912f0e7Spranavm-nvidia !mlirShapedTypeHasStaticShape(*explicitType))) { 707c912f0e7Spranavm-nvidia 708c912f0e7Spranavm-nvidia std::string message; 709c912f0e7Spranavm-nvidia llvm::raw_string_ostream os(message); 710c912f0e7Spranavm-nvidia os << "Expected a static ShapedType for the shaped_type parameter: " 711c912f0e7Spranavm-nvidia << py::repr(py::cast(*explicitType)); 712095b41c6SJOE1994 throw py::value_error(message); 713c912f0e7Spranavm-nvidia } 714c912f0e7Spranavm-nvidia shapedType = *explicitType; 715c912f0e7Spranavm-nvidia } else { 716c912f0e7Spranavm-nvidia SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)}; 717c912f0e7Spranavm-nvidia shapedType = mlirRankedTensorTypeGet( 718c912f0e7Spranavm-nvidia shape.size(), shape.data(), 719c912f0e7Spranavm-nvidia mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])), 720c912f0e7Spranavm-nvidia mlirAttributeGetNull()); 721c912f0e7Spranavm-nvidia } 722c912f0e7Spranavm-nvidia 723c912f0e7Spranavm-nvidia SmallVector<MlirAttribute> mlirAttributes; 724c912f0e7Spranavm-nvidia mlirAttributes.reserve(numAttributes); 725c912f0e7Spranavm-nvidia for (const py::handle &attribute : attributes) { 726c912f0e7Spranavm-nvidia MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute); 727c912f0e7Spranavm-nvidia MlirType attrType = mlirAttributeGetType(mlirAttribute); 728c912f0e7Spranavm-nvidia mlirAttributes.push_back(mlirAttribute); 729c912f0e7Spranavm-nvidia 730c912f0e7Spranavm-nvidia if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { 731c912f0e7Spranavm-nvidia std::string message; 732c912f0e7Spranavm-nvidia llvm::raw_string_ostream os(message); 733c912f0e7Spranavm-nvidia os << "All attributes must be of the same type and match " 734c912f0e7Spranavm-nvidia << "the type parameter: expected=" << py::repr(py::cast(shapedType)) 735c912f0e7Spranavm-nvidia << ", but got=" << py::repr(py::cast(attrType)); 736095b41c6SJOE1994 throw py::value_error(message); 737c912f0e7Spranavm-nvidia } 738c912f0e7Spranavm-nvidia } 739c912f0e7Spranavm-nvidia 740c912f0e7Spranavm-nvidia MlirAttribute elements = mlirDenseElementsAttrGet( 741c912f0e7Spranavm-nvidia shapedType, mlirAttributes.size(), mlirAttributes.data()); 742c912f0e7Spranavm-nvidia 743c912f0e7Spranavm-nvidia return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 744c912f0e7Spranavm-nvidia } 745c912f0e7Spranavm-nvidia 746c912f0e7Spranavm-nvidia static PyDenseElementsAttribute 7470a81ace0SKazu Hirata getFromBuffer(py::buffer array, bool signless, 7480a81ace0SKazu Hirata std::optional<PyType> explicitType, 7490a81ace0SKazu Hirata std::optional<std::vector<int64_t>> explicitShape, 750436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 751436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 75271a25454SPeter Hawkins int flags = PyBUF_ND; 75371a25454SPeter Hawkins if (!explicitType) { 75471a25454SPeter Hawkins flags |= PyBUF_FORMAT; 75571a25454SPeter Hawkins } 75671a25454SPeter Hawkins Py_buffer view; 75771a25454SPeter Hawkins if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { 758436c6c9cSStella Laurenzo throw py::error_already_set(); 759436c6c9cSStella Laurenzo } 76071a25454SPeter Hawkins auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); 761436c6c9cSStella Laurenzo 762436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 7631824e45cSKasper Nielsen MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, 7641824e45cSKasper Nielsen explicitShape, context); 7655d6d30edSStella Laurenzo if (mlirAttributeIsNull(attr)) { 7665d6d30edSStella Laurenzo throw std::invalid_argument( 7675d6d30edSStella Laurenzo "DenseElementsAttr could not be constructed from the given buffer. " 7685d6d30edSStella Laurenzo "This may mean that the Python buffer layout does not match that " 7695d6d30edSStella Laurenzo "MLIR expected layout and is a bug."); 7705d6d30edSStella Laurenzo } 7715d6d30edSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 7725d6d30edSStella Laurenzo } 773436c6c9cSStella Laurenzo 7741fc096afSMehdi Amini static PyDenseElementsAttribute getSplat(const PyType &shapedType, 775436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 776436c6c9cSStella Laurenzo auto contextWrapper = 777436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 778436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 779436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 780436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 781436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 7824811270bSmax throw py::value_error(message); 783436c6c9cSStella Laurenzo } 784436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 785436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 786436c6c9cSStella Laurenzo std::string message = 787436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 788436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 7894811270bSmax throw py::value_error(message); 790436c6c9cSStella Laurenzo } 791436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 792436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 793436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 794436c6c9cSStella Laurenzo std::string message = 795436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 796436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 797436c6c9cSStella Laurenzo message.append(", element="); 798436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 7994811270bSmax throw py::value_error(message); 800436c6c9cSStella Laurenzo } 801436c6c9cSStella Laurenzo 802436c6c9cSStella Laurenzo MlirAttribute elements = 803436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 804436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 805436c6c9cSStella Laurenzo } 806436c6c9cSStella Laurenzo 807436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 808436c6c9cSStella Laurenzo 809436c6c9cSStella Laurenzo py::buffer_info accessBuffer() { 810436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 811436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 8125d6d30edSStella Laurenzo std::string format; 813436c6c9cSStella Laurenzo 814436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 815436c6c9cSStella Laurenzo // f32 8165d6d30edSStella Laurenzo return bufferInfo<float>(shapedType); 81702b6fb21SMehdi Amini } 81802b6fb21SMehdi Amini if (mlirTypeIsAF64(elementType)) { 819436c6c9cSStella Laurenzo // f64 8205d6d30edSStella Laurenzo return bufferInfo<double>(shapedType); 821bb56c2b3SMehdi Amini } 822bb56c2b3SMehdi Amini if (mlirTypeIsAF16(elementType)) { 8235d6d30edSStella Laurenzo // f16 8245d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType, "e"); 825bb56c2b3SMehdi Amini } 826ef1b735dSmax if (mlirTypeIsAIndex(elementType)) { 827ef1b735dSmax // Same as IndexType::kInternalStorageBitWidth 828ef1b735dSmax return bufferInfo<int64_t>(shapedType); 829ef1b735dSmax } 830bb56c2b3SMehdi Amini if (mlirTypeIsAInteger(elementType) && 831436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 832436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 833436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 834436c6c9cSStella Laurenzo // i32 8355d6d30edSStella Laurenzo return bufferInfo<int32_t>(shapedType); 836e5639b3fSMehdi Amini } 837e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 838436c6c9cSStella Laurenzo // unsigned i32 8395d6d30edSStella Laurenzo return bufferInfo<uint32_t>(shapedType); 840436c6c9cSStella Laurenzo } 841436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 842436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 843436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 844436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 845436c6c9cSStella Laurenzo // i64 8465d6d30edSStella Laurenzo return bufferInfo<int64_t>(shapedType); 847e5639b3fSMehdi Amini } 848e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 849436c6c9cSStella Laurenzo // unsigned i64 8505d6d30edSStella Laurenzo return bufferInfo<uint64_t>(shapedType); 8515d6d30edSStella Laurenzo } 8525d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 8535d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 8) { 8545d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 8555d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 8565d6d30edSStella Laurenzo // i8 8575d6d30edSStella Laurenzo return bufferInfo<int8_t>(shapedType); 858e5639b3fSMehdi Amini } 859e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 8605d6d30edSStella Laurenzo // unsigned i8 8615d6d30edSStella Laurenzo return bufferInfo<uint8_t>(shapedType); 8625d6d30edSStella Laurenzo } 8635d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 8645d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 16) { 8655d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 8665d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 8675d6d30edSStella Laurenzo // i16 8685d6d30edSStella Laurenzo return bufferInfo<int16_t>(shapedType); 869e5639b3fSMehdi Amini } 870e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 8715d6d30edSStella Laurenzo // unsigned i16 8725d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType); 873436c6c9cSStella Laurenzo } 8741824e45cSKasper Nielsen } else if (mlirTypeIsAInteger(elementType) && 8751824e45cSKasper Nielsen mlirIntegerTypeGetWidth(elementType) == 1) { 8761824e45cSKasper Nielsen // i1 / bool 8771824e45cSKasper Nielsen // We can not send the buffer directly back to Python, because the i1 8781824e45cSKasper Nielsen // values are bitpacked within MLIR. We call numpy's unpackbits function 8791824e45cSKasper Nielsen // to convert the bytes. 8801824e45cSKasper Nielsen return getBooleanBufferFromBitpackedAttribute(); 881436c6c9cSStella Laurenzo } 882436c6c9cSStella Laurenzo 883c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 8845d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 885c5f445d1SStella Laurenzo throw std::invalid_argument( 886c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 887436c6c9cSStella Laurenzo } 888436c6c9cSStella Laurenzo 889436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 890436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 891436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 892436c6c9cSStella Laurenzo py::arg("array"), py::arg("signless") = true, 8935d6d30edSStella Laurenzo py::arg("type") = py::none(), py::arg("shape") = py::none(), 894436c6c9cSStella Laurenzo py::arg("context") = py::none(), 8955d6d30edSStella Laurenzo kDenseElementsAttrGetDocstring) 896c912f0e7Spranavm-nvidia .def_static("get", PyDenseElementsAttribute::getFromList, 897c912f0e7Spranavm-nvidia py::arg("attrs"), py::arg("type") = py::none(), 898c912f0e7Spranavm-nvidia py::arg("context") = py::none(), 899c912f0e7Spranavm-nvidia kDenseElementsAttrGetFromListDocstring) 900436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 901436c6c9cSStella Laurenzo py::arg("shaped_type"), py::arg("element_attr"), 902436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 903436c6c9cSStella Laurenzo .def_property_readonly("is_splat", 904436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 905436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 906436c6c9cSStella Laurenzo }) 90791259963SAdam Paszke .def("get_splat_value", 908974c1596SRahul Kayaith [](PyDenseElementsAttribute &self) { 909974c1596SRahul Kayaith if (!mlirDenseElementsAttrIsSplat(self)) 9104811270bSmax throw py::value_error( 91191259963SAdam Paszke "get_splat_value called on a non-splat attribute"); 912974c1596SRahul Kayaith return mlirDenseElementsAttrGetSplatValue(self); 91391259963SAdam Paszke }) 914436c6c9cSStella Laurenzo .def_buffer(&PyDenseElementsAttribute::accessBuffer); 915436c6c9cSStella Laurenzo } 916436c6c9cSStella Laurenzo 917436c6c9cSStella Laurenzo private: 91871a25454SPeter Hawkins static bool isUnsignedIntegerFormat(std::string_view format) { 919436c6c9cSStella Laurenzo if (format.empty()) 920436c6c9cSStella Laurenzo return false; 921436c6c9cSStella Laurenzo char code = format[0]; 922436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 923436c6c9cSStella Laurenzo code == 'Q'; 924436c6c9cSStella Laurenzo } 925436c6c9cSStella Laurenzo 92671a25454SPeter Hawkins static bool isSignedIntegerFormat(std::string_view format) { 927436c6c9cSStella Laurenzo if (format.empty()) 928436c6c9cSStella Laurenzo return false; 929436c6c9cSStella Laurenzo char code = format[0]; 930436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 931436c6c9cSStella Laurenzo code == 'q'; 932436c6c9cSStella Laurenzo } 933436c6c9cSStella Laurenzo 9341824e45cSKasper Nielsen static MlirType 9351824e45cSKasper Nielsen getShapedType(std::optional<MlirType> bulkLoadElementType, 9361824e45cSKasper Nielsen std::optional<std::vector<int64_t>> explicitShape, 9371824e45cSKasper Nielsen Py_buffer &view) { 9381824e45cSKasper Nielsen SmallVector<int64_t> shape; 9391824e45cSKasper Nielsen if (explicitShape) { 9401824e45cSKasper Nielsen shape.append(explicitShape->begin(), explicitShape->end()); 9411824e45cSKasper Nielsen } else { 9421824e45cSKasper Nielsen shape.append(view.shape, view.shape + view.ndim); 9431824e45cSKasper Nielsen } 9441824e45cSKasper Nielsen 9451824e45cSKasper Nielsen if (mlirTypeIsAShaped(*bulkLoadElementType)) { 9461824e45cSKasper Nielsen if (explicitShape) { 9471824e45cSKasper Nielsen throw std::invalid_argument("Shape can only be specified explicitly " 9481824e45cSKasper Nielsen "when the type is not a shaped type."); 9491824e45cSKasper Nielsen } 9501824e45cSKasper Nielsen return *bulkLoadElementType; 9511824e45cSKasper Nielsen } else { 9521824e45cSKasper Nielsen MlirAttribute encodingAttr = mlirAttributeGetNull(); 9531824e45cSKasper Nielsen return mlirRankedTensorTypeGet(shape.size(), shape.data(), 9541824e45cSKasper Nielsen *bulkLoadElementType, encodingAttr); 9551824e45cSKasper Nielsen } 9561824e45cSKasper Nielsen } 9571824e45cSKasper Nielsen 9581824e45cSKasper Nielsen static MlirAttribute getAttributeFromBuffer( 9591824e45cSKasper Nielsen Py_buffer &view, bool signless, std::optional<PyType> explicitType, 9601824e45cSKasper Nielsen std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) { 9611824e45cSKasper Nielsen // Detect format codes that are suitable for bulk loading. This includes 9621824e45cSKasper Nielsen // all byte aligned integer and floating point types up to 8 bytes. 9631824e45cSKasper Nielsen // Notably, this excludes exotics types which do not have a direct 9641824e45cSKasper Nielsen // representation in the buffer protocol (i.e. complex, etc). 9651824e45cSKasper Nielsen std::optional<MlirType> bulkLoadElementType; 9661824e45cSKasper Nielsen if (explicitType) { 9671824e45cSKasper Nielsen bulkLoadElementType = *explicitType; 9681824e45cSKasper Nielsen } else { 9691824e45cSKasper Nielsen std::string_view format(view.format); 9701824e45cSKasper Nielsen if (format == "f") { 9711824e45cSKasper Nielsen // f32 9721824e45cSKasper Nielsen assert(view.itemsize == 4 && "mismatched array itemsize"); 9731824e45cSKasper Nielsen bulkLoadElementType = mlirF32TypeGet(context); 9741824e45cSKasper Nielsen } else if (format == "d") { 9751824e45cSKasper Nielsen // f64 9761824e45cSKasper Nielsen assert(view.itemsize == 8 && "mismatched array itemsize"); 9771824e45cSKasper Nielsen bulkLoadElementType = mlirF64TypeGet(context); 9781824e45cSKasper Nielsen } else if (format == "e") { 9791824e45cSKasper Nielsen // f16 9801824e45cSKasper Nielsen assert(view.itemsize == 2 && "mismatched array itemsize"); 9811824e45cSKasper Nielsen bulkLoadElementType = mlirF16TypeGet(context); 9821824e45cSKasper Nielsen } else if (format == "?") { 9831824e45cSKasper Nielsen // i1 9841824e45cSKasper Nielsen // The i1 type needs to be bit-packed, so we will handle it seperately 9851824e45cSKasper Nielsen return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, 9861824e45cSKasper Nielsen context); 9871824e45cSKasper Nielsen } else if (isSignedIntegerFormat(format)) { 9881824e45cSKasper Nielsen if (view.itemsize == 4) { 9891824e45cSKasper Nielsen // i32 9901824e45cSKasper Nielsen bulkLoadElementType = signless 9911824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 32) 9921824e45cSKasper Nielsen : mlirIntegerTypeSignedGet(context, 32); 9931824e45cSKasper Nielsen } else if (view.itemsize == 8) { 9941824e45cSKasper Nielsen // i64 9951824e45cSKasper Nielsen bulkLoadElementType = signless 9961824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 64) 9971824e45cSKasper Nielsen : mlirIntegerTypeSignedGet(context, 64); 9981824e45cSKasper Nielsen } else if (view.itemsize == 1) { 9991824e45cSKasper Nielsen // i8 10001824e45cSKasper Nielsen bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 10011824e45cSKasper Nielsen : mlirIntegerTypeSignedGet(context, 8); 10021824e45cSKasper Nielsen } else if (view.itemsize == 2) { 10031824e45cSKasper Nielsen // i16 10041824e45cSKasper Nielsen bulkLoadElementType = signless 10051824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 16) 10061824e45cSKasper Nielsen : mlirIntegerTypeSignedGet(context, 16); 10071824e45cSKasper Nielsen } 10081824e45cSKasper Nielsen } else if (isUnsignedIntegerFormat(format)) { 10091824e45cSKasper Nielsen if (view.itemsize == 4) { 10101824e45cSKasper Nielsen // unsigned i32 10111824e45cSKasper Nielsen bulkLoadElementType = signless 10121824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 32) 10131824e45cSKasper Nielsen : mlirIntegerTypeUnsignedGet(context, 32); 10141824e45cSKasper Nielsen } else if (view.itemsize == 8) { 10151824e45cSKasper Nielsen // unsigned i64 10161824e45cSKasper Nielsen bulkLoadElementType = signless 10171824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 64) 10181824e45cSKasper Nielsen : mlirIntegerTypeUnsignedGet(context, 64); 10191824e45cSKasper Nielsen } else if (view.itemsize == 1) { 10201824e45cSKasper Nielsen // i8 10211824e45cSKasper Nielsen bulkLoadElementType = signless 10221824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 8) 10231824e45cSKasper Nielsen : mlirIntegerTypeUnsignedGet(context, 8); 10241824e45cSKasper Nielsen } else if (view.itemsize == 2) { 10251824e45cSKasper Nielsen // i16 10261824e45cSKasper Nielsen bulkLoadElementType = signless 10271824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 16) 10281824e45cSKasper Nielsen : mlirIntegerTypeUnsignedGet(context, 16); 10291824e45cSKasper Nielsen } 10301824e45cSKasper Nielsen } 10311824e45cSKasper Nielsen if (!bulkLoadElementType) { 10321824e45cSKasper Nielsen throw std::invalid_argument( 10331824e45cSKasper Nielsen std::string("unimplemented array format conversion from format: ") + 10341824e45cSKasper Nielsen std::string(format)); 10351824e45cSKasper Nielsen } 10361824e45cSKasper Nielsen } 10371824e45cSKasper Nielsen 10381824e45cSKasper Nielsen MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); 10391824e45cSKasper Nielsen return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); 10401824e45cSKasper Nielsen } 10411824e45cSKasper Nielsen 10421824e45cSKasper Nielsen // There is a complication for boolean numpy arrays, as numpy represents them 10431824e45cSKasper Nielsen // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans 10441824e45cSKasper Nielsen // per byte. 10451824e45cSKasper Nielsen static MlirAttribute getBitpackedAttributeFromBooleanBuffer( 10461824e45cSKasper Nielsen Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape, 10471824e45cSKasper Nielsen MlirContext &context) { 10481824e45cSKasper Nielsen if (llvm::endianness::native != llvm::endianness::little) { 10491824e45cSKasper Nielsen // Given we have no good way of testing the behavior on big-endian systems 10501824e45cSKasper Nielsen // we will throw 10511824e45cSKasper Nielsen throw py::type_error("Constructing a bit-packed MLIR attribute is " 10521824e45cSKasper Nielsen "unsupported on big-endian systems"); 10531824e45cSKasper Nielsen } 10541824e45cSKasper Nielsen 10551824e45cSKasper Nielsen py::array_t<uint8_t> unpackedArray(view.len, 10561824e45cSKasper Nielsen static_cast<uint8_t *>(view.buf)); 10571824e45cSKasper Nielsen 10581824e45cSKasper Nielsen py::module numpy = py::module::import("numpy"); 10591824e45cSKasper Nielsen py::object packbitsFunc = numpy.attr("packbits"); 10601824e45cSKasper Nielsen py::object packedBooleans = 10611824e45cSKasper Nielsen packbitsFunc(unpackedArray, "bitorder"_a = "little"); 10621824e45cSKasper Nielsen py::buffer_info pythonBuffer = packedBooleans.cast<py::buffer>().request(); 10631824e45cSKasper Nielsen 10641824e45cSKasper Nielsen MlirType bitpackedType = 10651824e45cSKasper Nielsen getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); 10661824e45cSKasper Nielsen assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8"); 10671824e45cSKasper Nielsen // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of 10681824e45cSKasper Nielsen // packedBooleans, hence the MlirAttribute will remain valid even when 10691824e45cSKasper Nielsen // packedBooleans get reclaimed by the end of the function. 10701824e45cSKasper Nielsen return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, 10711824e45cSKasper Nielsen pythonBuffer.ptr); 10721824e45cSKasper Nielsen } 10731824e45cSKasper Nielsen 10741824e45cSKasper Nielsen // This does the opposite transformation of 10751824e45cSKasper Nielsen // `getBitpackedAttributeFromBooleanBuffer` 10761824e45cSKasper Nielsen py::buffer_info getBooleanBufferFromBitpackedAttribute() { 10771824e45cSKasper Nielsen if (llvm::endianness::native != llvm::endianness::little) { 10781824e45cSKasper Nielsen // Given we have no good way of testing the behavior on big-endian systems 10791824e45cSKasper Nielsen // we will throw 10801824e45cSKasper Nielsen throw py::type_error("Constructing a numpy array from a MLIR attribute " 10811824e45cSKasper Nielsen "is unsupported on big-endian systems"); 10821824e45cSKasper Nielsen } 10831824e45cSKasper Nielsen 10841824e45cSKasper Nielsen int64_t numBooleans = mlirElementsAttrGetNumElements(*this); 10851824e45cSKasper Nielsen int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); 10861824e45cSKasper Nielsen uint8_t *bitpackedData = static_cast<uint8_t *>( 10871824e45cSKasper Nielsen const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 10881824e45cSKasper Nielsen py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData); 10891824e45cSKasper Nielsen 10901824e45cSKasper Nielsen py::module numpy = py::module::import("numpy"); 10911824e45cSKasper Nielsen py::object unpackbitsFunc = numpy.attr("unpackbits"); 10921824e45cSKasper Nielsen py::object equalFunc = numpy.attr("equal"); 10931824e45cSKasper Nielsen py::object reshapeFunc = numpy.attr("reshape"); 10941824e45cSKasper Nielsen py::array unpackedBooleans = 10951824e45cSKasper Nielsen unpackbitsFunc(packedArray, "bitorder"_a = "little"); 10961824e45cSKasper Nielsen 10971824e45cSKasper Nielsen // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. 10981824e45cSKasper Nielsen // We need to: 10991824e45cSKasper Nielsen // 1. Slice away the padded bits 11001824e45cSKasper Nielsen // 2. Make the boolean array have the correct shape 11011824e45cSKasper Nielsen // 3. Convert the array to a boolean array 11021824e45cSKasper Nielsen unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)]; 11031824e45cSKasper Nielsen unpackedBooleans = equalFunc(unpackedBooleans, 1); 11041824e45cSKasper Nielsen 11051824e45cSKasper Nielsen MlirType shapedType = mlirAttributeGetType(*this); 11061824e45cSKasper Nielsen intptr_t rank = mlirShapedTypeGetRank(shapedType); 1107*404d0e99SAdrian Kuegel std::vector<intptr_t> shape(rank); 11081824e45cSKasper Nielsen for (intptr_t i = 0; i < rank; ++i) { 1109*404d0e99SAdrian Kuegel shape[i] = mlirShapedTypeGetDimSize(shapedType, i); 11101824e45cSKasper Nielsen } 11111824e45cSKasper Nielsen unpackedBooleans = reshapeFunc(unpackedBooleans, shape); 11121824e45cSKasper Nielsen 11131824e45cSKasper Nielsen // Make sure the returned py::buffer_view claims ownership of the data in 11141824e45cSKasper Nielsen // `pythonBuffer` so it remains valid when Python reads it 11151824e45cSKasper Nielsen py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>(); 11161824e45cSKasper Nielsen return pythonBuffer.request(); 11171824e45cSKasper Nielsen } 11181824e45cSKasper Nielsen 1119436c6c9cSStella Laurenzo template <typename Type> 1120436c6c9cSStella Laurenzo py::buffer_info bufferInfo(MlirType shapedType, 11215d6d30edSStella Laurenzo const char *explicitFormat = nullptr) { 11220a68171bSDmitri Gribenko intptr_t rank = mlirShapedTypeGetRank(shapedType); 1123436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 11240a68171bSDmitri Gribenko // Buffer is configured for read-only access below. 1125436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 1126436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 1127436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 1128436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 1129436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 1130436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 1131436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 1132436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 1133f0e847d0SRahul Kayaith if (mlirDenseElementsAttrIsSplat(*this)) { 1134f0e847d0SRahul Kayaith // Splats are special, only the single value is stored. 1135f0e847d0SRahul Kayaith strides.assign(rank, 0); 1136f0e847d0SRahul Kayaith } else { 1137436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 1138f0e847d0SRahul Kayaith intptr_t strideFactor = 1; 1139f0e847d0SRahul Kayaith for (intptr_t j = i; j < rank; ++j) 1140436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 1141436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 1142436c6c9cSStella Laurenzo } 1143436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 1144f0e847d0SRahul Kayaith } 11455d6d30edSStella Laurenzo std::string format; 11465d6d30edSStella Laurenzo if (explicitFormat) { 11475d6d30edSStella Laurenzo format = explicitFormat; 11485d6d30edSStella Laurenzo } else { 11495d6d30edSStella Laurenzo format = py::format_descriptor<Type>::format(); 11505d6d30edSStella Laurenzo } 11515d6d30edSStella Laurenzo return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 11525d6d30edSStella Laurenzo /*readonly=*/true); 1153436c6c9cSStella Laurenzo } 1154436c6c9cSStella Laurenzo }; // namespace 1155436c6c9cSStella Laurenzo 1156436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer 1157436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access. 1158436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 1159436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 1160436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 1161436c6c9cSStella Laurenzo public: 1162436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 1163436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 1164436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1165436c6c9cSStella Laurenzo 1166436c6c9cSStella Laurenzo /// Returns the element at the given linear position. Asserts if the index is 1167436c6c9cSStella Laurenzo /// out of range. 1168436c6c9cSStella Laurenzo py::int_ dunderGetItem(intptr_t pos) { 1169436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 11704811270bSmax throw py::index_error("attempt to access out of bounds element"); 1171436c6c9cSStella Laurenzo } 1172436c6c9cSStella Laurenzo 1173436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 1174436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 1175436c6c9cSStella Laurenzo assert(mlirTypeIsAInteger(type) && 1176436c6c9cSStella Laurenzo "expected integer element type in dense int elements attribute"); 1177436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 1178436c6c9cSStella Laurenzo // elemental type of the attribute. py::int_ is implicitly constructible 1179436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 1180436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 1181436c6c9cSStella Laurenzo // querying them on each element access. 1182436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 1183436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 1184436c6c9cSStella Laurenzo if (isUnsigned) { 1185436c6c9cSStella Laurenzo if (width == 1) { 1186436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 1187436c6c9cSStella Laurenzo } 1188308d8b8cSRahul Kayaith if (width == 8) { 1189308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt8Value(*this, pos); 1190308d8b8cSRahul Kayaith } 1191308d8b8cSRahul Kayaith if (width == 16) { 1192308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt16Value(*this, pos); 1193308d8b8cSRahul Kayaith } 1194436c6c9cSStella Laurenzo if (width == 32) { 1195436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt32Value(*this, pos); 1196436c6c9cSStella Laurenzo } 1197436c6c9cSStella Laurenzo if (width == 64) { 1198436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt64Value(*this, pos); 1199436c6c9cSStella Laurenzo } 1200436c6c9cSStella Laurenzo } else { 1201436c6c9cSStella Laurenzo if (width == 1) { 1202436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 1203436c6c9cSStella Laurenzo } 1204308d8b8cSRahul Kayaith if (width == 8) { 1205308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt8Value(*this, pos); 1206308d8b8cSRahul Kayaith } 1207308d8b8cSRahul Kayaith if (width == 16) { 1208308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt16Value(*this, pos); 1209308d8b8cSRahul Kayaith } 1210436c6c9cSStella Laurenzo if (width == 32) { 1211436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt32Value(*this, pos); 1212436c6c9cSStella Laurenzo } 1213436c6c9cSStella Laurenzo if (width == 64) { 1214436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt64Value(*this, pos); 1215436c6c9cSStella Laurenzo } 1216436c6c9cSStella Laurenzo } 12174811270bSmax throw py::type_error("Unsupported integer type"); 1218436c6c9cSStella Laurenzo } 1219436c6c9cSStella Laurenzo 1220436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1221436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 1222436c6c9cSStella Laurenzo } 1223436c6c9cSStella Laurenzo }; 1224436c6c9cSStella Laurenzo 1225f66cd9e9SStella Laurenzo class PyDenseResourceElementsAttribute 1226f66cd9e9SStella Laurenzo : public PyConcreteAttribute<PyDenseResourceElementsAttribute> { 1227f66cd9e9SStella Laurenzo public: 1228f66cd9e9SStella Laurenzo static constexpr IsAFunctionTy isaFunction = 1229f66cd9e9SStella Laurenzo mlirAttributeIsADenseResourceElements; 1230f66cd9e9SStella Laurenzo static constexpr const char *pyClassName = "DenseResourceElementsAttr"; 1231f66cd9e9SStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1232f66cd9e9SStella Laurenzo 1233f66cd9e9SStella Laurenzo static PyDenseResourceElementsAttribute 1234962bf002SMehdi Amini getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type, 1235f66cd9e9SStella Laurenzo std::optional<size_t> alignment, bool isMutable, 1236f66cd9e9SStella Laurenzo DefaultingPyMlirContext contextWrapper) { 1237f66cd9e9SStella Laurenzo if (!mlirTypeIsAShaped(type)) { 1238f66cd9e9SStella Laurenzo throw std::invalid_argument( 1239f66cd9e9SStella Laurenzo "Constructing a DenseResourceElementsAttr requires a ShapedType."); 1240f66cd9e9SStella Laurenzo } 1241f66cd9e9SStella Laurenzo 1242f66cd9e9SStella Laurenzo // Do not request any conversions as we must ensure to use caller 1243f66cd9e9SStella Laurenzo // managed memory. 1244f66cd9e9SStella Laurenzo int flags = PyBUF_STRIDES; 1245f66cd9e9SStella Laurenzo std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>(); 1246f66cd9e9SStella Laurenzo if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { 1247f66cd9e9SStella Laurenzo throw py::error_already_set(); 1248f66cd9e9SStella Laurenzo } 1249f66cd9e9SStella Laurenzo 1250f66cd9e9SStella Laurenzo // This scope releaser will only release if we haven't yet transferred 1251f66cd9e9SStella Laurenzo // ownership. 1252f66cd9e9SStella Laurenzo auto freeBuffer = llvm::make_scope_exit([&]() { 1253f66cd9e9SStella Laurenzo if (view) 1254f66cd9e9SStella Laurenzo PyBuffer_Release(view.get()); 1255f66cd9e9SStella Laurenzo }); 1256f66cd9e9SStella Laurenzo 1257f66cd9e9SStella Laurenzo if (!PyBuffer_IsContiguous(view.get(), 'A')) { 1258f66cd9e9SStella Laurenzo throw std::invalid_argument("Contiguous buffer is required."); 1259f66cd9e9SStella Laurenzo } 1260f66cd9e9SStella Laurenzo 1261f66cd9e9SStella Laurenzo // Infer alignment to be the stride of one element if not explicit. 1262f66cd9e9SStella Laurenzo size_t inferredAlignment; 1263f66cd9e9SStella Laurenzo if (alignment) 1264f66cd9e9SStella Laurenzo inferredAlignment = *alignment; 1265f66cd9e9SStella Laurenzo else 1266f66cd9e9SStella Laurenzo inferredAlignment = view->strides[view->ndim - 1]; 1267f66cd9e9SStella Laurenzo 1268f66cd9e9SStella Laurenzo // The userData is a Py_buffer* that the deleter owns. 1269f66cd9e9SStella Laurenzo auto deleter = [](void *userData, const void *data, size_t size, 1270f66cd9e9SStella Laurenzo size_t align) { 1271f66cd9e9SStella Laurenzo Py_buffer *ownedView = static_cast<Py_buffer *>(userData); 1272f66cd9e9SStella Laurenzo PyBuffer_Release(ownedView); 1273f66cd9e9SStella Laurenzo delete ownedView; 1274f66cd9e9SStella Laurenzo }; 1275f66cd9e9SStella Laurenzo 1276f66cd9e9SStella Laurenzo size_t rawBufferSize = view->len; 1277f66cd9e9SStella Laurenzo MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( 1278f66cd9e9SStella Laurenzo type, toMlirStringRef(name), view->buf, rawBufferSize, 1279f66cd9e9SStella Laurenzo inferredAlignment, isMutable, deleter, static_cast<void *>(view.get())); 1280f66cd9e9SStella Laurenzo if (mlirAttributeIsNull(attr)) { 1281f66cd9e9SStella Laurenzo throw std::invalid_argument( 1282f66cd9e9SStella Laurenzo "DenseResourceElementsAttr could not be constructed from the given " 1283f66cd9e9SStella Laurenzo "buffer. " 1284f66cd9e9SStella Laurenzo "This may mean that the Python buffer layout does not match that " 1285f66cd9e9SStella Laurenzo "MLIR expected layout and is a bug."); 1286f66cd9e9SStella Laurenzo } 1287f66cd9e9SStella Laurenzo view.release(); 1288f66cd9e9SStella Laurenzo return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); 1289f66cd9e9SStella Laurenzo } 1290f66cd9e9SStella Laurenzo 1291f66cd9e9SStella Laurenzo static void bindDerived(ClassTy &c) { 1292f66cd9e9SStella Laurenzo c.def_static("get_from_buffer", 1293f66cd9e9SStella Laurenzo PyDenseResourceElementsAttribute::getFromBuffer, 1294f66cd9e9SStella Laurenzo py::arg("array"), py::arg("name"), py::arg("type"), 1295f66cd9e9SStella Laurenzo py::arg("alignment") = py::none(), 1296f66cd9e9SStella Laurenzo py::arg("is_mutable") = false, py::arg("context") = py::none(), 1297f66cd9e9SStella Laurenzo kDenseResourceElementsAttrGetFromBufferDocstring); 1298f66cd9e9SStella Laurenzo } 1299f66cd9e9SStella Laurenzo }; 1300f66cd9e9SStella Laurenzo 1301436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 1302436c6c9cSStella Laurenzo public: 1303436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 1304436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 1305436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 13069566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 13079566ee28Smax mlirDictionaryAttrGetTypeID; 1308436c6c9cSStella Laurenzo 1309436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 1310436c6c9cSStella Laurenzo 13119fb1086bSAdrian Kuegel bool dunderContains(const std::string &name) { 13129fb1086bSAdrian Kuegel return !mlirAttributeIsNull( 13139fb1086bSAdrian Kuegel mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 13149fb1086bSAdrian Kuegel } 13159fb1086bSAdrian Kuegel 1316436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 13179fb1086bSAdrian Kuegel c.def("__contains__", &PyDictAttribute::dunderContains); 1318436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 1319436c6c9cSStella Laurenzo c.def_static( 1320436c6c9cSStella Laurenzo "get", 1321436c6c9cSStella Laurenzo [](py::dict attributes, DefaultingPyMlirContext context) { 1322436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 1323436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 1324436c6c9cSStella Laurenzo for (auto &it : attributes) { 132502b6fb21SMehdi Amini auto &mlirAttr = it.second.cast<PyAttribute &>(); 1326436c6c9cSStella Laurenzo auto name = it.first.cast<std::string>(); 1327436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 132802b6fb21SMehdi Amini mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 1329436c6c9cSStella Laurenzo toMlirStringRef(name)), 133002b6fb21SMehdi Amini mlirAttr)); 1331436c6c9cSStella Laurenzo } 1332436c6c9cSStella Laurenzo MlirAttribute attr = 1333436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 1334436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 1335436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 1336436c6c9cSStella Laurenzo }, 1337ed9e52f3SAlex Zinenko py::arg("value") = py::dict(), py::arg("context") = py::none(), 1338436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 1339436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 1340436c6c9cSStella Laurenzo MlirAttribute attr = 1341436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 1342974c1596SRahul Kayaith if (mlirAttributeIsNull(attr)) 13434811270bSmax throw py::key_error("attempt to access a non-existent attribute"); 1344974c1596SRahul Kayaith return attr; 1345436c6c9cSStella Laurenzo }); 1346436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 1347436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 13484811270bSmax throw py::index_error("attempt to access out of bounds attribute"); 1349436c6c9cSStella Laurenzo } 1350436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 1351436c6c9cSStella Laurenzo return PyNamedAttribute( 1352436c6c9cSStella Laurenzo namedAttr.attribute, 1353436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 1354436c6c9cSStella Laurenzo }); 1355436c6c9cSStella Laurenzo } 1356436c6c9cSStella Laurenzo }; 1357436c6c9cSStella Laurenzo 1358436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 1359436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 1360436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 1361436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 1362436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 1363436c6c9cSStella Laurenzo public: 1364436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 1365436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 1366436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1367436c6c9cSStella Laurenzo 1368436c6c9cSStella Laurenzo py::float_ dunderGetItem(intptr_t pos) { 1369436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 13704811270bSmax throw py::index_error("attempt to access out of bounds element"); 1371436c6c9cSStella Laurenzo } 1372436c6c9cSStella Laurenzo 1373436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 1374436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 1375436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 1376436c6c9cSStella Laurenzo // elemental type of the attribute. py::float_ is implicitly constructible 1377436c6c9cSStella Laurenzo // from float and double. 1378436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 1379436c6c9cSStella Laurenzo // querying them on each element access. 1380436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 1381436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetFloatValue(*this, pos); 1382436c6c9cSStella Laurenzo } 1383436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 1384436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetDoubleValue(*this, pos); 1385436c6c9cSStella Laurenzo } 13864811270bSmax throw py::type_error("Unsupported floating-point type"); 1387436c6c9cSStella Laurenzo } 1388436c6c9cSStella Laurenzo 1389436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1390436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 1391436c6c9cSStella Laurenzo } 1392436c6c9cSStella Laurenzo }; 1393436c6c9cSStella Laurenzo 1394436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 1395436c6c9cSStella Laurenzo public: 1396436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 1397436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 1398436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 13999566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 14009566ee28Smax mlirTypeAttrGetTypeID; 1401436c6c9cSStella Laurenzo 1402436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1403436c6c9cSStella Laurenzo c.def_static( 1404436c6c9cSStella Laurenzo "get", 1405436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 1406436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 1407436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 1408436c6c9cSStella Laurenzo }, 1409436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 1410436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 1411436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyTypeAttribute &self) { 1412bfb1ba75Smax return mlirTypeAttrGetValue(self.get()); 1413436c6c9cSStella Laurenzo }); 1414436c6c9cSStella Laurenzo } 1415436c6c9cSStella Laurenzo }; 1416436c6c9cSStella Laurenzo 1417436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 1418436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 1419436c6c9cSStella Laurenzo public: 1420436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 1421436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 1422436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 14239566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 14249566ee28Smax mlirUnitAttrGetTypeID; 1425436c6c9cSStella Laurenzo 1426436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1427436c6c9cSStella Laurenzo c.def_static( 1428436c6c9cSStella Laurenzo "get", 1429436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 1430436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 1431436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 1432436c6c9cSStella Laurenzo }, 1433436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Create a Unit attribute."); 1434436c6c9cSStella Laurenzo } 1435436c6c9cSStella Laurenzo }; 1436436c6c9cSStella Laurenzo 1437ac2e2d65SDenys Shabalin /// Strided layout attribute subclass. 1438ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute 1439ac2e2d65SDenys Shabalin : public PyConcreteAttribute<PyStridedLayoutAttribute> { 1440ac2e2d65SDenys Shabalin public: 1441ac2e2d65SDenys Shabalin static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; 1442ac2e2d65SDenys Shabalin static constexpr const char *pyClassName = "StridedLayoutAttr"; 1443ac2e2d65SDenys Shabalin using PyConcreteAttribute::PyConcreteAttribute; 14449566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 14459566ee28Smax mlirStridedLayoutAttrGetTypeID; 1446ac2e2d65SDenys Shabalin 1447ac2e2d65SDenys Shabalin static void bindDerived(ClassTy &c) { 1448ac2e2d65SDenys Shabalin c.def_static( 1449ac2e2d65SDenys Shabalin "get", 1450ac2e2d65SDenys Shabalin [](int64_t offset, const std::vector<int64_t> strides, 1451ac2e2d65SDenys Shabalin DefaultingPyMlirContext ctx) { 1452ac2e2d65SDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1453ac2e2d65SDenys Shabalin ctx->get(), offset, strides.size(), strides.data()); 1454ac2e2d65SDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1455ac2e2d65SDenys Shabalin }, 1456ac2e2d65SDenys Shabalin py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), 1457ac2e2d65SDenys Shabalin "Gets a strided layout attribute."); 1458e3fd612eSDenys Shabalin c.def_static( 1459e3fd612eSDenys Shabalin "get_fully_dynamic", 1460e3fd612eSDenys Shabalin [](int64_t rank, DefaultingPyMlirContext ctx) { 1461e3fd612eSDenys Shabalin auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); 1462e3fd612eSDenys Shabalin std::vector<int64_t> strides(rank); 1463e3fd612eSDenys Shabalin std::fill(strides.begin(), strides.end(), dynamic); 1464e3fd612eSDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1465e3fd612eSDenys Shabalin ctx->get(), dynamic, strides.size(), strides.data()); 1466e3fd612eSDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1467e3fd612eSDenys Shabalin }, 1468e3fd612eSDenys Shabalin py::arg("rank"), py::arg("context") = py::none(), 1469e3fd612eSDenys Shabalin "Gets a strided layout attribute with dynamic offset and strides of a " 1470e3fd612eSDenys Shabalin "given rank."); 1471ac2e2d65SDenys Shabalin c.def_property_readonly( 1472ac2e2d65SDenys Shabalin "offset", 1473ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1474ac2e2d65SDenys Shabalin return mlirStridedLayoutAttrGetOffset(self); 1475ac2e2d65SDenys Shabalin }, 1476ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1477ac2e2d65SDenys Shabalin c.def_property_readonly( 1478ac2e2d65SDenys Shabalin "strides", 1479ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1480ac2e2d65SDenys Shabalin intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); 1481ac2e2d65SDenys Shabalin std::vector<int64_t> strides(size); 1482ac2e2d65SDenys Shabalin for (intptr_t i = 0; i < size; i++) { 1483ac2e2d65SDenys Shabalin strides[i] = mlirStridedLayoutAttrGetStride(self, i); 1484ac2e2d65SDenys Shabalin } 1485ac2e2d65SDenys Shabalin return strides; 1486ac2e2d65SDenys Shabalin }, 1487ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1488ac2e2d65SDenys Shabalin } 1489ac2e2d65SDenys Shabalin }; 1490ac2e2d65SDenys Shabalin 14919566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { 14929566ee28Smax if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) 14939566ee28Smax return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); 14949566ee28Smax if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) 14959566ee28Smax return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); 14969566ee28Smax if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) 14979566ee28Smax return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); 14989566ee28Smax if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) 14999566ee28Smax return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); 15009566ee28Smax if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) 15019566ee28Smax return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); 15029566ee28Smax if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) 15039566ee28Smax return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); 15049566ee28Smax if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) 15059566ee28Smax return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); 15069566ee28Smax std::string msg = 15079566ee28Smax std::string("Can't cast unknown element type DenseArrayAttr (") + 15089566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 15099566ee28Smax throw py::cast_error(msg); 15109566ee28Smax } 15119566ee28Smax 15129566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { 15139566ee28Smax if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) 15149566ee28Smax return py::cast(PyDenseFPElementsAttribute(pyAttribute)); 15159566ee28Smax if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) 15169566ee28Smax return py::cast(PyDenseIntElementsAttribute(pyAttribute)); 15179566ee28Smax std::string msg = 15189566ee28Smax std::string( 15199566ee28Smax "Can't cast unknown element type DenseIntOrFPElementsAttr (") + 15209566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 15219566ee28Smax throw py::cast_error(msg); 15229566ee28Smax } 15239566ee28Smax 15249566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { 15259566ee28Smax if (PyBoolAttribute::isaFunction(pyAttribute)) 15269566ee28Smax return py::cast(PyBoolAttribute(pyAttribute)); 15279566ee28Smax if (PyIntegerAttribute::isaFunction(pyAttribute)) 15289566ee28Smax return py::cast(PyIntegerAttribute(pyAttribute)); 15299566ee28Smax std::string msg = 15309566ee28Smax std::string("Can't cast unknown element type DenseArrayAttr (") + 15319566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 15329566ee28Smax throw py::cast_error(msg); 15339566ee28Smax } 15349566ee28Smax 15354eee9ef9Smax py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { 15364eee9ef9Smax if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) 15374eee9ef9Smax return py::cast(PyFlatSymbolRefAttribute(pyAttribute)); 15384eee9ef9Smax if (PySymbolRefAttribute::isaFunction(pyAttribute)) 15394eee9ef9Smax return py::cast(PySymbolRefAttribute(pyAttribute)); 15404eee9ef9Smax std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + 15414eee9ef9Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 15424eee9ef9Smax throw py::cast_error(msg); 15434eee9ef9Smax } 15444eee9ef9Smax 1545436c6c9cSStella Laurenzo } // namespace 1546436c6c9cSStella Laurenzo 1547436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) { 1548436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 1549619fd8c2SJeff Niu PyDenseBoolArrayAttribute::bind(m); 1550619fd8c2SJeff Niu PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); 1551619fd8c2SJeff Niu PyDenseI8ArrayAttribute::bind(m); 1552619fd8c2SJeff Niu PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); 1553619fd8c2SJeff Niu PyDenseI16ArrayAttribute::bind(m); 1554619fd8c2SJeff Niu PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); 1555619fd8c2SJeff Niu PyDenseI32ArrayAttribute::bind(m); 1556619fd8c2SJeff Niu PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); 1557619fd8c2SJeff Niu PyDenseI64ArrayAttribute::bind(m); 1558619fd8c2SJeff Niu PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); 1559619fd8c2SJeff Niu PyDenseF32ArrayAttribute::bind(m); 1560619fd8c2SJeff Niu PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); 1561619fd8c2SJeff Niu PyDenseF64ArrayAttribute::bind(m); 1562619fd8c2SJeff Niu PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); 15639566ee28Smax PyGlobals::get().registerTypeCaster( 15649566ee28Smax mlirDenseArrayAttrGetTypeID(), 15659566ee28Smax pybind11::cpp_function(denseArrayAttributeCaster)); 1566619fd8c2SJeff Niu 1567436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 1568436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 1569436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 1570436c6c9cSStella Laurenzo PyDenseElementsAttribute::bind(m); 1571436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 1572436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 15739566ee28Smax PyGlobals::get().registerTypeCaster( 15749566ee28Smax mlirDenseIntOrFPElementsAttrGetTypeID(), 15759566ee28Smax pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); 1576f66cd9e9SStella Laurenzo PyDenseResourceElementsAttribute::bind(m); 15779566ee28Smax 1578436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 15794eee9ef9Smax PySymbolRefAttribute::bind(m); 15804eee9ef9Smax PyGlobals::get().registerTypeCaster( 15814eee9ef9Smax mlirSymbolRefAttrGetTypeID(), 15824eee9ef9Smax pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)); 15834eee9ef9Smax 1584436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 15855c3861b2SYun Long PyOpaqueAttribute::bind(m); 1586436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 1587436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 1588334873feSAmy Wang PyIntegerSetAttribute::bind(m); 1589436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 1590436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 15919566ee28Smax PyGlobals::get().registerTypeCaster( 15929566ee28Smax mlirIntegerAttrGetTypeID(), 15939566ee28Smax pybind11::cpp_function(integerOrBoolAttributeCaster)); 1594436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 1595ac2e2d65SDenys Shabalin 1596ac2e2d65SDenys Shabalin PyStridedLayoutAttribute::bind(m); 1597436c6c9cSStella Laurenzo } 1598