xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision bca889524a23a10d7a32024b80d687e2b3a1360c)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
91fc096afSMehdi Amini #include <utility>
101fc096afSMehdi Amini 
11436c6c9cSStella Laurenzo #include "IRModule.h"
12436c6c9cSStella Laurenzo 
13436c6c9cSStella Laurenzo #include "PybindUtils.h"
14436c6c9cSStella Laurenzo 
15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
17436c6c9cSStella Laurenzo 
18436c6c9cSStella Laurenzo namespace py = pybind11;
19436c6c9cSStella Laurenzo using namespace mlir;
20436c6c9cSStella Laurenzo using namespace mlir::python;
21436c6c9cSStella Laurenzo 
225d6d30edSStella Laurenzo using llvm::Optional;
23436c6c9cSStella Laurenzo using llvm::SmallVector;
24436c6c9cSStella Laurenzo using llvm::Twine;
25436c6c9cSStella Laurenzo 
265d6d30edSStella Laurenzo //------------------------------------------------------------------------------
275d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
285d6d30edSStella Laurenzo //------------------------------------------------------------------------------
295d6d30edSStella Laurenzo 
/// Long-form Python docstring for dense elements attribute construction,
/// kept out of line because of its length. NOTE(review): its use site is not
/// visible here; presumably it is attached to `DenseElementsAttr.get`.
static constexpr char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
715d6d30edSStella Laurenzo 
72436c6c9cSStella Laurenzo namespace {
73436c6c9cSStella Laurenzo 
74436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
75436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
76436c6c9cSStella Laurenzo }
77436c6c9cSStella Laurenzo 
78436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
79436c6c9cSStella Laurenzo public:
80436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
81436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
82436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
83436c6c9cSStella Laurenzo 
84436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
85436c6c9cSStella Laurenzo     c.def_static(
86436c6c9cSStella Laurenzo         "get",
87436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
88436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
89436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
90436c6c9cSStella Laurenzo         },
91436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
92436c6c9cSStella Laurenzo   }
93436c6c9cSStella Laurenzo };
94436c6c9cSStella Laurenzo 
95ed9e52f3SAlex Zinenko template <typename T>
96ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
97ed9e52f3SAlex Zinenko   try {
98ed9e52f3SAlex Zinenko     return object.cast<T>();
99ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
100ed9e52f3SAlex Zinenko     std::string msg =
101ed9e52f3SAlex Zinenko         std::string(
102ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
103ed9e52f3SAlex Zinenko         err.what() + ")";
104ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
105ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
106ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
107ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
108ed9e52f3SAlex Zinenko                       err.what() + ")";
109ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
110ed9e52f3SAlex Zinenko   }
111ed9e52f3SAlex Zinenko }
112ed9e52f3SAlex Zinenko 
113619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
114619fd8c2SJeff Niu /// implementation class.
115619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
116619fd8c2SJeff Niu class PyDenseArrayAttribute
117619fd8c2SJeff Niu     : public PyConcreteAttribute<PyDenseArrayAttribute<EltTy, DerivedT>> {
118619fd8c2SJeff Niu public:
119619fd8c2SJeff Niu   static constexpr typename PyConcreteAttribute<
120619fd8c2SJeff Niu       PyDenseArrayAttribute<EltTy, DerivedT>>::IsAFunctionTy isaFunction =
121619fd8c2SJeff Niu       DerivedT::isaFunction;
122619fd8c2SJeff Niu   static constexpr const char *pyClassName = DerivedT::pyClassName;
123619fd8c2SJeff Niu   using PyConcreteAttribute<
124619fd8c2SJeff Niu       PyDenseArrayAttribute<EltTy, DerivedT>>::PyConcreteAttribute;
125619fd8c2SJeff Niu 
126619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
127619fd8c2SJeff Niu   class PyDenseArrayIterator {
128619fd8c2SJeff Niu   public:
129619fd8c2SJeff Niu     PyDenseArrayIterator(PyAttribute attr) : attr(attr) {}
130619fd8c2SJeff Niu 
131619fd8c2SJeff Niu     /// Return a copy of the iterator.
132619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
133619fd8c2SJeff Niu 
134619fd8c2SJeff Niu     /// Return the next element.
135619fd8c2SJeff Niu     EltTy dunderNext() {
136619fd8c2SJeff Niu       // Throw if the index has reached the end.
137619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
138619fd8c2SJeff Niu         throw py::stop_iteration();
139619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
140619fd8c2SJeff Niu     }
141619fd8c2SJeff Niu 
142619fd8c2SJeff Niu     /// Bind the iterator class.
143619fd8c2SJeff Niu     static void bind(py::module &m) {
144619fd8c2SJeff Niu       py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
145619fd8c2SJeff Niu                                        py::module_local())
146619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
147619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
148619fd8c2SJeff Niu     }
149619fd8c2SJeff Niu 
150619fd8c2SJeff Niu   private:
151619fd8c2SJeff Niu     /// The referenced dense array attribute.
152619fd8c2SJeff Niu     PyAttribute attr;
153619fd8c2SJeff Niu     /// The next index to read.
154619fd8c2SJeff Niu     int nextIndex = 0;
155619fd8c2SJeff Niu   };
156619fd8c2SJeff Niu 
157619fd8c2SJeff Niu   /// Get the element at the given index.
158619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
159619fd8c2SJeff Niu 
160619fd8c2SJeff Niu   /// Bind the attribute class.
161619fd8c2SJeff Niu   static void bindDerived(typename PyConcreteAttribute<
162619fd8c2SJeff Niu                           PyDenseArrayAttribute<EltTy, DerivedT>>::ClassTy &c) {
163619fd8c2SJeff Niu     // Bind the constructor.
164619fd8c2SJeff Niu     c.def_static(
165619fd8c2SJeff Niu         "get",
166619fd8c2SJeff Niu         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
167619fd8c2SJeff Niu           MlirAttribute attr =
168619fd8c2SJeff Niu               DerivedT::getAttribute(ctx->get(), values.size(), values.data());
169619fd8c2SJeff Niu           return PyDenseArrayAttribute<EltTy, DerivedT>(ctx->getRef(), attr);
170619fd8c2SJeff Niu         },
171619fd8c2SJeff Niu         py::arg("values"), py::arg("context") = py::none(),
172619fd8c2SJeff Niu         "Gets a uniqued dense array attribute");
173619fd8c2SJeff Niu     // Bind the array methods.
174619fd8c2SJeff Niu     c.def("__getitem__",
175619fd8c2SJeff Niu           [](PyDenseArrayAttribute<EltTy, DerivedT> &arr, intptr_t i) {
176619fd8c2SJeff Niu             if (i >= mlirDenseArrayGetNumElements(arr))
177619fd8c2SJeff Niu               throw py::index_error("DenseArray index out of range");
178619fd8c2SJeff Niu             return arr.getItem(i);
179619fd8c2SJeff Niu           });
180619fd8c2SJeff Niu     c.def("__len__", [](const PyDenseArrayAttribute<EltTy, DerivedT> &arr) {
181619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
182619fd8c2SJeff Niu     });
183619fd8c2SJeff Niu     c.def("__iter__", [](const PyDenseArrayAttribute<EltTy, DerivedT> &arr) {
184619fd8c2SJeff Niu       return PyDenseArrayIterator(arr);
185619fd8c2SJeff Niu     });
186619fd8c2SJeff Niu     c.def("__add__", [](PyDenseArrayAttribute<EltTy, DerivedT> &arr,
187619fd8c2SJeff Niu                         py::list extras) {
188619fd8c2SJeff Niu       std::vector<EltTy> values;
189619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
190619fd8c2SJeff Niu       values.reserve(numOldElements + py::len(extras));
191619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
192619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
193619fd8c2SJeff Niu       for (py::handle attr : extras)
194619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
195619fd8c2SJeff Niu       MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(),
196619fd8c2SJeff Niu                                                   values.size(), values.data());
197619fd8c2SJeff Niu       return PyDenseArrayAttribute<EltTy, DerivedT>(arr.getContext(), attr);
198619fd8c2SJeff Niu     });
199619fd8c2SJeff Niu   }
200619fd8c2SJeff Niu };
201619fd8c2SJeff Niu 
202619fd8c2SJeff Niu /// Instantiate the python dense array classes.
203619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
204619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> {
205619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
206619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
207619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
208619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
209619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
210619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
211619fd8c2SJeff Niu };
212619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
213619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
214619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
215619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
216619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
217619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
218619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
219619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
220619fd8c2SJeff Niu };
221619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
222619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
223619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
224619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
225619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
226619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
227619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
228619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
229619fd8c2SJeff Niu };
230619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
231619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
232619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
233619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
234619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
235619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
236619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
237619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
238619fd8c2SJeff Niu };
239619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
240619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
241619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
242619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
243619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
244619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
245619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
246619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
247619fd8c2SJeff Niu };
248619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
249619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
250619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
251619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
252619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
253619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
254619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
255619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
256619fd8c2SJeff Niu };
257619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
258619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
259619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
260619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
261619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
262619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
263619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
264619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
265619fd8c2SJeff Niu };
266619fd8c2SJeff Niu 
267436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
268436c6c9cSStella Laurenzo public:
269436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
270436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
271436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
272436c6c9cSStella Laurenzo 
273436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
274436c6c9cSStella Laurenzo   public:
2751fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
276436c6c9cSStella Laurenzo 
277436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
278436c6c9cSStella Laurenzo 
279436c6c9cSStella Laurenzo     PyAttribute dunderNext() {
280*bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
281*bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
282436c6c9cSStella Laurenzo         throw py::stop_iteration();
283436c6c9cSStella Laurenzo       return PyAttribute(attr.getContext(),
284436c6c9cSStella Laurenzo                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
285436c6c9cSStella Laurenzo     }
286436c6c9cSStella Laurenzo 
287436c6c9cSStella Laurenzo     static void bind(py::module &m) {
288f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
289f05ff4f7SStella Laurenzo                                            py::module_local())
290436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
291436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
292436c6c9cSStella Laurenzo     }
293436c6c9cSStella Laurenzo 
294436c6c9cSStella Laurenzo   private:
295436c6c9cSStella Laurenzo     PyAttribute attr;
296436c6c9cSStella Laurenzo     int nextIndex = 0;
297436c6c9cSStella Laurenzo   };
298436c6c9cSStella Laurenzo 
299ed9e52f3SAlex Zinenko   PyAttribute getItem(intptr_t i) {
300ed9e52f3SAlex Zinenko     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
301ed9e52f3SAlex Zinenko   }
302ed9e52f3SAlex Zinenko 
303436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
304436c6c9cSStella Laurenzo     c.def_static(
305436c6c9cSStella Laurenzo         "get",
306436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
307436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
308436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
309436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
310ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
311436c6c9cSStella Laurenzo           }
312436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
313436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
314436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
315436c6c9cSStella Laurenzo         },
316436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
317436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
318436c6c9cSStella Laurenzo     c.def("__getitem__",
319436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
320436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
321436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
322ed9e52f3SAlex Zinenko             return arr.getItem(i);
323436c6c9cSStella Laurenzo           })
324436c6c9cSStella Laurenzo         .def("__len__",
325436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
326436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
327436c6c9cSStella Laurenzo              })
328436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
329436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
330436c6c9cSStella Laurenzo         });
331ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
332ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
333ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
334ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
335ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
336ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
337ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
338ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
339ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
340ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
341ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
342ed9e52f3SAlex Zinenko     });
343436c6c9cSStella Laurenzo   }
344436c6c9cSStella Laurenzo };
345436c6c9cSStella Laurenzo 
346436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
347436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
348436c6c9cSStella Laurenzo public:
349436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
350436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
351436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
352436c6c9cSStella Laurenzo 
353436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
354436c6c9cSStella Laurenzo     c.def_static(
355436c6c9cSStella Laurenzo         "get",
356436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
357436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
358436c6c9cSStella Laurenzo           // TODO: Rework error reporting once diagnostic engine is exposed
359436c6c9cSStella Laurenzo           // in C API.
360436c6c9cSStella Laurenzo           if (mlirAttributeIsNull(attr)) {
361436c6c9cSStella Laurenzo             throw SetPyError(PyExc_ValueError,
362436c6c9cSStella Laurenzo                              Twine("invalid '") +
363436c6c9cSStella Laurenzo                                  py::repr(py::cast(type)).cast<std::string>() +
364436c6c9cSStella Laurenzo                                  "' and expected floating point type.");
365436c6c9cSStella Laurenzo           }
366436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
367436c6c9cSStella Laurenzo         },
368436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
369436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
370436c6c9cSStella Laurenzo     c.def_static(
371436c6c9cSStella Laurenzo         "get_f32",
372436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
373436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
374436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
375436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
376436c6c9cSStella Laurenzo         },
377436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
378436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
379436c6c9cSStella Laurenzo     c.def_static(
380436c6c9cSStella Laurenzo         "get_f64",
381436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
382436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
383436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
384436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
385436c6c9cSStella Laurenzo         },
386436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
387436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
388436c6c9cSStella Laurenzo     c.def_property_readonly(
389436c6c9cSStella Laurenzo         "value",
390436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
391436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
392436c6c9cSStella Laurenzo         },
393436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
394436c6c9cSStella Laurenzo   }
395436c6c9cSStella Laurenzo };
396436c6c9cSStella Laurenzo 
397436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
398436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
399436c6c9cSStella Laurenzo public:
400436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
401436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
402436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
403436c6c9cSStella Laurenzo 
404436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
405436c6c9cSStella Laurenzo     c.def_static(
406436c6c9cSStella Laurenzo         "get",
407436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
408436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
409436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
410436c6c9cSStella Laurenzo         },
411436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
412436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
413436c6c9cSStella Laurenzo     c.def_property_readonly(
414436c6c9cSStella Laurenzo         "value",
415e9db306dSrkayaith         [](PyIntegerAttribute &self) -> py::int_ {
416e9db306dSrkayaith           MlirType type = mlirAttributeGetType(self);
417e9db306dSrkayaith           if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
418436c6c9cSStella Laurenzo             return mlirIntegerAttrGetValueInt(self);
419e9db306dSrkayaith           if (mlirIntegerTypeIsSigned(type))
420e9db306dSrkayaith             return mlirIntegerAttrGetValueSInt(self);
421e9db306dSrkayaith           return mlirIntegerAttrGetValueUInt(self);
422436c6c9cSStella Laurenzo         },
423436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
424436c6c9cSStella Laurenzo   }
425436c6c9cSStella Laurenzo };
426436c6c9cSStella Laurenzo 
427436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
428436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
429436c6c9cSStella Laurenzo public:
430436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
431436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
432436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
433436c6c9cSStella Laurenzo 
434436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
435436c6c9cSStella Laurenzo     c.def_static(
436436c6c9cSStella Laurenzo         "get",
437436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
438436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
439436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
440436c6c9cSStella Laurenzo         },
441436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
442436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
443436c6c9cSStella Laurenzo     c.def_property_readonly(
444436c6c9cSStella Laurenzo         "value",
445436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
446436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
447436c6c9cSStella Laurenzo   }
448436c6c9cSStella Laurenzo };
449436c6c9cSStella Laurenzo 
450436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
451436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
452436c6c9cSStella Laurenzo public:
453436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
454436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
455436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
456436c6c9cSStella Laurenzo 
457436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
458436c6c9cSStella Laurenzo     c.def_static(
459436c6c9cSStella Laurenzo         "get",
460436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
461436c6c9cSStella Laurenzo           MlirAttribute attr =
462436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
463436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
464436c6c9cSStella Laurenzo         },
465436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
466436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
467436c6c9cSStella Laurenzo     c.def_property_readonly(
468436c6c9cSStella Laurenzo         "value",
469436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
470436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
471436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
472436c6c9cSStella Laurenzo         },
473436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
474436c6c9cSStella Laurenzo   }
475436c6c9cSStella Laurenzo };
476436c6c9cSStella Laurenzo 
4775c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
4785c3861b2SYun Long public:
4795c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
4805c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
4815c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
4825c3861b2SYun Long 
4835c3861b2SYun Long   static void bindDerived(ClassTy &c) {
4845c3861b2SYun Long     c.def_static(
4855c3861b2SYun Long         "get",
4865c3861b2SYun Long         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
4875c3861b2SYun Long            DefaultingPyMlirContext context) {
4885c3861b2SYun Long           const py::buffer_info bufferInfo = buffer.request();
4895c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
4905c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
4915c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
4925c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
4935c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
4945c3861b2SYun Long         },
4955c3861b2SYun Long         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
4965c3861b2SYun Long         py::arg("context") = py::none(), "Gets an Opaque attribute.");
4975c3861b2SYun Long     c.def_property_readonly(
4985c3861b2SYun Long         "dialect_namespace",
4995c3861b2SYun Long         [](PyOpaqueAttribute &self) {
5005c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
5015c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
5025c3861b2SYun Long         },
5035c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
5045c3861b2SYun Long     c.def_property_readonly(
5055c3861b2SYun Long         "data",
5065c3861b2SYun Long         [](PyOpaqueAttribute &self) {
5075c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
5085c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
5095c3861b2SYun Long         },
5105c3861b2SYun Long         "Returns the data for the Opaqued attributes as a string");
5115c3861b2SYun Long   }
5125c3861b2SYun Long };
5135c3861b2SYun Long 
514436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
515436c6c9cSStella Laurenzo public:
516436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
517436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
518436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
519436c6c9cSStella Laurenzo 
520436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
521436c6c9cSStella Laurenzo     c.def_static(
522436c6c9cSStella Laurenzo         "get",
523436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
524436c6c9cSStella Laurenzo           MlirAttribute attr =
525436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
526436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
527436c6c9cSStella Laurenzo         },
528436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
529436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
530436c6c9cSStella Laurenzo     c.def_static(
531436c6c9cSStella Laurenzo         "get_typed",
532436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
533436c6c9cSStella Laurenzo           MlirAttribute attr =
534436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
535436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
536436c6c9cSStella Laurenzo         },
537a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
538436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
539436c6c9cSStella Laurenzo     c.def_property_readonly(
540436c6c9cSStella Laurenzo         "value",
541436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
542436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
543436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
544436c6c9cSStella Laurenzo         },
545436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
546436c6c9cSStella Laurenzo   }
547436c6c9cSStella Laurenzo };
548436c6c9cSStella Laurenzo 
549436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
550436c6c9cSStella Laurenzo class PyDenseElementsAttribute
551436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
552436c6c9cSStella Laurenzo public:
553436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
554436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
555436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
556436c6c9cSStella Laurenzo 
557436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
5585d6d30edSStella Laurenzo   getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
5595d6d30edSStella Laurenzo                 Optional<std::vector<int64_t>> explicitShape,
560436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
561436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
562436c6c9cSStella Laurenzo     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
563436c6c9cSStella Laurenzo     Py_buffer *view = new Py_buffer();
564436c6c9cSStella Laurenzo     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
565436c6c9cSStella Laurenzo       delete view;
566436c6c9cSStella Laurenzo       throw py::error_already_set();
567436c6c9cSStella Laurenzo     }
568436c6c9cSStella Laurenzo     py::buffer_info arrayInfo(view);
5695d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
5705d6d30edSStella Laurenzo     if (explicitShape) {
5715d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
5725d6d30edSStella Laurenzo     } else {
5735d6d30edSStella Laurenzo       shape.append(arrayInfo.shape.begin(),
5745d6d30edSStella Laurenzo                    arrayInfo.shape.begin() + arrayInfo.ndim);
5755d6d30edSStella Laurenzo     }
576436c6c9cSStella Laurenzo 
5775d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
578436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
5795d6d30edSStella Laurenzo 
5805d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
5815d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
5825d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
5835d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
5845d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
5855d6d30edSStella Laurenzo     Optional<MlirType> bulkLoadElementType;
5865d6d30edSStella Laurenzo     if (explicitType) {
5875d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
5885d6d30edSStella Laurenzo     } else if (arrayInfo.format == "f") {
589436c6c9cSStella Laurenzo       // f32
590436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
5915d6d30edSStella Laurenzo       bulkLoadElementType = mlirF32TypeGet(context);
592436c6c9cSStella Laurenzo     } else if (arrayInfo.format == "d") {
593436c6c9cSStella Laurenzo       // f64
594436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
5955d6d30edSStella Laurenzo       bulkLoadElementType = mlirF64TypeGet(context);
5965d6d30edSStella Laurenzo     } else if (arrayInfo.format == "e") {
5975d6d30edSStella Laurenzo       // f16
5985d6d30edSStella Laurenzo       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
5995d6d30edSStella Laurenzo       bulkLoadElementType = mlirF16TypeGet(context);
600436c6c9cSStella Laurenzo     } else if (isSignedIntegerFormat(arrayInfo.format)) {
601436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
602436c6c9cSStella Laurenzo         // i32
6035d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
604436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 32);
605436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
606436c6c9cSStella Laurenzo         // i64
6075d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
608436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 64);
6095d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
6105d6d30edSStella Laurenzo         // i8
6115d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
6125d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 8);
6135d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
6145d6d30edSStella Laurenzo         // i16
6155d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
6165d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 16);
617436c6c9cSStella Laurenzo       }
618436c6c9cSStella Laurenzo     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
619436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
620436c6c9cSStella Laurenzo         // unsigned i32
6215d6d30edSStella Laurenzo         bulkLoadElementType = signless
622436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 32)
623436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 32);
624436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
625436c6c9cSStella Laurenzo         // unsigned i64
6265d6d30edSStella Laurenzo         bulkLoadElementType = signless
627436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 64)
628436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 64);
6295d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
6305d6d30edSStella Laurenzo         // i8
6315d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
6325d6d30edSStella Laurenzo                                        : mlirIntegerTypeUnsignedGet(context, 8);
6335d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
6345d6d30edSStella Laurenzo         // i16
6355d6d30edSStella Laurenzo         bulkLoadElementType = signless
6365d6d30edSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 16)
6375d6d30edSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 16);
638436c6c9cSStella Laurenzo       }
639436c6c9cSStella Laurenzo     }
6405d6d30edSStella Laurenzo     if (bulkLoadElementType) {
6415d6d30edSStella Laurenzo       auto shapedType = mlirRankedTensorTypeGet(
6425d6d30edSStella Laurenzo           shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
6435d6d30edSStella Laurenzo       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
6445d6d30edSStella Laurenzo       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
6455d6d30edSStella Laurenzo           shapedType, rawBufferSize, arrayInfo.ptr);
6465d6d30edSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
6475d6d30edSStella Laurenzo         throw std::invalid_argument(
6485d6d30edSStella Laurenzo             "DenseElementsAttr could not be constructed from the given buffer. "
6495d6d30edSStella Laurenzo             "This may mean that the Python buffer layout does not match that "
6505d6d30edSStella Laurenzo             "MLIR expected layout and is a bug.");
6515d6d30edSStella Laurenzo       }
6525d6d30edSStella Laurenzo       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
6535d6d30edSStella Laurenzo     }
654436c6c9cSStella Laurenzo 
6555d6d30edSStella Laurenzo     throw std::invalid_argument(
6565d6d30edSStella Laurenzo         std::string("unimplemented array format conversion from format: ") +
6575d6d30edSStella Laurenzo         arrayInfo.format);
658436c6c9cSStella Laurenzo   }
659436c6c9cSStella Laurenzo 
6601fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
661436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
662436c6c9cSStella Laurenzo     auto contextWrapper =
663436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
664436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
665436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
666436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
667436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
668436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
669436c6c9cSStella Laurenzo     }
670436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
671436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
672436c6c9cSStella Laurenzo       std::string message =
673436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
674436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
675436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
676436c6c9cSStella Laurenzo     }
677436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
678436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
679436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
680436c6c9cSStella Laurenzo       std::string message =
681436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
682436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
683436c6c9cSStella Laurenzo       message.append(", element=");
684436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
685436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
686436c6c9cSStella Laurenzo     }
687436c6c9cSStella Laurenzo 
688436c6c9cSStella Laurenzo     MlirAttribute elements =
689436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
690436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
691436c6c9cSStella Laurenzo   }
692436c6c9cSStella Laurenzo 
693436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
694436c6c9cSStella Laurenzo 
695436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
6965d6d30edSStella Laurenzo     if (mlirDenseElementsAttrIsSplat(*this)) {
697c5f445d1SStella Laurenzo       // TODO: Currently crashes the program.
6985d6d30edSStella Laurenzo       // Reported as https://github.com/pybind/pybind11/issues/3336
699c5f445d1SStella Laurenzo       throw std::invalid_argument(
700c5f445d1SStella Laurenzo           "unsupported data type for conversion to Python buffer");
7015d6d30edSStella Laurenzo     }
7025d6d30edSStella Laurenzo 
703436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
704436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
7055d6d30edSStella Laurenzo     std::string format;
706436c6c9cSStella Laurenzo 
707436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
708436c6c9cSStella Laurenzo       // f32
7095d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
71002b6fb21SMehdi Amini     }
71102b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
712436c6c9cSStella Laurenzo       // f64
7135d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
714bb56c2b3SMehdi Amini     }
715bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
7165d6d30edSStella Laurenzo       // f16
7175d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
718bb56c2b3SMehdi Amini     }
719bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
720436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
721436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
722436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
723436c6c9cSStella Laurenzo         // i32
7245d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
725e5639b3fSMehdi Amini       }
726e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
727436c6c9cSStella Laurenzo         // unsigned i32
7285d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
729436c6c9cSStella Laurenzo       }
730436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
731436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
732436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
733436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
734436c6c9cSStella Laurenzo         // i64
7355d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
736e5639b3fSMehdi Amini       }
737e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
738436c6c9cSStella Laurenzo         // unsigned i64
7395d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
7405d6d30edSStella Laurenzo       }
7415d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
7425d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
7435d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
7445d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
7455d6d30edSStella Laurenzo         // i8
7465d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
747e5639b3fSMehdi Amini       }
748e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
7495d6d30edSStella Laurenzo         // unsigned i8
7505d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
7515d6d30edSStella Laurenzo       }
7525d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
7535d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
7545d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
7555d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
7565d6d30edSStella Laurenzo         // i16
7575d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
758e5639b3fSMehdi Amini       }
759e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
7605d6d30edSStella Laurenzo         // unsigned i16
7615d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
762436c6c9cSStella Laurenzo       }
763436c6c9cSStella Laurenzo     }
764436c6c9cSStella Laurenzo 
765c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
7665d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
767c5f445d1SStella Laurenzo     throw std::invalid_argument(
768c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
769436c6c9cSStella Laurenzo   }
770436c6c9cSStella Laurenzo 
771436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
772436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
773436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
774436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
7755d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
776436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
7775d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
778436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
779436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
780436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
781436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
782436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
783436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
784436c6c9cSStella Laurenzo                                })
785436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
786436c6c9cSStella Laurenzo   }
787436c6c9cSStella Laurenzo 
788436c6c9cSStella Laurenzo private:
789436c6c9cSStella Laurenzo   static bool isUnsignedIntegerFormat(const std::string &format) {
790436c6c9cSStella Laurenzo     if (format.empty())
791436c6c9cSStella Laurenzo       return false;
792436c6c9cSStella Laurenzo     char code = format[0];
793436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
794436c6c9cSStella Laurenzo            code == 'Q';
795436c6c9cSStella Laurenzo   }
796436c6c9cSStella Laurenzo 
797436c6c9cSStella Laurenzo   static bool isSignedIntegerFormat(const std::string &format) {
798436c6c9cSStella Laurenzo     if (format.empty())
799436c6c9cSStella Laurenzo       return false;
800436c6c9cSStella Laurenzo     char code = format[0];
801436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
802436c6c9cSStella Laurenzo            code == 'q';
803436c6c9cSStella Laurenzo   }
804436c6c9cSStella Laurenzo 
805436c6c9cSStella Laurenzo   template <typename Type>
806436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
8075d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
808436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
809436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
810436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
811436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
812436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
813436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
814436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
815436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
816436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
817436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
818436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
819436c6c9cSStella Laurenzo     intptr_t strideFactor = 1;
820436c6c9cSStella Laurenzo     for (intptr_t i = 1; i < rank; ++i) {
821436c6c9cSStella Laurenzo       strideFactor = 1;
822436c6c9cSStella Laurenzo       for (intptr_t j = i; j < rank; ++j) {
823436c6c9cSStella Laurenzo         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
824436c6c9cSStella Laurenzo       }
825436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type) * strideFactor);
826436c6c9cSStella Laurenzo     }
827436c6c9cSStella Laurenzo     strides.push_back(sizeof(Type));
8285d6d30edSStella Laurenzo     std::string format;
8295d6d30edSStella Laurenzo     if (explicitFormat) {
8305d6d30edSStella Laurenzo       format = explicitFormat;
8315d6d30edSStella Laurenzo     } else {
8325d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
8335d6d30edSStella Laurenzo     }
8345d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
8355d6d30edSStella Laurenzo                            /*readonly=*/true);
836436c6c9cSStella Laurenzo   }
837436c6c9cSStella Laurenzo }; // namespace
838436c6c9cSStella Laurenzo 
839436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
840436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
841436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
842436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
843436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
844436c6c9cSStella Laurenzo public:
845436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
846436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
847436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
848436c6c9cSStella Laurenzo 
849436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
850436c6c9cSStella Laurenzo   /// out of range.
851436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
852436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
853436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
854436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
855436c6c9cSStella Laurenzo     }
856436c6c9cSStella Laurenzo 
857436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
858436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
859436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
860436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
861436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
862436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
863436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
864436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
865436c6c9cSStella Laurenzo     // querying them on each element access.
866436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
867436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
868436c6c9cSStella Laurenzo     if (isUnsigned) {
869436c6c9cSStella Laurenzo       if (width == 1) {
870436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
871436c6c9cSStella Laurenzo       }
872308d8b8cSRahul Kayaith       if (width == 8) {
873308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
874308d8b8cSRahul Kayaith       }
875308d8b8cSRahul Kayaith       if (width == 16) {
876308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
877308d8b8cSRahul Kayaith       }
878436c6c9cSStella Laurenzo       if (width == 32) {
879436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
880436c6c9cSStella Laurenzo       }
881436c6c9cSStella Laurenzo       if (width == 64) {
882436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
883436c6c9cSStella Laurenzo       }
884436c6c9cSStella Laurenzo     } else {
885436c6c9cSStella Laurenzo       if (width == 1) {
886436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
887436c6c9cSStella Laurenzo       }
888308d8b8cSRahul Kayaith       if (width == 8) {
889308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
890308d8b8cSRahul Kayaith       }
891308d8b8cSRahul Kayaith       if (width == 16) {
892308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
893308d8b8cSRahul Kayaith       }
894436c6c9cSStella Laurenzo       if (width == 32) {
895436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
896436c6c9cSStella Laurenzo       }
897436c6c9cSStella Laurenzo       if (width == 64) {
898436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
899436c6c9cSStella Laurenzo       }
900436c6c9cSStella Laurenzo     }
901436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
902436c6c9cSStella Laurenzo   }
903436c6c9cSStella Laurenzo 
904436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
905436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
906436c6c9cSStella Laurenzo   }
907436c6c9cSStella Laurenzo };
908436c6c9cSStella Laurenzo 
909436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
910436c6c9cSStella Laurenzo public:
911436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
912436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
913436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
914436c6c9cSStella Laurenzo 
915436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
916436c6c9cSStella Laurenzo 
9179fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
9189fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
9199fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
9209fb1086bSAdrian Kuegel   }
9219fb1086bSAdrian Kuegel 
922436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
9239fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
924436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
925436c6c9cSStella Laurenzo     c.def_static(
926436c6c9cSStella Laurenzo         "get",
927436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
928436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
929436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
930436c6c9cSStella Laurenzo           for (auto &it : attributes) {
93102b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
932436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
933436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
93402b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
935436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
93602b6fb21SMehdi Amini                 mlirAttr));
937436c6c9cSStella Laurenzo           }
938436c6c9cSStella Laurenzo           MlirAttribute attr =
939436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
940436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
941436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
942436c6c9cSStella Laurenzo         },
943ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
944436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
945436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
946436c6c9cSStella Laurenzo       MlirAttribute attr =
947436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
948436c6c9cSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
949436c6c9cSStella Laurenzo         throw SetPyError(PyExc_KeyError,
950436c6c9cSStella Laurenzo                          "attempt to access a non-existent attribute");
951436c6c9cSStella Laurenzo       }
952436c6c9cSStella Laurenzo       return PyAttribute(self.getContext(), attr);
953436c6c9cSStella Laurenzo     });
954436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
955436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
956436c6c9cSStella Laurenzo         throw SetPyError(PyExc_IndexError,
957436c6c9cSStella Laurenzo                          "attempt to access out of bounds attribute");
958436c6c9cSStella Laurenzo       }
959436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
960436c6c9cSStella Laurenzo       return PyNamedAttribute(
961436c6c9cSStella Laurenzo           namedAttr.attribute,
962436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
963436c6c9cSStella Laurenzo     });
964436c6c9cSStella Laurenzo   }
965436c6c9cSStella Laurenzo };
966436c6c9cSStella Laurenzo 
967436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
968436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
969436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
970436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
971436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
972436c6c9cSStella Laurenzo public:
973436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
974436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
975436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
976436c6c9cSStella Laurenzo 
977436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
978436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
979436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
980436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
981436c6c9cSStella Laurenzo     }
982436c6c9cSStella Laurenzo 
983436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
984436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
985436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
986436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
987436c6c9cSStella Laurenzo     // from float and double.
988436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
989436c6c9cSStella Laurenzo     // querying them on each element access.
990436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
991436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
992436c6c9cSStella Laurenzo     }
993436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
994436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
995436c6c9cSStella Laurenzo     }
996436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
997436c6c9cSStella Laurenzo   }
998436c6c9cSStella Laurenzo 
999436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1000436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1001436c6c9cSStella Laurenzo   }
1002436c6c9cSStella Laurenzo };
1003436c6c9cSStella Laurenzo 
1004436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1005436c6c9cSStella Laurenzo public:
1006436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1007436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
1008436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1009436c6c9cSStella Laurenzo 
1010436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1011436c6c9cSStella Laurenzo     c.def_static(
1012436c6c9cSStella Laurenzo         "get",
1013436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1014436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1015436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1016436c6c9cSStella Laurenzo         },
1017436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
1018436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1019436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1020436c6c9cSStella Laurenzo       return PyType(self.getContext()->getRef(),
1021436c6c9cSStella Laurenzo                     mlirTypeAttrGetValue(self.get()));
1022436c6c9cSStella Laurenzo     });
1023436c6c9cSStella Laurenzo   }
1024436c6c9cSStella Laurenzo };
1025436c6c9cSStella Laurenzo 
1026436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1027436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1028436c6c9cSStella Laurenzo public:
1029436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1030436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1031436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1032436c6c9cSStella Laurenzo 
1033436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1034436c6c9cSStella Laurenzo     c.def_static(
1035436c6c9cSStella Laurenzo         "get",
1036436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1037436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1038436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1039436c6c9cSStella Laurenzo         },
1040436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
1041436c6c9cSStella Laurenzo   }
1042436c6c9cSStella Laurenzo };
1043436c6c9cSStella Laurenzo 
1044436c6c9cSStella Laurenzo } // namespace
1045436c6c9cSStella Laurenzo 
1046436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
1047436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1048619fd8c2SJeff Niu 
1049619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1050619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1051619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1052619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1053619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1054619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1055619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1056619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1057619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1058619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1059619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1060619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1061619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1062619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
1063619fd8c2SJeff Niu 
1064436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1065436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1066436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1067436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
1068436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1069436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
1070436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
1071436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
10725c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1073436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1074436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1075436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1076436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
1077436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1078436c6c9cSStella Laurenzo }
1079