xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision 436c6c9c20cc522c92a923440a5fc509c342a7db)
1*436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2*436c6c9cSStella Laurenzo //
3*436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5*436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*436c6c9cSStella Laurenzo //
7*436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8*436c6c9cSStella Laurenzo 
9*436c6c9cSStella Laurenzo #include "IRModule.h"
10*436c6c9cSStella Laurenzo 
11*436c6c9cSStella Laurenzo #include "PybindUtils.h"
12*436c6c9cSStella Laurenzo 
13*436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
14*436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
15*436c6c9cSStella Laurenzo 
16*436c6c9cSStella Laurenzo namespace py = pybind11;
17*436c6c9cSStella Laurenzo using namespace mlir;
18*436c6c9cSStella Laurenzo using namespace mlir::python;
19*436c6c9cSStella Laurenzo 
20*436c6c9cSStella Laurenzo using llvm::SmallVector;
21*436c6c9cSStella Laurenzo using llvm::StringRef;
22*436c6c9cSStella Laurenzo using llvm::Twine;
23*436c6c9cSStella Laurenzo 
24*436c6c9cSStella Laurenzo namespace {
25*436c6c9cSStella Laurenzo 
26*436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
27*436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
28*436c6c9cSStella Laurenzo }
29*436c6c9cSStella Laurenzo 
30*436c6c9cSStella Laurenzo /// CRTP base classes for Python attributes that subclass Attribute and should
31*436c6c9cSStella Laurenzo /// be castable from it (i.e. via something like StringAttr(attr)).
32*436c6c9cSStella Laurenzo /// By default, attribute class hierarchies are one level deep (i.e. a
33*436c6c9cSStella Laurenzo /// concrete attribute class extends PyAttribute); however, intermediate
34*436c6c9cSStella Laurenzo /// python-visible base classes can be modeled by specifying a BaseTy.
35*436c6c9cSStella Laurenzo template <typename DerivedTy, typename BaseTy = PyAttribute>
36*436c6c9cSStella Laurenzo class PyConcreteAttribute : public BaseTy {
37*436c6c9cSStella Laurenzo public:
38*436c6c9cSStella Laurenzo   // Derived classes must define statics for:
39*436c6c9cSStella Laurenzo   //   IsAFunctionTy isaFunction
40*436c6c9cSStella Laurenzo   //   const char *pyClassName
41*436c6c9cSStella Laurenzo   using ClassTy = py::class_<DerivedTy, BaseTy>;
42*436c6c9cSStella Laurenzo   using IsAFunctionTy = bool (*)(MlirAttribute);
43*436c6c9cSStella Laurenzo 
44*436c6c9cSStella Laurenzo   PyConcreteAttribute() = default;
45*436c6c9cSStella Laurenzo   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
46*436c6c9cSStella Laurenzo       : BaseTy(std::move(contextRef), attr) {}
47*436c6c9cSStella Laurenzo   PyConcreteAttribute(PyAttribute &orig)
48*436c6c9cSStella Laurenzo       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
49*436c6c9cSStella Laurenzo 
50*436c6c9cSStella Laurenzo   static MlirAttribute castFrom(PyAttribute &orig) {
51*436c6c9cSStella Laurenzo     if (!DerivedTy::isaFunction(orig)) {
52*436c6c9cSStella Laurenzo       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
53*436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
54*436c6c9cSStella Laurenzo                                              DerivedTy::pyClassName +
55*436c6c9cSStella Laurenzo                                              " (from " + origRepr + ")");
56*436c6c9cSStella Laurenzo     }
57*436c6c9cSStella Laurenzo     return orig;
58*436c6c9cSStella Laurenzo   }
59*436c6c9cSStella Laurenzo 
60*436c6c9cSStella Laurenzo   static void bind(py::module &m) {
61*436c6c9cSStella Laurenzo     auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
62*436c6c9cSStella Laurenzo     cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
63*436c6c9cSStella Laurenzo     DerivedTy::bindDerived(cls);
64*436c6c9cSStella Laurenzo   }
65*436c6c9cSStella Laurenzo 
66*436c6c9cSStella Laurenzo   /// Implemented by derived classes to add methods to the Python subclass.
67*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &m) {}
68*436c6c9cSStella Laurenzo };
69*436c6c9cSStella Laurenzo 
70*436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
71*436c6c9cSStella Laurenzo public:
72*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
73*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
74*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
75*436c6c9cSStella Laurenzo 
76*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
77*436c6c9cSStella Laurenzo     c.def_static(
78*436c6c9cSStella Laurenzo         "get",
79*436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
80*436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
81*436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
82*436c6c9cSStella Laurenzo         },
83*436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
84*436c6c9cSStella Laurenzo   }
85*436c6c9cSStella Laurenzo };
86*436c6c9cSStella Laurenzo 
87*436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
88*436c6c9cSStella Laurenzo public:
89*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
90*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
91*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
92*436c6c9cSStella Laurenzo 
93*436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
94*436c6c9cSStella Laurenzo   public:
95*436c6c9cSStella Laurenzo     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
96*436c6c9cSStella Laurenzo 
97*436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
98*436c6c9cSStella Laurenzo 
99*436c6c9cSStella Laurenzo     PyAttribute dunderNext() {
100*436c6c9cSStella Laurenzo       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
101*436c6c9cSStella Laurenzo         throw py::stop_iteration();
102*436c6c9cSStella Laurenzo       }
103*436c6c9cSStella Laurenzo       return PyAttribute(attr.getContext(),
104*436c6c9cSStella Laurenzo                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
105*436c6c9cSStella Laurenzo     }
106*436c6c9cSStella Laurenzo 
107*436c6c9cSStella Laurenzo     static void bind(py::module &m) {
108*436c6c9cSStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
109*436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
110*436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
111*436c6c9cSStella Laurenzo     }
112*436c6c9cSStella Laurenzo 
113*436c6c9cSStella Laurenzo   private:
114*436c6c9cSStella Laurenzo     PyAttribute attr;
115*436c6c9cSStella Laurenzo     int nextIndex = 0;
116*436c6c9cSStella Laurenzo   };
117*436c6c9cSStella Laurenzo 
118*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
119*436c6c9cSStella Laurenzo     c.def_static(
120*436c6c9cSStella Laurenzo         "get",
121*436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
122*436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
123*436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
124*436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
125*436c6c9cSStella Laurenzo             try {
126*436c6c9cSStella Laurenzo               mlirAttributes.push_back(attribute.cast<PyAttribute>());
127*436c6c9cSStella Laurenzo             } catch (py::cast_error &err) {
128*436c6c9cSStella Laurenzo               std::string msg = std::string("Invalid attribute when attempting "
129*436c6c9cSStella Laurenzo                                             "to create an ArrayAttribute (") +
130*436c6c9cSStella Laurenzo                                 err.what() + ")";
131*436c6c9cSStella Laurenzo               throw py::cast_error(msg);
132*436c6c9cSStella Laurenzo             } catch (py::reference_cast_error &err) {
133*436c6c9cSStella Laurenzo               // This exception seems thrown when the value is "None".
134*436c6c9cSStella Laurenzo               std::string msg =
135*436c6c9cSStella Laurenzo                   std::string("Invalid attribute (None?) when attempting to "
136*436c6c9cSStella Laurenzo                               "create an ArrayAttribute (") +
137*436c6c9cSStella Laurenzo                   err.what() + ")";
138*436c6c9cSStella Laurenzo               throw py::cast_error(msg);
139*436c6c9cSStella Laurenzo             }
140*436c6c9cSStella Laurenzo           }
141*436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
142*436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
143*436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
144*436c6c9cSStella Laurenzo         },
145*436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
146*436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
147*436c6c9cSStella Laurenzo     c.def("__getitem__",
148*436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
149*436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
150*436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
151*436c6c9cSStella Laurenzo             return PyAttribute(arr.getContext(),
152*436c6c9cSStella Laurenzo                                mlirArrayAttrGetElement(arr, i));
153*436c6c9cSStella Laurenzo           })
154*436c6c9cSStella Laurenzo         .def("__len__",
155*436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
156*436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
157*436c6c9cSStella Laurenzo              })
158*436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
159*436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
160*436c6c9cSStella Laurenzo         });
161*436c6c9cSStella Laurenzo   }
162*436c6c9cSStella Laurenzo };
163*436c6c9cSStella Laurenzo 
164*436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
165*436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
166*436c6c9cSStella Laurenzo public:
167*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
168*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
169*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
170*436c6c9cSStella Laurenzo 
171*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
172*436c6c9cSStella Laurenzo     c.def_static(
173*436c6c9cSStella Laurenzo         "get",
174*436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
175*436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
176*436c6c9cSStella Laurenzo           // TODO: Rework error reporting once diagnostic engine is exposed
177*436c6c9cSStella Laurenzo           // in C API.
178*436c6c9cSStella Laurenzo           if (mlirAttributeIsNull(attr)) {
179*436c6c9cSStella Laurenzo             throw SetPyError(PyExc_ValueError,
180*436c6c9cSStella Laurenzo                              Twine("invalid '") +
181*436c6c9cSStella Laurenzo                                  py::repr(py::cast(type)).cast<std::string>() +
182*436c6c9cSStella Laurenzo                                  "' and expected floating point type.");
183*436c6c9cSStella Laurenzo           }
184*436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
185*436c6c9cSStella Laurenzo         },
186*436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
187*436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
188*436c6c9cSStella Laurenzo     c.def_static(
189*436c6c9cSStella Laurenzo         "get_f32",
190*436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
191*436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
192*436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
193*436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
194*436c6c9cSStella Laurenzo         },
195*436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
196*436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
197*436c6c9cSStella Laurenzo     c.def_static(
198*436c6c9cSStella Laurenzo         "get_f64",
199*436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
200*436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
201*436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
202*436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
203*436c6c9cSStella Laurenzo         },
204*436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
205*436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
206*436c6c9cSStella Laurenzo     c.def_property_readonly(
207*436c6c9cSStella Laurenzo         "value",
208*436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
209*436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
210*436c6c9cSStella Laurenzo         },
211*436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
212*436c6c9cSStella Laurenzo   }
213*436c6c9cSStella Laurenzo };
214*436c6c9cSStella Laurenzo 
215*436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
216*436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
217*436c6c9cSStella Laurenzo public:
218*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
219*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
220*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
221*436c6c9cSStella Laurenzo 
222*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
223*436c6c9cSStella Laurenzo     c.def_static(
224*436c6c9cSStella Laurenzo         "get",
225*436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
226*436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
227*436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
228*436c6c9cSStella Laurenzo         },
229*436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
230*436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
231*436c6c9cSStella Laurenzo     c.def_property_readonly(
232*436c6c9cSStella Laurenzo         "value",
233*436c6c9cSStella Laurenzo         [](PyIntegerAttribute &self) {
234*436c6c9cSStella Laurenzo           return mlirIntegerAttrGetValueInt(self);
235*436c6c9cSStella Laurenzo         },
236*436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
237*436c6c9cSStella Laurenzo   }
238*436c6c9cSStella Laurenzo };
239*436c6c9cSStella Laurenzo 
240*436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
241*436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
242*436c6c9cSStella Laurenzo public:
243*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
244*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
245*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
246*436c6c9cSStella Laurenzo 
247*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
248*436c6c9cSStella Laurenzo     c.def_static(
249*436c6c9cSStella Laurenzo         "get",
250*436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
251*436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
252*436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
253*436c6c9cSStella Laurenzo         },
254*436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
255*436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
256*436c6c9cSStella Laurenzo     c.def_property_readonly(
257*436c6c9cSStella Laurenzo         "value",
258*436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
259*436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
260*436c6c9cSStella Laurenzo   }
261*436c6c9cSStella Laurenzo };
262*436c6c9cSStella Laurenzo 
263*436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
264*436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
265*436c6c9cSStella Laurenzo public:
266*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
267*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
268*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
269*436c6c9cSStella Laurenzo 
270*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
271*436c6c9cSStella Laurenzo     c.def_static(
272*436c6c9cSStella Laurenzo         "get",
273*436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
274*436c6c9cSStella Laurenzo           MlirAttribute attr =
275*436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
276*436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
277*436c6c9cSStella Laurenzo         },
278*436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
279*436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
280*436c6c9cSStella Laurenzo     c.def_property_readonly(
281*436c6c9cSStella Laurenzo         "value",
282*436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
283*436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
284*436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
285*436c6c9cSStella Laurenzo         },
286*436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
287*436c6c9cSStella Laurenzo   }
288*436c6c9cSStella Laurenzo };
289*436c6c9cSStella Laurenzo 
290*436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
291*436c6c9cSStella Laurenzo public:
292*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
293*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
294*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
295*436c6c9cSStella Laurenzo 
296*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
297*436c6c9cSStella Laurenzo     c.def_static(
298*436c6c9cSStella Laurenzo         "get",
299*436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
300*436c6c9cSStella Laurenzo           MlirAttribute attr =
301*436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
302*436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
303*436c6c9cSStella Laurenzo         },
304*436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
305*436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
306*436c6c9cSStella Laurenzo     c.def_static(
307*436c6c9cSStella Laurenzo         "get_typed",
308*436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
309*436c6c9cSStella Laurenzo           MlirAttribute attr =
310*436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
311*436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
312*436c6c9cSStella Laurenzo         },
313*436c6c9cSStella Laurenzo 
314*436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
315*436c6c9cSStella Laurenzo     c.def_property_readonly(
316*436c6c9cSStella Laurenzo         "value",
317*436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
318*436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
319*436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
320*436c6c9cSStella Laurenzo         },
321*436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
322*436c6c9cSStella Laurenzo   }
323*436c6c9cSStella Laurenzo };
324*436c6c9cSStella Laurenzo 
325*436c6c9cSStella Laurenzo // TODO: Support construction of bool elements.
326*436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
327*436c6c9cSStella Laurenzo class PyDenseElementsAttribute
328*436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
329*436c6c9cSStella Laurenzo public:
330*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
331*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
332*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
333*436c6c9cSStella Laurenzo 
334*436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
335*436c6c9cSStella Laurenzo   getFromBuffer(py::buffer array, bool signless,
336*436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
337*436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
338*436c6c9cSStella Laurenzo     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
339*436c6c9cSStella Laurenzo     Py_buffer *view = new Py_buffer();
340*436c6c9cSStella Laurenzo     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
341*436c6c9cSStella Laurenzo       delete view;
342*436c6c9cSStella Laurenzo       throw py::error_already_set();
343*436c6c9cSStella Laurenzo     }
344*436c6c9cSStella Laurenzo     py::buffer_info arrayInfo(view);
345*436c6c9cSStella Laurenzo 
346*436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
347*436c6c9cSStella Laurenzo     // Switch on the types that can be bulk loaded between the Python and
348*436c6c9cSStella Laurenzo     // MLIR-C APIs.
349*436c6c9cSStella Laurenzo     // See: https://docs.python.org/3/library/struct.html#format-characters
350*436c6c9cSStella Laurenzo     if (arrayInfo.format == "f") {
351*436c6c9cSStella Laurenzo       // f32
352*436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
353*436c6c9cSStella Laurenzo       return PyDenseElementsAttribute(
354*436c6c9cSStella Laurenzo           contextWrapper->getRef(),
355*436c6c9cSStella Laurenzo           bulkLoad(context, mlirDenseElementsAttrFloatGet,
356*436c6c9cSStella Laurenzo                    mlirF32TypeGet(context), arrayInfo));
357*436c6c9cSStella Laurenzo     } else if (arrayInfo.format == "d") {
358*436c6c9cSStella Laurenzo       // f64
359*436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
360*436c6c9cSStella Laurenzo       return PyDenseElementsAttribute(
361*436c6c9cSStella Laurenzo           contextWrapper->getRef(),
362*436c6c9cSStella Laurenzo           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
363*436c6c9cSStella Laurenzo                    mlirF64TypeGet(context), arrayInfo));
364*436c6c9cSStella Laurenzo     } else if (isSignedIntegerFormat(arrayInfo.format)) {
365*436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
366*436c6c9cSStella Laurenzo         // i32
367*436c6c9cSStella Laurenzo         MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
368*436c6c9cSStella Laurenzo                                         : mlirIntegerTypeSignedGet(context, 32);
369*436c6c9cSStella Laurenzo         return PyDenseElementsAttribute(contextWrapper->getRef(),
370*436c6c9cSStella Laurenzo                                         bulkLoad(context,
371*436c6c9cSStella Laurenzo                                                  mlirDenseElementsAttrInt32Get,
372*436c6c9cSStella Laurenzo                                                  elementType, arrayInfo));
373*436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
374*436c6c9cSStella Laurenzo         // i64
375*436c6c9cSStella Laurenzo         MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
376*436c6c9cSStella Laurenzo                                         : mlirIntegerTypeSignedGet(context, 64);
377*436c6c9cSStella Laurenzo         return PyDenseElementsAttribute(contextWrapper->getRef(),
378*436c6c9cSStella Laurenzo                                         bulkLoad(context,
379*436c6c9cSStella Laurenzo                                                  mlirDenseElementsAttrInt64Get,
380*436c6c9cSStella Laurenzo                                                  elementType, arrayInfo));
381*436c6c9cSStella Laurenzo       }
382*436c6c9cSStella Laurenzo     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
383*436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
384*436c6c9cSStella Laurenzo         // unsigned i32
385*436c6c9cSStella Laurenzo         MlirType elementType = signless
386*436c6c9cSStella Laurenzo                                    ? mlirIntegerTypeGet(context, 32)
387*436c6c9cSStella Laurenzo                                    : mlirIntegerTypeUnsignedGet(context, 32);
388*436c6c9cSStella Laurenzo         return PyDenseElementsAttribute(contextWrapper->getRef(),
389*436c6c9cSStella Laurenzo                                         bulkLoad(context,
390*436c6c9cSStella Laurenzo                                                  mlirDenseElementsAttrUInt32Get,
391*436c6c9cSStella Laurenzo                                                  elementType, arrayInfo));
392*436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
393*436c6c9cSStella Laurenzo         // unsigned i64
394*436c6c9cSStella Laurenzo         MlirType elementType = signless
395*436c6c9cSStella Laurenzo                                    ? mlirIntegerTypeGet(context, 64)
396*436c6c9cSStella Laurenzo                                    : mlirIntegerTypeUnsignedGet(context, 64);
397*436c6c9cSStella Laurenzo         return PyDenseElementsAttribute(contextWrapper->getRef(),
398*436c6c9cSStella Laurenzo                                         bulkLoad(context,
399*436c6c9cSStella Laurenzo                                                  mlirDenseElementsAttrUInt64Get,
400*436c6c9cSStella Laurenzo                                                  elementType, arrayInfo));
401*436c6c9cSStella Laurenzo       }
402*436c6c9cSStella Laurenzo     }
403*436c6c9cSStella Laurenzo 
404*436c6c9cSStella Laurenzo     // TODO: Fall back to string-based get.
405*436c6c9cSStella Laurenzo     std::string message = "unimplemented array format conversion from format: ";
406*436c6c9cSStella Laurenzo     message.append(arrayInfo.format);
407*436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, message);
408*436c6c9cSStella Laurenzo   }
409*436c6c9cSStella Laurenzo 
410*436c6c9cSStella Laurenzo   static PyDenseElementsAttribute getSplat(PyType shapedType,
411*436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
412*436c6c9cSStella Laurenzo     auto contextWrapper =
413*436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
414*436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
415*436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
416*436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
417*436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
418*436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
419*436c6c9cSStella Laurenzo     }
420*436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
421*436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
422*436c6c9cSStella Laurenzo       std::string message =
423*436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
424*436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
425*436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
426*436c6c9cSStella Laurenzo     }
427*436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
428*436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
429*436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
430*436c6c9cSStella Laurenzo       std::string message =
431*436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
432*436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
433*436c6c9cSStella Laurenzo       message.append(", element=");
434*436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
435*436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
436*436c6c9cSStella Laurenzo     }
437*436c6c9cSStella Laurenzo 
438*436c6c9cSStella Laurenzo     MlirAttribute elements =
439*436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
440*436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
441*436c6c9cSStella Laurenzo   }
442*436c6c9cSStella Laurenzo 
443*436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
444*436c6c9cSStella Laurenzo 
445*436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
446*436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
447*436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
448*436c6c9cSStella Laurenzo 
449*436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
450*436c6c9cSStella Laurenzo       // f32
451*436c6c9cSStella Laurenzo       return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
452*436c6c9cSStella Laurenzo     } else if (mlirTypeIsAF64(elementType)) {
453*436c6c9cSStella Laurenzo       // f64
454*436c6c9cSStella Laurenzo       return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
455*436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
456*436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 32) {
457*436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
458*436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
459*436c6c9cSStella Laurenzo         // i32
460*436c6c9cSStella Laurenzo         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
461*436c6c9cSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
462*436c6c9cSStella Laurenzo         // unsigned i32
463*436c6c9cSStella Laurenzo         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
464*436c6c9cSStella Laurenzo       }
465*436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
466*436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
467*436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
468*436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
469*436c6c9cSStella Laurenzo         // i64
470*436c6c9cSStella Laurenzo         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
471*436c6c9cSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
472*436c6c9cSStella Laurenzo         // unsigned i64
473*436c6c9cSStella Laurenzo         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
474*436c6c9cSStella Laurenzo       }
475*436c6c9cSStella Laurenzo     }
476*436c6c9cSStella Laurenzo 
477*436c6c9cSStella Laurenzo     std::string message = "unimplemented array format.";
478*436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, message);
479*436c6c9cSStella Laurenzo   }
480*436c6c9cSStella Laurenzo 
481*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
482*436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
483*436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
484*436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
485*436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
486*436c6c9cSStella Laurenzo                     "Gets from a buffer or ndarray")
487*436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
488*436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
489*436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
490*436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
491*436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
492*436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
493*436c6c9cSStella Laurenzo                                })
494*436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
495*436c6c9cSStella Laurenzo   }
496*436c6c9cSStella Laurenzo 
497*436c6c9cSStella Laurenzo private:
498*436c6c9cSStella Laurenzo   template <typename ElementTy>
499*436c6c9cSStella Laurenzo   static MlirAttribute
500*436c6c9cSStella Laurenzo   bulkLoad(MlirContext context,
501*436c6c9cSStella Laurenzo            MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
502*436c6c9cSStella Laurenzo            MlirType mlirElementType, py::buffer_info &arrayInfo) {
503*436c6c9cSStella Laurenzo     SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
504*436c6c9cSStella Laurenzo                                   arrayInfo.shape.begin() + arrayInfo.ndim);
505*436c6c9cSStella Laurenzo     auto shapedType =
506*436c6c9cSStella Laurenzo         mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
507*436c6c9cSStella Laurenzo     intptr_t numElements = arrayInfo.size;
508*436c6c9cSStella Laurenzo     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
509*436c6c9cSStella Laurenzo     return ctor(shapedType, numElements, contents);
510*436c6c9cSStella Laurenzo   }
511*436c6c9cSStella Laurenzo 
512*436c6c9cSStella Laurenzo   static bool isUnsignedIntegerFormat(const std::string &format) {
513*436c6c9cSStella Laurenzo     if (format.empty())
514*436c6c9cSStella Laurenzo       return false;
515*436c6c9cSStella Laurenzo     char code = format[0];
516*436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
517*436c6c9cSStella Laurenzo            code == 'Q';
518*436c6c9cSStella Laurenzo   }
519*436c6c9cSStella Laurenzo 
520*436c6c9cSStella Laurenzo   static bool isSignedIntegerFormat(const std::string &format) {
521*436c6c9cSStella Laurenzo     if (format.empty())
522*436c6c9cSStella Laurenzo       return false;
523*436c6c9cSStella Laurenzo     char code = format[0];
524*436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
525*436c6c9cSStella Laurenzo            code == 'q';
526*436c6c9cSStella Laurenzo   }
527*436c6c9cSStella Laurenzo 
528*436c6c9cSStella Laurenzo   template <typename Type>
529*436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
530*436c6c9cSStella Laurenzo                              Type (*value)(MlirAttribute, intptr_t)) {
531*436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
532*436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
533*436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
534*436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
535*436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
536*436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
537*436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
538*436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
539*436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
540*436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
541*436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
542*436c6c9cSStella Laurenzo     intptr_t strideFactor = 1;
543*436c6c9cSStella Laurenzo     for (intptr_t i = 1; i < rank; ++i) {
544*436c6c9cSStella Laurenzo       strideFactor = 1;
545*436c6c9cSStella Laurenzo       for (intptr_t j = i; j < rank; ++j) {
546*436c6c9cSStella Laurenzo         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
547*436c6c9cSStella Laurenzo       }
548*436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type) * strideFactor);
549*436c6c9cSStella Laurenzo     }
550*436c6c9cSStella Laurenzo     strides.push_back(sizeof(Type));
551*436c6c9cSStella Laurenzo     return py::buffer_info(data, sizeof(Type),
552*436c6c9cSStella Laurenzo                            py::format_descriptor<Type>::format(), rank, shape,
553*436c6c9cSStella Laurenzo                            strides, /*readonly=*/true);
554*436c6c9cSStella Laurenzo   }
555*436c6c9cSStella Laurenzo }; // namespace
556*436c6c9cSStella Laurenzo 
557*436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
558*436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
559*436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
560*436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
561*436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
562*436c6c9cSStella Laurenzo public:
563*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
564*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
565*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
566*436c6c9cSStella Laurenzo 
567*436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
568*436c6c9cSStella Laurenzo   /// out of range.
569*436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
570*436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
571*436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
572*436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
573*436c6c9cSStella Laurenzo     }
574*436c6c9cSStella Laurenzo 
575*436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
576*436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
577*436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
578*436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
579*436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
580*436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
581*436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
582*436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
583*436c6c9cSStella Laurenzo     // querying them on each element access.
584*436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
585*436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
586*436c6c9cSStella Laurenzo     if (isUnsigned) {
587*436c6c9cSStella Laurenzo       if (width == 1) {
588*436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
589*436c6c9cSStella Laurenzo       }
590*436c6c9cSStella Laurenzo       if (width == 32) {
591*436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
592*436c6c9cSStella Laurenzo       }
593*436c6c9cSStella Laurenzo       if (width == 64) {
594*436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
595*436c6c9cSStella Laurenzo       }
596*436c6c9cSStella Laurenzo     } else {
597*436c6c9cSStella Laurenzo       if (width == 1) {
598*436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
599*436c6c9cSStella Laurenzo       }
600*436c6c9cSStella Laurenzo       if (width == 32) {
601*436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
602*436c6c9cSStella Laurenzo       }
603*436c6c9cSStella Laurenzo       if (width == 64) {
604*436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
605*436c6c9cSStella Laurenzo       }
606*436c6c9cSStella Laurenzo     }
607*436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
608*436c6c9cSStella Laurenzo   }
609*436c6c9cSStella Laurenzo 
610*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
611*436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
612*436c6c9cSStella Laurenzo   }
613*436c6c9cSStella Laurenzo };
614*436c6c9cSStella Laurenzo 
615*436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
616*436c6c9cSStella Laurenzo public:
617*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
618*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
619*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
620*436c6c9cSStella Laurenzo 
621*436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
622*436c6c9cSStella Laurenzo 
623*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
624*436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
625*436c6c9cSStella Laurenzo     c.def_static(
626*436c6c9cSStella Laurenzo         "get",
627*436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
628*436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
629*436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
630*436c6c9cSStella Laurenzo           for (auto &it : attributes) {
631*436c6c9cSStella Laurenzo             auto &mlir_attr = it.second.cast<PyAttribute &>();
632*436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
633*436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
634*436c6c9cSStella Laurenzo                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
635*436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
636*436c6c9cSStella Laurenzo                 mlir_attr));
637*436c6c9cSStella Laurenzo           }
638*436c6c9cSStella Laurenzo           MlirAttribute attr =
639*436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
640*436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
641*436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
642*436c6c9cSStella Laurenzo         },
643*436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
644*436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
645*436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
646*436c6c9cSStella Laurenzo       MlirAttribute attr =
647*436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
648*436c6c9cSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
649*436c6c9cSStella Laurenzo         throw SetPyError(PyExc_KeyError,
650*436c6c9cSStella Laurenzo                          "attempt to access a non-existent attribute");
651*436c6c9cSStella Laurenzo       }
652*436c6c9cSStella Laurenzo       return PyAttribute(self.getContext(), attr);
653*436c6c9cSStella Laurenzo     });
654*436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
655*436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
656*436c6c9cSStella Laurenzo         throw SetPyError(PyExc_IndexError,
657*436c6c9cSStella Laurenzo                          "attempt to access out of bounds attribute");
658*436c6c9cSStella Laurenzo       }
659*436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
660*436c6c9cSStella Laurenzo       return PyNamedAttribute(
661*436c6c9cSStella Laurenzo           namedAttr.attribute,
662*436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
663*436c6c9cSStella Laurenzo     });
664*436c6c9cSStella Laurenzo   }
665*436c6c9cSStella Laurenzo };
666*436c6c9cSStella Laurenzo 
667*436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
668*436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
669*436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
670*436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
671*436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
672*436c6c9cSStella Laurenzo public:
673*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
674*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
675*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
676*436c6c9cSStella Laurenzo 
677*436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
678*436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
679*436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
680*436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
681*436c6c9cSStella Laurenzo     }
682*436c6c9cSStella Laurenzo 
683*436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
684*436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
685*436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
686*436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
687*436c6c9cSStella Laurenzo     // from float and double.
688*436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
689*436c6c9cSStella Laurenzo     // querying them on each element access.
690*436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
691*436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
692*436c6c9cSStella Laurenzo     }
693*436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
694*436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
695*436c6c9cSStella Laurenzo     }
696*436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
697*436c6c9cSStella Laurenzo   }
698*436c6c9cSStella Laurenzo 
699*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
700*436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
701*436c6c9cSStella Laurenzo   }
702*436c6c9cSStella Laurenzo };
703*436c6c9cSStella Laurenzo 
704*436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
705*436c6c9cSStella Laurenzo public:
706*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
707*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
708*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
709*436c6c9cSStella Laurenzo 
710*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
711*436c6c9cSStella Laurenzo     c.def_static(
712*436c6c9cSStella Laurenzo         "get",
713*436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
714*436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
715*436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
716*436c6c9cSStella Laurenzo         },
717*436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
718*436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
719*436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
720*436c6c9cSStella Laurenzo       return PyType(self.getContext()->getRef(),
721*436c6c9cSStella Laurenzo                     mlirTypeAttrGetValue(self.get()));
722*436c6c9cSStella Laurenzo     });
723*436c6c9cSStella Laurenzo   }
724*436c6c9cSStella Laurenzo };
725*436c6c9cSStella Laurenzo 
726*436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
727*436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
728*436c6c9cSStella Laurenzo public:
729*436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
730*436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
731*436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
732*436c6c9cSStella Laurenzo 
733*436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
734*436c6c9cSStella Laurenzo     c.def_static(
735*436c6c9cSStella Laurenzo         "get",
736*436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
737*436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
738*436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
739*436c6c9cSStella Laurenzo         },
740*436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
741*436c6c9cSStella Laurenzo   }
742*436c6c9cSStella Laurenzo };
743*436c6c9cSStella Laurenzo 
744*436c6c9cSStella Laurenzo } // namespace
745*436c6c9cSStella Laurenzo 
746*436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
747*436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
748*436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
749*436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
750*436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
751*436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
752*436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
753*436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
754*436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
755*436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
756*436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
757*436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
758*436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
759*436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
760*436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
761*436c6c9cSStella Laurenzo }
762