1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9436c6c9cSStella Laurenzo #include "IRModule.h" 10436c6c9cSStella Laurenzo 11436c6c9cSStella Laurenzo #include "PybindUtils.h" 12436c6c9cSStella Laurenzo 13436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 14436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 15436c6c9cSStella Laurenzo 16436c6c9cSStella Laurenzo namespace py = pybind11; 17436c6c9cSStella Laurenzo using namespace mlir; 18436c6c9cSStella Laurenzo using namespace mlir::python; 19436c6c9cSStella Laurenzo 20436c6c9cSStella Laurenzo using llvm::SmallVector; 21436c6c9cSStella Laurenzo using llvm::StringRef; 22436c6c9cSStella Laurenzo using llvm::Twine; 23436c6c9cSStella Laurenzo 24436c6c9cSStella Laurenzo namespace { 25436c6c9cSStella Laurenzo 26436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 27436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 28436c6c9cSStella Laurenzo } 29436c6c9cSStella Laurenzo 30436c6c9cSStella Laurenzo /// CRTP base classes for Python attributes that subclass Attribute and should 31436c6c9cSStella Laurenzo /// be castable from it (i.e. via something like StringAttr(attr)). 32436c6c9cSStella Laurenzo /// By default, attribute class hierarchies are one level deep (i.e. a 33436c6c9cSStella Laurenzo /// concrete attribute class extends PyAttribute); however, intermediate 34436c6c9cSStella Laurenzo /// python-visible base classes can be modeled by specifying a BaseTy. 35436c6c9cSStella Laurenzo template <typename DerivedTy, typename BaseTy = PyAttribute> 36436c6c9cSStella Laurenzo class PyConcreteAttribute : public BaseTy { 37436c6c9cSStella Laurenzo public: 38436c6c9cSStella Laurenzo // Derived classes must define statics for: 39436c6c9cSStella Laurenzo // IsAFunctionTy isaFunction 40436c6c9cSStella Laurenzo // const char *pyClassName 41436c6c9cSStella Laurenzo using ClassTy = py::class_<DerivedTy, BaseTy>; 42436c6c9cSStella Laurenzo using IsAFunctionTy = bool (*)(MlirAttribute); 43436c6c9cSStella Laurenzo 44436c6c9cSStella Laurenzo PyConcreteAttribute() = default; 45436c6c9cSStella Laurenzo PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 46436c6c9cSStella Laurenzo : BaseTy(std::move(contextRef), attr) {} 47436c6c9cSStella Laurenzo PyConcreteAttribute(PyAttribute &orig) 48436c6c9cSStella Laurenzo : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 49436c6c9cSStella Laurenzo 50436c6c9cSStella Laurenzo static MlirAttribute castFrom(PyAttribute &orig) { 51436c6c9cSStella Laurenzo if (!DerivedTy::isaFunction(orig)) { 52436c6c9cSStella Laurenzo auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 53436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") + 54436c6c9cSStella Laurenzo DerivedTy::pyClassName + 55436c6c9cSStella Laurenzo " (from " + origRepr + ")"); 56436c6c9cSStella Laurenzo } 57436c6c9cSStella Laurenzo return orig; 58436c6c9cSStella Laurenzo } 59436c6c9cSStella Laurenzo 60436c6c9cSStella Laurenzo static void bind(py::module &m) { 61436c6c9cSStella Laurenzo auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol()); 62436c6c9cSStella Laurenzo cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>()); 63436c6c9cSStella Laurenzo DerivedTy::bindDerived(cls); 64436c6c9cSStella Laurenzo } 65436c6c9cSStella Laurenzo 66436c6c9cSStella Laurenzo /// Implemented by derived classes to add methods to the Python subclass. 67436c6c9cSStella Laurenzo static void bindDerived(ClassTy &m) {} 68436c6c9cSStella Laurenzo }; 69436c6c9cSStella Laurenzo 70436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 71436c6c9cSStella Laurenzo public: 72436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 73436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 74436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 75436c6c9cSStella Laurenzo 76436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 77436c6c9cSStella Laurenzo c.def_static( 78436c6c9cSStella Laurenzo "get", 79436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 80436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 81436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 82436c6c9cSStella Laurenzo }, 83436c6c9cSStella Laurenzo py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 84436c6c9cSStella Laurenzo } 85436c6c9cSStella Laurenzo }; 86436c6c9cSStella Laurenzo 87436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 88436c6c9cSStella Laurenzo public: 89436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 90436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 91436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 92436c6c9cSStella Laurenzo 93436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 94436c6c9cSStella Laurenzo public: 95436c6c9cSStella Laurenzo PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} 96436c6c9cSStella Laurenzo 97436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 98436c6c9cSStella Laurenzo 99436c6c9cSStella Laurenzo PyAttribute dunderNext() { 100436c6c9cSStella Laurenzo if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 101436c6c9cSStella Laurenzo throw py::stop_iteration(); 102436c6c9cSStella Laurenzo } 103436c6c9cSStella Laurenzo return PyAttribute(attr.getContext(), 104436c6c9cSStella Laurenzo mlirArrayAttrGetElement(attr.get(), nextIndex++)); 105436c6c9cSStella Laurenzo } 106436c6c9cSStella Laurenzo 107436c6c9cSStella Laurenzo static void bind(py::module &m) { 108436c6c9cSStella Laurenzo py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator") 109436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 110436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 111436c6c9cSStella Laurenzo } 112436c6c9cSStella Laurenzo 113436c6c9cSStella Laurenzo private: 114436c6c9cSStella Laurenzo PyAttribute attr; 115436c6c9cSStella Laurenzo int nextIndex = 0; 116436c6c9cSStella Laurenzo }; 117436c6c9cSStella Laurenzo 118436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 119436c6c9cSStella Laurenzo c.def_static( 120436c6c9cSStella Laurenzo "get", 121436c6c9cSStella Laurenzo [](py::list attributes, DefaultingPyMlirContext context) { 122436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 123436c6c9cSStella Laurenzo mlirAttributes.reserve(py::len(attributes)); 124436c6c9cSStella Laurenzo for (auto attribute : attributes) { 125436c6c9cSStella Laurenzo try { 126436c6c9cSStella Laurenzo mlirAttributes.push_back(attribute.cast<PyAttribute>()); 127436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 128436c6c9cSStella Laurenzo std::string msg = std::string("Invalid attribute when attempting " 129436c6c9cSStella Laurenzo "to create an ArrayAttribute (") + 130436c6c9cSStella Laurenzo err.what() + ")"; 131436c6c9cSStella Laurenzo throw py::cast_error(msg); 132436c6c9cSStella Laurenzo } catch (py::reference_cast_error &err) { 133436c6c9cSStella Laurenzo // This exception seems thrown when the value is "None". 134436c6c9cSStella Laurenzo std::string msg = 135436c6c9cSStella Laurenzo std::string("Invalid attribute (None?) when attempting to " 136436c6c9cSStella Laurenzo "create an ArrayAttribute (") + 137436c6c9cSStella Laurenzo err.what() + ")"; 138436c6c9cSStella Laurenzo throw py::cast_error(msg); 139436c6c9cSStella Laurenzo } 140436c6c9cSStella Laurenzo } 141436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 142436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 143436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 144436c6c9cSStella Laurenzo }, 145436c6c9cSStella Laurenzo py::arg("attributes"), py::arg("context") = py::none(), 146436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 147436c6c9cSStella Laurenzo c.def("__getitem__", 148436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 149436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 150436c6c9cSStella Laurenzo throw py::index_error("ArrayAttribute index out of range"); 151436c6c9cSStella Laurenzo return PyAttribute(arr.getContext(), 152436c6c9cSStella Laurenzo mlirArrayAttrGetElement(arr, i)); 153436c6c9cSStella Laurenzo }) 154436c6c9cSStella Laurenzo .def("__len__", 155436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 156436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 157436c6c9cSStella Laurenzo }) 158436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 159436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 160436c6c9cSStella Laurenzo }); 161436c6c9cSStella Laurenzo } 162436c6c9cSStella Laurenzo }; 163436c6c9cSStella Laurenzo 164436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 165436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 166436c6c9cSStella Laurenzo public: 167436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 168436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 169436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 170436c6c9cSStella Laurenzo 171436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 172436c6c9cSStella Laurenzo c.def_static( 173436c6c9cSStella Laurenzo "get", 174436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 175436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 176436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 177436c6c9cSStella Laurenzo // in C API. 178436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 179436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 180436c6c9cSStella Laurenzo Twine("invalid '") + 181436c6c9cSStella Laurenzo py::repr(py::cast(type)).cast<std::string>() + 182436c6c9cSStella Laurenzo "' and expected floating point type."); 183436c6c9cSStella Laurenzo } 184436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 185436c6c9cSStella Laurenzo }, 186436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 187436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 188436c6c9cSStella Laurenzo c.def_static( 189436c6c9cSStella Laurenzo "get_f32", 190436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 191436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 192436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 193436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 194436c6c9cSStella Laurenzo }, 195436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 196436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 197436c6c9cSStella Laurenzo c.def_static( 198436c6c9cSStella Laurenzo "get_f64", 199436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 200436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 201436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 202436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 203436c6c9cSStella Laurenzo }, 204436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 205436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 206436c6c9cSStella Laurenzo c.def_property_readonly( 207436c6c9cSStella Laurenzo "value", 208436c6c9cSStella Laurenzo [](PyFloatAttribute &self) { 209436c6c9cSStella Laurenzo return mlirFloatAttrGetValueDouble(self); 210436c6c9cSStella Laurenzo }, 211436c6c9cSStella Laurenzo "Returns the value of the float point attribute"); 212436c6c9cSStella Laurenzo } 213436c6c9cSStella Laurenzo }; 214436c6c9cSStella Laurenzo 215436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 216436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 217436c6c9cSStella Laurenzo public: 218436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 219436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 220436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 221436c6c9cSStella Laurenzo 222436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 223436c6c9cSStella Laurenzo c.def_static( 224436c6c9cSStella Laurenzo "get", 225436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 226436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 227436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 228436c6c9cSStella Laurenzo }, 229436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), 230436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 231436c6c9cSStella Laurenzo c.def_property_readonly( 232436c6c9cSStella Laurenzo "value", 233436c6c9cSStella Laurenzo [](PyIntegerAttribute &self) { 234436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 235436c6c9cSStella Laurenzo }, 236436c6c9cSStella Laurenzo "Returns the value of the integer attribute"); 237436c6c9cSStella Laurenzo } 238436c6c9cSStella Laurenzo }; 239436c6c9cSStella Laurenzo 240436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 241436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 242436c6c9cSStella Laurenzo public: 243436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 244436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 245436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 246436c6c9cSStella Laurenzo 247436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 248436c6c9cSStella Laurenzo c.def_static( 249436c6c9cSStella Laurenzo "get", 250436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 251436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 252436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 253436c6c9cSStella Laurenzo }, 254436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 255436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 256436c6c9cSStella Laurenzo c.def_property_readonly( 257436c6c9cSStella Laurenzo "value", 258436c6c9cSStella Laurenzo [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 259436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 260436c6c9cSStella Laurenzo } 261436c6c9cSStella Laurenzo }; 262436c6c9cSStella Laurenzo 263436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 264436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 265436c6c9cSStella Laurenzo public: 266436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 267436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 268436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 269436c6c9cSStella Laurenzo 270436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 271436c6c9cSStella Laurenzo c.def_static( 272436c6c9cSStella Laurenzo "get", 273436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 274436c6c9cSStella Laurenzo MlirAttribute attr = 275436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 276436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 277436c6c9cSStella Laurenzo }, 278436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 279436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 280436c6c9cSStella Laurenzo c.def_property_readonly( 281436c6c9cSStella Laurenzo "value", 282436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 283436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 284436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 285436c6c9cSStella Laurenzo }, 286436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 287436c6c9cSStella Laurenzo } 288436c6c9cSStella Laurenzo }; 289436c6c9cSStella Laurenzo 290436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 291436c6c9cSStella Laurenzo public: 292436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 293436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 294436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 295436c6c9cSStella Laurenzo 296436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 297436c6c9cSStella Laurenzo c.def_static( 298436c6c9cSStella Laurenzo "get", 299436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 300436c6c9cSStella Laurenzo MlirAttribute attr = 301436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 302436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 303436c6c9cSStella Laurenzo }, 304436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 305436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 306436c6c9cSStella Laurenzo c.def_static( 307436c6c9cSStella Laurenzo "get_typed", 308436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 309436c6c9cSStella Laurenzo MlirAttribute attr = 310436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 311436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 312436c6c9cSStella Laurenzo }, 313436c6c9cSStella Laurenzo 314436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 315436c6c9cSStella Laurenzo c.def_property_readonly( 316436c6c9cSStella Laurenzo "value", 317436c6c9cSStella Laurenzo [](PyStringAttribute &self) { 318436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirStringAttrGetValue(self); 319436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 320436c6c9cSStella Laurenzo }, 321436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 322436c6c9cSStella Laurenzo } 323436c6c9cSStella Laurenzo }; 324436c6c9cSStella Laurenzo 325436c6c9cSStella Laurenzo // TODO: Support construction of bool elements. 326436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 327436c6c9cSStella Laurenzo class PyDenseElementsAttribute 328436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 329436c6c9cSStella Laurenzo public: 330436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 331436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 332436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 333436c6c9cSStella Laurenzo 334436c6c9cSStella Laurenzo static PyDenseElementsAttribute 335436c6c9cSStella Laurenzo getFromBuffer(py::buffer array, bool signless, 336436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 337436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 338436c6c9cSStella Laurenzo int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 339436c6c9cSStella Laurenzo Py_buffer *view = new Py_buffer(); 340436c6c9cSStella Laurenzo if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 341436c6c9cSStella Laurenzo delete view; 342436c6c9cSStella Laurenzo throw py::error_already_set(); 343436c6c9cSStella Laurenzo } 344436c6c9cSStella Laurenzo py::buffer_info arrayInfo(view); 345436c6c9cSStella Laurenzo 346436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 347436c6c9cSStella Laurenzo // Switch on the types that can be bulk loaded between the Python and 348436c6c9cSStella Laurenzo // MLIR-C APIs. 349436c6c9cSStella Laurenzo // See: https://docs.python.org/3/library/struct.html#format-characters 350436c6c9cSStella Laurenzo if (arrayInfo.format == "f") { 351436c6c9cSStella Laurenzo // f32 352436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 353436c6c9cSStella Laurenzo return PyDenseElementsAttribute( 354436c6c9cSStella Laurenzo contextWrapper->getRef(), 355436c6c9cSStella Laurenzo bulkLoad(context, mlirDenseElementsAttrFloatGet, 356436c6c9cSStella Laurenzo mlirF32TypeGet(context), arrayInfo)); 357436c6c9cSStella Laurenzo } else if (arrayInfo.format == "d") { 358436c6c9cSStella Laurenzo // f64 359436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 360436c6c9cSStella Laurenzo return PyDenseElementsAttribute( 361436c6c9cSStella Laurenzo contextWrapper->getRef(), 362436c6c9cSStella Laurenzo bulkLoad(context, mlirDenseElementsAttrDoubleGet, 363436c6c9cSStella Laurenzo mlirF64TypeGet(context), arrayInfo)); 364436c6c9cSStella Laurenzo } else if (isSignedIntegerFormat(arrayInfo.format)) { 365436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 366436c6c9cSStella Laurenzo // i32 367436c6c9cSStella Laurenzo MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) 368436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 32); 369436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), 370436c6c9cSStella Laurenzo bulkLoad(context, 371436c6c9cSStella Laurenzo mlirDenseElementsAttrInt32Get, 372436c6c9cSStella Laurenzo elementType, arrayInfo)); 373436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 374436c6c9cSStella Laurenzo // i64 375436c6c9cSStella Laurenzo MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) 376436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 64); 377436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), 378436c6c9cSStella Laurenzo bulkLoad(context, 379436c6c9cSStella Laurenzo mlirDenseElementsAttrInt64Get, 380436c6c9cSStella Laurenzo elementType, arrayInfo)); 381436c6c9cSStella Laurenzo } 382436c6c9cSStella Laurenzo } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 383436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 384436c6c9cSStella Laurenzo // unsigned i32 385436c6c9cSStella Laurenzo MlirType elementType = signless 386436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 32) 387436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 32); 388436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), 389436c6c9cSStella Laurenzo bulkLoad(context, 390436c6c9cSStella Laurenzo mlirDenseElementsAttrUInt32Get, 391436c6c9cSStella Laurenzo elementType, arrayInfo)); 392436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 393436c6c9cSStella Laurenzo // unsigned i64 394436c6c9cSStella Laurenzo MlirType elementType = signless 395436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 64) 396436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 64); 397436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), 398436c6c9cSStella Laurenzo bulkLoad(context, 399436c6c9cSStella Laurenzo mlirDenseElementsAttrUInt64Get, 400436c6c9cSStella Laurenzo elementType, arrayInfo)); 401436c6c9cSStella Laurenzo } 402436c6c9cSStella Laurenzo } 403436c6c9cSStella Laurenzo 404436c6c9cSStella Laurenzo // TODO: Fall back to string-based get. 405436c6c9cSStella Laurenzo std::string message = "unimplemented array format conversion from format: "; 406436c6c9cSStella Laurenzo message.append(arrayInfo.format); 407436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 408436c6c9cSStella Laurenzo } 409436c6c9cSStella Laurenzo 410436c6c9cSStella Laurenzo static PyDenseElementsAttribute getSplat(PyType shapedType, 411436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 412436c6c9cSStella Laurenzo auto contextWrapper = 413436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 414436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 415436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 416436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 417436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 418436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 419436c6c9cSStella Laurenzo } 420436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 421436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 422436c6c9cSStella Laurenzo std::string message = 423436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 424436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 425436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 426436c6c9cSStella Laurenzo } 427436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 428436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 429436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 430436c6c9cSStella Laurenzo std::string message = 431436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 432436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 433436c6c9cSStella Laurenzo message.append(", element="); 434436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 435436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 436436c6c9cSStella Laurenzo } 437436c6c9cSStella Laurenzo 438436c6c9cSStella Laurenzo MlirAttribute elements = 439436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 440436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 441436c6c9cSStella Laurenzo } 442436c6c9cSStella Laurenzo 443436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 444436c6c9cSStella Laurenzo 445436c6c9cSStella Laurenzo py::buffer_info accessBuffer() { 446436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 447436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 448436c6c9cSStella Laurenzo 449436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 450436c6c9cSStella Laurenzo // f32 451436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); 452436c6c9cSStella Laurenzo } else if (mlirTypeIsAF64(elementType)) { 453436c6c9cSStella Laurenzo // f64 454436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); 455436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 456436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 457436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 458436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 459436c6c9cSStella Laurenzo // i32 460436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); 461436c6c9cSStella Laurenzo } else if (mlirIntegerTypeIsUnsigned(elementType)) { 462436c6c9cSStella Laurenzo // unsigned i32 463436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); 464436c6c9cSStella Laurenzo } 465436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 466436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 467436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 468436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 469436c6c9cSStella Laurenzo // i64 470436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value); 471436c6c9cSStella Laurenzo } else if (mlirIntegerTypeIsUnsigned(elementType)) { 472436c6c9cSStella Laurenzo // unsigned i64 473436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); 474436c6c9cSStella Laurenzo } 475436c6c9cSStella Laurenzo } 476436c6c9cSStella Laurenzo 477436c6c9cSStella Laurenzo std::string message = "unimplemented array format."; 478436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 479436c6c9cSStella Laurenzo } 480436c6c9cSStella Laurenzo 481436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 482436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 483436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 484436c6c9cSStella Laurenzo py::arg("array"), py::arg("signless") = true, 485436c6c9cSStella Laurenzo py::arg("context") = py::none(), 486436c6c9cSStella Laurenzo "Gets from a buffer or ndarray") 487436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 488436c6c9cSStella Laurenzo py::arg("shaped_type"), py::arg("element_attr"), 489436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 490436c6c9cSStella Laurenzo .def_property_readonly("is_splat", 491436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 492436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 493436c6c9cSStella Laurenzo }) 494436c6c9cSStella Laurenzo .def_buffer(&PyDenseElementsAttribute::accessBuffer); 495436c6c9cSStella Laurenzo } 496436c6c9cSStella Laurenzo 497436c6c9cSStella Laurenzo private: 498436c6c9cSStella Laurenzo template <typename ElementTy> 499436c6c9cSStella Laurenzo static MlirAttribute 500436c6c9cSStella Laurenzo bulkLoad(MlirContext context, 501436c6c9cSStella Laurenzo MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), 502436c6c9cSStella Laurenzo MlirType mlirElementType, py::buffer_info &arrayInfo) { 503436c6c9cSStella Laurenzo SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(), 504436c6c9cSStella Laurenzo arrayInfo.shape.begin() + arrayInfo.ndim); 505*7714b405SAart Bik MlirAttribute encodingAttr = mlirAttributeGetNull(); 506*7714b405SAart Bik auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), 507*7714b405SAart Bik mlirElementType, encodingAttr); 508436c6c9cSStella Laurenzo intptr_t numElements = arrayInfo.size; 509436c6c9cSStella Laurenzo const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr); 510436c6c9cSStella Laurenzo return ctor(shapedType, numElements, contents); 511436c6c9cSStella Laurenzo } 512436c6c9cSStella Laurenzo 513436c6c9cSStella Laurenzo static bool isUnsignedIntegerFormat(const std::string &format) { 514436c6c9cSStella Laurenzo if (format.empty()) 515436c6c9cSStella Laurenzo return false; 516436c6c9cSStella Laurenzo char code = format[0]; 517436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 518436c6c9cSStella Laurenzo code == 'Q'; 519436c6c9cSStella Laurenzo } 520436c6c9cSStella Laurenzo 521436c6c9cSStella Laurenzo static bool isSignedIntegerFormat(const std::string &format) { 522436c6c9cSStella Laurenzo if (format.empty()) 523436c6c9cSStella Laurenzo return false; 524436c6c9cSStella Laurenzo char code = format[0]; 525436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 526436c6c9cSStella Laurenzo code == 'q'; 527436c6c9cSStella Laurenzo } 528436c6c9cSStella Laurenzo 529436c6c9cSStella Laurenzo template <typename Type> 530436c6c9cSStella Laurenzo py::buffer_info bufferInfo(MlirType shapedType, 531436c6c9cSStella Laurenzo Type (*value)(MlirAttribute, intptr_t)) { 532436c6c9cSStella Laurenzo intptr_t rank = mlirShapedTypeGetRank(shapedType); 533436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 534436c6c9cSStella Laurenzo // Buffer is configured for read-only access below. 535436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 536436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 537436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 538436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 539436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 540436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 541436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 542436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 543436c6c9cSStella Laurenzo intptr_t strideFactor = 1; 544436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 545436c6c9cSStella Laurenzo strideFactor = 1; 546436c6c9cSStella Laurenzo for (intptr_t j = i; j < rank; ++j) { 547436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 548436c6c9cSStella Laurenzo } 549436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 550436c6c9cSStella Laurenzo } 551436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 552436c6c9cSStella Laurenzo return py::buffer_info(data, sizeof(Type), 553436c6c9cSStella Laurenzo py::format_descriptor<Type>::format(), rank, shape, 554436c6c9cSStella Laurenzo strides, /*readonly=*/true); 555436c6c9cSStella Laurenzo } 556436c6c9cSStella Laurenzo }; // namespace 557436c6c9cSStella Laurenzo 558436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer 559436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access. 560436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 561436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 562436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 563436c6c9cSStella Laurenzo public: 564436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 565436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 566436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 567436c6c9cSStella Laurenzo 568436c6c9cSStella Laurenzo /// Returns the element at the given linear position. Asserts if the index is 569436c6c9cSStella Laurenzo /// out of range. 570436c6c9cSStella Laurenzo py::int_ dunderGetItem(intptr_t pos) { 571436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 572436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 573436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 574436c6c9cSStella Laurenzo } 575436c6c9cSStella Laurenzo 576436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 577436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 578436c6c9cSStella Laurenzo assert(mlirTypeIsAInteger(type) && 579436c6c9cSStella Laurenzo "expected integer element type in dense int elements attribute"); 580436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 581436c6c9cSStella Laurenzo // elemental type of the attribute. py::int_ is implicitly constructible 582436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 583436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 584436c6c9cSStella Laurenzo // querying them on each element access. 585436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 586436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 587436c6c9cSStella Laurenzo if (isUnsigned) { 588436c6c9cSStella Laurenzo if (width == 1) { 589436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 590436c6c9cSStella Laurenzo } 591436c6c9cSStella Laurenzo if (width == 32) { 592436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt32Value(*this, pos); 593436c6c9cSStella Laurenzo } 594436c6c9cSStella Laurenzo if (width == 64) { 595436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt64Value(*this, pos); 596436c6c9cSStella Laurenzo } 597436c6c9cSStella Laurenzo } else { 598436c6c9cSStella Laurenzo if (width == 1) { 599436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 600436c6c9cSStella Laurenzo } 601436c6c9cSStella Laurenzo if (width == 32) { 602436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt32Value(*this, pos); 603436c6c9cSStella Laurenzo } 604436c6c9cSStella Laurenzo if (width == 64) { 605436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt64Value(*this, pos); 606436c6c9cSStella Laurenzo } 607436c6c9cSStella Laurenzo } 608436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 609436c6c9cSStella Laurenzo } 610436c6c9cSStella Laurenzo 611436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 612436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 613436c6c9cSStella Laurenzo } 614436c6c9cSStella Laurenzo }; 615436c6c9cSStella Laurenzo 616436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 617436c6c9cSStella Laurenzo public: 618436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 619436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 620436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 621436c6c9cSStella Laurenzo 622436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 623436c6c9cSStella Laurenzo 624436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 625436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 626436c6c9cSStella Laurenzo c.def_static( 627436c6c9cSStella Laurenzo "get", 628436c6c9cSStella Laurenzo [](py::dict attributes, DefaultingPyMlirContext context) { 629436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 630436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 631436c6c9cSStella Laurenzo for (auto &it : attributes) { 632436c6c9cSStella Laurenzo auto &mlir_attr = it.second.cast<PyAttribute &>(); 633436c6c9cSStella Laurenzo auto name = it.first.cast<std::string>(); 634436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 635436c6c9cSStella Laurenzo mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), 636436c6c9cSStella Laurenzo toMlirStringRef(name)), 637436c6c9cSStella Laurenzo mlir_attr)); 638436c6c9cSStella Laurenzo } 639436c6c9cSStella Laurenzo MlirAttribute attr = 640436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 641436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 642436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 643436c6c9cSStella Laurenzo }, 644436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 645436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 646436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 647436c6c9cSStella Laurenzo MlirAttribute attr = 648436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 649436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 650436c6c9cSStella Laurenzo throw SetPyError(PyExc_KeyError, 651436c6c9cSStella Laurenzo "attempt to access a non-existent attribute"); 652436c6c9cSStella Laurenzo } 653436c6c9cSStella Laurenzo return PyAttribute(self.getContext(), attr); 654436c6c9cSStella Laurenzo }); 655436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 656436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 657436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 658436c6c9cSStella Laurenzo "attempt to access out of bounds attribute"); 659436c6c9cSStella Laurenzo } 660436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 661436c6c9cSStella Laurenzo return PyNamedAttribute( 662436c6c9cSStella Laurenzo namedAttr.attribute, 663436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 664436c6c9cSStella Laurenzo }); 665436c6c9cSStella Laurenzo } 666436c6c9cSStella Laurenzo }; 667436c6c9cSStella Laurenzo 668436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 669436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 670436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 671436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 672436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 673436c6c9cSStella Laurenzo public: 674436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 675436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 676436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 677436c6c9cSStella Laurenzo 678436c6c9cSStella Laurenzo py::float_ dunderGetItem(intptr_t pos) { 679436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 680436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 681436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 682436c6c9cSStella Laurenzo } 683436c6c9cSStella Laurenzo 684436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 685436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 686436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 687436c6c9cSStella Laurenzo // elemental type of the attribute. py::float_ is implicitly constructible 688436c6c9cSStella Laurenzo // from float and double. 689436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 690436c6c9cSStella Laurenzo // querying them on each element access. 691436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 692436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetFloatValue(*this, pos); 693436c6c9cSStella Laurenzo } 694436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 695436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetDoubleValue(*this, pos); 696436c6c9cSStella Laurenzo } 697436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 698436c6c9cSStella Laurenzo } 699436c6c9cSStella Laurenzo 700436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 701436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 702436c6c9cSStella Laurenzo } 703436c6c9cSStella Laurenzo }; 704436c6c9cSStella Laurenzo 705436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 706436c6c9cSStella Laurenzo public: 707436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 708436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 709436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 710436c6c9cSStella Laurenzo 711436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 712436c6c9cSStella Laurenzo c.def_static( 713436c6c9cSStella Laurenzo "get", 714436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 715436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 716436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 717436c6c9cSStella Laurenzo }, 718436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 719436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 720436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyTypeAttribute &self) { 721436c6c9cSStella Laurenzo return PyType(self.getContext()->getRef(), 722436c6c9cSStella Laurenzo mlirTypeAttrGetValue(self.get())); 723436c6c9cSStella Laurenzo }); 724436c6c9cSStella Laurenzo } 725436c6c9cSStella Laurenzo }; 726436c6c9cSStella Laurenzo 727436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 728436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 729436c6c9cSStella Laurenzo public: 730436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 731436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 732436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 733436c6c9cSStella Laurenzo 734436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 735436c6c9cSStella Laurenzo c.def_static( 736436c6c9cSStella Laurenzo "get", 737436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 738436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 739436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 740436c6c9cSStella Laurenzo }, 741436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Create a Unit attribute."); 742436c6c9cSStella Laurenzo } 743436c6c9cSStella Laurenzo }; 744436c6c9cSStella Laurenzo 745436c6c9cSStella Laurenzo } // namespace 746436c6c9cSStella Laurenzo 747436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) { 748436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 749436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 750436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 751436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 752436c6c9cSStella Laurenzo PyDenseElementsAttribute::bind(m); 753436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 754436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 755436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 756436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 757436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 758436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 759436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 760436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 761436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 762436c6c9cSStella Laurenzo } 763