xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision c36b4248286c4546df0c0e93137a340facc75e17)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9a1fe1f5fSKazu Hirata #include <optional>
1071a25454SPeter Hawkins #include <string_view>
114811270bSmax #include <utility>
121fc096afSMehdi Amini 
13436c6c9cSStella Laurenzo #include "IRModule.h"
14436c6c9cSStella Laurenzo 
15436c6c9cSStella Laurenzo #include "PybindUtils.h"
16436c6c9cSStella Laurenzo 
1771a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h"
18c912f0e7Spranavm-nvidia #include "llvm/Support/raw_ostream.h"
1971a25454SPeter Hawkins 
20436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
21436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
22bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h"
23436c6c9cSStella Laurenzo 
24436c6c9cSStella Laurenzo namespace py = pybind11;
25436c6c9cSStella Laurenzo using namespace mlir;
26436c6c9cSStella Laurenzo using namespace mlir::python;
27436c6c9cSStella Laurenzo 
28436c6c9cSStella Laurenzo using llvm::SmallVector;
29436c6c9cSStella Laurenzo 
305d6d30edSStella Laurenzo //------------------------------------------------------------------------------
315d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
325d6d30edSStella Laurenzo //------------------------------------------------------------------------------
335d6d30edSStella Laurenzo 
/// Python-facing docstring text describing how a DenseElementsAttr is built
/// from a Python buffer/array (type inference rules, explicit `type=` layout
/// requirements for integers, floats and bit-packed i1, and the argument list).
/// NOTE(review): presumably attached to the buffer-based `DenseElementsAttr.get`
/// binding further down this file — confirm at the use site.
static constexpr char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
755d6d30edSStella Laurenzo 
/// Python-facing docstring text for constructing a DenseElementsAttr from a
/// Python list of attributes.
/// The Raises section names the `type` argument, which is the one listed under
/// Args. An earlier revision referred to a nonexistent `shaped_type` argument.
static const char kDenseElementsAttrGetFromListDocstring[] =
    R"(Gets a DenseElementsAttr from a Python list of attributes.

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
  attrs: A list of attributes.
  type: The desired shape and type of the resulting DenseElementsAttr.
    If not provided, the element type is determined based on the type
    of the 0th attribute and the shape is `[len(attrs)]`.
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the attributes does not match the type
    specified by `type`.
)";
95c912f0e7Spranavm-nvidia )";
96c912f0e7Spranavm-nvidia 
/// Python-facing docstring text for building a DenseResourceElementsAttr from
/// a Python buffer/array: minimal validation, buffer retention semantics, and
/// the accepted arguments (buffer, name, type, context).
static constexpr char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
  buffer: The array or buffer to convert.
  name: Name to provide to the resource (may be changed upon collision).
  type: The explicit ShapedType to construct the attribute with.
  context: Explicit context, if not from context manager.

Returns:
  DenseResourceElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
