1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 91fc096afSMehdi Amini #include <utility> 10a1fe1f5fSKazu Hirata #include <optional> 111fc096afSMehdi Amini 12436c6c9cSStella Laurenzo #include "IRModule.h" 13436c6c9cSStella Laurenzo 14436c6c9cSStella Laurenzo #include "PybindUtils.h" 15436c6c9cSStella Laurenzo 16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 17436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 18436c6c9cSStella Laurenzo 19436c6c9cSStella Laurenzo namespace py = pybind11; 20436c6c9cSStella Laurenzo using namespace mlir; 21436c6c9cSStella Laurenzo using namespace mlir::python; 22436c6c9cSStella Laurenzo 23436c6c9cSStella Laurenzo using llvm::SmallVector; 24436c6c9cSStella Laurenzo using llvm::Twine; 25436c6c9cSStella Laurenzo 265d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 275d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 285d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 295d6d30edSStella Laurenzo 305d6d30edSStella Laurenzo static const char kDenseElementsAttrGetDocstring[] = 315d6d30edSStella Laurenzo R"(Gets a DenseElementsAttr from a Python buffer or array. 325d6d30edSStella Laurenzo 335d6d30edSStella Laurenzo When `type` is not provided, then some limited type inferencing is done based 345d6d30edSStella Laurenzo on the buffer format. Support presently exists for 8/16/32/64 signed and 355d6d30edSStella Laurenzo unsigned integers and float16/float32/float64. DenseElementsAttrs of these 365d6d30edSStella Laurenzo types can also be converted back to a corresponding buffer. 375d6d30edSStella Laurenzo 385d6d30edSStella Laurenzo For conversions outside of these types, a `type=` must be explicitly provided 395d6d30edSStella Laurenzo and the buffer contents must be bit-castable to the MLIR internal 405d6d30edSStella Laurenzo representation: 415d6d30edSStella Laurenzo 425d6d30edSStella Laurenzo * Integer types (except for i1): the buffer must be byte aligned to the 435d6d30edSStella Laurenzo next byte boundary. 445d6d30edSStella Laurenzo * Floating point types: Must be bit-castable to the given floating point 455d6d30edSStella Laurenzo size. 465d6d30edSStella Laurenzo * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 475d6d30edSStella Laurenzo row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 485d6d30edSStella Laurenzo this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 495d6d30edSStella Laurenzo 505d6d30edSStella Laurenzo If a single element buffer is passed (or for i1, a single byte with value 0 515d6d30edSStella Laurenzo or 255), then a splat will be created. 525d6d30edSStella Laurenzo 535d6d30edSStella Laurenzo Args: 545d6d30edSStella Laurenzo array: The array or buffer to convert. 555d6d30edSStella Laurenzo signless: If inferring an appropriate MLIR type, use signless types for 565d6d30edSStella Laurenzo integers (defaults True). 575d6d30edSStella Laurenzo type: Skips inference of the MLIR element type and uses this instead. The 585d6d30edSStella Laurenzo storage size must be consistent with the actual contents of the buffer. 595d6d30edSStella Laurenzo shape: Overrides the shape of the buffer when constructing the MLIR 605d6d30edSStella Laurenzo shaped type. This is needed when the physical and logical shape differ (as 615d6d30edSStella Laurenzo for i1). 625d6d30edSStella Laurenzo context: Explicit context, if not from context manager. 635d6d30edSStella Laurenzo 645d6d30edSStella Laurenzo Returns: 655d6d30edSStella Laurenzo DenseElementsAttr on success. 665d6d30edSStella Laurenzo 675d6d30edSStella Laurenzo Raises: 685d6d30edSStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 695d6d30edSStella Laurenzo type or if the buffer does not meet expectations. 705d6d30edSStella Laurenzo )"; 715d6d30edSStella Laurenzo 72436c6c9cSStella Laurenzo namespace { 73436c6c9cSStella Laurenzo 74436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 75436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 76436c6c9cSStella Laurenzo } 77436c6c9cSStella Laurenzo 78436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 79436c6c9cSStella Laurenzo public: 80436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 81436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 82436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 83436c6c9cSStella Laurenzo 84436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 85436c6c9cSStella Laurenzo c.def_static( 86436c6c9cSStella Laurenzo "get", 87436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 88436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 89436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 90436c6c9cSStella Laurenzo }, 91436c6c9cSStella Laurenzo py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 92436c6c9cSStella Laurenzo } 93436c6c9cSStella Laurenzo }; 94436c6c9cSStella Laurenzo 95ed9e52f3SAlex Zinenko template <typename T> 96ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) { 97ed9e52f3SAlex Zinenko try { 98ed9e52f3SAlex Zinenko return object.cast<T>(); 99ed9e52f3SAlex Zinenko } catch (py::cast_error &err) { 100ed9e52f3SAlex Zinenko std::string msg = 101ed9e52f3SAlex Zinenko std::string( 102ed9e52f3SAlex Zinenko "Invalid attribute when attempting to create an ArrayAttribute (") + 103ed9e52f3SAlex Zinenko err.what() + ")"; 104ed9e52f3SAlex Zinenko throw py::cast_error(msg); 105ed9e52f3SAlex Zinenko } catch (py::reference_cast_error &err) { 106ed9e52f3SAlex Zinenko std::string msg = std::string("Invalid attribute (None?) when attempting " 107ed9e52f3SAlex Zinenko "to create an ArrayAttribute (") + 108ed9e52f3SAlex Zinenko err.what() + ")"; 109ed9e52f3SAlex Zinenko throw py::cast_error(msg); 110ed9e52f3SAlex Zinenko } 111ed9e52f3SAlex Zinenko } 112ed9e52f3SAlex Zinenko 113619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived 114619fd8c2SJeff Niu /// implementation class. 115619fd8c2SJeff Niu template <typename EltTy, typename DerivedT> 116133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> { 117619fd8c2SJeff Niu public: 118133624acSJeff Niu using PyConcreteAttribute<DerivedT>::PyConcreteAttribute; 119619fd8c2SJeff Niu 120619fd8c2SJeff Niu /// Iterator over the integer elements of a dense array. 121619fd8c2SJeff Niu class PyDenseArrayIterator { 122619fd8c2SJeff Niu public: 1234a1b1196SMehdi Amini PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} 124619fd8c2SJeff Niu 125619fd8c2SJeff Niu /// Return a copy of the iterator. 126619fd8c2SJeff Niu PyDenseArrayIterator dunderIter() { return *this; } 127619fd8c2SJeff Niu 128619fd8c2SJeff Niu /// Return the next element. 129619fd8c2SJeff Niu EltTy dunderNext() { 130619fd8c2SJeff Niu // Throw if the index has reached the end. 131619fd8c2SJeff Niu if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) 132619fd8c2SJeff Niu throw py::stop_iteration(); 133619fd8c2SJeff Niu return DerivedT::getElement(attr.get(), nextIndex++); 134619fd8c2SJeff Niu } 135619fd8c2SJeff Niu 136619fd8c2SJeff Niu /// Bind the iterator class. 137619fd8c2SJeff Niu static void bind(py::module &m) { 138619fd8c2SJeff Niu py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName, 139619fd8c2SJeff Niu py::module_local()) 140619fd8c2SJeff Niu .def("__iter__", &PyDenseArrayIterator::dunderIter) 141619fd8c2SJeff Niu .def("__next__", &PyDenseArrayIterator::dunderNext); 142619fd8c2SJeff Niu } 143619fd8c2SJeff Niu 144619fd8c2SJeff Niu private: 145619fd8c2SJeff Niu /// The referenced dense array attribute. 146619fd8c2SJeff Niu PyAttribute attr; 147619fd8c2SJeff Niu /// The next index to read. 148619fd8c2SJeff Niu int nextIndex = 0; 149619fd8c2SJeff Niu }; 150619fd8c2SJeff Niu 151619fd8c2SJeff Niu /// Get the element at the given index. 152619fd8c2SJeff Niu EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } 153619fd8c2SJeff Niu 154619fd8c2SJeff Niu /// Bind the attribute class. 155133624acSJeff Niu static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) { 156619fd8c2SJeff Niu // Bind the constructor. 157619fd8c2SJeff Niu c.def_static( 158619fd8c2SJeff Niu "get", 159619fd8c2SJeff Niu [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { 160619fd8c2SJeff Niu MlirAttribute attr = 161619fd8c2SJeff Niu DerivedT::getAttribute(ctx->get(), values.size(), values.data()); 162133624acSJeff Niu return DerivedT(ctx->getRef(), attr); 163619fd8c2SJeff Niu }, 164619fd8c2SJeff Niu py::arg("values"), py::arg("context") = py::none(), 165619fd8c2SJeff Niu "Gets a uniqued dense array attribute"); 166619fd8c2SJeff Niu // Bind the array methods. 167133624acSJeff Niu c.def("__getitem__", [](DerivedT &arr, intptr_t i) { 168619fd8c2SJeff Niu if (i >= mlirDenseArrayGetNumElements(arr)) 169619fd8c2SJeff Niu throw py::index_error("DenseArray index out of range"); 170619fd8c2SJeff Niu return arr.getItem(i); 171619fd8c2SJeff Niu }); 172133624acSJeff Niu c.def("__len__", [](const DerivedT &arr) { 173619fd8c2SJeff Niu return mlirDenseArrayGetNumElements(arr); 174619fd8c2SJeff Niu }); 175133624acSJeff Niu c.def("__iter__", 176133624acSJeff Niu [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); 1774a1b1196SMehdi Amini c.def("__add__", [](DerivedT &arr, const py::list &extras) { 178619fd8c2SJeff Niu std::vector<EltTy> values; 179619fd8c2SJeff Niu intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); 180619fd8c2SJeff Niu values.reserve(numOldElements + py::len(extras)); 181619fd8c2SJeff Niu for (intptr_t i = 0; i < numOldElements; ++i) 182619fd8c2SJeff Niu values.push_back(arr.getItem(i)); 183619fd8c2SJeff Niu for (py::handle attr : extras) 184619fd8c2SJeff Niu values.push_back(pyTryCast<EltTy>(attr)); 185619fd8c2SJeff Niu MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(), 186619fd8c2SJeff Niu values.size(), values.data()); 187133624acSJeff Niu return DerivedT(arr.getContext(), attr); 188619fd8c2SJeff Niu }); 189619fd8c2SJeff Niu } 190619fd8c2SJeff Niu }; 191619fd8c2SJeff Niu 192619fd8c2SJeff Niu /// Instantiate the python dense array classes. 193619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute 194619fd8c2SJeff Niu : public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> { 195619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; 196619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseBoolArrayGet; 197619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseBoolArrayGetElement; 198619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseBoolArrayAttr"; 199619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; 200619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 201619fd8c2SJeff Niu }; 202619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute 203619fd8c2SJeff Niu : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { 204619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; 205619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI8ArrayGet; 206619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI8ArrayGetElement; 207619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI8ArrayAttr"; 208619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; 209619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 210619fd8c2SJeff Niu }; 211619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute 212619fd8c2SJeff Niu : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { 213619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; 214619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI16ArrayGet; 215619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI16ArrayGetElement; 216619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI16ArrayAttr"; 217619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; 218619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 219619fd8c2SJeff Niu }; 220619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute 221619fd8c2SJeff Niu : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { 222619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; 223619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI32ArrayGet; 224619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI32ArrayGetElement; 225619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI32ArrayAttr"; 226619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; 227619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 228619fd8c2SJeff Niu }; 229619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute 230619fd8c2SJeff Niu : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { 231619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; 232619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI64ArrayGet; 233619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI64ArrayGetElement; 234619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI64ArrayAttr"; 235619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; 236619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 237619fd8c2SJeff Niu }; 238619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute 239619fd8c2SJeff Niu : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { 240619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; 241619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF32ArrayGet; 242619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF32ArrayGetElement; 243619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF32ArrayAttr"; 244619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; 245619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 246619fd8c2SJeff Niu }; 247619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute 248619fd8c2SJeff Niu : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { 249619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; 250619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF64ArrayGet; 251619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF64ArrayGetElement; 252619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF64ArrayAttr"; 253619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; 254619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 255619fd8c2SJeff Niu }; 256619fd8c2SJeff Niu 257436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 258436c6c9cSStella Laurenzo public: 259436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 260436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 261436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 262436c6c9cSStella Laurenzo 263436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 264436c6c9cSStella Laurenzo public: 2651fc096afSMehdi Amini PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 266436c6c9cSStella Laurenzo 267436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 268436c6c9cSStella Laurenzo 269436c6c9cSStella Laurenzo PyAttribute dunderNext() { 270bca88952SJeff Niu // TODO: Throw is an inefficient way to stop iteration. 271bca88952SJeff Niu if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) 272436c6c9cSStella Laurenzo throw py::stop_iteration(); 273436c6c9cSStella Laurenzo return PyAttribute(attr.getContext(), 274436c6c9cSStella Laurenzo mlirArrayAttrGetElement(attr.get(), nextIndex++)); 275436c6c9cSStella Laurenzo } 276436c6c9cSStella Laurenzo 277436c6c9cSStella Laurenzo static void bind(py::module &m) { 278f05ff4f7SStella Laurenzo py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 279f05ff4f7SStella Laurenzo py::module_local()) 280436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 281436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 282436c6c9cSStella Laurenzo } 283436c6c9cSStella Laurenzo 284436c6c9cSStella Laurenzo private: 285436c6c9cSStella Laurenzo PyAttribute attr; 286436c6c9cSStella Laurenzo int nextIndex = 0; 287436c6c9cSStella Laurenzo }; 288436c6c9cSStella Laurenzo 289ed9e52f3SAlex Zinenko PyAttribute getItem(intptr_t i) { 290ed9e52f3SAlex Zinenko return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); 291ed9e52f3SAlex Zinenko } 292ed9e52f3SAlex Zinenko 293436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 294436c6c9cSStella Laurenzo c.def_static( 295436c6c9cSStella Laurenzo "get", 296436c6c9cSStella Laurenzo [](py::list attributes, DefaultingPyMlirContext context) { 297436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 298436c6c9cSStella Laurenzo mlirAttributes.reserve(py::len(attributes)); 299436c6c9cSStella Laurenzo for (auto attribute : attributes) { 300ed9e52f3SAlex Zinenko mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 301436c6c9cSStella Laurenzo } 302436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 303436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 304436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 305436c6c9cSStella Laurenzo }, 306436c6c9cSStella Laurenzo py::arg("attributes"), py::arg("context") = py::none(), 307436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 308436c6c9cSStella Laurenzo c.def("__getitem__", 309436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 310436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 311436c6c9cSStella Laurenzo throw py::index_error("ArrayAttribute index out of range"); 312ed9e52f3SAlex Zinenko return arr.getItem(i); 313436c6c9cSStella Laurenzo }) 314436c6c9cSStella Laurenzo .def("__len__", 315436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 316436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 317436c6c9cSStella Laurenzo }) 318436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 319436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 320436c6c9cSStella Laurenzo }); 321ed9e52f3SAlex Zinenko c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 322ed9e52f3SAlex Zinenko std::vector<MlirAttribute> attributes; 323ed9e52f3SAlex Zinenko intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 324ed9e52f3SAlex Zinenko attributes.reserve(numOldElements + py::len(extras)); 325ed9e52f3SAlex Zinenko for (intptr_t i = 0; i < numOldElements; ++i) 326ed9e52f3SAlex Zinenko attributes.push_back(arr.getItem(i)); 327ed9e52f3SAlex Zinenko for (py::handle attr : extras) 328ed9e52f3SAlex Zinenko attributes.push_back(pyTryCast<PyAttribute>(attr)); 329ed9e52f3SAlex Zinenko MlirAttribute arrayAttr = mlirArrayAttrGet( 330ed9e52f3SAlex Zinenko arr.getContext()->get(), attributes.size(), attributes.data()); 331ed9e52f3SAlex Zinenko return PyArrayAttribute(arr.getContext(), arrayAttr); 332ed9e52f3SAlex Zinenko }); 333436c6c9cSStella Laurenzo } 334436c6c9cSStella Laurenzo }; 335436c6c9cSStella Laurenzo 336436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 337436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 338436c6c9cSStella Laurenzo public: 339436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 340436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 341436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 342436c6c9cSStella Laurenzo 343436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 344436c6c9cSStella Laurenzo c.def_static( 345436c6c9cSStella Laurenzo "get", 346436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 3473ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(loc->getContext()); 348436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 3493ea4c501SRahul Kayaith if (mlirAttributeIsNull(attr)) 3503ea4c501SRahul Kayaith throw MLIRError("Invalid attribute", errors.take()); 351436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 352436c6c9cSStella Laurenzo }, 353436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 354436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 355436c6c9cSStella Laurenzo c.def_static( 356436c6c9cSStella Laurenzo "get_f32", 357436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 358436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 359436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 360436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 361436c6c9cSStella Laurenzo }, 362436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 363436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 364436c6c9cSStella Laurenzo c.def_static( 365436c6c9cSStella Laurenzo "get_f64", 366436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 367436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 368436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 369436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 370436c6c9cSStella Laurenzo }, 371436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 372436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 373436c6c9cSStella Laurenzo c.def_property_readonly( 374436c6c9cSStella Laurenzo "value", 375436c6c9cSStella Laurenzo [](PyFloatAttribute &self) { 376436c6c9cSStella Laurenzo return mlirFloatAttrGetValueDouble(self); 377436c6c9cSStella Laurenzo }, 378436c6c9cSStella Laurenzo "Returns the value of the float point attribute"); 379436c6c9cSStella Laurenzo } 380436c6c9cSStella Laurenzo }; 381436c6c9cSStella Laurenzo 382436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 383436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 384436c6c9cSStella Laurenzo public: 385436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 386436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 387436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 388436c6c9cSStella Laurenzo 389436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 390436c6c9cSStella Laurenzo c.def_static( 391436c6c9cSStella Laurenzo "get", 392436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 393436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 394436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 395436c6c9cSStella Laurenzo }, 396436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), 397436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 398436c6c9cSStella Laurenzo c.def_property_readonly( 399436c6c9cSStella Laurenzo "value", 400e9db306dSrkayaith [](PyIntegerAttribute &self) -> py::int_ { 401e9db306dSrkayaith MlirType type = mlirAttributeGetType(self); 402e9db306dSrkayaith if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 403436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 404e9db306dSrkayaith if (mlirIntegerTypeIsSigned(type)) 405e9db306dSrkayaith return mlirIntegerAttrGetValueSInt(self); 406e9db306dSrkayaith return mlirIntegerAttrGetValueUInt(self); 407436c6c9cSStella Laurenzo }, 408436c6c9cSStella Laurenzo "Returns the value of the integer attribute"); 409436c6c9cSStella Laurenzo } 410436c6c9cSStella Laurenzo }; 411436c6c9cSStella Laurenzo 412436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 413436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 414436c6c9cSStella Laurenzo public: 415436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 416436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 417436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 418436c6c9cSStella Laurenzo 419436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 420436c6c9cSStella Laurenzo c.def_static( 421436c6c9cSStella Laurenzo "get", 422436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 423436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 424436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 425436c6c9cSStella Laurenzo }, 426436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 427436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 428436c6c9cSStella Laurenzo c.def_property_readonly( 429436c6c9cSStella Laurenzo "value", 430436c6c9cSStella Laurenzo [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 431436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 432436c6c9cSStella Laurenzo } 433436c6c9cSStella Laurenzo }; 434436c6c9cSStella Laurenzo 435436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 436436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 437436c6c9cSStella Laurenzo public: 438436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 439436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 440436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 441436c6c9cSStella Laurenzo 442436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 443436c6c9cSStella Laurenzo c.def_static( 444436c6c9cSStella Laurenzo "get", 445436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 446436c6c9cSStella Laurenzo MlirAttribute attr = 447436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 448436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 449436c6c9cSStella Laurenzo }, 450436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 451436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 452436c6c9cSStella Laurenzo c.def_property_readonly( 453436c6c9cSStella Laurenzo "value", 454436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 455436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 456436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 457436c6c9cSStella Laurenzo }, 458436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 459436c6c9cSStella Laurenzo } 460436c6c9cSStella Laurenzo }; 461436c6c9cSStella Laurenzo 4625c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { 4635c3861b2SYun Long public: 4645c3861b2SYun Long static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; 4655c3861b2SYun Long static constexpr const char *pyClassName = "OpaqueAttr"; 4665c3861b2SYun Long using PyConcreteAttribute::PyConcreteAttribute; 4675c3861b2SYun Long 4685c3861b2SYun Long static void bindDerived(ClassTy &c) { 4695c3861b2SYun Long c.def_static( 4705c3861b2SYun Long "get", 4715c3861b2SYun Long [](std::string dialectNamespace, py::buffer buffer, PyType &type, 4725c3861b2SYun Long DefaultingPyMlirContext context) { 4735c3861b2SYun Long const py::buffer_info bufferInfo = buffer.request(); 4745c3861b2SYun Long intptr_t bufferSize = bufferInfo.size; 4755c3861b2SYun Long MlirAttribute attr = mlirOpaqueAttrGet( 4765c3861b2SYun Long context->get(), toMlirStringRef(dialectNamespace), bufferSize, 4775c3861b2SYun Long static_cast<char *>(bufferInfo.ptr), type); 4785c3861b2SYun Long return PyOpaqueAttribute(context->getRef(), attr); 4795c3861b2SYun Long }, 4805c3861b2SYun Long py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), 4815c3861b2SYun Long py::arg("context") = py::none(), "Gets an Opaque attribute."); 4825c3861b2SYun Long c.def_property_readonly( 4835c3861b2SYun Long "dialect_namespace", 4845c3861b2SYun Long [](PyOpaqueAttribute &self) { 4855c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); 4865c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 4875c3861b2SYun Long }, 4885c3861b2SYun Long "Returns the dialect namespace for the Opaque attribute as a string"); 4895c3861b2SYun Long c.def_property_readonly( 4905c3861b2SYun Long "data", 4915c3861b2SYun Long [](PyOpaqueAttribute &self) { 4925c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetData(self); 4935c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 4945c3861b2SYun Long }, 4955c3861b2SYun Long "Returns the data for the Opaqued attributes as a string"); 4965c3861b2SYun Long } 4975c3861b2SYun Long }; 4985c3861b2SYun Long 499436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 500436c6c9cSStella Laurenzo public: 501436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 502436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 503436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 504436c6c9cSStella Laurenzo 505436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 506436c6c9cSStella Laurenzo c.def_static( 507436c6c9cSStella Laurenzo "get", 508436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 509436c6c9cSStella Laurenzo MlirAttribute attr = 510436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 511436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 512436c6c9cSStella Laurenzo }, 513436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 514436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 515436c6c9cSStella Laurenzo c.def_static( 516436c6c9cSStella Laurenzo "get_typed", 517436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 518436c6c9cSStella Laurenzo MlirAttribute attr = 519436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 520436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 521436c6c9cSStella Laurenzo }, 522a6e7d024SStella Laurenzo py::arg("type"), py::arg("value"), 523436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 524436c6c9cSStella Laurenzo c.def_property_readonly( 525436c6c9cSStella Laurenzo "value", 526436c6c9cSStella Laurenzo [](PyStringAttribute &self) { 527436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirStringAttrGetValue(self); 528436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 529436c6c9cSStella Laurenzo }, 530436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 531436c6c9cSStella Laurenzo } 532436c6c9cSStella Laurenzo }; 533436c6c9cSStella Laurenzo 534436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 535436c6c9cSStella Laurenzo class PyDenseElementsAttribute 536436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 537436c6c9cSStella Laurenzo public: 538436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 539436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 540436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 541436c6c9cSStella Laurenzo 542436c6c9cSStella Laurenzo static PyDenseElementsAttribute 5430a81ace0SKazu Hirata getFromBuffer(py::buffer array, bool signless, 5440a81ace0SKazu Hirata std::optional<PyType> explicitType, 5450a81ace0SKazu Hirata std::optional<std::vector<int64_t>> explicitShape, 546436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 547436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 548436c6c9cSStella Laurenzo int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 549436c6c9cSStella Laurenzo Py_buffer *view = new Py_buffer(); 550436c6c9cSStella Laurenzo if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 551436c6c9cSStella Laurenzo delete view; 552436c6c9cSStella Laurenzo throw py::error_already_set(); 553436c6c9cSStella Laurenzo } 554436c6c9cSStella Laurenzo py::buffer_info arrayInfo(view); 5555d6d30edSStella Laurenzo SmallVector<int64_t> shape; 5565d6d30edSStella Laurenzo if (explicitShape) { 5575d6d30edSStella Laurenzo shape.append(explicitShape->begin(), explicitShape->end()); 5585d6d30edSStella Laurenzo } else { 5595d6d30edSStella Laurenzo shape.append(arrayInfo.shape.begin(), 5605d6d30edSStella Laurenzo arrayInfo.shape.begin() + arrayInfo.ndim); 5615d6d30edSStella Laurenzo } 562436c6c9cSStella Laurenzo 5635d6d30edSStella Laurenzo MlirAttribute encodingAttr = mlirAttributeGetNull(); 564436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 5655d6d30edSStella Laurenzo 5665d6d30edSStella Laurenzo // Detect format codes that are suitable for bulk loading. This includes 5675d6d30edSStella Laurenzo // all byte aligned integer and floating point types up to 8 bytes. 5685d6d30edSStella Laurenzo // Notably, this excludes, bool (which needs to be bit-packed) and 5695d6d30edSStella Laurenzo // other exotics which do not have a direct representation in the buffer 5705d6d30edSStella Laurenzo // protocol (i.e. complex, etc). 5710a81ace0SKazu Hirata std::optional<MlirType> bulkLoadElementType; 5725d6d30edSStella Laurenzo if (explicitType) { 5735d6d30edSStella Laurenzo bulkLoadElementType = *explicitType; 5745d6d30edSStella Laurenzo } else if (arrayInfo.format == "f") { 575436c6c9cSStella Laurenzo // f32 576436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 5775d6d30edSStella Laurenzo bulkLoadElementType = mlirF32TypeGet(context); 578436c6c9cSStella Laurenzo } else if (arrayInfo.format == "d") { 579436c6c9cSStella Laurenzo // f64 580436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 5815d6d30edSStella Laurenzo bulkLoadElementType = mlirF64TypeGet(context); 5825d6d30edSStella Laurenzo } else if (arrayInfo.format == "e") { 5835d6d30edSStella Laurenzo // f16 5845d6d30edSStella Laurenzo assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); 5855d6d30edSStella Laurenzo bulkLoadElementType = mlirF16TypeGet(context); 586436c6c9cSStella Laurenzo } else if (isSignedIntegerFormat(arrayInfo.format)) { 587436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 588436c6c9cSStella Laurenzo // i32 5895d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) 590436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 32); 591436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 592436c6c9cSStella Laurenzo // i64 5935d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) 594436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 64); 5955d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 5965d6d30edSStella Laurenzo // i8 5975d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 5985d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 8); 5995d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 6005d6d30edSStella Laurenzo // i16 6015d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) 6025d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 16); 603436c6c9cSStella Laurenzo } 604436c6c9cSStella Laurenzo } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 605436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 606436c6c9cSStella Laurenzo // unsigned i32 6075d6d30edSStella Laurenzo bulkLoadElementType = signless 608436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 32) 609436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 32); 610436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 611436c6c9cSStella Laurenzo // unsigned i64 6125d6d30edSStella Laurenzo bulkLoadElementType = signless 613436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 64) 614436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 64); 6155d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 6165d6d30edSStella Laurenzo // i8 6175d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 6185d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 8); 6195d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 6205d6d30edSStella Laurenzo // i16 6215d6d30edSStella Laurenzo bulkLoadElementType = signless 6225d6d30edSStella Laurenzo ? mlirIntegerTypeGet(context, 16) 6235d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 16); 624436c6c9cSStella Laurenzo } 625436c6c9cSStella Laurenzo } 6265d6d30edSStella Laurenzo if (bulkLoadElementType) { 627*99dee31eSAdam Paszke MlirType shapedType; 628*99dee31eSAdam Paszke if (mlirTypeIsAShaped(*bulkLoadElementType)) { 629*99dee31eSAdam Paszke if (explicitShape) { 630*99dee31eSAdam Paszke throw std::invalid_argument("Shape can only be specified explicitly " 631*99dee31eSAdam Paszke "when the type is not a shaped type."); 632*99dee31eSAdam Paszke } 633*99dee31eSAdam Paszke shapedType = *bulkLoadElementType; 634*99dee31eSAdam Paszke } else { 635*99dee31eSAdam Paszke shapedType = mlirRankedTensorTypeGet( 6365d6d30edSStella Laurenzo shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); 637*99dee31eSAdam Paszke } 6385d6d30edSStella Laurenzo size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; 6395d6d30edSStella Laurenzo MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( 6405d6d30edSStella Laurenzo shapedType, rawBufferSize, arrayInfo.ptr); 6415d6d30edSStella Laurenzo if (mlirAttributeIsNull(attr)) { 6425d6d30edSStella Laurenzo throw std::invalid_argument( 6435d6d30edSStella Laurenzo "DenseElementsAttr could not be constructed from the given buffer. " 6445d6d30edSStella Laurenzo "This may mean that the Python buffer layout does not match that " 6455d6d30edSStella Laurenzo "MLIR expected layout and is a bug."); 6465d6d30edSStella Laurenzo } 6475d6d30edSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 6485d6d30edSStella Laurenzo } 649436c6c9cSStella Laurenzo 6505d6d30edSStella Laurenzo throw std::invalid_argument( 6515d6d30edSStella Laurenzo std::string("unimplemented array format conversion from format: ") + 6525d6d30edSStella Laurenzo arrayInfo.format); 653436c6c9cSStella Laurenzo } 654436c6c9cSStella Laurenzo 6551fc096afSMehdi Amini static PyDenseElementsAttribute getSplat(const PyType &shapedType, 656436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 657436c6c9cSStella Laurenzo auto contextWrapper = 658436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 659436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 660436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 661436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 662436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 663436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 664436c6c9cSStella Laurenzo } 665436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 666436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 667436c6c9cSStella Laurenzo std::string message = 668436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 669436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 670436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 671436c6c9cSStella Laurenzo } 672436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 673436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 674436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 675436c6c9cSStella Laurenzo std::string message = 676436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 677436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 678436c6c9cSStella Laurenzo message.append(", element="); 679436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 680436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 681436c6c9cSStella Laurenzo } 682436c6c9cSStella Laurenzo 683436c6c9cSStella Laurenzo MlirAttribute elements = 684436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 685436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 686436c6c9cSStella Laurenzo } 687436c6c9cSStella Laurenzo 688436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 689436c6c9cSStella Laurenzo 690436c6c9cSStella Laurenzo py::buffer_info accessBuffer() { 6915d6d30edSStella Laurenzo if (mlirDenseElementsAttrIsSplat(*this)) { 692c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 6935d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 694c5f445d1SStella Laurenzo throw std::invalid_argument( 695c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 6965d6d30edSStella Laurenzo } 6975d6d30edSStella Laurenzo 698436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 699436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 7005d6d30edSStella Laurenzo std::string format; 701436c6c9cSStella Laurenzo 702436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 703436c6c9cSStella Laurenzo // f32 7045d6d30edSStella Laurenzo return bufferInfo<float>(shapedType); 70502b6fb21SMehdi Amini } 70602b6fb21SMehdi Amini if (mlirTypeIsAF64(elementType)) { 707436c6c9cSStella Laurenzo // f64 7085d6d30edSStella Laurenzo return bufferInfo<double>(shapedType); 709bb56c2b3SMehdi Amini } 710bb56c2b3SMehdi Amini if (mlirTypeIsAF16(elementType)) { 7115d6d30edSStella Laurenzo // f16 7125d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType, "e"); 713bb56c2b3SMehdi Amini } 714bb56c2b3SMehdi Amini if (mlirTypeIsAInteger(elementType) && 715436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 716436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 717436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 718436c6c9cSStella Laurenzo // i32 7195d6d30edSStella Laurenzo return bufferInfo<int32_t>(shapedType); 720e5639b3fSMehdi Amini } 721e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 722436c6c9cSStella Laurenzo // unsigned i32 7235d6d30edSStella Laurenzo return bufferInfo<uint32_t>(shapedType); 724436c6c9cSStella Laurenzo } 725436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 726436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 727436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 728436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 729436c6c9cSStella Laurenzo // i64 7305d6d30edSStella Laurenzo return bufferInfo<int64_t>(shapedType); 731e5639b3fSMehdi Amini } 732e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 733436c6c9cSStella Laurenzo // unsigned i64 7345d6d30edSStella Laurenzo return bufferInfo<uint64_t>(shapedType); 7355d6d30edSStella Laurenzo } 7365d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 7375d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 8) { 7385d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 7395d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 7405d6d30edSStella Laurenzo // i8 7415d6d30edSStella Laurenzo return bufferInfo<int8_t>(shapedType); 742e5639b3fSMehdi Amini } 743e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 7445d6d30edSStella Laurenzo // unsigned i8 7455d6d30edSStella Laurenzo return bufferInfo<uint8_t>(shapedType); 7465d6d30edSStella Laurenzo } 7475d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 7485d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 16) { 7495d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 7505d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 7515d6d30edSStella Laurenzo // i16 7525d6d30edSStella Laurenzo return bufferInfo<int16_t>(shapedType); 753e5639b3fSMehdi Amini } 754e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 7555d6d30edSStella Laurenzo // unsigned i16 7565d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType); 757436c6c9cSStella Laurenzo } 758436c6c9cSStella Laurenzo } 759436c6c9cSStella Laurenzo 760c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 7615d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 762c5f445d1SStella Laurenzo throw std::invalid_argument( 763c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 764436c6c9cSStella Laurenzo } 765436c6c9cSStella Laurenzo 766436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 767436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 768436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 769436c6c9cSStella Laurenzo py::arg("array"), py::arg("signless") = true, 7705d6d30edSStella Laurenzo py::arg("type") = py::none(), py::arg("shape") = py::none(), 771436c6c9cSStella Laurenzo py::arg("context") = py::none(), 7725d6d30edSStella Laurenzo kDenseElementsAttrGetDocstring) 773436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 774436c6c9cSStella Laurenzo py::arg("shaped_type"), py::arg("element_attr"), 775436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 776436c6c9cSStella Laurenzo .def_property_readonly("is_splat", 777436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 778436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 779436c6c9cSStella Laurenzo }) 780436c6c9cSStella Laurenzo .def_buffer(&PyDenseElementsAttribute::accessBuffer); 781436c6c9cSStella Laurenzo } 782436c6c9cSStella Laurenzo 783436c6c9cSStella Laurenzo private: 784436c6c9cSStella Laurenzo static bool isUnsignedIntegerFormat(const std::string &format) { 785436c6c9cSStella Laurenzo if (format.empty()) 786436c6c9cSStella Laurenzo return false; 787436c6c9cSStella Laurenzo char code = format[0]; 788436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 789436c6c9cSStella Laurenzo code == 'Q'; 790436c6c9cSStella Laurenzo } 791436c6c9cSStella Laurenzo 792436c6c9cSStella Laurenzo static bool isSignedIntegerFormat(const std::string &format) { 793436c6c9cSStella Laurenzo if (format.empty()) 794436c6c9cSStella Laurenzo return false; 795436c6c9cSStella Laurenzo char code = format[0]; 796436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 797436c6c9cSStella Laurenzo code == 'q'; 798436c6c9cSStella Laurenzo } 799436c6c9cSStella Laurenzo 800436c6c9cSStella Laurenzo template <typename Type> 801436c6c9cSStella Laurenzo py::buffer_info bufferInfo(MlirType shapedType, 8025d6d30edSStella Laurenzo const char *explicitFormat = nullptr) { 803436c6c9cSStella Laurenzo intptr_t rank = mlirShapedTypeGetRank(shapedType); 804436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 805436c6c9cSStella Laurenzo // Buffer is configured for read-only access below. 806436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 807436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 808436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 809436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 810436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 811436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 812436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 813436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 814436c6c9cSStella Laurenzo intptr_t strideFactor = 1; 815436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 816436c6c9cSStella Laurenzo strideFactor = 1; 817436c6c9cSStella Laurenzo for (intptr_t j = i; j < rank; ++j) { 818436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 819436c6c9cSStella Laurenzo } 820436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 821436c6c9cSStella Laurenzo } 822436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 8235d6d30edSStella Laurenzo std::string format; 8245d6d30edSStella Laurenzo if (explicitFormat) { 8255d6d30edSStella Laurenzo format = explicitFormat; 8265d6d30edSStella Laurenzo } else { 8275d6d30edSStella Laurenzo format = py::format_descriptor<Type>::format(); 8285d6d30edSStella Laurenzo } 8295d6d30edSStella Laurenzo return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 8305d6d30edSStella Laurenzo /*readonly=*/true); 831436c6c9cSStella Laurenzo } 832436c6c9cSStella Laurenzo }; // namespace 833436c6c9cSStella Laurenzo 834436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer 835436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access. 836436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 837436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 838436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 839436c6c9cSStella Laurenzo public: 840436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 841436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 842436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 843436c6c9cSStella Laurenzo 844436c6c9cSStella Laurenzo /// Returns the element at the given linear position. Asserts if the index is 845436c6c9cSStella Laurenzo /// out of range. 846436c6c9cSStella Laurenzo py::int_ dunderGetItem(intptr_t pos) { 847436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 848436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 849436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 850436c6c9cSStella Laurenzo } 851436c6c9cSStella Laurenzo 852436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 853436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 854436c6c9cSStella Laurenzo assert(mlirTypeIsAInteger(type) && 855436c6c9cSStella Laurenzo "expected integer element type in dense int elements attribute"); 856436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 857436c6c9cSStella Laurenzo // elemental type of the attribute. py::int_ is implicitly constructible 858436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 859436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 860436c6c9cSStella Laurenzo // querying them on each element access. 861436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 862436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 863436c6c9cSStella Laurenzo if (isUnsigned) { 864436c6c9cSStella Laurenzo if (width == 1) { 865436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 866436c6c9cSStella Laurenzo } 867308d8b8cSRahul Kayaith if (width == 8) { 868308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt8Value(*this, pos); 869308d8b8cSRahul Kayaith } 870308d8b8cSRahul Kayaith if (width == 16) { 871308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt16Value(*this, pos); 872308d8b8cSRahul Kayaith } 873436c6c9cSStella Laurenzo if (width == 32) { 874436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt32Value(*this, pos); 875436c6c9cSStella Laurenzo } 876436c6c9cSStella Laurenzo if (width == 64) { 877436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt64Value(*this, pos); 878436c6c9cSStella Laurenzo } 879436c6c9cSStella Laurenzo } else { 880436c6c9cSStella Laurenzo if (width == 1) { 881436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 882436c6c9cSStella Laurenzo } 883308d8b8cSRahul Kayaith if (width == 8) { 884308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt8Value(*this, pos); 885308d8b8cSRahul Kayaith } 886308d8b8cSRahul Kayaith if (width == 16) { 887308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt16Value(*this, pos); 888308d8b8cSRahul Kayaith } 889436c6c9cSStella Laurenzo if (width == 32) { 890436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt32Value(*this, pos); 891436c6c9cSStella Laurenzo } 892436c6c9cSStella Laurenzo if (width == 64) { 893436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt64Value(*this, pos); 894436c6c9cSStella Laurenzo } 895436c6c9cSStella Laurenzo } 896436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 897436c6c9cSStella Laurenzo } 898436c6c9cSStella Laurenzo 899436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 900436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 901436c6c9cSStella Laurenzo } 902436c6c9cSStella Laurenzo }; 903436c6c9cSStella Laurenzo 904436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 905436c6c9cSStella Laurenzo public: 906436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 907436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 908436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 909436c6c9cSStella Laurenzo 910436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 911436c6c9cSStella Laurenzo 9129fb1086bSAdrian Kuegel bool dunderContains(const std::string &name) { 9139fb1086bSAdrian Kuegel return !mlirAttributeIsNull( 9149fb1086bSAdrian Kuegel mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 9159fb1086bSAdrian Kuegel } 9169fb1086bSAdrian Kuegel 917436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 9189fb1086bSAdrian Kuegel c.def("__contains__", &PyDictAttribute::dunderContains); 919436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 920436c6c9cSStella Laurenzo c.def_static( 921436c6c9cSStella Laurenzo "get", 922436c6c9cSStella Laurenzo [](py::dict attributes, DefaultingPyMlirContext context) { 923436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 924436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 925436c6c9cSStella Laurenzo for (auto &it : attributes) { 92602b6fb21SMehdi Amini auto &mlirAttr = it.second.cast<PyAttribute &>(); 927436c6c9cSStella Laurenzo auto name = it.first.cast<std::string>(); 928436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 92902b6fb21SMehdi Amini mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 930436c6c9cSStella Laurenzo toMlirStringRef(name)), 93102b6fb21SMehdi Amini mlirAttr)); 932436c6c9cSStella Laurenzo } 933436c6c9cSStella Laurenzo MlirAttribute attr = 934436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 935436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 936436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 937436c6c9cSStella Laurenzo }, 938ed9e52f3SAlex Zinenko py::arg("value") = py::dict(), py::arg("context") = py::none(), 939436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 940436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 941436c6c9cSStella Laurenzo MlirAttribute attr = 942436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 943436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 944436c6c9cSStella Laurenzo throw SetPyError(PyExc_KeyError, 945436c6c9cSStella Laurenzo "attempt to access a non-existent attribute"); 946436c6c9cSStella Laurenzo } 947436c6c9cSStella Laurenzo return PyAttribute(self.getContext(), attr); 948436c6c9cSStella Laurenzo }); 949436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 950436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 951436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 952436c6c9cSStella Laurenzo "attempt to access out of bounds attribute"); 953436c6c9cSStella Laurenzo } 954436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 955436c6c9cSStella Laurenzo return PyNamedAttribute( 956436c6c9cSStella Laurenzo namedAttr.attribute, 957436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 958436c6c9cSStella Laurenzo }); 959436c6c9cSStella Laurenzo } 960436c6c9cSStella Laurenzo }; 961436c6c9cSStella Laurenzo 962436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 963436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 964436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 965436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 966436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 967436c6c9cSStella Laurenzo public: 968436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 969436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 970436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 971436c6c9cSStella Laurenzo 972436c6c9cSStella Laurenzo py::float_ dunderGetItem(intptr_t pos) { 973436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 974436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 975436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 976436c6c9cSStella Laurenzo } 977436c6c9cSStella Laurenzo 978436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 979436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 980436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 981436c6c9cSStella Laurenzo // elemental type of the attribute. py::float_ is implicitly constructible 982436c6c9cSStella Laurenzo // from float and double. 983436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 984436c6c9cSStella Laurenzo // querying them on each element access. 985436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 986436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetFloatValue(*this, pos); 987436c6c9cSStella Laurenzo } 988436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 989436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetDoubleValue(*this, pos); 990436c6c9cSStella Laurenzo } 991436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 992436c6c9cSStella Laurenzo } 993436c6c9cSStella Laurenzo 994436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 995436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 996436c6c9cSStella Laurenzo } 997436c6c9cSStella Laurenzo }; 998436c6c9cSStella Laurenzo 999436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 1000436c6c9cSStella Laurenzo public: 1001436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 1002436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 1003436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1004436c6c9cSStella Laurenzo 1005436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1006436c6c9cSStella Laurenzo c.def_static( 1007436c6c9cSStella Laurenzo "get", 1008436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 1009436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 1010436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 1011436c6c9cSStella Laurenzo }, 1012436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 1013436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 1014436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyTypeAttribute &self) { 1015436c6c9cSStella Laurenzo return PyType(self.getContext()->getRef(), 1016436c6c9cSStella Laurenzo mlirTypeAttrGetValue(self.get())); 1017436c6c9cSStella Laurenzo }); 1018436c6c9cSStella Laurenzo } 1019436c6c9cSStella Laurenzo }; 1020436c6c9cSStella Laurenzo 1021436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 1022436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 1023436c6c9cSStella Laurenzo public: 1024436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 1025436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 1026436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1027436c6c9cSStella Laurenzo 1028436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1029436c6c9cSStella Laurenzo c.def_static( 1030436c6c9cSStella Laurenzo "get", 1031436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 1032436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 1033436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 1034436c6c9cSStella Laurenzo }, 1035436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Create a Unit attribute."); 1036436c6c9cSStella Laurenzo } 1037436c6c9cSStella Laurenzo }; 1038436c6c9cSStella Laurenzo 1039ac2e2d65SDenys Shabalin /// Strided layout attribute subclass. 1040ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute 1041ac2e2d65SDenys Shabalin : public PyConcreteAttribute<PyStridedLayoutAttribute> { 1042ac2e2d65SDenys Shabalin public: 1043ac2e2d65SDenys Shabalin static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; 1044ac2e2d65SDenys Shabalin static constexpr const char *pyClassName = "StridedLayoutAttr"; 1045ac2e2d65SDenys Shabalin using PyConcreteAttribute::PyConcreteAttribute; 1046ac2e2d65SDenys Shabalin 1047ac2e2d65SDenys Shabalin static void bindDerived(ClassTy &c) { 1048ac2e2d65SDenys Shabalin c.def_static( 1049ac2e2d65SDenys Shabalin "get", 1050ac2e2d65SDenys Shabalin [](int64_t offset, const std::vector<int64_t> strides, 1051ac2e2d65SDenys Shabalin DefaultingPyMlirContext ctx) { 1052ac2e2d65SDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1053ac2e2d65SDenys Shabalin ctx->get(), offset, strides.size(), strides.data()); 1054ac2e2d65SDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1055ac2e2d65SDenys Shabalin }, 1056ac2e2d65SDenys Shabalin py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), 1057ac2e2d65SDenys Shabalin "Gets a strided layout attribute."); 1058e3fd612eSDenys Shabalin c.def_static( 1059e3fd612eSDenys Shabalin "get_fully_dynamic", 1060e3fd612eSDenys Shabalin [](int64_t rank, DefaultingPyMlirContext ctx) { 1061e3fd612eSDenys Shabalin auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); 1062e3fd612eSDenys Shabalin std::vector<int64_t> strides(rank); 1063e3fd612eSDenys Shabalin std::fill(strides.begin(), strides.end(), dynamic); 1064e3fd612eSDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1065e3fd612eSDenys Shabalin ctx->get(), dynamic, strides.size(), strides.data()); 1066e3fd612eSDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1067e3fd612eSDenys Shabalin }, 1068e3fd612eSDenys Shabalin py::arg("rank"), py::arg("context") = py::none(), 1069e3fd612eSDenys Shabalin "Gets a strided layout attribute with dynamic offset and strides of a " 1070e3fd612eSDenys Shabalin "given rank."); 1071ac2e2d65SDenys Shabalin c.def_property_readonly( 1072ac2e2d65SDenys Shabalin "offset", 1073ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1074ac2e2d65SDenys Shabalin return mlirStridedLayoutAttrGetOffset(self); 1075ac2e2d65SDenys Shabalin }, 1076ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1077ac2e2d65SDenys Shabalin c.def_property_readonly( 1078ac2e2d65SDenys Shabalin "strides", 1079ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1080ac2e2d65SDenys Shabalin intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); 1081ac2e2d65SDenys Shabalin std::vector<int64_t> strides(size); 1082ac2e2d65SDenys Shabalin for (intptr_t i = 0; i < size; i++) { 1083ac2e2d65SDenys Shabalin strides[i] = mlirStridedLayoutAttrGetStride(self, i); 1084ac2e2d65SDenys Shabalin } 1085ac2e2d65SDenys Shabalin return strides; 1086ac2e2d65SDenys Shabalin }, 1087ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1088ac2e2d65SDenys Shabalin } 1089ac2e2d65SDenys Shabalin }; 1090ac2e2d65SDenys Shabalin 1091436c6c9cSStella Laurenzo } // namespace 1092436c6c9cSStella Laurenzo 1093436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) { 1094436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 1095619fd8c2SJeff Niu 1096619fd8c2SJeff Niu PyDenseBoolArrayAttribute::bind(m); 1097619fd8c2SJeff Niu PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); 1098619fd8c2SJeff Niu PyDenseI8ArrayAttribute::bind(m); 1099619fd8c2SJeff Niu PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); 1100619fd8c2SJeff Niu PyDenseI16ArrayAttribute::bind(m); 1101619fd8c2SJeff Niu PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); 1102619fd8c2SJeff Niu PyDenseI32ArrayAttribute::bind(m); 1103619fd8c2SJeff Niu PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); 1104619fd8c2SJeff Niu PyDenseI64ArrayAttribute::bind(m); 1105619fd8c2SJeff Niu PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); 1106619fd8c2SJeff Niu PyDenseF32ArrayAttribute::bind(m); 1107619fd8c2SJeff Niu PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); 1108619fd8c2SJeff Niu PyDenseF64ArrayAttribute::bind(m); 1109619fd8c2SJeff Niu PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); 1110619fd8c2SJeff Niu 1111436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 1112436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 1113436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 1114436c6c9cSStella Laurenzo PyDenseElementsAttribute::bind(m); 1115436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 1116436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 1117436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 1118436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 11195c3861b2SYun Long PyOpaqueAttribute::bind(m); 1120436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 1121436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 1122436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 1123436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 1124436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 1125ac2e2d65SDenys Shabalin 1126ac2e2d65SDenys Shabalin PyStridedLayoutAttribute::bind(m); 1127436c6c9cSStella Laurenzo } 1128