xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision 4811270bac0e57ab8f5baf27eb280012817bdfe5)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9a1fe1f5fSKazu Hirata #include <optional>
10*4811270bSmax #include <utility>
111fc096afSMehdi Amini 
12436c6c9cSStella Laurenzo #include "IRModule.h"
13436c6c9cSStella Laurenzo 
14436c6c9cSStella Laurenzo #include "PybindUtils.h"
15436c6c9cSStella Laurenzo 
16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
17436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
18436c6c9cSStella Laurenzo 
19436c6c9cSStella Laurenzo namespace py = pybind11;
20436c6c9cSStella Laurenzo using namespace mlir;
21436c6c9cSStella Laurenzo using namespace mlir::python;
22436c6c9cSStella Laurenzo 
23436c6c9cSStella Laurenzo using llvm::SmallVector;
24436c6c9cSStella Laurenzo 
255d6d30edSStella Laurenzo //------------------------------------------------------------------------------
265d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
275d6d30edSStella Laurenzo //------------------------------------------------------------------------------
285d6d30edSStella Laurenzo 
// Python-facing docstring for the DenseElementsAttr `get` binding. It is long
// enough that it is kept here, out of line, rather than at the binding site.
static constexpr char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
705d6d30edSStella Laurenzo 
71436c6c9cSStella Laurenzo namespace {
72436c6c9cSStella Laurenzo 
73436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
74436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
75436c6c9cSStella Laurenzo }
76436c6c9cSStella Laurenzo 
77436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
78436c6c9cSStella Laurenzo public:
79436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
80436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
81436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
82436c6c9cSStella Laurenzo 
83436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
84436c6c9cSStella Laurenzo     c.def_static(
85436c6c9cSStella Laurenzo         "get",
86436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
87436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
88436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
89436c6c9cSStella Laurenzo         },
90436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
91436c6c9cSStella Laurenzo   }
92436c6c9cSStella Laurenzo };
93436c6c9cSStella Laurenzo 
94ed9e52f3SAlex Zinenko template <typename T>
95ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
96ed9e52f3SAlex Zinenko   try {
97ed9e52f3SAlex Zinenko     return object.cast<T>();
98ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
99ed9e52f3SAlex Zinenko     std::string msg =
100ed9e52f3SAlex Zinenko         std::string(
101ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
102ed9e52f3SAlex Zinenko         err.what() + ")";
103ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
104ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
105ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
106ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
107ed9e52f3SAlex Zinenko                       err.what() + ")";
108ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
109ed9e52f3SAlex Zinenko   }
110ed9e52f3SAlex Zinenko }
111ed9e52f3SAlex Zinenko 
112619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
113619fd8c2SJeff Niu /// implementation class.
114619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
115133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
116619fd8c2SJeff Niu public:
117133624acSJeff Niu   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
118619fd8c2SJeff Niu 
119619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
120619fd8c2SJeff Niu   class PyDenseArrayIterator {
121619fd8c2SJeff Niu   public:
1224a1b1196SMehdi Amini     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
123619fd8c2SJeff Niu 
124619fd8c2SJeff Niu     /// Return a copy of the iterator.
125619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
126619fd8c2SJeff Niu 
127619fd8c2SJeff Niu     /// Return the next element.
128619fd8c2SJeff Niu     EltTy dunderNext() {
129619fd8c2SJeff Niu       // Throw if the index has reached the end.
130619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
131619fd8c2SJeff Niu         throw py::stop_iteration();
132619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
133619fd8c2SJeff Niu     }
134619fd8c2SJeff Niu 
135619fd8c2SJeff Niu     /// Bind the iterator class.
136619fd8c2SJeff Niu     static void bind(py::module &m) {
137619fd8c2SJeff Niu       py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
138619fd8c2SJeff Niu                                        py::module_local())
139619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
140619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
141619fd8c2SJeff Niu     }
142619fd8c2SJeff Niu 
143619fd8c2SJeff Niu   private:
144619fd8c2SJeff Niu     /// The referenced dense array attribute.
145619fd8c2SJeff Niu     PyAttribute attr;
146619fd8c2SJeff Niu     /// The next index to read.
147619fd8c2SJeff Niu     int nextIndex = 0;
148619fd8c2SJeff Niu   };
149619fd8c2SJeff Niu 
150619fd8c2SJeff Niu   /// Get the element at the given index.
151619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
152619fd8c2SJeff Niu 
153619fd8c2SJeff Niu   /// Bind the attribute class.
154133624acSJeff Niu   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
155619fd8c2SJeff Niu     // Bind the constructor.
156619fd8c2SJeff Niu     c.def_static(
157619fd8c2SJeff Niu         "get",
158619fd8c2SJeff Niu         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
159619fd8c2SJeff Niu           MlirAttribute attr =
160619fd8c2SJeff Niu               DerivedT::getAttribute(ctx->get(), values.size(), values.data());
161133624acSJeff Niu           return DerivedT(ctx->getRef(), attr);
162619fd8c2SJeff Niu         },
163619fd8c2SJeff Niu         py::arg("values"), py::arg("context") = py::none(),
164619fd8c2SJeff Niu         "Gets a uniqued dense array attribute");
165619fd8c2SJeff Niu     // Bind the array methods.
166133624acSJeff Niu     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
167619fd8c2SJeff Niu       if (i >= mlirDenseArrayGetNumElements(arr))
168619fd8c2SJeff Niu         throw py::index_error("DenseArray index out of range");
169619fd8c2SJeff Niu       return arr.getItem(i);
170619fd8c2SJeff Niu     });
171133624acSJeff Niu     c.def("__len__", [](const DerivedT &arr) {
172619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
173619fd8c2SJeff Niu     });
174133624acSJeff Niu     c.def("__iter__",
175133624acSJeff Niu           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
1764a1b1196SMehdi Amini     c.def("__add__", [](DerivedT &arr, const py::list &extras) {
177619fd8c2SJeff Niu       std::vector<EltTy> values;
178619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
179619fd8c2SJeff Niu       values.reserve(numOldElements + py::len(extras));
180619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
181619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
182619fd8c2SJeff Niu       for (py::handle attr : extras)
183619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
184619fd8c2SJeff Niu       MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(),
185619fd8c2SJeff Niu                                                   values.size(), values.data());
186133624acSJeff Niu       return DerivedT(arr.getContext(), attr);
187619fd8c2SJeff Niu     });
188619fd8c2SJeff Niu   }
189619fd8c2SJeff Niu };
190619fd8c2SJeff Niu 
191619fd8c2SJeff Niu /// Instantiate the python dense array classes.
192619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
193619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> {
194619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
195619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
196619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
197619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
198619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
199619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
200619fd8c2SJeff Niu };
201619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
202619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
203619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
204619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
205619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
206619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
207619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
208619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
209619fd8c2SJeff Niu };
210619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
211619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
212619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
213619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
214619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
215619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
216619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
217619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
218619fd8c2SJeff Niu };
219619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
220619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
221619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
222619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
223619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
224619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
225619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
226619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
227619fd8c2SJeff Niu };
228619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
229619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
230619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
231619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
232619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
233619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
234619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
235619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
236619fd8c2SJeff Niu };
237619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
238619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
239619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
240619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
241619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
242619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
243619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
244619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
245619fd8c2SJeff Niu };
246619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
247619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
248619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
249619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
250619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
251619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
252619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
253619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
254619fd8c2SJeff Niu };
255619fd8c2SJeff Niu 
256436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
257436c6c9cSStella Laurenzo public:
258436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
259436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
260436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
261436c6c9cSStella Laurenzo 
262436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
263436c6c9cSStella Laurenzo   public:
2641fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
265436c6c9cSStella Laurenzo 
266436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
267436c6c9cSStella Laurenzo 
268436c6c9cSStella Laurenzo     PyAttribute dunderNext() {
269bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
270bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
271436c6c9cSStella Laurenzo         throw py::stop_iteration();
272436c6c9cSStella Laurenzo       return PyAttribute(attr.getContext(),
273436c6c9cSStella Laurenzo                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
274436c6c9cSStella Laurenzo     }
275436c6c9cSStella Laurenzo 
276436c6c9cSStella Laurenzo     static void bind(py::module &m) {
277f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
278f05ff4f7SStella Laurenzo                                            py::module_local())
279436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
280436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
281436c6c9cSStella Laurenzo     }
282436c6c9cSStella Laurenzo 
283436c6c9cSStella Laurenzo   private:
284436c6c9cSStella Laurenzo     PyAttribute attr;
285436c6c9cSStella Laurenzo     int nextIndex = 0;
286436c6c9cSStella Laurenzo   };
287436c6c9cSStella Laurenzo 
288ed9e52f3SAlex Zinenko   PyAttribute getItem(intptr_t i) {
289ed9e52f3SAlex Zinenko     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
290ed9e52f3SAlex Zinenko   }
291ed9e52f3SAlex Zinenko 
292436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
293436c6c9cSStella Laurenzo     c.def_static(
294436c6c9cSStella Laurenzo         "get",
295436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
296436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
297436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
298436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
299ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
300436c6c9cSStella Laurenzo           }
301436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
302436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
303436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
304436c6c9cSStella Laurenzo         },
305436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
306436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
307436c6c9cSStella Laurenzo     c.def("__getitem__",
308436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
309436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
310436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
311ed9e52f3SAlex Zinenko             return arr.getItem(i);
312436c6c9cSStella Laurenzo           })
313436c6c9cSStella Laurenzo         .def("__len__",
314436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
315436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
316436c6c9cSStella Laurenzo              })
317436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
318436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
319436c6c9cSStella Laurenzo         });
320ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
321ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
322ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
323ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
324ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
325ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
326ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
327ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
328ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
329ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
330ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
331ed9e52f3SAlex Zinenko     });
332436c6c9cSStella Laurenzo   }
333436c6c9cSStella Laurenzo };
334436c6c9cSStella Laurenzo 
335436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
336436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
337436c6c9cSStella Laurenzo public:
338436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
339436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
340436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
341436c6c9cSStella Laurenzo 
342436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
343436c6c9cSStella Laurenzo     c.def_static(
344436c6c9cSStella Laurenzo         "get",
345436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
3463ea4c501SRahul Kayaith           PyMlirContext::ErrorCapture errors(loc->getContext());
347436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
3483ea4c501SRahul Kayaith           if (mlirAttributeIsNull(attr))
3493ea4c501SRahul Kayaith             throw MLIRError("Invalid attribute", errors.take());
350436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
351436c6c9cSStella Laurenzo         },
352436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
353436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
354436c6c9cSStella Laurenzo     c.def_static(
355436c6c9cSStella Laurenzo         "get_f32",
356436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
357436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
358436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
359436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
360436c6c9cSStella Laurenzo         },
361436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
362436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
363436c6c9cSStella Laurenzo     c.def_static(
364436c6c9cSStella Laurenzo         "get_f64",
365436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
366436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
367436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
368436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
369436c6c9cSStella Laurenzo         },
370436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
371436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
372436c6c9cSStella Laurenzo     c.def_property_readonly(
373436c6c9cSStella Laurenzo         "value",
374436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
375436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
376436c6c9cSStella Laurenzo         },
377436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
378436c6c9cSStella Laurenzo   }
379436c6c9cSStella Laurenzo };
380436c6c9cSStella Laurenzo 
381436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
382436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
383436c6c9cSStella Laurenzo public:
384436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
385436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
386436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
387436c6c9cSStella Laurenzo 
388436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
389436c6c9cSStella Laurenzo     c.def_static(
390436c6c9cSStella Laurenzo         "get",
391436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
392436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
393436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
394436c6c9cSStella Laurenzo         },
395436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
396436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
397436c6c9cSStella Laurenzo     c.def_property_readonly(
398436c6c9cSStella Laurenzo         "value",
399e9db306dSrkayaith         [](PyIntegerAttribute &self) -> py::int_ {
400e9db306dSrkayaith           MlirType type = mlirAttributeGetType(self);
401e9db306dSrkayaith           if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
402436c6c9cSStella Laurenzo             return mlirIntegerAttrGetValueInt(self);
403e9db306dSrkayaith           if (mlirIntegerTypeIsSigned(type))
404e9db306dSrkayaith             return mlirIntegerAttrGetValueSInt(self);
405e9db306dSrkayaith           return mlirIntegerAttrGetValueUInt(self);
406436c6c9cSStella Laurenzo         },
407436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
408436c6c9cSStella Laurenzo   }
409436c6c9cSStella Laurenzo };
410436c6c9cSStella Laurenzo 
411436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
412436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
413436c6c9cSStella Laurenzo public:
414436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
415436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
416436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
417436c6c9cSStella Laurenzo 
418436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
419436c6c9cSStella Laurenzo     c.def_static(
420436c6c9cSStella Laurenzo         "get",
421436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
422436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
423436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
424436c6c9cSStella Laurenzo         },
425436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
426436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
427436c6c9cSStella Laurenzo     c.def_property_readonly(
428436c6c9cSStella Laurenzo         "value",
429436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
430436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
431436c6c9cSStella Laurenzo   }
432436c6c9cSStella Laurenzo };
433436c6c9cSStella Laurenzo 
434436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
435436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
436436c6c9cSStella Laurenzo public:
437436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
438436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
439436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
440436c6c9cSStella Laurenzo 
441436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
442436c6c9cSStella Laurenzo     c.def_static(
443436c6c9cSStella Laurenzo         "get",
444436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
445436c6c9cSStella Laurenzo           MlirAttribute attr =
446436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
447436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
448436c6c9cSStella Laurenzo         },
449436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
450436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
451436c6c9cSStella Laurenzo     c.def_property_readonly(
452436c6c9cSStella Laurenzo         "value",
453436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
454436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
455436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
456436c6c9cSStella Laurenzo         },
457436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
458436c6c9cSStella Laurenzo   }
459436c6c9cSStella Laurenzo };
460436c6c9cSStella Laurenzo 
4615c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
4625c3861b2SYun Long public:
4635c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
4645c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
4655c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
4665c3861b2SYun Long 
4675c3861b2SYun Long   static void bindDerived(ClassTy &c) {
4685c3861b2SYun Long     c.def_static(
4695c3861b2SYun Long         "get",
4705c3861b2SYun Long         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
4715c3861b2SYun Long            DefaultingPyMlirContext context) {
4725c3861b2SYun Long           const py::buffer_info bufferInfo = buffer.request();
4735c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
4745c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
4755c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
4765c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
4775c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
4785c3861b2SYun Long         },
4795c3861b2SYun Long         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
4805c3861b2SYun Long         py::arg("context") = py::none(), "Gets an Opaque attribute.");
4815c3861b2SYun Long     c.def_property_readonly(
4825c3861b2SYun Long         "dialect_namespace",
4835c3861b2SYun Long         [](PyOpaqueAttribute &self) {
4845c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
4855c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
4865c3861b2SYun Long         },
4875c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
4885c3861b2SYun Long     c.def_property_readonly(
4895c3861b2SYun Long         "data",
4905c3861b2SYun Long         [](PyOpaqueAttribute &self) {
4915c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
49262bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
4935c3861b2SYun Long         },
49462bf6c2eSChris Jones         "Returns the data for the Opaqued attributes as `bytes`");
4955c3861b2SYun Long   }
4965c3861b2SYun Long };
4975c3861b2SYun Long 
498436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
499436c6c9cSStella Laurenzo public:
500436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
501436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
502436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
503436c6c9cSStella Laurenzo 
504436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
505436c6c9cSStella Laurenzo     c.def_static(
506436c6c9cSStella Laurenzo         "get",
507436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
508436c6c9cSStella Laurenzo           MlirAttribute attr =
509436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
510436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
511436c6c9cSStella Laurenzo         },
512436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
513436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
514436c6c9cSStella Laurenzo     c.def_static(
515436c6c9cSStella Laurenzo         "get_typed",
516436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
517436c6c9cSStella Laurenzo           MlirAttribute attr =
518436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
519436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
520436c6c9cSStella Laurenzo         },
521a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
522436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
523436c6c9cSStella Laurenzo     c.def_property_readonly(
524436c6c9cSStella Laurenzo         "value",
525436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
526436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
527436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
528436c6c9cSStella Laurenzo         },
529436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
53062bf6c2eSChris Jones     c.def_property_readonly(
53162bf6c2eSChris Jones         "value_bytes",
53262bf6c2eSChris Jones         [](PyStringAttribute &self) {
53362bf6c2eSChris Jones           MlirStringRef stringRef = mlirStringAttrGetValue(self);
53462bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
53562bf6c2eSChris Jones         },
53662bf6c2eSChris Jones         "Returns the value of the string attribute as `bytes`");
537436c6c9cSStella Laurenzo   }
538436c6c9cSStella Laurenzo };
539436c6c9cSStella Laurenzo 
540436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
541436c6c9cSStella Laurenzo class PyDenseElementsAttribute
542436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
543436c6c9cSStella Laurenzo public:
544436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
545436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
546436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
547436c6c9cSStella Laurenzo 
548436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
5490a81ace0SKazu Hirata   getFromBuffer(py::buffer array, bool signless,
5500a81ace0SKazu Hirata                 std::optional<PyType> explicitType,
5510a81ace0SKazu Hirata                 std::optional<std::vector<int64_t>> explicitShape,
552436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
553436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
554436c6c9cSStella Laurenzo     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
555436c6c9cSStella Laurenzo     Py_buffer *view = new Py_buffer();
556436c6c9cSStella Laurenzo     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
557436c6c9cSStella Laurenzo       delete view;
558436c6c9cSStella Laurenzo       throw py::error_already_set();
559436c6c9cSStella Laurenzo     }
560436c6c9cSStella Laurenzo     py::buffer_info arrayInfo(view);
5615d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
5625d6d30edSStella Laurenzo     if (explicitShape) {
5635d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
5645d6d30edSStella Laurenzo     } else {
5655d6d30edSStella Laurenzo       shape.append(arrayInfo.shape.begin(),
5665d6d30edSStella Laurenzo                    arrayInfo.shape.begin() + arrayInfo.ndim);
5675d6d30edSStella Laurenzo     }
568436c6c9cSStella Laurenzo 
5695d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
570436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
5715d6d30edSStella Laurenzo 
5725d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
5735d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
5745d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
5755d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
5765d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
5770a81ace0SKazu Hirata     std::optional<MlirType> bulkLoadElementType;
5785d6d30edSStella Laurenzo     if (explicitType) {
5795d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
5805d6d30edSStella Laurenzo     } else if (arrayInfo.format == "f") {
581436c6c9cSStella Laurenzo       // f32
582436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
5835d6d30edSStella Laurenzo       bulkLoadElementType = mlirF32TypeGet(context);
584436c6c9cSStella Laurenzo     } else if (arrayInfo.format == "d") {
585436c6c9cSStella Laurenzo       // f64
586436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
5875d6d30edSStella Laurenzo       bulkLoadElementType = mlirF64TypeGet(context);
5885d6d30edSStella Laurenzo     } else if (arrayInfo.format == "e") {
5895d6d30edSStella Laurenzo       // f16
5905d6d30edSStella Laurenzo       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
5915d6d30edSStella Laurenzo       bulkLoadElementType = mlirF16TypeGet(context);
592436c6c9cSStella Laurenzo     } else if (isSignedIntegerFormat(arrayInfo.format)) {
593436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
594436c6c9cSStella Laurenzo         // i32
5955d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
596436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 32);
597436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
598436c6c9cSStella Laurenzo         // i64
5995d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
600436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 64);
6015d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
6025d6d30edSStella Laurenzo         // i8
6035d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
6045d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 8);
6055d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
6065d6d30edSStella Laurenzo         // i16
6075d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
6085d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 16);
609436c6c9cSStella Laurenzo       }
610436c6c9cSStella Laurenzo     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
611436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
612436c6c9cSStella Laurenzo         // unsigned i32
6135d6d30edSStella Laurenzo         bulkLoadElementType = signless
614436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 32)
615436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 32);
616436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
617436c6c9cSStella Laurenzo         // unsigned i64
6185d6d30edSStella Laurenzo         bulkLoadElementType = signless
619436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 64)
620436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 64);
6215d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
6225d6d30edSStella Laurenzo         // i8
6235d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
6245d6d30edSStella Laurenzo                                        : mlirIntegerTypeUnsignedGet(context, 8);
6255d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
6265d6d30edSStella Laurenzo         // i16
6275d6d30edSStella Laurenzo         bulkLoadElementType = signless
6285d6d30edSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 16)
6295d6d30edSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 16);
630436c6c9cSStella Laurenzo       }
631436c6c9cSStella Laurenzo     }
6325d6d30edSStella Laurenzo     if (bulkLoadElementType) {
63399dee31eSAdam Paszke       MlirType shapedType;
63499dee31eSAdam Paszke       if (mlirTypeIsAShaped(*bulkLoadElementType)) {
63599dee31eSAdam Paszke         if (explicitShape) {
63699dee31eSAdam Paszke           throw std::invalid_argument("Shape can only be specified explicitly "
63799dee31eSAdam Paszke                                       "when the type is not a shaped type.");
63899dee31eSAdam Paszke         }
63999dee31eSAdam Paszke         shapedType = *bulkLoadElementType;
64099dee31eSAdam Paszke       } else {
64199dee31eSAdam Paszke         shapedType = mlirRankedTensorTypeGet(
6425d6d30edSStella Laurenzo             shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
64399dee31eSAdam Paszke       }
6445d6d30edSStella Laurenzo       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
6455d6d30edSStella Laurenzo       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
6465d6d30edSStella Laurenzo           shapedType, rawBufferSize, arrayInfo.ptr);
6475d6d30edSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
6485d6d30edSStella Laurenzo         throw std::invalid_argument(
6495d6d30edSStella Laurenzo             "DenseElementsAttr could not be constructed from the given buffer. "
6505d6d30edSStella Laurenzo             "This may mean that the Python buffer layout does not match that "
6515d6d30edSStella Laurenzo             "MLIR expected layout and is a bug.");
6525d6d30edSStella Laurenzo       }
6535d6d30edSStella Laurenzo       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
6545d6d30edSStella Laurenzo     }
655436c6c9cSStella Laurenzo 
6565d6d30edSStella Laurenzo     throw std::invalid_argument(
6575d6d30edSStella Laurenzo         std::string("unimplemented array format conversion from format: ") +
6585d6d30edSStella Laurenzo         arrayInfo.format);
659436c6c9cSStella Laurenzo   }
660436c6c9cSStella Laurenzo 
6611fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
662436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
663436c6c9cSStella Laurenzo     auto contextWrapper =
664436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
665436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
666436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
667436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
668436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
669*4811270bSmax       throw py::value_error(message);
670436c6c9cSStella Laurenzo     }
671436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
672436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
673436c6c9cSStella Laurenzo       std::string message =
674436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
675436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
676*4811270bSmax       throw py::value_error(message);
677436c6c9cSStella Laurenzo     }
678436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
679436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
680436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
681436c6c9cSStella Laurenzo       std::string message =
682436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
683436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
684436c6c9cSStella Laurenzo       message.append(", element=");
685436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
686*4811270bSmax       throw py::value_error(message);
687436c6c9cSStella Laurenzo     }
688436c6c9cSStella Laurenzo 
689436c6c9cSStella Laurenzo     MlirAttribute elements =
690436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
691436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
692436c6c9cSStella Laurenzo   }
693436c6c9cSStella Laurenzo 
694436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
695436c6c9cSStella Laurenzo 
696436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
697436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
698436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
6995d6d30edSStella Laurenzo     std::string format;
700436c6c9cSStella Laurenzo 
701436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
702436c6c9cSStella Laurenzo       // f32
7035d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
70402b6fb21SMehdi Amini     }
70502b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
706436c6c9cSStella Laurenzo       // f64
7075d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
708bb56c2b3SMehdi Amini     }
709bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
7105d6d30edSStella Laurenzo       // f16
7115d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
712bb56c2b3SMehdi Amini     }
713ef1b735dSmax     if (mlirTypeIsAIndex(elementType)) {
714ef1b735dSmax       // Same as IndexType::kInternalStorageBitWidth
715ef1b735dSmax       return bufferInfo<int64_t>(shapedType);
716ef1b735dSmax     }
717bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
718436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
719436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
720436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
721436c6c9cSStella Laurenzo         // i32
7225d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
723e5639b3fSMehdi Amini       }
724e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
725436c6c9cSStella Laurenzo         // unsigned i32
7265d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
727436c6c9cSStella Laurenzo       }
728436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
729436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
730436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
731436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
732436c6c9cSStella Laurenzo         // i64
7335d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
734e5639b3fSMehdi Amini       }
735e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
736436c6c9cSStella Laurenzo         // unsigned i64
7375d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
7385d6d30edSStella Laurenzo       }
7395d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
7405d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
7415d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
7425d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
7435d6d30edSStella Laurenzo         // i8
7445d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
745e5639b3fSMehdi Amini       }
746e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
7475d6d30edSStella Laurenzo         // unsigned i8
7485d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
7495d6d30edSStella Laurenzo       }
7505d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
7515d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
7525d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
7535d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
7545d6d30edSStella Laurenzo         // i16
7555d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
756e5639b3fSMehdi Amini       }
757e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
7585d6d30edSStella Laurenzo         // unsigned i16
7595d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
760436c6c9cSStella Laurenzo       }
761436c6c9cSStella Laurenzo     }
762436c6c9cSStella Laurenzo 
763c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
7645d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
765c5f445d1SStella Laurenzo     throw std::invalid_argument(
766c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
767436c6c9cSStella Laurenzo   }
768436c6c9cSStella Laurenzo 
769436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
770436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
771436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
772436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
7735d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
774436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
7755d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
776436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
777436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
778436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
779436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
780436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
781436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
782436c6c9cSStella Laurenzo                                })
78391259963SAdam Paszke         .def("get_splat_value",
78491259963SAdam Paszke              [](PyDenseElementsAttribute &self) -> PyAttribute {
78591259963SAdam Paszke                if (!mlirDenseElementsAttrIsSplat(self)) {
786*4811270bSmax                  throw py::value_error(
78791259963SAdam Paszke                      "get_splat_value called on a non-splat attribute");
78891259963SAdam Paszke                }
78991259963SAdam Paszke                return PyAttribute(self.getContext(),
79091259963SAdam Paszke                                   mlirDenseElementsAttrGetSplatValue(self));
79191259963SAdam Paszke              })
792436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
793436c6c9cSStella Laurenzo   }
794436c6c9cSStella Laurenzo 
795436c6c9cSStella Laurenzo private:
796436c6c9cSStella Laurenzo   static bool isUnsignedIntegerFormat(const std::string &format) {
797436c6c9cSStella Laurenzo     if (format.empty())
798436c6c9cSStella Laurenzo       return false;
799436c6c9cSStella Laurenzo     char code = format[0];
800436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
801436c6c9cSStella Laurenzo            code == 'Q';
802436c6c9cSStella Laurenzo   }
803436c6c9cSStella Laurenzo 
804436c6c9cSStella Laurenzo   static bool isSignedIntegerFormat(const std::string &format) {
805436c6c9cSStella Laurenzo     if (format.empty())
806436c6c9cSStella Laurenzo       return false;
807436c6c9cSStella Laurenzo     char code = format[0];
808436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
809436c6c9cSStella Laurenzo            code == 'q';
810436c6c9cSStella Laurenzo   }
811436c6c9cSStella Laurenzo 
812436c6c9cSStella Laurenzo   template <typename Type>
813436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
8145d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
815436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
816436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
817436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
818436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
819436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
820436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
821436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
822436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
823436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
824436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
825436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
826f0e847d0SRahul Kayaith     if (mlirDenseElementsAttrIsSplat(*this)) {
827f0e847d0SRahul Kayaith       // Splats are special, only the single value is stored.
828f0e847d0SRahul Kayaith       strides.assign(rank, 0);
829f0e847d0SRahul Kayaith     } else {
830436c6c9cSStella Laurenzo       for (intptr_t i = 1; i < rank; ++i) {
831f0e847d0SRahul Kayaith         intptr_t strideFactor = 1;
832f0e847d0SRahul Kayaith         for (intptr_t j = i; j < rank; ++j)
833436c6c9cSStella Laurenzo           strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
834436c6c9cSStella Laurenzo         strides.push_back(sizeof(Type) * strideFactor);
835436c6c9cSStella Laurenzo       }
836436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type));
837f0e847d0SRahul Kayaith     }
8385d6d30edSStella Laurenzo     std::string format;
8395d6d30edSStella Laurenzo     if (explicitFormat) {
8405d6d30edSStella Laurenzo       format = explicitFormat;
8415d6d30edSStella Laurenzo     } else {
8425d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
8435d6d30edSStella Laurenzo     }
8445d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
8455d6d30edSStella Laurenzo                            /*readonly=*/true);
846436c6c9cSStella Laurenzo   }
847436c6c9cSStella Laurenzo }; // namespace
848436c6c9cSStella Laurenzo 
849436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
850436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
851436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
852436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
853436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
854436c6c9cSStella Laurenzo public:
855436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
856436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
857436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
858436c6c9cSStella Laurenzo 
859436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
860436c6c9cSStella Laurenzo   /// out of range.
861436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
862436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
863*4811270bSmax       throw py::index_error("attempt to access out of bounds element");
864436c6c9cSStella Laurenzo     }
865436c6c9cSStella Laurenzo 
866436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
867436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
868436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
869436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
870436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
871436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
872436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
873436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
874436c6c9cSStella Laurenzo     // querying them on each element access.
875436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
876436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
877436c6c9cSStella Laurenzo     if (isUnsigned) {
878436c6c9cSStella Laurenzo       if (width == 1) {
879436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
880436c6c9cSStella Laurenzo       }
881308d8b8cSRahul Kayaith       if (width == 8) {
882308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
883308d8b8cSRahul Kayaith       }
884308d8b8cSRahul Kayaith       if (width == 16) {
885308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
886308d8b8cSRahul Kayaith       }
887436c6c9cSStella Laurenzo       if (width == 32) {
888436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
889436c6c9cSStella Laurenzo       }
890436c6c9cSStella Laurenzo       if (width == 64) {
891436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
892436c6c9cSStella Laurenzo       }
893436c6c9cSStella Laurenzo     } else {
894436c6c9cSStella Laurenzo       if (width == 1) {
895436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
896436c6c9cSStella Laurenzo       }
897308d8b8cSRahul Kayaith       if (width == 8) {
898308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
899308d8b8cSRahul Kayaith       }
900308d8b8cSRahul Kayaith       if (width == 16) {
901308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
902308d8b8cSRahul Kayaith       }
903436c6c9cSStella Laurenzo       if (width == 32) {
904436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
905436c6c9cSStella Laurenzo       }
906436c6c9cSStella Laurenzo       if (width == 64) {
907436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
908436c6c9cSStella Laurenzo       }
909436c6c9cSStella Laurenzo     }
910*4811270bSmax     throw py::type_error("Unsupported integer type");
911436c6c9cSStella Laurenzo   }
912436c6c9cSStella Laurenzo 
913436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
914436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
915436c6c9cSStella Laurenzo   }
916436c6c9cSStella Laurenzo };
917436c6c9cSStella Laurenzo 
918436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
919436c6c9cSStella Laurenzo public:
920436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
921436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
922436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
923436c6c9cSStella Laurenzo 
924436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
925436c6c9cSStella Laurenzo 
9269fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
9279fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
9289fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
9299fb1086bSAdrian Kuegel   }
9309fb1086bSAdrian Kuegel 
931436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
9329fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
933436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
934436c6c9cSStella Laurenzo     c.def_static(
935436c6c9cSStella Laurenzo         "get",
936436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
937436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
938436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
939436c6c9cSStella Laurenzo           for (auto &it : attributes) {
94002b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
941436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
942436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
94302b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
944436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
94502b6fb21SMehdi Amini                 mlirAttr));
946436c6c9cSStella Laurenzo           }
947436c6c9cSStella Laurenzo           MlirAttribute attr =
948436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
949436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
950436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
951436c6c9cSStella Laurenzo         },
952ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
953436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
954436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
955436c6c9cSStella Laurenzo       MlirAttribute attr =
956436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
957436c6c9cSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
958*4811270bSmax         throw py::key_error("attempt to access a non-existent attribute");
959436c6c9cSStella Laurenzo       }
960436c6c9cSStella Laurenzo       return PyAttribute(self.getContext(), attr);
961436c6c9cSStella Laurenzo     });
962436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
963436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
964*4811270bSmax         throw py::index_error("attempt to access out of bounds attribute");
965436c6c9cSStella Laurenzo       }
966436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
967436c6c9cSStella Laurenzo       return PyNamedAttribute(
968436c6c9cSStella Laurenzo           namedAttr.attribute,
969436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
970436c6c9cSStella Laurenzo     });
971436c6c9cSStella Laurenzo   }
972436c6c9cSStella Laurenzo };
973436c6c9cSStella Laurenzo 
974436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
975436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
976436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
977436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
978436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
979436c6c9cSStella Laurenzo public:
980436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
981436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
982436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
983436c6c9cSStella Laurenzo 
984436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
985436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
986*4811270bSmax       throw py::index_error("attempt to access out of bounds element");
987436c6c9cSStella Laurenzo     }
988436c6c9cSStella Laurenzo 
989436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
990436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
991436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
992436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
993436c6c9cSStella Laurenzo     // from float and double.
994436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
995436c6c9cSStella Laurenzo     // querying them on each element access.
996436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
997436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
998436c6c9cSStella Laurenzo     }
999436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
1000436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1001436c6c9cSStella Laurenzo     }
1002*4811270bSmax     throw py::type_error("Unsupported floating-point type");
1003436c6c9cSStella Laurenzo   }
1004436c6c9cSStella Laurenzo 
1005436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1006436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1007436c6c9cSStella Laurenzo   }
1008436c6c9cSStella Laurenzo };
1009436c6c9cSStella Laurenzo 
1010436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1011436c6c9cSStella Laurenzo public:
1012436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1013436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
1014436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1015436c6c9cSStella Laurenzo 
1016436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1017436c6c9cSStella Laurenzo     c.def_static(
1018436c6c9cSStella Laurenzo         "get",
1019436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1020436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1021436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1022436c6c9cSStella Laurenzo         },
1023436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
1024436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1025436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1026436c6c9cSStella Laurenzo       return PyType(self.getContext()->getRef(),
1027436c6c9cSStella Laurenzo                     mlirTypeAttrGetValue(self.get()));
1028436c6c9cSStella Laurenzo     });
1029436c6c9cSStella Laurenzo   }
1030436c6c9cSStella Laurenzo };
1031436c6c9cSStella Laurenzo 
1032436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1033436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1034436c6c9cSStella Laurenzo public:
1035436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1036436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1037436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1038436c6c9cSStella Laurenzo 
1039436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1040436c6c9cSStella Laurenzo     c.def_static(
1041436c6c9cSStella Laurenzo         "get",
1042436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1043436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1044436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1045436c6c9cSStella Laurenzo         },
1046436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
1047436c6c9cSStella Laurenzo   }
1048436c6c9cSStella Laurenzo };
1049436c6c9cSStella Laurenzo 
1050ac2e2d65SDenys Shabalin /// Strided layout attribute subclass.
1051ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute
1052ac2e2d65SDenys Shabalin     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1053ac2e2d65SDenys Shabalin public:
1054ac2e2d65SDenys Shabalin   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1055ac2e2d65SDenys Shabalin   static constexpr const char *pyClassName = "StridedLayoutAttr";
1056ac2e2d65SDenys Shabalin   using PyConcreteAttribute::PyConcreteAttribute;
1057ac2e2d65SDenys Shabalin 
1058ac2e2d65SDenys Shabalin   static void bindDerived(ClassTy &c) {
1059ac2e2d65SDenys Shabalin     c.def_static(
1060ac2e2d65SDenys Shabalin         "get",
1061ac2e2d65SDenys Shabalin         [](int64_t offset, const std::vector<int64_t> strides,
1062ac2e2d65SDenys Shabalin            DefaultingPyMlirContext ctx) {
1063ac2e2d65SDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1064ac2e2d65SDenys Shabalin               ctx->get(), offset, strides.size(), strides.data());
1065ac2e2d65SDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1066ac2e2d65SDenys Shabalin         },
1067ac2e2d65SDenys Shabalin         py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1068ac2e2d65SDenys Shabalin         "Gets a strided layout attribute.");
1069e3fd612eSDenys Shabalin     c.def_static(
1070e3fd612eSDenys Shabalin         "get_fully_dynamic",
1071e3fd612eSDenys Shabalin         [](int64_t rank, DefaultingPyMlirContext ctx) {
1072e3fd612eSDenys Shabalin           auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1073e3fd612eSDenys Shabalin           std::vector<int64_t> strides(rank);
1074e3fd612eSDenys Shabalin           std::fill(strides.begin(), strides.end(), dynamic);
1075e3fd612eSDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1076e3fd612eSDenys Shabalin               ctx->get(), dynamic, strides.size(), strides.data());
1077e3fd612eSDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1078e3fd612eSDenys Shabalin         },
1079e3fd612eSDenys Shabalin         py::arg("rank"), py::arg("context") = py::none(),
1080e3fd612eSDenys Shabalin         "Gets a strided layout attribute with dynamic offset and strides of a "
1081e3fd612eSDenys Shabalin         "given rank.");
1082ac2e2d65SDenys Shabalin     c.def_property_readonly(
1083ac2e2d65SDenys Shabalin         "offset",
1084ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1085ac2e2d65SDenys Shabalin           return mlirStridedLayoutAttrGetOffset(self);
1086ac2e2d65SDenys Shabalin         },
1087ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1088ac2e2d65SDenys Shabalin     c.def_property_readonly(
1089ac2e2d65SDenys Shabalin         "strides",
1090ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1091ac2e2d65SDenys Shabalin           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1092ac2e2d65SDenys Shabalin           std::vector<int64_t> strides(size);
1093ac2e2d65SDenys Shabalin           for (intptr_t i = 0; i < size; i++) {
1094ac2e2d65SDenys Shabalin             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1095ac2e2d65SDenys Shabalin           }
1096ac2e2d65SDenys Shabalin           return strides;
1097ac2e2d65SDenys Shabalin         },
1098ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1099ac2e2d65SDenys Shabalin   }
1100ac2e2d65SDenys Shabalin };
1101ac2e2d65SDenys Shabalin 
1102436c6c9cSStella Laurenzo } // namespace
1103436c6c9cSStella Laurenzo 
1104436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
1105436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1106619fd8c2SJeff Niu 
1107619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1108619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1109619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1110619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1111619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1112619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1113619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1114619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1115619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1116619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1117619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1118619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1119619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1120619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
1121619fd8c2SJeff Niu 
1122436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1123436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1124436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1125436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
1126436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1127436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
1128436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
1129436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
11305c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1131436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1132436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1133436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1134436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
1135436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1136ac2e2d65SDenys Shabalin 
1137ac2e2d65SDenys Shabalin   PyStridedLayoutAttribute::bind(m);
1138436c6c9cSStella Laurenzo }
1139