xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision 8dcb67225b2ce871b54f7a0f172b58f15f05f7fa)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9a1fe1f5fSKazu Hirata #include <optional>
1071a25454SPeter Hawkins #include <string_view>
114811270bSmax #include <utility>
121fc096afSMehdi Amini 
13436c6c9cSStella Laurenzo #include "IRModule.h"
14436c6c9cSStella Laurenzo 
15436c6c9cSStella Laurenzo #include "PybindUtils.h"
16436c6c9cSStella Laurenzo 
1771a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h"
1871a25454SPeter Hawkins 
19436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
20436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
21bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h"
22436c6c9cSStella Laurenzo 
23436c6c9cSStella Laurenzo namespace py = pybind11;
24436c6c9cSStella Laurenzo using namespace mlir;
25436c6c9cSStella Laurenzo using namespace mlir::python;
26436c6c9cSStella Laurenzo 
27436c6c9cSStella Laurenzo using llvm::SmallVector;
28436c6c9cSStella Laurenzo 
295d6d30edSStella Laurenzo //------------------------------------------------------------------------------
305d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
315d6d30edSStella Laurenzo //------------------------------------------------------------------------------
325d6d30edSStella Laurenzo 
// Python-visible docstring describing how a DenseElementsAttr is built from a
// Python buffer/array (type inference, explicit `type=`, i1 bit packing,
// splats, and the accepted arguments). The text is a runtime string shown to
// Python users; presumably it is attached to the corresponding `get` binding
// later in this file — verify before editing its wording.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
745d6d30edSStella Laurenzo 
75436c6c9cSStella Laurenzo namespace {
76436c6c9cSStella Laurenzo 
77436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
78436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
79436c6c9cSStella Laurenzo }
80436c6c9cSStella Laurenzo 
81436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
82436c6c9cSStella Laurenzo public:
83436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
84436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
85436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
869566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
879566ee28Smax       mlirAffineMapAttrGetTypeID;
88436c6c9cSStella Laurenzo 
89436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
90436c6c9cSStella Laurenzo     c.def_static(
91436c6c9cSStella Laurenzo         "get",
92436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
93436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
94436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
95436c6c9cSStella Laurenzo         },
96436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
97436c6c9cSStella Laurenzo   }
98436c6c9cSStella Laurenzo };
99436c6c9cSStella Laurenzo 
100ed9e52f3SAlex Zinenko template <typename T>
101ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
102ed9e52f3SAlex Zinenko   try {
103ed9e52f3SAlex Zinenko     return object.cast<T>();
104ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
105ed9e52f3SAlex Zinenko     std::string msg =
106ed9e52f3SAlex Zinenko         std::string(
107ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
108ed9e52f3SAlex Zinenko         err.what() + ")";
109ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
110ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
111ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
112ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
113ed9e52f3SAlex Zinenko                       err.what() + ")";
114ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
115ed9e52f3SAlex Zinenko   }
116ed9e52f3SAlex Zinenko }
117ed9e52f3SAlex Zinenko 
118619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
119619fd8c2SJeff Niu /// implementation class.
120619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
121133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
122619fd8c2SJeff Niu public:
123133624acSJeff Niu   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
124619fd8c2SJeff Niu 
125619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
126619fd8c2SJeff Niu   class PyDenseArrayIterator {
127619fd8c2SJeff Niu   public:
1284a1b1196SMehdi Amini     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
129619fd8c2SJeff Niu 
130619fd8c2SJeff Niu     /// Return a copy of the iterator.
131619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
132619fd8c2SJeff Niu 
133619fd8c2SJeff Niu     /// Return the next element.
134619fd8c2SJeff Niu     EltTy dunderNext() {
135619fd8c2SJeff Niu       // Throw if the index has reached the end.
136619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
137619fd8c2SJeff Niu         throw py::stop_iteration();
138619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
139619fd8c2SJeff Niu     }
140619fd8c2SJeff Niu 
141619fd8c2SJeff Niu     /// Bind the iterator class.
142619fd8c2SJeff Niu     static void bind(py::module &m) {
143619fd8c2SJeff Niu       py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
144619fd8c2SJeff Niu                                        py::module_local())
145619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
146619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
147619fd8c2SJeff Niu     }
148619fd8c2SJeff Niu 
149619fd8c2SJeff Niu   private:
150619fd8c2SJeff Niu     /// The referenced dense array attribute.
151619fd8c2SJeff Niu     PyAttribute attr;
152619fd8c2SJeff Niu     /// The next index to read.
153619fd8c2SJeff Niu     int nextIndex = 0;
154619fd8c2SJeff Niu   };
155619fd8c2SJeff Niu 
156619fd8c2SJeff Niu   /// Get the element at the given index.
157619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
158619fd8c2SJeff Niu 
159619fd8c2SJeff Niu   /// Bind the attribute class.
160133624acSJeff Niu   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
161619fd8c2SJeff Niu     // Bind the constructor.
162619fd8c2SJeff Niu     c.def_static(
163619fd8c2SJeff Niu         "get",
164619fd8c2SJeff Niu         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
165*8dcb6722SIngo Müller           return getAttribute(values, ctx->getRef());
166619fd8c2SJeff Niu         },
167619fd8c2SJeff Niu         py::arg("values"), py::arg("context") = py::none(),
168619fd8c2SJeff Niu         "Gets a uniqued dense array attribute");
169619fd8c2SJeff Niu     // Bind the array methods.
170133624acSJeff Niu     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
171619fd8c2SJeff Niu       if (i >= mlirDenseArrayGetNumElements(arr))
172619fd8c2SJeff Niu         throw py::index_error("DenseArray index out of range");
173619fd8c2SJeff Niu       return arr.getItem(i);
174619fd8c2SJeff Niu     });
175133624acSJeff Niu     c.def("__len__", [](const DerivedT &arr) {
176619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
177619fd8c2SJeff Niu     });
178133624acSJeff Niu     c.def("__iter__",
179133624acSJeff Niu           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
1804a1b1196SMehdi Amini     c.def("__add__", [](DerivedT &arr, const py::list &extras) {
181619fd8c2SJeff Niu       std::vector<EltTy> values;
182619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
183619fd8c2SJeff Niu       values.reserve(numOldElements + py::len(extras));
184619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
185619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
186619fd8c2SJeff Niu       for (py::handle attr : extras)
187619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
188*8dcb6722SIngo Müller       return getAttribute(values, arr.getContext());
189619fd8c2SJeff Niu     });
190619fd8c2SJeff Niu   }
191*8dcb6722SIngo Müller 
192*8dcb6722SIngo Müller private:
193*8dcb6722SIngo Müller   static DerivedT getAttribute(const std::vector<EltTy> &values,
194*8dcb6722SIngo Müller                                PyMlirContextRef ctx) {
195*8dcb6722SIngo Müller     if constexpr (std::is_same_v<EltTy, bool>) {
196*8dcb6722SIngo Müller       std::vector<int> intValues(values.begin(), values.end());
197*8dcb6722SIngo Müller       MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
198*8dcb6722SIngo Müller                                                   intValues.data());
199*8dcb6722SIngo Müller       return DerivedT(ctx, attr);
200*8dcb6722SIngo Müller     } else {
201*8dcb6722SIngo Müller       MlirAttribute attr =
202*8dcb6722SIngo Müller           DerivedT::getAttribute(ctx->get(), values.size(), values.data());
203*8dcb6722SIngo Müller       return DerivedT(ctx, attr);
204*8dcb6722SIngo Müller     }
205*8dcb6722SIngo Müller   }
206619fd8c2SJeff Niu };
207619fd8c2SJeff Niu 
208619fd8c2SJeff Niu /// Instantiate the python dense array classes.
209619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
210*8dcb6722SIngo Müller     : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
211619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
212619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
213619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
214619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
215619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
216619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
217619fd8c2SJeff Niu };
218619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
219619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
220619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
221619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
222619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
223619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
224619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
225619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
226619fd8c2SJeff Niu };
227619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
228619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
229619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
230619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
231619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
232619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
233619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
234619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
235619fd8c2SJeff Niu };
236619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
237619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
238619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
239619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
240619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
241619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
242619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
243619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
244619fd8c2SJeff Niu };
245619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
246619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
247619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
248619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
249619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
250619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
251619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
252619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
253619fd8c2SJeff Niu };
254619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
255619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
256619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
257619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
258619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
259619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
260619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
261619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
262619fd8c2SJeff Niu };
263619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
264619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
265619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
266619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
267619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
268619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
269619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
270619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
271619fd8c2SJeff Niu };
272619fd8c2SJeff Niu 
273436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
274436c6c9cSStella Laurenzo public:
275436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
276436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
277436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
2789566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
2799566ee28Smax       mlirArrayAttrGetTypeID;
280436c6c9cSStella Laurenzo 
281436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
282436c6c9cSStella Laurenzo   public:
2831fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
284436c6c9cSStella Laurenzo 
285436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
286436c6c9cSStella Laurenzo 
287974c1596SRahul Kayaith     MlirAttribute dunderNext() {
288bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
289bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
290436c6c9cSStella Laurenzo         throw py::stop_iteration();
291974c1596SRahul Kayaith       return mlirArrayAttrGetElement(attr.get(), nextIndex++);
292436c6c9cSStella Laurenzo     }
293436c6c9cSStella Laurenzo 
294436c6c9cSStella Laurenzo     static void bind(py::module &m) {
295f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
296f05ff4f7SStella Laurenzo                                            py::module_local())
297436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
298436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
299436c6c9cSStella Laurenzo     }
300436c6c9cSStella Laurenzo 
301436c6c9cSStella Laurenzo   private:
302436c6c9cSStella Laurenzo     PyAttribute attr;
303436c6c9cSStella Laurenzo     int nextIndex = 0;
304436c6c9cSStella Laurenzo   };
305436c6c9cSStella Laurenzo 
306974c1596SRahul Kayaith   MlirAttribute getItem(intptr_t i) {
307974c1596SRahul Kayaith     return mlirArrayAttrGetElement(*this, i);
308ed9e52f3SAlex Zinenko   }
309ed9e52f3SAlex Zinenko 
310436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
311436c6c9cSStella Laurenzo     c.def_static(
312436c6c9cSStella Laurenzo         "get",
313436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
314436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
315436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
316436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
317ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
318436c6c9cSStella Laurenzo           }
319436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
320436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
321436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
322436c6c9cSStella Laurenzo         },
323436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
324436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
325436c6c9cSStella Laurenzo     c.def("__getitem__",
326436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
327436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
328436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
329ed9e52f3SAlex Zinenko             return arr.getItem(i);
330436c6c9cSStella Laurenzo           })
331436c6c9cSStella Laurenzo         .def("__len__",
332436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
333436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
334436c6c9cSStella Laurenzo              })
335436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
336436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
337436c6c9cSStella Laurenzo         });
338ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
339ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
340ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
341ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
342ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
343ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
344ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
345ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
346ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
347ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
348ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
349ed9e52f3SAlex Zinenko     });
350436c6c9cSStella Laurenzo   }
351436c6c9cSStella Laurenzo };
352436c6c9cSStella Laurenzo 
353436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
354436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
355436c6c9cSStella Laurenzo public:
356436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
357436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
358436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
3599566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
3609566ee28Smax       mlirFloatAttrGetTypeID;
361436c6c9cSStella Laurenzo 
362436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
363436c6c9cSStella Laurenzo     c.def_static(
364436c6c9cSStella Laurenzo         "get",
365436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
3663ea4c501SRahul Kayaith           PyMlirContext::ErrorCapture errors(loc->getContext());
367436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
3683ea4c501SRahul Kayaith           if (mlirAttributeIsNull(attr))
3693ea4c501SRahul Kayaith             throw MLIRError("Invalid attribute", errors.take());
370436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
371436c6c9cSStella Laurenzo         },
372436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
373436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
374436c6c9cSStella Laurenzo     c.def_static(
375436c6c9cSStella Laurenzo         "get_f32",
376436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
377436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
378436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
379436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
380436c6c9cSStella Laurenzo         },
381436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
382436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
383436c6c9cSStella Laurenzo     c.def_static(
384436c6c9cSStella Laurenzo         "get_f64",
385436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
386436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
387436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
388436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
389436c6c9cSStella Laurenzo         },
390436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
391436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
392436c6c9cSStella Laurenzo     c.def_property_readonly(
393436c6c9cSStella Laurenzo         "value",
394436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
395436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
396436c6c9cSStella Laurenzo         },
397436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
398436c6c9cSStella Laurenzo   }
399436c6c9cSStella Laurenzo };
400436c6c9cSStella Laurenzo 
401436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
402436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
403436c6c9cSStella Laurenzo public:
404436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
405436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
406436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
407436c6c9cSStella Laurenzo 
408436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
409436c6c9cSStella Laurenzo     c.def_static(
410436c6c9cSStella Laurenzo         "get",
411436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
412436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
413436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
414436c6c9cSStella Laurenzo         },
415436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
416436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
417436c6c9cSStella Laurenzo     c.def_property_readonly(
418436c6c9cSStella Laurenzo         "value",
419e9db306dSrkayaith         [](PyIntegerAttribute &self) -> py::int_ {
420e9db306dSrkayaith           MlirType type = mlirAttributeGetType(self);
421e9db306dSrkayaith           if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
422436c6c9cSStella Laurenzo             return mlirIntegerAttrGetValueInt(self);
423e9db306dSrkayaith           if (mlirIntegerTypeIsSigned(type))
424e9db306dSrkayaith             return mlirIntegerAttrGetValueSInt(self);
425e9db306dSrkayaith           return mlirIntegerAttrGetValueUInt(self);
426436c6c9cSStella Laurenzo         },
427436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
4289566ee28Smax     c.def_property_readonly_static("static_typeid",
4299566ee28Smax                                    [](py::object & /*class*/) -> MlirTypeID {
4309566ee28Smax                                      return mlirIntegerAttrGetTypeID();
4319566ee28Smax                                    });
432436c6c9cSStella Laurenzo   }
433436c6c9cSStella Laurenzo };
434436c6c9cSStella Laurenzo 
435436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
436436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
437436c6c9cSStella Laurenzo public:
438436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
439436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
440436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
441436c6c9cSStella Laurenzo 
442436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
443436c6c9cSStella Laurenzo     c.def_static(
444436c6c9cSStella Laurenzo         "get",
445436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
446436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
447436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
448436c6c9cSStella Laurenzo         },
449436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
450436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
451436c6c9cSStella Laurenzo     c.def_property_readonly(
452436c6c9cSStella Laurenzo         "value",
453436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
454436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
455436c6c9cSStella Laurenzo   }
456436c6c9cSStella Laurenzo };
457436c6c9cSStella Laurenzo 
4584eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
4594eee9ef9Smax public:
4604eee9ef9Smax   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
4614eee9ef9Smax   static constexpr const char *pyClassName = "SymbolRefAttr";
4624eee9ef9Smax   using PyConcreteAttribute::PyConcreteAttribute;
4634eee9ef9Smax 
4644eee9ef9Smax   static MlirAttribute fromList(const std::vector<std::string> &symbols,
4654eee9ef9Smax                                 PyMlirContext &context) {
4664eee9ef9Smax     if (symbols.empty())
4674eee9ef9Smax       throw std::runtime_error("SymbolRefAttr must be composed of at least "
4684eee9ef9Smax                                "one symbol.");
4694eee9ef9Smax     MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
4704eee9ef9Smax     SmallVector<MlirAttribute, 3> referenceAttrs;
4714eee9ef9Smax     for (size_t i = 1; i < symbols.size(); ++i) {
4724eee9ef9Smax       referenceAttrs.push_back(
4734eee9ef9Smax           mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
4744eee9ef9Smax     }
4754eee9ef9Smax     return mlirSymbolRefAttrGet(context.get(), rootSymbol,
4764eee9ef9Smax                                 referenceAttrs.size(), referenceAttrs.data());
4774eee9ef9Smax   }
4784eee9ef9Smax 
4794eee9ef9Smax   static void bindDerived(ClassTy &c) {
4804eee9ef9Smax     c.def_static(
4814eee9ef9Smax         "get",
4824eee9ef9Smax         [](const std::vector<std::string> &symbols,
4834eee9ef9Smax            DefaultingPyMlirContext context) {
4844eee9ef9Smax           return PySymbolRefAttribute::fromList(symbols, context.resolve());
4854eee9ef9Smax         },
4864eee9ef9Smax         py::arg("symbols"), py::arg("context") = py::none(),
4874eee9ef9Smax         "Gets a uniqued SymbolRef attribute from a list of symbol names");
4884eee9ef9Smax     c.def_property_readonly(
4894eee9ef9Smax         "value",
4904eee9ef9Smax         [](PySymbolRefAttribute &self) {
4914eee9ef9Smax           std::vector<std::string> symbols = {
4924eee9ef9Smax               unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
4934eee9ef9Smax           for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
4944eee9ef9Smax                ++i)
4954eee9ef9Smax             symbols.push_back(
4964eee9ef9Smax                 unwrap(mlirSymbolRefAttrGetRootReference(
4974eee9ef9Smax                            mlirSymbolRefAttrGetNestedReference(self, i)))
4984eee9ef9Smax                     .str());
4994eee9ef9Smax           return symbols;
5004eee9ef9Smax         },
5014eee9ef9Smax         "Returns the value of the SymbolRef attribute as a list[str]");
5024eee9ef9Smax   }
5034eee9ef9Smax };
5044eee9ef9Smax 
505436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
506436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
507436c6c9cSStella Laurenzo public:
508436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
509436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
510436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
511436c6c9cSStella Laurenzo 
512436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
513436c6c9cSStella Laurenzo     c.def_static(
514436c6c9cSStella Laurenzo         "get",
515436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
516436c6c9cSStella Laurenzo           MlirAttribute attr =
517436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
518436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
519436c6c9cSStella Laurenzo         },
520436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
521436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
522436c6c9cSStella Laurenzo     c.def_property_readonly(
523436c6c9cSStella Laurenzo         "value",
524436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
525436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
526436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
527436c6c9cSStella Laurenzo         },
528436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
529436c6c9cSStella Laurenzo   }
530436c6c9cSStella Laurenzo };
531436c6c9cSStella Laurenzo 
5325c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
5335c3861b2SYun Long public:
5345c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
5355c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
5365c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
5379566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
5389566ee28Smax       mlirOpaqueAttrGetTypeID;
5395c3861b2SYun Long 
5405c3861b2SYun Long   static void bindDerived(ClassTy &c) {
5415c3861b2SYun Long     c.def_static(
5425c3861b2SYun Long         "get",
5435c3861b2SYun Long         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
5445c3861b2SYun Long            DefaultingPyMlirContext context) {
5455c3861b2SYun Long           const py::buffer_info bufferInfo = buffer.request();
5465c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
5475c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
5485c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
5495c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
5505c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
5515c3861b2SYun Long         },
5525c3861b2SYun Long         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
5535c3861b2SYun Long         py::arg("context") = py::none(), "Gets an Opaque attribute.");
5545c3861b2SYun Long     c.def_property_readonly(
5555c3861b2SYun Long         "dialect_namespace",
5565c3861b2SYun Long         [](PyOpaqueAttribute &self) {
5575c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
5585c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
5595c3861b2SYun Long         },
5605c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
5615c3861b2SYun Long     c.def_property_readonly(
5625c3861b2SYun Long         "data",
5635c3861b2SYun Long         [](PyOpaqueAttribute &self) {
5645c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
56562bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
5665c3861b2SYun Long         },
56762bf6c2eSChris Jones         "Returns the data for the Opaqued attributes as `bytes`");
5685c3861b2SYun Long   }
5695c3861b2SYun Long };
5705c3861b2SYun Long 
571436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
572436c6c9cSStella Laurenzo public:
573436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
574436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
575436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
5769566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
5779566ee28Smax       mlirStringAttrGetTypeID;
578436c6c9cSStella Laurenzo 
579436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
580436c6c9cSStella Laurenzo     c.def_static(
581436c6c9cSStella Laurenzo         "get",
582436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
583436c6c9cSStella Laurenzo           MlirAttribute attr =
584436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
585436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
586436c6c9cSStella Laurenzo         },
587436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
588436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
589436c6c9cSStella Laurenzo     c.def_static(
590436c6c9cSStella Laurenzo         "get_typed",
591436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
592436c6c9cSStella Laurenzo           MlirAttribute attr =
593436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
594436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
595436c6c9cSStella Laurenzo         },
596a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
597436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
598436c6c9cSStella Laurenzo     c.def_property_readonly(
599436c6c9cSStella Laurenzo         "value",
600436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
601436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
602436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
603436c6c9cSStella Laurenzo         },
604436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
60562bf6c2eSChris Jones     c.def_property_readonly(
60662bf6c2eSChris Jones         "value_bytes",
60762bf6c2eSChris Jones         [](PyStringAttribute &self) {
60862bf6c2eSChris Jones           MlirStringRef stringRef = mlirStringAttrGetValue(self);
60962bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
61062bf6c2eSChris Jones         },
61162bf6c2eSChris Jones         "Returns the value of the string attribute as `bytes`");
612436c6c9cSStella Laurenzo   }
613436c6c9cSStella Laurenzo };
614436c6c9cSStella Laurenzo 
615436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
616436c6c9cSStella Laurenzo class PyDenseElementsAttribute
617436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
618436c6c9cSStella Laurenzo public:
619436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
620436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
621436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
622436c6c9cSStella Laurenzo 
623436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
6240a81ace0SKazu Hirata   getFromBuffer(py::buffer array, bool signless,
6250a81ace0SKazu Hirata                 std::optional<PyType> explicitType,
6260a81ace0SKazu Hirata                 std::optional<std::vector<int64_t>> explicitShape,
627436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
628436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
62971a25454SPeter Hawkins     int flags = PyBUF_ND;
63071a25454SPeter Hawkins     if (!explicitType) {
63171a25454SPeter Hawkins       flags |= PyBUF_FORMAT;
63271a25454SPeter Hawkins     }
63371a25454SPeter Hawkins     Py_buffer view;
63471a25454SPeter Hawkins     if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
635436c6c9cSStella Laurenzo       throw py::error_already_set();
636436c6c9cSStella Laurenzo     }
63771a25454SPeter Hawkins     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
6385d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
6395d6d30edSStella Laurenzo     if (explicitShape) {
6405d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
6415d6d30edSStella Laurenzo     } else {
64271a25454SPeter Hawkins       shape.append(view.shape, view.shape + view.ndim);
6435d6d30edSStella Laurenzo     }
644436c6c9cSStella Laurenzo 
6455d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
646436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
6475d6d30edSStella Laurenzo 
6485d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
6495d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
6505d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
6515d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
6525d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
6530a81ace0SKazu Hirata     std::optional<MlirType> bulkLoadElementType;
6545d6d30edSStella Laurenzo     if (explicitType) {
6555d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
65671a25454SPeter Hawkins     } else {
65771a25454SPeter Hawkins       std::string_view format(view.format);
65871a25454SPeter Hawkins       if (format == "f") {
659436c6c9cSStella Laurenzo         // f32
66071a25454SPeter Hawkins         assert(view.itemsize == 4 && "mismatched array itemsize");
6615d6d30edSStella Laurenzo         bulkLoadElementType = mlirF32TypeGet(context);
66271a25454SPeter Hawkins       } else if (format == "d") {
663436c6c9cSStella Laurenzo         // f64
66471a25454SPeter Hawkins         assert(view.itemsize == 8 && "mismatched array itemsize");
6655d6d30edSStella Laurenzo         bulkLoadElementType = mlirF64TypeGet(context);
66671a25454SPeter Hawkins       } else if (format == "e") {
6675d6d30edSStella Laurenzo         // f16
66871a25454SPeter Hawkins         assert(view.itemsize == 2 && "mismatched array itemsize");
6695d6d30edSStella Laurenzo         bulkLoadElementType = mlirF16TypeGet(context);
67071a25454SPeter Hawkins       } else if (isSignedIntegerFormat(format)) {
67171a25454SPeter Hawkins         if (view.itemsize == 4) {
672436c6c9cSStella Laurenzo           // i32
67371a25454SPeter Hawkins           bulkLoadElementType = signless
67471a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 32)
675436c6c9cSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 32);
67671a25454SPeter Hawkins         } else if (view.itemsize == 8) {
677436c6c9cSStella Laurenzo           // i64
67871a25454SPeter Hawkins           bulkLoadElementType = signless
67971a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 64)
680436c6c9cSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 64);
68171a25454SPeter Hawkins         } else if (view.itemsize == 1) {
6825d6d30edSStella Laurenzo           // i8
6835d6d30edSStella Laurenzo           bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
6845d6d30edSStella Laurenzo                                          : mlirIntegerTypeSignedGet(context, 8);
68571a25454SPeter Hawkins         } else if (view.itemsize == 2) {
6865d6d30edSStella Laurenzo           // i16
68771a25454SPeter Hawkins           bulkLoadElementType = signless
68871a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 16)
6895d6d30edSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 16);
690436c6c9cSStella Laurenzo         }
69171a25454SPeter Hawkins       } else if (isUnsignedIntegerFormat(format)) {
69271a25454SPeter Hawkins         if (view.itemsize == 4) {
693436c6c9cSStella Laurenzo           // unsigned i32
6945d6d30edSStella Laurenzo           bulkLoadElementType = signless
695436c6c9cSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 32)
696436c6c9cSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 32);
69771a25454SPeter Hawkins         } else if (view.itemsize == 8) {
698436c6c9cSStella Laurenzo           // unsigned i64
6995d6d30edSStella Laurenzo           bulkLoadElementType = signless
700436c6c9cSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 64)
701436c6c9cSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 64);
70271a25454SPeter Hawkins         } else if (view.itemsize == 1) {
7035d6d30edSStella Laurenzo           // i8
70471a25454SPeter Hawkins           bulkLoadElementType = signless
70571a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 8)
7065d6d30edSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 8);
70771a25454SPeter Hawkins         } else if (view.itemsize == 2) {
7085d6d30edSStella Laurenzo           // i16
7095d6d30edSStella Laurenzo           bulkLoadElementType = signless
7105d6d30edSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 16)
7115d6d30edSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 16);
712436c6c9cSStella Laurenzo         }
713436c6c9cSStella Laurenzo       }
71471a25454SPeter Hawkins       if (!bulkLoadElementType) {
71571a25454SPeter Hawkins         throw std::invalid_argument(
71671a25454SPeter Hawkins             std::string("unimplemented array format conversion from format: ") +
71771a25454SPeter Hawkins             std::string(format));
71871a25454SPeter Hawkins       }
71971a25454SPeter Hawkins     }
72071a25454SPeter Hawkins 
72199dee31eSAdam Paszke     MlirType shapedType;
72299dee31eSAdam Paszke     if (mlirTypeIsAShaped(*bulkLoadElementType)) {
72399dee31eSAdam Paszke       if (explicitShape) {
72499dee31eSAdam Paszke         throw std::invalid_argument("Shape can only be specified explicitly "
72599dee31eSAdam Paszke                                     "when the type is not a shaped type.");
72699dee31eSAdam Paszke       }
72799dee31eSAdam Paszke       shapedType = *bulkLoadElementType;
72899dee31eSAdam Paszke     } else {
72971a25454SPeter Hawkins       shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
73071a25454SPeter Hawkins                                            *bulkLoadElementType, encodingAttr);
73199dee31eSAdam Paszke     }
73271a25454SPeter Hawkins     size_t rawBufferSize = view.len;
73371a25454SPeter Hawkins     MlirAttribute attr =
73471a25454SPeter Hawkins         mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
7355d6d30edSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
7365d6d30edSStella Laurenzo       throw std::invalid_argument(
7375d6d30edSStella Laurenzo           "DenseElementsAttr could not be constructed from the given buffer. "
7385d6d30edSStella Laurenzo           "This may mean that the Python buffer layout does not match that "
7395d6d30edSStella Laurenzo           "MLIR expected layout and is a bug.");
7405d6d30edSStella Laurenzo     }
7415d6d30edSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
7425d6d30edSStella Laurenzo   }
743436c6c9cSStella Laurenzo 
7441fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
745436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
746436c6c9cSStella Laurenzo     auto contextWrapper =
747436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
748436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
749436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
750436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
751436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
7524811270bSmax       throw py::value_error(message);
753436c6c9cSStella Laurenzo     }
754436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
755436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
756436c6c9cSStella Laurenzo       std::string message =
757436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
758436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
7594811270bSmax       throw py::value_error(message);
760436c6c9cSStella Laurenzo     }
761436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
762436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
763436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
764436c6c9cSStella Laurenzo       std::string message =
765436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
766436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
767436c6c9cSStella Laurenzo       message.append(", element=");
768436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
7694811270bSmax       throw py::value_error(message);
770436c6c9cSStella Laurenzo     }
771436c6c9cSStella Laurenzo 
772436c6c9cSStella Laurenzo     MlirAttribute elements =
773436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
774436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
775436c6c9cSStella Laurenzo   }
776436c6c9cSStella Laurenzo 
777436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
778436c6c9cSStella Laurenzo 
779436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
780436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
781436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
7825d6d30edSStella Laurenzo     std::string format;
783436c6c9cSStella Laurenzo 
784436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
785436c6c9cSStella Laurenzo       // f32
7865d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
78702b6fb21SMehdi Amini     }
78802b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
789436c6c9cSStella Laurenzo       // f64
7905d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
791bb56c2b3SMehdi Amini     }
792bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
7935d6d30edSStella Laurenzo       // f16
7945d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
795bb56c2b3SMehdi Amini     }
796ef1b735dSmax     if (mlirTypeIsAIndex(elementType)) {
797ef1b735dSmax       // Same as IndexType::kInternalStorageBitWidth
798ef1b735dSmax       return bufferInfo<int64_t>(shapedType);
799ef1b735dSmax     }
800bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
801436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
802436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
803436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
804436c6c9cSStella Laurenzo         // i32
8055d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
806e5639b3fSMehdi Amini       }
807e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
808436c6c9cSStella Laurenzo         // unsigned i32
8095d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
810436c6c9cSStella Laurenzo       }
811436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
812436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
813436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
814436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
815436c6c9cSStella Laurenzo         // i64
8165d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
817e5639b3fSMehdi Amini       }
818e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
819436c6c9cSStella Laurenzo         // unsigned i64
8205d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
8215d6d30edSStella Laurenzo       }
8225d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
8235d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
8245d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
8255d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
8265d6d30edSStella Laurenzo         // i8
8275d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
828e5639b3fSMehdi Amini       }
829e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
8305d6d30edSStella Laurenzo         // unsigned i8
8315d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
8325d6d30edSStella Laurenzo       }
8335d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
8345d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
8355d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
8365d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
8375d6d30edSStella Laurenzo         // i16
8385d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
839e5639b3fSMehdi Amini       }
840e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
8415d6d30edSStella Laurenzo         // unsigned i16
8425d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
843436c6c9cSStella Laurenzo       }
844436c6c9cSStella Laurenzo     }
845436c6c9cSStella Laurenzo 
846c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
8475d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
848c5f445d1SStella Laurenzo     throw std::invalid_argument(
849c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
850436c6c9cSStella Laurenzo   }
851436c6c9cSStella Laurenzo 
852436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
853436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
854436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
855436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
8565d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
857436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
8585d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
859436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
860436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
861436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
862436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
863436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
864436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
865436c6c9cSStella Laurenzo                                })
86691259963SAdam Paszke         .def("get_splat_value",
867974c1596SRahul Kayaith              [](PyDenseElementsAttribute &self) {
868974c1596SRahul Kayaith                if (!mlirDenseElementsAttrIsSplat(self))
8694811270bSmax                  throw py::value_error(
87091259963SAdam Paszke                      "get_splat_value called on a non-splat attribute");
871974c1596SRahul Kayaith                return mlirDenseElementsAttrGetSplatValue(self);
87291259963SAdam Paszke              })
873436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
874436c6c9cSStella Laurenzo   }
875436c6c9cSStella Laurenzo 
876436c6c9cSStella Laurenzo private:
87771a25454SPeter Hawkins   static bool isUnsignedIntegerFormat(std::string_view format) {
878436c6c9cSStella Laurenzo     if (format.empty())
879436c6c9cSStella Laurenzo       return false;
880436c6c9cSStella Laurenzo     char code = format[0];
881436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
882436c6c9cSStella Laurenzo            code == 'Q';
883436c6c9cSStella Laurenzo   }
884436c6c9cSStella Laurenzo 
88571a25454SPeter Hawkins   static bool isSignedIntegerFormat(std::string_view format) {
886436c6c9cSStella Laurenzo     if (format.empty())
887436c6c9cSStella Laurenzo       return false;
888436c6c9cSStella Laurenzo     char code = format[0];
889436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
890436c6c9cSStella Laurenzo            code == 'q';
891436c6c9cSStella Laurenzo   }
892436c6c9cSStella Laurenzo 
893436c6c9cSStella Laurenzo   template <typename Type>
894436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
8955d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
896436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
897436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
898436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
899436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
900436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
901436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
902436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
903436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
904436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
905436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
906436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
907f0e847d0SRahul Kayaith     if (mlirDenseElementsAttrIsSplat(*this)) {
908f0e847d0SRahul Kayaith       // Splats are special, only the single value is stored.
909f0e847d0SRahul Kayaith       strides.assign(rank, 0);
910f0e847d0SRahul Kayaith     } else {
911436c6c9cSStella Laurenzo       for (intptr_t i = 1; i < rank; ++i) {
912f0e847d0SRahul Kayaith         intptr_t strideFactor = 1;
913f0e847d0SRahul Kayaith         for (intptr_t j = i; j < rank; ++j)
914436c6c9cSStella Laurenzo           strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
915436c6c9cSStella Laurenzo         strides.push_back(sizeof(Type) * strideFactor);
916436c6c9cSStella Laurenzo       }
917436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type));
918f0e847d0SRahul Kayaith     }
9195d6d30edSStella Laurenzo     std::string format;
9205d6d30edSStella Laurenzo     if (explicitFormat) {
9215d6d30edSStella Laurenzo       format = explicitFormat;
9225d6d30edSStella Laurenzo     } else {
9235d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
9245d6d30edSStella Laurenzo     }
9255d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
9265d6d30edSStella Laurenzo                            /*readonly=*/true);
927436c6c9cSStella Laurenzo   }
928436c6c9cSStella Laurenzo }; // namespace
929436c6c9cSStella Laurenzo 
930436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
931436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
932436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
933436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
934436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
935436c6c9cSStella Laurenzo public:
936436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
937436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
938436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
939436c6c9cSStella Laurenzo 
940436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
941436c6c9cSStella Laurenzo   /// out of range.
942436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
943436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
9444811270bSmax       throw py::index_error("attempt to access out of bounds element");
945436c6c9cSStella Laurenzo     }
946436c6c9cSStella Laurenzo 
947436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
948436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
949436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
950436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
951436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
952436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
953436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
954436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
955436c6c9cSStella Laurenzo     // querying them on each element access.
956436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
957436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
958436c6c9cSStella Laurenzo     if (isUnsigned) {
959436c6c9cSStella Laurenzo       if (width == 1) {
960436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
961436c6c9cSStella Laurenzo       }
962308d8b8cSRahul Kayaith       if (width == 8) {
963308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
964308d8b8cSRahul Kayaith       }
965308d8b8cSRahul Kayaith       if (width == 16) {
966308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
967308d8b8cSRahul Kayaith       }
968436c6c9cSStella Laurenzo       if (width == 32) {
969436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
970436c6c9cSStella Laurenzo       }
971436c6c9cSStella Laurenzo       if (width == 64) {
972436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
973436c6c9cSStella Laurenzo       }
974436c6c9cSStella Laurenzo     } else {
975436c6c9cSStella Laurenzo       if (width == 1) {
976436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
977436c6c9cSStella Laurenzo       }
978308d8b8cSRahul Kayaith       if (width == 8) {
979308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
980308d8b8cSRahul Kayaith       }
981308d8b8cSRahul Kayaith       if (width == 16) {
982308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
983308d8b8cSRahul Kayaith       }
984436c6c9cSStella Laurenzo       if (width == 32) {
985436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
986436c6c9cSStella Laurenzo       }
987436c6c9cSStella Laurenzo       if (width == 64) {
988436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
989436c6c9cSStella Laurenzo       }
990436c6c9cSStella Laurenzo     }
9914811270bSmax     throw py::type_error("Unsupported integer type");
992436c6c9cSStella Laurenzo   }
993436c6c9cSStella Laurenzo 
994436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
995436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
996436c6c9cSStella Laurenzo   }
997436c6c9cSStella Laurenzo };
998436c6c9cSStella Laurenzo 
999436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1000436c6c9cSStella Laurenzo public:
1001436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1002436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
1003436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
10049566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
10059566ee28Smax       mlirDictionaryAttrGetTypeID;
1006436c6c9cSStella Laurenzo 
1007436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1008436c6c9cSStella Laurenzo 
10099fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
10109fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
10119fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
10129fb1086bSAdrian Kuegel   }
10139fb1086bSAdrian Kuegel 
1014436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
10159fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
1016436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
1017436c6c9cSStella Laurenzo     c.def_static(
1018436c6c9cSStella Laurenzo         "get",
1019436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
1020436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1021436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
1022436c6c9cSStella Laurenzo           for (auto &it : attributes) {
102302b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
1024436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
1025436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
102602b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1027436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
102802b6fb21SMehdi Amini                 mlirAttr));
1029436c6c9cSStella Laurenzo           }
1030436c6c9cSStella Laurenzo           MlirAttribute attr =
1031436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1032436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1033436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
1034436c6c9cSStella Laurenzo         },
1035ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
1036436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
1037436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1038436c6c9cSStella Laurenzo       MlirAttribute attr =
1039436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1040974c1596SRahul Kayaith       if (mlirAttributeIsNull(attr))
10414811270bSmax         throw py::key_error("attempt to access a non-existent attribute");
1042974c1596SRahul Kayaith       return attr;
1043436c6c9cSStella Laurenzo     });
1044436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1045436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
10464811270bSmax         throw py::index_error("attempt to access out of bounds attribute");
1047436c6c9cSStella Laurenzo       }
1048436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1049436c6c9cSStella Laurenzo       return PyNamedAttribute(
1050436c6c9cSStella Laurenzo           namedAttr.attribute,
1051436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
1052436c6c9cSStella Laurenzo     });
1053436c6c9cSStella Laurenzo   }
1054436c6c9cSStella Laurenzo };
1055436c6c9cSStella Laurenzo 
1056436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
1057436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
1058436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
1059436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1060436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1061436c6c9cSStella Laurenzo public:
1062436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1063436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
1064436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1065436c6c9cSStella Laurenzo 
1066436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
1067436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
10684811270bSmax       throw py::index_error("attempt to access out of bounds element");
1069436c6c9cSStella Laurenzo     }
1070436c6c9cSStella Laurenzo 
1071436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1072436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1073436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1074436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
1075436c6c9cSStella Laurenzo     // from float and double.
1076436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1077436c6c9cSStella Laurenzo     // querying them on each element access.
1078436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
1079436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
1080436c6c9cSStella Laurenzo     }
1081436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
1082436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1083436c6c9cSStella Laurenzo     }
10844811270bSmax     throw py::type_error("Unsupported floating-point type");
1085436c6c9cSStella Laurenzo   }
1086436c6c9cSStella Laurenzo 
1087436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1088436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1089436c6c9cSStella Laurenzo   }
1090436c6c9cSStella Laurenzo };
1091436c6c9cSStella Laurenzo 
1092436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1093436c6c9cSStella Laurenzo public:
1094436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1095436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
1096436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
10979566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
10989566ee28Smax       mlirTypeAttrGetTypeID;
1099436c6c9cSStella Laurenzo 
1100436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1101436c6c9cSStella Laurenzo     c.def_static(
1102436c6c9cSStella Laurenzo         "get",
1103436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1104436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1105436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1106436c6c9cSStella Laurenzo         },
1107436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
1108436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1109436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1110bfb1ba75Smax       return mlirTypeAttrGetValue(self.get());
1111436c6c9cSStella Laurenzo     });
1112436c6c9cSStella Laurenzo   }
1113436c6c9cSStella Laurenzo };
1114436c6c9cSStella Laurenzo 
1115436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1116436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1117436c6c9cSStella Laurenzo public:
1118436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1119436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1120436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
11219566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
11229566ee28Smax       mlirUnitAttrGetTypeID;
1123436c6c9cSStella Laurenzo 
1124436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1125436c6c9cSStella Laurenzo     c.def_static(
1126436c6c9cSStella Laurenzo         "get",
1127436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1128436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1129436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1130436c6c9cSStella Laurenzo         },
1131436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
1132436c6c9cSStella Laurenzo   }
1133436c6c9cSStella Laurenzo };
1134436c6c9cSStella Laurenzo 
1135ac2e2d65SDenys Shabalin /// Strided layout attribute subclass.
1136ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute
1137ac2e2d65SDenys Shabalin     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1138ac2e2d65SDenys Shabalin public:
1139ac2e2d65SDenys Shabalin   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1140ac2e2d65SDenys Shabalin   static constexpr const char *pyClassName = "StridedLayoutAttr";
1141ac2e2d65SDenys Shabalin   using PyConcreteAttribute::PyConcreteAttribute;
11429566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
11439566ee28Smax       mlirStridedLayoutAttrGetTypeID;
1144ac2e2d65SDenys Shabalin 
1145ac2e2d65SDenys Shabalin   static void bindDerived(ClassTy &c) {
1146ac2e2d65SDenys Shabalin     c.def_static(
1147ac2e2d65SDenys Shabalin         "get",
1148ac2e2d65SDenys Shabalin         [](int64_t offset, const std::vector<int64_t> strides,
1149ac2e2d65SDenys Shabalin            DefaultingPyMlirContext ctx) {
1150ac2e2d65SDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1151ac2e2d65SDenys Shabalin               ctx->get(), offset, strides.size(), strides.data());
1152ac2e2d65SDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1153ac2e2d65SDenys Shabalin         },
1154ac2e2d65SDenys Shabalin         py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1155ac2e2d65SDenys Shabalin         "Gets a strided layout attribute.");
1156e3fd612eSDenys Shabalin     c.def_static(
1157e3fd612eSDenys Shabalin         "get_fully_dynamic",
1158e3fd612eSDenys Shabalin         [](int64_t rank, DefaultingPyMlirContext ctx) {
1159e3fd612eSDenys Shabalin           auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1160e3fd612eSDenys Shabalin           std::vector<int64_t> strides(rank);
1161e3fd612eSDenys Shabalin           std::fill(strides.begin(), strides.end(), dynamic);
1162e3fd612eSDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1163e3fd612eSDenys Shabalin               ctx->get(), dynamic, strides.size(), strides.data());
1164e3fd612eSDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1165e3fd612eSDenys Shabalin         },
1166e3fd612eSDenys Shabalin         py::arg("rank"), py::arg("context") = py::none(),
1167e3fd612eSDenys Shabalin         "Gets a strided layout attribute with dynamic offset and strides of a "
1168e3fd612eSDenys Shabalin         "given rank.");
1169ac2e2d65SDenys Shabalin     c.def_property_readonly(
1170ac2e2d65SDenys Shabalin         "offset",
1171ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1172ac2e2d65SDenys Shabalin           return mlirStridedLayoutAttrGetOffset(self);
1173ac2e2d65SDenys Shabalin         },
1174ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1175ac2e2d65SDenys Shabalin     c.def_property_readonly(
1176ac2e2d65SDenys Shabalin         "strides",
1177ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1178ac2e2d65SDenys Shabalin           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1179ac2e2d65SDenys Shabalin           std::vector<int64_t> strides(size);
1180ac2e2d65SDenys Shabalin           for (intptr_t i = 0; i < size; i++) {
1181ac2e2d65SDenys Shabalin             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1182ac2e2d65SDenys Shabalin           }
1183ac2e2d65SDenys Shabalin           return strides;
1184ac2e2d65SDenys Shabalin         },
1185ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1186ac2e2d65SDenys Shabalin   }
1187ac2e2d65SDenys Shabalin };
1188ac2e2d65SDenys Shabalin 
11899566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
11909566ee28Smax   if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
11919566ee28Smax     return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
11929566ee28Smax   if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
11939566ee28Smax     return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
11949566ee28Smax   if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
11959566ee28Smax     return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
11969566ee28Smax   if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
11979566ee28Smax     return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
11989566ee28Smax   if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
11999566ee28Smax     return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
12009566ee28Smax   if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
12019566ee28Smax     return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
12029566ee28Smax   if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
12039566ee28Smax     return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
12049566ee28Smax   std::string msg =
12059566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
12069566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
12079566ee28Smax   throw py::cast_error(msg);
12089566ee28Smax }
12099566ee28Smax 
12109566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
12119566ee28Smax   if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
12129566ee28Smax     return py::cast(PyDenseFPElementsAttribute(pyAttribute));
12139566ee28Smax   if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
12149566ee28Smax     return py::cast(PyDenseIntElementsAttribute(pyAttribute));
12159566ee28Smax   std::string msg =
12169566ee28Smax       std::string(
12179566ee28Smax           "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
12189566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
12199566ee28Smax   throw py::cast_error(msg);
12209566ee28Smax }
12219566ee28Smax 
12229566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
12239566ee28Smax   if (PyBoolAttribute::isaFunction(pyAttribute))
12249566ee28Smax     return py::cast(PyBoolAttribute(pyAttribute));
12259566ee28Smax   if (PyIntegerAttribute::isaFunction(pyAttribute))
12269566ee28Smax     return py::cast(PyIntegerAttribute(pyAttribute));
12279566ee28Smax   std::string msg =
12289566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
12299566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
12309566ee28Smax   throw py::cast_error(msg);
12319566ee28Smax }
12329566ee28Smax 
12334eee9ef9Smax py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
12344eee9ef9Smax   if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
12354eee9ef9Smax     return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
12364eee9ef9Smax   if (PySymbolRefAttribute::isaFunction(pyAttribute))
12374eee9ef9Smax     return py::cast(PySymbolRefAttribute(pyAttribute));
12384eee9ef9Smax   std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
12394eee9ef9Smax                     std::string(py::repr(py::cast(pyAttribute))) + ")";
12404eee9ef9Smax   throw py::cast_error(msg);
12414eee9ef9Smax }
12424eee9ef9Smax 
1243436c6c9cSStella Laurenzo } // namespace
1244436c6c9cSStella Laurenzo 
1245436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
1246436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1247619fd8c2SJeff Niu 
1248619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1249619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1250619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1251619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1252619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1253619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1254619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1255619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1256619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1257619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1258619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1259619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1260619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1261619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
12629566ee28Smax   PyGlobals::get().registerTypeCaster(
12639566ee28Smax       mlirDenseArrayAttrGetTypeID(),
12649566ee28Smax       pybind11::cpp_function(denseArrayAttributeCaster));
1265619fd8c2SJeff Niu 
1266436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1267436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1268436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1269436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
1270436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1271436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
12729566ee28Smax   PyGlobals::get().registerTypeCaster(
12739566ee28Smax       mlirDenseIntOrFPElementsAttrGetTypeID(),
12749566ee28Smax       pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
12759566ee28Smax 
1276436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
12774eee9ef9Smax   PySymbolRefAttribute::bind(m);
12784eee9ef9Smax   PyGlobals::get().registerTypeCaster(
12794eee9ef9Smax       mlirSymbolRefAttrGetTypeID(),
12804eee9ef9Smax       pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
12814eee9ef9Smax 
1282436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
12835c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1284436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1285436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1286436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1287436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
12889566ee28Smax   PyGlobals::get().registerTypeCaster(
12899566ee28Smax       mlirIntegerAttrGetTypeID(),
12909566ee28Smax       pybind11::cpp_function(integerOrBoolAttributeCaster));
1291436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1292ac2e2d65SDenys Shabalin 
1293ac2e2d65SDenys Shabalin   PyStridedLayoutAttribute::bind(m);
1294436c6c9cSStella Laurenzo }
1295