xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision c912f0e773386cc309155b78e2441ee5f1052c13)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9a1fe1f5fSKazu Hirata #include <optional>
1071a25454SPeter Hawkins #include <string_view>
114811270bSmax #include <utility>
121fc096afSMehdi Amini 
13436c6c9cSStella Laurenzo #include "IRModule.h"
14436c6c9cSStella Laurenzo 
15436c6c9cSStella Laurenzo #include "PybindUtils.h"
16436c6c9cSStella Laurenzo 
1771a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h"
18*c912f0e7Spranavm-nvidia #include "llvm/Support/raw_ostream.h"
1971a25454SPeter Hawkins 
20436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
21436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
22bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h"
23436c6c9cSStella Laurenzo 
24436c6c9cSStella Laurenzo namespace py = pybind11;
25436c6c9cSStella Laurenzo using namespace mlir;
26436c6c9cSStella Laurenzo using namespace mlir::python;
27436c6c9cSStella Laurenzo 
28436c6c9cSStella Laurenzo using llvm::SmallVector;
29436c6c9cSStella Laurenzo 
305d6d30edSStella Laurenzo //------------------------------------------------------------------------------
315d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
325d6d30edSStella Laurenzo //------------------------------------------------------------------------------
335d6d30edSStella Laurenzo 
// Python-facing docstring for the buffer/array-based DenseElementsAttr
// constructor. The raw-string contents are emitted verbatim as the Python
// `__doc__`, so any edit to the text is a user-visible change.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
755d6d30edSStella Laurenzo 
// Python-facing docstring for the list-of-attributes DenseElementsAttr
// constructor. The "Raises" section refers to the `type` argument documented
// under "Args" (previously it named a nonexistent `shaped_type` parameter).
static const char kDenseElementsAttrGetFromListDocstring[] =
    R"(Gets a DenseElementsAttr from a Python list of attributes.

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
  attrs: A list of attributes.
  type: The desired shape and type of the resulting DenseElementsAttr.
    If not provided, the element type is determined based on the type
    of the 0th attribute and the shape is `[len(attrs)]`.
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the attributes does not match the type
    specified by `type`.
)";
96*c912f0e7Spranavm-nvidia 
// Python-facing docstring for constructing a DenseResourceElementsAttr from a
// buffer. The raw-string contents are emitted verbatim as the Python
// `__doc__`, so any edit to the text is a user-visible change.
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
  buffer: The array or buffer to convert.
  name: Name to provide to the resource (may be changed upon collision).
  type: The explicit ShapedType to construct the attribute with.
  context: Explicit context, if not from context manager.

Returns:
  DenseResourceElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
