1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 91fc096afSMehdi Amini #include <utility> 10a1fe1f5fSKazu Hirata #include <optional> 111fc096afSMehdi Amini 12436c6c9cSStella Laurenzo #include "IRModule.h" 13436c6c9cSStella Laurenzo 14436c6c9cSStella Laurenzo #include "PybindUtils.h" 15436c6c9cSStella Laurenzo 16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 17436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 18436c6c9cSStella Laurenzo 19436c6c9cSStella Laurenzo namespace py = pybind11; 20436c6c9cSStella Laurenzo using namespace mlir; 21436c6c9cSStella Laurenzo using namespace mlir::python; 22436c6c9cSStella Laurenzo 235d6d30edSStella Laurenzo using llvm::Optional; 24436c6c9cSStella Laurenzo using llvm::SmallVector; 25436c6c9cSStella Laurenzo using llvm::Twine; 26436c6c9cSStella Laurenzo 275d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 285d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 295d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 305d6d30edSStella Laurenzo 315d6d30edSStella Laurenzo static const char kDenseElementsAttrGetDocstring[] = 325d6d30edSStella Laurenzo R"(Gets a DenseElementsAttr from a Python buffer or array. 335d6d30edSStella Laurenzo 345d6d30edSStella Laurenzo When `type` is not provided, then some limited type inferencing is done based 355d6d30edSStella Laurenzo on the buffer format. Support presently exists for 8/16/32/64 signed and 365d6d30edSStella Laurenzo unsigned integers and float16/float32/float64. DenseElementsAttrs of these 375d6d30edSStella Laurenzo types can also be converted back to a corresponding buffer. 385d6d30edSStella Laurenzo 395d6d30edSStella Laurenzo For conversions outside of these types, a `type=` must be explicitly provided 405d6d30edSStella Laurenzo and the buffer contents must be bit-castable to the MLIR internal 415d6d30edSStella Laurenzo representation: 425d6d30edSStella Laurenzo 435d6d30edSStella Laurenzo * Integer types (except for i1): the buffer must be byte aligned to the 445d6d30edSStella Laurenzo next byte boundary. 455d6d30edSStella Laurenzo * Floating point types: Must be bit-castable to the given floating point 465d6d30edSStella Laurenzo size. 475d6d30edSStella Laurenzo * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 485d6d30edSStella Laurenzo row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 495d6d30edSStella Laurenzo this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 505d6d30edSStella Laurenzo 515d6d30edSStella Laurenzo If a single element buffer is passed (or for i1, a single byte with value 0 525d6d30edSStella Laurenzo or 255), then a splat will be created. 535d6d30edSStella Laurenzo 545d6d30edSStella Laurenzo Args: 555d6d30edSStella Laurenzo array: The array or buffer to convert. 565d6d30edSStella Laurenzo signless: If inferring an appropriate MLIR type, use signless types for 575d6d30edSStella Laurenzo integers (defaults True). 585d6d30edSStella Laurenzo type: Skips inference of the MLIR element type and uses this instead. The 595d6d30edSStella Laurenzo storage size must be consistent with the actual contents of the buffer. 605d6d30edSStella Laurenzo shape: Overrides the shape of the buffer when constructing the MLIR 615d6d30edSStella Laurenzo shaped type. This is needed when the physical and logical shape differ (as 625d6d30edSStella Laurenzo for i1). 635d6d30edSStella Laurenzo context: Explicit context, if not from context manager. 645d6d30edSStella Laurenzo 655d6d30edSStella Laurenzo Returns: 665d6d30edSStella Laurenzo DenseElementsAttr on success. 675d6d30edSStella Laurenzo 685d6d30edSStella Laurenzo Raises: 695d6d30edSStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 705d6d30edSStella Laurenzo type or if the buffer does not meet expectations. 715d6d30edSStella Laurenzo )"; 725d6d30edSStella Laurenzo 73436c6c9cSStella Laurenzo namespace { 74436c6c9cSStella Laurenzo 75436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 76436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 77436c6c9cSStella Laurenzo } 78436c6c9cSStella Laurenzo 79436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 80436c6c9cSStella Laurenzo public: 81436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 82436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 83436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 84436c6c9cSStella Laurenzo 85436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 86436c6c9cSStella Laurenzo c.def_static( 87436c6c9cSStella Laurenzo "get", 88436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 89436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 90436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 91436c6c9cSStella Laurenzo }, 92436c6c9cSStella Laurenzo py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 93436c6c9cSStella Laurenzo } 94436c6c9cSStella Laurenzo }; 95436c6c9cSStella Laurenzo 96ed9e52f3SAlex Zinenko template <typename T> 97ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) { 98ed9e52f3SAlex Zinenko try { 99ed9e52f3SAlex Zinenko return object.cast<T>(); 100ed9e52f3SAlex Zinenko } catch (py::cast_error &err) { 101ed9e52f3SAlex Zinenko std::string msg = 102ed9e52f3SAlex Zinenko std::string( 103ed9e52f3SAlex Zinenko "Invalid attribute when attempting to create an ArrayAttribute (") + 104ed9e52f3SAlex Zinenko err.what() + ")"; 105ed9e52f3SAlex Zinenko throw py::cast_error(msg); 106ed9e52f3SAlex Zinenko } catch (py::reference_cast_error &err) { 107ed9e52f3SAlex Zinenko std::string msg = std::string("Invalid attribute (None?) when attempting " 108ed9e52f3SAlex Zinenko "to create an ArrayAttribute (") + 109ed9e52f3SAlex Zinenko err.what() + ")"; 110ed9e52f3SAlex Zinenko throw py::cast_error(msg); 111ed9e52f3SAlex Zinenko } 112ed9e52f3SAlex Zinenko } 113ed9e52f3SAlex Zinenko 114619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived 115619fd8c2SJeff Niu /// implementation class. 116619fd8c2SJeff Niu template <typename EltTy, typename DerivedT> 117133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> { 118619fd8c2SJeff Niu public: 119133624acSJeff Niu using PyConcreteAttribute<DerivedT>::PyConcreteAttribute; 120619fd8c2SJeff Niu 121619fd8c2SJeff Niu /// Iterator over the integer elements of a dense array. 122619fd8c2SJeff Niu class PyDenseArrayIterator { 123619fd8c2SJeff Niu public: 1244a1b1196SMehdi Amini PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} 125619fd8c2SJeff Niu 126619fd8c2SJeff Niu /// Return a copy of the iterator. 127619fd8c2SJeff Niu PyDenseArrayIterator dunderIter() { return *this; } 128619fd8c2SJeff Niu 129619fd8c2SJeff Niu /// Return the next element. 130619fd8c2SJeff Niu EltTy dunderNext() { 131619fd8c2SJeff Niu // Throw if the index has reached the end. 132619fd8c2SJeff Niu if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) 133619fd8c2SJeff Niu throw py::stop_iteration(); 134619fd8c2SJeff Niu return DerivedT::getElement(attr.get(), nextIndex++); 135619fd8c2SJeff Niu } 136619fd8c2SJeff Niu 137619fd8c2SJeff Niu /// Bind the iterator class. 138619fd8c2SJeff Niu static void bind(py::module &m) { 139619fd8c2SJeff Niu py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName, 140619fd8c2SJeff Niu py::module_local()) 141619fd8c2SJeff Niu .def("__iter__", &PyDenseArrayIterator::dunderIter) 142619fd8c2SJeff Niu .def("__next__", &PyDenseArrayIterator::dunderNext); 143619fd8c2SJeff Niu } 144619fd8c2SJeff Niu 145619fd8c2SJeff Niu private: 146619fd8c2SJeff Niu /// The referenced dense array attribute. 147619fd8c2SJeff Niu PyAttribute attr; 148619fd8c2SJeff Niu /// The next index to read. 149619fd8c2SJeff Niu int nextIndex = 0; 150619fd8c2SJeff Niu }; 151619fd8c2SJeff Niu 152619fd8c2SJeff Niu /// Get the element at the given index. 153619fd8c2SJeff Niu EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } 154619fd8c2SJeff Niu 155619fd8c2SJeff Niu /// Bind the attribute class. 156133624acSJeff Niu static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) { 157619fd8c2SJeff Niu // Bind the constructor. 158619fd8c2SJeff Niu c.def_static( 159619fd8c2SJeff Niu "get", 160619fd8c2SJeff Niu [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { 161619fd8c2SJeff Niu MlirAttribute attr = 162619fd8c2SJeff Niu DerivedT::getAttribute(ctx->get(), values.size(), values.data()); 163133624acSJeff Niu return DerivedT(ctx->getRef(), attr); 164619fd8c2SJeff Niu }, 165619fd8c2SJeff Niu py::arg("values"), py::arg("context") = py::none(), 166619fd8c2SJeff Niu "Gets a uniqued dense array attribute"); 167619fd8c2SJeff Niu // Bind the array methods. 168133624acSJeff Niu c.def("__getitem__", [](DerivedT &arr, intptr_t i) { 169619fd8c2SJeff Niu if (i >= mlirDenseArrayGetNumElements(arr)) 170619fd8c2SJeff Niu throw py::index_error("DenseArray index out of range"); 171619fd8c2SJeff Niu return arr.getItem(i); 172619fd8c2SJeff Niu }); 173133624acSJeff Niu c.def("__len__", [](const DerivedT &arr) { 174619fd8c2SJeff Niu return mlirDenseArrayGetNumElements(arr); 175619fd8c2SJeff Niu }); 176133624acSJeff Niu c.def("__iter__", 177133624acSJeff Niu [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); 1784a1b1196SMehdi Amini c.def("__add__", [](DerivedT &arr, const py::list &extras) { 179619fd8c2SJeff Niu std::vector<EltTy> values; 180619fd8c2SJeff Niu intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); 181619fd8c2SJeff Niu values.reserve(numOldElements + py::len(extras)); 182619fd8c2SJeff Niu for (intptr_t i = 0; i < numOldElements; ++i) 183619fd8c2SJeff Niu values.push_back(arr.getItem(i)); 184619fd8c2SJeff Niu for (py::handle attr : extras) 185619fd8c2SJeff Niu values.push_back(pyTryCast<EltTy>(attr)); 186619fd8c2SJeff Niu MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(), 187619fd8c2SJeff Niu values.size(), values.data()); 188133624acSJeff Niu return DerivedT(arr.getContext(), attr); 189619fd8c2SJeff Niu }); 190619fd8c2SJeff Niu } 191619fd8c2SJeff Niu }; 192619fd8c2SJeff Niu 193619fd8c2SJeff Niu /// Instantiate the python dense array classes. 194619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute 195619fd8c2SJeff Niu : public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> { 196619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; 197619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseBoolArrayGet; 198619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseBoolArrayGetElement; 199619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseBoolArrayAttr"; 200619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; 201619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 202619fd8c2SJeff Niu }; 203619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute 204619fd8c2SJeff Niu : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { 205619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; 206619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI8ArrayGet; 207619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI8ArrayGetElement; 208619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI8ArrayAttr"; 209619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; 210619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 211619fd8c2SJeff Niu }; 212619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute 213619fd8c2SJeff Niu : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { 214619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; 215619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI16ArrayGet; 216619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI16ArrayGetElement; 217619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI16ArrayAttr"; 218619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; 219619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 220619fd8c2SJeff Niu }; 221619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute 222619fd8c2SJeff Niu : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { 223619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; 224619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI32ArrayGet; 225619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI32ArrayGetElement; 226619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI32ArrayAttr"; 227619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; 228619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 229619fd8c2SJeff Niu }; 230619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute 231619fd8c2SJeff Niu : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { 232619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; 233619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI64ArrayGet; 234619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI64ArrayGetElement; 235619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI64ArrayAttr"; 236619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; 237619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 238619fd8c2SJeff Niu }; 239619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute 240619fd8c2SJeff Niu : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { 241619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; 242619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF32ArrayGet; 243619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF32ArrayGetElement; 244619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF32ArrayAttr"; 245619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; 246619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 247619fd8c2SJeff Niu }; 248619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute 249619fd8c2SJeff Niu : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { 250619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; 251619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF64ArrayGet; 252619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF64ArrayGetElement; 253619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF64ArrayAttr"; 254619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; 255619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 256619fd8c2SJeff Niu }; 257619fd8c2SJeff Niu 258436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 259436c6c9cSStella Laurenzo public: 260436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 261436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 262436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 263436c6c9cSStella Laurenzo 264436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 265436c6c9cSStella Laurenzo public: 2661fc096afSMehdi Amini PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 267436c6c9cSStella Laurenzo 268436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 269436c6c9cSStella Laurenzo 270436c6c9cSStella Laurenzo PyAttribute dunderNext() { 271bca88952SJeff Niu // TODO: Throw is an inefficient way to stop iteration. 272bca88952SJeff Niu if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) 273436c6c9cSStella Laurenzo throw py::stop_iteration(); 274436c6c9cSStella Laurenzo return PyAttribute(attr.getContext(), 275436c6c9cSStella Laurenzo mlirArrayAttrGetElement(attr.get(), nextIndex++)); 276436c6c9cSStella Laurenzo } 277436c6c9cSStella Laurenzo 278436c6c9cSStella Laurenzo static void bind(py::module &m) { 279f05ff4f7SStella Laurenzo py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 280f05ff4f7SStella Laurenzo py::module_local()) 281436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 282436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 283436c6c9cSStella Laurenzo } 284436c6c9cSStella Laurenzo 285436c6c9cSStella Laurenzo private: 286436c6c9cSStella Laurenzo PyAttribute attr; 287436c6c9cSStella Laurenzo int nextIndex = 0; 288436c6c9cSStella Laurenzo }; 289436c6c9cSStella Laurenzo 290ed9e52f3SAlex Zinenko PyAttribute getItem(intptr_t i) { 291ed9e52f3SAlex Zinenko return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); 292ed9e52f3SAlex Zinenko } 293ed9e52f3SAlex Zinenko 294436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 295436c6c9cSStella Laurenzo c.def_static( 296436c6c9cSStella Laurenzo "get", 297436c6c9cSStella Laurenzo [](py::list attributes, DefaultingPyMlirContext context) { 298436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 299436c6c9cSStella Laurenzo mlirAttributes.reserve(py::len(attributes)); 300436c6c9cSStella Laurenzo for (auto attribute : attributes) { 301ed9e52f3SAlex Zinenko mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 302436c6c9cSStella Laurenzo } 303436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 304436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 305436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 306436c6c9cSStella Laurenzo }, 307436c6c9cSStella Laurenzo py::arg("attributes"), py::arg("context") = py::none(), 308436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 309436c6c9cSStella Laurenzo c.def("__getitem__", 310436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 311436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 312436c6c9cSStella Laurenzo throw py::index_error("ArrayAttribute index out of range"); 313ed9e52f3SAlex Zinenko return arr.getItem(i); 314436c6c9cSStella Laurenzo }) 315436c6c9cSStella Laurenzo .def("__len__", 316436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 317436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 318436c6c9cSStella Laurenzo }) 319436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 320436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 321436c6c9cSStella Laurenzo }); 322ed9e52f3SAlex Zinenko c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 323ed9e52f3SAlex Zinenko std::vector<MlirAttribute> attributes; 324ed9e52f3SAlex Zinenko intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 325ed9e52f3SAlex Zinenko attributes.reserve(numOldElements + py::len(extras)); 326ed9e52f3SAlex Zinenko for (intptr_t i = 0; i < numOldElements; ++i) 327ed9e52f3SAlex Zinenko attributes.push_back(arr.getItem(i)); 328ed9e52f3SAlex Zinenko for (py::handle attr : extras) 329ed9e52f3SAlex Zinenko attributes.push_back(pyTryCast<PyAttribute>(attr)); 330ed9e52f3SAlex Zinenko MlirAttribute arrayAttr = mlirArrayAttrGet( 331ed9e52f3SAlex Zinenko arr.getContext()->get(), attributes.size(), attributes.data()); 332ed9e52f3SAlex Zinenko return PyArrayAttribute(arr.getContext(), arrayAttr); 333ed9e52f3SAlex Zinenko }); 334436c6c9cSStella Laurenzo } 335436c6c9cSStella Laurenzo }; 336436c6c9cSStella Laurenzo 337436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 338436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 339436c6c9cSStella Laurenzo public: 340436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 341436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 342436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 343436c6c9cSStella Laurenzo 344436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 345436c6c9cSStella Laurenzo c.def_static( 346436c6c9cSStella Laurenzo "get", 347436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 348436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 349436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 350436c6c9cSStella Laurenzo // in C API. 351436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 352436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 353436c6c9cSStella Laurenzo Twine("invalid '") + 354436c6c9cSStella Laurenzo py::repr(py::cast(type)).cast<std::string>() + 355436c6c9cSStella Laurenzo "' and expected floating point type."); 356436c6c9cSStella Laurenzo } 357436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 358436c6c9cSStella Laurenzo }, 359436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 360436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 361436c6c9cSStella Laurenzo c.def_static( 362436c6c9cSStella Laurenzo "get_f32", 363436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 364436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 365436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 366436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 367436c6c9cSStella Laurenzo }, 368436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 369436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 370436c6c9cSStella Laurenzo c.def_static( 371436c6c9cSStella Laurenzo "get_f64", 372436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 373436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 374436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 375436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 376436c6c9cSStella Laurenzo }, 377436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 378436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 379436c6c9cSStella Laurenzo c.def_property_readonly( 380436c6c9cSStella Laurenzo "value", 381436c6c9cSStella Laurenzo [](PyFloatAttribute &self) { 382436c6c9cSStella Laurenzo return mlirFloatAttrGetValueDouble(self); 383436c6c9cSStella Laurenzo }, 384436c6c9cSStella Laurenzo "Returns the value of the float point attribute"); 385436c6c9cSStella Laurenzo } 386436c6c9cSStella Laurenzo }; 387436c6c9cSStella Laurenzo 388436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 389436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 390436c6c9cSStella Laurenzo public: 391436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 392436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 393436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 394436c6c9cSStella Laurenzo 395436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 396436c6c9cSStella Laurenzo c.def_static( 397436c6c9cSStella Laurenzo "get", 398436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 399436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 400436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 401436c6c9cSStella Laurenzo }, 402436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), 403436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 404436c6c9cSStella Laurenzo c.def_property_readonly( 405436c6c9cSStella Laurenzo "value", 406e9db306dSrkayaith [](PyIntegerAttribute &self) -> py::int_ { 407e9db306dSrkayaith MlirType type = mlirAttributeGetType(self); 408e9db306dSrkayaith if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 409436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 410e9db306dSrkayaith if (mlirIntegerTypeIsSigned(type)) 411e9db306dSrkayaith return mlirIntegerAttrGetValueSInt(self); 412e9db306dSrkayaith return mlirIntegerAttrGetValueUInt(self); 413436c6c9cSStella Laurenzo }, 414436c6c9cSStella Laurenzo "Returns the value of the integer attribute"); 415436c6c9cSStella Laurenzo } 416436c6c9cSStella Laurenzo }; 417436c6c9cSStella Laurenzo 418436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 419436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 420436c6c9cSStella Laurenzo public: 421436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 422436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 423436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 424436c6c9cSStella Laurenzo 425436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 426436c6c9cSStella Laurenzo c.def_static( 427436c6c9cSStella Laurenzo "get", 428436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 429436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 430436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 431436c6c9cSStella Laurenzo }, 432436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 433436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 434436c6c9cSStella Laurenzo c.def_property_readonly( 435436c6c9cSStella Laurenzo "value", 436436c6c9cSStella Laurenzo [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 437436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 438436c6c9cSStella Laurenzo } 439436c6c9cSStella Laurenzo }; 440436c6c9cSStella Laurenzo 441436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 442436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 443436c6c9cSStella Laurenzo public: 444436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 445436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 446436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 447436c6c9cSStella Laurenzo 448436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 449436c6c9cSStella Laurenzo c.def_static( 450436c6c9cSStella Laurenzo "get", 451436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 452436c6c9cSStella Laurenzo MlirAttribute attr = 453436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 454436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 455436c6c9cSStella Laurenzo }, 456436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 457436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 458436c6c9cSStella Laurenzo c.def_property_readonly( 459436c6c9cSStella Laurenzo "value", 460436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 461436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 462436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 463436c6c9cSStella Laurenzo }, 464436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 465436c6c9cSStella Laurenzo } 466436c6c9cSStella Laurenzo }; 467436c6c9cSStella Laurenzo 4685c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { 4695c3861b2SYun Long public: 4705c3861b2SYun Long static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; 4715c3861b2SYun Long static constexpr const char *pyClassName = "OpaqueAttr"; 4725c3861b2SYun Long using PyConcreteAttribute::PyConcreteAttribute; 4735c3861b2SYun Long 4745c3861b2SYun Long static void bindDerived(ClassTy &c) { 4755c3861b2SYun Long c.def_static( 4765c3861b2SYun Long "get", 4775c3861b2SYun Long [](std::string dialectNamespace, py::buffer buffer, PyType &type, 4785c3861b2SYun Long DefaultingPyMlirContext context) { 4795c3861b2SYun Long const py::buffer_info bufferInfo = buffer.request(); 4805c3861b2SYun Long intptr_t bufferSize = bufferInfo.size; 4815c3861b2SYun Long MlirAttribute attr = mlirOpaqueAttrGet( 4825c3861b2SYun Long context->get(), toMlirStringRef(dialectNamespace), bufferSize, 4835c3861b2SYun Long static_cast<char *>(bufferInfo.ptr), type); 4845c3861b2SYun Long return PyOpaqueAttribute(context->getRef(), attr); 4855c3861b2SYun Long }, 4865c3861b2SYun Long py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), 4875c3861b2SYun Long py::arg("context") = py::none(), "Gets an Opaque attribute."); 4885c3861b2SYun Long c.def_property_readonly( 4895c3861b2SYun Long "dialect_namespace", 4905c3861b2SYun Long [](PyOpaqueAttribute &self) { 4915c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); 4925c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 4935c3861b2SYun Long }, 4945c3861b2SYun Long "Returns the dialect namespace for the Opaque attribute as a string"); 4955c3861b2SYun Long c.def_property_readonly( 4965c3861b2SYun Long "data", 4975c3861b2SYun Long [](PyOpaqueAttribute &self) { 4985c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetData(self); 4995c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 5005c3861b2SYun Long }, 5015c3861b2SYun Long "Returns the data for the Opaqued attributes as a string"); 5025c3861b2SYun Long } 5035c3861b2SYun Long }; 5045c3861b2SYun Long 505436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 506436c6c9cSStella Laurenzo public: 507436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 508436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 509436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 510436c6c9cSStella Laurenzo 511436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 512436c6c9cSStella Laurenzo c.def_static( 513436c6c9cSStella Laurenzo "get", 514436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 515436c6c9cSStella Laurenzo MlirAttribute attr = 516436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 517436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 518436c6c9cSStella Laurenzo }, 519436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 520436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 521436c6c9cSStella Laurenzo c.def_static( 522436c6c9cSStella Laurenzo "get_typed", 523436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 524436c6c9cSStella Laurenzo MlirAttribute attr = 525436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 526436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 527436c6c9cSStella Laurenzo }, 528a6e7d024SStella Laurenzo py::arg("type"), py::arg("value"), 529436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 530436c6c9cSStella Laurenzo c.def_property_readonly( 531436c6c9cSStella Laurenzo "value", 532436c6c9cSStella Laurenzo [](PyStringAttribute &self) { 533436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirStringAttrGetValue(self); 534436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 535436c6c9cSStella Laurenzo }, 536436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 537436c6c9cSStella Laurenzo } 538436c6c9cSStella Laurenzo }; 539436c6c9cSStella Laurenzo 540436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 541436c6c9cSStella Laurenzo class PyDenseElementsAttribute 542436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 543436c6c9cSStella Laurenzo public: 544436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 545436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 546436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 547436c6c9cSStella Laurenzo 548436c6c9cSStella Laurenzo static PyDenseElementsAttribute 549*0a81ace0SKazu Hirata getFromBuffer(py::buffer array, bool signless, 550*0a81ace0SKazu Hirata std::optional<PyType> explicitType, 551*0a81ace0SKazu Hirata std::optional<std::vector<int64_t>> explicitShape, 552436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 553436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 554436c6c9cSStella Laurenzo int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 555436c6c9cSStella Laurenzo Py_buffer *view = new Py_buffer(); 556436c6c9cSStella Laurenzo if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 557436c6c9cSStella Laurenzo delete view; 558436c6c9cSStella Laurenzo throw py::error_already_set(); 559436c6c9cSStella Laurenzo } 560436c6c9cSStella Laurenzo py::buffer_info arrayInfo(view); 5615d6d30edSStella Laurenzo SmallVector<int64_t> shape; 5625d6d30edSStella Laurenzo if (explicitShape) { 5635d6d30edSStella Laurenzo shape.append(explicitShape->begin(), explicitShape->end()); 5645d6d30edSStella Laurenzo } else { 5655d6d30edSStella Laurenzo shape.append(arrayInfo.shape.begin(), 5665d6d30edSStella Laurenzo arrayInfo.shape.begin() + arrayInfo.ndim); 5675d6d30edSStella Laurenzo } 568436c6c9cSStella Laurenzo 5695d6d30edSStella Laurenzo MlirAttribute encodingAttr = mlirAttributeGetNull(); 570436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 5715d6d30edSStella Laurenzo 5725d6d30edSStella Laurenzo // Detect format codes that are suitable for bulk loading. This includes 5735d6d30edSStella Laurenzo // all byte aligned integer and floating point types up to 8 bytes. 5745d6d30edSStella Laurenzo // Notably, this excludes, bool (which needs to be bit-packed) and 5755d6d30edSStella Laurenzo // other exotics which do not have a direct representation in the buffer 5765d6d30edSStella Laurenzo // protocol (i.e. complex, etc). 577*0a81ace0SKazu Hirata std::optional<MlirType> bulkLoadElementType; 5785d6d30edSStella Laurenzo if (explicitType) { 5795d6d30edSStella Laurenzo bulkLoadElementType = *explicitType; 5805d6d30edSStella Laurenzo } else if (arrayInfo.format == "f") { 581436c6c9cSStella Laurenzo // f32 582436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 5835d6d30edSStella Laurenzo bulkLoadElementType = mlirF32TypeGet(context); 584436c6c9cSStella Laurenzo } else if (arrayInfo.format == "d") { 585436c6c9cSStella Laurenzo // f64 586436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 5875d6d30edSStella Laurenzo bulkLoadElementType = mlirF64TypeGet(context); 5885d6d30edSStella Laurenzo } else if (arrayInfo.format == "e") { 5895d6d30edSStella Laurenzo // f16 5905d6d30edSStella Laurenzo assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); 5915d6d30edSStella Laurenzo bulkLoadElementType = mlirF16TypeGet(context); 592436c6c9cSStella Laurenzo } else if (isSignedIntegerFormat(arrayInfo.format)) { 593436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 594436c6c9cSStella Laurenzo // i32 5955d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) 596436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 32); 597436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 598436c6c9cSStella Laurenzo // i64 5995d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) 600436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 64); 6015d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 6025d6d30edSStella Laurenzo // i8 6035d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 6045d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 8); 6055d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 6065d6d30edSStella Laurenzo // i16 6075d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) 6085d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 16); 609436c6c9cSStella Laurenzo } 610436c6c9cSStella Laurenzo } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 611436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 612436c6c9cSStella Laurenzo // unsigned i32 6135d6d30edSStella Laurenzo bulkLoadElementType = signless 614436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 32) 615436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 32); 616436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 617436c6c9cSStella Laurenzo // unsigned i64 6185d6d30edSStella Laurenzo bulkLoadElementType = signless 619436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 64) 620436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 64); 6215d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 6225d6d30edSStella Laurenzo // i8 6235d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 6245d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 8); 6255d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 6265d6d30edSStella Laurenzo // i16 6275d6d30edSStella Laurenzo bulkLoadElementType = signless 6285d6d30edSStella Laurenzo ? mlirIntegerTypeGet(context, 16) 6295d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 16); 630436c6c9cSStella Laurenzo } 631436c6c9cSStella Laurenzo } 6325d6d30edSStella Laurenzo if (bulkLoadElementType) { 6335d6d30edSStella Laurenzo auto shapedType = mlirRankedTensorTypeGet( 6345d6d30edSStella Laurenzo shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); 6355d6d30edSStella Laurenzo size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; 6365d6d30edSStella Laurenzo MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( 6375d6d30edSStella Laurenzo shapedType, rawBufferSize, arrayInfo.ptr); 6385d6d30edSStella Laurenzo if (mlirAttributeIsNull(attr)) { 6395d6d30edSStella Laurenzo throw std::invalid_argument( 6405d6d30edSStella Laurenzo "DenseElementsAttr could not be constructed from the given buffer. " 6415d6d30edSStella Laurenzo "This may mean that the Python buffer layout does not match that " 6425d6d30edSStella Laurenzo "MLIR expected layout and is a bug."); 6435d6d30edSStella Laurenzo } 6445d6d30edSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 6455d6d30edSStella Laurenzo } 646436c6c9cSStella Laurenzo 6475d6d30edSStella Laurenzo throw std::invalid_argument( 6485d6d30edSStella Laurenzo std::string("unimplemented array format conversion from format: ") + 6495d6d30edSStella Laurenzo arrayInfo.format); 650436c6c9cSStella Laurenzo } 651436c6c9cSStella Laurenzo 6521fc096afSMehdi Amini static PyDenseElementsAttribute getSplat(const PyType &shapedType, 653436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 654436c6c9cSStella Laurenzo auto contextWrapper = 655436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 656436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 657436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 658436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 659436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 660436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 661436c6c9cSStella Laurenzo } 662436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 663436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 664436c6c9cSStella Laurenzo std::string message = 665436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 666436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 667436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 668436c6c9cSStella Laurenzo } 669436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 670436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 671436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 672436c6c9cSStella Laurenzo std::string message = 673436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 674436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 675436c6c9cSStella Laurenzo message.append(", element="); 676436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 677436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 678436c6c9cSStella Laurenzo } 679436c6c9cSStella Laurenzo 680436c6c9cSStella Laurenzo MlirAttribute elements = 681436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 682436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 683436c6c9cSStella Laurenzo } 684436c6c9cSStella Laurenzo 685436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 686436c6c9cSStella Laurenzo 687436c6c9cSStella Laurenzo py::buffer_info accessBuffer() { 6885d6d30edSStella Laurenzo if (mlirDenseElementsAttrIsSplat(*this)) { 689c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 6905d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 691c5f445d1SStella Laurenzo throw std::invalid_argument( 692c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 6935d6d30edSStella Laurenzo } 6945d6d30edSStella Laurenzo 695436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 696436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 6975d6d30edSStella Laurenzo std::string format; 698436c6c9cSStella Laurenzo 699436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 700436c6c9cSStella Laurenzo // f32 7015d6d30edSStella Laurenzo return bufferInfo<float>(shapedType); 70202b6fb21SMehdi Amini } 70302b6fb21SMehdi Amini if (mlirTypeIsAF64(elementType)) { 704436c6c9cSStella Laurenzo // f64 7055d6d30edSStella Laurenzo return bufferInfo<double>(shapedType); 706bb56c2b3SMehdi Amini } 707bb56c2b3SMehdi Amini if (mlirTypeIsAF16(elementType)) { 7085d6d30edSStella Laurenzo // f16 7095d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType, "e"); 710bb56c2b3SMehdi Amini } 711bb56c2b3SMehdi Amini if (mlirTypeIsAInteger(elementType) && 712436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 713436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 714436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 715436c6c9cSStella Laurenzo // i32 7165d6d30edSStella Laurenzo return bufferInfo<int32_t>(shapedType); 717e5639b3fSMehdi Amini } 718e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 719436c6c9cSStella Laurenzo // unsigned i32 7205d6d30edSStella Laurenzo return bufferInfo<uint32_t>(shapedType); 721436c6c9cSStella Laurenzo } 722436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 723436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 724436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 725436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 726436c6c9cSStella Laurenzo // i64 7275d6d30edSStella Laurenzo return bufferInfo<int64_t>(shapedType); 728e5639b3fSMehdi Amini } 729e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 730436c6c9cSStella Laurenzo // unsigned i64 7315d6d30edSStella Laurenzo return bufferInfo<uint64_t>(shapedType); 7325d6d30edSStella Laurenzo } 7335d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 7345d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 8) { 7355d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 7365d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 7375d6d30edSStella Laurenzo // i8 7385d6d30edSStella Laurenzo return bufferInfo<int8_t>(shapedType); 739e5639b3fSMehdi Amini } 740e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 7415d6d30edSStella Laurenzo // unsigned i8 7425d6d30edSStella Laurenzo return bufferInfo<uint8_t>(shapedType); 7435d6d30edSStella Laurenzo } 7445d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 7455d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 16) { 7465d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 7475d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 7485d6d30edSStella Laurenzo // i16 7495d6d30edSStella Laurenzo return bufferInfo<int16_t>(shapedType); 750e5639b3fSMehdi Amini } 751e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 7525d6d30edSStella Laurenzo // unsigned i16 7535d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType); 754436c6c9cSStella Laurenzo } 755436c6c9cSStella Laurenzo } 756436c6c9cSStella Laurenzo 757c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 7585d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 759c5f445d1SStella Laurenzo throw std::invalid_argument( 760c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 761436c6c9cSStella Laurenzo } 762436c6c9cSStella Laurenzo 763436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 764436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 765436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 766436c6c9cSStella Laurenzo py::arg("array"), py::arg("signless") = true, 7675d6d30edSStella Laurenzo py::arg("type") = py::none(), py::arg("shape") = py::none(), 768436c6c9cSStella Laurenzo py::arg("context") = py::none(), 7695d6d30edSStella Laurenzo kDenseElementsAttrGetDocstring) 770436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 771436c6c9cSStella Laurenzo py::arg("shaped_type"), py::arg("element_attr"), 772436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 773436c6c9cSStella Laurenzo .def_property_readonly("is_splat", 774436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 775436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 776436c6c9cSStella Laurenzo }) 777436c6c9cSStella Laurenzo .def_buffer(&PyDenseElementsAttribute::accessBuffer); 778436c6c9cSStella Laurenzo } 779436c6c9cSStella Laurenzo 780436c6c9cSStella Laurenzo private: 781436c6c9cSStella Laurenzo static bool isUnsignedIntegerFormat(const std::string &format) { 782436c6c9cSStella Laurenzo if (format.empty()) 783436c6c9cSStella Laurenzo return false; 784436c6c9cSStella Laurenzo char code = format[0]; 785436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 786436c6c9cSStella Laurenzo code == 'Q'; 787436c6c9cSStella Laurenzo } 788436c6c9cSStella Laurenzo 789436c6c9cSStella Laurenzo static bool isSignedIntegerFormat(const std::string &format) { 790436c6c9cSStella Laurenzo if (format.empty()) 791436c6c9cSStella Laurenzo return false; 792436c6c9cSStella Laurenzo char code = format[0]; 793436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 794436c6c9cSStella Laurenzo code == 'q'; 795436c6c9cSStella Laurenzo } 796436c6c9cSStella Laurenzo 797436c6c9cSStella Laurenzo template <typename Type> 798436c6c9cSStella Laurenzo py::buffer_info bufferInfo(MlirType shapedType, 7995d6d30edSStella Laurenzo const char *explicitFormat = nullptr) { 800436c6c9cSStella Laurenzo intptr_t rank = mlirShapedTypeGetRank(shapedType); 801436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 802436c6c9cSStella Laurenzo // Buffer is configured for read-only access below. 803436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 804436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 805436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 806436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 807436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 808436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 809436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 810436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 811436c6c9cSStella Laurenzo intptr_t strideFactor = 1; 812436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 813436c6c9cSStella Laurenzo strideFactor = 1; 814436c6c9cSStella Laurenzo for (intptr_t j = i; j < rank; ++j) { 815436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 816436c6c9cSStella Laurenzo } 817436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 818436c6c9cSStella Laurenzo } 819436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 8205d6d30edSStella Laurenzo std::string format; 8215d6d30edSStella Laurenzo if (explicitFormat) { 8225d6d30edSStella Laurenzo format = explicitFormat; 8235d6d30edSStella Laurenzo } else { 8245d6d30edSStella Laurenzo format = py::format_descriptor<Type>::format(); 8255d6d30edSStella Laurenzo } 8265d6d30edSStella Laurenzo return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 8275d6d30edSStella Laurenzo /*readonly=*/true); 828436c6c9cSStella Laurenzo } 829436c6c9cSStella Laurenzo }; // namespace 830436c6c9cSStella Laurenzo 831436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer 832436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access. 833436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 834436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 835436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 836436c6c9cSStella Laurenzo public: 837436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 838436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 839436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 840436c6c9cSStella Laurenzo 841436c6c9cSStella Laurenzo /// Returns the element at the given linear position. Asserts if the index is 842436c6c9cSStella Laurenzo /// out of range. 843436c6c9cSStella Laurenzo py::int_ dunderGetItem(intptr_t pos) { 844436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 845436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 846436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 847436c6c9cSStella Laurenzo } 848436c6c9cSStella Laurenzo 849436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 850436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 851436c6c9cSStella Laurenzo assert(mlirTypeIsAInteger(type) && 852436c6c9cSStella Laurenzo "expected integer element type in dense int elements attribute"); 853436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 854436c6c9cSStella Laurenzo // elemental type of the attribute. py::int_ is implicitly constructible 855436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 856436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 857436c6c9cSStella Laurenzo // querying them on each element access. 858436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 859436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 860436c6c9cSStella Laurenzo if (isUnsigned) { 861436c6c9cSStella Laurenzo if (width == 1) { 862436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 863436c6c9cSStella Laurenzo } 864308d8b8cSRahul Kayaith if (width == 8) { 865308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt8Value(*this, pos); 866308d8b8cSRahul Kayaith } 867308d8b8cSRahul Kayaith if (width == 16) { 868308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt16Value(*this, pos); 869308d8b8cSRahul Kayaith } 870436c6c9cSStella Laurenzo if (width == 32) { 871436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt32Value(*this, pos); 872436c6c9cSStella Laurenzo } 873436c6c9cSStella Laurenzo if (width == 64) { 874436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt64Value(*this, pos); 875436c6c9cSStella Laurenzo } 876436c6c9cSStella Laurenzo } else { 877436c6c9cSStella Laurenzo if (width == 1) { 878436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 879436c6c9cSStella Laurenzo } 880308d8b8cSRahul Kayaith if (width == 8) { 881308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt8Value(*this, pos); 882308d8b8cSRahul Kayaith } 883308d8b8cSRahul Kayaith if (width == 16) { 884308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt16Value(*this, pos); 885308d8b8cSRahul Kayaith } 886436c6c9cSStella Laurenzo if (width == 32) { 887436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt32Value(*this, pos); 888436c6c9cSStella Laurenzo } 889436c6c9cSStella Laurenzo if (width == 64) { 890436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt64Value(*this, pos); 891436c6c9cSStella Laurenzo } 892436c6c9cSStella Laurenzo } 893436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 894436c6c9cSStella Laurenzo } 895436c6c9cSStella Laurenzo 896436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 897436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 898436c6c9cSStella Laurenzo } 899436c6c9cSStella Laurenzo }; 900436c6c9cSStella Laurenzo 901436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 902436c6c9cSStella Laurenzo public: 903436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 904436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 905436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 906436c6c9cSStella Laurenzo 907436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 908436c6c9cSStella Laurenzo 9099fb1086bSAdrian Kuegel bool dunderContains(const std::string &name) { 9109fb1086bSAdrian Kuegel return !mlirAttributeIsNull( 9119fb1086bSAdrian Kuegel mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 9129fb1086bSAdrian Kuegel } 9139fb1086bSAdrian Kuegel 914436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 9159fb1086bSAdrian Kuegel c.def("__contains__", &PyDictAttribute::dunderContains); 916436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 917436c6c9cSStella Laurenzo c.def_static( 918436c6c9cSStella Laurenzo "get", 919436c6c9cSStella Laurenzo [](py::dict attributes, DefaultingPyMlirContext context) { 920436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 921436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 922436c6c9cSStella Laurenzo for (auto &it : attributes) { 92302b6fb21SMehdi Amini auto &mlirAttr = it.second.cast<PyAttribute &>(); 924436c6c9cSStella Laurenzo auto name = it.first.cast<std::string>(); 925436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 92602b6fb21SMehdi Amini mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 927436c6c9cSStella Laurenzo toMlirStringRef(name)), 92802b6fb21SMehdi Amini mlirAttr)); 929436c6c9cSStella Laurenzo } 930436c6c9cSStella Laurenzo MlirAttribute attr = 931436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 932436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 933436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 934436c6c9cSStella Laurenzo }, 935ed9e52f3SAlex Zinenko py::arg("value") = py::dict(), py::arg("context") = py::none(), 936436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 937436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 938436c6c9cSStella Laurenzo MlirAttribute attr = 939436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 940436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 941436c6c9cSStella Laurenzo throw SetPyError(PyExc_KeyError, 942436c6c9cSStella Laurenzo "attempt to access a non-existent attribute"); 943436c6c9cSStella Laurenzo } 944436c6c9cSStella Laurenzo return PyAttribute(self.getContext(), attr); 945436c6c9cSStella Laurenzo }); 946436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 947436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 948436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 949436c6c9cSStella Laurenzo "attempt to access out of bounds attribute"); 950436c6c9cSStella Laurenzo } 951436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 952436c6c9cSStella Laurenzo return PyNamedAttribute( 953436c6c9cSStella Laurenzo namedAttr.attribute, 954436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 955436c6c9cSStella Laurenzo }); 956436c6c9cSStella Laurenzo } 957436c6c9cSStella Laurenzo }; 958436c6c9cSStella Laurenzo 959436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 960436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 961436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 962436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 963436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 964436c6c9cSStella Laurenzo public: 965436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 966436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 967436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 968436c6c9cSStella Laurenzo 969436c6c9cSStella Laurenzo py::float_ dunderGetItem(intptr_t pos) { 970436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 971436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 972436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 973436c6c9cSStella Laurenzo } 974436c6c9cSStella Laurenzo 975436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 976436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 977436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 978436c6c9cSStella Laurenzo // elemental type of the attribute. py::float_ is implicitly constructible 979436c6c9cSStella Laurenzo // from float and double. 980436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 981436c6c9cSStella Laurenzo // querying them on each element access. 982436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 983436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetFloatValue(*this, pos); 984436c6c9cSStella Laurenzo } 985436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 986436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetDoubleValue(*this, pos); 987436c6c9cSStella Laurenzo } 988436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 989436c6c9cSStella Laurenzo } 990436c6c9cSStella Laurenzo 991436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 992436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 993436c6c9cSStella Laurenzo } 994436c6c9cSStella Laurenzo }; 995436c6c9cSStella Laurenzo 996436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 997436c6c9cSStella Laurenzo public: 998436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 999436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 1000436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1001436c6c9cSStella Laurenzo 1002436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1003436c6c9cSStella Laurenzo c.def_static( 1004436c6c9cSStella Laurenzo "get", 1005436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 1006436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 1007436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 1008436c6c9cSStella Laurenzo }, 1009436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 1010436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 1011436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyTypeAttribute &self) { 1012436c6c9cSStella Laurenzo return PyType(self.getContext()->getRef(), 1013436c6c9cSStella Laurenzo mlirTypeAttrGetValue(self.get())); 1014436c6c9cSStella Laurenzo }); 1015436c6c9cSStella Laurenzo } 1016436c6c9cSStella Laurenzo }; 1017436c6c9cSStella Laurenzo 1018436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 1019436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 1020436c6c9cSStella Laurenzo public: 1021436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 1022436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 1023436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1024436c6c9cSStella Laurenzo 1025436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1026436c6c9cSStella Laurenzo c.def_static( 1027436c6c9cSStella Laurenzo "get", 1028436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 1029436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 1030436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 1031436c6c9cSStella Laurenzo }, 1032436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Create a Unit attribute."); 1033436c6c9cSStella Laurenzo } 1034436c6c9cSStella Laurenzo }; 1035436c6c9cSStella Laurenzo 1036ac2e2d65SDenys Shabalin /// Strided layout attribute subclass. 1037ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute 1038ac2e2d65SDenys Shabalin : public PyConcreteAttribute<PyStridedLayoutAttribute> { 1039ac2e2d65SDenys Shabalin public: 1040ac2e2d65SDenys Shabalin static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; 1041ac2e2d65SDenys Shabalin static constexpr const char *pyClassName = "StridedLayoutAttr"; 1042ac2e2d65SDenys Shabalin using PyConcreteAttribute::PyConcreteAttribute; 1043ac2e2d65SDenys Shabalin 1044ac2e2d65SDenys Shabalin static void bindDerived(ClassTy &c) { 1045ac2e2d65SDenys Shabalin c.def_static( 1046ac2e2d65SDenys Shabalin "get", 1047ac2e2d65SDenys Shabalin [](int64_t offset, const std::vector<int64_t> strides, 1048ac2e2d65SDenys Shabalin DefaultingPyMlirContext ctx) { 1049ac2e2d65SDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1050ac2e2d65SDenys Shabalin ctx->get(), offset, strides.size(), strides.data()); 1051ac2e2d65SDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1052ac2e2d65SDenys Shabalin }, 1053ac2e2d65SDenys Shabalin py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), 1054ac2e2d65SDenys Shabalin "Gets a strided layout attribute."); 1055e3fd612eSDenys Shabalin c.def_static( 1056e3fd612eSDenys Shabalin "get_fully_dynamic", 1057e3fd612eSDenys Shabalin [](int64_t rank, DefaultingPyMlirContext ctx) { 1058e3fd612eSDenys Shabalin auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); 1059e3fd612eSDenys Shabalin std::vector<int64_t> strides(rank); 1060e3fd612eSDenys Shabalin std::fill(strides.begin(), strides.end(), dynamic); 1061e3fd612eSDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1062e3fd612eSDenys Shabalin ctx->get(), dynamic, strides.size(), strides.data()); 1063e3fd612eSDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1064e3fd612eSDenys Shabalin }, 1065e3fd612eSDenys Shabalin py::arg("rank"), py::arg("context") = py::none(), 1066e3fd612eSDenys Shabalin "Gets a strided layout attribute with dynamic offset and strides of a " 1067e3fd612eSDenys Shabalin "given rank."); 1068ac2e2d65SDenys Shabalin c.def_property_readonly( 1069ac2e2d65SDenys Shabalin "offset", 1070ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1071ac2e2d65SDenys Shabalin return mlirStridedLayoutAttrGetOffset(self); 1072ac2e2d65SDenys Shabalin }, 1073ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1074ac2e2d65SDenys Shabalin c.def_property_readonly( 1075ac2e2d65SDenys Shabalin "strides", 1076ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1077ac2e2d65SDenys Shabalin intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); 1078ac2e2d65SDenys Shabalin std::vector<int64_t> strides(size); 1079ac2e2d65SDenys Shabalin for (intptr_t i = 0; i < size; i++) { 1080ac2e2d65SDenys Shabalin strides[i] = mlirStridedLayoutAttrGetStride(self, i); 1081ac2e2d65SDenys Shabalin } 1082ac2e2d65SDenys Shabalin return strides; 1083ac2e2d65SDenys Shabalin }, 1084ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1085ac2e2d65SDenys Shabalin } 1086ac2e2d65SDenys Shabalin }; 1087ac2e2d65SDenys Shabalin 1088436c6c9cSStella Laurenzo } // namespace 1089436c6c9cSStella Laurenzo 1090436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) { 1091436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 1092619fd8c2SJeff Niu 1093619fd8c2SJeff Niu PyDenseBoolArrayAttribute::bind(m); 1094619fd8c2SJeff Niu PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); 1095619fd8c2SJeff Niu PyDenseI8ArrayAttribute::bind(m); 1096619fd8c2SJeff Niu PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); 1097619fd8c2SJeff Niu PyDenseI16ArrayAttribute::bind(m); 1098619fd8c2SJeff Niu PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); 1099619fd8c2SJeff Niu PyDenseI32ArrayAttribute::bind(m); 1100619fd8c2SJeff Niu PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); 1101619fd8c2SJeff Niu PyDenseI64ArrayAttribute::bind(m); 1102619fd8c2SJeff Niu PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); 1103619fd8c2SJeff Niu PyDenseF32ArrayAttribute::bind(m); 1104619fd8c2SJeff Niu PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); 1105619fd8c2SJeff Niu PyDenseF64ArrayAttribute::bind(m); 1106619fd8c2SJeff Niu PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); 1107619fd8c2SJeff Niu 1108436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 1109436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 1110436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 1111436c6c9cSStella Laurenzo PyDenseElementsAttribute::bind(m); 1112436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 1113436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 1114436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 1115436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 11165c3861b2SYun Long PyOpaqueAttribute::bind(m); 1117436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 1118436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 1119436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 1120436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 1121436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 1122ac2e2d65SDenys Shabalin 1123ac2e2d65SDenys Shabalin PyStridedLayoutAttribute::bind(m); 1124436c6c9cSStella Laurenzo } 1125