xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision 7714b405a0de47e461c77fa8dbd2c21f0d34bbf2)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9436c6c9cSStella Laurenzo #include "IRModule.h"
10436c6c9cSStella Laurenzo 
11436c6c9cSStella Laurenzo #include "PybindUtils.h"
12436c6c9cSStella Laurenzo 
13436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
14436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
15436c6c9cSStella Laurenzo 
16436c6c9cSStella Laurenzo namespace py = pybind11;
17436c6c9cSStella Laurenzo using namespace mlir;
18436c6c9cSStella Laurenzo using namespace mlir::python;
19436c6c9cSStella Laurenzo 
20436c6c9cSStella Laurenzo using llvm::SmallVector;
21436c6c9cSStella Laurenzo using llvm::StringRef;
22436c6c9cSStella Laurenzo using llvm::Twine;
23436c6c9cSStella Laurenzo 
24436c6c9cSStella Laurenzo namespace {
25436c6c9cSStella Laurenzo 
26436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
27436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
28436c6c9cSStella Laurenzo }
29436c6c9cSStella Laurenzo 
30436c6c9cSStella Laurenzo /// CRTP base classes for Python attributes that subclass Attribute and should
31436c6c9cSStella Laurenzo /// be castable from it (i.e. via something like StringAttr(attr)).
32436c6c9cSStella Laurenzo /// By default, attribute class hierarchies are one level deep (i.e. a
33436c6c9cSStella Laurenzo /// concrete attribute class extends PyAttribute); however, intermediate
34436c6c9cSStella Laurenzo /// python-visible base classes can be modeled by specifying a BaseTy.
35436c6c9cSStella Laurenzo template <typename DerivedTy, typename BaseTy = PyAttribute>
36436c6c9cSStella Laurenzo class PyConcreteAttribute : public BaseTy {
37436c6c9cSStella Laurenzo public:
38436c6c9cSStella Laurenzo   // Derived classes must define statics for:
39436c6c9cSStella Laurenzo   //   IsAFunctionTy isaFunction
40436c6c9cSStella Laurenzo   //   const char *pyClassName
41436c6c9cSStella Laurenzo   using ClassTy = py::class_<DerivedTy, BaseTy>;
42436c6c9cSStella Laurenzo   using IsAFunctionTy = bool (*)(MlirAttribute);
43436c6c9cSStella Laurenzo 
44436c6c9cSStella Laurenzo   PyConcreteAttribute() = default;
45436c6c9cSStella Laurenzo   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
46436c6c9cSStella Laurenzo       : BaseTy(std::move(contextRef), attr) {}
47436c6c9cSStella Laurenzo   PyConcreteAttribute(PyAttribute &orig)
48436c6c9cSStella Laurenzo       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
49436c6c9cSStella Laurenzo 
50436c6c9cSStella Laurenzo   static MlirAttribute castFrom(PyAttribute &orig) {
51436c6c9cSStella Laurenzo     if (!DerivedTy::isaFunction(orig)) {
52436c6c9cSStella Laurenzo       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
53436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
54436c6c9cSStella Laurenzo                                              DerivedTy::pyClassName +
55436c6c9cSStella Laurenzo                                              " (from " + origRepr + ")");
56436c6c9cSStella Laurenzo     }
57436c6c9cSStella Laurenzo     return orig;
58436c6c9cSStella Laurenzo   }
59436c6c9cSStella Laurenzo 
60436c6c9cSStella Laurenzo   static void bind(py::module &m) {
61436c6c9cSStella Laurenzo     auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
62436c6c9cSStella Laurenzo     cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
63436c6c9cSStella Laurenzo     DerivedTy::bindDerived(cls);
64436c6c9cSStella Laurenzo   }
65436c6c9cSStella Laurenzo 
66436c6c9cSStella Laurenzo   /// Implemented by derived classes to add methods to the Python subclass.
67436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &m) {}
68436c6c9cSStella Laurenzo };
69436c6c9cSStella Laurenzo 
70436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
71436c6c9cSStella Laurenzo public:
72436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
73436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
74436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
75436c6c9cSStella Laurenzo 
76436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
77436c6c9cSStella Laurenzo     c.def_static(
78436c6c9cSStella Laurenzo         "get",
79436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
80436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
81436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
82436c6c9cSStella Laurenzo         },
83436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
84436c6c9cSStella Laurenzo   }
85436c6c9cSStella Laurenzo };
86436c6c9cSStella Laurenzo 
87436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
88436c6c9cSStella Laurenzo public:
89436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
90436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
91436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
92436c6c9cSStella Laurenzo 
93436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
94436c6c9cSStella Laurenzo   public:
95436c6c9cSStella Laurenzo     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
96436c6c9cSStella Laurenzo 
97436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
98436c6c9cSStella Laurenzo 
99436c6c9cSStella Laurenzo     PyAttribute dunderNext() {
100436c6c9cSStella Laurenzo       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
101436c6c9cSStella Laurenzo         throw py::stop_iteration();
102436c6c9cSStella Laurenzo       }
103436c6c9cSStella Laurenzo       return PyAttribute(attr.getContext(),
104436c6c9cSStella Laurenzo                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
105436c6c9cSStella Laurenzo     }
106436c6c9cSStella Laurenzo 
107436c6c9cSStella Laurenzo     static void bind(py::module &m) {
108436c6c9cSStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
109436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
110436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
111436c6c9cSStella Laurenzo     }
112436c6c9cSStella Laurenzo 
113436c6c9cSStella Laurenzo   private:
114436c6c9cSStella Laurenzo     PyAttribute attr;
115436c6c9cSStella Laurenzo     int nextIndex = 0;
116436c6c9cSStella Laurenzo   };
117436c6c9cSStella Laurenzo 
118436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
119436c6c9cSStella Laurenzo     c.def_static(
120436c6c9cSStella Laurenzo         "get",
121436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
122436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
123436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
124436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
125436c6c9cSStella Laurenzo             try {
126436c6c9cSStella Laurenzo               mlirAttributes.push_back(attribute.cast<PyAttribute>());
127436c6c9cSStella Laurenzo             } catch (py::cast_error &err) {
128436c6c9cSStella Laurenzo               std::string msg = std::string("Invalid attribute when attempting "
129436c6c9cSStella Laurenzo                                             "to create an ArrayAttribute (") +
130436c6c9cSStella Laurenzo                                 err.what() + ")";
131436c6c9cSStella Laurenzo               throw py::cast_error(msg);
132436c6c9cSStella Laurenzo             } catch (py::reference_cast_error &err) {
133436c6c9cSStella Laurenzo               // This exception seems thrown when the value is "None".
134436c6c9cSStella Laurenzo               std::string msg =
135436c6c9cSStella Laurenzo                   std::string("Invalid attribute (None?) when attempting to "
136436c6c9cSStella Laurenzo                               "create an ArrayAttribute (") +
137436c6c9cSStella Laurenzo                   err.what() + ")";
138436c6c9cSStella Laurenzo               throw py::cast_error(msg);
139436c6c9cSStella Laurenzo             }
140436c6c9cSStella Laurenzo           }
141436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
142436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
143436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
144436c6c9cSStella Laurenzo         },
145436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
146436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
147436c6c9cSStella Laurenzo     c.def("__getitem__",
148436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
149436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
150436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
151436c6c9cSStella Laurenzo             return PyAttribute(arr.getContext(),
152436c6c9cSStella Laurenzo                                mlirArrayAttrGetElement(arr, i));
153436c6c9cSStella Laurenzo           })
154436c6c9cSStella Laurenzo         .def("__len__",
155436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
156436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
157436c6c9cSStella Laurenzo              })
158436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
159436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
160436c6c9cSStella Laurenzo         });
161436c6c9cSStella Laurenzo   }
162436c6c9cSStella Laurenzo };
163436c6c9cSStella Laurenzo 
164436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
165436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
166436c6c9cSStella Laurenzo public:
167436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
168436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
169436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
170436c6c9cSStella Laurenzo 
171436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
172436c6c9cSStella Laurenzo     c.def_static(
173436c6c9cSStella Laurenzo         "get",
174436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
175436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
176436c6c9cSStella Laurenzo           // TODO: Rework error reporting once diagnostic engine is exposed
177436c6c9cSStella Laurenzo           // in C API.
178436c6c9cSStella Laurenzo           if (mlirAttributeIsNull(attr)) {
179436c6c9cSStella Laurenzo             throw SetPyError(PyExc_ValueError,
180436c6c9cSStella Laurenzo                              Twine("invalid '") +
181436c6c9cSStella Laurenzo                                  py::repr(py::cast(type)).cast<std::string>() +
182436c6c9cSStella Laurenzo                                  "' and expected floating point type.");
183436c6c9cSStella Laurenzo           }
184436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
185436c6c9cSStella Laurenzo         },
186436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
187436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
188436c6c9cSStella Laurenzo     c.def_static(
189436c6c9cSStella Laurenzo         "get_f32",
190436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
191436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
192436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
193436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
194436c6c9cSStella Laurenzo         },
195436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
196436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
197436c6c9cSStella Laurenzo     c.def_static(
198436c6c9cSStella Laurenzo         "get_f64",
199436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
200436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
201436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
202436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
203436c6c9cSStella Laurenzo         },
204436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
205436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
206436c6c9cSStella Laurenzo     c.def_property_readonly(
207436c6c9cSStella Laurenzo         "value",
208436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
209436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
210436c6c9cSStella Laurenzo         },
211436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
212436c6c9cSStella Laurenzo   }
213436c6c9cSStella Laurenzo };
214436c6c9cSStella Laurenzo 
215436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
216436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
217436c6c9cSStella Laurenzo public:
218436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
219436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
220436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
221436c6c9cSStella Laurenzo 
222436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
223436c6c9cSStella Laurenzo     c.def_static(
224436c6c9cSStella Laurenzo         "get",
225436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
226436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
227436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
228436c6c9cSStella Laurenzo         },
229436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
230436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
231436c6c9cSStella Laurenzo     c.def_property_readonly(
232436c6c9cSStella Laurenzo         "value",
233436c6c9cSStella Laurenzo         [](PyIntegerAttribute &self) {
234436c6c9cSStella Laurenzo           return mlirIntegerAttrGetValueInt(self);
235436c6c9cSStella Laurenzo         },
236436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
237436c6c9cSStella Laurenzo   }
238436c6c9cSStella Laurenzo };
239436c6c9cSStella Laurenzo 
240436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
241436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
242436c6c9cSStella Laurenzo public:
243436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
244436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
245436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
246436c6c9cSStella Laurenzo 
247436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
248436c6c9cSStella Laurenzo     c.def_static(
249436c6c9cSStella Laurenzo         "get",
250436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
251436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
252436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
253436c6c9cSStella Laurenzo         },
254436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
255436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
256436c6c9cSStella Laurenzo     c.def_property_readonly(
257436c6c9cSStella Laurenzo         "value",
258436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
259436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
260436c6c9cSStella Laurenzo   }
261436c6c9cSStella Laurenzo };
262436c6c9cSStella Laurenzo 
263436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
264436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
265436c6c9cSStella Laurenzo public:
266436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
267436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
268436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
269436c6c9cSStella Laurenzo 
270436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
271436c6c9cSStella Laurenzo     c.def_static(
272436c6c9cSStella Laurenzo         "get",
273436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
274436c6c9cSStella Laurenzo           MlirAttribute attr =
275436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
276436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
277436c6c9cSStella Laurenzo         },
278436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
279436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
280436c6c9cSStella Laurenzo     c.def_property_readonly(
281436c6c9cSStella Laurenzo         "value",
282436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
283436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
284436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
285436c6c9cSStella Laurenzo         },
286436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
287436c6c9cSStella Laurenzo   }
288436c6c9cSStella Laurenzo };
289436c6c9cSStella Laurenzo 
290436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
291436c6c9cSStella Laurenzo public:
292436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
293436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
294436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
295436c6c9cSStella Laurenzo 
296436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
297436c6c9cSStella Laurenzo     c.def_static(
298436c6c9cSStella Laurenzo         "get",
299436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
300436c6c9cSStella Laurenzo           MlirAttribute attr =
301436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
302436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
303436c6c9cSStella Laurenzo         },
304436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
305436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
306436c6c9cSStella Laurenzo     c.def_static(
307436c6c9cSStella Laurenzo         "get_typed",
308436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
309436c6c9cSStella Laurenzo           MlirAttribute attr =
310436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
311436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
312436c6c9cSStella Laurenzo         },
313436c6c9cSStella Laurenzo 
314436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
315436c6c9cSStella Laurenzo     c.def_property_readonly(
316436c6c9cSStella Laurenzo         "value",
317436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
318436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
319436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
320436c6c9cSStella Laurenzo         },
321436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
322436c6c9cSStella Laurenzo   }
323436c6c9cSStella Laurenzo };
324436c6c9cSStella Laurenzo 
325436c6c9cSStella Laurenzo // TODO: Support construction of bool elements.
326436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
327436c6c9cSStella Laurenzo class PyDenseElementsAttribute
328436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
329436c6c9cSStella Laurenzo public:
330436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
331436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
332436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
333436c6c9cSStella Laurenzo 
334436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
335436c6c9cSStella Laurenzo   getFromBuffer(py::buffer array, bool signless,
336436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
337436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
338436c6c9cSStella Laurenzo     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
339436c6c9cSStella Laurenzo     Py_buffer *view = new Py_buffer();
340436c6c9cSStella Laurenzo     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
341436c6c9cSStella Laurenzo       delete view;
342436c6c9cSStella Laurenzo       throw py::error_already_set();
343436c6c9cSStella Laurenzo     }
344436c6c9cSStella Laurenzo     py::buffer_info arrayInfo(view);
345436c6c9cSStella Laurenzo 
346436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
347436c6c9cSStella Laurenzo     // Switch on the types that can be bulk loaded between the Python and
348436c6c9cSStella Laurenzo     // MLIR-C APIs.
349436c6c9cSStella Laurenzo     // See: https://docs.python.org/3/library/struct.html#format-characters
350436c6c9cSStella Laurenzo     if (arrayInfo.format == "f") {
351436c6c9cSStella Laurenzo       // f32
352436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
353436c6c9cSStella Laurenzo       return PyDenseElementsAttribute(
354436c6c9cSStella Laurenzo           contextWrapper->getRef(),
355436c6c9cSStella Laurenzo           bulkLoad(context, mlirDenseElementsAttrFloatGet,
356436c6c9cSStella Laurenzo                    mlirF32TypeGet(context), arrayInfo));
357436c6c9cSStella Laurenzo     } else if (arrayInfo.format == "d") {
358436c6c9cSStella Laurenzo       // f64
359436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
360436c6c9cSStella Laurenzo       return PyDenseElementsAttribute(
361436c6c9cSStella Laurenzo           contextWrapper->getRef(),
362436c6c9cSStella Laurenzo           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
363436c6c9cSStella Laurenzo                    mlirF64TypeGet(context), arrayInfo));
364436c6c9cSStella Laurenzo     } else if (isSignedIntegerFormat(arrayInfo.format)) {
365436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
366436c6c9cSStella Laurenzo         // i32
367436c6c9cSStella Laurenzo         MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
368436c6c9cSStella Laurenzo                                         : mlirIntegerTypeSignedGet(context, 32);
369436c6c9cSStella Laurenzo         return PyDenseElementsAttribute(contextWrapper->getRef(),
370436c6c9cSStella Laurenzo                                         bulkLoad(context,
371436c6c9cSStella Laurenzo                                                  mlirDenseElementsAttrInt32Get,
372436c6c9cSStella Laurenzo                                                  elementType, arrayInfo));
373436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
374436c6c9cSStella Laurenzo         // i64
375436c6c9cSStella Laurenzo         MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
376436c6c9cSStella Laurenzo                                         : mlirIntegerTypeSignedGet(context, 64);
377436c6c9cSStella Laurenzo         return PyDenseElementsAttribute(contextWrapper->getRef(),
378436c6c9cSStella Laurenzo                                         bulkLoad(context,
379436c6c9cSStella Laurenzo                                                  mlirDenseElementsAttrInt64Get,
380436c6c9cSStella Laurenzo                                                  elementType, arrayInfo));
381436c6c9cSStella Laurenzo       }
382436c6c9cSStella Laurenzo     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
383436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
384436c6c9cSStella Laurenzo         // unsigned i32
385436c6c9cSStella Laurenzo         MlirType elementType = signless
386436c6c9cSStella Laurenzo                                    ? mlirIntegerTypeGet(context, 32)
387436c6c9cSStella Laurenzo                                    : mlirIntegerTypeUnsignedGet(context, 32);
388436c6c9cSStella Laurenzo         return PyDenseElementsAttribute(contextWrapper->getRef(),
389436c6c9cSStella Laurenzo                                         bulkLoad(context,
390436c6c9cSStella Laurenzo                                                  mlirDenseElementsAttrUInt32Get,
391436c6c9cSStella Laurenzo                                                  elementType, arrayInfo));
392436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
393436c6c9cSStella Laurenzo         // unsigned i64
394436c6c9cSStella Laurenzo         MlirType elementType = signless
395436c6c9cSStella Laurenzo                                    ? mlirIntegerTypeGet(context, 64)
396436c6c9cSStella Laurenzo                                    : mlirIntegerTypeUnsignedGet(context, 64);
397436c6c9cSStella Laurenzo         return PyDenseElementsAttribute(contextWrapper->getRef(),
398436c6c9cSStella Laurenzo                                         bulkLoad(context,
399436c6c9cSStella Laurenzo                                                  mlirDenseElementsAttrUInt64Get,
400436c6c9cSStella Laurenzo                                                  elementType, arrayInfo));
401436c6c9cSStella Laurenzo       }
402436c6c9cSStella Laurenzo     }
403436c6c9cSStella Laurenzo 
404436c6c9cSStella Laurenzo     // TODO: Fall back to string-based get.
405436c6c9cSStella Laurenzo     std::string message = "unimplemented array format conversion from format: ";
406436c6c9cSStella Laurenzo     message.append(arrayInfo.format);
407436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, message);
408436c6c9cSStella Laurenzo   }
409436c6c9cSStella Laurenzo 
410436c6c9cSStella Laurenzo   static PyDenseElementsAttribute getSplat(PyType shapedType,
411436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
412436c6c9cSStella Laurenzo     auto contextWrapper =
413436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
414436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
415436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
416436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
417436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
418436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
419436c6c9cSStella Laurenzo     }
420436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
421436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
422436c6c9cSStella Laurenzo       std::string message =
423436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
424436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
425436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
426436c6c9cSStella Laurenzo     }
427436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
428436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
429436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
430436c6c9cSStella Laurenzo       std::string message =
431436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
432436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
433436c6c9cSStella Laurenzo       message.append(", element=");
434436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
435436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
436436c6c9cSStella Laurenzo     }
437436c6c9cSStella Laurenzo 
438436c6c9cSStella Laurenzo     MlirAttribute elements =
439436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
440436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
441436c6c9cSStella Laurenzo   }
442436c6c9cSStella Laurenzo 
443436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
444436c6c9cSStella Laurenzo 
445436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
446436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
447436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
448436c6c9cSStella Laurenzo 
449436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
450436c6c9cSStella Laurenzo       // f32
451436c6c9cSStella Laurenzo       return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
452436c6c9cSStella Laurenzo     } else if (mlirTypeIsAF64(elementType)) {
453436c6c9cSStella Laurenzo       // f64
454436c6c9cSStella Laurenzo       return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
455436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
456436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 32) {
457436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
458436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
459436c6c9cSStella Laurenzo         // i32
460436c6c9cSStella Laurenzo         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
461436c6c9cSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
462436c6c9cSStella Laurenzo         // unsigned i32
463436c6c9cSStella Laurenzo         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
464436c6c9cSStella Laurenzo       }
465436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
466436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
467436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
468436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
469436c6c9cSStella Laurenzo         // i64
470436c6c9cSStella Laurenzo         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
471436c6c9cSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
472436c6c9cSStella Laurenzo         // unsigned i64
473436c6c9cSStella Laurenzo         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
474436c6c9cSStella Laurenzo       }
475436c6c9cSStella Laurenzo     }
476436c6c9cSStella Laurenzo 
477436c6c9cSStella Laurenzo     std::string message = "unimplemented array format.";
478436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, message);
479436c6c9cSStella Laurenzo   }
480436c6c9cSStella Laurenzo 
481436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
482436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
483436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
484436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
485436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
486436c6c9cSStella Laurenzo                     "Gets from a buffer or ndarray")
487436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
488436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
489436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
490436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
491436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
492436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
493436c6c9cSStella Laurenzo                                })
494436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
495436c6c9cSStella Laurenzo   }
496436c6c9cSStella Laurenzo 
497436c6c9cSStella Laurenzo private:
498436c6c9cSStella Laurenzo   template <typename ElementTy>
499436c6c9cSStella Laurenzo   static MlirAttribute
500436c6c9cSStella Laurenzo   bulkLoad(MlirContext context,
501436c6c9cSStella Laurenzo            MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
502436c6c9cSStella Laurenzo            MlirType mlirElementType, py::buffer_info &arrayInfo) {
503436c6c9cSStella Laurenzo     SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
504436c6c9cSStella Laurenzo                                   arrayInfo.shape.begin() + arrayInfo.ndim);
505*7714b405SAart Bik     MlirAttribute encodingAttr = mlirAttributeGetNull();
506*7714b405SAart Bik     auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
507*7714b405SAart Bik                                               mlirElementType, encodingAttr);
508436c6c9cSStella Laurenzo     intptr_t numElements = arrayInfo.size;
509436c6c9cSStella Laurenzo     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
510436c6c9cSStella Laurenzo     return ctor(shapedType, numElements, contents);
511436c6c9cSStella Laurenzo   }
512436c6c9cSStella Laurenzo 
/// Returns true when the first character of `format` is one of the codes
/// 'B', 'H', 'I', 'L' or 'Q' (presumably Python struct-style unsigned integer
/// codes). Only the first character is inspected; an empty string yields
/// false.
static bool isUnsignedIntegerFormat(const std::string &format) {
  if (format.empty())
    return false;
  switch (format.front()) {
  case 'B':
  case 'H':
  case 'I':
  case 'L':
  case 'Q':
    return true;
  default:
    return false;
  }
}
520436c6c9cSStella Laurenzo 
/// Returns true when the first character of `format` is one of the codes
/// 'b', 'h', 'i', 'l' or 'q' (presumably Python struct-style signed integer
/// codes). Only the first character is inspected; an empty string yields
/// false.
static bool isSignedIntegerFormat(const std::string &format) {
  if (format.empty())
    return false;
  switch (format.front()) {
  case 'b':
  case 'h':
  case 'i':
  case 'l':
  case 'q':
    return true;
  default:
    return false;
  }
}
528436c6c9cSStella Laurenzo 
529436c6c9cSStella Laurenzo   template <typename Type>
530436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
531436c6c9cSStella Laurenzo                              Type (*value)(MlirAttribute, intptr_t)) {
532436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
533436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
534436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
535436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
536436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
537436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
538436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
539436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
540436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
541436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
542436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
543436c6c9cSStella Laurenzo     intptr_t strideFactor = 1;
544436c6c9cSStella Laurenzo     for (intptr_t i = 1; i < rank; ++i) {
545436c6c9cSStella Laurenzo       strideFactor = 1;
546436c6c9cSStella Laurenzo       for (intptr_t j = i; j < rank; ++j) {
547436c6c9cSStella Laurenzo         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
548436c6c9cSStella Laurenzo       }
549436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type) * strideFactor);
550436c6c9cSStella Laurenzo     }
551436c6c9cSStella Laurenzo     strides.push_back(sizeof(Type));
552436c6c9cSStella Laurenzo     return py::buffer_info(data, sizeof(Type),
553436c6c9cSStella Laurenzo                            py::format_descriptor<Type>::format(), rank, shape,
554436c6c9cSStella Laurenzo                            strides, /*readonly=*/true);
555436c6c9cSStella Laurenzo   }
556436c6c9cSStella Laurenzo }; // namespace
557436c6c9cSStella Laurenzo 
558436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
559436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
560436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
561436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
562436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
563436c6c9cSStella Laurenzo public:
564436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
565436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
566436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
567436c6c9cSStella Laurenzo 
568436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
569436c6c9cSStella Laurenzo   /// out of range.
570436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
571436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
572436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
573436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
574436c6c9cSStella Laurenzo     }
575436c6c9cSStella Laurenzo 
576436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
577436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
578436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
579436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
580436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
581436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
582436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
583436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
584436c6c9cSStella Laurenzo     // querying them on each element access.
585436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
586436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
587436c6c9cSStella Laurenzo     if (isUnsigned) {
588436c6c9cSStella Laurenzo       if (width == 1) {
589436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
590436c6c9cSStella Laurenzo       }
591436c6c9cSStella Laurenzo       if (width == 32) {
592436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
593436c6c9cSStella Laurenzo       }
594436c6c9cSStella Laurenzo       if (width == 64) {
595436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
596436c6c9cSStella Laurenzo       }
597436c6c9cSStella Laurenzo     } else {
598436c6c9cSStella Laurenzo       if (width == 1) {
599436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
600436c6c9cSStella Laurenzo       }
601436c6c9cSStella Laurenzo       if (width == 32) {
602436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
603436c6c9cSStella Laurenzo       }
604436c6c9cSStella Laurenzo       if (width == 64) {
605436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
606436c6c9cSStella Laurenzo       }
607436c6c9cSStella Laurenzo     }
608436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
609436c6c9cSStella Laurenzo   }
610436c6c9cSStella Laurenzo 
611436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
612436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
613436c6c9cSStella Laurenzo   }
614436c6c9cSStella Laurenzo };
615436c6c9cSStella Laurenzo 
616436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
617436c6c9cSStella Laurenzo public:
618436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
619436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
620436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
621436c6c9cSStella Laurenzo 
622436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
623436c6c9cSStella Laurenzo 
624436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
625436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
626436c6c9cSStella Laurenzo     c.def_static(
627436c6c9cSStella Laurenzo         "get",
628436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
629436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
630436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
631436c6c9cSStella Laurenzo           for (auto &it : attributes) {
632436c6c9cSStella Laurenzo             auto &mlir_attr = it.second.cast<PyAttribute &>();
633436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
634436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
635436c6c9cSStella Laurenzo                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
636436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
637436c6c9cSStella Laurenzo                 mlir_attr));
638436c6c9cSStella Laurenzo           }
639436c6c9cSStella Laurenzo           MlirAttribute attr =
640436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
641436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
642436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
643436c6c9cSStella Laurenzo         },
644436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
645436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
646436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
647436c6c9cSStella Laurenzo       MlirAttribute attr =
648436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
649436c6c9cSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
650436c6c9cSStella Laurenzo         throw SetPyError(PyExc_KeyError,
651436c6c9cSStella Laurenzo                          "attempt to access a non-existent attribute");
652436c6c9cSStella Laurenzo       }
653436c6c9cSStella Laurenzo       return PyAttribute(self.getContext(), attr);
654436c6c9cSStella Laurenzo     });
655436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
656436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
657436c6c9cSStella Laurenzo         throw SetPyError(PyExc_IndexError,
658436c6c9cSStella Laurenzo                          "attempt to access out of bounds attribute");
659436c6c9cSStella Laurenzo       }
660436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
661436c6c9cSStella Laurenzo       return PyNamedAttribute(
662436c6c9cSStella Laurenzo           namedAttr.attribute,
663436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
664436c6c9cSStella Laurenzo     });
665436c6c9cSStella Laurenzo   }
666436c6c9cSStella Laurenzo };
667436c6c9cSStella Laurenzo 
668436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
669436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
670436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
671436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
672436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
673436c6c9cSStella Laurenzo public:
674436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
675436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
676436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
677436c6c9cSStella Laurenzo 
678436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
679436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
680436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
681436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
682436c6c9cSStella Laurenzo     }
683436c6c9cSStella Laurenzo 
684436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
685436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
686436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
687436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
688436c6c9cSStella Laurenzo     // from float and double.
689436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
690436c6c9cSStella Laurenzo     // querying them on each element access.
691436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
692436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
693436c6c9cSStella Laurenzo     }
694436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
695436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
696436c6c9cSStella Laurenzo     }
697436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
698436c6c9cSStella Laurenzo   }
699436c6c9cSStella Laurenzo 
700436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
701436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
702436c6c9cSStella Laurenzo   }
703436c6c9cSStella Laurenzo };
704436c6c9cSStella Laurenzo 
705436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
706436c6c9cSStella Laurenzo public:
707436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
708436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
709436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
710436c6c9cSStella Laurenzo 
711436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
712436c6c9cSStella Laurenzo     c.def_static(
713436c6c9cSStella Laurenzo         "get",
714436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
715436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
716436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
717436c6c9cSStella Laurenzo         },
718436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
719436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
720436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
721436c6c9cSStella Laurenzo       return PyType(self.getContext()->getRef(),
722436c6c9cSStella Laurenzo                     mlirTypeAttrGetValue(self.get()));
723436c6c9cSStella Laurenzo     });
724436c6c9cSStella Laurenzo   }
725436c6c9cSStella Laurenzo };
726436c6c9cSStella Laurenzo 
727436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
728436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
729436c6c9cSStella Laurenzo public:
730436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
731436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
732436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
733436c6c9cSStella Laurenzo 
734436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
735436c6c9cSStella Laurenzo     c.def_static(
736436c6c9cSStella Laurenzo         "get",
737436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
738436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
739436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
740436c6c9cSStella Laurenzo         },
741436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
742436c6c9cSStella Laurenzo   }
743436c6c9cSStella Laurenzo };
744436c6c9cSStella Laurenzo 
745436c6c9cSStella Laurenzo } // namespace
746436c6c9cSStella Laurenzo 
747436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
748436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
749436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
750436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
751436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
752436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
753436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
754436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
755436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
756436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
757436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
758436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
759436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
760436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
761436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
762436c6c9cSStella Laurenzo }
763