121f66cd9e9SStella Laurenzo )";
122f66cd9e9SStella Laurenzo 
123436c6c9cSStella Laurenzo namespace {
124436c6c9cSStella Laurenzo 
125436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
126436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
127436c6c9cSStella Laurenzo }
128436c6c9cSStella Laurenzo 
129436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
130436c6c9cSStella Laurenzo public:
131436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
132436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
133436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1349566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1359566ee28Smax       mlirAffineMapAttrGetTypeID;
136436c6c9cSStella Laurenzo 
137436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
138436c6c9cSStella Laurenzo     c.def_static(
139436c6c9cSStella Laurenzo         "get",
140436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
141436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
142436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
143436c6c9cSStella Laurenzo         },
144436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
145*c36b4248SBimo     c.def_property_readonly("value", mlirAffineMapAttrGetValue,
146*c36b4248SBimo                             "Returns the value of the AffineMap attribute");
147436c6c9cSStella Laurenzo   }
148436c6c9cSStella Laurenzo };
149436c6c9cSStella Laurenzo 
150ed9e52f3SAlex Zinenko template <typename T>
151ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
152ed9e52f3SAlex Zinenko   try {
153ed9e52f3SAlex Zinenko     return object.cast<T>();
154ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
155ed9e52f3SAlex Zinenko     std::string msg =
156ed9e52f3SAlex Zinenko         std::string(
157ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
158ed9e52f3SAlex Zinenko         err.what() + ")";
159ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
160ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
161ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
162ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
163ed9e52f3SAlex Zinenko                       err.what() + ")";
164ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
165ed9e52f3SAlex Zinenko   }
166ed9e52f3SAlex Zinenko }
167ed9e52f3SAlex Zinenko 
168619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
169619fd8c2SJeff Niu /// implementation class.
170619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
171133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
172619fd8c2SJeff Niu public:
173133624acSJeff Niu   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
174619fd8c2SJeff Niu 
175619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
176619fd8c2SJeff Niu   class PyDenseArrayIterator {
177619fd8c2SJeff Niu   public:
1784a1b1196SMehdi Amini     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
179619fd8c2SJeff Niu 
180619fd8c2SJeff Niu     /// Return a copy of the iterator.
181619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
182619fd8c2SJeff Niu 
183619fd8c2SJeff Niu     /// Return the next element.
184619fd8c2SJeff Niu     EltTy dunderNext() {
185619fd8c2SJeff Niu       // Throw if the index has reached the end.
186619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
187619fd8c2SJeff Niu         throw py::stop_iteration();
188619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
189619fd8c2SJeff Niu     }
190619fd8c2SJeff Niu 
191619fd8c2SJeff Niu     /// Bind the iterator class.
192619fd8c2SJeff Niu     static void bind(py::module &m) {
193619fd8c2SJeff Niu       py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
194619fd8c2SJeff Niu                                        py::module_local())
195619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
196619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
197619fd8c2SJeff Niu     }
198619fd8c2SJeff Niu 
199619fd8c2SJeff Niu   private:
200619fd8c2SJeff Niu     /// The referenced dense array attribute.
201619fd8c2SJeff Niu     PyAttribute attr;
202619fd8c2SJeff Niu     /// The next index to read.
203619fd8c2SJeff Niu     int nextIndex = 0;
204619fd8c2SJeff Niu   };
205619fd8c2SJeff Niu 
206619fd8c2SJeff Niu   /// Get the element at the given index.
207619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
208619fd8c2SJeff Niu 
209619fd8c2SJeff Niu   /// Bind the attribute class.
210133624acSJeff Niu   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
211619fd8c2SJeff Niu     // Bind the constructor.
212619fd8c2SJeff Niu     c.def_static(
213619fd8c2SJeff Niu         "get",
214619fd8c2SJeff Niu         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
2158dcb6722SIngo Müller           return getAttribute(values, ctx->getRef());
216619fd8c2SJeff Niu         },
217619fd8c2SJeff Niu         py::arg("values"), py::arg("context") = py::none(),
218619fd8c2SJeff Niu         "Gets a uniqued dense array attribute");
219619fd8c2SJeff Niu     // Bind the array methods.
220133624acSJeff Niu     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
221619fd8c2SJeff Niu       if (i >= mlirDenseArrayGetNumElements(arr))
222619fd8c2SJeff Niu         throw py::index_error("DenseArray index out of range");
223619fd8c2SJeff Niu       return arr.getItem(i);
224619fd8c2SJeff Niu     });
225133624acSJeff Niu     c.def("__len__", [](const DerivedT &arr) {
226619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
227619fd8c2SJeff Niu     });
228133624acSJeff Niu     c.def("__iter__",
229133624acSJeff Niu           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
2304a1b1196SMehdi Amini     c.def("__add__", [](DerivedT &arr, const py::list &extras) {
231619fd8c2SJeff Niu       std::vector<EltTy> values;
232619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
233619fd8c2SJeff Niu       values.reserve(numOldElements + py::len(extras));
234619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
235619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
236619fd8c2SJeff Niu       for (py::handle attr : extras)
237619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
2388dcb6722SIngo Müller       return getAttribute(values, arr.getContext());
239619fd8c2SJeff Niu     });
240619fd8c2SJeff Niu   }
2418dcb6722SIngo Müller 
2428dcb6722SIngo Müller private:
2438dcb6722SIngo Müller   static DerivedT getAttribute(const std::vector<EltTy> &values,
2448dcb6722SIngo Müller                                PyMlirContextRef ctx) {
2458dcb6722SIngo Müller     if constexpr (std::is_same_v<EltTy, bool>) {
2468dcb6722SIngo Müller       std::vector<int> intValues(values.begin(), values.end());
2478dcb6722SIngo Müller       MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
2488dcb6722SIngo Müller                                                   intValues.data());
2498dcb6722SIngo Müller       return DerivedT(ctx, attr);
2508dcb6722SIngo Müller     } else {
2518dcb6722SIngo Müller       MlirAttribute attr =
2528dcb6722SIngo Müller           DerivedT::getAttribute(ctx->get(), values.size(), values.data());
2538dcb6722SIngo Müller       return DerivedT(ctx, attr);
2548dcb6722SIngo Müller     }
2558dcb6722SIngo Müller   }
256619fd8c2SJeff Niu };
257619fd8c2SJeff Niu 
258619fd8c2SJeff Niu /// Instantiate the python dense array classes.
259619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
2608dcb6722SIngo Müller     : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
261619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
262619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
263619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
264619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
265619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
266619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
267619fd8c2SJeff Niu };
268619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
269619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
270619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
271619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
272619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
273619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
274619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
275619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
276619fd8c2SJeff Niu };
277619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
278619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
279619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
280619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
281619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
282619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
283619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
284619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
285619fd8c2SJeff Niu };
286619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
287619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
288619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
289619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
290619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
291619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
292619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
293619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
294619fd8c2SJeff Niu };
295619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
296619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
297619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
298619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
299619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
300619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
301619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
302619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
303619fd8c2SJeff Niu };
304619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
305619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
306619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
307619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
308619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
309619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
310619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
311619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
312619fd8c2SJeff Niu };
313619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
314619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
315619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
316619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
317619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
318619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
319619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
320619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
321619fd8c2SJeff Niu };
322619fd8c2SJeff Niu 
323436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
324436c6c9cSStella Laurenzo public:
325436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
326436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
327436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
3289566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
3299566ee28Smax       mlirArrayAttrGetTypeID;
330436c6c9cSStella Laurenzo 
331436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
332436c6c9cSStella Laurenzo   public:
3331fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
334436c6c9cSStella Laurenzo 
335436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
336436c6c9cSStella Laurenzo 
337974c1596SRahul Kayaith     MlirAttribute dunderNext() {
338bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
339bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
340436c6c9cSStella Laurenzo         throw py::stop_iteration();
341974c1596SRahul Kayaith       return mlirArrayAttrGetElement(attr.get(), nextIndex++);
342436c6c9cSStella Laurenzo     }
343436c6c9cSStella Laurenzo 
344436c6c9cSStella Laurenzo     static void bind(py::module &m) {
345f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
346f05ff4f7SStella Laurenzo                                            py::module_local())
347436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
348436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
349436c6c9cSStella Laurenzo     }
350436c6c9cSStella Laurenzo 
351436c6c9cSStella Laurenzo   private:
352436c6c9cSStella Laurenzo     PyAttribute attr;
353436c6c9cSStella Laurenzo     int nextIndex = 0;
354436c6c9cSStella Laurenzo   };
355436c6c9cSStella Laurenzo 
356974c1596SRahul Kayaith   MlirAttribute getItem(intptr_t i) {
357974c1596SRahul Kayaith     return mlirArrayAttrGetElement(*this, i);
358ed9e52f3SAlex Zinenko   }
359ed9e52f3SAlex Zinenko 
360436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
361436c6c9cSStella Laurenzo     c.def_static(
362436c6c9cSStella Laurenzo         "get",
363436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
364436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
365436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
366436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
367ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
368436c6c9cSStella Laurenzo           }
369436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
370436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
371436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
372436c6c9cSStella Laurenzo         },
373436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
374436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
375436c6c9cSStella Laurenzo     c.def("__getitem__",
376436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
377436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
378436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
379ed9e52f3SAlex Zinenko             return arr.getItem(i);
380436c6c9cSStella Laurenzo           })
381436c6c9cSStella Laurenzo         .def("__len__",
382436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
383436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
384436c6c9cSStella Laurenzo              })
385436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
386436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
387436c6c9cSStella Laurenzo         });
388ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
389ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
390ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
391ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
392ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
393ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
394ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
395ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
396ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
397ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
398ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
399ed9e52f3SAlex Zinenko     });
400436c6c9cSStella Laurenzo   }
401436c6c9cSStella Laurenzo };
402436c6c9cSStella Laurenzo 
403436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
404436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
405436c6c9cSStella Laurenzo public:
406436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
407436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
408436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
4099566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
4109566ee28Smax       mlirFloatAttrGetTypeID;
411436c6c9cSStella Laurenzo 
412436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
413436c6c9cSStella Laurenzo     c.def_static(
414436c6c9cSStella Laurenzo         "get",
415436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
4163ea4c501SRahul Kayaith           PyMlirContext::ErrorCapture errors(loc->getContext());
417436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
4183ea4c501SRahul Kayaith           if (mlirAttributeIsNull(attr))
4193ea4c501SRahul Kayaith             throw MLIRError("Invalid attribute", errors.take());
420436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
421436c6c9cSStella Laurenzo         },
422436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
423436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
424436c6c9cSStella Laurenzo     c.def_static(
425436c6c9cSStella Laurenzo         "get_f32",
426436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
427436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
428436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
429436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
430436c6c9cSStella Laurenzo         },
431436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
432436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
433436c6c9cSStella Laurenzo     c.def_static(
434436c6c9cSStella Laurenzo         "get_f64",
435436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
436436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
437436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
438436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
439436c6c9cSStella Laurenzo         },
440436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
441436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
4422a5d4974SIngo Müller     c.def_property_readonly("value", mlirFloatAttrGetValueDouble,
4432a5d4974SIngo Müller                             "Returns the value of the float attribute");
4442a5d4974SIngo Müller     c.def("__float__", mlirFloatAttrGetValueDouble,
4452a5d4974SIngo Müller           "Converts the value of the float attribute to a Python float");
446436c6c9cSStella Laurenzo   }
447436c6c9cSStella Laurenzo };
448436c6c9cSStella Laurenzo 
449436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
450436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
451436c6c9cSStella Laurenzo public:
452436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
453436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
454436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
455436c6c9cSStella Laurenzo 
456436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
457436c6c9cSStella Laurenzo     c.def_static(
458436c6c9cSStella Laurenzo         "get",
459436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
460436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
461436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
462436c6c9cSStella Laurenzo         },
463436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
464436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
4652a5d4974SIngo Müller     c.def_property_readonly("value", toPyInt,
4662a5d4974SIngo Müller                             "Returns the value of the integer attribute");
4672a5d4974SIngo Müller     c.def("__int__", toPyInt,
4682a5d4974SIngo Müller           "Converts the value of the integer attribute to a Python int");
4692a5d4974SIngo Müller     c.def_property_readonly_static("static_typeid",
4702a5d4974SIngo Müller                                    [](py::object & /*class*/) -> MlirTypeID {
4712a5d4974SIngo Müller                                      return mlirIntegerAttrGetTypeID();
4722a5d4974SIngo Müller                                    });
4732a5d4974SIngo Müller   }
4742a5d4974SIngo Müller 
4752a5d4974SIngo Müller private:
4762a5d4974SIngo Müller   static py::int_ toPyInt(PyIntegerAttribute &self) {
477e9db306dSrkayaith     MlirType type = mlirAttributeGetType(self);
478e9db306dSrkayaith     if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
479436c6c9cSStella Laurenzo       return mlirIntegerAttrGetValueInt(self);
480e9db306dSrkayaith     if (mlirIntegerTypeIsSigned(type))
481e9db306dSrkayaith       return mlirIntegerAttrGetValueSInt(self);
482e9db306dSrkayaith     return mlirIntegerAttrGetValueUInt(self);
483436c6c9cSStella Laurenzo   }
484436c6c9cSStella Laurenzo };
485436c6c9cSStella Laurenzo 
486436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
487436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
488436c6c9cSStella Laurenzo public:
489436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
490436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
491436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
492436c6c9cSStella Laurenzo 
493436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
494436c6c9cSStella Laurenzo     c.def_static(
495436c6c9cSStella Laurenzo         "get",
496436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
497436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
498436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
499436c6c9cSStella Laurenzo         },
500436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
501436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
5022a5d4974SIngo Müller     c.def_property_readonly("value", mlirBoolAttrGetValue,
503436c6c9cSStella Laurenzo                             "Returns the value of the bool attribute");
5042a5d4974SIngo Müller     c.def("__bool__", mlirBoolAttrGetValue,
5052a5d4974SIngo Müller           "Converts the value of the bool attribute to a Python bool");
506436c6c9cSStella Laurenzo   }
507436c6c9cSStella Laurenzo };
508436c6c9cSStella Laurenzo 
5094eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
5104eee9ef9Smax public:
5114eee9ef9Smax   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
5124eee9ef9Smax   static constexpr const char *pyClassName = "SymbolRefAttr";
5134eee9ef9Smax   using PyConcreteAttribute::PyConcreteAttribute;
5144eee9ef9Smax 
5154eee9ef9Smax   static MlirAttribute fromList(const std::vector<std::string> &symbols,
5164eee9ef9Smax                                 PyMlirContext &context) {
5174eee9ef9Smax     if (symbols.empty())
5184eee9ef9Smax       throw std::runtime_error("SymbolRefAttr must be composed of at least "
5194eee9ef9Smax                                "one symbol.");
5204eee9ef9Smax     MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
5214eee9ef9Smax     SmallVector<MlirAttribute, 3> referenceAttrs;
5224eee9ef9Smax     for (size_t i = 1; i < symbols.size(); ++i) {
5234eee9ef9Smax       referenceAttrs.push_back(
5244eee9ef9Smax           mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
5254eee9ef9Smax     }
5264eee9ef9Smax     return mlirSymbolRefAttrGet(context.get(), rootSymbol,
5274eee9ef9Smax                                 referenceAttrs.size(), referenceAttrs.data());
5284eee9ef9Smax   }
5294eee9ef9Smax 
5304eee9ef9Smax   static void bindDerived(ClassTy &c) {
5314eee9ef9Smax     c.def_static(
5324eee9ef9Smax         "get",
5334eee9ef9Smax         [](const std::vector<std::string> &symbols,
5344eee9ef9Smax            DefaultingPyMlirContext context) {
5354eee9ef9Smax           return PySymbolRefAttribute::fromList(symbols, context.resolve());
5364eee9ef9Smax         },
5374eee9ef9Smax         py::arg("symbols"), py::arg("context") = py::none(),
5384eee9ef9Smax         "Gets a uniqued SymbolRef attribute from a list of symbol names");
5394eee9ef9Smax     c.def_property_readonly(
5404eee9ef9Smax         "value",
5414eee9ef9Smax         [](PySymbolRefAttribute &self) {
5424eee9ef9Smax           std::vector<std::string> symbols = {
5434eee9ef9Smax               unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
5444eee9ef9Smax           for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
5454eee9ef9Smax                ++i)
5464eee9ef9Smax             symbols.push_back(
5474eee9ef9Smax                 unwrap(mlirSymbolRefAttrGetRootReference(
5484eee9ef9Smax                            mlirSymbolRefAttrGetNestedReference(self, i)))
5494eee9ef9Smax                     .str());
5504eee9ef9Smax           return symbols;
5514eee9ef9Smax         },
5524eee9ef9Smax         "Returns the value of the SymbolRef attribute as a list[str]");
5534eee9ef9Smax   }
5544eee9ef9Smax };
5554eee9ef9Smax 
556436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
557436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
558436c6c9cSStella Laurenzo public:
559436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
560436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
561436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
562436c6c9cSStella Laurenzo 
563436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
564436c6c9cSStella Laurenzo     c.def_static(
565436c6c9cSStella Laurenzo         "get",
566436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
567436c6c9cSStella Laurenzo           MlirAttribute attr =
568436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
569436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
570436c6c9cSStella Laurenzo         },
571436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
572436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
573436c6c9cSStella Laurenzo     c.def_property_readonly(
574436c6c9cSStella Laurenzo         "value",
575436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
576436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
577436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
578436c6c9cSStella Laurenzo         },
579436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
580436c6c9cSStella Laurenzo   }
581436c6c9cSStella Laurenzo };
582436c6c9cSStella Laurenzo 
5835c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
5845c3861b2SYun Long public:
5855c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
5865c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
5875c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
5889566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
5899566ee28Smax       mlirOpaqueAttrGetTypeID;
5905c3861b2SYun Long 
5915c3861b2SYun Long   static void bindDerived(ClassTy &c) {
5925c3861b2SYun Long     c.def_static(
5935c3861b2SYun Long         "get",
5945c3861b2SYun Long         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
5955c3861b2SYun Long            DefaultingPyMlirContext context) {
5965c3861b2SYun Long           const py::buffer_info bufferInfo = buffer.request();
5975c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
5985c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
5995c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
6005c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
6015c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
6025c3861b2SYun Long         },
6035c3861b2SYun Long         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
6045c3861b2SYun Long         py::arg("context") = py::none(), "Gets an Opaque attribute.");
6055c3861b2SYun Long     c.def_property_readonly(
6065c3861b2SYun Long         "dialect_namespace",
6075c3861b2SYun Long         [](PyOpaqueAttribute &self) {
6085c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
6095c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
6105c3861b2SYun Long         },
6115c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
6125c3861b2SYun Long     c.def_property_readonly(
6135c3861b2SYun Long         "data",
6145c3861b2SYun Long         [](PyOpaqueAttribute &self) {
6155c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
61662bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
6175c3861b2SYun Long         },
61862bf6c2eSChris Jones         "Returns the data for the Opaqued attributes as `bytes`");
6195c3861b2SYun Long   }
6205c3861b2SYun Long };
6215c3861b2SYun Long 
622436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
623436c6c9cSStella Laurenzo public:
624436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
625436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
626436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
6279566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
6289566ee28Smax       mlirStringAttrGetTypeID;
629436c6c9cSStella Laurenzo 
630436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
631436c6c9cSStella Laurenzo     c.def_static(
632436c6c9cSStella Laurenzo         "get",
633436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
634436c6c9cSStella Laurenzo           MlirAttribute attr =
635436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
636436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
637436c6c9cSStella Laurenzo         },
638436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
639436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
640436c6c9cSStella Laurenzo     c.def_static(
641436c6c9cSStella Laurenzo         "get_typed",
642436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
643436c6c9cSStella Laurenzo           MlirAttribute attr =
644436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
645436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
646436c6c9cSStella Laurenzo         },
647a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
648436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
6499f533548SIngo Müller     c.def_property_readonly(
6509f533548SIngo Müller         "value",
6519f533548SIngo Müller         [](PyStringAttribute &self) {
6529f533548SIngo Müller           MlirStringRef stringRef = mlirStringAttrGetValue(self);
6539f533548SIngo Müller           return py::str(stringRef.data, stringRef.length);
6549f533548SIngo Müller         },
655436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
65662bf6c2eSChris Jones     c.def_property_readonly(
65762bf6c2eSChris Jones         "value_bytes",
65862bf6c2eSChris Jones         [](PyStringAttribute &self) {
65962bf6c2eSChris Jones           MlirStringRef stringRef = mlirStringAttrGetValue(self);
66062bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
66162bf6c2eSChris Jones         },
66262bf6c2eSChris Jones         "Returns the value of the string attribute as `bytes`");
663436c6c9cSStella Laurenzo   }
664436c6c9cSStella Laurenzo };
665436c6c9cSStella Laurenzo 
666436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
667436c6c9cSStella Laurenzo class PyDenseElementsAttribute
668436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
669436c6c9cSStella Laurenzo public:
670436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
671436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
672436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
673436c6c9cSStella Laurenzo 
674436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
675c912f0e7Spranavm-nvidia   getFromList(py::list attributes, std::optional<PyType> explicitType,
676c912f0e7Spranavm-nvidia               DefaultingPyMlirContext contextWrapper) {
677c912f0e7Spranavm-nvidia 
678c912f0e7Spranavm-nvidia     const size_t numAttributes = py::len(attributes);
679c912f0e7Spranavm-nvidia     if (numAttributes == 0)
680c912f0e7Spranavm-nvidia       throw py::value_error("Attributes list must be non-empty.");
681c912f0e7Spranavm-nvidia 
682c912f0e7Spranavm-nvidia     MlirType shapedType;
683c912f0e7Spranavm-nvidia     if (explicitType) {
684c912f0e7Spranavm-nvidia       if ((!mlirTypeIsAShaped(*explicitType) ||
685c912f0e7Spranavm-nvidia            !mlirShapedTypeHasStaticShape(*explicitType))) {
686c912f0e7Spranavm-nvidia 
687c912f0e7Spranavm-nvidia         std::string message;
688c912f0e7Spranavm-nvidia         llvm::raw_string_ostream os(message);
689c912f0e7Spranavm-nvidia         os << "Expected a static ShapedType for the shaped_type parameter: "
690c912f0e7Spranavm-nvidia            << py::repr(py::cast(*explicitType));
691c912f0e7Spranavm-nvidia         throw py::value_error(os.str());
692c912f0e7Spranavm-nvidia       }
693c912f0e7Spranavm-nvidia       shapedType = *explicitType;
694c912f0e7Spranavm-nvidia     } else {
695c912f0e7Spranavm-nvidia       SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
696c912f0e7Spranavm-nvidia       shapedType = mlirRankedTensorTypeGet(
697c912f0e7Spranavm-nvidia           shape.size(), shape.data(),
698c912f0e7Spranavm-nvidia           mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
699c912f0e7Spranavm-nvidia           mlirAttributeGetNull());
700c912f0e7Spranavm-nvidia     }
701c912f0e7Spranavm-nvidia 
702c912f0e7Spranavm-nvidia     SmallVector<MlirAttribute> mlirAttributes;
703c912f0e7Spranavm-nvidia     mlirAttributes.reserve(numAttributes);
704c912f0e7Spranavm-nvidia     for (const py::handle &attribute : attributes) {
705c912f0e7Spranavm-nvidia       MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
706c912f0e7Spranavm-nvidia       MlirType attrType = mlirAttributeGetType(mlirAttribute);
707c912f0e7Spranavm-nvidia       mlirAttributes.push_back(mlirAttribute);
708c912f0e7Spranavm-nvidia 
709c912f0e7Spranavm-nvidia       if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
710c912f0e7Spranavm-nvidia         std::string message;
711c912f0e7Spranavm-nvidia         llvm::raw_string_ostream os(message);
712c912f0e7Spranavm-nvidia         os << "All attributes must be of the same type and match "
713c912f0e7Spranavm-nvidia            << "the type parameter: expected=" << py::repr(py::cast(shapedType))
714c912f0e7Spranavm-nvidia            << ", but got=" << py::repr(py::cast(attrType));
715c912f0e7Spranavm-nvidia         throw py::value_error(os.str());
716c912f0e7Spranavm-nvidia       }
717c912f0e7Spranavm-nvidia     }
718c912f0e7Spranavm-nvidia 
719c912f0e7Spranavm-nvidia     MlirAttribute elements = mlirDenseElementsAttrGet(
720c912f0e7Spranavm-nvidia         shapedType, mlirAttributes.size(), mlirAttributes.data());
721c912f0e7Spranavm-nvidia 
722c912f0e7Spranavm-nvidia     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
723c912f0e7Spranavm-nvidia   }
724c912f0e7Spranavm-nvidia 
725c912f0e7Spranavm-nvidia   static PyDenseElementsAttribute
7260a81ace0SKazu Hirata   getFromBuffer(py::buffer array, bool signless,
7270a81ace0SKazu Hirata                 std::optional<PyType> explicitType,
7280a81ace0SKazu Hirata                 std::optional<std::vector<int64_t>> explicitShape,
729436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
730436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
73171a25454SPeter Hawkins     int flags = PyBUF_ND;
73271a25454SPeter Hawkins     if (!explicitType) {
73371a25454SPeter Hawkins       flags |= PyBUF_FORMAT;
73471a25454SPeter Hawkins     }
73571a25454SPeter Hawkins     Py_buffer view;
73671a25454SPeter Hawkins     if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
737436c6c9cSStella Laurenzo       throw py::error_already_set();
738436c6c9cSStella Laurenzo     }
73971a25454SPeter Hawkins     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
7405d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
7415d6d30edSStella Laurenzo     if (explicitShape) {
7425d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
7435d6d30edSStella Laurenzo     } else {
74471a25454SPeter Hawkins       shape.append(view.shape, view.shape + view.ndim);
7455d6d30edSStella Laurenzo     }
746436c6c9cSStella Laurenzo 
7475d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
748436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
7495d6d30edSStella Laurenzo 
7505d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
7515d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
7525d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
7535d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
7545d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
7550a81ace0SKazu Hirata     std::optional<MlirType> bulkLoadElementType;
7565d6d30edSStella Laurenzo     if (explicitType) {
7575d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
75871a25454SPeter Hawkins     } else {
75971a25454SPeter Hawkins       std::string_view format(view.format);
76071a25454SPeter Hawkins       if (format == "f") {
761436c6c9cSStella Laurenzo         // f32
76271a25454SPeter Hawkins         assert(view.itemsize == 4 && "mismatched array itemsize");
7635d6d30edSStella Laurenzo         bulkLoadElementType = mlirF32TypeGet(context);
76471a25454SPeter Hawkins       } else if (format == "d") {
765436c6c9cSStella Laurenzo         // f64
76671a25454SPeter Hawkins         assert(view.itemsize == 8 && "mismatched array itemsize");
7675d6d30edSStella Laurenzo         bulkLoadElementType = mlirF64TypeGet(context);
76871a25454SPeter Hawkins       } else if (format == "e") {
7695d6d30edSStella Laurenzo         // f16
77071a25454SPeter Hawkins         assert(view.itemsize == 2 && "mismatched array itemsize");
7715d6d30edSStella Laurenzo         bulkLoadElementType = mlirF16TypeGet(context);
77271a25454SPeter Hawkins       } else if (isSignedIntegerFormat(format)) {
77371a25454SPeter Hawkins         if (view.itemsize == 4) {
774436c6c9cSStella Laurenzo           // i32
77571a25454SPeter Hawkins           bulkLoadElementType = signless
77671a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 32)
777436c6c9cSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 32);
77871a25454SPeter Hawkins         } else if (view.itemsize == 8) {
779436c6c9cSStella Laurenzo           // i64
78071a25454SPeter Hawkins           bulkLoadElementType = signless
78171a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 64)
782436c6c9cSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 64);
78371a25454SPeter Hawkins         } else if (view.itemsize == 1) {
7845d6d30edSStella Laurenzo           // i8
7855d6d30edSStella Laurenzo           bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
7865d6d30edSStella Laurenzo                                          : mlirIntegerTypeSignedGet(context, 8);
78771a25454SPeter Hawkins         } else if (view.itemsize == 2) {
7885d6d30edSStella Laurenzo           // i16
78971a25454SPeter Hawkins           bulkLoadElementType = signless
79071a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 16)
7915d6d30edSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 16);
792436c6c9cSStella Laurenzo         }
79371a25454SPeter Hawkins       } else if (isUnsignedIntegerFormat(format)) {
79471a25454SPeter Hawkins         if (view.itemsize == 4) {
795436c6c9cSStella Laurenzo           // unsigned i32
7965d6d30edSStella Laurenzo           bulkLoadElementType = signless
797436c6c9cSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 32)
798436c6c9cSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 32);
79971a25454SPeter Hawkins         } else if (view.itemsize == 8) {
800436c6c9cSStella Laurenzo           // unsigned i64
8015d6d30edSStella Laurenzo           bulkLoadElementType = signless
802436c6c9cSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 64)
803436c6c9cSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 64);
80471a25454SPeter Hawkins         } else if (view.itemsize == 1) {
8055d6d30edSStella Laurenzo           // i8
80671a25454SPeter Hawkins           bulkLoadElementType = signless
80771a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 8)
8085d6d30edSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 8);
80971a25454SPeter Hawkins         } else if (view.itemsize == 2) {
8105d6d30edSStella Laurenzo           // i16
8115d6d30edSStella Laurenzo           bulkLoadElementType = signless
8125d6d30edSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 16)
8135d6d30edSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 16);
814436c6c9cSStella Laurenzo         }
815436c6c9cSStella Laurenzo       }
81671a25454SPeter Hawkins       if (!bulkLoadElementType) {
81771a25454SPeter Hawkins         throw std::invalid_argument(
81871a25454SPeter Hawkins             std::string("unimplemented array format conversion from format: ") +
81971a25454SPeter Hawkins             std::string(format));
82071a25454SPeter Hawkins       }
82171a25454SPeter Hawkins     }
82271a25454SPeter Hawkins 
82399dee31eSAdam Paszke     MlirType shapedType;
82499dee31eSAdam Paszke     if (mlirTypeIsAShaped(*bulkLoadElementType)) {
82599dee31eSAdam Paszke       if (explicitShape) {
82699dee31eSAdam Paszke         throw std::invalid_argument("Shape can only be specified explicitly "
82799dee31eSAdam Paszke                                     "when the type is not a shaped type.");
82899dee31eSAdam Paszke       }
82999dee31eSAdam Paszke       shapedType = *bulkLoadElementType;
83099dee31eSAdam Paszke     } else {
83171a25454SPeter Hawkins       shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
83271a25454SPeter Hawkins                                            *bulkLoadElementType, encodingAttr);
83399dee31eSAdam Paszke     }
83471a25454SPeter Hawkins     size_t rawBufferSize = view.len;
83571a25454SPeter Hawkins     MlirAttribute attr =
83671a25454SPeter Hawkins         mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
8375d6d30edSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
8385d6d30edSStella Laurenzo       throw std::invalid_argument(
8395d6d30edSStella Laurenzo           "DenseElementsAttr could not be constructed from the given buffer. "
8405d6d30edSStella Laurenzo           "This may mean that the Python buffer layout does not match that "
8415d6d30edSStella Laurenzo           "MLIR expected layout and is a bug.");
8425d6d30edSStella Laurenzo     }
8435d6d30edSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
8445d6d30edSStella Laurenzo   }
845436c6c9cSStella Laurenzo 
8461fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
847436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
848436c6c9cSStella Laurenzo     auto contextWrapper =
849436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
850436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
851436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
852436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
853436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
8544811270bSmax       throw py::value_error(message);
855436c6c9cSStella Laurenzo     }
856436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
857436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
858436c6c9cSStella Laurenzo       std::string message =
859436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
860436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
8614811270bSmax       throw py::value_error(message);
862436c6c9cSStella Laurenzo     }
863436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
864436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
865436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
866436c6c9cSStella Laurenzo       std::string message =
867436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
868436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
869436c6c9cSStella Laurenzo       message.append(", element=");
870436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
8714811270bSmax       throw py::value_error(message);
872436c6c9cSStella Laurenzo     }
873436c6c9cSStella Laurenzo 
874436c6c9cSStella Laurenzo     MlirAttribute elements =
875436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
876436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
877436c6c9cSStella Laurenzo   }
878436c6c9cSStella Laurenzo 
879436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
880436c6c9cSStella Laurenzo 
881436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
882436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
883436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
8845d6d30edSStella Laurenzo     std::string format;
885436c6c9cSStella Laurenzo 
886436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
887436c6c9cSStella Laurenzo       // f32
8885d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
88902b6fb21SMehdi Amini     }
89002b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
891436c6c9cSStella Laurenzo       // f64
8925d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
893bb56c2b3SMehdi Amini     }
894bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
8955d6d30edSStella Laurenzo       // f16
8965d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
897bb56c2b3SMehdi Amini     }
898ef1b735dSmax     if (mlirTypeIsAIndex(elementType)) {
899ef1b735dSmax       // Same as IndexType::kInternalStorageBitWidth
900ef1b735dSmax       return bufferInfo<int64_t>(shapedType);
901ef1b735dSmax     }
902bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
903436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
904436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
905436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
906436c6c9cSStella Laurenzo         // i32
9075d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
908e5639b3fSMehdi Amini       }
909e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
910436c6c9cSStella Laurenzo         // unsigned i32
9115d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
912436c6c9cSStella Laurenzo       }
913436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
914436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
915436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
916436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
917436c6c9cSStella Laurenzo         // i64
9185d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
919e5639b3fSMehdi Amini       }
920e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
921436c6c9cSStella Laurenzo         // unsigned i64
9225d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
9235d6d30edSStella Laurenzo       }
9245d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
9255d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
9265d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
9275d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
9285d6d30edSStella Laurenzo         // i8
9295d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
930e5639b3fSMehdi Amini       }
931e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
9325d6d30edSStella Laurenzo         // unsigned i8
9335d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
9345d6d30edSStella Laurenzo       }
9355d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
9365d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
9375d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
9385d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
9395d6d30edSStella Laurenzo         // i16
9405d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
941e5639b3fSMehdi Amini       }
942e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
9435d6d30edSStella Laurenzo         // unsigned i16
9445d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
945436c6c9cSStella Laurenzo       }
946436c6c9cSStella Laurenzo     }
947436c6c9cSStella Laurenzo 
948c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
9495d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
950c5f445d1SStella Laurenzo     throw std::invalid_argument(
951c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
952436c6c9cSStella Laurenzo   }
953436c6c9cSStella Laurenzo 
954436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
955436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
956436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
957436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
9585d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
959436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
9605d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
961c912f0e7Spranavm-nvidia         .def_static("get", PyDenseElementsAttribute::getFromList,
962c912f0e7Spranavm-nvidia                     py::arg("attrs"), py::arg("type") = py::none(),
963c912f0e7Spranavm-nvidia                     py::arg("context") = py::none(),
964c912f0e7Spranavm-nvidia                     kDenseElementsAttrGetFromListDocstring)
965436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
966436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
967436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
968436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
969436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
970436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
971436c6c9cSStella Laurenzo                                })
97291259963SAdam Paszke         .def("get_splat_value",
973974c1596SRahul Kayaith              [](PyDenseElementsAttribute &self) {
974974c1596SRahul Kayaith                if (!mlirDenseElementsAttrIsSplat(self))
9754811270bSmax                  throw py::value_error(
97691259963SAdam Paszke                      "get_splat_value called on a non-splat attribute");
977974c1596SRahul Kayaith                return mlirDenseElementsAttrGetSplatValue(self);
97891259963SAdam Paszke              })
979436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
980436c6c9cSStella Laurenzo   }
981436c6c9cSStella Laurenzo 
982436c6c9cSStella Laurenzo private:
98371a25454SPeter Hawkins   static bool isUnsignedIntegerFormat(std::string_view format) {
984436c6c9cSStella Laurenzo     if (format.empty())
985436c6c9cSStella Laurenzo       return false;
986436c6c9cSStella Laurenzo     char code = format[0];
987436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
988436c6c9cSStella Laurenzo            code == 'Q';
989436c6c9cSStella Laurenzo   }
990436c6c9cSStella Laurenzo 
99171a25454SPeter Hawkins   static bool isSignedIntegerFormat(std::string_view format) {
992436c6c9cSStella Laurenzo     if (format.empty())
993436c6c9cSStella Laurenzo       return false;
994436c6c9cSStella Laurenzo     char code = format[0];
995436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
996436c6c9cSStella Laurenzo            code == 'q';
997436c6c9cSStella Laurenzo   }
998436c6c9cSStella Laurenzo 
999436c6c9cSStella Laurenzo   template <typename Type>
1000436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
10015d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
1002436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
1003436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
1004436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
1005436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
1006436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1007436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
1008436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
1009436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
1010436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1011436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
1012436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
1013f0e847d0SRahul Kayaith     if (mlirDenseElementsAttrIsSplat(*this)) {
1014f0e847d0SRahul Kayaith       // Splats are special, only the single value is stored.
1015f0e847d0SRahul Kayaith       strides.assign(rank, 0);
1016f0e847d0SRahul Kayaith     } else {
1017436c6c9cSStella Laurenzo       for (intptr_t i = 1; i < rank; ++i) {
1018f0e847d0SRahul Kayaith         intptr_t strideFactor = 1;
1019f0e847d0SRahul Kayaith         for (intptr_t j = i; j < rank; ++j)
1020436c6c9cSStella Laurenzo           strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1021436c6c9cSStella Laurenzo         strides.push_back(sizeof(Type) * strideFactor);
1022436c6c9cSStella Laurenzo       }
1023436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type));
1024f0e847d0SRahul Kayaith     }
10255d6d30edSStella Laurenzo     std::string format;
10265d6d30edSStella Laurenzo     if (explicitFormat) {
10275d6d30edSStella Laurenzo       format = explicitFormat;
10285d6d30edSStella Laurenzo     } else {
10295d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
10305d6d30edSStella Laurenzo     }
10315d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
10325d6d30edSStella Laurenzo                            /*readonly=*/true);
1033436c6c9cSStella Laurenzo   }
1034436c6c9cSStella Laurenzo }; // namespace
1035436c6c9cSStella Laurenzo 
1036436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
1037436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
1038436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
1039436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1040436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1041436c6c9cSStella Laurenzo public:
1042436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1043436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
1044436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1045436c6c9cSStella Laurenzo 
1046436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
1047436c6c9cSStella Laurenzo   /// out of range.
1048436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
1049436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
10504811270bSmax       throw py::index_error("attempt to access out of bounds element");
1051436c6c9cSStella Laurenzo     }
1052436c6c9cSStella Laurenzo 
1053436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1054436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1055436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
1056436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
1057436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1058436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
1059436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
1060436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1061436c6c9cSStella Laurenzo     // querying them on each element access.
1062436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
1063436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1064436c6c9cSStella Laurenzo     if (isUnsigned) {
1065436c6c9cSStella Laurenzo       if (width == 1) {
1066436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
1067436c6c9cSStella Laurenzo       }
1068308d8b8cSRahul Kayaith       if (width == 8) {
1069308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
1070308d8b8cSRahul Kayaith       }
1071308d8b8cSRahul Kayaith       if (width == 16) {
1072308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
1073308d8b8cSRahul Kayaith       }
1074436c6c9cSStella Laurenzo       if (width == 32) {
1075436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
1076436c6c9cSStella Laurenzo       }
1077436c6c9cSStella Laurenzo       if (width == 64) {
1078436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
1079436c6c9cSStella Laurenzo       }
1080436c6c9cSStella Laurenzo     } else {
1081436c6c9cSStella Laurenzo       if (width == 1) {
1082436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
1083436c6c9cSStella Laurenzo       }
1084308d8b8cSRahul Kayaith       if (width == 8) {
1085308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
1086308d8b8cSRahul Kayaith       }
1087308d8b8cSRahul Kayaith       if (width == 16) {
1088308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
1089308d8b8cSRahul Kayaith       }
1090436c6c9cSStella Laurenzo       if (width == 32) {
1091436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
1092436c6c9cSStella Laurenzo       }
1093436c6c9cSStella Laurenzo       if (width == 64) {
1094436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
1095436c6c9cSStella Laurenzo       }
1096436c6c9cSStella Laurenzo     }
10974811270bSmax     throw py::type_error("Unsupported integer type");
1098436c6c9cSStella Laurenzo   }
1099436c6c9cSStella Laurenzo 
1100436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1101436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1102436c6c9cSStella Laurenzo   }
1103436c6c9cSStella Laurenzo };
1104436c6c9cSStella Laurenzo 
1105f66cd9e9SStella Laurenzo class PyDenseResourceElementsAttribute
1106f66cd9e9SStella Laurenzo     : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1107f66cd9e9SStella Laurenzo public:
1108f66cd9e9SStella Laurenzo   static constexpr IsAFunctionTy isaFunction =
1109f66cd9e9SStella Laurenzo       mlirAttributeIsADenseResourceElements;
1110f66cd9e9SStella Laurenzo   static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1111f66cd9e9SStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1112f66cd9e9SStella Laurenzo 
1113f66cd9e9SStella Laurenzo   static PyDenseResourceElementsAttribute
1114962bf002SMehdi Amini   getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type,
1115f66cd9e9SStella Laurenzo                 std::optional<size_t> alignment, bool isMutable,
1116f66cd9e9SStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
1117f66cd9e9SStella Laurenzo     if (!mlirTypeIsAShaped(type)) {
1118f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1119f66cd9e9SStella Laurenzo           "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1120f66cd9e9SStella Laurenzo     }
1121f66cd9e9SStella Laurenzo 
1122f66cd9e9SStella Laurenzo     // Do not request any conversions as we must ensure to use caller
1123f66cd9e9SStella Laurenzo     // managed memory.
1124f66cd9e9SStella Laurenzo     int flags = PyBUF_STRIDES;
1125f66cd9e9SStella Laurenzo     std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1126f66cd9e9SStella Laurenzo     if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1127f66cd9e9SStella Laurenzo       throw py::error_already_set();
1128f66cd9e9SStella Laurenzo     }
1129f66cd9e9SStella Laurenzo 
1130f66cd9e9SStella Laurenzo     // This scope releaser will only release if we haven't yet transferred
1131f66cd9e9SStella Laurenzo     // ownership.
1132f66cd9e9SStella Laurenzo     auto freeBuffer = llvm::make_scope_exit([&]() {
1133f66cd9e9SStella Laurenzo       if (view)
1134f66cd9e9SStella Laurenzo         PyBuffer_Release(view.get());
1135f66cd9e9SStella Laurenzo     });
1136f66cd9e9SStella Laurenzo 
1137f66cd9e9SStella Laurenzo     if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1138f66cd9e9SStella Laurenzo       throw std::invalid_argument("Contiguous buffer is required.");
1139f66cd9e9SStella Laurenzo     }
1140f66cd9e9SStella Laurenzo 
1141f66cd9e9SStella Laurenzo     // Infer alignment to be the stride of one element if not explicit.
1142f66cd9e9SStella Laurenzo     size_t inferredAlignment;
1143f66cd9e9SStella Laurenzo     if (alignment)
1144f66cd9e9SStella Laurenzo       inferredAlignment = *alignment;
1145f66cd9e9SStella Laurenzo     else
1146f66cd9e9SStella Laurenzo       inferredAlignment = view->strides[view->ndim - 1];
1147f66cd9e9SStella Laurenzo 
1148f66cd9e9SStella Laurenzo     // The userData is a Py_buffer* that the deleter owns.
1149f66cd9e9SStella Laurenzo     auto deleter = [](void *userData, const void *data, size_t size,
1150f66cd9e9SStella Laurenzo                       size_t align) {
1151f66cd9e9SStella Laurenzo       Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1152f66cd9e9SStella Laurenzo       PyBuffer_Release(ownedView);
1153f66cd9e9SStella Laurenzo       delete ownedView;
1154f66cd9e9SStella Laurenzo     };
1155f66cd9e9SStella Laurenzo 
1156f66cd9e9SStella Laurenzo     size_t rawBufferSize = view->len;
1157f66cd9e9SStella Laurenzo     MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1158f66cd9e9SStella Laurenzo         type, toMlirStringRef(name), view->buf, rawBufferSize,
1159f66cd9e9SStella Laurenzo         inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1160f66cd9e9SStella Laurenzo     if (mlirAttributeIsNull(attr)) {
1161f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1162f66cd9e9SStella Laurenzo           "DenseResourceElementsAttr could not be constructed from the given "
1163f66cd9e9SStella Laurenzo           "buffer. "
1164f66cd9e9SStella Laurenzo           "This may mean that the Python buffer layout does not match that "
1165f66cd9e9SStella Laurenzo           "MLIR expected layout and is a bug.");
1166f66cd9e9SStella Laurenzo     }
1167f66cd9e9SStella Laurenzo     view.release();
1168f66cd9e9SStella Laurenzo     return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1169f66cd9e9SStella Laurenzo   }
1170f66cd9e9SStella Laurenzo 
1171f66cd9e9SStella Laurenzo   static void bindDerived(ClassTy &c) {
1172f66cd9e9SStella Laurenzo     c.def_static("get_from_buffer",
1173f66cd9e9SStella Laurenzo                  PyDenseResourceElementsAttribute::getFromBuffer,
1174f66cd9e9SStella Laurenzo                  py::arg("array"), py::arg("name"), py::arg("type"),
1175f66cd9e9SStella Laurenzo                  py::arg("alignment") = py::none(),
1176f66cd9e9SStella Laurenzo                  py::arg("is_mutable") = false, py::arg("context") = py::none(),
1177f66cd9e9SStella Laurenzo                  kDenseResourceElementsAttrGetFromBufferDocstring);
1178f66cd9e9SStella Laurenzo   }
1179f66cd9e9SStella Laurenzo };
1180f66cd9e9SStella Laurenzo 
1181436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1182436c6c9cSStella Laurenzo public:
1183436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1184436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
1185436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
11869566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
11879566ee28Smax       mlirDictionaryAttrGetTypeID;
1188436c6c9cSStella Laurenzo 
1189436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1190436c6c9cSStella Laurenzo 
11919fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
11929fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
11939fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
11949fb1086bSAdrian Kuegel   }
11959fb1086bSAdrian Kuegel 
1196436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
11979fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
1198436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
1199436c6c9cSStella Laurenzo     c.def_static(
1200436c6c9cSStella Laurenzo         "get",
1201436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
1202436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1203436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
1204436c6c9cSStella Laurenzo           for (auto &it : attributes) {
120502b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
1206436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
1207436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
120802b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1209436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
121002b6fb21SMehdi Amini                 mlirAttr));
1211436c6c9cSStella Laurenzo           }
1212436c6c9cSStella Laurenzo           MlirAttribute attr =
1213436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1214436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1215436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
1216436c6c9cSStella Laurenzo         },
1217ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
1218436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
1219436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1220436c6c9cSStella Laurenzo       MlirAttribute attr =
1221436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1222974c1596SRahul Kayaith       if (mlirAttributeIsNull(attr))
12234811270bSmax         throw py::key_error("attempt to access a non-existent attribute");
1224974c1596SRahul Kayaith       return attr;
1225436c6c9cSStella Laurenzo     });
1226436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1227436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
12284811270bSmax         throw py::index_error("attempt to access out of bounds attribute");
1229436c6c9cSStella Laurenzo       }
1230436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1231436c6c9cSStella Laurenzo       return PyNamedAttribute(
1232436c6c9cSStella Laurenzo           namedAttr.attribute,
1233436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
1234436c6c9cSStella Laurenzo     });
1235436c6c9cSStella Laurenzo   }
1236436c6c9cSStella Laurenzo };
1237436c6c9cSStella Laurenzo 
1238436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
1239436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
1240436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
1241436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1242436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1243436c6c9cSStella Laurenzo public:
1244436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1245436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
1246436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1247436c6c9cSStella Laurenzo 
1248436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
1249436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
12504811270bSmax       throw py::index_error("attempt to access out of bounds element");
1251436c6c9cSStella Laurenzo     }
1252436c6c9cSStella Laurenzo 
1253436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1254436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1255436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1256436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
1257436c6c9cSStella Laurenzo     // from float and double.
1258436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1259436c6c9cSStella Laurenzo     // querying them on each element access.
1260436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
1261436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
1262436c6c9cSStella Laurenzo     }
1263436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
1264436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1265436c6c9cSStella Laurenzo     }
12664811270bSmax     throw py::type_error("Unsupported floating-point type");
1267436c6c9cSStella Laurenzo   }
1268436c6c9cSStella Laurenzo 
1269436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1270436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1271436c6c9cSStella Laurenzo   }
1272436c6c9cSStella Laurenzo };
1273436c6c9cSStella Laurenzo 
1274436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1275436c6c9cSStella Laurenzo public:
1276436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1277436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
1278436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
12799566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
12809566ee28Smax       mlirTypeAttrGetTypeID;
1281436c6c9cSStella Laurenzo 
1282436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1283436c6c9cSStella Laurenzo     c.def_static(
1284436c6c9cSStella Laurenzo         "get",
1285436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1286436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1287436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1288436c6c9cSStella Laurenzo         },
1289436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
1290436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1291436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1292bfb1ba75Smax       return mlirTypeAttrGetValue(self.get());
1293436c6c9cSStella Laurenzo     });
1294436c6c9cSStella Laurenzo   }
1295436c6c9cSStella Laurenzo };
1296436c6c9cSStella Laurenzo 
1297436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1298436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1299436c6c9cSStella Laurenzo public:
1300436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1301436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1302436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
13039566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
13049566ee28Smax       mlirUnitAttrGetTypeID;
1305436c6c9cSStella Laurenzo 
1306436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1307436c6c9cSStella Laurenzo     c.def_static(
1308436c6c9cSStella Laurenzo         "get",
1309436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1310436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1311436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1312436c6c9cSStella Laurenzo         },
1313436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
1314436c6c9cSStella Laurenzo   }
1315436c6c9cSStella Laurenzo };
1316436c6c9cSStella Laurenzo 
1317ac2e2d65SDenys Shabalin /// Strided layout attribute subclass.
1318ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute
1319ac2e2d65SDenys Shabalin     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1320ac2e2d65SDenys Shabalin public:
1321ac2e2d65SDenys Shabalin   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1322ac2e2d65SDenys Shabalin   static constexpr const char *pyClassName = "StridedLayoutAttr";
1323ac2e2d65SDenys Shabalin   using PyConcreteAttribute::PyConcreteAttribute;
13249566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
13259566ee28Smax       mlirStridedLayoutAttrGetTypeID;
1326ac2e2d65SDenys Shabalin 
1327ac2e2d65SDenys Shabalin   static void bindDerived(ClassTy &c) {
1328ac2e2d65SDenys Shabalin     c.def_static(
1329ac2e2d65SDenys Shabalin         "get",
1330ac2e2d65SDenys Shabalin         [](int64_t offset, const std::vector<int64_t> strides,
1331ac2e2d65SDenys Shabalin            DefaultingPyMlirContext ctx) {
1332ac2e2d65SDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1333ac2e2d65SDenys Shabalin               ctx->get(), offset, strides.size(), strides.data());
1334ac2e2d65SDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1335ac2e2d65SDenys Shabalin         },
1336ac2e2d65SDenys Shabalin         py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1337ac2e2d65SDenys Shabalin         "Gets a strided layout attribute.");
1338e3fd612eSDenys Shabalin     c.def_static(
1339e3fd612eSDenys Shabalin         "get_fully_dynamic",
1340e3fd612eSDenys Shabalin         [](int64_t rank, DefaultingPyMlirContext ctx) {
1341e3fd612eSDenys Shabalin           auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1342e3fd612eSDenys Shabalin           std::vector<int64_t> strides(rank);
1343e3fd612eSDenys Shabalin           std::fill(strides.begin(), strides.end(), dynamic);
1344e3fd612eSDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1345e3fd612eSDenys Shabalin               ctx->get(), dynamic, strides.size(), strides.data());
1346e3fd612eSDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1347e3fd612eSDenys Shabalin         },
1348e3fd612eSDenys Shabalin         py::arg("rank"), py::arg("context") = py::none(),
1349e3fd612eSDenys Shabalin         "Gets a strided layout attribute with dynamic offset and strides of a "
1350e3fd612eSDenys Shabalin         "given rank.");
1351ac2e2d65SDenys Shabalin     c.def_property_readonly(
1352ac2e2d65SDenys Shabalin         "offset",
1353ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1354ac2e2d65SDenys Shabalin           return mlirStridedLayoutAttrGetOffset(self);
1355ac2e2d65SDenys Shabalin         },
1356ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1357ac2e2d65SDenys Shabalin     c.def_property_readonly(
1358ac2e2d65SDenys Shabalin         "strides",
1359ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1360ac2e2d65SDenys Shabalin           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1361ac2e2d65SDenys Shabalin           std::vector<int64_t> strides(size);
1362ac2e2d65SDenys Shabalin           for (intptr_t i = 0; i < size; i++) {
1363ac2e2d65SDenys Shabalin             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1364ac2e2d65SDenys Shabalin           }
1365ac2e2d65SDenys Shabalin           return strides;
1366ac2e2d65SDenys Shabalin         },
1367ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1368ac2e2d65SDenys Shabalin   }
1369ac2e2d65SDenys Shabalin };
1370ac2e2d65SDenys Shabalin 
13719566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
13729566ee28Smax   if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
13739566ee28Smax     return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
13749566ee28Smax   if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
13759566ee28Smax     return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
13769566ee28Smax   if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
13779566ee28Smax     return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
13789566ee28Smax   if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
13799566ee28Smax     return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
13809566ee28Smax   if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
13819566ee28Smax     return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
13829566ee28Smax   if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
13839566ee28Smax     return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
13849566ee28Smax   if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
13859566ee28Smax     return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
13869566ee28Smax   std::string msg =
13879566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
13889566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
13899566ee28Smax   throw py::cast_error(msg);
13909566ee28Smax }
13919566ee28Smax 
13929566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
13939566ee28Smax   if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
13949566ee28Smax     return py::cast(PyDenseFPElementsAttribute(pyAttribute));
13959566ee28Smax   if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
13969566ee28Smax     return py::cast(PyDenseIntElementsAttribute(pyAttribute));
13979566ee28Smax   std::string msg =
13989566ee28Smax       std::string(
13999566ee28Smax           "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
14009566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
14019566ee28Smax   throw py::cast_error(msg);
14029566ee28Smax }
14039566ee28Smax 
14049566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
14059566ee28Smax   if (PyBoolAttribute::isaFunction(pyAttribute))
14069566ee28Smax     return py::cast(PyBoolAttribute(pyAttribute));
14079566ee28Smax   if (PyIntegerAttribute::isaFunction(pyAttribute))
14089566ee28Smax     return py::cast(PyIntegerAttribute(pyAttribute));
14099566ee28Smax   std::string msg =
14109566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
14119566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
14129566ee28Smax   throw py::cast_error(msg);
14139566ee28Smax }
14149566ee28Smax 
14154eee9ef9Smax py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
14164eee9ef9Smax   if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
14174eee9ef9Smax     return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
14184eee9ef9Smax   if (PySymbolRefAttribute::isaFunction(pyAttribute))
14194eee9ef9Smax     return py::cast(PySymbolRefAttribute(pyAttribute));
14204eee9ef9Smax   std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
14214eee9ef9Smax                     std::string(py::repr(py::cast(pyAttribute))) + ")";
14224eee9ef9Smax   throw py::cast_error(msg);
14234eee9ef9Smax }
14244eee9ef9Smax 
1425436c6c9cSStella Laurenzo } // namespace
1426436c6c9cSStella Laurenzo 
1427436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
1428436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1429619fd8c2SJeff Niu 
1430619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1431619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1432619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1433619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1434619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1435619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1436619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1437619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1438619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1439619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1440619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1441619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1442619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1443619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
14449566ee28Smax   PyGlobals::get().registerTypeCaster(
14459566ee28Smax       mlirDenseArrayAttrGetTypeID(),
14469566ee28Smax       pybind11::cpp_function(denseArrayAttributeCaster));
1447619fd8c2SJeff Niu 
1448436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1449436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1450436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1451436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
1452436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1453436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
14549566ee28Smax   PyGlobals::get().registerTypeCaster(
14559566ee28Smax       mlirDenseIntOrFPElementsAttrGetTypeID(),
14569566ee28Smax       pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
1457f66cd9e9SStella Laurenzo   PyDenseResourceElementsAttribute::bind(m);
14589566ee28Smax 
1459436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
14604eee9ef9Smax   PySymbolRefAttribute::bind(m);
14614eee9ef9Smax   PyGlobals::get().registerTypeCaster(
14624eee9ef9Smax       mlirSymbolRefAttrGetTypeID(),
14634eee9ef9Smax       pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
14644eee9ef9Smax 
1465436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
14665c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1467436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1468436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1469436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1470436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
14719566ee28Smax   PyGlobals::get().registerTypeCaster(
14729566ee28Smax       mlirIntegerAttrGetTypeID(),
14739566ee28Smax       pybind11::cpp_function(integerOrBoolAttributeCaster));
1474436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1475ac2e2d65SDenys Shabalin 
1476ac2e2d65SDenys Shabalin   PyStridedLayoutAttribute::bind(m);
1477436c6c9cSStella Laurenzo }
1478