1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9a1fe1f5fSKazu Hirata #include <optional> 1071a25454SPeter Hawkins #include <string_view> 114811270bSmax #include <utility> 121fc096afSMehdi Amini 13436c6c9cSStella Laurenzo #include "IRModule.h" 14436c6c9cSStella Laurenzo 15436c6c9cSStella Laurenzo #include "PybindUtils.h" 16436c6c9cSStella Laurenzo 1771a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h" 18*c912f0e7Spranavm-nvidia #include "llvm/Support/raw_ostream.h" 1971a25454SPeter Hawkins 20436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 21436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 22bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h" 23436c6c9cSStella Laurenzo 24436c6c9cSStella Laurenzo namespace py = pybind11; 25436c6c9cSStella Laurenzo using namespace mlir; 26436c6c9cSStella Laurenzo using namespace mlir::python; 27436c6c9cSStella Laurenzo 28436c6c9cSStella Laurenzo using llvm::SmallVector; 29436c6c9cSStella Laurenzo 305d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 315d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 325d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 335d6d30edSStella Laurenzo 345d6d30edSStella Laurenzo static const char kDenseElementsAttrGetDocstring[] = 355d6d30edSStella Laurenzo R"(Gets a DenseElementsAttr from a Python buffer or array. 365d6d30edSStella Laurenzo 375d6d30edSStella Laurenzo When `type` is not provided, then some limited type inferencing is done based 385d6d30edSStella Laurenzo on the buffer format. Support presently exists for 8/16/32/64 signed and 395d6d30edSStella Laurenzo unsigned integers and float16/float32/float64. DenseElementsAttrs of these 405d6d30edSStella Laurenzo types can also be converted back to a corresponding buffer. 415d6d30edSStella Laurenzo 425d6d30edSStella Laurenzo For conversions outside of these types, a `type=` must be explicitly provided 435d6d30edSStella Laurenzo and the buffer contents must be bit-castable to the MLIR internal 445d6d30edSStella Laurenzo representation: 455d6d30edSStella Laurenzo 465d6d30edSStella Laurenzo * Integer types (except for i1): the buffer must be byte aligned to the 475d6d30edSStella Laurenzo next byte boundary. 485d6d30edSStella Laurenzo * Floating point types: Must be bit-castable to the given floating point 495d6d30edSStella Laurenzo size. 505d6d30edSStella Laurenzo * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 515d6d30edSStella Laurenzo row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 525d6d30edSStella Laurenzo this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 535d6d30edSStella Laurenzo 545d6d30edSStella Laurenzo If a single element buffer is passed (or for i1, a single byte with value 0 555d6d30edSStella Laurenzo or 255), then a splat will be created. 565d6d30edSStella Laurenzo 575d6d30edSStella Laurenzo Args: 585d6d30edSStella Laurenzo array: The array or buffer to convert. 595d6d30edSStella Laurenzo signless: If inferring an appropriate MLIR type, use signless types for 605d6d30edSStella Laurenzo integers (defaults True). 615d6d30edSStella Laurenzo type: Skips inference of the MLIR element type and uses this instead. The 625d6d30edSStella Laurenzo storage size must be consistent with the actual contents of the buffer. 635d6d30edSStella Laurenzo shape: Overrides the shape of the buffer when constructing the MLIR 645d6d30edSStella Laurenzo shaped type. This is needed when the physical and logical shape differ (as 655d6d30edSStella Laurenzo for i1). 665d6d30edSStella Laurenzo context: Explicit context, if not from context manager. 675d6d30edSStella Laurenzo 685d6d30edSStella Laurenzo Returns: 695d6d30edSStella Laurenzo DenseElementsAttr on success. 705d6d30edSStella Laurenzo 715d6d30edSStella Laurenzo Raises: 725d6d30edSStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 735d6d30edSStella Laurenzo type or if the buffer does not meet expectations. 745d6d30edSStella Laurenzo )"; 755d6d30edSStella Laurenzo 76*c912f0e7Spranavm-nvidia static const char kDenseElementsAttrGetFromListDocstring[] = 77*c912f0e7Spranavm-nvidia R"(Gets a DenseElementsAttr from a Python list of attributes. 78*c912f0e7Spranavm-nvidia 79*c912f0e7Spranavm-nvidia Note that it can be expensive to construct attributes individually. 80*c912f0e7Spranavm-nvidia For a large number of elements, consider using a Python buffer or array instead. 81*c912f0e7Spranavm-nvidia 82*c912f0e7Spranavm-nvidia Args: 83*c912f0e7Spranavm-nvidia attrs: A list of attributes. 84*c912f0e7Spranavm-nvidia type: The desired shape and type of the resulting DenseElementsAttr. 85*c912f0e7Spranavm-nvidia If not provided, the element type is determined based on the type 86*c912f0e7Spranavm-nvidia of the 0th attribute and the shape is `[len(attrs)]`. 87*c912f0e7Spranavm-nvidia context: Explicit context, if not from context manager. 88*c912f0e7Spranavm-nvidia 89*c912f0e7Spranavm-nvidia Returns: 90*c912f0e7Spranavm-nvidia DenseElementsAttr on success. 91*c912f0e7Spranavm-nvidia 92*c912f0e7Spranavm-nvidia Raises: 93*c912f0e7Spranavm-nvidia ValueError: If the type of the attributes does not match the type 94*c912f0e7Spranavm-nvidia specified by `shaped_type`. 95*c912f0e7Spranavm-nvidia )"; 96*c912f0e7Spranavm-nvidia 97f66cd9e9SStella Laurenzo static const char kDenseResourceElementsAttrGetFromBufferDocstring[] = 98f66cd9e9SStella Laurenzo R"(Gets a DenseResourceElementsAttr from a Python buffer or array. 99f66cd9e9SStella Laurenzo 100f66cd9e9SStella Laurenzo This function does minimal validation or massaging of the data, and it is 101f66cd9e9SStella Laurenzo up to the caller to ensure that the buffer meets the characteristics 102f66cd9e9SStella Laurenzo implied by the shape. 103f66cd9e9SStella Laurenzo 104f66cd9e9SStella Laurenzo The backing buffer and any user objects will be retained for the lifetime 105f66cd9e9SStella Laurenzo of the resource blob. This is typically bounded to the context but the 106f66cd9e9SStella Laurenzo resource can have a shorter lifespan depending on how it is used in 107f66cd9e9SStella Laurenzo subsequent processing. 108f66cd9e9SStella Laurenzo 109f66cd9e9SStella Laurenzo Args: 110f66cd9e9SStella Laurenzo buffer: The array or buffer to convert. 111f66cd9e9SStella Laurenzo name: Name to provide to the resource (may be changed upon collision). 112f66cd9e9SStella Laurenzo type: The explicit ShapedType to construct the attribute with. 113f66cd9e9SStella Laurenzo context: Explicit context, if not from context manager. 114f66cd9e9SStella Laurenzo 115f66cd9e9SStella Laurenzo Returns: 116f66cd9e9SStella Laurenzo DenseResourceElementsAttr on success. 117f66cd9e9SStella Laurenzo 118f66cd9e9SStella Laurenzo Raises: 119f66cd9e9SStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 120f66cd9e9SStella Laurenzo type or if the buffer does not meet expectations. 121f66cd9e9SStella Laurenzo )"; 122f66cd9e9SStella Laurenzo 123436c6c9cSStella Laurenzo namespace { 124436c6c9cSStella Laurenzo 125436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 126436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 127436c6c9cSStella Laurenzo } 128436c6c9cSStella Laurenzo 129436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 130436c6c9cSStella Laurenzo public: 131436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 132436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 133436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1349566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 1359566ee28Smax mlirAffineMapAttrGetTypeID; 136436c6c9cSStella Laurenzo 137436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 138436c6c9cSStella Laurenzo c.def_static( 139436c6c9cSStella Laurenzo "get", 140436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 141436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 142436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 143436c6c9cSStella Laurenzo }, 144436c6c9cSStella Laurenzo py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 145436c6c9cSStella Laurenzo } 146436c6c9cSStella Laurenzo }; 147436c6c9cSStella Laurenzo 148ed9e52f3SAlex Zinenko template <typename T> 149ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) { 150ed9e52f3SAlex Zinenko try { 151ed9e52f3SAlex Zinenko return object.cast<T>(); 152ed9e52f3SAlex Zinenko } catch (py::cast_error &err) { 153ed9e52f3SAlex Zinenko std::string msg = 154ed9e52f3SAlex Zinenko std::string( 155ed9e52f3SAlex Zinenko "Invalid attribute when attempting to create an ArrayAttribute (") + 156ed9e52f3SAlex Zinenko err.what() + ")"; 157ed9e52f3SAlex Zinenko throw py::cast_error(msg); 158ed9e52f3SAlex Zinenko } catch (py::reference_cast_error &err) { 159ed9e52f3SAlex Zinenko std::string msg = std::string("Invalid attribute (None?) when attempting " 160ed9e52f3SAlex Zinenko "to create an ArrayAttribute (") + 161ed9e52f3SAlex Zinenko err.what() + ")"; 162ed9e52f3SAlex Zinenko throw py::cast_error(msg); 163ed9e52f3SAlex Zinenko } 164ed9e52f3SAlex Zinenko } 165ed9e52f3SAlex Zinenko 166619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived 167619fd8c2SJeff Niu /// implementation class. 168619fd8c2SJeff Niu template <typename EltTy, typename DerivedT> 169133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> { 170619fd8c2SJeff Niu public: 171133624acSJeff Niu using PyConcreteAttribute<DerivedT>::PyConcreteAttribute; 172619fd8c2SJeff Niu 173619fd8c2SJeff Niu /// Iterator over the integer elements of a dense array. 174619fd8c2SJeff Niu class PyDenseArrayIterator { 175619fd8c2SJeff Niu public: 1764a1b1196SMehdi Amini PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} 177619fd8c2SJeff Niu 178619fd8c2SJeff Niu /// Return a copy of the iterator. 179619fd8c2SJeff Niu PyDenseArrayIterator dunderIter() { return *this; } 180619fd8c2SJeff Niu 181619fd8c2SJeff Niu /// Return the next element. 182619fd8c2SJeff Niu EltTy dunderNext() { 183619fd8c2SJeff Niu // Throw if the index has reached the end. 184619fd8c2SJeff Niu if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) 185619fd8c2SJeff Niu throw py::stop_iteration(); 186619fd8c2SJeff Niu return DerivedT::getElement(attr.get(), nextIndex++); 187619fd8c2SJeff Niu } 188619fd8c2SJeff Niu 189619fd8c2SJeff Niu /// Bind the iterator class. 190619fd8c2SJeff Niu static void bind(py::module &m) { 191619fd8c2SJeff Niu py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName, 192619fd8c2SJeff Niu py::module_local()) 193619fd8c2SJeff Niu .def("__iter__", &PyDenseArrayIterator::dunderIter) 194619fd8c2SJeff Niu .def("__next__", &PyDenseArrayIterator::dunderNext); 195619fd8c2SJeff Niu } 196619fd8c2SJeff Niu 197619fd8c2SJeff Niu private: 198619fd8c2SJeff Niu /// The referenced dense array attribute. 199619fd8c2SJeff Niu PyAttribute attr; 200619fd8c2SJeff Niu /// The next index to read. 201619fd8c2SJeff Niu int nextIndex = 0; 202619fd8c2SJeff Niu }; 203619fd8c2SJeff Niu 204619fd8c2SJeff Niu /// Get the element at the given index. 205619fd8c2SJeff Niu EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } 206619fd8c2SJeff Niu 207619fd8c2SJeff Niu /// Bind the attribute class. 208133624acSJeff Niu static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) { 209619fd8c2SJeff Niu // Bind the constructor. 210619fd8c2SJeff Niu c.def_static( 211619fd8c2SJeff Niu "get", 212619fd8c2SJeff Niu [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { 2138dcb6722SIngo Müller return getAttribute(values, ctx->getRef()); 214619fd8c2SJeff Niu }, 215619fd8c2SJeff Niu py::arg("values"), py::arg("context") = py::none(), 216619fd8c2SJeff Niu "Gets a uniqued dense array attribute"); 217619fd8c2SJeff Niu // Bind the array methods. 218133624acSJeff Niu c.def("__getitem__", [](DerivedT &arr, intptr_t i) { 219619fd8c2SJeff Niu if (i >= mlirDenseArrayGetNumElements(arr)) 220619fd8c2SJeff Niu throw py::index_error("DenseArray index out of range"); 221619fd8c2SJeff Niu return arr.getItem(i); 222619fd8c2SJeff Niu }); 223133624acSJeff Niu c.def("__len__", [](const DerivedT &arr) { 224619fd8c2SJeff Niu return mlirDenseArrayGetNumElements(arr); 225619fd8c2SJeff Niu }); 226133624acSJeff Niu c.def("__iter__", 227133624acSJeff Niu [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); 2284a1b1196SMehdi Amini c.def("__add__", [](DerivedT &arr, const py::list &extras) { 229619fd8c2SJeff Niu std::vector<EltTy> values; 230619fd8c2SJeff Niu intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); 231619fd8c2SJeff Niu values.reserve(numOldElements + py::len(extras)); 232619fd8c2SJeff Niu for (intptr_t i = 0; i < numOldElements; ++i) 233619fd8c2SJeff Niu values.push_back(arr.getItem(i)); 234619fd8c2SJeff Niu for (py::handle attr : extras) 235619fd8c2SJeff Niu values.push_back(pyTryCast<EltTy>(attr)); 2368dcb6722SIngo Müller return getAttribute(values, arr.getContext()); 237619fd8c2SJeff Niu }); 238619fd8c2SJeff Niu } 2398dcb6722SIngo Müller 2408dcb6722SIngo Müller private: 2418dcb6722SIngo Müller static DerivedT getAttribute(const std::vector<EltTy> &values, 2428dcb6722SIngo Müller PyMlirContextRef ctx) { 2438dcb6722SIngo Müller if constexpr (std::is_same_v<EltTy, bool>) { 2448dcb6722SIngo Müller std::vector<int> intValues(values.begin(), values.end()); 2458dcb6722SIngo Müller MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(), 2468dcb6722SIngo Müller intValues.data()); 2478dcb6722SIngo Müller return DerivedT(ctx, attr); 2488dcb6722SIngo Müller } else { 2498dcb6722SIngo Müller MlirAttribute attr = 2508dcb6722SIngo Müller DerivedT::getAttribute(ctx->get(), values.size(), values.data()); 2518dcb6722SIngo Müller return DerivedT(ctx, attr); 2528dcb6722SIngo Müller } 2538dcb6722SIngo Müller } 254619fd8c2SJeff Niu }; 255619fd8c2SJeff Niu 256619fd8c2SJeff Niu /// Instantiate the python dense array classes. 257619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute 2588dcb6722SIngo Müller : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> { 259619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; 260619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseBoolArrayGet; 261619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseBoolArrayGetElement; 262619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseBoolArrayAttr"; 263619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; 264619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 265619fd8c2SJeff Niu }; 266619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute 267619fd8c2SJeff Niu : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { 268619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; 269619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI8ArrayGet; 270619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI8ArrayGetElement; 271619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI8ArrayAttr"; 272619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; 273619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 274619fd8c2SJeff Niu }; 275619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute 276619fd8c2SJeff Niu : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { 277619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; 278619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI16ArrayGet; 279619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI16ArrayGetElement; 280619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI16ArrayAttr"; 281619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; 282619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 283619fd8c2SJeff Niu }; 284619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute 285619fd8c2SJeff Niu : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { 286619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; 287619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI32ArrayGet; 288619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI32ArrayGetElement; 289619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI32ArrayAttr"; 290619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; 291619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 292619fd8c2SJeff Niu }; 293619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute 294619fd8c2SJeff Niu : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { 295619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; 296619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI64ArrayGet; 297619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI64ArrayGetElement; 298619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI64ArrayAttr"; 299619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; 300619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 301619fd8c2SJeff Niu }; 302619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute 303619fd8c2SJeff Niu : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { 304619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; 305619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF32ArrayGet; 306619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF32ArrayGetElement; 307619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF32ArrayAttr"; 308619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; 309619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 310619fd8c2SJeff Niu }; 311619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute 312619fd8c2SJeff Niu : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { 313619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; 314619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF64ArrayGet; 315619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF64ArrayGetElement; 316619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF64ArrayAttr"; 317619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; 318619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 319619fd8c2SJeff Niu }; 320619fd8c2SJeff Niu 321436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 322436c6c9cSStella Laurenzo public: 323436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 324436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 325436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 3269566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 3279566ee28Smax mlirArrayAttrGetTypeID; 328436c6c9cSStella Laurenzo 329436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 330436c6c9cSStella Laurenzo public: 3311fc096afSMehdi Amini PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 332436c6c9cSStella Laurenzo 333436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 334436c6c9cSStella Laurenzo 335974c1596SRahul Kayaith MlirAttribute dunderNext() { 336bca88952SJeff Niu // TODO: Throw is an inefficient way to stop iteration. 337bca88952SJeff Niu if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) 338436c6c9cSStella Laurenzo throw py::stop_iteration(); 339974c1596SRahul Kayaith return mlirArrayAttrGetElement(attr.get(), nextIndex++); 340436c6c9cSStella Laurenzo } 341436c6c9cSStella Laurenzo 342436c6c9cSStella Laurenzo static void bind(py::module &m) { 343f05ff4f7SStella Laurenzo py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 344f05ff4f7SStella Laurenzo py::module_local()) 345436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 346436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 347436c6c9cSStella Laurenzo } 348436c6c9cSStella Laurenzo 349436c6c9cSStella Laurenzo private: 350436c6c9cSStella Laurenzo PyAttribute attr; 351436c6c9cSStella Laurenzo int nextIndex = 0; 352436c6c9cSStella Laurenzo }; 353436c6c9cSStella Laurenzo 354974c1596SRahul Kayaith MlirAttribute getItem(intptr_t i) { 355974c1596SRahul Kayaith return mlirArrayAttrGetElement(*this, i); 356ed9e52f3SAlex Zinenko } 357ed9e52f3SAlex Zinenko 358436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 359436c6c9cSStella Laurenzo c.def_static( 360436c6c9cSStella Laurenzo "get", 361436c6c9cSStella Laurenzo [](py::list attributes, DefaultingPyMlirContext context) { 362436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 363436c6c9cSStella Laurenzo mlirAttributes.reserve(py::len(attributes)); 364436c6c9cSStella Laurenzo for (auto attribute : attributes) { 365ed9e52f3SAlex Zinenko mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 366436c6c9cSStella Laurenzo } 367436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 368436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 369436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 370436c6c9cSStella Laurenzo }, 371436c6c9cSStella Laurenzo py::arg("attributes"), py::arg("context") = py::none(), 372436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 373436c6c9cSStella Laurenzo c.def("__getitem__", 374436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 375436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 376436c6c9cSStella Laurenzo throw py::index_error("ArrayAttribute index out of range"); 377ed9e52f3SAlex Zinenko return arr.getItem(i); 378436c6c9cSStella Laurenzo }) 379436c6c9cSStella Laurenzo .def("__len__", 380436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 381436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 382436c6c9cSStella Laurenzo }) 383436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 384436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 385436c6c9cSStella Laurenzo }); 386ed9e52f3SAlex Zinenko c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 387ed9e52f3SAlex Zinenko std::vector<MlirAttribute> attributes; 388ed9e52f3SAlex Zinenko intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 389ed9e52f3SAlex Zinenko attributes.reserve(numOldElements + py::len(extras)); 390ed9e52f3SAlex Zinenko for (intptr_t i = 0; i < numOldElements; ++i) 391ed9e52f3SAlex Zinenko attributes.push_back(arr.getItem(i)); 392ed9e52f3SAlex Zinenko for (py::handle attr : extras) 393ed9e52f3SAlex Zinenko attributes.push_back(pyTryCast<PyAttribute>(attr)); 394ed9e52f3SAlex Zinenko MlirAttribute arrayAttr = mlirArrayAttrGet( 395ed9e52f3SAlex Zinenko arr.getContext()->get(), attributes.size(), attributes.data()); 396ed9e52f3SAlex Zinenko return PyArrayAttribute(arr.getContext(), arrayAttr); 397ed9e52f3SAlex Zinenko }); 398436c6c9cSStella Laurenzo } 399436c6c9cSStella Laurenzo }; 400436c6c9cSStella Laurenzo 401436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 402436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 403436c6c9cSStella Laurenzo public: 404436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 405436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 406436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 4079566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 4089566ee28Smax mlirFloatAttrGetTypeID; 409436c6c9cSStella Laurenzo 410436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 411436c6c9cSStella Laurenzo c.def_static( 412436c6c9cSStella Laurenzo "get", 413436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 4143ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(loc->getContext()); 415436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 4163ea4c501SRahul Kayaith if (mlirAttributeIsNull(attr)) 4173ea4c501SRahul Kayaith throw MLIRError("Invalid attribute", errors.take()); 418436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 419436c6c9cSStella Laurenzo }, 420436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 421436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 422436c6c9cSStella Laurenzo c.def_static( 423436c6c9cSStella Laurenzo "get_f32", 424436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 425436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 426436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 427436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 428436c6c9cSStella Laurenzo }, 429436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 430436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 431436c6c9cSStella Laurenzo c.def_static( 432436c6c9cSStella Laurenzo "get_f64", 433436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 434436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 435436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 436436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 437436c6c9cSStella Laurenzo }, 438436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 439436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 4402a5d4974SIngo Müller c.def_property_readonly("value", mlirFloatAttrGetValueDouble, 4412a5d4974SIngo Müller "Returns the value of the float attribute"); 4422a5d4974SIngo Müller c.def("__float__", mlirFloatAttrGetValueDouble, 4432a5d4974SIngo Müller "Converts the value of the float attribute to a Python float"); 444436c6c9cSStella Laurenzo } 445436c6c9cSStella Laurenzo }; 446436c6c9cSStella Laurenzo 447436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 448436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 449436c6c9cSStella Laurenzo public: 450436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 451436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 452436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 453436c6c9cSStella Laurenzo 454436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 455436c6c9cSStella Laurenzo c.def_static( 456436c6c9cSStella Laurenzo "get", 457436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 458436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 459436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 460436c6c9cSStella Laurenzo }, 461436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), 462436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 4632a5d4974SIngo Müller c.def_property_readonly("value", toPyInt, 4642a5d4974SIngo Müller "Returns the value of the integer attribute"); 4652a5d4974SIngo Müller c.def("__int__", toPyInt, 4662a5d4974SIngo Müller "Converts the value of the integer attribute to a Python int"); 4672a5d4974SIngo Müller c.def_property_readonly_static("static_typeid", 4682a5d4974SIngo Müller [](py::object & /*class*/) -> MlirTypeID { 4692a5d4974SIngo Müller return mlirIntegerAttrGetTypeID(); 4702a5d4974SIngo Müller }); 4712a5d4974SIngo Müller } 4722a5d4974SIngo Müller 4732a5d4974SIngo Müller private: 4742a5d4974SIngo Müller static py::int_ toPyInt(PyIntegerAttribute &self) { 475e9db306dSrkayaith MlirType type = mlirAttributeGetType(self); 476e9db306dSrkayaith if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 477436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 478e9db306dSrkayaith if (mlirIntegerTypeIsSigned(type)) 479e9db306dSrkayaith return mlirIntegerAttrGetValueSInt(self); 480e9db306dSrkayaith return mlirIntegerAttrGetValueUInt(self); 481436c6c9cSStella Laurenzo } 482436c6c9cSStella Laurenzo }; 483436c6c9cSStella Laurenzo 484436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 485436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 486436c6c9cSStella Laurenzo public: 487436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 488436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 489436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 490436c6c9cSStella Laurenzo 491436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 492436c6c9cSStella Laurenzo c.def_static( 493436c6c9cSStella Laurenzo "get", 494436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 495436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 496436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 497436c6c9cSStella Laurenzo }, 498436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 499436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 5002a5d4974SIngo Müller c.def_property_readonly("value", mlirBoolAttrGetValue, 501436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 5022a5d4974SIngo Müller c.def("__bool__", mlirBoolAttrGetValue, 5032a5d4974SIngo Müller "Converts the value of the bool attribute to a Python bool"); 504436c6c9cSStella Laurenzo } 505436c6c9cSStella Laurenzo }; 506436c6c9cSStella Laurenzo 5074eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> { 5084eee9ef9Smax public: 5094eee9ef9Smax static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; 5104eee9ef9Smax static constexpr const char *pyClassName = "SymbolRefAttr"; 5114eee9ef9Smax using PyConcreteAttribute::PyConcreteAttribute; 5124eee9ef9Smax 5134eee9ef9Smax static MlirAttribute fromList(const std::vector<std::string> &symbols, 5144eee9ef9Smax PyMlirContext &context) { 5154eee9ef9Smax if (symbols.empty()) 5164eee9ef9Smax throw std::runtime_error("SymbolRefAttr must be composed of at least " 5174eee9ef9Smax "one symbol."); 5184eee9ef9Smax MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); 5194eee9ef9Smax SmallVector<MlirAttribute, 3> referenceAttrs; 5204eee9ef9Smax for (size_t i = 1; i < symbols.size(); ++i) { 5214eee9ef9Smax referenceAttrs.push_back( 5224eee9ef9Smax mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); 5234eee9ef9Smax } 5244eee9ef9Smax return mlirSymbolRefAttrGet(context.get(), rootSymbol, 5254eee9ef9Smax referenceAttrs.size(), referenceAttrs.data()); 5264eee9ef9Smax } 5274eee9ef9Smax 5284eee9ef9Smax static void bindDerived(ClassTy &c) { 5294eee9ef9Smax c.def_static( 5304eee9ef9Smax "get", 5314eee9ef9Smax [](const std::vector<std::string> &symbols, 5324eee9ef9Smax DefaultingPyMlirContext context) { 5334eee9ef9Smax return PySymbolRefAttribute::fromList(symbols, context.resolve()); 5344eee9ef9Smax }, 5354eee9ef9Smax py::arg("symbols"), py::arg("context") = py::none(), 5364eee9ef9Smax "Gets a uniqued SymbolRef attribute from a list of symbol names"); 5374eee9ef9Smax c.def_property_readonly( 5384eee9ef9Smax "value", 5394eee9ef9Smax [](PySymbolRefAttribute &self) { 5404eee9ef9Smax std::vector<std::string> symbols = { 5414eee9ef9Smax unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; 5424eee9ef9Smax for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); 5434eee9ef9Smax ++i) 5444eee9ef9Smax symbols.push_back( 5454eee9ef9Smax unwrap(mlirSymbolRefAttrGetRootReference( 5464eee9ef9Smax mlirSymbolRefAttrGetNestedReference(self, i))) 5474eee9ef9Smax .str()); 5484eee9ef9Smax return symbols; 5494eee9ef9Smax }, 5504eee9ef9Smax "Returns the value of the SymbolRef attribute as a list[str]"); 5514eee9ef9Smax } 5524eee9ef9Smax }; 5534eee9ef9Smax 554436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 555436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 556436c6c9cSStella Laurenzo public: 557436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 558436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 559436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 560436c6c9cSStella Laurenzo 561436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 562436c6c9cSStella Laurenzo c.def_static( 563436c6c9cSStella Laurenzo "get", 564436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 565436c6c9cSStella Laurenzo MlirAttribute attr = 566436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 567436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 568436c6c9cSStella Laurenzo }, 569436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 570436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 571436c6c9cSStella Laurenzo c.def_property_readonly( 572436c6c9cSStella Laurenzo "value", 573436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 574436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 575436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 576436c6c9cSStella Laurenzo }, 577436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 578436c6c9cSStella Laurenzo } 579436c6c9cSStella Laurenzo }; 580436c6c9cSStella Laurenzo 5815c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { 5825c3861b2SYun Long public: 5835c3861b2SYun Long static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; 5845c3861b2SYun Long static constexpr const char *pyClassName = "OpaqueAttr"; 5855c3861b2SYun Long using PyConcreteAttribute::PyConcreteAttribute; 5869566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 5879566ee28Smax mlirOpaqueAttrGetTypeID; 5885c3861b2SYun Long 5895c3861b2SYun Long static void bindDerived(ClassTy &c) { 5905c3861b2SYun Long c.def_static( 5915c3861b2SYun Long "get", 5925c3861b2SYun Long [](std::string dialectNamespace, py::buffer buffer, PyType &type, 5935c3861b2SYun Long DefaultingPyMlirContext context) { 5945c3861b2SYun Long const py::buffer_info bufferInfo = buffer.request(); 5955c3861b2SYun Long intptr_t bufferSize = bufferInfo.size; 5965c3861b2SYun Long MlirAttribute attr = mlirOpaqueAttrGet( 5975c3861b2SYun Long context->get(), toMlirStringRef(dialectNamespace), bufferSize, 5985c3861b2SYun Long static_cast<char *>(bufferInfo.ptr), type); 5995c3861b2SYun Long return PyOpaqueAttribute(context->getRef(), attr); 6005c3861b2SYun Long }, 6015c3861b2SYun Long py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), 6025c3861b2SYun Long py::arg("context") = py::none(), "Gets an Opaque attribute."); 6035c3861b2SYun Long c.def_property_readonly( 6045c3861b2SYun Long "dialect_namespace", 6055c3861b2SYun Long [](PyOpaqueAttribute &self) { 6065c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); 6075c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 6085c3861b2SYun Long }, 6095c3861b2SYun Long "Returns the dialect namespace for the Opaque attribute as a string"); 6105c3861b2SYun Long c.def_property_readonly( 6115c3861b2SYun Long "data", 6125c3861b2SYun Long [](PyOpaqueAttribute &self) { 6135c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetData(self); 61462bf6c2eSChris Jones return py::bytes(stringRef.data, stringRef.length); 6155c3861b2SYun Long }, 61662bf6c2eSChris Jones "Returns the data for the Opaqued attributes as `bytes`"); 6175c3861b2SYun Long } 6185c3861b2SYun Long }; 6195c3861b2SYun Long 620436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 621436c6c9cSStella Laurenzo public: 622436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 623436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 624436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 6259566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 6269566ee28Smax mlirStringAttrGetTypeID; 627436c6c9cSStella Laurenzo 628436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 629436c6c9cSStella Laurenzo c.def_static( 630436c6c9cSStella Laurenzo "get", 631436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 632436c6c9cSStella Laurenzo MlirAttribute attr = 633436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 634436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 635436c6c9cSStella Laurenzo }, 636436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 637436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 638436c6c9cSStella Laurenzo c.def_static( 639436c6c9cSStella Laurenzo "get_typed", 640436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 641436c6c9cSStella Laurenzo MlirAttribute attr = 642436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 643436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 644436c6c9cSStella Laurenzo }, 645a6e7d024SStella Laurenzo py::arg("type"), py::arg("value"), 646436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 6479f533548SIngo Müller c.def_property_readonly( 6489f533548SIngo Müller "value", 6499f533548SIngo Müller [](PyStringAttribute &self) { 6509f533548SIngo Müller MlirStringRef stringRef = mlirStringAttrGetValue(self); 6519f533548SIngo Müller return py::str(stringRef.data, stringRef.length); 6529f533548SIngo Müller }, 653436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 65462bf6c2eSChris Jones c.def_property_readonly( 65562bf6c2eSChris Jones "value_bytes", 65662bf6c2eSChris Jones [](PyStringAttribute &self) { 65762bf6c2eSChris Jones MlirStringRef stringRef = mlirStringAttrGetValue(self); 65862bf6c2eSChris Jones return py::bytes(stringRef.data, stringRef.length); 65962bf6c2eSChris Jones }, 66062bf6c2eSChris Jones "Returns the value of the string attribute as `bytes`"); 661436c6c9cSStella Laurenzo } 662436c6c9cSStella Laurenzo }; 663436c6c9cSStella Laurenzo 664436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 665436c6c9cSStella Laurenzo class PyDenseElementsAttribute 666436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 667436c6c9cSStella Laurenzo public: 668436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 669436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 670436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 671436c6c9cSStella Laurenzo 672436c6c9cSStella Laurenzo static PyDenseElementsAttribute 673*c912f0e7Spranavm-nvidia getFromList(py::list attributes, std::optional<PyType> explicitType, 674*c912f0e7Spranavm-nvidia DefaultingPyMlirContext contextWrapper) { 675*c912f0e7Spranavm-nvidia 676*c912f0e7Spranavm-nvidia const size_t numAttributes = py::len(attributes); 677*c912f0e7Spranavm-nvidia if (numAttributes == 0) 678*c912f0e7Spranavm-nvidia throw py::value_error("Attributes list must be non-empty."); 679*c912f0e7Spranavm-nvidia 680*c912f0e7Spranavm-nvidia MlirType shapedType; 681*c912f0e7Spranavm-nvidia if (explicitType) { 682*c912f0e7Spranavm-nvidia if ((!mlirTypeIsAShaped(*explicitType) || 683*c912f0e7Spranavm-nvidia !mlirShapedTypeHasStaticShape(*explicitType))) { 684*c912f0e7Spranavm-nvidia 685*c912f0e7Spranavm-nvidia std::string message; 686*c912f0e7Spranavm-nvidia llvm::raw_string_ostream os(message); 687*c912f0e7Spranavm-nvidia os << "Expected a static ShapedType for the shaped_type parameter: " 688*c912f0e7Spranavm-nvidia << py::repr(py::cast(*explicitType)); 689*c912f0e7Spranavm-nvidia throw py::value_error(os.str()); 690*c912f0e7Spranavm-nvidia } 691*c912f0e7Spranavm-nvidia shapedType = *explicitType; 692*c912f0e7Spranavm-nvidia } else { 693*c912f0e7Spranavm-nvidia SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)}; 694*c912f0e7Spranavm-nvidia shapedType = mlirRankedTensorTypeGet( 695*c912f0e7Spranavm-nvidia shape.size(), shape.data(), 696*c912f0e7Spranavm-nvidia mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])), 697*c912f0e7Spranavm-nvidia mlirAttributeGetNull()); 698*c912f0e7Spranavm-nvidia } 699*c912f0e7Spranavm-nvidia 700*c912f0e7Spranavm-nvidia SmallVector<MlirAttribute> mlirAttributes; 701*c912f0e7Spranavm-nvidia mlirAttributes.reserve(numAttributes); 702*c912f0e7Spranavm-nvidia for (const py::handle &attribute : attributes) { 703*c912f0e7Spranavm-nvidia MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute); 704*c912f0e7Spranavm-nvidia MlirType attrType = mlirAttributeGetType(mlirAttribute); 705*c912f0e7Spranavm-nvidia mlirAttributes.push_back(mlirAttribute); 706*c912f0e7Spranavm-nvidia 707*c912f0e7Spranavm-nvidia if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { 708*c912f0e7Spranavm-nvidia std::string message; 709*c912f0e7Spranavm-nvidia llvm::raw_string_ostream os(message); 710*c912f0e7Spranavm-nvidia os << "All attributes must be of the same type and match " 711*c912f0e7Spranavm-nvidia << "the type parameter: expected=" << py::repr(py::cast(shapedType)) 712*c912f0e7Spranavm-nvidia << ", but got=" << py::repr(py::cast(attrType)); 713*c912f0e7Spranavm-nvidia throw py::value_error(os.str()); 714*c912f0e7Spranavm-nvidia } 715*c912f0e7Spranavm-nvidia } 716*c912f0e7Spranavm-nvidia 717*c912f0e7Spranavm-nvidia MlirAttribute elements = mlirDenseElementsAttrGet( 718*c912f0e7Spranavm-nvidia shapedType, mlirAttributes.size(), mlirAttributes.data()); 719*c912f0e7Spranavm-nvidia 720*c912f0e7Spranavm-nvidia return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 721*c912f0e7Spranavm-nvidia } 722*c912f0e7Spranavm-nvidia 723*c912f0e7Spranavm-nvidia static PyDenseElementsAttribute 7240a81ace0SKazu Hirata getFromBuffer(py::buffer array, bool signless, 7250a81ace0SKazu Hirata std::optional<PyType> explicitType, 7260a81ace0SKazu Hirata std::optional<std::vector<int64_t>> explicitShape, 727436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 728436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 72971a25454SPeter Hawkins int flags = PyBUF_ND; 73071a25454SPeter Hawkins if (!explicitType) { 73171a25454SPeter Hawkins flags |= PyBUF_FORMAT; 73271a25454SPeter Hawkins } 73371a25454SPeter Hawkins Py_buffer view; 73471a25454SPeter Hawkins if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { 735436c6c9cSStella Laurenzo throw py::error_already_set(); 736436c6c9cSStella Laurenzo } 73771a25454SPeter Hawkins auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); 7385d6d30edSStella Laurenzo SmallVector<int64_t> shape; 7395d6d30edSStella Laurenzo if (explicitShape) { 7405d6d30edSStella Laurenzo shape.append(explicitShape->begin(), explicitShape->end()); 7415d6d30edSStella Laurenzo } else { 74271a25454SPeter Hawkins shape.append(view.shape, view.shape + view.ndim); 7435d6d30edSStella Laurenzo } 744436c6c9cSStella Laurenzo 7455d6d30edSStella Laurenzo MlirAttribute encodingAttr = mlirAttributeGetNull(); 746436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 7475d6d30edSStella Laurenzo 7485d6d30edSStella Laurenzo // Detect format codes that are suitable for bulk loading. This includes 7495d6d30edSStella Laurenzo // all byte aligned integer and floating point types up to 8 bytes. 7505d6d30edSStella Laurenzo // Notably, this excludes, bool (which needs to be bit-packed) and 7515d6d30edSStella Laurenzo // other exotics which do not have a direct representation in the buffer 7525d6d30edSStella Laurenzo // protocol (i.e. complex, etc). 7530a81ace0SKazu Hirata std::optional<MlirType> bulkLoadElementType; 7545d6d30edSStella Laurenzo if (explicitType) { 7555d6d30edSStella Laurenzo bulkLoadElementType = *explicitType; 75671a25454SPeter Hawkins } else { 75771a25454SPeter Hawkins std::string_view format(view.format); 75871a25454SPeter Hawkins if (format == "f") { 759436c6c9cSStella Laurenzo // f32 76071a25454SPeter Hawkins assert(view.itemsize == 4 && "mismatched array itemsize"); 7615d6d30edSStella Laurenzo bulkLoadElementType = mlirF32TypeGet(context); 76271a25454SPeter Hawkins } else if (format == "d") { 763436c6c9cSStella Laurenzo // f64 76471a25454SPeter Hawkins assert(view.itemsize == 8 && "mismatched array itemsize"); 7655d6d30edSStella Laurenzo bulkLoadElementType = mlirF64TypeGet(context); 76671a25454SPeter Hawkins } else if (format == "e") { 7675d6d30edSStella Laurenzo // f16 76871a25454SPeter Hawkins assert(view.itemsize == 2 && "mismatched array itemsize"); 7695d6d30edSStella Laurenzo bulkLoadElementType = mlirF16TypeGet(context); 77071a25454SPeter Hawkins } else if (isSignedIntegerFormat(format)) { 77171a25454SPeter Hawkins if (view.itemsize == 4) { 772436c6c9cSStella Laurenzo // i32 77371a25454SPeter Hawkins bulkLoadElementType = signless 77471a25454SPeter Hawkins ? mlirIntegerTypeGet(context, 32) 775436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 32); 77671a25454SPeter Hawkins } else if (view.itemsize == 8) { 777436c6c9cSStella Laurenzo // i64 77871a25454SPeter Hawkins bulkLoadElementType = signless 77971a25454SPeter Hawkins ? mlirIntegerTypeGet(context, 64) 780436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 64); 78171a25454SPeter Hawkins } else if (view.itemsize == 1) { 7825d6d30edSStella Laurenzo // i8 7835d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 7845d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 8); 78571a25454SPeter Hawkins } else if (view.itemsize == 2) { 7865d6d30edSStella Laurenzo // i16 78771a25454SPeter Hawkins bulkLoadElementType = signless 78871a25454SPeter Hawkins ? mlirIntegerTypeGet(context, 16) 7895d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 16); 790436c6c9cSStella Laurenzo } 79171a25454SPeter Hawkins } else if (isUnsignedIntegerFormat(format)) { 79271a25454SPeter Hawkins if (view.itemsize == 4) { 793436c6c9cSStella Laurenzo // unsigned i32 7945d6d30edSStella Laurenzo bulkLoadElementType = signless 795436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 32) 796436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 32); 79771a25454SPeter Hawkins } else if (view.itemsize == 8) { 798436c6c9cSStella Laurenzo // unsigned i64 7995d6d30edSStella Laurenzo bulkLoadElementType = signless 800436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 64) 801436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 64); 80271a25454SPeter Hawkins } else if (view.itemsize == 1) { 8035d6d30edSStella Laurenzo // i8 80471a25454SPeter Hawkins bulkLoadElementType = signless 80571a25454SPeter Hawkins ? mlirIntegerTypeGet(context, 8) 8065d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 8); 80771a25454SPeter Hawkins } else if (view.itemsize == 2) { 8085d6d30edSStella Laurenzo // i16 8095d6d30edSStella Laurenzo bulkLoadElementType = signless 8105d6d30edSStella Laurenzo ? mlirIntegerTypeGet(context, 16) 8115d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 16); 812436c6c9cSStella Laurenzo } 813436c6c9cSStella Laurenzo } 81471a25454SPeter Hawkins if (!bulkLoadElementType) { 81571a25454SPeter Hawkins throw std::invalid_argument( 81671a25454SPeter Hawkins std::string("unimplemented array format conversion from format: ") + 81771a25454SPeter Hawkins std::string(format)); 81871a25454SPeter Hawkins } 81971a25454SPeter Hawkins } 82071a25454SPeter Hawkins 82199dee31eSAdam Paszke MlirType shapedType; 82299dee31eSAdam Paszke if (mlirTypeIsAShaped(*bulkLoadElementType)) { 82399dee31eSAdam Paszke if (explicitShape) { 82499dee31eSAdam Paszke throw std::invalid_argument("Shape can only be specified explicitly " 82599dee31eSAdam Paszke "when the type is not a shaped type."); 82699dee31eSAdam Paszke } 82799dee31eSAdam Paszke shapedType = *bulkLoadElementType; 82899dee31eSAdam Paszke } else { 82971a25454SPeter Hawkins shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), 83071a25454SPeter Hawkins *bulkLoadElementType, encodingAttr); 83199dee31eSAdam Paszke } 83271a25454SPeter Hawkins size_t rawBufferSize = view.len; 83371a25454SPeter Hawkins MlirAttribute attr = 83471a25454SPeter Hawkins mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); 8355d6d30edSStella Laurenzo if (mlirAttributeIsNull(attr)) { 8365d6d30edSStella Laurenzo throw std::invalid_argument( 8375d6d30edSStella Laurenzo "DenseElementsAttr could not be constructed from the given buffer. " 8385d6d30edSStella Laurenzo "This may mean that the Python buffer layout does not match that " 8395d6d30edSStella Laurenzo "MLIR expected layout and is a bug."); 8405d6d30edSStella Laurenzo } 8415d6d30edSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 8425d6d30edSStella Laurenzo } 843436c6c9cSStella Laurenzo 8441fc096afSMehdi Amini static PyDenseElementsAttribute getSplat(const PyType &shapedType, 845436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 846436c6c9cSStella Laurenzo auto contextWrapper = 847436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 848436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 849436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 850436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 851436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 8524811270bSmax throw py::value_error(message); 853436c6c9cSStella Laurenzo } 854436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 855436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 856436c6c9cSStella Laurenzo std::string message = 857436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 858436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 8594811270bSmax throw py::value_error(message); 860436c6c9cSStella Laurenzo } 861436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 862436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 863436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 864436c6c9cSStella Laurenzo std::string message = 865436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 866436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 867436c6c9cSStella Laurenzo message.append(", element="); 868436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 8694811270bSmax throw py::value_error(message); 870436c6c9cSStella Laurenzo } 871436c6c9cSStella Laurenzo 872436c6c9cSStella Laurenzo MlirAttribute elements = 873436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 874436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 875436c6c9cSStella Laurenzo } 876436c6c9cSStella Laurenzo 877436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 878436c6c9cSStella Laurenzo 879436c6c9cSStella Laurenzo py::buffer_info accessBuffer() { 880436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 881436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 8825d6d30edSStella Laurenzo std::string format; 883436c6c9cSStella Laurenzo 884436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 885436c6c9cSStella Laurenzo // f32 8865d6d30edSStella Laurenzo return bufferInfo<float>(shapedType); 88702b6fb21SMehdi Amini } 88802b6fb21SMehdi Amini if (mlirTypeIsAF64(elementType)) { 889436c6c9cSStella Laurenzo // f64 8905d6d30edSStella Laurenzo return bufferInfo<double>(shapedType); 891bb56c2b3SMehdi Amini } 892bb56c2b3SMehdi Amini if (mlirTypeIsAF16(elementType)) { 8935d6d30edSStella Laurenzo // f16 8945d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType, "e"); 895bb56c2b3SMehdi Amini } 896ef1b735dSmax if (mlirTypeIsAIndex(elementType)) { 897ef1b735dSmax // Same as IndexType::kInternalStorageBitWidth 898ef1b735dSmax return bufferInfo<int64_t>(shapedType); 899ef1b735dSmax } 900bb56c2b3SMehdi Amini if (mlirTypeIsAInteger(elementType) && 901436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 902436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 903436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 904436c6c9cSStella Laurenzo // i32 9055d6d30edSStella Laurenzo return bufferInfo<int32_t>(shapedType); 906e5639b3fSMehdi Amini } 907e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 908436c6c9cSStella Laurenzo // unsigned i32 9095d6d30edSStella Laurenzo return bufferInfo<uint32_t>(shapedType); 910436c6c9cSStella Laurenzo } 911436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 912436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 913436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 914436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 915436c6c9cSStella Laurenzo // i64 9165d6d30edSStella Laurenzo return bufferInfo<int64_t>(shapedType); 917e5639b3fSMehdi Amini } 918e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 919436c6c9cSStella Laurenzo // unsigned i64 9205d6d30edSStella Laurenzo return bufferInfo<uint64_t>(shapedType); 9215d6d30edSStella Laurenzo } 9225d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 9235d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 8) { 9245d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 9255d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 9265d6d30edSStella Laurenzo // i8 9275d6d30edSStella Laurenzo return bufferInfo<int8_t>(shapedType); 928e5639b3fSMehdi Amini } 929e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 9305d6d30edSStella Laurenzo // unsigned i8 9315d6d30edSStella Laurenzo return bufferInfo<uint8_t>(shapedType); 9325d6d30edSStella Laurenzo } 9335d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 9345d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 16) { 9355d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 9365d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 9375d6d30edSStella Laurenzo // i16 9385d6d30edSStella Laurenzo return bufferInfo<int16_t>(shapedType); 939e5639b3fSMehdi Amini } 940e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 9415d6d30edSStella Laurenzo // unsigned i16 9425d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType); 943436c6c9cSStella Laurenzo } 944436c6c9cSStella Laurenzo } 945436c6c9cSStella Laurenzo 946c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 9475d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 948c5f445d1SStella Laurenzo throw std::invalid_argument( 949c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 950436c6c9cSStella Laurenzo } 951436c6c9cSStella Laurenzo 952436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 953436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 954436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 955436c6c9cSStella Laurenzo py::arg("array"), py::arg("signless") = true, 9565d6d30edSStella Laurenzo py::arg("type") = py::none(), py::arg("shape") = py::none(), 957436c6c9cSStella Laurenzo py::arg("context") = py::none(), 9585d6d30edSStella Laurenzo kDenseElementsAttrGetDocstring) 959*c912f0e7Spranavm-nvidia .def_static("get", PyDenseElementsAttribute::getFromList, 960*c912f0e7Spranavm-nvidia py::arg("attrs"), py::arg("type") = py::none(), 961*c912f0e7Spranavm-nvidia py::arg("context") = py::none(), 962*c912f0e7Spranavm-nvidia kDenseElementsAttrGetFromListDocstring) 963436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 964436c6c9cSStella Laurenzo py::arg("shaped_type"), py::arg("element_attr"), 965436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 966436c6c9cSStella Laurenzo .def_property_readonly("is_splat", 967436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 968436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 969436c6c9cSStella Laurenzo }) 97091259963SAdam Paszke .def("get_splat_value", 971974c1596SRahul Kayaith [](PyDenseElementsAttribute &self) { 972974c1596SRahul Kayaith if (!mlirDenseElementsAttrIsSplat(self)) 9734811270bSmax throw py::value_error( 97491259963SAdam Paszke "get_splat_value called on a non-splat attribute"); 975974c1596SRahul Kayaith return mlirDenseElementsAttrGetSplatValue(self); 97691259963SAdam Paszke }) 977436c6c9cSStella Laurenzo .def_buffer(&PyDenseElementsAttribute::accessBuffer); 978436c6c9cSStella Laurenzo } 979436c6c9cSStella Laurenzo 980436c6c9cSStella Laurenzo private: 98171a25454SPeter Hawkins static bool isUnsignedIntegerFormat(std::string_view format) { 982436c6c9cSStella Laurenzo if (format.empty()) 983436c6c9cSStella Laurenzo return false; 984436c6c9cSStella Laurenzo char code = format[0]; 985436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 986436c6c9cSStella Laurenzo code == 'Q'; 987436c6c9cSStella Laurenzo } 988436c6c9cSStella Laurenzo 98971a25454SPeter Hawkins static bool isSignedIntegerFormat(std::string_view format) { 990436c6c9cSStella Laurenzo if (format.empty()) 991436c6c9cSStella Laurenzo return false; 992436c6c9cSStella Laurenzo char code = format[0]; 993436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 994436c6c9cSStella Laurenzo code == 'q'; 995436c6c9cSStella Laurenzo } 996436c6c9cSStella Laurenzo 997436c6c9cSStella Laurenzo template <typename Type> 998436c6c9cSStella Laurenzo py::buffer_info bufferInfo(MlirType shapedType, 9995d6d30edSStella Laurenzo const char *explicitFormat = nullptr) { 1000436c6c9cSStella Laurenzo intptr_t rank = mlirShapedTypeGetRank(shapedType); 1001436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 1002436c6c9cSStella Laurenzo // Buffer is configured for read-only access below. 1003436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 1004436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 1005436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 1006436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 1007436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 1008436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 1009436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 1010436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 1011f0e847d0SRahul Kayaith if (mlirDenseElementsAttrIsSplat(*this)) { 1012f0e847d0SRahul Kayaith // Splats are special, only the single value is stored. 1013f0e847d0SRahul Kayaith strides.assign(rank, 0); 1014f0e847d0SRahul Kayaith } else { 1015436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 1016f0e847d0SRahul Kayaith intptr_t strideFactor = 1; 1017f0e847d0SRahul Kayaith for (intptr_t j = i; j < rank; ++j) 1018436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 1019436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 1020436c6c9cSStella Laurenzo } 1021436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 1022f0e847d0SRahul Kayaith } 10235d6d30edSStella Laurenzo std::string format; 10245d6d30edSStella Laurenzo if (explicitFormat) { 10255d6d30edSStella Laurenzo format = explicitFormat; 10265d6d30edSStella Laurenzo } else { 10275d6d30edSStella Laurenzo format = py::format_descriptor<Type>::format(); 10285d6d30edSStella Laurenzo } 10295d6d30edSStella Laurenzo return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 10305d6d30edSStella Laurenzo /*readonly=*/true); 1031436c6c9cSStella Laurenzo } 1032436c6c9cSStella Laurenzo }; // namespace 1033436c6c9cSStella Laurenzo 1034436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer 1035436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access. 1036436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 1037436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 1038436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 1039436c6c9cSStella Laurenzo public: 1040436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 1041436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 1042436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1043436c6c9cSStella Laurenzo 1044436c6c9cSStella Laurenzo /// Returns the element at the given linear position. Asserts if the index is 1045436c6c9cSStella Laurenzo /// out of range. 1046436c6c9cSStella Laurenzo py::int_ dunderGetItem(intptr_t pos) { 1047436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 10484811270bSmax throw py::index_error("attempt to access out of bounds element"); 1049436c6c9cSStella Laurenzo } 1050436c6c9cSStella Laurenzo 1051436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 1052436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 1053436c6c9cSStella Laurenzo assert(mlirTypeIsAInteger(type) && 1054436c6c9cSStella Laurenzo "expected integer element type in dense int elements attribute"); 1055436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 1056436c6c9cSStella Laurenzo // elemental type of the attribute. py::int_ is implicitly constructible 1057436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 1058436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 1059436c6c9cSStella Laurenzo // querying them on each element access. 1060436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 1061436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 1062436c6c9cSStella Laurenzo if (isUnsigned) { 1063436c6c9cSStella Laurenzo if (width == 1) { 1064436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 1065436c6c9cSStella Laurenzo } 1066308d8b8cSRahul Kayaith if (width == 8) { 1067308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt8Value(*this, pos); 1068308d8b8cSRahul Kayaith } 1069308d8b8cSRahul Kayaith if (width == 16) { 1070308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt16Value(*this, pos); 1071308d8b8cSRahul Kayaith } 1072436c6c9cSStella Laurenzo if (width == 32) { 1073436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt32Value(*this, pos); 1074436c6c9cSStella Laurenzo } 1075436c6c9cSStella Laurenzo if (width == 64) { 1076436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt64Value(*this, pos); 1077436c6c9cSStella Laurenzo } 1078436c6c9cSStella Laurenzo } else { 1079436c6c9cSStella Laurenzo if (width == 1) { 1080436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 1081436c6c9cSStella Laurenzo } 1082308d8b8cSRahul Kayaith if (width == 8) { 1083308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt8Value(*this, pos); 1084308d8b8cSRahul Kayaith } 1085308d8b8cSRahul Kayaith if (width == 16) { 1086308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt16Value(*this, pos); 1087308d8b8cSRahul Kayaith } 1088436c6c9cSStella Laurenzo if (width == 32) { 1089436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt32Value(*this, pos); 1090436c6c9cSStella Laurenzo } 1091436c6c9cSStella Laurenzo if (width == 64) { 1092436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt64Value(*this, pos); 1093436c6c9cSStella Laurenzo } 1094436c6c9cSStella Laurenzo } 10954811270bSmax throw py::type_error("Unsupported integer type"); 1096436c6c9cSStella Laurenzo } 1097436c6c9cSStella Laurenzo 1098436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1099436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 1100436c6c9cSStella Laurenzo } 1101436c6c9cSStella Laurenzo }; 1102436c6c9cSStella Laurenzo 1103f66cd9e9SStella Laurenzo class PyDenseResourceElementsAttribute 1104f66cd9e9SStella Laurenzo : public PyConcreteAttribute<PyDenseResourceElementsAttribute> { 1105f66cd9e9SStella Laurenzo public: 1106f66cd9e9SStella Laurenzo static constexpr IsAFunctionTy isaFunction = 1107f66cd9e9SStella Laurenzo mlirAttributeIsADenseResourceElements; 1108f66cd9e9SStella Laurenzo static constexpr const char *pyClassName = "DenseResourceElementsAttr"; 1109f66cd9e9SStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1110f66cd9e9SStella Laurenzo 1111f66cd9e9SStella Laurenzo static PyDenseResourceElementsAttribute 1112962bf002SMehdi Amini getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type, 1113f66cd9e9SStella Laurenzo std::optional<size_t> alignment, bool isMutable, 1114f66cd9e9SStella Laurenzo DefaultingPyMlirContext contextWrapper) { 1115f66cd9e9SStella Laurenzo if (!mlirTypeIsAShaped(type)) { 1116f66cd9e9SStella Laurenzo throw std::invalid_argument( 1117f66cd9e9SStella Laurenzo "Constructing a DenseResourceElementsAttr requires a ShapedType."); 1118f66cd9e9SStella Laurenzo } 1119f66cd9e9SStella Laurenzo 1120f66cd9e9SStella Laurenzo // Do not request any conversions as we must ensure to use caller 1121f66cd9e9SStella Laurenzo // managed memory. 1122f66cd9e9SStella Laurenzo int flags = PyBUF_STRIDES; 1123f66cd9e9SStella Laurenzo std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>(); 1124f66cd9e9SStella Laurenzo if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { 1125f66cd9e9SStella Laurenzo throw py::error_already_set(); 1126f66cd9e9SStella Laurenzo } 1127f66cd9e9SStella Laurenzo 1128f66cd9e9SStella Laurenzo // This scope releaser will only release if we haven't yet transferred 1129f66cd9e9SStella Laurenzo // ownership. 1130f66cd9e9SStella Laurenzo auto freeBuffer = llvm::make_scope_exit([&]() { 1131f66cd9e9SStella Laurenzo if (view) 1132f66cd9e9SStella Laurenzo PyBuffer_Release(view.get()); 1133f66cd9e9SStella Laurenzo }); 1134f66cd9e9SStella Laurenzo 1135f66cd9e9SStella Laurenzo if (!PyBuffer_IsContiguous(view.get(), 'A')) { 1136f66cd9e9SStella Laurenzo throw std::invalid_argument("Contiguous buffer is required."); 1137f66cd9e9SStella Laurenzo } 1138f66cd9e9SStella Laurenzo 1139f66cd9e9SStella Laurenzo // Infer alignment to be the stride of one element if not explicit. 1140f66cd9e9SStella Laurenzo size_t inferredAlignment; 1141f66cd9e9SStella Laurenzo if (alignment) 1142f66cd9e9SStella Laurenzo inferredAlignment = *alignment; 1143f66cd9e9SStella Laurenzo else 1144f66cd9e9SStella Laurenzo inferredAlignment = view->strides[view->ndim - 1]; 1145f66cd9e9SStella Laurenzo 1146f66cd9e9SStella Laurenzo // The userData is a Py_buffer* that the deleter owns. 1147f66cd9e9SStella Laurenzo auto deleter = [](void *userData, const void *data, size_t size, 1148f66cd9e9SStella Laurenzo size_t align) { 1149f66cd9e9SStella Laurenzo Py_buffer *ownedView = static_cast<Py_buffer *>(userData); 1150f66cd9e9SStella Laurenzo PyBuffer_Release(ownedView); 1151f66cd9e9SStella Laurenzo delete ownedView; 1152f66cd9e9SStella Laurenzo }; 1153f66cd9e9SStella Laurenzo 1154f66cd9e9SStella Laurenzo size_t rawBufferSize = view->len; 1155f66cd9e9SStella Laurenzo MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( 1156f66cd9e9SStella Laurenzo type, toMlirStringRef(name), view->buf, rawBufferSize, 1157f66cd9e9SStella Laurenzo inferredAlignment, isMutable, deleter, static_cast<void *>(view.get())); 1158f66cd9e9SStella Laurenzo if (mlirAttributeIsNull(attr)) { 1159f66cd9e9SStella Laurenzo throw std::invalid_argument( 1160f66cd9e9SStella Laurenzo "DenseResourceElementsAttr could not be constructed from the given " 1161f66cd9e9SStella Laurenzo "buffer. " 1162f66cd9e9SStella Laurenzo "This may mean that the Python buffer layout does not match that " 1163f66cd9e9SStella Laurenzo "MLIR expected layout and is a bug."); 1164f66cd9e9SStella Laurenzo } 1165f66cd9e9SStella Laurenzo view.release(); 1166f66cd9e9SStella Laurenzo return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); 1167f66cd9e9SStella Laurenzo } 1168f66cd9e9SStella Laurenzo 1169f66cd9e9SStella Laurenzo static void bindDerived(ClassTy &c) { 1170f66cd9e9SStella Laurenzo c.def_static("get_from_buffer", 1171f66cd9e9SStella Laurenzo PyDenseResourceElementsAttribute::getFromBuffer, 1172f66cd9e9SStella Laurenzo py::arg("array"), py::arg("name"), py::arg("type"), 1173f66cd9e9SStella Laurenzo py::arg("alignment") = py::none(), 1174f66cd9e9SStella Laurenzo py::arg("is_mutable") = false, py::arg("context") = py::none(), 1175f66cd9e9SStella Laurenzo kDenseResourceElementsAttrGetFromBufferDocstring); 1176f66cd9e9SStella Laurenzo } 1177f66cd9e9SStella Laurenzo }; 1178f66cd9e9SStella Laurenzo 1179436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 1180436c6c9cSStella Laurenzo public: 1181436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 1182436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 1183436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 11849566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 11859566ee28Smax mlirDictionaryAttrGetTypeID; 1186436c6c9cSStella Laurenzo 1187436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 1188436c6c9cSStella Laurenzo 11899fb1086bSAdrian Kuegel bool dunderContains(const std::string &name) { 11909fb1086bSAdrian Kuegel return !mlirAttributeIsNull( 11919fb1086bSAdrian Kuegel mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 11929fb1086bSAdrian Kuegel } 11939fb1086bSAdrian Kuegel 1194436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 11959fb1086bSAdrian Kuegel c.def("__contains__", &PyDictAttribute::dunderContains); 1196436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 1197436c6c9cSStella Laurenzo c.def_static( 1198436c6c9cSStella Laurenzo "get", 1199436c6c9cSStella Laurenzo [](py::dict attributes, DefaultingPyMlirContext context) { 1200436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 1201436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 1202436c6c9cSStella Laurenzo for (auto &it : attributes) { 120302b6fb21SMehdi Amini auto &mlirAttr = it.second.cast<PyAttribute &>(); 1204436c6c9cSStella Laurenzo auto name = it.first.cast<std::string>(); 1205436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 120602b6fb21SMehdi Amini mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 1207436c6c9cSStella Laurenzo toMlirStringRef(name)), 120802b6fb21SMehdi Amini mlirAttr)); 1209436c6c9cSStella Laurenzo } 1210436c6c9cSStella Laurenzo MlirAttribute attr = 1211436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 1212436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 1213436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 1214436c6c9cSStella Laurenzo }, 1215ed9e52f3SAlex Zinenko py::arg("value") = py::dict(), py::arg("context") = py::none(), 1216436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 1217436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 1218436c6c9cSStella Laurenzo MlirAttribute attr = 1219436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 1220974c1596SRahul Kayaith if (mlirAttributeIsNull(attr)) 12214811270bSmax throw py::key_error("attempt to access a non-existent attribute"); 1222974c1596SRahul Kayaith return attr; 1223436c6c9cSStella Laurenzo }); 1224436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 1225436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 12264811270bSmax throw py::index_error("attempt to access out of bounds attribute"); 1227436c6c9cSStella Laurenzo } 1228436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 1229436c6c9cSStella Laurenzo return PyNamedAttribute( 1230436c6c9cSStella Laurenzo namedAttr.attribute, 1231436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 1232436c6c9cSStella Laurenzo }); 1233436c6c9cSStella Laurenzo } 1234436c6c9cSStella Laurenzo }; 1235436c6c9cSStella Laurenzo 1236436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 1237436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 1238436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 1239436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 1240436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 1241436c6c9cSStella Laurenzo public: 1242436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 1243436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 1244436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1245436c6c9cSStella Laurenzo 1246436c6c9cSStella Laurenzo py::float_ dunderGetItem(intptr_t pos) { 1247436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 12484811270bSmax throw py::index_error("attempt to access out of bounds element"); 1249436c6c9cSStella Laurenzo } 1250436c6c9cSStella Laurenzo 1251436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 1252436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 1253436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 1254436c6c9cSStella Laurenzo // elemental type of the attribute. py::float_ is implicitly constructible 1255436c6c9cSStella Laurenzo // from float and double. 1256436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 1257436c6c9cSStella Laurenzo // querying them on each element access. 1258436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 1259436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetFloatValue(*this, pos); 1260436c6c9cSStella Laurenzo } 1261436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 1262436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetDoubleValue(*this, pos); 1263436c6c9cSStella Laurenzo } 12644811270bSmax throw py::type_error("Unsupported floating-point type"); 1265436c6c9cSStella Laurenzo } 1266436c6c9cSStella Laurenzo 1267436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1268436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 1269436c6c9cSStella Laurenzo } 1270436c6c9cSStella Laurenzo }; 1271436c6c9cSStella Laurenzo 1272436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 1273436c6c9cSStella Laurenzo public: 1274436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 1275436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 1276436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 12779566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 12789566ee28Smax mlirTypeAttrGetTypeID; 1279436c6c9cSStella Laurenzo 1280436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1281436c6c9cSStella Laurenzo c.def_static( 1282436c6c9cSStella Laurenzo "get", 1283436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 1284436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 1285436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 1286436c6c9cSStella Laurenzo }, 1287436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 1288436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 1289436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyTypeAttribute &self) { 1290bfb1ba75Smax return mlirTypeAttrGetValue(self.get()); 1291436c6c9cSStella Laurenzo }); 1292436c6c9cSStella Laurenzo } 1293436c6c9cSStella Laurenzo }; 1294436c6c9cSStella Laurenzo 1295436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 1296436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 1297436c6c9cSStella Laurenzo public: 1298436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 1299436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 1300436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 13019566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 13029566ee28Smax mlirUnitAttrGetTypeID; 1303436c6c9cSStella Laurenzo 1304436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1305436c6c9cSStella Laurenzo c.def_static( 1306436c6c9cSStella Laurenzo "get", 1307436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 1308436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 1309436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 1310436c6c9cSStella Laurenzo }, 1311436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Create a Unit attribute."); 1312436c6c9cSStella Laurenzo } 1313436c6c9cSStella Laurenzo }; 1314436c6c9cSStella Laurenzo 1315ac2e2d65SDenys Shabalin /// Strided layout attribute subclass. 1316ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute 1317ac2e2d65SDenys Shabalin : public PyConcreteAttribute<PyStridedLayoutAttribute> { 1318ac2e2d65SDenys Shabalin public: 1319ac2e2d65SDenys Shabalin static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; 1320ac2e2d65SDenys Shabalin static constexpr const char *pyClassName = "StridedLayoutAttr"; 1321ac2e2d65SDenys Shabalin using PyConcreteAttribute::PyConcreteAttribute; 13229566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 13239566ee28Smax mlirStridedLayoutAttrGetTypeID; 1324ac2e2d65SDenys Shabalin 1325ac2e2d65SDenys Shabalin static void bindDerived(ClassTy &c) { 1326ac2e2d65SDenys Shabalin c.def_static( 1327ac2e2d65SDenys Shabalin "get", 1328ac2e2d65SDenys Shabalin [](int64_t offset, const std::vector<int64_t> strides, 1329ac2e2d65SDenys Shabalin DefaultingPyMlirContext ctx) { 1330ac2e2d65SDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1331ac2e2d65SDenys Shabalin ctx->get(), offset, strides.size(), strides.data()); 1332ac2e2d65SDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1333ac2e2d65SDenys Shabalin }, 1334ac2e2d65SDenys Shabalin py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), 1335ac2e2d65SDenys Shabalin "Gets a strided layout attribute."); 1336e3fd612eSDenys Shabalin c.def_static( 1337e3fd612eSDenys Shabalin "get_fully_dynamic", 1338e3fd612eSDenys Shabalin [](int64_t rank, DefaultingPyMlirContext ctx) { 1339e3fd612eSDenys Shabalin auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); 1340e3fd612eSDenys Shabalin std::vector<int64_t> strides(rank); 1341e3fd612eSDenys Shabalin std::fill(strides.begin(), strides.end(), dynamic); 1342e3fd612eSDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1343e3fd612eSDenys Shabalin ctx->get(), dynamic, strides.size(), strides.data()); 1344e3fd612eSDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1345e3fd612eSDenys Shabalin }, 1346e3fd612eSDenys Shabalin py::arg("rank"), py::arg("context") = py::none(), 1347e3fd612eSDenys Shabalin "Gets a strided layout attribute with dynamic offset and strides of a " 1348e3fd612eSDenys Shabalin "given rank."); 1349ac2e2d65SDenys Shabalin c.def_property_readonly( 1350ac2e2d65SDenys Shabalin "offset", 1351ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1352ac2e2d65SDenys Shabalin return mlirStridedLayoutAttrGetOffset(self); 1353ac2e2d65SDenys Shabalin }, 1354ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1355ac2e2d65SDenys Shabalin c.def_property_readonly( 1356ac2e2d65SDenys Shabalin "strides", 1357ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1358ac2e2d65SDenys Shabalin intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); 1359ac2e2d65SDenys Shabalin std::vector<int64_t> strides(size); 1360ac2e2d65SDenys Shabalin for (intptr_t i = 0; i < size; i++) { 1361ac2e2d65SDenys Shabalin strides[i] = mlirStridedLayoutAttrGetStride(self, i); 1362ac2e2d65SDenys Shabalin } 1363ac2e2d65SDenys Shabalin return strides; 1364ac2e2d65SDenys Shabalin }, 1365ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1366ac2e2d65SDenys Shabalin } 1367ac2e2d65SDenys Shabalin }; 1368ac2e2d65SDenys Shabalin 13699566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { 13709566ee28Smax if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) 13719566ee28Smax return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); 13729566ee28Smax if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) 13739566ee28Smax return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); 13749566ee28Smax if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) 13759566ee28Smax return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); 13769566ee28Smax if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) 13779566ee28Smax return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); 13789566ee28Smax if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) 13799566ee28Smax return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); 13809566ee28Smax if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) 13819566ee28Smax return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); 13829566ee28Smax if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) 13839566ee28Smax return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); 13849566ee28Smax std::string msg = 13859566ee28Smax std::string("Can't cast unknown element type DenseArrayAttr (") + 13869566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 13879566ee28Smax throw py::cast_error(msg); 13889566ee28Smax } 13899566ee28Smax 13909566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { 13919566ee28Smax if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) 13929566ee28Smax return py::cast(PyDenseFPElementsAttribute(pyAttribute)); 13939566ee28Smax if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) 13949566ee28Smax return py::cast(PyDenseIntElementsAttribute(pyAttribute)); 13959566ee28Smax std::string msg = 13969566ee28Smax std::string( 13979566ee28Smax "Can't cast unknown element type DenseIntOrFPElementsAttr (") + 13989566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 13999566ee28Smax throw py::cast_error(msg); 14009566ee28Smax } 14019566ee28Smax 14029566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { 14039566ee28Smax if (PyBoolAttribute::isaFunction(pyAttribute)) 14049566ee28Smax return py::cast(PyBoolAttribute(pyAttribute)); 14059566ee28Smax if (PyIntegerAttribute::isaFunction(pyAttribute)) 14069566ee28Smax return py::cast(PyIntegerAttribute(pyAttribute)); 14079566ee28Smax std::string msg = 14089566ee28Smax std::string("Can't cast unknown element type DenseArrayAttr (") + 14099566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 14109566ee28Smax throw py::cast_error(msg); 14119566ee28Smax } 14129566ee28Smax 14134eee9ef9Smax py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { 14144eee9ef9Smax if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) 14154eee9ef9Smax return py::cast(PyFlatSymbolRefAttribute(pyAttribute)); 14164eee9ef9Smax if (PySymbolRefAttribute::isaFunction(pyAttribute)) 14174eee9ef9Smax return py::cast(PySymbolRefAttribute(pyAttribute)); 14184eee9ef9Smax std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + 14194eee9ef9Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 14204eee9ef9Smax throw py::cast_error(msg); 14214eee9ef9Smax } 14224eee9ef9Smax 1423436c6c9cSStella Laurenzo } // namespace 1424436c6c9cSStella Laurenzo 1425436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) { 1426436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 1427619fd8c2SJeff Niu 1428619fd8c2SJeff Niu PyDenseBoolArrayAttribute::bind(m); 1429619fd8c2SJeff Niu PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); 1430619fd8c2SJeff Niu PyDenseI8ArrayAttribute::bind(m); 1431619fd8c2SJeff Niu PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); 1432619fd8c2SJeff Niu PyDenseI16ArrayAttribute::bind(m); 1433619fd8c2SJeff Niu PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); 1434619fd8c2SJeff Niu PyDenseI32ArrayAttribute::bind(m); 1435619fd8c2SJeff Niu PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); 1436619fd8c2SJeff Niu PyDenseI64ArrayAttribute::bind(m); 1437619fd8c2SJeff Niu PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); 1438619fd8c2SJeff Niu PyDenseF32ArrayAttribute::bind(m); 1439619fd8c2SJeff Niu PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); 1440619fd8c2SJeff Niu PyDenseF64ArrayAttribute::bind(m); 1441619fd8c2SJeff Niu PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); 14429566ee28Smax PyGlobals::get().registerTypeCaster( 14439566ee28Smax mlirDenseArrayAttrGetTypeID(), 14449566ee28Smax pybind11::cpp_function(denseArrayAttributeCaster)); 1445619fd8c2SJeff Niu 1446436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 1447436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 1448436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 1449436c6c9cSStella Laurenzo PyDenseElementsAttribute::bind(m); 1450436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 1451436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 14529566ee28Smax PyGlobals::get().registerTypeCaster( 14539566ee28Smax mlirDenseIntOrFPElementsAttrGetTypeID(), 14549566ee28Smax pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); 1455f66cd9e9SStella Laurenzo PyDenseResourceElementsAttribute::bind(m); 14569566ee28Smax 1457436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 14584eee9ef9Smax PySymbolRefAttribute::bind(m); 14594eee9ef9Smax PyGlobals::get().registerTypeCaster( 14604eee9ef9Smax mlirSymbolRefAttrGetTypeID(), 14614eee9ef9Smax pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)); 14624eee9ef9Smax 1463436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 14645c3861b2SYun Long PyOpaqueAttribute::bind(m); 1465436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 1466436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 1467436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 1468436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 14699566ee28Smax PyGlobals::get().registerTypeCaster( 14709566ee28Smax mlirIntegerAttrGetTypeID(), 14719566ee28Smax pybind11::cpp_function(integerOrBoolAttributeCaster)); 1472436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 1473ac2e2d65SDenys Shabalin 1474ac2e2d65SDenys Shabalin PyStridedLayoutAttribute::bind(m); 1475436c6c9cSStella Laurenzo } 1476