1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 91fc096afSMehdi Amini #include <utility> 101fc096afSMehdi Amini 11436c6c9cSStella Laurenzo #include "IRModule.h" 12436c6c9cSStella Laurenzo 13436c6c9cSStella Laurenzo #include "PybindUtils.h" 14436c6c9cSStella Laurenzo 15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 17436c6c9cSStella Laurenzo 18436c6c9cSStella Laurenzo namespace py = pybind11; 19436c6c9cSStella Laurenzo using namespace mlir; 20436c6c9cSStella Laurenzo using namespace mlir::python; 21436c6c9cSStella Laurenzo 225d6d30edSStella Laurenzo using llvm::Optional; 23436c6c9cSStella Laurenzo using llvm::SmallVector; 24436c6c9cSStella Laurenzo using llvm::Twine; 25436c6c9cSStella Laurenzo 265d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 275d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 285d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 295d6d30edSStella Laurenzo 305d6d30edSStella Laurenzo static const char kDenseElementsAttrGetDocstring[] = 315d6d30edSStella Laurenzo R"(Gets a DenseElementsAttr from a Python buffer or array. 325d6d30edSStella Laurenzo 335d6d30edSStella Laurenzo When `type` is not provided, then some limited type inferencing is done based 345d6d30edSStella Laurenzo on the buffer format. Support presently exists for 8/16/32/64 signed and 355d6d30edSStella Laurenzo unsigned integers and float16/float32/float64. DenseElementsAttrs of these 365d6d30edSStella Laurenzo types can also be converted back to a corresponding buffer. 375d6d30edSStella Laurenzo 385d6d30edSStella Laurenzo For conversions outside of these types, a `type=` must be explicitly provided 395d6d30edSStella Laurenzo and the buffer contents must be bit-castable to the MLIR internal 405d6d30edSStella Laurenzo representation: 415d6d30edSStella Laurenzo 425d6d30edSStella Laurenzo * Integer types (except for i1): the buffer must be byte aligned to the 435d6d30edSStella Laurenzo next byte boundary. 445d6d30edSStella Laurenzo * Floating point types: Must be bit-castable to the given floating point 455d6d30edSStella Laurenzo size. 465d6d30edSStella Laurenzo * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 475d6d30edSStella Laurenzo row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 485d6d30edSStella Laurenzo this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 495d6d30edSStella Laurenzo 505d6d30edSStella Laurenzo If a single element buffer is passed (or for i1, a single byte with value 0 515d6d30edSStella Laurenzo or 255), then a splat will be created. 525d6d30edSStella Laurenzo 535d6d30edSStella Laurenzo Args: 545d6d30edSStella Laurenzo array: The array or buffer to convert. 555d6d30edSStella Laurenzo signless: If inferring an appropriate MLIR type, use signless types for 565d6d30edSStella Laurenzo integers (defaults True). 575d6d30edSStella Laurenzo type: Skips inference of the MLIR element type and uses this instead. The 585d6d30edSStella Laurenzo storage size must be consistent with the actual contents of the buffer. 595d6d30edSStella Laurenzo shape: Overrides the shape of the buffer when constructing the MLIR 605d6d30edSStella Laurenzo shaped type. This is needed when the physical and logical shape differ (as 615d6d30edSStella Laurenzo for i1). 625d6d30edSStella Laurenzo context: Explicit context, if not from context manager. 635d6d30edSStella Laurenzo 645d6d30edSStella Laurenzo Returns: 655d6d30edSStella Laurenzo DenseElementsAttr on success. 665d6d30edSStella Laurenzo 675d6d30edSStella Laurenzo Raises: 685d6d30edSStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 695d6d30edSStella Laurenzo type or if the buffer does not meet expectations. 705d6d30edSStella Laurenzo )"; 715d6d30edSStella Laurenzo 72436c6c9cSStella Laurenzo namespace { 73436c6c9cSStella Laurenzo 74436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 75436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 76436c6c9cSStella Laurenzo } 77436c6c9cSStella Laurenzo 78436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 79436c6c9cSStella Laurenzo public: 80436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 81436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 82436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 83436c6c9cSStella Laurenzo 84436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 85436c6c9cSStella Laurenzo c.def_static( 86436c6c9cSStella Laurenzo "get", 87436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 88436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 89436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 90436c6c9cSStella Laurenzo }, 91436c6c9cSStella Laurenzo py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 92436c6c9cSStella Laurenzo } 93436c6c9cSStella Laurenzo }; 94436c6c9cSStella Laurenzo 95ed9e52f3SAlex Zinenko template <typename T> 96ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) { 97ed9e52f3SAlex Zinenko try { 98ed9e52f3SAlex Zinenko return object.cast<T>(); 99ed9e52f3SAlex Zinenko } catch (py::cast_error &err) { 100ed9e52f3SAlex Zinenko std::string msg = 101ed9e52f3SAlex Zinenko std::string( 102ed9e52f3SAlex Zinenko "Invalid attribute when attempting to create an ArrayAttribute (") + 103ed9e52f3SAlex Zinenko err.what() + ")"; 104ed9e52f3SAlex Zinenko throw py::cast_error(msg); 105ed9e52f3SAlex Zinenko } catch (py::reference_cast_error &err) { 106ed9e52f3SAlex Zinenko std::string msg = std::string("Invalid attribute (None?) when attempting " 107ed9e52f3SAlex Zinenko "to create an ArrayAttribute (") + 108ed9e52f3SAlex Zinenko err.what() + ")"; 109ed9e52f3SAlex Zinenko throw py::cast_error(msg); 110ed9e52f3SAlex Zinenko } 111ed9e52f3SAlex Zinenko } 112ed9e52f3SAlex Zinenko 113436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 114436c6c9cSStella Laurenzo public: 115436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 116436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 117436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 118436c6c9cSStella Laurenzo 119436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 120436c6c9cSStella Laurenzo public: 1211fc096afSMehdi Amini PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 122436c6c9cSStella Laurenzo 123436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 124436c6c9cSStella Laurenzo 125436c6c9cSStella Laurenzo PyAttribute dunderNext() { 126436c6c9cSStella Laurenzo if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { 127436c6c9cSStella Laurenzo throw py::stop_iteration(); 128436c6c9cSStella Laurenzo } 129436c6c9cSStella Laurenzo return PyAttribute(attr.getContext(), 130436c6c9cSStella Laurenzo mlirArrayAttrGetElement(attr.get(), nextIndex++)); 131436c6c9cSStella Laurenzo } 132436c6c9cSStella Laurenzo 133436c6c9cSStella Laurenzo static void bind(py::module &m) { 134f05ff4f7SStella Laurenzo py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator", 135f05ff4f7SStella Laurenzo py::module_local()) 136436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 137436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 138436c6c9cSStella Laurenzo } 139436c6c9cSStella Laurenzo 140436c6c9cSStella Laurenzo private: 141436c6c9cSStella Laurenzo PyAttribute attr; 142436c6c9cSStella Laurenzo int nextIndex = 0; 143436c6c9cSStella Laurenzo }; 144436c6c9cSStella Laurenzo 145ed9e52f3SAlex Zinenko PyAttribute getItem(intptr_t i) { 146ed9e52f3SAlex Zinenko return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); 147ed9e52f3SAlex Zinenko } 148ed9e52f3SAlex Zinenko 149436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 150436c6c9cSStella Laurenzo c.def_static( 151436c6c9cSStella Laurenzo "get", 152436c6c9cSStella Laurenzo [](py::list attributes, DefaultingPyMlirContext context) { 153436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 154436c6c9cSStella Laurenzo mlirAttributes.reserve(py::len(attributes)); 155436c6c9cSStella Laurenzo for (auto attribute : attributes) { 156ed9e52f3SAlex Zinenko mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 157436c6c9cSStella Laurenzo } 158436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 159436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 160436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 161436c6c9cSStella Laurenzo }, 162436c6c9cSStella Laurenzo py::arg("attributes"), py::arg("context") = py::none(), 163436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 164436c6c9cSStella Laurenzo c.def("__getitem__", 165436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 166436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 167436c6c9cSStella Laurenzo throw py::index_error("ArrayAttribute index out of range"); 168ed9e52f3SAlex Zinenko return arr.getItem(i); 169436c6c9cSStella Laurenzo }) 170436c6c9cSStella Laurenzo .def("__len__", 171436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 172436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 173436c6c9cSStella Laurenzo }) 174436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 175436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 176436c6c9cSStella Laurenzo }); 177ed9e52f3SAlex Zinenko c.def("__add__", [](PyArrayAttribute arr, py::list extras) { 178ed9e52f3SAlex Zinenko std::vector<MlirAttribute> attributes; 179ed9e52f3SAlex Zinenko intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 180ed9e52f3SAlex Zinenko attributes.reserve(numOldElements + py::len(extras)); 181ed9e52f3SAlex Zinenko for (intptr_t i = 0; i < numOldElements; ++i) 182ed9e52f3SAlex Zinenko attributes.push_back(arr.getItem(i)); 183ed9e52f3SAlex Zinenko for (py::handle attr : extras) 184ed9e52f3SAlex Zinenko attributes.push_back(pyTryCast<PyAttribute>(attr)); 185ed9e52f3SAlex Zinenko MlirAttribute arrayAttr = mlirArrayAttrGet( 186ed9e52f3SAlex Zinenko arr.getContext()->get(), attributes.size(), attributes.data()); 187ed9e52f3SAlex Zinenko return PyArrayAttribute(arr.getContext(), arrayAttr); 188ed9e52f3SAlex Zinenko }); 189436c6c9cSStella Laurenzo } 190436c6c9cSStella Laurenzo }; 191436c6c9cSStella Laurenzo 192436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 193436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 194436c6c9cSStella Laurenzo public: 195436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 196436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 197436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 198436c6c9cSStella Laurenzo 199436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 200436c6c9cSStella Laurenzo c.def_static( 201436c6c9cSStella Laurenzo "get", 202436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 203436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 204436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 205436c6c9cSStella Laurenzo // in C API. 206436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 207436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 208436c6c9cSStella Laurenzo Twine("invalid '") + 209436c6c9cSStella Laurenzo py::repr(py::cast(type)).cast<std::string>() + 210436c6c9cSStella Laurenzo "' and expected floating point type."); 211436c6c9cSStella Laurenzo } 212436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 213436c6c9cSStella Laurenzo }, 214436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), 215436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 216436c6c9cSStella Laurenzo c.def_static( 217436c6c9cSStella Laurenzo "get_f32", 218436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 219436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 220436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 221436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 222436c6c9cSStella Laurenzo }, 223436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 224436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 225436c6c9cSStella Laurenzo c.def_static( 226436c6c9cSStella Laurenzo "get_f64", 227436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 228436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 229436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 230436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 231436c6c9cSStella Laurenzo }, 232436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 233436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 234436c6c9cSStella Laurenzo c.def_property_readonly( 235436c6c9cSStella Laurenzo "value", 236436c6c9cSStella Laurenzo [](PyFloatAttribute &self) { 237436c6c9cSStella Laurenzo return mlirFloatAttrGetValueDouble(self); 238436c6c9cSStella Laurenzo }, 239436c6c9cSStella Laurenzo "Returns the value of the float point attribute"); 240436c6c9cSStella Laurenzo } 241436c6c9cSStella Laurenzo }; 242436c6c9cSStella Laurenzo 243436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 244436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 245436c6c9cSStella Laurenzo public: 246436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 247436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 248436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 249436c6c9cSStella Laurenzo 250436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 251436c6c9cSStella Laurenzo c.def_static( 252436c6c9cSStella Laurenzo "get", 253436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 254436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 255436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 256436c6c9cSStella Laurenzo }, 257436c6c9cSStella Laurenzo py::arg("type"), py::arg("value"), 258436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 259436c6c9cSStella Laurenzo c.def_property_readonly( 260436c6c9cSStella Laurenzo "value", 261e9db306dSrkayaith [](PyIntegerAttribute &self) -> py::int_ { 262e9db306dSrkayaith MlirType type = mlirAttributeGetType(self); 263e9db306dSrkayaith if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 264436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 265e9db306dSrkayaith if (mlirIntegerTypeIsSigned(type)) 266e9db306dSrkayaith return mlirIntegerAttrGetValueSInt(self); 267e9db306dSrkayaith return mlirIntegerAttrGetValueUInt(self); 268436c6c9cSStella Laurenzo }, 269436c6c9cSStella Laurenzo "Returns the value of the integer attribute"); 270436c6c9cSStella Laurenzo } 271436c6c9cSStella Laurenzo }; 272436c6c9cSStella Laurenzo 273436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 274436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 275436c6c9cSStella Laurenzo public: 276436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 277436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 278436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 279436c6c9cSStella Laurenzo 280436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 281436c6c9cSStella Laurenzo c.def_static( 282436c6c9cSStella Laurenzo "get", 283436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 284436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 285436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 286436c6c9cSStella Laurenzo }, 287436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 288436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 289436c6c9cSStella Laurenzo c.def_property_readonly( 290436c6c9cSStella Laurenzo "value", 291436c6c9cSStella Laurenzo [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, 292436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 293436c6c9cSStella Laurenzo } 294436c6c9cSStella Laurenzo }; 295436c6c9cSStella Laurenzo 296436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 297436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 298436c6c9cSStella Laurenzo public: 299436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 300436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 301436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 302436c6c9cSStella Laurenzo 303436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 304436c6c9cSStella Laurenzo c.def_static( 305436c6c9cSStella Laurenzo "get", 306436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 307436c6c9cSStella Laurenzo MlirAttribute attr = 308436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 309436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 310436c6c9cSStella Laurenzo }, 311436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 312436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 313436c6c9cSStella Laurenzo c.def_property_readonly( 314436c6c9cSStella Laurenzo "value", 315436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 316436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 317436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 318436c6c9cSStella Laurenzo }, 319436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 320436c6c9cSStella Laurenzo } 321436c6c9cSStella Laurenzo }; 322436c6c9cSStella Laurenzo 323*5c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { 324*5c3861b2SYun Long public: 325*5c3861b2SYun Long static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; 326*5c3861b2SYun Long static constexpr const char *pyClassName = "OpaqueAttr"; 327*5c3861b2SYun Long using PyConcreteAttribute::PyConcreteAttribute; 328*5c3861b2SYun Long 329*5c3861b2SYun Long static void bindDerived(ClassTy &c) { 330*5c3861b2SYun Long c.def_static( 331*5c3861b2SYun Long "get", 332*5c3861b2SYun Long [](std::string dialectNamespace, py::buffer buffer, PyType &type, 333*5c3861b2SYun Long DefaultingPyMlirContext context) { 334*5c3861b2SYun Long const py::buffer_info bufferInfo = buffer.request(); 335*5c3861b2SYun Long intptr_t bufferSize = bufferInfo.size; 336*5c3861b2SYun Long MlirAttribute attr = mlirOpaqueAttrGet( 337*5c3861b2SYun Long context->get(), toMlirStringRef(dialectNamespace), bufferSize, 338*5c3861b2SYun Long static_cast<char *>(bufferInfo.ptr), type); 339*5c3861b2SYun Long return PyOpaqueAttribute(context->getRef(), attr); 340*5c3861b2SYun Long }, 341*5c3861b2SYun Long py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), 342*5c3861b2SYun Long py::arg("context") = py::none(), "Gets an Opaque attribute."); 343*5c3861b2SYun Long c.def_property_readonly( 344*5c3861b2SYun Long "dialect_namespace", 345*5c3861b2SYun Long [](PyOpaqueAttribute &self) { 346*5c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); 347*5c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 348*5c3861b2SYun Long }, 349*5c3861b2SYun Long "Returns the dialect namespace for the Opaque attribute as a string"); 350*5c3861b2SYun Long c.def_property_readonly( 351*5c3861b2SYun Long "data", 352*5c3861b2SYun Long [](PyOpaqueAttribute &self) { 353*5c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetData(self); 354*5c3861b2SYun Long return py::str(stringRef.data, stringRef.length); 355*5c3861b2SYun Long }, 356*5c3861b2SYun Long "Returns the data for the Opaqued attributes as a string"); 357*5c3861b2SYun Long } 358*5c3861b2SYun Long }; 359*5c3861b2SYun Long 360436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 361436c6c9cSStella Laurenzo public: 362436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 363436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 364436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 365436c6c9cSStella Laurenzo 366436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 367436c6c9cSStella Laurenzo c.def_static( 368436c6c9cSStella Laurenzo "get", 369436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 370436c6c9cSStella Laurenzo MlirAttribute attr = 371436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 372436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 373436c6c9cSStella Laurenzo }, 374436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 375436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 376436c6c9cSStella Laurenzo c.def_static( 377436c6c9cSStella Laurenzo "get_typed", 378436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 379436c6c9cSStella Laurenzo MlirAttribute attr = 380436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 381436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 382436c6c9cSStella Laurenzo }, 383a6e7d024SStella Laurenzo py::arg("type"), py::arg("value"), 384436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 385436c6c9cSStella Laurenzo c.def_property_readonly( 386436c6c9cSStella Laurenzo "value", 387436c6c9cSStella Laurenzo [](PyStringAttribute &self) { 388436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirStringAttrGetValue(self); 389436c6c9cSStella Laurenzo return py::str(stringRef.data, stringRef.length); 390436c6c9cSStella Laurenzo }, 391436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 392436c6c9cSStella Laurenzo } 393436c6c9cSStella Laurenzo }; 394436c6c9cSStella Laurenzo 395436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 396436c6c9cSStella Laurenzo class PyDenseElementsAttribute 397436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 398436c6c9cSStella Laurenzo public: 399436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 400436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 401436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 402436c6c9cSStella Laurenzo 403436c6c9cSStella Laurenzo static PyDenseElementsAttribute 4045d6d30edSStella Laurenzo getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType, 4055d6d30edSStella Laurenzo Optional<std::vector<int64_t>> explicitShape, 406436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 407436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 408436c6c9cSStella Laurenzo int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; 409436c6c9cSStella Laurenzo Py_buffer *view = new Py_buffer(); 410436c6c9cSStella Laurenzo if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { 411436c6c9cSStella Laurenzo delete view; 412436c6c9cSStella Laurenzo throw py::error_already_set(); 413436c6c9cSStella Laurenzo } 414436c6c9cSStella Laurenzo py::buffer_info arrayInfo(view); 4155d6d30edSStella Laurenzo SmallVector<int64_t> shape; 4165d6d30edSStella Laurenzo if (explicitShape) { 4175d6d30edSStella Laurenzo shape.append(explicitShape->begin(), explicitShape->end()); 4185d6d30edSStella Laurenzo } else { 4195d6d30edSStella Laurenzo shape.append(arrayInfo.shape.begin(), 4205d6d30edSStella Laurenzo arrayInfo.shape.begin() + arrayInfo.ndim); 4215d6d30edSStella Laurenzo } 422436c6c9cSStella Laurenzo 4235d6d30edSStella Laurenzo MlirAttribute encodingAttr = mlirAttributeGetNull(); 424436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 4255d6d30edSStella Laurenzo 4265d6d30edSStella Laurenzo // Detect format codes that are suitable for bulk loading. This includes 4275d6d30edSStella Laurenzo // all byte aligned integer and floating point types up to 8 bytes. 4285d6d30edSStella Laurenzo // Notably, this excludes, bool (which needs to be bit-packed) and 4295d6d30edSStella Laurenzo // other exotics which do not have a direct representation in the buffer 4305d6d30edSStella Laurenzo // protocol (i.e. complex, etc). 4315d6d30edSStella Laurenzo Optional<MlirType> bulkLoadElementType; 4325d6d30edSStella Laurenzo if (explicitType) { 4335d6d30edSStella Laurenzo bulkLoadElementType = *explicitType; 4345d6d30edSStella Laurenzo } else if (arrayInfo.format == "f") { 435436c6c9cSStella Laurenzo // f32 436436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); 4375d6d30edSStella Laurenzo bulkLoadElementType = mlirF32TypeGet(context); 438436c6c9cSStella Laurenzo } else if (arrayInfo.format == "d") { 439436c6c9cSStella Laurenzo // f64 440436c6c9cSStella Laurenzo assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); 4415d6d30edSStella Laurenzo bulkLoadElementType = mlirF64TypeGet(context); 4425d6d30edSStella Laurenzo } else if (arrayInfo.format == "e") { 4435d6d30edSStella Laurenzo // f16 4445d6d30edSStella Laurenzo assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); 4455d6d30edSStella Laurenzo bulkLoadElementType = mlirF16TypeGet(context); 446436c6c9cSStella Laurenzo } else if (isSignedIntegerFormat(arrayInfo.format)) { 447436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 448436c6c9cSStella Laurenzo // i32 4495d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) 450436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 32); 451436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 452436c6c9cSStella Laurenzo // i64 4535d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) 454436c6c9cSStella Laurenzo : mlirIntegerTypeSignedGet(context, 64); 4555d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 4565d6d30edSStella Laurenzo // i8 4575d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 4585d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 8); 4595d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 4605d6d30edSStella Laurenzo // i16 4615d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) 4625d6d30edSStella Laurenzo : mlirIntegerTypeSignedGet(context, 16); 463436c6c9cSStella Laurenzo } 464436c6c9cSStella Laurenzo } else if (isUnsignedIntegerFormat(arrayInfo.format)) { 465436c6c9cSStella Laurenzo if (arrayInfo.itemsize == 4) { 466436c6c9cSStella Laurenzo // unsigned i32 4675d6d30edSStella Laurenzo bulkLoadElementType = signless 468436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 32) 469436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 32); 470436c6c9cSStella Laurenzo } else if (arrayInfo.itemsize == 8) { 471436c6c9cSStella Laurenzo // unsigned i64 4725d6d30edSStella Laurenzo bulkLoadElementType = signless 473436c6c9cSStella Laurenzo ? mlirIntegerTypeGet(context, 64) 474436c6c9cSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 64); 4755d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 1) { 4765d6d30edSStella Laurenzo // i8 4775d6d30edSStella Laurenzo bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 4785d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 8); 4795d6d30edSStella Laurenzo } else if (arrayInfo.itemsize == 2) { 4805d6d30edSStella Laurenzo // i16 4815d6d30edSStella Laurenzo bulkLoadElementType = signless 4825d6d30edSStella Laurenzo ? mlirIntegerTypeGet(context, 16) 4835d6d30edSStella Laurenzo : mlirIntegerTypeUnsignedGet(context, 16); 484436c6c9cSStella Laurenzo } 485436c6c9cSStella Laurenzo } 4865d6d30edSStella Laurenzo if (bulkLoadElementType) { 4875d6d30edSStella Laurenzo auto shapedType = mlirRankedTensorTypeGet( 4885d6d30edSStella Laurenzo shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); 4895d6d30edSStella Laurenzo size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; 4905d6d30edSStella Laurenzo MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( 4915d6d30edSStella Laurenzo shapedType, rawBufferSize, arrayInfo.ptr); 4925d6d30edSStella Laurenzo if (mlirAttributeIsNull(attr)) { 4935d6d30edSStella Laurenzo throw std::invalid_argument( 4945d6d30edSStella Laurenzo "DenseElementsAttr could not be constructed from the given buffer. " 4955d6d30edSStella Laurenzo "This may mean that the Python buffer layout does not match that " 4965d6d30edSStella Laurenzo "MLIR expected layout and is a bug."); 4975d6d30edSStella Laurenzo } 4985d6d30edSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 4995d6d30edSStella Laurenzo } 500436c6c9cSStella Laurenzo 5015d6d30edSStella Laurenzo throw std::invalid_argument( 5025d6d30edSStella Laurenzo std::string("unimplemented array format conversion from format: ") + 5035d6d30edSStella Laurenzo arrayInfo.format); 504436c6c9cSStella Laurenzo } 505436c6c9cSStella Laurenzo 5061fc096afSMehdi Amini static PyDenseElementsAttribute getSplat(const PyType &shapedType, 507436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 508436c6c9cSStella Laurenzo auto contextWrapper = 509436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 510436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 511436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 512436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 513436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 514436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 515436c6c9cSStella Laurenzo } 516436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 517436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 518436c6c9cSStella Laurenzo std::string message = 519436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 520436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 521436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 522436c6c9cSStella Laurenzo } 523436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 524436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 525436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 526436c6c9cSStella Laurenzo std::string message = 527436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 528436c6c9cSStella Laurenzo message.append(py::repr(py::cast(shapedType))); 529436c6c9cSStella Laurenzo message.append(", element="); 530436c6c9cSStella Laurenzo message.append(py::repr(py::cast(elementAttr))); 531436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, message); 532436c6c9cSStella Laurenzo } 533436c6c9cSStella Laurenzo 534436c6c9cSStella Laurenzo MlirAttribute elements = 535436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 536436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 537436c6c9cSStella Laurenzo } 538436c6c9cSStella Laurenzo 539436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 540436c6c9cSStella Laurenzo 541436c6c9cSStella Laurenzo py::buffer_info accessBuffer() { 5425d6d30edSStella Laurenzo if (mlirDenseElementsAttrIsSplat(*this)) { 543c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 5445d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 545c5f445d1SStella Laurenzo throw std::invalid_argument( 546c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 5475d6d30edSStella Laurenzo } 5485d6d30edSStella Laurenzo 549436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 550436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 5515d6d30edSStella Laurenzo std::string format; 552436c6c9cSStella Laurenzo 553436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 554436c6c9cSStella Laurenzo // f32 5555d6d30edSStella Laurenzo return bufferInfo<float>(shapedType); 55602b6fb21SMehdi Amini } 55702b6fb21SMehdi Amini if (mlirTypeIsAF64(elementType)) { 558436c6c9cSStella Laurenzo // f64 5595d6d30edSStella Laurenzo return bufferInfo<double>(shapedType); 560bb56c2b3SMehdi Amini } 561bb56c2b3SMehdi Amini if (mlirTypeIsAF16(elementType)) { 5625d6d30edSStella Laurenzo // f16 5635d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType, "e"); 564bb56c2b3SMehdi Amini } 565bb56c2b3SMehdi Amini if (mlirTypeIsAInteger(elementType) && 566436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 567436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 568436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 569436c6c9cSStella Laurenzo // i32 5705d6d30edSStella Laurenzo return bufferInfo<int32_t>(shapedType); 571e5639b3fSMehdi Amini } 572e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 573436c6c9cSStella Laurenzo // unsigned i32 5745d6d30edSStella Laurenzo return bufferInfo<uint32_t>(shapedType); 575436c6c9cSStella Laurenzo } 576436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 577436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 578436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 579436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 580436c6c9cSStella Laurenzo // i64 5815d6d30edSStella Laurenzo return bufferInfo<int64_t>(shapedType); 582e5639b3fSMehdi Amini } 583e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 584436c6c9cSStella Laurenzo // unsigned i64 5855d6d30edSStella Laurenzo return bufferInfo<uint64_t>(shapedType); 5865d6d30edSStella Laurenzo } 5875d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 5885d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 8) { 5895d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 5905d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 5915d6d30edSStella Laurenzo // i8 5925d6d30edSStella Laurenzo return bufferInfo<int8_t>(shapedType); 593e5639b3fSMehdi Amini } 594e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 5955d6d30edSStella Laurenzo // unsigned i8 5965d6d30edSStella Laurenzo return bufferInfo<uint8_t>(shapedType); 5975d6d30edSStella Laurenzo } 5985d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 5995d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 16) { 6005d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 6015d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 6025d6d30edSStella Laurenzo // i16 6035d6d30edSStella Laurenzo return bufferInfo<int16_t>(shapedType); 604e5639b3fSMehdi Amini } 605e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 6065d6d30edSStella Laurenzo // unsigned i16 6075d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType); 608436c6c9cSStella Laurenzo } 609436c6c9cSStella Laurenzo } 610436c6c9cSStella Laurenzo 611c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 6125d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 613c5f445d1SStella Laurenzo throw std::invalid_argument( 614c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 615436c6c9cSStella Laurenzo } 616436c6c9cSStella Laurenzo 617436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 618436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 619436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 620436c6c9cSStella Laurenzo py::arg("array"), py::arg("signless") = true, 6215d6d30edSStella Laurenzo py::arg("type") = py::none(), py::arg("shape") = py::none(), 622436c6c9cSStella Laurenzo py::arg("context") = py::none(), 6235d6d30edSStella Laurenzo kDenseElementsAttrGetDocstring) 624436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 625436c6c9cSStella Laurenzo py::arg("shaped_type"), py::arg("element_attr"), 626436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 627436c6c9cSStella Laurenzo .def_property_readonly("is_splat", 628436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 629436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 630436c6c9cSStella Laurenzo }) 631436c6c9cSStella Laurenzo .def_buffer(&PyDenseElementsAttribute::accessBuffer); 632436c6c9cSStella Laurenzo } 633436c6c9cSStella Laurenzo 634436c6c9cSStella Laurenzo private: 635436c6c9cSStella Laurenzo static bool isUnsignedIntegerFormat(const std::string &format) { 636436c6c9cSStella Laurenzo if (format.empty()) 637436c6c9cSStella Laurenzo return false; 638436c6c9cSStella Laurenzo char code = format[0]; 639436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 640436c6c9cSStella Laurenzo code == 'Q'; 641436c6c9cSStella Laurenzo } 642436c6c9cSStella Laurenzo 643436c6c9cSStella Laurenzo static bool isSignedIntegerFormat(const std::string &format) { 644436c6c9cSStella Laurenzo if (format.empty()) 645436c6c9cSStella Laurenzo return false; 646436c6c9cSStella Laurenzo char code = format[0]; 647436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 648436c6c9cSStella Laurenzo code == 'q'; 649436c6c9cSStella Laurenzo } 650436c6c9cSStella Laurenzo 651436c6c9cSStella Laurenzo template <typename Type> 652436c6c9cSStella Laurenzo py::buffer_info bufferInfo(MlirType shapedType, 6535d6d30edSStella Laurenzo const char *explicitFormat = nullptr) { 654436c6c9cSStella Laurenzo intptr_t rank = mlirShapedTypeGetRank(shapedType); 655436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 656436c6c9cSStella Laurenzo // Buffer is configured for read-only access below. 657436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 658436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 659436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 660436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 661436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 662436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 663436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 664436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 665436c6c9cSStella Laurenzo intptr_t strideFactor = 1; 666436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 667436c6c9cSStella Laurenzo strideFactor = 1; 668436c6c9cSStella Laurenzo for (intptr_t j = i; j < rank; ++j) { 669436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 670436c6c9cSStella Laurenzo } 671436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 672436c6c9cSStella Laurenzo } 673436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 6745d6d30edSStella Laurenzo std::string format; 6755d6d30edSStella Laurenzo if (explicitFormat) { 6765d6d30edSStella Laurenzo format = explicitFormat; 6775d6d30edSStella Laurenzo } else { 6785d6d30edSStella Laurenzo format = py::format_descriptor<Type>::format(); 6795d6d30edSStella Laurenzo } 6805d6d30edSStella Laurenzo return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, 6815d6d30edSStella Laurenzo /*readonly=*/true); 682436c6c9cSStella Laurenzo } 683436c6c9cSStella Laurenzo }; // namespace 684436c6c9cSStella Laurenzo 685436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer 686436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access. 687436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 688436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 689436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 690436c6c9cSStella Laurenzo public: 691436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 692436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 693436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 694436c6c9cSStella Laurenzo 695436c6c9cSStella Laurenzo /// Returns the element at the given linear position. Asserts if the index is 696436c6c9cSStella Laurenzo /// out of range. 697436c6c9cSStella Laurenzo py::int_ dunderGetItem(intptr_t pos) { 698436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 699436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 700436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 701436c6c9cSStella Laurenzo } 702436c6c9cSStella Laurenzo 703436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 704436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 705436c6c9cSStella Laurenzo assert(mlirTypeIsAInteger(type) && 706436c6c9cSStella Laurenzo "expected integer element type in dense int elements attribute"); 707436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 708436c6c9cSStella Laurenzo // elemental type of the attribute. py::int_ is implicitly constructible 709436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 710436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 711436c6c9cSStella Laurenzo // querying them on each element access. 712436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 713436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 714436c6c9cSStella Laurenzo if (isUnsigned) { 715436c6c9cSStella Laurenzo if (width == 1) { 716436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 717436c6c9cSStella Laurenzo } 718308d8b8cSRahul Kayaith if (width == 8) { 719308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt8Value(*this, pos); 720308d8b8cSRahul Kayaith } 721308d8b8cSRahul Kayaith if (width == 16) { 722308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetUInt16Value(*this, pos); 723308d8b8cSRahul Kayaith } 724436c6c9cSStella Laurenzo if (width == 32) { 725436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt32Value(*this, pos); 726436c6c9cSStella Laurenzo } 727436c6c9cSStella Laurenzo if (width == 64) { 728436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetUInt64Value(*this, pos); 729436c6c9cSStella Laurenzo } 730436c6c9cSStella Laurenzo } else { 731436c6c9cSStella Laurenzo if (width == 1) { 732436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetBoolValue(*this, pos); 733436c6c9cSStella Laurenzo } 734308d8b8cSRahul Kayaith if (width == 8) { 735308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt8Value(*this, pos); 736308d8b8cSRahul Kayaith } 737308d8b8cSRahul Kayaith if (width == 16) { 738308d8b8cSRahul Kayaith return mlirDenseElementsAttrGetInt16Value(*this, pos); 739308d8b8cSRahul Kayaith } 740436c6c9cSStella Laurenzo if (width == 32) { 741436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt32Value(*this, pos); 742436c6c9cSStella Laurenzo } 743436c6c9cSStella Laurenzo if (width == 64) { 744436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetInt64Value(*this, pos); 745436c6c9cSStella Laurenzo } 746436c6c9cSStella Laurenzo } 747436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported integer type"); 748436c6c9cSStella Laurenzo } 749436c6c9cSStella Laurenzo 750436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 751436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 752436c6c9cSStella Laurenzo } 753436c6c9cSStella Laurenzo }; 754436c6c9cSStella Laurenzo 755436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 756436c6c9cSStella Laurenzo public: 757436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 758436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 759436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 760436c6c9cSStella Laurenzo 761436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 762436c6c9cSStella Laurenzo 7639fb1086bSAdrian Kuegel bool dunderContains(const std::string &name) { 7649fb1086bSAdrian Kuegel return !mlirAttributeIsNull( 7659fb1086bSAdrian Kuegel mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 7669fb1086bSAdrian Kuegel } 7679fb1086bSAdrian Kuegel 768436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 7699fb1086bSAdrian Kuegel c.def("__contains__", &PyDictAttribute::dunderContains); 770436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 771436c6c9cSStella Laurenzo c.def_static( 772436c6c9cSStella Laurenzo "get", 773436c6c9cSStella Laurenzo [](py::dict attributes, DefaultingPyMlirContext context) { 774436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 775436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 776436c6c9cSStella Laurenzo for (auto &it : attributes) { 77702b6fb21SMehdi Amini auto &mlirAttr = it.second.cast<PyAttribute &>(); 778436c6c9cSStella Laurenzo auto name = it.first.cast<std::string>(); 779436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 78002b6fb21SMehdi Amini mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 781436c6c9cSStella Laurenzo toMlirStringRef(name)), 78202b6fb21SMehdi Amini mlirAttr)); 783436c6c9cSStella Laurenzo } 784436c6c9cSStella Laurenzo MlirAttribute attr = 785436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 786436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 787436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 788436c6c9cSStella Laurenzo }, 789ed9e52f3SAlex Zinenko py::arg("value") = py::dict(), py::arg("context") = py::none(), 790436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 791436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 792436c6c9cSStella Laurenzo MlirAttribute attr = 793436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 794436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 795436c6c9cSStella Laurenzo throw SetPyError(PyExc_KeyError, 796436c6c9cSStella Laurenzo "attempt to access a non-existent attribute"); 797436c6c9cSStella Laurenzo } 798436c6c9cSStella Laurenzo return PyAttribute(self.getContext(), attr); 799436c6c9cSStella Laurenzo }); 800436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 801436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 802436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 803436c6c9cSStella Laurenzo "attempt to access out of bounds attribute"); 804436c6c9cSStella Laurenzo } 805436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 806436c6c9cSStella Laurenzo return PyNamedAttribute( 807436c6c9cSStella Laurenzo namedAttr.attribute, 808436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 809436c6c9cSStella Laurenzo }); 810436c6c9cSStella Laurenzo } 811436c6c9cSStella Laurenzo }; 812436c6c9cSStella Laurenzo 813436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 814436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 815436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 816436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 817436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 818436c6c9cSStella Laurenzo public: 819436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 820436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 821436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 822436c6c9cSStella Laurenzo 823436c6c9cSStella Laurenzo py::float_ dunderGetItem(intptr_t pos) { 824436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 825436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 826436c6c9cSStella Laurenzo "attempt to access out of bounds element"); 827436c6c9cSStella Laurenzo } 828436c6c9cSStella Laurenzo 829436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 830436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 831436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 832436c6c9cSStella Laurenzo // elemental type of the attribute. py::float_ is implicitly constructible 833436c6c9cSStella Laurenzo // from float and double. 834436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 835436c6c9cSStella Laurenzo // querying them on each element access. 836436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 837436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetFloatValue(*this, pos); 838436c6c9cSStella Laurenzo } 839436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 840436c6c9cSStella Laurenzo return mlirDenseElementsAttrGetDoubleValue(*this, pos); 841436c6c9cSStella Laurenzo } 842436c6c9cSStella Laurenzo throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); 843436c6c9cSStella Laurenzo } 844436c6c9cSStella Laurenzo 845436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 846436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 847436c6c9cSStella Laurenzo } 848436c6c9cSStella Laurenzo }; 849436c6c9cSStella Laurenzo 850436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 851436c6c9cSStella Laurenzo public: 852436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 853436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 854436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 855436c6c9cSStella Laurenzo 856436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 857436c6c9cSStella Laurenzo c.def_static( 858436c6c9cSStella Laurenzo "get", 859436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 860436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 861436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 862436c6c9cSStella Laurenzo }, 863436c6c9cSStella Laurenzo py::arg("value"), py::arg("context") = py::none(), 864436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 865436c6c9cSStella Laurenzo c.def_property_readonly("value", [](PyTypeAttribute &self) { 866436c6c9cSStella Laurenzo return PyType(self.getContext()->getRef(), 867436c6c9cSStella Laurenzo mlirTypeAttrGetValue(self.get())); 868436c6c9cSStella Laurenzo }); 869436c6c9cSStella Laurenzo } 870436c6c9cSStella Laurenzo }; 871436c6c9cSStella Laurenzo 872436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 873436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 874436c6c9cSStella Laurenzo public: 875436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 876436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 877436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 878436c6c9cSStella Laurenzo 879436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 880436c6c9cSStella Laurenzo c.def_static( 881436c6c9cSStella Laurenzo "get", 882436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 883436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 884436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 885436c6c9cSStella Laurenzo }, 886436c6c9cSStella Laurenzo py::arg("context") = py::none(), "Create a Unit attribute."); 887436c6c9cSStella Laurenzo } 888436c6c9cSStella Laurenzo }; 889436c6c9cSStella Laurenzo 890436c6c9cSStella Laurenzo } // namespace 891436c6c9cSStella Laurenzo 892436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) { 893436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 894436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 895436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 896436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 897436c6c9cSStella Laurenzo PyDenseElementsAttribute::bind(m); 898436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 899436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 900436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 901436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 902*5c3861b2SYun Long PyOpaqueAttribute::bind(m); 903436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 904436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 905436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 906436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 907436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 908436c6c9cSStella Laurenzo } 909