1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9a1fe1f5fSKazu Hirata #include <optional> 1071a25454SPeter Hawkins #include <string_view> 114811270bSmax #include <utility> 121fc096afSMehdi Amini 13436c6c9cSStella Laurenzo #include "IRModule.h" 14436c6c9cSStella Laurenzo 15436c6c9cSStella Laurenzo #include "PybindUtils.h" 16436c6c9cSStella Laurenzo 1771a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h" 18c912f0e7Spranavm-nvidia #include "llvm/Support/raw_ostream.h" 1971a25454SPeter Hawkins 20436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 21436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 22bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h" 23436c6c9cSStella Laurenzo 24436c6c9cSStella Laurenzo namespace py = pybind11; 25436c6c9cSStella Laurenzo using namespace mlir; 26436c6c9cSStella Laurenzo using namespace mlir::python; 27436c6c9cSStella Laurenzo 28436c6c9cSStella Laurenzo using llvm::SmallVector; 29436c6c9cSStella Laurenzo 305d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 315d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 325d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 335d6d30edSStella Laurenzo 345d6d30edSStella Laurenzo static const char kDenseElementsAttrGetDocstring[] = 355d6d30edSStella Laurenzo R"(Gets a DenseElementsAttr from a Python buffer or array. 365d6d30edSStella Laurenzo 375d6d30edSStella Laurenzo When `type` is not provided, then some limited type inferencing is done based 385d6d30edSStella Laurenzo on the buffer format. Support presently exists for 8/16/32/64 signed and 395d6d30edSStella Laurenzo unsigned integers and float16/float32/float64. DenseElementsAttrs of these 405d6d30edSStella Laurenzo types can also be converted back to a corresponding buffer. 415d6d30edSStella Laurenzo 425d6d30edSStella Laurenzo For conversions outside of these types, a `type=` must be explicitly provided 435d6d30edSStella Laurenzo and the buffer contents must be bit-castable to the MLIR internal 445d6d30edSStella Laurenzo representation: 455d6d30edSStella Laurenzo 465d6d30edSStella Laurenzo * Integer types (except for i1): the buffer must be byte aligned to the 475d6d30edSStella Laurenzo next byte boundary. 485d6d30edSStella Laurenzo * Floating point types: Must be bit-castable to the given floating point 495d6d30edSStella Laurenzo size. 505d6d30edSStella Laurenzo * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 515d6d30edSStella Laurenzo row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 525d6d30edSStella Laurenzo this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 535d6d30edSStella Laurenzo 545d6d30edSStella Laurenzo If a single element buffer is passed (or for i1, a single byte with value 0 555d6d30edSStella Laurenzo or 255), then a splat will be created. 565d6d30edSStella Laurenzo 575d6d30edSStella Laurenzo Args: 585d6d30edSStella Laurenzo array: The array or buffer to convert. 595d6d30edSStella Laurenzo signless: If inferring an appropriate MLIR type, use signless types for 605d6d30edSStella Laurenzo integers (defaults True). 615d6d30edSStella Laurenzo type: Skips inference of the MLIR element type and uses this instead. The 625d6d30edSStella Laurenzo storage size must be consistent with the actual contents of the buffer. 635d6d30edSStella Laurenzo shape: Overrides the shape of the buffer when constructing the MLIR 645d6d30edSStella Laurenzo shaped type. This is needed when the physical and logical shape differ (as 655d6d30edSStella Laurenzo for i1). 665d6d30edSStella Laurenzo context: Explicit context, if not from context manager. 675d6d30edSStella Laurenzo 685d6d30edSStella Laurenzo Returns: 695d6d30edSStella Laurenzo DenseElementsAttr on success. 705d6d30edSStella Laurenzo 715d6d30edSStella Laurenzo Raises: 725d6d30edSStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 735d6d30edSStella Laurenzo type or if the buffer does not meet expectations. 745d6d30edSStella Laurenzo )"; 755d6d30edSStella Laurenzo 76c912f0e7Spranavm-nvidia static const char kDenseElementsAttrGetFromListDocstring[] = 77c912f0e7Spranavm-nvidia R"(Gets a DenseElementsAttr from a Python list of attributes. 78c912f0e7Spranavm-nvidia 79c912f0e7Spranavm-nvidia Note that it can be expensive to construct attributes individually. 80c912f0e7Spranavm-nvidia For a large number of elements, consider using a Python buffer or array instead. 81c912f0e7Spranavm-nvidia 82c912f0e7Spranavm-nvidia Args: 83c912f0e7Spranavm-nvidia attrs: A list of attributes. 84c912f0e7Spranavm-nvidia type: The desired shape and type of the resulting DenseElementsAttr. 85c912f0e7Spranavm-nvidia If not provided, the element type is determined based on the type 86c912f0e7Spranavm-nvidia of the 0th attribute and the shape is `[len(attrs)]`. 87c912f0e7Spranavm-nvidia context: Explicit context, if not from context manager. 88c912f0e7Spranavm-nvidia 89c912f0e7Spranavm-nvidia Returns: 90c912f0e7Spranavm-nvidia DenseElementsAttr on success. 91c912f0e7Spranavm-nvidia 92c912f0e7Spranavm-nvidia Raises: 93c912f0e7Spranavm-nvidia ValueError: If the type of the attributes does not match the type 94c912f0e7Spranavm-nvidia specified by `shaped_type`. 95c912f0e7Spranavm-nvidia )"; 96c912f0e7Spranavm-nvidia 97f66cd9e9SStella Laurenzo static const char kDenseResourceElementsAttrGetFromBufferDocstring[] = 98f66cd9e9SStella Laurenzo R"(Gets a DenseResourceElementsAttr from a Python buffer or array. 99f66cd9e9SStella Laurenzo 100f66cd9e9SStella Laurenzo This function does minimal validation or massaging of the data, and it is 101f66cd9e9SStella Laurenzo up to the caller to ensure that the buffer meets the characteristics 102f66cd9e9SStella Laurenzo implied by the shape. 103f66cd9e9SStella Laurenzo 104f66cd9e9SStella Laurenzo The backing buffer and any user objects will be retained for the lifetime 105f66cd9e9SStella Laurenzo of the resource blob. This is typically bounded to the context but the 106f66cd9e9SStella Laurenzo resource can have a shorter lifespan depending on how it is used in 107f66cd9e9SStella Laurenzo subsequent processing. 108f66cd9e9SStella Laurenzo 109f66cd9e9SStella Laurenzo Args: 110f66cd9e9SStella Laurenzo buffer: The array or buffer to convert. 111f66cd9e9SStella Laurenzo name: Name to provide to the resource (may be changed upon collision). 112f66cd9e9SStella Laurenzo type: The explicit ShapedType to construct the attribute with. 113f66cd9e9SStella Laurenzo context: Explicit context, if not from context manager. 114f66cd9e9SStella Laurenzo 115f66cd9e9SStella Laurenzo Returns: 116f66cd9e9SStella Laurenzo DenseResourceElementsAttr on success. 117f66cd9e9SStella Laurenzo 118f66cd9e9SStella Laurenzo Raises: 119f66cd9e9SStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 120f66cd9e9SStella Laurenzo type or if the buffer does not meet expectations. 121f66cd9e9SStella Laurenzo )"; 122f66cd9e9SStella Laurenzo 123436c6c9cSStella Laurenzo namespace { 124436c6c9cSStella Laurenzo 125436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 126436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 127436c6c9cSStella Laurenzo } 128436c6c9cSStella Laurenzo 129436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 130436c6c9cSStella Laurenzo public: 131436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 132436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 133436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1349566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 1359566ee28Smax mlirAffineMapAttrGetTypeID; 136436c6c9cSStella Laurenzo 137436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 138436c6c9cSStella Laurenzo c.def_static( 139436c6c9cSStella Laurenzo "get", 140436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 141436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 142436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 143436c6c9cSStella Laurenzo }, 144436c6c9cSStella Laurenzo py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 145*c36b4248SBimo c.def_property_readonly("value", mlirAffineMapAttrGetValue, 146*c36b4248SBimo "Returns the value of the AffineMap attribute"); 147436c6c9cSStella Laurenzo } 148436c6c9cSStella Laurenzo }; 149436c6c9cSStella Laurenzo 150ed9e52f3SAlex Zinenko template <typename T> 151ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) { 152ed9e52f3SAlex Zinenko try { 153ed9e52f3SAlex Zinenko return object.cast<T>(); 154ed9e52f3SAlex Zinenko } catch (py::cast_error &err) { 155ed9e52f3SAlex Zinenko std::string msg = 156ed9e52f3SAlex Zinenko std::string( 157ed9e52f3SAlex Zinenko "Invalid attribute when attempting to create an ArrayAttribute (") + 158ed9e52f3SAlex Zinenko err.what() + ")"; 159ed9e52f3SAlex Zinenko throw py::cast_error(msg); 160ed9e52f3SAlex Zinenko } catch (py::reference_cast_error &err) { 161ed9e52f3SAlex Zinenko std::string msg = std::string("Invalid attribute (None?) when attempting " 162ed9e52f3SAlex Zinenko "to create an ArrayAttribute (") + 163ed9e52f3SAlex Zinenko err.what() + ")"; 164ed9e52f3SAlex Zinenko throw py::cast_error(msg); 165ed9e52f3SAlex Zinenko } 166ed9e52f3SAlex Zinenko } 167ed9e52f3SAlex Zinenko 168619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived 169619fd8c2SJeff Niu /// implementation class. 170619fd8c2SJeff Niu template <typename EltTy, typename DerivedT> 171133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> { 172619fd8c2SJeff Niu public: 173133624acSJeff Niu using PyConcreteAttribute<DerivedT>::PyConcreteAttribute; 174619fd8c2SJeff Niu 175619fd8c2SJeff Niu /// Iterator over the integer elements of a dense array. 176619fd8c2SJeff Niu class PyDenseArrayIterator { 177619fd8c2SJeff Niu public: 1784a1b1196SMehdi Amini PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} 179619fd8c2SJeff Niu 180619fd8c2SJeff Niu /// Return a copy of the iterator. 181619fd8c2SJeff Niu PyDenseArrayIterator dunderIter() { return *this; } 182619fd8c2SJeff Niu 183619fd8c2SJeff Niu /// Return the next element. 184619fd8c2SJeff Niu EltTy dunderNext() { 185619fd8c2SJeff Niu // Throw if the index has reached the end. 186619fd8c2SJeff Niu if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) 187619fd8c2SJeff Niu throw py::stop_iteration(); 188619fd8c2SJeff Niu return DerivedT::getElement(attr.get(), nextIndex++); 189619fd8c2SJeff Niu } 190619fd8c2SJeff Niu 191619fd8c2SJeff Niu /// Bind the iterator class. 192619fd8c2SJeff Niu static void bind(py::module &m) { 193619fd8c2SJeff Niu py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName, 194619fd8c2SJeff Niu py::module_local()) 195619fd8c2SJeff Niu .def("__iter__", &PyDenseArrayIterator::dunderIter) 196619fd8c2SJeff Niu .def("__next__", &PyDenseArrayIterator::dunderNext); 197619fd8c2SJeff Niu } 198619fd8c2SJeff Niu 199619fd8c2SJeff Niu private: 200619fd8c2SJeff Niu /// The referenced dense array attribute. 201619fd8c2SJeff Niu PyAttribute attr; 202619fd8c2SJeff Niu /// The next index to read. 203619fd8c2SJeff Niu int nextIndex = 0; 204619fd8c2SJeff Niu }; 205619fd8c2SJeff Niu 206619fd8c2SJeff Niu /// Get the element at the given index. 207619fd8c2SJeff Niu EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } 208619fd8c2SJeff Niu 209619fd8c2SJeff Niu /// Bind the attribute class. 210133624acSJeff Niu static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) { 211619fd8c2SJeff Niu // Bind the constructor. 212619fd8c2SJeff Niu c.def_static( 213619fd8c2SJeff Niu "get", 214619fd8c2SJeff Niu [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { 2158dcb6722SIngo Müller return getAttribute(values, ctx->getRef()); 216619fd8c2SJeff Niu }, 217619fd8c2SJeff Niu py::arg("values"), py::arg("context") = py::none(), 218619fd8c2SJeff Niu "Gets a uniqued dense array attribute"); 219619fd8c2SJeff Niu // Bind the array methods. 220133624acSJeff Niu c.def("__getitem__", [](DerivedT &arr, intptr_t i) { 221619fd8c2SJeff Niu if (i >= mlirDenseArrayGetNumElements(arr)) 222619fd8c2SJeff Niu throw py::index_error("DenseArray index out of range"); 223619fd8c2SJeff Niu return arr.getItem(i); 224619fd8c2SJeff Niu }); 225133624acSJeff Niu c.def("__len__", [](const DerivedT &arr) { 226619fd8c2SJeff Niu return mlirDenseArrayGetNumElements(arr); 227619fd8c2SJeff Niu }); 228133624acSJeff Niu c.def("__iter__", 229133624acSJeff Niu [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); 2304a1b1196SMehdi Amini c.def("__add__", [](DerivedT &arr, const py::list &extras) { 231619fd8c2SJeff Niu std::vector<EltTy> values; 232619fd8c2SJeff Niu intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); 233619fd8c2SJeff Niu values.reserve(numOldElements + py::len(extras)); 234619fd8c2SJeff Niu for (intptr_t i = 0; i < numOldElements; ++i) 235619fd8c2SJeff Niu values.push_back(arr.getItem(i)); 236619fd8c2SJeff Niu for (py::handle attr : extras) 237619fd8c2SJeff Niu values.push_back(pyTryCast<EltTy>(attr)); 2388dcb6722SIngo Müller return getAttribute(values, arr.getContext()); 239619fd8c2SJeff Niu }); 240619fd8c2SJeff Niu } 2418dcb6722SIngo Müller 2428dcb6722SIngo Müller private: 2438dcb6722SIngo Müller static DerivedT getAttribute(const std::vector<EltTy> &values, 2448dcb6722SIngo Müller PyMlirContextRef ctx) { 2458dcb6722SIngo Müller if constexpr (std::is_same_v<EltTy, bool>) { 2468dcb6722SIngo Müller std::vector<int> intValues(values.begin(), values.end()); 2478dcb6722SIngo Müller MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(), 2488dcb6722SIngo Müller intValues.data()); 2498dcb6722SIngo Müller return DerivedT(ctx, attr); 2508dcb6722SIngo Müller } else { 2518dcb6722SIngo Müller MlirAttribute attr = 2528dcb6722SIngo Müller DerivedT::getAttribute(ctx->get(), values.size(), values.data()); 2538dcb6722SIngo Müller return DerivedT(ctx, attr); 2548dcb6722SIngo Müller } 2558dcb6722SIngo Müller } 256619fd8c2SJeff Niu }; 257619fd8c2SJeff Niu 258619fd8c2SJeff Niu /// Instantiate the python dense array classes. 259619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute 2608dcb6722SIngo Müller : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> { 261619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; 262619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseBoolArrayGet; 263619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseBoolArrayGetElement; 264619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseBoolArrayAttr"; 265619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; 266619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 267619fd8c2SJeff Niu }; 268619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute 269619fd8c2SJeff Niu : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { 270619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; 271619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI8ArrayGet; 272619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI8ArrayGetElement; 273619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI8ArrayAttr"; 274619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; 275619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 276619fd8c2SJeff Niu }; 277619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute 278619fd8c2SJeff Niu : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { 279619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; 280619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI16ArrayGet; 281619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI16ArrayGetElement; 282619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI16ArrayAttr"; 283619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; 284619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 285619fd8c2SJeff Niu }; 286619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute 287619fd8c2SJeff Niu : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { 288619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; 289619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI32ArrayGet; 290619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI32ArrayGetElement; 291619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI32ArrayAttr"; 292619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; 293619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 294619fd8c2SJeff Niu }; 295619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute 296619fd8c2SJeff Niu : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { 297619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; 298619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI64ArrayGet; 299619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI64ArrayGetElement; 300619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI64ArrayAttr"; 301619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; 302619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 303619fd8c2SJeff Niu }; 304619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute 305619fd8c2SJeff Niu : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { 306619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; 307619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF32ArrayGet; 308619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF32ArrayGetElement; 309619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF32ArrayAttr"; 310619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; 311619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 312619fd8c2SJeff Niu }; 313619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute 314619fd8c2SJeff Niu : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { 315619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; 316619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF64ArrayGet; 317619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF64ArrayGetElement; 318619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF64ArrayAttr"; 319619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; 320619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 321619fd8c2SJeff Niu }; 322619fd8c2SJeff Niu 323436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 324436c6c9cSStella Laurenzo public: 325436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 326436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 327436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 3289566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 3299566ee28Smax mlirArrayAttrGetTypeID; 330436c6c9cSStella Laurenzo 331436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 332436c6c9cSStella Laurenzo public: 3331fc096afSMehdi Amini PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 334436c6c9cSStella Laurenzo 335436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 336436c6c9cSStella Laurenzo 337974c1596SRahul Kayaith MlirAttribute dunderNext() { 338bca88952SJeff Niu // TODO: Throw is an inefficient way to stop iteration. 339bca88952SJeff Niu if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) 340436c6c9cSStella Laurenzo throw py::stop_iteration(); 341974c1596SRahul Kayaith return mlirArrayAttrGetElement(attr.get(), nextIndex++); 342436c6c9cSStella Laurenzo } 343436c6c9cSStella Laurenzo 344436c6c9cSStella Laurenzo static void bind(py::module &m) { 345f05ff4f7SStella Laurenzo py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 346f05ff4f7SStella Laurenzo py::module_local()) 347436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 348436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 349436c6c9cSStella Laurenzo } 350436c6c9cSStella Laurenzo 351436c6c9cSStella Laurenzo private: 352436c6c9cSStella Laurenzo PyAttribute attr; 353436c6c9cSStella Laurenzo int nextIndex = 0; 354436c6c9cSStella Laurenzo }; 355436c6c9cSStella Laurenzo 356974c1596SRahul Kayaith MlirAttribute getItem(intptr_t i) { 357974c1596SRahul Kayaith return mlirArrayAttrGetElement(*this, i); 358ed9e52f3SAlex Zinenko } 359ed9e52f3SAlex Zinenko 360436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 361436c6c9cSStella Laurenzo c.def_static( 362436c6c9cSStella Laurenzo "get", 363436c6c9cSStella Laurenzo [](py::list attributes, DefaultingPyMlirContext context) { 364436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 365436c6c9cSStella Laurenzo mlirAttributes.reserve(py::len(attributes)); 366436c6c9cSStella Laurenzo for (auto attribute : attributes) { 367ed9e52f3SAlex Zinenko mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 368436c6c9cSStella Laurenzo } 369436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 370436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 371436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 372436c6c9cSStella Laurenzo }, 373436c6c9cSStella Laurenzo py::arg("attributes"), py::arg("context") = py::none(), 374436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 375436c6c9cSStella Laurenzo c.def("__getitem__", 376436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 377436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 378436c6c9cSStella Laurenzo throw py::index_error("ArrayAttribute index out of range"); 379ed9e52f3SAlex Zinenko return arr.getItem(i); 380436c6c9cSStella Laurenzo }) 381436c6c9cSStella Laurenzo .def("__len__", 382436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 383436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 384436c6c9cSStella Laurenzo }) 385436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 386436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 387436c6c9cSStella Laurenzo }); 388ed9e52f3SAlex Zinenko c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 389ed9e52f3SAlex Zinenko std::vector<MlirAttribute> attributes; 390ed9e52f3SAlex Zinenko intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 391ed9e52f3SAlex Zinenko attributes.reserve(numOldElements + py::len(extras)); 392ed9e52f3SAlex Zinenko for (intptr_t i = 0; i < numOldElements; ++i) 393ed9e52f3SAlex Zinenko attributes.push_back(arr.getItem(i)); 394ed9e52f3SAlex Zinenko for (py::handle attr : extras) 395ed9e52f3SAlex Zinenko attributes.push_back(pyTryCast<PyAttribute>(attr)); 396ed9e52f3SAlex Zinenko MlirAttribute arrayAttr = mlirArrayAttrGet( 397ed9e52f3SAlex Zinenko arr.getContext()->get(), attributes.size(), attributes.data()); 398ed9e52f3SAlex Zinenko return PyArrayAttribute(arr.getContext(), arrayAttr); 399ed9e52f3SAlex Zinenko }); 400436c6c9cSStella Laurenzo } 401436c6c9cSStella Laurenzo }; 402436c6c9cSStella Laurenzo 403436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 404436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 405436c6c9cSStella Laurenzo public: 406436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 407436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 408436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 4099566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 4109566ee28Smax mlirFloatAttrGetTypeID; 411436c6c9cSStella Laurenzo 412436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 413436c6c9cSStella Laurenzo c.def_static( 414436c6c9cSStella Laurenzo "get", 415436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 4163ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(loc->getContext()); 417436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 4183ea4c501SRahul Kayaith if (mlirAttributeIsNull(attr)) 4193ea4c501SRahul Kayaith throw MLIRError("Invalid attribute", errors.take()); 420436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 421436c6c9cSStella Laurenzo }, 422436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 423436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 424436c6c9cSStella Laurenzo c.def_static( 425436c6c9cSStella Laurenzo "get_f32", 426436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 427436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 428436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 429436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 430436c6c9cSStella Laurenzo }, 431436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 432436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 433436c6c9cSStella Laurenzo c.def_static( 434436c6c9cSStella Laurenzo "get_f64", 435436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 436436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 437436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 438436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 439436c6c9cSStella Laurenzo }, 440436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 441436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 4422a5d4974SIngo Müller c.def_property_readonly("value", mlirFloatAttrGetValueDouble, 4432a5d4974SIngo Müller "Returns the value of the float attribute"); 4442a5d4974SIngo Müller c.def("__float__", mlirFloatAttrGetValueDouble, 4452a5d4974SIngo Müller "Converts the value of the float attribute to a Python float"); 446436c6c9cSStella Laurenzo } 447436c6c9cSStella Laurenzo }; 448436c6c9cSStella Laurenzo 449436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 450436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 451436c6c9cSStella Laurenzo public: 452436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 453436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 454436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 455436c6c9cSStella Laurenzo 456436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 457436c6c9cSStella Laurenzo c.def_static( 458436c6c9cSStella Laurenzo "get", 459436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 460436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 461436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 462436c6c9cSStella Laurenzo }, 463436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), 464436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 4652a5d4974SIngo Müller c.def_property_readonly("value", toPyInt, 4662a5d4974SIngo Müller "Returns the value of the integer attribute"); 4672a5d4974SIngo Müller c.def("__int__", toPyInt, 4682a5d4974SIngo Müller "Converts the value of the integer attribute to a Python int"); 4692a5d4974SIngo Müller c.def_property_readonly_static("static_typeid", 4702a5d4974SIngo Müller [](py::object & /*class*/) -> MlirTypeID { 4712a5d4974SIngo Müller return mlirIntegerAttrGetTypeID(); 4722a5d4974SIngo Müller }); 4732a5d4974SIngo Müller } 4742a5d4974SIngo Müller 4752a5d4974SIngo Müller private: 4762a5d4974SIngo Müller static py::int_ toPyInt(PyIntegerAttribute &self) { 477e9db306dSrkayaith MlirType type = mlirAttributeGetType(self); 478e9db306dSrkayaith if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 479436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 480e9db306dSrkayaith if (mlirIntegerTypeIsSigned(type)) 481e9db306dSrkayaith return mlirIntegerAttrGetValueSInt(self); 482e9db306dSrkayaith return mlirIntegerAttrGetValueUInt(self); 483436c6c9cSStella Laurenzo } 484436c6c9cSStella Laurenzo }; 485436c6c9cSStella Laurenzo 486436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 487436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 488436c6c9cSStella Laurenzo public: 489436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 490436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 491436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 492436c6c9cSStella Laurenzo 493436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 494436c6c9cSStella Laurenzo c.def_static( 495436c6c9cSStella Laurenzo "get", 496436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 497436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 498436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 499436c6c9cSStella Laurenzo }, 500436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 501436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 5022a5d4974SIngo Müller c.def_property_readonly("value", mlirBoolAttrGetValue, 503436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 5042a5d4974SIngo Müller c.def("__bool__", mlirBoolAttrGetValue, 5052a5d4974SIngo Müller "Converts the value of the bool attribute to a Python bool"); 506436c6c9cSStella Laurenzo } 507436c6c9cSStella Laurenzo }; 508436c6c9cSStella Laurenzo 5094eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> { 5104eee9ef9Smax public: 5114eee9ef9Smax static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; 5124eee9ef9Smax static constexpr const char *pyClassName = "SymbolRefAttr"; 5134eee9ef9Smax using PyConcreteAttribute::PyConcreteAttribute; 5144eee9ef9Smax 5154eee9ef9Smax static MlirAttribute fromList(const std::vector<std::string> &symbols, 5164eee9ef9Smax PyMlirContext &context) { 5174eee9ef9Smax if (symbols.empty()) 5184eee9ef9Smax throw std::runtime_error("SymbolRefAttr must be composed of at least " 5194eee9ef9Smax "one symbol."); 5204eee9ef9Smax MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); 5214eee9ef9Smax SmallVector<MlirAttribute, 3> referenceAttrs; 5224eee9ef9Smax for (size_t i = 1; i < symbols.size(); ++i) { 5234eee9ef9Smax referenceAttrs.push_back( 5244eee9ef9Smax mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); 5254eee9ef9Smax } 5264eee9ef9Smax return mlirSymbolRefAttrGet(context.get(), rootSymbol, 5274eee9ef9Smax referenceAttrs.size(), referenceAttrs.data()); 5284eee9ef9Smax } 5294eee9ef9Smax 5304eee9ef9Smax static void bindDerived(ClassTy &c) { 5314eee9ef9Smax c.def_static( 5324eee9ef9Smax "get", 5334eee9ef9Smax [](const std::vector<std::string> &symbols, 5344eee9ef9Smax DefaultingPyMlirContext context) { 5354eee9ef9Smax return PySymbolRefAttribute::fromList(symbols, context.resolve()); 5364eee9ef9Smax }, 5374eee9ef9Smax py::arg("symbols"), py::arg("context") = py::none(), 5384eee9ef9Smax "Gets a uniqued SymbolRef attribute from a list of symbol names"); 5394eee9ef9Smax c.def_property_readonly( 5404eee9ef9Smax "value", 5414eee9ef9Smax [](PySymbolRefAttribute &self) { 5424eee9ef9Smax std::vector<std::string> symbols = { 5434eee9ef9Smax unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; 5444eee9ef9Smax for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); 5454eee9ef9Smax ++i) 5464eee9ef9Smax symbols.push_back( 5474eee9ef9Smax unwrap(mlirSymbolRefAttrGetRootReference( 5484eee9ef9Smax mlirSymbolRefAttrGetNestedReference(self, i))) 5494eee9ef9Smax .str()); 5504eee9ef9Smax return symbols; 5514eee9ef9Smax }, 5524eee9ef9Smax "Returns the value of the SymbolRef attribute as a list[str]"); 5534eee9ef9Smax } 5544eee9ef9Smax }; 5554eee9ef9Smax 556436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 557436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 558436c6c9cSStella Laurenzo public: 559436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 560436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 561436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 562436c6c9cSStella Laurenzo 563436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 564436c6c9cSStella Laurenzo c.def_static( 565436c6c9cSStella Laurenzo "get", 566436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 567436c6c9cSStella Laurenzo MlirAttribute attr = 568436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 569436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 570436c6c9cSStella Laurenzo }, 571436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 572436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 573436c6c9cSStella Laurenzo c.def_property_readonly( 574436c6c9cSStella Laurenzo "value", 575436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 576436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 577436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 578436c6c9cSStella Laurenzo }, 579436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 580436c6c9cSStella Laurenzo } 581436c6c9cSStella Laurenzo }; 582436c6c9cSStella Laurenzo 5835c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { 5845c3861b2SYun Long public: 5855c3861b2SYun Long static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; 5865c3861b2SYun Long static constexpr const char *pyClassName = "OpaqueAttr"; 5875c3861b2SYun Long using PyConcreteAttribute::PyConcreteAttribute; 5889566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 5899566ee28Smax mlirOpaqueAttrGetTypeID; 5905c3861b2SYun Long 5915c3861b2SYun Long static void bindDerived(ClassTy &c) { 5925c3861b2SYun Long c.def_static( 5935c3861b2SYun Long "get", 5945c3861b2SYun Long [](std::string dialectNamespace, py::buffer buffer, PyType &type, 5955c3861b2SYun Long DefaultingPyMlirContext context) { 5965c3861b2SYun Long const py::buffer_info bufferInfo = buffer.request(); 5975c3861b2SYun Long intptr_t bufferSize = bufferInfo.size; 5985c3861b2SYun Long MlirAttribute attr = mlirOpaqueAttrGet( 5995c3861b2SYun Long context->get(), toMlirStringRef(dialectNamespace), bufferSize, 6005c3861b2SYun Long static_cast<char *>(bufferInfo.ptr), type); 6015c3861b2SYun Long return PyOpaqueAttribute(context->getRef(), attr); 6025c3861b2SYun Long }, 6035c3861b2SYun Long py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), 6045c3861b2SYun Long py::arg("context") = py::none(), "Gets an Opaque attribute."); 6055c3861b2SYun Long c.def_property_readonly( 6065c3861b2SYun Long "dialect_namespace", 6075c3861b2SYun Long [](PyOpaqueAttribute &self) { 6085c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); 6095c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 6105c3861b2SYun Long }, 6115c3861b2SYun Long "Returns the dialect namespace for the Opaque attribute as a string"); 6125c3861b2SYun Long c.def_property_readonly( 6135c3861b2SYun Long "data", 6145c3861b2SYun Long [](PyOpaqueAttribute &self) { 6155c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetData(self); 61662bf6c2eSChris Jones return py::bytes(stringRef.data, stringRef.length); 6175c3861b2SYun Long }, 61862bf6c2eSChris Jones "Returns the data for the Opaqued attributes as `bytes`"); 6195c3861b2SYun Long } 6205c3861b2SYun Long }; 6215c3861b2SYun Long 622436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 623436c6c9cSStella Laurenzo public: 624436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 625436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 626436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 6279566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 6289566ee28Smax mlirStringAttrGetTypeID; 629436c6c9cSStella Laurenzo 630436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 631436c6c9cSStella Laurenzo c.def_static( 632436c6c9cSStella Laurenzo "get", 633436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 634436c6c9cSStella Laurenzo MlirAttribute attr = 635436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 636436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 637436c6c9cSStella Laurenzo }, 638436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 639436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 640436c6c9cSStella Laurenzo c.def_static( 641436c6c9cSStella Laurenzo "get_typed", 642436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 643436c6c9cSStella Laurenzo MlirAttribute attr = 644436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 645436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 646436c6c9cSStella Laurenzo }, 647a6e7d024SStella Laurenzo py::arg("type"), py::arg("value"), 648436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 6499f533548SIngo Müller c.def_property_readonly( 6509f533548SIngo Müller "value", 6519f533548SIngo Müller [](PyStringAttribute &self) { 6529f533548SIngo Müller MlirStringRef stringRef = mlirStringAttrGetValue(self); 6539f533548SIngo Müller return py::str(stringRef.data, stringRef.length); 6549f533548SIngo Müller }, 655436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 65662bf6c2eSChris Jones c.def_property_readonly( 65762bf6c2eSChris Jones "value_bytes", 65862bf6c2eSChris Jones [](PyStringAttribute &self) { 65962bf6c2eSChris Jones MlirStringRef stringRef = mlirStringAttrGetValue(self); 66062bf6c2eSChris Jones return py::bytes(stringRef.data, stringRef.length); 66162bf6c2eSChris Jones }, 66262bf6c2eSChris Jones "Returns the value of the string attribute as `bytes`"); 663436c6c9cSStella Laurenzo } 664436c6c9cSStella Laurenzo }; 665436c6c9cSStella Laurenzo 666436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 667436c6c9cSStella Laurenzo class PyDenseElementsAttribute 668436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 669436c6c9cSStella Laurenzo public: 670436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 671436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 672436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 673436c6c9cSStella Laurenzo 674436c6c9cSStella Laurenzo static PyDenseElementsAttribute 675c912f0e7Spranavm-nvidia getFromList(py::list attributes, std::optional<PyType> explicitType, 676c912f0e7Spranavm-nvidia DefaultingPyMlirContext contextWrapper) { 677c912f0e7Spranavm-nvidia 678c912f0e7Spranavm-nvidia const size_t numAttributes = py::len(attributes); 679c912f0e7Spranavm-nvidia if (numAttributes == 0) 680c912f0e7Spranavm-nvidia throw py::value_error("Attributes list must be non-empty."); 681c912f0e7Spranavm-nvidia 682c912f0e7Spranavm-nvidia MlirType shapedType; 683c912f0e7Spranavm-nvidia if (explicitType) { 684c912f0e7Spranavm-nvidia if ((!mlirTypeIsAShaped(*explicitType) || 685c912f0e7Spranavm-nvidia !mlirShapedTypeHasStaticShape(*explicitType))) { 686c912f0e7Spranavm-nvidia 687c912f0e7Spranavm-nvidia std::string message; 688c912f0e7Spranavm-nvidia llvm::raw_string_ostream os(message); 689c912f0e7Spranavm-nvidia os << "Expected a static ShapedType for the shaped_type parameter: " 690c912f0e7Spranavm-nvidia << py::repr(py::cast(*explicitType)); 691c912f0e7Spranavm-nvidia throw py::value_error(os.str()); 692c912f0e7Spranavm-nvidia } 693c912f0e7Spranavm-nvidia shapedType = *explicitType; 694c912f0e7Spranavm-nvidia } else { 695c912f0e7Spranavm-nvidia SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)}; 696c912f0e7Spranavm-nvidia shapedType = mlirRankedTensorTypeGet( 697c912f0e7Spranavm-nvidia shape.size(), shape.data(), 698c912f0e7Spranavm-nvidia mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])), 699c912f0e7Spranavm-nvidia mlirAttributeGetNull()); 700c912f0e7Spranavm-nvidia } 701c912f0e7Spranavm-nvidia 702c912f0e7Spranavm-nvidia SmallVector<MlirAttribute> mlirAttributes; 703c912f0e7Spranavm-nvidia mlirAttributes.reserve(numAttributes); 704c912f0e7Spranavm-nvidia for (const py::handle &attribute : attributes) { 705c912f0e7Spranavm-nvidia MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute); 706c912f0e7Spranavm-nvidia MlirType attrType = mlirAttributeGetType(mlirAttribute); 707c912f0e7Spranavm-nvidia mlirAttributes.push_back(mlirAttribute); 708c912f0e7Spranavm-nvidia 709c912f0e7Spranavm-nvidia if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { 710c912f0e7Spranavm-nvidia std::string message; 711c912f0e7Spranavm-nvidia llvm::raw_string_ostream os(message); 712c912f0e7Spranavm-nvidia os << "All attributes must be of the same type and match " 713c912f0e7Spranavm-nvidia << "the type parameter: expected=" << py::repr(py::cast(shapedType)) 714c912f0e7Spranavm-nvidia << ", but got=" << py::repr(py::cast(attrType)); 715c912f0e7Spranavm-nvidia throw py::value_error(os.str()); 716c912f0e7Spranavm-nvidia } 717c912f0e7Spranavm-nvidia } 718c912f0e7Spranavm-nvidia 719c912f0e7Spranavm-nvidia MlirAttribute elements = mlirDenseElementsAttrGet( 720c912f0e7Spranavm-nvidia shapedType, mlirAttributes.size(), mlirAttributes.data()); 721c912f0e7Spranavm-nvidia 722c912f0e7Spranavm-nvidia return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 723c912f0e7Spranavm-nvidia } 724c912f0e7Spranavm-nvidia 725c912f0e7Spranavm-nvidia static PyDenseElementsAttribute 7260a81ace0SKazu Hirata getFromBuffer(py::buffer array, bool signless, 7270a81ace0SKazu Hirata std::optional<PyType> explicitType, 7280a81ace0SKazu Hirata std::optional<std::vector<int64_t>> explicitShape, 729436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 730436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 73171a25454SPeter Hawkins int flags = PyBUF_ND; 73271a25454SPeter Hawkins if (!explicitType) { 73371a25454SPeter Hawkins flags |= PyBUF_FORMAT; 73471a25454SPeter Hawkins } 73571a25454SPeter Hawkins Py_buffer view; 73671a25454SPeter Hawkins if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { 737436c6c9cSStella Laurenzo throw py::error_already_set(); 738436c6c9cSStella Laurenzo } 73971a25454SPeter Hawkins auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); 7405d6d30edSStella Laurenzo SmallVector<int64_t> shape; 7415d6d30edSStella Laurenzo if (explicitShape) { 7425d6d30edSStella Laurenzo shape.append(explicitShape->begin(), explicitShape->end()); 7435d6d30edSStella Laurenzo } else { 74471a25454SPeter Hawkins shape.append(view.shape, view.shape + view.ndim); 7455d6d30edSStella Laurenzo } 746436c6c9cSStella Laurenzo 7475d6d30edSStella Laurenzo MlirAttribute encodingAttr = mlirAttributeGetNull(); 748436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 7495d6d30edSStella Laurenzo 7505d6d30edSStella Laurenzo // Detect format codes that are suitable for bulk loading. This includes 7515d6d30edSStella Laurenzo // all byte aligned integer and floating point types up to 8 bytes. 7525d6d30edSStella Laurenzo // Notably, this excludes, bool (which needs to be bit-packed) and 7535d6d30edSStella Laurenzo // other exotics which do not have a direct representation in the buffer 7545d6d30edSStella Laurenzo // protocol (i.e. complex, etc). 7550a81ace0SKazu Hirata std::optional<MlirType> bulkLoadElementType; 7565d6d30edSStella Laurenzo if (explicitType) { 7575d6d30edSStella Laurenzo bulkLoadElementType = *explicitType; 75871a25454SPeter Hawkins } else { 75971a25454SPeter Hawkins std::string_view format(view.format); 76071a25454SPeter Hawkins if (format == "f") { 761436c6c9cSStella Laurenzo // f32 76271a25454SPeter Hawkins assert(view.itemsize == 4 && "mismatched array itemsize"); 7635d6d30edSStella Laurenzo bulkLoadElementType = mlirF32TypeGet(context); 76471a25454SPeter Hawkins } else if (format == "d") { 765436c6c9cSStella Laurenzo // f64 76671a25454SPeter Hawkins assert(view.itemsize == 8 && "mismatched array itemsize"); 7675d6d30edSStella Laurenzo bulkLoadElementType = mlirF64TypeGet(context); 76871a25454SPeter Hawkins } else if (format == "e") { 7695d6d30edSStella Laurenzo // f16 77071a25454SPeter Hawkins assert(view.itemsize == 2 && "mismatched array itemsize"); 7715d6d30edSStella Laurenzo bulkLoadElementType = mlirF16TypeGet(context); 77271a25454SPeter Hawkins } else if (isSignedIntegerFormat(format)) { 77371a25454SPeter Hawkins if (view.itemsize == 4) { 774436c6c9cSStella Laurenzo // i32 77571a25454SPeter Hawkins bulkLoadElementType = signless 77671a25454SPeter Hawkins ? mlirIntegerTypeGet(context, 32) 777436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 32); 77871a25454SPeter Hawkins } else if (view.itemsize == 8) { 779436c6c9cSStella Laurenzo // i64 78071a25454SPeter Hawkins bulkLoadElementType = signless 78171a25454SPeter Hawkins ? mlirIntegerTypeGet(context, 64) 782436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 64); 78371a25454SPeter Hawkins } else if (view.itemsize == 1) { 7845d6d30edSStella Laurenzo // i8 7855d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 7865d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 8); 78771a25454SPeter Hawkins } else if (view.itemsize == 2) { 7885d6d30edSStella Laurenzo // i16 78971a25454SPeter Hawkins bulkLoadElementType = signless 79071a25454SPeter Hawkins ? mlirIntegerTypeGet(context, 16) 7915d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 16); 792436c6c9cSStella Laurenzo } 79371a25454SPeter Hawkins } else if (isUnsignedIntegerFormat(format)) { 79471a25454SPeter Hawkins if (view.itemsize == 4) { 795436c6c9cSStella Laurenzo // unsigned i32 7965d6d30edSStella Laurenzo bulkLoadElementType = signless 797436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 32) 798436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 32); 79971a25454SPeter Hawkins } else if (view.itemsize == 8) { 800436c6c9cSStella Laurenzo // unsigned i64 8015d6d30edSStella Laurenzo bulkLoadElementType = signless 802436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 64) 803436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 64); 80471a25454SPeter Hawkins } else if (view.itemsize == 1) { 8055d6d30edSStella Laurenzo // i8 80671a25454SPeter Hawkins bulkLoadElementType = signless 80771a25454SPeter Hawkins ? mlirIntegerTypeGet(context, 8) 8085d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 8); 80971a25454SPeter Hawkins } else if (view.itemsize == 2) { 8105d6d30edSStella Laurenzo // i16 8115d6d30edSStella Laurenzo bulkLoadElementType = signless 8125d6d30edSStella Laurenzo ? mlirIntegerTypeGet(context, 16) 8135d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 16); 814436c6c9cSStella Laurenzo } 815436c6c9cSStella Laurenzo } 81671a25454SPeter Hawkins if (!bulkLoadElementType) { 81771a25454SPeter Hawkins throw std::invalid_argument( 81871a25454SPeter Hawkins std::string("unimplemented array format conversion from format: ") + 81971a25454SPeter Hawkins std::string(format)); 82071a25454SPeter Hawkins } 82171a25454SPeter Hawkins } 82271a25454SPeter Hawkins 82399dee31eSAdam Paszke MlirType shapedType; 82499dee31eSAdam Paszke if (mlirTypeIsAShaped(*bulkLoadElementType)) { 82599dee31eSAdam Paszke if (explicitShape) { 82699dee31eSAdam Paszke throw std::invalid_argument("Shape can only be specified explicitly " 82799dee31eSAdam Paszke "when the type is not a shaped type."); 82899dee31eSAdam Paszke } 82999dee31eSAdam Paszke shapedType = *bulkLoadElementType; 83099dee31eSAdam Paszke } else { 83171a25454SPeter Hawkins shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), 83271a25454SPeter Hawkins *bulkLoadElementType, encodingAttr); 83399dee31eSAdam Paszke } 83471a25454SPeter Hawkins size_t rawBufferSize = view.len; 83571a25454SPeter Hawkins MlirAttribute attr = 83671a25454SPeter Hawkins mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); 8375d6d30edSStella Laurenzo if (mlirAttributeIsNull(attr)) { 8385d6d30edSStella Laurenzo throw std::invalid_argument( 8395d6d30edSStella Laurenzo "DenseElementsAttr could not be constructed from the given buffer. " 8405d6d30edSStella Laurenzo "This may mean that the Python buffer layout does not match that " 8415d6d30edSStella Laurenzo "MLIR expected layout and is a bug."); 8425d6d30edSStella Laurenzo } 8435d6d30edSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 8445d6d30edSStella Laurenzo } 845436c6c9cSStella Laurenzo 8461fc096afSMehdi Amini static PyDenseElementsAttribute getSplat(const PyType &shapedType, 847436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 848436c6c9cSStella Laurenzo auto contextWrapper = 849436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 850436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 851436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 852436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 853436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 8544811270bSmax throw py::value_error(message); 855436c6c9cSStella Laurenzo } 856436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 857436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 858436c6c9cSStella Laurenzo std::string message = 859436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 860436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 8614811270bSmax throw py::value_error(message); 862436c6c9cSStella Laurenzo } 863436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 864436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 865436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 866436c6c9cSStella Laurenzo std::string message = 867436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 868436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 869436c6c9cSStella Laurenzo message.append(", element="); 870436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 8714811270bSmax throw py::value_error(message); 872436c6c9cSStella Laurenzo } 873436c6c9cSStella Laurenzo 874436c6c9cSStella Laurenzo MlirAttribute elements = 875436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 876436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 877436c6c9cSStella Laurenzo } 878436c6c9cSStella Laurenzo 879436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 880436c6c9cSStella Laurenzo 881436c6c9cSStella Laurenzo py::buffer_info accessBuffer() { 882436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 883436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 8845d6d30edSStella Laurenzo std::string format; 885436c6c9cSStella Laurenzo 886436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 887436c6c9cSStella Laurenzo // f32 8885d6d30edSStella Laurenzo return bufferInfo<float>(shapedType); 88902b6fb21SMehdi Amini } 89002b6fb21SMehdi Amini if (mlirTypeIsAF64(elementType)) { 891436c6c9cSStella Laurenzo // f64 8925d6d30edSStella Laurenzo return bufferInfo<double>(shapedType); 893bb56c2b3SMehdi Amini } 894bb56c2b3SMehdi Amini if (mlirTypeIsAF16(elementType)) { 8955d6d30edSStella Laurenzo // f16 8965d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType, "e"); 897bb56c2b3SMehdi Amini } 898ef1b735dSmax if (mlirTypeIsAIndex(elementType)) { 899ef1b735dSmax // Same as IndexType::kInternalStorageBitWidth 900ef1b735dSmax return bufferInfo<int64_t>(shapedType); 901ef1b735dSmax } 902bb56c2b3SMehdi Amini if (mlirTypeIsAInteger(elementType) && 903436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 904436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 905436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 906436c6c9cSStella Laurenzo // i32 9075d6d30edSStella Laurenzo return bufferInfo<int32_t>(shapedType); 908e5639b3fSMehdi Amini } 909e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 910436c6c9cSStella Laurenzo // unsigned i32 9115d6d30edSStella Laurenzo return bufferInfo<uint32_t>(shapedType); 912436c6c9cSStella Laurenzo } 913436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 914436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 915436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 916436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 917436c6c9cSStella Laurenzo // i64 9185d6d30edSStella Laurenzo return bufferInfo<int64_t>(shapedType); 919e5639b3fSMehdi Amini } 920e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 921436c6c9cSStella Laurenzo // unsigned i64 9225d6d30edSStella Laurenzo return bufferInfo<uint64_t>(shapedType); 9235d6d30edSStella Laurenzo } 9245d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 9255d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 8) { 9265d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 9275d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 9285d6d30edSStella Laurenzo // i8 9295d6d30edSStella Laurenzo return bufferInfo<int8_t>(shapedType); 930e5639b3fSMehdi Amini } 931e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 9325d6d30edSStella Laurenzo // unsigned i8 9335d6d30edSStella Laurenzo return bufferInfo<uint8_t>(shapedType); 9345d6d30edSStella Laurenzo } 9355d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 9365d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 16) { 9375d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 9385d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 9395d6d30edSStella Laurenzo // i16 9405d6d30edSStella Laurenzo return bufferInfo<int16_t>(shapedType); 941e5639b3fSMehdi Amini } 942e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 9435d6d30edSStella Laurenzo // unsigned i16 9445d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType); 945436c6c9cSStella Laurenzo } 946436c6c9cSStella Laurenzo } 947436c6c9cSStella Laurenzo 948c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 9495d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 950c5f445d1SStella Laurenzo throw std::invalid_argument( 951c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 952436c6c9cSStella Laurenzo } 953436c6c9cSStella Laurenzo 954436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 955436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 956436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 957436c6c9cSStella Laurenzo py::arg("array"), py::arg("signless") = true, 9585d6d30edSStella Laurenzo py::arg("type") = py::none(), py::arg("shape") = py::none(), 959436c6c9cSStella Laurenzo py::arg("context") = py::none(), 9605d6d30edSStella Laurenzo kDenseElementsAttrGetDocstring) 961c912f0e7Spranavm-nvidia .def_static("get", PyDenseElementsAttribute::getFromList, 962c912f0e7Spranavm-nvidia py::arg("attrs"), py::arg("type") = py::none(), 963c912f0e7Spranavm-nvidia py::arg("context") = py::none(), 964c912f0e7Spranavm-nvidia kDenseElementsAttrGetFromListDocstring) 965436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 966436c6c9cSStella Laurenzo py::arg("shaped_type"), py::arg("element_attr"), 967436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 968436c6c9cSStella Laurenzo .def_property_readonly("is_splat", 969436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 970436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 971436c6c9cSStella Laurenzo }) 97291259963SAdam Paszke .def("get_splat_value", 973974c1596SRahul Kayaith [](PyDenseElementsAttribute &self) { 974974c1596SRahul Kayaith if (!mlirDenseElementsAttrIsSplat(self)) 9754811270bSmax throw py::value_error( 97691259963SAdam Paszke "get_splat_value called on a non-splat attribute"); 977974c1596SRahul Kayaith return mlirDenseElementsAttrGetSplatValue(self); 97891259963SAdam Paszke }) 979436c6c9cSStella Laurenzo .def_buffer(&PyDenseElementsAttribute::accessBuffer); 980436c6c9cSStella Laurenzo } 981436c6c9cSStella Laurenzo 982436c6c9cSStella Laurenzo private: 98371a25454SPeter Hawkins static bool isUnsignedIntegerFormat(std::string_view format) { 984436c6c9cSStella Laurenzo if (format.empty()) 985436c6c9cSStella Laurenzo return false; 986436c6c9cSStella Laurenzo char code = format[0]; 987436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 988436c6c9cSStella Laurenzo code == 'Q'; 989436c6c9cSStella Laurenzo } 990436c6c9cSStella Laurenzo 99171a25454SPeter Hawkins static bool isSignedIntegerFormat(std::string_view format) { 992436c6c9cSStella Laurenzo if (format.empty()) 993436c6c9cSStella Laurenzo return false; 994436c6c9cSStella Laurenzo char code = format[0]; 995436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 996436c6c9cSStella Laurenzo code == 'q'; 997436c6c9cSStella Laurenzo } 998436c6c9cSStella Laurenzo 999436c6c9cSStella Laurenzo template <typename Type> 1000436c6c9cSStella Laurenzo py::buffer_info bufferInfo(MlirType shapedType, 10015d6d30edSStella Laurenzo const char *explicitFormat = nullptr) { 1002436c6c9cSStella Laurenzo intptr_t rank = mlirShapedTypeGetRank(shapedType); 1003436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 1004436c6c9cSStella Laurenzo // Buffer is configured for read-only access below. 1005436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 1006436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 1007436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 1008436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 1009436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 1010436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 1011436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 1012436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 1013f0e847d0SRahul Kayaith if (mlirDenseElementsAttrIsSplat(*this)) { 1014f0e847d0SRahul Kayaith // Splats are special, only the single value is stored. 1015f0e847d0SRahul Kayaith strides.assign(rank, 0); 1016f0e847d0SRahul Kayaith } else { 1017436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 1018f0e847d0SRahul Kayaith intptr_t strideFactor = 1; 1019f0e847d0SRahul Kayaith for (intptr_t j = i; j < rank; ++j) 1020436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 1021436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 1022436c6c9cSStella Laurenzo } 1023436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 1024f0e847d0SRahul Kayaith } 10255d6d30edSStella Laurenzo std::string format; 10265d6d30edSStella Laurenzo if (explicitFormat) { 10275d6d30edSStella Laurenzo format = explicitFormat; 10285d6d30edSStella Laurenzo } else { 10295d6d30edSStella Laurenzo format = py::format_descriptor<Type>::format(); 10305d6d30edSStella Laurenzo } 10315d6d30edSStella Laurenzo return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 10325d6d30edSStella Laurenzo /*readonly=*/true); 1033436c6c9cSStella Laurenzo } 1034436c6c9cSStella Laurenzo }; // namespace 1035436c6c9cSStella Laurenzo 1036436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer 1037436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access. 1038436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 1039436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 1040436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 1041436c6c9cSStella Laurenzo public: 1042436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 1043436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 1044436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1045436c6c9cSStella Laurenzo 1046436c6c9cSStella Laurenzo /// Returns the element at the given linear position. Asserts if the index is 1047436c6c9cSStella Laurenzo /// out of range. 1048436c6c9cSStella Laurenzo py::int_ dunderGetItem(intptr_t pos) { 1049436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 10504811270bSmax throw py::index_error("attempt to access out of bounds element"); 1051436c6c9cSStella Laurenzo } 1052436c6c9cSStella Laurenzo 1053436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 1054436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 1055436c6c9cSStella Laurenzo assert(mlirTypeIsAInteger(type) && 1056436c6c9cSStella Laurenzo "expected integer element type in dense int elements attribute"); 1057436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 1058436c6c9cSStella Laurenzo // elemental type of the attribute. py::int_ is implicitly constructible 1059436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 1060436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 1061436c6c9cSStella Laurenzo // querying them on each element access. 1062436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 1063436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 1064436c6c9cSStella Laurenzo if (isUnsigned) { 1065436c6c9cSStella Laurenzo if (width == 1) { 1066436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 1067436c6c9cSStella Laurenzo } 1068308d8b8cSRahul Kayaith if (width == 8) { 1069308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt8Value(*this, pos); 1070308d8b8cSRahul Kayaith } 1071308d8b8cSRahul Kayaith if (width == 16) { 1072308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt16Value(*this, pos); 1073308d8b8cSRahul Kayaith } 1074436c6c9cSStella Laurenzo if (width == 32) { 1075436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt32Value(*this, pos); 1076436c6c9cSStella Laurenzo } 1077436c6c9cSStella Laurenzo if (width == 64) { 1078436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt64Value(*this, pos); 1079436c6c9cSStella Laurenzo } 1080436c6c9cSStella Laurenzo } else { 1081436c6c9cSStella Laurenzo if (width == 1) { 1082436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 1083436c6c9cSStella Laurenzo } 1084308d8b8cSRahul Kayaith if (width == 8) { 1085308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt8Value(*this, pos); 1086308d8b8cSRahul Kayaith } 1087308d8b8cSRahul Kayaith if (width == 16) { 1088308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt16Value(*this, pos); 1089308d8b8cSRahul Kayaith } 1090436c6c9cSStella Laurenzo if (width == 32) { 1091436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt32Value(*this, pos); 1092436c6c9cSStella Laurenzo } 1093436c6c9cSStella Laurenzo if (width == 64) { 1094436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt64Value(*this, pos); 1095436c6c9cSStella Laurenzo } 1096436c6c9cSStella Laurenzo } 10974811270bSmax throw py::type_error("Unsupported integer type"); 1098436c6c9cSStella Laurenzo } 1099436c6c9cSStella Laurenzo 1100436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1101436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 1102436c6c9cSStella Laurenzo } 1103436c6c9cSStella Laurenzo }; 1104436c6c9cSStella Laurenzo 1105f66cd9e9SStella Laurenzo class PyDenseResourceElementsAttribute 1106f66cd9e9SStella Laurenzo : public PyConcreteAttribute<PyDenseResourceElementsAttribute> { 1107f66cd9e9SStella Laurenzo public: 1108f66cd9e9SStella Laurenzo static constexpr IsAFunctionTy isaFunction = 1109f66cd9e9SStella Laurenzo mlirAttributeIsADenseResourceElements; 1110f66cd9e9SStella Laurenzo static constexpr const char *pyClassName = "DenseResourceElementsAttr"; 1111f66cd9e9SStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1112f66cd9e9SStella Laurenzo 1113f66cd9e9SStella Laurenzo static PyDenseResourceElementsAttribute 1114962bf002SMehdi Amini getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type, 1115f66cd9e9SStella Laurenzo std::optional<size_t> alignment, bool isMutable, 1116f66cd9e9SStella Laurenzo DefaultingPyMlirContext contextWrapper) { 1117f66cd9e9SStella Laurenzo if (!mlirTypeIsAShaped(type)) { 1118f66cd9e9SStella Laurenzo throw std::invalid_argument( 1119f66cd9e9SStella Laurenzo "Constructing a DenseResourceElementsAttr requires a ShapedType."); 1120f66cd9e9SStella Laurenzo } 1121f66cd9e9SStella Laurenzo 1122f66cd9e9SStella Laurenzo // Do not request any conversions as we must ensure to use caller 1123f66cd9e9SStella Laurenzo // managed memory. 1124f66cd9e9SStella Laurenzo int flags = PyBUF_STRIDES; 1125f66cd9e9SStella Laurenzo std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>(); 1126f66cd9e9SStella Laurenzo if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { 1127f66cd9e9SStella Laurenzo throw py::error_already_set(); 1128f66cd9e9SStella Laurenzo } 1129f66cd9e9SStella Laurenzo 1130f66cd9e9SStella Laurenzo // This scope releaser will only release if we haven't yet transferred 1131f66cd9e9SStella Laurenzo // ownership. 1132f66cd9e9SStella Laurenzo auto freeBuffer = llvm::make_scope_exit([&]() { 1133f66cd9e9SStella Laurenzo if (view) 1134f66cd9e9SStella Laurenzo PyBuffer_Release(view.get()); 1135f66cd9e9SStella Laurenzo }); 1136f66cd9e9SStella Laurenzo 1137f66cd9e9SStella Laurenzo if (!PyBuffer_IsContiguous(view.get(), 'A')) { 1138f66cd9e9SStella Laurenzo throw std::invalid_argument("Contiguous buffer is required."); 1139f66cd9e9SStella Laurenzo } 1140f66cd9e9SStella Laurenzo 1141f66cd9e9SStella Laurenzo // Infer alignment to be the stride of one element if not explicit. 1142f66cd9e9SStella Laurenzo size_t inferredAlignment; 1143f66cd9e9SStella Laurenzo if (alignment) 1144f66cd9e9SStella Laurenzo inferredAlignment = *alignment; 1145f66cd9e9SStella Laurenzo else 1146f66cd9e9SStella Laurenzo inferredAlignment = view->strides[view->ndim - 1]; 1147f66cd9e9SStella Laurenzo 1148f66cd9e9SStella Laurenzo // The userData is a Py_buffer* that the deleter owns. 1149f66cd9e9SStella Laurenzo auto deleter = [](void *userData, const void *data, size_t size, 1150f66cd9e9SStella Laurenzo size_t align) { 1151f66cd9e9SStella Laurenzo Py_buffer *ownedView = static_cast<Py_buffer *>(userData); 1152f66cd9e9SStella Laurenzo PyBuffer_Release(ownedView); 1153f66cd9e9SStella Laurenzo delete ownedView; 1154f66cd9e9SStella Laurenzo }; 1155f66cd9e9SStella Laurenzo 1156f66cd9e9SStella Laurenzo size_t rawBufferSize = view->len; 1157f66cd9e9SStella Laurenzo MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( 1158f66cd9e9SStella Laurenzo type, toMlirStringRef(name), view->buf, rawBufferSize, 1159f66cd9e9SStella Laurenzo inferredAlignment, isMutable, deleter, static_cast<void *>(view.get())); 1160f66cd9e9SStella Laurenzo if (mlirAttributeIsNull(attr)) { 1161f66cd9e9SStella Laurenzo throw std::invalid_argument( 1162f66cd9e9SStella Laurenzo "DenseResourceElementsAttr could not be constructed from the given " 1163f66cd9e9SStella Laurenzo "buffer. " 1164f66cd9e9SStella Laurenzo "This may mean that the Python buffer layout does not match that " 1165f66cd9e9SStella Laurenzo "MLIR expected layout and is a bug."); 1166f66cd9e9SStella Laurenzo } 1167f66cd9e9SStella Laurenzo view.release(); 1168f66cd9e9SStella Laurenzo return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); 1169f66cd9e9SStella Laurenzo } 1170f66cd9e9SStella Laurenzo 1171f66cd9e9SStella Laurenzo static void bindDerived(ClassTy &c) { 1172f66cd9e9SStella Laurenzo c.def_static("get_from_buffer", 1173f66cd9e9SStella Laurenzo PyDenseResourceElementsAttribute::getFromBuffer, 1174f66cd9e9SStella Laurenzo py::arg("array"), py::arg("name"), py::arg("type"), 1175f66cd9e9SStella Laurenzo py::arg("alignment") = py::none(), 1176f66cd9e9SStella Laurenzo py::arg("is_mutable") = false, py::arg("context") = py::none(), 1177f66cd9e9SStella Laurenzo kDenseResourceElementsAttrGetFromBufferDocstring); 1178f66cd9e9SStella Laurenzo } 1179f66cd9e9SStella Laurenzo }; 1180f66cd9e9SStella Laurenzo 1181436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 1182436c6c9cSStella Laurenzo public: 1183436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 1184436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 1185436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 11869566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 11879566ee28Smax mlirDictionaryAttrGetTypeID; 1188436c6c9cSStella Laurenzo 1189436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 1190436c6c9cSStella Laurenzo 11919fb1086bSAdrian Kuegel bool dunderContains(const std::string &name) { 11929fb1086bSAdrian Kuegel return !mlirAttributeIsNull( 11939fb1086bSAdrian Kuegel mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 11949fb1086bSAdrian Kuegel } 11959fb1086bSAdrian Kuegel 1196436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 11979fb1086bSAdrian Kuegel c.def("__contains__", &PyDictAttribute::dunderContains); 1198436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 1199436c6c9cSStella Laurenzo c.def_static( 1200436c6c9cSStella Laurenzo "get", 1201436c6c9cSStella Laurenzo [](py::dict attributes, DefaultingPyMlirContext context) { 1202436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 1203436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 1204436c6c9cSStella Laurenzo for (auto &it : attributes) { 120502b6fb21SMehdi Amini auto &mlirAttr = it.second.cast<PyAttribute &>(); 1206436c6c9cSStella Laurenzo auto name = it.first.cast<std::string>(); 1207436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 120802b6fb21SMehdi Amini mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 1209436c6c9cSStella Laurenzo toMlirStringRef(name)), 121002b6fb21SMehdi Amini mlirAttr)); 1211436c6c9cSStella Laurenzo } 1212436c6c9cSStella Laurenzo MlirAttribute attr = 1213436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 1214436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 1215436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 1216436c6c9cSStella Laurenzo }, 1217ed9e52f3SAlex Zinenko py::arg("value") = py::dict(), py::arg("context") = py::none(), 1218436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 1219436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 1220436c6c9cSStella Laurenzo MlirAttribute attr = 1221436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 1222974c1596SRahul Kayaith if (mlirAttributeIsNull(attr)) 12234811270bSmax throw py::key_error("attempt to access a non-existent attribute"); 1224974c1596SRahul Kayaith return attr; 1225436c6c9cSStella Laurenzo }); 1226436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 1227436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 12284811270bSmax throw py::index_error("attempt to access out of bounds attribute"); 1229436c6c9cSStella Laurenzo } 1230436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 1231436c6c9cSStella Laurenzo return PyNamedAttribute( 1232436c6c9cSStella Laurenzo namedAttr.attribute, 1233436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 1234436c6c9cSStella Laurenzo }); 1235436c6c9cSStella Laurenzo } 1236436c6c9cSStella Laurenzo }; 1237436c6c9cSStella Laurenzo 1238436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 1239436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 1240436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 1241436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 1242436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 1243436c6c9cSStella Laurenzo public: 1244436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 1245436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 1246436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1247436c6c9cSStella Laurenzo 1248436c6c9cSStella Laurenzo py::float_ dunderGetItem(intptr_t pos) { 1249436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 12504811270bSmax throw py::index_error("attempt to access out of bounds element"); 1251436c6c9cSStella Laurenzo } 1252436c6c9cSStella Laurenzo 1253436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 1254436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 1255436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 1256436c6c9cSStella Laurenzo // elemental type of the attribute. py::float_ is implicitly constructible 1257436c6c9cSStella Laurenzo // from float and double. 1258436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 1259436c6c9cSStella Laurenzo // querying them on each element access. 1260436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 1261436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetFloatValue(*this, pos); 1262436c6c9cSStella Laurenzo } 1263436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 1264436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetDoubleValue(*this, pos); 1265436c6c9cSStella Laurenzo } 12664811270bSmax throw py::type_error("Unsupported floating-point type"); 1267436c6c9cSStella Laurenzo } 1268436c6c9cSStella Laurenzo 1269436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1270436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 1271436c6c9cSStella Laurenzo } 1272436c6c9cSStella Laurenzo }; 1273436c6c9cSStella Laurenzo 1274436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 1275436c6c9cSStella Laurenzo public: 1276436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 1277436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 1278436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 12799566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 12809566ee28Smax mlirTypeAttrGetTypeID; 1281436c6c9cSStella Laurenzo 1282436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1283436c6c9cSStella Laurenzo c.def_static( 1284436c6c9cSStella Laurenzo "get", 1285436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 1286436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 1287436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 1288436c6c9cSStella Laurenzo }, 1289436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 1290436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 1291436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyTypeAttribute &self) { 1292bfb1ba75Smax return mlirTypeAttrGetValue(self.get()); 1293436c6c9cSStella Laurenzo }); 1294436c6c9cSStella Laurenzo } 1295436c6c9cSStella Laurenzo }; 1296436c6c9cSStella Laurenzo 1297436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 1298436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 1299436c6c9cSStella Laurenzo public: 1300436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 1301436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 1302436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 13039566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 13049566ee28Smax mlirUnitAttrGetTypeID; 1305436c6c9cSStella Laurenzo 1306436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1307436c6c9cSStella Laurenzo c.def_static( 1308436c6c9cSStella Laurenzo "get", 1309436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 1310436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 1311436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 1312436c6c9cSStella Laurenzo }, 1313436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Create a Unit attribute."); 1314436c6c9cSStella Laurenzo } 1315436c6c9cSStella Laurenzo }; 1316436c6c9cSStella Laurenzo 1317ac2e2d65SDenys Shabalin /// Strided layout attribute subclass. 1318ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute 1319ac2e2d65SDenys Shabalin : public PyConcreteAttribute<PyStridedLayoutAttribute> { 1320ac2e2d65SDenys Shabalin public: 1321ac2e2d65SDenys Shabalin static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; 1322ac2e2d65SDenys Shabalin static constexpr const char *pyClassName = "StridedLayoutAttr"; 1323ac2e2d65SDenys Shabalin using PyConcreteAttribute::PyConcreteAttribute; 13249566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 13259566ee28Smax mlirStridedLayoutAttrGetTypeID; 1326ac2e2d65SDenys Shabalin 1327ac2e2d65SDenys Shabalin static void bindDerived(ClassTy &c) { 1328ac2e2d65SDenys Shabalin c.def_static( 1329ac2e2d65SDenys Shabalin "get", 1330ac2e2d65SDenys Shabalin [](int64_t offset, const std::vector<int64_t> strides, 1331ac2e2d65SDenys Shabalin DefaultingPyMlirContext ctx) { 1332ac2e2d65SDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1333ac2e2d65SDenys Shabalin ctx->get(), offset, strides.size(), strides.data()); 1334ac2e2d65SDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1335ac2e2d65SDenys Shabalin }, 1336ac2e2d65SDenys Shabalin py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), 1337ac2e2d65SDenys Shabalin "Gets a strided layout attribute."); 1338e3fd612eSDenys Shabalin c.def_static( 1339e3fd612eSDenys Shabalin "get_fully_dynamic", 1340e3fd612eSDenys Shabalin [](int64_t rank, DefaultingPyMlirContext ctx) { 1341e3fd612eSDenys Shabalin auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); 1342e3fd612eSDenys Shabalin std::vector<int64_t> strides(rank); 1343e3fd612eSDenys Shabalin std::fill(strides.begin(), strides.end(), dynamic); 1344e3fd612eSDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1345e3fd612eSDenys Shabalin ctx->get(), dynamic, strides.size(), strides.data()); 1346e3fd612eSDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1347e3fd612eSDenys Shabalin }, 1348e3fd612eSDenys Shabalin py::arg("rank"), py::arg("context") = py::none(), 1349e3fd612eSDenys Shabalin "Gets a strided layout attribute with dynamic offset and strides of a " 1350e3fd612eSDenys Shabalin "given rank."); 1351ac2e2d65SDenys Shabalin c.def_property_readonly( 1352ac2e2d65SDenys Shabalin "offset", 1353ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1354ac2e2d65SDenys Shabalin return mlirStridedLayoutAttrGetOffset(self); 1355ac2e2d65SDenys Shabalin }, 1356ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1357ac2e2d65SDenys Shabalin c.def_property_readonly( 1358ac2e2d65SDenys Shabalin "strides", 1359ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1360ac2e2d65SDenys Shabalin intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); 1361ac2e2d65SDenys Shabalin std::vector<int64_t> strides(size); 1362ac2e2d65SDenys Shabalin for (intptr_t i = 0; i < size; i++) { 1363ac2e2d65SDenys Shabalin strides[i] = mlirStridedLayoutAttrGetStride(self, i); 1364ac2e2d65SDenys Shabalin } 1365ac2e2d65SDenys Shabalin return strides; 1366ac2e2d65SDenys Shabalin }, 1367ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1368ac2e2d65SDenys Shabalin } 1369ac2e2d65SDenys Shabalin }; 1370ac2e2d65SDenys Shabalin 13719566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { 13729566ee28Smax if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) 13739566ee28Smax return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); 13749566ee28Smax if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) 13759566ee28Smax return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); 13769566ee28Smax if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) 13779566ee28Smax return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); 13789566ee28Smax if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) 13799566ee28Smax return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); 13809566ee28Smax if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) 13819566ee28Smax return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); 13829566ee28Smax if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) 13839566ee28Smax return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); 13849566ee28Smax if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) 13859566ee28Smax return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); 13869566ee28Smax std::string msg = 13879566ee28Smax std::string("Can't cast unknown element type DenseArrayAttr (") + 13889566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 13899566ee28Smax throw py::cast_error(msg); 13909566ee28Smax } 13919566ee28Smax 13929566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { 13939566ee28Smax if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) 13949566ee28Smax return py::cast(PyDenseFPElementsAttribute(pyAttribute)); 13959566ee28Smax if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) 13969566ee28Smax return py::cast(PyDenseIntElementsAttribute(pyAttribute)); 13979566ee28Smax std::string msg = 13989566ee28Smax std::string( 13999566ee28Smax "Can't cast unknown element type DenseIntOrFPElementsAttr (") + 14009566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 14019566ee28Smax throw py::cast_error(msg); 14029566ee28Smax } 14039566ee28Smax 14049566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { 14059566ee28Smax if (PyBoolAttribute::isaFunction(pyAttribute)) 14069566ee28Smax return py::cast(PyBoolAttribute(pyAttribute)); 14079566ee28Smax if (PyIntegerAttribute::isaFunction(pyAttribute)) 14089566ee28Smax return py::cast(PyIntegerAttribute(pyAttribute)); 14099566ee28Smax std::string msg = 14109566ee28Smax std::string("Can't cast unknown element type DenseArrayAttr (") + 14119566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 14129566ee28Smax throw py::cast_error(msg); 14139566ee28Smax } 14149566ee28Smax 14154eee9ef9Smax py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { 14164eee9ef9Smax if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) 14174eee9ef9Smax return py::cast(PyFlatSymbolRefAttribute(pyAttribute)); 14184eee9ef9Smax if (PySymbolRefAttribute::isaFunction(pyAttribute)) 14194eee9ef9Smax return py::cast(PySymbolRefAttribute(pyAttribute)); 14204eee9ef9Smax std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + 14214eee9ef9Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 14224eee9ef9Smax throw py::cast_error(msg); 14234eee9ef9Smax } 14244eee9ef9Smax 1425436c6c9cSStella Laurenzo } // namespace 1426436c6c9cSStella Laurenzo 1427436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) { 1428436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 1429619fd8c2SJeff Niu 1430619fd8c2SJeff Niu PyDenseBoolArrayAttribute::bind(m); 1431619fd8c2SJeff Niu PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); 1432619fd8c2SJeff Niu PyDenseI8ArrayAttribute::bind(m); 1433619fd8c2SJeff Niu PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); 1434619fd8c2SJeff Niu PyDenseI16ArrayAttribute::bind(m); 1435619fd8c2SJeff Niu PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); 1436619fd8c2SJeff Niu PyDenseI32ArrayAttribute::bind(m); 1437619fd8c2SJeff Niu PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); 1438619fd8c2SJeff Niu PyDenseI64ArrayAttribute::bind(m); 1439619fd8c2SJeff Niu PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); 1440619fd8c2SJeff Niu PyDenseF32ArrayAttribute::bind(m); 1441619fd8c2SJeff Niu PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); 1442619fd8c2SJeff Niu PyDenseF64ArrayAttribute::bind(m); 1443619fd8c2SJeff Niu PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); 14449566ee28Smax PyGlobals::get().registerTypeCaster( 14459566ee28Smax mlirDenseArrayAttrGetTypeID(), 14469566ee28Smax pybind11::cpp_function(denseArrayAttributeCaster)); 1447619fd8c2SJeff Niu 1448436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 1449436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 1450436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 1451436c6c9cSStella Laurenzo PyDenseElementsAttribute::bind(m); 1452436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 1453436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 14549566ee28Smax PyGlobals::get().registerTypeCaster( 14559566ee28Smax mlirDenseIntOrFPElementsAttrGetTypeID(), 14569566ee28Smax pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); 1457f66cd9e9SStella Laurenzo PyDenseResourceElementsAttribute::bind(m); 14589566ee28Smax 1459436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 14604eee9ef9Smax PySymbolRefAttribute::bind(m); 14614eee9ef9Smax PyGlobals::get().registerTypeCaster( 14624eee9ef9Smax mlirSymbolRefAttrGetTypeID(), 14634eee9ef9Smax pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)); 14644eee9ef9Smax 1465436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 14665c3861b2SYun Long PyOpaqueAttribute::bind(m); 1467436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 1468436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 1469436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 1470436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 14719566ee28Smax PyGlobals::get().registerTypeCaster( 14729566ee28Smax mlirIntegerAttrGetTypeID(), 14739566ee28Smax pybind11::cpp_function(integerOrBoolAttributeCaster)); 1474436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 1475ac2e2d65SDenys Shabalin 1476ac2e2d65SDenys Shabalin PyStridedLayoutAttribute::bind(m); 1477436c6c9cSStella Laurenzo } 1478