122f66cd9e9SStella Laurenzo 
123436c6c9cSStella Laurenzo namespace {
124436c6c9cSStella Laurenzo 
125436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
126436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
127436c6c9cSStella Laurenzo }
128436c6c9cSStella Laurenzo 
129436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
130436c6c9cSStella Laurenzo public:
131436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
132436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
133436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1349566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1359566ee28Smax       mlirAffineMapAttrGetTypeID;
136436c6c9cSStella Laurenzo 
137436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
138436c6c9cSStella Laurenzo     c.def_static(
139436c6c9cSStella Laurenzo         "get",
140436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
141436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
142436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
143436c6c9cSStella Laurenzo         },
144436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
145436c6c9cSStella Laurenzo   }
146436c6c9cSStella Laurenzo };
147436c6c9cSStella Laurenzo 
148ed9e52f3SAlex Zinenko template <typename T>
149ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
150ed9e52f3SAlex Zinenko   try {
151ed9e52f3SAlex Zinenko     return object.cast<T>();
152ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
153ed9e52f3SAlex Zinenko     std::string msg =
154ed9e52f3SAlex Zinenko         std::string(
155ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
156ed9e52f3SAlex Zinenko         err.what() + ")";
157ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
158ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
159ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
160ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
161ed9e52f3SAlex Zinenko                       err.what() + ")";
162ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
163ed9e52f3SAlex Zinenko   }
164ed9e52f3SAlex Zinenko }
165ed9e52f3SAlex Zinenko 
166619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
167619fd8c2SJeff Niu /// implementation class.
168619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
169133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
170619fd8c2SJeff Niu public:
171133624acSJeff Niu   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
172619fd8c2SJeff Niu 
173619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
174619fd8c2SJeff Niu   class PyDenseArrayIterator {
175619fd8c2SJeff Niu   public:
1764a1b1196SMehdi Amini     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
177619fd8c2SJeff Niu 
178619fd8c2SJeff Niu     /// Return a copy of the iterator.
179619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
180619fd8c2SJeff Niu 
181619fd8c2SJeff Niu     /// Return the next element.
182619fd8c2SJeff Niu     EltTy dunderNext() {
183619fd8c2SJeff Niu       // Throw if the index has reached the end.
184619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
185619fd8c2SJeff Niu         throw py::stop_iteration();
186619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
187619fd8c2SJeff Niu     }
188619fd8c2SJeff Niu 
189619fd8c2SJeff Niu     /// Bind the iterator class.
190619fd8c2SJeff Niu     static void bind(py::module &m) {
191619fd8c2SJeff Niu       py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
192619fd8c2SJeff Niu                                        py::module_local())
193619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
194619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
195619fd8c2SJeff Niu     }
196619fd8c2SJeff Niu 
197619fd8c2SJeff Niu   private:
198619fd8c2SJeff Niu     /// The referenced dense array attribute.
199619fd8c2SJeff Niu     PyAttribute attr;
200619fd8c2SJeff Niu     /// The next index to read.
201619fd8c2SJeff Niu     int nextIndex = 0;
202619fd8c2SJeff Niu   };
203619fd8c2SJeff Niu 
204619fd8c2SJeff Niu   /// Get the element at the given index.
205619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
206619fd8c2SJeff Niu 
207619fd8c2SJeff Niu   /// Bind the attribute class.
208133624acSJeff Niu   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
209619fd8c2SJeff Niu     // Bind the constructor.
210619fd8c2SJeff Niu     c.def_static(
211619fd8c2SJeff Niu         "get",
212619fd8c2SJeff Niu         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
2138dcb6722SIngo Müller           return getAttribute(values, ctx->getRef());
214619fd8c2SJeff Niu         },
215619fd8c2SJeff Niu         py::arg("values"), py::arg("context") = py::none(),
216619fd8c2SJeff Niu         "Gets a uniqued dense array attribute");
217619fd8c2SJeff Niu     // Bind the array methods.
218133624acSJeff Niu     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
219619fd8c2SJeff Niu       if (i >= mlirDenseArrayGetNumElements(arr))
220619fd8c2SJeff Niu         throw py::index_error("DenseArray index out of range");
221619fd8c2SJeff Niu       return arr.getItem(i);
222619fd8c2SJeff Niu     });
223133624acSJeff Niu     c.def("__len__", [](const DerivedT &arr) {
224619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
225619fd8c2SJeff Niu     });
226133624acSJeff Niu     c.def("__iter__",
227133624acSJeff Niu           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
2284a1b1196SMehdi Amini     c.def("__add__", [](DerivedT &arr, const py::list &extras) {
229619fd8c2SJeff Niu       std::vector<EltTy> values;
230619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
231619fd8c2SJeff Niu       values.reserve(numOldElements + py::len(extras));
232619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
233619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
234619fd8c2SJeff Niu       for (py::handle attr : extras)
235619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
2368dcb6722SIngo Müller       return getAttribute(values, arr.getContext());
237619fd8c2SJeff Niu     });
238619fd8c2SJeff Niu   }
2398dcb6722SIngo Müller 
2408dcb6722SIngo Müller private:
2418dcb6722SIngo Müller   static DerivedT getAttribute(const std::vector<EltTy> &values,
2428dcb6722SIngo Müller                                PyMlirContextRef ctx) {
2438dcb6722SIngo Müller     if constexpr (std::is_same_v<EltTy, bool>) {
2448dcb6722SIngo Müller       std::vector<int> intValues(values.begin(), values.end());
2458dcb6722SIngo Müller       MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
2468dcb6722SIngo Müller                                                   intValues.data());
2478dcb6722SIngo Müller       return DerivedT(ctx, attr);
2488dcb6722SIngo Müller     } else {
2498dcb6722SIngo Müller       MlirAttribute attr =
2508dcb6722SIngo Müller           DerivedT::getAttribute(ctx->get(), values.size(), values.data());
2518dcb6722SIngo Müller       return DerivedT(ctx, attr);
2528dcb6722SIngo Müller     }
2538dcb6722SIngo Müller   }
254619fd8c2SJeff Niu };
255619fd8c2SJeff Niu 
256619fd8c2SJeff Niu /// Instantiate the python dense array classes.
257619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
2588dcb6722SIngo Müller     : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
259619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
260619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
261619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
262619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
263619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
264619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
265619fd8c2SJeff Niu };
266619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
267619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
268619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
269619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
270619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
271619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
272619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
273619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
274619fd8c2SJeff Niu };
275619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
276619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
277619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
278619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
279619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
280619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
281619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
282619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
283619fd8c2SJeff Niu };
284619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
285619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
286619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
287619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
288619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
289619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
290619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
291619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
292619fd8c2SJeff Niu };
293619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
294619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
295619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
296619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
297619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
298619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
299619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
300619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
301619fd8c2SJeff Niu };
302619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
303619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
304619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
305619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
306619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
307619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
308619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
309619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
310619fd8c2SJeff Niu };
311619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
312619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
313619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
314619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
315619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
316619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
317619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
318619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
319619fd8c2SJeff Niu };
320619fd8c2SJeff Niu 
321436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
322436c6c9cSStella Laurenzo public:
323436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
324436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
325436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
3269566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
3279566ee28Smax       mlirArrayAttrGetTypeID;
328436c6c9cSStella Laurenzo 
329436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
330436c6c9cSStella Laurenzo   public:
3311fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
332436c6c9cSStella Laurenzo 
333436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
334436c6c9cSStella Laurenzo 
335974c1596SRahul Kayaith     MlirAttribute dunderNext() {
336bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
337bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
338436c6c9cSStella Laurenzo         throw py::stop_iteration();
339974c1596SRahul Kayaith       return mlirArrayAttrGetElement(attr.get(), nextIndex++);
340436c6c9cSStella Laurenzo     }
341436c6c9cSStella Laurenzo 
342436c6c9cSStella Laurenzo     static void bind(py::module &m) {
343f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
344f05ff4f7SStella Laurenzo                                            py::module_local())
345436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
346436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
347436c6c9cSStella Laurenzo     }
348436c6c9cSStella Laurenzo 
349436c6c9cSStella Laurenzo   private:
350436c6c9cSStella Laurenzo     PyAttribute attr;
351436c6c9cSStella Laurenzo     int nextIndex = 0;
352436c6c9cSStella Laurenzo   };
353436c6c9cSStella Laurenzo 
354974c1596SRahul Kayaith   MlirAttribute getItem(intptr_t i) {
355974c1596SRahul Kayaith     return mlirArrayAttrGetElement(*this, i);
356ed9e52f3SAlex Zinenko   }
357ed9e52f3SAlex Zinenko 
358436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
359436c6c9cSStella Laurenzo     c.def_static(
360436c6c9cSStella Laurenzo         "get",
361436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
362436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
363436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
364436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
365ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
366436c6c9cSStella Laurenzo           }
367436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
368436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
369436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
370436c6c9cSStella Laurenzo         },
371436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
372436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
373436c6c9cSStella Laurenzo     c.def("__getitem__",
374436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
375436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
376436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
377ed9e52f3SAlex Zinenko             return arr.getItem(i);
378436c6c9cSStella Laurenzo           })
379436c6c9cSStella Laurenzo         .def("__len__",
380436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
381436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
382436c6c9cSStella Laurenzo              })
383436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
384436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
385436c6c9cSStella Laurenzo         });
386ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
387ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
388ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
389ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
390ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
391ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
392ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
393ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
394ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
395ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
396ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
397ed9e52f3SAlex Zinenko     });
398436c6c9cSStella Laurenzo   }
399436c6c9cSStella Laurenzo };
400436c6c9cSStella Laurenzo 
401436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
402436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
403436c6c9cSStella Laurenzo public:
404436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
405436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
406436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
4079566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
4089566ee28Smax       mlirFloatAttrGetTypeID;
409436c6c9cSStella Laurenzo 
410436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
411436c6c9cSStella Laurenzo     c.def_static(
412436c6c9cSStella Laurenzo         "get",
413436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
4143ea4c501SRahul Kayaith           PyMlirContext::ErrorCapture errors(loc->getContext());
415436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
4163ea4c501SRahul Kayaith           if (mlirAttributeIsNull(attr))
4173ea4c501SRahul Kayaith             throw MLIRError("Invalid attribute", errors.take());
418436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
419436c6c9cSStella Laurenzo         },
420436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
421436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
422436c6c9cSStella Laurenzo     c.def_static(
423436c6c9cSStella Laurenzo         "get_f32",
424436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
425436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
426436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
427436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
428436c6c9cSStella Laurenzo         },
429436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
430436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
431436c6c9cSStella Laurenzo     c.def_static(
432436c6c9cSStella Laurenzo         "get_f64",
433436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
434436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
435436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
436436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
437436c6c9cSStella Laurenzo         },
438436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
439436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
4402a5d4974SIngo Müller     c.def_property_readonly("value", mlirFloatAttrGetValueDouble,
4412a5d4974SIngo Müller                             "Returns the value of the float attribute");
4422a5d4974SIngo Müller     c.def("__float__", mlirFloatAttrGetValueDouble,
4432a5d4974SIngo Müller           "Converts the value of the float attribute to a Python float");
444436c6c9cSStella Laurenzo   }
445436c6c9cSStella Laurenzo };
446436c6c9cSStella Laurenzo 
447436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
448436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
449436c6c9cSStella Laurenzo public:
450436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
451436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
452436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
453436c6c9cSStella Laurenzo 
454436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
455436c6c9cSStella Laurenzo     c.def_static(
456436c6c9cSStella Laurenzo         "get",
457436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
458436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
459436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
460436c6c9cSStella Laurenzo         },
461436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
462436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
4632a5d4974SIngo Müller     c.def_property_readonly("value", toPyInt,
4642a5d4974SIngo Müller                             "Returns the value of the integer attribute");
4652a5d4974SIngo Müller     c.def("__int__", toPyInt,
4662a5d4974SIngo Müller           "Converts the value of the integer attribute to a Python int");
4672a5d4974SIngo Müller     c.def_property_readonly_static("static_typeid",
4682a5d4974SIngo Müller                                    [](py::object & /*class*/) -> MlirTypeID {
4692a5d4974SIngo Müller                                      return mlirIntegerAttrGetTypeID();
4702a5d4974SIngo Müller                                    });
4712a5d4974SIngo Müller   }
4722a5d4974SIngo Müller 
4732a5d4974SIngo Müller private:
4742a5d4974SIngo Müller   static py::int_ toPyInt(PyIntegerAttribute &self) {
475e9db306dSrkayaith     MlirType type = mlirAttributeGetType(self);
476e9db306dSrkayaith     if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
477436c6c9cSStella Laurenzo       return mlirIntegerAttrGetValueInt(self);
478e9db306dSrkayaith     if (mlirIntegerTypeIsSigned(type))
479e9db306dSrkayaith       return mlirIntegerAttrGetValueSInt(self);
480e9db306dSrkayaith     return mlirIntegerAttrGetValueUInt(self);
481436c6c9cSStella Laurenzo   }
482436c6c9cSStella Laurenzo };
483436c6c9cSStella Laurenzo 
484436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
485436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
486436c6c9cSStella Laurenzo public:
487436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
488436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
489436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
490436c6c9cSStella Laurenzo 
491436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
492436c6c9cSStella Laurenzo     c.def_static(
493436c6c9cSStella Laurenzo         "get",
494436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
495436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
496436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
497436c6c9cSStella Laurenzo         },
498436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
499436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
5002a5d4974SIngo Müller     c.def_property_readonly("value", mlirBoolAttrGetValue,
501436c6c9cSStella Laurenzo                             "Returns the value of the bool attribute");
5022a5d4974SIngo Müller     c.def("__bool__", mlirBoolAttrGetValue,
5032a5d4974SIngo Müller           "Converts the value of the bool attribute to a Python bool");
504436c6c9cSStella Laurenzo   }
505436c6c9cSStella Laurenzo };
506436c6c9cSStella Laurenzo 
5074eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
5084eee9ef9Smax public:
5094eee9ef9Smax   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
5104eee9ef9Smax   static constexpr const char *pyClassName = "SymbolRefAttr";
5114eee9ef9Smax   using PyConcreteAttribute::PyConcreteAttribute;
5124eee9ef9Smax 
5134eee9ef9Smax   static MlirAttribute fromList(const std::vector<std::string> &symbols,
5144eee9ef9Smax                                 PyMlirContext &context) {
5154eee9ef9Smax     if (symbols.empty())
5164eee9ef9Smax       throw std::runtime_error("SymbolRefAttr must be composed of at least "
5174eee9ef9Smax                                "one symbol.");
5184eee9ef9Smax     MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
5194eee9ef9Smax     SmallVector<MlirAttribute, 3> referenceAttrs;
5204eee9ef9Smax     for (size_t i = 1; i < symbols.size(); ++i) {
5214eee9ef9Smax       referenceAttrs.push_back(
5224eee9ef9Smax           mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
5234eee9ef9Smax     }
5244eee9ef9Smax     return mlirSymbolRefAttrGet(context.get(), rootSymbol,
5254eee9ef9Smax                                 referenceAttrs.size(), referenceAttrs.data());
5264eee9ef9Smax   }
5274eee9ef9Smax 
5284eee9ef9Smax   static void bindDerived(ClassTy &c) {
5294eee9ef9Smax     c.def_static(
5304eee9ef9Smax         "get",
5314eee9ef9Smax         [](const std::vector<std::string> &symbols,
5324eee9ef9Smax            DefaultingPyMlirContext context) {
5334eee9ef9Smax           return PySymbolRefAttribute::fromList(symbols, context.resolve());
5344eee9ef9Smax         },
5354eee9ef9Smax         py::arg("symbols"), py::arg("context") = py::none(),
5364eee9ef9Smax         "Gets a uniqued SymbolRef attribute from a list of symbol names");
5374eee9ef9Smax     c.def_property_readonly(
5384eee9ef9Smax         "value",
5394eee9ef9Smax         [](PySymbolRefAttribute &self) {
5404eee9ef9Smax           std::vector<std::string> symbols = {
5414eee9ef9Smax               unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
5424eee9ef9Smax           for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
5434eee9ef9Smax                ++i)
5444eee9ef9Smax             symbols.push_back(
5454eee9ef9Smax                 unwrap(mlirSymbolRefAttrGetRootReference(
5464eee9ef9Smax                            mlirSymbolRefAttrGetNestedReference(self, i)))
5474eee9ef9Smax                     .str());
5484eee9ef9Smax           return symbols;
5494eee9ef9Smax         },
5504eee9ef9Smax         "Returns the value of the SymbolRef attribute as a list[str]");
5514eee9ef9Smax   }
5524eee9ef9Smax };
5534eee9ef9Smax 
554436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
555436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
556436c6c9cSStella Laurenzo public:
557436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
558436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
559436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
560436c6c9cSStella Laurenzo 
561436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
562436c6c9cSStella Laurenzo     c.def_static(
563436c6c9cSStella Laurenzo         "get",
564436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
565436c6c9cSStella Laurenzo           MlirAttribute attr =
566436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
567436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
568436c6c9cSStella Laurenzo         },
569436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
570436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
571436c6c9cSStella Laurenzo     c.def_property_readonly(
572436c6c9cSStella Laurenzo         "value",
573436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
574436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
575436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
576436c6c9cSStella Laurenzo         },
577436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
578436c6c9cSStella Laurenzo   }
579436c6c9cSStella Laurenzo };
580436c6c9cSStella Laurenzo 
5815c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
5825c3861b2SYun Long public:
5835c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
5845c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
5855c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
5869566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
5879566ee28Smax       mlirOpaqueAttrGetTypeID;
5885c3861b2SYun Long 
5895c3861b2SYun Long   static void bindDerived(ClassTy &c) {
5905c3861b2SYun Long     c.def_static(
5915c3861b2SYun Long         "get",
5925c3861b2SYun Long         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
5935c3861b2SYun Long            DefaultingPyMlirContext context) {
5945c3861b2SYun Long           const py::buffer_info bufferInfo = buffer.request();
5955c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
5965c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
5975c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
5985c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
5995c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
6005c3861b2SYun Long         },
6015c3861b2SYun Long         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
6025c3861b2SYun Long         py::arg("context") = py::none(), "Gets an Opaque attribute.");
6035c3861b2SYun Long     c.def_property_readonly(
6045c3861b2SYun Long         "dialect_namespace",
6055c3861b2SYun Long         [](PyOpaqueAttribute &self) {
6065c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
6075c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
6085c3861b2SYun Long         },
6095c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
6105c3861b2SYun Long     c.def_property_readonly(
6115c3861b2SYun Long         "data",
6125c3861b2SYun Long         [](PyOpaqueAttribute &self) {
6135c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
61462bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
6155c3861b2SYun Long         },
61662bf6c2eSChris Jones         "Returns the data for the Opaqued attributes as `bytes`");
6175c3861b2SYun Long   }
6185c3861b2SYun Long };
6195c3861b2SYun Long 
620436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
621436c6c9cSStella Laurenzo public:
622436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
623436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
624436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
6259566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
6269566ee28Smax       mlirStringAttrGetTypeID;
627436c6c9cSStella Laurenzo 
628436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
629436c6c9cSStella Laurenzo     c.def_static(
630436c6c9cSStella Laurenzo         "get",
631436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
632436c6c9cSStella Laurenzo           MlirAttribute attr =
633436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
634436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
635436c6c9cSStella Laurenzo         },
636436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
637436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
638436c6c9cSStella Laurenzo     c.def_static(
639436c6c9cSStella Laurenzo         "get_typed",
640436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
641436c6c9cSStella Laurenzo           MlirAttribute attr =
642436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
643436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
644436c6c9cSStella Laurenzo         },
645a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
646436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
6479f533548SIngo Müller     c.def_property_readonly(
6489f533548SIngo Müller         "value",
6499f533548SIngo Müller         [](PyStringAttribute &self) {
6509f533548SIngo Müller           MlirStringRef stringRef = mlirStringAttrGetValue(self);
6519f533548SIngo Müller           return py::str(stringRef.data, stringRef.length);
6529f533548SIngo Müller         },
653436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
65462bf6c2eSChris Jones     c.def_property_readonly(
65562bf6c2eSChris Jones         "value_bytes",
65662bf6c2eSChris Jones         [](PyStringAttribute &self) {
65762bf6c2eSChris Jones           MlirStringRef stringRef = mlirStringAttrGetValue(self);
65862bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
65962bf6c2eSChris Jones         },
66062bf6c2eSChris Jones         "Returns the value of the string attribute as `bytes`");
661436c6c9cSStella Laurenzo   }
662436c6c9cSStella Laurenzo };
663436c6c9cSStella Laurenzo 
664436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
665436c6c9cSStella Laurenzo class PyDenseElementsAttribute
666436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
667436c6c9cSStella Laurenzo public:
668436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
669436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
670436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
671436c6c9cSStella Laurenzo 
672436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
673*c912f0e7Spranavm-nvidia   getFromList(py::list attributes, std::optional<PyType> explicitType,
674*c912f0e7Spranavm-nvidia               DefaultingPyMlirContext contextWrapper) {
675*c912f0e7Spranavm-nvidia 
676*c912f0e7Spranavm-nvidia     const size_t numAttributes = py::len(attributes);
677*c912f0e7Spranavm-nvidia     if (numAttributes == 0)
678*c912f0e7Spranavm-nvidia       throw py::value_error("Attributes list must be non-empty.");
679*c912f0e7Spranavm-nvidia 
680*c912f0e7Spranavm-nvidia     MlirType shapedType;
681*c912f0e7Spranavm-nvidia     if (explicitType) {
682*c912f0e7Spranavm-nvidia       if ((!mlirTypeIsAShaped(*explicitType) ||
683*c912f0e7Spranavm-nvidia            !mlirShapedTypeHasStaticShape(*explicitType))) {
684*c912f0e7Spranavm-nvidia 
685*c912f0e7Spranavm-nvidia         std::string message;
686*c912f0e7Spranavm-nvidia         llvm::raw_string_ostream os(message);
687*c912f0e7Spranavm-nvidia         os << "Expected a static ShapedType for the shaped_type parameter: "
688*c912f0e7Spranavm-nvidia            << py::repr(py::cast(*explicitType));
689*c912f0e7Spranavm-nvidia         throw py::value_error(os.str());
690*c912f0e7Spranavm-nvidia       }
691*c912f0e7Spranavm-nvidia       shapedType = *explicitType;
692*c912f0e7Spranavm-nvidia     } else {
693*c912f0e7Spranavm-nvidia       SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
694*c912f0e7Spranavm-nvidia       shapedType = mlirRankedTensorTypeGet(
695*c912f0e7Spranavm-nvidia           shape.size(), shape.data(),
696*c912f0e7Spranavm-nvidia           mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
697*c912f0e7Spranavm-nvidia           mlirAttributeGetNull());
698*c912f0e7Spranavm-nvidia     }
699*c912f0e7Spranavm-nvidia 
700*c912f0e7Spranavm-nvidia     SmallVector<MlirAttribute> mlirAttributes;
701*c912f0e7Spranavm-nvidia     mlirAttributes.reserve(numAttributes);
702*c912f0e7Spranavm-nvidia     for (const py::handle &attribute : attributes) {
703*c912f0e7Spranavm-nvidia       MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
704*c912f0e7Spranavm-nvidia       MlirType attrType = mlirAttributeGetType(mlirAttribute);
705*c912f0e7Spranavm-nvidia       mlirAttributes.push_back(mlirAttribute);
706*c912f0e7Spranavm-nvidia 
707*c912f0e7Spranavm-nvidia       if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
708*c912f0e7Spranavm-nvidia         std::string message;
709*c912f0e7Spranavm-nvidia         llvm::raw_string_ostream os(message);
710*c912f0e7Spranavm-nvidia         os << "All attributes must be of the same type and match "
711*c912f0e7Spranavm-nvidia            << "the type parameter: expected=" << py::repr(py::cast(shapedType))
712*c912f0e7Spranavm-nvidia            << ", but got=" << py::repr(py::cast(attrType));
713*c912f0e7Spranavm-nvidia         throw py::value_error(os.str());
714*c912f0e7Spranavm-nvidia       }
715*c912f0e7Spranavm-nvidia     }
716*c912f0e7Spranavm-nvidia 
717*c912f0e7Spranavm-nvidia     MlirAttribute elements = mlirDenseElementsAttrGet(
718*c912f0e7Spranavm-nvidia         shapedType, mlirAttributes.size(), mlirAttributes.data());
719*c912f0e7Spranavm-nvidia 
720*c912f0e7Spranavm-nvidia     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
721*c912f0e7Spranavm-nvidia   }
722*c912f0e7Spranavm-nvidia 
723*c912f0e7Spranavm-nvidia   static PyDenseElementsAttribute
7240a81ace0SKazu Hirata   getFromBuffer(py::buffer array, bool signless,
7250a81ace0SKazu Hirata                 std::optional<PyType> explicitType,
7260a81ace0SKazu Hirata                 std::optional<std::vector<int64_t>> explicitShape,
727436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
728436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
72971a25454SPeter Hawkins     int flags = PyBUF_ND;
73071a25454SPeter Hawkins     if (!explicitType) {
73171a25454SPeter Hawkins       flags |= PyBUF_FORMAT;
73271a25454SPeter Hawkins     }
73371a25454SPeter Hawkins     Py_buffer view;
73471a25454SPeter Hawkins     if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
735436c6c9cSStella Laurenzo       throw py::error_already_set();
736436c6c9cSStella Laurenzo     }
73771a25454SPeter Hawkins     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
7385d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
7395d6d30edSStella Laurenzo     if (explicitShape) {
7405d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
7415d6d30edSStella Laurenzo     } else {
74271a25454SPeter Hawkins       shape.append(view.shape, view.shape + view.ndim);
7435d6d30edSStella Laurenzo     }
744436c6c9cSStella Laurenzo 
7455d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
746436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
7475d6d30edSStella Laurenzo 
7485d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
7495d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
7505d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
7515d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
7525d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
7530a81ace0SKazu Hirata     std::optional<MlirType> bulkLoadElementType;
7545d6d30edSStella Laurenzo     if (explicitType) {
7555d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
75671a25454SPeter Hawkins     } else {
75771a25454SPeter Hawkins       std::string_view format(view.format);
75871a25454SPeter Hawkins       if (format == "f") {
759436c6c9cSStella Laurenzo         // f32
76071a25454SPeter Hawkins         assert(view.itemsize == 4 && "mismatched array itemsize");
7615d6d30edSStella Laurenzo         bulkLoadElementType = mlirF32TypeGet(context);
76271a25454SPeter Hawkins       } else if (format == "d") {
763436c6c9cSStella Laurenzo         // f64
76471a25454SPeter Hawkins         assert(view.itemsize == 8 && "mismatched array itemsize");
7655d6d30edSStella Laurenzo         bulkLoadElementType = mlirF64TypeGet(context);
76671a25454SPeter Hawkins       } else if (format == "e") {
7675d6d30edSStella Laurenzo         // f16
76871a25454SPeter Hawkins         assert(view.itemsize == 2 && "mismatched array itemsize");
7695d6d30edSStella Laurenzo         bulkLoadElementType = mlirF16TypeGet(context);
77071a25454SPeter Hawkins       } else if (isSignedIntegerFormat(format)) {
77171a25454SPeter Hawkins         if (view.itemsize == 4) {
772436c6c9cSStella Laurenzo           // i32
77371a25454SPeter Hawkins           bulkLoadElementType = signless
77471a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 32)
775436c6c9cSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 32);
77671a25454SPeter Hawkins         } else if (view.itemsize == 8) {
777436c6c9cSStella Laurenzo           // i64
77871a25454SPeter Hawkins           bulkLoadElementType = signless
77971a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 64)
780436c6c9cSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 64);
78171a25454SPeter Hawkins         } else if (view.itemsize == 1) {
7825d6d30edSStella Laurenzo           // i8
7835d6d30edSStella Laurenzo           bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
7845d6d30edSStella Laurenzo                                          : mlirIntegerTypeSignedGet(context, 8);
78571a25454SPeter Hawkins         } else if (view.itemsize == 2) {
7865d6d30edSStella Laurenzo           // i16
78771a25454SPeter Hawkins           bulkLoadElementType = signless
78871a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 16)
7895d6d30edSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 16);
790436c6c9cSStella Laurenzo         }
79171a25454SPeter Hawkins       } else if (isUnsignedIntegerFormat(format)) {
79271a25454SPeter Hawkins         if (view.itemsize == 4) {
793436c6c9cSStella Laurenzo           // unsigned i32
7945d6d30edSStella Laurenzo           bulkLoadElementType = signless
795436c6c9cSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 32)
796436c6c9cSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 32);
79771a25454SPeter Hawkins         } else if (view.itemsize == 8) {
798436c6c9cSStella Laurenzo           // unsigned i64
7995d6d30edSStella Laurenzo           bulkLoadElementType = signless
800436c6c9cSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 64)
801436c6c9cSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 64);
80271a25454SPeter Hawkins         } else if (view.itemsize == 1) {
8035d6d30edSStella Laurenzo           // i8
80471a25454SPeter Hawkins           bulkLoadElementType = signless
80571a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 8)
8065d6d30edSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 8);
80771a25454SPeter Hawkins         } else if (view.itemsize == 2) {
8085d6d30edSStella Laurenzo           // i16
8095d6d30edSStella Laurenzo           bulkLoadElementType = signless
8105d6d30edSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 16)
8115d6d30edSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 16);
812436c6c9cSStella Laurenzo         }
813436c6c9cSStella Laurenzo       }
81471a25454SPeter Hawkins       if (!bulkLoadElementType) {
81571a25454SPeter Hawkins         throw std::invalid_argument(
81671a25454SPeter Hawkins             std::string("unimplemented array format conversion from format: ") +
81771a25454SPeter Hawkins             std::string(format));
81871a25454SPeter Hawkins       }
81971a25454SPeter Hawkins     }
82071a25454SPeter Hawkins 
82199dee31eSAdam Paszke     MlirType shapedType;
82299dee31eSAdam Paszke     if (mlirTypeIsAShaped(*bulkLoadElementType)) {
82399dee31eSAdam Paszke       if (explicitShape) {
82499dee31eSAdam Paszke         throw std::invalid_argument("Shape can only be specified explicitly "
82599dee31eSAdam Paszke                                     "when the type is not a shaped type.");
82699dee31eSAdam Paszke       }
82799dee31eSAdam Paszke       shapedType = *bulkLoadElementType;
82899dee31eSAdam Paszke     } else {
82971a25454SPeter Hawkins       shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
83071a25454SPeter Hawkins                                            *bulkLoadElementType, encodingAttr);
83199dee31eSAdam Paszke     }
83271a25454SPeter Hawkins     size_t rawBufferSize = view.len;
83371a25454SPeter Hawkins     MlirAttribute attr =
83471a25454SPeter Hawkins         mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
8355d6d30edSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
8365d6d30edSStella Laurenzo       throw std::invalid_argument(
8375d6d30edSStella Laurenzo           "DenseElementsAttr could not be constructed from the given buffer. "
8385d6d30edSStella Laurenzo           "This may mean that the Python buffer layout does not match that "
8395d6d30edSStella Laurenzo           "MLIR expected layout and is a bug.");
8405d6d30edSStella Laurenzo     }
8415d6d30edSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
8425d6d30edSStella Laurenzo   }
843436c6c9cSStella Laurenzo 
8441fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
845436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
846436c6c9cSStella Laurenzo     auto contextWrapper =
847436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
848436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
849436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
850436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
851436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
8524811270bSmax       throw py::value_error(message);
853436c6c9cSStella Laurenzo     }
854436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
855436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
856436c6c9cSStella Laurenzo       std::string message =
857436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
858436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
8594811270bSmax       throw py::value_error(message);
860436c6c9cSStella Laurenzo     }
861436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
862436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
863436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
864436c6c9cSStella Laurenzo       std::string message =
865436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
866436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
867436c6c9cSStella Laurenzo       message.append(", element=");
868436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
8694811270bSmax       throw py::value_error(message);
870436c6c9cSStella Laurenzo     }
871436c6c9cSStella Laurenzo 
872436c6c9cSStella Laurenzo     MlirAttribute elements =
873436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
874436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
875436c6c9cSStella Laurenzo   }
876436c6c9cSStella Laurenzo 
877436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
878436c6c9cSStella Laurenzo 
879436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
880436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
881436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
8825d6d30edSStella Laurenzo     std::string format;
883436c6c9cSStella Laurenzo 
884436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
885436c6c9cSStella Laurenzo       // f32
8865d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
88702b6fb21SMehdi Amini     }
88802b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
889436c6c9cSStella Laurenzo       // f64
8905d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
891bb56c2b3SMehdi Amini     }
892bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
8935d6d30edSStella Laurenzo       // f16
8945d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
895bb56c2b3SMehdi Amini     }
896ef1b735dSmax     if (mlirTypeIsAIndex(elementType)) {
897ef1b735dSmax       // Same as IndexType::kInternalStorageBitWidth
898ef1b735dSmax       return bufferInfo<int64_t>(shapedType);
899ef1b735dSmax     }
900bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
901436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
902436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
903436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
904436c6c9cSStella Laurenzo         // i32
9055d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
906e5639b3fSMehdi Amini       }
907e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
908436c6c9cSStella Laurenzo         // unsigned i32
9095d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
910436c6c9cSStella Laurenzo       }
911436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
912436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
913436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
914436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
915436c6c9cSStella Laurenzo         // i64
9165d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
917e5639b3fSMehdi Amini       }
918e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
919436c6c9cSStella Laurenzo         // unsigned i64
9205d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
9215d6d30edSStella Laurenzo       }
9225d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
9235d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
9245d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
9255d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
9265d6d30edSStella Laurenzo         // i8
9275d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
928e5639b3fSMehdi Amini       }
929e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
9305d6d30edSStella Laurenzo         // unsigned i8
9315d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
9325d6d30edSStella Laurenzo       }
9335d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
9345d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
9355d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
9365d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
9375d6d30edSStella Laurenzo         // i16
9385d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
939e5639b3fSMehdi Amini       }
940e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
9415d6d30edSStella Laurenzo         // unsigned i16
9425d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
943436c6c9cSStella Laurenzo       }
944436c6c9cSStella Laurenzo     }
945436c6c9cSStella Laurenzo 
946c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
9475d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
948c5f445d1SStella Laurenzo     throw std::invalid_argument(
949c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
950436c6c9cSStella Laurenzo   }
951436c6c9cSStella Laurenzo 
952436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
953436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
954436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
955436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
9565d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
957436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
9585d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
959*c912f0e7Spranavm-nvidia         .def_static("get", PyDenseElementsAttribute::getFromList,
960*c912f0e7Spranavm-nvidia                     py::arg("attrs"), py::arg("type") = py::none(),
961*c912f0e7Spranavm-nvidia                     py::arg("context") = py::none(),
962*c912f0e7Spranavm-nvidia                     kDenseElementsAttrGetFromListDocstring)
963436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
964436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
965436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
966436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
967436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
968436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
969436c6c9cSStella Laurenzo                                })
97091259963SAdam Paszke         .def("get_splat_value",
971974c1596SRahul Kayaith              [](PyDenseElementsAttribute &self) {
972974c1596SRahul Kayaith                if (!mlirDenseElementsAttrIsSplat(self))
9734811270bSmax                  throw py::value_error(
97491259963SAdam Paszke                      "get_splat_value called on a non-splat attribute");
975974c1596SRahul Kayaith                return mlirDenseElementsAttrGetSplatValue(self);
97691259963SAdam Paszke              })
977436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
978436c6c9cSStella Laurenzo   }
979436c6c9cSStella Laurenzo 
980436c6c9cSStella Laurenzo private:
/// Reports whether a buffer-protocol format string denotes an unsigned
/// integer element. Only the first character is inspected and it must be one
/// of the struct codes 'I', 'B', 'H', 'L' or 'Q'. A leading byte-order/size
/// prefix (e.g. '<' or '@') is not skipped, and an empty format yields false.
static bool isUnsignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kUnsignedCodes = "IBHLQ";
  return !format.empty() &&
         kUnsignedCodes.find(format.front()) != std::string_view::npos;
}
988436c6c9cSStella Laurenzo 
/// Reports whether a buffer-protocol format string denotes a signed integer
/// element. Only the first character is inspected and it must be one of the
/// struct codes 'i', 'b', 'h', 'l' or 'q'. A leading byte-order/size prefix
/// (e.g. '<' or '@') is not skipped, and an empty format yields false.
static bool isSignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kSignedCodes = "ibhlq";
  return !format.empty() &&
         kSignedCodes.find(format.front()) != std::string_view::npos;
}
996436c6c9cSStella Laurenzo 
997436c6c9cSStella Laurenzo   template <typename Type>
998436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
9995d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
1000436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
1001436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
1002436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
1003436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
1004436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1005436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
1006436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
1007436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
1008436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1009436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
1010436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
1011f0e847d0SRahul Kayaith     if (mlirDenseElementsAttrIsSplat(*this)) {
1012f0e847d0SRahul Kayaith       // Splats are special, only the single value is stored.
1013f0e847d0SRahul Kayaith       strides.assign(rank, 0);
1014f0e847d0SRahul Kayaith     } else {
1015436c6c9cSStella Laurenzo       for (intptr_t i = 1; i < rank; ++i) {
1016f0e847d0SRahul Kayaith         intptr_t strideFactor = 1;
1017f0e847d0SRahul Kayaith         for (intptr_t j = i; j < rank; ++j)
1018436c6c9cSStella Laurenzo           strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1019436c6c9cSStella Laurenzo         strides.push_back(sizeof(Type) * strideFactor);
1020436c6c9cSStella Laurenzo       }
1021436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type));
1022f0e847d0SRahul Kayaith     }
10235d6d30edSStella Laurenzo     std::string format;
10245d6d30edSStella Laurenzo     if (explicitFormat) {
10255d6d30edSStella Laurenzo       format = explicitFormat;
10265d6d30edSStella Laurenzo     } else {
10275d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
10285d6d30edSStella Laurenzo     }
10295d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
10305d6d30edSStella Laurenzo                            /*readonly=*/true);
1031436c6c9cSStella Laurenzo   }
1032436c6c9cSStella Laurenzo }; // namespace
1033436c6c9cSStella Laurenzo 
1034436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
1035436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
1036436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
1037436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1038436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1039436c6c9cSStella Laurenzo public:
1040436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1041436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
1042436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1043436c6c9cSStella Laurenzo 
1044436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
1045436c6c9cSStella Laurenzo   /// out of range.
1046436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
1047436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
10484811270bSmax       throw py::index_error("attempt to access out of bounds element");
1049436c6c9cSStella Laurenzo     }
1050436c6c9cSStella Laurenzo 
1051436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1052436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1053436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
1054436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
1055436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1056436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
1057436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
1058436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1059436c6c9cSStella Laurenzo     // querying them on each element access.
1060436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
1061436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1062436c6c9cSStella Laurenzo     if (isUnsigned) {
1063436c6c9cSStella Laurenzo       if (width == 1) {
1064436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
1065436c6c9cSStella Laurenzo       }
1066308d8b8cSRahul Kayaith       if (width == 8) {
1067308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
1068308d8b8cSRahul Kayaith       }
1069308d8b8cSRahul Kayaith       if (width == 16) {
1070308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
1071308d8b8cSRahul Kayaith       }
1072436c6c9cSStella Laurenzo       if (width == 32) {
1073436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
1074436c6c9cSStella Laurenzo       }
1075436c6c9cSStella Laurenzo       if (width == 64) {
1076436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
1077436c6c9cSStella Laurenzo       }
1078436c6c9cSStella Laurenzo     } else {
1079436c6c9cSStella Laurenzo       if (width == 1) {
1080436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
1081436c6c9cSStella Laurenzo       }
1082308d8b8cSRahul Kayaith       if (width == 8) {
1083308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
1084308d8b8cSRahul Kayaith       }
1085308d8b8cSRahul Kayaith       if (width == 16) {
1086308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
1087308d8b8cSRahul Kayaith       }
1088436c6c9cSStella Laurenzo       if (width == 32) {
1089436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
1090436c6c9cSStella Laurenzo       }
1091436c6c9cSStella Laurenzo       if (width == 64) {
1092436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
1093436c6c9cSStella Laurenzo       }
1094436c6c9cSStella Laurenzo     }
10954811270bSmax     throw py::type_error("Unsupported integer type");
1096436c6c9cSStella Laurenzo   }
1097436c6c9cSStella Laurenzo 
1098436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1099436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1100436c6c9cSStella Laurenzo   }
1101436c6c9cSStella Laurenzo };
1102436c6c9cSStella Laurenzo 
1103f66cd9e9SStella Laurenzo class PyDenseResourceElementsAttribute
1104f66cd9e9SStella Laurenzo     : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1105f66cd9e9SStella Laurenzo public:
1106f66cd9e9SStella Laurenzo   static constexpr IsAFunctionTy isaFunction =
1107f66cd9e9SStella Laurenzo       mlirAttributeIsADenseResourceElements;
1108f66cd9e9SStella Laurenzo   static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1109f66cd9e9SStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1110f66cd9e9SStella Laurenzo 
1111f66cd9e9SStella Laurenzo   static PyDenseResourceElementsAttribute
1112962bf002SMehdi Amini   getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type,
1113f66cd9e9SStella Laurenzo                 std::optional<size_t> alignment, bool isMutable,
1114f66cd9e9SStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
1115f66cd9e9SStella Laurenzo     if (!mlirTypeIsAShaped(type)) {
1116f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1117f66cd9e9SStella Laurenzo           "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1118f66cd9e9SStella Laurenzo     }
1119f66cd9e9SStella Laurenzo 
1120f66cd9e9SStella Laurenzo     // Do not request any conversions as we must ensure to use caller
1121f66cd9e9SStella Laurenzo     // managed memory.
1122f66cd9e9SStella Laurenzo     int flags = PyBUF_STRIDES;
1123f66cd9e9SStella Laurenzo     std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1124f66cd9e9SStella Laurenzo     if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1125f66cd9e9SStella Laurenzo       throw py::error_already_set();
1126f66cd9e9SStella Laurenzo     }
1127f66cd9e9SStella Laurenzo 
1128f66cd9e9SStella Laurenzo     // This scope releaser will only release if we haven't yet transferred
1129f66cd9e9SStella Laurenzo     // ownership.
1130f66cd9e9SStella Laurenzo     auto freeBuffer = llvm::make_scope_exit([&]() {
1131f66cd9e9SStella Laurenzo       if (view)
1132f66cd9e9SStella Laurenzo         PyBuffer_Release(view.get());
1133f66cd9e9SStella Laurenzo     });
1134f66cd9e9SStella Laurenzo 
1135f66cd9e9SStella Laurenzo     if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1136f66cd9e9SStella Laurenzo       throw std::invalid_argument("Contiguous buffer is required.");
1137f66cd9e9SStella Laurenzo     }
1138f66cd9e9SStella Laurenzo 
1139f66cd9e9SStella Laurenzo     // Infer alignment to be the stride of one element if not explicit.
1140f66cd9e9SStella Laurenzo     size_t inferredAlignment;
1141f66cd9e9SStella Laurenzo     if (alignment)
1142f66cd9e9SStella Laurenzo       inferredAlignment = *alignment;
1143f66cd9e9SStella Laurenzo     else
1144f66cd9e9SStella Laurenzo       inferredAlignment = view->strides[view->ndim - 1];
1145f66cd9e9SStella Laurenzo 
1146f66cd9e9SStella Laurenzo     // The userData is a Py_buffer* that the deleter owns.
1147f66cd9e9SStella Laurenzo     auto deleter = [](void *userData, const void *data, size_t size,
1148f66cd9e9SStella Laurenzo                       size_t align) {
1149f66cd9e9SStella Laurenzo       Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1150f66cd9e9SStella Laurenzo       PyBuffer_Release(ownedView);
1151f66cd9e9SStella Laurenzo       delete ownedView;
1152f66cd9e9SStella Laurenzo     };
1153f66cd9e9SStella Laurenzo 
1154f66cd9e9SStella Laurenzo     size_t rawBufferSize = view->len;
1155f66cd9e9SStella Laurenzo     MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1156f66cd9e9SStella Laurenzo         type, toMlirStringRef(name), view->buf, rawBufferSize,
1157f66cd9e9SStella Laurenzo         inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1158f66cd9e9SStella Laurenzo     if (mlirAttributeIsNull(attr)) {
1159f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1160f66cd9e9SStella Laurenzo           "DenseResourceElementsAttr could not be constructed from the given "
1161f66cd9e9SStella Laurenzo           "buffer. "
1162f66cd9e9SStella Laurenzo           "This may mean that the Python buffer layout does not match that "
1163f66cd9e9SStella Laurenzo           "MLIR expected layout and is a bug.");
1164f66cd9e9SStella Laurenzo     }
1165f66cd9e9SStella Laurenzo     view.release();
1166f66cd9e9SStella Laurenzo     return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1167f66cd9e9SStella Laurenzo   }
1168f66cd9e9SStella Laurenzo 
1169f66cd9e9SStella Laurenzo   static void bindDerived(ClassTy &c) {
1170f66cd9e9SStella Laurenzo     c.def_static("get_from_buffer",
1171f66cd9e9SStella Laurenzo                  PyDenseResourceElementsAttribute::getFromBuffer,
1172f66cd9e9SStella Laurenzo                  py::arg("array"), py::arg("name"), py::arg("type"),
1173f66cd9e9SStella Laurenzo                  py::arg("alignment") = py::none(),
1174f66cd9e9SStella Laurenzo                  py::arg("is_mutable") = false, py::arg("context") = py::none(),
1175f66cd9e9SStella Laurenzo                  kDenseResourceElementsAttrGetFromBufferDocstring);
1176f66cd9e9SStella Laurenzo   }
1177f66cd9e9SStella Laurenzo };
1178f66cd9e9SStella Laurenzo 
1179436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1180436c6c9cSStella Laurenzo public:
1181436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1182436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
1183436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
11849566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
11859566ee28Smax       mlirDictionaryAttrGetTypeID;
1186436c6c9cSStella Laurenzo 
1187436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1188436c6c9cSStella Laurenzo 
11899fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
11909fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
11919fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
11929fb1086bSAdrian Kuegel   }
11939fb1086bSAdrian Kuegel 
1194436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
11959fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
1196436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
1197436c6c9cSStella Laurenzo     c.def_static(
1198436c6c9cSStella Laurenzo         "get",
1199436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
1200436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1201436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
1202436c6c9cSStella Laurenzo           for (auto &it : attributes) {
120302b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
1204436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
1205436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
120602b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1207436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
120802b6fb21SMehdi Amini                 mlirAttr));
1209436c6c9cSStella Laurenzo           }
1210436c6c9cSStella Laurenzo           MlirAttribute attr =
1211436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1212436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1213436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
1214436c6c9cSStella Laurenzo         },
1215ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
1216436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
1217436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1218436c6c9cSStella Laurenzo       MlirAttribute attr =
1219436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1220974c1596SRahul Kayaith       if (mlirAttributeIsNull(attr))
12214811270bSmax         throw py::key_error("attempt to access a non-existent attribute");
1222974c1596SRahul Kayaith       return attr;
1223436c6c9cSStella Laurenzo     });
1224436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1225436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
12264811270bSmax         throw py::index_error("attempt to access out of bounds attribute");
1227436c6c9cSStella Laurenzo       }
1228436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1229436c6c9cSStella Laurenzo       return PyNamedAttribute(
1230436c6c9cSStella Laurenzo           namedAttr.attribute,
1231436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
1232436c6c9cSStella Laurenzo     });
1233436c6c9cSStella Laurenzo   }
1234436c6c9cSStella Laurenzo };
1235436c6c9cSStella Laurenzo 
1236436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
1237436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
1238436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
1239436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1240436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1241436c6c9cSStella Laurenzo public:
1242436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1243436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
1244436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1245436c6c9cSStella Laurenzo 
1246436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
1247436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
12484811270bSmax       throw py::index_error("attempt to access out of bounds element");
1249436c6c9cSStella Laurenzo     }
1250436c6c9cSStella Laurenzo 
1251436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1252436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1253436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1254436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
1255436c6c9cSStella Laurenzo     // from float and double.
1256436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1257436c6c9cSStella Laurenzo     // querying them on each element access.
1258436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
1259436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
1260436c6c9cSStella Laurenzo     }
1261436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
1262436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1263436c6c9cSStella Laurenzo     }
12644811270bSmax     throw py::type_error("Unsupported floating-point type");
1265436c6c9cSStella Laurenzo   }
1266436c6c9cSStella Laurenzo 
1267436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1268436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1269436c6c9cSStella Laurenzo   }
1270436c6c9cSStella Laurenzo };
1271436c6c9cSStella Laurenzo 
1272436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1273436c6c9cSStella Laurenzo public:
1274436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1275436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
1276436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
12779566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
12789566ee28Smax       mlirTypeAttrGetTypeID;
1279436c6c9cSStella Laurenzo 
1280436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1281436c6c9cSStella Laurenzo     c.def_static(
1282436c6c9cSStella Laurenzo         "get",
1283436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1284436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1285436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1286436c6c9cSStella Laurenzo         },
1287436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
1288436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1289436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1290bfb1ba75Smax       return mlirTypeAttrGetValue(self.get());
1291436c6c9cSStella Laurenzo     });
1292436c6c9cSStella Laurenzo   }
1293436c6c9cSStella Laurenzo };
1294436c6c9cSStella Laurenzo 
1295436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1296436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1297436c6c9cSStella Laurenzo public:
1298436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1299436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1300436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
13019566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
13029566ee28Smax       mlirUnitAttrGetTypeID;
1303436c6c9cSStella Laurenzo 
1304436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1305436c6c9cSStella Laurenzo     c.def_static(
1306436c6c9cSStella Laurenzo         "get",
1307436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1308436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1309436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1310436c6c9cSStella Laurenzo         },
1311436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
1312436c6c9cSStella Laurenzo   }
1313436c6c9cSStella Laurenzo };
1314436c6c9cSStella Laurenzo 
1315ac2e2d65SDenys Shabalin /// Strided layout attribute subclass.
1316ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute
1317ac2e2d65SDenys Shabalin     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1318ac2e2d65SDenys Shabalin public:
1319ac2e2d65SDenys Shabalin   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1320ac2e2d65SDenys Shabalin   static constexpr const char *pyClassName = "StridedLayoutAttr";
1321ac2e2d65SDenys Shabalin   using PyConcreteAttribute::PyConcreteAttribute;
13229566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
13239566ee28Smax       mlirStridedLayoutAttrGetTypeID;
1324ac2e2d65SDenys Shabalin 
1325ac2e2d65SDenys Shabalin   static void bindDerived(ClassTy &c) {
1326ac2e2d65SDenys Shabalin     c.def_static(
1327ac2e2d65SDenys Shabalin         "get",
1328ac2e2d65SDenys Shabalin         [](int64_t offset, const std::vector<int64_t> strides,
1329ac2e2d65SDenys Shabalin            DefaultingPyMlirContext ctx) {
1330ac2e2d65SDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1331ac2e2d65SDenys Shabalin               ctx->get(), offset, strides.size(), strides.data());
1332ac2e2d65SDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1333ac2e2d65SDenys Shabalin         },
1334ac2e2d65SDenys Shabalin         py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1335ac2e2d65SDenys Shabalin         "Gets a strided layout attribute.");
1336e3fd612eSDenys Shabalin     c.def_static(
1337e3fd612eSDenys Shabalin         "get_fully_dynamic",
1338e3fd612eSDenys Shabalin         [](int64_t rank, DefaultingPyMlirContext ctx) {
1339e3fd612eSDenys Shabalin           auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1340e3fd612eSDenys Shabalin           std::vector<int64_t> strides(rank);
1341e3fd612eSDenys Shabalin           std::fill(strides.begin(), strides.end(), dynamic);
1342e3fd612eSDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1343e3fd612eSDenys Shabalin               ctx->get(), dynamic, strides.size(), strides.data());
1344e3fd612eSDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1345e3fd612eSDenys Shabalin         },
1346e3fd612eSDenys Shabalin         py::arg("rank"), py::arg("context") = py::none(),
1347e3fd612eSDenys Shabalin         "Gets a strided layout attribute with dynamic offset and strides of a "
1348e3fd612eSDenys Shabalin         "given rank.");
1349ac2e2d65SDenys Shabalin     c.def_property_readonly(
1350ac2e2d65SDenys Shabalin         "offset",
1351ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1352ac2e2d65SDenys Shabalin           return mlirStridedLayoutAttrGetOffset(self);
1353ac2e2d65SDenys Shabalin         },
1354ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1355ac2e2d65SDenys Shabalin     c.def_property_readonly(
1356ac2e2d65SDenys Shabalin         "strides",
1357ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1358ac2e2d65SDenys Shabalin           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1359ac2e2d65SDenys Shabalin           std::vector<int64_t> strides(size);
1360ac2e2d65SDenys Shabalin           for (intptr_t i = 0; i < size; i++) {
1361ac2e2d65SDenys Shabalin             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1362ac2e2d65SDenys Shabalin           }
1363ac2e2d65SDenys Shabalin           return strides;
1364ac2e2d65SDenys Shabalin         },
1365ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1366ac2e2d65SDenys Shabalin   }
1367ac2e2d65SDenys Shabalin };
1368ac2e2d65SDenys Shabalin 
13699566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
13709566ee28Smax   if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
13719566ee28Smax     return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
13729566ee28Smax   if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
13739566ee28Smax     return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
13749566ee28Smax   if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
13759566ee28Smax     return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
13769566ee28Smax   if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
13779566ee28Smax     return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
13789566ee28Smax   if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
13799566ee28Smax     return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
13809566ee28Smax   if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
13819566ee28Smax     return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
13829566ee28Smax   if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
13839566ee28Smax     return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
13849566ee28Smax   std::string msg =
13859566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
13869566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
13879566ee28Smax   throw py::cast_error(msg);
13889566ee28Smax }
13899566ee28Smax 
13909566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
13919566ee28Smax   if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
13929566ee28Smax     return py::cast(PyDenseFPElementsAttribute(pyAttribute));
13939566ee28Smax   if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
13949566ee28Smax     return py::cast(PyDenseIntElementsAttribute(pyAttribute));
13959566ee28Smax   std::string msg =
13969566ee28Smax       std::string(
13979566ee28Smax           "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
13989566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
13999566ee28Smax   throw py::cast_error(msg);
14009566ee28Smax }
14019566ee28Smax 
14029566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
14039566ee28Smax   if (PyBoolAttribute::isaFunction(pyAttribute))
14049566ee28Smax     return py::cast(PyBoolAttribute(pyAttribute));
14059566ee28Smax   if (PyIntegerAttribute::isaFunction(pyAttribute))
14069566ee28Smax     return py::cast(PyIntegerAttribute(pyAttribute));
14079566ee28Smax   std::string msg =
14089566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
14099566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
14109566ee28Smax   throw py::cast_error(msg);
14119566ee28Smax }
14129566ee28Smax 
14134eee9ef9Smax py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
14144eee9ef9Smax   if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
14154eee9ef9Smax     return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
14164eee9ef9Smax   if (PySymbolRefAttribute::isaFunction(pyAttribute))
14174eee9ef9Smax     return py::cast(PySymbolRefAttribute(pyAttribute));
14184eee9ef9Smax   std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
14194eee9ef9Smax                     std::string(py::repr(py::cast(pyAttribute))) + ")";
14204eee9ef9Smax   throw py::cast_error(msg);
14214eee9ef9Smax }
14224eee9ef9Smax 
1423436c6c9cSStella Laurenzo } // namespace
1424436c6c9cSStella Laurenzo 
1425436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
1426436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1427619fd8c2SJeff Niu 
1428619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1429619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1430619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1431619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1432619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1433619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1434619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1435619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1436619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1437619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1438619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1439619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1440619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1441619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
14429566ee28Smax   PyGlobals::get().registerTypeCaster(
14439566ee28Smax       mlirDenseArrayAttrGetTypeID(),
14449566ee28Smax       pybind11::cpp_function(denseArrayAttributeCaster));
1445619fd8c2SJeff Niu 
1446436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1447436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1448436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1449436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
1450436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1451436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
14529566ee28Smax   PyGlobals::get().registerTypeCaster(
14539566ee28Smax       mlirDenseIntOrFPElementsAttrGetTypeID(),
14549566ee28Smax       pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
1455f66cd9e9SStella Laurenzo   PyDenseResourceElementsAttribute::bind(m);
14569566ee28Smax 
1457436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
14584eee9ef9Smax   PySymbolRefAttribute::bind(m);
14594eee9ef9Smax   PyGlobals::get().registerTypeCaster(
14604eee9ef9Smax       mlirSymbolRefAttrGetTypeID(),
14614eee9ef9Smax       pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
14624eee9ef9Smax 
1463436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
14645c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1465436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1466436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1467436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1468436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
14699566ee28Smax   PyGlobals::get().registerTypeCaster(
14709566ee28Smax       mlirIntegerAttrGetTypeID(),
14719566ee28Smax       pybind11::cpp_function(integerOrBoolAttributeCaster));
1472436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1473ac2e2d65SDenys Shabalin 
1474ac2e2d65SDenys Shabalin   PyStridedLayoutAttribute::bind(m);
1475436c6c9cSStella Laurenzo }
1476