xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision 4eee9ef9768b1335800878b8f0b7aa3e549e41dc)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9a1fe1f5fSKazu Hirata #include <optional>
104811270bSmax #include <utility>
111fc096afSMehdi Amini 
12436c6c9cSStella Laurenzo #include "IRModule.h"
13436c6c9cSStella Laurenzo 
14436c6c9cSStella Laurenzo #include "PybindUtils.h"
15436c6c9cSStella Laurenzo 
16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
17436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
18bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h"
19436c6c9cSStella Laurenzo 
20436c6c9cSStella Laurenzo namespace py = pybind11;
21436c6c9cSStella Laurenzo using namespace mlir;
22436c6c9cSStella Laurenzo using namespace mlir::python;
23436c6c9cSStella Laurenzo 
24436c6c9cSStella Laurenzo using llvm::SmallVector;
25436c6c9cSStella Laurenzo 
265d6d30edSStella Laurenzo //------------------------------------------------------------------------------
275d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
285d6d30edSStella Laurenzo //------------------------------------------------------------------------------
295d6d30edSStella Laurenzo 
// Python-facing docstring text for the DenseElementsAttr `get` factory.
// NOTE(review): the binding that consumes this constant is outside this part
// of the file (presumably the DenseElementsAttr class binding) — confirm.
// The raw string is runtime text shown to Python users, so any edit to it is
// a user-visible change; keep it in sync with the binding's actual behavior.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
715d6d30edSStella Laurenzo 
72436c6c9cSStella Laurenzo namespace {
73436c6c9cSStella Laurenzo 
74436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
75436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
76436c6c9cSStella Laurenzo }
77436c6c9cSStella Laurenzo 
78436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
79436c6c9cSStella Laurenzo public:
80436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
81436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
82436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
839566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
849566ee28Smax       mlirAffineMapAttrGetTypeID;
85436c6c9cSStella Laurenzo 
86436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
87436c6c9cSStella Laurenzo     c.def_static(
88436c6c9cSStella Laurenzo         "get",
89436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
90436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
91436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
92436c6c9cSStella Laurenzo         },
93436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
94436c6c9cSStella Laurenzo   }
95436c6c9cSStella Laurenzo };
96436c6c9cSStella Laurenzo 
97ed9e52f3SAlex Zinenko template <typename T>
98ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
99ed9e52f3SAlex Zinenko   try {
100ed9e52f3SAlex Zinenko     return object.cast<T>();
101ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
102ed9e52f3SAlex Zinenko     std::string msg =
103ed9e52f3SAlex Zinenko         std::string(
104ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
105ed9e52f3SAlex Zinenko         err.what() + ")";
106ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
107ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
108ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
109ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
110ed9e52f3SAlex Zinenko                       err.what() + ")";
111ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
112ed9e52f3SAlex Zinenko   }
113ed9e52f3SAlex Zinenko }
114ed9e52f3SAlex Zinenko 
115619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
116619fd8c2SJeff Niu /// implementation class.
117619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
118133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
119619fd8c2SJeff Niu public:
120133624acSJeff Niu   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
121619fd8c2SJeff Niu 
122619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
123619fd8c2SJeff Niu   class PyDenseArrayIterator {
124619fd8c2SJeff Niu   public:
1254a1b1196SMehdi Amini     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
126619fd8c2SJeff Niu 
127619fd8c2SJeff Niu     /// Return a copy of the iterator.
128619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
129619fd8c2SJeff Niu 
130619fd8c2SJeff Niu     /// Return the next element.
131619fd8c2SJeff Niu     EltTy dunderNext() {
132619fd8c2SJeff Niu       // Throw if the index has reached the end.
133619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
134619fd8c2SJeff Niu         throw py::stop_iteration();
135619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
136619fd8c2SJeff Niu     }
137619fd8c2SJeff Niu 
138619fd8c2SJeff Niu     /// Bind the iterator class.
139619fd8c2SJeff Niu     static void bind(py::module &m) {
140619fd8c2SJeff Niu       py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
141619fd8c2SJeff Niu                                        py::module_local())
142619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
143619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
144619fd8c2SJeff Niu     }
145619fd8c2SJeff Niu 
146619fd8c2SJeff Niu   private:
147619fd8c2SJeff Niu     /// The referenced dense array attribute.
148619fd8c2SJeff Niu     PyAttribute attr;
149619fd8c2SJeff Niu     /// The next index to read.
150619fd8c2SJeff Niu     int nextIndex = 0;
151619fd8c2SJeff Niu   };
152619fd8c2SJeff Niu 
153619fd8c2SJeff Niu   /// Get the element at the given index.
154619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
155619fd8c2SJeff Niu 
156619fd8c2SJeff Niu   /// Bind the attribute class.
157133624acSJeff Niu   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
158619fd8c2SJeff Niu     // Bind the constructor.
159619fd8c2SJeff Niu     c.def_static(
160619fd8c2SJeff Niu         "get",
161619fd8c2SJeff Niu         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
162619fd8c2SJeff Niu           MlirAttribute attr =
163619fd8c2SJeff Niu               DerivedT::getAttribute(ctx->get(), values.size(), values.data());
164133624acSJeff Niu           return DerivedT(ctx->getRef(), attr);
165619fd8c2SJeff Niu         },
166619fd8c2SJeff Niu         py::arg("values"), py::arg("context") = py::none(),
167619fd8c2SJeff Niu         "Gets a uniqued dense array attribute");
168619fd8c2SJeff Niu     // Bind the array methods.
169133624acSJeff Niu     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
170619fd8c2SJeff Niu       if (i >= mlirDenseArrayGetNumElements(arr))
171619fd8c2SJeff Niu         throw py::index_error("DenseArray index out of range");
172619fd8c2SJeff Niu       return arr.getItem(i);
173619fd8c2SJeff Niu     });
174133624acSJeff Niu     c.def("__len__", [](const DerivedT &arr) {
175619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
176619fd8c2SJeff Niu     });
177133624acSJeff Niu     c.def("__iter__",
178133624acSJeff Niu           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
1794a1b1196SMehdi Amini     c.def("__add__", [](DerivedT &arr, const py::list &extras) {
180619fd8c2SJeff Niu       std::vector<EltTy> values;
181619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
182619fd8c2SJeff Niu       values.reserve(numOldElements + py::len(extras));
183619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
184619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
185619fd8c2SJeff Niu       for (py::handle attr : extras)
186619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
187619fd8c2SJeff Niu       MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(),
188619fd8c2SJeff Niu                                                   values.size(), values.data());
189133624acSJeff Niu       return DerivedT(arr.getContext(), attr);
190619fd8c2SJeff Niu     });
191619fd8c2SJeff Niu   }
192619fd8c2SJeff Niu };
193619fd8c2SJeff Niu 
194619fd8c2SJeff Niu /// Instantiate the python dense array classes.
195619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
196619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> {
197619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
198619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
199619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
200619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
201619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
202619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
203619fd8c2SJeff Niu };
204619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
205619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
206619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
207619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
208619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
209619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
210619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
211619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
212619fd8c2SJeff Niu };
213619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
214619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
215619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
216619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
217619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
218619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
219619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
220619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
221619fd8c2SJeff Niu };
222619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
223619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
224619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
225619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
226619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
227619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
228619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
229619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
230619fd8c2SJeff Niu };
231619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
232619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
233619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
234619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
235619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
236619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
237619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
238619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
239619fd8c2SJeff Niu };
240619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
241619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
242619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
243619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
244619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
245619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
246619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
247619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
248619fd8c2SJeff Niu };
249619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
250619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
251619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
252619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
253619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
254619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
255619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
256619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
257619fd8c2SJeff Niu };
258619fd8c2SJeff Niu 
259436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
260436c6c9cSStella Laurenzo public:
261436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
262436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
263436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
2649566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
2659566ee28Smax       mlirArrayAttrGetTypeID;
266436c6c9cSStella Laurenzo 
267436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
268436c6c9cSStella Laurenzo   public:
2691fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
270436c6c9cSStella Laurenzo 
271436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
272436c6c9cSStella Laurenzo 
273436c6c9cSStella Laurenzo     PyAttribute dunderNext() {
274bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
275bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
276436c6c9cSStella Laurenzo         throw py::stop_iteration();
277436c6c9cSStella Laurenzo       return PyAttribute(attr.getContext(),
278436c6c9cSStella Laurenzo                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
279436c6c9cSStella Laurenzo     }
280436c6c9cSStella Laurenzo 
281436c6c9cSStella Laurenzo     static void bind(py::module &m) {
282f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
283f05ff4f7SStella Laurenzo                                            py::module_local())
284436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
285436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
286436c6c9cSStella Laurenzo     }
287436c6c9cSStella Laurenzo 
288436c6c9cSStella Laurenzo   private:
289436c6c9cSStella Laurenzo     PyAttribute attr;
290436c6c9cSStella Laurenzo     int nextIndex = 0;
291436c6c9cSStella Laurenzo   };
292436c6c9cSStella Laurenzo 
293ed9e52f3SAlex Zinenko   PyAttribute getItem(intptr_t i) {
294ed9e52f3SAlex Zinenko     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
295ed9e52f3SAlex Zinenko   }
296ed9e52f3SAlex Zinenko 
297436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
298436c6c9cSStella Laurenzo     c.def_static(
299436c6c9cSStella Laurenzo         "get",
300436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
301436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
302436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
303436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
304ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
305436c6c9cSStella Laurenzo           }
306436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
307436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
308436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
309436c6c9cSStella Laurenzo         },
310436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
311436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
312436c6c9cSStella Laurenzo     c.def("__getitem__",
313436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
314436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
315436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
316ed9e52f3SAlex Zinenko             return arr.getItem(i);
317436c6c9cSStella Laurenzo           })
318436c6c9cSStella Laurenzo         .def("__len__",
319436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
320436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
321436c6c9cSStella Laurenzo              })
322436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
323436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
324436c6c9cSStella Laurenzo         });
325ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
326ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
327ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
328ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
329ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
330ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
331ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
332ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
333ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
334ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
335ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
336ed9e52f3SAlex Zinenko     });
337436c6c9cSStella Laurenzo   }
338436c6c9cSStella Laurenzo };
339436c6c9cSStella Laurenzo 
340436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
341436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
342436c6c9cSStella Laurenzo public:
343436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
344436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
345436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
3469566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
3479566ee28Smax       mlirFloatAttrGetTypeID;
348436c6c9cSStella Laurenzo 
349436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
350436c6c9cSStella Laurenzo     c.def_static(
351436c6c9cSStella Laurenzo         "get",
352436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
3533ea4c501SRahul Kayaith           PyMlirContext::ErrorCapture errors(loc->getContext());
354436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
3553ea4c501SRahul Kayaith           if (mlirAttributeIsNull(attr))
3563ea4c501SRahul Kayaith             throw MLIRError("Invalid attribute", errors.take());
357436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
358436c6c9cSStella Laurenzo         },
359436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
360436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
361436c6c9cSStella Laurenzo     c.def_static(
362436c6c9cSStella Laurenzo         "get_f32",
363436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
364436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
365436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
366436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
367436c6c9cSStella Laurenzo         },
368436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
369436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
370436c6c9cSStella Laurenzo     c.def_static(
371436c6c9cSStella Laurenzo         "get_f64",
372436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
373436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
374436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
375436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
376436c6c9cSStella Laurenzo         },
377436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
378436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
379436c6c9cSStella Laurenzo     c.def_property_readonly(
380436c6c9cSStella Laurenzo         "value",
381436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
382436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
383436c6c9cSStella Laurenzo         },
384436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
385436c6c9cSStella Laurenzo   }
386436c6c9cSStella Laurenzo };
387436c6c9cSStella Laurenzo 
388436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
389436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
390436c6c9cSStella Laurenzo public:
391436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
392436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
393436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
394436c6c9cSStella Laurenzo 
395436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
396436c6c9cSStella Laurenzo     c.def_static(
397436c6c9cSStella Laurenzo         "get",
398436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
399436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
400436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
401436c6c9cSStella Laurenzo         },
402436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
403436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
404436c6c9cSStella Laurenzo     c.def_property_readonly(
405436c6c9cSStella Laurenzo         "value",
406e9db306dSrkayaith         [](PyIntegerAttribute &self) -> py::int_ {
407e9db306dSrkayaith           MlirType type = mlirAttributeGetType(self);
408e9db306dSrkayaith           if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
409436c6c9cSStella Laurenzo             return mlirIntegerAttrGetValueInt(self);
410e9db306dSrkayaith           if (mlirIntegerTypeIsSigned(type))
411e9db306dSrkayaith             return mlirIntegerAttrGetValueSInt(self);
412e9db306dSrkayaith           return mlirIntegerAttrGetValueUInt(self);
413436c6c9cSStella Laurenzo         },
414436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
4159566ee28Smax     c.def_property_readonly_static("static_typeid",
4169566ee28Smax                                    [](py::object & /*class*/) -> MlirTypeID {
4179566ee28Smax                                      return mlirIntegerAttrGetTypeID();
4189566ee28Smax                                    });
419436c6c9cSStella Laurenzo   }
420436c6c9cSStella Laurenzo };
421436c6c9cSStella Laurenzo 
422436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
423436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
424436c6c9cSStella Laurenzo public:
425436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
426436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
427436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
428436c6c9cSStella Laurenzo 
429436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
430436c6c9cSStella Laurenzo     c.def_static(
431436c6c9cSStella Laurenzo         "get",
432436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
433436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
434436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
435436c6c9cSStella Laurenzo         },
436436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
437436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
438436c6c9cSStella Laurenzo     c.def_property_readonly(
439436c6c9cSStella Laurenzo         "value",
440436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
441436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
442436c6c9cSStella Laurenzo   }
443436c6c9cSStella Laurenzo };
444436c6c9cSStella Laurenzo 
445*4eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
446*4eee9ef9Smax public:
447*4eee9ef9Smax   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
448*4eee9ef9Smax   static constexpr const char *pyClassName = "SymbolRefAttr";
449*4eee9ef9Smax   using PyConcreteAttribute::PyConcreteAttribute;
450*4eee9ef9Smax 
451*4eee9ef9Smax   static MlirAttribute fromList(const std::vector<std::string> &symbols,
452*4eee9ef9Smax                                 PyMlirContext &context) {
453*4eee9ef9Smax     if (symbols.empty())
454*4eee9ef9Smax       throw std::runtime_error("SymbolRefAttr must be composed of at least "
455*4eee9ef9Smax                                "one symbol.");
456*4eee9ef9Smax     MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
457*4eee9ef9Smax     SmallVector<MlirAttribute, 3> referenceAttrs;
458*4eee9ef9Smax     for (size_t i = 1; i < symbols.size(); ++i) {
459*4eee9ef9Smax       referenceAttrs.push_back(
460*4eee9ef9Smax           mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
461*4eee9ef9Smax     }
462*4eee9ef9Smax     return mlirSymbolRefAttrGet(context.get(), rootSymbol,
463*4eee9ef9Smax                                 referenceAttrs.size(), referenceAttrs.data());
464*4eee9ef9Smax   }
465*4eee9ef9Smax 
466*4eee9ef9Smax   static void bindDerived(ClassTy &c) {
467*4eee9ef9Smax     c.def_static(
468*4eee9ef9Smax         "get",
469*4eee9ef9Smax         [](const std::vector<std::string> &symbols,
470*4eee9ef9Smax            DefaultingPyMlirContext context) {
471*4eee9ef9Smax           return PySymbolRefAttribute::fromList(symbols, context.resolve());
472*4eee9ef9Smax         },
473*4eee9ef9Smax         py::arg("symbols"), py::arg("context") = py::none(),
474*4eee9ef9Smax         "Gets a uniqued SymbolRef attribute from a list of symbol names");
475*4eee9ef9Smax     c.def_property_readonly(
476*4eee9ef9Smax         "value",
477*4eee9ef9Smax         [](PySymbolRefAttribute &self) {
478*4eee9ef9Smax           std::vector<std::string> symbols = {
479*4eee9ef9Smax               unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
480*4eee9ef9Smax           for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
481*4eee9ef9Smax                ++i)
482*4eee9ef9Smax             symbols.push_back(
483*4eee9ef9Smax                 unwrap(mlirSymbolRefAttrGetRootReference(
484*4eee9ef9Smax                            mlirSymbolRefAttrGetNestedReference(self, i)))
485*4eee9ef9Smax                     .str());
486*4eee9ef9Smax           return symbols;
487*4eee9ef9Smax         },
488*4eee9ef9Smax         "Returns the value of the SymbolRef attribute as a list[str]");
489*4eee9ef9Smax   }
490*4eee9ef9Smax };
491*4eee9ef9Smax 
492436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
493436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
494436c6c9cSStella Laurenzo public:
495436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
496436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
497436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
498436c6c9cSStella Laurenzo 
499436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
500436c6c9cSStella Laurenzo     c.def_static(
501436c6c9cSStella Laurenzo         "get",
502436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
503436c6c9cSStella Laurenzo           MlirAttribute attr =
504436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
505436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
506436c6c9cSStella Laurenzo         },
507436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
508436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
509436c6c9cSStella Laurenzo     c.def_property_readonly(
510436c6c9cSStella Laurenzo         "value",
511436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
512436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
513436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
514436c6c9cSStella Laurenzo         },
515436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
516436c6c9cSStella Laurenzo   }
517436c6c9cSStella Laurenzo };
518436c6c9cSStella Laurenzo 
5195c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
5205c3861b2SYun Long public:
5215c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
5225c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
5235c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
5249566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
5259566ee28Smax       mlirOpaqueAttrGetTypeID;
5265c3861b2SYun Long 
5275c3861b2SYun Long   static void bindDerived(ClassTy &c) {
5285c3861b2SYun Long     c.def_static(
5295c3861b2SYun Long         "get",
5305c3861b2SYun Long         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
5315c3861b2SYun Long            DefaultingPyMlirContext context) {
5325c3861b2SYun Long           const py::buffer_info bufferInfo = buffer.request();
5335c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
5345c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
5355c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
5365c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
5375c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
5385c3861b2SYun Long         },
5395c3861b2SYun Long         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
5405c3861b2SYun Long         py::arg("context") = py::none(), "Gets an Opaque attribute.");
5415c3861b2SYun Long     c.def_property_readonly(
5425c3861b2SYun Long         "dialect_namespace",
5435c3861b2SYun Long         [](PyOpaqueAttribute &self) {
5445c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
5455c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
5465c3861b2SYun Long         },
5475c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
5485c3861b2SYun Long     c.def_property_readonly(
5495c3861b2SYun Long         "data",
5505c3861b2SYun Long         [](PyOpaqueAttribute &self) {
5515c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
55262bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
5535c3861b2SYun Long         },
55462bf6c2eSChris Jones         "Returns the data for the Opaqued attributes as `bytes`");
5555c3861b2SYun Long   }
5565c3861b2SYun Long };
5575c3861b2SYun Long 
558436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
559436c6c9cSStella Laurenzo public:
560436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
561436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
562436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
5639566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
5649566ee28Smax       mlirStringAttrGetTypeID;
565436c6c9cSStella Laurenzo 
566436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
567436c6c9cSStella Laurenzo     c.def_static(
568436c6c9cSStella Laurenzo         "get",
569436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
570436c6c9cSStella Laurenzo           MlirAttribute attr =
571436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
572436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
573436c6c9cSStella Laurenzo         },
574436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
575436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
576436c6c9cSStella Laurenzo     c.def_static(
577436c6c9cSStella Laurenzo         "get_typed",
578436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
579436c6c9cSStella Laurenzo           MlirAttribute attr =
580436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
581436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
582436c6c9cSStella Laurenzo         },
583a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
584436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
585436c6c9cSStella Laurenzo     c.def_property_readonly(
586436c6c9cSStella Laurenzo         "value",
587436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
588436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
589436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
590436c6c9cSStella Laurenzo         },
591436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
59262bf6c2eSChris Jones     c.def_property_readonly(
59362bf6c2eSChris Jones         "value_bytes",
59462bf6c2eSChris Jones         [](PyStringAttribute &self) {
59562bf6c2eSChris Jones           MlirStringRef stringRef = mlirStringAttrGetValue(self);
59662bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
59762bf6c2eSChris Jones         },
59862bf6c2eSChris Jones         "Returns the value of the string attribute as `bytes`");
599436c6c9cSStella Laurenzo   }
600436c6c9cSStella Laurenzo };
601436c6c9cSStella Laurenzo 
602436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
603436c6c9cSStella Laurenzo class PyDenseElementsAttribute
604436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
605436c6c9cSStella Laurenzo public:
606436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
607436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
608436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
609436c6c9cSStella Laurenzo 
610436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
6110a81ace0SKazu Hirata   getFromBuffer(py::buffer array, bool signless,
6120a81ace0SKazu Hirata                 std::optional<PyType> explicitType,
6130a81ace0SKazu Hirata                 std::optional<std::vector<int64_t>> explicitShape,
614436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
615436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
616436c6c9cSStella Laurenzo     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
617436c6c9cSStella Laurenzo     Py_buffer *view = new Py_buffer();
618436c6c9cSStella Laurenzo     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
619436c6c9cSStella Laurenzo       delete view;
620436c6c9cSStella Laurenzo       throw py::error_already_set();
621436c6c9cSStella Laurenzo     }
622436c6c9cSStella Laurenzo     py::buffer_info arrayInfo(view);
6235d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
6245d6d30edSStella Laurenzo     if (explicitShape) {
6255d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
6265d6d30edSStella Laurenzo     } else {
6275d6d30edSStella Laurenzo       shape.append(arrayInfo.shape.begin(),
6285d6d30edSStella Laurenzo                    arrayInfo.shape.begin() + arrayInfo.ndim);
6295d6d30edSStella Laurenzo     }
630436c6c9cSStella Laurenzo 
6315d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
632436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
6335d6d30edSStella Laurenzo 
6345d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
6355d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
6365d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
6375d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
6385d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
6390a81ace0SKazu Hirata     std::optional<MlirType> bulkLoadElementType;
6405d6d30edSStella Laurenzo     if (explicitType) {
6415d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
6425d6d30edSStella Laurenzo     } else if (arrayInfo.format == "f") {
643436c6c9cSStella Laurenzo       // f32
644436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
6455d6d30edSStella Laurenzo       bulkLoadElementType = mlirF32TypeGet(context);
646436c6c9cSStella Laurenzo     } else if (arrayInfo.format == "d") {
647436c6c9cSStella Laurenzo       // f64
648436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
6495d6d30edSStella Laurenzo       bulkLoadElementType = mlirF64TypeGet(context);
6505d6d30edSStella Laurenzo     } else if (arrayInfo.format == "e") {
6515d6d30edSStella Laurenzo       // f16
6525d6d30edSStella Laurenzo       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
6535d6d30edSStella Laurenzo       bulkLoadElementType = mlirF16TypeGet(context);
654436c6c9cSStella Laurenzo     } else if (isSignedIntegerFormat(arrayInfo.format)) {
655436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
656436c6c9cSStella Laurenzo         // i32
6575d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
658436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 32);
659436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
660436c6c9cSStella Laurenzo         // i64
6615d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
662436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 64);
6635d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
6645d6d30edSStella Laurenzo         // i8
6655d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
6665d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 8);
6675d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
6685d6d30edSStella Laurenzo         // i16
6695d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
6705d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 16);
671436c6c9cSStella Laurenzo       }
672436c6c9cSStella Laurenzo     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
673436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
674436c6c9cSStella Laurenzo         // unsigned i32
6755d6d30edSStella Laurenzo         bulkLoadElementType = signless
676436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 32)
677436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 32);
678436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
679436c6c9cSStella Laurenzo         // unsigned i64
6805d6d30edSStella Laurenzo         bulkLoadElementType = signless
681436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 64)
682436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 64);
6835d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
6845d6d30edSStella Laurenzo         // i8
6855d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
6865d6d30edSStella Laurenzo                                        : mlirIntegerTypeUnsignedGet(context, 8);
6875d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
6885d6d30edSStella Laurenzo         // i16
6895d6d30edSStella Laurenzo         bulkLoadElementType = signless
6905d6d30edSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 16)
6915d6d30edSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 16);
692436c6c9cSStella Laurenzo       }
693436c6c9cSStella Laurenzo     }
6945d6d30edSStella Laurenzo     if (bulkLoadElementType) {
69599dee31eSAdam Paszke       MlirType shapedType;
69699dee31eSAdam Paszke       if (mlirTypeIsAShaped(*bulkLoadElementType)) {
69799dee31eSAdam Paszke         if (explicitShape) {
69899dee31eSAdam Paszke           throw std::invalid_argument("Shape can only be specified explicitly "
69999dee31eSAdam Paszke                                       "when the type is not a shaped type.");
70099dee31eSAdam Paszke         }
70199dee31eSAdam Paszke         shapedType = *bulkLoadElementType;
70299dee31eSAdam Paszke       } else {
70399dee31eSAdam Paszke         shapedType = mlirRankedTensorTypeGet(
7045d6d30edSStella Laurenzo             shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
70599dee31eSAdam Paszke       }
7065d6d30edSStella Laurenzo       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
7075d6d30edSStella Laurenzo       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
7085d6d30edSStella Laurenzo           shapedType, rawBufferSize, arrayInfo.ptr);
7095d6d30edSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
7105d6d30edSStella Laurenzo         throw std::invalid_argument(
7115d6d30edSStella Laurenzo             "DenseElementsAttr could not be constructed from the given buffer. "
7125d6d30edSStella Laurenzo             "This may mean that the Python buffer layout does not match that "
7135d6d30edSStella Laurenzo             "MLIR expected layout and is a bug.");
7145d6d30edSStella Laurenzo       }
7155d6d30edSStella Laurenzo       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
7165d6d30edSStella Laurenzo     }
717436c6c9cSStella Laurenzo 
7185d6d30edSStella Laurenzo     throw std::invalid_argument(
7195d6d30edSStella Laurenzo         std::string("unimplemented array format conversion from format: ") +
7205d6d30edSStella Laurenzo         arrayInfo.format);
721436c6c9cSStella Laurenzo   }
722436c6c9cSStella Laurenzo 
7231fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
724436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
725436c6c9cSStella Laurenzo     auto contextWrapper =
726436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
727436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
728436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
729436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
730436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
7314811270bSmax       throw py::value_error(message);
732436c6c9cSStella Laurenzo     }
733436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
734436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
735436c6c9cSStella Laurenzo       std::string message =
736436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
737436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
7384811270bSmax       throw py::value_error(message);
739436c6c9cSStella Laurenzo     }
740436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
741436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
742436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
743436c6c9cSStella Laurenzo       std::string message =
744436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
745436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
746436c6c9cSStella Laurenzo       message.append(", element=");
747436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
7484811270bSmax       throw py::value_error(message);
749436c6c9cSStella Laurenzo     }
750436c6c9cSStella Laurenzo 
751436c6c9cSStella Laurenzo     MlirAttribute elements =
752436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
753436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
754436c6c9cSStella Laurenzo   }
755436c6c9cSStella Laurenzo 
756436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
757436c6c9cSStella Laurenzo 
758436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
759436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
760436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
7615d6d30edSStella Laurenzo     std::string format;
762436c6c9cSStella Laurenzo 
763436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
764436c6c9cSStella Laurenzo       // f32
7655d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
76602b6fb21SMehdi Amini     }
76702b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
768436c6c9cSStella Laurenzo       // f64
7695d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
770bb56c2b3SMehdi Amini     }
771bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
7725d6d30edSStella Laurenzo       // f16
7735d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
774bb56c2b3SMehdi Amini     }
775ef1b735dSmax     if (mlirTypeIsAIndex(elementType)) {
776ef1b735dSmax       // Same as IndexType::kInternalStorageBitWidth
777ef1b735dSmax       return bufferInfo<int64_t>(shapedType);
778ef1b735dSmax     }
779bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
780436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
781436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
782436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
783436c6c9cSStella Laurenzo         // i32
7845d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
785e5639b3fSMehdi Amini       }
786e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
787436c6c9cSStella Laurenzo         // unsigned i32
7885d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
789436c6c9cSStella Laurenzo       }
790436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
791436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
792436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
793436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
794436c6c9cSStella Laurenzo         // i64
7955d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
796e5639b3fSMehdi Amini       }
797e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
798436c6c9cSStella Laurenzo         // unsigned i64
7995d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
8005d6d30edSStella Laurenzo       }
8015d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
8025d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
8035d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
8045d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
8055d6d30edSStella Laurenzo         // i8
8065d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
807e5639b3fSMehdi Amini       }
808e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
8095d6d30edSStella Laurenzo         // unsigned i8
8105d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
8115d6d30edSStella Laurenzo       }
8125d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
8135d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
8145d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
8155d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
8165d6d30edSStella Laurenzo         // i16
8175d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
818e5639b3fSMehdi Amini       }
819e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
8205d6d30edSStella Laurenzo         // unsigned i16
8215d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
822436c6c9cSStella Laurenzo       }
823436c6c9cSStella Laurenzo     }
824436c6c9cSStella Laurenzo 
825c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
8265d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
827c5f445d1SStella Laurenzo     throw std::invalid_argument(
828c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
829436c6c9cSStella Laurenzo   }
830436c6c9cSStella Laurenzo 
831436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
832436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
833436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
834436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
8355d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
836436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
8375d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
838436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
839436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
840436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
841436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
842436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
843436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
844436c6c9cSStella Laurenzo                                })
84591259963SAdam Paszke         .def("get_splat_value",
84691259963SAdam Paszke              [](PyDenseElementsAttribute &self) -> PyAttribute {
84791259963SAdam Paszke                if (!mlirDenseElementsAttrIsSplat(self)) {
8484811270bSmax                  throw py::value_error(
84991259963SAdam Paszke                      "get_splat_value called on a non-splat attribute");
85091259963SAdam Paszke                }
85191259963SAdam Paszke                return PyAttribute(self.getContext(),
85291259963SAdam Paszke                                   mlirDenseElementsAttrGetSplatValue(self));
85391259963SAdam Paszke              })
854436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
855436c6c9cSStella Laurenzo   }
856436c6c9cSStella Laurenzo 
857436c6c9cSStella Laurenzo private:
858436c6c9cSStella Laurenzo   static bool isUnsignedIntegerFormat(const std::string &format) {
859436c6c9cSStella Laurenzo     if (format.empty())
860436c6c9cSStella Laurenzo       return false;
861436c6c9cSStella Laurenzo     char code = format[0];
862436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
863436c6c9cSStella Laurenzo            code == 'Q';
864436c6c9cSStella Laurenzo   }
865436c6c9cSStella Laurenzo 
866436c6c9cSStella Laurenzo   static bool isSignedIntegerFormat(const std::string &format) {
867436c6c9cSStella Laurenzo     if (format.empty())
868436c6c9cSStella Laurenzo       return false;
869436c6c9cSStella Laurenzo     char code = format[0];
870436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
871436c6c9cSStella Laurenzo            code == 'q';
872436c6c9cSStella Laurenzo   }
873436c6c9cSStella Laurenzo 
874436c6c9cSStella Laurenzo   template <typename Type>
875436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
8765d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
877436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
878436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
879436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
880436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
881436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
882436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
883436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
884436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
885436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
886436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
887436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
888f0e847d0SRahul Kayaith     if (mlirDenseElementsAttrIsSplat(*this)) {
889f0e847d0SRahul Kayaith       // Splats are special, only the single value is stored.
890f0e847d0SRahul Kayaith       strides.assign(rank, 0);
891f0e847d0SRahul Kayaith     } else {
892436c6c9cSStella Laurenzo       for (intptr_t i = 1; i < rank; ++i) {
893f0e847d0SRahul Kayaith         intptr_t strideFactor = 1;
894f0e847d0SRahul Kayaith         for (intptr_t j = i; j < rank; ++j)
895436c6c9cSStella Laurenzo           strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
896436c6c9cSStella Laurenzo         strides.push_back(sizeof(Type) * strideFactor);
897436c6c9cSStella Laurenzo       }
898436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type));
899f0e847d0SRahul Kayaith     }
9005d6d30edSStella Laurenzo     std::string format;
9015d6d30edSStella Laurenzo     if (explicitFormat) {
9025d6d30edSStella Laurenzo       format = explicitFormat;
9035d6d30edSStella Laurenzo     } else {
9045d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
9055d6d30edSStella Laurenzo     }
9065d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
9075d6d30edSStella Laurenzo                            /*readonly=*/true);
908436c6c9cSStella Laurenzo   }
909436c6c9cSStella Laurenzo }; // namespace
910436c6c9cSStella Laurenzo 
911436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
912436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
913436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
914436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
915436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
916436c6c9cSStella Laurenzo public:
917436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
918436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
919436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
920436c6c9cSStella Laurenzo 
921436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
922436c6c9cSStella Laurenzo   /// out of range.
923436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
924436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
9254811270bSmax       throw py::index_error("attempt to access out of bounds element");
926436c6c9cSStella Laurenzo     }
927436c6c9cSStella Laurenzo 
928436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
929436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
930436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
931436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
932436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
933436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
934436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
935436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
936436c6c9cSStella Laurenzo     // querying them on each element access.
937436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
938436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
939436c6c9cSStella Laurenzo     if (isUnsigned) {
940436c6c9cSStella Laurenzo       if (width == 1) {
941436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
942436c6c9cSStella Laurenzo       }
943308d8b8cSRahul Kayaith       if (width == 8) {
944308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
945308d8b8cSRahul Kayaith       }
946308d8b8cSRahul Kayaith       if (width == 16) {
947308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
948308d8b8cSRahul Kayaith       }
949436c6c9cSStella Laurenzo       if (width == 32) {
950436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
951436c6c9cSStella Laurenzo       }
952436c6c9cSStella Laurenzo       if (width == 64) {
953436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
954436c6c9cSStella Laurenzo       }
955436c6c9cSStella Laurenzo     } else {
956436c6c9cSStella Laurenzo       if (width == 1) {
957436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
958436c6c9cSStella Laurenzo       }
959308d8b8cSRahul Kayaith       if (width == 8) {
960308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
961308d8b8cSRahul Kayaith       }
962308d8b8cSRahul Kayaith       if (width == 16) {
963308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
964308d8b8cSRahul Kayaith       }
965436c6c9cSStella Laurenzo       if (width == 32) {
966436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
967436c6c9cSStella Laurenzo       }
968436c6c9cSStella Laurenzo       if (width == 64) {
969436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
970436c6c9cSStella Laurenzo       }
971436c6c9cSStella Laurenzo     }
9724811270bSmax     throw py::type_error("Unsupported integer type");
973436c6c9cSStella Laurenzo   }
974436c6c9cSStella Laurenzo 
975436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
976436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
977436c6c9cSStella Laurenzo   }
978436c6c9cSStella Laurenzo };
979436c6c9cSStella Laurenzo 
980436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
981436c6c9cSStella Laurenzo public:
982436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
983436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
984436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
9859566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
9869566ee28Smax       mlirDictionaryAttrGetTypeID;
987436c6c9cSStella Laurenzo 
988436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
989436c6c9cSStella Laurenzo 
9909fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
9919fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
9929fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
9939fb1086bSAdrian Kuegel   }
9949fb1086bSAdrian Kuegel 
995436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
9969fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
997436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
998436c6c9cSStella Laurenzo     c.def_static(
999436c6c9cSStella Laurenzo         "get",
1000436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
1001436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1002436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
1003436c6c9cSStella Laurenzo           for (auto &it : attributes) {
100402b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
1005436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
1006436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
100702b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1008436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
100902b6fb21SMehdi Amini                 mlirAttr));
1010436c6c9cSStella Laurenzo           }
1011436c6c9cSStella Laurenzo           MlirAttribute attr =
1012436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1013436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1014436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
1015436c6c9cSStella Laurenzo         },
1016ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
1017436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
1018436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1019436c6c9cSStella Laurenzo       MlirAttribute attr =
1020436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1021436c6c9cSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
10224811270bSmax         throw py::key_error("attempt to access a non-existent attribute");
1023436c6c9cSStella Laurenzo       }
1024436c6c9cSStella Laurenzo       return PyAttribute(self.getContext(), attr);
1025436c6c9cSStella Laurenzo     });
1026436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1027436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
10284811270bSmax         throw py::index_error("attempt to access out of bounds attribute");
1029436c6c9cSStella Laurenzo       }
1030436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1031436c6c9cSStella Laurenzo       return PyNamedAttribute(
1032436c6c9cSStella Laurenzo           namedAttr.attribute,
1033436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
1034436c6c9cSStella Laurenzo     });
1035436c6c9cSStella Laurenzo   }
1036436c6c9cSStella Laurenzo };
1037436c6c9cSStella Laurenzo 
1038436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
1039436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
1040436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
1041436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1042436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1043436c6c9cSStella Laurenzo public:
1044436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1045436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
1046436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1047436c6c9cSStella Laurenzo 
1048436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
1049436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
10504811270bSmax       throw py::index_error("attempt to access out of bounds element");
1051436c6c9cSStella Laurenzo     }
1052436c6c9cSStella Laurenzo 
1053436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1054436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1055436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1056436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
1057436c6c9cSStella Laurenzo     // from float and double.
1058436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1059436c6c9cSStella Laurenzo     // querying them on each element access.
1060436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
1061436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
1062436c6c9cSStella Laurenzo     }
1063436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
1064436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1065436c6c9cSStella Laurenzo     }
10664811270bSmax     throw py::type_error("Unsupported floating-point type");
1067436c6c9cSStella Laurenzo   }
1068436c6c9cSStella Laurenzo 
1069436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1070436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1071436c6c9cSStella Laurenzo   }
1072436c6c9cSStella Laurenzo };
1073436c6c9cSStella Laurenzo 
1074436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1075436c6c9cSStella Laurenzo public:
1076436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1077436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
1078436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
10799566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
10809566ee28Smax       mlirTypeAttrGetTypeID;
1081436c6c9cSStella Laurenzo 
1082436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1083436c6c9cSStella Laurenzo     c.def_static(
1084436c6c9cSStella Laurenzo         "get",
1085436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1086436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1087436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1088436c6c9cSStella Laurenzo         },
1089436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
1090436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1091436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1092bfb1ba75Smax       return mlirTypeAttrGetValue(self.get());
1093436c6c9cSStella Laurenzo     });
1094436c6c9cSStella Laurenzo   }
1095436c6c9cSStella Laurenzo };
1096436c6c9cSStella Laurenzo 
1097436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1098436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1099436c6c9cSStella Laurenzo public:
1100436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1101436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1102436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
11039566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
11049566ee28Smax       mlirUnitAttrGetTypeID;
1105436c6c9cSStella Laurenzo 
1106436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1107436c6c9cSStella Laurenzo     c.def_static(
1108436c6c9cSStella Laurenzo         "get",
1109436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1110436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1111436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1112436c6c9cSStella Laurenzo         },
1113436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
1114436c6c9cSStella Laurenzo   }
1115436c6c9cSStella Laurenzo };
1116436c6c9cSStella Laurenzo 
1117ac2e2d65SDenys Shabalin /// Strided layout attribute subclass.
1118ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute
1119ac2e2d65SDenys Shabalin     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1120ac2e2d65SDenys Shabalin public:
1121ac2e2d65SDenys Shabalin   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1122ac2e2d65SDenys Shabalin   static constexpr const char *pyClassName = "StridedLayoutAttr";
1123ac2e2d65SDenys Shabalin   using PyConcreteAttribute::PyConcreteAttribute;
11249566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
11259566ee28Smax       mlirStridedLayoutAttrGetTypeID;
1126ac2e2d65SDenys Shabalin 
1127ac2e2d65SDenys Shabalin   static void bindDerived(ClassTy &c) {
1128ac2e2d65SDenys Shabalin     c.def_static(
1129ac2e2d65SDenys Shabalin         "get",
1130ac2e2d65SDenys Shabalin         [](int64_t offset, const std::vector<int64_t> strides,
1131ac2e2d65SDenys Shabalin            DefaultingPyMlirContext ctx) {
1132ac2e2d65SDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1133ac2e2d65SDenys Shabalin               ctx->get(), offset, strides.size(), strides.data());
1134ac2e2d65SDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1135ac2e2d65SDenys Shabalin         },
1136ac2e2d65SDenys Shabalin         py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1137ac2e2d65SDenys Shabalin         "Gets a strided layout attribute.");
1138e3fd612eSDenys Shabalin     c.def_static(
1139e3fd612eSDenys Shabalin         "get_fully_dynamic",
1140e3fd612eSDenys Shabalin         [](int64_t rank, DefaultingPyMlirContext ctx) {
1141e3fd612eSDenys Shabalin           auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1142e3fd612eSDenys Shabalin           std::vector<int64_t> strides(rank);
1143e3fd612eSDenys Shabalin           std::fill(strides.begin(), strides.end(), dynamic);
1144e3fd612eSDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1145e3fd612eSDenys Shabalin               ctx->get(), dynamic, strides.size(), strides.data());
1146e3fd612eSDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1147e3fd612eSDenys Shabalin         },
1148e3fd612eSDenys Shabalin         py::arg("rank"), py::arg("context") = py::none(),
1149e3fd612eSDenys Shabalin         "Gets a strided layout attribute with dynamic offset and strides of a "
1150e3fd612eSDenys Shabalin         "given rank.");
1151ac2e2d65SDenys Shabalin     c.def_property_readonly(
1152ac2e2d65SDenys Shabalin         "offset",
1153ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1154ac2e2d65SDenys Shabalin           return mlirStridedLayoutAttrGetOffset(self);
1155ac2e2d65SDenys Shabalin         },
1156ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1157ac2e2d65SDenys Shabalin     c.def_property_readonly(
1158ac2e2d65SDenys Shabalin         "strides",
1159ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1160ac2e2d65SDenys Shabalin           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1161ac2e2d65SDenys Shabalin           std::vector<int64_t> strides(size);
1162ac2e2d65SDenys Shabalin           for (intptr_t i = 0; i < size; i++) {
1163ac2e2d65SDenys Shabalin             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1164ac2e2d65SDenys Shabalin           }
1165ac2e2d65SDenys Shabalin           return strides;
1166ac2e2d65SDenys Shabalin         },
1167ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1168ac2e2d65SDenys Shabalin   }
1169ac2e2d65SDenys Shabalin };
1170ac2e2d65SDenys Shabalin 
11719566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
11729566ee28Smax   if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
11739566ee28Smax     return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
11749566ee28Smax   if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
11759566ee28Smax     return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
11769566ee28Smax   if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
11779566ee28Smax     return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
11789566ee28Smax   if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
11799566ee28Smax     return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
11809566ee28Smax   if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
11819566ee28Smax     return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
11829566ee28Smax   if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
11839566ee28Smax     return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
11849566ee28Smax   if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
11859566ee28Smax     return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
11869566ee28Smax   std::string msg =
11879566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
11889566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
11899566ee28Smax   throw py::cast_error(msg);
11909566ee28Smax }
11919566ee28Smax 
11929566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
11939566ee28Smax   if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
11949566ee28Smax     return py::cast(PyDenseFPElementsAttribute(pyAttribute));
11959566ee28Smax   if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
11969566ee28Smax     return py::cast(PyDenseIntElementsAttribute(pyAttribute));
11979566ee28Smax   std::string msg =
11989566ee28Smax       std::string(
11999566ee28Smax           "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
12009566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
12019566ee28Smax   throw py::cast_error(msg);
12029566ee28Smax }
12039566ee28Smax 
12049566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
12059566ee28Smax   if (PyBoolAttribute::isaFunction(pyAttribute))
12069566ee28Smax     return py::cast(PyBoolAttribute(pyAttribute));
12079566ee28Smax   if (PyIntegerAttribute::isaFunction(pyAttribute))
12089566ee28Smax     return py::cast(PyIntegerAttribute(pyAttribute));
12099566ee28Smax   std::string msg =
12109566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
12119566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
12129566ee28Smax   throw py::cast_error(msg);
12139566ee28Smax }
12149566ee28Smax 
1215*4eee9ef9Smax py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
1216*4eee9ef9Smax   if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
1217*4eee9ef9Smax     return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
1218*4eee9ef9Smax   if (PySymbolRefAttribute::isaFunction(pyAttribute))
1219*4eee9ef9Smax     return py::cast(PySymbolRefAttribute(pyAttribute));
1220*4eee9ef9Smax   std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
1221*4eee9ef9Smax                     std::string(py::repr(py::cast(pyAttribute))) + ")";
1222*4eee9ef9Smax   throw py::cast_error(msg);
1223*4eee9ef9Smax }
1224*4eee9ef9Smax 
1225436c6c9cSStella Laurenzo } // namespace
1226436c6c9cSStella Laurenzo 
1227436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
1228436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1229619fd8c2SJeff Niu 
1230619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1231619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1232619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1233619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1234619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1235619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1236619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1237619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1238619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1239619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1240619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1241619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1242619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1243619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
12449566ee28Smax   PyGlobals::get().registerTypeCaster(
12459566ee28Smax       mlirDenseArrayAttrGetTypeID(),
12469566ee28Smax       pybind11::cpp_function(denseArrayAttributeCaster));
1247619fd8c2SJeff Niu 
1248436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1249436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1250436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1251436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
1252436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1253436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
12549566ee28Smax   PyGlobals::get().registerTypeCaster(
12559566ee28Smax       mlirDenseIntOrFPElementsAttrGetTypeID(),
12569566ee28Smax       pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
12579566ee28Smax 
1258436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
1259*4eee9ef9Smax   PySymbolRefAttribute::bind(m);
1260*4eee9ef9Smax   PyGlobals::get().registerTypeCaster(
1261*4eee9ef9Smax       mlirSymbolRefAttrGetTypeID(),
1262*4eee9ef9Smax       pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
1263*4eee9ef9Smax 
1264436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
12655c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1266436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1267436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1268436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1269436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
12709566ee28Smax   PyGlobals::get().registerTypeCaster(
12719566ee28Smax       mlirIntegerAttrGetTypeID(),
12729566ee28Smax       pybind11::cpp_function(integerOrBoolAttributeCaster));
1273436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1274ac2e2d65SDenys Shabalin 
1275ac2e2d65SDenys Shabalin   PyStridedLayoutAttribute::bind(m);
1276436c6c9cSStella Laurenzo }
1277