xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision e9db306dcd53f33b982d772793ffe7326d40c018)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
91fc096afSMehdi Amini #include <utility>
101fc096afSMehdi Amini 
11436c6c9cSStella Laurenzo #include "IRModule.h"
12436c6c9cSStella Laurenzo 
13436c6c9cSStella Laurenzo #include "PybindUtils.h"
14436c6c9cSStella Laurenzo 
15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
17436c6c9cSStella Laurenzo 
18436c6c9cSStella Laurenzo namespace py = pybind11;
19436c6c9cSStella Laurenzo using namespace mlir;
20436c6c9cSStella Laurenzo using namespace mlir::python;
21436c6c9cSStella Laurenzo 
225d6d30edSStella Laurenzo using llvm::Optional;
23436c6c9cSStella Laurenzo using llvm::SmallVector;
24436c6c9cSStella Laurenzo using llvm::Twine;
25436c6c9cSStella Laurenzo 
265d6d30edSStella Laurenzo //------------------------------------------------------------------------------
275d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
285d6d30edSStella Laurenzo //------------------------------------------------------------------------------
295d6d30edSStella Laurenzo 
// Python-facing docstring, kept out of line because of its length. The name
// suggests it documents the DenseElementsAttr `get` binding (usage is outside
// this view).
static constexpr char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
715d6d30edSStella Laurenzo 
72436c6c9cSStella Laurenzo namespace {
73436c6c9cSStella Laurenzo 
74436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
75436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
76436c6c9cSStella Laurenzo }
77436c6c9cSStella Laurenzo 
78436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
79436c6c9cSStella Laurenzo public:
80436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
81436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
82436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
83436c6c9cSStella Laurenzo 
84436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
85436c6c9cSStella Laurenzo     c.def_static(
86436c6c9cSStella Laurenzo         "get",
87436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
88436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
89436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
90436c6c9cSStella Laurenzo         },
91436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
92436c6c9cSStella Laurenzo   }
93436c6c9cSStella Laurenzo };
94436c6c9cSStella Laurenzo 
95ed9e52f3SAlex Zinenko template <typename T>
96ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
97ed9e52f3SAlex Zinenko   try {
98ed9e52f3SAlex Zinenko     return object.cast<T>();
99ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
100ed9e52f3SAlex Zinenko     std::string msg =
101ed9e52f3SAlex Zinenko         std::string(
102ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
103ed9e52f3SAlex Zinenko         err.what() + ")";
104ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
105ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
106ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
107ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
108ed9e52f3SAlex Zinenko                       err.what() + ")";
109ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
110ed9e52f3SAlex Zinenko   }
111ed9e52f3SAlex Zinenko }
112ed9e52f3SAlex Zinenko 
113436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
114436c6c9cSStella Laurenzo public:
115436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
116436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
117436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
118436c6c9cSStella Laurenzo 
119436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
120436c6c9cSStella Laurenzo   public:
1211fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
122436c6c9cSStella Laurenzo 
123436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
124436c6c9cSStella Laurenzo 
125436c6c9cSStella Laurenzo     PyAttribute dunderNext() {
126436c6c9cSStella Laurenzo       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
127436c6c9cSStella Laurenzo         throw py::stop_iteration();
128436c6c9cSStella Laurenzo       }
129436c6c9cSStella Laurenzo       return PyAttribute(attr.getContext(),
130436c6c9cSStella Laurenzo                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
131436c6c9cSStella Laurenzo     }
132436c6c9cSStella Laurenzo 
133436c6c9cSStella Laurenzo     static void bind(py::module &m) {
134f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
135f05ff4f7SStella Laurenzo                                            py::module_local())
136436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
137436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
138436c6c9cSStella Laurenzo     }
139436c6c9cSStella Laurenzo 
140436c6c9cSStella Laurenzo   private:
141436c6c9cSStella Laurenzo     PyAttribute attr;
142436c6c9cSStella Laurenzo     int nextIndex = 0;
143436c6c9cSStella Laurenzo   };
144436c6c9cSStella Laurenzo 
145ed9e52f3SAlex Zinenko   PyAttribute getItem(intptr_t i) {
146ed9e52f3SAlex Zinenko     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
147ed9e52f3SAlex Zinenko   }
148ed9e52f3SAlex Zinenko 
149436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
150436c6c9cSStella Laurenzo     c.def_static(
151436c6c9cSStella Laurenzo         "get",
152436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
153436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
154436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
155436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
156ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
157436c6c9cSStella Laurenzo           }
158436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
159436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
160436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
161436c6c9cSStella Laurenzo         },
162436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
163436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
164436c6c9cSStella Laurenzo     c.def("__getitem__",
165436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
166436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
167436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
168ed9e52f3SAlex Zinenko             return arr.getItem(i);
169436c6c9cSStella Laurenzo           })
170436c6c9cSStella Laurenzo         .def("__len__",
171436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
172436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
173436c6c9cSStella Laurenzo              })
174436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
175436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
176436c6c9cSStella Laurenzo         });
177ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
178ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
179ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
180ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
181ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
182ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
183ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
184ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
185ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
186ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
187ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
188ed9e52f3SAlex Zinenko     });
189436c6c9cSStella Laurenzo   }
190436c6c9cSStella Laurenzo };
191436c6c9cSStella Laurenzo 
192436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
193436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
194436c6c9cSStella Laurenzo public:
195436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
196436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
197436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
198436c6c9cSStella Laurenzo 
199436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
200436c6c9cSStella Laurenzo     c.def_static(
201436c6c9cSStella Laurenzo         "get",
202436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
203436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
204436c6c9cSStella Laurenzo           // TODO: Rework error reporting once diagnostic engine is exposed
205436c6c9cSStella Laurenzo           // in C API.
206436c6c9cSStella Laurenzo           if (mlirAttributeIsNull(attr)) {
207436c6c9cSStella Laurenzo             throw SetPyError(PyExc_ValueError,
208436c6c9cSStella Laurenzo                              Twine("invalid '") +
209436c6c9cSStella Laurenzo                                  py::repr(py::cast(type)).cast<std::string>() +
210436c6c9cSStella Laurenzo                                  "' and expected floating point type.");
211436c6c9cSStella Laurenzo           }
212436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
213436c6c9cSStella Laurenzo         },
214436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
215436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
216436c6c9cSStella Laurenzo     c.def_static(
217436c6c9cSStella Laurenzo         "get_f32",
218436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
219436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
220436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
221436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
222436c6c9cSStella Laurenzo         },
223436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
224436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
225436c6c9cSStella Laurenzo     c.def_static(
226436c6c9cSStella Laurenzo         "get_f64",
227436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
228436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
229436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
230436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
231436c6c9cSStella Laurenzo         },
232436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
233436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
234436c6c9cSStella Laurenzo     c.def_property_readonly(
235436c6c9cSStella Laurenzo         "value",
236436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
237436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
238436c6c9cSStella Laurenzo         },
239436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
240436c6c9cSStella Laurenzo   }
241436c6c9cSStella Laurenzo };
242436c6c9cSStella Laurenzo 
243436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
244436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
245436c6c9cSStella Laurenzo public:
246436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
247436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
248436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
249436c6c9cSStella Laurenzo 
250436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
251436c6c9cSStella Laurenzo     c.def_static(
252436c6c9cSStella Laurenzo         "get",
253436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
254436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
255436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
256436c6c9cSStella Laurenzo         },
257436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
258436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
259436c6c9cSStella Laurenzo     c.def_property_readonly(
260436c6c9cSStella Laurenzo         "value",
261*e9db306dSrkayaith         [](PyIntegerAttribute &self) -> py::int_ {
262*e9db306dSrkayaith           MlirType type = mlirAttributeGetType(self);
263*e9db306dSrkayaith           if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
264436c6c9cSStella Laurenzo             return mlirIntegerAttrGetValueInt(self);
265*e9db306dSrkayaith           if (mlirIntegerTypeIsSigned(type))
266*e9db306dSrkayaith             return mlirIntegerAttrGetValueSInt(self);
267*e9db306dSrkayaith           return mlirIntegerAttrGetValueUInt(self);
268436c6c9cSStella Laurenzo         },
269436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
270436c6c9cSStella Laurenzo   }
271436c6c9cSStella Laurenzo };
272436c6c9cSStella Laurenzo 
273436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
274436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
275436c6c9cSStella Laurenzo public:
276436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
277436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
278436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
279436c6c9cSStella Laurenzo 
280436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
281436c6c9cSStella Laurenzo     c.def_static(
282436c6c9cSStella Laurenzo         "get",
283436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
284436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
285436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
286436c6c9cSStella Laurenzo         },
287436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
288436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
289436c6c9cSStella Laurenzo     c.def_property_readonly(
290436c6c9cSStella Laurenzo         "value",
291436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
292436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
293436c6c9cSStella Laurenzo   }
294436c6c9cSStella Laurenzo };
295436c6c9cSStella Laurenzo 
296436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
297436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
298436c6c9cSStella Laurenzo public:
299436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
300436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
301436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
302436c6c9cSStella Laurenzo 
303436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
304436c6c9cSStella Laurenzo     c.def_static(
305436c6c9cSStella Laurenzo         "get",
306436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
307436c6c9cSStella Laurenzo           MlirAttribute attr =
308436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
309436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
310436c6c9cSStella Laurenzo         },
311436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
312436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
313436c6c9cSStella Laurenzo     c.def_property_readonly(
314436c6c9cSStella Laurenzo         "value",
315436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
316436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
317436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
318436c6c9cSStella Laurenzo         },
319436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
320436c6c9cSStella Laurenzo   }
321436c6c9cSStella Laurenzo };
322436c6c9cSStella Laurenzo 
323436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
324436c6c9cSStella Laurenzo public:
325436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
326436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
327436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
328436c6c9cSStella Laurenzo 
329436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
330436c6c9cSStella Laurenzo     c.def_static(
331436c6c9cSStella Laurenzo         "get",
332436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
333436c6c9cSStella Laurenzo           MlirAttribute attr =
334436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
335436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
336436c6c9cSStella Laurenzo         },
337436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
338436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
339436c6c9cSStella Laurenzo     c.def_static(
340436c6c9cSStella Laurenzo         "get_typed",
341436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
342436c6c9cSStella Laurenzo           MlirAttribute attr =
343436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
344436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
345436c6c9cSStella Laurenzo         },
346a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
347436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
348436c6c9cSStella Laurenzo     c.def_property_readonly(
349436c6c9cSStella Laurenzo         "value",
350436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
351436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
352436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
353436c6c9cSStella Laurenzo         },
354436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
355436c6c9cSStella Laurenzo   }
356436c6c9cSStella Laurenzo };
357436c6c9cSStella Laurenzo 
358436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
359436c6c9cSStella Laurenzo class PyDenseElementsAttribute
360436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
361436c6c9cSStella Laurenzo public:
362436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
363436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
364436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
365436c6c9cSStella Laurenzo 
366436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
3675d6d30edSStella Laurenzo   getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
3685d6d30edSStella Laurenzo                 Optional<std::vector<int64_t>> explicitShape,
369436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
370436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
371436c6c9cSStella Laurenzo     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
372436c6c9cSStella Laurenzo     Py_buffer *view = new Py_buffer();
373436c6c9cSStella Laurenzo     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
374436c6c9cSStella Laurenzo       delete view;
375436c6c9cSStella Laurenzo       throw py::error_already_set();
376436c6c9cSStella Laurenzo     }
377436c6c9cSStella Laurenzo     py::buffer_info arrayInfo(view);
3785d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
3795d6d30edSStella Laurenzo     if (explicitShape) {
3805d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
3815d6d30edSStella Laurenzo     } else {
3825d6d30edSStella Laurenzo       shape.append(arrayInfo.shape.begin(),
3835d6d30edSStella Laurenzo                    arrayInfo.shape.begin() + arrayInfo.ndim);
3845d6d30edSStella Laurenzo     }
385436c6c9cSStella Laurenzo 
3865d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
387436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
3885d6d30edSStella Laurenzo 
3895d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
3905d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
3915d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
3925d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
3935d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
3945d6d30edSStella Laurenzo     Optional<MlirType> bulkLoadElementType;
3955d6d30edSStella Laurenzo     if (explicitType) {
3965d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
3975d6d30edSStella Laurenzo     } else if (arrayInfo.format == "f") {
398436c6c9cSStella Laurenzo       // f32
399436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
4005d6d30edSStella Laurenzo       bulkLoadElementType = mlirF32TypeGet(context);
401436c6c9cSStella Laurenzo     } else if (arrayInfo.format == "d") {
402436c6c9cSStella Laurenzo       // f64
403436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
4045d6d30edSStella Laurenzo       bulkLoadElementType = mlirF64TypeGet(context);
4055d6d30edSStella Laurenzo     } else if (arrayInfo.format == "e") {
4065d6d30edSStella Laurenzo       // f16
4075d6d30edSStella Laurenzo       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
4085d6d30edSStella Laurenzo       bulkLoadElementType = mlirF16TypeGet(context);
409436c6c9cSStella Laurenzo     } else if (isSignedIntegerFormat(arrayInfo.format)) {
410436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
411436c6c9cSStella Laurenzo         // i32
4125d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
413436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 32);
414436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
415436c6c9cSStella Laurenzo         // i64
4165d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
417436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 64);
4185d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
4195d6d30edSStella Laurenzo         // i8
4205d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
4215d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 8);
4225d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
4235d6d30edSStella Laurenzo         // i16
4245d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
4255d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 16);
426436c6c9cSStella Laurenzo       }
427436c6c9cSStella Laurenzo     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
428436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
429436c6c9cSStella Laurenzo         // unsigned i32
4305d6d30edSStella Laurenzo         bulkLoadElementType = signless
431436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 32)
432436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 32);
433436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
434436c6c9cSStella Laurenzo         // unsigned i64
4355d6d30edSStella Laurenzo         bulkLoadElementType = signless
436436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 64)
437436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 64);
4385d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
4395d6d30edSStella Laurenzo         // i8
4405d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
4415d6d30edSStella Laurenzo                                        : mlirIntegerTypeUnsignedGet(context, 8);
4425d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
4435d6d30edSStella Laurenzo         // i16
4445d6d30edSStella Laurenzo         bulkLoadElementType = signless
4455d6d30edSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 16)
4465d6d30edSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 16);
447436c6c9cSStella Laurenzo       }
448436c6c9cSStella Laurenzo     }
4495d6d30edSStella Laurenzo     if (bulkLoadElementType) {
4505d6d30edSStella Laurenzo       auto shapedType = mlirRankedTensorTypeGet(
4515d6d30edSStella Laurenzo           shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
4525d6d30edSStella Laurenzo       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
4535d6d30edSStella Laurenzo       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
4545d6d30edSStella Laurenzo           shapedType, rawBufferSize, arrayInfo.ptr);
4555d6d30edSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
4565d6d30edSStella Laurenzo         throw std::invalid_argument(
4575d6d30edSStella Laurenzo             "DenseElementsAttr could not be constructed from the given buffer. "
4585d6d30edSStella Laurenzo             "This may mean that the Python buffer layout does not match that "
4595d6d30edSStella Laurenzo             "MLIR expected layout and is a bug.");
4605d6d30edSStella Laurenzo       }
4615d6d30edSStella Laurenzo       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
4625d6d30edSStella Laurenzo     }
463436c6c9cSStella Laurenzo 
4645d6d30edSStella Laurenzo     throw std::invalid_argument(
4655d6d30edSStella Laurenzo         std::string("unimplemented array format conversion from format: ") +
4665d6d30edSStella Laurenzo         arrayInfo.format);
467436c6c9cSStella Laurenzo   }
468436c6c9cSStella Laurenzo 
4691fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
470436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
471436c6c9cSStella Laurenzo     auto contextWrapper =
472436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
473436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
474436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
475436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
476436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
477436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
478436c6c9cSStella Laurenzo     }
479436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
480436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
481436c6c9cSStella Laurenzo       std::string message =
482436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
483436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
484436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
485436c6c9cSStella Laurenzo     }
486436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
487436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
488436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
489436c6c9cSStella Laurenzo       std::string message =
490436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
491436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
492436c6c9cSStella Laurenzo       message.append(", element=");
493436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
494436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
495436c6c9cSStella Laurenzo     }
496436c6c9cSStella Laurenzo 
497436c6c9cSStella Laurenzo     MlirAttribute elements =
498436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
499436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
500436c6c9cSStella Laurenzo   }
501436c6c9cSStella Laurenzo 
502436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
503436c6c9cSStella Laurenzo 
504436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
5055d6d30edSStella Laurenzo     if (mlirDenseElementsAttrIsSplat(*this)) {
506c5f445d1SStella Laurenzo       // TODO: Currently crashes the program.
5075d6d30edSStella Laurenzo       // Reported as https://github.com/pybind/pybind11/issues/3336
508c5f445d1SStella Laurenzo       throw std::invalid_argument(
509c5f445d1SStella Laurenzo           "unsupported data type for conversion to Python buffer");
5105d6d30edSStella Laurenzo     }
5115d6d30edSStella Laurenzo 
512436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
513436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
5145d6d30edSStella Laurenzo     std::string format;
515436c6c9cSStella Laurenzo 
516436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
517436c6c9cSStella Laurenzo       // f32
5185d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
51902b6fb21SMehdi Amini     }
52002b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
521436c6c9cSStella Laurenzo       // f64
5225d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
523bb56c2b3SMehdi Amini     }
524bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
5255d6d30edSStella Laurenzo       // f16
5265d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
527bb56c2b3SMehdi Amini     }
528bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
529436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
530436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
531436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
532436c6c9cSStella Laurenzo         // i32
5335d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
534e5639b3fSMehdi Amini       }
535e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
536436c6c9cSStella Laurenzo         // unsigned i32
5375d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
538436c6c9cSStella Laurenzo       }
539436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
540436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
541436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
542436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
543436c6c9cSStella Laurenzo         // i64
5445d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
545e5639b3fSMehdi Amini       }
546e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
547436c6c9cSStella Laurenzo         // unsigned i64
5485d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
5495d6d30edSStella Laurenzo       }
5505d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
5515d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
5525d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
5535d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
5545d6d30edSStella Laurenzo         // i8
5555d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
556e5639b3fSMehdi Amini       }
557e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
5585d6d30edSStella Laurenzo         // unsigned i8
5595d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
5605d6d30edSStella Laurenzo       }
5615d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
5625d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
5635d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
5645d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
5655d6d30edSStella Laurenzo         // i16
5665d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
567e5639b3fSMehdi Amini       }
568e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
5695d6d30edSStella Laurenzo         // unsigned i16
5705d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
571436c6c9cSStella Laurenzo       }
572436c6c9cSStella Laurenzo     }
573436c6c9cSStella Laurenzo 
574c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
5755d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
576c5f445d1SStella Laurenzo     throw std::invalid_argument(
577c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
578436c6c9cSStella Laurenzo   }
579436c6c9cSStella Laurenzo 
580436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
581436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
582436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
583436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
5845d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
585436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
5865d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
587436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
588436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
589436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
590436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
591436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
592436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
593436c6c9cSStella Laurenzo                                })
594436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
595436c6c9cSStella Laurenzo   }
596436c6c9cSStella Laurenzo 
597436c6c9cSStella Laurenzo private:
598436c6c9cSStella Laurenzo   static bool isUnsignedIntegerFormat(const std::string &format) {
599436c6c9cSStella Laurenzo     if (format.empty())
600436c6c9cSStella Laurenzo       return false;
601436c6c9cSStella Laurenzo     char code = format[0];
602436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
603436c6c9cSStella Laurenzo            code == 'Q';
604436c6c9cSStella Laurenzo   }
605436c6c9cSStella Laurenzo 
606436c6c9cSStella Laurenzo   static bool isSignedIntegerFormat(const std::string &format) {
607436c6c9cSStella Laurenzo     if (format.empty())
608436c6c9cSStella Laurenzo       return false;
609436c6c9cSStella Laurenzo     char code = format[0];
610436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
611436c6c9cSStella Laurenzo            code == 'q';
612436c6c9cSStella Laurenzo   }
613436c6c9cSStella Laurenzo 
614436c6c9cSStella Laurenzo   template <typename Type>
615436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
6165d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
617436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
618436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
619436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
620436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
621436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
622436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
623436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
624436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
625436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
626436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
627436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
628436c6c9cSStella Laurenzo     intptr_t strideFactor = 1;
629436c6c9cSStella Laurenzo     for (intptr_t i = 1; i < rank; ++i) {
630436c6c9cSStella Laurenzo       strideFactor = 1;
631436c6c9cSStella Laurenzo       for (intptr_t j = i; j < rank; ++j) {
632436c6c9cSStella Laurenzo         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
633436c6c9cSStella Laurenzo       }
634436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type) * strideFactor);
635436c6c9cSStella Laurenzo     }
636436c6c9cSStella Laurenzo     strides.push_back(sizeof(Type));
6375d6d30edSStella Laurenzo     std::string format;
6385d6d30edSStella Laurenzo     if (explicitFormat) {
6395d6d30edSStella Laurenzo       format = explicitFormat;
6405d6d30edSStella Laurenzo     } else {
6415d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
6425d6d30edSStella Laurenzo     }
6435d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
6445d6d30edSStella Laurenzo                            /*readonly=*/true);
645436c6c9cSStella Laurenzo   }
}; // class PyDenseElementsAttribute
647436c6c9cSStella Laurenzo 
648436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
649436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
650436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
651436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
652436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
653436c6c9cSStella Laurenzo public:
654436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
655436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
656436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
657436c6c9cSStella Laurenzo 
658436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
659436c6c9cSStella Laurenzo   /// out of range.
660436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
661436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
662436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
663436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
664436c6c9cSStella Laurenzo     }
665436c6c9cSStella Laurenzo 
666436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
667436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
668436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
669436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
670436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
671436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
672436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
673436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
674436c6c9cSStella Laurenzo     // querying them on each element access.
675436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
676436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
677436c6c9cSStella Laurenzo     if (isUnsigned) {
678436c6c9cSStella Laurenzo       if (width == 1) {
679436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
680436c6c9cSStella Laurenzo       }
681308d8b8cSRahul Kayaith       if (width == 8) {
682308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
683308d8b8cSRahul Kayaith       }
684308d8b8cSRahul Kayaith       if (width == 16) {
685308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
686308d8b8cSRahul Kayaith       }
687436c6c9cSStella Laurenzo       if (width == 32) {
688436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
689436c6c9cSStella Laurenzo       }
690436c6c9cSStella Laurenzo       if (width == 64) {
691436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
692436c6c9cSStella Laurenzo       }
693436c6c9cSStella Laurenzo     } else {
694436c6c9cSStella Laurenzo       if (width == 1) {
695436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
696436c6c9cSStella Laurenzo       }
697308d8b8cSRahul Kayaith       if (width == 8) {
698308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
699308d8b8cSRahul Kayaith       }
700308d8b8cSRahul Kayaith       if (width == 16) {
701308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
702308d8b8cSRahul Kayaith       }
703436c6c9cSStella Laurenzo       if (width == 32) {
704436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
705436c6c9cSStella Laurenzo       }
706436c6c9cSStella Laurenzo       if (width == 64) {
707436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
708436c6c9cSStella Laurenzo       }
709436c6c9cSStella Laurenzo     }
710436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
711436c6c9cSStella Laurenzo   }
712436c6c9cSStella Laurenzo 
713436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
714436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
715436c6c9cSStella Laurenzo   }
716436c6c9cSStella Laurenzo };
717436c6c9cSStella Laurenzo 
718436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
719436c6c9cSStella Laurenzo public:
720436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
721436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
722436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
723436c6c9cSStella Laurenzo 
724436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
725436c6c9cSStella Laurenzo 
7269fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
7279fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
7289fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
7299fb1086bSAdrian Kuegel   }
7309fb1086bSAdrian Kuegel 
731436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
7329fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
733436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
734436c6c9cSStella Laurenzo     c.def_static(
735436c6c9cSStella Laurenzo         "get",
736436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
737436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
738436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
739436c6c9cSStella Laurenzo           for (auto &it : attributes) {
74002b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
741436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
742436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
74302b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
744436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
74502b6fb21SMehdi Amini                 mlirAttr));
746436c6c9cSStella Laurenzo           }
747436c6c9cSStella Laurenzo           MlirAttribute attr =
748436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
749436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
750436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
751436c6c9cSStella Laurenzo         },
752ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
753436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
754436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
755436c6c9cSStella Laurenzo       MlirAttribute attr =
756436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
757436c6c9cSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
758436c6c9cSStella Laurenzo         throw SetPyError(PyExc_KeyError,
759436c6c9cSStella Laurenzo                          "attempt to access a non-existent attribute");
760436c6c9cSStella Laurenzo       }
761436c6c9cSStella Laurenzo       return PyAttribute(self.getContext(), attr);
762436c6c9cSStella Laurenzo     });
763436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
764436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
765436c6c9cSStella Laurenzo         throw SetPyError(PyExc_IndexError,
766436c6c9cSStella Laurenzo                          "attempt to access out of bounds attribute");
767436c6c9cSStella Laurenzo       }
768436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
769436c6c9cSStella Laurenzo       return PyNamedAttribute(
770436c6c9cSStella Laurenzo           namedAttr.attribute,
771436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
772436c6c9cSStella Laurenzo     });
773436c6c9cSStella Laurenzo   }
774436c6c9cSStella Laurenzo };
775436c6c9cSStella Laurenzo 
776436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
777436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
778436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
779436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
780436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
781436c6c9cSStella Laurenzo public:
782436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
783436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
784436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
785436c6c9cSStella Laurenzo 
786436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
787436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
788436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
789436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
790436c6c9cSStella Laurenzo     }
791436c6c9cSStella Laurenzo 
792436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
793436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
794436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
795436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
796436c6c9cSStella Laurenzo     // from float and double.
797436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
798436c6c9cSStella Laurenzo     // querying them on each element access.
799436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
800436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
801436c6c9cSStella Laurenzo     }
802436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
803436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
804436c6c9cSStella Laurenzo     }
805436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
806436c6c9cSStella Laurenzo   }
807436c6c9cSStella Laurenzo 
808436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
809436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
810436c6c9cSStella Laurenzo   }
811436c6c9cSStella Laurenzo };
812436c6c9cSStella Laurenzo 
813436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
814436c6c9cSStella Laurenzo public:
815436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
816436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
817436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
818436c6c9cSStella Laurenzo 
819436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
820436c6c9cSStella Laurenzo     c.def_static(
821436c6c9cSStella Laurenzo         "get",
822436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
823436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
824436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
825436c6c9cSStella Laurenzo         },
826436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
827436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
828436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
829436c6c9cSStella Laurenzo       return PyType(self.getContext()->getRef(),
830436c6c9cSStella Laurenzo                     mlirTypeAttrGetValue(self.get()));
831436c6c9cSStella Laurenzo     });
832436c6c9cSStella Laurenzo   }
833436c6c9cSStella Laurenzo };
834436c6c9cSStella Laurenzo 
835436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
836436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
837436c6c9cSStella Laurenzo public:
838436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
839436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
840436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
841436c6c9cSStella Laurenzo 
842436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
843436c6c9cSStella Laurenzo     c.def_static(
844436c6c9cSStella Laurenzo         "get",
845436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
846436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
847436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
848436c6c9cSStella Laurenzo         },
849436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
850436c6c9cSStella Laurenzo   }
851436c6c9cSStella Laurenzo };
852436c6c9cSStella Laurenzo 
853436c6c9cSStella Laurenzo } // namespace
854436c6c9cSStella Laurenzo 
/// Registers the Python classes for the builtin attributes on module `m` by
/// calling each binding class's `bind`.
///
/// NOTE: keep the registration order. PyDenseElementsAttribute is bound
/// before PyDenseFPElementsAttribute and PyDenseIntElementsAttribute, which
/// derive from it; pybind11 requires a base class to be registered before a
/// subclass that names it.
855436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
856436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
857436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
858436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
859436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
  // Base dense-elements class first, then its FP/Int refinements.
860436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
861436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
862436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
863436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
864436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
865436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
866436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
867436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
868436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
869436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
870436c6c9cSStella Laurenzo }
871