xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision 5d3ae5161210c068d01ffba36c8e0761e9971179)
1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cstdint>
10 #include <optional>
11 #include <string>
12 #include <string_view>
13 #include <utility>
14 
15 #include "IRModule.h"
16 #include "NanobindUtils.h"
17 #include "mlir-c/BuiltinAttributes.h"
18 #include "mlir-c/BuiltinTypes.h"
19 #include "mlir/Bindings/Python/NanobindAdaptors.h"
20 #include "mlir/Bindings/Python/Nanobind.h"
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/Support/raw_ostream.h"
23 
24 namespace nb = nanobind;
25 using namespace nanobind::literals;
26 using namespace mlir;
27 using namespace mlir::python;
28 
29 using llvm::SmallVector;
30 
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34 
// Python docstring attached to DenseElementsAttr.get (buffer/array overload).
// The text is user-visible via help(); keep it in sync with the binding's
// argument names (array, signless, type, shape, context).
static constexpr char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
76 
// Python docstring attached to DenseElementsAttr.get (list-of-attributes
// overload). The Raises section refers to the `type` argument, which is the
// name listed under Args.
static const char kDenseElementsAttrGetFromListDocstring[] =
    R"(Gets a DenseElementsAttr from a Python list of attributes.

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
  attrs: A list of attributes.
  type: The desired shape and type of the resulting DenseElementsAttr.
    If not provided, the element type is determined based on the type
    of the 0th attribute and the shape is `[len(attrs)]`.
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the attributes does not match the type
    specified by `type`.
)";
97 
// Python docstring attached to the DenseResourceElementsAttr buffer factory.
// User-visible via help(); describes retention of the backing buffer.
static constexpr char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
  buffer: The array or buffer to convert.
  name: Name to provide to the resource (may be changed upon collision).
  type: The explicit ShapedType to construct the attribute with.
  context: Explicit context, if not from context manager.

Returns:
  DenseResourceElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
123 
124 namespace {
125 
126 struct nb_buffer_info {
127   void *ptr = nullptr;
128   ssize_t itemsize = 0;
129   ssize_t size = 0;
130   const char *format = nullptr;
131   ssize_t ndim = 0;
132   SmallVector<ssize_t, 4> shape;
133   SmallVector<ssize_t, 4> strides;
134   bool readonly = false;
135 
136   nb_buffer_info(
137       void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
138       SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
139       bool readonly = false,
140       std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
141           std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
142       : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
143         shape(std::move(shape_in)), strides(std::move(strides_in)),
144         readonly(readonly), owned_view(std::move(owned_view_in)) {
145     size = 1;
146     for (ssize_t i = 0; i < ndim; ++i) {
147       size *= shape[i];
148     }
149   }
150 
151   explicit nb_buffer_info(Py_buffer *view)
152       : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
153                        {view->shape, view->shape + view->ndim},
154                        // TODO(phawkins): check for null strides
155                        {view->strides, view->strides + view->ndim},
156                        view->readonly != 0,
157                        std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
158                            view, PyBuffer_Release)) {}
159 
160   nb_buffer_info(const nb_buffer_info &) = delete;
161   nb_buffer_info(nb_buffer_info &&) = default;
162   nb_buffer_info &operator=(const nb_buffer_info &) = delete;
163   nb_buffer_info &operator=(nb_buffer_info &&) = default;
164 
165 private:
166   std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
167 };
168 
169 class nb_buffer : public nb::object {
170   NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer);
171 
172   nb_buffer_info request() const {
173     int flags = PyBUF_STRIDES | PyBUF_FORMAT;
174     auto *view = new Py_buffer();
175     if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
176       delete view;
177       throw nb::python_error();
178     }
179     return nb_buffer_info(view);
180   }
181 };
182 
/// Maps a C++ element type to its Python buffer-protocol (struct-module)
/// format string. Only the specializations below provide `format()`. Using
/// any other element type fails to compile because the primary template is
/// empty.
template <typename T>
struct nb_format_descriptor {};

// Boolean.
template <>
struct nb_format_descriptor<bool> {
  static constexpr const char *format() { return "?"; }
};

// Signed integers.
template <>
struct nb_format_descriptor<int8_t> {
  static constexpr const char *format() { return "b"; }
};
template <>
struct nb_format_descriptor<int16_t> {
  static constexpr const char *format() { return "h"; }
};
template <>
struct nb_format_descriptor<int32_t> {
  static constexpr const char *format() { return "i"; }
};
template <>
struct nb_format_descriptor<int64_t> {
  static constexpr const char *format() { return "q"; }
};

// Unsigned integers.
template <>
struct nb_format_descriptor<uint8_t> {
  static constexpr const char *format() { return "B"; }
};
template <>
struct nb_format_descriptor<uint16_t> {
  static constexpr const char *format() { return "H"; }
};
template <>
struct nb_format_descriptor<uint32_t> {
  static constexpr const char *format() { return "I"; }
};
template <>
struct nb_format_descriptor<uint64_t> {
  static constexpr const char *format() { return "Q"; }
};

