1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9b56d1ec6SPeter Hawkins #include <cstdint> 10a1fe1f5fSKazu Hirata #include <optional> 11b56d1ec6SPeter Hawkins #include <string> 1271a25454SPeter Hawkins #include <string_view> 134811270bSmax #include <utility> 141fc096afSMehdi Amini 15436c6c9cSStella Laurenzo #include "IRModule.h" 16b56d1ec6SPeter Hawkins #include "NanobindUtils.h" 17b56d1ec6SPeter Hawkins #include "mlir-c/BuiltinAttributes.h" 18b56d1ec6SPeter Hawkins #include "mlir-c/BuiltinTypes.h" 19b56d1ec6SPeter Hawkins #include "mlir/Bindings/Python/NanobindAdaptors.h" 205cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h" 2171a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h" 22c912f0e7Spranavm-nvidia #include "llvm/Support/raw_ostream.h" 2371a25454SPeter Hawkins 24b56d1ec6SPeter Hawkins namespace nb = nanobind; 25b56d1ec6SPeter Hawkins using namespace nanobind::literals; 26436c6c9cSStella Laurenzo using namespace mlir; 27436c6c9cSStella Laurenzo using namespace mlir::python; 28436c6c9cSStella Laurenzo 29436c6c9cSStella Laurenzo using llvm::SmallVector; 30436c6c9cSStella Laurenzo 315d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 325d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 335d6d30edSStella Laurenzo //------------------------------------------------------------------------------ 345d6d30edSStella Laurenzo 355d6d30edSStella Laurenzo static const char kDenseElementsAttrGetDocstring[] = 365d6d30edSStella Laurenzo R"(Gets a DenseElementsAttr from a Python buffer or array. 375d6d30edSStella Laurenzo 385d6d30edSStella Laurenzo When `type` is not provided, then some limited type inferencing is done based 395d6d30edSStella Laurenzo on the buffer format. Support presently exists for 8/16/32/64 signed and 405d6d30edSStella Laurenzo unsigned integers and float16/float32/float64. DenseElementsAttrs of these 415d6d30edSStella Laurenzo types can also be converted back to a corresponding buffer. 425d6d30edSStella Laurenzo 435d6d30edSStella Laurenzo For conversions outside of these types, a `type=` must be explicitly provided 445d6d30edSStella Laurenzo and the buffer contents must be bit-castable to the MLIR internal 455d6d30edSStella Laurenzo representation: 465d6d30edSStella Laurenzo 475d6d30edSStella Laurenzo * Integer types (except for i1): the buffer must be byte aligned to the 485d6d30edSStella Laurenzo next byte boundary. 495d6d30edSStella Laurenzo * Floating point types: Must be bit-castable to the given floating point 505d6d30edSStella Laurenzo size. 515d6d30edSStella Laurenzo * i1 (bool): Bit packed into 8bit words where the bit pattern matches a 525d6d30edSStella Laurenzo row major ordering. An arbitrary Numpy `bool_` array can be bit packed to 535d6d30edSStella Laurenzo this specification with: `np.packbits(ary, axis=None, bitorder='little')`. 545d6d30edSStella Laurenzo 555d6d30edSStella Laurenzo If a single element buffer is passed (or for i1, a single byte with value 0 565d6d30edSStella Laurenzo or 255), then a splat will be created. 575d6d30edSStella Laurenzo 585d6d30edSStella Laurenzo Args: 595d6d30edSStella Laurenzo array: The array or buffer to convert. 605d6d30edSStella Laurenzo signless: If inferring an appropriate MLIR type, use signless types for 615d6d30edSStella Laurenzo integers (defaults True). 625d6d30edSStella Laurenzo type: Skips inference of the MLIR element type and uses this instead. The 635d6d30edSStella Laurenzo storage size must be consistent with the actual contents of the buffer. 645d6d30edSStella Laurenzo shape: Overrides the shape of the buffer when constructing the MLIR 655d6d30edSStella Laurenzo shaped type. This is needed when the physical and logical shape differ (as 665d6d30edSStella Laurenzo for i1). 675d6d30edSStella Laurenzo context: Explicit context, if not from context manager. 685d6d30edSStella Laurenzo 695d6d30edSStella Laurenzo Returns: 705d6d30edSStella Laurenzo DenseElementsAttr on success. 715d6d30edSStella Laurenzo 725d6d30edSStella Laurenzo Raises: 735d6d30edSStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 745d6d30edSStella Laurenzo type or if the buffer does not meet expectations. 755d6d30edSStella Laurenzo )"; 765d6d30edSStella Laurenzo 77c912f0e7Spranavm-nvidia static const char kDenseElementsAttrGetFromListDocstring[] = 78c912f0e7Spranavm-nvidia R"(Gets a DenseElementsAttr from a Python list of attributes. 79c912f0e7Spranavm-nvidia 80c912f0e7Spranavm-nvidia Note that it can be expensive to construct attributes individually. 81c912f0e7Spranavm-nvidia For a large number of elements, consider using a Python buffer or array instead. 82c912f0e7Spranavm-nvidia 83c912f0e7Spranavm-nvidia Args: 84c912f0e7Spranavm-nvidia attrs: A list of attributes. 85c912f0e7Spranavm-nvidia type: The desired shape and type of the resulting DenseElementsAttr. 86c912f0e7Spranavm-nvidia If not provided, the element type is determined based on the type 87c912f0e7Spranavm-nvidia of the 0th attribute and the shape is `[len(attrs)]`. 88c912f0e7Spranavm-nvidia context: Explicit context, if not from context manager. 89c912f0e7Spranavm-nvidia 90c912f0e7Spranavm-nvidia Returns: 91c912f0e7Spranavm-nvidia DenseElementsAttr on success. 92c912f0e7Spranavm-nvidia 93c912f0e7Spranavm-nvidia Raises: 94c912f0e7Spranavm-nvidia ValueError: If the type of the attributes does not match the type 95c912f0e7Spranavm-nvidia specified by `shaped_type`. 96c912f0e7Spranavm-nvidia )"; 97c912f0e7Spranavm-nvidia 98f66cd9e9SStella Laurenzo static const char kDenseResourceElementsAttrGetFromBufferDocstring[] = 99f66cd9e9SStella Laurenzo R"(Gets a DenseResourceElementsAttr from a Python buffer or array. 100f66cd9e9SStella Laurenzo 101f66cd9e9SStella Laurenzo This function does minimal validation or massaging of the data, and it is 102f66cd9e9SStella Laurenzo up to the caller to ensure that the buffer meets the characteristics 103f66cd9e9SStella Laurenzo implied by the shape. 104f66cd9e9SStella Laurenzo 105f66cd9e9SStella Laurenzo The backing buffer and any user objects will be retained for the lifetime 106f66cd9e9SStella Laurenzo of the resource blob. This is typically bounded to the context but the 107f66cd9e9SStella Laurenzo resource can have a shorter lifespan depending on how it is used in 108f66cd9e9SStella Laurenzo subsequent processing. 109f66cd9e9SStella Laurenzo 110f66cd9e9SStella Laurenzo Args: 111f66cd9e9SStella Laurenzo buffer: The array or buffer to convert. 112f66cd9e9SStella Laurenzo name: Name to provide to the resource (may be changed upon collision). 113f66cd9e9SStella Laurenzo type: The explicit ShapedType to construct the attribute with. 114f66cd9e9SStella Laurenzo context: Explicit context, if not from context manager. 115f66cd9e9SStella Laurenzo 116f66cd9e9SStella Laurenzo Returns: 117f66cd9e9SStella Laurenzo DenseResourceElementsAttr on success. 118f66cd9e9SStella Laurenzo 119f66cd9e9SStella Laurenzo Raises: 120f66cd9e9SStella Laurenzo ValueError: If the type of the buffer or array cannot be matched to an MLIR 121f66cd9e9SStella Laurenzo type or if the buffer does not meet expectations. 122f66cd9e9SStella Laurenzo )"; 123f66cd9e9SStella Laurenzo 124436c6c9cSStella Laurenzo namespace { 125436c6c9cSStella Laurenzo 126b56d1ec6SPeter Hawkins struct nb_buffer_info { 127b56d1ec6SPeter Hawkins void *ptr = nullptr; 128b56d1ec6SPeter Hawkins ssize_t itemsize = 0; 129b56d1ec6SPeter Hawkins ssize_t size = 0; 130b56d1ec6SPeter Hawkins const char *format = nullptr; 131b56d1ec6SPeter Hawkins ssize_t ndim = 0; 132b56d1ec6SPeter Hawkins SmallVector<ssize_t, 4> shape; 133b56d1ec6SPeter Hawkins SmallVector<ssize_t, 4> strides; 134b56d1ec6SPeter Hawkins bool readonly = false; 135b56d1ec6SPeter Hawkins 136b56d1ec6SPeter Hawkins nb_buffer_info( 137b56d1ec6SPeter Hawkins void *ptr, ssize_t itemsize, const char *format, ssize_t ndim, 138b56d1ec6SPeter Hawkins SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in, 139b56d1ec6SPeter Hawkins bool readonly = false, 140b56d1ec6SPeter Hawkins std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in = 141b56d1ec6SPeter Hawkins std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr)) 142b56d1ec6SPeter Hawkins : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim), 143b56d1ec6SPeter Hawkins shape(std::move(shape_in)), strides(std::move(strides_in)), 144b56d1ec6SPeter Hawkins readonly(readonly), owned_view(std::move(owned_view_in)) { 145b56d1ec6SPeter Hawkins size = 1; 146b56d1ec6SPeter Hawkins for (ssize_t i = 0; i < ndim; ++i) { 147b56d1ec6SPeter Hawkins size *= shape[i]; 148b56d1ec6SPeter Hawkins } 149b56d1ec6SPeter Hawkins } 150b56d1ec6SPeter Hawkins 151b56d1ec6SPeter Hawkins explicit nb_buffer_info(Py_buffer *view) 152b56d1ec6SPeter Hawkins : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim, 153b56d1ec6SPeter Hawkins {view->shape, view->shape + view->ndim}, 154b56d1ec6SPeter Hawkins // TODO(phawkins): check for null strides 155b56d1ec6SPeter Hawkins {view->strides, view->strides + view->ndim}, 156b56d1ec6SPeter Hawkins view->readonly != 0, 157b56d1ec6SPeter Hawkins std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>( 158b56d1ec6SPeter Hawkins view, PyBuffer_Release)) {} 159b56d1ec6SPeter Hawkins 160b56d1ec6SPeter Hawkins nb_buffer_info(const nb_buffer_info &) = delete; 161b56d1ec6SPeter Hawkins nb_buffer_info(nb_buffer_info &&) = default; 162b56d1ec6SPeter Hawkins nb_buffer_info &operator=(const nb_buffer_info &) = delete; 163b56d1ec6SPeter Hawkins nb_buffer_info &operator=(nb_buffer_info &&) = default; 164b56d1ec6SPeter Hawkins 165b56d1ec6SPeter Hawkins private: 166b56d1ec6SPeter Hawkins std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view; 167b56d1ec6SPeter Hawkins }; 168b56d1ec6SPeter Hawkins 169b56d1ec6SPeter Hawkins class nb_buffer : public nb::object { 170b56d1ec6SPeter Hawkins NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer); 171b56d1ec6SPeter Hawkins 172b56d1ec6SPeter Hawkins nb_buffer_info request() const { 173b56d1ec6SPeter Hawkins int flags = PyBUF_STRIDES | PyBUF_FORMAT; 174b56d1ec6SPeter Hawkins auto *view = new Py_buffer(); 175b56d1ec6SPeter Hawkins if (PyObject_GetBuffer(ptr(), view, flags) != 0) { 176b56d1ec6SPeter Hawkins delete view; 177b56d1ec6SPeter Hawkins throw nb::python_error(); 178b56d1ec6SPeter Hawkins } 179b56d1ec6SPeter Hawkins return nb_buffer_info(view); 180b56d1ec6SPeter Hawkins } 181b56d1ec6SPeter Hawkins }; 182b56d1ec6SPeter Hawkins 183b56d1ec6SPeter Hawkins template <typename T> 184b56d1ec6SPeter Hawkins struct nb_format_descriptor {}; 185b56d1ec6SPeter Hawkins 186b56d1ec6SPeter Hawkins template <> 187b56d1ec6SPeter Hawkins struct nb_format_descriptor<bool> { 188b56d1ec6SPeter Hawkins static const char *format() { return "?"; } 189b56d1ec6SPeter Hawkins }; 190b56d1ec6SPeter Hawkins template <> 191b56d1ec6SPeter Hawkins struct nb_format_descriptor<int8_t> { 192b56d1ec6SPeter Hawkins static const char *format() { return "b"; } 193b56d1ec6SPeter Hawkins }; 194b56d1ec6SPeter Hawkins template <> 195b56d1ec6SPeter Hawkins struct nb_format_descriptor<uint8_t> { 196b56d1ec6SPeter Hawkins static const char *format() { return "B"; } 197b56d1ec6SPeter Hawkins }; 198b56d1ec6SPeter Hawkins template <> 199b56d1ec6SPeter Hawkins struct nb_format_descriptor<int16_t> { 200b56d1ec6SPeter Hawkins static const char *format() { return "h"; } 201b56d1ec6SPeter Hawkins }; 202b56d1ec6SPeter Hawkins template <> 203b56d1ec6SPeter Hawkins struct nb_format_descriptor<uint16_t> { 204b56d1ec6SPeter Hawkins static const char *format() { return "H"; } 205b56d1ec6SPeter Hawkins }; 206b56d1ec6SPeter Hawkins template <> 207b56d1ec6SPeter Hawkins struct nb_format_descriptor<int32_t> { 208b56d1ec6SPeter Hawkins static const char *format() { return "i"; } 209b56d1ec6SPeter Hawkins }; 210b56d1ec6SPeter Hawkins template <> 211b56d1ec6SPeter Hawkins struct nb_format_descriptor<uint32_t> { 212b56d1ec6SPeter Hawkins static const char *format() { return "I"; } 213b56d1ec6SPeter Hawkins }; 214b56d1ec6SPeter Hawkins template <> 215b56d1ec6SPeter Hawkins struct nb_format_descriptor<int64_t> { 216b56d1ec6SPeter Hawkins static const char *format() { return "q"; } 217b56d1ec6SPeter Hawkins }; 218b56d1ec6SPeter Hawkins template <> 219b56d1ec6SPeter Hawkins struct nb_format_descriptor<uint64_t> { 220b56d1ec6SPeter Hawkins static const char *format() { return "Q"; } 221b56d1ec6SPeter Hawkins }; 222b56d1ec6SPeter Hawkins template <> 223b56d1ec6SPeter Hawkins struct nb_format_descriptor<float> { 224b56d1ec6SPeter Hawkins static const char *format() { return "f"; } 225b56d1ec6SPeter Hawkins }; 226b56d1ec6SPeter Hawkins template <> 227b56d1ec6SPeter Hawkins struct nb_format_descriptor<double> { 228b56d1ec6SPeter Hawkins static const char *format() { return "d"; } 229b56d1ec6SPeter Hawkins }; 230b56d1ec6SPeter Hawkins 231436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 232436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 233436c6c9cSStella Laurenzo } 234436c6c9cSStella Laurenzo 235b56d1ec6SPeter Hawkins static MlirStringRef toMlirStringRef(const nb::bytes &s) { 236b56d1ec6SPeter Hawkins return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size()); 237b56d1ec6SPeter Hawkins } 238b56d1ec6SPeter Hawkins 239436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { 240436c6c9cSStella Laurenzo public: 241436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; 242436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "AffineMapAttr"; 243436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 2449566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 2459566ee28Smax mlirAffineMapAttrGetTypeID; 246436c6c9cSStella Laurenzo 247436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 248436c6c9cSStella Laurenzo c.def_static( 249436c6c9cSStella Laurenzo "get", 250436c6c9cSStella Laurenzo [](PyAffineMap &affineMap) { 251436c6c9cSStella Laurenzo MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); 252436c6c9cSStella Laurenzo return PyAffineMapAttribute(affineMap.getContext(), attr); 253436c6c9cSStella Laurenzo }, 254b56d1ec6SPeter Hawkins nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); 255b56d1ec6SPeter Hawkins c.def_prop_ro("value", mlirAffineMapAttrGetValue, 256c36b4248SBimo "Returns the value of the AffineMap attribute"); 257436c6c9cSStella Laurenzo } 258436c6c9cSStella Laurenzo }; 259436c6c9cSStella Laurenzo 260334873feSAmy Wang class PyIntegerSetAttribute 261334873feSAmy Wang : public PyConcreteAttribute<PyIntegerSetAttribute> { 262334873feSAmy Wang public: 263334873feSAmy Wang static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet; 264334873feSAmy Wang static constexpr const char *pyClassName = "IntegerSetAttr"; 265334873feSAmy Wang using PyConcreteAttribute::PyConcreteAttribute; 266334873feSAmy Wang static constexpr GetTypeIDFunctionTy getTypeIdFunction = 267334873feSAmy Wang mlirIntegerSetAttrGetTypeID; 268334873feSAmy Wang 269334873feSAmy Wang static void bindDerived(ClassTy &c) { 270334873feSAmy Wang c.def_static( 271334873feSAmy Wang "get", 272334873feSAmy Wang [](PyIntegerSet &integerSet) { 273334873feSAmy Wang MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); 274334873feSAmy Wang return PyIntegerSetAttribute(integerSet.getContext(), attr); 275334873feSAmy Wang }, 276b56d1ec6SPeter Hawkins nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); 277334873feSAmy Wang } 278334873feSAmy Wang }; 279334873feSAmy Wang 280ed9e52f3SAlex Zinenko template <typename T> 281b56d1ec6SPeter Hawkins static T pyTryCast(nb::handle object) { 282ed9e52f3SAlex Zinenko try { 283b56d1ec6SPeter Hawkins return nb::cast<T>(object); 284b56d1ec6SPeter Hawkins } catch (nb::cast_error &err) { 285b56d1ec6SPeter Hawkins std::string msg = std::string("Invalid attribute when attempting to " 286b56d1ec6SPeter Hawkins "create an ArrayAttribute (") + 287ed9e52f3SAlex Zinenko err.what() + ")"; 288b56d1ec6SPeter Hawkins throw std::runtime_error(msg.c_str()); 289b56d1ec6SPeter Hawkins } catch (std::runtime_error &err) { 290ed9e52f3SAlex Zinenko std::string msg = std::string("Invalid attribute (None?) when attempting " 291ed9e52f3SAlex Zinenko "to create an ArrayAttribute (") + 292ed9e52f3SAlex Zinenko err.what() + ")"; 293b56d1ec6SPeter Hawkins throw std::runtime_error(msg.c_str()); 294ed9e52f3SAlex Zinenko } 295ed9e52f3SAlex Zinenko } 296ed9e52f3SAlex Zinenko 297619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived 298619fd8c2SJeff Niu /// implementation class. 299619fd8c2SJeff Niu template <typename EltTy, typename DerivedT> 300133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> { 301619fd8c2SJeff Niu public: 302133624acSJeff Niu using PyConcreteAttribute<DerivedT>::PyConcreteAttribute; 303619fd8c2SJeff Niu 304619fd8c2SJeff Niu /// Iterator over the integer elements of a dense array. 305619fd8c2SJeff Niu class PyDenseArrayIterator { 306619fd8c2SJeff Niu public: 3074a1b1196SMehdi Amini PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} 308619fd8c2SJeff Niu 309619fd8c2SJeff Niu /// Return a copy of the iterator. 310619fd8c2SJeff Niu PyDenseArrayIterator dunderIter() { return *this; } 311619fd8c2SJeff Niu 312619fd8c2SJeff Niu /// Return the next element. 313619fd8c2SJeff Niu EltTy dunderNext() { 314619fd8c2SJeff Niu // Throw if the index has reached the end. 315619fd8c2SJeff Niu if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) 316b56d1ec6SPeter Hawkins throw nb::stop_iteration(); 317619fd8c2SJeff Niu return DerivedT::getElement(attr.get(), nextIndex++); 318619fd8c2SJeff Niu } 319619fd8c2SJeff Niu 320619fd8c2SJeff Niu /// Bind the iterator class. 321b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 322b56d1ec6SPeter Hawkins nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName) 323619fd8c2SJeff Niu .def("__iter__", &PyDenseArrayIterator::dunderIter) 324619fd8c2SJeff Niu .def("__next__", &PyDenseArrayIterator::dunderNext); 325619fd8c2SJeff Niu } 326619fd8c2SJeff Niu 327619fd8c2SJeff Niu private: 328619fd8c2SJeff Niu /// The referenced dense array attribute. 329619fd8c2SJeff Niu PyAttribute attr; 330619fd8c2SJeff Niu /// The next index to read. 331619fd8c2SJeff Niu int nextIndex = 0; 332619fd8c2SJeff Niu }; 333619fd8c2SJeff Niu 334619fd8c2SJeff Niu /// Get the element at the given index. 335619fd8c2SJeff Niu EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } 336619fd8c2SJeff Niu 337619fd8c2SJeff Niu /// Bind the attribute class. 338133624acSJeff Niu static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) { 339619fd8c2SJeff Niu // Bind the constructor. 340b56d1ec6SPeter Hawkins if constexpr (std::is_same_v<EltTy, bool>) { 341b56d1ec6SPeter Hawkins c.def_static( 342b56d1ec6SPeter Hawkins "get", 343b56d1ec6SPeter Hawkins [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) { 344b56d1ec6SPeter Hawkins std::vector<bool> values; 345b56d1ec6SPeter Hawkins for (nb::handle py_value : py_values) { 346b56d1ec6SPeter Hawkins int is_true = PyObject_IsTrue(py_value.ptr()); 347b56d1ec6SPeter Hawkins if (is_true < 0) { 348b56d1ec6SPeter Hawkins throw nb::python_error(); 349b56d1ec6SPeter Hawkins } 350b56d1ec6SPeter Hawkins values.push_back(is_true); 351b56d1ec6SPeter Hawkins } 352b56d1ec6SPeter Hawkins return getAttribute(values, ctx->getRef()); 353b56d1ec6SPeter Hawkins }, 354b56d1ec6SPeter Hawkins nb::arg("values"), nb::arg("context").none() = nb::none(), 355b56d1ec6SPeter Hawkins "Gets a uniqued dense array attribute"); 356b56d1ec6SPeter Hawkins } else { 357619fd8c2SJeff Niu c.def_static( 358619fd8c2SJeff Niu "get", 359619fd8c2SJeff Niu [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { 3608dcb6722SIngo Müller return getAttribute(values, ctx->getRef()); 361619fd8c2SJeff Niu }, 362b56d1ec6SPeter Hawkins nb::arg("values"), nb::arg("context").none() = nb::none(), 363619fd8c2SJeff Niu "Gets a uniqued dense array attribute"); 364b56d1ec6SPeter Hawkins } 365619fd8c2SJeff Niu // Bind the array methods. 366133624acSJeff Niu c.def("__getitem__", [](DerivedT &arr, intptr_t i) { 367619fd8c2SJeff Niu if (i >= mlirDenseArrayGetNumElements(arr)) 368b56d1ec6SPeter Hawkins throw nb::index_error("DenseArray index out of range"); 369619fd8c2SJeff Niu return arr.getItem(i); 370619fd8c2SJeff Niu }); 371133624acSJeff Niu c.def("__len__", [](const DerivedT &arr) { 372619fd8c2SJeff Niu return mlirDenseArrayGetNumElements(arr); 373619fd8c2SJeff Niu }); 374133624acSJeff Niu c.def("__iter__", 375133624acSJeff Niu [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); 376b56d1ec6SPeter Hawkins c.def("__add__", [](DerivedT &arr, const nb::list &extras) { 377619fd8c2SJeff Niu std::vector<EltTy> values; 378619fd8c2SJeff Niu intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); 379b56d1ec6SPeter Hawkins values.reserve(numOldElements + nb::len(extras)); 380619fd8c2SJeff Niu for (intptr_t i = 0; i < numOldElements; ++i) 381619fd8c2SJeff Niu values.push_back(arr.getItem(i)); 382b56d1ec6SPeter Hawkins for (nb::handle attr : extras) 383619fd8c2SJeff Niu values.push_back(pyTryCast<EltTy>(attr)); 3848dcb6722SIngo Müller return getAttribute(values, arr.getContext()); 385619fd8c2SJeff Niu }); 386619fd8c2SJeff Niu } 3878dcb6722SIngo Müller 3888dcb6722SIngo Müller private: 3898dcb6722SIngo Müller static DerivedT getAttribute(const std::vector<EltTy> &values, 3908dcb6722SIngo Müller PyMlirContextRef ctx) { 3918dcb6722SIngo Müller if constexpr (std::is_same_v<EltTy, bool>) { 3928dcb6722SIngo Müller std::vector<int> intValues(values.begin(), values.end()); 3938dcb6722SIngo Müller MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(), 3948dcb6722SIngo Müller intValues.data()); 3958dcb6722SIngo Müller return DerivedT(ctx, attr); 3968dcb6722SIngo Müller } else { 3978dcb6722SIngo Müller MlirAttribute attr = 3988dcb6722SIngo Müller DerivedT::getAttribute(ctx->get(), values.size(), values.data()); 3998dcb6722SIngo Müller return DerivedT(ctx, attr); 4008dcb6722SIngo Müller } 4018dcb6722SIngo Müller } 402619fd8c2SJeff Niu }; 403619fd8c2SJeff Niu 404619fd8c2SJeff Niu /// Instantiate the python dense array classes. 405619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute 4068dcb6722SIngo Müller : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> { 407619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; 408619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseBoolArrayGet; 409619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseBoolArrayGetElement; 410619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseBoolArrayAttr"; 411619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; 412619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 413619fd8c2SJeff Niu }; 414619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute 415619fd8c2SJeff Niu : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { 416619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; 417619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI8ArrayGet; 418619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI8ArrayGetElement; 419619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI8ArrayAttr"; 420619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; 421619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 422619fd8c2SJeff Niu }; 423619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute 424619fd8c2SJeff Niu : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { 425619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; 426619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI16ArrayGet; 427619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI16ArrayGetElement; 428619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI16ArrayAttr"; 429619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; 430619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 431619fd8c2SJeff Niu }; 432619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute 433619fd8c2SJeff Niu : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { 434619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; 435619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI32ArrayGet; 436619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI32ArrayGetElement; 437619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI32ArrayAttr"; 438619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; 439619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 440619fd8c2SJeff Niu }; 441619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute 442619fd8c2SJeff Niu : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { 443619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; 444619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseI64ArrayGet; 445619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseI64ArrayGetElement; 446619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseI64ArrayAttr"; 447619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; 448619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 449619fd8c2SJeff Niu }; 450619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute 451619fd8c2SJeff Niu : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { 452619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; 453619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF32ArrayGet; 454619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF32ArrayGetElement; 455619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF32ArrayAttr"; 456619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; 457619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 458619fd8c2SJeff Niu }; 459619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute 460619fd8c2SJeff Niu : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { 461619fd8c2SJeff Niu static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; 462619fd8c2SJeff Niu static constexpr auto getAttribute = mlirDenseF64ArrayGet; 463619fd8c2SJeff Niu static constexpr auto getElement = mlirDenseF64ArrayGetElement; 464619fd8c2SJeff Niu static constexpr const char *pyClassName = "DenseF64ArrayAttr"; 465619fd8c2SJeff Niu static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; 466619fd8c2SJeff Niu using PyDenseArrayAttribute::PyDenseArrayAttribute; 467619fd8c2SJeff Niu }; 468619fd8c2SJeff Niu 469436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { 470436c6c9cSStella Laurenzo public: 471436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; 472436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "ArrayAttr"; 473436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 4749566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 4759566ee28Smax mlirArrayAttrGetTypeID; 476436c6c9cSStella Laurenzo 477436c6c9cSStella Laurenzo class PyArrayAttributeIterator { 478436c6c9cSStella Laurenzo public: 4791fc096afSMehdi Amini PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} 480436c6c9cSStella Laurenzo 481436c6c9cSStella Laurenzo PyArrayAttributeIterator &dunderIter() { return *this; } 482436c6c9cSStella Laurenzo 483974c1596SRahul Kayaith MlirAttribute dunderNext() { 484bca88952SJeff Niu // TODO: Throw is an inefficient way to stop iteration. 485bca88952SJeff Niu if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) 486b56d1ec6SPeter Hawkins throw nb::stop_iteration(); 487974c1596SRahul Kayaith return mlirArrayAttrGetElement(attr.get(), nextIndex++); 488436c6c9cSStella Laurenzo } 489436c6c9cSStella Laurenzo 490b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 491b56d1ec6SPeter Hawkins nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator") 492436c6c9cSStella Laurenzo .def("__iter__", &PyArrayAttributeIterator::dunderIter) 493436c6c9cSStella Laurenzo .def("__next__", &PyArrayAttributeIterator::dunderNext); 494436c6c9cSStella Laurenzo } 495436c6c9cSStella Laurenzo 496436c6c9cSStella Laurenzo private: 497436c6c9cSStella Laurenzo PyAttribute attr; 498436c6c9cSStella Laurenzo int nextIndex = 0; 499436c6c9cSStella Laurenzo }; 500436c6c9cSStella Laurenzo 501974c1596SRahul Kayaith MlirAttribute getItem(intptr_t i) { 502974c1596SRahul Kayaith return mlirArrayAttrGetElement(*this, i); 503ed9e52f3SAlex Zinenko } 504ed9e52f3SAlex Zinenko 505436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 506436c6c9cSStella Laurenzo c.def_static( 507436c6c9cSStella Laurenzo "get", 508b56d1ec6SPeter Hawkins [](nb::list attributes, DefaultingPyMlirContext context) { 509436c6c9cSStella Laurenzo SmallVector<MlirAttribute> mlirAttributes; 510b56d1ec6SPeter Hawkins mlirAttributes.reserve(nb::len(attributes)); 511436c6c9cSStella Laurenzo for (auto attribute : attributes) { 512ed9e52f3SAlex Zinenko mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); 513436c6c9cSStella Laurenzo } 514436c6c9cSStella Laurenzo MlirAttribute attr = mlirArrayAttrGet( 515436c6c9cSStella Laurenzo context->get(), mlirAttributes.size(), mlirAttributes.data()); 516436c6c9cSStella Laurenzo return PyArrayAttribute(context->getRef(), attr); 517436c6c9cSStella Laurenzo }, 518b56d1ec6SPeter Hawkins nb::arg("attributes"), nb::arg("context").none() = nb::none(), 519436c6c9cSStella Laurenzo "Gets a uniqued Array attribute"); 520436c6c9cSStella Laurenzo c.def("__getitem__", 521436c6c9cSStella Laurenzo [](PyArrayAttribute &arr, intptr_t i) { 522436c6c9cSStella Laurenzo if (i >= mlirArrayAttrGetNumElements(arr)) 523b56d1ec6SPeter Hawkins throw nb::index_error("ArrayAttribute index out of range"); 524ed9e52f3SAlex Zinenko return arr.getItem(i); 525436c6c9cSStella Laurenzo }) 526436c6c9cSStella Laurenzo .def("__len__", 527436c6c9cSStella Laurenzo [](const PyArrayAttribute &arr) { 528436c6c9cSStella Laurenzo return mlirArrayAttrGetNumElements(arr); 529436c6c9cSStella Laurenzo }) 530436c6c9cSStella Laurenzo .def("__iter__", [](const PyArrayAttribute &arr) { 531436c6c9cSStella Laurenzo return PyArrayAttributeIterator(arr); 532436c6c9cSStella Laurenzo }); 533b56d1ec6SPeter Hawkins c.def("__add__", [](PyArrayAttribute arr, nb::list extras) { 534ed9e52f3SAlex Zinenko std::vector<MlirAttribute> attributes; 535ed9e52f3SAlex Zinenko intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); 536b56d1ec6SPeter Hawkins attributes.reserve(numOldElements + nb::len(extras)); 537ed9e52f3SAlex Zinenko for (intptr_t i = 0; i < numOldElements; ++i) 538ed9e52f3SAlex Zinenko attributes.push_back(arr.getItem(i)); 539b56d1ec6SPeter Hawkins for (nb::handle attr : extras) 540ed9e52f3SAlex Zinenko attributes.push_back(pyTryCast<PyAttribute>(attr)); 541ed9e52f3SAlex Zinenko MlirAttribute arrayAttr = mlirArrayAttrGet( 542ed9e52f3SAlex Zinenko arr.getContext()->get(), attributes.size(), attributes.data()); 543ed9e52f3SAlex Zinenko return PyArrayAttribute(arr.getContext(), arrayAttr); 544ed9e52f3SAlex Zinenko }); 545436c6c9cSStella Laurenzo } 546436c6c9cSStella Laurenzo }; 547436c6c9cSStella Laurenzo 548436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr. 549436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { 550436c6c9cSStella Laurenzo public: 551436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; 552436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FloatAttr"; 553436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 5549566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 5559566ee28Smax mlirFloatAttrGetTypeID; 556436c6c9cSStella Laurenzo 557436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 558436c6c9cSStella Laurenzo c.def_static( 559436c6c9cSStella Laurenzo "get", 560436c6c9cSStella Laurenzo [](PyType &type, double value, DefaultingPyLocation loc) { 5613ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(loc->getContext()); 562436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); 5633ea4c501SRahul Kayaith if (mlirAttributeIsNull(attr)) 5643ea4c501SRahul Kayaith throw MLIRError("Invalid attribute", errors.take()); 565436c6c9cSStella Laurenzo return PyFloatAttribute(type.getContext(), attr); 566436c6c9cSStella Laurenzo }, 567b56d1ec6SPeter Hawkins nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(), 568436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a type"); 569436c6c9cSStella Laurenzo c.def_static( 570436c6c9cSStella Laurenzo "get_f32", 571436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 572436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 573436c6c9cSStella Laurenzo context->get(), mlirF32TypeGet(context->get()), value); 574436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 575436c6c9cSStella Laurenzo }, 576b56d1ec6SPeter Hawkins nb::arg("value"), nb::arg("context").none() = nb::none(), 577436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f32 type"); 578436c6c9cSStella Laurenzo c.def_static( 579436c6c9cSStella Laurenzo "get_f64", 580436c6c9cSStella Laurenzo [](double value, DefaultingPyMlirContext context) { 581436c6c9cSStella Laurenzo MlirAttribute attr = mlirFloatAttrDoubleGet( 582436c6c9cSStella Laurenzo context->get(), mlirF64TypeGet(context->get()), value); 583436c6c9cSStella Laurenzo return PyFloatAttribute(context->getRef(), attr); 584436c6c9cSStella Laurenzo }, 585b56d1ec6SPeter Hawkins nb::arg("value"), nb::arg("context").none() = nb::none(), 586436c6c9cSStella Laurenzo "Gets an uniqued float point attribute associated to a f64 type"); 587b56d1ec6SPeter Hawkins c.def_prop_ro("value", mlirFloatAttrGetValueDouble, 5882a5d4974SIngo Müller "Returns the value of the float attribute"); 5892a5d4974SIngo Müller c.def("__float__", mlirFloatAttrGetValueDouble, 5902a5d4974SIngo Müller "Converts the value of the float attribute to a Python float"); 591436c6c9cSStella Laurenzo } 592436c6c9cSStella Laurenzo }; 593436c6c9cSStella Laurenzo 594436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr. 595436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { 596436c6c9cSStella Laurenzo public: 597436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; 598436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "IntegerAttr"; 599436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 600436c6c9cSStella Laurenzo 601436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 602436c6c9cSStella Laurenzo c.def_static( 603436c6c9cSStella Laurenzo "get", 604436c6c9cSStella Laurenzo [](PyType &type, int64_t value) { 605436c6c9cSStella Laurenzo MlirAttribute attr = mlirIntegerAttrGet(type, value); 606436c6c9cSStella Laurenzo return PyIntegerAttribute(type.getContext(), attr); 607436c6c9cSStella Laurenzo }, 608b56d1ec6SPeter Hawkins nb::arg("type"), nb::arg("value"), 609436c6c9cSStella Laurenzo "Gets an uniqued integer attribute associated to a type"); 610b56d1ec6SPeter Hawkins c.def_prop_ro("value", toPyInt, 6112a5d4974SIngo Müller "Returns the value of the integer attribute"); 6122a5d4974SIngo Müller c.def("__int__", toPyInt, 6132a5d4974SIngo Müller "Converts the value of the integer attribute to a Python int"); 614b56d1ec6SPeter Hawkins c.def_prop_ro_static("static_typeid", 615b56d1ec6SPeter Hawkins [](nb::object & /*class*/) -> MlirTypeID { 6162a5d4974SIngo Müller return mlirIntegerAttrGetTypeID(); 6172a5d4974SIngo Müller }); 6182a5d4974SIngo Müller } 6192a5d4974SIngo Müller 6202a5d4974SIngo Müller private: 621b56d1ec6SPeter Hawkins static int64_t toPyInt(PyIntegerAttribute &self) { 622e9db306dSrkayaith MlirType type = mlirAttributeGetType(self); 623e9db306dSrkayaith if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) 624436c6c9cSStella Laurenzo return mlirIntegerAttrGetValueInt(self); 625e9db306dSrkayaith if (mlirIntegerTypeIsSigned(type)) 626e9db306dSrkayaith return mlirIntegerAttrGetValueSInt(self); 627e9db306dSrkayaith return mlirIntegerAttrGetValueUInt(self); 628436c6c9cSStella Laurenzo } 629436c6c9cSStella Laurenzo }; 630436c6c9cSStella Laurenzo 631436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr. 632436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { 633436c6c9cSStella Laurenzo public: 634436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; 635436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BoolAttr"; 636436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 637436c6c9cSStella Laurenzo 638436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 639436c6c9cSStella Laurenzo c.def_static( 640436c6c9cSStella Laurenzo "get", 641436c6c9cSStella Laurenzo [](bool value, DefaultingPyMlirContext context) { 642436c6c9cSStella Laurenzo MlirAttribute attr = mlirBoolAttrGet(context->get(), value); 643436c6c9cSStella Laurenzo return PyBoolAttribute(context->getRef(), attr); 644436c6c9cSStella Laurenzo }, 645b56d1ec6SPeter Hawkins nb::arg("value"), nb::arg("context").none() = nb::none(), 646436c6c9cSStella Laurenzo "Gets an uniqued bool attribute"); 647b56d1ec6SPeter Hawkins c.def_prop_ro("value", mlirBoolAttrGetValue, 648436c6c9cSStella Laurenzo "Returns the value of the bool attribute"); 6492a5d4974SIngo Müller c.def("__bool__", mlirBoolAttrGetValue, 6502a5d4974SIngo Müller "Converts the value of the bool attribute to a Python bool"); 651436c6c9cSStella Laurenzo } 652436c6c9cSStella Laurenzo }; 653436c6c9cSStella Laurenzo 6544eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> { 6554eee9ef9Smax public: 6564eee9ef9Smax static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; 6574eee9ef9Smax static constexpr const char *pyClassName = "SymbolRefAttr"; 6584eee9ef9Smax using PyConcreteAttribute::PyConcreteAttribute; 6594eee9ef9Smax 6604eee9ef9Smax static MlirAttribute fromList(const std::vector<std::string> &symbols, 6614eee9ef9Smax PyMlirContext &context) { 6624eee9ef9Smax if (symbols.empty()) 6634eee9ef9Smax throw std::runtime_error("SymbolRefAttr must be composed of at least " 6644eee9ef9Smax "one symbol."); 6654eee9ef9Smax MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); 6664eee9ef9Smax SmallVector<MlirAttribute, 3> referenceAttrs; 6674eee9ef9Smax for (size_t i = 1; i < symbols.size(); ++i) { 6684eee9ef9Smax referenceAttrs.push_back( 6694eee9ef9Smax mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); 6704eee9ef9Smax } 6714eee9ef9Smax return mlirSymbolRefAttrGet(context.get(), rootSymbol, 6724eee9ef9Smax referenceAttrs.size(), referenceAttrs.data()); 6734eee9ef9Smax } 6744eee9ef9Smax 6754eee9ef9Smax static void bindDerived(ClassTy &c) { 6764eee9ef9Smax c.def_static( 6774eee9ef9Smax "get", 6784eee9ef9Smax [](const std::vector<std::string> &symbols, 6794eee9ef9Smax DefaultingPyMlirContext context) { 6804eee9ef9Smax return PySymbolRefAttribute::fromList(symbols, context.resolve()); 6814eee9ef9Smax }, 682b56d1ec6SPeter Hawkins nb::arg("symbols"), nb::arg("context").none() = nb::none(), 6834eee9ef9Smax "Gets a uniqued SymbolRef attribute from a list of symbol names"); 684b56d1ec6SPeter Hawkins c.def_prop_ro( 6854eee9ef9Smax "value", 6864eee9ef9Smax [](PySymbolRefAttribute &self) { 6874eee9ef9Smax std::vector<std::string> symbols = { 6884eee9ef9Smax unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; 6894eee9ef9Smax for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); 6904eee9ef9Smax ++i) 6914eee9ef9Smax symbols.push_back( 6924eee9ef9Smax unwrap(mlirSymbolRefAttrGetRootReference( 6934eee9ef9Smax mlirSymbolRefAttrGetNestedReference(self, i))) 6944eee9ef9Smax .str()); 6954eee9ef9Smax return symbols; 6964eee9ef9Smax }, 6974eee9ef9Smax "Returns the value of the SymbolRef attribute as a list[str]"); 6984eee9ef9Smax } 6994eee9ef9Smax }; 7004eee9ef9Smax 701436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute 702436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { 703436c6c9cSStella Laurenzo public: 704436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; 705436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "FlatSymbolRefAttr"; 706436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 707436c6c9cSStella Laurenzo 708436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 709436c6c9cSStella Laurenzo c.def_static( 710436c6c9cSStella Laurenzo "get", 711436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 712436c6c9cSStella Laurenzo MlirAttribute attr = 713436c6c9cSStella Laurenzo mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); 714436c6c9cSStella Laurenzo return PyFlatSymbolRefAttribute(context->getRef(), attr); 715436c6c9cSStella Laurenzo }, 716b56d1ec6SPeter Hawkins nb::arg("value"), nb::arg("context").none() = nb::none(), 717436c6c9cSStella Laurenzo "Gets a uniqued FlatSymbolRef attribute"); 718b56d1ec6SPeter Hawkins c.def_prop_ro( 719436c6c9cSStella Laurenzo "value", 720436c6c9cSStella Laurenzo [](PyFlatSymbolRefAttribute &self) { 721436c6c9cSStella Laurenzo MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); 722b56d1ec6SPeter Hawkins return nb::str(stringRef.data, stringRef.length); 723436c6c9cSStella Laurenzo }, 724436c6c9cSStella Laurenzo "Returns the value of the FlatSymbolRef attribute as a string"); 725436c6c9cSStella Laurenzo } 726436c6c9cSStella Laurenzo }; 727436c6c9cSStella Laurenzo 7285c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { 7295c3861b2SYun Long public: 7305c3861b2SYun Long static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; 7315c3861b2SYun Long static constexpr const char *pyClassName = "OpaqueAttr"; 7325c3861b2SYun Long using PyConcreteAttribute::PyConcreteAttribute; 7339566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 7349566ee28Smax mlirOpaqueAttrGetTypeID; 7355c3861b2SYun Long 7365c3861b2SYun Long static void bindDerived(ClassTy &c) { 7375c3861b2SYun Long c.def_static( 7385c3861b2SYun Long "get", 739b56d1ec6SPeter Hawkins [](std::string dialectNamespace, nb_buffer buffer, PyType &type, 7405c3861b2SYun Long DefaultingPyMlirContext context) { 741b56d1ec6SPeter Hawkins const nb_buffer_info bufferInfo = buffer.request(); 7425c3861b2SYun Long intptr_t bufferSize = bufferInfo.size; 7435c3861b2SYun Long MlirAttribute attr = mlirOpaqueAttrGet( 7445c3861b2SYun Long context->get(), toMlirStringRef(dialectNamespace), bufferSize, 7455c3861b2SYun Long static_cast<char *>(bufferInfo.ptr), type); 7465c3861b2SYun Long return PyOpaqueAttribute(context->getRef(), attr); 7475c3861b2SYun Long }, 748b56d1ec6SPeter Hawkins nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"), 749b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), "Gets an Opaque attribute."); 750b56d1ec6SPeter Hawkins c.def_prop_ro( 7515c3861b2SYun Long "dialect_namespace", 7525c3861b2SYun Long [](PyOpaqueAttribute &self) { 7535c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); 754b56d1ec6SPeter Hawkins return nb::str(stringRef.data, stringRef.length); 7555c3861b2SYun Long }, 7565c3861b2SYun Long "Returns the dialect namespace for the Opaque attribute as a string"); 757b56d1ec6SPeter Hawkins c.def_prop_ro( 7585c3861b2SYun Long "data", 7595c3861b2SYun Long [](PyOpaqueAttribute &self) { 7605c3861b2SYun Long MlirStringRef stringRef = mlirOpaqueAttrGetData(self); 761b56d1ec6SPeter Hawkins return nb::bytes(stringRef.data, stringRef.length); 7625c3861b2SYun Long }, 76362bf6c2eSChris Jones "Returns the data for the Opaqued attributes as `bytes`"); 7645c3861b2SYun Long } 7655c3861b2SYun Long }; 7665c3861b2SYun Long 767436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { 768436c6c9cSStella Laurenzo public: 769436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; 770436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "StringAttr"; 771436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 7729566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 7739566ee28Smax mlirStringAttrGetTypeID; 774436c6c9cSStella Laurenzo 775436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 776436c6c9cSStella Laurenzo c.def_static( 777436c6c9cSStella Laurenzo "get", 778436c6c9cSStella Laurenzo [](std::string value, DefaultingPyMlirContext context) { 779436c6c9cSStella Laurenzo MlirAttribute attr = 780436c6c9cSStella Laurenzo mlirStringAttrGet(context->get(), toMlirStringRef(value)); 781436c6c9cSStella Laurenzo return PyStringAttribute(context->getRef(), attr); 782436c6c9cSStella Laurenzo }, 783b56d1ec6SPeter Hawkins nb::arg("value"), nb::arg("context").none() = nb::none(), 784b56d1ec6SPeter Hawkins "Gets a uniqued string attribute"); 785b56d1ec6SPeter Hawkins c.def_static( 786b56d1ec6SPeter Hawkins "get", 787b56d1ec6SPeter Hawkins [](nb::bytes value, DefaultingPyMlirContext context) { 788b56d1ec6SPeter Hawkins MlirAttribute attr = 789b56d1ec6SPeter Hawkins mlirStringAttrGet(context->get(), toMlirStringRef(value)); 790b56d1ec6SPeter Hawkins return PyStringAttribute(context->getRef(), attr); 791b56d1ec6SPeter Hawkins }, 792b56d1ec6SPeter Hawkins nb::arg("value"), nb::arg("context").none() = nb::none(), 793436c6c9cSStella Laurenzo "Gets a uniqued string attribute"); 794436c6c9cSStella Laurenzo c.def_static( 795436c6c9cSStella Laurenzo "get_typed", 796436c6c9cSStella Laurenzo [](PyType &type, std::string value) { 797436c6c9cSStella Laurenzo MlirAttribute attr = 798436c6c9cSStella Laurenzo mlirStringAttrTypedGet(type, toMlirStringRef(value)); 799436c6c9cSStella Laurenzo return PyStringAttribute(type.getContext(), attr); 800436c6c9cSStella Laurenzo }, 801b56d1ec6SPeter Hawkins nb::arg("type"), nb::arg("value"), 802436c6c9cSStella Laurenzo "Gets a uniqued string attribute associated to a type"); 803b56d1ec6SPeter Hawkins c.def_prop_ro( 8049f533548SIngo Müller "value", 8059f533548SIngo Müller [](PyStringAttribute &self) { 8069f533548SIngo Müller MlirStringRef stringRef = mlirStringAttrGetValue(self); 807b56d1ec6SPeter Hawkins return nb::str(stringRef.data, stringRef.length); 8089f533548SIngo Müller }, 809436c6c9cSStella Laurenzo "Returns the value of the string attribute"); 810b56d1ec6SPeter Hawkins c.def_prop_ro( 81162bf6c2eSChris Jones "value_bytes", 81262bf6c2eSChris Jones [](PyStringAttribute &self) { 81362bf6c2eSChris Jones MlirStringRef stringRef = mlirStringAttrGetValue(self); 814b56d1ec6SPeter Hawkins return nb::bytes(stringRef.data, stringRef.length); 81562bf6c2eSChris Jones }, 81662bf6c2eSChris Jones "Returns the value of the string attribute as `bytes`"); 817436c6c9cSStella Laurenzo } 818436c6c9cSStella Laurenzo }; 819436c6c9cSStella Laurenzo 820436c6c9cSStella Laurenzo // TODO: Support construction of string elements. 821436c6c9cSStella Laurenzo class PyDenseElementsAttribute 822436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseElementsAttribute> { 823436c6c9cSStella Laurenzo public: 824436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; 825436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseElementsAttr"; 826436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 827436c6c9cSStella Laurenzo 828436c6c9cSStella Laurenzo static PyDenseElementsAttribute 829b56d1ec6SPeter Hawkins getFromList(nb::list attributes, std::optional<PyType> explicitType, 830c912f0e7Spranavm-nvidia DefaultingPyMlirContext contextWrapper) { 831b56d1ec6SPeter Hawkins const size_t numAttributes = nb::len(attributes); 832c912f0e7Spranavm-nvidia if (numAttributes == 0) 833b56d1ec6SPeter Hawkins throw nb::value_error("Attributes list must be non-empty."); 834c912f0e7Spranavm-nvidia 835c912f0e7Spranavm-nvidia MlirType shapedType; 836c912f0e7Spranavm-nvidia if (explicitType) { 837c912f0e7Spranavm-nvidia if ((!mlirTypeIsAShaped(*explicitType) || 838c912f0e7Spranavm-nvidia !mlirShapedTypeHasStaticShape(*explicitType))) { 839c912f0e7Spranavm-nvidia 840c912f0e7Spranavm-nvidia std::string message; 841c912f0e7Spranavm-nvidia llvm::raw_string_ostream os(message); 842c912f0e7Spranavm-nvidia os << "Expected a static ShapedType for the shaped_type parameter: " 843b56d1ec6SPeter Hawkins << nb::cast<std::string>(nb::repr(nb::cast(*explicitType))); 844b56d1ec6SPeter Hawkins throw nb::value_error(message.c_str()); 845c912f0e7Spranavm-nvidia } 846c912f0e7Spranavm-nvidia shapedType = *explicitType; 847c912f0e7Spranavm-nvidia } else { 8489cbc1f29SHan-Chung Wang SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)}; 849c912f0e7Spranavm-nvidia shapedType = mlirRankedTensorTypeGet( 850c912f0e7Spranavm-nvidia shape.size(), shape.data(), 851c912f0e7Spranavm-nvidia mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])), 852c912f0e7Spranavm-nvidia mlirAttributeGetNull()); 853c912f0e7Spranavm-nvidia } 854c912f0e7Spranavm-nvidia 855c912f0e7Spranavm-nvidia SmallVector<MlirAttribute> mlirAttributes; 856c912f0e7Spranavm-nvidia mlirAttributes.reserve(numAttributes); 857b56d1ec6SPeter Hawkins for (const nb::handle &attribute : attributes) { 858c912f0e7Spranavm-nvidia MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute); 859c912f0e7Spranavm-nvidia MlirType attrType = mlirAttributeGetType(mlirAttribute); 860c912f0e7Spranavm-nvidia mlirAttributes.push_back(mlirAttribute); 861c912f0e7Spranavm-nvidia 862c912f0e7Spranavm-nvidia if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { 863c912f0e7Spranavm-nvidia std::string message; 864c912f0e7Spranavm-nvidia llvm::raw_string_ostream os(message); 865c912f0e7Spranavm-nvidia os << "All attributes must be of the same type and match " 866b56d1ec6SPeter Hawkins << "the type parameter: expected=" 867b56d1ec6SPeter Hawkins << nb::cast<std::string>(nb::repr(nb::cast(shapedType))) 868b56d1ec6SPeter Hawkins << ", but got=" 869b56d1ec6SPeter Hawkins << nb::cast<std::string>(nb::repr(nb::cast(attrType))); 870b56d1ec6SPeter Hawkins throw nb::value_error(message.c_str()); 871c912f0e7Spranavm-nvidia } 872c912f0e7Spranavm-nvidia } 873c912f0e7Spranavm-nvidia 874c912f0e7Spranavm-nvidia MlirAttribute elements = mlirDenseElementsAttrGet( 875c912f0e7Spranavm-nvidia shapedType, mlirAttributes.size(), mlirAttributes.data()); 876c912f0e7Spranavm-nvidia 877c912f0e7Spranavm-nvidia return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 878c912f0e7Spranavm-nvidia } 879c912f0e7Spranavm-nvidia 880c912f0e7Spranavm-nvidia static PyDenseElementsAttribute 881b56d1ec6SPeter Hawkins getFromBuffer(nb_buffer array, bool signless, 8820a81ace0SKazu Hirata std::optional<PyType> explicitType, 8830a81ace0SKazu Hirata std::optional<std::vector<int64_t>> explicitShape, 884436c6c9cSStella Laurenzo DefaultingPyMlirContext contextWrapper) { 885436c6c9cSStella Laurenzo // Request a contiguous view. In exotic cases, this will cause a copy. 88671a25454SPeter Hawkins int flags = PyBUF_ND; 88771a25454SPeter Hawkins if (!explicitType) { 88871a25454SPeter Hawkins flags |= PyBUF_FORMAT; 88971a25454SPeter Hawkins } 89071a25454SPeter Hawkins Py_buffer view; 89171a25454SPeter Hawkins if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { 892b56d1ec6SPeter Hawkins throw nb::python_error(); 893436c6c9cSStella Laurenzo } 89471a25454SPeter Hawkins auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); 895436c6c9cSStella Laurenzo 896436c6c9cSStella Laurenzo MlirContext context = contextWrapper->get(); 8971824e45cSKasper Nielsen MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, 8981824e45cSKasper Nielsen explicitShape, context); 8995d6d30edSStella Laurenzo if (mlirAttributeIsNull(attr)) { 9005d6d30edSStella Laurenzo throw std::invalid_argument( 9015d6d30edSStella Laurenzo "DenseElementsAttr could not be constructed from the given buffer. " 9025d6d30edSStella Laurenzo "This may mean that the Python buffer layout does not match that " 9035d6d30edSStella Laurenzo "MLIR expected layout and is a bug."); 9045d6d30edSStella Laurenzo } 9055d6d30edSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), attr); 9065d6d30edSStella Laurenzo } 907436c6c9cSStella Laurenzo 9081fc096afSMehdi Amini static PyDenseElementsAttribute getSplat(const PyType &shapedType, 909436c6c9cSStella Laurenzo PyAttribute &elementAttr) { 910436c6c9cSStella Laurenzo auto contextWrapper = 911436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirTypeGetContext(shapedType)); 912436c6c9cSStella Laurenzo if (!mlirAttributeIsAInteger(elementAttr) && 913436c6c9cSStella Laurenzo !mlirAttributeIsAFloat(elementAttr)) { 914436c6c9cSStella Laurenzo std::string message = "Illegal element type for DenseElementsAttr: "; 915b56d1ec6SPeter Hawkins message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr)))); 916b56d1ec6SPeter Hawkins throw nb::value_error(message.c_str()); 917436c6c9cSStella Laurenzo } 918436c6c9cSStella Laurenzo if (!mlirTypeIsAShaped(shapedType) || 919436c6c9cSStella Laurenzo !mlirShapedTypeHasStaticShape(shapedType)) { 920436c6c9cSStella Laurenzo std::string message = 921436c6c9cSStella Laurenzo "Expected a static ShapedType for the shaped_type parameter: "; 922b56d1ec6SPeter Hawkins message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType)))); 923b56d1ec6SPeter Hawkins throw nb::value_error(message.c_str()); 924436c6c9cSStella Laurenzo } 925436c6c9cSStella Laurenzo MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); 926436c6c9cSStella Laurenzo MlirType attrType = mlirAttributeGetType(elementAttr); 927436c6c9cSStella Laurenzo if (!mlirTypeEqual(shapedElementType, attrType)) { 928436c6c9cSStella Laurenzo std::string message = 929436c6c9cSStella Laurenzo "Shaped element type and attribute type must be equal: shaped="; 930b56d1ec6SPeter Hawkins message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType)))); 931436c6c9cSStella Laurenzo message.append(", element="); 932b56d1ec6SPeter Hawkins message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr)))); 933b56d1ec6SPeter Hawkins throw nb::value_error(message.c_str()); 934436c6c9cSStella Laurenzo } 935436c6c9cSStella Laurenzo 936436c6c9cSStella Laurenzo MlirAttribute elements = 937436c6c9cSStella Laurenzo mlirDenseElementsAttrSplatGet(shapedType, elementAttr); 938436c6c9cSStella Laurenzo return PyDenseElementsAttribute(contextWrapper->getRef(), elements); 939436c6c9cSStella Laurenzo } 940436c6c9cSStella Laurenzo 941436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } 942436c6c9cSStella Laurenzo 943b56d1ec6SPeter Hawkins std::unique_ptr<nb_buffer_info> accessBuffer() { 944436c6c9cSStella Laurenzo MlirType shapedType = mlirAttributeGetType(*this); 945436c6c9cSStella Laurenzo MlirType elementType = mlirShapedTypeGetElementType(shapedType); 9465d6d30edSStella Laurenzo std::string format; 947436c6c9cSStella Laurenzo 948436c6c9cSStella Laurenzo if (mlirTypeIsAF32(elementType)) { 949436c6c9cSStella Laurenzo // f32 9505d6d30edSStella Laurenzo return bufferInfo<float>(shapedType); 95102b6fb21SMehdi Amini } 95202b6fb21SMehdi Amini if (mlirTypeIsAF64(elementType)) { 953436c6c9cSStella Laurenzo // f64 9545d6d30edSStella Laurenzo return bufferInfo<double>(shapedType); 955bb56c2b3SMehdi Amini } 956bb56c2b3SMehdi Amini if (mlirTypeIsAF16(elementType)) { 9575d6d30edSStella Laurenzo // f16 9585d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType, "e"); 959bb56c2b3SMehdi Amini } 960ef1b735dSmax if (mlirTypeIsAIndex(elementType)) { 961ef1b735dSmax // Same as IndexType::kInternalStorageBitWidth 962ef1b735dSmax return bufferInfo<int64_t>(shapedType); 963ef1b735dSmax } 964bb56c2b3SMehdi Amini if (mlirTypeIsAInteger(elementType) && 965436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 32) { 966436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 967436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 968436c6c9cSStella Laurenzo // i32 9695d6d30edSStella Laurenzo return bufferInfo<int32_t>(shapedType); 970e5639b3fSMehdi Amini } 971e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 972436c6c9cSStella Laurenzo // unsigned i32 9735d6d30edSStella Laurenzo return bufferInfo<uint32_t>(shapedType); 974436c6c9cSStella Laurenzo } 975436c6c9cSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 976436c6c9cSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 64) { 977436c6c9cSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 978436c6c9cSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 979436c6c9cSStella Laurenzo // i64 9805d6d30edSStella Laurenzo return bufferInfo<int64_t>(shapedType); 981e5639b3fSMehdi Amini } 982e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 983436c6c9cSStella Laurenzo // unsigned i64 9845d6d30edSStella Laurenzo return bufferInfo<uint64_t>(shapedType); 9855d6d30edSStella Laurenzo } 9865d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 9875d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 8) { 9885d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 9895d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 9905d6d30edSStella Laurenzo // i8 9915d6d30edSStella Laurenzo return bufferInfo<int8_t>(shapedType); 992e5639b3fSMehdi Amini } 993e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 9945d6d30edSStella Laurenzo // unsigned i8 9955d6d30edSStella Laurenzo return bufferInfo<uint8_t>(shapedType); 9965d6d30edSStella Laurenzo } 9975d6d30edSStella Laurenzo } else if (mlirTypeIsAInteger(elementType) && 9985d6d30edSStella Laurenzo mlirIntegerTypeGetWidth(elementType) == 16) { 9995d6d30edSStella Laurenzo if (mlirIntegerTypeIsSignless(elementType) || 10005d6d30edSStella Laurenzo mlirIntegerTypeIsSigned(elementType)) { 10015d6d30edSStella Laurenzo // i16 10025d6d30edSStella Laurenzo return bufferInfo<int16_t>(shapedType); 1003e5639b3fSMehdi Amini } 1004e5639b3fSMehdi Amini if (mlirIntegerTypeIsUnsigned(elementType)) { 10055d6d30edSStella Laurenzo // unsigned i16 10065d6d30edSStella Laurenzo return bufferInfo<uint16_t>(shapedType); 1007436c6c9cSStella Laurenzo } 10081824e45cSKasper Nielsen } else if (mlirTypeIsAInteger(elementType) && 10091824e45cSKasper Nielsen mlirIntegerTypeGetWidth(elementType) == 1) { 10101824e45cSKasper Nielsen // i1 / bool 10111824e45cSKasper Nielsen // We can not send the buffer directly back to Python, because the i1 10121824e45cSKasper Nielsen // values are bitpacked within MLIR. We call numpy's unpackbits function 10131824e45cSKasper Nielsen // to convert the bytes. 10141824e45cSKasper Nielsen return getBooleanBufferFromBitpackedAttribute(); 1015436c6c9cSStella Laurenzo } 1016436c6c9cSStella Laurenzo 1017c5f445d1SStella Laurenzo // TODO: Currently crashes the program. 10185d6d30edSStella Laurenzo // Reported as https://github.com/pybind/pybind11/issues/3336 1019c5f445d1SStella Laurenzo throw std::invalid_argument( 1020c5f445d1SStella Laurenzo "unsupported data type for conversion to Python buffer"); 1021436c6c9cSStella Laurenzo } 1022436c6c9cSStella Laurenzo 1023436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1024b56d1ec6SPeter Hawkins #if PY_VERSION_HEX < 0x03090000 1025b56d1ec6SPeter Hawkins PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr()); 1026b56d1ec6SPeter Hawkins tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer; 1027b56d1ec6SPeter Hawkins tp->tp_as_buffer->bf_releasebuffer = 1028b56d1ec6SPeter Hawkins PyDenseElementsAttribute::bf_releasebuffer; 1029b56d1ec6SPeter Hawkins #endif 1030436c6c9cSStella Laurenzo c.def("__len__", &PyDenseElementsAttribute::dunderLen) 1031436c6c9cSStella Laurenzo .def_static("get", PyDenseElementsAttribute::getFromBuffer, 1032b56d1ec6SPeter Hawkins nb::arg("array"), nb::arg("signless") = true, 1033b56d1ec6SPeter Hawkins nb::arg("type").none() = nb::none(), 1034b56d1ec6SPeter Hawkins nb::arg("shape").none() = nb::none(), 1035b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 10365d6d30edSStella Laurenzo kDenseElementsAttrGetDocstring) 1037c912f0e7Spranavm-nvidia .def_static("get", PyDenseElementsAttribute::getFromList, 1038b56d1ec6SPeter Hawkins nb::arg("attrs"), nb::arg("type").none() = nb::none(), 1039b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 1040c912f0e7Spranavm-nvidia kDenseElementsAttrGetFromListDocstring) 1041436c6c9cSStella Laurenzo .def_static("get_splat", PyDenseElementsAttribute::getSplat, 1042b56d1ec6SPeter Hawkins nb::arg("shaped_type"), nb::arg("element_attr"), 1043436c6c9cSStella Laurenzo "Gets a DenseElementsAttr where all values are the same") 1044b56d1ec6SPeter Hawkins .def_prop_ro("is_splat", 1045436c6c9cSStella Laurenzo [](PyDenseElementsAttribute &self) -> bool { 1046436c6c9cSStella Laurenzo return mlirDenseElementsAttrIsSplat(self); 1047436c6c9cSStella Laurenzo }) 1048b56d1ec6SPeter Hawkins .def("get_splat_value", [](PyDenseElementsAttribute &self) { 1049974c1596SRahul Kayaith if (!mlirDenseElementsAttrIsSplat(self)) 1050b56d1ec6SPeter Hawkins throw nb::value_error( 105191259963SAdam Paszke "get_splat_value called on a non-splat attribute"); 1052974c1596SRahul Kayaith return mlirDenseElementsAttrGetSplatValue(self); 1053b56d1ec6SPeter Hawkins }); 1054436c6c9cSStella Laurenzo } 1055436c6c9cSStella Laurenzo 1056b56d1ec6SPeter Hawkins static PyType_Slot slots[]; 1057b56d1ec6SPeter Hawkins 1058436c6c9cSStella Laurenzo private: 1059b56d1ec6SPeter Hawkins static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags); 1060b56d1ec6SPeter Hawkins static void bf_releasebuffer(PyObject *, Py_buffer *buffer); 1061b56d1ec6SPeter Hawkins 106271a25454SPeter Hawkins static bool isUnsignedIntegerFormat(std::string_view format) { 1063436c6c9cSStella Laurenzo if (format.empty()) 1064436c6c9cSStella Laurenzo return false; 1065436c6c9cSStella Laurenzo char code = format[0]; 1066436c6c9cSStella Laurenzo return code == 'I' || code == 'B' || code == 'H' || code == 'L' || 1067436c6c9cSStella Laurenzo code == 'Q'; 1068436c6c9cSStella Laurenzo } 1069436c6c9cSStella Laurenzo 107071a25454SPeter Hawkins static bool isSignedIntegerFormat(std::string_view format) { 1071436c6c9cSStella Laurenzo if (format.empty()) 1072436c6c9cSStella Laurenzo return false; 1073436c6c9cSStella Laurenzo char code = format[0]; 1074436c6c9cSStella Laurenzo return code == 'i' || code == 'b' || code == 'h' || code == 'l' || 1075436c6c9cSStella Laurenzo code == 'q'; 1076436c6c9cSStella Laurenzo } 1077436c6c9cSStella Laurenzo 10781824e45cSKasper Nielsen static MlirType 10791824e45cSKasper Nielsen getShapedType(std::optional<MlirType> bulkLoadElementType, 10801824e45cSKasper Nielsen std::optional<std::vector<int64_t>> explicitShape, 10811824e45cSKasper Nielsen Py_buffer &view) { 10821824e45cSKasper Nielsen SmallVector<int64_t> shape; 10831824e45cSKasper Nielsen if (explicitShape) { 10841824e45cSKasper Nielsen shape.append(explicitShape->begin(), explicitShape->end()); 10851824e45cSKasper Nielsen } else { 10861824e45cSKasper Nielsen shape.append(view.shape, view.shape + view.ndim); 10871824e45cSKasper Nielsen } 10881824e45cSKasper Nielsen 10891824e45cSKasper Nielsen if (mlirTypeIsAShaped(*bulkLoadElementType)) { 10901824e45cSKasper Nielsen if (explicitShape) { 10911824e45cSKasper Nielsen throw std::invalid_argument("Shape can only be specified explicitly " 10921824e45cSKasper Nielsen "when the type is not a shaped type."); 10931824e45cSKasper Nielsen } 10941824e45cSKasper Nielsen return *bulkLoadElementType; 10951824e45cSKasper Nielsen } else { 10961824e45cSKasper Nielsen MlirAttribute encodingAttr = mlirAttributeGetNull(); 10971824e45cSKasper Nielsen return mlirRankedTensorTypeGet(shape.size(), shape.data(), 10981824e45cSKasper Nielsen *bulkLoadElementType, encodingAttr); 10991824e45cSKasper Nielsen } 11001824e45cSKasper Nielsen } 11011824e45cSKasper Nielsen 11021824e45cSKasper Nielsen static MlirAttribute getAttributeFromBuffer( 11031824e45cSKasper Nielsen Py_buffer &view, bool signless, std::optional<PyType> explicitType, 11041824e45cSKasper Nielsen std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) { 11051824e45cSKasper Nielsen // Detect format codes that are suitable for bulk loading. This includes 11061824e45cSKasper Nielsen // all byte aligned integer and floating point types up to 8 bytes. 11071824e45cSKasper Nielsen // Notably, this excludes exotics types which do not have a direct 11081824e45cSKasper Nielsen // representation in the buffer protocol (i.e. complex, etc). 11091824e45cSKasper Nielsen std::optional<MlirType> bulkLoadElementType; 11101824e45cSKasper Nielsen if (explicitType) { 11111824e45cSKasper Nielsen bulkLoadElementType = *explicitType; 11121824e45cSKasper Nielsen } else { 11131824e45cSKasper Nielsen std::string_view format(view.format); 11141824e45cSKasper Nielsen if (format == "f") { 11151824e45cSKasper Nielsen // f32 11161824e45cSKasper Nielsen assert(view.itemsize == 4 && "mismatched array itemsize"); 11171824e45cSKasper Nielsen bulkLoadElementType = mlirF32TypeGet(context); 11181824e45cSKasper Nielsen } else if (format == "d") { 11191824e45cSKasper Nielsen // f64 11201824e45cSKasper Nielsen assert(view.itemsize == 8 && "mismatched array itemsize"); 11211824e45cSKasper Nielsen bulkLoadElementType = mlirF64TypeGet(context); 11221824e45cSKasper Nielsen } else if (format == "e") { 11231824e45cSKasper Nielsen // f16 11241824e45cSKasper Nielsen assert(view.itemsize == 2 && "mismatched array itemsize"); 11251824e45cSKasper Nielsen bulkLoadElementType = mlirF16TypeGet(context); 11261824e45cSKasper Nielsen } else if (format == "?") { 11271824e45cSKasper Nielsen // i1 11281824e45cSKasper Nielsen // The i1 type needs to be bit-packed, so we will handle it seperately 11291824e45cSKasper Nielsen return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, 11301824e45cSKasper Nielsen context); 11311824e45cSKasper Nielsen } else if (isSignedIntegerFormat(format)) { 11321824e45cSKasper Nielsen if (view.itemsize == 4) { 11331824e45cSKasper Nielsen // i32 11341824e45cSKasper Nielsen bulkLoadElementType = signless 11351824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 32) 11361824e45cSKasper Nielsen : mlirIntegerTypeSignedGet(context, 32); 11371824e45cSKasper Nielsen } else if (view.itemsize == 8) { 11381824e45cSKasper Nielsen // i64 11391824e45cSKasper Nielsen bulkLoadElementType = signless 11401824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 64) 11411824e45cSKasper Nielsen : mlirIntegerTypeSignedGet(context, 64); 11421824e45cSKasper Nielsen } else if (view.itemsize == 1) { 11431824e45cSKasper Nielsen // i8 11441824e45cSKasper Nielsen bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) 11451824e45cSKasper Nielsen : mlirIntegerTypeSignedGet(context, 8); 11461824e45cSKasper Nielsen } else if (view.itemsize == 2) { 11471824e45cSKasper Nielsen // i16 11481824e45cSKasper Nielsen bulkLoadElementType = signless 11491824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 16) 11501824e45cSKasper Nielsen : mlirIntegerTypeSignedGet(context, 16); 11511824e45cSKasper Nielsen } 11521824e45cSKasper Nielsen } else if (isUnsignedIntegerFormat(format)) { 11531824e45cSKasper Nielsen if (view.itemsize == 4) { 11541824e45cSKasper Nielsen // unsigned i32 11551824e45cSKasper Nielsen bulkLoadElementType = signless 11561824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 32) 11571824e45cSKasper Nielsen : mlirIntegerTypeUnsignedGet(context, 32); 11581824e45cSKasper Nielsen } else if (view.itemsize == 8) { 11591824e45cSKasper Nielsen // unsigned i64 11601824e45cSKasper Nielsen bulkLoadElementType = signless 11611824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 64) 11621824e45cSKasper Nielsen : mlirIntegerTypeUnsignedGet(context, 64); 11631824e45cSKasper Nielsen } else if (view.itemsize == 1) { 11641824e45cSKasper Nielsen // i8 11651824e45cSKasper Nielsen bulkLoadElementType = signless 11661824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 8) 11671824e45cSKasper Nielsen : mlirIntegerTypeUnsignedGet(context, 8); 11681824e45cSKasper Nielsen } else if (view.itemsize == 2) { 11691824e45cSKasper Nielsen // i16 11701824e45cSKasper Nielsen bulkLoadElementType = signless 11711824e45cSKasper Nielsen ? mlirIntegerTypeGet(context, 16) 11721824e45cSKasper Nielsen : mlirIntegerTypeUnsignedGet(context, 16); 11731824e45cSKasper Nielsen } 11741824e45cSKasper Nielsen } 11751824e45cSKasper Nielsen if (!bulkLoadElementType) { 11761824e45cSKasper Nielsen throw std::invalid_argument( 11771824e45cSKasper Nielsen std::string("unimplemented array format conversion from format: ") + 11781824e45cSKasper Nielsen std::string(format)); 11791824e45cSKasper Nielsen } 11801824e45cSKasper Nielsen } 11811824e45cSKasper Nielsen 11821824e45cSKasper Nielsen MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); 11831824e45cSKasper Nielsen return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); 11841824e45cSKasper Nielsen } 11851824e45cSKasper Nielsen 1186b56d1ec6SPeter Hawkins // There is a complication for boolean numpy arrays, as numpy represents 1187b56d1ec6SPeter Hawkins // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 1188b56d1ec6SPeter Hawkins // booleans per byte. 11891824e45cSKasper Nielsen static MlirAttribute getBitpackedAttributeFromBooleanBuffer( 11901824e45cSKasper Nielsen Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape, 11911824e45cSKasper Nielsen MlirContext &context) { 11921824e45cSKasper Nielsen if (llvm::endianness::native != llvm::endianness::little) { 1193b56d1ec6SPeter Hawkins // Given we have no good way of testing the behavior on big-endian 1194b56d1ec6SPeter Hawkins // systems we will throw 1195b56d1ec6SPeter Hawkins throw nb::type_error("Constructing a bit-packed MLIR attribute is " 11961824e45cSKasper Nielsen "unsupported on big-endian systems"); 11971824e45cSKasper Nielsen } 1198b56d1ec6SPeter Hawkins nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray( 1199b56d1ec6SPeter Hawkins /*data=*/static_cast<uint8_t *>(view.buf), 1200b56d1ec6SPeter Hawkins /*shape=*/{static_cast<size_t>(view.len)}); 12011824e45cSKasper Nielsen 1202b56d1ec6SPeter Hawkins nb::module_ numpy = nb::module_::import_("numpy"); 1203b56d1ec6SPeter Hawkins nb::object packbitsFunc = numpy.attr("packbits"); 1204b56d1ec6SPeter Hawkins nb::object packedBooleans = 1205b56d1ec6SPeter Hawkins packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little"); 1206b56d1ec6SPeter Hawkins nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request(); 12071824e45cSKasper Nielsen 12081824e45cSKasper Nielsen MlirType bitpackedType = 12091824e45cSKasper Nielsen getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); 12101824e45cSKasper Nielsen assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8"); 12111824e45cSKasper Nielsen // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of 12121824e45cSKasper Nielsen // packedBooleans, hence the MlirAttribute will remain valid even when 12131824e45cSKasper Nielsen // packedBooleans get reclaimed by the end of the function. 12141824e45cSKasper Nielsen return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, 12151824e45cSKasper Nielsen pythonBuffer.ptr); 12161824e45cSKasper Nielsen } 12171824e45cSKasper Nielsen 12181824e45cSKasper Nielsen // This does the opposite transformation of 12191824e45cSKasper Nielsen // `getBitpackedAttributeFromBooleanBuffer` 1220b56d1ec6SPeter Hawkins std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() { 12211824e45cSKasper Nielsen if (llvm::endianness::native != llvm::endianness::little) { 1222b56d1ec6SPeter Hawkins // Given we have no good way of testing the behavior on big-endian 1223b56d1ec6SPeter Hawkins // systems we will throw 1224b56d1ec6SPeter Hawkins throw nb::type_error("Constructing a numpy array from a MLIR attribute " 12251824e45cSKasper Nielsen "is unsupported on big-endian systems"); 12261824e45cSKasper Nielsen } 12271824e45cSKasper Nielsen 12281824e45cSKasper Nielsen int64_t numBooleans = mlirElementsAttrGetNumElements(*this); 12291824e45cSKasper Nielsen int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); 12301824e45cSKasper Nielsen uint8_t *bitpackedData = static_cast<uint8_t *>( 12311824e45cSKasper Nielsen const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 1232b56d1ec6SPeter Hawkins nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray( 1233b56d1ec6SPeter Hawkins /*data=*/bitpackedData, 1234b56d1ec6SPeter Hawkins /*shape=*/{static_cast<size_t>(numBitpackedBytes)}); 12351824e45cSKasper Nielsen 1236b56d1ec6SPeter Hawkins nb::module_ numpy = nb::module_::import_("numpy"); 1237b56d1ec6SPeter Hawkins nb::object unpackbitsFunc = numpy.attr("unpackbits"); 1238b56d1ec6SPeter Hawkins nb::object equalFunc = numpy.attr("equal"); 1239b56d1ec6SPeter Hawkins nb::object reshapeFunc = numpy.attr("reshape"); 1240b56d1ec6SPeter Hawkins nb::object unpackedBooleans = 1241b56d1ec6SPeter Hawkins unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little"); 12421824e45cSKasper Nielsen 12431824e45cSKasper Nielsen // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. 12441824e45cSKasper Nielsen // We need to: 12451824e45cSKasper Nielsen // 1. Slice away the padded bits 12461824e45cSKasper Nielsen // 2. Make the boolean array have the correct shape 12471824e45cSKasper Nielsen // 3. Convert the array to a boolean array 1248b56d1ec6SPeter Hawkins unpackedBooleans = unpackedBooleans[nb::slice( 1249b56d1ec6SPeter Hawkins nb::int_(0), nb::int_(numBooleans), nb::int_(1))]; 12501824e45cSKasper Nielsen unpackedBooleans = equalFunc(unpackedBooleans, 1); 12511824e45cSKasper Nielsen 12521824e45cSKasper Nielsen MlirType shapedType = mlirAttributeGetType(*this); 12531824e45cSKasper Nielsen intptr_t rank = mlirShapedTypeGetRank(shapedType); 1254404d0e99SAdrian Kuegel std::vector<intptr_t> shape(rank); 12551824e45cSKasper Nielsen for (intptr_t i = 0; i < rank; ++i) { 1256404d0e99SAdrian Kuegel shape[i] = mlirShapedTypeGetDimSize(shapedType, i); 12571824e45cSKasper Nielsen } 12581824e45cSKasper Nielsen unpackedBooleans = reshapeFunc(unpackedBooleans, shape); 12591824e45cSKasper Nielsen 1260b56d1ec6SPeter Hawkins // Make sure the returned nb::buffer_view claims ownership of the data in 12611824e45cSKasper Nielsen // `pythonBuffer` so it remains valid when Python reads it 1262b56d1ec6SPeter Hawkins nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans); 1263b56d1ec6SPeter Hawkins return std::make_unique<nb_buffer_info>(pythonBuffer.request()); 12641824e45cSKasper Nielsen } 12651824e45cSKasper Nielsen 1266436c6c9cSStella Laurenzo template <typename Type> 1267b56d1ec6SPeter Hawkins std::unique_ptr<nb_buffer_info> 1268b56d1ec6SPeter Hawkins bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { 12690a68171bSDmitri Gribenko intptr_t rank = mlirShapedTypeGetRank(shapedType); 1270436c6c9cSStella Laurenzo // Prepare the data for the buffer_info. 12710a68171bSDmitri Gribenko // Buffer is configured for read-only access below. 1272436c6c9cSStella Laurenzo Type *data = static_cast<Type *>( 1273436c6c9cSStella Laurenzo const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); 1274436c6c9cSStella Laurenzo // Prepare the shape for the buffer_info. 1275436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> shape; 1276436c6c9cSStella Laurenzo for (intptr_t i = 0; i < rank; ++i) 1277436c6c9cSStella Laurenzo shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); 1278436c6c9cSStella Laurenzo // Prepare the strides for the buffer_info. 1279436c6c9cSStella Laurenzo SmallVector<intptr_t, 4> strides; 1280f0e847d0SRahul Kayaith if (mlirDenseElementsAttrIsSplat(*this)) { 1281f0e847d0SRahul Kayaith // Splats are special, only the single value is stored. 1282f0e847d0SRahul Kayaith strides.assign(rank, 0); 1283f0e847d0SRahul Kayaith } else { 1284436c6c9cSStella Laurenzo for (intptr_t i = 1; i < rank; ++i) { 1285f0e847d0SRahul Kayaith intptr_t strideFactor = 1; 1286f0e847d0SRahul Kayaith for (intptr_t j = i; j < rank; ++j) 1287436c6c9cSStella Laurenzo strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); 1288436c6c9cSStella Laurenzo strides.push_back(sizeof(Type) * strideFactor); 1289436c6c9cSStella Laurenzo } 1290436c6c9cSStella Laurenzo strides.push_back(sizeof(Type)); 1291f0e847d0SRahul Kayaith } 1292b56d1ec6SPeter Hawkins const char *format; 12935d6d30edSStella Laurenzo if (explicitFormat) { 12945d6d30edSStella Laurenzo format = explicitFormat; 12955d6d30edSStella Laurenzo } else { 1296b56d1ec6SPeter Hawkins format = nb_format_descriptor<Type>::format(); 12975d6d30edSStella Laurenzo } 1298b56d1ec6SPeter Hawkins return std::make_unique<nb_buffer_info>( 1299b56d1ec6SPeter Hawkins data, sizeof(Type), format, rank, std::move(shape), std::move(strides), 13005d6d30edSStella Laurenzo /*readonly=*/true); 1301436c6c9cSStella Laurenzo } 1302436c6c9cSStella Laurenzo }; // namespace 1303436c6c9cSStella Laurenzo 1304b56d1ec6SPeter Hawkins PyType_Slot PyDenseElementsAttribute::slots[] = { 1305b56d1ec6SPeter Hawkins // Python 3.8 doesn't allow setting the buffer protocol slots from a type spec. 1306b56d1ec6SPeter Hawkins #if PY_VERSION_HEX >= 0x03090000 1307b56d1ec6SPeter Hawkins {Py_bf_getbuffer, 1308b56d1ec6SPeter Hawkins reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)}, 1309b56d1ec6SPeter Hawkins {Py_bf_releasebuffer, 1310b56d1ec6SPeter Hawkins reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)}, 1311b56d1ec6SPeter Hawkins #endif 1312b56d1ec6SPeter Hawkins {0, nullptr}, 1313b56d1ec6SPeter Hawkins }; 1314b56d1ec6SPeter Hawkins 1315b56d1ec6SPeter Hawkins /*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, 1316b56d1ec6SPeter Hawkins Py_buffer *view, 1317b56d1ec6SPeter Hawkins int flags) { 1318b56d1ec6SPeter Hawkins view->obj = nullptr; 1319b56d1ec6SPeter Hawkins std::unique_ptr<nb_buffer_info> info; 1320b56d1ec6SPeter Hawkins try { 1321b56d1ec6SPeter Hawkins auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj)); 1322b56d1ec6SPeter Hawkins info = attr->accessBuffer(); 1323b56d1ec6SPeter Hawkins } catch (nb::python_error &e) { 1324b56d1ec6SPeter Hawkins e.restore(); 1325b56d1ec6SPeter Hawkins nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer"); 1326b56d1ec6SPeter Hawkins return -1; 1327b56d1ec6SPeter Hawkins } 1328b56d1ec6SPeter Hawkins view->obj = obj; 1329b56d1ec6SPeter Hawkins view->ndim = 1; 1330b56d1ec6SPeter Hawkins view->buf = info->ptr; 1331b56d1ec6SPeter Hawkins view->itemsize = info->itemsize; 1332b56d1ec6SPeter Hawkins view->len = info->itemsize; 1333b56d1ec6SPeter Hawkins for (auto s : info->shape) { 1334b56d1ec6SPeter Hawkins view->len *= s; 1335b56d1ec6SPeter Hawkins } 1336b56d1ec6SPeter Hawkins view->readonly = info->readonly; 1337b56d1ec6SPeter Hawkins if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { 1338b56d1ec6SPeter Hawkins view->format = const_cast<char *>(info->format); 1339b56d1ec6SPeter Hawkins } 1340b56d1ec6SPeter Hawkins if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { 1341b56d1ec6SPeter Hawkins view->ndim = static_cast<int>(info->ndim); 1342b56d1ec6SPeter Hawkins view->strides = info->strides.data(); 1343b56d1ec6SPeter Hawkins view->shape = info->shape.data(); 1344b56d1ec6SPeter Hawkins } 1345b56d1ec6SPeter Hawkins view->suboffsets = nullptr; 1346b56d1ec6SPeter Hawkins view->internal = info.release(); 1347b56d1ec6SPeter Hawkins Py_INCREF(obj); 1348b56d1ec6SPeter Hawkins return 0; 1349b56d1ec6SPeter Hawkins } 1350b56d1ec6SPeter Hawkins 1351b56d1ec6SPeter Hawkins /*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, 1352b56d1ec6SPeter Hawkins Py_buffer *view) { 1353b56d1ec6SPeter Hawkins delete reinterpret_cast<nb_buffer_info *>(view->internal); 1354b56d1ec6SPeter Hawkins } 1355b56d1ec6SPeter Hawkins 1356b56d1ec6SPeter Hawkins /// Refinement of the PyDenseElementsAttribute for attributes containing 1357b56d1ec6SPeter Hawkins /// integer (and boolean) values. Supports element access. 1358436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute 1359436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseIntElementsAttribute, 1360436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 1361436c6c9cSStella Laurenzo public: 1362436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; 1363436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseIntElementsAttr"; 1364436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1365436c6c9cSStella Laurenzo 1366b56d1ec6SPeter Hawkins /// Returns the element at the given linear position. Asserts if the index 1367b56d1ec6SPeter Hawkins /// is out of range. 1368b56d1ec6SPeter Hawkins nb::object dunderGetItem(intptr_t pos) { 1369436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 1370b56d1ec6SPeter Hawkins throw nb::index_error("attempt to access out of bounds element"); 1371436c6c9cSStella Laurenzo } 1372436c6c9cSStella Laurenzo 1373436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 1374436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 1375*5d3ae516SMatthias Gehre // Index type can also appear as a DenseIntElementsAttr and therefore can be 1376*5d3ae516SMatthias Gehre // casted to integer. 1377*5d3ae516SMatthias Gehre assert(mlirTypeIsAInteger(type) || 1378*5d3ae516SMatthias Gehre mlirTypeIsAIndex(type) && "expected integer/index element type in " 1379*5d3ae516SMatthias Gehre "dense int elements attribute"); 1380436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 1381b56d1ec6SPeter Hawkins // elemental type of the attribute. nb::int_ is implicitly constructible 1382436c6c9cSStella Laurenzo // from any C++ integral type and handles bitwidth correctly. 1383436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 1384436c6c9cSStella Laurenzo // querying them on each element access. 1385*5d3ae516SMatthias Gehre if (mlirTypeIsAIndex(type)) { 1386*5d3ae516SMatthias Gehre return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos)); 1387*5d3ae516SMatthias Gehre } 1388436c6c9cSStella Laurenzo unsigned width = mlirIntegerTypeGetWidth(type); 1389436c6c9cSStella Laurenzo bool isUnsigned = mlirIntegerTypeIsUnsigned(type); 1390436c6c9cSStella Laurenzo if (isUnsigned) { 1391436c6c9cSStella Laurenzo if (width == 1) { 1392b56d1ec6SPeter Hawkins return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); 1393436c6c9cSStella Laurenzo } 1394308d8b8cSRahul Kayaith if (width == 8) { 1395b56d1ec6SPeter Hawkins return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos)); 1396308d8b8cSRahul Kayaith } 1397308d8b8cSRahul Kayaith if (width == 16) { 1398b56d1ec6SPeter Hawkins return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos)); 1399308d8b8cSRahul Kayaith } 1400436c6c9cSStella Laurenzo if (width == 32) { 1401b56d1ec6SPeter Hawkins return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos)); 1402436c6c9cSStella Laurenzo } 1403436c6c9cSStella Laurenzo if (width == 64) { 1404b56d1ec6SPeter Hawkins return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos)); 1405436c6c9cSStella Laurenzo } 1406436c6c9cSStella Laurenzo } else { 1407436c6c9cSStella Laurenzo if (width == 1) { 1408b56d1ec6SPeter Hawkins return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); 1409436c6c9cSStella Laurenzo } 1410308d8b8cSRahul Kayaith if (width == 8) { 1411b56d1ec6SPeter Hawkins return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos)); 1412308d8b8cSRahul Kayaith } 1413308d8b8cSRahul Kayaith if (width == 16) { 1414b56d1ec6SPeter Hawkins return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos)); 1415308d8b8cSRahul Kayaith } 1416436c6c9cSStella Laurenzo if (width == 32) { 1417b56d1ec6SPeter Hawkins return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos)); 1418436c6c9cSStella Laurenzo } 1419436c6c9cSStella Laurenzo if (width == 64) { 1420b56d1ec6SPeter Hawkins return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos)); 1421436c6c9cSStella Laurenzo } 1422436c6c9cSStella Laurenzo } 1423b56d1ec6SPeter Hawkins throw nb::type_error("Unsupported integer type"); 1424436c6c9cSStella Laurenzo } 1425436c6c9cSStella Laurenzo 1426436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1427436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); 1428436c6c9cSStella Laurenzo } 1429436c6c9cSStella Laurenzo }; 1430436c6c9cSStella Laurenzo 1431f66cd9e9SStella Laurenzo class PyDenseResourceElementsAttribute 1432f66cd9e9SStella Laurenzo : public PyConcreteAttribute<PyDenseResourceElementsAttribute> { 1433f66cd9e9SStella Laurenzo public: 1434f66cd9e9SStella Laurenzo static constexpr IsAFunctionTy isaFunction = 1435f66cd9e9SStella Laurenzo mlirAttributeIsADenseResourceElements; 1436f66cd9e9SStella Laurenzo static constexpr const char *pyClassName = "DenseResourceElementsAttr"; 1437f66cd9e9SStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1438f66cd9e9SStella Laurenzo 1439f66cd9e9SStella Laurenzo static PyDenseResourceElementsAttribute 1440b56d1ec6SPeter Hawkins getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type, 1441f66cd9e9SStella Laurenzo std::optional<size_t> alignment, bool isMutable, 1442f66cd9e9SStella Laurenzo DefaultingPyMlirContext contextWrapper) { 1443f66cd9e9SStella Laurenzo if (!mlirTypeIsAShaped(type)) { 1444f66cd9e9SStella Laurenzo throw std::invalid_argument( 1445f66cd9e9SStella Laurenzo "Constructing a DenseResourceElementsAttr requires a ShapedType."); 1446f66cd9e9SStella Laurenzo } 1447f66cd9e9SStella Laurenzo 1448f66cd9e9SStella Laurenzo // Do not request any conversions as we must ensure to use caller 1449f66cd9e9SStella Laurenzo // managed memory. 1450f66cd9e9SStella Laurenzo int flags = PyBUF_STRIDES; 1451f66cd9e9SStella Laurenzo std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>(); 1452f66cd9e9SStella Laurenzo if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { 1453b56d1ec6SPeter Hawkins throw nb::python_error(); 1454f66cd9e9SStella Laurenzo } 1455f66cd9e9SStella Laurenzo 1456f66cd9e9SStella Laurenzo // This scope releaser will only release if we haven't yet transferred 1457f66cd9e9SStella Laurenzo // ownership. 1458f66cd9e9SStella Laurenzo auto freeBuffer = llvm::make_scope_exit([&]() { 1459f66cd9e9SStella Laurenzo if (view) 1460f66cd9e9SStella Laurenzo PyBuffer_Release(view.get()); 1461f66cd9e9SStella Laurenzo }); 1462f66cd9e9SStella Laurenzo 1463f66cd9e9SStella Laurenzo if (!PyBuffer_IsContiguous(view.get(), 'A')) { 1464f66cd9e9SStella Laurenzo throw std::invalid_argument("Contiguous buffer is required."); 1465f66cd9e9SStella Laurenzo } 1466f66cd9e9SStella Laurenzo 1467f66cd9e9SStella Laurenzo // Infer alignment to be the stride of one element if not explicit. 1468f66cd9e9SStella Laurenzo size_t inferredAlignment; 1469f66cd9e9SStella Laurenzo if (alignment) 1470f66cd9e9SStella Laurenzo inferredAlignment = *alignment; 1471f66cd9e9SStella Laurenzo else 1472f66cd9e9SStella Laurenzo inferredAlignment = view->strides[view->ndim - 1]; 1473f66cd9e9SStella Laurenzo 1474f66cd9e9SStella Laurenzo // The userData is a Py_buffer* that the deleter owns. 1475f66cd9e9SStella Laurenzo auto deleter = [](void *userData, const void *data, size_t size, 1476f66cd9e9SStella Laurenzo size_t align) { 147728507ac6SFabian Tschopp if (!Py_IsInitialized()) 147828507ac6SFabian Tschopp Py_Initialize(); 1479f66cd9e9SStella Laurenzo Py_buffer *ownedView = static_cast<Py_buffer *>(userData); 148028507ac6SFabian Tschopp nb::gil_scoped_acquire gil; 1481f66cd9e9SStella Laurenzo PyBuffer_Release(ownedView); 1482f66cd9e9SStella Laurenzo delete ownedView; 1483f66cd9e9SStella Laurenzo }; 1484f66cd9e9SStella Laurenzo 1485f66cd9e9SStella Laurenzo size_t rawBufferSize = view->len; 1486f66cd9e9SStella Laurenzo MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( 1487f66cd9e9SStella Laurenzo type, toMlirStringRef(name), view->buf, rawBufferSize, 1488f66cd9e9SStella Laurenzo inferredAlignment, isMutable, deleter, static_cast<void *>(view.get())); 1489f66cd9e9SStella Laurenzo if (mlirAttributeIsNull(attr)) { 1490f66cd9e9SStella Laurenzo throw std::invalid_argument( 1491f66cd9e9SStella Laurenzo "DenseResourceElementsAttr could not be constructed from the given " 1492f66cd9e9SStella Laurenzo "buffer. " 1493f66cd9e9SStella Laurenzo "This may mean that the Python buffer layout does not match that " 1494f66cd9e9SStella Laurenzo "MLIR expected layout and is a bug."); 1495f66cd9e9SStella Laurenzo } 1496f66cd9e9SStella Laurenzo view.release(); 1497f66cd9e9SStella Laurenzo return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); 1498f66cd9e9SStella Laurenzo } 1499f66cd9e9SStella Laurenzo 1500f66cd9e9SStella Laurenzo static void bindDerived(ClassTy &c) { 1501b56d1ec6SPeter Hawkins c.def_static( 1502b56d1ec6SPeter Hawkins "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer, 1503b56d1ec6SPeter Hawkins nb::arg("array"), nb::arg("name"), nb::arg("type"), 1504b56d1ec6SPeter Hawkins nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false, 1505b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 1506f66cd9e9SStella Laurenzo kDenseResourceElementsAttrGetFromBufferDocstring); 1507f66cd9e9SStella Laurenzo } 1508f66cd9e9SStella Laurenzo }; 1509f66cd9e9SStella Laurenzo 1510436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { 1511436c6c9cSStella Laurenzo public: 1512436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; 1513436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DictAttr"; 1514436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 15159566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 15169566ee28Smax mlirDictionaryAttrGetTypeID; 1517436c6c9cSStella Laurenzo 1518436c6c9cSStella Laurenzo intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } 1519436c6c9cSStella Laurenzo 15209fb1086bSAdrian Kuegel bool dunderContains(const std::string &name) { 15219fb1086bSAdrian Kuegel return !mlirAttributeIsNull( 15229fb1086bSAdrian Kuegel mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); 15239fb1086bSAdrian Kuegel } 15249fb1086bSAdrian Kuegel 1525436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 15269fb1086bSAdrian Kuegel c.def("__contains__", &PyDictAttribute::dunderContains); 1527436c6c9cSStella Laurenzo c.def("__len__", &PyDictAttribute::dunderLen); 1528436c6c9cSStella Laurenzo c.def_static( 1529436c6c9cSStella Laurenzo "get", 1530b56d1ec6SPeter Hawkins [](nb::dict attributes, DefaultingPyMlirContext context) { 1531436c6c9cSStella Laurenzo SmallVector<MlirNamedAttribute> mlirNamedAttributes; 1532436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(attributes.size()); 1533b56d1ec6SPeter Hawkins for (std::pair<nb::handle, nb::handle> it : attributes) { 1534b56d1ec6SPeter Hawkins auto &mlirAttr = nb::cast<PyAttribute &>(it.second); 1535b56d1ec6SPeter Hawkins auto name = nb::cast<std::string>(it.first); 1536436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 153702b6fb21SMehdi Amini mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), 1538436c6c9cSStella Laurenzo toMlirStringRef(name)), 153902b6fb21SMehdi Amini mlirAttr)); 1540436c6c9cSStella Laurenzo } 1541436c6c9cSStella Laurenzo MlirAttribute attr = 1542436c6c9cSStella Laurenzo mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), 1543436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 1544436c6c9cSStella Laurenzo return PyDictAttribute(context->getRef(), attr); 1545436c6c9cSStella Laurenzo }, 1546b56d1ec6SPeter Hawkins nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(), 1547436c6c9cSStella Laurenzo "Gets an uniqued dict attribute"); 1548436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { 1549436c6c9cSStella Laurenzo MlirAttribute attr = 1550436c6c9cSStella Laurenzo mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); 1551974c1596SRahul Kayaith if (mlirAttributeIsNull(attr)) 1552b56d1ec6SPeter Hawkins throw nb::key_error("attempt to access a non-existent attribute"); 1553974c1596SRahul Kayaith return attr; 1554436c6c9cSStella Laurenzo }); 1555436c6c9cSStella Laurenzo c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { 1556436c6c9cSStella Laurenzo if (index < 0 || index >= self.dunderLen()) { 1557b56d1ec6SPeter Hawkins throw nb::index_error("attempt to access out of bounds attribute"); 1558436c6c9cSStella Laurenzo } 1559436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); 1560436c6c9cSStella Laurenzo return PyNamedAttribute( 1561436c6c9cSStella Laurenzo namedAttr.attribute, 1562436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 1563436c6c9cSStella Laurenzo }); 1564436c6c9cSStella Laurenzo } 1565436c6c9cSStella Laurenzo }; 1566436c6c9cSStella Laurenzo 1567436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing 1568436c6c9cSStella Laurenzo /// floating-point values. Supports element access. 1569436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute 1570436c6c9cSStella Laurenzo : public PyConcreteAttribute<PyDenseFPElementsAttribute, 1571436c6c9cSStella Laurenzo PyDenseElementsAttribute> { 1572436c6c9cSStella Laurenzo public: 1573436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; 1574436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "DenseFPElementsAttr"; 1575436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 1576436c6c9cSStella Laurenzo 1577b56d1ec6SPeter Hawkins nb::float_ dunderGetItem(intptr_t pos) { 1578436c6c9cSStella Laurenzo if (pos < 0 || pos >= dunderLen()) { 1579b56d1ec6SPeter Hawkins throw nb::index_error("attempt to access out of bounds element"); 1580436c6c9cSStella Laurenzo } 1581436c6c9cSStella Laurenzo 1582436c6c9cSStella Laurenzo MlirType type = mlirAttributeGetType(*this); 1583436c6c9cSStella Laurenzo type = mlirShapedTypeGetElementType(type); 1584436c6c9cSStella Laurenzo // Dispatch element extraction to an appropriate C function based on the 1585b56d1ec6SPeter Hawkins // elemental type of the attribute. nb::float_ is implicitly constructible 1586436c6c9cSStella Laurenzo // from float and double. 1587436c6c9cSStella Laurenzo // TODO: consider caching the type properties in the constructor to avoid 1588436c6c9cSStella Laurenzo // querying them on each element access. 1589436c6c9cSStella Laurenzo if (mlirTypeIsAF32(type)) { 1590b56d1ec6SPeter Hawkins return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos)); 1591436c6c9cSStella Laurenzo } 1592436c6c9cSStella Laurenzo if (mlirTypeIsAF64(type)) { 1593b56d1ec6SPeter Hawkins return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos)); 1594436c6c9cSStella Laurenzo } 1595b56d1ec6SPeter Hawkins throw nb::type_error("Unsupported floating-point type"); 1596436c6c9cSStella Laurenzo } 1597436c6c9cSStella Laurenzo 1598436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1599436c6c9cSStella Laurenzo c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); 1600436c6c9cSStella Laurenzo } 1601436c6c9cSStella Laurenzo }; 1602436c6c9cSStella Laurenzo 1603436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { 1604436c6c9cSStella Laurenzo public: 1605436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; 1606436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "TypeAttr"; 1607436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 16089566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 16099566ee28Smax mlirTypeAttrGetTypeID; 1610436c6c9cSStella Laurenzo 1611436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1612436c6c9cSStella Laurenzo c.def_static( 1613436c6c9cSStella Laurenzo "get", 1614436c6c9cSStella Laurenzo [](PyType value, DefaultingPyMlirContext context) { 1615436c6c9cSStella Laurenzo MlirAttribute attr = mlirTypeAttrGet(value.get()); 1616436c6c9cSStella Laurenzo return PyTypeAttribute(context->getRef(), attr); 1617436c6c9cSStella Laurenzo }, 1618b56d1ec6SPeter Hawkins nb::arg("value"), nb::arg("context").none() = nb::none(), 1619436c6c9cSStella Laurenzo "Gets a uniqued Type attribute"); 1620b56d1ec6SPeter Hawkins c.def_prop_ro("value", [](PyTypeAttribute &self) { 1621bfb1ba75Smax return mlirTypeAttrGetValue(self.get()); 1622436c6c9cSStella Laurenzo }); 1623436c6c9cSStella Laurenzo } 1624436c6c9cSStella Laurenzo }; 1625436c6c9cSStella Laurenzo 1626436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values. 1627436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { 1628436c6c9cSStella Laurenzo public: 1629436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; 1630436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "UnitAttr"; 1631436c6c9cSStella Laurenzo using PyConcreteAttribute::PyConcreteAttribute; 16329566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 16339566ee28Smax mlirUnitAttrGetTypeID; 1634436c6c9cSStella Laurenzo 1635436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1636436c6c9cSStella Laurenzo c.def_static( 1637436c6c9cSStella Laurenzo "get", 1638436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 1639436c6c9cSStella Laurenzo return PyUnitAttribute(context->getRef(), 1640436c6c9cSStella Laurenzo mlirUnitAttrGet(context->get())); 1641436c6c9cSStella Laurenzo }, 1642b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), "Create a Unit attribute."); 1643436c6c9cSStella Laurenzo } 1644436c6c9cSStella Laurenzo }; 1645436c6c9cSStella Laurenzo 1646ac2e2d65SDenys Shabalin /// Strided layout attribute subclass. 1647ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute 1648ac2e2d65SDenys Shabalin : public PyConcreteAttribute<PyStridedLayoutAttribute> { 1649ac2e2d65SDenys Shabalin public: 1650ac2e2d65SDenys Shabalin static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; 1651ac2e2d65SDenys Shabalin static constexpr const char *pyClassName = "StridedLayoutAttr"; 1652ac2e2d65SDenys Shabalin using PyConcreteAttribute::PyConcreteAttribute; 16539566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = 16549566ee28Smax mlirStridedLayoutAttrGetTypeID; 1655ac2e2d65SDenys Shabalin 1656ac2e2d65SDenys Shabalin static void bindDerived(ClassTy &c) { 1657ac2e2d65SDenys Shabalin c.def_static( 1658ac2e2d65SDenys Shabalin "get", 1659ac2e2d65SDenys Shabalin [](int64_t offset, const std::vector<int64_t> strides, 1660ac2e2d65SDenys Shabalin DefaultingPyMlirContext ctx) { 1661ac2e2d65SDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1662ac2e2d65SDenys Shabalin ctx->get(), offset, strides.size(), strides.data()); 1663ac2e2d65SDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1664ac2e2d65SDenys Shabalin }, 1665b56d1ec6SPeter Hawkins nb::arg("offset"), nb::arg("strides"), 1666b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 1667ac2e2d65SDenys Shabalin "Gets a strided layout attribute."); 1668e3fd612eSDenys Shabalin c.def_static( 1669e3fd612eSDenys Shabalin "get_fully_dynamic", 1670e3fd612eSDenys Shabalin [](int64_t rank, DefaultingPyMlirContext ctx) { 1671e3fd612eSDenys Shabalin auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); 1672e3fd612eSDenys Shabalin std::vector<int64_t> strides(rank); 1673e3fd612eSDenys Shabalin std::fill(strides.begin(), strides.end(), dynamic); 1674e3fd612eSDenys Shabalin MlirAttribute attr = mlirStridedLayoutAttrGet( 1675e3fd612eSDenys Shabalin ctx->get(), dynamic, strides.size(), strides.data()); 1676e3fd612eSDenys Shabalin return PyStridedLayoutAttribute(ctx->getRef(), attr); 1677e3fd612eSDenys Shabalin }, 1678b56d1ec6SPeter Hawkins nb::arg("rank"), nb::arg("context").none() = nb::none(), 1679b56d1ec6SPeter Hawkins "Gets a strided layout attribute with dynamic offset and strides of " 1680b56d1ec6SPeter Hawkins "a " 1681e3fd612eSDenys Shabalin "given rank."); 1682b56d1ec6SPeter Hawkins c.def_prop_ro( 1683ac2e2d65SDenys Shabalin "offset", 1684ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1685ac2e2d65SDenys Shabalin return mlirStridedLayoutAttrGetOffset(self); 1686ac2e2d65SDenys Shabalin }, 1687ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1688b56d1ec6SPeter Hawkins c.def_prop_ro( 1689ac2e2d65SDenys Shabalin "strides", 1690ac2e2d65SDenys Shabalin [](PyStridedLayoutAttribute &self) { 1691ac2e2d65SDenys Shabalin intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); 1692ac2e2d65SDenys Shabalin std::vector<int64_t> strides(size); 1693ac2e2d65SDenys Shabalin for (intptr_t i = 0; i < size; i++) { 1694ac2e2d65SDenys Shabalin strides[i] = mlirStridedLayoutAttrGetStride(self, i); 1695ac2e2d65SDenys Shabalin } 1696ac2e2d65SDenys Shabalin return strides; 1697ac2e2d65SDenys Shabalin }, 1698ac2e2d65SDenys Shabalin "Returns the value of the float point attribute"); 1699ac2e2d65SDenys Shabalin } 1700ac2e2d65SDenys Shabalin }; 1701ac2e2d65SDenys Shabalin 1702b56d1ec6SPeter Hawkins nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { 17039566ee28Smax if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) 1704b56d1ec6SPeter Hawkins return nb::cast(PyDenseBoolArrayAttribute(pyAttribute)); 17059566ee28Smax if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) 1706b56d1ec6SPeter Hawkins return nb::cast(PyDenseI8ArrayAttribute(pyAttribute)); 17079566ee28Smax if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) 1708b56d1ec6SPeter Hawkins return nb::cast(PyDenseI16ArrayAttribute(pyAttribute)); 17099566ee28Smax if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) 1710b56d1ec6SPeter Hawkins return nb::cast(PyDenseI32ArrayAttribute(pyAttribute)); 17119566ee28Smax if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) 1712b56d1ec6SPeter Hawkins return nb::cast(PyDenseI64ArrayAttribute(pyAttribute)); 17139566ee28Smax if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) 1714b56d1ec6SPeter Hawkins return nb::cast(PyDenseF32ArrayAttribute(pyAttribute)); 17159566ee28Smax if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) 1716b56d1ec6SPeter Hawkins return nb::cast(PyDenseF64ArrayAttribute(pyAttribute)); 17179566ee28Smax std::string msg = 17189566ee28Smax std::string("Can't cast unknown element type DenseArrayAttr (") + 1719b56d1ec6SPeter Hawkins nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")"; 1720b56d1ec6SPeter Hawkins throw nb::type_error(msg.c_str()); 17219566ee28Smax } 17229566ee28Smax 1723b56d1ec6SPeter Hawkins nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { 17249566ee28Smax if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) 1725b56d1ec6SPeter Hawkins return nb::cast(PyDenseFPElementsAttribute(pyAttribute)); 17269566ee28Smax if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) 1727b56d1ec6SPeter Hawkins return nb::cast(PyDenseIntElementsAttribute(pyAttribute)); 17289566ee28Smax std::string msg = 17299566ee28Smax std::string( 17309566ee28Smax "Can't cast unknown element type DenseIntOrFPElementsAttr (") + 1731b56d1ec6SPeter Hawkins nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")"; 1732b56d1ec6SPeter Hawkins throw nb::type_error(msg.c_str()); 17339566ee28Smax } 17349566ee28Smax 1735b56d1ec6SPeter Hawkins nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { 17369566ee28Smax if (PyBoolAttribute::isaFunction(pyAttribute)) 1737b56d1ec6SPeter Hawkins return nb::cast(PyBoolAttribute(pyAttribute)); 17389566ee28Smax if (PyIntegerAttribute::isaFunction(pyAttribute)) 1739b56d1ec6SPeter Hawkins return nb::cast(PyIntegerAttribute(pyAttribute)); 17409566ee28Smax std::string msg = 17419566ee28Smax std::string("Can't cast unknown element type DenseArrayAttr (") + 1742b56d1ec6SPeter Hawkins nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")"; 1743b56d1ec6SPeter Hawkins throw nb::type_error(msg.c_str()); 17449566ee28Smax } 17459566ee28Smax 1746b56d1ec6SPeter Hawkins nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { 17474eee9ef9Smax if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) 1748b56d1ec6SPeter Hawkins return nb::cast(PyFlatSymbolRefAttribute(pyAttribute)); 17494eee9ef9Smax if (PySymbolRefAttribute::isaFunction(pyAttribute)) 1750b56d1ec6SPeter Hawkins return nb::cast(PySymbolRefAttribute(pyAttribute)); 17514eee9ef9Smax std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + 1752b56d1ec6SPeter Hawkins nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + 1753b56d1ec6SPeter Hawkins ")"; 1754b56d1ec6SPeter Hawkins throw nb::type_error(msg.c_str()); 17554eee9ef9Smax } 17564eee9ef9Smax 1757436c6c9cSStella Laurenzo } // namespace 1758436c6c9cSStella Laurenzo 1759b56d1ec6SPeter Hawkins void mlir::python::populateIRAttributes(nb::module_ &m) { 1760436c6c9cSStella Laurenzo PyAffineMapAttribute::bind(m); 1761619fd8c2SJeff Niu PyDenseBoolArrayAttribute::bind(m); 1762619fd8c2SJeff Niu PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); 1763619fd8c2SJeff Niu PyDenseI8ArrayAttribute::bind(m); 1764619fd8c2SJeff Niu PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); 1765619fd8c2SJeff Niu PyDenseI16ArrayAttribute::bind(m); 1766619fd8c2SJeff Niu PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); 1767619fd8c2SJeff Niu PyDenseI32ArrayAttribute::bind(m); 1768619fd8c2SJeff Niu PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); 1769619fd8c2SJeff Niu PyDenseI64ArrayAttribute::bind(m); 1770619fd8c2SJeff Niu PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); 1771619fd8c2SJeff Niu PyDenseF32ArrayAttribute::bind(m); 1772619fd8c2SJeff Niu PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); 1773619fd8c2SJeff Niu PyDenseF64ArrayAttribute::bind(m); 1774619fd8c2SJeff Niu PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); 17759566ee28Smax PyGlobals::get().registerTypeCaster( 17769566ee28Smax mlirDenseArrayAttrGetTypeID(), 1777b56d1ec6SPeter Hawkins nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster))); 1778619fd8c2SJeff Niu 1779436c6c9cSStella Laurenzo PyArrayAttribute::bind(m); 1780436c6c9cSStella Laurenzo PyArrayAttribute::PyArrayAttributeIterator::bind(m); 1781436c6c9cSStella Laurenzo PyBoolAttribute::bind(m); 1782b56d1ec6SPeter Hawkins PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots); 1783436c6c9cSStella Laurenzo PyDenseFPElementsAttribute::bind(m); 1784436c6c9cSStella Laurenzo PyDenseIntElementsAttribute::bind(m); 17859566ee28Smax PyGlobals::get().registerTypeCaster( 17869566ee28Smax mlirDenseIntOrFPElementsAttrGetTypeID(), 1787b56d1ec6SPeter Hawkins nb::cast<nb::callable>( 1788b56d1ec6SPeter Hawkins nb::cpp_function(denseIntOrFPElementsAttributeCaster))); 1789f66cd9e9SStella Laurenzo PyDenseResourceElementsAttribute::bind(m); 17909566ee28Smax 1791436c6c9cSStella Laurenzo PyDictAttribute::bind(m); 17924eee9ef9Smax PySymbolRefAttribute::bind(m); 17934eee9ef9Smax PyGlobals::get().registerTypeCaster( 17944eee9ef9Smax mlirSymbolRefAttrGetTypeID(), 1795b56d1ec6SPeter Hawkins nb::cast<nb::callable>( 1796b56d1ec6SPeter Hawkins nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster))); 17974eee9ef9Smax 1798436c6c9cSStella Laurenzo PyFlatSymbolRefAttribute::bind(m); 17995c3861b2SYun Long PyOpaqueAttribute::bind(m); 1800436c6c9cSStella Laurenzo PyFloatAttribute::bind(m); 1801436c6c9cSStella Laurenzo PyIntegerAttribute::bind(m); 1802334873feSAmy Wang PyIntegerSetAttribute::bind(m); 1803436c6c9cSStella Laurenzo PyStringAttribute::bind(m); 1804436c6c9cSStella Laurenzo PyTypeAttribute::bind(m); 18059566ee28Smax PyGlobals::get().registerTypeCaster( 18069566ee28Smax mlirIntegerAttrGetTypeID(), 1807b56d1ec6SPeter Hawkins nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster))); 1808436c6c9cSStella Laurenzo PyUnitAttribute::bind(m); 1809ac2e2d65SDenys Shabalin 1810ac2e2d65SDenys Shabalin PyStridedLayoutAttribute::bind(m); 1811436c6c9cSStella Laurenzo } 1812