xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision f66cd9e9556a53142a26a5c21a72e21f1579217c)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9a1fe1f5fSKazu Hirata #include <optional>
1071a25454SPeter Hawkins #include <string_view>
114811270bSmax #include <utility>
121fc096afSMehdi Amini 
13436c6c9cSStella Laurenzo #include "IRModule.h"
14436c6c9cSStella Laurenzo 
15436c6c9cSStella Laurenzo #include "PybindUtils.h"
16436c6c9cSStella Laurenzo 
1771a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h"
1871a25454SPeter Hawkins 
19436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
20436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
21bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h"
22436c6c9cSStella Laurenzo 
23436c6c9cSStella Laurenzo namespace py = pybind11;
24436c6c9cSStella Laurenzo using namespace mlir;
25436c6c9cSStella Laurenzo using namespace mlir::python;
26436c6c9cSStella Laurenzo 
27436c6c9cSStella Laurenzo using llvm::SmallVector;
28436c6c9cSStella Laurenzo 
295d6d30edSStella Laurenzo //------------------------------------------------------------------------------
305d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
315d6d30edSStella Laurenzo //------------------------------------------------------------------------------
325d6d30edSStella Laurenzo 
// Python-visible docstring for `DenseElementsAttr.get`. Kept out of line
// because it is long; the text is emitted verbatim as the Python `__doc__`.
static constexpr char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
745d6d30edSStella Laurenzo 
// Python-visible docstring for the DenseResourceElementsAttr buffer factory.
// The text is emitted verbatim as the Python `__doc__`.
static constexpr char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
  buffer: The array or buffer to convert.
  name: Name to provide to the resource (may be changed upon collision).
  type: The explicit ShapedType to construct the attribute with.
  context: Explicit context, if not from context manager.

Returns:
  DenseResourceElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
