xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision 5d6d30edf8b9b2c69215bdbbc651a85e4d0dc4ff)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9436c6c9cSStella Laurenzo #include "IRModule.h"
10436c6c9cSStella Laurenzo 
11436c6c9cSStella Laurenzo #include "PybindUtils.h"
12436c6c9cSStella Laurenzo 
13436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
14436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
15436c6c9cSStella Laurenzo 
16436c6c9cSStella Laurenzo namespace py = pybind11;
17436c6c9cSStella Laurenzo using namespace mlir;
18436c6c9cSStella Laurenzo using namespace mlir::python;
19436c6c9cSStella Laurenzo 
20*5d6d30edSStella Laurenzo using llvm::None;
21*5d6d30edSStella Laurenzo using llvm::Optional;
22436c6c9cSStella Laurenzo using llvm::SmallVector;
23436c6c9cSStella Laurenzo using llvm::Twine;
24436c6c9cSStella Laurenzo 
25*5d6d30edSStella Laurenzo //------------------------------------------------------------------------------
26*5d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
27*5d6d30edSStella Laurenzo //------------------------------------------------------------------------------
28*5d6d30edSStella Laurenzo 
29*5d6d30edSStella Laurenzo static const char kDenseElementsAttrGetDocstring[] =
30*5d6d30edSStella Laurenzo     R"(Gets a DenseElementsAttr from a Python buffer or array.
31*5d6d30edSStella Laurenzo 
32*5d6d30edSStella Laurenzo When `type` is not provided, then some limited type inferencing is done based
33*5d6d30edSStella Laurenzo on the buffer format. Support presently exists for 8/16/32/64 signed and
34*5d6d30edSStella Laurenzo unsigned integers and float16/float32/float64. DenseElementsAttrs of these
35*5d6d30edSStella Laurenzo types can also be converted back to a corresponding buffer.
36*5d6d30edSStella Laurenzo 
37*5d6d30edSStella Laurenzo For conversions outside of these types, a `type=` must be explicitly provided
38*5d6d30edSStella Laurenzo and the buffer contents must be bit-castable to the MLIR internal
39*5d6d30edSStella Laurenzo representation:
40*5d6d30edSStella Laurenzo 
41*5d6d30edSStella Laurenzo   * Integer types (except for i1): the buffer must be byte aligned to the
42*5d6d30edSStella Laurenzo     next byte boundary.
43*5d6d30edSStella Laurenzo   * Floating point types: Must be bit-castable to the given floating point
44*5d6d30edSStella Laurenzo     size.
45*5d6d30edSStella Laurenzo   * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
46*5d6d30edSStella Laurenzo     row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
47*5d6d30edSStella Laurenzo     this specification with: `np.packbits(ary, axis=None, bitorder='little')`.
48*5d6d30edSStella Laurenzo 
49*5d6d30edSStella Laurenzo If a single element buffer is passed (or for i1, a single byte with value 0
50*5d6d30edSStella Laurenzo or 255), then a splat will be created.
51*5d6d30edSStella Laurenzo 
52*5d6d30edSStella Laurenzo Args:
53*5d6d30edSStella Laurenzo   array: The array or buffer to convert.
54*5d6d30edSStella Laurenzo   signless: If inferring an appropriate MLIR type, use signless types for
55*5d6d30edSStella Laurenzo     integers (defaults True).
56*5d6d30edSStella Laurenzo   type: Skips inference of the MLIR element type and uses this instead. The
57*5d6d30edSStella Laurenzo     storage size must be consistent with the actual contents of the buffer.
58*5d6d30edSStella Laurenzo   shape: Overrides the shape of the buffer when constructing the MLIR
59*5d6d30edSStella Laurenzo     shaped type. This is needed when the physical and logical shape differ (as
60*5d6d30edSStella Laurenzo     for i1).
61*5d6d30edSStella Laurenzo   context: Explicit context, if not from context manager.
62*5d6d30edSStella Laurenzo 
63*5d6d30edSStella Laurenzo Returns:
64*5d6d30edSStella Laurenzo   DenseElementsAttr on success.
65*5d6d30edSStella Laurenzo 
66*5d6d30edSStella Laurenzo Raises:
67*5d6d30edSStella Laurenzo   ValueError: If the type of the buffer or array cannot be matched to an MLIR
68*5d6d30edSStella Laurenzo     type or if the buffer does not meet expectations.
69*5d6d30edSStella Laurenzo )";
70*5d6d30edSStella Laurenzo 
71436c6c9cSStella Laurenzo namespace {
72436c6c9cSStella Laurenzo 
73436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
74436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
75436c6c9cSStella Laurenzo }
76436c6c9cSStella Laurenzo 
77436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
78436c6c9cSStella Laurenzo public:
79436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
80436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
81436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
82436c6c9cSStella Laurenzo 
83436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
84436c6c9cSStella Laurenzo     c.def_static(
85436c6c9cSStella Laurenzo         "get",
86436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
87436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
88436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
89436c6c9cSStella Laurenzo         },
90436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
91436c6c9cSStella Laurenzo   }
92436c6c9cSStella Laurenzo };
93436c6c9cSStella Laurenzo 
94ed9e52f3SAlex Zinenko template <typename T>
95ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
96ed9e52f3SAlex Zinenko   try {
97ed9e52f3SAlex Zinenko     return object.cast<T>();
98ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
99ed9e52f3SAlex Zinenko     std::string msg =
100ed9e52f3SAlex Zinenko         std::string(
101ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
102ed9e52f3SAlex Zinenko         err.what() + ")";
103ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
104ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
105ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
106ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
107ed9e52f3SAlex Zinenko                       err.what() + ")";
108ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
109ed9e52f3SAlex Zinenko   }
110ed9e52f3SAlex Zinenko }
111ed9e52f3SAlex Zinenko 
112436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
113436c6c9cSStella Laurenzo public:
114436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
115436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
116436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
117436c6c9cSStella Laurenzo 
118436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
119436c6c9cSStella Laurenzo   public:
120436c6c9cSStella Laurenzo     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
121436c6c9cSStella Laurenzo 
122436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
123436c6c9cSStella Laurenzo 
124436c6c9cSStella Laurenzo     PyAttribute dunderNext() {
125436c6c9cSStella Laurenzo       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
126436c6c9cSStella Laurenzo         throw py::stop_iteration();
127436c6c9cSStella Laurenzo       }
128436c6c9cSStella Laurenzo       return PyAttribute(attr.getContext(),
129436c6c9cSStella Laurenzo                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
130436c6c9cSStella Laurenzo     }
131436c6c9cSStella Laurenzo 
132436c6c9cSStella Laurenzo     static void bind(py::module &m) {
133f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
134f05ff4f7SStella Laurenzo                                            py::module_local())
135436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
136436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
137436c6c9cSStella Laurenzo     }
138436c6c9cSStella Laurenzo 
139436c6c9cSStella Laurenzo   private:
140436c6c9cSStella Laurenzo     PyAttribute attr;
141436c6c9cSStella Laurenzo     int nextIndex = 0;
142436c6c9cSStella Laurenzo   };
143436c6c9cSStella Laurenzo 
144ed9e52f3SAlex Zinenko   PyAttribute getItem(intptr_t i) {
145ed9e52f3SAlex Zinenko     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
146ed9e52f3SAlex Zinenko   }
147ed9e52f3SAlex Zinenko 
148436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
149436c6c9cSStella Laurenzo     c.def_static(
150436c6c9cSStella Laurenzo         "get",
151436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
152436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
153436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
154436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
155ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
156436c6c9cSStella Laurenzo           }
157436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
158436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
159436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
160436c6c9cSStella Laurenzo         },
161436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
162436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
163436c6c9cSStella Laurenzo     c.def("__getitem__",
164436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
165436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
166436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
167ed9e52f3SAlex Zinenko             return arr.getItem(i);
168436c6c9cSStella Laurenzo           })
169436c6c9cSStella Laurenzo         .def("__len__",
170436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
171436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
172436c6c9cSStella Laurenzo              })
173436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
174436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
175436c6c9cSStella Laurenzo         });
176ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
177ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
178ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
179ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
180ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
181ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
182ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
183ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
184ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
185ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
186ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
187ed9e52f3SAlex Zinenko     });
188436c6c9cSStella Laurenzo   }
189436c6c9cSStella Laurenzo };
190436c6c9cSStella Laurenzo 
191436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
192436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
193436c6c9cSStella Laurenzo public:
194436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
195436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
196436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
197436c6c9cSStella Laurenzo 
198436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
199436c6c9cSStella Laurenzo     c.def_static(
200436c6c9cSStella Laurenzo         "get",
201436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
202436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
203436c6c9cSStella Laurenzo           // TODO: Rework error reporting once diagnostic engine is exposed
204436c6c9cSStella Laurenzo           // in C API.
205436c6c9cSStella Laurenzo           if (mlirAttributeIsNull(attr)) {
206436c6c9cSStella Laurenzo             throw SetPyError(PyExc_ValueError,
207436c6c9cSStella Laurenzo                              Twine("invalid '") +
208436c6c9cSStella Laurenzo                                  py::repr(py::cast(type)).cast<std::string>() +
209436c6c9cSStella Laurenzo                                  "' and expected floating point type.");
210436c6c9cSStella Laurenzo           }
211436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
212436c6c9cSStella Laurenzo         },
213436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
214436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
215436c6c9cSStella Laurenzo     c.def_static(
216436c6c9cSStella Laurenzo         "get_f32",
217436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
218436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
219436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
220436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
221436c6c9cSStella Laurenzo         },
222436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
223436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
224436c6c9cSStella Laurenzo     c.def_static(
225436c6c9cSStella Laurenzo         "get_f64",
226436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
227436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
228436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
229436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
230436c6c9cSStella Laurenzo         },
231436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
232436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
233436c6c9cSStella Laurenzo     c.def_property_readonly(
234436c6c9cSStella Laurenzo         "value",
235436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
236436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
237436c6c9cSStella Laurenzo         },
238436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
239436c6c9cSStella Laurenzo   }
240436c6c9cSStella Laurenzo };
241436c6c9cSStella Laurenzo 
242436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
243436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
244436c6c9cSStella Laurenzo public:
245436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
246436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
247436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
248436c6c9cSStella Laurenzo 
249436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
250436c6c9cSStella Laurenzo     c.def_static(
251436c6c9cSStella Laurenzo         "get",
252436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
253436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
254436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
255436c6c9cSStella Laurenzo         },
256436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
257436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
258436c6c9cSStella Laurenzo     c.def_property_readonly(
259436c6c9cSStella Laurenzo         "value",
260436c6c9cSStella Laurenzo         [](PyIntegerAttribute &self) {
261436c6c9cSStella Laurenzo           return mlirIntegerAttrGetValueInt(self);
262436c6c9cSStella Laurenzo         },
263436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
264436c6c9cSStella Laurenzo   }
265436c6c9cSStella Laurenzo };
266436c6c9cSStella Laurenzo 
267436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
268436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
269436c6c9cSStella Laurenzo public:
270436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
271436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
272436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
273436c6c9cSStella Laurenzo 
274436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
275436c6c9cSStella Laurenzo     c.def_static(
276436c6c9cSStella Laurenzo         "get",
277436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
278436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
279436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
280436c6c9cSStella Laurenzo         },
281436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
282436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
283436c6c9cSStella Laurenzo     c.def_property_readonly(
284436c6c9cSStella Laurenzo         "value",
285436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
286436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
287436c6c9cSStella Laurenzo   }
288436c6c9cSStella Laurenzo };
289436c6c9cSStella Laurenzo 
290436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
291436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
292436c6c9cSStella Laurenzo public:
293436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
294436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
295436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
296436c6c9cSStella Laurenzo 
297436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
298436c6c9cSStella Laurenzo     c.def_static(
299436c6c9cSStella Laurenzo         "get",
300436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
301436c6c9cSStella Laurenzo           MlirAttribute attr =
302436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
303436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
304436c6c9cSStella Laurenzo         },
305436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
306436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
307436c6c9cSStella Laurenzo     c.def_property_readonly(
308436c6c9cSStella Laurenzo         "value",
309436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
310436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
311436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
312436c6c9cSStella Laurenzo         },
313436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
314436c6c9cSStella Laurenzo   }
315436c6c9cSStella Laurenzo };
316436c6c9cSStella Laurenzo 
317436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
318436c6c9cSStella Laurenzo public:
319436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
320436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
321436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
322436c6c9cSStella Laurenzo 
323436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
324436c6c9cSStella Laurenzo     c.def_static(
325436c6c9cSStella Laurenzo         "get",
326436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
327436c6c9cSStella Laurenzo           MlirAttribute attr =
328436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
329436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
330436c6c9cSStella Laurenzo         },
331436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
332436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
333436c6c9cSStella Laurenzo     c.def_static(
334436c6c9cSStella Laurenzo         "get_typed",
335436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
336436c6c9cSStella Laurenzo           MlirAttribute attr =
337436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
338436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
339436c6c9cSStella Laurenzo         },
340436c6c9cSStella Laurenzo 
341436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
342436c6c9cSStella Laurenzo     c.def_property_readonly(
343436c6c9cSStella Laurenzo         "value",
344436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
345436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
346436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
347436c6c9cSStella Laurenzo         },
348436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
349436c6c9cSStella Laurenzo   }
350436c6c9cSStella Laurenzo };
351436c6c9cSStella Laurenzo 
352436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
353436c6c9cSStella Laurenzo class PyDenseElementsAttribute
354436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
355436c6c9cSStella Laurenzo public:
356436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
357436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
358436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
359436c6c9cSStella Laurenzo 
360436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
361*5d6d30edSStella Laurenzo   getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
362*5d6d30edSStella Laurenzo                 Optional<std::vector<int64_t>> explicitShape,
363436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
364436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
365436c6c9cSStella Laurenzo     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
366436c6c9cSStella Laurenzo     Py_buffer *view = new Py_buffer();
367436c6c9cSStella Laurenzo     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
368436c6c9cSStella Laurenzo       delete view;
369436c6c9cSStella Laurenzo       throw py::error_already_set();
370436c6c9cSStella Laurenzo     }
371436c6c9cSStella Laurenzo     py::buffer_info arrayInfo(view);
372*5d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
373*5d6d30edSStella Laurenzo     if (explicitShape) {
374*5d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
375*5d6d30edSStella Laurenzo     } else {
376*5d6d30edSStella Laurenzo       shape.append(arrayInfo.shape.begin(),
377*5d6d30edSStella Laurenzo                    arrayInfo.shape.begin() + arrayInfo.ndim);
378*5d6d30edSStella Laurenzo     }
379436c6c9cSStella Laurenzo 
380*5d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
381436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
382*5d6d30edSStella Laurenzo 
383*5d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
384*5d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
385*5d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
386*5d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
387*5d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
388*5d6d30edSStella Laurenzo     Optional<MlirType> bulkLoadElementType;
389*5d6d30edSStella Laurenzo     if (explicitType) {
390*5d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
391*5d6d30edSStella Laurenzo     } else if (arrayInfo.format == "f") {
392436c6c9cSStella Laurenzo       // f32
393436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
394*5d6d30edSStella Laurenzo       bulkLoadElementType = mlirF32TypeGet(context);
395436c6c9cSStella Laurenzo     } else if (arrayInfo.format == "d") {
396436c6c9cSStella Laurenzo       // f64
397436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
398*5d6d30edSStella Laurenzo       bulkLoadElementType = mlirF64TypeGet(context);
399*5d6d30edSStella Laurenzo     } else if (arrayInfo.format == "e") {
400*5d6d30edSStella Laurenzo       // f16
401*5d6d30edSStella Laurenzo       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
402*5d6d30edSStella Laurenzo       bulkLoadElementType = mlirF16TypeGet(context);
403436c6c9cSStella Laurenzo     } else if (isSignedIntegerFormat(arrayInfo.format)) {
404436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
405436c6c9cSStella Laurenzo         // i32
406*5d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
407436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 32);
408436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
409436c6c9cSStella Laurenzo         // i64
410*5d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
411436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 64);
412*5d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
413*5d6d30edSStella Laurenzo         // i8
414*5d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
415*5d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 8);
416*5d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
417*5d6d30edSStella Laurenzo         // i16
418*5d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
419*5d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 16);
420436c6c9cSStella Laurenzo       }
421436c6c9cSStella Laurenzo     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
422436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
423436c6c9cSStella Laurenzo         // unsigned i32
424*5d6d30edSStella Laurenzo         bulkLoadElementType = signless
425436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 32)
426436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 32);
427436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
428436c6c9cSStella Laurenzo         // unsigned i64
429*5d6d30edSStella Laurenzo         bulkLoadElementType = signless
430436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 64)
431436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 64);
432*5d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
433*5d6d30edSStella Laurenzo         // i8
434*5d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
435*5d6d30edSStella Laurenzo                                        : mlirIntegerTypeUnsignedGet(context, 8);
436*5d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
437*5d6d30edSStella Laurenzo         // i16
438*5d6d30edSStella Laurenzo         bulkLoadElementType = signless
439*5d6d30edSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 16)
440*5d6d30edSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 16);
441436c6c9cSStella Laurenzo       }
442436c6c9cSStella Laurenzo     }
443*5d6d30edSStella Laurenzo     if (bulkLoadElementType) {
444*5d6d30edSStella Laurenzo       auto shapedType = mlirRankedTensorTypeGet(
445*5d6d30edSStella Laurenzo           shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
446*5d6d30edSStella Laurenzo       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
447*5d6d30edSStella Laurenzo       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
448*5d6d30edSStella Laurenzo           shapedType, rawBufferSize, arrayInfo.ptr);
449*5d6d30edSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
450*5d6d30edSStella Laurenzo         throw std::invalid_argument(
451*5d6d30edSStella Laurenzo             "DenseElementsAttr could not be constructed from the given buffer. "
452*5d6d30edSStella Laurenzo             "This may mean that the Python buffer layout does not match that "
453*5d6d30edSStella Laurenzo             "MLIR expected layout and is a bug.");
454*5d6d30edSStella Laurenzo       }
455*5d6d30edSStella Laurenzo       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
456*5d6d30edSStella Laurenzo     }
457436c6c9cSStella Laurenzo 
458*5d6d30edSStella Laurenzo     throw std::invalid_argument(
459*5d6d30edSStella Laurenzo         std::string("unimplemented array format conversion from format: ") +
460*5d6d30edSStella Laurenzo         arrayInfo.format);
461436c6c9cSStella Laurenzo   }
462436c6c9cSStella Laurenzo 
463436c6c9cSStella Laurenzo   static PyDenseElementsAttribute getSplat(PyType shapedType,
464436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
465436c6c9cSStella Laurenzo     auto contextWrapper =
466436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
467436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
468436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
469436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
470436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
471436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
472436c6c9cSStella Laurenzo     }
473436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
474436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
475436c6c9cSStella Laurenzo       std::string message =
476436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
477436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
478436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
479436c6c9cSStella Laurenzo     }
480436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
481436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
482436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
483436c6c9cSStella Laurenzo       std::string message =
484436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
485436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
486436c6c9cSStella Laurenzo       message.append(", element=");
487436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
488436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
489436c6c9cSStella Laurenzo     }
490436c6c9cSStella Laurenzo 
491436c6c9cSStella Laurenzo     MlirAttribute elements =
492436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
493436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
494436c6c9cSStella Laurenzo   }
495436c6c9cSStella Laurenzo 
496436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
497436c6c9cSStella Laurenzo 
498436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
499*5d6d30edSStella Laurenzo     if (mlirDenseElementsAttrIsSplat(*this)) {
500*5d6d30edSStella Laurenzo       // TODO: Raise an exception.
501*5d6d30edSStella Laurenzo       // Reported as https://github.com/pybind/pybind11/issues/3336
502*5d6d30edSStella Laurenzo       return py::buffer_info();
503*5d6d30edSStella Laurenzo     }
504*5d6d30edSStella Laurenzo 
505436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
506436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
507*5d6d30edSStella Laurenzo     std::string format;
508436c6c9cSStella Laurenzo 
509436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
510436c6c9cSStella Laurenzo       // f32
511*5d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
512436c6c9cSStella Laurenzo     } else if (mlirTypeIsAF64(elementType)) {
513436c6c9cSStella Laurenzo       // f64
514*5d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
515*5d6d30edSStella Laurenzo     } else if (mlirTypeIsAF16(elementType)) {
516*5d6d30edSStella Laurenzo       // f16
517*5d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
518436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
519436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 32) {
520436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
521436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
522436c6c9cSStella Laurenzo         // i32
523*5d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
524436c6c9cSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
525436c6c9cSStella Laurenzo         // unsigned i32
526*5d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
527436c6c9cSStella Laurenzo       }
528436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
529436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
530436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
531436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
532436c6c9cSStella Laurenzo         // i64
533*5d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
534436c6c9cSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
535436c6c9cSStella Laurenzo         // unsigned i64
536*5d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
537*5d6d30edSStella Laurenzo       }
538*5d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
539*5d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
540*5d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
541*5d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
542*5d6d30edSStella Laurenzo         // i8
543*5d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
544*5d6d30edSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
545*5d6d30edSStella Laurenzo         // unsigned i8
546*5d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
547*5d6d30edSStella Laurenzo       }
548*5d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
549*5d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
550*5d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
551*5d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
552*5d6d30edSStella Laurenzo         // i16
553*5d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
554*5d6d30edSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
555*5d6d30edSStella Laurenzo         // unsigned i16
556*5d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
557436c6c9cSStella Laurenzo       }
558436c6c9cSStella Laurenzo     }
559436c6c9cSStella Laurenzo 
560*5d6d30edSStella Laurenzo     // TODO: Currently crashes the program. Just returning an empty buffer
561*5d6d30edSStella Laurenzo     // for now.
562*5d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
563*5d6d30edSStella Laurenzo     // throw std::invalid_argument(
564*5d6d30edSStella Laurenzo     //     "unsupported data type for conversion to Python buffer");
565*5d6d30edSStella Laurenzo     return py::buffer_info();
566436c6c9cSStella Laurenzo   }
567436c6c9cSStella Laurenzo 
568436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
569436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
570436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
571436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
572*5d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
573436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
574*5d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
575436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
576436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
577436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
578436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
579436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
580436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
581436c6c9cSStella Laurenzo                                })
582436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
583436c6c9cSStella Laurenzo   }
584436c6c9cSStella Laurenzo 
585436c6c9cSStella Laurenzo private:
/// Returns the struct-module type code of a Python buffer format string.
/// One leading native byte-order marker ('@' or '=') is skipped. Returns '\0'
/// when no type code is present. Non-native markers ('<', '>', '!') are left
/// in place, so such formats never match the native integer codes below.
static char bufferFormatTypeCode(const std::string &format) {
  std::size_t pos = 0;
  if (!format.empty() && (format[0] == '@' || format[0] == '='))
    pos = 1;
  return pos < format.size() ? format[pos] : '\0';
}

