xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision 404d0e9966a46c29e6539e20d9295adcbc8bf9bf)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9a1fe1f5fSKazu Hirata #include <optional>
1071a25454SPeter Hawkins #include <string_view>
114811270bSmax #include <utility>
121fc096afSMehdi Amini 
13436c6c9cSStella Laurenzo #include "IRModule.h"
14436c6c9cSStella Laurenzo 
15436c6c9cSStella Laurenzo #include "PybindUtils.h"
161824e45cSKasper Nielsen #include <pybind11/numpy.h>
17436c6c9cSStella Laurenzo 
1871a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h"
19c912f0e7Spranavm-nvidia #include "llvm/Support/raw_ostream.h"
2071a25454SPeter Hawkins 
21436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
22436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
23bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h"
24436c6c9cSStella Laurenzo 
25436c6c9cSStella Laurenzo namespace py = pybind11;
26436c6c9cSStella Laurenzo using namespace mlir;
27436c6c9cSStella Laurenzo using namespace mlir::python;
28436c6c9cSStella Laurenzo 
29436c6c9cSStella Laurenzo using llvm::SmallVector;
30436c6c9cSStella Laurenzo 
//------------------------------------------------------------------------------
// Docstrings (trivial, non-duplicated docstrings are included inline).
//------------------------------------------------------------------------------

// Python-facing docstring for constructing a DenseElementsAttr from a Python
// buffer or array. The text below is emitted verbatim to Python users, so it is
// program data, not a comment. NOTE(review): its use site (presumably the
// DenseElementsAttr `get` binding) lies outside this view — keep the documented
// argument names in sync with that binding.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
765d6d30edSStella Laurenzo 
// Python-facing docstring for constructing a DenseElementsAttr from a Python
// list of attributes. The text is emitted verbatim to Python users; every
// parameter it mentions must use the names listed under "Args" (`attrs`,
// `type`, `context`).
static const char kDenseElementsAttrGetFromListDocstring[] =
    R"(Gets a DenseElementsAttr from a Python list of attributes.

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
  attrs: A list of attributes.
  type: The desired shape and type of the resulting DenseElementsAttr.
    If not provided, the element type is determined based on the type
    of the 0th attribute and the shape is `[len(attrs)]`.
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the attributes does not match the type
    specified by `type`.
)";
97c912f0e7Spranavm-nvidia 
// Python-facing docstring for constructing a DenseResourceElementsAttr from a
// Python buffer or array. The text is emitted verbatim to Python users, so it is
// program data, not a comment. NOTE(review): the binding that uses it is outside
// this view — keep the documented argument names (`buffer`, `name`, `type`,
// `context`) in sync with it.
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
  buffer: The array or buffer to convert.
  name: Name to provide to the resource (may be changed upon collision).
  type: The explicit ShapedType to construct the attribute with.
  context: Explicit context, if not from context manager.

Returns:
  DenseResourceElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
