xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision 71a254543d44a943dfe8790abc60795b87173f0b)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9a1fe1f5fSKazu Hirata #include <optional>
10*71a25454SPeter Hawkins #include <string_view>
114811270bSmax #include <utility>
121fc096afSMehdi Amini 
13436c6c9cSStella Laurenzo #include "IRModule.h"
14436c6c9cSStella Laurenzo 
15436c6c9cSStella Laurenzo #include "PybindUtils.h"
16436c6c9cSStella Laurenzo 
17*71a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h"
18*71a25454SPeter Hawkins 
19436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
20436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
21bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h"
22436c6c9cSStella Laurenzo 
23436c6c9cSStella Laurenzo namespace py = pybind11;
24436c6c9cSStella Laurenzo using namespace mlir;
25436c6c9cSStella Laurenzo using namespace mlir::python;
26436c6c9cSStella Laurenzo 
27436c6c9cSStella Laurenzo using llvm::SmallVector;
28436c6c9cSStella Laurenzo 
295d6d30edSStella Laurenzo //------------------------------------------------------------------------------
305d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
315d6d30edSStella Laurenzo //------------------------------------------------------------------------------
325d6d30edSStella Laurenzo 
// Python docstring text for the DenseElementsAttr buffer-based constructor.
// NOTE(review): presumably attached to `DenseElementsAttr.get` further down in
// this file; the binding itself is outside this view — confirm.
// The raw string below is user-visible runtime text: edit it only to change
// the documentation shown to Python users.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
745d6d30edSStella Laurenzo 
75436c6c9cSStella Laurenzo namespace {
76436c6c9cSStella Laurenzo 
77436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
78436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
79436c6c9cSStella Laurenzo }
80436c6c9cSStella Laurenzo 
81436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
82436c6c9cSStella Laurenzo public:
83436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
84436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
85436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
869566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
879566ee28Smax       mlirAffineMapAttrGetTypeID;
88436c6c9cSStella Laurenzo 
89436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
90436c6c9cSStella Laurenzo     c.def_static(
91436c6c9cSStella Laurenzo         "get",
92436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
93436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
94436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
95436c6c9cSStella Laurenzo         },
96436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
97436c6c9cSStella Laurenzo   }
98436c6c9cSStella Laurenzo };
99436c6c9cSStella Laurenzo 
100ed9e52f3SAlex Zinenko template <typename T>
101ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
102ed9e52f3SAlex Zinenko   try {
103ed9e52f3SAlex Zinenko     return object.cast<T>();
104ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
105ed9e52f3SAlex Zinenko     std::string msg =
106ed9e52f3SAlex Zinenko         std::string(
107ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
108ed9e52f3SAlex Zinenko         err.what() + ")";
109ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
110ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
111ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
112ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
113ed9e52f3SAlex Zinenko                       err.what() + ")";
114ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
115ed9e52f3SAlex Zinenko   }
116ed9e52f3SAlex Zinenko }
117ed9e52f3SAlex Zinenko 
118619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
119619fd8c2SJeff Niu /// implementation class.
120619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
121133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
122619fd8c2SJeff Niu public:
123133624acSJeff Niu   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
124619fd8c2SJeff Niu 
125619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
126619fd8c2SJeff Niu   class PyDenseArrayIterator {
127619fd8c2SJeff Niu   public:
1284a1b1196SMehdi Amini     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
129619fd8c2SJeff Niu 
130619fd8c2SJeff Niu     /// Return a copy of the iterator.
131619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
132619fd8c2SJeff Niu 
133619fd8c2SJeff Niu     /// Return the next element.
134619fd8c2SJeff Niu     EltTy dunderNext() {
135619fd8c2SJeff Niu       // Throw if the index has reached the end.
136619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
137619fd8c2SJeff Niu         throw py::stop_iteration();
138619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
139619fd8c2SJeff Niu     }
140619fd8c2SJeff Niu 
141619fd8c2SJeff Niu     /// Bind the iterator class.
142619fd8c2SJeff Niu     static void bind(py::module &m) {
143619fd8c2SJeff Niu       py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
144619fd8c2SJeff Niu                                        py::module_local())
145619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
146619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
147619fd8c2SJeff Niu     }
148619fd8c2SJeff Niu 
149619fd8c2SJeff Niu   private:
150619fd8c2SJeff Niu     /// The referenced dense array attribute.
151619fd8c2SJeff Niu     PyAttribute attr;
152619fd8c2SJeff Niu     /// The next index to read.
153619fd8c2SJeff Niu     int nextIndex = 0;
154619fd8c2SJeff Niu   };
155619fd8c2SJeff Niu 
156619fd8c2SJeff Niu   /// Get the element at the given index.
157619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
158619fd8c2SJeff Niu 
159619fd8c2SJeff Niu   /// Bind the attribute class.
160133624acSJeff Niu   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
161619fd8c2SJeff Niu     // Bind the constructor.
162619fd8c2SJeff Niu     c.def_static(
163619fd8c2SJeff Niu         "get",
164619fd8c2SJeff Niu         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
165619fd8c2SJeff Niu           MlirAttribute attr =
166619fd8c2SJeff Niu               DerivedT::getAttribute(ctx->get(), values.size(), values.data());
167133624acSJeff Niu           return DerivedT(ctx->getRef(), attr);
168619fd8c2SJeff Niu         },
169619fd8c2SJeff Niu         py::arg("values"), py::arg("context") = py::none(),
170619fd8c2SJeff Niu         "Gets a uniqued dense array attribute");
171619fd8c2SJeff Niu     // Bind the array methods.
172133624acSJeff Niu     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
173619fd8c2SJeff Niu       if (i >= mlirDenseArrayGetNumElements(arr))
174619fd8c2SJeff Niu         throw py::index_error("DenseArray index out of range");
175619fd8c2SJeff Niu       return arr.getItem(i);
176619fd8c2SJeff Niu     });
177133624acSJeff Niu     c.def("__len__", [](const DerivedT &arr) {
178619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
179619fd8c2SJeff Niu     });
180133624acSJeff Niu     c.def("__iter__",
181133624acSJeff Niu           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
1824a1b1196SMehdi Amini     c.def("__add__", [](DerivedT &arr, const py::list &extras) {
183619fd8c2SJeff Niu       std::vector<EltTy> values;
184619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
185619fd8c2SJeff Niu       values.reserve(numOldElements + py::len(extras));
186619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
187619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
188619fd8c2SJeff Niu       for (py::handle attr : extras)
189619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
190619fd8c2SJeff Niu       MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(),
191619fd8c2SJeff Niu                                                   values.size(), values.data());
192133624acSJeff Niu       return DerivedT(arr.getContext(), attr);
193619fd8c2SJeff Niu     });
194619fd8c2SJeff Niu   }
195619fd8c2SJeff Niu };
196619fd8c2SJeff Niu 
197619fd8c2SJeff Niu /// Instantiate the python dense array classes.
198619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
199619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> {
200619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
201619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
202619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
203619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
204619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
205619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
206619fd8c2SJeff Niu };
207619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
208619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
209619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
210619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
211619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
212619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
213619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
214619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
215619fd8c2SJeff Niu };
216619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
217619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
218619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
219619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
220619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
221619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
222619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
223619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
224619fd8c2SJeff Niu };
225619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
226619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
227619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
228619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
229619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
230619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
231619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
232619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
233619fd8c2SJeff Niu };
234619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
235619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
236619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
237619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
238619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
239619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
240619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
241619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
242619fd8c2SJeff Niu };
243619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
244619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
245619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
246619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
247619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
248619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
249619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
250619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
251619fd8c2SJeff Niu };
252619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
253619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
254619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
255619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
256619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
257619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
258619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
259619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
260619fd8c2SJeff Niu };
261619fd8c2SJeff Niu 
262436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
263436c6c9cSStella Laurenzo public:
264436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
265436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
266436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
2679566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
2689566ee28Smax       mlirArrayAttrGetTypeID;
269436c6c9cSStella Laurenzo 
270436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
271436c6c9cSStella Laurenzo   public:
2721fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
273436c6c9cSStella Laurenzo 
274436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
275436c6c9cSStella Laurenzo 
276974c1596SRahul Kayaith     MlirAttribute dunderNext() {
277bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
278bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
279436c6c9cSStella Laurenzo         throw py::stop_iteration();
280974c1596SRahul Kayaith       return mlirArrayAttrGetElement(attr.get(), nextIndex++);
281436c6c9cSStella Laurenzo     }
282436c6c9cSStella Laurenzo 
283436c6c9cSStella Laurenzo     static void bind(py::module &m) {
284f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
285f05ff4f7SStella Laurenzo                                            py::module_local())
286436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
287436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
288436c6c9cSStella Laurenzo     }
289436c6c9cSStella Laurenzo 
290436c6c9cSStella Laurenzo   private:
291436c6c9cSStella Laurenzo     PyAttribute attr;
292436c6c9cSStella Laurenzo     int nextIndex = 0;
293436c6c9cSStella Laurenzo   };
294436c6c9cSStella Laurenzo 
295974c1596SRahul Kayaith   MlirAttribute getItem(intptr_t i) {
296974c1596SRahul Kayaith     return mlirArrayAttrGetElement(*this, i);
297ed9e52f3SAlex Zinenko   }
298ed9e52f3SAlex Zinenko 
299436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
300436c6c9cSStella Laurenzo     c.def_static(
301436c6c9cSStella Laurenzo         "get",
302436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
303436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
304436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
305436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
306ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
307436c6c9cSStella Laurenzo           }
308436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
309436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
310436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
311436c6c9cSStella Laurenzo         },
312436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
313436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
314436c6c9cSStella Laurenzo     c.def("__getitem__",
315436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
316436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
317436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
318ed9e52f3SAlex Zinenko             return arr.getItem(i);
319436c6c9cSStella Laurenzo           })
320436c6c9cSStella Laurenzo         .def("__len__",
321436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
322436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
323436c6c9cSStella Laurenzo              })
324436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
325436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
326436c6c9cSStella Laurenzo         });
327ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
328ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
329ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
330ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
331ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
332ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
333ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
334ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
335ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
336ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
337ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
338ed9e52f3SAlex Zinenko     });
339436c6c9cSStella Laurenzo   }
340436c6c9cSStella Laurenzo };
341436c6c9cSStella Laurenzo 
342436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
343436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
344436c6c9cSStella Laurenzo public:
345436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
346436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
347436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
3489566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
3499566ee28Smax       mlirFloatAttrGetTypeID;
350436c6c9cSStella Laurenzo 
351436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
352436c6c9cSStella Laurenzo     c.def_static(
353436c6c9cSStella Laurenzo         "get",
354436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
3553ea4c501SRahul Kayaith           PyMlirContext::ErrorCapture errors(loc->getContext());
356436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
3573ea4c501SRahul Kayaith           if (mlirAttributeIsNull(attr))
3583ea4c501SRahul Kayaith             throw MLIRError("Invalid attribute", errors.take());
359436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
360436c6c9cSStella Laurenzo         },
361436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
362436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
363436c6c9cSStella Laurenzo     c.def_static(
364436c6c9cSStella Laurenzo         "get_f32",
365436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
366436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
367436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
368436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
369436c6c9cSStella Laurenzo         },
370436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
371436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
372436c6c9cSStella Laurenzo     c.def_static(
373436c6c9cSStella Laurenzo         "get_f64",
374436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
375436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
376436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
377436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
378436c6c9cSStella Laurenzo         },
379436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
380436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
381436c6c9cSStella Laurenzo     c.def_property_readonly(
382436c6c9cSStella Laurenzo         "value",
383436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
384436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
385436c6c9cSStella Laurenzo         },
386436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
387436c6c9cSStella Laurenzo   }
388436c6c9cSStella Laurenzo };
389436c6c9cSStella Laurenzo 
390436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
391436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
392436c6c9cSStella Laurenzo public:
393436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
394436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
395436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
396436c6c9cSStella Laurenzo 
397436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
398436c6c9cSStella Laurenzo     c.def_static(
399436c6c9cSStella Laurenzo         "get",
400436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
401436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
402436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
403436c6c9cSStella Laurenzo         },
404436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
405436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
406436c6c9cSStella Laurenzo     c.def_property_readonly(
407436c6c9cSStella Laurenzo         "value",
408e9db306dSrkayaith         [](PyIntegerAttribute &self) -> py::int_ {
409e9db306dSrkayaith           MlirType type = mlirAttributeGetType(self);
410e9db306dSrkayaith           if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
411436c6c9cSStella Laurenzo             return mlirIntegerAttrGetValueInt(self);
412e9db306dSrkayaith           if (mlirIntegerTypeIsSigned(type))
413e9db306dSrkayaith             return mlirIntegerAttrGetValueSInt(self);
414e9db306dSrkayaith           return mlirIntegerAttrGetValueUInt(self);
415436c6c9cSStella Laurenzo         },
416436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
4179566ee28Smax     c.def_property_readonly_static("static_typeid",
4189566ee28Smax                                    [](py::object & /*class*/) -> MlirTypeID {
4199566ee28Smax                                      return mlirIntegerAttrGetTypeID();
4209566ee28Smax                                    });
421436c6c9cSStella Laurenzo   }
422436c6c9cSStella Laurenzo };
423436c6c9cSStella Laurenzo 
424436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
425436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
426436c6c9cSStella Laurenzo public:
427436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
428436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
429436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
430436c6c9cSStella Laurenzo 
431436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
432436c6c9cSStella Laurenzo     c.def_static(
433436c6c9cSStella Laurenzo         "get",
434436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
435436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
436436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
437436c6c9cSStella Laurenzo         },
438436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
439436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
440436c6c9cSStella Laurenzo     c.def_property_readonly(
441436c6c9cSStella Laurenzo         "value",
442436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
443436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
444436c6c9cSStella Laurenzo   }
445436c6c9cSStella Laurenzo };
446436c6c9cSStella Laurenzo 
4474eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
4484eee9ef9Smax public:
4494eee9ef9Smax   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
4504eee9ef9Smax   static constexpr const char *pyClassName = "SymbolRefAttr";
4514eee9ef9Smax   using PyConcreteAttribute::PyConcreteAttribute;
4524eee9ef9Smax 
4534eee9ef9Smax   static MlirAttribute fromList(const std::vector<std::string> &symbols,
4544eee9ef9Smax                                 PyMlirContext &context) {
4554eee9ef9Smax     if (symbols.empty())
4564eee9ef9Smax       throw std::runtime_error("SymbolRefAttr must be composed of at least "
4574eee9ef9Smax                                "one symbol.");
4584eee9ef9Smax     MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
4594eee9ef9Smax     SmallVector<MlirAttribute, 3> referenceAttrs;
4604eee9ef9Smax     for (size_t i = 1; i < symbols.size(); ++i) {
4614eee9ef9Smax       referenceAttrs.push_back(
4624eee9ef9Smax           mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
4634eee9ef9Smax     }
4644eee9ef9Smax     return mlirSymbolRefAttrGet(context.get(), rootSymbol,
4654eee9ef9Smax                                 referenceAttrs.size(), referenceAttrs.data());
4664eee9ef9Smax   }
4674eee9ef9Smax 
4684eee9ef9Smax   static void bindDerived(ClassTy &c) {
4694eee9ef9Smax     c.def_static(
4704eee9ef9Smax         "get",
4714eee9ef9Smax         [](const std::vector<std::string> &symbols,
4724eee9ef9Smax            DefaultingPyMlirContext context) {
4734eee9ef9Smax           return PySymbolRefAttribute::fromList(symbols, context.resolve());
4744eee9ef9Smax         },
4754eee9ef9Smax         py::arg("symbols"), py::arg("context") = py::none(),
4764eee9ef9Smax         "Gets a uniqued SymbolRef attribute from a list of symbol names");
4774eee9ef9Smax     c.def_property_readonly(
4784eee9ef9Smax         "value",
4794eee9ef9Smax         [](PySymbolRefAttribute &self) {
4804eee9ef9Smax           std::vector<std::string> symbols = {
4814eee9ef9Smax               unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
4824eee9ef9Smax           for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
4834eee9ef9Smax                ++i)
4844eee9ef9Smax             symbols.push_back(
4854eee9ef9Smax                 unwrap(mlirSymbolRefAttrGetRootReference(
4864eee9ef9Smax                            mlirSymbolRefAttrGetNestedReference(self, i)))
4874eee9ef9Smax                     .str());
4884eee9ef9Smax           return symbols;
4894eee9ef9Smax         },
4904eee9ef9Smax         "Returns the value of the SymbolRef attribute as a list[str]");
4914eee9ef9Smax   }
4924eee9ef9Smax };
4934eee9ef9Smax 
494436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
495436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
496436c6c9cSStella Laurenzo public:
497436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
498436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
499436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
500436c6c9cSStella Laurenzo 
501436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
502436c6c9cSStella Laurenzo     c.def_static(
503436c6c9cSStella Laurenzo         "get",
504436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
505436c6c9cSStella Laurenzo           MlirAttribute attr =
506436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
507436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
508436c6c9cSStella Laurenzo         },
509436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
510436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
511436c6c9cSStella Laurenzo     c.def_property_readonly(
512436c6c9cSStella Laurenzo         "value",
513436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
514436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
515436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
516436c6c9cSStella Laurenzo         },
517436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
518436c6c9cSStella Laurenzo   }
519436c6c9cSStella Laurenzo };
520436c6c9cSStella Laurenzo 
5215c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
5225c3861b2SYun Long public:
5235c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
5245c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
5255c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
5269566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
5279566ee28Smax       mlirOpaqueAttrGetTypeID;
5285c3861b2SYun Long 
5295c3861b2SYun Long   static void bindDerived(ClassTy &c) {
5305c3861b2SYun Long     c.def_static(
5315c3861b2SYun Long         "get",
5325c3861b2SYun Long         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
5335c3861b2SYun Long            DefaultingPyMlirContext context) {
5345c3861b2SYun Long           const py::buffer_info bufferInfo = buffer.request();
5355c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
5365c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
5375c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
5385c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
5395c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
5405c3861b2SYun Long         },
5415c3861b2SYun Long         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
5425c3861b2SYun Long         py::arg("context") = py::none(), "Gets an Opaque attribute.");
5435c3861b2SYun Long     c.def_property_readonly(
5445c3861b2SYun Long         "dialect_namespace",
5455c3861b2SYun Long         [](PyOpaqueAttribute &self) {
5465c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
5475c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
5485c3861b2SYun Long         },
5495c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
5505c3861b2SYun Long     c.def_property_readonly(
5515c3861b2SYun Long         "data",
5525c3861b2SYun Long         [](PyOpaqueAttribute &self) {
5535c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
55462bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
5555c3861b2SYun Long         },
55662bf6c2eSChris Jones         "Returns the data for the Opaqued attributes as `bytes`");
5575c3861b2SYun Long   }
5585c3861b2SYun Long };
5595c3861b2SYun Long 
560436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
561436c6c9cSStella Laurenzo public:
562436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
563436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
564436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
5659566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
5669566ee28Smax       mlirStringAttrGetTypeID;
567436c6c9cSStella Laurenzo 
568436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
569436c6c9cSStella Laurenzo     c.def_static(
570436c6c9cSStella Laurenzo         "get",
571436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
572436c6c9cSStella Laurenzo           MlirAttribute attr =
573436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
574436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
575436c6c9cSStella Laurenzo         },
576436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
577436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
578436c6c9cSStella Laurenzo     c.def_static(
579436c6c9cSStella Laurenzo         "get_typed",
580436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
581436c6c9cSStella Laurenzo           MlirAttribute attr =
582436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
583436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
584436c6c9cSStella Laurenzo         },
585a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
586436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
587436c6c9cSStella Laurenzo     c.def_property_readonly(
588436c6c9cSStella Laurenzo         "value",
589436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
590436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
591436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
592436c6c9cSStella Laurenzo         },
593436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
59462bf6c2eSChris Jones     c.def_property_readonly(
59562bf6c2eSChris Jones         "value_bytes",
59662bf6c2eSChris Jones         [](PyStringAttribute &self) {
59762bf6c2eSChris Jones           MlirStringRef stringRef = mlirStringAttrGetValue(self);
59862bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
59962bf6c2eSChris Jones         },
60062bf6c2eSChris Jones         "Returns the value of the string attribute as `bytes`");
601436c6c9cSStella Laurenzo   }
602436c6c9cSStella Laurenzo };
603436c6c9cSStella Laurenzo 
604436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
605436c6c9cSStella Laurenzo class PyDenseElementsAttribute
606436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
607436c6c9cSStella Laurenzo public:
608436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
609436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
610436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
611436c6c9cSStella Laurenzo 
612436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
6130a81ace0SKazu Hirata   getFromBuffer(py::buffer array, bool signless,
6140a81ace0SKazu Hirata                 std::optional<PyType> explicitType,
6150a81ace0SKazu Hirata                 std::optional<std::vector<int64_t>> explicitShape,
616436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
617436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
618*71a25454SPeter Hawkins     int flags = PyBUF_ND;
619*71a25454SPeter Hawkins     if (!explicitType) {
620*71a25454SPeter Hawkins       flags |= PyBUF_FORMAT;
621*71a25454SPeter Hawkins     }
622*71a25454SPeter Hawkins     Py_buffer view;
623*71a25454SPeter Hawkins     if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
624436c6c9cSStella Laurenzo       throw py::error_already_set();
625436c6c9cSStella Laurenzo     }
626*71a25454SPeter Hawkins     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
6275d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
6285d6d30edSStella Laurenzo     if (explicitShape) {
6295d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
6305d6d30edSStella Laurenzo     } else {
631*71a25454SPeter Hawkins       shape.append(view.shape, view.shape + view.ndim);
6325d6d30edSStella Laurenzo     }
633436c6c9cSStella Laurenzo 
6345d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
635436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
6365d6d30edSStella Laurenzo 
6375d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
6385d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
6395d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
6405d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
6415d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
6420a81ace0SKazu Hirata     std::optional<MlirType> bulkLoadElementType;
6435d6d30edSStella Laurenzo     if (explicitType) {
6445d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
645*71a25454SPeter Hawkins     } else {
646*71a25454SPeter Hawkins       std::string_view format(view.format);
647*71a25454SPeter Hawkins       if (format == "f") {
648436c6c9cSStella Laurenzo         // f32
649*71a25454SPeter Hawkins         assert(view.itemsize == 4 && "mismatched array itemsize");
6505d6d30edSStella Laurenzo         bulkLoadElementType = mlirF32TypeGet(context);
651*71a25454SPeter Hawkins       } else if (format == "d") {
652436c6c9cSStella Laurenzo         // f64
653*71a25454SPeter Hawkins         assert(view.itemsize == 8 && "mismatched array itemsize");
6545d6d30edSStella Laurenzo         bulkLoadElementType = mlirF64TypeGet(context);
655*71a25454SPeter Hawkins       } else if (format == "e") {
6565d6d30edSStella Laurenzo         // f16
657*71a25454SPeter Hawkins         assert(view.itemsize == 2 && "mismatched array itemsize");
6585d6d30edSStella Laurenzo         bulkLoadElementType = mlirF16TypeGet(context);
659*71a25454SPeter Hawkins       } else if (isSignedIntegerFormat(format)) {
660*71a25454SPeter Hawkins         if (view.itemsize == 4) {
661436c6c9cSStella Laurenzo           // i32
662*71a25454SPeter Hawkins           bulkLoadElementType = signless
663*71a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 32)
664436c6c9cSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 32);
665*71a25454SPeter Hawkins         } else if (view.itemsize == 8) {
666436c6c9cSStella Laurenzo           // i64
667*71a25454SPeter Hawkins           bulkLoadElementType = signless
668*71a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 64)
669436c6c9cSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 64);
670*71a25454SPeter Hawkins         } else if (view.itemsize == 1) {
6715d6d30edSStella Laurenzo           // i8
6725d6d30edSStella Laurenzo           bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
6735d6d30edSStella Laurenzo                                          : mlirIntegerTypeSignedGet(context, 8);
674*71a25454SPeter Hawkins         } else if (view.itemsize == 2) {
6755d6d30edSStella Laurenzo           // i16
676*71a25454SPeter Hawkins           bulkLoadElementType = signless
677*71a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 16)
6785d6d30edSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 16);
679436c6c9cSStella Laurenzo         }
680*71a25454SPeter Hawkins       } else if (isUnsignedIntegerFormat(format)) {
681*71a25454SPeter Hawkins         if (view.itemsize == 4) {
682436c6c9cSStella Laurenzo           // unsigned i32
6835d6d30edSStella Laurenzo           bulkLoadElementType = signless
684436c6c9cSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 32)
685436c6c9cSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 32);
686*71a25454SPeter Hawkins         } else if (view.itemsize == 8) {
687436c6c9cSStella Laurenzo           // unsigned i64
6885d6d30edSStella Laurenzo           bulkLoadElementType = signless
689436c6c9cSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 64)
690436c6c9cSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 64);
691*71a25454SPeter Hawkins         } else if (view.itemsize == 1) {
6925d6d30edSStella Laurenzo           // i8
693*71a25454SPeter Hawkins           bulkLoadElementType = signless
694*71a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 8)
6955d6d30edSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 8);
696*71a25454SPeter Hawkins         } else if (view.itemsize == 2) {
6975d6d30edSStella Laurenzo           // i16
6985d6d30edSStella Laurenzo           bulkLoadElementType = signless
6995d6d30edSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 16)
7005d6d30edSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 16);
701436c6c9cSStella Laurenzo         }
702436c6c9cSStella Laurenzo       }
703*71a25454SPeter Hawkins       if (!bulkLoadElementType) {
704*71a25454SPeter Hawkins         throw std::invalid_argument(
705*71a25454SPeter Hawkins             std::string("unimplemented array format conversion from format: ") +
706*71a25454SPeter Hawkins             std::string(format));
707*71a25454SPeter Hawkins       }
708*71a25454SPeter Hawkins     }
709*71a25454SPeter Hawkins 
71099dee31eSAdam Paszke     MlirType shapedType;
71199dee31eSAdam Paszke     if (mlirTypeIsAShaped(*bulkLoadElementType)) {
71299dee31eSAdam Paszke       if (explicitShape) {
71399dee31eSAdam Paszke         throw std::invalid_argument("Shape can only be specified explicitly "
71499dee31eSAdam Paszke                                     "when the type is not a shaped type.");
71599dee31eSAdam Paszke       }
71699dee31eSAdam Paszke       shapedType = *bulkLoadElementType;
71799dee31eSAdam Paszke     } else {
718*71a25454SPeter Hawkins       shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
719*71a25454SPeter Hawkins                                            *bulkLoadElementType, encodingAttr);
72099dee31eSAdam Paszke     }
721*71a25454SPeter Hawkins     size_t rawBufferSize = view.len;
722*71a25454SPeter Hawkins     MlirAttribute attr =
723*71a25454SPeter Hawkins         mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
7245d6d30edSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
7255d6d30edSStella Laurenzo       throw std::invalid_argument(
7265d6d30edSStella Laurenzo           "DenseElementsAttr could not be constructed from the given buffer. "
7275d6d30edSStella Laurenzo           "This may mean that the Python buffer layout does not match that "
7285d6d30edSStella Laurenzo           "MLIR expected layout and is a bug.");
7295d6d30edSStella Laurenzo     }
7305d6d30edSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
7315d6d30edSStella Laurenzo   }
732436c6c9cSStella Laurenzo 
7331fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
734436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
735436c6c9cSStella Laurenzo     auto contextWrapper =
736436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
737436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
738436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
739436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
740436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
7414811270bSmax       throw py::value_error(message);
742436c6c9cSStella Laurenzo     }
743436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
744436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
745436c6c9cSStella Laurenzo       std::string message =
746436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
747436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
7484811270bSmax       throw py::value_error(message);
749436c6c9cSStella Laurenzo     }
750436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
751436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
752436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
753436c6c9cSStella Laurenzo       std::string message =
754436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
755436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
756436c6c9cSStella Laurenzo       message.append(", element=");
757436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
7584811270bSmax       throw py::value_error(message);
759436c6c9cSStella Laurenzo     }
760436c6c9cSStella Laurenzo 
761436c6c9cSStella Laurenzo     MlirAttribute elements =
762436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
763436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
764436c6c9cSStella Laurenzo   }
765436c6c9cSStella Laurenzo 
766436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
767436c6c9cSStella Laurenzo 
768436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
769436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
770436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
7715d6d30edSStella Laurenzo     std::string format;
772436c6c9cSStella Laurenzo 
773436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
774436c6c9cSStella Laurenzo       // f32
7755d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
77602b6fb21SMehdi Amini     }
77702b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
778436c6c9cSStella Laurenzo       // f64
7795d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
780bb56c2b3SMehdi Amini     }
781bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
7825d6d30edSStella Laurenzo       // f16
7835d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
784bb56c2b3SMehdi Amini     }
785ef1b735dSmax     if (mlirTypeIsAIndex(elementType)) {
786ef1b735dSmax       // Same as IndexType::kInternalStorageBitWidth
787ef1b735dSmax       return bufferInfo<int64_t>(shapedType);
788ef1b735dSmax     }
789bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
790436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
791436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
792436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
793436c6c9cSStella Laurenzo         // i32
7945d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
795e5639b3fSMehdi Amini       }
796e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
797436c6c9cSStella Laurenzo         // unsigned i32
7985d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
799436c6c9cSStella Laurenzo       }
800436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
801436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
802436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
803436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
804436c6c9cSStella Laurenzo         // i64
8055d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
806e5639b3fSMehdi Amini       }
807e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
808436c6c9cSStella Laurenzo         // unsigned i64
8095d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
8105d6d30edSStella Laurenzo       }
8115d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
8125d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
8135d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
8145d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
8155d6d30edSStella Laurenzo         // i8
8165d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
817e5639b3fSMehdi Amini       }
818e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
8195d6d30edSStella Laurenzo         // unsigned i8
8205d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
8215d6d30edSStella Laurenzo       }
8225d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
8235d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
8245d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
8255d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
8265d6d30edSStella Laurenzo         // i16
8275d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
828e5639b3fSMehdi Amini       }
829e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
8305d6d30edSStella Laurenzo         // unsigned i16
8315d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
832436c6c9cSStella Laurenzo       }
833436c6c9cSStella Laurenzo     }
834436c6c9cSStella Laurenzo 
835c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
8365d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
837c5f445d1SStella Laurenzo     throw std::invalid_argument(
838c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
839436c6c9cSStella Laurenzo   }
840436c6c9cSStella Laurenzo 
841436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
842436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
843436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
844436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
8455d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
846436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
8475d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
848436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
849436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
850436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
851436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
852436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
853436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
854436c6c9cSStella Laurenzo                                })
85591259963SAdam Paszke         .def("get_splat_value",
856974c1596SRahul Kayaith              [](PyDenseElementsAttribute &self) {
857974c1596SRahul Kayaith                if (!mlirDenseElementsAttrIsSplat(self))
8584811270bSmax                  throw py::value_error(
85991259963SAdam Paszke                      "get_splat_value called on a non-splat attribute");
860974c1596SRahul Kayaith                return mlirDenseElementsAttrGetSplatValue(self);
86191259963SAdam Paszke              })
862436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
863436c6c9cSStella Laurenzo   }
864436c6c9cSStella Laurenzo 
865436c6c9cSStella Laurenzo private:
866*71a25454SPeter Hawkins   static bool isUnsignedIntegerFormat(std::string_view format) {
867436c6c9cSStella Laurenzo     if (format.empty())
868436c6c9cSStella Laurenzo       return false;
869436c6c9cSStella Laurenzo     char code = format[0];
870436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
871436c6c9cSStella Laurenzo            code == 'Q';
872436c6c9cSStella Laurenzo   }
873436c6c9cSStella Laurenzo 
874*71a25454SPeter Hawkins   static bool isSignedIntegerFormat(std::string_view format) {
875436c6c9cSStella Laurenzo     if (format.empty())
876436c6c9cSStella Laurenzo       return false;
877436c6c9cSStella Laurenzo     char code = format[0];
878436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
879436c6c9cSStella Laurenzo            code == 'q';
880436c6c9cSStella Laurenzo   }
881436c6c9cSStella Laurenzo 
882436c6c9cSStella Laurenzo   template <typename Type>
883436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
8845d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
885436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
886436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
887436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
888436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
889436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
890436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
891436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
892436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
893436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
894436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
895436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
896f0e847d0SRahul Kayaith     if (mlirDenseElementsAttrIsSplat(*this)) {
897f0e847d0SRahul Kayaith       // Splats are special, only the single value is stored.
898f0e847d0SRahul Kayaith       strides.assign(rank, 0);
899f0e847d0SRahul Kayaith     } else {
900436c6c9cSStella Laurenzo       for (intptr_t i = 1; i < rank; ++i) {
901f0e847d0SRahul Kayaith         intptr_t strideFactor = 1;
902f0e847d0SRahul Kayaith         for (intptr_t j = i; j < rank; ++j)
903436c6c9cSStella Laurenzo           strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
904436c6c9cSStella Laurenzo         strides.push_back(sizeof(Type) * strideFactor);
905436c6c9cSStella Laurenzo       }
906436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type));
907f0e847d0SRahul Kayaith     }
9085d6d30edSStella Laurenzo     std::string format;
9095d6d30edSStella Laurenzo     if (explicitFormat) {
9105d6d30edSStella Laurenzo       format = explicitFormat;
9115d6d30edSStella Laurenzo     } else {
9125d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
9135d6d30edSStella Laurenzo     }
9145d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
9155d6d30edSStella Laurenzo                            /*readonly=*/true);
916436c6c9cSStella Laurenzo   }
917436c6c9cSStella Laurenzo }; // namespace
918436c6c9cSStella Laurenzo 
919436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
920436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
921436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
922436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
923436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
924436c6c9cSStella Laurenzo public:
925436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
926436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
927436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
928436c6c9cSStella Laurenzo 
929436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
930436c6c9cSStella Laurenzo   /// out of range.
931436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
932436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
9334811270bSmax       throw py::index_error("attempt to access out of bounds element");
934436c6c9cSStella Laurenzo     }
935436c6c9cSStella Laurenzo 
936436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
937436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
938436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
939436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
940436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
941436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
942436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
943436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
944436c6c9cSStella Laurenzo     // querying them on each element access.
945436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
946436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
947436c6c9cSStella Laurenzo     if (isUnsigned) {
948436c6c9cSStella Laurenzo       if (width == 1) {
949436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
950436c6c9cSStella Laurenzo       }
951308d8b8cSRahul Kayaith       if (width == 8) {
952308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
953308d8b8cSRahul Kayaith       }
954308d8b8cSRahul Kayaith       if (width == 16) {
955308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
956308d8b8cSRahul Kayaith       }
957436c6c9cSStella Laurenzo       if (width == 32) {
958436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
959436c6c9cSStella Laurenzo       }
960436c6c9cSStella Laurenzo       if (width == 64) {
961436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
962436c6c9cSStella Laurenzo       }
963436c6c9cSStella Laurenzo     } else {
964436c6c9cSStella Laurenzo       if (width == 1) {
965436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
966436c6c9cSStella Laurenzo       }
967308d8b8cSRahul Kayaith       if (width == 8) {
968308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
969308d8b8cSRahul Kayaith       }
970308d8b8cSRahul Kayaith       if (width == 16) {
971308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
972308d8b8cSRahul Kayaith       }
973436c6c9cSStella Laurenzo       if (width == 32) {
974436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
975436c6c9cSStella Laurenzo       }
976436c6c9cSStella Laurenzo       if (width == 64) {
977436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
978436c6c9cSStella Laurenzo       }
979436c6c9cSStella Laurenzo     }
9804811270bSmax     throw py::type_error("Unsupported integer type");
981436c6c9cSStella Laurenzo   }
982436c6c9cSStella Laurenzo 
983436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
984436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
985436c6c9cSStella Laurenzo   }
986436c6c9cSStella Laurenzo };
987436c6c9cSStella Laurenzo 
988436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
989436c6c9cSStella Laurenzo public:
990436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
991436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
992436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
9939566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
9949566ee28Smax       mlirDictionaryAttrGetTypeID;
995436c6c9cSStella Laurenzo 
996436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
997436c6c9cSStella Laurenzo 
9989fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
9999fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
10009fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
10019fb1086bSAdrian Kuegel   }
10029fb1086bSAdrian Kuegel 
1003436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
10049fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
1005436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
1006436c6c9cSStella Laurenzo     c.def_static(
1007436c6c9cSStella Laurenzo         "get",
1008436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
1009436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1010436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
1011436c6c9cSStella Laurenzo           for (auto &it : attributes) {
101202b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
1013436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
1014436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
101502b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1016436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
101702b6fb21SMehdi Amini                 mlirAttr));
1018436c6c9cSStella Laurenzo           }
1019436c6c9cSStella Laurenzo           MlirAttribute attr =
1020436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1021436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1022436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
1023436c6c9cSStella Laurenzo         },
1024ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
1025436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
1026436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1027436c6c9cSStella Laurenzo       MlirAttribute attr =
1028436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1029974c1596SRahul Kayaith       if (mlirAttributeIsNull(attr))
10304811270bSmax         throw py::key_error("attempt to access a non-existent attribute");
1031974c1596SRahul Kayaith       return attr;
1032436c6c9cSStella Laurenzo     });
1033436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1034436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
10354811270bSmax         throw py::index_error("attempt to access out of bounds attribute");
1036436c6c9cSStella Laurenzo       }
1037436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1038436c6c9cSStella Laurenzo       return PyNamedAttribute(
1039436c6c9cSStella Laurenzo           namedAttr.attribute,
1040436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
1041436c6c9cSStella Laurenzo     });
1042436c6c9cSStella Laurenzo   }
1043436c6c9cSStella Laurenzo };
1044436c6c9cSStella Laurenzo 
1045436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
1046436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
1047436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
1048436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1049436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1050436c6c9cSStella Laurenzo public:
1051436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1052436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
1053436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1054436c6c9cSStella Laurenzo 
1055436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
1056436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
10574811270bSmax       throw py::index_error("attempt to access out of bounds element");
1058436c6c9cSStella Laurenzo     }
1059436c6c9cSStella Laurenzo 
1060436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1061436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1062436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1063436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
1064436c6c9cSStella Laurenzo     // from float and double.
1065436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1066436c6c9cSStella Laurenzo     // querying them on each element access.
1067436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
1068436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
1069436c6c9cSStella Laurenzo     }
1070436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
1071436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1072436c6c9cSStella Laurenzo     }
10734811270bSmax     throw py::type_error("Unsupported floating-point type");
1074436c6c9cSStella Laurenzo   }
1075436c6c9cSStella Laurenzo 
1076436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1077436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1078436c6c9cSStella Laurenzo   }
1079436c6c9cSStella Laurenzo };
1080436c6c9cSStella Laurenzo 
1081436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1082436c6c9cSStella Laurenzo public:
1083436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1084436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
1085436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
10869566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
10879566ee28Smax       mlirTypeAttrGetTypeID;
1088436c6c9cSStella Laurenzo 
1089436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1090436c6c9cSStella Laurenzo     c.def_static(
1091436c6c9cSStella Laurenzo         "get",
1092436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1093436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1094436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1095436c6c9cSStella Laurenzo         },
1096436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
1097436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1098436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1099bfb1ba75Smax       return mlirTypeAttrGetValue(self.get());
1100436c6c9cSStella Laurenzo     });
1101436c6c9cSStella Laurenzo   }
1102436c6c9cSStella Laurenzo };
1103436c6c9cSStella Laurenzo 
1104436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1105436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1106436c6c9cSStella Laurenzo public:
1107436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1108436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1109436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
11109566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
11119566ee28Smax       mlirUnitAttrGetTypeID;
1112436c6c9cSStella Laurenzo 
1113436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1114436c6c9cSStella Laurenzo     c.def_static(
1115436c6c9cSStella Laurenzo         "get",
1116436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1117436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1118436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1119436c6c9cSStella Laurenzo         },
1120436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
1121436c6c9cSStella Laurenzo   }
1122436c6c9cSStella Laurenzo };
1123436c6c9cSStella Laurenzo 
1124ac2e2d65SDenys Shabalin /// Strided layout attribute subclass.
1125ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute
1126ac2e2d65SDenys Shabalin     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1127ac2e2d65SDenys Shabalin public:
1128ac2e2d65SDenys Shabalin   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1129ac2e2d65SDenys Shabalin   static constexpr const char *pyClassName = "StridedLayoutAttr";
1130ac2e2d65SDenys Shabalin   using PyConcreteAttribute::PyConcreteAttribute;
11319566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
11329566ee28Smax       mlirStridedLayoutAttrGetTypeID;
1133ac2e2d65SDenys Shabalin 
1134ac2e2d65SDenys Shabalin   static void bindDerived(ClassTy &c) {
1135ac2e2d65SDenys Shabalin     c.def_static(
1136ac2e2d65SDenys Shabalin         "get",
1137ac2e2d65SDenys Shabalin         [](int64_t offset, const std::vector<int64_t> strides,
1138ac2e2d65SDenys Shabalin            DefaultingPyMlirContext ctx) {
1139ac2e2d65SDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1140ac2e2d65SDenys Shabalin               ctx->get(), offset, strides.size(), strides.data());
1141ac2e2d65SDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1142ac2e2d65SDenys Shabalin         },
1143ac2e2d65SDenys Shabalin         py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1144ac2e2d65SDenys Shabalin         "Gets a strided layout attribute.");
1145e3fd612eSDenys Shabalin     c.def_static(
1146e3fd612eSDenys Shabalin         "get_fully_dynamic",
1147e3fd612eSDenys Shabalin         [](int64_t rank, DefaultingPyMlirContext ctx) {
1148e3fd612eSDenys Shabalin           auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1149e3fd612eSDenys Shabalin           std::vector<int64_t> strides(rank);
1150e3fd612eSDenys Shabalin           std::fill(strides.begin(), strides.end(), dynamic);
1151e3fd612eSDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1152e3fd612eSDenys Shabalin               ctx->get(), dynamic, strides.size(), strides.data());
1153e3fd612eSDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1154e3fd612eSDenys Shabalin         },
1155e3fd612eSDenys Shabalin         py::arg("rank"), py::arg("context") = py::none(),
1156e3fd612eSDenys Shabalin         "Gets a strided layout attribute with dynamic offset and strides of a "
1157e3fd612eSDenys Shabalin         "given rank.");
1158ac2e2d65SDenys Shabalin     c.def_property_readonly(
1159ac2e2d65SDenys Shabalin         "offset",
1160ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1161ac2e2d65SDenys Shabalin           return mlirStridedLayoutAttrGetOffset(self);
1162ac2e2d65SDenys Shabalin         },
1163ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1164ac2e2d65SDenys Shabalin     c.def_property_readonly(
1165ac2e2d65SDenys Shabalin         "strides",
1166ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1167ac2e2d65SDenys Shabalin           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1168ac2e2d65SDenys Shabalin           std::vector<int64_t> strides(size);
1169ac2e2d65SDenys Shabalin           for (intptr_t i = 0; i < size; i++) {
1170ac2e2d65SDenys Shabalin             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1171ac2e2d65SDenys Shabalin           }
1172ac2e2d65SDenys Shabalin           return strides;
1173ac2e2d65SDenys Shabalin         },
1174ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1175ac2e2d65SDenys Shabalin   }
1176ac2e2d65SDenys Shabalin };
1177ac2e2d65SDenys Shabalin 
11789566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
11799566ee28Smax   if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
11809566ee28Smax     return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
11819566ee28Smax   if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
11829566ee28Smax     return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
11839566ee28Smax   if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
11849566ee28Smax     return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
11859566ee28Smax   if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
11869566ee28Smax     return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
11879566ee28Smax   if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
11889566ee28Smax     return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
11899566ee28Smax   if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
11909566ee28Smax     return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
11919566ee28Smax   if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
11929566ee28Smax     return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
11939566ee28Smax   std::string msg =
11949566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
11959566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
11969566ee28Smax   throw py::cast_error(msg);
11979566ee28Smax }
11989566ee28Smax 
11999566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
12009566ee28Smax   if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
12019566ee28Smax     return py::cast(PyDenseFPElementsAttribute(pyAttribute));
12029566ee28Smax   if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
12039566ee28Smax     return py::cast(PyDenseIntElementsAttribute(pyAttribute));
12049566ee28Smax   std::string msg =
12059566ee28Smax       std::string(
12069566ee28Smax           "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
12079566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
12089566ee28Smax   throw py::cast_error(msg);
12099566ee28Smax }
12109566ee28Smax 
12119566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
12129566ee28Smax   if (PyBoolAttribute::isaFunction(pyAttribute))
12139566ee28Smax     return py::cast(PyBoolAttribute(pyAttribute));
12149566ee28Smax   if (PyIntegerAttribute::isaFunction(pyAttribute))
12159566ee28Smax     return py::cast(PyIntegerAttribute(pyAttribute));
12169566ee28Smax   std::string msg =
12179566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
12189566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
12199566ee28Smax   throw py::cast_error(msg);
12209566ee28Smax }
12219566ee28Smax 
12224eee9ef9Smax py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
12234eee9ef9Smax   if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
12244eee9ef9Smax     return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
12254eee9ef9Smax   if (PySymbolRefAttribute::isaFunction(pyAttribute))
12264eee9ef9Smax     return py::cast(PySymbolRefAttribute(pyAttribute));
12274eee9ef9Smax   std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
12284eee9ef9Smax                     std::string(py::repr(py::cast(pyAttribute))) + ")";
12294eee9ef9Smax   throw py::cast_error(msg);
12304eee9ef9Smax }
12314eee9ef9Smax 
1232436c6c9cSStella Laurenzo } // namespace
1233436c6c9cSStella Laurenzo 
1234436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
1235436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1236619fd8c2SJeff Niu 
1237619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1238619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1239619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1240619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1241619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1242619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1243619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1244619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1245619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1246619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1247619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1248619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1249619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1250619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
12519566ee28Smax   PyGlobals::get().registerTypeCaster(
12529566ee28Smax       mlirDenseArrayAttrGetTypeID(),
12539566ee28Smax       pybind11::cpp_function(denseArrayAttributeCaster));
1254619fd8c2SJeff Niu 
1255436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1256436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1257436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1258436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
1259436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1260436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
12619566ee28Smax   PyGlobals::get().registerTypeCaster(
12629566ee28Smax       mlirDenseIntOrFPElementsAttrGetTypeID(),
12639566ee28Smax       pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
12649566ee28Smax 
1265436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
12664eee9ef9Smax   PySymbolRefAttribute::bind(m);
12674eee9ef9Smax   PyGlobals::get().registerTypeCaster(
12684eee9ef9Smax       mlirSymbolRefAttrGetTypeID(),
12694eee9ef9Smax       pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
12704eee9ef9Smax 
1271436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
12725c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1273436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1274436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1275436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1276436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
12779566ee28Smax   PyGlobals::get().registerTypeCaster(
12789566ee28Smax       mlirIntegerAttrGetTypeID(),
12799566ee28Smax       pybind11::cpp_function(integerOrBoolAttributeCaster));
1280436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1281ac2e2d65SDenys Shabalin 
1282ac2e2d65SDenys Shabalin   PyStridedLayoutAttribute::bind(m);
1283436c6c9cSStella Laurenzo }
1284