1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9a1fe1f5fSKazu Hirata #include <optional> 104811270bSmax #include <utility> 111fc096afSMehdi Amini 12436c6c9cSStella Laurenzo #include "IRModule.h" 13436c6c9cSStella Laurenzo 14436c6c9cSStella Laurenzo #include "PybindUtils.h" 15436c6c9cSStella Laurenzo 16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 17436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 18bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h" 19436c6c9cSStella Laurenzo 20436c6c9cSStella Laurenzo namespace py = pybind11; 21436c6c9cSStella Laurenzo using namespace mlir; 22436c6c9cSStella Laurenzo using namespace mlir::python; 23436c6c9cSStella Laurenzo 24436c6c9cSStella Laurenzo using llvm::SmallVector; 25436c6c9cSStella Laurenzo 265d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 275d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 285d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 295d6d30edSStella Laurenzo 305d6d30edSStella Laurenzo static const char kDenseElementsAttrGetDocstring[] = 315d6d30edSStella Laurenzo R"(Gets a DenseElementsAttr from a Python buffer or array. 325d6d30edSStella Laurenzo 335d6d30edSStella Laurenzo When `type` is not provided, then some limited type inferencing is done based 345d6d30edSStella Laurenzo on the buffer format. Support presently exists for 8/16/32/64 signed and 355d6d30edSStella Laurenzo unsigned integers and float16/float32/float64. DenseElementsAttrs of these 365d6d30edSStella Laurenzo types can also be converted back to a corresponding buffer. 375d6d30edSStella Laurenzo 385d6d30edSStella Laurenzo For conversions outside of these types, a `type=` must be explicitly provided 395d6d30edSStella Laurenzo and the buffer contents must be bit-castable to the MLIR internal 405d6d30edSStella Laurenzo representation: 415d6d30edSStella Laurenzo 425d6d30edSStella Laurenzo * Integer types (except for i1): the buffer must be byte aligned to the 435d6d30edSStella Laurenzo next byte boundary. 445d6d30edSStella Laurenzo * Floating point types: Must be bit-castable to the given floating point 455d6d30edSStella Laurenzo size. 465d6d30edSStella Laurenzo * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 475d6d30edSStella Laurenzo row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 485d6d30edSStella Laurenzo this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 495d6d30edSStella Laurenzo 505d6d30edSStella Laurenzo If a single element buffer is passed (or for i1, a single byte with value 0 515d6d30edSStella Laurenzo or 255), then a splat will be created. 525d6d30edSStella Laurenzo 535d6d30edSStella Laurenzo Args: 545d6d30edSStella Laurenzo array: The array or buffer to convert. 555d6d30edSStella Laurenzo signless: If inferring an appropriate MLIR type, use signless types for 565d6d30edSStella Laurenzo integers (defaults True). 575d6d30edSStella Laurenzo type: Skips inference of the MLIR element type and uses this instead. The 585d6d30edSStella Laurenzo storage size must be consistent with the actual contents of the buffer. 595d6d30edSStella Laurenzo shape: Overrides the shape of the buffer when constructing the MLIR 605d6d30edSStella Laurenzo shaped type. This is needed when the physical and logical shape differ (as 615d6d30edSStella Laurenzo for i1). 625d6d30edSStella Laurenzo context: Explicit context, if not from context manager. 635d6d30edSStella Laurenzo 645d6d30edSStella Laurenzo Returns: 655d6d30edSStella Laurenzo DenseElementsAttr on success. 665d6d30edSStella Laurenzo 675d6d30edSStella Laurenzo Raises: 685d6d30edSStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 695d6d30edSStella Laurenzo type or if the buffer does not meet expectations. 705d6d30edSStella Laurenzo )"; 715d6d30edSStella Laurenzo 72436c6c9cSStella Laurenzo namespace { 73436c6c9cSStella Laurenzo 74436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 75436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 76436c6c9cSStella Laurenzo } 77436c6c9cSStella Laurenzo 78436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 79436c6c9cSStella Laurenzo public: 80436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 81436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 82436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 83*9566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 84*9566ee28Smax mlirAffineMapAttrGetTypeID; 85436c6c9cSStella Laurenzo 86436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 87436c6c9cSStella Laurenzo c.def_static( 88436c6c9cSStella Laurenzo "get", 89436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 90436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 91436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 92436c6c9cSStella Laurenzo }, 93436c6c9cSStella Laurenzo py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 94436c6c9cSStella Laurenzo } 95436c6c9cSStella Laurenzo }; 96436c6c9cSStella Laurenzo 97ed9e52f3SAlex Zinenko template <typename T> 98ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) { 99ed9e52f3SAlex Zinenko try { 100ed9e52f3SAlex Zinenko return object.cast<T>(); 101ed9e52f3SAlex Zinenko } catch (py::cast_error &err) { 102ed9e52f3SAlex Zinenko std::string msg = 103ed9e52f3SAlex Zinenko std::string( 104ed9e52f3SAlex Zinenko "Invalid attribute when attempting to create an ArrayAttribute (") + 105ed9e52f3SAlex Zinenko err.what() + ")"; 106ed9e52f3SAlex Zinenko throw py::cast_error(msg); 107ed9e52f3SAlex Zinenko } catch (py::reference_cast_error &err) { 108ed9e52f3SAlex Zinenko std::string msg = std::string("Invalid attribute (None?) when attempting " 109ed9e52f3SAlex Zinenko "to create an ArrayAttribute (") + 110ed9e52f3SAlex Zinenko err.what() + ")"; 111ed9e52f3SAlex Zinenko throw py::cast_error(msg); 112ed9e52f3SAlex Zinenko } 113ed9e52f3SAlex Zinenko } 114ed9e52f3SAlex Zinenko 115619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived 116619fd8c2SJeff Niu /// implementation class. 117619fd8c2SJeff Niu template <typename EltTy, typename DerivedT> 118133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> { 119619fd8c2SJeff Niu public: 120133624acSJeff Niu using PyConcreteAttribute<DerivedT>::PyConcreteAttribute; 121619fd8c2SJeff Niu 122619fd8c2SJeff Niu /// Iterator over the integer elements of a dense array. 123619fd8c2SJeff Niu class PyDenseArrayIterator { 124619fd8c2SJeff Niu public: 1254a1b1196SMehdi Amini PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} 126619fd8c2SJeff Niu 127619fd8c2SJeff Niu /// Return a copy of the iterator. 128619fd8c2SJeff Niu PyDenseArrayIterator dunderIter() { return *this; } 129619fd8c2SJeff Niu 130619fd8c2SJeff Niu /// Return the next element. 131619fd8c2SJeff Niu EltTy dunderNext() { 132619fd8c2SJeff Niu // Throw if the index has reached the end. 133619fd8c2SJeff Niu if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) 134619fd8c2SJeff Niu throw py::stop_iteration(); 135619fd8c2SJeff Niu return DerivedT::getElement(attr.get(), nextIndex++); 136619fd8c2SJeff Niu } 137619fd8c2SJeff Niu 138619fd8c2SJeff Niu /// Bind the iterator class. 139619fd8c2SJeff Niu static void bind(py::module &m) { 140619fd8c2SJeff Niu py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName, 141619fd8c2SJeff Niu py::module_local()) 142619fd8c2SJeff Niu .def("__iter__", &PyDenseArrayIterator::dunderIter) 143619fd8c2SJeff Niu .def("__next__", &PyDenseArrayIterator::dunderNext); 144619fd8c2SJeff Niu } 145619fd8c2SJeff Niu 146619fd8c2SJeff Niu private: 147619fd8c2SJeff Niu /// The referenced dense array attribute. 148619fd8c2SJeff Niu PyAttribute attr; 149619fd8c2SJeff Niu /// The next index to read. 150619fd8c2SJeff Niu int nextIndex = 0; 151619fd8c2SJeff Niu }; 152619fd8c2SJeff Niu 153619fd8c2SJeff Niu /// Get the element at the given index. 154619fd8c2SJeff Niu EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } 155619fd8c2SJeff Niu 156619fd8c2SJeff Niu /// Bind the attribute class. 157133624acSJeff Niu static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) { 158619fd8c2SJeff Niu // Bind the constructor. 159619fd8c2SJeff Niu c.def_static( 160619fd8c2SJeff Niu "get", 161619fd8c2SJeff Niu [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { 162619fd8c2SJeff Niu MlirAttribute attr = 163619fd8c2SJeff Niu DerivedT::getAttribute(ctx->get(), values.size(), values.data()); 164133624acSJeff Niu return DerivedT(ctx->getRef(), attr); 165619fd8c2SJeff Niu }, 166619fd8c2SJeff Niu py::arg("values"), py::arg("context") = py::none(), 167619fd8c2SJeff Niu "Gets a uniqued dense array attribute"); 168619fd8c2SJeff Niu // Bind the array methods. 169133624acSJeff Niu c.def("__getitem__", [](DerivedT &arr, intptr_t i) { 170619fd8c2SJeff Niu if (i >= mlirDenseArrayGetNumElements(arr)) 171619fd8c2SJeff Niu throw py::index_error("DenseArray index out of range"); 172619fd8c2SJeff Niu return arr.getItem(i); 173619fd8c2SJeff Niu }); 174133624acSJeff Niu c.def("__len__", [](const DerivedT &arr) { 175619fd8c2SJeff Niu return mlirDenseArrayGetNumElements(arr); 176619fd8c2SJeff Niu }); 177133624acSJeff Niu c.def("__iter__", 178133624acSJeff Niu [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); 1794a1b1196SMehdi Amini c.def("__add__", [](DerivedT &arr, const py::list &extras) { 180619fd8c2SJeff Niu std::vector<EltTy> values; 181619fd8c2SJeff Niu intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); 182619fd8c2SJeff Niu values.reserve(numOldElements + py::len(extras)); 183619fd8c2SJeff Niu for (intptr_t i = 0; i < numOldElements; ++i) 184619fd8c2SJeff Niu values.push_back(arr.getItem(i)); 185619fd8c2SJeff Niu for (py::handle attr : extras) 186619fd8c2SJeff Niu values.push_back(pyTryCast<EltTy>(attr)); 187619fd8c2SJeff Niu MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(), 188619fd8c2SJeff Niu values.size(), values.data()); 189133624acSJeff Niu return DerivedT(arr.getContext(), attr); 190619fd8c2SJeff Niu }); 191619fd8c2SJeff Niu } 192619fd8c2SJeff Niu }; 193619fd8c2SJeff Niu 194619fd8c2SJeff Niu /// Instantiate the python dense array classes. 195619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute 196619fd8c2SJeff Niu : public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> { 197619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; 198619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseBoolArrayGet; 199619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseBoolArrayGetElement; 200619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseBoolArrayAttr"; 201619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; 202619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 203619fd8c2SJeff Niu }; 204619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute 205619fd8c2SJeff Niu : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { 206619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; 207619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI8ArrayGet; 208619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI8ArrayGetElement; 209619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI8ArrayAttr"; 210619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; 211619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 212619fd8c2SJeff Niu }; 213619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute 214619fd8c2SJeff Niu : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { 215619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; 216619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI16ArrayGet; 217619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI16ArrayGetElement; 218619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI16ArrayAttr"; 219619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; 220619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 221619fd8c2SJeff Niu }; 222619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute 223619fd8c2SJeff Niu : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { 224619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; 225619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI32ArrayGet; 226619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI32ArrayGetElement; 227619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI32ArrayAttr"; 228619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; 229619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 230619fd8c2SJeff Niu }; 231619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute 232619fd8c2SJeff Niu : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { 233619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; 234619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI64ArrayGet; 235619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI64ArrayGetElement; 236619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI64ArrayAttr"; 237619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; 238619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 239619fd8c2SJeff Niu }; 240619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute 241619fd8c2SJeff Niu : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { 242619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; 243619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF32ArrayGet; 244619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF32ArrayGetElement; 245619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF32ArrayAttr"; 246619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; 247619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 248619fd8c2SJeff Niu }; 249619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute 250619fd8c2SJeff Niu : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { 251619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; 252619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF64ArrayGet; 253619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF64ArrayGetElement; 254619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF64ArrayAttr"; 255619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; 256619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 257619fd8c2SJeff Niu }; 258619fd8c2SJeff Niu 259436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 260436c6c9cSStella Laurenzo public: 261436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 262436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 263436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 264*9566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 265*9566ee28Smax mlirArrayAttrGetTypeID; 266436c6c9cSStella Laurenzo 267436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 268436c6c9cSStella Laurenzo public: 2691fc096afSMehdi Amini PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 270436c6c9cSStella Laurenzo 271436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 272436c6c9cSStella Laurenzo 273436c6c9cSStella Laurenzo PyAttribute dunderNext() { 274bca88952SJeff Niu // TODO: Throw is an inefficient way to stop iteration. 275bca88952SJeff Niu if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) 276436c6c9cSStella Laurenzo throw py::stop_iteration(); 277436c6c9cSStella Laurenzo return PyAttribute(attr.getContext(), 278436c6c9cSStella Laurenzo mlirArrayAttrGetElement(attr.get(), nextIndex++)); 279436c6c9cSStella Laurenzo } 280436c6c9cSStella Laurenzo 281436c6c9cSStella Laurenzo static void bind(py::module &m) { 282f05ff4f7SStella Laurenzo py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 283f05ff4f7SStella Laurenzo py::module_local()) 284436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 285436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 286436c6c9cSStella Laurenzo } 287436c6c9cSStella Laurenzo 288436c6c9cSStella Laurenzo private: 289436c6c9cSStella Laurenzo PyAttribute attr; 290436c6c9cSStella Laurenzo int nextIndex = 0; 291436c6c9cSStella Laurenzo }; 292436c6c9cSStella Laurenzo 293ed9e52f3SAlex Zinenko PyAttribute getItem(intptr_t i) { 294ed9e52f3SAlex Zinenko return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); 295ed9e52f3SAlex Zinenko } 296ed9e52f3SAlex Zinenko 297436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 298436c6c9cSStella Laurenzo c.def_static( 299436c6c9cSStella Laurenzo "get", 300436c6c9cSStella Laurenzo [](py::list attributes, DefaultingPyMlirContext context) { 301436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 302436c6c9cSStella Laurenzo mlirAttributes.reserve(py::len(attributes)); 303436c6c9cSStella Laurenzo for (auto attribute : attributes) { 304ed9e52f3SAlex Zinenko mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 305436c6c9cSStella Laurenzo } 306436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 307436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 308436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 309436c6c9cSStella Laurenzo }, 310436c6c9cSStella Laurenzo py::arg("attributes"), py::arg("context") = py::none(), 311436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 312436c6c9cSStella Laurenzo c.def("__getitem__", 313436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 314436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 315436c6c9cSStella Laurenzo throw py::index_error("ArrayAttribute index out of range"); 316ed9e52f3SAlex Zinenko return arr.getItem(i); 317436c6c9cSStella Laurenzo }) 318436c6c9cSStella Laurenzo .def("__len__", 319436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 320436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 321436c6c9cSStella Laurenzo }) 322436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 323436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 324436c6c9cSStella Laurenzo }); 325ed9e52f3SAlex Zinenko c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 326ed9e52f3SAlex Zinenko std::vector<MlirAttribute> attributes; 327ed9e52f3SAlex Zinenko intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 328ed9e52f3SAlex Zinenko attributes.reserve(numOldElements + py::len(extras)); 329ed9e52f3SAlex Zinenko for (intptr_t i = 0; i < numOldElements; ++i) 330ed9e52f3SAlex Zinenko attributes.push_back(arr.getItem(i)); 331ed9e52f3SAlex Zinenko for (py::handle attr : extras) 332ed9e52f3SAlex Zinenko attributes.push_back(pyTryCast<PyAttribute>(attr)); 333ed9e52f3SAlex Zinenko MlirAttribute arrayAttr = mlirArrayAttrGet( 334ed9e52f3SAlex Zinenko arr.getContext()->get(), attributes.size(), attributes.data()); 335ed9e52f3SAlex Zinenko return PyArrayAttribute(arr.getContext(), arrayAttr); 336ed9e52f3SAlex Zinenko }); 337436c6c9cSStella Laurenzo } 338436c6c9cSStella Laurenzo }; 339436c6c9cSStella Laurenzo 340436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 341436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 342436c6c9cSStella Laurenzo public: 343436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 344436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 345436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 346*9566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 347*9566ee28Smax mlirFloatAttrGetTypeID; 348436c6c9cSStella Laurenzo 349436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 350436c6c9cSStella Laurenzo c.def_static( 351436c6c9cSStella Laurenzo "get", 352436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 3533ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(loc->getContext()); 354436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 3553ea4c501SRahul Kayaith if (mlirAttributeIsNull(attr)) 3563ea4c501SRahul Kayaith throw MLIRError("Invalid attribute", errors.take()); 357436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 358436c6c9cSStella Laurenzo }, 359436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 360436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 361436c6c9cSStella Laurenzo c.def_static( 362436c6c9cSStella Laurenzo "get_f32", 363436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 364436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 365436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 366436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 367436c6c9cSStella Laurenzo }, 368436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 369436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 370436c6c9cSStella Laurenzo c.def_static( 371436c6c9cSStella Laurenzo "get_f64", 372436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 373436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 374436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 375436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 376436c6c9cSStella Laurenzo }, 377436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 378436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 379436c6c9cSStella Laurenzo c.def_property_readonly( 380436c6c9cSStella Laurenzo "value", 381436c6c9cSStella Laurenzo [](PyFloatAttribute &self) { 382436c6c9cSStella Laurenzo return mlirFloatAttrGetValueDouble(self); 383436c6c9cSStella Laurenzo }, 384436c6c9cSStella Laurenzo "Returns the value of the float point attribute"); 385436c6c9cSStella Laurenzo } 386436c6c9cSStella Laurenzo }; 387436c6c9cSStella Laurenzo 388436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 389436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 390436c6c9cSStella Laurenzo public: 391436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 392436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 393436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 394436c6c9cSStella Laurenzo 395436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 396436c6c9cSStella Laurenzo c.def_static( 397436c6c9cSStella Laurenzo "get", 398436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 399436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 400436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 401436c6c9cSStella Laurenzo }, 402436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), 403436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 404436c6c9cSStella Laurenzo c.def_property_readonly( 405436c6c9cSStella Laurenzo "value", 406e9db306dSrkayaith [](PyIntegerAttribute &self) -> py::int_ { 407e9db306dSrkayaith MlirType type = mlirAttributeGetType(self); 408e9db306dSrkayaith if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 409436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 410e9db306dSrkayaith if (mlirIntegerTypeIsSigned(type)) 411e9db306dSrkayaith return mlirIntegerAttrGetValueSInt(self); 412e9db306dSrkayaith return mlirIntegerAttrGetValueUInt(self); 413436c6c9cSStella Laurenzo }, 414436c6c9cSStella Laurenzo "Returns the value of the integer attribute"); 415*9566ee28Smax c.def_property_readonly_static("static_typeid", 416*9566ee28Smax [](py::object & /*class*/) -> MlirTypeID { 417*9566ee28Smax return mlirIntegerAttrGetTypeID(); 418*9566ee28Smax }); 419436c6c9cSStella Laurenzo } 420436c6c9cSStella Laurenzo }; 421436c6c9cSStella Laurenzo 422436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 423436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 424436c6c9cSStella Laurenzo public: 425436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 426436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 427436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 428436c6c9cSStella Laurenzo 429436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 430436c6c9cSStella Laurenzo c.def_static( 431436c6c9cSStella Laurenzo "get", 432436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 433436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 434436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 435436c6c9cSStella Laurenzo }, 436436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 437436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 438436c6c9cSStella Laurenzo c.def_property_readonly( 439436c6c9cSStella Laurenzo "value", 440436c6c9cSStella Laurenzo [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 441436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 442436c6c9cSStella Laurenzo } 443436c6c9cSStella Laurenzo }; 444436c6c9cSStella Laurenzo 445436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 446436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 447436c6c9cSStella Laurenzo public: 448436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 449436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 450436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 451*9566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 452*9566ee28Smax mlirFlatSymbolRefAttrGetTypeID; 453436c6c9cSStella Laurenzo 454436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 455436c6c9cSStella Laurenzo c.def_static( 456436c6c9cSStella Laurenzo "get", 457436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 458436c6c9cSStella Laurenzo MlirAttribute attr = 459436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 460436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 461436c6c9cSStella Laurenzo }, 462436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 463436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 464436c6c9cSStella Laurenzo c.def_property_readonly( 465436c6c9cSStella Laurenzo "value", 466436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 467436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 468436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 469436c6c9cSStella Laurenzo }, 470436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 471436c6c9cSStella Laurenzo } 472436c6c9cSStella Laurenzo }; 473436c6c9cSStella Laurenzo 4745c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { 4755c3861b2SYun Long public: 4765c3861b2SYun Long static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; 4775c3861b2SYun Long static constexpr const char *pyClassName = "OpaqueAttr"; 4785c3861b2SYun Long using PyConcreteAttribute::PyConcreteAttribute; 479*9566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 480*9566ee28Smax mlirOpaqueAttrGetTypeID; 4815c3861b2SYun Long 4825c3861b2SYun Long static void bindDerived(ClassTy &c) { 4835c3861b2SYun Long c.def_static( 4845c3861b2SYun Long "get", 4855c3861b2SYun Long [](std::string dialectNamespace, py::buffer buffer, PyType &type, 4865c3861b2SYun Long DefaultingPyMlirContext context) { 4875c3861b2SYun Long const py::buffer_info bufferInfo = buffer.request(); 4885c3861b2SYun Long intptr_t bufferSize = bufferInfo.size; 4895c3861b2SYun Long MlirAttribute attr = mlirOpaqueAttrGet( 4905c3861b2SYun Long context->get(), toMlirStringRef(dialectNamespace), bufferSize, 4915c3861b2SYun Long static_cast<char *>(bufferInfo.ptr), type); 4925c3861b2SYun Long return PyOpaqueAttribute(context->getRef(), attr); 4935c3861b2SYun Long }, 4945c3861b2SYun Long py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), 4955c3861b2SYun Long py::arg("context") = py::none(), "Gets an Opaque attribute."); 4965c3861b2SYun Long c.def_property_readonly( 4975c3861b2SYun Long "dialect_namespace", 4985c3861b2SYun Long [](PyOpaqueAttribute &self) { 4995c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); 5005c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 5015c3861b2SYun Long }, 5025c3861b2SYun Long "Returns the dialect namespace for the Opaque attribute as a string"); 5035c3861b2SYun Long c.def_property_readonly( 5045c3861b2SYun Long "data", 5055c3861b2SYun Long [](PyOpaqueAttribute &self) { 5065c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetData(self); 50762bf6c2eSChris Jones return py::bytes(stringRef.data, stringRef.length); 5085c3861b2SYun Long }, 50962bf6c2eSChris Jones "Returns the data for the Opaqued attributes as `bytes`"); 5105c3861b2SYun Long } 5115c3861b2SYun Long }; 5125c3861b2SYun Long 513436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 514436c6c9cSStella Laurenzo public: 515436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 516436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 517436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 518*9566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 519*9566ee28Smax mlirStringAttrGetTypeID; 520436c6c9cSStella Laurenzo 521436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 522436c6c9cSStella Laurenzo c.def_static( 523436c6c9cSStella Laurenzo "get", 524436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 525436c6c9cSStella Laurenzo MlirAttribute attr = 526436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 527436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 528436c6c9cSStella Laurenzo }, 529436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 530436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 531436c6c9cSStella Laurenzo c.def_static( 532436c6c9cSStella Laurenzo "get_typed", 533436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 534436c6c9cSStella Laurenzo MlirAttribute attr = 535436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 536436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 537436c6c9cSStella Laurenzo }, 538a6e7d024SStella Laurenzo py::arg("type"), py::arg("value"), 539436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 540436c6c9cSStella Laurenzo c.def_property_readonly( 541436c6c9cSStella Laurenzo "value", 542436c6c9cSStella Laurenzo [](PyStringAttribute &self) { 543436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirStringAttrGetValue(self); 544436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 545436c6c9cSStella Laurenzo }, 546436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 54762bf6c2eSChris Jones c.def_property_readonly( 54862bf6c2eSChris Jones "value_bytes", 54962bf6c2eSChris Jones [](PyStringAttribute &self) { 55062bf6c2eSChris Jones MlirStringRef stringRef = mlirStringAttrGetValue(self); 55162bf6c2eSChris Jones return py::bytes(stringRef.data, stringRef.length); 55262bf6c2eSChris Jones }, 55362bf6c2eSChris Jones "Returns the value of the string attribute as `bytes`"); 554436c6c9cSStella Laurenzo } 555436c6c9cSStella Laurenzo }; 556436c6c9cSStella Laurenzo 557436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 558436c6c9cSStella Laurenzo class PyDenseElementsAttribute 559436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 560436c6c9cSStella Laurenzo public: 561436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 562436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 563436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 564436c6c9cSStella Laurenzo 565436c6c9cSStella Laurenzo static PyDenseElementsAttribute 5660a81ace0SKazu Hirata getFromBuffer(py::buffer array, bool signless, 5670a81ace0SKazu Hirata std::optional<PyType> explicitType, 5680a81ace0SKazu Hirata std::optional<std::vector<int64_t>> explicitShape, 569436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 570436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 571436c6c9cSStella Laurenzo int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 572436c6c9cSStella Laurenzo Py_buffer *view = new Py_buffer(); 573436c6c9cSStella Laurenzo if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 574436c6c9cSStella Laurenzo delete view; 575436c6c9cSStella Laurenzo throw py::error_already_set(); 576436c6c9cSStella Laurenzo } 577436c6c9cSStella Laurenzo py::buffer_info arrayInfo(view); 5785d6d30edSStella Laurenzo SmallVector<int64_t> shape; 5795d6d30edSStella Laurenzo if (explicitShape) { 5805d6d30edSStella Laurenzo shape.append(explicitShape->begin(), explicitShape->end()); 5815d6d30edSStella Laurenzo } else { 5825d6d30edSStella Laurenzo shape.append(arrayInfo.shape.begin(), 5835d6d30edSStella Laurenzo arrayInfo.shape.begin() + arrayInfo.ndim); 5845d6d30edSStella Laurenzo } 585436c6c9cSStella Laurenzo 5865d6d30edSStella Laurenzo MlirAttribute encodingAttr = mlirAttributeGetNull(); 587436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 5885d6d30edSStella Laurenzo 5895d6d30edSStella Laurenzo // Detect format codes that are suitable for bulk loading. This includes 5905d6d30edSStella Laurenzo // all byte aligned integer and floating point types up to 8 bytes. 5915d6d30edSStella Laurenzo // Notably, this excludes, bool (which needs to be bit-packed) and 5925d6d30edSStella Laurenzo // other exotics which do not have a direct representation in the buffer 5935d6d30edSStella Laurenzo // protocol (i.e. complex, etc). 5940a81ace0SKazu Hirata std::optional<MlirType> bulkLoadElementType; 5955d6d30edSStella Laurenzo if (explicitType) { 5965d6d30edSStella Laurenzo bulkLoadElementType = *explicitType; 5975d6d30edSStella Laurenzo } else if (arrayInfo.format == "f") { 598436c6c9cSStella Laurenzo // f32 599436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 6005d6d30edSStella Laurenzo bulkLoadElementType = mlirF32TypeGet(context); 601436c6c9cSStella Laurenzo } else if (arrayInfo.format == "d") { 602436c6c9cSStella Laurenzo // f64 603436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 6045d6d30edSStella Laurenzo bulkLoadElementType = mlirF64TypeGet(context); 6055d6d30edSStella Laurenzo } else if (arrayInfo.format == "e") { 6065d6d30edSStella Laurenzo // f16 6075d6d30edSStella Laurenzo assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); 6085d6d30edSStella Laurenzo bulkLoadElementType = mlirF16TypeGet(context); 609436c6c9cSStella Laurenzo } else if (isSignedIntegerFormat(arrayInfo.format)) { 610436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 611436c6c9cSStella Laurenzo // i32 6125d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) 613436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 32); 614436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 615436c6c9cSStella Laurenzo // i64 6165d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) 617436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 64); 6185d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 6195d6d30edSStella Laurenzo // i8 6205d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 6215d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 8); 6225d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 6235d6d30edSStella Laurenzo // i16 6245d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) 6255d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 16); 626436c6c9cSStella Laurenzo } 627436c6c9cSStella Laurenzo } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 628436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 629436c6c9cSStella Laurenzo // unsigned i32 6305d6d30edSStella Laurenzo bulkLoadElementType = signless 631436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 32) 632436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 32); 633436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 634436c6c9cSStella Laurenzo // unsigned i64 6355d6d30edSStella Laurenzo bulkLoadElementType = signless 636436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 64) 637436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 64); 6385d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 6395d6d30edSStella Laurenzo // i8 6405d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 6415d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 8); 6425d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 6435d6d30edSStella Laurenzo // i16 6445d6d30edSStella Laurenzo bulkLoadElementType = signless 6455d6d30edSStella Laurenzo ? mlirIntegerTypeGet(context, 16) 6465d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 16); 647436c6c9cSStella Laurenzo } 648436c6c9cSStella Laurenzo } 6495d6d30edSStella Laurenzo if (bulkLoadElementType) { 65099dee31eSAdam Paszke MlirType shapedType; 65199dee31eSAdam Paszke if (mlirTypeIsAShaped(*bulkLoadElementType)) { 65299dee31eSAdam Paszke if (explicitShape) { 65399dee31eSAdam Paszke throw std::invalid_argument("Shape can only be specified explicitly " 65499dee31eSAdam Paszke "when the type is not a shaped type."); 65599dee31eSAdam Paszke } 65699dee31eSAdam Paszke shapedType = *bulkLoadElementType; 65799dee31eSAdam Paszke } else { 65899dee31eSAdam Paszke shapedType = mlirRankedTensorTypeGet( 6595d6d30edSStella Laurenzo shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); 66099dee31eSAdam Paszke } 6615d6d30edSStella Laurenzo size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; 6625d6d30edSStella Laurenzo MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( 6635d6d30edSStella Laurenzo shapedType, rawBufferSize, arrayInfo.ptr); 6645d6d30edSStella Laurenzo if (mlirAttributeIsNull(attr)) { 6655d6d30edSStella Laurenzo throw std::invalid_argument( 6665d6d30edSStella Laurenzo "DenseElementsAttr could not be constructed from the given buffer. " 6675d6d30edSStella Laurenzo "This may mean that the Python buffer layout does not match that " 6685d6d30edSStella Laurenzo "MLIR expected layout and is a bug."); 6695d6d30edSStella Laurenzo } 6705d6d30edSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 6715d6d30edSStella Laurenzo } 672436c6c9cSStella Laurenzo 6735d6d30edSStella Laurenzo throw std::invalid_argument( 6745d6d30edSStella Laurenzo std::string("unimplemented array format conversion from format: ") + 6755d6d30edSStella Laurenzo arrayInfo.format); 676436c6c9cSStella Laurenzo } 677436c6c9cSStella Laurenzo 6781fc096afSMehdi Amini static PyDenseElementsAttribute getSplat(const PyType &shapedType, 679436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 680436c6c9cSStella Laurenzo auto contextWrapper = 681436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 682436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 683436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 684436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 685436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 6864811270bSmax throw py::value_error(message); 687436c6c9cSStella Laurenzo } 688436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 689436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 690436c6c9cSStella Laurenzo std::string message = 691436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 692436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 6934811270bSmax throw py::value_error(message); 694436c6c9cSStella Laurenzo } 695436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 696436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 697436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 698436c6c9cSStella Laurenzo std::string message = 699436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 700436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 701436c6c9cSStella Laurenzo message.append(", element="); 702436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 7034811270bSmax throw py::value_error(message); 704436c6c9cSStella Laurenzo } 705436c6c9cSStella Laurenzo 706436c6c9cSStella Laurenzo MlirAttribute elements = 707436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 708436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 709436c6c9cSStella Laurenzo } 710436c6c9cSStella Laurenzo 711436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 712436c6c9cSStella Laurenzo 713436c6c9cSStella Laurenzo py::buffer_info accessBuffer() { 714436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 715436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 7165d6d30edSStella Laurenzo std::string format; 717436c6c9cSStella Laurenzo 718436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 719436c6c9cSStella Laurenzo // f32 7205d6d30edSStella Laurenzo return bufferInfo<float>(shapedType); 72102b6fb21SMehdi Amini } 72202b6fb21SMehdi Amini if (mlirTypeIsAF64(elementType)) { 723436c6c9cSStella Laurenzo // f64 7245d6d30edSStella Laurenzo return bufferInfo<double>(shapedType); 725bb56c2b3SMehdi Amini } 726bb56c2b3SMehdi Amini if (mlirTypeIsAF16(elementType)) { 7275d6d30edSStella Laurenzo // f16 7285d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType, "e"); 729bb56c2b3SMehdi Amini } 730ef1b735dSmax if (mlirTypeIsAIndex(elementType)) { 731ef1b735dSmax // Same as IndexType::kInternalStorageBitWidth 732ef1b735dSmax return bufferInfo<int64_t>(shapedType); 733ef1b735dSmax } 734bb56c2b3SMehdi Amini if (mlirTypeIsAInteger(elementType) && 735436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 736436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 737436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 738436c6c9cSStella Laurenzo // i32 7395d6d30edSStella Laurenzo return bufferInfo<int32_t>(shapedType); 740e5639b3fSMehdi Amini } 741e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 742436c6c9cSStella Laurenzo // unsigned i32 7435d6d30edSStella Laurenzo return bufferInfo<uint32_t>(shapedType); 744436c6c9cSStella Laurenzo } 745436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 746436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 747436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 748436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 749436c6c9cSStella Laurenzo // i64 7505d6d30edSStella Laurenzo return bufferInfo<int64_t>(shapedType); 751e5639b3fSMehdi Amini } 752e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 753436c6c9cSStella Laurenzo // unsigned i64 7545d6d30edSStella Laurenzo return bufferInfo<uint64_t>(shapedType); 7555d6d30edSStella Laurenzo } 7565d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 7575d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 8) { 7585d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 7595d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 7605d6d30edSStella Laurenzo // i8 7615d6d30edSStella Laurenzo return bufferInfo<int8_t>(shapedType); 762e5639b3fSMehdi Amini } 763e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 7645d6d30edSStella Laurenzo // unsigned i8 7655d6d30edSStella Laurenzo return bufferInfo<uint8_t>(shapedType); 7665d6d30edSStella Laurenzo } 7675d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 7685d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 16) { 7695d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 7705d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 7715d6d30edSStella Laurenzo // i16 7725d6d30edSStella Laurenzo return bufferInfo<int16_t>(shapedType); 773e5639b3fSMehdi Amini } 774e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 7755d6d30edSStella Laurenzo // unsigned i16 7765d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType); 777436c6c9cSStella Laurenzo } 778436c6c9cSStella Laurenzo } 779436c6c9cSStella Laurenzo 780c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 7815d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 782c5f445d1SStella Laurenzo throw std::invalid_argument( 783c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 784436c6c9cSStella Laurenzo } 785436c6c9cSStella Laurenzo 786436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 787436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 788436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 789436c6c9cSStella Laurenzo py::arg("array"), py::arg("signless") = true, 7905d6d30edSStella Laurenzo py::arg("type") = py::none(), py::arg("shape") = py::none(), 791436c6c9cSStella Laurenzo py::arg("context") = py::none(), 7925d6d30edSStella Laurenzo kDenseElementsAttrGetDocstring) 793436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 794436c6c9cSStella Laurenzo py::arg("shaped_type"), py::arg("element_attr"), 795436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 796436c6c9cSStella Laurenzo .def_property_readonly("is_splat", 797436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 798436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 799436c6c9cSStella Laurenzo }) 80091259963SAdam Paszke .def("get_splat_value", 80191259963SAdam Paszke [](PyDenseElementsAttribute &self) -> PyAttribute { 80291259963SAdam Paszke if (!mlirDenseElementsAttrIsSplat(self)) { 8034811270bSmax throw py::value_error( 80491259963SAdam Paszke "get_splat_value called on a non-splat attribute"); 80591259963SAdam Paszke } 80691259963SAdam Paszke return PyAttribute(self.getContext(), 80791259963SAdam Paszke mlirDenseElementsAttrGetSplatValue(self)); 80891259963SAdam Paszke }) 809436c6c9cSStella Laurenzo .def_buffer(&PyDenseElementsAttribute::accessBuffer); 810436c6c9cSStella Laurenzo } 811436c6c9cSStella Laurenzo 812436c6c9cSStella Laurenzo private: 813436c6c9cSStella Laurenzo static bool isUnsignedIntegerFormat(const std::string &format) { 814436c6c9cSStella Laurenzo if (format.empty()) 815436c6c9cSStella Laurenzo return false; 816436c6c9cSStella Laurenzo char code = format[0]; 817436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 818436c6c9cSStella Laurenzo code == 'Q'; 819436c6c9cSStella Laurenzo } 820436c6c9cSStella Laurenzo 821436c6c9cSStella Laurenzo static bool isSignedIntegerFormat(const std::string &format) { 822436c6c9cSStella Laurenzo if (format.empty()) 823436c6c9cSStella Laurenzo return false; 824436c6c9cSStella Laurenzo char code = format[0]; 825436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 826436c6c9cSStella Laurenzo code == 'q'; 827436c6c9cSStella Laurenzo } 828436c6c9cSStella Laurenzo 829436c6c9cSStella Laurenzo template <typename Type> 830436c6c9cSStella Laurenzo py::buffer_info bufferInfo(MlirType shapedType, 8315d6d30edSStella Laurenzo const char *explicitFormat = nullptr) { 832436c6c9cSStella Laurenzo intptr_t rank = mlirShapedTypeGetRank(shapedType); 833436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 834436c6c9cSStella Laurenzo // Buffer is configured for read-only access below. 835436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 836436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 837436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 838436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 839436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 840436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 841436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 842436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 843f0e847d0SRahul Kayaith if (mlirDenseElementsAttrIsSplat(*this)) { 844f0e847d0SRahul Kayaith // Splats are special, only the single value is stored. 845f0e847d0SRahul Kayaith strides.assign(rank, 0); 846f0e847d0SRahul Kayaith } else { 847436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 848f0e847d0SRahul Kayaith intptr_t strideFactor = 1; 849f0e847d0SRahul Kayaith for (intptr_t j = i; j < rank; ++j) 850436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 851436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 852436c6c9cSStella Laurenzo } 853436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 854f0e847d0SRahul Kayaith } 8555d6d30edSStella Laurenzo std::string format; 8565d6d30edSStella Laurenzo if (explicitFormat) { 8575d6d30edSStella Laurenzo format = explicitFormat; 8585d6d30edSStella Laurenzo } else { 8595d6d30edSStella Laurenzo format = py::format_descriptor<Type>::format(); 8605d6d30edSStella Laurenzo } 8615d6d30edSStella Laurenzo return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 8625d6d30edSStella Laurenzo /*readonly=*/true); 863436c6c9cSStella Laurenzo } 864436c6c9cSStella Laurenzo }; // namespace 865436c6c9cSStella Laurenzo 866436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer 867436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access. 868436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 869436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 870436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 871436c6c9cSStella Laurenzo public: 872436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 873436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 874436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 875436c6c9cSStella Laurenzo 876436c6c9cSStella Laurenzo /// Returns the element at the given linear position. Asserts if the index is 877436c6c9cSStella Laurenzo /// out of range. 878436c6c9cSStella Laurenzo py::int_ dunderGetItem(intptr_t pos) { 879436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 8804811270bSmax throw py::index_error("attempt to access out of bounds element"); 881436c6c9cSStella Laurenzo } 882436c6c9cSStella Laurenzo 883436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 884436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 885436c6c9cSStella Laurenzo assert(mlirTypeIsAInteger(type) && 886436c6c9cSStella Laurenzo "expected integer element type in dense int elements attribute"); 887436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 888436c6c9cSStella Laurenzo // elemental type of the attribute. py::int_ is implicitly constructible 889436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 890436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 891436c6c9cSStella Laurenzo // querying them on each element access. 892436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 893436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 894436c6c9cSStella Laurenzo if (isUnsigned) { 895436c6c9cSStella Laurenzo if (width == 1) { 896436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 897436c6c9cSStella Laurenzo } 898308d8b8cSRahul Kayaith if (width == 8) { 899308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt8Value(*this, pos); 900308d8b8cSRahul Kayaith } 901308d8b8cSRahul Kayaith if (width == 16) { 902308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt16Value(*this, pos); 903308d8b8cSRahul Kayaith } 904436c6c9cSStella Laurenzo if (width == 32) { 905436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt32Value(*this, pos); 906436c6c9cSStella Laurenzo } 907436c6c9cSStella Laurenzo if (width == 64) { 908436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt64Value(*this, pos); 909436c6c9cSStella Laurenzo } 910436c6c9cSStella Laurenzo } else { 911436c6c9cSStella Laurenzo if (width == 1) { 912436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 913436c6c9cSStella Laurenzo } 914308d8b8cSRahul Kayaith if (width == 8) { 915308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt8Value(*this, pos); 916308d8b8cSRahul Kayaith } 917308d8b8cSRahul Kayaith if (width == 16) { 918308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt16Value(*this, pos); 919308d8b8cSRahul Kayaith } 920436c6c9cSStella Laurenzo if (width == 32) { 921436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt32Value(*this, pos); 922436c6c9cSStella Laurenzo } 923436c6c9cSStella Laurenzo if (width == 64) { 924436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt64Value(*this, pos); 925436c6c9cSStella Laurenzo } 926436c6c9cSStella Laurenzo } 9274811270bSmax throw py::type_error("Unsupported integer type"); 928436c6c9cSStella Laurenzo } 929436c6c9cSStella Laurenzo 930436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 931436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 932436c6c9cSStella Laurenzo } 933436c6c9cSStella Laurenzo }; 934436c6c9cSStella Laurenzo 935436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 936436c6c9cSStella Laurenzo public: 937436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 938436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 939436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 940*9566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 941*9566ee28Smax mlirDictionaryAttrGetTypeID; 942436c6c9cSStella Laurenzo 943436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 944436c6c9cSStella Laurenzo 9459fb1086bSAdrian Kuegel bool dunderContains(const std::string &name) { 9469fb1086bSAdrian Kuegel return !mlirAttributeIsNull( 9479fb1086bSAdrian Kuegel mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 9489fb1086bSAdrian Kuegel } 9499fb1086bSAdrian Kuegel 950436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 9519fb1086bSAdrian Kuegel c.def("__contains__", &PyDictAttribute::dunderContains); 952436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 953436c6c9cSStella Laurenzo c.def_static( 954436c6c9cSStella Laurenzo "get", 955436c6c9cSStella Laurenzo [](py::dict attributes, DefaultingPyMlirContext context) { 956436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 957436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 958436c6c9cSStella Laurenzo for (auto &it : attributes) { 95902b6fb21SMehdi Amini auto &mlirAttr = it.second.cast<PyAttribute &>(); 960436c6c9cSStella Laurenzo auto name = it.first.cast<std::string>(); 961436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 96202b6fb21SMehdi Amini mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 963436c6c9cSStella Laurenzo toMlirStringRef(name)), 96402b6fb21SMehdi Amini mlirAttr)); 965436c6c9cSStella Laurenzo } 966436c6c9cSStella Laurenzo MlirAttribute attr = 967436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 968436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 969436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 970436c6c9cSStella Laurenzo }, 971ed9e52f3SAlex Zinenko py::arg("value") = py::dict(), py::arg("context") = py::none(), 972436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 973436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 974436c6c9cSStella Laurenzo MlirAttribute attr = 975436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 976436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 9774811270bSmax throw py::key_error("attempt to access a non-existent attribute"); 978436c6c9cSStella Laurenzo } 979436c6c9cSStella Laurenzo return PyAttribute(self.getContext(), attr); 980436c6c9cSStella Laurenzo }); 981436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 982436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 9834811270bSmax throw py::index_error("attempt to access out of bounds attribute"); 984436c6c9cSStella Laurenzo } 985436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 986436c6c9cSStella Laurenzo return PyNamedAttribute( 987436c6c9cSStella Laurenzo namedAttr.attribute, 988436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 989436c6c9cSStella Laurenzo }); 990436c6c9cSStella Laurenzo } 991436c6c9cSStella Laurenzo }; 992436c6c9cSStella Laurenzo 993436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 994436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 995436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 996436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 997436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 998436c6c9cSStella Laurenzo public: 999436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 1000436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 1001436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1002436c6c9cSStella Laurenzo 1003436c6c9cSStella Laurenzo py::float_ dunderGetItem(intptr_t pos) { 1004436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 10054811270bSmax throw py::index_error("attempt to access out of bounds element"); 1006436c6c9cSStella Laurenzo } 1007436c6c9cSStella Laurenzo 1008436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 1009436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 1010436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 1011436c6c9cSStella Laurenzo // elemental type of the attribute. py::float_ is implicitly constructible 1012436c6c9cSStella Laurenzo // from float and double. 1013436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 1014436c6c9cSStella Laurenzo // querying them on each element access. 1015436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 1016436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetFloatValue(*this, pos); 1017436c6c9cSStella Laurenzo } 1018436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 1019436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetDoubleValue(*this, pos); 1020436c6c9cSStella Laurenzo } 10214811270bSmax throw py::type_error("Unsupported floating-point type"); 1022436c6c9cSStella Laurenzo } 1023436c6c9cSStella Laurenzo 1024436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1025436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 1026436c6c9cSStella Laurenzo } 1027436c6c9cSStella Laurenzo }; 1028436c6c9cSStella Laurenzo 1029436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 1030436c6c9cSStella Laurenzo public: 1031436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 1032436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 1033436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1034*9566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 1035*9566ee28Smax mlirTypeAttrGetTypeID; 1036436c6c9cSStella Laurenzo 1037436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1038436c6c9cSStella Laurenzo c.def_static( 1039436c6c9cSStella Laurenzo "get", 1040436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 1041436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 1042436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 1043436c6c9cSStella Laurenzo }, 1044436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 1045436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 1046436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyTypeAttribute &self) { 1047bfb1ba75Smax return mlirTypeAttrGetValue(self.get()); 1048436c6c9cSStella Laurenzo }); 1049436c6c9cSStella Laurenzo } 1050436c6c9cSStella Laurenzo }; 1051436c6c9cSStella Laurenzo 1052436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 1053436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 1054436c6c9cSStella Laurenzo public: 1055436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 1056436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 1057436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1058*9566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 1059*9566ee28Smax mlirUnitAttrGetTypeID; 1060436c6c9cSStella Laurenzo 1061436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1062436c6c9cSStella Laurenzo c.def_static( 1063436c6c9cSStella Laurenzo "get", 1064436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 1065436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 1066436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 1067436c6c9cSStella Laurenzo }, 1068436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Create a Unit attribute."); 1069436c6c9cSStella Laurenzo } 1070436c6c9cSStella Laurenzo }; 1071436c6c9cSStella Laurenzo 1072ac2e2d65SDenys Shabalin /// Strided layout attribute subclass. 1073ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute 1074ac2e2d65SDenys Shabalin : public PyConcreteAttribute<PyStridedLayoutAttribute> { 1075ac2e2d65SDenys Shabalin public: 1076ac2e2d65SDenys Shabalin static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; 1077ac2e2d65SDenys Shabalin static constexpr const char *pyClassName = "StridedLayoutAttr"; 1078ac2e2d65SDenys Shabalin using PyConcreteAttribute::PyConcreteAttribute; 1079*9566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 1080*9566ee28Smax mlirStridedLayoutAttrGetTypeID; 1081ac2e2d65SDenys Shabalin 1082ac2e2d65SDenys Shabalin static void bindDerived(ClassTy &c) { 1083ac2e2d65SDenys Shabalin c.def_static( 1084ac2e2d65SDenys Shabalin "get", 1085ac2e2d65SDenys Shabalin [](int64_t offset, const std::vector<int64_t> strides, 1086ac2e2d65SDenys Shabalin DefaultingPyMlirContext ctx) { 1087ac2e2d65SDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1088ac2e2d65SDenys Shabalin ctx->get(), offset, strides.size(), strides.data()); 1089ac2e2d65SDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1090ac2e2d65SDenys Shabalin }, 1091ac2e2d65SDenys Shabalin py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), 1092ac2e2d65SDenys Shabalin "Gets a strided layout attribute."); 1093e3fd612eSDenys Shabalin c.def_static( 1094e3fd612eSDenys Shabalin "get_fully_dynamic", 1095e3fd612eSDenys Shabalin [](int64_t rank, DefaultingPyMlirContext ctx) { 1096e3fd612eSDenys Shabalin auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); 1097e3fd612eSDenys Shabalin std::vector<int64_t> strides(rank); 1098e3fd612eSDenys Shabalin std::fill(strides.begin(), strides.end(), dynamic); 1099e3fd612eSDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1100e3fd612eSDenys Shabalin ctx->get(), dynamic, strides.size(), strides.data()); 1101e3fd612eSDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1102e3fd612eSDenys Shabalin }, 1103e3fd612eSDenys Shabalin py::arg("rank"), py::arg("context") = py::none(), 1104e3fd612eSDenys Shabalin "Gets a strided layout attribute with dynamic offset and strides of a " 1105e3fd612eSDenys Shabalin "given rank."); 1106ac2e2d65SDenys Shabalin c.def_property_readonly( 1107ac2e2d65SDenys Shabalin "offset", 1108ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1109ac2e2d65SDenys Shabalin return mlirStridedLayoutAttrGetOffset(self); 1110ac2e2d65SDenys Shabalin }, 1111ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1112ac2e2d65SDenys Shabalin c.def_property_readonly( 1113ac2e2d65SDenys Shabalin "strides", 1114ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1115ac2e2d65SDenys Shabalin intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); 1116ac2e2d65SDenys Shabalin std::vector<int64_t> strides(size); 1117ac2e2d65SDenys Shabalin for (intptr_t i = 0; i < size; i++) { 1118ac2e2d65SDenys Shabalin strides[i] = mlirStridedLayoutAttrGetStride(self, i); 1119ac2e2d65SDenys Shabalin } 1120ac2e2d65SDenys Shabalin return strides; 1121ac2e2d65SDenys Shabalin }, 1122ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1123ac2e2d65SDenys Shabalin } 1124ac2e2d65SDenys Shabalin }; 1125ac2e2d65SDenys Shabalin 1126*9566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { 1127*9566ee28Smax if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) 1128*9566ee28Smax return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); 1129*9566ee28Smax if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) 1130*9566ee28Smax return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); 1131*9566ee28Smax if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) 1132*9566ee28Smax return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); 1133*9566ee28Smax if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) 1134*9566ee28Smax return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); 1135*9566ee28Smax if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) 1136*9566ee28Smax return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); 1137*9566ee28Smax if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) 1138*9566ee28Smax return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); 1139*9566ee28Smax if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) 1140*9566ee28Smax return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); 1141*9566ee28Smax std::string msg = 1142*9566ee28Smax std::string("Can't cast unknown element type DenseArrayAttr (") + 1143*9566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 1144*9566ee28Smax throw py::cast_error(msg); 1145*9566ee28Smax } 1146*9566ee28Smax 1147*9566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { 1148*9566ee28Smax if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) 1149*9566ee28Smax return py::cast(PyDenseFPElementsAttribute(pyAttribute)); 1150*9566ee28Smax if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) 1151*9566ee28Smax return py::cast(PyDenseIntElementsAttribute(pyAttribute)); 1152*9566ee28Smax std::string msg = 1153*9566ee28Smax std::string( 1154*9566ee28Smax "Can't cast unknown element type DenseIntOrFPElementsAttr (") + 1155*9566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 1156*9566ee28Smax throw py::cast_error(msg); 1157*9566ee28Smax } 1158*9566ee28Smax 1159*9566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { 1160*9566ee28Smax if (PyBoolAttribute::isaFunction(pyAttribute)) 1161*9566ee28Smax return py::cast(PyBoolAttribute(pyAttribute)); 1162*9566ee28Smax if (PyIntegerAttribute::isaFunction(pyAttribute)) 1163*9566ee28Smax return py::cast(PyIntegerAttribute(pyAttribute)); 1164*9566ee28Smax std::string msg = 1165*9566ee28Smax std::string("Can't cast unknown element type DenseArrayAttr (") + 1166*9566ee28Smax std::string(py::repr(py::cast(pyAttribute))) + ")"; 1167*9566ee28Smax throw py::cast_error(msg); 1168*9566ee28Smax } 1169*9566ee28Smax 1170436c6c9cSStella Laurenzo } // namespace 1171436c6c9cSStella Laurenzo 1172436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) { 1173436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 1174619fd8c2SJeff Niu 1175619fd8c2SJeff Niu PyDenseBoolArrayAttribute::bind(m); 1176619fd8c2SJeff Niu PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); 1177619fd8c2SJeff Niu PyDenseI8ArrayAttribute::bind(m); 1178619fd8c2SJeff Niu PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); 1179619fd8c2SJeff Niu PyDenseI16ArrayAttribute::bind(m); 1180619fd8c2SJeff Niu PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); 1181619fd8c2SJeff Niu PyDenseI32ArrayAttribute::bind(m); 1182619fd8c2SJeff Niu PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); 1183619fd8c2SJeff Niu PyDenseI64ArrayAttribute::bind(m); 1184619fd8c2SJeff Niu PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); 1185619fd8c2SJeff Niu PyDenseF32ArrayAttribute::bind(m); 1186619fd8c2SJeff Niu PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); 1187619fd8c2SJeff Niu PyDenseF64ArrayAttribute::bind(m); 1188619fd8c2SJeff Niu PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); 1189*9566ee28Smax PyGlobals::get().registerTypeCaster( 1190*9566ee28Smax mlirDenseArrayAttrGetTypeID(), 1191*9566ee28Smax pybind11::cpp_function(denseArrayAttributeCaster)); 1192619fd8c2SJeff Niu 1193436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 1194436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 1195436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 1196436c6c9cSStella Laurenzo PyDenseElementsAttribute::bind(m); 1197436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 1198436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 1199*9566ee28Smax PyGlobals::get().registerTypeCaster( 1200*9566ee28Smax mlirDenseIntOrFPElementsAttrGetTypeID(), 1201*9566ee28Smax pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); 1202*9566ee28Smax 1203436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 1204436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 12055c3861b2SYun Long PyOpaqueAttribute::bind(m); 1206436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 1207436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 1208436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 1209436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 1210*9566ee28Smax PyGlobals::get().registerTypeCaster( 1211*9566ee28Smax mlirIntegerAttrGetTypeID(), 1212*9566ee28Smax pybind11::cpp_function(integerOrBoolAttributeCaster)); 1213436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 1214ac2e2d65SDenys Shabalin 1215ac2e2d65SDenys Shabalin PyStridedLayoutAttribute::bind(m); 1216436c6c9cSStella Laurenzo } 1217