123f66cd9e9SStella Laurenzo 
124436c6c9cSStella Laurenzo namespace {
125436c6c9cSStella Laurenzo 
126436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
127436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
128436c6c9cSStella Laurenzo }
129436c6c9cSStella Laurenzo 
130436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
131436c6c9cSStella Laurenzo public:
132436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
133436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
134436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1359566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1369566ee28Smax       mlirAffineMapAttrGetTypeID;
137436c6c9cSStella Laurenzo 
138436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
139436c6c9cSStella Laurenzo     c.def_static(
140436c6c9cSStella Laurenzo         "get",
141436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
142436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
143436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
144436c6c9cSStella Laurenzo         },
145436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
146c36b4248SBimo     c.def_property_readonly("value", mlirAffineMapAttrGetValue,
147c36b4248SBimo                             "Returns the value of the AffineMap attribute");
148436c6c9cSStella Laurenzo   }
149436c6c9cSStella Laurenzo };
150436c6c9cSStella Laurenzo 
151334873feSAmy Wang class PyIntegerSetAttribute
152334873feSAmy Wang     : public PyConcreteAttribute<PyIntegerSetAttribute> {
153334873feSAmy Wang public:
154334873feSAmy Wang   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
155334873feSAmy Wang   static constexpr const char *pyClassName = "IntegerSetAttr";
156334873feSAmy Wang   using PyConcreteAttribute::PyConcreteAttribute;
157334873feSAmy Wang   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
158334873feSAmy Wang       mlirIntegerSetAttrGetTypeID;
159334873feSAmy Wang 
160334873feSAmy Wang   static void bindDerived(ClassTy &c) {
161334873feSAmy Wang     c.def_static(
162334873feSAmy Wang         "get",
163334873feSAmy Wang         [](PyIntegerSet &integerSet) {
164334873feSAmy Wang           MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
165334873feSAmy Wang           return PyIntegerSetAttribute(integerSet.getContext(), attr);
166334873feSAmy Wang         },
167334873feSAmy Wang         py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
168334873feSAmy Wang   }
169334873feSAmy Wang };
170334873feSAmy Wang 
171ed9e52f3SAlex Zinenko template <typename T>
172ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
173ed9e52f3SAlex Zinenko   try {
174ed9e52f3SAlex Zinenko     return object.cast<T>();
175ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
176ed9e52f3SAlex Zinenko     std::string msg =
177ed9e52f3SAlex Zinenko         std::string(
178ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
179ed9e52f3SAlex Zinenko         err.what() + ")";
180ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
181ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
182ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
183ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
184ed9e52f3SAlex Zinenko                       err.what() + ")";
185ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
186ed9e52f3SAlex Zinenko   }
187ed9e52f3SAlex Zinenko }
188ed9e52f3SAlex Zinenko 
189619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
190619fd8c2SJeff Niu /// implementation class.
191619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
192133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
193619fd8c2SJeff Niu public:
194133624acSJeff Niu   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
195619fd8c2SJeff Niu 
196619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
197619fd8c2SJeff Niu   class PyDenseArrayIterator {
198619fd8c2SJeff Niu   public:
1994a1b1196SMehdi Amini     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
200619fd8c2SJeff Niu 
201619fd8c2SJeff Niu     /// Return a copy of the iterator.
202619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
203619fd8c2SJeff Niu 
204619fd8c2SJeff Niu     /// Return the next element.
205619fd8c2SJeff Niu     EltTy dunderNext() {
206619fd8c2SJeff Niu       // Throw if the index has reached the end.
207619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
208619fd8c2SJeff Niu         throw py::stop_iteration();
209619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
210619fd8c2SJeff Niu     }
211619fd8c2SJeff Niu 
212619fd8c2SJeff Niu     /// Bind the iterator class.
213619fd8c2SJeff Niu     static void bind(py::module &m) {
214619fd8c2SJeff Niu       py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
215619fd8c2SJeff Niu                                        py::module_local())
216619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
217619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
218619fd8c2SJeff Niu     }
219619fd8c2SJeff Niu 
220619fd8c2SJeff Niu   private:
221619fd8c2SJeff Niu     /// The referenced dense array attribute.
222619fd8c2SJeff Niu     PyAttribute attr;
223619fd8c2SJeff Niu     /// The next index to read.
224619fd8c2SJeff Niu     int nextIndex = 0;
225619fd8c2SJeff Niu   };
226619fd8c2SJeff Niu 
227619fd8c2SJeff Niu   /// Get the element at the given index.
228619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
229619fd8c2SJeff Niu 
230619fd8c2SJeff Niu   /// Bind the attribute class.
231133624acSJeff Niu   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
232619fd8c2SJeff Niu     // Bind the constructor.
233619fd8c2SJeff Niu     c.def_static(
234619fd8c2SJeff Niu         "get",
235619fd8c2SJeff Niu         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
2368dcb6722SIngo Müller           return getAttribute(values, ctx->getRef());
237619fd8c2SJeff Niu         },
238619fd8c2SJeff Niu         py::arg("values"), py::arg("context") = py::none(),
239619fd8c2SJeff Niu         "Gets a uniqued dense array attribute");
240619fd8c2SJeff Niu     // Bind the array methods.
241133624acSJeff Niu     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
242619fd8c2SJeff Niu       if (i >= mlirDenseArrayGetNumElements(arr))
243619fd8c2SJeff Niu         throw py::index_error("DenseArray index out of range");
244619fd8c2SJeff Niu       return arr.getItem(i);
245619fd8c2SJeff Niu     });
246133624acSJeff Niu     c.def("__len__", [](const DerivedT &arr) {
247619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
248619fd8c2SJeff Niu     });
249133624acSJeff Niu     c.def("__iter__",
250133624acSJeff Niu           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
2514a1b1196SMehdi Amini     c.def("__add__", [](DerivedT &arr, const py::list &extras) {
252619fd8c2SJeff Niu       std::vector<EltTy> values;
253619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
254619fd8c2SJeff Niu       values.reserve(numOldElements + py::len(extras));
255619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
256619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
257619fd8c2SJeff Niu       for (py::handle attr : extras)
258619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
2598dcb6722SIngo Müller       return getAttribute(values, arr.getContext());
260619fd8c2SJeff Niu     });
261619fd8c2SJeff Niu   }
2628dcb6722SIngo Müller 
2638dcb6722SIngo Müller private:
2648dcb6722SIngo Müller   static DerivedT getAttribute(const std::vector<EltTy> &values,
2658dcb6722SIngo Müller                                PyMlirContextRef ctx) {
2668dcb6722SIngo Müller     if constexpr (std::is_same_v<EltTy, bool>) {
2678dcb6722SIngo Müller       std::vector<int> intValues(values.begin(), values.end());
2688dcb6722SIngo Müller       MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
2698dcb6722SIngo Müller                                                   intValues.data());
2708dcb6722SIngo Müller       return DerivedT(ctx, attr);
2718dcb6722SIngo Müller     } else {
2728dcb6722SIngo Müller       MlirAttribute attr =
2738dcb6722SIngo Müller           DerivedT::getAttribute(ctx->get(), values.size(), values.data());
2748dcb6722SIngo Müller       return DerivedT(ctx, attr);
2758dcb6722SIngo Müller     }
2768dcb6722SIngo Müller   }
277619fd8c2SJeff Niu };
278619fd8c2SJeff Niu 
279619fd8c2SJeff Niu /// Instantiate the python dense array classes.
280619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
2818dcb6722SIngo Müller     : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
282619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
283619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
284619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
285619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
286619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
287619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
288619fd8c2SJeff Niu };
289619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
290619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
291619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
292619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
293619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
294619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
295619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
296619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
297619fd8c2SJeff Niu };
298619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
299619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
300619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
301619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
302619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
303619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
304619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
305619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
306619fd8c2SJeff Niu };
307619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
308619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
309619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
310619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
311619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
312619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
313619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
314619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
315619fd8c2SJeff Niu };
316619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
317619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
318619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
319619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
320619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
321619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
322619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
323619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
324619fd8c2SJeff Niu };
325619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
326619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
327619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
328619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
329619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
330619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
331619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
332619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
333619fd8c2SJeff Niu };
334619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
335619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
336619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
337619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
338619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
339619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
340619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
341619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
342619fd8c2SJeff Niu };
343619fd8c2SJeff Niu 
344436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
345436c6c9cSStella Laurenzo public:
346436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
347436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
348436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
3499566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
3509566ee28Smax       mlirArrayAttrGetTypeID;
351436c6c9cSStella Laurenzo 
352436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
353436c6c9cSStella Laurenzo   public:
3541fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
355436c6c9cSStella Laurenzo 
356436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
357436c6c9cSStella Laurenzo 
358974c1596SRahul Kayaith     MlirAttribute dunderNext() {
359bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
360bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
361436c6c9cSStella Laurenzo         throw py::stop_iteration();
362974c1596SRahul Kayaith       return mlirArrayAttrGetElement(attr.get(), nextIndex++);
363436c6c9cSStella Laurenzo     }
364436c6c9cSStella Laurenzo 
365436c6c9cSStella Laurenzo     static void bind(py::module &m) {
366f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
367f05ff4f7SStella Laurenzo                                            py::module_local())
368436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
369436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
370436c6c9cSStella Laurenzo     }
371436c6c9cSStella Laurenzo 
372436c6c9cSStella Laurenzo   private:
373436c6c9cSStella Laurenzo     PyAttribute attr;
374436c6c9cSStella Laurenzo     int nextIndex = 0;
375436c6c9cSStella Laurenzo   };
376436c6c9cSStella Laurenzo 
377974c1596SRahul Kayaith   MlirAttribute getItem(intptr_t i) {
378974c1596SRahul Kayaith     return mlirArrayAttrGetElement(*this, i);
379ed9e52f3SAlex Zinenko   }
380ed9e52f3SAlex Zinenko 
381436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
382436c6c9cSStella Laurenzo     c.def_static(
383436c6c9cSStella Laurenzo         "get",
384436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
385436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
386436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
387436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
388ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
389436c6c9cSStella Laurenzo           }
390436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
391436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
392436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
393436c6c9cSStella Laurenzo         },
394436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
395436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
396436c6c9cSStella Laurenzo     c.def("__getitem__",
397436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
398436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
399436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
400ed9e52f3SAlex Zinenko             return arr.getItem(i);
401436c6c9cSStella Laurenzo           })
402436c6c9cSStella Laurenzo         .def("__len__",
403436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
404436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
405436c6c9cSStella Laurenzo              })
406436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
407436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
408436c6c9cSStella Laurenzo         });
409ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
410ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
411ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
412ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
413ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
414ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
415ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
416ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
417ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
418ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
419ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
420ed9e52f3SAlex Zinenko     });
421436c6c9cSStella Laurenzo   }
422436c6c9cSStella Laurenzo };
423436c6c9cSStella Laurenzo 
424436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
425436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
426436c6c9cSStella Laurenzo public:
427436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
428436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
429436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
4309566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
4319566ee28Smax       mlirFloatAttrGetTypeID;
432436c6c9cSStella Laurenzo 
433436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
434436c6c9cSStella Laurenzo     c.def_static(
435436c6c9cSStella Laurenzo         "get",
436436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
4373ea4c501SRahul Kayaith           PyMlirContext::ErrorCapture errors(loc->getContext());
438436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
4393ea4c501SRahul Kayaith           if (mlirAttributeIsNull(attr))
4403ea4c501SRahul Kayaith             throw MLIRError("Invalid attribute", errors.take());
441436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
442436c6c9cSStella Laurenzo         },
443436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
444436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
445436c6c9cSStella Laurenzo     c.def_static(
446436c6c9cSStella Laurenzo         "get_f32",
447436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
448436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
449436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
450436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
451436c6c9cSStella Laurenzo         },
452436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
453436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
454436c6c9cSStella Laurenzo     c.def_static(
455436c6c9cSStella Laurenzo         "get_f64",
456436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
457436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
458436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
459436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
460436c6c9cSStella Laurenzo         },
461436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
462436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
4632a5d4974SIngo Müller     c.def_property_readonly("value", mlirFloatAttrGetValueDouble,
4642a5d4974SIngo Müller                             "Returns the value of the float attribute");
4652a5d4974SIngo Müller     c.def("__float__", mlirFloatAttrGetValueDouble,
4662a5d4974SIngo Müller           "Converts the value of the float attribute to a Python float");
467436c6c9cSStella Laurenzo   }
468436c6c9cSStella Laurenzo };
469436c6c9cSStella Laurenzo 
470436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
471436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
472436c6c9cSStella Laurenzo public:
473436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
474436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
475436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
476436c6c9cSStella Laurenzo 
477436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
478436c6c9cSStella Laurenzo     c.def_static(
479436c6c9cSStella Laurenzo         "get",
480436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
481436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
482436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
483436c6c9cSStella Laurenzo         },
484436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
485436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
4862a5d4974SIngo Müller     c.def_property_readonly("value", toPyInt,
4872a5d4974SIngo Müller                             "Returns the value of the integer attribute");
4882a5d4974SIngo Müller     c.def("__int__", toPyInt,
4892a5d4974SIngo Müller           "Converts the value of the integer attribute to a Python int");
4902a5d4974SIngo Müller     c.def_property_readonly_static("static_typeid",
4912a5d4974SIngo Müller                                    [](py::object & /*class*/) -> MlirTypeID {
4922a5d4974SIngo Müller                                      return mlirIntegerAttrGetTypeID();
4932a5d4974SIngo Müller                                    });
4942a5d4974SIngo Müller   }
4952a5d4974SIngo Müller 
4962a5d4974SIngo Müller private:
4972a5d4974SIngo Müller   static py::int_ toPyInt(PyIntegerAttribute &self) {
498e9db306dSrkayaith     MlirType type = mlirAttributeGetType(self);
499e9db306dSrkayaith     if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
500436c6c9cSStella Laurenzo       return mlirIntegerAttrGetValueInt(self);
501e9db306dSrkayaith     if (mlirIntegerTypeIsSigned(type))
502e9db306dSrkayaith       return mlirIntegerAttrGetValueSInt(self);
503e9db306dSrkayaith     return mlirIntegerAttrGetValueUInt(self);
504436c6c9cSStella Laurenzo   }
505436c6c9cSStella Laurenzo };
506436c6c9cSStella Laurenzo 
507436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
508436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
509436c6c9cSStella Laurenzo public:
510436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
511436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
512436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
513436c6c9cSStella Laurenzo 
514436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
515436c6c9cSStella Laurenzo     c.def_static(
516436c6c9cSStella Laurenzo         "get",
517436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
518436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
519436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
520436c6c9cSStella Laurenzo         },
521436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
522436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
5232a5d4974SIngo Müller     c.def_property_readonly("value", mlirBoolAttrGetValue,
524436c6c9cSStella Laurenzo                             "Returns the value of the bool attribute");
5252a5d4974SIngo Müller     c.def("__bool__", mlirBoolAttrGetValue,
5262a5d4974SIngo Müller           "Converts the value of the bool attribute to a Python bool");
527436c6c9cSStella Laurenzo   }
528436c6c9cSStella Laurenzo };
529436c6c9cSStella Laurenzo 
5304eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
5314eee9ef9Smax public:
5324eee9ef9Smax   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
5334eee9ef9Smax   static constexpr const char *pyClassName = "SymbolRefAttr";
5344eee9ef9Smax   using PyConcreteAttribute::PyConcreteAttribute;
5354eee9ef9Smax 
5364eee9ef9Smax   static MlirAttribute fromList(const std::vector<std::string> &symbols,
5374eee9ef9Smax                                 PyMlirContext &context) {
5384eee9ef9Smax     if (symbols.empty())
5394eee9ef9Smax       throw std::runtime_error("SymbolRefAttr must be composed of at least "
5404eee9ef9Smax                                "one symbol.");
5414eee9ef9Smax     MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
5424eee9ef9Smax     SmallVector<MlirAttribute, 3> referenceAttrs;
5434eee9ef9Smax     for (size_t i = 1; i < symbols.size(); ++i) {
5444eee9ef9Smax       referenceAttrs.push_back(
5454eee9ef9Smax           mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
5464eee9ef9Smax     }
5474eee9ef9Smax     return mlirSymbolRefAttrGet(context.get(), rootSymbol,
5484eee9ef9Smax                                 referenceAttrs.size(), referenceAttrs.data());
5494eee9ef9Smax   }
5504eee9ef9Smax 
5514eee9ef9Smax   static void bindDerived(ClassTy &c) {
5524eee9ef9Smax     c.def_static(
5534eee9ef9Smax         "get",
5544eee9ef9Smax         [](const std::vector<std::string> &symbols,
5554eee9ef9Smax            DefaultingPyMlirContext context) {
5564eee9ef9Smax           return PySymbolRefAttribute::fromList(symbols, context.resolve());
5574eee9ef9Smax         },
5584eee9ef9Smax         py::arg("symbols"), py::arg("context") = py::none(),
5594eee9ef9Smax         "Gets a uniqued SymbolRef attribute from a list of symbol names");
5604eee9ef9Smax     c.def_property_readonly(
5614eee9ef9Smax         "value",
5624eee9ef9Smax         [](PySymbolRefAttribute &self) {
5634eee9ef9Smax           std::vector<std::string> symbols = {
5644eee9ef9Smax               unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
5654eee9ef9Smax           for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
5664eee9ef9Smax                ++i)
5674eee9ef9Smax             symbols.push_back(
5684eee9ef9Smax                 unwrap(mlirSymbolRefAttrGetRootReference(
5694eee9ef9Smax                            mlirSymbolRefAttrGetNestedReference(self, i)))
5704eee9ef9Smax                     .str());
5714eee9ef9Smax           return symbols;
5724eee9ef9Smax         },
5734eee9ef9Smax         "Returns the value of the SymbolRef attribute as a list[str]");
5744eee9ef9Smax   }
5754eee9ef9Smax };
5764eee9ef9Smax 
577436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
578436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
579436c6c9cSStella Laurenzo public:
580436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
581436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
582436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
583436c6c9cSStella Laurenzo 
584436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
585436c6c9cSStella Laurenzo     c.def_static(
586436c6c9cSStella Laurenzo         "get",
587436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
588436c6c9cSStella Laurenzo           MlirAttribute attr =
589436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
590436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
591436c6c9cSStella Laurenzo         },
592436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
593436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
594436c6c9cSStella Laurenzo     c.def_property_readonly(
595436c6c9cSStella Laurenzo         "value",
596436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
597436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
598436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
599436c6c9cSStella Laurenzo         },
600436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
601436c6c9cSStella Laurenzo   }
602436c6c9cSStella Laurenzo };
603436c6c9cSStella Laurenzo 
6045c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
6055c3861b2SYun Long public:
6065c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
6075c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
6085c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
6099566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
6109566ee28Smax       mlirOpaqueAttrGetTypeID;
6115c3861b2SYun Long 
6125c3861b2SYun Long   static void bindDerived(ClassTy &c) {
6135c3861b2SYun Long     c.def_static(
6145c3861b2SYun Long         "get",
6155c3861b2SYun Long         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
6165c3861b2SYun Long            DefaultingPyMlirContext context) {
6175c3861b2SYun Long           const py::buffer_info bufferInfo = buffer.request();
6185c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
6195c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
6205c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
6215c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
6225c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
6235c3861b2SYun Long         },
6245c3861b2SYun Long         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
6255c3861b2SYun Long         py::arg("context") = py::none(), "Gets an Opaque attribute.");
6265c3861b2SYun Long     c.def_property_readonly(
6275c3861b2SYun Long         "dialect_namespace",
6285c3861b2SYun Long         [](PyOpaqueAttribute &self) {
6295c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
6305c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
6315c3861b2SYun Long         },
6325c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
6335c3861b2SYun Long     c.def_property_readonly(
6345c3861b2SYun Long         "data",
6355c3861b2SYun Long         [](PyOpaqueAttribute &self) {
6365c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
63762bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
6385c3861b2SYun Long         },
63962bf6c2eSChris Jones         "Returns the data for the Opaqued attributes as `bytes`");
6405c3861b2SYun Long   }
6415c3861b2SYun Long };
6425c3861b2SYun Long 
643436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
644436c6c9cSStella Laurenzo public:
645436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
646436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
647436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
6489566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
6499566ee28Smax       mlirStringAttrGetTypeID;
650436c6c9cSStella Laurenzo 
651436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
652436c6c9cSStella Laurenzo     c.def_static(
653436c6c9cSStella Laurenzo         "get",
654436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
655436c6c9cSStella Laurenzo           MlirAttribute attr =
656436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
657436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
658436c6c9cSStella Laurenzo         },
659436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
660436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
661436c6c9cSStella Laurenzo     c.def_static(
662436c6c9cSStella Laurenzo         "get_typed",
663436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
664436c6c9cSStella Laurenzo           MlirAttribute attr =
665436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
666436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
667436c6c9cSStella Laurenzo         },
668a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
669436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
6709f533548SIngo Müller     c.def_property_readonly(
6719f533548SIngo Müller         "value",
6729f533548SIngo Müller         [](PyStringAttribute &self) {
6739f533548SIngo Müller           MlirStringRef stringRef = mlirStringAttrGetValue(self);
6749f533548SIngo Müller           return py::str(stringRef.data, stringRef.length);
6759f533548SIngo Müller         },
676436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
67762bf6c2eSChris Jones     c.def_property_readonly(
67862bf6c2eSChris Jones         "value_bytes",
67962bf6c2eSChris Jones         [](PyStringAttribute &self) {
68062bf6c2eSChris Jones           MlirStringRef stringRef = mlirStringAttrGetValue(self);
68162bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
68262bf6c2eSChris Jones         },
68362bf6c2eSChris Jones         "Returns the value of the string attribute as `bytes`");
684436c6c9cSStella Laurenzo   }
685436c6c9cSStella Laurenzo };
686436c6c9cSStella Laurenzo 
687436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
688436c6c9cSStella Laurenzo class PyDenseElementsAttribute
689436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
690436c6c9cSStella Laurenzo public:
691436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
692436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
693436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
694436c6c9cSStella Laurenzo 
695436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
696c912f0e7Spranavm-nvidia   getFromList(py::list attributes, std::optional<PyType> explicitType,
697c912f0e7Spranavm-nvidia               DefaultingPyMlirContext contextWrapper) {
698c912f0e7Spranavm-nvidia 
699c912f0e7Spranavm-nvidia     const size_t numAttributes = py::len(attributes);
700c912f0e7Spranavm-nvidia     if (numAttributes == 0)
701c912f0e7Spranavm-nvidia       throw py::value_error("Attributes list must be non-empty.");
702c912f0e7Spranavm-nvidia 
703c912f0e7Spranavm-nvidia     MlirType shapedType;
704c912f0e7Spranavm-nvidia     if (explicitType) {
705c912f0e7Spranavm-nvidia       if ((!mlirTypeIsAShaped(*explicitType) ||
706c912f0e7Spranavm-nvidia            !mlirShapedTypeHasStaticShape(*explicitType))) {
707c912f0e7Spranavm-nvidia 
708c912f0e7Spranavm-nvidia         std::string message;
709c912f0e7Spranavm-nvidia         llvm::raw_string_ostream os(message);
710c912f0e7Spranavm-nvidia         os << "Expected a static ShapedType for the shaped_type parameter: "
711c912f0e7Spranavm-nvidia            << py::repr(py::cast(*explicitType));
712095b41c6SJOE1994         throw py::value_error(message);
713c912f0e7Spranavm-nvidia       }
714c912f0e7Spranavm-nvidia       shapedType = *explicitType;
715c912f0e7Spranavm-nvidia     } else {
716c912f0e7Spranavm-nvidia       SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
717c912f0e7Spranavm-nvidia       shapedType = mlirRankedTensorTypeGet(
718c912f0e7Spranavm-nvidia           shape.size(), shape.data(),
719c912f0e7Spranavm-nvidia           mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
720c912f0e7Spranavm-nvidia           mlirAttributeGetNull());
721c912f0e7Spranavm-nvidia     }
722c912f0e7Spranavm-nvidia 
723c912f0e7Spranavm-nvidia     SmallVector<MlirAttribute> mlirAttributes;
724c912f0e7Spranavm-nvidia     mlirAttributes.reserve(numAttributes);
725c912f0e7Spranavm-nvidia     for (const py::handle &attribute : attributes) {
726c912f0e7Spranavm-nvidia       MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
727c912f0e7Spranavm-nvidia       MlirType attrType = mlirAttributeGetType(mlirAttribute);
728c912f0e7Spranavm-nvidia       mlirAttributes.push_back(mlirAttribute);
729c912f0e7Spranavm-nvidia 
730c912f0e7Spranavm-nvidia       if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
731c912f0e7Spranavm-nvidia         std::string message;
732c912f0e7Spranavm-nvidia         llvm::raw_string_ostream os(message);
733c912f0e7Spranavm-nvidia         os << "All attributes must be of the same type and match "
734c912f0e7Spranavm-nvidia            << "the type parameter: expected=" << py::repr(py::cast(shapedType))
735c912f0e7Spranavm-nvidia            << ", but got=" << py::repr(py::cast(attrType));
736095b41c6SJOE1994         throw py::value_error(message);
737c912f0e7Spranavm-nvidia       }
738c912f0e7Spranavm-nvidia     }
739c912f0e7Spranavm-nvidia 
740c912f0e7Spranavm-nvidia     MlirAttribute elements = mlirDenseElementsAttrGet(
741c912f0e7Spranavm-nvidia         shapedType, mlirAttributes.size(), mlirAttributes.data());
742c912f0e7Spranavm-nvidia 
743c912f0e7Spranavm-nvidia     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
744c912f0e7Spranavm-nvidia   }
745c912f0e7Spranavm-nvidia 
746c912f0e7Spranavm-nvidia   static PyDenseElementsAttribute
7470a81ace0SKazu Hirata   getFromBuffer(py::buffer array, bool signless,
7480a81ace0SKazu Hirata                 std::optional<PyType> explicitType,
7490a81ace0SKazu Hirata                 std::optional<std::vector<int64_t>> explicitShape,
750436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
751436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
75271a25454SPeter Hawkins     int flags = PyBUF_ND;
75371a25454SPeter Hawkins     if (!explicitType) {
75471a25454SPeter Hawkins       flags |= PyBUF_FORMAT;
75571a25454SPeter Hawkins     }
75671a25454SPeter Hawkins     Py_buffer view;
75771a25454SPeter Hawkins     if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
758436c6c9cSStella Laurenzo       throw py::error_already_set();
759436c6c9cSStella Laurenzo     }
76071a25454SPeter Hawkins     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
761436c6c9cSStella Laurenzo 
762436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
7631824e45cSKasper Nielsen     MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
7641824e45cSKasper Nielsen                                                 explicitShape, context);
7655d6d30edSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
7665d6d30edSStella Laurenzo       throw std::invalid_argument(
7675d6d30edSStella Laurenzo           "DenseElementsAttr could not be constructed from the given buffer. "
7685d6d30edSStella Laurenzo           "This may mean that the Python buffer layout does not match that "
7695d6d30edSStella Laurenzo           "MLIR expected layout and is a bug.");
7705d6d30edSStella Laurenzo     }
7715d6d30edSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
7725d6d30edSStella Laurenzo   }
773436c6c9cSStella Laurenzo 
7741fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
775436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
776436c6c9cSStella Laurenzo     auto contextWrapper =
777436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
778436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
779436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
780436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
781436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
7824811270bSmax       throw py::value_error(message);
783436c6c9cSStella Laurenzo     }
784436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
785436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
786436c6c9cSStella Laurenzo       std::string message =
787436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
788436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
7894811270bSmax       throw py::value_error(message);
790436c6c9cSStella Laurenzo     }
791436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
792436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
793436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
794436c6c9cSStella Laurenzo       std::string message =
795436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
796436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
797436c6c9cSStella Laurenzo       message.append(", element=");
798436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
7994811270bSmax       throw py::value_error(message);
800436c6c9cSStella Laurenzo     }
801436c6c9cSStella Laurenzo 
802436c6c9cSStella Laurenzo     MlirAttribute elements =
803436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
804436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
805436c6c9cSStella Laurenzo   }
806436c6c9cSStella Laurenzo 
807436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
808436c6c9cSStella Laurenzo 
809436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
810436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
811436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
8125d6d30edSStella Laurenzo     std::string format;
813436c6c9cSStella Laurenzo 
814436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
815436c6c9cSStella Laurenzo       // f32
8165d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
81702b6fb21SMehdi Amini     }
81802b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
819436c6c9cSStella Laurenzo       // f64
8205d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
821bb56c2b3SMehdi Amini     }
822bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
8235d6d30edSStella Laurenzo       // f16
8245d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
825bb56c2b3SMehdi Amini     }
826ef1b735dSmax     if (mlirTypeIsAIndex(elementType)) {
827ef1b735dSmax       // Same as IndexType::kInternalStorageBitWidth
828ef1b735dSmax       return bufferInfo<int64_t>(shapedType);
829ef1b735dSmax     }
830bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
831436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
832436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
833436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
834436c6c9cSStella Laurenzo         // i32
8355d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
836e5639b3fSMehdi Amini       }
837e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
838436c6c9cSStella Laurenzo         // unsigned i32
8395d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
840436c6c9cSStella Laurenzo       }
841436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
842436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
843436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
844436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
845436c6c9cSStella Laurenzo         // i64
8465d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
847e5639b3fSMehdi Amini       }
848e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
849436c6c9cSStella Laurenzo         // unsigned i64
8505d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
8515d6d30edSStella Laurenzo       }
8525d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
8535d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
8545d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
8555d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
8565d6d30edSStella Laurenzo         // i8
8575d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
858e5639b3fSMehdi Amini       }
859e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
8605d6d30edSStella Laurenzo         // unsigned i8
8615d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
8625d6d30edSStella Laurenzo       }
8635d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
8645d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
8655d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
8665d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
8675d6d30edSStella Laurenzo         // i16
8685d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
869e5639b3fSMehdi Amini       }
870e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
8715d6d30edSStella Laurenzo         // unsigned i16
8725d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
873436c6c9cSStella Laurenzo       }
8741824e45cSKasper Nielsen     } else if (mlirTypeIsAInteger(elementType) &&
8751824e45cSKasper Nielsen                mlirIntegerTypeGetWidth(elementType) == 1) {
8761824e45cSKasper Nielsen       // i1 / bool
8771824e45cSKasper Nielsen       // We can not send the buffer directly back to Python, because the i1
8781824e45cSKasper Nielsen       // values are bitpacked within MLIR. We call numpy's unpackbits function
8791824e45cSKasper Nielsen       // to convert the bytes.
8801824e45cSKasper Nielsen       return getBooleanBufferFromBitpackedAttribute();
881436c6c9cSStella Laurenzo     }
882436c6c9cSStella Laurenzo 
883c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
8845d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
885c5f445d1SStella Laurenzo     throw std::invalid_argument(
886c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
887436c6c9cSStella Laurenzo   }
888436c6c9cSStella Laurenzo 
889436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
890436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
891436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
892436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
8935d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
894436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
8955d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
896c912f0e7Spranavm-nvidia         .def_static("get", PyDenseElementsAttribute::getFromList,
897c912f0e7Spranavm-nvidia                     py::arg("attrs"), py::arg("type") = py::none(),
898c912f0e7Spranavm-nvidia                     py::arg("context") = py::none(),
899c912f0e7Spranavm-nvidia                     kDenseElementsAttrGetFromListDocstring)
900436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
901436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
902436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
903436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
904436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
905436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
906436c6c9cSStella Laurenzo                                })
90791259963SAdam Paszke         .def("get_splat_value",
908974c1596SRahul Kayaith              [](PyDenseElementsAttribute &self) {
909974c1596SRahul Kayaith                if (!mlirDenseElementsAttrIsSplat(self))
9104811270bSmax                  throw py::value_error(
91191259963SAdam Paszke                      "get_splat_value called on a non-splat attribute");
912974c1596SRahul Kayaith                return mlirDenseElementsAttrGetSplatValue(self);
91391259963SAdam Paszke              })
914436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
915436c6c9cSStella Laurenzo   }
916436c6c9cSStella Laurenzo 
917436c6c9cSStella Laurenzo private:
/// True when the first character of a buffer `format` is one of the
/// struct-module unsigned integer codes B, H, I, L or Q. Only the first
/// character is inspected: a leading byte-order character (e.g. '<') makes
/// the format report as non-matching.
static bool isUnsignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kUnsignedCodes = "BHILQ";
  return !format.empty() &&
         kUnsignedCodes.find(format.front()) != std::string_view::npos;
}
925436c6c9cSStella Laurenzo 
/// True when the first character of a buffer `format` is one of the
/// struct-module signed integer codes b, h, i, l or q. Only the first
/// character is inspected: a leading byte-order character (e.g. '<') makes
/// the format report as non-matching.
static bool isSignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kSignedCodes = "bhilq";
  return !format.empty() &&
         kSignedCodes.find(format.front()) != std::string_view::npos;
}
933436c6c9cSStella Laurenzo 
9341824e45cSKasper Nielsen   static MlirType
9351824e45cSKasper Nielsen   getShapedType(std::optional<MlirType> bulkLoadElementType,
9361824e45cSKasper Nielsen                 std::optional<std::vector<int64_t>> explicitShape,
9371824e45cSKasper Nielsen                 Py_buffer &view) {
9381824e45cSKasper Nielsen     SmallVector<int64_t> shape;
9391824e45cSKasper Nielsen     if (explicitShape) {
9401824e45cSKasper Nielsen       shape.append(explicitShape->begin(), explicitShape->end());
9411824e45cSKasper Nielsen     } else {
9421824e45cSKasper Nielsen       shape.append(view.shape, view.shape + view.ndim);
9431824e45cSKasper Nielsen     }
9441824e45cSKasper Nielsen 
9451824e45cSKasper Nielsen     if (mlirTypeIsAShaped(*bulkLoadElementType)) {
9461824e45cSKasper Nielsen       if (explicitShape) {
9471824e45cSKasper Nielsen         throw std::invalid_argument("Shape can only be specified explicitly "
9481824e45cSKasper Nielsen                                     "when the type is not a shaped type.");
9491824e45cSKasper Nielsen       }
9501824e45cSKasper Nielsen       return *bulkLoadElementType;
9511824e45cSKasper Nielsen     } else {
9521824e45cSKasper Nielsen       MlirAttribute encodingAttr = mlirAttributeGetNull();
9531824e45cSKasper Nielsen       return mlirRankedTensorTypeGet(shape.size(), shape.data(),
9541824e45cSKasper Nielsen                                      *bulkLoadElementType, encodingAttr);
9551824e45cSKasper Nielsen     }
9561824e45cSKasper Nielsen   }
9571824e45cSKasper Nielsen 
9581824e45cSKasper Nielsen   static MlirAttribute getAttributeFromBuffer(
9591824e45cSKasper Nielsen       Py_buffer &view, bool signless, std::optional<PyType> explicitType,
9601824e45cSKasper Nielsen       std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
9611824e45cSKasper Nielsen     // Detect format codes that are suitable for bulk loading. This includes
9621824e45cSKasper Nielsen     // all byte aligned integer and floating point types up to 8 bytes.
9631824e45cSKasper Nielsen     // Notably, this excludes exotics types which do not have a direct
9641824e45cSKasper Nielsen     // representation in the buffer protocol (i.e. complex, etc).
9651824e45cSKasper Nielsen     std::optional<MlirType> bulkLoadElementType;
9661824e45cSKasper Nielsen     if (explicitType) {
9671824e45cSKasper Nielsen       bulkLoadElementType = *explicitType;
9681824e45cSKasper Nielsen     } else {
9691824e45cSKasper Nielsen       std::string_view format(view.format);
9701824e45cSKasper Nielsen       if (format == "f") {
9711824e45cSKasper Nielsen         // f32
9721824e45cSKasper Nielsen         assert(view.itemsize == 4 && "mismatched array itemsize");
9731824e45cSKasper Nielsen         bulkLoadElementType = mlirF32TypeGet(context);
9741824e45cSKasper Nielsen       } else if (format == "d") {
9751824e45cSKasper Nielsen         // f64
9761824e45cSKasper Nielsen         assert(view.itemsize == 8 && "mismatched array itemsize");
9771824e45cSKasper Nielsen         bulkLoadElementType = mlirF64TypeGet(context);
9781824e45cSKasper Nielsen       } else if (format == "e") {
9791824e45cSKasper Nielsen         // f16
9801824e45cSKasper Nielsen         assert(view.itemsize == 2 && "mismatched array itemsize");
9811824e45cSKasper Nielsen         bulkLoadElementType = mlirF16TypeGet(context);
9821824e45cSKasper Nielsen       } else if (format == "?") {
9831824e45cSKasper Nielsen         // i1
9841824e45cSKasper Nielsen         // The i1 type needs to be bit-packed, so we will handle it seperately
9851824e45cSKasper Nielsen         return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
9861824e45cSKasper Nielsen                                                       context);
9871824e45cSKasper Nielsen       } else if (isSignedIntegerFormat(format)) {
9881824e45cSKasper Nielsen         if (view.itemsize == 4) {
9891824e45cSKasper Nielsen           // i32
9901824e45cSKasper Nielsen           bulkLoadElementType = signless
9911824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 32)
9921824e45cSKasper Nielsen                                     : mlirIntegerTypeSignedGet(context, 32);
9931824e45cSKasper Nielsen         } else if (view.itemsize == 8) {
9941824e45cSKasper Nielsen           // i64
9951824e45cSKasper Nielsen           bulkLoadElementType = signless
9961824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 64)
9971824e45cSKasper Nielsen                                     : mlirIntegerTypeSignedGet(context, 64);
9981824e45cSKasper Nielsen         } else if (view.itemsize == 1) {
9991824e45cSKasper Nielsen           // i8
10001824e45cSKasper Nielsen           bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
10011824e45cSKasper Nielsen                                          : mlirIntegerTypeSignedGet(context, 8);
10021824e45cSKasper Nielsen         } else if (view.itemsize == 2) {
10031824e45cSKasper Nielsen           // i16
10041824e45cSKasper Nielsen           bulkLoadElementType = signless
10051824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 16)
10061824e45cSKasper Nielsen                                     : mlirIntegerTypeSignedGet(context, 16);
10071824e45cSKasper Nielsen         }
10081824e45cSKasper Nielsen       } else if (isUnsignedIntegerFormat(format)) {
10091824e45cSKasper Nielsen         if (view.itemsize == 4) {
10101824e45cSKasper Nielsen           // unsigned i32
10111824e45cSKasper Nielsen           bulkLoadElementType = signless
10121824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 32)
10131824e45cSKasper Nielsen                                     : mlirIntegerTypeUnsignedGet(context, 32);
10141824e45cSKasper Nielsen         } else if (view.itemsize == 8) {
10151824e45cSKasper Nielsen           // unsigned i64
10161824e45cSKasper Nielsen           bulkLoadElementType = signless
10171824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 64)
10181824e45cSKasper Nielsen                                     : mlirIntegerTypeUnsignedGet(context, 64);
10191824e45cSKasper Nielsen         } else if (view.itemsize == 1) {
10201824e45cSKasper Nielsen           // i8
10211824e45cSKasper Nielsen           bulkLoadElementType = signless
10221824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 8)
10231824e45cSKasper Nielsen                                     : mlirIntegerTypeUnsignedGet(context, 8);
10241824e45cSKasper Nielsen         } else if (view.itemsize == 2) {
10251824e45cSKasper Nielsen           // i16
10261824e45cSKasper Nielsen           bulkLoadElementType = signless
10271824e45cSKasper Nielsen                                     ? mlirIntegerTypeGet(context, 16)
10281824e45cSKasper Nielsen                                     : mlirIntegerTypeUnsignedGet(context, 16);
10291824e45cSKasper Nielsen         }
10301824e45cSKasper Nielsen       }
10311824e45cSKasper Nielsen       if (!bulkLoadElementType) {
10321824e45cSKasper Nielsen         throw std::invalid_argument(
10331824e45cSKasper Nielsen             std::string("unimplemented array format conversion from format: ") +
10341824e45cSKasper Nielsen             std::string(format));
10351824e45cSKasper Nielsen       }
10361824e45cSKasper Nielsen     }
10371824e45cSKasper Nielsen 
10381824e45cSKasper Nielsen     MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
10391824e45cSKasper Nielsen     return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
10401824e45cSKasper Nielsen   }
10411824e45cSKasper Nielsen 
10421824e45cSKasper Nielsen   // There is a complication for boolean numpy arrays, as numpy represents them
10431824e45cSKasper Nielsen   // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
10441824e45cSKasper Nielsen   // per byte.
10451824e45cSKasper Nielsen   static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
10461824e45cSKasper Nielsen       Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
10471824e45cSKasper Nielsen       MlirContext &context) {
10481824e45cSKasper Nielsen     if (llvm::endianness::native != llvm::endianness::little) {
10491824e45cSKasper Nielsen       // Given we have no good way of testing the behavior on big-endian systems
10501824e45cSKasper Nielsen       // we will throw
10511824e45cSKasper Nielsen       throw py::type_error("Constructing a bit-packed MLIR attribute is "
10521824e45cSKasper Nielsen                            "unsupported on big-endian systems");
10531824e45cSKasper Nielsen     }
10541824e45cSKasper Nielsen 
10551824e45cSKasper Nielsen     py::array_t<uint8_t> unpackedArray(view.len,
10561824e45cSKasper Nielsen                                        static_cast<uint8_t *>(view.buf));
10571824e45cSKasper Nielsen 
10581824e45cSKasper Nielsen     py::module numpy = py::module::import("numpy");
10591824e45cSKasper Nielsen     py::object packbitsFunc = numpy.attr("packbits");
10601824e45cSKasper Nielsen     py::object packedBooleans =
10611824e45cSKasper Nielsen         packbitsFunc(unpackedArray, "bitorder"_a = "little");
10621824e45cSKasper Nielsen     py::buffer_info pythonBuffer = packedBooleans.cast<py::buffer>().request();
10631824e45cSKasper Nielsen 
10641824e45cSKasper Nielsen     MlirType bitpackedType =
10651824e45cSKasper Nielsen         getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
10661824e45cSKasper Nielsen     assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
10671824e45cSKasper Nielsen     // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
10681824e45cSKasper Nielsen     // packedBooleans, hence the MlirAttribute will remain valid even when
10691824e45cSKasper Nielsen     // packedBooleans get reclaimed by the end of the function.
10701824e45cSKasper Nielsen     return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
10711824e45cSKasper Nielsen                                              pythonBuffer.ptr);
10721824e45cSKasper Nielsen   }
10731824e45cSKasper Nielsen 
10741824e45cSKasper Nielsen   // This does the opposite transformation of
10751824e45cSKasper Nielsen   // `getBitpackedAttributeFromBooleanBuffer`
10761824e45cSKasper Nielsen   py::buffer_info getBooleanBufferFromBitpackedAttribute() {
10771824e45cSKasper Nielsen     if (llvm::endianness::native != llvm::endianness::little) {
10781824e45cSKasper Nielsen       // Given we have no good way of testing the behavior on big-endian systems
10791824e45cSKasper Nielsen       // we will throw
10801824e45cSKasper Nielsen       throw py::type_error("Constructing a numpy array from a MLIR attribute "
10811824e45cSKasper Nielsen                            "is unsupported on big-endian systems");
10821824e45cSKasper Nielsen     }
10831824e45cSKasper Nielsen 
10841824e45cSKasper Nielsen     int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
10851824e45cSKasper Nielsen     int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
10861824e45cSKasper Nielsen     uint8_t *bitpackedData = static_cast<uint8_t *>(
10871824e45cSKasper Nielsen         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
10881824e45cSKasper Nielsen     py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
10891824e45cSKasper Nielsen 
10901824e45cSKasper Nielsen     py::module numpy = py::module::import("numpy");
10911824e45cSKasper Nielsen     py::object unpackbitsFunc = numpy.attr("unpackbits");
10921824e45cSKasper Nielsen     py::object equalFunc = numpy.attr("equal");
10931824e45cSKasper Nielsen     py::object reshapeFunc = numpy.attr("reshape");
10941824e45cSKasper Nielsen     py::array unpackedBooleans =
10951824e45cSKasper Nielsen         unpackbitsFunc(packedArray, "bitorder"_a = "little");
10961824e45cSKasper Nielsen 
10971824e45cSKasper Nielsen     // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
10981824e45cSKasper Nielsen     // We need to:
10991824e45cSKasper Nielsen     //   1. Slice away the padded bits
11001824e45cSKasper Nielsen     //   2. Make the boolean array have the correct shape
11011824e45cSKasper Nielsen     //   3. Convert the array to a boolean array
11021824e45cSKasper Nielsen     unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)];
11031824e45cSKasper Nielsen     unpackedBooleans = equalFunc(unpackedBooleans, 1);
11041824e45cSKasper Nielsen 
11051824e45cSKasper Nielsen     MlirType shapedType = mlirAttributeGetType(*this);
11061824e45cSKasper Nielsen     intptr_t rank = mlirShapedTypeGetRank(shapedType);
1107*404d0e99SAdrian Kuegel     std::vector<intptr_t> shape(rank);
11081824e45cSKasper Nielsen     for (intptr_t i = 0; i < rank; ++i) {
1109*404d0e99SAdrian Kuegel       shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
11101824e45cSKasper Nielsen     }
11111824e45cSKasper Nielsen     unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
11121824e45cSKasper Nielsen 
11131824e45cSKasper Nielsen     // Make sure the returned py::buffer_view claims ownership of the data in
11141824e45cSKasper Nielsen     // `pythonBuffer` so it remains valid when Python reads it
11151824e45cSKasper Nielsen     py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
11161824e45cSKasper Nielsen     return pythonBuffer.request();
11171824e45cSKasper Nielsen   }
11181824e45cSKasper Nielsen 
1119436c6c9cSStella Laurenzo   template <typename Type>
1120436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
11215d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
11220a68171bSDmitri Gribenko     intptr_t rank = mlirShapedTypeGetRank(shapedType);
1123436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
11240a68171bSDmitri Gribenko     // Buffer is configured for read-only access below.
1125436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
1126436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1127436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
1128436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
1129436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
1130436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1131436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
1132436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
1133f0e847d0SRahul Kayaith     if (mlirDenseElementsAttrIsSplat(*this)) {
1134f0e847d0SRahul Kayaith       // Splats are special, only the single value is stored.
1135f0e847d0SRahul Kayaith       strides.assign(rank, 0);
1136f0e847d0SRahul Kayaith     } else {
1137436c6c9cSStella Laurenzo       for (intptr_t i = 1; i < rank; ++i) {
1138f0e847d0SRahul Kayaith         intptr_t strideFactor = 1;
1139f0e847d0SRahul Kayaith         for (intptr_t j = i; j < rank; ++j)
1140436c6c9cSStella Laurenzo           strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1141436c6c9cSStella Laurenzo         strides.push_back(sizeof(Type) * strideFactor);
1142436c6c9cSStella Laurenzo       }
1143436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type));
1144f0e847d0SRahul Kayaith     }
11455d6d30edSStella Laurenzo     std::string format;
11465d6d30edSStella Laurenzo     if (explicitFormat) {
11475d6d30edSStella Laurenzo       format = explicitFormat;
11485d6d30edSStella Laurenzo     } else {
11495d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
11505d6d30edSStella Laurenzo     }
11515d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
11525d6d30edSStella Laurenzo                            /*readonly=*/true);
1153436c6c9cSStella Laurenzo   }
1154436c6c9cSStella Laurenzo }; // namespace
1155436c6c9cSStella Laurenzo 
1156436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
1157436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
1158436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
1159436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1160436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1161436c6c9cSStella Laurenzo public:
1162436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1163436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
1164436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1165436c6c9cSStella Laurenzo 
1166436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
1167436c6c9cSStella Laurenzo   /// out of range.
1168436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
1169436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
11704811270bSmax       throw py::index_error("attempt to access out of bounds element");
1171436c6c9cSStella Laurenzo     }
1172436c6c9cSStella Laurenzo 
1173436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1174436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1175436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
1176436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
1177436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1178436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
1179436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
1180436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1181436c6c9cSStella Laurenzo     // querying them on each element access.
1182436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
1183436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1184436c6c9cSStella Laurenzo     if (isUnsigned) {
1185436c6c9cSStella Laurenzo       if (width == 1) {
1186436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
1187436c6c9cSStella Laurenzo       }
1188308d8b8cSRahul Kayaith       if (width == 8) {
1189308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
1190308d8b8cSRahul Kayaith       }
1191308d8b8cSRahul Kayaith       if (width == 16) {
1192308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
1193308d8b8cSRahul Kayaith       }
1194436c6c9cSStella Laurenzo       if (width == 32) {
1195436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
1196436c6c9cSStella Laurenzo       }
1197436c6c9cSStella Laurenzo       if (width == 64) {
1198436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
1199436c6c9cSStella Laurenzo       }
1200436c6c9cSStella Laurenzo     } else {
1201436c6c9cSStella Laurenzo       if (width == 1) {
1202436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
1203436c6c9cSStella Laurenzo       }
1204308d8b8cSRahul Kayaith       if (width == 8) {
1205308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
1206308d8b8cSRahul Kayaith       }
1207308d8b8cSRahul Kayaith       if (width == 16) {
1208308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
1209308d8b8cSRahul Kayaith       }
1210436c6c9cSStella Laurenzo       if (width == 32) {
1211436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
1212436c6c9cSStella Laurenzo       }
1213436c6c9cSStella Laurenzo       if (width == 64) {
1214436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
1215436c6c9cSStella Laurenzo       }
1216436c6c9cSStella Laurenzo     }
12174811270bSmax     throw py::type_error("Unsupported integer type");
1218436c6c9cSStella Laurenzo   }
1219436c6c9cSStella Laurenzo 
1220436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1221436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1222436c6c9cSStella Laurenzo   }
1223436c6c9cSStella Laurenzo };
1224436c6c9cSStella Laurenzo 
1225f66cd9e9SStella Laurenzo class PyDenseResourceElementsAttribute
1226f66cd9e9SStella Laurenzo     : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1227f66cd9e9SStella Laurenzo public:
1228f66cd9e9SStella Laurenzo   static constexpr IsAFunctionTy isaFunction =
1229f66cd9e9SStella Laurenzo       mlirAttributeIsADenseResourceElements;
1230f66cd9e9SStella Laurenzo   static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1231f66cd9e9SStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1232f66cd9e9SStella Laurenzo 
1233f66cd9e9SStella Laurenzo   static PyDenseResourceElementsAttribute
1234962bf002SMehdi Amini   getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type,
1235f66cd9e9SStella Laurenzo                 std::optional<size_t> alignment, bool isMutable,
1236f66cd9e9SStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
1237f66cd9e9SStella Laurenzo     if (!mlirTypeIsAShaped(type)) {
1238f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1239f66cd9e9SStella Laurenzo           "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1240f66cd9e9SStella Laurenzo     }
1241f66cd9e9SStella Laurenzo 
1242f66cd9e9SStella Laurenzo     // Do not request any conversions as we must ensure to use caller
1243f66cd9e9SStella Laurenzo     // managed memory.
1244f66cd9e9SStella Laurenzo     int flags = PyBUF_STRIDES;
1245f66cd9e9SStella Laurenzo     std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1246f66cd9e9SStella Laurenzo     if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1247f66cd9e9SStella Laurenzo       throw py::error_already_set();
1248f66cd9e9SStella Laurenzo     }
1249f66cd9e9SStella Laurenzo 
1250f66cd9e9SStella Laurenzo     // This scope releaser will only release if we haven't yet transferred
1251f66cd9e9SStella Laurenzo     // ownership.
1252f66cd9e9SStella Laurenzo     auto freeBuffer = llvm::make_scope_exit([&]() {
1253f66cd9e9SStella Laurenzo       if (view)
1254f66cd9e9SStella Laurenzo         PyBuffer_Release(view.get());
1255f66cd9e9SStella Laurenzo     });
1256f66cd9e9SStella Laurenzo 
1257f66cd9e9SStella Laurenzo     if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1258f66cd9e9SStella Laurenzo       throw std::invalid_argument("Contiguous buffer is required.");
1259f66cd9e9SStella Laurenzo     }
1260f66cd9e9SStella Laurenzo 
1261f66cd9e9SStella Laurenzo     // Infer alignment to be the stride of one element if not explicit.
1262f66cd9e9SStella Laurenzo     size_t inferredAlignment;
1263f66cd9e9SStella Laurenzo     if (alignment)
1264f66cd9e9SStella Laurenzo       inferredAlignment = *alignment;
1265f66cd9e9SStella Laurenzo     else
1266f66cd9e9SStella Laurenzo       inferredAlignment = view->strides[view->ndim - 1];
1267f66cd9e9SStella Laurenzo 
1268f66cd9e9SStella Laurenzo     // The userData is a Py_buffer* that the deleter owns.
1269f66cd9e9SStella Laurenzo     auto deleter = [](void *userData, const void *data, size_t size,
1270f66cd9e9SStella Laurenzo                       size_t align) {
1271f66cd9e9SStella Laurenzo       Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1272f66cd9e9SStella Laurenzo       PyBuffer_Release(ownedView);
1273f66cd9e9SStella Laurenzo       delete ownedView;
1274f66cd9e9SStella Laurenzo     };
1275f66cd9e9SStella Laurenzo 
1276f66cd9e9SStella Laurenzo     size_t rawBufferSize = view->len;
1277f66cd9e9SStella Laurenzo     MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1278f66cd9e9SStella Laurenzo         type, toMlirStringRef(name), view->buf, rawBufferSize,
1279f66cd9e9SStella Laurenzo         inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1280f66cd9e9SStella Laurenzo     if (mlirAttributeIsNull(attr)) {
1281f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1282f66cd9e9SStella Laurenzo           "DenseResourceElementsAttr could not be constructed from the given "
1283f66cd9e9SStella Laurenzo           "buffer. "
1284f66cd9e9SStella Laurenzo           "This may mean that the Python buffer layout does not match that "
1285f66cd9e9SStella Laurenzo           "MLIR expected layout and is a bug.");
1286f66cd9e9SStella Laurenzo     }
1287f66cd9e9SStella Laurenzo     view.release();
1288f66cd9e9SStella Laurenzo     return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1289f66cd9e9SStella Laurenzo   }
1290f66cd9e9SStella Laurenzo 
1291f66cd9e9SStella Laurenzo   static void bindDerived(ClassTy &c) {
1292f66cd9e9SStella Laurenzo     c.def_static("get_from_buffer",
1293f66cd9e9SStella Laurenzo                  PyDenseResourceElementsAttribute::getFromBuffer,
1294f66cd9e9SStella Laurenzo                  py::arg("array"), py::arg("name"), py::arg("type"),
1295f66cd9e9SStella Laurenzo                  py::arg("alignment") = py::none(),
1296f66cd9e9SStella Laurenzo                  py::arg("is_mutable") = false, py::arg("context") = py::none(),
1297f66cd9e9SStella Laurenzo                  kDenseResourceElementsAttrGetFromBufferDocstring);
1298f66cd9e9SStella Laurenzo   }
1299f66cd9e9SStella Laurenzo };
1300f66cd9e9SStella Laurenzo 
1301436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1302436c6c9cSStella Laurenzo public:
1303436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1304436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
1305436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
13069566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
13079566ee28Smax       mlirDictionaryAttrGetTypeID;
1308436c6c9cSStella Laurenzo 
1309436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1310436c6c9cSStella Laurenzo 
13119fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
13129fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
13139fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
13149fb1086bSAdrian Kuegel   }
13159fb1086bSAdrian Kuegel 
1316436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
13179fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
1318436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
1319436c6c9cSStella Laurenzo     c.def_static(
1320436c6c9cSStella Laurenzo         "get",
1321436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
1322436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1323436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
1324436c6c9cSStella Laurenzo           for (auto &it : attributes) {
132502b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
1326436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
1327436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
132802b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1329436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
133002b6fb21SMehdi Amini                 mlirAttr));
1331436c6c9cSStella Laurenzo           }
1332436c6c9cSStella Laurenzo           MlirAttribute attr =
1333436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1334436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1335436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
1336436c6c9cSStella Laurenzo         },
1337ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
1338436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
1339436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1340436c6c9cSStella Laurenzo       MlirAttribute attr =
1341436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1342974c1596SRahul Kayaith       if (mlirAttributeIsNull(attr))
13434811270bSmax         throw py::key_error("attempt to access a non-existent attribute");
1344974c1596SRahul Kayaith       return attr;
1345436c6c9cSStella Laurenzo     });
1346436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1347436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
13484811270bSmax         throw py::index_error("attempt to access out of bounds attribute");
1349436c6c9cSStella Laurenzo       }
1350436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1351436c6c9cSStella Laurenzo       return PyNamedAttribute(
1352436c6c9cSStella Laurenzo           namedAttr.attribute,
1353436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
1354436c6c9cSStella Laurenzo     });
1355436c6c9cSStella Laurenzo   }
1356436c6c9cSStella Laurenzo };
1357436c6c9cSStella Laurenzo 
1358436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
1359436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
1360436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
1361436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1362436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1363436c6c9cSStella Laurenzo public:
1364436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1365436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
1366436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1367436c6c9cSStella Laurenzo 
1368436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
1369436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
13704811270bSmax       throw py::index_error("attempt to access out of bounds element");
1371436c6c9cSStella Laurenzo     }
1372436c6c9cSStella Laurenzo 
1373436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1374436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1375436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1376436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
1377436c6c9cSStella Laurenzo     // from float and double.
1378436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1379436c6c9cSStella Laurenzo     // querying them on each element access.
1380436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
1381436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
1382436c6c9cSStella Laurenzo     }
1383436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
1384436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1385436c6c9cSStella Laurenzo     }
13864811270bSmax     throw py::type_error("Unsupported floating-point type");
1387436c6c9cSStella Laurenzo   }
1388436c6c9cSStella Laurenzo 
1389436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1390436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1391436c6c9cSStella Laurenzo   }
1392436c6c9cSStella Laurenzo };
1393436c6c9cSStella Laurenzo 
1394436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1395436c6c9cSStella Laurenzo public:
1396436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1397436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
1398436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
13999566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
14009566ee28Smax       mlirTypeAttrGetTypeID;
1401436c6c9cSStella Laurenzo 
1402436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1403436c6c9cSStella Laurenzo     c.def_static(
1404436c6c9cSStella Laurenzo         "get",
1405436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1406436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1407436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1408436c6c9cSStella Laurenzo         },
1409436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
1410436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1411436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1412bfb1ba75Smax       return mlirTypeAttrGetValue(self.get());
1413436c6c9cSStella Laurenzo     });
1414436c6c9cSStella Laurenzo   }
1415436c6c9cSStella Laurenzo };
1416436c6c9cSStella Laurenzo 
1417436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1418436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1419436c6c9cSStella Laurenzo public:
1420436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1421436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1422436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
14239566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
14249566ee28Smax       mlirUnitAttrGetTypeID;
1425436c6c9cSStella Laurenzo 
1426436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1427436c6c9cSStella Laurenzo     c.def_static(
1428436c6c9cSStella Laurenzo         "get",
1429436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1430436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1431436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1432436c6c9cSStella Laurenzo         },
1433436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
1434436c6c9cSStella Laurenzo   }
1435436c6c9cSStella Laurenzo };
1436436c6c9cSStella Laurenzo 
1437ac2e2d65SDenys Shabalin /// Strided layout attribute subclass.
1438ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute
1439ac2e2d65SDenys Shabalin     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1440ac2e2d65SDenys Shabalin public:
1441ac2e2d65SDenys Shabalin   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1442ac2e2d65SDenys Shabalin   static constexpr const char *pyClassName = "StridedLayoutAttr";
1443ac2e2d65SDenys Shabalin   using PyConcreteAttribute::PyConcreteAttribute;
14449566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
14459566ee28Smax       mlirStridedLayoutAttrGetTypeID;
1446ac2e2d65SDenys Shabalin 
1447ac2e2d65SDenys Shabalin   static void bindDerived(ClassTy &c) {
1448ac2e2d65SDenys Shabalin     c.def_static(
1449ac2e2d65SDenys Shabalin         "get",
1450ac2e2d65SDenys Shabalin         [](int64_t offset, const std::vector<int64_t> strides,
1451ac2e2d65SDenys Shabalin            DefaultingPyMlirContext ctx) {
1452ac2e2d65SDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1453ac2e2d65SDenys Shabalin               ctx->get(), offset, strides.size(), strides.data());
1454ac2e2d65SDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1455ac2e2d65SDenys Shabalin         },
1456ac2e2d65SDenys Shabalin         py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1457ac2e2d65SDenys Shabalin         "Gets a strided layout attribute.");
1458e3fd612eSDenys Shabalin     c.def_static(
1459e3fd612eSDenys Shabalin         "get_fully_dynamic",
1460e3fd612eSDenys Shabalin         [](int64_t rank, DefaultingPyMlirContext ctx) {
1461e3fd612eSDenys Shabalin           auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1462e3fd612eSDenys Shabalin           std::vector<int64_t> strides(rank);
1463e3fd612eSDenys Shabalin           std::fill(strides.begin(), strides.end(), dynamic);
1464e3fd612eSDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1465e3fd612eSDenys Shabalin               ctx->get(), dynamic, strides.size(), strides.data());
1466e3fd612eSDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1467e3fd612eSDenys Shabalin         },
1468e3fd612eSDenys Shabalin         py::arg("rank"), py::arg("context") = py::none(),
1469e3fd612eSDenys Shabalin         "Gets a strided layout attribute with dynamic offset and strides of a "
1470e3fd612eSDenys Shabalin         "given rank.");
1471ac2e2d65SDenys Shabalin     c.def_property_readonly(
1472ac2e2d65SDenys Shabalin         "offset",
1473ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1474ac2e2d65SDenys Shabalin           return mlirStridedLayoutAttrGetOffset(self);
1475ac2e2d65SDenys Shabalin         },
1476ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1477ac2e2d65SDenys Shabalin     c.def_property_readonly(
1478ac2e2d65SDenys Shabalin         "strides",
1479ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1480ac2e2d65SDenys Shabalin           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1481ac2e2d65SDenys Shabalin           std::vector<int64_t> strides(size);
1482ac2e2d65SDenys Shabalin           for (intptr_t i = 0; i < size; i++) {
1483ac2e2d65SDenys Shabalin             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1484ac2e2d65SDenys Shabalin           }
1485ac2e2d65SDenys Shabalin           return strides;
1486ac2e2d65SDenys Shabalin         },
1487ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1488ac2e2d65SDenys Shabalin   }
1489ac2e2d65SDenys Shabalin };
1490ac2e2d65SDenys Shabalin 
14919566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
14929566ee28Smax   if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
14939566ee28Smax     return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
14949566ee28Smax   if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
14959566ee28Smax     return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
14969566ee28Smax   if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
14979566ee28Smax     return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
14989566ee28Smax   if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
14999566ee28Smax     return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
15009566ee28Smax   if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
15019566ee28Smax     return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
15029566ee28Smax   if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
15039566ee28Smax     return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
15049566ee28Smax   if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
15059566ee28Smax     return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
15069566ee28Smax   std::string msg =
15079566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
15089566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
15099566ee28Smax   throw py::cast_error(msg);
15109566ee28Smax }
15119566ee28Smax 
15129566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
15139566ee28Smax   if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
15149566ee28Smax     return py::cast(PyDenseFPElementsAttribute(pyAttribute));
15159566ee28Smax   if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
15169566ee28Smax     return py::cast(PyDenseIntElementsAttribute(pyAttribute));
15179566ee28Smax   std::string msg =
15189566ee28Smax       std::string(
15199566ee28Smax           "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
15209566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
15219566ee28Smax   throw py::cast_error(msg);
15229566ee28Smax }
15239566ee28Smax 
15249566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
15259566ee28Smax   if (PyBoolAttribute::isaFunction(pyAttribute))
15269566ee28Smax     return py::cast(PyBoolAttribute(pyAttribute));
15279566ee28Smax   if (PyIntegerAttribute::isaFunction(pyAttribute))
15289566ee28Smax     return py::cast(PyIntegerAttribute(pyAttribute));
15299566ee28Smax   std::string msg =
15309566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
15319566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
15329566ee28Smax   throw py::cast_error(msg);
15339566ee28Smax }
15349566ee28Smax 
15354eee9ef9Smax py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
15364eee9ef9Smax   if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
15374eee9ef9Smax     return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
15384eee9ef9Smax   if (PySymbolRefAttribute::isaFunction(pyAttribute))
15394eee9ef9Smax     return py::cast(PySymbolRefAttribute(pyAttribute));
15404eee9ef9Smax   std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
15414eee9ef9Smax                     std::string(py::repr(py::cast(pyAttribute))) + ")";
15424eee9ef9Smax   throw py::cast_error(msg);
15434eee9ef9Smax }
15444eee9ef9Smax 
1545436c6c9cSStella Laurenzo } // namespace
1546436c6c9cSStella Laurenzo 
1547436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
1548436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1549619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1550619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1551619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1552619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1553619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1554619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1555619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1556619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1557619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1558619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1559619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1560619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1561619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1562619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
15639566ee28Smax   PyGlobals::get().registerTypeCaster(
15649566ee28Smax       mlirDenseArrayAttrGetTypeID(),
15659566ee28Smax       pybind11::cpp_function(denseArrayAttributeCaster));
1566619fd8c2SJeff Niu 
1567436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1568436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1569436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1570436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
1571436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1572436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
15739566ee28Smax   PyGlobals::get().registerTypeCaster(
15749566ee28Smax       mlirDenseIntOrFPElementsAttrGetTypeID(),
15759566ee28Smax       pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
1576f66cd9e9SStella Laurenzo   PyDenseResourceElementsAttribute::bind(m);
15779566ee28Smax 
1578436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
15794eee9ef9Smax   PySymbolRefAttribute::bind(m);
15804eee9ef9Smax   PyGlobals::get().registerTypeCaster(
15814eee9ef9Smax       mlirSymbolRefAttrGetTypeID(),
15824eee9ef9Smax       pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
15834eee9ef9Smax 
1584436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
15855c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1586436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1587436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1588334873feSAmy Wang   PyIntegerSetAttribute::bind(m);
1589436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1590436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
15919566ee28Smax   PyGlobals::get().registerTypeCaster(
15929566ee28Smax       mlirIntegerAttrGetTypeID(),
15939566ee28Smax       pybind11::cpp_function(integerOrBoolAttributeCaster));
1594436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1595ac2e2d65SDenys Shabalin 
1596ac2e2d65SDenys Shabalin   PyStridedLayoutAttribute::bind(m);
1597436c6c9cSStella Laurenzo }
1598