1*436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2*436c6c9cSStella Laurenzo // 3*436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5*436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*436c6c9cSStella Laurenzo // 7*436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8*436c6c9cSStella Laurenzo 9*436c6c9cSStella Laurenzo #include "IRModule.h" 10*436c6c9cSStella Laurenzo 11*436c6c9cSStella Laurenzo #include "PybindUtils.h" 12*436c6c9cSStella Laurenzo 13*436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 14*436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 15*436c6c9cSStella Laurenzo 16*436c6c9cSStella Laurenzo namespace py = pybind11; 17*436c6c9cSStella Laurenzo using namespace mlir; 18*436c6c9cSStella Laurenzo using namespace mlir::python; 19*436c6c9cSStella Laurenzo 20*436c6c9cSStella Laurenzo using llvm::SmallVector; 21*436c6c9cSStella Laurenzo using llvm::StringRef; 22*436c6c9cSStella Laurenzo using llvm::Twine; 23*436c6c9cSStella Laurenzo 24*436c6c9cSStella Laurenzo namespace { 25*436c6c9cSStella Laurenzo 26*436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 27*436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 28*436c6c9cSStella Laurenzo } 29*436c6c9cSStella Laurenzo 30*436c6c9cSStella Laurenzo /// CRTP base classes for Python attributes that subclass Attribute and should 31*436c6c9cSStella Laurenzo /// be castable from it (i.e. via something like StringAttr(attr)). 32*436c6c9cSStella Laurenzo /// By default, attribute class hierarchies are one level deep (i.e. a 33*436c6c9cSStella Laurenzo /// concrete attribute class extends PyAttribute); however, intermediate 34*436c6c9cSStella Laurenzo /// python-visible base classes can be modeled by specifying a BaseTy. 35*436c6c9cSStella Laurenzo template <typename DerivedTy, typename BaseTy = PyAttribute> 36*436c6c9cSStella Laurenzo class PyConcreteAttribute : public BaseTy { 37*436c6c9cSStella Laurenzo public: 38*436c6c9cSStella Laurenzo // Derived classes must define statics for: 39*436c6c9cSStella Laurenzo // IsAFunctionTy isaFunction 40*436c6c9cSStella Laurenzo // const char *pyClassName 41*436c6c9cSStella Laurenzo using ClassTy = py::class_<DerivedTy, BaseTy>; 42*436c6c9cSStella Laurenzo using IsAFunctionTy = bool (*)(MlirAttribute); 43*436c6c9cSStella Laurenzo 44*436c6c9cSStella Laurenzo PyConcreteAttribute() = default; 45*436c6c9cSStella Laurenzo PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 46*436c6c9cSStella Laurenzo : BaseTy(std::move(contextRef), attr) {} 47*436c6c9cSStella Laurenzo PyConcreteAttribute(PyAttribute &orig) 48*436c6c9cSStella Laurenzo : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 49*436c6c9cSStella Laurenzo 50*436c6c9cSStella Laurenzo static MlirAttribute castFrom(PyAttribute &orig) { 51*436c6c9cSStella Laurenzo if (!DerivedTy::isaFunction(orig)) { 52*436c6c9cSStella Laurenzo auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 53*436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") + 54*436c6c9cSStella Laurenzo DerivedTy::pyClassName + 55*436c6c9cSStella Laurenzo " (from " + origRepr + ")"); 56*436c6c9cSStella Laurenzo } 57*436c6c9cSStella Laurenzo return orig; 58*436c6c9cSStella Laurenzo } 59*436c6c9cSStella Laurenzo 60*436c6c9cSStella Laurenzo static void bind(py::module &m) { 61*436c6c9cSStella Laurenzo auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol()); 62*436c6c9cSStella Laurenzo cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>()); 63*436c6c9cSStella Laurenzo DerivedTy::bindDerived(cls); 64*436c6c9cSStella Laurenzo } 65*436c6c9cSStella Laurenzo 66*436c6c9cSStella Laurenzo /// Implemented by derived classes to add methods to the Python subclass. 67*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &m) {} 68*436c6c9cSStella Laurenzo }; 69*436c6c9cSStella Laurenzo 70*436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 71*436c6c9cSStella Laurenzo public: 72*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 73*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 74*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 75*436c6c9cSStella Laurenzo 76*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 77*436c6c9cSStella Laurenzo c.def_static( 78*436c6c9cSStella Laurenzo "get", 79*436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 80*436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 81*436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 82*436c6c9cSStella Laurenzo }, 83*436c6c9cSStella Laurenzo py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 84*436c6c9cSStella Laurenzo } 85*436c6c9cSStella Laurenzo }; 86*436c6c9cSStella Laurenzo 87*436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 88*436c6c9cSStella Laurenzo public: 89*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 90*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 91*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 92*436c6c9cSStella Laurenzo 93*436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 94*436c6c9cSStella Laurenzo public: 95*436c6c9cSStella Laurenzo PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} 96*436c6c9cSStella Laurenzo 97*436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 98*436c6c9cSStella Laurenzo 99*436c6c9cSStella Laurenzo PyAttribute dunderNext() { 100*436c6c9cSStella Laurenzo if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 101*436c6c9cSStella Laurenzo throw py::stop_iteration(); 102*436c6c9cSStella Laurenzo } 103*436c6c9cSStella Laurenzo return PyAttribute(attr.getContext(), 104*436c6c9cSStella Laurenzo mlirArrayAttrGetElement(attr.get(), nextIndex++)); 105*436c6c9cSStella Laurenzo } 106*436c6c9cSStella Laurenzo 107*436c6c9cSStella Laurenzo static void bind(py::module &m) { 108*436c6c9cSStella Laurenzo py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator") 109*436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 110*436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 111*436c6c9cSStella Laurenzo } 112*436c6c9cSStella Laurenzo 113*436c6c9cSStella Laurenzo private: 114*436c6c9cSStella Laurenzo PyAttribute attr; 115*436c6c9cSStella Laurenzo int nextIndex = 0; 116*436c6c9cSStella Laurenzo }; 117*436c6c9cSStella Laurenzo 118*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 119*436c6c9cSStella Laurenzo c.def_static( 120*436c6c9cSStella Laurenzo "get", 121*436c6c9cSStella Laurenzo [](py::list attributes, DefaultingPyMlirContext context) { 122*436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 123*436c6c9cSStella Laurenzo mlirAttributes.reserve(py::len(attributes)); 124*436c6c9cSStella Laurenzo for (auto attribute : attributes) { 125*436c6c9cSStella Laurenzo try { 126*436c6c9cSStella Laurenzo mlirAttributes.push_back(attribute.cast<PyAttribute>()); 127*436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 128*436c6c9cSStella Laurenzo std::string msg = std::string("Invalid attribute when attempting " 129*436c6c9cSStella Laurenzo "to create an ArrayAttribute (") + 130*436c6c9cSStella Laurenzo err.what() + ")"; 131*436c6c9cSStella Laurenzo throw py::cast_error(msg); 132*436c6c9cSStella Laurenzo } catch (py::reference_cast_error &err) { 133*436c6c9cSStella Laurenzo // This exception seems thrown when the value is "None". 134*436c6c9cSStella Laurenzo std::string msg = 135*436c6c9cSStella Laurenzo std::string("Invalid attribute (None?) when attempting to " 136*436c6c9cSStella Laurenzo "create an ArrayAttribute (") + 137*436c6c9cSStella Laurenzo err.what() + ")"; 138*436c6c9cSStella Laurenzo throw py::cast_error(msg); 139*436c6c9cSStella Laurenzo } 140*436c6c9cSStella Laurenzo } 141*436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 142*436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 143*436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 144*436c6c9cSStella Laurenzo }, 145*436c6c9cSStella Laurenzo py::arg("attributes"), py::arg("context") = py::none(), 146*436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 147*436c6c9cSStella Laurenzo c.def("__getitem__", 148*436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 149*436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 150*436c6c9cSStella Laurenzo throw py::index_error("ArrayAttribute index out of range"); 151*436c6c9cSStella Laurenzo return PyAttribute(arr.getContext(), 152*436c6c9cSStella Laurenzo mlirArrayAttrGetElement(arr, i)); 153*436c6c9cSStella Laurenzo }) 154*436c6c9cSStella Laurenzo .def("__len__", 155*436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 156*436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 157*436c6c9cSStella Laurenzo }) 158*436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 159*436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 160*436c6c9cSStella Laurenzo }); 161*436c6c9cSStella Laurenzo } 162*436c6c9cSStella Laurenzo }; 163*436c6c9cSStella Laurenzo 164*436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 165*436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 166*436c6c9cSStella Laurenzo public: 167*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 168*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 169*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 170*436c6c9cSStella Laurenzo 171*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 172*436c6c9cSStella Laurenzo c.def_static( 173*436c6c9cSStella Laurenzo "get", 174*436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 175*436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 176*436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 177*436c6c9cSStella Laurenzo // in C API. 178*436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 179*436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 180*436c6c9cSStella Laurenzo Twine("invalid '") + 181*436c6c9cSStella Laurenzo py::repr(py::cast(type)).cast<std::string>() + 182*436c6c9cSStella Laurenzo "' and expected floating point type."); 183*436c6c9cSStella Laurenzo } 184*436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 185*436c6c9cSStella Laurenzo }, 186*436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 187*436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 188*436c6c9cSStella Laurenzo c.def_static( 189*436c6c9cSStella Laurenzo "get_f32", 190*436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 191*436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 192*436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 193*436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 194*436c6c9cSStella Laurenzo }, 195*436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 196*436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 197*436c6c9cSStella Laurenzo c.def_static( 198*436c6c9cSStella Laurenzo "get_f64", 199*436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 200*436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 201*436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 202*436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 203*436c6c9cSStella Laurenzo }, 204*436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 205*436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 206*436c6c9cSStella Laurenzo c.def_property_readonly( 207*436c6c9cSStella Laurenzo "value", 208*436c6c9cSStella Laurenzo [](PyFloatAttribute &self) { 209*436c6c9cSStella Laurenzo return mlirFloatAttrGetValueDouble(self); 210*436c6c9cSStella Laurenzo }, 211*436c6c9cSStella Laurenzo "Returns the value of the float point attribute"); 212*436c6c9cSStella Laurenzo } 213*436c6c9cSStella Laurenzo }; 214*436c6c9cSStella Laurenzo 215*436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 216*436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 217*436c6c9cSStella Laurenzo public: 218*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 219*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 220*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 221*436c6c9cSStella Laurenzo 222*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 223*436c6c9cSStella Laurenzo c.def_static( 224*436c6c9cSStella Laurenzo "get", 225*436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 226*436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 227*436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 228*436c6c9cSStella Laurenzo }, 229*436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), 230*436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 231*436c6c9cSStella Laurenzo c.def_property_readonly( 232*436c6c9cSStella Laurenzo "value", 233*436c6c9cSStella Laurenzo [](PyIntegerAttribute &self) { 234*436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 235*436c6c9cSStella Laurenzo }, 236*436c6c9cSStella Laurenzo "Returns the value of the integer attribute"); 237*436c6c9cSStella Laurenzo } 238*436c6c9cSStella Laurenzo }; 239*436c6c9cSStella Laurenzo 240*436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 241*436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 242*436c6c9cSStella Laurenzo public: 243*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 244*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 245*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 246*436c6c9cSStella Laurenzo 247*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 248*436c6c9cSStella Laurenzo c.def_static( 249*436c6c9cSStella Laurenzo "get", 250*436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 251*436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 252*436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 253*436c6c9cSStella Laurenzo }, 254*436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 255*436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 256*436c6c9cSStella Laurenzo c.def_property_readonly( 257*436c6c9cSStella Laurenzo "value", 258*436c6c9cSStella Laurenzo [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 259*436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 260*436c6c9cSStella Laurenzo } 261*436c6c9cSStella Laurenzo }; 262*436c6c9cSStella Laurenzo 263*436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 264*436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 265*436c6c9cSStella Laurenzo public: 266*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 267*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 268*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 269*436c6c9cSStella Laurenzo 270*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 271*436c6c9cSStella Laurenzo c.def_static( 272*436c6c9cSStella Laurenzo "get", 273*436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 274*436c6c9cSStella Laurenzo MlirAttribute attr = 275*436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 276*436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 277*436c6c9cSStella Laurenzo }, 278*436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 279*436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 280*436c6c9cSStella Laurenzo c.def_property_readonly( 281*436c6c9cSStella Laurenzo "value", 282*436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 283*436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 284*436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 285*436c6c9cSStella Laurenzo }, 286*436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 287*436c6c9cSStella Laurenzo } 288*436c6c9cSStella Laurenzo }; 289*436c6c9cSStella Laurenzo 290*436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 291*436c6c9cSStella Laurenzo public: 292*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 293*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 294*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 295*436c6c9cSStella Laurenzo 296*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 297*436c6c9cSStella Laurenzo c.def_static( 298*436c6c9cSStella Laurenzo "get", 299*436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 300*436c6c9cSStella Laurenzo MlirAttribute attr = 301*436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 302*436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 303*436c6c9cSStella Laurenzo }, 304*436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 305*436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 306*436c6c9cSStella Laurenzo c.def_static( 307*436c6c9cSStella Laurenzo "get_typed", 308*436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 309*436c6c9cSStella Laurenzo MlirAttribute attr = 310*436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 311*436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 312*436c6c9cSStella Laurenzo }, 313*436c6c9cSStella Laurenzo 314*436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 315*436c6c9cSStella Laurenzo c.def_property_readonly( 316*436c6c9cSStella Laurenzo "value", 317*436c6c9cSStella Laurenzo [](PyStringAttribute &self) { 318*436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirStringAttrGetValue(self); 319*436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 320*436c6c9cSStella Laurenzo }, 321*436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 322*436c6c9cSStella Laurenzo } 323*436c6c9cSStella Laurenzo }; 324*436c6c9cSStella Laurenzo 325*436c6c9cSStella Laurenzo // TODO: Support construction of bool elements. 326*436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 327*436c6c9cSStella Laurenzo class PyDenseElementsAttribute 328*436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 329*436c6c9cSStella Laurenzo public: 330*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 331*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 332*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 333*436c6c9cSStella Laurenzo 334*436c6c9cSStella Laurenzo static PyDenseElementsAttribute 335*436c6c9cSStella Laurenzo getFromBuffer(py::buffer array, bool signless, 336*436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 337*436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 338*436c6c9cSStella Laurenzo int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 339*436c6c9cSStella Laurenzo Py_buffer *view = new Py_buffer(); 340*436c6c9cSStella Laurenzo if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 341*436c6c9cSStella Laurenzo delete view; 342*436c6c9cSStella Laurenzo throw py::error_already_set(); 343*436c6c9cSStella Laurenzo } 344*436c6c9cSStella Laurenzo py::buffer_info arrayInfo(view); 345*436c6c9cSStella Laurenzo 346*436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 347*436c6c9cSStella Laurenzo // Switch on the types that can be bulk loaded between the Python and 348*436c6c9cSStella Laurenzo // MLIR-C APIs. 349*436c6c9cSStella Laurenzo // See: https://docs.python.org/3/library/struct.html#format-characters 350*436c6c9cSStella Laurenzo if (arrayInfo.format == "f") { 351*436c6c9cSStella Laurenzo // f32 352*436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 353*436c6c9cSStella Laurenzo return PyDenseElementsAttribute( 354*436c6c9cSStella Laurenzo contextWrapper->getRef(), 355*436c6c9cSStella Laurenzo bulkLoad(context, mlirDenseElementsAttrFloatGet, 356*436c6c9cSStella Laurenzo mlirF32TypeGet(context), arrayInfo)); 357*436c6c9cSStella Laurenzo } else if (arrayInfo.format == "d") { 358*436c6c9cSStella Laurenzo // f64 359*436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 360*436c6c9cSStella Laurenzo return PyDenseElementsAttribute( 361*436c6c9cSStella Laurenzo contextWrapper->getRef(), 362*436c6c9cSStella Laurenzo bulkLoad(context, mlirDenseElementsAttrDoubleGet, 363*436c6c9cSStella Laurenzo mlirF64TypeGet(context), arrayInfo)); 364*436c6c9cSStella Laurenzo } else if (isSignedIntegerFormat(arrayInfo.format)) { 365*436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 366*436c6c9cSStella Laurenzo // i32 367*436c6c9cSStella Laurenzo MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) 368*436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 32); 369*436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), 370*436c6c9cSStella Laurenzo bulkLoad(context, 371*436c6c9cSStella Laurenzo mlirDenseElementsAttrInt32Get, 372*436c6c9cSStella Laurenzo elementType, arrayInfo)); 373*436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 374*436c6c9cSStella Laurenzo // i64 375*436c6c9cSStella Laurenzo MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) 376*436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 64); 377*436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), 378*436c6c9cSStella Laurenzo bulkLoad(context, 379*436c6c9cSStella Laurenzo mlirDenseElementsAttrInt64Get, 380*436c6c9cSStella Laurenzo elementType, arrayInfo)); 381*436c6c9cSStella Laurenzo } 382*436c6c9cSStella Laurenzo } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 383*436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 384*436c6c9cSStella Laurenzo // unsigned i32 385*436c6c9cSStella Laurenzo MlirType elementType = signless 386*436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 32) 387*436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 32); 388*436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), 389*436c6c9cSStella Laurenzo bulkLoad(context, 390*436c6c9cSStella Laurenzo mlirDenseElementsAttrUInt32Get, 391*436c6c9cSStella Laurenzo elementType, arrayInfo)); 392*436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 393*436c6c9cSStella Laurenzo // unsigned i64 394*436c6c9cSStella Laurenzo MlirType elementType = signless 395*436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 64) 396*436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 64); 397*436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), 398*436c6c9cSStella Laurenzo bulkLoad(context, 399*436c6c9cSStella Laurenzo mlirDenseElementsAttrUInt64Get, 400*436c6c9cSStella Laurenzo elementType, arrayInfo)); 401*436c6c9cSStella Laurenzo } 402*436c6c9cSStella Laurenzo } 403*436c6c9cSStella Laurenzo 404*436c6c9cSStella Laurenzo // TODO: Fall back to string-based get. 405*436c6c9cSStella Laurenzo std::string message = "unimplemented array format conversion from format: "; 406*436c6c9cSStella Laurenzo message.append(arrayInfo.format); 407*436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 408*436c6c9cSStella Laurenzo } 409*436c6c9cSStella Laurenzo 410*436c6c9cSStella Laurenzo static PyDenseElementsAttribute getSplat(PyType shapedType, 411*436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 412*436c6c9cSStella Laurenzo auto contextWrapper = 413*436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 414*436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 415*436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 416*436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 417*436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 418*436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 419*436c6c9cSStella Laurenzo } 420*436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 421*436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 422*436c6c9cSStella Laurenzo std::string message = 423*436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 424*436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 425*436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 426*436c6c9cSStella Laurenzo } 427*436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 428*436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 429*436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 430*436c6c9cSStella Laurenzo std::string message = 431*436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 432*436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 433*436c6c9cSStella Laurenzo message.append(", element="); 434*436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 435*436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 436*436c6c9cSStella Laurenzo } 437*436c6c9cSStella Laurenzo 438*436c6c9cSStella Laurenzo MlirAttribute elements = 439*436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 440*436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 441*436c6c9cSStella Laurenzo } 442*436c6c9cSStella Laurenzo 443*436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 444*436c6c9cSStella Laurenzo 445*436c6c9cSStella Laurenzo py::buffer_info accessBuffer() { 446*436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 447*436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 448*436c6c9cSStella Laurenzo 449*436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 450*436c6c9cSStella Laurenzo // f32 451*436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); 452*436c6c9cSStella Laurenzo } else if (mlirTypeIsAF64(elementType)) { 453*436c6c9cSStella Laurenzo // f64 454*436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); 455*436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 456*436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 457*436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 458*436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 459*436c6c9cSStella Laurenzo // i32 460*436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); 461*436c6c9cSStella Laurenzo } else if (mlirIntegerTypeIsUnsigned(elementType)) { 462*436c6c9cSStella Laurenzo // unsigned i32 463*436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); 464*436c6c9cSStella Laurenzo } 465*436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 466*436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 467*436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 468*436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 469*436c6c9cSStella Laurenzo // i64 470*436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value); 471*436c6c9cSStella Laurenzo } else if (mlirIntegerTypeIsUnsigned(elementType)) { 472*436c6c9cSStella Laurenzo // unsigned i64 473*436c6c9cSStella Laurenzo return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); 474*436c6c9cSStella Laurenzo } 475*436c6c9cSStella Laurenzo } 476*436c6c9cSStella Laurenzo 477*436c6c9cSStella Laurenzo std::string message = "unimplemented array format."; 478*436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 479*436c6c9cSStella Laurenzo } 480*436c6c9cSStella Laurenzo 481*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 482*436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 483*436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 484*436c6c9cSStella Laurenzo py::arg("array"), py::arg("signless") = true, 485*436c6c9cSStella Laurenzo py::arg("context") = py::none(), 486*436c6c9cSStella Laurenzo "Gets from a buffer or ndarray") 487*436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 488*436c6c9cSStella Laurenzo py::arg("shaped_type"), py::arg("element_attr"), 489*436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 490*436c6c9cSStella Laurenzo .def_property_readonly("is_splat", 491*436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 492*436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 493*436c6c9cSStella Laurenzo }) 494*436c6c9cSStella Laurenzo .def_buffer(&PyDenseElementsAttribute::accessBuffer); 495*436c6c9cSStella Laurenzo } 496*436c6c9cSStella Laurenzo 497*436c6c9cSStella Laurenzo private: 498*436c6c9cSStella Laurenzo template <typename ElementTy> 499*436c6c9cSStella Laurenzo static MlirAttribute 500*436c6c9cSStella Laurenzo bulkLoad(MlirContext context, 501*436c6c9cSStella Laurenzo MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), 502*436c6c9cSStella Laurenzo MlirType mlirElementType, py::buffer_info &arrayInfo) { 503*436c6c9cSStella Laurenzo SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(), 504*436c6c9cSStella Laurenzo arrayInfo.shape.begin() + arrayInfo.ndim); 505*436c6c9cSStella Laurenzo auto shapedType = 506*436c6c9cSStella Laurenzo mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType); 507*436c6c9cSStella Laurenzo intptr_t numElements = arrayInfo.size; 508*436c6c9cSStella Laurenzo const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr); 509*436c6c9cSStella Laurenzo return ctor(shapedType, numElements, contents); 510*436c6c9cSStella Laurenzo } 511*436c6c9cSStella Laurenzo 512*436c6c9cSStella Laurenzo static bool isUnsignedIntegerFormat(const std::string &format) { 513*436c6c9cSStella Laurenzo if (format.empty()) 514*436c6c9cSStella Laurenzo return false; 515*436c6c9cSStella Laurenzo char code = format[0]; 516*436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 517*436c6c9cSStella Laurenzo code == 'Q'; 518*436c6c9cSStella Laurenzo } 519*436c6c9cSStella Laurenzo 520*436c6c9cSStella Laurenzo static bool isSignedIntegerFormat(const std::string &format) { 521*436c6c9cSStella Laurenzo if (format.empty()) 522*436c6c9cSStella Laurenzo return false; 523*436c6c9cSStella Laurenzo char code = format[0]; 524*436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 525*436c6c9cSStella Laurenzo code == 'q'; 526*436c6c9cSStella Laurenzo } 527*436c6c9cSStella Laurenzo 528*436c6c9cSStella Laurenzo template <typename Type> 529*436c6c9cSStella Laurenzo py::buffer_info bufferInfo(MlirType shapedType, 530*436c6c9cSStella Laurenzo Type (*value)(MlirAttribute, intptr_t)) { 531*436c6c9cSStella Laurenzo intptr_t rank = mlirShapedTypeGetRank(shapedType); 532*436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 533*436c6c9cSStella Laurenzo // Buffer is configured for read-only access below. 534*436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 535*436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 536*436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 537*436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 538*436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 539*436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 540*436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 541*436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 542*436c6c9cSStella Laurenzo intptr_t strideFactor = 1; 543*436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 544*436c6c9cSStella Laurenzo strideFactor = 1; 545*436c6c9cSStella Laurenzo for (intptr_t j = i; j < rank; ++j) { 546*436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 547*436c6c9cSStella Laurenzo } 548*436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 549*436c6c9cSStella Laurenzo } 550*436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 551*436c6c9cSStella Laurenzo return py::buffer_info(data, sizeof(Type), 552*436c6c9cSStella Laurenzo py::format_descriptor<Type>::format(), rank, shape, 553*436c6c9cSStella Laurenzo strides, /*readonly=*/true); 554*436c6c9cSStella Laurenzo } 555*436c6c9cSStella Laurenzo }; // namespace 556*436c6c9cSStella Laurenzo 557*436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer 558*436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access. 559*436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 560*436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 561*436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 562*436c6c9cSStella Laurenzo public: 563*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 564*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 565*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 566*436c6c9cSStella Laurenzo 567*436c6c9cSStella Laurenzo /// Returns the element at the given linear position. Asserts if the index is 568*436c6c9cSStella Laurenzo /// out of range. 569*436c6c9cSStella Laurenzo py::int_ dunderGetItem(intptr_t pos) { 570*436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 571*436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 572*436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 573*436c6c9cSStella Laurenzo } 574*436c6c9cSStella Laurenzo 575*436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 576*436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 577*436c6c9cSStella Laurenzo assert(mlirTypeIsAInteger(type) && 578*436c6c9cSStella Laurenzo "expected integer element type in dense int elements attribute"); 579*436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 580*436c6c9cSStella Laurenzo // elemental type of the attribute. py::int_ is implicitly constructible 581*436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 582*436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 583*436c6c9cSStella Laurenzo // querying them on each element access. 584*436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 585*436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 586*436c6c9cSStella Laurenzo if (isUnsigned) { 587*436c6c9cSStella Laurenzo if (width == 1) { 588*436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 589*436c6c9cSStella Laurenzo } 590*436c6c9cSStella Laurenzo if (width == 32) { 591*436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt32Value(*this, pos); 592*436c6c9cSStella Laurenzo } 593*436c6c9cSStella Laurenzo if (width == 64) { 594*436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt64Value(*this, pos); 595*436c6c9cSStella Laurenzo } 596*436c6c9cSStella Laurenzo } else { 597*436c6c9cSStella Laurenzo if (width == 1) { 598*436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 599*436c6c9cSStella Laurenzo } 600*436c6c9cSStella Laurenzo if (width == 32) { 601*436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt32Value(*this, pos); 602*436c6c9cSStella Laurenzo } 603*436c6c9cSStella Laurenzo if (width == 64) { 604*436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt64Value(*this, pos); 605*436c6c9cSStella Laurenzo } 606*436c6c9cSStella Laurenzo } 607*436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 608*436c6c9cSStella Laurenzo } 609*436c6c9cSStella Laurenzo 610*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 611*436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 612*436c6c9cSStella Laurenzo } 613*436c6c9cSStella Laurenzo }; 614*436c6c9cSStella Laurenzo 615*436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 616*436c6c9cSStella Laurenzo public: 617*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 618*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 619*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 620*436c6c9cSStella Laurenzo 621*436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 622*436c6c9cSStella Laurenzo 623*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 624*436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 625*436c6c9cSStella Laurenzo c.def_static( 626*436c6c9cSStella Laurenzo "get", 627*436c6c9cSStella Laurenzo [](py::dict attributes, DefaultingPyMlirContext context) { 628*436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 629*436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 630*436c6c9cSStella Laurenzo for (auto &it : attributes) { 631*436c6c9cSStella Laurenzo auto &mlir_attr = it.second.cast<PyAttribute &>(); 632*436c6c9cSStella Laurenzo auto name = it.first.cast<std::string>(); 633*436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 634*436c6c9cSStella Laurenzo mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), 635*436c6c9cSStella Laurenzo toMlirStringRef(name)), 636*436c6c9cSStella Laurenzo mlir_attr)); 637*436c6c9cSStella Laurenzo } 638*436c6c9cSStella Laurenzo MlirAttribute attr = 639*436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 640*436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 641*436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 642*436c6c9cSStella Laurenzo }, 643*436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 644*436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 645*436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 646*436c6c9cSStella Laurenzo MlirAttribute attr = 647*436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 648*436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 649*436c6c9cSStella Laurenzo throw SetPyError(PyExc_KeyError, 650*436c6c9cSStella Laurenzo "attempt to access a non-existent attribute"); 651*436c6c9cSStella Laurenzo } 652*436c6c9cSStella Laurenzo return PyAttribute(self.getContext(), attr); 653*436c6c9cSStella Laurenzo }); 654*436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 655*436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 656*436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 657*436c6c9cSStella Laurenzo "attempt to access out of bounds attribute"); 658*436c6c9cSStella Laurenzo } 659*436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 660*436c6c9cSStella Laurenzo return PyNamedAttribute( 661*436c6c9cSStella Laurenzo namedAttr.attribute, 662*436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 663*436c6c9cSStella Laurenzo }); 664*436c6c9cSStella Laurenzo } 665*436c6c9cSStella Laurenzo }; 666*436c6c9cSStella Laurenzo 667*436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 668*436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 669*436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 670*436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 671*436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 672*436c6c9cSStella Laurenzo public: 673*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 674*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 675*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 676*436c6c9cSStella Laurenzo 677*436c6c9cSStella Laurenzo py::float_ dunderGetItem(intptr_t pos) { 678*436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 679*436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 680*436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 681*436c6c9cSStella Laurenzo } 682*436c6c9cSStella Laurenzo 683*436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 684*436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 685*436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 686*436c6c9cSStella Laurenzo // elemental type of the attribute. py::float_ is implicitly constructible 687*436c6c9cSStella Laurenzo // from float and double. 688*436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 689*436c6c9cSStella Laurenzo // querying them on each element access. 690*436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 691*436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetFloatValue(*this, pos); 692*436c6c9cSStella Laurenzo } 693*436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 694*436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetDoubleValue(*this, pos); 695*436c6c9cSStella Laurenzo } 696*436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 697*436c6c9cSStella Laurenzo } 698*436c6c9cSStella Laurenzo 699*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 700*436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 701*436c6c9cSStella Laurenzo } 702*436c6c9cSStella Laurenzo }; 703*436c6c9cSStella Laurenzo 704*436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 705*436c6c9cSStella Laurenzo public: 706*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 707*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 708*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 709*436c6c9cSStella Laurenzo 710*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 711*436c6c9cSStella Laurenzo c.def_static( 712*436c6c9cSStella Laurenzo "get", 713*436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 714*436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 715*436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 716*436c6c9cSStella Laurenzo }, 717*436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 718*436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 719*436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyTypeAttribute &self) { 720*436c6c9cSStella Laurenzo return PyType(self.getContext()->getRef(), 721*436c6c9cSStella Laurenzo mlirTypeAttrGetValue(self.get())); 722*436c6c9cSStella Laurenzo }); 723*436c6c9cSStella Laurenzo } 724*436c6c9cSStella Laurenzo }; 725*436c6c9cSStella Laurenzo 726*436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 727*436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 728*436c6c9cSStella Laurenzo public: 729*436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 730*436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 731*436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 732*436c6c9cSStella Laurenzo 733*436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 734*436c6c9cSStella Laurenzo c.def_static( 735*436c6c9cSStella Laurenzo "get", 736*436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 737*436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 738*436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 739*436c6c9cSStella Laurenzo }, 740*436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Create a Unit attribute."); 741*436c6c9cSStella Laurenzo } 742*436c6c9cSStella Laurenzo }; 743*436c6c9cSStella Laurenzo 744*436c6c9cSStella Laurenzo } // namespace 745*436c6c9cSStella Laurenzo 746*436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) { 747*436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 748*436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 749*436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 750*436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 751*436c6c9cSStella Laurenzo PyDenseElementsAttribute::bind(m); 752*436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 753*436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 754*436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 755*436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 756*436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 757*436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 758*436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 759*436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 760*436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 761*436c6c9cSStella Laurenzo } 762