xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision 5d3ae5161210c068d01ffba36c8e0761e9971179)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9b56d1ec6SPeter Hawkins #include <cstdint>
10a1fe1f5fSKazu Hirata #include <optional>
11b56d1ec6SPeter Hawkins #include <string>
1271a25454SPeter Hawkins #include <string_view>
134811270bSmax #include <utility>
141fc096afSMehdi Amini 
15436c6c9cSStella Laurenzo #include "IRModule.h"
16b56d1ec6SPeter Hawkins #include "NanobindUtils.h"
17b56d1ec6SPeter Hawkins #include "mlir-c/BuiltinAttributes.h"
18b56d1ec6SPeter Hawkins #include "mlir-c/BuiltinTypes.h"
19b56d1ec6SPeter Hawkins #include "mlir/Bindings/Python/NanobindAdaptors.h"
205cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h"
2171a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h"
22c912f0e7Spranavm-nvidia #include "llvm/Support/raw_ostream.h"
2371a25454SPeter Hawkins 
24b56d1ec6SPeter Hawkins namespace nb = nanobind;
25b56d1ec6SPeter Hawkins using namespace nanobind::literals;
26436c6c9cSStella Laurenzo using namespace mlir;
27436c6c9cSStella Laurenzo using namespace mlir::python;
28436c6c9cSStella Laurenzo 
29436c6c9cSStella Laurenzo using llvm::SmallVector;
30436c6c9cSStella Laurenzo 
315d6d30edSStella Laurenzo //------------------------------------------------------------------------------
325d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
335d6d30edSStella Laurenzo //------------------------------------------------------------------------------
345d6d30edSStella Laurenzo 
// Python-facing docstring text. Judging by its wording it documents the
// buffer/array-based DenseElementsAttr `get` constructor; the binding site that
// attaches it is not visible here (TODO confirm).
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
765d6d30edSStella Laurenzo 
// Python-facing docstring text for building a DenseElementsAttr from a list of
// attributes. The Raises section names the argument documented under Args
// (`type`); it previously referred to a non-existent `shaped_type` argument.
static const char kDenseElementsAttrGetFromListDocstring[] =
    R"(Gets a DenseElementsAttr from a Python list of attributes.

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
  attrs: A list of attributes.
  type: The desired shape and type of the resulting DenseElementsAttr.
    If not provided, the element type is determined based on the type
    of the 0th attribute and the shape is `[len(attrs)]`.
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the attributes does not match the type
    specified by `type`.
)";
97c912f0e7Spranavm-nvidia 
// Python-facing docstring text. Judging by its wording it documents building a
// DenseResourceElementsAttr from a buffer; the binding site that attaches it is
// not visible here (TODO confirm).
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
  buffer: The array or buffer to convert.
  name: Name to provide to the resource (may be changed upon collision).
  type: The explicit ShapedType to construct the attribute with.
  context: Explicit context, if not from context manager.

Returns:
  DenseResourceElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
123f66cd9e9SStella Laurenzo 
124436c6c9cSStella Laurenzo namespace {
125436c6c9cSStella Laurenzo 
126b56d1ec6SPeter Hawkins struct nb_buffer_info {
127b56d1ec6SPeter Hawkins   void *ptr = nullptr;
128b56d1ec6SPeter Hawkins   ssize_t itemsize = 0;
129b56d1ec6SPeter Hawkins   ssize_t size = 0;
130b56d1ec6SPeter Hawkins   const char *format = nullptr;
131b56d1ec6SPeter Hawkins   ssize_t ndim = 0;
132b56d1ec6SPeter Hawkins   SmallVector<ssize_t, 4> shape;
133b56d1ec6SPeter Hawkins   SmallVector<ssize_t, 4> strides;
134b56d1ec6SPeter Hawkins   bool readonly = false;
135b56d1ec6SPeter Hawkins 
136b56d1ec6SPeter Hawkins   nb_buffer_info(
137b56d1ec6SPeter Hawkins       void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
138b56d1ec6SPeter Hawkins       SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
139b56d1ec6SPeter Hawkins       bool readonly = false,
140b56d1ec6SPeter Hawkins       std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
141b56d1ec6SPeter Hawkins           std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
142b56d1ec6SPeter Hawkins       : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
143b56d1ec6SPeter Hawkins         shape(std::move(shape_in)), strides(std::move(strides_in)),
144b56d1ec6SPeter Hawkins         readonly(readonly), owned_view(std::move(owned_view_in)) {
145b56d1ec6SPeter Hawkins     size = 1;
146b56d1ec6SPeter Hawkins     for (ssize_t i = 0; i < ndim; ++i) {
147b56d1ec6SPeter Hawkins       size *= shape[i];
148b56d1ec6SPeter Hawkins     }
149b56d1ec6SPeter Hawkins   }
150b56d1ec6SPeter Hawkins 
151b56d1ec6SPeter Hawkins   explicit nb_buffer_info(Py_buffer *view)
152b56d1ec6SPeter Hawkins       : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
153b56d1ec6SPeter Hawkins                        {view->shape, view->shape + view->ndim},
154b56d1ec6SPeter Hawkins                        // TODO(phawkins): check for null strides
155b56d1ec6SPeter Hawkins                        {view->strides, view->strides + view->ndim},
156b56d1ec6SPeter Hawkins                        view->readonly != 0,
157b56d1ec6SPeter Hawkins                        std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
158b56d1ec6SPeter Hawkins                            view, PyBuffer_Release)) {}
159b56d1ec6SPeter Hawkins 
160b56d1ec6SPeter Hawkins   nb_buffer_info(const nb_buffer_info &) = delete;
161b56d1ec6SPeter Hawkins   nb_buffer_info(nb_buffer_info &&) = default;
162b56d1ec6SPeter Hawkins   nb_buffer_info &operator=(const nb_buffer_info &) = delete;
163b56d1ec6SPeter Hawkins   nb_buffer_info &operator=(nb_buffer_info &&) = default;
164b56d1ec6SPeter Hawkins 
165b56d1ec6SPeter Hawkins private:
166b56d1ec6SPeter Hawkins   std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
167b56d1ec6SPeter Hawkins };
168b56d1ec6SPeter Hawkins 
169b56d1ec6SPeter Hawkins class nb_buffer : public nb::object {
170b56d1ec6SPeter Hawkins   NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer);
171b56d1ec6SPeter Hawkins 
172b56d1ec6SPeter Hawkins   nb_buffer_info request() const {
173b56d1ec6SPeter Hawkins     int flags = PyBUF_STRIDES | PyBUF_FORMAT;
174b56d1ec6SPeter Hawkins     auto *view = new Py_buffer();
175b56d1ec6SPeter Hawkins     if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
176b56d1ec6SPeter Hawkins       delete view;
177b56d1ec6SPeter Hawkins       throw nb::python_error();
178b56d1ec6SPeter Hawkins     }
179b56d1ec6SPeter Hawkins     return nb_buffer_info(view);
180b56d1ec6SPeter Hawkins   }
181b56d1ec6SPeter Hawkins };
182b56d1ec6SPeter Hawkins 
/// Maps a C++ scalar type to its Python buffer-protocol (`struct` module)
/// format code via a static `format()` function. The primary template is
/// empty, so only the specialized types below provide `format()`.
template <typename T>
struct nb_format_descriptor {};

template <>
struct nb_format_descriptor<bool> {
  static constexpr const char *format() { return "?"; }
};

// Signed / unsigned 8-bit.
template <>
struct nb_format_descriptor<int8_t> {
  static constexpr const char *format() { return "b"; }
};
template <>
struct nb_format_descriptor<uint8_t> {
  static constexpr const char *format() { return "B"; }
};

// Signed / unsigned 16-bit.
template <>
struct nb_format_descriptor<int16_t> {
  static constexpr const char *format() { return "h"; }
};
template <>
struct nb_format_descriptor<uint16_t> {
  static constexpr const char *format() { return "H"; }
};

// Signed / unsigned 32-bit.
template <>
struct nb_format_descriptor<int32_t> {
  static constexpr const char *format() { return "i"; }
};
template <>
struct nb_format_descriptor<uint32_t> {
  static constexpr const char *format() { return "I"; }
};

// Signed / unsigned 64-bit.
template <>
struct nb_format_descriptor<int64_t> {
  static constexpr const char *format() { return "q"; }
};
template <>
struct nb_format_descriptor<uint64_t> {
  static constexpr const char *format() { return "Q"; }
};

