1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 91fc096afSMehdi Amini #include <utility> 101fc096afSMehdi Amini 11436c6c9cSStella Laurenzo #include "IRModule.h" 12436c6c9cSStella Laurenzo 13436c6c9cSStella Laurenzo #include "PybindUtils.h" 14436c6c9cSStella Laurenzo 15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 17436c6c9cSStella Laurenzo 18436c6c9cSStella Laurenzo namespace py = pybind11; 19436c6c9cSStella Laurenzo using namespace mlir; 20436c6c9cSStella Laurenzo using namespace mlir::python; 21436c6c9cSStella Laurenzo 225d6d30edSStella Laurenzo using llvm::Optional; 23436c6c9cSStella Laurenzo using llvm::SmallVector; 24436c6c9cSStella Laurenzo using llvm::Twine; 25436c6c9cSStella Laurenzo 265d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 275d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 285d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 295d6d30edSStella Laurenzo 305d6d30edSStella Laurenzo static const char kDenseElementsAttrGetDocstring[] = 315d6d30edSStella Laurenzo R"(Gets a DenseElementsAttr from a Python buffer or array. 325d6d30edSStella Laurenzo 335d6d30edSStella Laurenzo When `type` is not provided, then some limited type inferencing is done based 345d6d30edSStella Laurenzo on the buffer format. Support presently exists for 8/16/32/64 signed and 355d6d30edSStella Laurenzo unsigned integers and float16/float32/float64. DenseElementsAttrs of these 365d6d30edSStella Laurenzo types can also be converted back to a corresponding buffer. 375d6d30edSStella Laurenzo 385d6d30edSStella Laurenzo For conversions outside of these types, a `type=` must be explicitly provided 395d6d30edSStella Laurenzo and the buffer contents must be bit-castable to the MLIR internal 405d6d30edSStella Laurenzo representation: 415d6d30edSStella Laurenzo 425d6d30edSStella Laurenzo * Integer types (except for i1): the buffer must be byte aligned to the 435d6d30edSStella Laurenzo next byte boundary. 445d6d30edSStella Laurenzo * Floating point types: Must be bit-castable to the given floating point 455d6d30edSStella Laurenzo size. 465d6d30edSStella Laurenzo * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 475d6d30edSStella Laurenzo row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 485d6d30edSStella Laurenzo this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 495d6d30edSStella Laurenzo 505d6d30edSStella Laurenzo If a single element buffer is passed (or for i1, a single byte with value 0 515d6d30edSStella Laurenzo or 255), then a splat will be created. 525d6d30edSStella Laurenzo 535d6d30edSStella Laurenzo Args: 545d6d30edSStella Laurenzo array: The array or buffer to convert. 555d6d30edSStella Laurenzo signless: If inferring an appropriate MLIR type, use signless types for 565d6d30edSStella Laurenzo integers (defaults True). 575d6d30edSStella Laurenzo type: Skips inference of the MLIR element type and uses this instead. The 585d6d30edSStella Laurenzo storage size must be consistent with the actual contents of the buffer. 595d6d30edSStella Laurenzo shape: Overrides the shape of the buffer when constructing the MLIR 605d6d30edSStella Laurenzo shaped type. This is needed when the physical and logical shape differ (as 615d6d30edSStella Laurenzo for i1). 625d6d30edSStella Laurenzo context: Explicit context, if not from context manager. 635d6d30edSStella Laurenzo 645d6d30edSStella Laurenzo Returns: 655d6d30edSStella Laurenzo DenseElementsAttr on success. 665d6d30edSStella Laurenzo 675d6d30edSStella Laurenzo Raises: 685d6d30edSStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 695d6d30edSStella Laurenzo type or if the buffer does not meet expectations. 705d6d30edSStella Laurenzo )"; 715d6d30edSStella Laurenzo 72436c6c9cSStella Laurenzo namespace { 73436c6c9cSStella Laurenzo 74436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 75436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 76436c6c9cSStella Laurenzo } 77436c6c9cSStella Laurenzo 78436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 79436c6c9cSStella Laurenzo public: 80436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 81436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 82436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 83436c6c9cSStella Laurenzo 84436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 85436c6c9cSStella Laurenzo c.def_static( 86436c6c9cSStella Laurenzo "get", 87436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 88436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 89436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 90436c6c9cSStella Laurenzo }, 91436c6c9cSStella Laurenzo py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 92436c6c9cSStella Laurenzo } 93436c6c9cSStella Laurenzo }; 94436c6c9cSStella Laurenzo 95ed9e52f3SAlex Zinenko template <typename T> 96ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) { 97ed9e52f3SAlex Zinenko try { 98ed9e52f3SAlex Zinenko return object.cast<T>(); 99ed9e52f3SAlex Zinenko } catch (py::cast_error &err) { 100ed9e52f3SAlex Zinenko std::string msg = 101ed9e52f3SAlex Zinenko std::string( 102ed9e52f3SAlex Zinenko "Invalid attribute when attempting to create an ArrayAttribute (") + 103ed9e52f3SAlex Zinenko err.what() + ")"; 104ed9e52f3SAlex Zinenko throw py::cast_error(msg); 105ed9e52f3SAlex Zinenko } catch (py::reference_cast_error &err) { 106ed9e52f3SAlex Zinenko std::string msg = std::string("Invalid attribute (None?) when attempting " 107ed9e52f3SAlex Zinenko "to create an ArrayAttribute (") + 108ed9e52f3SAlex Zinenko err.what() + ")"; 109ed9e52f3SAlex Zinenko throw py::cast_error(msg); 110ed9e52f3SAlex Zinenko } 111ed9e52f3SAlex Zinenko } 112ed9e52f3SAlex Zinenko 113*619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived 114*619fd8c2SJeff Niu /// implementation class. 115*619fd8c2SJeff Niu template <typename EltTy, typename DerivedT> 116*619fd8c2SJeff Niu class PyDenseArrayAttribute 117*619fd8c2SJeff Niu : public PyConcreteAttribute<PyDenseArrayAttribute<EltTy, DerivedT>> { 118*619fd8c2SJeff Niu public: 119*619fd8c2SJeff Niu static constexpr typename PyConcreteAttribute< 120*619fd8c2SJeff Niu PyDenseArrayAttribute<EltTy, DerivedT>>::IsAFunctionTy isaFunction = 121*619fd8c2SJeff Niu DerivedT::isaFunction; 122*619fd8c2SJeff Niu static constexpr const char *pyClassName = DerivedT::pyClassName; 123*619fd8c2SJeff Niu using PyConcreteAttribute< 124*619fd8c2SJeff Niu PyDenseArrayAttribute<EltTy, DerivedT>>::PyConcreteAttribute; 125*619fd8c2SJeff Niu 126*619fd8c2SJeff Niu /// Iterator over the integer elements of a dense array. 127*619fd8c2SJeff Niu class PyDenseArrayIterator { 128*619fd8c2SJeff Niu public: 129*619fd8c2SJeff Niu PyDenseArrayIterator(PyAttribute attr) : attr(attr) {} 130*619fd8c2SJeff Niu 131*619fd8c2SJeff Niu /// Return a copy of the iterator. 132*619fd8c2SJeff Niu PyDenseArrayIterator dunderIter() { return *this; } 133*619fd8c2SJeff Niu 134*619fd8c2SJeff Niu /// Return the next element. 135*619fd8c2SJeff Niu EltTy dunderNext() { 136*619fd8c2SJeff Niu // Throw if the index has reached the end. 137*619fd8c2SJeff Niu if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) 138*619fd8c2SJeff Niu throw py::stop_iteration(); 139*619fd8c2SJeff Niu return DerivedT::getElement(attr.get(), nextIndex++); 140*619fd8c2SJeff Niu } 141*619fd8c2SJeff Niu 142*619fd8c2SJeff Niu /// Bind the iterator class. 143*619fd8c2SJeff Niu static void bind(py::module &m) { 144*619fd8c2SJeff Niu py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName, 145*619fd8c2SJeff Niu py::module_local()) 146*619fd8c2SJeff Niu .def("__iter__", &PyDenseArrayIterator::dunderIter) 147*619fd8c2SJeff Niu .def("__next__", &PyDenseArrayIterator::dunderNext); 148*619fd8c2SJeff Niu } 149*619fd8c2SJeff Niu 150*619fd8c2SJeff Niu private: 151*619fd8c2SJeff Niu /// The referenced dense array attribute. 152*619fd8c2SJeff Niu PyAttribute attr; 153*619fd8c2SJeff Niu /// The next index to read. 154*619fd8c2SJeff Niu int nextIndex = 0; 155*619fd8c2SJeff Niu }; 156*619fd8c2SJeff Niu 157*619fd8c2SJeff Niu /// Get the element at the given index. 158*619fd8c2SJeff Niu EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } 159*619fd8c2SJeff Niu 160*619fd8c2SJeff Niu /// Bind the attribute class. 161*619fd8c2SJeff Niu static void bindDerived(typename PyConcreteAttribute< 162*619fd8c2SJeff Niu PyDenseArrayAttribute<EltTy, DerivedT>>::ClassTy &c) { 163*619fd8c2SJeff Niu // Bind the constructor. 164*619fd8c2SJeff Niu c.def_static( 165*619fd8c2SJeff Niu "get", 166*619fd8c2SJeff Niu [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { 167*619fd8c2SJeff Niu MlirAttribute attr = 168*619fd8c2SJeff Niu DerivedT::getAttribute(ctx->get(), values.size(), values.data()); 169*619fd8c2SJeff Niu return PyDenseArrayAttribute<EltTy, DerivedT>(ctx->getRef(), attr); 170*619fd8c2SJeff Niu }, 171*619fd8c2SJeff Niu py::arg("values"), py::arg("context") = py::none(), 172*619fd8c2SJeff Niu "Gets a uniqued dense array attribute"); 173*619fd8c2SJeff Niu // Bind the array methods. 174*619fd8c2SJeff Niu c.def("__getitem__", 175*619fd8c2SJeff Niu [](PyDenseArrayAttribute<EltTy, DerivedT> &arr, intptr_t i) { 176*619fd8c2SJeff Niu if (i >= mlirDenseArrayGetNumElements(arr)) 177*619fd8c2SJeff Niu throw py::index_error("DenseArray index out of range"); 178*619fd8c2SJeff Niu return arr.getItem(i); 179*619fd8c2SJeff Niu }); 180*619fd8c2SJeff Niu c.def("__len__", [](const PyDenseArrayAttribute<EltTy, DerivedT> &arr) { 181*619fd8c2SJeff Niu return mlirDenseArrayGetNumElements(arr); 182*619fd8c2SJeff Niu }); 183*619fd8c2SJeff Niu c.def("__iter__", [](const PyDenseArrayAttribute<EltTy, DerivedT> &arr) { 184*619fd8c2SJeff Niu return PyDenseArrayIterator(arr); 185*619fd8c2SJeff Niu }); 186*619fd8c2SJeff Niu // Bind a concat. 187*619fd8c2SJeff Niu c.def("__add__", [](PyDenseArrayAttribute<EltTy, DerivedT> &arr, 188*619fd8c2SJeff Niu py::list extras) { 189*619fd8c2SJeff Niu std::vector<EltTy> values; 190*619fd8c2SJeff Niu intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); 191*619fd8c2SJeff Niu values.reserve(numOldElements + py::len(extras)); 192*619fd8c2SJeff Niu for (intptr_t i = 0; i < numOldElements; ++i) 193*619fd8c2SJeff Niu values.push_back(arr.getItem(i)); 194*619fd8c2SJeff Niu for (py::handle attr : extras) 195*619fd8c2SJeff Niu values.push_back(pyTryCast<EltTy>(attr)); 196*619fd8c2SJeff Niu MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(), 197*619fd8c2SJeff Niu values.size(), values.data()); 198*619fd8c2SJeff Niu return PyDenseArrayAttribute<EltTy, DerivedT>(arr.getContext(), attr); 199*619fd8c2SJeff Niu }); 200*619fd8c2SJeff Niu } 201*619fd8c2SJeff Niu }; 202*619fd8c2SJeff Niu 203*619fd8c2SJeff Niu /// Instantiate the python dense array classes. 204*619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute 205*619fd8c2SJeff Niu : public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> { 206*619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; 207*619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseBoolArrayGet; 208*619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseBoolArrayGetElement; 209*619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseBoolArrayAttr"; 210*619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; 211*619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 212*619fd8c2SJeff Niu }; 213*619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute 214*619fd8c2SJeff Niu : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { 215*619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; 216*619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI8ArrayGet; 217*619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI8ArrayGetElement; 218*619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI8ArrayAttr"; 219*619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; 220*619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 221*619fd8c2SJeff Niu }; 222*619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute 223*619fd8c2SJeff Niu : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { 224*619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; 225*619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI16ArrayGet; 226*619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI16ArrayGetElement; 227*619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI16ArrayAttr"; 228*619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; 229*619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 230*619fd8c2SJeff Niu }; 231*619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute 232*619fd8c2SJeff Niu : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { 233*619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; 234*619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI32ArrayGet; 235*619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI32ArrayGetElement; 236*619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI32ArrayAttr"; 237*619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; 238*619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 239*619fd8c2SJeff Niu }; 240*619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute 241*619fd8c2SJeff Niu : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { 242*619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; 243*619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI64ArrayGet; 244*619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI64ArrayGetElement; 245*619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI64ArrayAttr"; 246*619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; 247*619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 248*619fd8c2SJeff Niu }; 249*619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute 250*619fd8c2SJeff Niu : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { 251*619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; 252*619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF32ArrayGet; 253*619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF32ArrayGetElement; 254*619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF32ArrayAttr"; 255*619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; 256*619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 257*619fd8c2SJeff Niu }; 258*619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute 259*619fd8c2SJeff Niu : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { 260*619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; 261*619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF64ArrayGet; 262*619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF64ArrayGetElement; 263*619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF64ArrayAttr"; 264*619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; 265*619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 266*619fd8c2SJeff Niu }; 267*619fd8c2SJeff Niu 268436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 269436c6c9cSStella Laurenzo public: 270436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 271436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 272436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 273436c6c9cSStella Laurenzo 274436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 275436c6c9cSStella Laurenzo public: 2761fc096afSMehdi Amini PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 277436c6c9cSStella Laurenzo 278436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 279436c6c9cSStella Laurenzo 280436c6c9cSStella Laurenzo PyAttribute dunderNext() { 281436c6c9cSStella Laurenzo if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 282436c6c9cSStella Laurenzo throw py::stop_iteration(); 283436c6c9cSStella Laurenzo } 284436c6c9cSStella Laurenzo return PyAttribute(attr.getContext(), 285436c6c9cSStella Laurenzo mlirArrayAttrGetElement(attr.get(), nextIndex++)); 286436c6c9cSStella Laurenzo } 287436c6c9cSStella Laurenzo 288436c6c9cSStella Laurenzo static void bind(py::module &m) { 289f05ff4f7SStella Laurenzo py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 290f05ff4f7SStella Laurenzo py::module_local()) 291436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 292436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 293436c6c9cSStella Laurenzo } 294436c6c9cSStella Laurenzo 295436c6c9cSStella Laurenzo private: 296436c6c9cSStella Laurenzo PyAttribute attr; 297436c6c9cSStella Laurenzo int nextIndex = 0; 298436c6c9cSStella Laurenzo }; 299436c6c9cSStella Laurenzo 300ed9e52f3SAlex Zinenko PyAttribute getItem(intptr_t i) { 301ed9e52f3SAlex Zinenko return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); 302ed9e52f3SAlex Zinenko } 303ed9e52f3SAlex Zinenko 304436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 305436c6c9cSStella Laurenzo c.def_static( 306436c6c9cSStella Laurenzo "get", 307436c6c9cSStella Laurenzo [](py::list attributes, DefaultingPyMlirContext context) { 308436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 309436c6c9cSStella Laurenzo mlirAttributes.reserve(py::len(attributes)); 310436c6c9cSStella Laurenzo for (auto attribute : attributes) { 311ed9e52f3SAlex Zinenko mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 312436c6c9cSStella Laurenzo } 313436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 314436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 315436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 316436c6c9cSStella Laurenzo }, 317436c6c9cSStella Laurenzo py::arg("attributes"), py::arg("context") = py::none(), 318436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 319436c6c9cSStella Laurenzo c.def("__getitem__", 320436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 321436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 322436c6c9cSStella Laurenzo throw py::index_error("ArrayAttribute index out of range"); 323ed9e52f3SAlex Zinenko return arr.getItem(i); 324436c6c9cSStella Laurenzo }) 325436c6c9cSStella Laurenzo .def("__len__", 326436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 327436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 328436c6c9cSStella Laurenzo }) 329436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 330436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 331436c6c9cSStella Laurenzo }); 332ed9e52f3SAlex Zinenko c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 333ed9e52f3SAlex Zinenko std::vector<MlirAttribute> attributes; 334ed9e52f3SAlex Zinenko intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 335ed9e52f3SAlex Zinenko attributes.reserve(numOldElements + py::len(extras)); 336ed9e52f3SAlex Zinenko for (intptr_t i = 0; i < numOldElements; ++i) 337ed9e52f3SAlex Zinenko attributes.push_back(arr.getItem(i)); 338ed9e52f3SAlex Zinenko for (py::handle attr : extras) 339ed9e52f3SAlex Zinenko attributes.push_back(pyTryCast<PyAttribute>(attr)); 340ed9e52f3SAlex Zinenko MlirAttribute arrayAttr = mlirArrayAttrGet( 341ed9e52f3SAlex Zinenko arr.getContext()->get(), attributes.size(), attributes.data()); 342ed9e52f3SAlex Zinenko return PyArrayAttribute(arr.getContext(), arrayAttr); 343ed9e52f3SAlex Zinenko }); 344436c6c9cSStella Laurenzo } 345436c6c9cSStella Laurenzo }; 346436c6c9cSStella Laurenzo 347436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 348436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 349436c6c9cSStella Laurenzo public: 350436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 351436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 352436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 353436c6c9cSStella Laurenzo 354436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 355436c6c9cSStella Laurenzo c.def_static( 356436c6c9cSStella Laurenzo "get", 357436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 358436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 359436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 360436c6c9cSStella Laurenzo // in C API. 361436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 362436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 363436c6c9cSStella Laurenzo Twine("invalid '") + 364436c6c9cSStella Laurenzo py::repr(py::cast(type)).cast<std::string>() + 365436c6c9cSStella Laurenzo "' and expected floating point type."); 366436c6c9cSStella Laurenzo } 367436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 368436c6c9cSStella Laurenzo }, 369436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 370436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 371436c6c9cSStella Laurenzo c.def_static( 372436c6c9cSStella Laurenzo "get_f32", 373436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 374436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 375436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 376436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 377436c6c9cSStella Laurenzo }, 378436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 379436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 380436c6c9cSStella Laurenzo c.def_static( 381436c6c9cSStella Laurenzo "get_f64", 382436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 383436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 384436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 385436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 386436c6c9cSStella Laurenzo }, 387436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 388436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 389436c6c9cSStella Laurenzo c.def_property_readonly( 390436c6c9cSStella Laurenzo "value", 391436c6c9cSStella Laurenzo [](PyFloatAttribute &self) { 392436c6c9cSStella Laurenzo return mlirFloatAttrGetValueDouble(self); 393436c6c9cSStella Laurenzo }, 394436c6c9cSStella Laurenzo "Returns the value of the float point attribute"); 395436c6c9cSStella Laurenzo } 396436c6c9cSStella Laurenzo }; 397436c6c9cSStella Laurenzo 398436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 399436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 400436c6c9cSStella Laurenzo public: 401436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 402436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 403436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 404436c6c9cSStella Laurenzo 405436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 406436c6c9cSStella Laurenzo c.def_static( 407436c6c9cSStella Laurenzo "get", 408436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 409436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 410436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 411436c6c9cSStella Laurenzo }, 412436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), 413436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 414436c6c9cSStella Laurenzo c.def_property_readonly( 415436c6c9cSStella Laurenzo "value", 416e9db306dSrkayaith [](PyIntegerAttribute &self) -> py::int_ { 417e9db306dSrkayaith MlirType type = mlirAttributeGetType(self); 418e9db306dSrkayaith if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 419436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 420e9db306dSrkayaith if (mlirIntegerTypeIsSigned(type)) 421e9db306dSrkayaith return mlirIntegerAttrGetValueSInt(self); 422e9db306dSrkayaith return mlirIntegerAttrGetValueUInt(self); 423436c6c9cSStella Laurenzo }, 424436c6c9cSStella Laurenzo "Returns the value of the integer attribute"); 425436c6c9cSStella Laurenzo } 426436c6c9cSStella Laurenzo }; 427436c6c9cSStella Laurenzo 428436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 429436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 430436c6c9cSStella Laurenzo public: 431436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 432436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 433436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 434436c6c9cSStella Laurenzo 435436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 436436c6c9cSStella Laurenzo c.def_static( 437436c6c9cSStella Laurenzo "get", 438436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 439436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 440436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 441436c6c9cSStella Laurenzo }, 442436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 443436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 444436c6c9cSStella Laurenzo c.def_property_readonly( 445436c6c9cSStella Laurenzo "value", 446436c6c9cSStella Laurenzo [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 447436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 448436c6c9cSStella Laurenzo } 449436c6c9cSStella Laurenzo }; 450436c6c9cSStella Laurenzo 451436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 452436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 453436c6c9cSStella Laurenzo public: 454436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 455436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 456436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 457436c6c9cSStella Laurenzo 458436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 459436c6c9cSStella Laurenzo c.def_static( 460436c6c9cSStella Laurenzo "get", 461436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 462436c6c9cSStella Laurenzo MlirAttribute attr = 463436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 464436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 465436c6c9cSStella Laurenzo }, 466436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 467436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 468436c6c9cSStella Laurenzo c.def_property_readonly( 469436c6c9cSStella Laurenzo "value", 470436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 471436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 472436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 473436c6c9cSStella Laurenzo }, 474436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 475436c6c9cSStella Laurenzo } 476436c6c9cSStella Laurenzo }; 477436c6c9cSStella Laurenzo 4785c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { 4795c3861b2SYun Long public: 4805c3861b2SYun Long static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; 4815c3861b2SYun Long static constexpr const char *pyClassName = "OpaqueAttr"; 4825c3861b2SYun Long using PyConcreteAttribute::PyConcreteAttribute; 4835c3861b2SYun Long 4845c3861b2SYun Long static void bindDerived(ClassTy &c) { 4855c3861b2SYun Long c.def_static( 4865c3861b2SYun Long "get", 4875c3861b2SYun Long [](std::string dialectNamespace, py::buffer buffer, PyType &type, 4885c3861b2SYun Long DefaultingPyMlirContext context) { 4895c3861b2SYun Long const py::buffer_info bufferInfo = buffer.request(); 4905c3861b2SYun Long intptr_t bufferSize = bufferInfo.size; 4915c3861b2SYun Long MlirAttribute attr = mlirOpaqueAttrGet( 4925c3861b2SYun Long context->get(), toMlirStringRef(dialectNamespace), bufferSize, 4935c3861b2SYun Long static_cast<char *>(bufferInfo.ptr), type); 4945c3861b2SYun Long return PyOpaqueAttribute(context->getRef(), attr); 4955c3861b2SYun Long }, 4965c3861b2SYun Long py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), 4975c3861b2SYun Long py::arg("context") = py::none(), "Gets an Opaque attribute."); 4985c3861b2SYun Long c.def_property_readonly( 4995c3861b2SYun Long "dialect_namespace", 5005c3861b2SYun Long [](PyOpaqueAttribute &self) { 5015c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); 5025c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 5035c3861b2SYun Long }, 5045c3861b2SYun Long "Returns the dialect namespace for the Opaque attribute as a string"); 5055c3861b2SYun Long c.def_property_readonly( 5065c3861b2SYun Long "data", 5075c3861b2SYun Long [](PyOpaqueAttribute &self) { 5085c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetData(self); 5095c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 5105c3861b2SYun Long }, 5115c3861b2SYun Long "Returns the data for the Opaqued attributes as a string"); 5125c3861b2SYun Long } 5135c3861b2SYun Long }; 5145c3861b2SYun Long 515436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 516436c6c9cSStella Laurenzo public: 517436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 518436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 519436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 520436c6c9cSStella Laurenzo 521436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 522436c6c9cSStella Laurenzo c.def_static( 523436c6c9cSStella Laurenzo "get", 524436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 525436c6c9cSStella Laurenzo MlirAttribute attr = 526436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 527436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 528436c6c9cSStella Laurenzo }, 529436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 530436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 531436c6c9cSStella Laurenzo c.def_static( 532436c6c9cSStella Laurenzo "get_typed", 533436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 534436c6c9cSStella Laurenzo MlirAttribute attr = 535436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 536436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 537436c6c9cSStella Laurenzo }, 538a6e7d024SStella Laurenzo py::arg("type"), py::arg("value"), 539436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 540436c6c9cSStella Laurenzo c.def_property_readonly( 541436c6c9cSStella Laurenzo "value", 542436c6c9cSStella Laurenzo [](PyStringAttribute &self) { 543436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirStringAttrGetValue(self); 544436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 545436c6c9cSStella Laurenzo }, 546436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 547436c6c9cSStella Laurenzo } 548436c6c9cSStella Laurenzo }; 549436c6c9cSStella Laurenzo 550436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 551436c6c9cSStella Laurenzo class PyDenseElementsAttribute 552436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 553436c6c9cSStella Laurenzo public: 554436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 555436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 556436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 557436c6c9cSStella Laurenzo 558436c6c9cSStella Laurenzo static PyDenseElementsAttribute 5595d6d30edSStella Laurenzo getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType, 5605d6d30edSStella Laurenzo Optional<std::vector<int64_t>> explicitShape, 561436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 562436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 563436c6c9cSStella Laurenzo int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 564436c6c9cSStella Laurenzo Py_buffer *view = new Py_buffer(); 565436c6c9cSStella Laurenzo if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 566436c6c9cSStella Laurenzo delete view; 567436c6c9cSStella Laurenzo throw py::error_already_set(); 568436c6c9cSStella Laurenzo } 569436c6c9cSStella Laurenzo py::buffer_info arrayInfo(view); 5705d6d30edSStella Laurenzo SmallVector<int64_t> shape; 5715d6d30edSStella Laurenzo if (explicitShape) { 5725d6d30edSStella Laurenzo shape.append(explicitShape->begin(), explicitShape->end()); 5735d6d30edSStella Laurenzo } else { 5745d6d30edSStella Laurenzo shape.append(arrayInfo.shape.begin(), 5755d6d30edSStella Laurenzo arrayInfo.shape.begin() + arrayInfo.ndim); 5765d6d30edSStella Laurenzo } 577436c6c9cSStella Laurenzo 5785d6d30edSStella Laurenzo MlirAttribute encodingAttr = mlirAttributeGetNull(); 579436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 5805d6d30edSStella Laurenzo 5815d6d30edSStella Laurenzo // Detect format codes that are suitable for bulk loading. This includes 5825d6d30edSStella Laurenzo // all byte aligned integer and floating point types up to 8 bytes. 5835d6d30edSStella Laurenzo // Notably, this excludes, bool (which needs to be bit-packed) and 5845d6d30edSStella Laurenzo // other exotics which do not have a direct representation in the buffer 5855d6d30edSStella Laurenzo // protocol (i.e. complex, etc). 5865d6d30edSStella Laurenzo Optional<MlirType> bulkLoadElementType; 5875d6d30edSStella Laurenzo if (explicitType) { 5885d6d30edSStella Laurenzo bulkLoadElementType = *explicitType; 5895d6d30edSStella Laurenzo } else if (arrayInfo.format == "f") { 590436c6c9cSStella Laurenzo // f32 591436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 5925d6d30edSStella Laurenzo bulkLoadElementType = mlirF32TypeGet(context); 593436c6c9cSStella Laurenzo } else if (arrayInfo.format == "d") { 594436c6c9cSStella Laurenzo // f64 595436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 5965d6d30edSStella Laurenzo bulkLoadElementType = mlirF64TypeGet(context); 5975d6d30edSStella Laurenzo } else if (arrayInfo.format == "e") { 5985d6d30edSStella Laurenzo // f16 5995d6d30edSStella Laurenzo assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); 6005d6d30edSStella Laurenzo bulkLoadElementType = mlirF16TypeGet(context); 601436c6c9cSStella Laurenzo } else if (isSignedIntegerFormat(arrayInfo.format)) { 602436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 603436c6c9cSStella Laurenzo // i32 6045d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) 605436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 32); 606436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 607436c6c9cSStella Laurenzo // i64 6085d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) 609436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 64); 6105d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 6115d6d30edSStella Laurenzo // i8 6125d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 6135d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 8); 6145d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 6155d6d30edSStella Laurenzo // i16 6165d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) 6175d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 16); 618436c6c9cSStella Laurenzo } 619436c6c9cSStella Laurenzo } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 620436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 621436c6c9cSStella Laurenzo // unsigned i32 6225d6d30edSStella Laurenzo bulkLoadElementType = signless 623436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 32) 624436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 32); 625436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 626436c6c9cSStella Laurenzo // unsigned i64 6275d6d30edSStella Laurenzo bulkLoadElementType = signless 628436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 64) 629436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 64); 6305d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 6315d6d30edSStella Laurenzo // i8 6325d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 6335d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 8); 6345d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 6355d6d30edSStella Laurenzo // i16 6365d6d30edSStella Laurenzo bulkLoadElementType = signless 6375d6d30edSStella Laurenzo ? mlirIntegerTypeGet(context, 16) 6385d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 16); 639436c6c9cSStella Laurenzo } 640436c6c9cSStella Laurenzo } 6415d6d30edSStella Laurenzo if (bulkLoadElementType) { 6425d6d30edSStella Laurenzo auto shapedType = mlirRankedTensorTypeGet( 6435d6d30edSStella Laurenzo shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); 6445d6d30edSStella Laurenzo size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; 6455d6d30edSStella Laurenzo MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( 6465d6d30edSStella Laurenzo shapedType, rawBufferSize, arrayInfo.ptr); 6475d6d30edSStella Laurenzo if (mlirAttributeIsNull(attr)) { 6485d6d30edSStella Laurenzo throw std::invalid_argument( 6495d6d30edSStella Laurenzo "DenseElementsAttr could not be constructed from the given buffer. " 6505d6d30edSStella Laurenzo "This may mean that the Python buffer layout does not match that " 6515d6d30edSStella Laurenzo "MLIR expected layout and is a bug."); 6525d6d30edSStella Laurenzo } 6535d6d30edSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 6545d6d30edSStella Laurenzo } 655436c6c9cSStella Laurenzo 6565d6d30edSStella Laurenzo throw std::invalid_argument( 6575d6d30edSStella Laurenzo std::string("unimplemented array format conversion from format: ") + 6585d6d30edSStella Laurenzo arrayInfo.format); 659436c6c9cSStella Laurenzo } 660436c6c9cSStella Laurenzo 6611fc096afSMehdi Amini static PyDenseElementsAttribute getSplat(const PyType &shapedType, 662436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 663436c6c9cSStella Laurenzo auto contextWrapper = 664436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 665436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 666436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 667436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 668436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 669436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 670436c6c9cSStella Laurenzo } 671436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 672436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 673436c6c9cSStella Laurenzo std::string message = 674436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 675436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 676436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 677436c6c9cSStella Laurenzo } 678436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 679436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 680436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 681436c6c9cSStella Laurenzo std::string message = 682436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 683436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 684436c6c9cSStella Laurenzo message.append(", element="); 685436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 686436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 687436c6c9cSStella Laurenzo } 688436c6c9cSStella Laurenzo 689436c6c9cSStella Laurenzo MlirAttribute elements = 690436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 691436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 692436c6c9cSStella Laurenzo } 693436c6c9cSStella Laurenzo 694436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 695436c6c9cSStella Laurenzo 696436c6c9cSStella Laurenzo py::buffer_info accessBuffer() { 6975d6d30edSStella Laurenzo if (mlirDenseElementsAttrIsSplat(*this)) { 698c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 6995d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 700c5f445d1SStella Laurenzo throw std::invalid_argument( 701c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 7025d6d30edSStella Laurenzo } 7035d6d30edSStella Laurenzo 704436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 705436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 7065d6d30edSStella Laurenzo std::string format; 707436c6c9cSStella Laurenzo 708436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 709436c6c9cSStella Laurenzo // f32 7105d6d30edSStella Laurenzo return bufferInfo<float>(shapedType); 71102b6fb21SMehdi Amini } 71202b6fb21SMehdi Amini if (mlirTypeIsAF64(elementType)) { 713436c6c9cSStella Laurenzo // f64 7145d6d30edSStella Laurenzo return bufferInfo<double>(shapedType); 715bb56c2b3SMehdi Amini } 716bb56c2b3SMehdi Amini if (mlirTypeIsAF16(elementType)) { 7175d6d30edSStella Laurenzo // f16 7185d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType, "e"); 719bb56c2b3SMehdi Amini } 720bb56c2b3SMehdi Amini if (mlirTypeIsAInteger(elementType) && 721436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 722436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 723436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 724436c6c9cSStella Laurenzo // i32 7255d6d30edSStella Laurenzo return bufferInfo<int32_t>(shapedType); 726e5639b3fSMehdi Amini } 727e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 728436c6c9cSStella Laurenzo // unsigned i32 7295d6d30edSStella Laurenzo return bufferInfo<uint32_t>(shapedType); 730436c6c9cSStella Laurenzo } 731436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 732436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 733436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 734436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 735436c6c9cSStella Laurenzo // i64 7365d6d30edSStella Laurenzo return bufferInfo<int64_t>(shapedType); 737e5639b3fSMehdi Amini } 738e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 739436c6c9cSStella Laurenzo // unsigned i64 7405d6d30edSStella Laurenzo return bufferInfo<uint64_t>(shapedType); 7415d6d30edSStella Laurenzo } 7425d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 7435d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 8) { 7445d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 7455d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 7465d6d30edSStella Laurenzo // i8 7475d6d30edSStella Laurenzo return bufferInfo<int8_t>(shapedType); 748e5639b3fSMehdi Amini } 749e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 7505d6d30edSStella Laurenzo // unsigned i8 7515d6d30edSStella Laurenzo return bufferInfo<uint8_t>(shapedType); 7525d6d30edSStella Laurenzo } 7535d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 7545d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 16) { 7555d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 7565d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 7575d6d30edSStella Laurenzo // i16 7585d6d30edSStella Laurenzo return bufferInfo<int16_t>(shapedType); 759e5639b3fSMehdi Amini } 760e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 7615d6d30edSStella Laurenzo // unsigned i16 7625d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType); 763436c6c9cSStella Laurenzo } 764436c6c9cSStella Laurenzo } 765436c6c9cSStella Laurenzo 766c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 7675d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 768c5f445d1SStella Laurenzo throw std::invalid_argument( 769c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 770436c6c9cSStella Laurenzo } 771436c6c9cSStella Laurenzo 772436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 773436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 774436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 775436c6c9cSStella Laurenzo py::arg("array"), py::arg("signless") = true, 7765d6d30edSStella Laurenzo py::arg("type") = py::none(), py::arg("shape") = py::none(), 777436c6c9cSStella Laurenzo py::arg("context") = py::none(), 7785d6d30edSStella Laurenzo kDenseElementsAttrGetDocstring) 779436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 780436c6c9cSStella Laurenzo py::arg("shaped_type"), py::arg("element_attr"), 781436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 782436c6c9cSStella Laurenzo .def_property_readonly("is_splat", 783436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 784436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 785436c6c9cSStella Laurenzo }) 786436c6c9cSStella Laurenzo .def_buffer(&PyDenseElementsAttribute::accessBuffer); 787436c6c9cSStella Laurenzo } 788436c6c9cSStella Laurenzo 789436c6c9cSStella Laurenzo private: 790436c6c9cSStella Laurenzo static bool isUnsignedIntegerFormat(const std::string &format) { 791436c6c9cSStella Laurenzo if (format.empty()) 792436c6c9cSStella Laurenzo return false; 793436c6c9cSStella Laurenzo char code = format[0]; 794436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 795436c6c9cSStella Laurenzo code == 'Q'; 796436c6c9cSStella Laurenzo } 797436c6c9cSStella Laurenzo 798436c6c9cSStella Laurenzo static bool isSignedIntegerFormat(const std::string &format) { 799436c6c9cSStella Laurenzo if (format.empty()) 800436c6c9cSStella Laurenzo return false; 801436c6c9cSStella Laurenzo char code = format[0]; 802436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 803436c6c9cSStella Laurenzo code == 'q'; 804436c6c9cSStella Laurenzo } 805436c6c9cSStella Laurenzo 806436c6c9cSStella Laurenzo template <typename Type> 807436c6c9cSStella Laurenzo py::buffer_info bufferInfo(MlirType shapedType, 8085d6d30edSStella Laurenzo const char *explicitFormat = nullptr) { 809436c6c9cSStella Laurenzo intptr_t rank = mlirShapedTypeGetRank(shapedType); 810436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 811436c6c9cSStella Laurenzo // Buffer is configured for read-only access below. 812436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 813436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 814436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 815436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 816436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 817436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 818436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 819436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 820436c6c9cSStella Laurenzo intptr_t strideFactor = 1; 821436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 822436c6c9cSStella Laurenzo strideFactor = 1; 823436c6c9cSStella Laurenzo for (intptr_t j = i; j < rank; ++j) { 824436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 825436c6c9cSStella Laurenzo } 826436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 827436c6c9cSStella Laurenzo } 828436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 8295d6d30edSStella Laurenzo std::string format; 8305d6d30edSStella Laurenzo if (explicitFormat) { 8315d6d30edSStella Laurenzo format = explicitFormat; 8325d6d30edSStella Laurenzo } else { 8335d6d30edSStella Laurenzo format = py::format_descriptor<Type>::format(); 8345d6d30edSStella Laurenzo } 8355d6d30edSStella Laurenzo return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 8365d6d30edSStella Laurenzo /*readonly=*/true); 837436c6c9cSStella Laurenzo } 838436c6c9cSStella Laurenzo }; // namespace 839436c6c9cSStella Laurenzo 840436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer 841436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access. 842436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 843436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 844436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 845436c6c9cSStella Laurenzo public: 846436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 847436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 848436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 849436c6c9cSStella Laurenzo 850436c6c9cSStella Laurenzo /// Returns the element at the given linear position. Asserts if the index is 851436c6c9cSStella Laurenzo /// out of range. 852436c6c9cSStella Laurenzo py::int_ dunderGetItem(intptr_t pos) { 853436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 854436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 855436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 856436c6c9cSStella Laurenzo } 857436c6c9cSStella Laurenzo 858436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 859436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 860436c6c9cSStella Laurenzo assert(mlirTypeIsAInteger(type) && 861436c6c9cSStella Laurenzo "expected integer element type in dense int elements attribute"); 862436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 863436c6c9cSStella Laurenzo // elemental type of the attribute. py::int_ is implicitly constructible 864436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 865436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 866436c6c9cSStella Laurenzo // querying them on each element access. 867436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 868436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 869436c6c9cSStella Laurenzo if (isUnsigned) { 870436c6c9cSStella Laurenzo if (width == 1) { 871436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 872436c6c9cSStella Laurenzo } 873308d8b8cSRahul Kayaith if (width == 8) { 874308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt8Value(*this, pos); 875308d8b8cSRahul Kayaith } 876308d8b8cSRahul Kayaith if (width == 16) { 877308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt16Value(*this, pos); 878308d8b8cSRahul Kayaith } 879436c6c9cSStella Laurenzo if (width == 32) { 880436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt32Value(*this, pos); 881436c6c9cSStella Laurenzo } 882436c6c9cSStella Laurenzo if (width == 64) { 883436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt64Value(*this, pos); 884436c6c9cSStella Laurenzo } 885436c6c9cSStella Laurenzo } else { 886436c6c9cSStella Laurenzo if (width == 1) { 887436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 888436c6c9cSStella Laurenzo } 889308d8b8cSRahul Kayaith if (width == 8) { 890308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt8Value(*this, pos); 891308d8b8cSRahul Kayaith } 892308d8b8cSRahul Kayaith if (width == 16) { 893308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt16Value(*this, pos); 894308d8b8cSRahul Kayaith } 895436c6c9cSStella Laurenzo if (width == 32) { 896436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt32Value(*this, pos); 897436c6c9cSStella Laurenzo } 898436c6c9cSStella Laurenzo if (width == 64) { 899436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt64Value(*this, pos); 900436c6c9cSStella Laurenzo } 901436c6c9cSStella Laurenzo } 902436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 903436c6c9cSStella Laurenzo } 904436c6c9cSStella Laurenzo 905436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 906436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 907436c6c9cSStella Laurenzo } 908436c6c9cSStella Laurenzo }; 909436c6c9cSStella Laurenzo 910436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 911436c6c9cSStella Laurenzo public: 912436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 913436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 914436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 915436c6c9cSStella Laurenzo 916436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 917436c6c9cSStella Laurenzo 9189fb1086bSAdrian Kuegel bool dunderContains(const std::string &name) { 9199fb1086bSAdrian Kuegel return !mlirAttributeIsNull( 9209fb1086bSAdrian Kuegel mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 9219fb1086bSAdrian Kuegel } 9229fb1086bSAdrian Kuegel 923436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 9249fb1086bSAdrian Kuegel c.def("__contains__", &PyDictAttribute::dunderContains); 925436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 926436c6c9cSStella Laurenzo c.def_static( 927436c6c9cSStella Laurenzo "get", 928436c6c9cSStella Laurenzo [](py::dict attributes, DefaultingPyMlirContext context) { 929436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 930436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 931436c6c9cSStella Laurenzo for (auto &it : attributes) { 93202b6fb21SMehdi Amini auto &mlirAttr = it.second.cast<PyAttribute &>(); 933436c6c9cSStella Laurenzo auto name = it.first.cast<std::string>(); 934436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 93502b6fb21SMehdi Amini mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 936436c6c9cSStella Laurenzo toMlirStringRef(name)), 93702b6fb21SMehdi Amini mlirAttr)); 938436c6c9cSStella Laurenzo } 939436c6c9cSStella Laurenzo MlirAttribute attr = 940436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 941436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 942436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 943436c6c9cSStella Laurenzo }, 944ed9e52f3SAlex Zinenko py::arg("value") = py::dict(), py::arg("context") = py::none(), 945436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 946436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 947436c6c9cSStella Laurenzo MlirAttribute attr = 948436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 949436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 950436c6c9cSStella Laurenzo throw SetPyError(PyExc_KeyError, 951436c6c9cSStella Laurenzo "attempt to access a non-existent attribute"); 952436c6c9cSStella Laurenzo } 953436c6c9cSStella Laurenzo return PyAttribute(self.getContext(), attr); 954436c6c9cSStella Laurenzo }); 955436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 956436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 957436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 958436c6c9cSStella Laurenzo "attempt to access out of bounds attribute"); 959436c6c9cSStella Laurenzo } 960436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 961436c6c9cSStella Laurenzo return PyNamedAttribute( 962436c6c9cSStella Laurenzo namedAttr.attribute, 963436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 964436c6c9cSStella Laurenzo }); 965436c6c9cSStella Laurenzo } 966436c6c9cSStella Laurenzo }; 967436c6c9cSStella Laurenzo 968436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 969436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 970436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 971436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 972436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 973436c6c9cSStella Laurenzo public: 974436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 975436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 976436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 977436c6c9cSStella Laurenzo 978436c6c9cSStella Laurenzo py::float_ dunderGetItem(intptr_t pos) { 979436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 980436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 981436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 982436c6c9cSStella Laurenzo } 983436c6c9cSStella Laurenzo 984436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 985436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 986436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 987436c6c9cSStella Laurenzo // elemental type of the attribute. py::float_ is implicitly constructible 988436c6c9cSStella Laurenzo // from float and double. 989436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 990436c6c9cSStella Laurenzo // querying them on each element access. 991436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 992436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetFloatValue(*this, pos); 993436c6c9cSStella Laurenzo } 994436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 995436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetDoubleValue(*this, pos); 996436c6c9cSStella Laurenzo } 997436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 998436c6c9cSStella Laurenzo } 999436c6c9cSStella Laurenzo 1000436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1001436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 1002436c6c9cSStella Laurenzo } 1003436c6c9cSStella Laurenzo }; 1004436c6c9cSStella Laurenzo 1005436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 1006436c6c9cSStella Laurenzo public: 1007436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 1008436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 1009436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1010436c6c9cSStella Laurenzo 1011436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1012436c6c9cSStella Laurenzo c.def_static( 1013436c6c9cSStella Laurenzo "get", 1014436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 1015436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 1016436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 1017436c6c9cSStella Laurenzo }, 1018436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 1019436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 1020436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyTypeAttribute &self) { 1021436c6c9cSStella Laurenzo return PyType(self.getContext()->getRef(), 1022436c6c9cSStella Laurenzo mlirTypeAttrGetValue(self.get())); 1023436c6c9cSStella Laurenzo }); 1024436c6c9cSStella Laurenzo } 1025436c6c9cSStella Laurenzo }; 1026436c6c9cSStella Laurenzo 1027436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 1028436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 1029436c6c9cSStella Laurenzo public: 1030436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 1031436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 1032436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1033436c6c9cSStella Laurenzo 1034436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1035436c6c9cSStella Laurenzo c.def_static( 1036436c6c9cSStella Laurenzo "get", 1037436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 1038436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 1039436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 1040436c6c9cSStella Laurenzo }, 1041436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Create a Unit attribute."); 1042436c6c9cSStella Laurenzo } 1043436c6c9cSStella Laurenzo }; 1044436c6c9cSStella Laurenzo 1045436c6c9cSStella Laurenzo } // namespace 1046436c6c9cSStella Laurenzo 1047436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) { 1048436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 1049*619fd8c2SJeff Niu 1050*619fd8c2SJeff Niu PyDenseBoolArrayAttribute::bind(m); 1051*619fd8c2SJeff Niu PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); 1052*619fd8c2SJeff Niu PyDenseI8ArrayAttribute::bind(m); 1053*619fd8c2SJeff Niu PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); 1054*619fd8c2SJeff Niu PyDenseI16ArrayAttribute::bind(m); 1055*619fd8c2SJeff Niu PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); 1056*619fd8c2SJeff Niu PyDenseI32ArrayAttribute::bind(m); 1057*619fd8c2SJeff Niu PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); 1058*619fd8c2SJeff Niu PyDenseI64ArrayAttribute::bind(m); 1059*619fd8c2SJeff Niu PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); 1060*619fd8c2SJeff Niu PyDenseF32ArrayAttribute::bind(m); 1061*619fd8c2SJeff Niu PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); 1062*619fd8c2SJeff Niu PyDenseF64ArrayAttribute::bind(m); 1063*619fd8c2SJeff Niu PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); 1064*619fd8c2SJeff Niu 1065436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 1066436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 1067436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 1068436c6c9cSStella Laurenzo PyDenseElementsAttribute::bind(m); 1069436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 1070436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 1071436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 1072436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 10735c3861b2SYun Long PyOpaqueAttribute::bind(m); 1074436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 1075436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 1076436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 1077436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 1078436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 1079436c6c9cSStella Laurenzo } 1080