/// Returns true if `format` denotes an unsigned integer element (struct codes
/// 'B', 'H', 'I', 'L', 'Q'), optionally prefixed by '@' or '='.
/// The element width is not checked here (NOTE(review): callers are expected
/// to select the width from the buffer itemsize).
static bool isUnsignedIntegerFormat(const std::string &format) {
  switch (bufferFormatTypeCode(format)) {
  case 'B':
  case 'H':
  case 'I':
  case 'L':
  case 'Q':
    return true;
  default:
    return false;
  }
}

/// Returns true if `format` denotes a signed integer element (struct codes
/// 'b', 'h', 'i', 'l', 'q'), optionally prefixed by '@' or '='.
/// The element width is not checked here (NOTE(review): callers are expected
/// to select the width from the buffer itemsize).
static bool isSignedIntegerFormat(const std::string &format) {
  switch (bufferFormatTypeCode(format)) {
  case 'b':
  case 'h':
  case 'i':
  case 'l':
  case 'q':
    return true;
  default:
    return false;
  }
}
601436c6c9cSStella Laurenzo 
602436c6c9cSStella Laurenzo   template <typename Type>
603436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
604*5d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
605436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
606436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
607436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
608436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
609436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
610436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
611436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
612436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
613436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
614436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
615436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
616436c6c9cSStella Laurenzo     intptr_t strideFactor = 1;
617436c6c9cSStella Laurenzo     for (intptr_t i = 1; i < rank; ++i) {
618436c6c9cSStella Laurenzo       strideFactor = 1;
619436c6c9cSStella Laurenzo       for (intptr_t j = i; j < rank; ++j) {
620436c6c9cSStella Laurenzo         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
621436c6c9cSStella Laurenzo       }
622436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type) * strideFactor);
623436c6c9cSStella Laurenzo     }
624436c6c9cSStella Laurenzo     strides.push_back(sizeof(Type));
625*5d6d30edSStella Laurenzo     std::string format;
626*5d6d30edSStella Laurenzo     if (explicitFormat) {
627*5d6d30edSStella Laurenzo       format = explicitFormat;
628*5d6d30edSStella Laurenzo     } else {
629*5d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
630*5d6d30edSStella Laurenzo     }
631*5d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
632*5d6d30edSStella Laurenzo                            /*readonly=*/true);
633436c6c9cSStella Laurenzo   }
634436c6c9cSStella Laurenzo }; // namespace
635436c6c9cSStella Laurenzo 
636436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
637436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
638436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
639436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
640436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
641436c6c9cSStella Laurenzo public:
642436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
643436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
644436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
645436c6c9cSStella Laurenzo 
646436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
647436c6c9cSStella Laurenzo   /// out of range.
648436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
649436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
650436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
651436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
652436c6c9cSStella Laurenzo     }
653436c6c9cSStella Laurenzo 
654436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
655436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
656436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
657436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
658436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
659436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
660436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
661436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
662436c6c9cSStella Laurenzo     // querying them on each element access.
663436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
664436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
665436c6c9cSStella Laurenzo     if (isUnsigned) {
666436c6c9cSStella Laurenzo       if (width == 1) {
667436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
668436c6c9cSStella Laurenzo       }
669436c6c9cSStella Laurenzo       if (width == 32) {
670436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
671436c6c9cSStella Laurenzo       }
672436c6c9cSStella Laurenzo       if (width == 64) {
673436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
674436c6c9cSStella Laurenzo       }
675436c6c9cSStella Laurenzo     } else {
676436c6c9cSStella Laurenzo       if (width == 1) {
677436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
678436c6c9cSStella Laurenzo       }
679436c6c9cSStella Laurenzo       if (width == 32) {
680436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
681436c6c9cSStella Laurenzo       }
682436c6c9cSStella Laurenzo       if (width == 64) {
683436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
684436c6c9cSStella Laurenzo       }
685436c6c9cSStella Laurenzo     }
686436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
687436c6c9cSStella Laurenzo   }
688436c6c9cSStella Laurenzo 
689436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
690436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
691436c6c9cSStella Laurenzo   }
692436c6c9cSStella Laurenzo };
693436c6c9cSStella Laurenzo 
694436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
695436c6c9cSStella Laurenzo public:
696436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
697436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
698436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
699436c6c9cSStella Laurenzo 
700436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
701436c6c9cSStella Laurenzo 
702436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
703436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
704436c6c9cSStella Laurenzo     c.def_static(
705436c6c9cSStella Laurenzo         "get",
706436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
707436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
708436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
709436c6c9cSStella Laurenzo           for (auto &it : attributes) {
710436c6c9cSStella Laurenzo             auto &mlir_attr = it.second.cast<PyAttribute &>();
711436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
712436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
713436c6c9cSStella Laurenzo                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
714436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
715436c6c9cSStella Laurenzo                 mlir_attr));
716436c6c9cSStella Laurenzo           }
717436c6c9cSStella Laurenzo           MlirAttribute attr =
718436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
719436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
720436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
721436c6c9cSStella Laurenzo         },
722ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
723436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
724436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
725436c6c9cSStella Laurenzo       MlirAttribute attr =
726436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
727436c6c9cSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
728436c6c9cSStella Laurenzo         throw SetPyError(PyExc_KeyError,
729436c6c9cSStella Laurenzo                          "attempt to access a non-existent attribute");
730436c6c9cSStella Laurenzo       }
731436c6c9cSStella Laurenzo       return PyAttribute(self.getContext(), attr);
732436c6c9cSStella Laurenzo     });
733436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
734436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
735436c6c9cSStella Laurenzo         throw SetPyError(PyExc_IndexError,
736436c6c9cSStella Laurenzo                          "attempt to access out of bounds attribute");
737436c6c9cSStella Laurenzo       }
738436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
739436c6c9cSStella Laurenzo       return PyNamedAttribute(
740436c6c9cSStella Laurenzo           namedAttr.attribute,
741436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
742436c6c9cSStella Laurenzo     });
743436c6c9cSStella Laurenzo   }
744436c6c9cSStella Laurenzo };
745436c6c9cSStella Laurenzo 
746436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
747436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
748436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
749436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
750436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
751436c6c9cSStella Laurenzo public:
752436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
753436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
754436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
755436c6c9cSStella Laurenzo 
756436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
757436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
758436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
759436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
760436c6c9cSStella Laurenzo     }
761436c6c9cSStella Laurenzo 
762436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
763436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
764436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
765436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
766436c6c9cSStella Laurenzo     // from float and double.
767436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
768436c6c9cSStella Laurenzo     // querying them on each element access.
769436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
770436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
771436c6c9cSStella Laurenzo     }
772436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
773436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
774436c6c9cSStella Laurenzo     }
775436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
776436c6c9cSStella Laurenzo   }
777436c6c9cSStella Laurenzo 
778436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
779436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
780436c6c9cSStella Laurenzo   }
781436c6c9cSStella Laurenzo };
782436c6c9cSStella Laurenzo 
783436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
784436c6c9cSStella Laurenzo public:
785436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
786436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
787436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
788436c6c9cSStella Laurenzo 
789436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
790436c6c9cSStella Laurenzo     c.def_static(
791436c6c9cSStella Laurenzo         "get",
792436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
793436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
794436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
795436c6c9cSStella Laurenzo         },
796436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
797436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
798436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
799436c6c9cSStella Laurenzo       return PyType(self.getContext()->getRef(),
800436c6c9cSStella Laurenzo                     mlirTypeAttrGetValue(self.get()));
801436c6c9cSStella Laurenzo     });
802436c6c9cSStella Laurenzo   }
803436c6c9cSStella Laurenzo };
804436c6c9cSStella Laurenzo 
805436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
806436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
807436c6c9cSStella Laurenzo public:
808436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
809436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
810436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
811436c6c9cSStella Laurenzo 
812436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
813436c6c9cSStella Laurenzo     c.def_static(
814436c6c9cSStella Laurenzo         "get",
815436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
816436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
817436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
818436c6c9cSStella Laurenzo         },
819436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
820436c6c9cSStella Laurenzo   }
821436c6c9cSStella Laurenzo };
822436c6c9cSStella Laurenzo 
823436c6c9cSStella Laurenzo } // namespace
824436c6c9cSStella Laurenzo 
825436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
826436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
827436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
828436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
829436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
830436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
831436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
832436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
833436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
834436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
835436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
836436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
837436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
838436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
839436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
840436c6c9cSStella Laurenzo }
841