// Floating point.
template <>
struct nb_format_descriptor<float> {
  static constexpr const char *format() { return "f"; }
};
template <>
struct nb_format_descriptor<double> {
  static constexpr const char *format() { return "d"; }
};
230b56d1ec6SPeter Hawkins 
231436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
232436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
233436c6c9cSStella Laurenzo }
234436c6c9cSStella Laurenzo 
235b56d1ec6SPeter Hawkins static MlirStringRef toMlirStringRef(const nb::bytes &s) {
236b56d1ec6SPeter Hawkins   return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
237b56d1ec6SPeter Hawkins }
238b56d1ec6SPeter Hawkins 
239436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
240436c6c9cSStella Laurenzo public:
241436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
242436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
243436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
2449566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
2459566ee28Smax       mlirAffineMapAttrGetTypeID;
246436c6c9cSStella Laurenzo 
247436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
248436c6c9cSStella Laurenzo     c.def_static(
249436c6c9cSStella Laurenzo         "get",
250436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
251436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
252436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
253436c6c9cSStella Laurenzo         },
254b56d1ec6SPeter Hawkins         nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
255b56d1ec6SPeter Hawkins     c.def_prop_ro("value", mlirAffineMapAttrGetValue,
256c36b4248SBimo                   "Returns the value of the AffineMap attribute");
257436c6c9cSStella Laurenzo   }
258436c6c9cSStella Laurenzo };
259436c6c9cSStella Laurenzo 
260334873feSAmy Wang class PyIntegerSetAttribute
261334873feSAmy Wang     : public PyConcreteAttribute<PyIntegerSetAttribute> {
262334873feSAmy Wang public:
263334873feSAmy Wang   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
264334873feSAmy Wang   static constexpr const char *pyClassName = "IntegerSetAttr";
265334873feSAmy Wang   using PyConcreteAttribute::PyConcreteAttribute;
266334873feSAmy Wang   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
267334873feSAmy Wang       mlirIntegerSetAttrGetTypeID;
268334873feSAmy Wang 
269334873feSAmy Wang   static void bindDerived(ClassTy &c) {
270334873feSAmy Wang     c.def_static(
271334873feSAmy Wang         "get",
272334873feSAmy Wang         [](PyIntegerSet &integerSet) {
273334873feSAmy Wang           MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
274334873feSAmy Wang           return PyIntegerSetAttribute(integerSet.getContext(), attr);
275334873feSAmy Wang         },
276b56d1ec6SPeter Hawkins         nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
277334873feSAmy Wang   }
278334873feSAmy Wang };
279334873feSAmy Wang 
280ed9e52f3SAlex Zinenko template <typename T>
281b56d1ec6SPeter Hawkins static T pyTryCast(nb::handle object) {
282ed9e52f3SAlex Zinenko   try {
283b56d1ec6SPeter Hawkins     return nb::cast<T>(object);
284b56d1ec6SPeter Hawkins   } catch (nb::cast_error &err) {
285b56d1ec6SPeter Hawkins     std::string msg = std::string("Invalid attribute when attempting to "
286b56d1ec6SPeter Hawkins                                   "create an ArrayAttribute (") +
287ed9e52f3SAlex Zinenko                       err.what() + ")";
288b56d1ec6SPeter Hawkins     throw std::runtime_error(msg.c_str());
289b56d1ec6SPeter Hawkins   } catch (std::runtime_error &err) {
290ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
291ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
292ed9e52f3SAlex Zinenko                       err.what() + ")";
293b56d1ec6SPeter Hawkins     throw std::runtime_error(msg.c_str());
294ed9e52f3SAlex Zinenko   }
295ed9e52f3SAlex Zinenko }
296ed9e52f3SAlex Zinenko 
297619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
298619fd8c2SJeff Niu /// implementation class.
299619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
300133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
301619fd8c2SJeff Niu public:
302133624acSJeff Niu   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
303619fd8c2SJeff Niu 
304619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
305619fd8c2SJeff Niu   class PyDenseArrayIterator {
306619fd8c2SJeff Niu   public:
3074a1b1196SMehdi Amini     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
308619fd8c2SJeff Niu 
309619fd8c2SJeff Niu     /// Return a copy of the iterator.
310619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
311619fd8c2SJeff Niu 
312619fd8c2SJeff Niu     /// Return the next element.
313619fd8c2SJeff Niu     EltTy dunderNext() {
314619fd8c2SJeff Niu       // Throw if the index has reached the end.
315619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
316b56d1ec6SPeter Hawkins         throw nb::stop_iteration();
317619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
318619fd8c2SJeff Niu     }
319619fd8c2SJeff Niu 
320619fd8c2SJeff Niu     /// Bind the iterator class.
321b56d1ec6SPeter Hawkins     static void bind(nb::module_ &m) {
322b56d1ec6SPeter Hawkins       nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
323619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
324619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
325619fd8c2SJeff Niu     }
326619fd8c2SJeff Niu 
327619fd8c2SJeff Niu   private:
328619fd8c2SJeff Niu     /// The referenced dense array attribute.
329619fd8c2SJeff Niu     PyAttribute attr;
330619fd8c2SJeff Niu     /// The next index to read.
331619fd8c2SJeff Niu     int nextIndex = 0;
332619fd8c2SJeff Niu   };
333619fd8c2SJeff Niu 
334619fd8c2SJeff Niu   /// Get the element at the given index.
335619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
336619fd8c2SJeff Niu 
337619fd8c2SJeff Niu   /// Bind the attribute class.
338133624acSJeff Niu   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
339619fd8c2SJeff Niu     // Bind the constructor.
340b56d1ec6SPeter Hawkins     if constexpr (std::is_same_v<EltTy, bool>) {
341b56d1ec6SPeter Hawkins       c.def_static(
342b56d1ec6SPeter Hawkins           "get",
343b56d1ec6SPeter Hawkins           [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
344b56d1ec6SPeter Hawkins             std::vector<bool> values;
345b56d1ec6SPeter Hawkins             for (nb::handle py_value : py_values) {
346b56d1ec6SPeter Hawkins               int is_true = PyObject_IsTrue(py_value.ptr());
347b56d1ec6SPeter Hawkins               if (is_true < 0) {
348b56d1ec6SPeter Hawkins                 throw nb::python_error();
349b56d1ec6SPeter Hawkins               }
350b56d1ec6SPeter Hawkins               values.push_back(is_true);
351b56d1ec6SPeter Hawkins             }
352b56d1ec6SPeter Hawkins             return getAttribute(values, ctx->getRef());
353b56d1ec6SPeter Hawkins           },
354b56d1ec6SPeter Hawkins           nb::arg("values"), nb::arg("context").none() = nb::none(),
355b56d1ec6SPeter Hawkins           "Gets a uniqued dense array attribute");
356b56d1ec6SPeter Hawkins     } else {
357619fd8c2SJeff Niu       c.def_static(
358619fd8c2SJeff Niu           "get",
359619fd8c2SJeff Niu           [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
3608dcb6722SIngo Müller             return getAttribute(values, ctx->getRef());
361619fd8c2SJeff Niu           },
362b56d1ec6SPeter Hawkins           nb::arg("values"), nb::arg("context").none() = nb::none(),
363619fd8c2SJeff Niu           "Gets a uniqued dense array attribute");
364b56d1ec6SPeter Hawkins     }
365619fd8c2SJeff Niu     // Bind the array methods.
366133624acSJeff Niu     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
367619fd8c2SJeff Niu       if (i >= mlirDenseArrayGetNumElements(arr))
368b56d1ec6SPeter Hawkins         throw nb::index_error("DenseArray index out of range");
369619fd8c2SJeff Niu       return arr.getItem(i);
370619fd8c2SJeff Niu     });
371133624acSJeff Niu     c.def("__len__", [](const DerivedT &arr) {
372619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
373619fd8c2SJeff Niu     });
374133624acSJeff Niu     c.def("__iter__",
375133624acSJeff Niu           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
376b56d1ec6SPeter Hawkins     c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
377619fd8c2SJeff Niu       std::vector<EltTy> values;
378619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
379b56d1ec6SPeter Hawkins       values.reserve(numOldElements + nb::len(extras));
380619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
381619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
382b56d1ec6SPeter Hawkins       for (nb::handle attr : extras)
383619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
3848dcb6722SIngo Müller       return getAttribute(values, arr.getContext());
385619fd8c2SJeff Niu     });
386619fd8c2SJeff Niu   }
3878dcb6722SIngo Müller 
3888dcb6722SIngo Müller private:
3898dcb6722SIngo Müller   static DerivedT getAttribute(const std::vector<EltTy> &values,
3908dcb6722SIngo Müller                                PyMlirContextRef ctx) {
3918dcb6722SIngo Müller     if constexpr (std::is_same_v<EltTy, bool>) {
3928dcb6722SIngo Müller       std::vector<int> intValues(values.begin(), values.end());
3938dcb6722SIngo Müller       MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
3948dcb6722SIngo Müller                                                   intValues.data());
3958dcb6722SIngo Müller       return DerivedT(ctx, attr);
3968dcb6722SIngo Müller     } else {
3978dcb6722SIngo Müller       MlirAttribute attr =
3988dcb6722SIngo Müller           DerivedT::getAttribute(ctx->get(), values.size(), values.data());
3998dcb6722SIngo Müller       return DerivedT(ctx, attr);
4008dcb6722SIngo Müller     }
4018dcb6722SIngo Müller   }
402619fd8c2SJeff Niu };
403619fd8c2SJeff Niu 
404619fd8c2SJeff Niu /// Instantiate the python dense array classes.
405619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
4068dcb6722SIngo Müller     : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
407619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
408619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
409619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
410619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
411619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
412619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
413619fd8c2SJeff Niu };
414619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
415619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
416619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
417619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
418619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
419619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
420619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
421619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
422619fd8c2SJeff Niu };
423619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
424619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
425619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
426619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
427619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
428619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
429619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
430619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
431619fd8c2SJeff Niu };
432619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
433619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
434619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
435619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
436619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
437619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
438619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
439619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
440619fd8c2SJeff Niu };
441619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
442619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
443619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
444619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
445619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
446619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
447619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
448619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
449619fd8c2SJeff Niu };
450619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
451619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
452619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
453619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
454619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
455619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
456619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
457619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
458619fd8c2SJeff Niu };
459619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
460619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
461619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
462619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
463619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
464619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
465619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
466619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
467619fd8c2SJeff Niu };
468619fd8c2SJeff Niu 
469436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
470436c6c9cSStella Laurenzo public:
471436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
472436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
473436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
4749566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
4759566ee28Smax       mlirArrayAttrGetTypeID;
476436c6c9cSStella Laurenzo 
477436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
478436c6c9cSStella Laurenzo   public:
4791fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
480436c6c9cSStella Laurenzo 
481436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
482436c6c9cSStella Laurenzo 
483974c1596SRahul Kayaith     MlirAttribute dunderNext() {
484bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
485bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
486b56d1ec6SPeter Hawkins         throw nb::stop_iteration();
487974c1596SRahul Kayaith       return mlirArrayAttrGetElement(attr.get(), nextIndex++);
488436c6c9cSStella Laurenzo     }
489436c6c9cSStella Laurenzo 
490b56d1ec6SPeter Hawkins     static void bind(nb::module_ &m) {
491b56d1ec6SPeter Hawkins       nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
492436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
493436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
494436c6c9cSStella Laurenzo     }
495436c6c9cSStella Laurenzo 
496436c6c9cSStella Laurenzo   private:
497436c6c9cSStella Laurenzo     PyAttribute attr;
498436c6c9cSStella Laurenzo     int nextIndex = 0;
499436c6c9cSStella Laurenzo   };
500436c6c9cSStella Laurenzo 
501974c1596SRahul Kayaith   MlirAttribute getItem(intptr_t i) {
502974c1596SRahul Kayaith     return mlirArrayAttrGetElement(*this, i);
503ed9e52f3SAlex Zinenko   }
504ed9e52f3SAlex Zinenko 
505436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
506436c6c9cSStella Laurenzo     c.def_static(
507436c6c9cSStella Laurenzo         "get",
508b56d1ec6SPeter Hawkins         [](nb::list attributes, DefaultingPyMlirContext context) {
509436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
510b56d1ec6SPeter Hawkins           mlirAttributes.reserve(nb::len(attributes));
511436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
512ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
513436c6c9cSStella Laurenzo           }
514436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
515436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
516436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
517436c6c9cSStella Laurenzo         },
518b56d1ec6SPeter Hawkins         nb::arg("attributes"), nb::arg("context").none() = nb::none(),
519436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
520436c6c9cSStella Laurenzo     c.def("__getitem__",
521436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
522436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
523b56d1ec6SPeter Hawkins               throw nb::index_error("ArrayAttribute index out of range");
524ed9e52f3SAlex Zinenko             return arr.getItem(i);
525436c6c9cSStella Laurenzo           })
526436c6c9cSStella Laurenzo         .def("__len__",
527436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
528436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
529436c6c9cSStella Laurenzo              })
530436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
531436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
532436c6c9cSStella Laurenzo         });
533b56d1ec6SPeter Hawkins     c.def("__add__", [](PyArrayAttribute arr, nb::list extras) {
534ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
535ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
536b56d1ec6SPeter Hawkins       attributes.reserve(numOldElements + nb::len(extras));
537ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
538ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
539b56d1ec6SPeter Hawkins       for (nb::handle attr : extras)
540ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
541ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
542ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
543ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
544ed9e52f3SAlex Zinenko     });
545436c6c9cSStella Laurenzo   }
546436c6c9cSStella Laurenzo };
547436c6c9cSStella Laurenzo 
548436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
549436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
550436c6c9cSStella Laurenzo public:
551436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
552436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
553436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
5549566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
5559566ee28Smax       mlirFloatAttrGetTypeID;
556436c6c9cSStella Laurenzo 
557436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
558436c6c9cSStella Laurenzo     c.def_static(
559436c6c9cSStella Laurenzo         "get",
560436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
5613ea4c501SRahul Kayaith           PyMlirContext::ErrorCapture errors(loc->getContext());
562436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
5633ea4c501SRahul Kayaith           if (mlirAttributeIsNull(attr))
5643ea4c501SRahul Kayaith             throw MLIRError("Invalid attribute", errors.take());
565436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
566436c6c9cSStella Laurenzo         },
567b56d1ec6SPeter Hawkins         nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(),
568436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
569436c6c9cSStella Laurenzo     c.def_static(
570436c6c9cSStella Laurenzo         "get_f32",
571436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
572436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
573436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
574436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
575436c6c9cSStella Laurenzo         },
576b56d1ec6SPeter Hawkins         nb::arg("value"), nb::arg("context").none() = nb::none(),
577436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
578436c6c9cSStella Laurenzo     c.def_static(
579436c6c9cSStella Laurenzo         "get_f64",
580436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
581436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
582436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
583436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
584436c6c9cSStella Laurenzo         },
585b56d1ec6SPeter Hawkins         nb::arg("value"), nb::arg("context").none() = nb::none(),
586436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
587b56d1ec6SPeter Hawkins     c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
5882a5d4974SIngo Müller                   "Returns the value of the float attribute");
5892a5d4974SIngo Müller     c.def("__float__", mlirFloatAttrGetValueDouble,
5902a5d4974SIngo Müller           "Converts the value of the float attribute to a Python float");
591436c6c9cSStella Laurenzo   }
592436c6c9cSStella Laurenzo };
593436c6c9cSStella Laurenzo 
594436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
595436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
596436c6c9cSStella Laurenzo public:
597436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
598436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
599436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
600436c6c9cSStella Laurenzo 
601436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
602436c6c9cSStella Laurenzo     c.def_static(
603436c6c9cSStella Laurenzo         "get",
604436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
605436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
606436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
607436c6c9cSStella Laurenzo         },
608b56d1ec6SPeter Hawkins         nb::arg("type"), nb::arg("value"),
609436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
610b56d1ec6SPeter Hawkins     c.def_prop_ro("value", toPyInt,
6112a5d4974SIngo Müller                   "Returns the value of the integer attribute");
6122a5d4974SIngo Müller     c.def("__int__", toPyInt,
6132a5d4974SIngo Müller           "Converts the value of the integer attribute to a Python int");
614b56d1ec6SPeter Hawkins     c.def_prop_ro_static("static_typeid",
615b56d1ec6SPeter Hawkins                          [](nb::object & /*class*/) -> MlirTypeID {
6162a5d4974SIngo Müller                            return mlirIntegerAttrGetTypeID();
6172a5d4974SIngo Müller                          });
6182a5d4974SIngo Müller   }
6192a5d4974SIngo Müller 
6202a5d4974SIngo Müller private:
621b56d1ec6SPeter Hawkins   static int64_t toPyInt(PyIntegerAttribute &self) {
622e9db306dSrkayaith     MlirType type = mlirAttributeGetType(self);
623e9db306dSrkayaith     if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
624436c6c9cSStella Laurenzo       return mlirIntegerAttrGetValueInt(self);
625e9db306dSrkayaith     if (mlirIntegerTypeIsSigned(type))
626e9db306dSrkayaith       return mlirIntegerAttrGetValueSInt(self);
627e9db306dSrkayaith     return mlirIntegerAttrGetValueUInt(self);
628436c6c9cSStella Laurenzo   }
629436c6c9cSStella Laurenzo };
630436c6c9cSStella Laurenzo 
631436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
632436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
633436c6c9cSStella Laurenzo public:
634436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
635436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
636436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
637436c6c9cSStella Laurenzo 
638436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
639436c6c9cSStella Laurenzo     c.def_static(
640436c6c9cSStella Laurenzo         "get",
641436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
642436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
643436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
644436c6c9cSStella Laurenzo         },
645b56d1ec6SPeter Hawkins         nb::arg("value"), nb::arg("context").none() = nb::none(),
646436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
647b56d1ec6SPeter Hawkins     c.def_prop_ro("value", mlirBoolAttrGetValue,
648436c6c9cSStella Laurenzo                   "Returns the value of the bool attribute");
6492a5d4974SIngo Müller     c.def("__bool__", mlirBoolAttrGetValue,
6502a5d4974SIngo Müller           "Converts the value of the bool attribute to a Python bool");
651436c6c9cSStella Laurenzo   }
652436c6c9cSStella Laurenzo };
653436c6c9cSStella Laurenzo 
6544eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
6554eee9ef9Smax public:
6564eee9ef9Smax   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
6574eee9ef9Smax   static constexpr const char *pyClassName = "SymbolRefAttr";
6584eee9ef9Smax   using PyConcreteAttribute::PyConcreteAttribute;
6594eee9ef9Smax 
6604eee9ef9Smax   static MlirAttribute fromList(const std::vector<std::string> &symbols,
6614eee9ef9Smax                                 PyMlirContext &context) {
6624eee9ef9Smax     if (symbols.empty())
6634eee9ef9Smax       throw std::runtime_error("SymbolRefAttr must be composed of at least "
6644eee9ef9Smax                                "one symbol.");
6654eee9ef9Smax     MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
6664eee9ef9Smax     SmallVector<MlirAttribute, 3> referenceAttrs;
6674eee9ef9Smax     for (size_t i = 1; i < symbols.size(); ++i) {
6684eee9ef9Smax       referenceAttrs.push_back(
6694eee9ef9Smax           mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
6704eee9ef9Smax     }
6714eee9ef9Smax     return mlirSymbolRefAttrGet(context.get(), rootSymbol,
6724eee9ef9Smax                                 referenceAttrs.size(), referenceAttrs.data());
6734eee9ef9Smax   }
6744eee9ef9Smax 
6754eee9ef9Smax   static void bindDerived(ClassTy &c) {
6764eee9ef9Smax     c.def_static(
6774eee9ef9Smax         "get",
6784eee9ef9Smax         [](const std::vector<std::string> &symbols,
6794eee9ef9Smax            DefaultingPyMlirContext context) {
6804eee9ef9Smax           return PySymbolRefAttribute::fromList(symbols, context.resolve());
6814eee9ef9Smax         },
682b56d1ec6SPeter Hawkins         nb::arg("symbols"), nb::arg("context").none() = nb::none(),
6834eee9ef9Smax         "Gets a uniqued SymbolRef attribute from a list of symbol names");
684b56d1ec6SPeter Hawkins     c.def_prop_ro(
6854eee9ef9Smax         "value",
6864eee9ef9Smax         [](PySymbolRefAttribute &self) {
6874eee9ef9Smax           std::vector<std::string> symbols = {
6884eee9ef9Smax               unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
6894eee9ef9Smax           for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
6904eee9ef9Smax                ++i)
6914eee9ef9Smax             symbols.push_back(
6924eee9ef9Smax                 unwrap(mlirSymbolRefAttrGetRootReference(
6934eee9ef9Smax                            mlirSymbolRefAttrGetNestedReference(self, i)))
6944eee9ef9Smax                     .str());
6954eee9ef9Smax           return symbols;
6964eee9ef9Smax         },
6974eee9ef9Smax         "Returns the value of the SymbolRef attribute as a list[str]");
6984eee9ef9Smax   }
6994eee9ef9Smax };
7004eee9ef9Smax 
701436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
702436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
703436c6c9cSStella Laurenzo public:
704436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
705436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
706436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
707436c6c9cSStella Laurenzo 
708436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
709436c6c9cSStella Laurenzo     c.def_static(
710436c6c9cSStella Laurenzo         "get",
711436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
712436c6c9cSStella Laurenzo           MlirAttribute attr =
713436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
714436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
715436c6c9cSStella Laurenzo         },
716b56d1ec6SPeter Hawkins         nb::arg("value"), nb::arg("context").none() = nb::none(),
717436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
718b56d1ec6SPeter Hawkins     c.def_prop_ro(
719436c6c9cSStella Laurenzo         "value",
720436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
721436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
722b56d1ec6SPeter Hawkins           return nb::str(stringRef.data, stringRef.length);
723436c6c9cSStella Laurenzo         },
724436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
725436c6c9cSStella Laurenzo   }
726436c6c9cSStella Laurenzo };
727436c6c9cSStella Laurenzo 
7285c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
7295c3861b2SYun Long public:
7305c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
7315c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
7325c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
7339566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
7349566ee28Smax       mlirOpaqueAttrGetTypeID;
7355c3861b2SYun Long 
7365c3861b2SYun Long   static void bindDerived(ClassTy &c) {
7375c3861b2SYun Long     c.def_static(
7385c3861b2SYun Long         "get",
739b56d1ec6SPeter Hawkins         [](std::string dialectNamespace, nb_buffer buffer, PyType &type,
7405c3861b2SYun Long            DefaultingPyMlirContext context) {
741b56d1ec6SPeter Hawkins           const nb_buffer_info bufferInfo = buffer.request();
7425c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
7435c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
7445c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
7455c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
7465c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
7475c3861b2SYun Long         },
748b56d1ec6SPeter Hawkins         nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
749b56d1ec6SPeter Hawkins         nb::arg("context").none() = nb::none(), "Gets an Opaque attribute.");
750b56d1ec6SPeter Hawkins     c.def_prop_ro(
7515c3861b2SYun Long         "dialect_namespace",
7525c3861b2SYun Long         [](PyOpaqueAttribute &self) {
7535c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
754b56d1ec6SPeter Hawkins           return nb::str(stringRef.data, stringRef.length);
7555c3861b2SYun Long         },
7565c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
757b56d1ec6SPeter Hawkins     c.def_prop_ro(
7585c3861b2SYun Long         "data",
7595c3861b2SYun Long         [](PyOpaqueAttribute &self) {
7605c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
761b56d1ec6SPeter Hawkins           return nb::bytes(stringRef.data, stringRef.length);
7625c3861b2SYun Long         },
76362bf6c2eSChris Jones         "Returns the data for the Opaqued attributes as `bytes`");
7645c3861b2SYun Long   }
7655c3861b2SYun Long };
7665c3861b2SYun Long 
767436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
768436c6c9cSStella Laurenzo public:
769436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
770436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
771436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
7729566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
7739566ee28Smax       mlirStringAttrGetTypeID;
774436c6c9cSStella Laurenzo 
775436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
776436c6c9cSStella Laurenzo     c.def_static(
777436c6c9cSStella Laurenzo         "get",
778436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
779436c6c9cSStella Laurenzo           MlirAttribute attr =
780436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
781436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
782436c6c9cSStella Laurenzo         },
783b56d1ec6SPeter Hawkins         nb::arg("value"), nb::arg("context").none() = nb::none(),
784b56d1ec6SPeter Hawkins         "Gets a uniqued string attribute");
785b56d1ec6SPeter Hawkins     c.def_static(
786b56d1ec6SPeter Hawkins         "get",
787b56d1ec6SPeter Hawkins         [](nb::bytes value, DefaultingPyMlirContext context) {
788b56d1ec6SPeter Hawkins           MlirAttribute attr =
789b56d1ec6SPeter Hawkins               mlirStringAttrGet(context->get(), toMlirStringRef(value));
790b56d1ec6SPeter Hawkins           return PyStringAttribute(context->getRef(), attr);
791b56d1ec6SPeter Hawkins         },
792b56d1ec6SPeter Hawkins         nb::arg("value"), nb::arg("context").none() = nb::none(),
793436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
794436c6c9cSStella Laurenzo     c.def_static(
795436c6c9cSStella Laurenzo         "get_typed",
796436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
797436c6c9cSStella Laurenzo           MlirAttribute attr =
798436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
799436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
800436c6c9cSStella Laurenzo         },
801b56d1ec6SPeter Hawkins         nb::arg("type"), nb::arg("value"),
802436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
803b56d1ec6SPeter Hawkins     c.def_prop_ro(
8049f533548SIngo Müller         "value",
8059f533548SIngo Müller         [](PyStringAttribute &self) {
8069f533548SIngo Müller           MlirStringRef stringRef = mlirStringAttrGetValue(self);
807b56d1ec6SPeter Hawkins           return nb::str(stringRef.data, stringRef.length);
8089f533548SIngo Müller         },
809436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
810b56d1ec6SPeter Hawkins     c.def_prop_ro(
81162bf6c2eSChris Jones         "value_bytes",
81262bf6c2eSChris Jones         [](PyStringAttribute &self) {
81362bf6c2eSChris Jones           MlirStringRef stringRef = mlirStringAttrGetValue(self);
814b56d1ec6SPeter Hawkins           return nb::bytes(stringRef.data, stringRef.length);
81562bf6c2eSChris Jones         },
81662bf6c2eSChris Jones         "Returns the value of the string attribute as `bytes`");
817436c6c9cSStella Laurenzo   }
818436c6c9cSStella Laurenzo };
819436c6c9cSStella Laurenzo 
820436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
821436c6c9cSStella Laurenzo class PyDenseElementsAttribute
822436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
823436c6c9cSStella Laurenzo public:
824436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
825436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
826436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
827436c6c9cSStella Laurenzo 
828436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
829b56d1ec6SPeter Hawkins   getFromList(nb::list attributes, std::optional<PyType> explicitType,
830c912f0e7Spranavm-nvidia               DefaultingPyMlirContext contextWrapper) {
831b56d1ec6SPeter Hawkins     const size_t numAttributes = nb::len(attributes);
832c912f0e7Spranavm-nvidia     if (numAttributes == 0)
833b56d1ec6SPeter Hawkins       throw nb::value_error("Attributes list must be non-empty.");
834c912f0e7Spranavm-nvidia 
835c912f0e7Spranavm-nvidia     MlirType shapedType;
836c912f0e7Spranavm-nvidia     if (explicitType) {
837c912f0e7Spranavm-nvidia       if ((!mlirTypeIsAShaped(*explicitType) ||
838c912f0e7Spranavm-nvidia            !mlirShapedTypeHasStaticShape(*explicitType))) {
839c912f0e7Spranavm-nvidia 
840c912f0e7Spranavm-nvidia         std::string message;
841c912f0e7Spranavm-nvidia         llvm::raw_string_ostream os(message);
842c912f0e7Spranavm-nvidia         os << "Expected a static ShapedType for the shaped_type parameter: "
843b56d1ec6SPeter Hawkins            << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
844b56d1ec6SPeter Hawkins         throw nb::value_error(message.c_str());
845c912f0e7Spranavm-nvidia       }
846c912f0e7Spranavm-nvidia       shapedType = *explicitType;
847c912f0e7Spranavm-nvidia     } else {
8489cbc1f29SHan-Chung Wang       SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
849c912f0e7Spranavm-nvidia       shapedType = mlirRankedTensorTypeGet(
850c912f0e7Spranavm-nvidia           shape.size(), shape.data(),
851c912f0e7Spranavm-nvidia           mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
852c912f0e7Spranavm-nvidia           mlirAttributeGetNull());
853c912f0e7Spranavm-nvidia     }
854c912f0e7Spranavm-nvidia 
855c912f0e7Spranavm-nvidia     SmallVector<MlirAttribute> mlirAttributes;
856c912f0e7Spranavm-nvidia     mlirAttributes.reserve(numAttributes);
857b56d1ec6SPeter Hawkins     for (const nb::handle &attribute : attributes) {
858c912f0e7Spranavm-nvidia       MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
859c912f0e7Spranavm-nvidia       MlirType attrType = mlirAttributeGetType(mlirAttribute);
860c912f0e7Spranavm-nvidia       mlirAttributes.push_back(mlirAttribute);
861c912f0e7Spranavm-nvidia 
862c912f0e7Spranavm-nvidia       if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
863c912f0e7Spranavm-nvidia         std::string message;
864c912f0e7Spranavm-nvidia         llvm::raw_string_ostream os(message);
865c912f0e7Spranavm-nvidia         os << "All attributes must be of the same type and match "
866b56d1ec6SPeter Hawkins            << "the type parameter: expected="
867b56d1ec6SPeter Hawkins            << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
868b56d1ec6SPeter Hawkins            << ", but got="
869b56d1ec6SPeter Hawkins            << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
870b56d1ec6SPeter Hawkins         throw nb::value_error(message.c_str());
871c912f0e7Spranavm-nvidia       }
872c912f0e7Spranavm-nvidia     }
873c912f0e7Spranavm-nvidia 
874c912f0e7Spranavm-nvidia     MlirAttribute elements = mlirDenseElementsAttrGet(
875c912f0e7Spranavm-nvidia         shapedType, mlirAttributes.size(), mlirAttributes.data());
876c912f0e7Spranavm-nvidia 
877c912f0e7Spranavm-nvidia     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
878c912f0e7Spranavm-nvidia   }
879c912f0e7Spranavm-nvidia 
880c912f0e7Spranavm-nvidia   static PyDenseElementsAttribute
881b56d1ec6SPeter Hawkins   getFromBuffer(nb_buffer array, bool signless,
8820a81ace0SKazu Hirata                 std::optional<PyType> explicitType,
8830a81ace0SKazu Hirata                 std::optional<std::vector<int64_t>> explicitShape,
884436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
885436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
88671a25454SPeter Hawkins     int flags = PyBUF_ND;
88771a25454SPeter Hawkins     if (!explicitType) {
88871a25454SPeter Hawkins       flags |= PyBUF_FORMAT;
88971a25454SPeter Hawkins     }
89071a25454SPeter Hawkins     Py_buffer view;
89171a25454SPeter Hawkins     if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
892b56d1ec6SPeter Hawkins       throw nb::python_error();
893436c6c9cSStella Laurenzo     }
89471a25454SPeter Hawkins     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
895436c6c9cSStella Laurenzo 
896436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
8971824e45cSKasper Nielsen     MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
8981824e45cSKasper Nielsen                                                 explicitShape, context);
8995d6d30edSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
9005d6d30edSStella Laurenzo       throw std::invalid_argument(
9015d6d30edSStella Laurenzo           "DenseElementsAttr could not be constructed from the given buffer. "
9025d6d30edSStella Laurenzo           "This may mean that the Python buffer layout does not match that "
9035d6d30edSStella Laurenzo           "MLIR expected layout and is a bug.");
9045d6d30edSStella Laurenzo     }
9055d6d30edSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
9065d6d30edSStella Laurenzo   }
907436c6c9cSStella Laurenzo 
9081fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
909436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
910436c6c9cSStella Laurenzo     auto contextWrapper =
911436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
912436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
913436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
914436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
915b56d1ec6SPeter Hawkins       message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
916b56d1ec6SPeter Hawkins       throw nb::value_error(message.c_str());
917436c6c9cSStella Laurenzo     }
918436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
919436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
920436c6c9cSStella Laurenzo       std::string message =
921436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
922b56d1ec6SPeter Hawkins       message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
923b56d1ec6SPeter Hawkins       throw nb::value_error(message.c_str());
924436c6c9cSStella Laurenzo     }
925436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
926436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
927436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
928436c6c9cSStella Laurenzo       std::string message =
929436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
930b56d1ec6SPeter Hawkins       message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
931436c6c9cSStella Laurenzo       message.append(", element=");
932b56d1ec6SPeter Hawkins       message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
933b56d1ec6SPeter Hawkins       throw nb::value_error(message.c_str());
934436c6c9cSStella Laurenzo     }
935436c6c9cSStella Laurenzo 
936436c6c9cSStella Laurenzo     MlirAttribute elements =
937436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
938436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
939436c6c9cSStella Laurenzo   }
940436c6c9cSStella Laurenzo 
941436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
942436c6c9cSStella Laurenzo 
943b56d1ec6SPeter Hawkins   std::unique_ptr<nb_buffer_info> accessBuffer() {
944436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
945436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
9465d6d30edSStella Laurenzo     std::string format;
947436c6c9cSStella Laurenzo 
948436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
949436c6c9cSStella Laurenzo       // f32
9505d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
95102b6fb21SMehdi Amini     }
95202b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
953436c6c9cSStella Laurenzo       // f64
9545d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
955bb56c2b3SMehdi Amini     }
956bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
9575d6d30edSStella Laurenzo       // f16
9585d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
959bb56c2b3SMehdi Amini     }
960ef1b735dSmax     if (mlirTypeIsAIndex(elementType)) {
961ef1b735dSmax       // Same as IndexType::kInternalStorageBitWidth
962ef1b735dSmax       return bufferInfo<int64_t>(shapedType);
963ef1b735dSmax     }
964bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
965436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
966436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
967436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
968436c6c9cSStella Laurenzo         // i32
9695d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
970e5639b3fSMehdi Amini       }
971e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
972436c6c9cSStella Laurenzo         // unsigned i32
9735d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
974436c6c9cSStella Laurenzo       }
975436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
976436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
977436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
978436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
979436c6c9cSStella Laurenzo         // i64
9805d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
981e5639b3fSMehdi Amini       }
982e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
983436c6c9cSStella Laurenzo         // unsigned i64
9845d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
9855d6d30edSStella Laurenzo       }
9865d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
9875d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
9885d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
9895d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
9905d6d30edSStella Laurenzo         // i8
9915d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
992e5639b3fSMehdi Amini       }
993e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
9945d6d30edSStella Laurenzo         // unsigned i8
9955d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
9965d6d30edSStella Laurenzo       }
9975d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
9985d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
9995d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
10005d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
10015d6d30edSStella Laurenzo         // i16
10025d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
1003e5639b3fSMehdi Amini       }
1004e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
10055d6d30edSStella Laurenzo         // unsigned i16
10065d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
1007436c6c9cSStella Laurenzo       }
10081824e45cSKasper Nielsen     } else if (mlirTypeIsAInteger(elementType) &&
10091824e45cSKasper Nielsen                mlirIntegerTypeGetWidth(elementType) == 1) {
10101824e45cSKasper Nielsen       // i1 / bool
10111824e45cSKasper Nielsen       // We can not send the buffer directly back to Python, because the i1
10121824e45cSKasper Nielsen       // values are bitpacked within MLIR. We call numpy's unpackbits function
10131824e45cSKasper Nielsen       // to convert the bytes.
10141824e45cSKasper Nielsen       return getBooleanBufferFromBitpackedAttribute();
1015436c6c9cSStella Laurenzo     }
1016436c6c9cSStella Laurenzo 
1017c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
10185d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
1019c5f445d1SStella Laurenzo     throw std::invalid_argument(
1020c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
1021436c6c9cSStella Laurenzo   }
1022436c6c9cSStella Laurenzo 
1023436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1024b56d1ec6SPeter Hawkins #if PY_VERSION_HEX < 0x03090000
1025b56d1ec6SPeter Hawkins     PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
1026b56d1ec6SPeter Hawkins     tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
1027b56d1ec6SPeter Hawkins     tp->tp_as_buffer->bf_releasebuffer =
1028b56d1ec6SPeter Hawkins         PyDenseElementsAttribute::bf_releasebuffer;
1029b56d1ec6SPeter Hawkins #endif
1030436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
1031436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
1032b56d1ec6SPeter Hawkins                     nb::arg("array"), nb::arg("signless") = true,
1033b56d1ec6SPeter Hawkins                     nb::arg("type").none() = nb::none(),
1034b56d1ec6SPeter Hawkins                     nb::arg("shape").none() = nb::none(),
1035b56d1ec6SPeter Hawkins                     nb::arg("context").none() = nb::none(),
10365d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
1037c912f0e7Spranavm-nvidia         .def_static("get", PyDenseElementsAttribute::getFromList,
1038b56d1ec6SPeter Hawkins                     nb::arg("attrs"), nb::arg("type").none() = nb::none(),
1039b56d1ec6SPeter Hawkins                     nb::arg("context").none() = nb::none(),
1040c912f0e7Spranavm-nvidia                     kDenseElementsAttrGetFromListDocstring)
1041436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
1042b56d1ec6SPeter Hawkins                     nb::arg("shaped_type"), nb::arg("element_attr"),
1043436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
1044b56d1ec6SPeter Hawkins         .def_prop_ro("is_splat",
1045436c6c9cSStella Laurenzo                      [](PyDenseElementsAttribute &self) -> bool {
1046436c6c9cSStella Laurenzo                        return mlirDenseElementsAttrIsSplat(self);
1047436c6c9cSStella Laurenzo                      })
1048b56d1ec6SPeter Hawkins         .def("get_splat_value", [](PyDenseElementsAttribute &self) {
1049974c1596SRahul Kayaith           if (!mlirDenseElementsAttrIsSplat(self))
1050b56d1ec6SPeter Hawkins             throw nb::value_error(
105191259963SAdam Paszke                 "get_splat_value called on a non-splat attribute");
1052974c1596SRahul Kayaith           return mlirDenseElementsAttrGetSplatValue(self);
1053b56d1ec6SPeter Hawkins         });
1054436c6c9cSStella Laurenzo   }
1055436c6c9cSStella Laurenzo 
1056b56d1ec6SPeter Hawkins   static PyType_Slot slots[];
1057b56d1ec6SPeter Hawkins 
1058436c6c9cSStella Laurenzo private:
1059b56d1ec6SPeter Hawkins   static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
1060b56d1ec6SPeter Hawkins   static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
1061b56d1ec6SPeter Hawkins 
106271a25454SPeter Hawkins   static bool isUnsignedIntegerFormat(std::string_view format) {
1063436c6c9cSStella Laurenzo     if (format.empty())
1064436c6c9cSStella Laurenzo       return false;
1065436c6c9cSStella Laurenzo     char code = format[0];
1066436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
1067436c6c9cSStella Laurenzo            code == 'Q';
1068436c6c9cSStella Laurenzo   }
1069436c6c9cSStella Laurenzo 
107071a25454SPeter Hawkins   static bool isSignedIntegerFormat(std::string_view format) {
1071436c6c9cSStella Laurenzo     if (format.empty())
1072436c6c9cSStella Laurenzo       return false;
1073436c6c9cSStella Laurenzo     char code = format[0];
1074436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
1075436c6c9cSStella Laurenzo            code == 'q';
1076436c6c9cSStella Laurenzo   }
1077436c6c9cSStella Laurenzo 
10781824e45cSKasper Nielsen   static MlirType
10791824e45cSKasper Nielsen   getShapedType(std::optional<MlirType> bulkLoadElementType,
10801824e45cSKasper Nielsen                 std::optional<std::vector<int64_t>> explicitShape,
10811824e45cSKasper Nielsen                 Py_buffer &view) {
10821824e45cSKasper Nielsen     SmallVector<int64_t> shape;
10831824e45cSKasper Nielsen     if (explicitShape) {
10841824e45cSKasper Nielsen       shape.append(explicitShape->begin(), explicitShape->end());
10851824e45cSKasper Nielsen     } else {
10861824e45cSKasper Nielsen       shape.append(view.shape, view.shape + view.ndim);
10871824e45cSKasper Nielsen     }
10881824e45cSKasper Nielsen 
10891824e45cSKasper Nielsen     if (mlirTypeIsAShaped(*bulkLoadElementType)) {
10901824e45cSKasper Nielsen       if (explicitShape) {
10911824e45cSKasper Nielsen         throw std::invalid_argument("Shape can only be specified explicitly "
10921824e45cSKasper Nielsen                                     "when the type is not a shaped type.");
10931824e45cSKasper Nielsen       }
10941824e45cSKasper Nielsen       return *bulkLoadElementType;
10951824e45cSKasper Nielsen     } else {
10961824e45cSKasper Nielsen       MlirAttribute encodingAttr = mlirAttributeGetNull();
10971824e45cSKasper Nielsen       return mlirRankedTensorTypeGet(shape.size(), shape.data(),
10981824e45cSKasper Nielsen                                      *bulkLoadElementType, encodingAttr);
10991824e45cSKasper Nielsen     }
11001824e45cSKasper Nielsen   }
11011824e45cSKasper Nielsen 
11021824e45cSKasper Nielsen   static MlirAttribute getAttributeFromBuffer(
11031824e45cSKasper Nielsen       Py_buffer &view, bool signless, std::optional<PyType> explicitType,
11041824e45cSKasper Nielsen       std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
11051824e45cSKasper Nielsen     // Detect format codes that are suitable for bulk loading. This includes
11061824e45cSKasper Nielsen     // all byte aligned integer and floating point types up to 8 bytes.
11071824e45cSKasper Nielsen     // Notably, this excludes exotics types which do not have a direct
11081824e45cSKasper Nielsen     // representation in the buffer protocol (i.e. complex, etc).
11091824e45cSKasper Nielsen     std::optional<MlirType> bulkLoadElementType;
11101824e45cSKasper Nielsen     if (explicitType) {
11111824e45cSKasper Nielsen       bulkLoadElementType = *explicitType;
11121824e45cSKasper Nielsen     } else {
11131824e45cSKasper Nielsen       std::string_view format(view.format);
11141824e45cSKasper Nielsen       if (format == "f") {
11151824e45cSKasper Nielsen         // f32
11161824e45cSKasper Nielsen         assert(view.itemsize == 4 && "mismatched array itemsize");
11171824e45cSKasper Nielsen         bulkLoadElementType = mlirF32TypeGet(context);
11181824e45cSKasper Nielsen       } else if (format == "d") {
11191824e45cSKasper Nielsen         // f64
11201824e45cSKasper Nielsen         assert(view.itemsize == 8 && "mismatched array itemsize");
11211824e45cSKasper Nielsen         bulkLoadElementType = mlirF64TypeGet(context);
11221824e45cSKasper Nielsen       } else if (format == "e") {
11231824e45cSKasper Nielsen         // f16
11241824e45cSKasper Nielsen         assert(view.itemsize == 2 && "mismatched array itemsize");
11251824e45cSKasper Nielsen         bulkLoadElementType = mlirF16TypeGet(context);
11261824e45cSKasper Nielsen       } else if (format == "?") {
11271824e45cSKasper Nielsen         // i1
11281824e45cSKasper Nielsen         // The i1 type needs to be bit-packed, so we will handle it seperately
11291824e45cSKasper Nielsen         return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
11301824e45cSKasper Nielsen                                                       context);
11311824e45cSKasper Nielsen       } else if (isSignedIntegerFormat(format)) {
11321824e45cSKasper Nielsen         if (view.itemsize == 4) {
11331824e45cSKasper Nielsen           // i32
11341824e45cSKasper Nielsen           bulkLoadElementType = signless
11351824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 32)
11361824e45cSKasper Nielsen                                     : mlirIntegerTypeSignedGet(context, 32);
11371824e45cSKasper Nielsen         } else if (view.itemsize == 8) {
11381824e45cSKasper Nielsen           // i64
11391824e45cSKasper Nielsen           bulkLoadElementType = signless
11401824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 64)
11411824e45cSKasper Nielsen                                     : mlirIntegerTypeSignedGet(context, 64);
11421824e45cSKasper Nielsen         } else if (view.itemsize == 1) {
11431824e45cSKasper Nielsen           // i8
11441824e45cSKasper Nielsen           bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
11451824e45cSKasper Nielsen                                          : mlirIntegerTypeSignedGet(context, 8);
11461824e45cSKasper Nielsen         } else if (view.itemsize == 2) {
11471824e45cSKasper Nielsen           // i16
11481824e45cSKasper Nielsen           bulkLoadElementType = signless
11491824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 16)
11501824e45cSKasper Nielsen                                     : mlirIntegerTypeSignedGet(context, 16);
11511824e45cSKasper Nielsen         }
11521824e45cSKasper Nielsen       } else if (isUnsignedIntegerFormat(format)) {
11531824e45cSKasper Nielsen         if (view.itemsize == 4) {
11541824e45cSKasper Nielsen           // unsigned i32
11551824e45cSKasper Nielsen           bulkLoadElementType = signless
11561824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 32)
11571824e45cSKasper Nielsen                                     : mlirIntegerTypeUnsignedGet(context, 32);
11581824e45cSKasper Nielsen         } else if (view.itemsize == 8) {
11591824e45cSKasper Nielsen           // unsigned i64
11601824e45cSKasper Nielsen           bulkLoadElementType = signless
11611824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 64)
11621824e45cSKasper Nielsen                                     : mlirIntegerTypeUnsignedGet(context, 64);
11631824e45cSKasper Nielsen         } else if (view.itemsize == 1) {
11641824e45cSKasper Nielsen           // i8
11651824e45cSKasper Nielsen           bulkLoadElementType = signless
11661824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 8)
11671824e45cSKasper Nielsen                                     : mlirIntegerTypeUnsignedGet(context, 8);
11681824e45cSKasper Nielsen         } else if (view.itemsize == 2) {
11691824e45cSKasper Nielsen           // i16
11701824e45cSKasper Nielsen           bulkLoadElementType = signless
11711824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 16)
11721824e45cSKasper Nielsen                                     : mlirIntegerTypeUnsignedGet(context, 16);
11731824e45cSKasper Nielsen         }
11741824e45cSKasper Nielsen       }
11751824e45cSKasper Nielsen       if (!bulkLoadElementType) {
11761824e45cSKasper Nielsen         throw std::invalid_argument(
11771824e45cSKasper Nielsen             std::string("unimplemented array format conversion from format: ") +
11781824e45cSKasper Nielsen             std::string(format));
11791824e45cSKasper Nielsen       }
11801824e45cSKasper Nielsen     }
11811824e45cSKasper Nielsen 
11821824e45cSKasper Nielsen     MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
11831824e45cSKasper Nielsen     return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
11841824e45cSKasper Nielsen   }
11851824e45cSKasper Nielsen 
1186b56d1ec6SPeter Hawkins   // There is a complication for boolean numpy arrays, as numpy represents
1187b56d1ec6SPeter Hawkins   // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
1188b56d1ec6SPeter Hawkins   // booleans per byte.
11891824e45cSKasper Nielsen   static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
11901824e45cSKasper Nielsen       Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
11911824e45cSKasper Nielsen       MlirContext &context) {
11921824e45cSKasper Nielsen     if (llvm::endianness::native != llvm::endianness::little) {
1193b56d1ec6SPeter Hawkins       // Given we have no good way of testing the behavior on big-endian
1194b56d1ec6SPeter Hawkins       // systems we will throw
1195b56d1ec6SPeter Hawkins       throw nb::type_error("Constructing a bit-packed MLIR attribute is "
11961824e45cSKasper Nielsen                            "unsupported on big-endian systems");
11971824e45cSKasper Nielsen     }
1198b56d1ec6SPeter Hawkins     nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
1199b56d1ec6SPeter Hawkins         /*data=*/static_cast<uint8_t *>(view.buf),
1200b56d1ec6SPeter Hawkins         /*shape=*/{static_cast<size_t>(view.len)});
12011824e45cSKasper Nielsen 
1202b56d1ec6SPeter Hawkins     nb::module_ numpy = nb::module_::import_("numpy");
1203b56d1ec6SPeter Hawkins     nb::object packbitsFunc = numpy.attr("packbits");
1204b56d1ec6SPeter Hawkins     nb::object packedBooleans =
1205b56d1ec6SPeter Hawkins         packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
1206b56d1ec6SPeter Hawkins     nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
12071824e45cSKasper Nielsen 
12081824e45cSKasper Nielsen     MlirType bitpackedType =
12091824e45cSKasper Nielsen         getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
12101824e45cSKasper Nielsen     assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
12111824e45cSKasper Nielsen     // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
12121824e45cSKasper Nielsen     // packedBooleans, hence the MlirAttribute will remain valid even when
12131824e45cSKasper Nielsen     // packedBooleans get reclaimed by the end of the function.
12141824e45cSKasper Nielsen     return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
12151824e45cSKasper Nielsen                                              pythonBuffer.ptr);
12161824e45cSKasper Nielsen   }
12171824e45cSKasper Nielsen 
12181824e45cSKasper Nielsen   // This does the opposite transformation of
12191824e45cSKasper Nielsen   // `getBitpackedAttributeFromBooleanBuffer`
1220b56d1ec6SPeter Hawkins   std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() {
12211824e45cSKasper Nielsen     if (llvm::endianness::native != llvm::endianness::little) {
1222b56d1ec6SPeter Hawkins       // Given we have no good way of testing the behavior on big-endian
1223b56d1ec6SPeter Hawkins       // systems we will throw
1224b56d1ec6SPeter Hawkins       throw nb::type_error("Constructing a numpy array from a MLIR attribute "
12251824e45cSKasper Nielsen                            "is unsupported on big-endian systems");
12261824e45cSKasper Nielsen     }
12271824e45cSKasper Nielsen 
12281824e45cSKasper Nielsen     int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
12291824e45cSKasper Nielsen     int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
12301824e45cSKasper Nielsen     uint8_t *bitpackedData = static_cast<uint8_t *>(
12311824e45cSKasper Nielsen         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1232b56d1ec6SPeter Hawkins     nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
1233b56d1ec6SPeter Hawkins         /*data=*/bitpackedData,
1234b56d1ec6SPeter Hawkins         /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
12351824e45cSKasper Nielsen 
1236b56d1ec6SPeter Hawkins     nb::module_ numpy = nb::module_::import_("numpy");
1237b56d1ec6SPeter Hawkins     nb::object unpackbitsFunc = numpy.attr("unpackbits");
1238b56d1ec6SPeter Hawkins     nb::object equalFunc = numpy.attr("equal");
1239b56d1ec6SPeter Hawkins     nb::object reshapeFunc = numpy.attr("reshape");
1240b56d1ec6SPeter Hawkins     nb::object unpackedBooleans =
1241b56d1ec6SPeter Hawkins         unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
12421824e45cSKasper Nielsen 
12431824e45cSKasper Nielsen     // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
12441824e45cSKasper Nielsen     // We need to:
12451824e45cSKasper Nielsen     //   1. Slice away the padded bits
12461824e45cSKasper Nielsen     //   2. Make the boolean array have the correct shape
12471824e45cSKasper Nielsen     //   3. Convert the array to a boolean array
1248b56d1ec6SPeter Hawkins     unpackedBooleans = unpackedBooleans[nb::slice(
1249b56d1ec6SPeter Hawkins         nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
12501824e45cSKasper Nielsen     unpackedBooleans = equalFunc(unpackedBooleans, 1);
12511824e45cSKasper Nielsen 
12521824e45cSKasper Nielsen     MlirType shapedType = mlirAttributeGetType(*this);
12531824e45cSKasper Nielsen     intptr_t rank = mlirShapedTypeGetRank(shapedType);
1254404d0e99SAdrian Kuegel     std::vector<intptr_t> shape(rank);
12551824e45cSKasper Nielsen     for (intptr_t i = 0; i < rank; ++i) {
1256404d0e99SAdrian Kuegel       shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
12571824e45cSKasper Nielsen     }
12581824e45cSKasper Nielsen     unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
12591824e45cSKasper Nielsen 
1260b56d1ec6SPeter Hawkins     // Make sure the returned nb::buffer_view claims ownership of the data in
12611824e45cSKasper Nielsen     // `pythonBuffer` so it remains valid when Python reads it
1262b56d1ec6SPeter Hawkins     nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
1263b56d1ec6SPeter Hawkins     return std::make_unique<nb_buffer_info>(pythonBuffer.request());
12641824e45cSKasper Nielsen   }
12651824e45cSKasper Nielsen 
1266436c6c9cSStella Laurenzo   template <typename Type>
1267b56d1ec6SPeter Hawkins   std::unique_ptr<nb_buffer_info>
1268b56d1ec6SPeter Hawkins   bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
12690a68171bSDmitri Gribenko     intptr_t rank = mlirShapedTypeGetRank(shapedType);
1270436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
12710a68171bSDmitri Gribenko     // Buffer is configured for read-only access below.
1272436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
1273436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1274436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
1275436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
1276436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
1277436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1278436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
1279436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
1280f0e847d0SRahul Kayaith     if (mlirDenseElementsAttrIsSplat(*this)) {
1281f0e847d0SRahul Kayaith       // Splats are special, only the single value is stored.
1282f0e847d0SRahul Kayaith       strides.assign(rank, 0);
1283f0e847d0SRahul Kayaith     } else {
1284436c6c9cSStella Laurenzo       for (intptr_t i = 1; i < rank; ++i) {
1285f0e847d0SRahul Kayaith         intptr_t strideFactor = 1;
1286f0e847d0SRahul Kayaith         for (intptr_t j = i; j < rank; ++j)
1287436c6c9cSStella Laurenzo           strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1288436c6c9cSStella Laurenzo         strides.push_back(sizeof(Type) * strideFactor);
1289436c6c9cSStella Laurenzo       }
1290436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type));
1291f0e847d0SRahul Kayaith     }
1292b56d1ec6SPeter Hawkins     const char *format;
12935d6d30edSStella Laurenzo     if (explicitFormat) {
12945d6d30edSStella Laurenzo       format = explicitFormat;
12955d6d30edSStella Laurenzo     } else {
1296b56d1ec6SPeter Hawkins       format = nb_format_descriptor<Type>::format();
12975d6d30edSStella Laurenzo     }
1298b56d1ec6SPeter Hawkins     return std::make_unique<nb_buffer_info>(
1299b56d1ec6SPeter Hawkins         data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
13005d6d30edSStella Laurenzo         /*readonly=*/true);
1301436c6c9cSStella Laurenzo   }
1302436c6c9cSStella Laurenzo }; // namespace
1303436c6c9cSStella Laurenzo 
1304b56d1ec6SPeter Hawkins PyType_Slot PyDenseElementsAttribute::slots[] = {
1305b56d1ec6SPeter Hawkins // Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
1306b56d1ec6SPeter Hawkins #if PY_VERSION_HEX >= 0x03090000
1307b56d1ec6SPeter Hawkins     {Py_bf_getbuffer,
1308b56d1ec6SPeter Hawkins      reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)},
1309b56d1ec6SPeter Hawkins     {Py_bf_releasebuffer,
1310b56d1ec6SPeter Hawkins      reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)},
1311b56d1ec6SPeter Hawkins #endif
1312b56d1ec6SPeter Hawkins     {0, nullptr},
1313b56d1ec6SPeter Hawkins };
1314b56d1ec6SPeter Hawkins 
1315b56d1ec6SPeter Hawkins /*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj,
1316b56d1ec6SPeter Hawkins                                                       Py_buffer *view,
1317b56d1ec6SPeter Hawkins                                                       int flags) {
1318b56d1ec6SPeter Hawkins   view->obj = nullptr;
1319b56d1ec6SPeter Hawkins   std::unique_ptr<nb_buffer_info> info;
1320b56d1ec6SPeter Hawkins   try {
1321b56d1ec6SPeter Hawkins     auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj));
1322b56d1ec6SPeter Hawkins     info = attr->accessBuffer();
1323b56d1ec6SPeter Hawkins   } catch (nb::python_error &e) {
1324b56d1ec6SPeter Hawkins     e.restore();
1325b56d1ec6SPeter Hawkins     nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer");
1326b56d1ec6SPeter Hawkins     return -1;
1327b56d1ec6SPeter Hawkins   }
1328b56d1ec6SPeter Hawkins   view->obj = obj;
1329b56d1ec6SPeter Hawkins   view->ndim = 1;
1330b56d1ec6SPeter Hawkins   view->buf = info->ptr;
1331b56d1ec6SPeter Hawkins   view->itemsize = info->itemsize;
1332b56d1ec6SPeter Hawkins   view->len = info->itemsize;
1333b56d1ec6SPeter Hawkins   for (auto s : info->shape) {
1334b56d1ec6SPeter Hawkins     view->len *= s;
1335b56d1ec6SPeter Hawkins   }
1336b56d1ec6SPeter Hawkins   view->readonly = info->readonly;
1337b56d1ec6SPeter Hawkins   if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
1338b56d1ec6SPeter Hawkins     view->format = const_cast<char *>(info->format);
1339b56d1ec6SPeter Hawkins   }
1340b56d1ec6SPeter Hawkins   if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
1341b56d1ec6SPeter Hawkins     view->ndim = static_cast<int>(info->ndim);
1342b56d1ec6SPeter Hawkins     view->strides = info->strides.data();
1343b56d1ec6SPeter Hawkins     view->shape = info->shape.data();
1344b56d1ec6SPeter Hawkins   }
1345b56d1ec6SPeter Hawkins   view->suboffsets = nullptr;
1346b56d1ec6SPeter Hawkins   view->internal = info.release();
1347b56d1ec6SPeter Hawkins   Py_INCREF(obj);
1348b56d1ec6SPeter Hawkins   return 0;
1349b56d1ec6SPeter Hawkins }
1350b56d1ec6SPeter Hawkins 
1351b56d1ec6SPeter Hawkins /*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *,
1352b56d1ec6SPeter Hawkins                                                            Py_buffer *view) {
1353b56d1ec6SPeter Hawkins   delete reinterpret_cast<nb_buffer_info *>(view->internal);
1354b56d1ec6SPeter Hawkins }
1355b56d1ec6SPeter Hawkins 
1356b56d1ec6SPeter Hawkins /// Refinement of the PyDenseElementsAttribute for attributes containing
1357b56d1ec6SPeter Hawkins /// integer (and boolean) values. Supports element access.
1358436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
1359436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1360436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1361436c6c9cSStella Laurenzo public:
1362436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1363436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
1364436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1365436c6c9cSStella Laurenzo 
1366b56d1ec6SPeter Hawkins   /// Returns the element at the given linear position. Asserts if the index
1367b56d1ec6SPeter Hawkins   /// is out of range.
1368b56d1ec6SPeter Hawkins   nb::object dunderGetItem(intptr_t pos) {
1369436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
1370b56d1ec6SPeter Hawkins       throw nb::index_error("attempt to access out of bounds element");
1371436c6c9cSStella Laurenzo     }
1372436c6c9cSStella Laurenzo 
1373436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1374436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1375*5d3ae516SMatthias Gehre     // Index type can also appear as a DenseIntElementsAttr and therefore can be
1376*5d3ae516SMatthias Gehre     // casted to integer.
1377*5d3ae516SMatthias Gehre     assert(mlirTypeIsAInteger(type) ||
1378*5d3ae516SMatthias Gehre            mlirTypeIsAIndex(type) && "expected integer/index element type in "
1379*5d3ae516SMatthias Gehre                                      "dense int elements attribute");
1380436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1381b56d1ec6SPeter Hawkins     // elemental type of the attribute. nb::int_ is implicitly constructible
1382436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
1383436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1384436c6c9cSStella Laurenzo     // querying them on each element access.
1385*5d3ae516SMatthias Gehre     if (mlirTypeIsAIndex(type)) {
1386*5d3ae516SMatthias Gehre       return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
1387*5d3ae516SMatthias Gehre     }
1388436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
1389436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1390436c6c9cSStella Laurenzo     if (isUnsigned) {
1391436c6c9cSStella Laurenzo       if (width == 1) {
1392b56d1ec6SPeter Hawkins         return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1393436c6c9cSStella Laurenzo       }
1394308d8b8cSRahul Kayaith       if (width == 8) {
1395b56d1ec6SPeter Hawkins         return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
1396308d8b8cSRahul Kayaith       }
1397308d8b8cSRahul Kayaith       if (width == 16) {
1398b56d1ec6SPeter Hawkins         return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
1399308d8b8cSRahul Kayaith       }
1400436c6c9cSStella Laurenzo       if (width == 32) {
1401b56d1ec6SPeter Hawkins         return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
1402436c6c9cSStella Laurenzo       }
1403436c6c9cSStella Laurenzo       if (width == 64) {
1404b56d1ec6SPeter Hawkins         return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
1405436c6c9cSStella Laurenzo       }
1406436c6c9cSStella Laurenzo     } else {
1407436c6c9cSStella Laurenzo       if (width == 1) {
1408b56d1ec6SPeter Hawkins         return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1409436c6c9cSStella Laurenzo       }
1410308d8b8cSRahul Kayaith       if (width == 8) {
1411b56d1ec6SPeter Hawkins         return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
1412308d8b8cSRahul Kayaith       }
1413308d8b8cSRahul Kayaith       if (width == 16) {
1414b56d1ec6SPeter Hawkins         return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
1415308d8b8cSRahul Kayaith       }
1416436c6c9cSStella Laurenzo       if (width == 32) {
1417b56d1ec6SPeter Hawkins         return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
1418436c6c9cSStella Laurenzo       }
1419436c6c9cSStella Laurenzo       if (width == 64) {
1420b56d1ec6SPeter Hawkins         return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
1421436c6c9cSStella Laurenzo       }
1422436c6c9cSStella Laurenzo     }
1423b56d1ec6SPeter Hawkins     throw nb::type_error("Unsupported integer type");
1424436c6c9cSStella Laurenzo   }
1425436c6c9cSStella Laurenzo 
1426436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1427436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1428436c6c9cSStella Laurenzo   }
1429436c6c9cSStella Laurenzo };
1430436c6c9cSStella Laurenzo 
1431f66cd9e9SStella Laurenzo class PyDenseResourceElementsAttribute
1432f66cd9e9SStella Laurenzo     : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1433f66cd9e9SStella Laurenzo public:
1434f66cd9e9SStella Laurenzo   static constexpr IsAFunctionTy isaFunction =
1435f66cd9e9SStella Laurenzo       mlirAttributeIsADenseResourceElements;
1436f66cd9e9SStella Laurenzo   static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1437f66cd9e9SStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1438f66cd9e9SStella Laurenzo 
1439f66cd9e9SStella Laurenzo   static PyDenseResourceElementsAttribute
1440b56d1ec6SPeter Hawkins   getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type,
1441f66cd9e9SStella Laurenzo                 std::optional<size_t> alignment, bool isMutable,
1442f66cd9e9SStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
1443f66cd9e9SStella Laurenzo     if (!mlirTypeIsAShaped(type)) {
1444f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1445f66cd9e9SStella Laurenzo           "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1446f66cd9e9SStella Laurenzo     }
1447f66cd9e9SStella Laurenzo 
1448f66cd9e9SStella Laurenzo     // Do not request any conversions as we must ensure to use caller
1449f66cd9e9SStella Laurenzo     // managed memory.
1450f66cd9e9SStella Laurenzo     int flags = PyBUF_STRIDES;
1451f66cd9e9SStella Laurenzo     std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1452f66cd9e9SStella Laurenzo     if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1453b56d1ec6SPeter Hawkins       throw nb::python_error();
1454f66cd9e9SStella Laurenzo     }
1455f66cd9e9SStella Laurenzo 
1456f66cd9e9SStella Laurenzo     // This scope releaser will only release if we haven't yet transferred
1457f66cd9e9SStella Laurenzo     // ownership.
1458f66cd9e9SStella Laurenzo     auto freeBuffer = llvm::make_scope_exit([&]() {
1459f66cd9e9SStella Laurenzo       if (view)
1460f66cd9e9SStella Laurenzo         PyBuffer_Release(view.get());
1461f66cd9e9SStella Laurenzo     });
1462f66cd9e9SStella Laurenzo 
1463f66cd9e9SStella Laurenzo     if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1464f66cd9e9SStella Laurenzo       throw std::invalid_argument("Contiguous buffer is required.");
1465f66cd9e9SStella Laurenzo     }
1466f66cd9e9SStella Laurenzo 
1467f66cd9e9SStella Laurenzo     // Infer alignment to be the stride of one element if not explicit.
1468f66cd9e9SStella Laurenzo     size_t inferredAlignment;
1469f66cd9e9SStella Laurenzo     if (alignment)
1470f66cd9e9SStella Laurenzo       inferredAlignment = *alignment;
1471f66cd9e9SStella Laurenzo     else
1472f66cd9e9SStella Laurenzo       inferredAlignment = view->strides[view->ndim - 1];
1473f66cd9e9SStella Laurenzo 
1474f66cd9e9SStella Laurenzo     // The userData is a Py_buffer* that the deleter owns.
1475f66cd9e9SStella Laurenzo     auto deleter = [](void *userData, const void *data, size_t size,
1476f66cd9e9SStella Laurenzo                       size_t align) {
147728507ac6SFabian Tschopp       if (!Py_IsInitialized())
147828507ac6SFabian Tschopp         Py_Initialize();
1479f66cd9e9SStella Laurenzo       Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
148028507ac6SFabian Tschopp       nb::gil_scoped_acquire gil;
1481f66cd9e9SStella Laurenzo       PyBuffer_Release(ownedView);
1482f66cd9e9SStella Laurenzo       delete ownedView;
1483f66cd9e9SStella Laurenzo     };
1484f66cd9e9SStella Laurenzo 
1485f66cd9e9SStella Laurenzo     size_t rawBufferSize = view->len;
1486f66cd9e9SStella Laurenzo     MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1487f66cd9e9SStella Laurenzo         type, toMlirStringRef(name), view->buf, rawBufferSize,
1488f66cd9e9SStella Laurenzo         inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1489f66cd9e9SStella Laurenzo     if (mlirAttributeIsNull(attr)) {
1490f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1491f66cd9e9SStella Laurenzo           "DenseResourceElementsAttr could not be constructed from the given "
1492f66cd9e9SStella Laurenzo           "buffer. "
1493f66cd9e9SStella Laurenzo           "This may mean that the Python buffer layout does not match that "
1494f66cd9e9SStella Laurenzo           "MLIR expected layout and is a bug.");
1495f66cd9e9SStella Laurenzo     }
1496f66cd9e9SStella Laurenzo     view.release();
1497f66cd9e9SStella Laurenzo     return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1498f66cd9e9SStella Laurenzo   }
1499f66cd9e9SStella Laurenzo 
1500f66cd9e9SStella Laurenzo   static void bindDerived(ClassTy &c) {
1501b56d1ec6SPeter Hawkins     c.def_static(
1502b56d1ec6SPeter Hawkins         "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
1503b56d1ec6SPeter Hawkins         nb::arg("array"), nb::arg("name"), nb::arg("type"),
1504b56d1ec6SPeter Hawkins         nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
1505b56d1ec6SPeter Hawkins         nb::arg("context").none() = nb::none(),
1506f66cd9e9SStella Laurenzo         kDenseResourceElementsAttrGetFromBufferDocstring);
1507f66cd9e9SStella Laurenzo   }
1508f66cd9e9SStella Laurenzo };
1509f66cd9e9SStella Laurenzo 
1510436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1511436c6c9cSStella Laurenzo public:
1512436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1513436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
1514436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
15159566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
15169566ee28Smax       mlirDictionaryAttrGetTypeID;
1517436c6c9cSStella Laurenzo 
1518436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1519436c6c9cSStella Laurenzo 
15209fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
15219fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
15229fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
15239fb1086bSAdrian Kuegel   }
15249fb1086bSAdrian Kuegel 
1525436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
15269fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
1527436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
1528436c6c9cSStella Laurenzo     c.def_static(
1529436c6c9cSStella Laurenzo         "get",
1530b56d1ec6SPeter Hawkins         [](nb::dict attributes, DefaultingPyMlirContext context) {
1531436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1532436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
1533b56d1ec6SPeter Hawkins           for (std::pair<nb::handle, nb::handle> it : attributes) {
1534b56d1ec6SPeter Hawkins             auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
1535b56d1ec6SPeter Hawkins             auto name = nb::cast<std::string>(it.first);
1536436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
153702b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1538436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
153902b6fb21SMehdi Amini                 mlirAttr));
1540436c6c9cSStella Laurenzo           }
1541436c6c9cSStella Laurenzo           MlirAttribute attr =
1542436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1543436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1544436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
1545436c6c9cSStella Laurenzo         },
1546b56d1ec6SPeter Hawkins         nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(),
1547436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
1548436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1549436c6c9cSStella Laurenzo       MlirAttribute attr =
1550436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1551974c1596SRahul Kayaith       if (mlirAttributeIsNull(attr))
1552b56d1ec6SPeter Hawkins         throw nb::key_error("attempt to access a non-existent attribute");
1553974c1596SRahul Kayaith       return attr;
1554436c6c9cSStella Laurenzo     });
1555436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1556436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
1557b56d1ec6SPeter Hawkins         throw nb::index_error("attempt to access out of bounds attribute");
1558436c6c9cSStella Laurenzo       }
1559436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1560436c6c9cSStella Laurenzo       return PyNamedAttribute(
1561436c6c9cSStella Laurenzo           namedAttr.attribute,
1562436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
1563436c6c9cSStella Laurenzo     });
1564436c6c9cSStella Laurenzo   }
1565436c6c9cSStella Laurenzo };
1566436c6c9cSStella Laurenzo 
1567436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
1568436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
1569436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
1570436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1571436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1572436c6c9cSStella Laurenzo public:
1573436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1574436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
1575436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1576436c6c9cSStella Laurenzo 
1577b56d1ec6SPeter Hawkins   nb::float_ dunderGetItem(intptr_t pos) {
1578436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
1579b56d1ec6SPeter Hawkins       throw nb::index_error("attempt to access out of bounds element");
1580436c6c9cSStella Laurenzo     }
1581436c6c9cSStella Laurenzo 
1582436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1583436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1584436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1585b56d1ec6SPeter Hawkins     // elemental type of the attribute. nb::float_ is implicitly constructible
1586436c6c9cSStella Laurenzo     // from float and double.
1587436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1588436c6c9cSStella Laurenzo     // querying them on each element access.
1589436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
1590b56d1ec6SPeter Hawkins       return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
1591436c6c9cSStella Laurenzo     }
1592436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
1593b56d1ec6SPeter Hawkins       return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
1594436c6c9cSStella Laurenzo     }
1595b56d1ec6SPeter Hawkins     throw nb::type_error("Unsupported floating-point type");
1596436c6c9cSStella Laurenzo   }
1597436c6c9cSStella Laurenzo 
1598436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1599436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1600436c6c9cSStella Laurenzo   }
1601436c6c9cSStella Laurenzo };
1602436c6c9cSStella Laurenzo 
1603436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1604436c6c9cSStella Laurenzo public:
1605436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1606436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
1607436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
16089566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
16099566ee28Smax       mlirTypeAttrGetTypeID;
1610436c6c9cSStella Laurenzo 
1611436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1612436c6c9cSStella Laurenzo     c.def_static(
1613436c6c9cSStella Laurenzo         "get",
1614436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1615436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1616436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1617436c6c9cSStella Laurenzo         },
1618b56d1ec6SPeter Hawkins         nb::arg("value"), nb::arg("context").none() = nb::none(),
1619436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1620b56d1ec6SPeter Hawkins     c.def_prop_ro("value", [](PyTypeAttribute &self) {
1621bfb1ba75Smax       return mlirTypeAttrGetValue(self.get());
1622436c6c9cSStella Laurenzo     });
1623436c6c9cSStella Laurenzo   }
1624436c6c9cSStella Laurenzo };
1625436c6c9cSStella Laurenzo 
1626436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1627436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1628436c6c9cSStella Laurenzo public:
1629436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1630436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1631436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
16329566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
16339566ee28Smax       mlirUnitAttrGetTypeID;
1634436c6c9cSStella Laurenzo 
1635436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1636436c6c9cSStella Laurenzo     c.def_static(
1637436c6c9cSStella Laurenzo         "get",
1638436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1639436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1640436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1641436c6c9cSStella Laurenzo         },
1642b56d1ec6SPeter Hawkins         nb::arg("context").none() = nb::none(), "Create a Unit attribute.");
1643436c6c9cSStella Laurenzo   }
1644436c6c9cSStella Laurenzo };
1645436c6c9cSStella Laurenzo 
1646ac2e2d65SDenys Shabalin /// Strided layout attribute subclass.
1647ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute
1648ac2e2d65SDenys Shabalin     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1649ac2e2d65SDenys Shabalin public:
1650ac2e2d65SDenys Shabalin   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1651ac2e2d65SDenys Shabalin   static constexpr const char *pyClassName = "StridedLayoutAttr";
1652ac2e2d65SDenys Shabalin   using PyConcreteAttribute::PyConcreteAttribute;
16539566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
16549566ee28Smax       mlirStridedLayoutAttrGetTypeID;
1655ac2e2d65SDenys Shabalin 
1656ac2e2d65SDenys Shabalin   static void bindDerived(ClassTy &c) {
1657ac2e2d65SDenys Shabalin     c.def_static(
1658ac2e2d65SDenys Shabalin         "get",
1659ac2e2d65SDenys Shabalin         [](int64_t offset, const std::vector<int64_t> strides,
1660ac2e2d65SDenys Shabalin            DefaultingPyMlirContext ctx) {
1661ac2e2d65SDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1662ac2e2d65SDenys Shabalin               ctx->get(), offset, strides.size(), strides.data());
1663ac2e2d65SDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1664ac2e2d65SDenys Shabalin         },
1665b56d1ec6SPeter Hawkins         nb::arg("offset"), nb::arg("strides"),
1666b56d1ec6SPeter Hawkins         nb::arg("context").none() = nb::none(),
1667ac2e2d65SDenys Shabalin         "Gets a strided layout attribute.");
1668e3fd612eSDenys Shabalin     c.def_static(
1669e3fd612eSDenys Shabalin         "get_fully_dynamic",
1670e3fd612eSDenys Shabalin         [](int64_t rank, DefaultingPyMlirContext ctx) {
1671e3fd612eSDenys Shabalin           auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1672e3fd612eSDenys Shabalin           std::vector<int64_t> strides(rank);
1673e3fd612eSDenys Shabalin           std::fill(strides.begin(), strides.end(), dynamic);
1674e3fd612eSDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1675e3fd612eSDenys Shabalin               ctx->get(), dynamic, strides.size(), strides.data());
1676e3fd612eSDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1677e3fd612eSDenys Shabalin         },
1678b56d1ec6SPeter Hawkins         nb::arg("rank"), nb::arg("context").none() = nb::none(),
1679b56d1ec6SPeter Hawkins         "Gets a strided layout attribute with dynamic offset and strides of "
1680b56d1ec6SPeter Hawkins         "a "
1681e3fd612eSDenys Shabalin         "given rank.");
1682b56d1ec6SPeter Hawkins     c.def_prop_ro(
1683ac2e2d65SDenys Shabalin         "offset",
1684ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1685ac2e2d65SDenys Shabalin           return mlirStridedLayoutAttrGetOffset(self);
1686ac2e2d65SDenys Shabalin         },
1687ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1688b56d1ec6SPeter Hawkins     c.def_prop_ro(
1689ac2e2d65SDenys Shabalin         "strides",
1690ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1691ac2e2d65SDenys Shabalin           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1692ac2e2d65SDenys Shabalin           std::vector<int64_t> strides(size);
1693ac2e2d65SDenys Shabalin           for (intptr_t i = 0; i < size; i++) {
1694ac2e2d65SDenys Shabalin             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1695ac2e2d65SDenys Shabalin           }
1696ac2e2d65SDenys Shabalin           return strides;
1697ac2e2d65SDenys Shabalin         },
1698ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1699ac2e2d65SDenys Shabalin   }
1700ac2e2d65SDenys Shabalin };
1701ac2e2d65SDenys Shabalin 
1702b56d1ec6SPeter Hawkins nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
17039566ee28Smax   if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
1704b56d1ec6SPeter Hawkins     return nb::cast(PyDenseBoolArrayAttribute(pyAttribute));
17059566ee28Smax   if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
1706b56d1ec6SPeter Hawkins     return nb::cast(PyDenseI8ArrayAttribute(pyAttribute));
17079566ee28Smax   if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
1708b56d1ec6SPeter Hawkins     return nb::cast(PyDenseI16ArrayAttribute(pyAttribute));
17099566ee28Smax   if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
1710b56d1ec6SPeter Hawkins     return nb::cast(PyDenseI32ArrayAttribute(pyAttribute));
17119566ee28Smax   if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
1712b56d1ec6SPeter Hawkins     return nb::cast(PyDenseI64ArrayAttribute(pyAttribute));
17139566ee28Smax   if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
1714b56d1ec6SPeter Hawkins     return nb::cast(PyDenseF32ArrayAttribute(pyAttribute));
17159566ee28Smax   if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
1716b56d1ec6SPeter Hawkins     return nb::cast(PyDenseF64ArrayAttribute(pyAttribute));
17179566ee28Smax   std::string msg =
17189566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
1719b56d1ec6SPeter Hawkins       nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1720b56d1ec6SPeter Hawkins   throw nb::type_error(msg.c_str());
17219566ee28Smax }
17229566ee28Smax 
1723b56d1ec6SPeter Hawkins nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
17249566ee28Smax   if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
1725b56d1ec6SPeter Hawkins     return nb::cast(PyDenseFPElementsAttribute(pyAttribute));
17269566ee28Smax   if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
1727b56d1ec6SPeter Hawkins     return nb::cast(PyDenseIntElementsAttribute(pyAttribute));
17289566ee28Smax   std::string msg =
17299566ee28Smax       std::string(
17309566ee28Smax           "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
1731b56d1ec6SPeter Hawkins       nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1732b56d1ec6SPeter Hawkins   throw nb::type_error(msg.c_str());
17339566ee28Smax }
17349566ee28Smax 
1735b56d1ec6SPeter Hawkins nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
17369566ee28Smax   if (PyBoolAttribute::isaFunction(pyAttribute))
1737b56d1ec6SPeter Hawkins     return nb::cast(PyBoolAttribute(pyAttribute));
17389566ee28Smax   if (PyIntegerAttribute::isaFunction(pyAttribute))
1739b56d1ec6SPeter Hawkins     return nb::cast(PyIntegerAttribute(pyAttribute));
17409566ee28Smax   std::string msg =
17419566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
1742b56d1ec6SPeter Hawkins       nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1743b56d1ec6SPeter Hawkins   throw nb::type_error(msg.c_str());
17449566ee28Smax }
17459566ee28Smax 
1746b56d1ec6SPeter Hawkins nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
17474eee9ef9Smax   if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
1748b56d1ec6SPeter Hawkins     return nb::cast(PyFlatSymbolRefAttribute(pyAttribute));
17494eee9ef9Smax   if (PySymbolRefAttribute::isaFunction(pyAttribute))
1750b56d1ec6SPeter Hawkins     return nb::cast(PySymbolRefAttribute(pyAttribute));
17514eee9ef9Smax   std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
1752b56d1ec6SPeter Hawkins                     nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
1753b56d1ec6SPeter Hawkins                     ")";
1754b56d1ec6SPeter Hawkins   throw nb::type_error(msg.c_str());
17554eee9ef9Smax }
17564eee9ef9Smax 
1757436c6c9cSStella Laurenzo } // namespace
1758436c6c9cSStella Laurenzo 
1759b56d1ec6SPeter Hawkins void mlir::python::populateIRAttributes(nb::module_ &m) {
1760436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1761619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1762619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1763619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1764619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1765619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1766619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1767619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1768619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1769619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1770619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1771619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1772619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1773619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1774619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
17759566ee28Smax   PyGlobals::get().registerTypeCaster(
17769566ee28Smax       mlirDenseArrayAttrGetTypeID(),
1777b56d1ec6SPeter Hawkins       nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster)));
1778619fd8c2SJeff Niu 
1779436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1780436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1781436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1782b56d1ec6SPeter Hawkins   PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots);
1783436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1784436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
17859566ee28Smax   PyGlobals::get().registerTypeCaster(
17869566ee28Smax       mlirDenseIntOrFPElementsAttrGetTypeID(),
1787b56d1ec6SPeter Hawkins       nb::cast<nb::callable>(
1788b56d1ec6SPeter Hawkins           nb::cpp_function(denseIntOrFPElementsAttributeCaster)));
1789f66cd9e9SStella Laurenzo   PyDenseResourceElementsAttribute::bind(m);
17909566ee28Smax 
1791436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
17924eee9ef9Smax   PySymbolRefAttribute::bind(m);
17934eee9ef9Smax   PyGlobals::get().registerTypeCaster(
17944eee9ef9Smax       mlirSymbolRefAttrGetTypeID(),
1795b56d1ec6SPeter Hawkins       nb::cast<nb::callable>(
1796b56d1ec6SPeter Hawkins           nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)));
17974eee9ef9Smax 
1798436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
17995c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1800436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1801436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1802334873feSAmy Wang   PyIntegerSetAttribute::bind(m);
1803436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1804436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
18059566ee28Smax   PyGlobals::get().registerTypeCaster(
18069566ee28Smax       mlirIntegerAttrGetTypeID(),
1807b56d1ec6SPeter Hawkins       nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster)));
1808436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1809ac2e2d65SDenys Shabalin 
1810ac2e2d65SDenys Shabalin   PyStridedLayoutAttribute::bind(m);
1811436c6c9cSStella Laurenzo }
1812