1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9436c6c9cSStella Laurenzo #include "IRModule.h" 10436c6c9cSStella Laurenzo 11436c6c9cSStella Laurenzo #include "PybindUtils.h" 12436c6c9cSStella Laurenzo 13436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 14436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 15436c6c9cSStella Laurenzo 16436c6c9cSStella Laurenzo namespace py = pybind11; 17436c6c9cSStella Laurenzo using namespace mlir; 18436c6c9cSStella Laurenzo using namespace mlir::python; 19436c6c9cSStella Laurenzo 20*5d6d30edSStella Laurenzo using llvm::None; 21*5d6d30edSStella Laurenzo using llvm::Optional; 22436c6c9cSStella Laurenzo using llvm::SmallVector; 23436c6c9cSStella Laurenzo using llvm::Twine; 24436c6c9cSStella Laurenzo 25*5d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 26*5d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 27*5d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 28*5d6d30edSStella Laurenzo 29*5d6d30edSStella Laurenzo static const char kDenseElementsAttrGetDocstring[] = 30*5d6d30edSStella Laurenzo R"(Gets a DenseElementsAttr from a Python buffer or array. 31*5d6d30edSStella Laurenzo 32*5d6d30edSStella Laurenzo When `type` is not provided, then some limited type inferencing is done based 33*5d6d30edSStella Laurenzo on the buffer format. Support presently exists for 8/16/32/64 signed and 34*5d6d30edSStella Laurenzo unsigned integers and float16/float32/float64. DenseElementsAttrs of these 35*5d6d30edSStella Laurenzo types can also be converted back to a corresponding buffer. 36*5d6d30edSStella Laurenzo 37*5d6d30edSStella Laurenzo For conversions outside of these types, a `type=` must be explicitly provided 38*5d6d30edSStella Laurenzo and the buffer contents must be bit-castable to the MLIR internal 39*5d6d30edSStella Laurenzo representation: 40*5d6d30edSStella Laurenzo 41*5d6d30edSStella Laurenzo * Integer types (except for i1): the buffer must be byte aligned to the 42*5d6d30edSStella Laurenzo next byte boundary. 43*5d6d30edSStella Laurenzo * Floating point types: Must be bit-castable to the given floating point 44*5d6d30edSStella Laurenzo size. 45*5d6d30edSStella Laurenzo * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 46*5d6d30edSStella Laurenzo row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 47*5d6d30edSStella Laurenzo this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 48*5d6d30edSStella Laurenzo 49*5d6d30edSStella Laurenzo If a single element buffer is passed (or for i1, a single byte with value 0 50*5d6d30edSStella Laurenzo or 255), then a splat will be created. 51*5d6d30edSStella Laurenzo 52*5d6d30edSStella Laurenzo Args: 53*5d6d30edSStella Laurenzo array: The array or buffer to convert. 54*5d6d30edSStella Laurenzo signless: If inferring an appropriate MLIR type, use signless types for 55*5d6d30edSStella Laurenzo integers (defaults True). 56*5d6d30edSStella Laurenzo type: Skips inference of the MLIR element type and uses this instead. The 57*5d6d30edSStella Laurenzo storage size must be consistent with the actual contents of the buffer. 58*5d6d30edSStella Laurenzo shape: Overrides the shape of the buffer when constructing the MLIR 59*5d6d30edSStella Laurenzo shaped type. This is needed when the physical and logical shape differ (as 60*5d6d30edSStella Laurenzo for i1). 61*5d6d30edSStella Laurenzo context: Explicit context, if not from context manager. 62*5d6d30edSStella Laurenzo 63*5d6d30edSStella Laurenzo Returns: 64*5d6d30edSStella Laurenzo DenseElementsAttr on success. 65*5d6d30edSStella Laurenzo 66*5d6d30edSStella Laurenzo Raises: 67*5d6d30edSStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 68*5d6d30edSStella Laurenzo type or if the buffer does not meet expectations. 69*5d6d30edSStella Laurenzo )"; 70*5d6d30edSStella Laurenzo 71436c6c9cSStella Laurenzo namespace { 72436c6c9cSStella Laurenzo 73436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 74436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 75436c6c9cSStella Laurenzo } 76436c6c9cSStella Laurenzo 77436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 78436c6c9cSStella Laurenzo public: 79436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 80436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 81436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 82436c6c9cSStella Laurenzo 83436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 84436c6c9cSStella Laurenzo c.def_static( 85436c6c9cSStella Laurenzo "get", 86436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 87436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 88436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 89436c6c9cSStella Laurenzo }, 90436c6c9cSStella Laurenzo py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 91436c6c9cSStella Laurenzo } 92436c6c9cSStella Laurenzo }; 93436c6c9cSStella Laurenzo 94ed9e52f3SAlex Zinenko template <typename T> 95ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) { 96ed9e52f3SAlex Zinenko try { 97ed9e52f3SAlex Zinenko return object.cast<T>(); 98ed9e52f3SAlex Zinenko } catch (py::cast_error &err) { 99ed9e52f3SAlex Zinenko std::string msg = 100ed9e52f3SAlex Zinenko std::string( 101ed9e52f3SAlex Zinenko "Invalid attribute when attempting to create an ArrayAttribute (") + 102ed9e52f3SAlex Zinenko err.what() + ")"; 103ed9e52f3SAlex Zinenko throw py::cast_error(msg); 104ed9e52f3SAlex Zinenko } catch (py::reference_cast_error &err) { 105ed9e52f3SAlex Zinenko std::string msg = std::string("Invalid attribute (None?) when attempting " 106ed9e52f3SAlex Zinenko "to create an ArrayAttribute (") + 107ed9e52f3SAlex Zinenko err.what() + ")"; 108ed9e52f3SAlex Zinenko throw py::cast_error(msg); 109ed9e52f3SAlex Zinenko } 110ed9e52f3SAlex Zinenko } 111ed9e52f3SAlex Zinenko 112436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 113436c6c9cSStella Laurenzo public: 114436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 115436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 116436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 117436c6c9cSStella Laurenzo 118436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 119436c6c9cSStella Laurenzo public: 120436c6c9cSStella Laurenzo PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} 121436c6c9cSStella Laurenzo 122436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 123436c6c9cSStella Laurenzo 124436c6c9cSStella Laurenzo PyAttribute dunderNext() { 125436c6c9cSStella Laurenzo if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 126436c6c9cSStella Laurenzo throw py::stop_iteration(); 127436c6c9cSStella Laurenzo } 128436c6c9cSStella Laurenzo return PyAttribute(attr.getContext(), 129436c6c9cSStella Laurenzo mlirArrayAttrGetElement(attr.get(), nextIndex++)); 130436c6c9cSStella Laurenzo } 131436c6c9cSStella Laurenzo 132436c6c9cSStella Laurenzo static void bind(py::module &m) { 133f05ff4f7SStella Laurenzo py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 134f05ff4f7SStella Laurenzo py::module_local()) 135436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 136436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 137436c6c9cSStella Laurenzo } 138436c6c9cSStella Laurenzo 139436c6c9cSStella Laurenzo private: 140436c6c9cSStella Laurenzo PyAttribute attr; 141436c6c9cSStella Laurenzo int nextIndex = 0; 142436c6c9cSStella Laurenzo }; 143436c6c9cSStella Laurenzo 144ed9e52f3SAlex Zinenko PyAttribute getItem(intptr_t i) { 145ed9e52f3SAlex Zinenko return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); 146ed9e52f3SAlex Zinenko } 147ed9e52f3SAlex Zinenko 148436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 149436c6c9cSStella Laurenzo c.def_static( 150436c6c9cSStella Laurenzo "get", 151436c6c9cSStella Laurenzo [](py::list attributes, DefaultingPyMlirContext context) { 152436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 153436c6c9cSStella Laurenzo mlirAttributes.reserve(py::len(attributes)); 154436c6c9cSStella Laurenzo for (auto attribute : attributes) { 155ed9e52f3SAlex Zinenko mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 156436c6c9cSStella Laurenzo } 157436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 158436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 159436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 160436c6c9cSStella Laurenzo }, 161436c6c9cSStella Laurenzo py::arg("attributes"), py::arg("context") = py::none(), 162436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 163436c6c9cSStella Laurenzo c.def("__getitem__", 164436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 165436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 166436c6c9cSStella Laurenzo throw py::index_error("ArrayAttribute index out of range"); 167ed9e52f3SAlex Zinenko return arr.getItem(i); 168436c6c9cSStella Laurenzo }) 169436c6c9cSStella Laurenzo .def("__len__", 170436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 171436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 172436c6c9cSStella Laurenzo }) 173436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 174436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 175436c6c9cSStella Laurenzo }); 176ed9e52f3SAlex Zinenko c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 177ed9e52f3SAlex Zinenko std::vector<MlirAttribute> attributes; 178ed9e52f3SAlex Zinenko intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 179ed9e52f3SAlex Zinenko attributes.reserve(numOldElements + py::len(extras)); 180ed9e52f3SAlex Zinenko for (intptr_t i = 0; i < numOldElements; ++i) 181ed9e52f3SAlex Zinenko attributes.push_back(arr.getItem(i)); 182ed9e52f3SAlex Zinenko for (py::handle attr : extras) 183ed9e52f3SAlex Zinenko attributes.push_back(pyTryCast<PyAttribute>(attr)); 184ed9e52f3SAlex Zinenko MlirAttribute arrayAttr = mlirArrayAttrGet( 185ed9e52f3SAlex Zinenko arr.getContext()->get(), attributes.size(), attributes.data()); 186ed9e52f3SAlex Zinenko return PyArrayAttribute(arr.getContext(), arrayAttr); 187ed9e52f3SAlex Zinenko }); 188436c6c9cSStella Laurenzo } 189436c6c9cSStella Laurenzo }; 190436c6c9cSStella Laurenzo 191436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 192436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 193436c6c9cSStella Laurenzo public: 194436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 195436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 196436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 197436c6c9cSStella Laurenzo 198436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 199436c6c9cSStella Laurenzo c.def_static( 200436c6c9cSStella Laurenzo "get", 201436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 202436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 203436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 204436c6c9cSStella Laurenzo // in C API. 205436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 206436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 207436c6c9cSStella Laurenzo Twine("invalid '") + 208436c6c9cSStella Laurenzo py::repr(py::cast(type)).cast<std::string>() + 209436c6c9cSStella Laurenzo "' and expected floating point type."); 210436c6c9cSStella Laurenzo } 211436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 212436c6c9cSStella Laurenzo }, 213436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 214436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 215436c6c9cSStella Laurenzo c.def_static( 216436c6c9cSStella Laurenzo "get_f32", 217436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 218436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 219436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 220436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 221436c6c9cSStella Laurenzo }, 222436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 223436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 224436c6c9cSStella Laurenzo c.def_static( 225436c6c9cSStella Laurenzo "get_f64", 226436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 227436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 228436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 229436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 230436c6c9cSStella Laurenzo }, 231436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 232436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 233436c6c9cSStella Laurenzo c.def_property_readonly( 234436c6c9cSStella Laurenzo "value", 235436c6c9cSStella Laurenzo [](PyFloatAttribute &self) { 236436c6c9cSStella Laurenzo return mlirFloatAttrGetValueDouble(self); 237436c6c9cSStella Laurenzo }, 238436c6c9cSStella Laurenzo "Returns the value of the float point attribute"); 239436c6c9cSStella Laurenzo } 240436c6c9cSStella Laurenzo }; 241436c6c9cSStella Laurenzo 242436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 243436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 244436c6c9cSStella Laurenzo public: 245436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 246436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 247436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 248436c6c9cSStella Laurenzo 249436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 250436c6c9cSStella Laurenzo c.def_static( 251436c6c9cSStella Laurenzo "get", 252436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 253436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 254436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 255436c6c9cSStella Laurenzo }, 256436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), 257436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 258436c6c9cSStella Laurenzo c.def_property_readonly( 259436c6c9cSStella Laurenzo "value", 260436c6c9cSStella Laurenzo [](PyIntegerAttribute &self) { 261436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 262436c6c9cSStella Laurenzo }, 263436c6c9cSStella Laurenzo "Returns the value of the integer attribute"); 264436c6c9cSStella Laurenzo } 265436c6c9cSStella Laurenzo }; 266436c6c9cSStella Laurenzo 267436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 268436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 269436c6c9cSStella Laurenzo public: 270436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 271436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 272436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 273436c6c9cSStella Laurenzo 274436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 275436c6c9cSStella Laurenzo c.def_static( 276436c6c9cSStella Laurenzo "get", 277436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 278436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 279436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 280436c6c9cSStella Laurenzo }, 281436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 282436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 283436c6c9cSStella Laurenzo c.def_property_readonly( 284436c6c9cSStella Laurenzo "value", 285436c6c9cSStella Laurenzo [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 286436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 287436c6c9cSStella Laurenzo } 288436c6c9cSStella Laurenzo }; 289436c6c9cSStella Laurenzo 290436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 291436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 292436c6c9cSStella Laurenzo public: 293436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 294436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 295436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 296436c6c9cSStella Laurenzo 297436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 298436c6c9cSStella Laurenzo c.def_static( 299436c6c9cSStella Laurenzo "get", 300436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 301436c6c9cSStella Laurenzo MlirAttribute attr = 302436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 303436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 304436c6c9cSStella Laurenzo }, 305436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 306436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 307436c6c9cSStella Laurenzo c.def_property_readonly( 308436c6c9cSStella Laurenzo "value", 309436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 310436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 311436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 312436c6c9cSStella Laurenzo }, 313436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 314436c6c9cSStella Laurenzo } 315436c6c9cSStella Laurenzo }; 316436c6c9cSStella Laurenzo 317436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 318436c6c9cSStella Laurenzo public: 319436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 320436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 321436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 322436c6c9cSStella Laurenzo 323436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 324436c6c9cSStella Laurenzo c.def_static( 325436c6c9cSStella Laurenzo "get", 326436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 327436c6c9cSStella Laurenzo MlirAttribute attr = 328436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 329436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 330436c6c9cSStella Laurenzo }, 331436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 332436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 333436c6c9cSStella Laurenzo c.def_static( 334436c6c9cSStella Laurenzo "get_typed", 335436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 336436c6c9cSStella Laurenzo MlirAttribute attr = 337436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 338436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 339436c6c9cSStella Laurenzo }, 340436c6c9cSStella Laurenzo 341436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 342436c6c9cSStella Laurenzo c.def_property_readonly( 343436c6c9cSStella Laurenzo "value", 344436c6c9cSStella Laurenzo [](PyStringAttribute &self) { 345436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirStringAttrGetValue(self); 346436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 347436c6c9cSStella Laurenzo }, 348436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 349436c6c9cSStella Laurenzo } 350436c6c9cSStella Laurenzo }; 351436c6c9cSStella Laurenzo 352436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 353436c6c9cSStella Laurenzo class PyDenseElementsAttribute 354436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 355436c6c9cSStella Laurenzo public: 356436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 357436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 358436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 359436c6c9cSStella Laurenzo 360436c6c9cSStella Laurenzo static PyDenseElementsAttribute 361*5d6d30edSStella Laurenzo getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType, 362*5d6d30edSStella Laurenzo Optional<std::vector<int64_t>> explicitShape, 363436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 364436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 365436c6c9cSStella Laurenzo int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 366436c6c9cSStella Laurenzo Py_buffer *view = new Py_buffer(); 367436c6c9cSStella Laurenzo if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 368436c6c9cSStella Laurenzo delete view; 369436c6c9cSStella Laurenzo throw py::error_already_set(); 370436c6c9cSStella Laurenzo } 371436c6c9cSStella Laurenzo py::buffer_info arrayInfo(view); 372*5d6d30edSStella Laurenzo SmallVector<int64_t> shape; 373*5d6d30edSStella Laurenzo if (explicitShape) { 374*5d6d30edSStella Laurenzo shape.append(explicitShape->begin(), explicitShape->end()); 375*5d6d30edSStella Laurenzo } else { 376*5d6d30edSStella Laurenzo shape.append(arrayInfo.shape.begin(), 377*5d6d30edSStella Laurenzo arrayInfo.shape.begin() + arrayInfo.ndim); 378*5d6d30edSStella Laurenzo } 379436c6c9cSStella Laurenzo 380*5d6d30edSStella Laurenzo MlirAttribute encodingAttr = mlirAttributeGetNull(); 381436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 382*5d6d30edSStella Laurenzo 383*5d6d30edSStella Laurenzo // Detect format codes that are suitable for bulk loading. This includes 384*5d6d30edSStella Laurenzo // all byte aligned integer and floating point types up to 8 bytes. 385*5d6d30edSStella Laurenzo // Notably, this excludes, bool (which needs to be bit-packed) and 386*5d6d30edSStella Laurenzo // other exotics which do not have a direct representation in the buffer 387*5d6d30edSStella Laurenzo // protocol (i.e. complex, etc). 388*5d6d30edSStella Laurenzo Optional<MlirType> bulkLoadElementType; 389*5d6d30edSStella Laurenzo if (explicitType) { 390*5d6d30edSStella Laurenzo bulkLoadElementType = *explicitType; 391*5d6d30edSStella Laurenzo } else if (arrayInfo.format == "f") { 392436c6c9cSStella Laurenzo // f32 393436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 394*5d6d30edSStella Laurenzo bulkLoadElementType = mlirF32TypeGet(context); 395436c6c9cSStella Laurenzo } else if (arrayInfo.format == "d") { 396436c6c9cSStella Laurenzo // f64 397436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 398*5d6d30edSStella Laurenzo bulkLoadElementType = mlirF64TypeGet(context); 399*5d6d30edSStella Laurenzo } else if (arrayInfo.format == "e") { 400*5d6d30edSStella Laurenzo // f16 401*5d6d30edSStella Laurenzo assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); 402*5d6d30edSStella Laurenzo bulkLoadElementType = mlirF16TypeGet(context); 403436c6c9cSStella Laurenzo } else if (isSignedIntegerFormat(arrayInfo.format)) { 404436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 405436c6c9cSStella Laurenzo // i32 406*5d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) 407436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 32); 408436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 409436c6c9cSStella Laurenzo // i64 410*5d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) 411436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 64); 412*5d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 413*5d6d30edSStella Laurenzo // i8 414*5d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 415*5d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 8); 416*5d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 417*5d6d30edSStella Laurenzo // i16 418*5d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) 419*5d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 16); 420436c6c9cSStella Laurenzo } 421436c6c9cSStella Laurenzo } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 422436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 423436c6c9cSStella Laurenzo // unsigned i32 424*5d6d30edSStella Laurenzo bulkLoadElementType = signless 425436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 32) 426436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 32); 427436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 428436c6c9cSStella Laurenzo // unsigned i64 429*5d6d30edSStella Laurenzo bulkLoadElementType = signless 430436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 64) 431436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 64); 432*5d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 433*5d6d30edSStella Laurenzo // i8 434*5d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 435*5d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 8); 436*5d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 437*5d6d30edSStella Laurenzo // i16 438*5d6d30edSStella Laurenzo bulkLoadElementType = signless 439*5d6d30edSStella Laurenzo ? mlirIntegerTypeGet(context, 16) 440*5d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 16); 441436c6c9cSStella Laurenzo } 442436c6c9cSStella Laurenzo } 443*5d6d30edSStella Laurenzo if (bulkLoadElementType) { 444*5d6d30edSStella Laurenzo auto shapedType = mlirRankedTensorTypeGet( 445*5d6d30edSStella Laurenzo shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); 446*5d6d30edSStella Laurenzo size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; 447*5d6d30edSStella Laurenzo MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( 448*5d6d30edSStella Laurenzo shapedType, rawBufferSize, arrayInfo.ptr); 449*5d6d30edSStella Laurenzo if (mlirAttributeIsNull(attr)) { 450*5d6d30edSStella Laurenzo throw std::invalid_argument( 451*5d6d30edSStella Laurenzo "DenseElementsAttr could not be constructed from the given buffer. " 452*5d6d30edSStella Laurenzo "This may mean that the Python buffer layout does not match that " 453*5d6d30edSStella Laurenzo "MLIR expected layout and is a bug."); 454*5d6d30edSStella Laurenzo } 455*5d6d30edSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 456*5d6d30edSStella Laurenzo } 457436c6c9cSStella Laurenzo 458*5d6d30edSStella Laurenzo throw std::invalid_argument( 459*5d6d30edSStella Laurenzo std::string("unimplemented array format conversion from format: ") + 460*5d6d30edSStella Laurenzo arrayInfo.format); 461436c6c9cSStella Laurenzo } 462436c6c9cSStella Laurenzo 463436c6c9cSStella Laurenzo static PyDenseElementsAttribute getSplat(PyType shapedType, 464436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 465436c6c9cSStella Laurenzo auto contextWrapper = 466436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 467436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 468436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 469436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 470436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 471436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 472436c6c9cSStella Laurenzo } 473436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 474436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 475436c6c9cSStella Laurenzo std::string message = 476436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 477436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 478436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 479436c6c9cSStella Laurenzo } 480436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 481436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 482436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 483436c6c9cSStella Laurenzo std::string message = 484436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 485436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 486436c6c9cSStella Laurenzo message.append(", element="); 487436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 488436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 489436c6c9cSStella Laurenzo } 490436c6c9cSStella Laurenzo 491436c6c9cSStella Laurenzo MlirAttribute elements = 492436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 493436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 494436c6c9cSStella Laurenzo } 495436c6c9cSStella Laurenzo 496436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 497436c6c9cSStella Laurenzo 498436c6c9cSStella Laurenzo py::buffer_info accessBuffer() { 499*5d6d30edSStella Laurenzo if (mlirDenseElementsAttrIsSplat(*this)) { 500*5d6d30edSStella Laurenzo // TODO: Raise an exception. 501*5d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 502*5d6d30edSStella Laurenzo return py::buffer_info(); 503*5d6d30edSStella Laurenzo } 504*5d6d30edSStella Laurenzo 505436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 506436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 507*5d6d30edSStella Laurenzo std::string format; 508436c6c9cSStella Laurenzo 509436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 510436c6c9cSStella Laurenzo // f32 511*5d6d30edSStella Laurenzo return bufferInfo<float>(shapedType); 512436c6c9cSStella Laurenzo } else if (mlirTypeIsAF64(elementType)) { 513436c6c9cSStella Laurenzo // f64 514*5d6d30edSStella Laurenzo return bufferInfo<double>(shapedType); 515*5d6d30edSStella Laurenzo } else if (mlirTypeIsAF16(elementType)) { 516*5d6d30edSStella Laurenzo // f16 517*5d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType, "e"); 518436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 519436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 520436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 521436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 522436c6c9cSStella Laurenzo // i32 523*5d6d30edSStella Laurenzo return bufferInfo<int32_t>(shapedType); 524436c6c9cSStella Laurenzo } else if (mlirIntegerTypeIsUnsigned(elementType)) { 525436c6c9cSStella Laurenzo // unsigned i32 526*5d6d30edSStella Laurenzo return bufferInfo<uint32_t>(shapedType); 527436c6c9cSStella Laurenzo } 528436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 529436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 530436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 531436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 532436c6c9cSStella Laurenzo // i64 533*5d6d30edSStella Laurenzo return bufferInfo<int64_t>(shapedType); 534436c6c9cSStella Laurenzo } else if (mlirIntegerTypeIsUnsigned(elementType)) { 535436c6c9cSStella Laurenzo // unsigned i64 536*5d6d30edSStella Laurenzo return bufferInfo<uint64_t>(shapedType); 537*5d6d30edSStella Laurenzo } 538*5d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 539*5d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 8) { 540*5d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 541*5d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 542*5d6d30edSStella Laurenzo // i8 543*5d6d30edSStella Laurenzo return bufferInfo<int8_t>(shapedType); 544*5d6d30edSStella Laurenzo } else if (mlirIntegerTypeIsUnsigned(elementType)) { 545*5d6d30edSStella Laurenzo // unsigned i8 546*5d6d30edSStella Laurenzo return bufferInfo<uint8_t>(shapedType); 547*5d6d30edSStella Laurenzo } 548*5d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 549*5d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 16) { 550*5d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 551*5d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 552*5d6d30edSStella Laurenzo // i16 553*5d6d30edSStella Laurenzo return bufferInfo<int16_t>(shapedType); 554*5d6d30edSStella Laurenzo } else if (mlirIntegerTypeIsUnsigned(elementType)) { 555*5d6d30edSStella Laurenzo // unsigned i16 556*5d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType); 557436c6c9cSStella Laurenzo } 558436c6c9cSStella Laurenzo } 559436c6c9cSStella Laurenzo 560*5d6d30edSStella Laurenzo // TODO: Currently crashes the program. Just returning an empty buffer 561*5d6d30edSStella Laurenzo // for now. 562*5d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 563*5d6d30edSStella Laurenzo // throw std::invalid_argument( 564*5d6d30edSStella Laurenzo // "unsupported data type for conversion to Python buffer"); 565*5d6d30edSStella Laurenzo return py::buffer_info(); 566436c6c9cSStella Laurenzo } 567436c6c9cSStella Laurenzo 568436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 569436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 570436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 571436c6c9cSStella Laurenzo py::arg("array"), py::arg("signless") = true, 572*5d6d30edSStella Laurenzo py::arg("type") = py::none(), py::arg("shape") = py::none(), 573436c6c9cSStella Laurenzo py::arg("context") = py::none(), 574*5d6d30edSStella Laurenzo kDenseElementsAttrGetDocstring) 575436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 576436c6c9cSStella Laurenzo py::arg("shaped_type"), py::arg("element_attr"), 577436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 578436c6c9cSStella Laurenzo .def_property_readonly("is_splat", 579436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 580436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 581436c6c9cSStella Laurenzo }) 582436c6c9cSStella Laurenzo .def_buffer(&PyDenseElementsAttribute::accessBuffer); 583436c6c9cSStella Laurenzo } 584436c6c9cSStella Laurenzo 585436c6c9cSStella Laurenzo private: 586436c6c9cSStella Laurenzo static bool isUnsignedIntegerFormat(const std::string &format) { 587436c6c9cSStella Laurenzo if (format.empty()) 588436c6c9cSStella Laurenzo return false; 589436c6c9cSStella Laurenzo char code = format[0]; 590436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 591436c6c9cSStella Laurenzo code == 'Q'; 592436c6c9cSStella Laurenzo } 593436c6c9cSStella Laurenzo 594436c6c9cSStella Laurenzo static bool isSignedIntegerFormat(const std::string &format) { 595436c6c9cSStella Laurenzo if (format.empty()) 596436c6c9cSStella Laurenzo return false; 597436c6c9cSStella Laurenzo char code = format[0]; 598436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 599436c6c9cSStella Laurenzo code == 'q'; 600436c6c9cSStella Laurenzo } 601436c6c9cSStella Laurenzo 602436c6c9cSStella Laurenzo template <typename Type> 603436c6c9cSStella Laurenzo py::buffer_info bufferInfo(MlirType shapedType, 604*5d6d30edSStella Laurenzo const char *explicitFormat = nullptr) { 605436c6c9cSStella Laurenzo intptr_t rank = mlirShapedTypeGetRank(shapedType); 606436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 607436c6c9cSStella Laurenzo // Buffer is configured for read-only access below. 608436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 609436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 610436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 611436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 612436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 613436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 614436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 615436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 616436c6c9cSStella Laurenzo intptr_t strideFactor = 1; 617436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 618436c6c9cSStella Laurenzo strideFactor = 1; 619436c6c9cSStella Laurenzo for (intptr_t j = i; j < rank; ++j) { 620436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 621436c6c9cSStella Laurenzo } 622436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 623436c6c9cSStella Laurenzo } 624436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 625*5d6d30edSStella Laurenzo std::string format; 626*5d6d30edSStella Laurenzo if (explicitFormat) { 627*5d6d30edSStella Laurenzo format = explicitFormat; 628*5d6d30edSStella Laurenzo } else { 629*5d6d30edSStella Laurenzo format = py::format_descriptor<Type>::format(); 630*5d6d30edSStella Laurenzo } 631*5d6d30edSStella Laurenzo return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 632*5d6d30edSStella Laurenzo /*readonly=*/true); 633436c6c9cSStella Laurenzo } 634436c6c9cSStella Laurenzo }; // namespace 635436c6c9cSStella Laurenzo 636436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer 637436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access. 638436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 639436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 640436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 641436c6c9cSStella Laurenzo public: 642436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 643436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 644436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 645436c6c9cSStella Laurenzo 646436c6c9cSStella Laurenzo /// Returns the element at the given linear position. Asserts if the index is 647436c6c9cSStella Laurenzo /// out of range. 648436c6c9cSStella Laurenzo py::int_ dunderGetItem(intptr_t pos) { 649436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 650436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 651436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 652436c6c9cSStella Laurenzo } 653436c6c9cSStella Laurenzo 654436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 655436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 656436c6c9cSStella Laurenzo assert(mlirTypeIsAInteger(type) && 657436c6c9cSStella Laurenzo "expected integer element type in dense int elements attribute"); 658436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 659436c6c9cSStella Laurenzo // elemental type of the attribute. py::int_ is implicitly constructible 660436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 661436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 662436c6c9cSStella Laurenzo // querying them on each element access. 663436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 664436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 665436c6c9cSStella Laurenzo if (isUnsigned) { 666436c6c9cSStella Laurenzo if (width == 1) { 667436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 668436c6c9cSStella Laurenzo } 669436c6c9cSStella Laurenzo if (width == 32) { 670436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt32Value(*this, pos); 671436c6c9cSStella Laurenzo } 672436c6c9cSStella Laurenzo if (width == 64) { 673436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt64Value(*this, pos); 674436c6c9cSStella Laurenzo } 675436c6c9cSStella Laurenzo } else { 676436c6c9cSStella Laurenzo if (width == 1) { 677436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 678436c6c9cSStella Laurenzo } 679436c6c9cSStella Laurenzo if (width == 32) { 680436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt32Value(*this, pos); 681436c6c9cSStella Laurenzo } 682436c6c9cSStella Laurenzo if (width == 64) { 683436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt64Value(*this, pos); 684436c6c9cSStella Laurenzo } 685436c6c9cSStella Laurenzo } 686436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 687436c6c9cSStella Laurenzo } 688436c6c9cSStella Laurenzo 689436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 690436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 691436c6c9cSStella Laurenzo } 692436c6c9cSStella Laurenzo }; 693436c6c9cSStella Laurenzo 694436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 695436c6c9cSStella Laurenzo public: 696436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 697436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 698436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 699436c6c9cSStella Laurenzo 700436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 701436c6c9cSStella Laurenzo 702436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 703436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 704436c6c9cSStella Laurenzo c.def_static( 705436c6c9cSStella Laurenzo "get", 706436c6c9cSStella Laurenzo [](py::dict attributes, DefaultingPyMlirContext context) { 707436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 708436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 709436c6c9cSStella Laurenzo for (auto &it : attributes) { 710436c6c9cSStella Laurenzo auto &mlir_attr = it.second.cast<PyAttribute &>(); 711436c6c9cSStella Laurenzo auto name = it.first.cast<std::string>(); 712436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 713436c6c9cSStella Laurenzo mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), 714436c6c9cSStella Laurenzo toMlirStringRef(name)), 715436c6c9cSStella Laurenzo mlir_attr)); 716436c6c9cSStella Laurenzo } 717436c6c9cSStella Laurenzo MlirAttribute attr = 718436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 719436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 720436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 721436c6c9cSStella Laurenzo }, 722ed9e52f3SAlex Zinenko py::arg("value") = py::dict(), py::arg("context") = py::none(), 723436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 724436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 725436c6c9cSStella Laurenzo MlirAttribute attr = 726436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 727436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 728436c6c9cSStella Laurenzo throw SetPyError(PyExc_KeyError, 729436c6c9cSStella Laurenzo "attempt to access a non-existent attribute"); 730436c6c9cSStella Laurenzo } 731436c6c9cSStella Laurenzo return PyAttribute(self.getContext(), attr); 732436c6c9cSStella Laurenzo }); 733436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 734436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 735436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 736436c6c9cSStella Laurenzo "attempt to access out of bounds attribute"); 737436c6c9cSStella Laurenzo } 738436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 739436c6c9cSStella Laurenzo return PyNamedAttribute( 740436c6c9cSStella Laurenzo namedAttr.attribute, 741436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 742436c6c9cSStella Laurenzo }); 743436c6c9cSStella Laurenzo } 744436c6c9cSStella Laurenzo }; 745436c6c9cSStella Laurenzo 746436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 747436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 748436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 749436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 750436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 751436c6c9cSStella Laurenzo public: 752436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 753436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 754436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 755436c6c9cSStella Laurenzo 756436c6c9cSStella Laurenzo py::float_ dunderGetItem(intptr_t pos) { 757436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 758436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 759436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 760436c6c9cSStella Laurenzo } 761436c6c9cSStella Laurenzo 762436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 763436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 764436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 765436c6c9cSStella Laurenzo // elemental type of the attribute. py::float_ is implicitly constructible 766436c6c9cSStella Laurenzo // from float and double. 767436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 768436c6c9cSStella Laurenzo // querying them on each element access. 769436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 770436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetFloatValue(*this, pos); 771436c6c9cSStella Laurenzo } 772436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 773436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetDoubleValue(*this, pos); 774436c6c9cSStella Laurenzo } 775436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 776436c6c9cSStella Laurenzo } 777436c6c9cSStella Laurenzo 778436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 779436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 780436c6c9cSStella Laurenzo } 781436c6c9cSStella Laurenzo }; 782436c6c9cSStella Laurenzo 783436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 784436c6c9cSStella Laurenzo public: 785436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 786436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 787436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 788436c6c9cSStella Laurenzo 789436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 790436c6c9cSStella Laurenzo c.def_static( 791436c6c9cSStella Laurenzo "get", 792436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 793436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 794436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 795436c6c9cSStella Laurenzo }, 796436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 797436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 798436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyTypeAttribute &self) { 799436c6c9cSStella Laurenzo return PyType(self.getContext()->getRef(), 800436c6c9cSStella Laurenzo mlirTypeAttrGetValue(self.get())); 801436c6c9cSStella Laurenzo }); 802436c6c9cSStella Laurenzo } 803436c6c9cSStella Laurenzo }; 804436c6c9cSStella Laurenzo 805436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 806436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 807436c6c9cSStella Laurenzo public: 808436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 809436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 810436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 811436c6c9cSStella Laurenzo 812436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 813436c6c9cSStella Laurenzo c.def_static( 814436c6c9cSStella Laurenzo "get", 815436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 816436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 817436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 818436c6c9cSStella Laurenzo }, 819436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Create a Unit attribute."); 820436c6c9cSStella Laurenzo } 821436c6c9cSStella Laurenzo }; 822436c6c9cSStella Laurenzo 823436c6c9cSStella Laurenzo } // namespace 824436c6c9cSStella Laurenzo 825436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) { 826436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 827436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 828436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 829436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 830436c6c9cSStella Laurenzo PyDenseElementsAttribute::bind(m); 831436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 832436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 833436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 834436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 835436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 836436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 837436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 838436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 839436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 840436c6c9cSStella Laurenzo } 841