100*f66cd9e9SStella Laurenzo 
101436c6c9cSStella Laurenzo namespace {
102436c6c9cSStella Laurenzo 
103436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
104436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
105436c6c9cSStella Laurenzo }
106436c6c9cSStella Laurenzo 
107436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
108436c6c9cSStella Laurenzo public:
109436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
110436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
111436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1129566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1139566ee28Smax       mlirAffineMapAttrGetTypeID;
114436c6c9cSStella Laurenzo 
115436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
116436c6c9cSStella Laurenzo     c.def_static(
117436c6c9cSStella Laurenzo         "get",
118436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
119436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
120436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
121436c6c9cSStella Laurenzo         },
122436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
123436c6c9cSStella Laurenzo   }
124436c6c9cSStella Laurenzo };
125436c6c9cSStella Laurenzo 
126ed9e52f3SAlex Zinenko template <typename T>
127ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
128ed9e52f3SAlex Zinenko   try {
129ed9e52f3SAlex Zinenko     return object.cast<T>();
130ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
131ed9e52f3SAlex Zinenko     std::string msg =
132ed9e52f3SAlex Zinenko         std::string(
133ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
134ed9e52f3SAlex Zinenko         err.what() + ")";
135ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
136ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
137ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
138ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
139ed9e52f3SAlex Zinenko                       err.what() + ")";
140ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
141ed9e52f3SAlex Zinenko   }
142ed9e52f3SAlex Zinenko }
143ed9e52f3SAlex Zinenko 
144619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
145619fd8c2SJeff Niu /// implementation class.
146619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
147133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
148619fd8c2SJeff Niu public:
149133624acSJeff Niu   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
150619fd8c2SJeff Niu 
151619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
152619fd8c2SJeff Niu   class PyDenseArrayIterator {
153619fd8c2SJeff Niu   public:
1544a1b1196SMehdi Amini     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
155619fd8c2SJeff Niu 
156619fd8c2SJeff Niu     /// Return a copy of the iterator.
157619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
158619fd8c2SJeff Niu 
159619fd8c2SJeff Niu     /// Return the next element.
160619fd8c2SJeff Niu     EltTy dunderNext() {
161619fd8c2SJeff Niu       // Throw if the index has reached the end.
162619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
163619fd8c2SJeff Niu         throw py::stop_iteration();
164619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
165619fd8c2SJeff Niu     }
166619fd8c2SJeff Niu 
167619fd8c2SJeff Niu     /// Bind the iterator class.
168619fd8c2SJeff Niu     static void bind(py::module &m) {
169619fd8c2SJeff Niu       py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
170619fd8c2SJeff Niu                                        py::module_local())
171619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
172619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
173619fd8c2SJeff Niu     }
174619fd8c2SJeff Niu 
175619fd8c2SJeff Niu   private:
176619fd8c2SJeff Niu     /// The referenced dense array attribute.
177619fd8c2SJeff Niu     PyAttribute attr;
178619fd8c2SJeff Niu     /// The next index to read.
179619fd8c2SJeff Niu     int nextIndex = 0;
180619fd8c2SJeff Niu   };
181619fd8c2SJeff Niu 
182619fd8c2SJeff Niu   /// Get the element at the given index.
183619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
184619fd8c2SJeff Niu 
185619fd8c2SJeff Niu   /// Bind the attribute class.
186133624acSJeff Niu   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
187619fd8c2SJeff Niu     // Bind the constructor.
188619fd8c2SJeff Niu     c.def_static(
189619fd8c2SJeff Niu         "get",
190619fd8c2SJeff Niu         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
1918dcb6722SIngo Müller           return getAttribute(values, ctx->getRef());
192619fd8c2SJeff Niu         },
193619fd8c2SJeff Niu         py::arg("values"), py::arg("context") = py::none(),
194619fd8c2SJeff Niu         "Gets a uniqued dense array attribute");
195619fd8c2SJeff Niu     // Bind the array methods.
196133624acSJeff Niu     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
197619fd8c2SJeff Niu       if (i >= mlirDenseArrayGetNumElements(arr))
198619fd8c2SJeff Niu         throw py::index_error("DenseArray index out of range");
199619fd8c2SJeff Niu       return arr.getItem(i);
200619fd8c2SJeff Niu     });
201133624acSJeff Niu     c.def("__len__", [](const DerivedT &arr) {
202619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
203619fd8c2SJeff Niu     });
204133624acSJeff Niu     c.def("__iter__",
205133624acSJeff Niu           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
2064a1b1196SMehdi Amini     c.def("__add__", [](DerivedT &arr, const py::list &extras) {
207619fd8c2SJeff Niu       std::vector<EltTy> values;
208619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
209619fd8c2SJeff Niu       values.reserve(numOldElements + py::len(extras));
210619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
211619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
212619fd8c2SJeff Niu       for (py::handle attr : extras)
213619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
2148dcb6722SIngo Müller       return getAttribute(values, arr.getContext());
215619fd8c2SJeff Niu     });
216619fd8c2SJeff Niu   }
2178dcb6722SIngo Müller 
2188dcb6722SIngo Müller private:
2198dcb6722SIngo Müller   static DerivedT getAttribute(const std::vector<EltTy> &values,
2208dcb6722SIngo Müller                                PyMlirContextRef ctx) {
2218dcb6722SIngo Müller     if constexpr (std::is_same_v<EltTy, bool>) {
2228dcb6722SIngo Müller       std::vector<int> intValues(values.begin(), values.end());
2238dcb6722SIngo Müller       MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
2248dcb6722SIngo Müller                                                   intValues.data());
2258dcb6722SIngo Müller       return DerivedT(ctx, attr);
2268dcb6722SIngo Müller     } else {
2278dcb6722SIngo Müller       MlirAttribute attr =
2288dcb6722SIngo Müller           DerivedT::getAttribute(ctx->get(), values.size(), values.data());
2298dcb6722SIngo Müller       return DerivedT(ctx, attr);
2308dcb6722SIngo Müller     }
2318dcb6722SIngo Müller   }
232619fd8c2SJeff Niu };
233619fd8c2SJeff Niu 
234619fd8c2SJeff Niu /// Instantiate the python dense array classes.
235619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
2368dcb6722SIngo Müller     : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
237619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
238619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
239619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
240619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
241619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
242619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
243619fd8c2SJeff Niu };
244619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
245619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
246619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
247619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
248619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
249619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
250619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
251619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
252619fd8c2SJeff Niu };
253619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
254619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
255619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
256619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
257619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
258619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
259619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
260619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
261619fd8c2SJeff Niu };
262619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
263619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
264619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
265619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
266619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
267619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
268619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
269619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
270619fd8c2SJeff Niu };
271619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
272619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
273619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
274619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
275619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
276619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
277619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
278619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
279619fd8c2SJeff Niu };
280619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
281619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
282619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
283619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
284619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
285619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
286619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
287619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
288619fd8c2SJeff Niu };
289619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
290619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
291619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
292619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
293619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
294619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
295619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
296619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
297619fd8c2SJeff Niu };
298619fd8c2SJeff Niu 
299436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
300436c6c9cSStella Laurenzo public:
301436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
302436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
303436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
3049566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
3059566ee28Smax       mlirArrayAttrGetTypeID;
306436c6c9cSStella Laurenzo 
307436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
308436c6c9cSStella Laurenzo   public:
3091fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
310436c6c9cSStella Laurenzo 
311436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
312436c6c9cSStella Laurenzo 
313974c1596SRahul Kayaith     MlirAttribute dunderNext() {
314bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
315bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
316436c6c9cSStella Laurenzo         throw py::stop_iteration();
317974c1596SRahul Kayaith       return mlirArrayAttrGetElement(attr.get(), nextIndex++);
318436c6c9cSStella Laurenzo     }
319436c6c9cSStella Laurenzo 
320436c6c9cSStella Laurenzo     static void bind(py::module &m) {
321f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
322f05ff4f7SStella Laurenzo                                            py::module_local())
323436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
324436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
325436c6c9cSStella Laurenzo     }
326436c6c9cSStella Laurenzo 
327436c6c9cSStella Laurenzo   private:
328436c6c9cSStella Laurenzo     PyAttribute attr;
329436c6c9cSStella Laurenzo     int nextIndex = 0;
330436c6c9cSStella Laurenzo   };
331436c6c9cSStella Laurenzo 
332974c1596SRahul Kayaith   MlirAttribute getItem(intptr_t i) {
333974c1596SRahul Kayaith     return mlirArrayAttrGetElement(*this, i);
334ed9e52f3SAlex Zinenko   }
335ed9e52f3SAlex Zinenko 
336436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
337436c6c9cSStella Laurenzo     c.def_static(
338436c6c9cSStella Laurenzo         "get",
339436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
340436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
341436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
342436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
343ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
344436c6c9cSStella Laurenzo           }
345436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
346436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
347436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
348436c6c9cSStella Laurenzo         },
349436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
350436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
351436c6c9cSStella Laurenzo     c.def("__getitem__",
352436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
353436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
354436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
355ed9e52f3SAlex Zinenko             return arr.getItem(i);
356436c6c9cSStella Laurenzo           })
357436c6c9cSStella Laurenzo         .def("__len__",
358436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
359436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
360436c6c9cSStella Laurenzo              })
361436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
362436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
363436c6c9cSStella Laurenzo         });
364ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
365ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
366ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
367ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
368ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
369ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
370ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
371ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
372ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
373ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
374ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
375ed9e52f3SAlex Zinenko     });
376436c6c9cSStella Laurenzo   }
377436c6c9cSStella Laurenzo };
378436c6c9cSStella Laurenzo 
379436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
380436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
381436c6c9cSStella Laurenzo public:
382436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
383436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
384436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
3859566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
3869566ee28Smax       mlirFloatAttrGetTypeID;
387436c6c9cSStella Laurenzo 
388436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
389436c6c9cSStella Laurenzo     c.def_static(
390436c6c9cSStella Laurenzo         "get",
391436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
3923ea4c501SRahul Kayaith           PyMlirContext::ErrorCapture errors(loc->getContext());
393436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
3943ea4c501SRahul Kayaith           if (mlirAttributeIsNull(attr))
3953ea4c501SRahul Kayaith             throw MLIRError("Invalid attribute", errors.take());
396436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
397436c6c9cSStella Laurenzo         },
398436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
399436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
400436c6c9cSStella Laurenzo     c.def_static(
401436c6c9cSStella Laurenzo         "get_f32",
402436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
403436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
404436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
405436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
406436c6c9cSStella Laurenzo         },
407436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
408436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
409436c6c9cSStella Laurenzo     c.def_static(
410436c6c9cSStella Laurenzo         "get_f64",
411436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
412436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
413436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
414436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
415436c6c9cSStella Laurenzo         },
416436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
417436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
4182a5d4974SIngo Müller     c.def_property_readonly("value", mlirFloatAttrGetValueDouble,
4192a5d4974SIngo Müller                             "Returns the value of the float attribute");
4202a5d4974SIngo Müller     c.def("__float__", mlirFloatAttrGetValueDouble,
4212a5d4974SIngo Müller           "Converts the value of the float attribute to a Python float");
422436c6c9cSStella Laurenzo   }
423436c6c9cSStella Laurenzo };
424436c6c9cSStella Laurenzo 
425436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
426436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
427436c6c9cSStella Laurenzo public:
428436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
429436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
430436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
431436c6c9cSStella Laurenzo 
432436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
433436c6c9cSStella Laurenzo     c.def_static(
434436c6c9cSStella Laurenzo         "get",
435436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
436436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
437436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
438436c6c9cSStella Laurenzo         },
439436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
440436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
4412a5d4974SIngo Müller     c.def_property_readonly("value", toPyInt,
4422a5d4974SIngo Müller                             "Returns the value of the integer attribute");
4432a5d4974SIngo Müller     c.def("__int__", toPyInt,
4442a5d4974SIngo Müller           "Converts the value of the integer attribute to a Python int");
4452a5d4974SIngo Müller     c.def_property_readonly_static("static_typeid",
4462a5d4974SIngo Müller                                    [](py::object & /*class*/) -> MlirTypeID {
4472a5d4974SIngo Müller                                      return mlirIntegerAttrGetTypeID();
4482a5d4974SIngo Müller                                    });
4492a5d4974SIngo Müller   }
4502a5d4974SIngo Müller 
4512a5d4974SIngo Müller private:
4522a5d4974SIngo Müller   static py::int_ toPyInt(PyIntegerAttribute &self) {
453e9db306dSrkayaith     MlirType type = mlirAttributeGetType(self);
454e9db306dSrkayaith     if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
455436c6c9cSStella Laurenzo       return mlirIntegerAttrGetValueInt(self);
456e9db306dSrkayaith     if (mlirIntegerTypeIsSigned(type))
457e9db306dSrkayaith       return mlirIntegerAttrGetValueSInt(self);
458e9db306dSrkayaith     return mlirIntegerAttrGetValueUInt(self);
459436c6c9cSStella Laurenzo   }
460436c6c9cSStella Laurenzo };
461436c6c9cSStella Laurenzo 
462436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
463436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
464436c6c9cSStella Laurenzo public:
465436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
466436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
467436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
468436c6c9cSStella Laurenzo 
469436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
470436c6c9cSStella Laurenzo     c.def_static(
471436c6c9cSStella Laurenzo         "get",
472436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
473436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
474436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
475436c6c9cSStella Laurenzo         },
476436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
477436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
4782a5d4974SIngo Müller     c.def_property_readonly("value", mlirBoolAttrGetValue,
479436c6c9cSStella Laurenzo                             "Returns the value of the bool attribute");
4802a5d4974SIngo Müller     c.def("__bool__", mlirBoolAttrGetValue,
4812a5d4974SIngo Müller           "Converts the value of the bool attribute to a Python bool");
482436c6c9cSStella Laurenzo   }
483436c6c9cSStella Laurenzo };
484436c6c9cSStella Laurenzo 
4854eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
4864eee9ef9Smax public:
4874eee9ef9Smax   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
4884eee9ef9Smax   static constexpr const char *pyClassName = "SymbolRefAttr";
4894eee9ef9Smax   using PyConcreteAttribute::PyConcreteAttribute;
4904eee9ef9Smax 
4914eee9ef9Smax   static MlirAttribute fromList(const std::vector<std::string> &symbols,
4924eee9ef9Smax                                 PyMlirContext &context) {
4934eee9ef9Smax     if (symbols.empty())
4944eee9ef9Smax       throw std::runtime_error("SymbolRefAttr must be composed of at least "
4954eee9ef9Smax                                "one symbol.");
4964eee9ef9Smax     MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
4974eee9ef9Smax     SmallVector<MlirAttribute, 3> referenceAttrs;
4984eee9ef9Smax     for (size_t i = 1; i < symbols.size(); ++i) {
4994eee9ef9Smax       referenceAttrs.push_back(
5004eee9ef9Smax           mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
5014eee9ef9Smax     }
5024eee9ef9Smax     return mlirSymbolRefAttrGet(context.get(), rootSymbol,
5034eee9ef9Smax                                 referenceAttrs.size(), referenceAttrs.data());
5044eee9ef9Smax   }
5054eee9ef9Smax 
5064eee9ef9Smax   static void bindDerived(ClassTy &c) {
5074eee9ef9Smax     c.def_static(
5084eee9ef9Smax         "get",
5094eee9ef9Smax         [](const std::vector<std::string> &symbols,
5104eee9ef9Smax            DefaultingPyMlirContext context) {
5114eee9ef9Smax           return PySymbolRefAttribute::fromList(symbols, context.resolve());
5124eee9ef9Smax         },
5134eee9ef9Smax         py::arg("symbols"), py::arg("context") = py::none(),
5144eee9ef9Smax         "Gets a uniqued SymbolRef attribute from a list of symbol names");
5154eee9ef9Smax     c.def_property_readonly(
5164eee9ef9Smax         "value",
5174eee9ef9Smax         [](PySymbolRefAttribute &self) {
5184eee9ef9Smax           std::vector<std::string> symbols = {
5194eee9ef9Smax               unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
5204eee9ef9Smax           for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
5214eee9ef9Smax                ++i)
5224eee9ef9Smax             symbols.push_back(
5234eee9ef9Smax                 unwrap(mlirSymbolRefAttrGetRootReference(
5244eee9ef9Smax                            mlirSymbolRefAttrGetNestedReference(self, i)))
5254eee9ef9Smax                     .str());
5264eee9ef9Smax           return symbols;
5274eee9ef9Smax         },
5284eee9ef9Smax         "Returns the value of the SymbolRef attribute as a list[str]");
5294eee9ef9Smax   }
5304eee9ef9Smax };
5314eee9ef9Smax 
532436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
533436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
534436c6c9cSStella Laurenzo public:
535436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
536436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
537436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
538436c6c9cSStella Laurenzo 
539436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
540436c6c9cSStella Laurenzo     c.def_static(
541436c6c9cSStella Laurenzo         "get",
542436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
543436c6c9cSStella Laurenzo           MlirAttribute attr =
544436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
545436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
546436c6c9cSStella Laurenzo         },
547436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
548436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
549436c6c9cSStella Laurenzo     c.def_property_readonly(
550436c6c9cSStella Laurenzo         "value",
551436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
552436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
553436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
554436c6c9cSStella Laurenzo         },
555436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
556436c6c9cSStella Laurenzo   }
557436c6c9cSStella Laurenzo };
558436c6c9cSStella Laurenzo 
5595c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
5605c3861b2SYun Long public:
5615c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
5625c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
5635c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
5649566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
5659566ee28Smax       mlirOpaqueAttrGetTypeID;
5665c3861b2SYun Long 
5675c3861b2SYun Long   static void bindDerived(ClassTy &c) {
5685c3861b2SYun Long     c.def_static(
5695c3861b2SYun Long         "get",
5705c3861b2SYun Long         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
5715c3861b2SYun Long            DefaultingPyMlirContext context) {
5725c3861b2SYun Long           const py::buffer_info bufferInfo = buffer.request();
5735c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
5745c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
5755c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
5765c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
5775c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
5785c3861b2SYun Long         },
5795c3861b2SYun Long         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
5805c3861b2SYun Long         py::arg("context") = py::none(), "Gets an Opaque attribute.");
5815c3861b2SYun Long     c.def_property_readonly(
5825c3861b2SYun Long         "dialect_namespace",
5835c3861b2SYun Long         [](PyOpaqueAttribute &self) {
5845c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
5855c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
5865c3861b2SYun Long         },
5875c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
5885c3861b2SYun Long     c.def_property_readonly(
5895c3861b2SYun Long         "data",
5905c3861b2SYun Long         [](PyOpaqueAttribute &self) {
5915c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
59262bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
5935c3861b2SYun Long         },
59462bf6c2eSChris Jones         "Returns the data for the Opaqued attributes as `bytes`");
5955c3861b2SYun Long   }
5965c3861b2SYun Long };
5975c3861b2SYun Long 
598436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
599436c6c9cSStella Laurenzo public:
600436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
601436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
602436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
6039566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
6049566ee28Smax       mlirStringAttrGetTypeID;
605436c6c9cSStella Laurenzo 
606436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
607436c6c9cSStella Laurenzo     c.def_static(
608436c6c9cSStella Laurenzo         "get",
609436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
610436c6c9cSStella Laurenzo           MlirAttribute attr =
611436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
612436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
613436c6c9cSStella Laurenzo         },
614436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
615436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
616436c6c9cSStella Laurenzo     c.def_static(
617436c6c9cSStella Laurenzo         "get_typed",
618436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
619436c6c9cSStella Laurenzo           MlirAttribute attr =
620436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
621436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
622436c6c9cSStella Laurenzo         },
623a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
624436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
6259f533548SIngo Müller     c.def_property_readonly(
6269f533548SIngo Müller         "value",
6279f533548SIngo Müller         [](PyStringAttribute &self) {
6289f533548SIngo Müller           MlirStringRef stringRef = mlirStringAttrGetValue(self);
6299f533548SIngo Müller           return py::str(stringRef.data, stringRef.length);
6309f533548SIngo Müller         },
631436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
63262bf6c2eSChris Jones     c.def_property_readonly(
63362bf6c2eSChris Jones         "value_bytes",
63462bf6c2eSChris Jones         [](PyStringAttribute &self) {
63562bf6c2eSChris Jones           MlirStringRef stringRef = mlirStringAttrGetValue(self);
63662bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
63762bf6c2eSChris Jones         },
63862bf6c2eSChris Jones         "Returns the value of the string attribute as `bytes`");
639436c6c9cSStella Laurenzo   }
640436c6c9cSStella Laurenzo };
641436c6c9cSStella Laurenzo 
642436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
643436c6c9cSStella Laurenzo class PyDenseElementsAttribute
644436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
645436c6c9cSStella Laurenzo public:
646436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
647436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
648436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
649436c6c9cSStella Laurenzo 
650436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
6510a81ace0SKazu Hirata   getFromBuffer(py::buffer array, bool signless,
6520a81ace0SKazu Hirata                 std::optional<PyType> explicitType,
6530a81ace0SKazu Hirata                 std::optional<std::vector<int64_t>> explicitShape,
654436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
655436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
65671a25454SPeter Hawkins     int flags = PyBUF_ND;
65771a25454SPeter Hawkins     if (!explicitType) {
65871a25454SPeter Hawkins       flags |= PyBUF_FORMAT;
65971a25454SPeter Hawkins     }
66071a25454SPeter Hawkins     Py_buffer view;
66171a25454SPeter Hawkins     if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
662436c6c9cSStella Laurenzo       throw py::error_already_set();
663436c6c9cSStella Laurenzo     }
66471a25454SPeter Hawkins     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
6655d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
6665d6d30edSStella Laurenzo     if (explicitShape) {
6675d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
6685d6d30edSStella Laurenzo     } else {
66971a25454SPeter Hawkins       shape.append(view.shape, view.shape + view.ndim);
6705d6d30edSStella Laurenzo     }
671436c6c9cSStella Laurenzo 
6725d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
673436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
6745d6d30edSStella Laurenzo 
6755d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
6765d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
6775d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
6785d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
6795d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
6800a81ace0SKazu Hirata     std::optional<MlirType> bulkLoadElementType;
6815d6d30edSStella Laurenzo     if (explicitType) {
6825d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
68371a25454SPeter Hawkins     } else {
68471a25454SPeter Hawkins       std::string_view format(view.format);
68571a25454SPeter Hawkins       if (format == "f") {
686436c6c9cSStella Laurenzo         // f32
68771a25454SPeter Hawkins         assert(view.itemsize == 4 && "mismatched array itemsize");
6885d6d30edSStella Laurenzo         bulkLoadElementType = mlirF32TypeGet(context);
68971a25454SPeter Hawkins       } else if (format == "d") {
690436c6c9cSStella Laurenzo         // f64
69171a25454SPeter Hawkins         assert(view.itemsize == 8 && "mismatched array itemsize");
6925d6d30edSStella Laurenzo         bulkLoadElementType = mlirF64TypeGet(context);
69371a25454SPeter Hawkins       } else if (format == "e") {
6945d6d30edSStella Laurenzo         // f16
69571a25454SPeter Hawkins         assert(view.itemsize == 2 && "mismatched array itemsize");
6965d6d30edSStella Laurenzo         bulkLoadElementType = mlirF16TypeGet(context);
69771a25454SPeter Hawkins       } else if (isSignedIntegerFormat(format)) {
69871a25454SPeter Hawkins         if (view.itemsize == 4) {
699436c6c9cSStella Laurenzo           // i32
70071a25454SPeter Hawkins           bulkLoadElementType = signless
70171a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 32)
702436c6c9cSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 32);
70371a25454SPeter Hawkins         } else if (view.itemsize == 8) {
704436c6c9cSStella Laurenzo           // i64
70571a25454SPeter Hawkins           bulkLoadElementType = signless
70671a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 64)
707436c6c9cSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 64);
70871a25454SPeter Hawkins         } else if (view.itemsize == 1) {
7095d6d30edSStella Laurenzo           // i8
7105d6d30edSStella Laurenzo           bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
7115d6d30edSStella Laurenzo                                          : mlirIntegerTypeSignedGet(context, 8);
71271a25454SPeter Hawkins         } else if (view.itemsize == 2) {
7135d6d30edSStella Laurenzo           // i16
71471a25454SPeter Hawkins           bulkLoadElementType = signless
71571a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 16)
7165d6d30edSStella Laurenzo                                     : mlirIntegerTypeSignedGet(context, 16);
717436c6c9cSStella Laurenzo         }
71871a25454SPeter Hawkins       } else if (isUnsignedIntegerFormat(format)) {
71971a25454SPeter Hawkins         if (view.itemsize == 4) {
720436c6c9cSStella Laurenzo           // unsigned i32
7215d6d30edSStella Laurenzo           bulkLoadElementType = signless
722436c6c9cSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 32)
723436c6c9cSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 32);
72471a25454SPeter Hawkins         } else if (view.itemsize == 8) {
725436c6c9cSStella Laurenzo           // unsigned i64
7265d6d30edSStella Laurenzo           bulkLoadElementType = signless
727436c6c9cSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 64)
728436c6c9cSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 64);
72971a25454SPeter Hawkins         } else if (view.itemsize == 1) {
7305d6d30edSStella Laurenzo           // i8
73171a25454SPeter Hawkins           bulkLoadElementType = signless
73271a25454SPeter Hawkins                                     ? mlirIntegerTypeGet(context, 8)
7335d6d30edSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 8);
73471a25454SPeter Hawkins         } else if (view.itemsize == 2) {
7355d6d30edSStella Laurenzo           // i16
7365d6d30edSStella Laurenzo           bulkLoadElementType = signless
7375d6d30edSStella Laurenzo                                     ? mlirIntegerTypeGet(context, 16)
7385d6d30edSStella Laurenzo                                     : mlirIntegerTypeUnsignedGet(context, 16);
739436c6c9cSStella Laurenzo         }
740436c6c9cSStella Laurenzo       }
74171a25454SPeter Hawkins       if (!bulkLoadElementType) {
74271a25454SPeter Hawkins         throw std::invalid_argument(
74371a25454SPeter Hawkins             std::string("unimplemented array format conversion from format: ") +
74471a25454SPeter Hawkins             std::string(format));
74571a25454SPeter Hawkins       }
74671a25454SPeter Hawkins     }
74771a25454SPeter Hawkins 
74899dee31eSAdam Paszke     MlirType shapedType;
74999dee31eSAdam Paszke     if (mlirTypeIsAShaped(*bulkLoadElementType)) {
75099dee31eSAdam Paszke       if (explicitShape) {
75199dee31eSAdam Paszke         throw std::invalid_argument("Shape can only be specified explicitly "
75299dee31eSAdam Paszke                                     "when the type is not a shaped type.");
75399dee31eSAdam Paszke       }
75499dee31eSAdam Paszke       shapedType = *bulkLoadElementType;
75599dee31eSAdam Paszke     } else {
75671a25454SPeter Hawkins       shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
75771a25454SPeter Hawkins                                            *bulkLoadElementType, encodingAttr);
75899dee31eSAdam Paszke     }
75971a25454SPeter Hawkins     size_t rawBufferSize = view.len;
76071a25454SPeter Hawkins     MlirAttribute attr =
76171a25454SPeter Hawkins         mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
7625d6d30edSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
7635d6d30edSStella Laurenzo       throw std::invalid_argument(
7645d6d30edSStella Laurenzo           "DenseElementsAttr could not be constructed from the given buffer. "
7655d6d30edSStella Laurenzo           "This may mean that the Python buffer layout does not match that "
7665d6d30edSStella Laurenzo           "MLIR expected layout and is a bug.");
7675d6d30edSStella Laurenzo     }
7685d6d30edSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
7695d6d30edSStella Laurenzo   }
770436c6c9cSStella Laurenzo 
7711fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
772436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
773436c6c9cSStella Laurenzo     auto contextWrapper =
774436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
775436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
776436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
777436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
778436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
7794811270bSmax       throw py::value_error(message);
780436c6c9cSStella Laurenzo     }
781436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
782436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
783436c6c9cSStella Laurenzo       std::string message =
784436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
785436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
7864811270bSmax       throw py::value_error(message);
787436c6c9cSStella Laurenzo     }
788436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
789436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
790436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
791436c6c9cSStella Laurenzo       std::string message =
792436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
793436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
794436c6c9cSStella Laurenzo       message.append(", element=");
795436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
7964811270bSmax       throw py::value_error(message);
797436c6c9cSStella Laurenzo     }
798436c6c9cSStella Laurenzo 
799436c6c9cSStella Laurenzo     MlirAttribute elements =
800436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
801436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
802436c6c9cSStella Laurenzo   }
803436c6c9cSStella Laurenzo 
804436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
805436c6c9cSStella Laurenzo 
806436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
807436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
808436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
8095d6d30edSStella Laurenzo     std::string format;
810436c6c9cSStella Laurenzo 
811436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
812436c6c9cSStella Laurenzo       // f32
8135d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
81402b6fb21SMehdi Amini     }
81502b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
816436c6c9cSStella Laurenzo       // f64
8175d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
818bb56c2b3SMehdi Amini     }
819bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
8205d6d30edSStella Laurenzo       // f16
8215d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
822bb56c2b3SMehdi Amini     }
823ef1b735dSmax     if (mlirTypeIsAIndex(elementType)) {
824ef1b735dSmax       // Same as IndexType::kInternalStorageBitWidth
825ef1b735dSmax       return bufferInfo<int64_t>(shapedType);
826ef1b735dSmax     }
827bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
828436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
829436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
830436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
831436c6c9cSStella Laurenzo         // i32
8325d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
833e5639b3fSMehdi Amini       }
834e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
835436c6c9cSStella Laurenzo         // unsigned i32
8365d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
837436c6c9cSStella Laurenzo       }
838436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
839436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
840436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
841436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
842436c6c9cSStella Laurenzo         // i64
8435d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
844e5639b3fSMehdi Amini       }
845e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
846436c6c9cSStella Laurenzo         // unsigned i64
8475d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
8485d6d30edSStella Laurenzo       }
8495d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
8505d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
8515d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
8525d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
8535d6d30edSStella Laurenzo         // i8
8545d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
855e5639b3fSMehdi Amini       }
856e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
8575d6d30edSStella Laurenzo         // unsigned i8
8585d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
8595d6d30edSStella Laurenzo       }
8605d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
8615d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
8625d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
8635d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
8645d6d30edSStella Laurenzo         // i16
8655d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
866e5639b3fSMehdi Amini       }
867e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
8685d6d30edSStella Laurenzo         // unsigned i16
8695d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
870436c6c9cSStella Laurenzo       }
871436c6c9cSStella Laurenzo     }
872436c6c9cSStella Laurenzo 
873c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
8745d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
875c5f445d1SStella Laurenzo     throw std::invalid_argument(
876c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
877436c6c9cSStella Laurenzo   }
878436c6c9cSStella Laurenzo 
879436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
880436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
881436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
882436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
8835d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
884436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
8855d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
886436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
887436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
888436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
889436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
890436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
891436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
892436c6c9cSStella Laurenzo                                })
89391259963SAdam Paszke         .def("get_splat_value",
894974c1596SRahul Kayaith              [](PyDenseElementsAttribute &self) {
895974c1596SRahul Kayaith                if (!mlirDenseElementsAttrIsSplat(self))
8964811270bSmax                  throw py::value_error(
89791259963SAdam Paszke                      "get_splat_value called on a non-splat attribute");
898974c1596SRahul Kayaith                return mlirDenseElementsAttrGetSplatValue(self);
89991259963SAdam Paszke              })
900436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
901436c6c9cSStella Laurenzo   }
902436c6c9cSStella Laurenzo 
903436c6c9cSStella Laurenzo private:
90471a25454SPeter Hawkins   static bool isUnsignedIntegerFormat(std::string_view format) {
905436c6c9cSStella Laurenzo     if (format.empty())
906436c6c9cSStella Laurenzo       return false;
907436c6c9cSStella Laurenzo     char code = format[0];
908436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
909436c6c9cSStella Laurenzo            code == 'Q';
910436c6c9cSStella Laurenzo   }
911436c6c9cSStella Laurenzo 
91271a25454SPeter Hawkins   static bool isSignedIntegerFormat(std::string_view format) {
913436c6c9cSStella Laurenzo     if (format.empty())
914436c6c9cSStella Laurenzo       return false;
915436c6c9cSStella Laurenzo     char code = format[0];
916436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
917436c6c9cSStella Laurenzo            code == 'q';
918436c6c9cSStella Laurenzo   }
919436c6c9cSStella Laurenzo 
920436c6c9cSStella Laurenzo   template <typename Type>
921436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
9225d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
923436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
924436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
925436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
926436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
927436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
928436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
929436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
930436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
931436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
932436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
933436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
934f0e847d0SRahul Kayaith     if (mlirDenseElementsAttrIsSplat(*this)) {
935f0e847d0SRahul Kayaith       // Splats are special, only the single value is stored.
936f0e847d0SRahul Kayaith       strides.assign(rank, 0);
937f0e847d0SRahul Kayaith     } else {
938436c6c9cSStella Laurenzo       for (intptr_t i = 1; i < rank; ++i) {
939f0e847d0SRahul Kayaith         intptr_t strideFactor = 1;
940f0e847d0SRahul Kayaith         for (intptr_t j = i; j < rank; ++j)
941436c6c9cSStella Laurenzo           strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
942436c6c9cSStella Laurenzo         strides.push_back(sizeof(Type) * strideFactor);
943436c6c9cSStella Laurenzo       }
944436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type));
945f0e847d0SRahul Kayaith     }
9465d6d30edSStella Laurenzo     std::string format;
9475d6d30edSStella Laurenzo     if (explicitFormat) {
9485d6d30edSStella Laurenzo       format = explicitFormat;
9495d6d30edSStella Laurenzo     } else {
9505d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
9515d6d30edSStella Laurenzo     }
9525d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
9535d6d30edSStella Laurenzo                            /*readonly=*/true);
954436c6c9cSStella Laurenzo   }
955436c6c9cSStella Laurenzo }; // namespace
956436c6c9cSStella Laurenzo 
957436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
958436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
959436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
960436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
961436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
962436c6c9cSStella Laurenzo public:
963436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
964436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
965436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
966436c6c9cSStella Laurenzo 
967436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
968436c6c9cSStella Laurenzo   /// out of range.
969436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
970436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
9714811270bSmax       throw py::index_error("attempt to access out of bounds element");
972436c6c9cSStella Laurenzo     }
973436c6c9cSStella Laurenzo 
974436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
975436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
976436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
977436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
978436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
979436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
980436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
981436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
982436c6c9cSStella Laurenzo     // querying them on each element access.
983436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
984436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
985436c6c9cSStella Laurenzo     if (isUnsigned) {
986436c6c9cSStella Laurenzo       if (width == 1) {
987436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
988436c6c9cSStella Laurenzo       }
989308d8b8cSRahul Kayaith       if (width == 8) {
990308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
991308d8b8cSRahul Kayaith       }
992308d8b8cSRahul Kayaith       if (width == 16) {
993308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
994308d8b8cSRahul Kayaith       }
995436c6c9cSStella Laurenzo       if (width == 32) {
996436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
997436c6c9cSStella Laurenzo       }
998436c6c9cSStella Laurenzo       if (width == 64) {
999436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
1000436c6c9cSStella Laurenzo       }
1001436c6c9cSStella Laurenzo     } else {
1002436c6c9cSStella Laurenzo       if (width == 1) {
1003436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
1004436c6c9cSStella Laurenzo       }
1005308d8b8cSRahul Kayaith       if (width == 8) {
1006308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
1007308d8b8cSRahul Kayaith       }
1008308d8b8cSRahul Kayaith       if (width == 16) {
1009308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
1010308d8b8cSRahul Kayaith       }
1011436c6c9cSStella Laurenzo       if (width == 32) {
1012436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
1013436c6c9cSStella Laurenzo       }
1014436c6c9cSStella Laurenzo       if (width == 64) {
1015436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
1016436c6c9cSStella Laurenzo       }
1017436c6c9cSStella Laurenzo     }
10184811270bSmax     throw py::type_error("Unsupported integer type");
1019436c6c9cSStella Laurenzo   }
1020436c6c9cSStella Laurenzo 
1021436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1022436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1023436c6c9cSStella Laurenzo   }
1024436c6c9cSStella Laurenzo };
1025436c6c9cSStella Laurenzo 
1026*f66cd9e9SStella Laurenzo class PyDenseResourceElementsAttribute
1027*f66cd9e9SStella Laurenzo     : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1028*f66cd9e9SStella Laurenzo public:
1029*f66cd9e9SStella Laurenzo   static constexpr IsAFunctionTy isaFunction =
1030*f66cd9e9SStella Laurenzo       mlirAttributeIsADenseResourceElements;
1031*f66cd9e9SStella Laurenzo   static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1032*f66cd9e9SStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1033*f66cd9e9SStella Laurenzo 
1034*f66cd9e9SStella Laurenzo   static PyDenseResourceElementsAttribute
1035*f66cd9e9SStella Laurenzo   getFromBuffer(py::buffer buffer, std::string name, PyType type,
1036*f66cd9e9SStella Laurenzo                 std::optional<size_t> alignment, bool isMutable,
1037*f66cd9e9SStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
1038*f66cd9e9SStella Laurenzo     if (!mlirTypeIsAShaped(type)) {
1039*f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1040*f66cd9e9SStella Laurenzo           "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1041*f66cd9e9SStella Laurenzo     }
1042*f66cd9e9SStella Laurenzo 
1043*f66cd9e9SStella Laurenzo     // Do not request any conversions as we must ensure to use caller
1044*f66cd9e9SStella Laurenzo     // managed memory.
1045*f66cd9e9SStella Laurenzo     int flags = PyBUF_STRIDES;
1046*f66cd9e9SStella Laurenzo     std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1047*f66cd9e9SStella Laurenzo     if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1048*f66cd9e9SStella Laurenzo       throw py::error_already_set();
1049*f66cd9e9SStella Laurenzo     }
1050*f66cd9e9SStella Laurenzo 
1051*f66cd9e9SStella Laurenzo     // This scope releaser will only release if we haven't yet transferred
1052*f66cd9e9SStella Laurenzo     // ownership.
1053*f66cd9e9SStella Laurenzo     auto freeBuffer = llvm::make_scope_exit([&]() {
1054*f66cd9e9SStella Laurenzo       if (view)
1055*f66cd9e9SStella Laurenzo         PyBuffer_Release(view.get());
1056*f66cd9e9SStella Laurenzo     });
1057*f66cd9e9SStella Laurenzo 
1058*f66cd9e9SStella Laurenzo     if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1059*f66cd9e9SStella Laurenzo       throw std::invalid_argument("Contiguous buffer is required.");
1060*f66cd9e9SStella Laurenzo     }
1061*f66cd9e9SStella Laurenzo 
1062*f66cd9e9SStella Laurenzo     // Infer alignment to be the stride of one element if not explicit.
1063*f66cd9e9SStella Laurenzo     size_t inferredAlignment;
1064*f66cd9e9SStella Laurenzo     if (alignment)
1065*f66cd9e9SStella Laurenzo       inferredAlignment = *alignment;
1066*f66cd9e9SStella Laurenzo     else
1067*f66cd9e9SStella Laurenzo       inferredAlignment = view->strides[view->ndim - 1];
1068*f66cd9e9SStella Laurenzo 
1069*f66cd9e9SStella Laurenzo     // The userData is a Py_buffer* that the deleter owns.
1070*f66cd9e9SStella Laurenzo     auto deleter = [](void *userData, const void *data, size_t size,
1071*f66cd9e9SStella Laurenzo                       size_t align) {
1072*f66cd9e9SStella Laurenzo       Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1073*f66cd9e9SStella Laurenzo       PyBuffer_Release(ownedView);
1074*f66cd9e9SStella Laurenzo       delete ownedView;
1075*f66cd9e9SStella Laurenzo     };
1076*f66cd9e9SStella Laurenzo 
1077*f66cd9e9SStella Laurenzo     size_t rawBufferSize = view->len;
1078*f66cd9e9SStella Laurenzo     MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1079*f66cd9e9SStella Laurenzo         type, toMlirStringRef(name), view->buf, rawBufferSize,
1080*f66cd9e9SStella Laurenzo         inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1081*f66cd9e9SStella Laurenzo     if (mlirAttributeIsNull(attr)) {
1082*f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1083*f66cd9e9SStella Laurenzo           "DenseResourceElementsAttr could not be constructed from the given "
1084*f66cd9e9SStella Laurenzo           "buffer. "
1085*f66cd9e9SStella Laurenzo           "This may mean that the Python buffer layout does not match that "
1086*f66cd9e9SStella Laurenzo           "MLIR expected layout and is a bug.");
1087*f66cd9e9SStella Laurenzo     }
1088*f66cd9e9SStella Laurenzo     view.release();
1089*f66cd9e9SStella Laurenzo     return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1090*f66cd9e9SStella Laurenzo   }
1091*f66cd9e9SStella Laurenzo 
1092*f66cd9e9SStella Laurenzo   static void bindDerived(ClassTy &c) {
1093*f66cd9e9SStella Laurenzo     c.def_static("get_from_buffer",
1094*f66cd9e9SStella Laurenzo                  PyDenseResourceElementsAttribute::getFromBuffer,
1095*f66cd9e9SStella Laurenzo                  py::arg("array"), py::arg("name"), py::arg("type"),
1096*f66cd9e9SStella Laurenzo                  py::arg("alignment") = py::none(),
1097*f66cd9e9SStella Laurenzo                  py::arg("is_mutable") = false, py::arg("context") = py::none(),
1098*f66cd9e9SStella Laurenzo                  kDenseResourceElementsAttrGetFromBufferDocstring);
1099*f66cd9e9SStella Laurenzo   }
1100*f66cd9e9SStella Laurenzo };
1101*f66cd9e9SStella Laurenzo 
1102436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1103436c6c9cSStella Laurenzo public:
1104436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1105436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
1106436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
11079566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
11089566ee28Smax       mlirDictionaryAttrGetTypeID;
1109436c6c9cSStella Laurenzo 
1110436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1111436c6c9cSStella Laurenzo 
11129fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
11139fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
11149fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
11159fb1086bSAdrian Kuegel   }
11169fb1086bSAdrian Kuegel 
1117436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
11189fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
1119436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
1120436c6c9cSStella Laurenzo     c.def_static(
1121436c6c9cSStella Laurenzo         "get",
1122436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
1123436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1124436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
1125436c6c9cSStella Laurenzo           for (auto &it : attributes) {
112602b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
1127436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
1128436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
112902b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1130436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
113102b6fb21SMehdi Amini                 mlirAttr));
1132436c6c9cSStella Laurenzo           }
1133436c6c9cSStella Laurenzo           MlirAttribute attr =
1134436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1135436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1136436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
1137436c6c9cSStella Laurenzo         },
1138ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
1139436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
1140436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1141436c6c9cSStella Laurenzo       MlirAttribute attr =
1142436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1143974c1596SRahul Kayaith       if (mlirAttributeIsNull(attr))
11444811270bSmax         throw py::key_error("attempt to access a non-existent attribute");
1145974c1596SRahul Kayaith       return attr;
1146436c6c9cSStella Laurenzo     });
1147436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1148436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
11494811270bSmax         throw py::index_error("attempt to access out of bounds attribute");
1150436c6c9cSStella Laurenzo       }
1151436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1152436c6c9cSStella Laurenzo       return PyNamedAttribute(
1153436c6c9cSStella Laurenzo           namedAttr.attribute,
1154436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
1155436c6c9cSStella Laurenzo     });
1156436c6c9cSStella Laurenzo   }
1157436c6c9cSStella Laurenzo };
1158436c6c9cSStella Laurenzo 
1159436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
1160436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
1161436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
1162436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1163436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1164436c6c9cSStella Laurenzo public:
1165436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1166436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
1167436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1168436c6c9cSStella Laurenzo 
1169436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
1170436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
11714811270bSmax       throw py::index_error("attempt to access out of bounds element");
1172436c6c9cSStella Laurenzo     }
1173436c6c9cSStella Laurenzo 
1174436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1175436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1176436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1177436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
1178436c6c9cSStella Laurenzo     // from float and double.
1179436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1180436c6c9cSStella Laurenzo     // querying them on each element access.
1181436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
1182436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
1183436c6c9cSStella Laurenzo     }
1184436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
1185436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1186436c6c9cSStella Laurenzo     }
11874811270bSmax     throw py::type_error("Unsupported floating-point type");
1188436c6c9cSStella Laurenzo   }
1189436c6c9cSStella Laurenzo 
1190436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1191436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1192436c6c9cSStella Laurenzo   }
1193436c6c9cSStella Laurenzo };
1194436c6c9cSStella Laurenzo 
1195436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1196436c6c9cSStella Laurenzo public:
1197436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1198436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
1199436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
12009566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
12019566ee28Smax       mlirTypeAttrGetTypeID;
1202436c6c9cSStella Laurenzo 
1203436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1204436c6c9cSStella Laurenzo     c.def_static(
1205436c6c9cSStella Laurenzo         "get",
1206436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1207436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1208436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1209436c6c9cSStella Laurenzo         },
1210436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
1211436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1212436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1213bfb1ba75Smax       return mlirTypeAttrGetValue(self.get());
1214436c6c9cSStella Laurenzo     });
1215436c6c9cSStella Laurenzo   }
1216436c6c9cSStella Laurenzo };
1217436c6c9cSStella Laurenzo 
1218436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1219436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1220436c6c9cSStella Laurenzo public:
1221436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1222436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1223436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
12249566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
12259566ee28Smax       mlirUnitAttrGetTypeID;
1226436c6c9cSStella Laurenzo 
1227436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1228436c6c9cSStella Laurenzo     c.def_static(
1229436c6c9cSStella Laurenzo         "get",
1230436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1231436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1232436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1233436c6c9cSStella Laurenzo         },
1234436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
1235436c6c9cSStella Laurenzo   }
1236436c6c9cSStella Laurenzo };
1237436c6c9cSStella Laurenzo 
1238ac2e2d65SDenys Shabalin /// Strided layout attribute subclass.
1239ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute
1240ac2e2d65SDenys Shabalin     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1241ac2e2d65SDenys Shabalin public:
1242ac2e2d65SDenys Shabalin   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1243ac2e2d65SDenys Shabalin   static constexpr const char *pyClassName = "StridedLayoutAttr";
1244ac2e2d65SDenys Shabalin   using PyConcreteAttribute::PyConcreteAttribute;
12459566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
12469566ee28Smax       mlirStridedLayoutAttrGetTypeID;
1247ac2e2d65SDenys Shabalin 
1248ac2e2d65SDenys Shabalin   static void bindDerived(ClassTy &c) {
1249ac2e2d65SDenys Shabalin     c.def_static(
1250ac2e2d65SDenys Shabalin         "get",
1251ac2e2d65SDenys Shabalin         [](int64_t offset, const std::vector<int64_t> strides,
1252ac2e2d65SDenys Shabalin            DefaultingPyMlirContext ctx) {
1253ac2e2d65SDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1254ac2e2d65SDenys Shabalin               ctx->get(), offset, strides.size(), strides.data());
1255ac2e2d65SDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1256ac2e2d65SDenys Shabalin         },
1257ac2e2d65SDenys Shabalin         py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1258ac2e2d65SDenys Shabalin         "Gets a strided layout attribute.");
1259e3fd612eSDenys Shabalin     c.def_static(
1260e3fd612eSDenys Shabalin         "get_fully_dynamic",
1261e3fd612eSDenys Shabalin         [](int64_t rank, DefaultingPyMlirContext ctx) {
1262e3fd612eSDenys Shabalin           auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1263e3fd612eSDenys Shabalin           std::vector<int64_t> strides(rank);
1264e3fd612eSDenys Shabalin           std::fill(strides.begin(), strides.end(), dynamic);
1265e3fd612eSDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1266e3fd612eSDenys Shabalin               ctx->get(), dynamic, strides.size(), strides.data());
1267e3fd612eSDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1268e3fd612eSDenys Shabalin         },
1269e3fd612eSDenys Shabalin         py::arg("rank"), py::arg("context") = py::none(),
1270e3fd612eSDenys Shabalin         "Gets a strided layout attribute with dynamic offset and strides of a "
1271e3fd612eSDenys Shabalin         "given rank.");
1272ac2e2d65SDenys Shabalin     c.def_property_readonly(
1273ac2e2d65SDenys Shabalin         "offset",
1274ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1275ac2e2d65SDenys Shabalin           return mlirStridedLayoutAttrGetOffset(self);
1276ac2e2d65SDenys Shabalin         },
1277ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1278ac2e2d65SDenys Shabalin     c.def_property_readonly(
1279ac2e2d65SDenys Shabalin         "strides",
1280ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1281ac2e2d65SDenys Shabalin           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1282ac2e2d65SDenys Shabalin           std::vector<int64_t> strides(size);
1283ac2e2d65SDenys Shabalin           for (intptr_t i = 0; i < size; i++) {
1284ac2e2d65SDenys Shabalin             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1285ac2e2d65SDenys Shabalin           }
1286ac2e2d65SDenys Shabalin           return strides;
1287ac2e2d65SDenys Shabalin         },
1288ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1289ac2e2d65SDenys Shabalin   }
1290ac2e2d65SDenys Shabalin };
1291ac2e2d65SDenys Shabalin 
12929566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
12939566ee28Smax   if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
12949566ee28Smax     return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
12959566ee28Smax   if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
12969566ee28Smax     return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
12979566ee28Smax   if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
12989566ee28Smax     return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
12999566ee28Smax   if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
13009566ee28Smax     return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
13019566ee28Smax   if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
13029566ee28Smax     return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
13039566ee28Smax   if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
13049566ee28Smax     return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
13059566ee28Smax   if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
13069566ee28Smax     return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
13079566ee28Smax   std::string msg =
13089566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
13099566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
13109566ee28Smax   throw py::cast_error(msg);
13119566ee28Smax }
13129566ee28Smax 
13139566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
13149566ee28Smax   if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
13159566ee28Smax     return py::cast(PyDenseFPElementsAttribute(pyAttribute));
13169566ee28Smax   if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
13179566ee28Smax     return py::cast(PyDenseIntElementsAttribute(pyAttribute));
13189566ee28Smax   std::string msg =
13199566ee28Smax       std::string(
13209566ee28Smax           "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
13219566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
13229566ee28Smax   throw py::cast_error(msg);
13239566ee28Smax }
13249566ee28Smax 
13259566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
13269566ee28Smax   if (PyBoolAttribute::isaFunction(pyAttribute))
13279566ee28Smax     return py::cast(PyBoolAttribute(pyAttribute));
13289566ee28Smax   if (PyIntegerAttribute::isaFunction(pyAttribute))
13299566ee28Smax     return py::cast(PyIntegerAttribute(pyAttribute));
13309566ee28Smax   std::string msg =
13319566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
13329566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
13339566ee28Smax   throw py::cast_error(msg);
13349566ee28Smax }
13359566ee28Smax 
13364eee9ef9Smax py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
13374eee9ef9Smax   if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
13384eee9ef9Smax     return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
13394eee9ef9Smax   if (PySymbolRefAttribute::isaFunction(pyAttribute))
13404eee9ef9Smax     return py::cast(PySymbolRefAttribute(pyAttribute));
13414eee9ef9Smax   std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
13424eee9ef9Smax                     std::string(py::repr(py::cast(pyAttribute))) + ")";
13434eee9ef9Smax   throw py::cast_error(msg);
13444eee9ef9Smax }
13454eee9ef9Smax 
1346436c6c9cSStella Laurenzo } // namespace
1347436c6c9cSStella Laurenzo 
1348436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
1349436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1350619fd8c2SJeff Niu 
1351619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1352619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1353619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1354619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1355619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1356619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1357619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1358619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1359619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1360619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1361619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1362619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1363619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1364619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
13659566ee28Smax   PyGlobals::get().registerTypeCaster(
13669566ee28Smax       mlirDenseArrayAttrGetTypeID(),
13679566ee28Smax       pybind11::cpp_function(denseArrayAttributeCaster));
1368619fd8c2SJeff Niu 
1369436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1370436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1371436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1372436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
1373436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1374436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
13759566ee28Smax   PyGlobals::get().registerTypeCaster(
13769566ee28Smax       mlirDenseIntOrFPElementsAttrGetTypeID(),
13779566ee28Smax       pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
1378*f66cd9e9SStella Laurenzo   PyDenseResourceElementsAttribute::bind(m);
13799566ee28Smax 
1380436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
13814eee9ef9Smax   PySymbolRefAttribute::bind(m);
13824eee9ef9Smax   PyGlobals::get().registerTypeCaster(
13834eee9ef9Smax       mlirSymbolRefAttrGetTypeID(),
13844eee9ef9Smax       pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
13854eee9ef9Smax 
1386436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
13875c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1388436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1389436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1390436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1391436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
13929566ee28Smax   PyGlobals::get().registerTypeCaster(
13939566ee28Smax       mlirIntegerAttrGetTypeID(),
13949566ee28Smax       pybind11::cpp_function(integerOrBoolAttributeCaster));
1395436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1396ac2e2d65SDenys Shabalin 
1397ac2e2d65SDenys Shabalin   PyStridedLayoutAttribute::bind(m);
1398436c6c9cSStella Laurenzo }
1399