xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision ac2e2d6598191d6ffc31127b80d8cba10d00b765)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
91fc096afSMehdi Amini #include <utility>
101fc096afSMehdi Amini 
11436c6c9cSStella Laurenzo #include "IRModule.h"
12436c6c9cSStella Laurenzo 
13436c6c9cSStella Laurenzo #include "PybindUtils.h"
14436c6c9cSStella Laurenzo 
15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
17436c6c9cSStella Laurenzo 
18436c6c9cSStella Laurenzo namespace py = pybind11;
19436c6c9cSStella Laurenzo using namespace mlir;
20436c6c9cSStella Laurenzo using namespace mlir::python;
21436c6c9cSStella Laurenzo 
225d6d30edSStella Laurenzo using llvm::Optional;
23436c6c9cSStella Laurenzo using llvm::SmallVector;
24436c6c9cSStella Laurenzo using llvm::Twine;
25436c6c9cSStella Laurenzo 
265d6d30edSStella Laurenzo //------------------------------------------------------------------------------
275d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
285d6d30edSStella Laurenzo //------------------------------------------------------------------------------
295d6d30edSStella Laurenzo 
// Help text for the DenseElementsAttr `get` factory. The binding that attaches
// it lives outside this excerpt (presumably further down in this file).
static constexpr char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
715d6d30edSStella Laurenzo 
72436c6c9cSStella Laurenzo namespace {
73436c6c9cSStella Laurenzo 
74436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
75436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
76436c6c9cSStella Laurenzo }
77436c6c9cSStella Laurenzo 
78436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
79436c6c9cSStella Laurenzo public:
80436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
81436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
82436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
83436c6c9cSStella Laurenzo 
84436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
85436c6c9cSStella Laurenzo     c.def_static(
86436c6c9cSStella Laurenzo         "get",
87436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
88436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
89436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
90436c6c9cSStella Laurenzo         },
91436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
92436c6c9cSStella Laurenzo   }
93436c6c9cSStella Laurenzo };
94436c6c9cSStella Laurenzo 
95ed9e52f3SAlex Zinenko template <typename T>
96ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
97ed9e52f3SAlex Zinenko   try {
98ed9e52f3SAlex Zinenko     return object.cast<T>();
99ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
100ed9e52f3SAlex Zinenko     std::string msg =
101ed9e52f3SAlex Zinenko         std::string(
102ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
103ed9e52f3SAlex Zinenko         err.what() + ")";
104ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
105ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
106ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
107ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
108ed9e52f3SAlex Zinenko                       err.what() + ")";
109ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
110ed9e52f3SAlex Zinenko   }
111ed9e52f3SAlex Zinenko }
112ed9e52f3SAlex Zinenko 
113619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
114619fd8c2SJeff Niu /// implementation class.
115619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
116133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
117619fd8c2SJeff Niu public:
118133624acSJeff Niu   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
119619fd8c2SJeff Niu 
120619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
121619fd8c2SJeff Niu   class PyDenseArrayIterator {
122619fd8c2SJeff Niu   public:
1234a1b1196SMehdi Amini     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
124619fd8c2SJeff Niu 
125619fd8c2SJeff Niu     /// Return a copy of the iterator.
126619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
127619fd8c2SJeff Niu 
128619fd8c2SJeff Niu     /// Return the next element.
129619fd8c2SJeff Niu     EltTy dunderNext() {
130619fd8c2SJeff Niu       // Throw if the index has reached the end.
131619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
132619fd8c2SJeff Niu         throw py::stop_iteration();
133619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
134619fd8c2SJeff Niu     }
135619fd8c2SJeff Niu 
136619fd8c2SJeff Niu     /// Bind the iterator class.
137619fd8c2SJeff Niu     static void bind(py::module &m) {
138619fd8c2SJeff Niu       py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
139619fd8c2SJeff Niu                                        py::module_local())
140619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
141619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
142619fd8c2SJeff Niu     }
143619fd8c2SJeff Niu 
144619fd8c2SJeff Niu   private:
145619fd8c2SJeff Niu     /// The referenced dense array attribute.
146619fd8c2SJeff Niu     PyAttribute attr;
147619fd8c2SJeff Niu     /// The next index to read.
148619fd8c2SJeff Niu     int nextIndex = 0;
149619fd8c2SJeff Niu   };
150619fd8c2SJeff Niu 
151619fd8c2SJeff Niu   /// Get the element at the given index.
152619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
153619fd8c2SJeff Niu 
154619fd8c2SJeff Niu   /// Bind the attribute class.
155133624acSJeff Niu   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
156619fd8c2SJeff Niu     // Bind the constructor.
157619fd8c2SJeff Niu     c.def_static(
158619fd8c2SJeff Niu         "get",
159619fd8c2SJeff Niu         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
160619fd8c2SJeff Niu           MlirAttribute attr =
161619fd8c2SJeff Niu               DerivedT::getAttribute(ctx->get(), values.size(), values.data());
162133624acSJeff Niu           return DerivedT(ctx->getRef(), attr);
163619fd8c2SJeff Niu         },
164619fd8c2SJeff Niu         py::arg("values"), py::arg("context") = py::none(),
165619fd8c2SJeff Niu         "Gets a uniqued dense array attribute");
166619fd8c2SJeff Niu     // Bind the array methods.
167133624acSJeff Niu     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
168619fd8c2SJeff Niu       if (i >= mlirDenseArrayGetNumElements(arr))
169619fd8c2SJeff Niu         throw py::index_error("DenseArray index out of range");
170619fd8c2SJeff Niu       return arr.getItem(i);
171619fd8c2SJeff Niu     });
172133624acSJeff Niu     c.def("__len__", [](const DerivedT &arr) {
173619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
174619fd8c2SJeff Niu     });
175133624acSJeff Niu     c.def("__iter__",
176133624acSJeff Niu           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
1774a1b1196SMehdi Amini     c.def("__add__", [](DerivedT &arr, const py::list &extras) {
178619fd8c2SJeff Niu       std::vector<EltTy> values;
179619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
180619fd8c2SJeff Niu       values.reserve(numOldElements + py::len(extras));
181619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
182619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
183619fd8c2SJeff Niu       for (py::handle attr : extras)
184619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
185619fd8c2SJeff Niu       MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(),
186619fd8c2SJeff Niu                                                   values.size(), values.data());
187133624acSJeff Niu       return DerivedT(arr.getContext(), attr);
188619fd8c2SJeff Niu     });
189619fd8c2SJeff Niu   }
190619fd8c2SJeff Niu };
191619fd8c2SJeff Niu 
192619fd8c2SJeff Niu /// Instantiate the python dense array classes.
193619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
194619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> {
195619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
196619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
197619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
198619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
199619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
200619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
201619fd8c2SJeff Niu };
202619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
203619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
204619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
205619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
206619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
207619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
208619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
209619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
210619fd8c2SJeff Niu };
211619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
212619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
213619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
214619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
215619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
216619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
217619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
218619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
219619fd8c2SJeff Niu };
220619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
221619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
222619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
223619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
224619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
225619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
226619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
227619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
228619fd8c2SJeff Niu };
229619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
230619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
231619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
232619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
233619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
234619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
235619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
236619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
237619fd8c2SJeff Niu };
238619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
239619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
240619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
241619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
242619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
243619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
244619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
245619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
246619fd8c2SJeff Niu };
247619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
248619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
249619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
250619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
251619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
252619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
253619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
254619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
255619fd8c2SJeff Niu };
256619fd8c2SJeff Niu 
257436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
258436c6c9cSStella Laurenzo public:
259436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
260436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
261436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
262436c6c9cSStella Laurenzo 
263436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
264436c6c9cSStella Laurenzo   public:
2651fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
266436c6c9cSStella Laurenzo 
267436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
268436c6c9cSStella Laurenzo 
269436c6c9cSStella Laurenzo     PyAttribute dunderNext() {
270bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
271bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
272436c6c9cSStella Laurenzo         throw py::stop_iteration();
273436c6c9cSStella Laurenzo       return PyAttribute(attr.getContext(),
274436c6c9cSStella Laurenzo                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
275436c6c9cSStella Laurenzo     }
276436c6c9cSStella Laurenzo 
277436c6c9cSStella Laurenzo     static void bind(py::module &m) {
278f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
279f05ff4f7SStella Laurenzo                                            py::module_local())
280436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
281436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
282436c6c9cSStella Laurenzo     }
283436c6c9cSStella Laurenzo 
284436c6c9cSStella Laurenzo   private:
285436c6c9cSStella Laurenzo     PyAttribute attr;
286436c6c9cSStella Laurenzo     int nextIndex = 0;
287436c6c9cSStella Laurenzo   };
288436c6c9cSStella Laurenzo 
289ed9e52f3SAlex Zinenko   PyAttribute getItem(intptr_t i) {
290ed9e52f3SAlex Zinenko     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
291ed9e52f3SAlex Zinenko   }
292ed9e52f3SAlex Zinenko 
293436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
294436c6c9cSStella Laurenzo     c.def_static(
295436c6c9cSStella Laurenzo         "get",
296436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
297436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
298436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
299436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
300ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
301436c6c9cSStella Laurenzo           }
302436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
303436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
304436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
305436c6c9cSStella Laurenzo         },
306436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
307436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
308436c6c9cSStella Laurenzo     c.def("__getitem__",
309436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
310436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
311436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
312ed9e52f3SAlex Zinenko             return arr.getItem(i);
313436c6c9cSStella Laurenzo           })
314436c6c9cSStella Laurenzo         .def("__len__",
315436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
316436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
317436c6c9cSStella Laurenzo              })
318436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
319436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
320436c6c9cSStella Laurenzo         });
321ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
322ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
323ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
324ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
325ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
326ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
327ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
328ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
329ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
330ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
331ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
332ed9e52f3SAlex Zinenko     });
333436c6c9cSStella Laurenzo   }
334436c6c9cSStella Laurenzo };
335436c6c9cSStella Laurenzo 
336436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
337436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
338436c6c9cSStella Laurenzo public:
339436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
340436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
341436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
342436c6c9cSStella Laurenzo 
343436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
344436c6c9cSStella Laurenzo     c.def_static(
345436c6c9cSStella Laurenzo         "get",
346436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
347436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
348436c6c9cSStella Laurenzo           // TODO: Rework error reporting once diagnostic engine is exposed
349436c6c9cSStella Laurenzo           // in C API.
350436c6c9cSStella Laurenzo           if (mlirAttributeIsNull(attr)) {
351436c6c9cSStella Laurenzo             throw SetPyError(PyExc_ValueError,
352436c6c9cSStella Laurenzo                              Twine("invalid '") +
353436c6c9cSStella Laurenzo                                  py::repr(py::cast(type)).cast<std::string>() +
354436c6c9cSStella Laurenzo                                  "' and expected floating point type.");
355436c6c9cSStella Laurenzo           }
356436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
357436c6c9cSStella Laurenzo         },
358436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
359436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
360436c6c9cSStella Laurenzo     c.def_static(
361436c6c9cSStella Laurenzo         "get_f32",
362436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
363436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
364436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
365436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
366436c6c9cSStella Laurenzo         },
367436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
368436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
369436c6c9cSStella Laurenzo     c.def_static(
370436c6c9cSStella Laurenzo         "get_f64",
371436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
372436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
373436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
374436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
375436c6c9cSStella Laurenzo         },
376436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
377436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
378436c6c9cSStella Laurenzo     c.def_property_readonly(
379436c6c9cSStella Laurenzo         "value",
380436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
381436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
382436c6c9cSStella Laurenzo         },
383436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
384436c6c9cSStella Laurenzo   }
385436c6c9cSStella Laurenzo };
386436c6c9cSStella Laurenzo 
387436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
388436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
389436c6c9cSStella Laurenzo public:
390436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
391436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
392436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
393436c6c9cSStella Laurenzo 
394436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
395436c6c9cSStella Laurenzo     c.def_static(
396436c6c9cSStella Laurenzo         "get",
397436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
398436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
399436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
400436c6c9cSStella Laurenzo         },
401436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
402436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
403436c6c9cSStella Laurenzo     c.def_property_readonly(
404436c6c9cSStella Laurenzo         "value",
405e9db306dSrkayaith         [](PyIntegerAttribute &self) -> py::int_ {
406e9db306dSrkayaith           MlirType type = mlirAttributeGetType(self);
407e9db306dSrkayaith           if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
408436c6c9cSStella Laurenzo             return mlirIntegerAttrGetValueInt(self);
409e9db306dSrkayaith           if (mlirIntegerTypeIsSigned(type))
410e9db306dSrkayaith             return mlirIntegerAttrGetValueSInt(self);
411e9db306dSrkayaith           return mlirIntegerAttrGetValueUInt(self);
412436c6c9cSStella Laurenzo         },
413436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
414436c6c9cSStella Laurenzo   }
415436c6c9cSStella Laurenzo };
416436c6c9cSStella Laurenzo 
417436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
418436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
419436c6c9cSStella Laurenzo public:
420436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
421436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
422436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
423436c6c9cSStella Laurenzo 
424436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
425436c6c9cSStella Laurenzo     c.def_static(
426436c6c9cSStella Laurenzo         "get",
427436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
428436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
429436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
430436c6c9cSStella Laurenzo         },
431436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
432436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
433436c6c9cSStella Laurenzo     c.def_property_readonly(
434436c6c9cSStella Laurenzo         "value",
435436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
436436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
437436c6c9cSStella Laurenzo   }
438436c6c9cSStella Laurenzo };
439436c6c9cSStella Laurenzo 
440436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
441436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
442436c6c9cSStella Laurenzo public:
443436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
444436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
445436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
446436c6c9cSStella Laurenzo 
447436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
448436c6c9cSStella Laurenzo     c.def_static(
449436c6c9cSStella Laurenzo         "get",
450436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
451436c6c9cSStella Laurenzo           MlirAttribute attr =
452436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
453436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
454436c6c9cSStella Laurenzo         },
455436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
456436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
457436c6c9cSStella Laurenzo     c.def_property_readonly(
458436c6c9cSStella Laurenzo         "value",
459436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
460436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
461436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
462436c6c9cSStella Laurenzo         },
463436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
464436c6c9cSStella Laurenzo   }
465436c6c9cSStella Laurenzo };
466436c6c9cSStella Laurenzo 
4675c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
4685c3861b2SYun Long public:
4695c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
4705c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
4715c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
4725c3861b2SYun Long 
4735c3861b2SYun Long   static void bindDerived(ClassTy &c) {
4745c3861b2SYun Long     c.def_static(
4755c3861b2SYun Long         "get",
4765c3861b2SYun Long         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
4775c3861b2SYun Long            DefaultingPyMlirContext context) {
4785c3861b2SYun Long           const py::buffer_info bufferInfo = buffer.request();
4795c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
4805c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
4815c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
4825c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
4835c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
4845c3861b2SYun Long         },
4855c3861b2SYun Long         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
4865c3861b2SYun Long         py::arg("context") = py::none(), "Gets an Opaque attribute.");
4875c3861b2SYun Long     c.def_property_readonly(
4885c3861b2SYun Long         "dialect_namespace",
4895c3861b2SYun Long         [](PyOpaqueAttribute &self) {
4905c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
4915c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
4925c3861b2SYun Long         },
4935c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
4945c3861b2SYun Long     c.def_property_readonly(
4955c3861b2SYun Long         "data",
4965c3861b2SYun Long         [](PyOpaqueAttribute &self) {
4975c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
4985c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
4995c3861b2SYun Long         },
5005c3861b2SYun Long         "Returns the data for the Opaqued attributes as a string");
5015c3861b2SYun Long   }
5025c3861b2SYun Long };
5035c3861b2SYun Long 
504436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
505436c6c9cSStella Laurenzo public:
506436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
507436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
508436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
509436c6c9cSStella Laurenzo 
510436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
511436c6c9cSStella Laurenzo     c.def_static(
512436c6c9cSStella Laurenzo         "get",
513436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
514436c6c9cSStella Laurenzo           MlirAttribute attr =
515436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
516436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
517436c6c9cSStella Laurenzo         },
518436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
519436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
520436c6c9cSStella Laurenzo     c.def_static(
521436c6c9cSStella Laurenzo         "get_typed",
522436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
523436c6c9cSStella Laurenzo           MlirAttribute attr =
524436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
525436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
526436c6c9cSStella Laurenzo         },
527a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
528436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
529436c6c9cSStella Laurenzo     c.def_property_readonly(
530436c6c9cSStella Laurenzo         "value",
531436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
532436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
533436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
534436c6c9cSStella Laurenzo         },
535436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
536436c6c9cSStella Laurenzo   }
537436c6c9cSStella Laurenzo };
538436c6c9cSStella Laurenzo 
539436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
540436c6c9cSStella Laurenzo class PyDenseElementsAttribute
541436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
542436c6c9cSStella Laurenzo public:
543436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
544436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
545436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
546436c6c9cSStella Laurenzo 
547436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
5485d6d30edSStella Laurenzo   getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
5495d6d30edSStella Laurenzo                 Optional<std::vector<int64_t>> explicitShape,
550436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
551436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
552436c6c9cSStella Laurenzo     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
553436c6c9cSStella Laurenzo     Py_buffer *view = new Py_buffer();
554436c6c9cSStella Laurenzo     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
555436c6c9cSStella Laurenzo       delete view;
556436c6c9cSStella Laurenzo       throw py::error_already_set();
557436c6c9cSStella Laurenzo     }
558436c6c9cSStella Laurenzo     py::buffer_info arrayInfo(view);
5595d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
5605d6d30edSStella Laurenzo     if (explicitShape) {
5615d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
5625d6d30edSStella Laurenzo     } else {
5635d6d30edSStella Laurenzo       shape.append(arrayInfo.shape.begin(),
5645d6d30edSStella Laurenzo                    arrayInfo.shape.begin() + arrayInfo.ndim);
5655d6d30edSStella Laurenzo     }
566436c6c9cSStella Laurenzo 
5675d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
568436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
5695d6d30edSStella Laurenzo 
5705d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
5715d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
5725d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
5735d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
5745d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
5755d6d30edSStella Laurenzo     Optional<MlirType> bulkLoadElementType;
5765d6d30edSStella Laurenzo     if (explicitType) {
5775d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
5785d6d30edSStella Laurenzo     } else if (arrayInfo.format == "f") {
579436c6c9cSStella Laurenzo       // f32
580436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
5815d6d30edSStella Laurenzo       bulkLoadElementType = mlirF32TypeGet(context);
582436c6c9cSStella Laurenzo     } else if (arrayInfo.format == "d") {
583436c6c9cSStella Laurenzo       // f64
584436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
5855d6d30edSStella Laurenzo       bulkLoadElementType = mlirF64TypeGet(context);
5865d6d30edSStella Laurenzo     } else if (arrayInfo.format == "e") {
5875d6d30edSStella Laurenzo       // f16
5885d6d30edSStella Laurenzo       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
5895d6d30edSStella Laurenzo       bulkLoadElementType = mlirF16TypeGet(context);
590436c6c9cSStella Laurenzo     } else if (isSignedIntegerFormat(arrayInfo.format)) {
591436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
592436c6c9cSStella Laurenzo         // i32
5935d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
594436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 32);
595436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
596436c6c9cSStella Laurenzo         // i64
5975d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
598436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 64);
5995d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
6005d6d30edSStella Laurenzo         // i8
6015d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
6025d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 8);
6035d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
6045d6d30edSStella Laurenzo         // i16
6055d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
6065d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 16);
607436c6c9cSStella Laurenzo       }
608436c6c9cSStella Laurenzo     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
609436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
610436c6c9cSStella Laurenzo         // unsigned i32
6115d6d30edSStella Laurenzo         bulkLoadElementType = signless
612436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 32)
613436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 32);
614436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
615436c6c9cSStella Laurenzo         // unsigned i64
6165d6d30edSStella Laurenzo         bulkLoadElementType = signless
617436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 64)
618436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 64);
6195d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
6205d6d30edSStella Laurenzo         // i8
6215d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
6225d6d30edSStella Laurenzo                                        : mlirIntegerTypeUnsignedGet(context, 8);
6235d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
6245d6d30edSStella Laurenzo         // i16
6255d6d30edSStella Laurenzo         bulkLoadElementType = signless
6265d6d30edSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 16)
6275d6d30edSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 16);
628436c6c9cSStella Laurenzo       }
629436c6c9cSStella Laurenzo     }
6305d6d30edSStella Laurenzo     if (bulkLoadElementType) {
6315d6d30edSStella Laurenzo       auto shapedType = mlirRankedTensorTypeGet(
6325d6d30edSStella Laurenzo           shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
6335d6d30edSStella Laurenzo       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
6345d6d30edSStella Laurenzo       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
6355d6d30edSStella Laurenzo           shapedType, rawBufferSize, arrayInfo.ptr);
6365d6d30edSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
6375d6d30edSStella Laurenzo         throw std::invalid_argument(
6385d6d30edSStella Laurenzo             "DenseElementsAttr could not be constructed from the given buffer. "
6395d6d30edSStella Laurenzo             "This may mean that the Python buffer layout does not match that "
6405d6d30edSStella Laurenzo             "MLIR expected layout and is a bug.");
6415d6d30edSStella Laurenzo       }
6425d6d30edSStella Laurenzo       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
6435d6d30edSStella Laurenzo     }
644436c6c9cSStella Laurenzo 
6455d6d30edSStella Laurenzo     throw std::invalid_argument(
6465d6d30edSStella Laurenzo         std::string("unimplemented array format conversion from format: ") +
6475d6d30edSStella Laurenzo         arrayInfo.format);
648436c6c9cSStella Laurenzo   }
649436c6c9cSStella Laurenzo 
6501fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
651436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
652436c6c9cSStella Laurenzo     auto contextWrapper =
653436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
654436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
655436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
656436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
657436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
658436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
659436c6c9cSStella Laurenzo     }
660436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
661436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
662436c6c9cSStella Laurenzo       std::string message =
663436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
664436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
665436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
666436c6c9cSStella Laurenzo     }
667436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
668436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
669436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
670436c6c9cSStella Laurenzo       std::string message =
671436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
672436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
673436c6c9cSStella Laurenzo       message.append(", element=");
674436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
675436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
676436c6c9cSStella Laurenzo     }
677436c6c9cSStella Laurenzo 
678436c6c9cSStella Laurenzo     MlirAttribute elements =
679436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
680436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
681436c6c9cSStella Laurenzo   }
682436c6c9cSStella Laurenzo 
683436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
684436c6c9cSStella Laurenzo 
685436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
6865d6d30edSStella Laurenzo     if (mlirDenseElementsAttrIsSplat(*this)) {
687c5f445d1SStella Laurenzo       // TODO: Currently crashes the program.
6885d6d30edSStella Laurenzo       // Reported as https://github.com/pybind/pybind11/issues/3336
689c5f445d1SStella Laurenzo       throw std::invalid_argument(
690c5f445d1SStella Laurenzo           "unsupported data type for conversion to Python buffer");
6915d6d30edSStella Laurenzo     }
6925d6d30edSStella Laurenzo 
693436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
694436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
6955d6d30edSStella Laurenzo     std::string format;
696436c6c9cSStella Laurenzo 
697436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
698436c6c9cSStella Laurenzo       // f32
6995d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
70002b6fb21SMehdi Amini     }
70102b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
702436c6c9cSStella Laurenzo       // f64
7035d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
704bb56c2b3SMehdi Amini     }
705bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
7065d6d30edSStella Laurenzo       // f16
7075d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
708bb56c2b3SMehdi Amini     }
709bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
710436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
711436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
712436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
713436c6c9cSStella Laurenzo         // i32
7145d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
715e5639b3fSMehdi Amini       }
716e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
717436c6c9cSStella Laurenzo         // unsigned i32
7185d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
719436c6c9cSStella Laurenzo       }
720436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
721436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
722436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
723436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
724436c6c9cSStella Laurenzo         // i64
7255d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
726e5639b3fSMehdi Amini       }
727e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
728436c6c9cSStella Laurenzo         // unsigned i64
7295d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
7305d6d30edSStella Laurenzo       }
7315d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
7325d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
7335d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
7345d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
7355d6d30edSStella Laurenzo         // i8
7365d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
737e5639b3fSMehdi Amini       }
738e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
7395d6d30edSStella Laurenzo         // unsigned i8
7405d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
7415d6d30edSStella Laurenzo       }
7425d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
7435d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
7445d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
7455d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
7465d6d30edSStella Laurenzo         // i16
7475d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
748e5639b3fSMehdi Amini       }
749e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
7505d6d30edSStella Laurenzo         // unsigned i16
7515d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
752436c6c9cSStella Laurenzo       }
753436c6c9cSStella Laurenzo     }
754436c6c9cSStella Laurenzo 
755c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
7565d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
757c5f445d1SStella Laurenzo     throw std::invalid_argument(
758c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
759436c6c9cSStella Laurenzo   }
760436c6c9cSStella Laurenzo 
761436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
762436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
763436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
764436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
7655d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
766436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
7675d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
768436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
769436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
770436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
771436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
772436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
773436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
774436c6c9cSStella Laurenzo                                })
775436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
776436c6c9cSStella Laurenzo   }
777436c6c9cSStella Laurenzo 
778436c6c9cSStella Laurenzo private:
779436c6c9cSStella Laurenzo   static bool isUnsignedIntegerFormat(const std::string &format) {
780436c6c9cSStella Laurenzo     if (format.empty())
781436c6c9cSStella Laurenzo       return false;
782436c6c9cSStella Laurenzo     char code = format[0];
783436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
784436c6c9cSStella Laurenzo            code == 'Q';
785436c6c9cSStella Laurenzo   }
786436c6c9cSStella Laurenzo 
787436c6c9cSStella Laurenzo   static bool isSignedIntegerFormat(const std::string &format) {
788436c6c9cSStella Laurenzo     if (format.empty())
789436c6c9cSStella Laurenzo       return false;
790436c6c9cSStella Laurenzo     char code = format[0];
791436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
792436c6c9cSStella Laurenzo            code == 'q';
793436c6c9cSStella Laurenzo   }
794436c6c9cSStella Laurenzo 
795436c6c9cSStella Laurenzo   template <typename Type>
796436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
7975d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
798436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
799436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
800436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
801436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
802436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
803436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
804436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
805436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
806436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
807436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
808436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
809436c6c9cSStella Laurenzo     intptr_t strideFactor = 1;
810436c6c9cSStella Laurenzo     for (intptr_t i = 1; i < rank; ++i) {
811436c6c9cSStella Laurenzo       strideFactor = 1;
812436c6c9cSStella Laurenzo       for (intptr_t j = i; j < rank; ++j) {
813436c6c9cSStella Laurenzo         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
814436c6c9cSStella Laurenzo       }
815436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type) * strideFactor);
816436c6c9cSStella Laurenzo     }
817436c6c9cSStella Laurenzo     strides.push_back(sizeof(Type));
8185d6d30edSStella Laurenzo     std::string format;
8195d6d30edSStella Laurenzo     if (explicitFormat) {
8205d6d30edSStella Laurenzo       format = explicitFormat;
8215d6d30edSStella Laurenzo     } else {
8225d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
8235d6d30edSStella Laurenzo     }
8245d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
8255d6d30edSStella Laurenzo                            /*readonly=*/true);
826436c6c9cSStella Laurenzo   }
827436c6c9cSStella Laurenzo }; // namespace
828436c6c9cSStella Laurenzo 
829436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
830436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
831436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
832436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
833436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
834436c6c9cSStella Laurenzo public:
835436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
836436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
837436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
838436c6c9cSStella Laurenzo 
839436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
840436c6c9cSStella Laurenzo   /// out of range.
841436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
842436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
843436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
844436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
845436c6c9cSStella Laurenzo     }
846436c6c9cSStella Laurenzo 
847436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
848436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
849436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
850436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
851436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
852436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
853436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
854436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
855436c6c9cSStella Laurenzo     // querying them on each element access.
856436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
857436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
858436c6c9cSStella Laurenzo     if (isUnsigned) {
859436c6c9cSStella Laurenzo       if (width == 1) {
860436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
861436c6c9cSStella Laurenzo       }
862308d8b8cSRahul Kayaith       if (width == 8) {
863308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
864308d8b8cSRahul Kayaith       }
865308d8b8cSRahul Kayaith       if (width == 16) {
866308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
867308d8b8cSRahul Kayaith       }
868436c6c9cSStella Laurenzo       if (width == 32) {
869436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
870436c6c9cSStella Laurenzo       }
871436c6c9cSStella Laurenzo       if (width == 64) {
872436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
873436c6c9cSStella Laurenzo       }
874436c6c9cSStella Laurenzo     } else {
875436c6c9cSStella Laurenzo       if (width == 1) {
876436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
877436c6c9cSStella Laurenzo       }
878308d8b8cSRahul Kayaith       if (width == 8) {
879308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
880308d8b8cSRahul Kayaith       }
881308d8b8cSRahul Kayaith       if (width == 16) {
882308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
883308d8b8cSRahul Kayaith       }
884436c6c9cSStella Laurenzo       if (width == 32) {
885436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
886436c6c9cSStella Laurenzo       }
887436c6c9cSStella Laurenzo       if (width == 64) {
888436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
889436c6c9cSStella Laurenzo       }
890436c6c9cSStella Laurenzo     }
891436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
892436c6c9cSStella Laurenzo   }
893436c6c9cSStella Laurenzo 
894436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
895436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
896436c6c9cSStella Laurenzo   }
897436c6c9cSStella Laurenzo };
898436c6c9cSStella Laurenzo 
899436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
900436c6c9cSStella Laurenzo public:
901436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
902436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
903436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
904436c6c9cSStella Laurenzo 
905436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
906436c6c9cSStella Laurenzo 
9079fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
9089fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
9099fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
9109fb1086bSAdrian Kuegel   }
9119fb1086bSAdrian Kuegel 
912436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
9139fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
914436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
915436c6c9cSStella Laurenzo     c.def_static(
916436c6c9cSStella Laurenzo         "get",
917436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
918436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
919436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
920436c6c9cSStella Laurenzo           for (auto &it : attributes) {
92102b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
922436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
923436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
92402b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
925436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
92602b6fb21SMehdi Amini                 mlirAttr));
927436c6c9cSStella Laurenzo           }
928436c6c9cSStella Laurenzo           MlirAttribute attr =
929436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
930436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
931436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
932436c6c9cSStella Laurenzo         },
933ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
934436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
935436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
936436c6c9cSStella Laurenzo       MlirAttribute attr =
937436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
938436c6c9cSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
939436c6c9cSStella Laurenzo         throw SetPyError(PyExc_KeyError,
940436c6c9cSStella Laurenzo                          "attempt to access a non-existent attribute");
941436c6c9cSStella Laurenzo       }
942436c6c9cSStella Laurenzo       return PyAttribute(self.getContext(), attr);
943436c6c9cSStella Laurenzo     });
944436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
945436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
946436c6c9cSStella Laurenzo         throw SetPyError(PyExc_IndexError,
947436c6c9cSStella Laurenzo                          "attempt to access out of bounds attribute");
948436c6c9cSStella Laurenzo       }
949436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
950436c6c9cSStella Laurenzo       return PyNamedAttribute(
951436c6c9cSStella Laurenzo           namedAttr.attribute,
952436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
953436c6c9cSStella Laurenzo     });
954436c6c9cSStella Laurenzo   }
955436c6c9cSStella Laurenzo };
956436c6c9cSStella Laurenzo 
957436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
958436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
959436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
960436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
961436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
962436c6c9cSStella Laurenzo public:
963436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
964436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
965436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
966436c6c9cSStella Laurenzo 
967436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
968436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
969436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
970436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
971436c6c9cSStella Laurenzo     }
972436c6c9cSStella Laurenzo 
973436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
974436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
975436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
976436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
977436c6c9cSStella Laurenzo     // from float and double.
978436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
979436c6c9cSStella Laurenzo     // querying them on each element access.
980436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
981436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
982436c6c9cSStella Laurenzo     }
983436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
984436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
985436c6c9cSStella Laurenzo     }
986436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
987436c6c9cSStella Laurenzo   }
988436c6c9cSStella Laurenzo 
989436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
990436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
991436c6c9cSStella Laurenzo   }
992436c6c9cSStella Laurenzo };
993436c6c9cSStella Laurenzo 
994436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
995436c6c9cSStella Laurenzo public:
996436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
997436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
998436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
999436c6c9cSStella Laurenzo 
1000436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1001436c6c9cSStella Laurenzo     c.def_static(
1002436c6c9cSStella Laurenzo         "get",
1003436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1004436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1005436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1006436c6c9cSStella Laurenzo         },
1007436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
1008436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1009436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1010436c6c9cSStella Laurenzo       return PyType(self.getContext()->getRef(),
1011436c6c9cSStella Laurenzo                     mlirTypeAttrGetValue(self.get()));
1012436c6c9cSStella Laurenzo     });
1013436c6c9cSStella Laurenzo   }
1014436c6c9cSStella Laurenzo };
1015436c6c9cSStella Laurenzo 
1016436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1017436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1018436c6c9cSStella Laurenzo public:
1019436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1020436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1021436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1022436c6c9cSStella Laurenzo 
1023436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1024436c6c9cSStella Laurenzo     c.def_static(
1025436c6c9cSStella Laurenzo         "get",
1026436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1027436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1028436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1029436c6c9cSStella Laurenzo         },
1030436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
1031436c6c9cSStella Laurenzo   }
1032436c6c9cSStella Laurenzo };
1033436c6c9cSStella Laurenzo 
1034*ac2e2d65SDenys Shabalin /// Strided layout attribute subclass.
1035*ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute
1036*ac2e2d65SDenys Shabalin     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1037*ac2e2d65SDenys Shabalin public:
1038*ac2e2d65SDenys Shabalin   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1039*ac2e2d65SDenys Shabalin   static constexpr const char *pyClassName = "StridedLayoutAttr";
1040*ac2e2d65SDenys Shabalin   using PyConcreteAttribute::PyConcreteAttribute;
1041*ac2e2d65SDenys Shabalin 
1042*ac2e2d65SDenys Shabalin   static void bindDerived(ClassTy &c) {
1043*ac2e2d65SDenys Shabalin     c.def_static(
1044*ac2e2d65SDenys Shabalin         "get",
1045*ac2e2d65SDenys Shabalin         [](int64_t offset, const std::vector<int64_t> strides,
1046*ac2e2d65SDenys Shabalin            DefaultingPyMlirContext ctx) {
1047*ac2e2d65SDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1048*ac2e2d65SDenys Shabalin               ctx->get(), offset, strides.size(), strides.data());
1049*ac2e2d65SDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1050*ac2e2d65SDenys Shabalin         },
1051*ac2e2d65SDenys Shabalin         py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1052*ac2e2d65SDenys Shabalin         "Gets a strided layout attribute.");
1053*ac2e2d65SDenys Shabalin     c.def_property_readonly(
1054*ac2e2d65SDenys Shabalin         "offset",
1055*ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1056*ac2e2d65SDenys Shabalin           return mlirStridedLayoutAttrGetOffset(self);
1057*ac2e2d65SDenys Shabalin         },
1058*ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1059*ac2e2d65SDenys Shabalin     c.def_property_readonly(
1060*ac2e2d65SDenys Shabalin         "strides",
1061*ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1062*ac2e2d65SDenys Shabalin           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1063*ac2e2d65SDenys Shabalin           std::vector<int64_t> strides(size);
1064*ac2e2d65SDenys Shabalin           for (intptr_t i = 0; i < size; i++) {
1065*ac2e2d65SDenys Shabalin             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1066*ac2e2d65SDenys Shabalin           }
1067*ac2e2d65SDenys Shabalin           return strides;
1068*ac2e2d65SDenys Shabalin         },
1069*ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1070*ac2e2d65SDenys Shabalin   }
1071*ac2e2d65SDenys Shabalin };
1072*ac2e2d65SDenys Shabalin 
1073436c6c9cSStella Laurenzo } // namespace
1074436c6c9cSStella Laurenzo 
1075436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
1076436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1077619fd8c2SJeff Niu 
1078619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1079619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1080619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1081619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1082619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1083619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1084619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1085619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1086619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1087619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1088619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1089619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1090619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1091619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
1092619fd8c2SJeff Niu 
1093436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1094436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1095436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1096436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
1097436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1098436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
1099436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
1100436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
11015c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1102436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1103436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1104436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1105436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
1106436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1107*ac2e2d65SDenys Shabalin 
1108*ac2e2d65SDenys Shabalin   PyStridedLayoutAttribute::bind(m);
1109436c6c9cSStella Laurenzo }
1110