// Floating point.
template <>
struct nb_format_descriptor<float> {
  static constexpr const char *format() { return "f"; }
};
template <>
struct nb_format_descriptor<double> {
  static constexpr const char *format() { return "d"; }
};
230 
231 static MlirStringRef toMlirStringRef(const std::string &s) {
232   return mlirStringRefCreate(s.data(), s.size());
233 }
234 
235 static MlirStringRef toMlirStringRef(const nb::bytes &s) {
236   return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
237 }
238 
239 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
240 public:
241   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
242   static constexpr const char *pyClassName = "AffineMapAttr";
243   using PyConcreteAttribute::PyConcreteAttribute;
244   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
245       mlirAffineMapAttrGetTypeID;
246 
247   static void bindDerived(ClassTy &c) {
248     c.def_static(
249         "get",
250         [](PyAffineMap &affineMap) {
251           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
252           return PyAffineMapAttribute(affineMap.getContext(), attr);
253         },
254         nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
255     c.def_prop_ro("value", mlirAffineMapAttrGetValue,
256                   "Returns the value of the AffineMap attribute");
257   }
258 };
259 
260 class PyIntegerSetAttribute
261     : public PyConcreteAttribute<PyIntegerSetAttribute> {
262 public:
263   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
264   static constexpr const char *pyClassName = "IntegerSetAttr";
265   using PyConcreteAttribute::PyConcreteAttribute;
266   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
267       mlirIntegerSetAttrGetTypeID;
268 
269   static void bindDerived(ClassTy &c) {
270     c.def_static(
271         "get",
272         [](PyIntegerSet &integerSet) {
273           MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
274           return PyIntegerSetAttribute(integerSet.getContext(), attr);
275         },
276         nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
277   }
278 };
279 
280 template <typename T>
281 static T pyTryCast(nb::handle object) {
282   try {
283     return nb::cast<T>(object);
284   } catch (nb::cast_error &err) {
285     std::string msg = std::string("Invalid attribute when attempting to "
286                                   "create an ArrayAttribute (") +
287                       err.what() + ")";
288     throw std::runtime_error(msg.c_str());
289   } catch (std::runtime_error &err) {
290     std::string msg = std::string("Invalid attribute (None?) when attempting "
291                                   "to create an ArrayAttribute (") +
292                       err.what() + ")";
293     throw std::runtime_error(msg.c_str());
294   }
295 }
296 
297 /// A python-wrapped dense array attribute with an element type and a derived
298 /// implementation class.
299 template <typename EltTy, typename DerivedT>
300 class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
301 public:
302   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
303 
304   /// Iterator over the integer elements of a dense array.
305   class PyDenseArrayIterator {
306   public:
307     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
308 
309     /// Return a copy of the iterator.
310     PyDenseArrayIterator dunderIter() { return *this; }
311 
312     /// Return the next element.
313     EltTy dunderNext() {
314       // Throw if the index has reached the end.
315       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
316         throw nb::stop_iteration();
317       return DerivedT::getElement(attr.get(), nextIndex++);
318     }
319 
320     /// Bind the iterator class.
321     static void bind(nb::module_ &m) {
322       nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
323           .def("__iter__", &PyDenseArrayIterator::dunderIter)
324           .def("__next__", &PyDenseArrayIterator::dunderNext);
325     }
326 
327   private:
328     /// The referenced dense array attribute.
329     PyAttribute attr;
330     /// The next index to read.
331     int nextIndex = 0;
332   };
333 
334   /// Get the element at the given index.
335   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
336 
337   /// Bind the attribute class.
338   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
339     // Bind the constructor.
340     if constexpr (std::is_same_v<EltTy, bool>) {
341       c.def_static(
342           "get",
343           [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
344             std::vector<bool> values;
345             for (nb::handle py_value : py_values) {
346               int is_true = PyObject_IsTrue(py_value.ptr());
347               if (is_true < 0) {
348                 throw nb::python_error();
349               }
350               values.push_back(is_true);
351             }
352             return getAttribute(values, ctx->getRef());
353           },
354           nb::arg("values"), nb::arg("context").none() = nb::none(),
355           "Gets a uniqued dense array attribute");
356     } else {
357       c.def_static(
358           "get",
359           [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
360             return getAttribute(values, ctx->getRef());
361           },
362           nb::arg("values"), nb::arg("context").none() = nb::none(),
363           "Gets a uniqued dense array attribute");
364     }
365     // Bind the array methods.
366     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
367       if (i >= mlirDenseArrayGetNumElements(arr))
368         throw nb::index_error("DenseArray index out of range");
369       return arr.getItem(i);
370     });
371     c.def("__len__", [](const DerivedT &arr) {
372       return mlirDenseArrayGetNumElements(arr);
373     });
374     c.def("__iter__",
375           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
376     c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
377       std::vector<EltTy> values;
378       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
379       values.reserve(numOldElements + nb::len(extras));
380       for (intptr_t i = 0; i < numOldElements; ++i)
381         values.push_back(arr.getItem(i));
382       for (nb::handle attr : extras)
383         values.push_back(pyTryCast<EltTy>(attr));
384       return getAttribute(values, arr.getContext());
385     });
386   }
387 
388 private:
389   static DerivedT getAttribute(const std::vector<EltTy> &values,
390                                PyMlirContextRef ctx) {
391     if constexpr (std::is_same_v<EltTy, bool>) {
392       std::vector<int> intValues(values.begin(), values.end());
393       MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
394                                                   intValues.data());
395       return DerivedT(ctx, attr);
396     } else {
397       MlirAttribute attr =
398           DerivedT::getAttribute(ctx->get(), values.size(), values.data());
399       return DerivedT(ctx, attr);
400     }
401   }
402 };
403 
404 /// Instantiate the python dense array classes.
405 struct PyDenseBoolArrayAttribute
406     : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
407   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
408   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
409   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
410   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
411   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
412   using PyDenseArrayAttribute::PyDenseArrayAttribute;
413 };
414 struct PyDenseI8ArrayAttribute
415     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
416   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
417   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
418   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
419   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
420   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
421   using PyDenseArrayAttribute::PyDenseArrayAttribute;
422 };
423 struct PyDenseI16ArrayAttribute
424     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
425   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
426   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
427   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
428   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
429   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
430   using PyDenseArrayAttribute::PyDenseArrayAttribute;
431 };
432 struct PyDenseI32ArrayAttribute
433     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
434   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
435   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
436   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
437   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
438   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
439   using PyDenseArrayAttribute::PyDenseArrayAttribute;
440 };
441 struct PyDenseI64ArrayAttribute
442     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
443   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
444   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
445   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
446   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
447   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
448   using PyDenseArrayAttribute::PyDenseArrayAttribute;
449 };
450 struct PyDenseF32ArrayAttribute
451     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
452   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
453   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
454   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
455   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
456   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
457   using PyDenseArrayAttribute::PyDenseArrayAttribute;
458 };
459 struct PyDenseF64ArrayAttribute
460     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
461   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
462   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
463   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
464   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
465   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
466   using PyDenseArrayAttribute::PyDenseArrayAttribute;
467 };
468 
469 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
470 public:
471   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
472   static constexpr const char *pyClassName = "ArrayAttr";
473   using PyConcreteAttribute::PyConcreteAttribute;
474   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
475       mlirArrayAttrGetTypeID;
476 
477   class PyArrayAttributeIterator {
478   public:
479     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
480 
481     PyArrayAttributeIterator &dunderIter() { return *this; }
482 
483     MlirAttribute dunderNext() {
484       // TODO: Throw is an inefficient way to stop iteration.
485       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
486         throw nb::stop_iteration();
487       return mlirArrayAttrGetElement(attr.get(), nextIndex++);
488     }
489 
490     static void bind(nb::module_ &m) {
491       nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
492           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
493           .def("__next__", &PyArrayAttributeIterator::dunderNext);
494     }
495 
496   private:
497     PyAttribute attr;
498     int nextIndex = 0;
499   };
500 
501   MlirAttribute getItem(intptr_t i) {
502     return mlirArrayAttrGetElement(*this, i);
503   }
504 
505   static void bindDerived(ClassTy &c) {
506     c.def_static(
507         "get",
508         [](nb::list attributes, DefaultingPyMlirContext context) {
509           SmallVector<MlirAttribute> mlirAttributes;
510           mlirAttributes.reserve(nb::len(attributes));
511           for (auto attribute : attributes) {
512             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
513           }
514           MlirAttribute attr = mlirArrayAttrGet(
515               context->get(), mlirAttributes.size(), mlirAttributes.data());
516           return PyArrayAttribute(context->getRef(), attr);
517         },
518         nb::arg("attributes"), nb::arg("context").none() = nb::none(),
519         "Gets a uniqued Array attribute");
520     c.def("__getitem__",
521           [](PyArrayAttribute &arr, intptr_t i) {
522             if (i >= mlirArrayAttrGetNumElements(arr))
523               throw nb::index_error("ArrayAttribute index out of range");
524             return arr.getItem(i);
525           })
526         .def("__len__",
527              [](const PyArrayAttribute &arr) {
528                return mlirArrayAttrGetNumElements(arr);
529              })
530         .def("__iter__", [](const PyArrayAttribute &arr) {
531           return PyArrayAttributeIterator(arr);
532         });
533     c.def("__add__", [](PyArrayAttribute arr, nb::list extras) {
534       std::vector<MlirAttribute> attributes;
535       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
536       attributes.reserve(numOldElements + nb::len(extras));
537       for (intptr_t i = 0; i < numOldElements; ++i)
538         attributes.push_back(arr.getItem(i));
539       for (nb::handle attr : extras)
540         attributes.push_back(pyTryCast<PyAttribute>(attr));
541       MlirAttribute arrayAttr = mlirArrayAttrGet(
542           arr.getContext()->get(), attributes.size(), attributes.data());
543       return PyArrayAttribute(arr.getContext(), arrayAttr);
544     });
545   }
546 };
547 
548 /// Float Point Attribute subclass - FloatAttr.
549 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
550 public:
551   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
552   static constexpr const char *pyClassName = "FloatAttr";
553   using PyConcreteAttribute::PyConcreteAttribute;
554   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
555       mlirFloatAttrGetTypeID;
556 
557   static void bindDerived(ClassTy &c) {
558     c.def_static(
559         "get",
560         [](PyType &type, double value, DefaultingPyLocation loc) {
561           PyMlirContext::ErrorCapture errors(loc->getContext());
562           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
563           if (mlirAttributeIsNull(attr))
564             throw MLIRError("Invalid attribute", errors.take());
565           return PyFloatAttribute(type.getContext(), attr);
566         },
567         nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(),
568         "Gets an uniqued float point attribute associated to a type");
569     c.def_static(
570         "get_f32",
571         [](double value, DefaultingPyMlirContext context) {
572           MlirAttribute attr = mlirFloatAttrDoubleGet(
573               context->get(), mlirF32TypeGet(context->get()), value);
574           return PyFloatAttribute(context->getRef(), attr);
575         },
576         nb::arg("value"), nb::arg("context").none() = nb::none(),
577         "Gets an uniqued float point attribute associated to a f32 type");
578     c.def_static(
579         "get_f64",
580         [](double value, DefaultingPyMlirContext context) {
581           MlirAttribute attr = mlirFloatAttrDoubleGet(
582               context->get(), mlirF64TypeGet(context->get()), value);
583           return PyFloatAttribute(context->getRef(), attr);
584         },
585         nb::arg("value"), nb::arg("context").none() = nb::none(),
586         "Gets an uniqued float point attribute associated to a f64 type");
587     c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
588                   "Returns the value of the float attribute");
589     c.def("__float__", mlirFloatAttrGetValueDouble,
590           "Converts the value of the float attribute to a Python float");
591   }
592 };
593 
594 /// Integer Attribute subclass - IntegerAttr.
595 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
596 public:
597   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
598   static constexpr const char *pyClassName = "IntegerAttr";
599   using PyConcreteAttribute::PyConcreteAttribute;
600 
601   static void bindDerived(ClassTy &c) {
602     c.def_static(
603         "get",
604         [](PyType &type, int64_t value) {
605           MlirAttribute attr = mlirIntegerAttrGet(type, value);
606           return PyIntegerAttribute(type.getContext(), attr);
607         },
608         nb::arg("type"), nb::arg("value"),
609         "Gets an uniqued integer attribute associated to a type");
610     c.def_prop_ro("value", toPyInt,
611                   "Returns the value of the integer attribute");
612     c.def("__int__", toPyInt,
613           "Converts the value of the integer attribute to a Python int");
614     c.def_prop_ro_static("static_typeid",
615                          [](nb::object & /*class*/) -> MlirTypeID {
616                            return mlirIntegerAttrGetTypeID();
617                          });
618   }
619 
620 private:
621   static int64_t toPyInt(PyIntegerAttribute &self) {
622     MlirType type = mlirAttributeGetType(self);
623     if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
624       return mlirIntegerAttrGetValueInt(self);
625     if (mlirIntegerTypeIsSigned(type))
626       return mlirIntegerAttrGetValueSInt(self);
627     return mlirIntegerAttrGetValueUInt(self);
628   }
629 };
630 
631 /// Bool Attribute subclass - BoolAttr.
632 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
633 public:
634   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
635   static constexpr const char *pyClassName = "BoolAttr";
636   using PyConcreteAttribute::PyConcreteAttribute;
637 
638   static void bindDerived(ClassTy &c) {
639     c.def_static(
640         "get",
641         [](bool value, DefaultingPyMlirContext context) {
642           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
643           return PyBoolAttribute(context->getRef(), attr);
644         },
645         nb::arg("value"), nb::arg("context").none() = nb::none(),
646         "Gets an uniqued bool attribute");
647     c.def_prop_ro("value", mlirBoolAttrGetValue,
648                   "Returns the value of the bool attribute");
649     c.def("__bool__", mlirBoolAttrGetValue,
650           "Converts the value of the bool attribute to a Python bool");
651   }
652 };
653 
654 class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
655 public:
656   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
657   static constexpr const char *pyClassName = "SymbolRefAttr";
658   using PyConcreteAttribute::PyConcreteAttribute;
659 
660   static MlirAttribute fromList(const std::vector<std::string> &symbols,
661                                 PyMlirContext &context) {
662     if (symbols.empty())
663       throw std::runtime_error("SymbolRefAttr must be composed of at least "
664                                "one symbol.");
665     MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
666     SmallVector<MlirAttribute, 3> referenceAttrs;
667     for (size_t i = 1; i < symbols.size(); ++i) {
668       referenceAttrs.push_back(
669           mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
670     }
671     return mlirSymbolRefAttrGet(context.get(), rootSymbol,
672                                 referenceAttrs.size(), referenceAttrs.data());
673   }
674 
675   static void bindDerived(ClassTy &c) {
676     c.def_static(
677         "get",
678         [](const std::vector<std::string> &symbols,
679            DefaultingPyMlirContext context) {
680           return PySymbolRefAttribute::fromList(symbols, context.resolve());
681         },
682         nb::arg("symbols"), nb::arg("context").none() = nb::none(),
683         "Gets a uniqued SymbolRef attribute from a list of symbol names");
684     c.def_prop_ro(
685         "value",
686         [](PySymbolRefAttribute &self) {
687           std::vector<std::string> symbols = {
688               unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
689           for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
690                ++i)
691             symbols.push_back(
692                 unwrap(mlirSymbolRefAttrGetRootReference(
693                            mlirSymbolRefAttrGetNestedReference(self, i)))
694                     .str());
695           return symbols;
696         },
697         "Returns the value of the SymbolRef attribute as a list[str]");
698   }
699 };
700 
701 class PyFlatSymbolRefAttribute
702     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
703 public:
704   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
705   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
706   using PyConcreteAttribute::PyConcreteAttribute;
707 
708   static void bindDerived(ClassTy &c) {
709     c.def_static(
710         "get",
711         [](std::string value, DefaultingPyMlirContext context) {
712           MlirAttribute attr =
713               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
714           return PyFlatSymbolRefAttribute(context->getRef(), attr);
715         },
716         nb::arg("value"), nb::arg("context").none() = nb::none(),
717         "Gets a uniqued FlatSymbolRef attribute");
718     c.def_prop_ro(
719         "value",
720         [](PyFlatSymbolRefAttribute &self) {
721           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
722           return nb::str(stringRef.data, stringRef.length);
723         },
724         "Returns the value of the FlatSymbolRef attribute as a string");
725   }
726 };
727 
728 class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
729 public:
730   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
731   static constexpr const char *pyClassName = "OpaqueAttr";
732   using PyConcreteAttribute::PyConcreteAttribute;
733   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
734       mlirOpaqueAttrGetTypeID;
735 
736   static void bindDerived(ClassTy &c) {
737     c.def_static(
738         "get",
739         [](std::string dialectNamespace, nb_buffer buffer, PyType &type,
740            DefaultingPyMlirContext context) {
741           const nb_buffer_info bufferInfo = buffer.request();
742           intptr_t bufferSize = bufferInfo.size;
743           MlirAttribute attr = mlirOpaqueAttrGet(
744               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
745               static_cast<char *>(bufferInfo.ptr), type);
746           return PyOpaqueAttribute(context->getRef(), attr);
747         },
748         nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
749         nb::arg("context").none() = nb::none(), "Gets an Opaque attribute.");
750     c.def_prop_ro(
751         "dialect_namespace",
752         [](PyOpaqueAttribute &self) {
753           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
754           return nb::str(stringRef.data, stringRef.length);
755         },
756         "Returns the dialect namespace for the Opaque attribute as a string");
757     c.def_prop_ro(
758         "data",
759         [](PyOpaqueAttribute &self) {
760           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
761           return nb::bytes(stringRef.data, stringRef.length);
762         },
763         "Returns the data for the Opaqued attributes as `bytes`");
764   }
765 };
766 
767 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
768 public:
769   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
770   static constexpr const char *pyClassName = "StringAttr";
771   using PyConcreteAttribute::PyConcreteAttribute;
772   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
773       mlirStringAttrGetTypeID;
774 
775   static void bindDerived(ClassTy &c) {
776     c.def_static(
777         "get",
778         [](std::string value, DefaultingPyMlirContext context) {
779           MlirAttribute attr =
780               mlirStringAttrGet(context->get(), toMlirStringRef(value));
781           return PyStringAttribute(context->getRef(), attr);
782         },
783         nb::arg("value"), nb::arg("context").none() = nb::none(),
784         "Gets a uniqued string attribute");
785     c.def_static(
786         "get",
787         [](nb::bytes value, DefaultingPyMlirContext context) {
788           MlirAttribute attr =
789               mlirStringAttrGet(context->get(), toMlirStringRef(value));
790           return PyStringAttribute(context->getRef(), attr);
791         },
792         nb::arg("value"), nb::arg("context").none() = nb::none(),
793         "Gets a uniqued string attribute");
794     c.def_static(
795         "get_typed",
796         [](PyType &type, std::string value) {
797           MlirAttribute attr =
798               mlirStringAttrTypedGet(type, toMlirStringRef(value));
799           return PyStringAttribute(type.getContext(), attr);
800         },
801         nb::arg("type"), nb::arg("value"),
802         "Gets a uniqued string attribute associated to a type");
803     c.def_prop_ro(
804         "value",
805         [](PyStringAttribute &self) {
806           MlirStringRef stringRef = mlirStringAttrGetValue(self);
807           return nb::str(stringRef.data, stringRef.length);
808         },
809         "Returns the value of the string attribute");
810     c.def_prop_ro(
811         "value_bytes",
812         [](PyStringAttribute &self) {
813           MlirStringRef stringRef = mlirStringAttrGetValue(self);
814           return nb::bytes(stringRef.data, stringRef.length);
815         },
816         "Returns the value of the string attribute as `bytes`");
817   }
818 };
819 
820 // TODO: Support construction of string elements.
821 class PyDenseElementsAttribute
822     : public PyConcreteAttribute<PyDenseElementsAttribute> {
823 public:
824   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
825   static constexpr const char *pyClassName = "DenseElementsAttr";
826   using PyConcreteAttribute::PyConcreteAttribute;
827 
828   static PyDenseElementsAttribute
829   getFromList(nb::list attributes, std::optional<PyType> explicitType,
830               DefaultingPyMlirContext contextWrapper) {
831     const size_t numAttributes = nb::len(attributes);
832     if (numAttributes == 0)
833       throw nb::value_error("Attributes list must be non-empty.");
834 
835     MlirType shapedType;
836     if (explicitType) {
837       if ((!mlirTypeIsAShaped(*explicitType) ||
838            !mlirShapedTypeHasStaticShape(*explicitType))) {
839 
840         std::string message;
841         llvm::raw_string_ostream os(message);
842         os << "Expected a static ShapedType for the shaped_type parameter: "
843            << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
844         throw nb::value_error(message.c_str());
845       }
846       shapedType = *explicitType;
847     } else {
848       SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
849       shapedType = mlirRankedTensorTypeGet(
850           shape.size(), shape.data(),
851           mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
852           mlirAttributeGetNull());
853     }
854 
855     SmallVector<MlirAttribute> mlirAttributes;
856     mlirAttributes.reserve(numAttributes);
857     for (const nb::handle &attribute : attributes) {
858       MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
859       MlirType attrType = mlirAttributeGetType(mlirAttribute);
860       mlirAttributes.push_back(mlirAttribute);
861 
862       if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
863         std::string message;
864         llvm::raw_string_ostream os(message);
865         os << "All attributes must be of the same type and match "
866            << "the type parameter: expected="
867            << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
868            << ", but got="
869            << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
870         throw nb::value_error(message.c_str());
871       }
872     }
873 
874     MlirAttribute elements = mlirDenseElementsAttrGet(
875         shapedType, mlirAttributes.size(), mlirAttributes.data());
876 
877     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
878   }
879 
880   static PyDenseElementsAttribute
881   getFromBuffer(nb_buffer array, bool signless,
882                 std::optional<PyType> explicitType,
883                 std::optional<std::vector<int64_t>> explicitShape,
884                 DefaultingPyMlirContext contextWrapper) {
885     // Request a contiguous view. In exotic cases, this will cause a copy.
886     int flags = PyBUF_ND;
887     if (!explicitType) {
888       flags |= PyBUF_FORMAT;
889     }
890     Py_buffer view;
891     if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
892       throw nb::python_error();
893     }
894     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
895 
896     MlirContext context = contextWrapper->get();
897     MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
898                                                 explicitShape, context);
899     if (mlirAttributeIsNull(attr)) {
900       throw std::invalid_argument(
901           "DenseElementsAttr could not be constructed from the given buffer. "
902           "This may mean that the Python buffer layout does not match that "
903           "MLIR expected layout and is a bug.");
904     }
905     return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
906   }
907 
908   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
909                                            PyAttribute &elementAttr) {
910     auto contextWrapper =
911         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
912     if (!mlirAttributeIsAInteger(elementAttr) &&
913         !mlirAttributeIsAFloat(elementAttr)) {
914       std::string message = "Illegal element type for DenseElementsAttr: ";
915       message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
916       throw nb::value_error(message.c_str());
917     }
918     if (!mlirTypeIsAShaped(shapedType) ||
919         !mlirShapedTypeHasStaticShape(shapedType)) {
920       std::string message =
921           "Expected a static ShapedType for the shaped_type parameter: ";
922       message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
923       throw nb::value_error(message.c_str());
924     }
925     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
926     MlirType attrType = mlirAttributeGetType(elementAttr);
927     if (!mlirTypeEqual(shapedElementType, attrType)) {
928       std::string message =
929           "Shaped element type and attribute type must be equal: shaped=";
930       message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
931       message.append(", element=");
932       message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
933       throw nb::value_error(message.c_str());
934     }
935 
936     MlirAttribute elements =
937         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
938     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
939   }
940 
941   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
942 
943   std::unique_ptr<nb_buffer_info> accessBuffer() {
944     MlirType shapedType = mlirAttributeGetType(*this);
945     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
946     std::string format;
947 
948     if (mlirTypeIsAF32(elementType)) {
949       // f32
950       return bufferInfo<float>(shapedType);
951     }
952     if (mlirTypeIsAF64(elementType)) {
953       // f64
954       return bufferInfo<double>(shapedType);
955     }
956     if (mlirTypeIsAF16(elementType)) {
957       // f16
958       return bufferInfo<uint16_t>(shapedType, "e");
959     }
960     if (mlirTypeIsAIndex(elementType)) {
961       // Same as IndexType::kInternalStorageBitWidth
962       return bufferInfo<int64_t>(shapedType);
963     }
964     if (mlirTypeIsAInteger(elementType) &&
965         mlirIntegerTypeGetWidth(elementType) == 32) {
966       if (mlirIntegerTypeIsSignless(elementType) ||
967           mlirIntegerTypeIsSigned(elementType)) {
968         // i32
969         return bufferInfo<int32_t>(shapedType);
970       }
971       if (mlirIntegerTypeIsUnsigned(elementType)) {
972         // unsigned i32
973         return bufferInfo<uint32_t>(shapedType);
974       }
975     } else if (mlirTypeIsAInteger(elementType) &&
976                mlirIntegerTypeGetWidth(elementType) == 64) {
977       if (mlirIntegerTypeIsSignless(elementType) ||
978           mlirIntegerTypeIsSigned(elementType)) {
979         // i64
980         return bufferInfo<int64_t>(shapedType);
981       }
982       if (mlirIntegerTypeIsUnsigned(elementType)) {
983         // unsigned i64
984         return bufferInfo<uint64_t>(shapedType);
985       }
986     } else if (mlirTypeIsAInteger(elementType) &&
987                mlirIntegerTypeGetWidth(elementType) == 8) {
988       if (mlirIntegerTypeIsSignless(elementType) ||
989           mlirIntegerTypeIsSigned(elementType)) {
990         // i8
991         return bufferInfo<int8_t>(shapedType);
992       }
993       if (mlirIntegerTypeIsUnsigned(elementType)) {
994         // unsigned i8
995         return bufferInfo<uint8_t>(shapedType);
996       }
997     } else if (mlirTypeIsAInteger(elementType) &&
998                mlirIntegerTypeGetWidth(elementType) == 16) {
999       if (mlirIntegerTypeIsSignless(elementType) ||
1000           mlirIntegerTypeIsSigned(elementType)) {
1001         // i16
1002         return bufferInfo<int16_t>(shapedType);
1003       }
1004       if (mlirIntegerTypeIsUnsigned(elementType)) {
1005         // unsigned i16
1006         return bufferInfo<uint16_t>(shapedType);
1007       }
1008     } else if (mlirTypeIsAInteger(elementType) &&
1009                mlirIntegerTypeGetWidth(elementType) == 1) {
1010       // i1 / bool
1011       // We can not send the buffer directly back to Python, because the i1
1012       // values are bitpacked within MLIR. We call numpy's unpackbits function
1013       // to convert the bytes.
1014       return getBooleanBufferFromBitpackedAttribute();
1015     }
1016 
1017     // TODO: Currently crashes the program.
1018     // Reported as https://github.com/pybind/pybind11/issues/3336
1019     throw std::invalid_argument(
1020         "unsupported data type for conversion to Python buffer");
1021   }
1022 
1023   static void bindDerived(ClassTy &c) {
1024 #if PY_VERSION_HEX < 0x03090000
1025     PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
1026     tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
1027     tp->tp_as_buffer->bf_releasebuffer =
1028         PyDenseElementsAttribute::bf_releasebuffer;
1029 #endif
1030     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
1031         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
1032                     nb::arg("array"), nb::arg("signless") = true,
1033                     nb::arg("type").none() = nb::none(),
1034                     nb::arg("shape").none() = nb::none(),
1035                     nb::arg("context").none() = nb::none(),
1036                     kDenseElementsAttrGetDocstring)
1037         .def_static("get", PyDenseElementsAttribute::getFromList,
1038                     nb::arg("attrs"), nb::arg("type").none() = nb::none(),
1039                     nb::arg("context").none() = nb::none(),
1040                     kDenseElementsAttrGetFromListDocstring)
1041         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
1042                     nb::arg("shaped_type"), nb::arg("element_attr"),
1043                     "Gets a DenseElementsAttr where all values are the same")
1044         .def_prop_ro("is_splat",
1045                      [](PyDenseElementsAttribute &self) -> bool {
1046                        return mlirDenseElementsAttrIsSplat(self);
1047                      })
1048         .def("get_splat_value", [](PyDenseElementsAttribute &self) {
1049           if (!mlirDenseElementsAttrIsSplat(self))
1050             throw nb::value_error(
1051                 "get_splat_value called on a non-splat attribute");
1052           return mlirDenseElementsAttrGetSplatValue(self);
1053         });
1054   }
1055 
1056   static PyType_Slot slots[];
1057 
1058 private:
1059   static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
1060   static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
1061 
1062   static bool isUnsignedIntegerFormat(std::string_view format) {
1063     if (format.empty())
1064       return false;
1065     char code = format[0];
1066     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
1067            code == 'Q';
1068   }
1069 
1070   static bool isSignedIntegerFormat(std::string_view format) {
1071     if (format.empty())
1072       return false;
1073     char code = format[0];
1074     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
1075            code == 'q';
1076   }
1077 
1078   static MlirType
1079   getShapedType(std::optional<MlirType> bulkLoadElementType,
1080                 std::optional<std::vector<int64_t>> explicitShape,
1081                 Py_buffer &view) {
1082     SmallVector<int64_t> shape;
1083     if (explicitShape) {
1084       shape.append(explicitShape->begin(), explicitShape->end());
1085     } else {
1086       shape.append(view.shape, view.shape + view.ndim);
1087     }
1088 
1089     if (mlirTypeIsAShaped(*bulkLoadElementType)) {
1090       if (explicitShape) {
1091         throw std::invalid_argument("Shape can only be specified explicitly "
1092                                     "when the type is not a shaped type.");
1093       }
1094       return *bulkLoadElementType;
1095     } else {
1096       MlirAttribute encodingAttr = mlirAttributeGetNull();
1097       return mlirRankedTensorTypeGet(shape.size(), shape.data(),
1098                                      *bulkLoadElementType, encodingAttr);
1099     }
1100   }
1101 
1102   static MlirAttribute getAttributeFromBuffer(
1103       Py_buffer &view, bool signless, std::optional<PyType> explicitType,
1104       std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
1105     // Detect format codes that are suitable for bulk loading. This includes
1106     // all byte aligned integer and floating point types up to 8 bytes.
1107     // Notably, this excludes exotics types which do not have a direct
1108     // representation in the buffer protocol (i.e. complex, etc).
1109     std::optional<MlirType> bulkLoadElementType;
1110     if (explicitType) {
1111       bulkLoadElementType = *explicitType;
1112     } else {
1113       std::string_view format(view.format);
1114       if (format == "f") {
1115         // f32
1116         assert(view.itemsize == 4 && "mismatched array itemsize");
1117         bulkLoadElementType = mlirF32TypeGet(context);
1118       } else if (format == "d") {
1119         // f64
1120         assert(view.itemsize == 8 && "mismatched array itemsize");
1121         bulkLoadElementType = mlirF64TypeGet(context);
1122       } else if (format == "e") {
1123         // f16
1124         assert(view.itemsize == 2 && "mismatched array itemsize");
1125         bulkLoadElementType = mlirF16TypeGet(context);
1126       } else if (format == "?") {
1127         // i1
1128         // The i1 type needs to be bit-packed, so we will handle it seperately
1129         return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
1130                                                       context);
1131       } else if (isSignedIntegerFormat(format)) {
1132         if (view.itemsize == 4) {
1133           // i32
1134           bulkLoadElementType = signless
1135                                     ? mlirIntegerTypeGet(context, 32)
1136                                     : mlirIntegerTypeSignedGet(context, 32);
1137         } else if (view.itemsize == 8) {
1138           // i64
1139           bulkLoadElementType = signless
1140                                     ? mlirIntegerTypeGet(context, 64)
1141                                     : mlirIntegerTypeSignedGet(context, 64);
1142         } else if (view.itemsize == 1) {
1143           // i8
1144           bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
1145                                          : mlirIntegerTypeSignedGet(context, 8);
1146         } else if (view.itemsize == 2) {
1147           // i16
1148           bulkLoadElementType = signless
1149                                     ? mlirIntegerTypeGet(context, 16)
1150                                     : mlirIntegerTypeSignedGet(context, 16);
1151         }
1152       } else if (isUnsignedIntegerFormat(format)) {
1153         if (view.itemsize == 4) {
1154           // unsigned i32
1155           bulkLoadElementType = signless
1156                                     ? mlirIntegerTypeGet(context, 32)
1157                                     : mlirIntegerTypeUnsignedGet(context, 32);
1158         } else if (view.itemsize == 8) {
1159           // unsigned i64
1160           bulkLoadElementType = signless
1161                                     ? mlirIntegerTypeGet(context, 64)
1162                                     : mlirIntegerTypeUnsignedGet(context, 64);
1163         } else if (view.itemsize == 1) {
1164           // i8
1165           bulkLoadElementType = signless
1166                                     ? mlirIntegerTypeGet(context, 8)
1167                                     : mlirIntegerTypeUnsignedGet(context, 8);
1168         } else if (view.itemsize == 2) {
1169           // i16
1170           bulkLoadElementType = signless
1171                                     ? mlirIntegerTypeGet(context, 16)
1172                                     : mlirIntegerTypeUnsignedGet(context, 16);
1173         }
1174       }
1175       if (!bulkLoadElementType) {
1176         throw std::invalid_argument(
1177             std::string("unimplemented array format conversion from format: ") +
1178             std::string(format));
1179       }
1180     }
1181 
1182     MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
1183     return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
1184   }
1185 
1186   // There is a complication for boolean numpy arrays, as numpy represents
1187   // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
1188   // booleans per byte.
1189   static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
1190       Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
1191       MlirContext &context) {
1192     if (llvm::endianness::native != llvm::endianness::little) {
1193       // Given we have no good way of testing the behavior on big-endian
1194       // systems we will throw
1195       throw nb::type_error("Constructing a bit-packed MLIR attribute is "
1196                            "unsupported on big-endian systems");
1197     }
1198     nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
1199         /*data=*/static_cast<uint8_t *>(view.buf),
1200         /*shape=*/{static_cast<size_t>(view.len)});
1201 
1202     nb::module_ numpy = nb::module_::import_("numpy");
1203     nb::object packbitsFunc = numpy.attr("packbits");
1204     nb::object packedBooleans =
1205         packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
1206     nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
1207 
1208     MlirType bitpackedType =
1209         getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
1210     assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
1211     // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
1212     // packedBooleans, hence the MlirAttribute will remain valid even when
1213     // packedBooleans get reclaimed by the end of the function.
1214     return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
1215                                              pythonBuffer.ptr);
1216   }
1217 
1218   // This does the opposite transformation of
1219   // `getBitpackedAttributeFromBooleanBuffer`
1220   std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() {
1221     if (llvm::endianness::native != llvm::endianness::little) {
1222       // Given we have no good way of testing the behavior on big-endian
1223       // systems we will throw
1224       throw nb::type_error("Constructing a numpy array from a MLIR attribute "
1225                            "is unsupported on big-endian systems");
1226     }
1227 
1228     int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
1229     int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
1230     uint8_t *bitpackedData = static_cast<uint8_t *>(
1231         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1232     nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
1233         /*data=*/bitpackedData,
1234         /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
1235 
1236     nb::module_ numpy = nb::module_::import_("numpy");
1237     nb::object unpackbitsFunc = numpy.attr("unpackbits");
1238     nb::object equalFunc = numpy.attr("equal");
1239     nb::object reshapeFunc = numpy.attr("reshape");
1240     nb::object unpackedBooleans =
1241         unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
1242 
1243     // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
1244     // We need to:
1245     //   1. Slice away the padded bits
1246     //   2. Make the boolean array have the correct shape
1247     //   3. Convert the array to a boolean array
1248     unpackedBooleans = unpackedBooleans[nb::slice(
1249         nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
1250     unpackedBooleans = equalFunc(unpackedBooleans, 1);
1251 
1252     MlirType shapedType = mlirAttributeGetType(*this);
1253     intptr_t rank = mlirShapedTypeGetRank(shapedType);
1254     std::vector<intptr_t> shape(rank);
1255     for (intptr_t i = 0; i < rank; ++i) {
1256       shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
1257     }
1258     unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
1259 
1260     // Make sure the returned nb::buffer_view claims ownership of the data in
1261     // `pythonBuffer` so it remains valid when Python reads it
1262     nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
1263     return std::make_unique<nb_buffer_info>(pythonBuffer.request());
1264   }
1265 
1266   template <typename Type>
1267   std::unique_ptr<nb_buffer_info>
1268   bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
1269     intptr_t rank = mlirShapedTypeGetRank(shapedType);
1270     // Prepare the data for the buffer_info.
1271     // Buffer is configured for read-only access below.
1272     Type *data = static_cast<Type *>(
1273         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1274     // Prepare the shape for the buffer_info.
1275     SmallVector<intptr_t, 4> shape;
1276     for (intptr_t i = 0; i < rank; ++i)
1277       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1278     // Prepare the strides for the buffer_info.
1279     SmallVector<intptr_t, 4> strides;
1280     if (mlirDenseElementsAttrIsSplat(*this)) {
1281       // Splats are special, only the single value is stored.
1282       strides.assign(rank, 0);
1283     } else {
1284       for (intptr_t i = 1; i < rank; ++i) {
1285         intptr_t strideFactor = 1;
1286         for (intptr_t j = i; j < rank; ++j)
1287           strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1288         strides.push_back(sizeof(Type) * strideFactor);
1289       }
1290       strides.push_back(sizeof(Type));
1291     }
1292     const char *format;
1293     if (explicitFormat) {
1294       format = explicitFormat;
1295     } else {
1296       format = nb_format_descriptor<Type>::format();
1297     }
1298     return std::make_unique<nb_buffer_info>(
1299         data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
1300         /*readonly=*/true);
1301   }
1302 }; // namespace
1303 
1304 PyType_Slot PyDenseElementsAttribute::slots[] = {
1305 // Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
1306 #if PY_VERSION_HEX >= 0x03090000
1307     {Py_bf_getbuffer,
1308      reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)},
1309     {Py_bf_releasebuffer,
1310      reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)},
1311 #endif
1312     {0, nullptr},
1313 };
1314 
1315 /*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj,
1316                                                       Py_buffer *view,
1317                                                       int flags) {
1318   view->obj = nullptr;
1319   std::unique_ptr<nb_buffer_info> info;
1320   try {
1321     auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj));
1322     info = attr->accessBuffer();
1323   } catch (nb::python_error &e) {
1324     e.restore();
1325     nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer");
1326     return -1;
1327   }
1328   view->obj = obj;
1329   view->ndim = 1;
1330   view->buf = info->ptr;
1331   view->itemsize = info->itemsize;
1332   view->len = info->itemsize;
1333   for (auto s : info->shape) {
1334     view->len *= s;
1335   }
1336   view->readonly = info->readonly;
1337   if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
1338     view->format = const_cast<char *>(info->format);
1339   }
1340   if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
1341     view->ndim = static_cast<int>(info->ndim);
1342     view->strides = info->strides.data();
1343     view->shape = info->shape.data();
1344   }
1345   view->suboffsets = nullptr;
1346   view->internal = info.release();
1347   Py_INCREF(obj);
1348   return 0;
1349 }
1350 
1351 /*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *,
1352                                                            Py_buffer *view) {
1353   delete reinterpret_cast<nb_buffer_info *>(view->internal);
1354 }
1355 
1356 /// Refinement of the PyDenseElementsAttribute for attributes containing
1357 /// integer (and boolean) values. Supports element access.
1358 class PyDenseIntElementsAttribute
1359     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1360                                  PyDenseElementsAttribute> {
1361 public:
1362   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1363   static constexpr const char *pyClassName = "DenseIntElementsAttr";
1364   using PyConcreteAttribute::PyConcreteAttribute;
1365 
1366   /// Returns the element at the given linear position. Asserts if the index
1367   /// is out of range.
1368   nb::object dunderGetItem(intptr_t pos) {
1369     if (pos < 0 || pos >= dunderLen()) {
1370       throw nb::index_error("attempt to access out of bounds element");
1371     }
1372 
1373     MlirType type = mlirAttributeGetType(*this);
1374     type = mlirShapedTypeGetElementType(type);
1375     // Index type can also appear as a DenseIntElementsAttr and therefore can be
1376     // casted to integer.
1377     assert(mlirTypeIsAInteger(type) ||
1378            mlirTypeIsAIndex(type) && "expected integer/index element type in "
1379                                      "dense int elements attribute");
1380     // Dispatch element extraction to an appropriate C function based on the
1381     // elemental type of the attribute. nb::int_ is implicitly constructible
1382     // from any C++ integral type and handles bitwidth correctly.
1383     // TODO: consider caching the type properties in the constructor to avoid
1384     // querying them on each element access.
1385     if (mlirTypeIsAIndex(type)) {
1386       return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
1387     }
1388     unsigned width = mlirIntegerTypeGetWidth(type);
1389     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1390     if (isUnsigned) {
1391       if (width == 1) {
1392         return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1393       }
1394       if (width == 8) {
1395         return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
1396       }
1397       if (width == 16) {
1398         return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
1399       }
1400       if (width == 32) {
1401         return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
1402       }
1403       if (width == 64) {
1404         return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
1405       }
1406     } else {
1407       if (width == 1) {
1408         return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1409       }
1410       if (width == 8) {
1411         return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
1412       }
1413       if (width == 16) {
1414         return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
1415       }
1416       if (width == 32) {
1417         return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
1418       }
1419       if (width == 64) {
1420         return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
1421       }
1422     }
1423     throw nb::type_error("Unsupported integer type");
1424   }
1425 
1426   static void bindDerived(ClassTy &c) {
1427     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1428   }
1429 };
1430 
1431 class PyDenseResourceElementsAttribute
1432     : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1433 public:
1434   static constexpr IsAFunctionTy isaFunction =
1435       mlirAttributeIsADenseResourceElements;
1436   static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1437   using PyConcreteAttribute::PyConcreteAttribute;
1438 
1439   static PyDenseResourceElementsAttribute
1440   getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type,
1441                 std::optional<size_t> alignment, bool isMutable,
1442                 DefaultingPyMlirContext contextWrapper) {
1443     if (!mlirTypeIsAShaped(type)) {
1444       throw std::invalid_argument(
1445           "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1446     }
1447 
1448     // Do not request any conversions as we must ensure to use caller
1449     // managed memory.
1450     int flags = PyBUF_STRIDES;
1451     std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1452     if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1453       throw nb::python_error();
1454     }
1455 
1456     // This scope releaser will only release if we haven't yet transferred
1457     // ownership.
1458     auto freeBuffer = llvm::make_scope_exit([&]() {
1459       if (view)
1460         PyBuffer_Release(view.get());
1461     });
1462 
1463     if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1464       throw std::invalid_argument("Contiguous buffer is required.");
1465     }
1466 
1467     // Infer alignment to be the stride of one element if not explicit.
1468     size_t inferredAlignment;
1469     if (alignment)
1470       inferredAlignment = *alignment;
1471     else
1472       inferredAlignment = view->strides[view->ndim - 1];
1473 
1474     // The userData is a Py_buffer* that the deleter owns.
1475     auto deleter = [](void *userData, const void *data, size_t size,
1476                       size_t align) {
1477       if (!Py_IsInitialized())
1478         Py_Initialize();
1479       Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1480       nb::gil_scoped_acquire gil;
1481       PyBuffer_Release(ownedView);
1482       delete ownedView;
1483     };
1484 
1485     size_t rawBufferSize = view->len;
1486     MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1487         type, toMlirStringRef(name), view->buf, rawBufferSize,
1488         inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1489     if (mlirAttributeIsNull(attr)) {
1490       throw std::invalid_argument(
1491           "DenseResourceElementsAttr could not be constructed from the given "
1492           "buffer. "
1493           "This may mean that the Python buffer layout does not match that "
1494           "MLIR expected layout and is a bug.");
1495     }
1496     view.release();
1497     return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1498   }
1499 
1500   static void bindDerived(ClassTy &c) {
1501     c.def_static(
1502         "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
1503         nb::arg("array"), nb::arg("name"), nb::arg("type"),
1504         nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
1505         nb::arg("context").none() = nb::none(),
1506         kDenseResourceElementsAttrGetFromBufferDocstring);
1507   }
1508 };
1509 
1510 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1511 public:
1512   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1513   static constexpr const char *pyClassName = "DictAttr";
1514   using PyConcreteAttribute::PyConcreteAttribute;
1515   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1516       mlirDictionaryAttrGetTypeID;
1517 
1518   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1519 
1520   bool dunderContains(const std::string &name) {
1521     return !mlirAttributeIsNull(
1522         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
1523   }
1524 
1525   static void bindDerived(ClassTy &c) {
1526     c.def("__contains__", &PyDictAttribute::dunderContains);
1527     c.def("__len__", &PyDictAttribute::dunderLen);
1528     c.def_static(
1529         "get",
1530         [](nb::dict attributes, DefaultingPyMlirContext context) {
1531           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1532           mlirNamedAttributes.reserve(attributes.size());
1533           for (std::pair<nb::handle, nb::handle> it : attributes) {
1534             auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
1535             auto name = nb::cast<std::string>(it.first);
1536             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1537                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1538                                   toMlirStringRef(name)),
1539                 mlirAttr));
1540           }
1541           MlirAttribute attr =
1542               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1543                                     mlirNamedAttributes.data());
1544           return PyDictAttribute(context->getRef(), attr);
1545         },
1546         nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(),
1547         "Gets an uniqued dict attribute");
1548     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1549       MlirAttribute attr =
1550           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1551       if (mlirAttributeIsNull(attr))
1552         throw nb::key_error("attempt to access a non-existent attribute");
1553       return attr;
1554     });
1555     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1556       if (index < 0 || index >= self.dunderLen()) {
1557         throw nb::index_error("attempt to access out of bounds attribute");
1558       }
1559       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1560       return PyNamedAttribute(
1561           namedAttr.attribute,
1562           std::string(mlirIdentifierStr(namedAttr.name).data));
1563     });
1564   }
1565 };
1566 
1567 /// Refinement of PyDenseElementsAttribute for attributes containing
1568 /// floating-point values. Supports element access.
1569 class PyDenseFPElementsAttribute
1570     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1571                                  PyDenseElementsAttribute> {
1572 public:
1573   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1574   static constexpr const char *pyClassName = "DenseFPElementsAttr";
1575   using PyConcreteAttribute::PyConcreteAttribute;
1576 
1577   nb::float_ dunderGetItem(intptr_t pos) {
1578     if (pos < 0 || pos >= dunderLen()) {
1579       throw nb::index_error("attempt to access out of bounds element");
1580     }
1581 
1582     MlirType type = mlirAttributeGetType(*this);
1583     type = mlirShapedTypeGetElementType(type);
1584     // Dispatch element extraction to an appropriate C function based on the
1585     // elemental type of the attribute. nb::float_ is implicitly constructible
1586     // from float and double.
1587     // TODO: consider caching the type properties in the constructor to avoid
1588     // querying them on each element access.
1589     if (mlirTypeIsAF32(type)) {
1590       return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
1591     }
1592     if (mlirTypeIsAF64(type)) {
1593       return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
1594     }
1595     throw nb::type_error("Unsupported floating-point type");
1596   }
1597 
1598   static void bindDerived(ClassTy &c) {
1599     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1600   }
1601 };
1602 
1603 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1604 public:
1605   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1606   static constexpr const char *pyClassName = "TypeAttr";
1607   using PyConcreteAttribute::PyConcreteAttribute;
1608   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1609       mlirTypeAttrGetTypeID;
1610 
1611   static void bindDerived(ClassTy &c) {
1612     c.def_static(
1613         "get",
1614         [](PyType value, DefaultingPyMlirContext context) {
1615           MlirAttribute attr = mlirTypeAttrGet(value.get());
1616           return PyTypeAttribute(context->getRef(), attr);
1617         },
1618         nb::arg("value"), nb::arg("context").none() = nb::none(),
1619         "Gets a uniqued Type attribute");
1620     c.def_prop_ro("value", [](PyTypeAttribute &self) {
1621       return mlirTypeAttrGetValue(self.get());
1622     });
1623   }
1624 };
1625 
1626 /// Unit Attribute subclass. Unit attributes don't have values.
1627 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1628 public:
1629   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1630   static constexpr const char *pyClassName = "UnitAttr";
1631   using PyConcreteAttribute::PyConcreteAttribute;
1632   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1633       mlirUnitAttrGetTypeID;
1634 
1635   static void bindDerived(ClassTy &c) {
1636     c.def_static(
1637         "get",
1638         [](DefaultingPyMlirContext context) {
1639           return PyUnitAttribute(context->getRef(),
1640                                  mlirUnitAttrGet(context->get()));
1641         },
1642         nb::arg("context").none() = nb::none(), "Create a Unit attribute.");
1643   }
1644 };
1645 
1646 /// Strided layout attribute subclass.
1647 class PyStridedLayoutAttribute
1648     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1649 public:
1650   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1651   static constexpr const char *pyClassName = "StridedLayoutAttr";
1652   using PyConcreteAttribute::PyConcreteAttribute;
1653   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1654       mlirStridedLayoutAttrGetTypeID;
1655 
1656   static void bindDerived(ClassTy &c) {
1657     c.def_static(
1658         "get",
1659         [](int64_t offset, const std::vector<int64_t> strides,
1660            DefaultingPyMlirContext ctx) {
1661           MlirAttribute attr = mlirStridedLayoutAttrGet(
1662               ctx->get(), offset, strides.size(), strides.data());
1663           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1664         },
1665         nb::arg("offset"), nb::arg("strides"),
1666         nb::arg("context").none() = nb::none(),
1667         "Gets a strided layout attribute.");
1668     c.def_static(
1669         "get_fully_dynamic",
1670         [](int64_t rank, DefaultingPyMlirContext ctx) {
1671           auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1672           std::vector<int64_t> strides(rank);
1673           std::fill(strides.begin(), strides.end(), dynamic);
1674           MlirAttribute attr = mlirStridedLayoutAttrGet(
1675               ctx->get(), dynamic, strides.size(), strides.data());
1676           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1677         },
1678         nb::arg("rank"), nb::arg("context").none() = nb::none(),
1679         "Gets a strided layout attribute with dynamic offset and strides of "
1680         "a "
1681         "given rank.");
1682     c.def_prop_ro(
1683         "offset",
1684         [](PyStridedLayoutAttribute &self) {
1685           return mlirStridedLayoutAttrGetOffset(self);
1686         },
1687         "Returns the value of the float point attribute");
1688     c.def_prop_ro(
1689         "strides",
1690         [](PyStridedLayoutAttribute &self) {
1691           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1692           std::vector<int64_t> strides(size);
1693           for (intptr_t i = 0; i < size; i++) {
1694             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1695           }
1696           return strides;
1697         },
1698         "Returns the value of the float point attribute");
1699   }
1700 };
1701 
1702 nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
1703   if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
1704     return nb::cast(PyDenseBoolArrayAttribute(pyAttribute));
1705   if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
1706     return nb::cast(PyDenseI8ArrayAttribute(pyAttribute));
1707   if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
1708     return nb::cast(PyDenseI16ArrayAttribute(pyAttribute));
1709   if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
1710     return nb::cast(PyDenseI32ArrayAttribute(pyAttribute));
1711   if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
1712     return nb::cast(PyDenseI64ArrayAttribute(pyAttribute));
1713   if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
1714     return nb::cast(PyDenseF32ArrayAttribute(pyAttribute));
1715   if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
1716     return nb::cast(PyDenseF64ArrayAttribute(pyAttribute));
1717   std::string msg =
1718       std::string("Can't cast unknown element type DenseArrayAttr (") +
1719       nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1720   throw nb::type_error(msg.c_str());
1721 }
1722 
1723 nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
1724   if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
1725     return nb::cast(PyDenseFPElementsAttribute(pyAttribute));
1726   if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
1727     return nb::cast(PyDenseIntElementsAttribute(pyAttribute));
1728   std::string msg =
1729       std::string(
1730           "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
1731       nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1732   throw nb::type_error(msg.c_str());
1733 }
1734 
1735 nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
1736   if (PyBoolAttribute::isaFunction(pyAttribute))
1737     return nb::cast(PyBoolAttribute(pyAttribute));
1738   if (PyIntegerAttribute::isaFunction(pyAttribute))
1739     return nb::cast(PyIntegerAttribute(pyAttribute));
1740   std::string msg =
1741       std::string("Can't cast unknown element type DenseArrayAttr (") +
1742       nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1743   throw nb::type_error(msg.c_str());
1744 }
1745 
1746 nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
1747   if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
1748     return nb::cast(PyFlatSymbolRefAttribute(pyAttribute));
1749   if (PySymbolRefAttribute::isaFunction(pyAttribute))
1750     return nb::cast(PySymbolRefAttribute(pyAttribute));
1751   std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
1752                     nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
1753                     ")";
1754   throw nb::type_error(msg.c_str());
1755 }
1756 
1757 } // namespace
1758 
/// Registers the builtin attribute Python classes from this file on module
/// `m`, and installs the TypeID-keyed type casters (the *Caster functions
/// above) that downcast generic attributes to their concrete Python class.
///
/// Statement order matters: a base class is bound before classes derived from
/// it (e.g. PyDenseElementsAttribute before PyDenseFPElementsAttribute, which
/// derives from it). Presumably this is required because nanobind needs the
/// base registered first; confirm before reordering.
void mlir::python::populateIRAttributes(nb::module_ &m) {
  PyAffineMapAttribute::bind(m);
  // Each DenseArrayAttr element-type variant is bound together with its
  // nested iterator class.
  PyDenseBoolArrayAttribute::bind(m);
  PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseI8ArrayAttribute::bind(m);
  PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseI16ArrayAttribute::bind(m);
  PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseI32ArrayAttribute::bind(m);
  PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseI64ArrayAttribute::bind(m);
  PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseF32ArrayAttribute::bind(m);
  PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseF64ArrayAttribute::bind(m);
  PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
  // Attributes with the DenseArrayAttr TypeID are dispatched to the matching
  // element-type class by denseArrayAttributeCaster.
  PyGlobals::get().registerTypeCaster(
      mlirDenseArrayAttrGetTypeID(),
      nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster)));

  PyArrayAttribute::bind(m);
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  // The base dense-elements class is bound (with extra slots defined
  // elsewhere) before its subclasses.
  PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  // DenseIntOrFPElementsAttr TypeID -> FP or Int subclass.
  PyGlobals::get().registerTypeCaster(
      mlirDenseIntOrFPElementsAttrGetTypeID(),
      nb::cast<nb::callable>(
          nb::cpp_function(denseIntOrFPElementsAttributeCaster)));
  PyDenseResourceElementsAttribute::bind(m);

  PyDictAttribute::bind(m);
  PySymbolRefAttribute::bind(m);
  // SymbolRefAttr TypeID -> FlatSymbolRefAttr or SymbolRefAttr.
  PyGlobals::get().registerTypeCaster(
      mlirSymbolRefAttrGetTypeID(),
      nb::cast<nb::callable>(
          nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)));

  PyFlatSymbolRefAttribute::bind(m);
  PyOpaqueAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyIntegerSetAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  // IntegerAttr TypeID -> BoolAttr or IntegerAttr.
  PyGlobals::get().registerTypeCaster(
      mlirIntegerAttrGetTypeID(),
      nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster)));
  PyUnitAttribute::bind(m);

  PyStridedLayoutAttribute::bind(m);
}
1812