xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision c5f445d143485f898353df6d422eea1dea22c7a8)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9436c6c9cSStella Laurenzo #include "IRModule.h"
10436c6c9cSStella Laurenzo 
11436c6c9cSStella Laurenzo #include "PybindUtils.h"
12436c6c9cSStella Laurenzo 
13436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
14436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
15436c6c9cSStella Laurenzo 
16436c6c9cSStella Laurenzo namespace py = pybind11;
17436c6c9cSStella Laurenzo using namespace mlir;
18436c6c9cSStella Laurenzo using namespace mlir::python;
19436c6c9cSStella Laurenzo 
205d6d30edSStella Laurenzo using llvm::None;
215d6d30edSStella Laurenzo using llvm::Optional;
22436c6c9cSStella Laurenzo using llvm::SmallVector;
23436c6c9cSStella Laurenzo using llvm::Twine;
24436c6c9cSStella Laurenzo 
255d6d30edSStella Laurenzo //------------------------------------------------------------------------------
265d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
275d6d30edSStella Laurenzo //------------------------------------------------------------------------------
285d6d30edSStella Laurenzo 
// Python-facing docstring for building a DenseElementsAttr from a buffer.
// Presumably attached to the `DenseElementsAttr.get` binding, which is not
// visible in this chunk — TODO confirm. The inferred types it lists
// (8/16/32/64-bit signed/unsigned integers, f16/f32/f64) are the ones that
// PyDenseElementsAttribute::getFromBuffer maps from buffer format codes; any
// other element type requires an explicit `type=`.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
705d6d30edSStella Laurenzo 
71436c6c9cSStella Laurenzo namespace {
72436c6c9cSStella Laurenzo 
73436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
74436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
75436c6c9cSStella Laurenzo }
76436c6c9cSStella Laurenzo 
77436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
78436c6c9cSStella Laurenzo public:
79436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
80436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
81436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
82436c6c9cSStella Laurenzo 
83436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
84436c6c9cSStella Laurenzo     c.def_static(
85436c6c9cSStella Laurenzo         "get",
86436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
87436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
88436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
89436c6c9cSStella Laurenzo         },
90436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
91436c6c9cSStella Laurenzo   }
92436c6c9cSStella Laurenzo };
93436c6c9cSStella Laurenzo 
94ed9e52f3SAlex Zinenko template <typename T>
95ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
96ed9e52f3SAlex Zinenko   try {
97ed9e52f3SAlex Zinenko     return object.cast<T>();
98ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
99ed9e52f3SAlex Zinenko     std::string msg =
100ed9e52f3SAlex Zinenko         std::string(
101ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
102ed9e52f3SAlex Zinenko         err.what() + ")";
103ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
104ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
105ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
106ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
107ed9e52f3SAlex Zinenko                       err.what() + ")";
108ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
109ed9e52f3SAlex Zinenko   }
110ed9e52f3SAlex Zinenko }
111ed9e52f3SAlex Zinenko 
112436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
113436c6c9cSStella Laurenzo public:
114436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
115436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
116436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
117436c6c9cSStella Laurenzo 
118436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
119436c6c9cSStella Laurenzo   public:
120436c6c9cSStella Laurenzo     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
121436c6c9cSStella Laurenzo 
122436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
123436c6c9cSStella Laurenzo 
124436c6c9cSStella Laurenzo     PyAttribute dunderNext() {
125436c6c9cSStella Laurenzo       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
126436c6c9cSStella Laurenzo         throw py::stop_iteration();
127436c6c9cSStella Laurenzo       }
128436c6c9cSStella Laurenzo       return PyAttribute(attr.getContext(),
129436c6c9cSStella Laurenzo                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
130436c6c9cSStella Laurenzo     }
131436c6c9cSStella Laurenzo 
132436c6c9cSStella Laurenzo     static void bind(py::module &m) {
133f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
134f05ff4f7SStella Laurenzo                                            py::module_local())
135436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
136436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
137436c6c9cSStella Laurenzo     }
138436c6c9cSStella Laurenzo 
139436c6c9cSStella Laurenzo   private:
140436c6c9cSStella Laurenzo     PyAttribute attr;
141436c6c9cSStella Laurenzo     int nextIndex = 0;
142436c6c9cSStella Laurenzo   };
143436c6c9cSStella Laurenzo 
144ed9e52f3SAlex Zinenko   PyAttribute getItem(intptr_t i) {
145ed9e52f3SAlex Zinenko     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
146ed9e52f3SAlex Zinenko   }
147ed9e52f3SAlex Zinenko 
148436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
149436c6c9cSStella Laurenzo     c.def_static(
150436c6c9cSStella Laurenzo         "get",
151436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
152436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
153436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
154436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
155ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
156436c6c9cSStella Laurenzo           }
157436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
158436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
159436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
160436c6c9cSStella Laurenzo         },
161436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
162436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
163436c6c9cSStella Laurenzo     c.def("__getitem__",
164436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
165436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
166436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
167ed9e52f3SAlex Zinenko             return arr.getItem(i);
168436c6c9cSStella Laurenzo           })
169436c6c9cSStella Laurenzo         .def("__len__",
170436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
171436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
172436c6c9cSStella Laurenzo              })
173436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
174436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
175436c6c9cSStella Laurenzo         });
176ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
177ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
178ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
179ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
180ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
181ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
182ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
183ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
184ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
185ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
186ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
187ed9e52f3SAlex Zinenko     });
188436c6c9cSStella Laurenzo   }
189436c6c9cSStella Laurenzo };
190436c6c9cSStella Laurenzo 
191436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
192436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
193436c6c9cSStella Laurenzo public:
194436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
195436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
196436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
197436c6c9cSStella Laurenzo 
198436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
199436c6c9cSStella Laurenzo     c.def_static(
200436c6c9cSStella Laurenzo         "get",
201436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
202436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
203436c6c9cSStella Laurenzo           // TODO: Rework error reporting once diagnostic engine is exposed
204436c6c9cSStella Laurenzo           // in C API.
205436c6c9cSStella Laurenzo           if (mlirAttributeIsNull(attr)) {
206436c6c9cSStella Laurenzo             throw SetPyError(PyExc_ValueError,
207436c6c9cSStella Laurenzo                              Twine("invalid '") +
208436c6c9cSStella Laurenzo                                  py::repr(py::cast(type)).cast<std::string>() +
209436c6c9cSStella Laurenzo                                  "' and expected floating point type.");
210436c6c9cSStella Laurenzo           }
211436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
212436c6c9cSStella Laurenzo         },
213436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
214436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
215436c6c9cSStella Laurenzo     c.def_static(
216436c6c9cSStella Laurenzo         "get_f32",
217436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
218436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
219436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
220436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
221436c6c9cSStella Laurenzo         },
222436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
223436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
224436c6c9cSStella Laurenzo     c.def_static(
225436c6c9cSStella Laurenzo         "get_f64",
226436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
227436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
228436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
229436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
230436c6c9cSStella Laurenzo         },
231436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
232436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
233436c6c9cSStella Laurenzo     c.def_property_readonly(
234436c6c9cSStella Laurenzo         "value",
235436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
236436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
237436c6c9cSStella Laurenzo         },
238436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
239436c6c9cSStella Laurenzo   }
240436c6c9cSStella Laurenzo };
241436c6c9cSStella Laurenzo 
242436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
243436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
244436c6c9cSStella Laurenzo public:
245436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
246436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
247436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
248436c6c9cSStella Laurenzo 
249436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
250436c6c9cSStella Laurenzo     c.def_static(
251436c6c9cSStella Laurenzo         "get",
252436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
253436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
254436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
255436c6c9cSStella Laurenzo         },
256436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
257436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
258436c6c9cSStella Laurenzo     c.def_property_readonly(
259436c6c9cSStella Laurenzo         "value",
260436c6c9cSStella Laurenzo         [](PyIntegerAttribute &self) {
261436c6c9cSStella Laurenzo           return mlirIntegerAttrGetValueInt(self);
262436c6c9cSStella Laurenzo         },
263436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
264436c6c9cSStella Laurenzo   }
265436c6c9cSStella Laurenzo };
266436c6c9cSStella Laurenzo 
267436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
268436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
269436c6c9cSStella Laurenzo public:
270436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
271436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
272436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
273436c6c9cSStella Laurenzo 
274436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
275436c6c9cSStella Laurenzo     c.def_static(
276436c6c9cSStella Laurenzo         "get",
277436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
278436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
279436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
280436c6c9cSStella Laurenzo         },
281436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
282436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
283436c6c9cSStella Laurenzo     c.def_property_readonly(
284436c6c9cSStella Laurenzo         "value",
285436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
286436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
287436c6c9cSStella Laurenzo   }
288436c6c9cSStella Laurenzo };
289436c6c9cSStella Laurenzo 
290436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
291436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
292436c6c9cSStella Laurenzo public:
293436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
294436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
295436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
296436c6c9cSStella Laurenzo 
297436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
298436c6c9cSStella Laurenzo     c.def_static(
299436c6c9cSStella Laurenzo         "get",
300436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
301436c6c9cSStella Laurenzo           MlirAttribute attr =
302436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
303436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
304436c6c9cSStella Laurenzo         },
305436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
306436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
307436c6c9cSStella Laurenzo     c.def_property_readonly(
308436c6c9cSStella Laurenzo         "value",
309436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
310436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
311436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
312436c6c9cSStella Laurenzo         },
313436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
314436c6c9cSStella Laurenzo   }
315436c6c9cSStella Laurenzo };
316436c6c9cSStella Laurenzo 
317436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
318436c6c9cSStella Laurenzo public:
319436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
320436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
321436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
322436c6c9cSStella Laurenzo 
323436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
324436c6c9cSStella Laurenzo     c.def_static(
325436c6c9cSStella Laurenzo         "get",
326436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
327436c6c9cSStella Laurenzo           MlirAttribute attr =
328436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
329436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
330436c6c9cSStella Laurenzo         },
331436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
332436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
333436c6c9cSStella Laurenzo     c.def_static(
334436c6c9cSStella Laurenzo         "get_typed",
335436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
336436c6c9cSStella Laurenzo           MlirAttribute attr =
337436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
338436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
339436c6c9cSStella Laurenzo         },
340436c6c9cSStella Laurenzo 
341436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
342436c6c9cSStella Laurenzo     c.def_property_readonly(
343436c6c9cSStella Laurenzo         "value",
344436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
345436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
346436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
347436c6c9cSStella Laurenzo         },
348436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
349436c6c9cSStella Laurenzo   }
350436c6c9cSStella Laurenzo };
351436c6c9cSStella Laurenzo 
352436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
353436c6c9cSStella Laurenzo class PyDenseElementsAttribute
354436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
355436c6c9cSStella Laurenzo public:
356436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
357436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
358436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
359436c6c9cSStella Laurenzo 
360436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
3615d6d30edSStella Laurenzo   getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
3625d6d30edSStella Laurenzo                 Optional<std::vector<int64_t>> explicitShape,
363436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
364436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
365436c6c9cSStella Laurenzo     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
366436c6c9cSStella Laurenzo     Py_buffer *view = new Py_buffer();
367436c6c9cSStella Laurenzo     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
368436c6c9cSStella Laurenzo       delete view;
369436c6c9cSStella Laurenzo       throw py::error_already_set();
370436c6c9cSStella Laurenzo     }
371436c6c9cSStella Laurenzo     py::buffer_info arrayInfo(view);
3725d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
3735d6d30edSStella Laurenzo     if (explicitShape) {
3745d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
3755d6d30edSStella Laurenzo     } else {
3765d6d30edSStella Laurenzo       shape.append(arrayInfo.shape.begin(),
3775d6d30edSStella Laurenzo                    arrayInfo.shape.begin() + arrayInfo.ndim);
3785d6d30edSStella Laurenzo     }
379436c6c9cSStella Laurenzo 
3805d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
381436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
3825d6d30edSStella Laurenzo 
3835d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
3845d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
3855d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
3865d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
3875d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
3885d6d30edSStella Laurenzo     Optional<MlirType> bulkLoadElementType;
3895d6d30edSStella Laurenzo     if (explicitType) {
3905d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
3915d6d30edSStella Laurenzo     } else if (arrayInfo.format == "f") {
392436c6c9cSStella Laurenzo       // f32
393436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
3945d6d30edSStella Laurenzo       bulkLoadElementType = mlirF32TypeGet(context);
395436c6c9cSStella Laurenzo     } else if (arrayInfo.format == "d") {
396436c6c9cSStella Laurenzo       // f64
397436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
3985d6d30edSStella Laurenzo       bulkLoadElementType = mlirF64TypeGet(context);
3995d6d30edSStella Laurenzo     } else if (arrayInfo.format == "e") {
4005d6d30edSStella Laurenzo       // f16
4015d6d30edSStella Laurenzo       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
4025d6d30edSStella Laurenzo       bulkLoadElementType = mlirF16TypeGet(context);
403436c6c9cSStella Laurenzo     } else if (isSignedIntegerFormat(arrayInfo.format)) {
404436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
405436c6c9cSStella Laurenzo         // i32
4065d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
407436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 32);
408436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
409436c6c9cSStella Laurenzo         // i64
4105d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
411436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 64);
4125d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
4135d6d30edSStella Laurenzo         // i8
4145d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
4155d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 8);
4165d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
4175d6d30edSStella Laurenzo         // i16
4185d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
4195d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 16);
420436c6c9cSStella Laurenzo       }
421436c6c9cSStella Laurenzo     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
422436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
423436c6c9cSStella Laurenzo         // unsigned i32
4245d6d30edSStella Laurenzo         bulkLoadElementType = signless
425436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 32)
426436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 32);
427436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
428436c6c9cSStella Laurenzo         // unsigned i64
4295d6d30edSStella Laurenzo         bulkLoadElementType = signless
430436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 64)
431436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 64);
4325d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
4335d6d30edSStella Laurenzo         // i8
4345d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
4355d6d30edSStella Laurenzo                                        : mlirIntegerTypeUnsignedGet(context, 8);
4365d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
4375d6d30edSStella Laurenzo         // i16
4385d6d30edSStella Laurenzo         bulkLoadElementType = signless
4395d6d30edSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 16)
4405d6d30edSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 16);
441436c6c9cSStella Laurenzo       }
442436c6c9cSStella Laurenzo     }
4435d6d30edSStella Laurenzo     if (bulkLoadElementType) {
4445d6d30edSStella Laurenzo       auto shapedType = mlirRankedTensorTypeGet(
4455d6d30edSStella Laurenzo           shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
4465d6d30edSStella Laurenzo       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
4475d6d30edSStella Laurenzo       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
4485d6d30edSStella Laurenzo           shapedType, rawBufferSize, arrayInfo.ptr);
4495d6d30edSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
4505d6d30edSStella Laurenzo         throw std::invalid_argument(
4515d6d30edSStella Laurenzo             "DenseElementsAttr could not be constructed from the given buffer. "
4525d6d30edSStella Laurenzo             "This may mean that the Python buffer layout does not match that "
4535d6d30edSStella Laurenzo             "MLIR expected layout and is a bug.");
4545d6d30edSStella Laurenzo       }
4555d6d30edSStella Laurenzo       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
4565d6d30edSStella Laurenzo     }
457436c6c9cSStella Laurenzo 
4585d6d30edSStella Laurenzo     throw std::invalid_argument(
4595d6d30edSStella Laurenzo         std::string("unimplemented array format conversion from format: ") +
4605d6d30edSStella Laurenzo         arrayInfo.format);
461436c6c9cSStella Laurenzo   }
462436c6c9cSStella Laurenzo 
463436c6c9cSStella Laurenzo   static PyDenseElementsAttribute getSplat(PyType shapedType,
464436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
465436c6c9cSStella Laurenzo     auto contextWrapper =
466436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
467436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
468436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
469436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
470436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
471436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
472436c6c9cSStella Laurenzo     }
473436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
474436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
475436c6c9cSStella Laurenzo       std::string message =
476436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
477436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
478436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
479436c6c9cSStella Laurenzo     }
480436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
481436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
482436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
483436c6c9cSStella Laurenzo       std::string message =
484436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
485436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
486436c6c9cSStella Laurenzo       message.append(", element=");
487436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
488436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
489436c6c9cSStella Laurenzo     }
490436c6c9cSStella Laurenzo 
491436c6c9cSStella Laurenzo     MlirAttribute elements =
492436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
493436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
494436c6c9cSStella Laurenzo   }
495436c6c9cSStella Laurenzo 
496436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
497436c6c9cSStella Laurenzo 
498436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
4995d6d30edSStella Laurenzo     if (mlirDenseElementsAttrIsSplat(*this)) {
500*c5f445d1SStella Laurenzo       // TODO: Currently crashes the program.
5015d6d30edSStella Laurenzo       // Reported as https://github.com/pybind/pybind11/issues/3336
502*c5f445d1SStella Laurenzo       throw std::invalid_argument(
503*c5f445d1SStella Laurenzo           "unsupported data type for conversion to Python buffer");
5045d6d30edSStella Laurenzo     }
5055d6d30edSStella Laurenzo 
506436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
507436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
5085d6d30edSStella Laurenzo     std::string format;
509436c6c9cSStella Laurenzo 
510436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
511436c6c9cSStella Laurenzo       // f32
5125d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
513436c6c9cSStella Laurenzo     } else if (mlirTypeIsAF64(elementType)) {
514436c6c9cSStella Laurenzo       // f64
5155d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
5165d6d30edSStella Laurenzo     } else if (mlirTypeIsAF16(elementType)) {
5175d6d30edSStella Laurenzo       // f16
5185d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
519436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
520436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 32) {
521436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
522436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
523436c6c9cSStella Laurenzo         // i32
5245d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
525436c6c9cSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
526436c6c9cSStella Laurenzo         // unsigned i32
5275d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
528436c6c9cSStella Laurenzo       }
529436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
530436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
531436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
532436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
533436c6c9cSStella Laurenzo         // i64
5345d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
535436c6c9cSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
536436c6c9cSStella Laurenzo         // unsigned i64
5375d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
5385d6d30edSStella Laurenzo       }
5395d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
5405d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
5415d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
5425d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
5435d6d30edSStella Laurenzo         // i8
5445d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
5455d6d30edSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
5465d6d30edSStella Laurenzo         // unsigned i8
5475d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
5485d6d30edSStella Laurenzo       }
5495d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
5505d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
5515d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
5525d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
5535d6d30edSStella Laurenzo         // i16
5545d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
5555d6d30edSStella Laurenzo       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
5565d6d30edSStella Laurenzo         // unsigned i16
5575d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
558436c6c9cSStella Laurenzo       }
559436c6c9cSStella Laurenzo     }
560436c6c9cSStella Laurenzo 
561*c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
5625d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
563*c5f445d1SStella Laurenzo     throw std::invalid_argument(
564*c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
565436c6c9cSStella Laurenzo   }
566436c6c9cSStella Laurenzo 
567436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
568436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
569436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
570436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
5715d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
572436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
5735d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
574436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
575436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
576436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
577436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
578436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
579436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
580436c6c9cSStella Laurenzo                                })
581436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
582436c6c9cSStella Laurenzo   }
583436c6c9cSStella Laurenzo 
584436c6c9cSStella Laurenzo private:
585436c6c9cSStella Laurenzo   static bool isUnsignedIntegerFormat(const std::string &format) {
586436c6c9cSStella Laurenzo     if (format.empty())
587436c6c9cSStella Laurenzo       return false;
588436c6c9cSStella Laurenzo     char code = format[0];
589436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
590436c6c9cSStella Laurenzo            code == 'Q';
591436c6c9cSStella Laurenzo   }
592436c6c9cSStella Laurenzo 
593436c6c9cSStella Laurenzo   static bool isSignedIntegerFormat(const std::string &format) {
594436c6c9cSStella Laurenzo     if (format.empty())
595436c6c9cSStella Laurenzo       return false;
596436c6c9cSStella Laurenzo     char code = format[0];
597436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
598436c6c9cSStella Laurenzo            code == 'q';
599436c6c9cSStella Laurenzo   }
600436c6c9cSStella Laurenzo 
601436c6c9cSStella Laurenzo   template <typename Type>
602436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
6035d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
604436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
605436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
606436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
607436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
608436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
609436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
610436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
611436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
612436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
613436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
614436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
615436c6c9cSStella Laurenzo     intptr_t strideFactor = 1;
616436c6c9cSStella Laurenzo     for (intptr_t i = 1; i < rank; ++i) {
617436c6c9cSStella Laurenzo       strideFactor = 1;
618436c6c9cSStella Laurenzo       for (intptr_t j = i; j < rank; ++j) {
619436c6c9cSStella Laurenzo         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
620436c6c9cSStella Laurenzo       }
621436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type) * strideFactor);
622436c6c9cSStella Laurenzo     }
623436c6c9cSStella Laurenzo     strides.push_back(sizeof(Type));
6245d6d30edSStella Laurenzo     std::string format;
6255d6d30edSStella Laurenzo     if (explicitFormat) {
6265d6d30edSStella Laurenzo       format = explicitFormat;
6275d6d30edSStella Laurenzo     } else {
6285d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
6295d6d30edSStella Laurenzo     }
6305d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
6315d6d30edSStella Laurenzo                            /*readonly=*/true);
632436c6c9cSStella Laurenzo   }
633436c6c9cSStella Laurenzo }; // namespace
634436c6c9cSStella Laurenzo 
635436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
636436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
637436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
638436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
639436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
640436c6c9cSStella Laurenzo public:
641436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
642436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
643436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
644436c6c9cSStella Laurenzo 
645436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
646436c6c9cSStella Laurenzo   /// out of range.
647436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
648436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
649436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
650436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
651436c6c9cSStella Laurenzo     }
652436c6c9cSStella Laurenzo 
653436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
654436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
655436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
656436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
657436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
658436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
659436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
660436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
661436c6c9cSStella Laurenzo     // querying them on each element access.
662436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
663436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
664436c6c9cSStella Laurenzo     if (isUnsigned) {
665436c6c9cSStella Laurenzo       if (width == 1) {
666436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
667436c6c9cSStella Laurenzo       }
668436c6c9cSStella Laurenzo       if (width == 32) {
669436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
670436c6c9cSStella Laurenzo       }
671436c6c9cSStella Laurenzo       if (width == 64) {
672436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
673436c6c9cSStella Laurenzo       }
674436c6c9cSStella Laurenzo     } else {
675436c6c9cSStella Laurenzo       if (width == 1) {
676436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
677436c6c9cSStella Laurenzo       }
678436c6c9cSStella Laurenzo       if (width == 32) {
679436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
680436c6c9cSStella Laurenzo       }
681436c6c9cSStella Laurenzo       if (width == 64) {
682436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
683436c6c9cSStella Laurenzo       }
684436c6c9cSStella Laurenzo     }
685436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
686436c6c9cSStella Laurenzo   }
687436c6c9cSStella Laurenzo 
688436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
689436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
690436c6c9cSStella Laurenzo   }
691436c6c9cSStella Laurenzo };
692436c6c9cSStella Laurenzo 
693436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
694436c6c9cSStella Laurenzo public:
695436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
696436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
697436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
698436c6c9cSStella Laurenzo 
699436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
700436c6c9cSStella Laurenzo 
701436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
702436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
703436c6c9cSStella Laurenzo     c.def_static(
704436c6c9cSStella Laurenzo         "get",
705436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
706436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
707436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
708436c6c9cSStella Laurenzo           for (auto &it : attributes) {
709436c6c9cSStella Laurenzo             auto &mlir_attr = it.second.cast<PyAttribute &>();
710436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
711436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
712436c6c9cSStella Laurenzo                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
713436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
714436c6c9cSStella Laurenzo                 mlir_attr));
715436c6c9cSStella Laurenzo           }
716436c6c9cSStella Laurenzo           MlirAttribute attr =
717436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
718436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
719436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
720436c6c9cSStella Laurenzo         },
721ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
722436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
723436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
724436c6c9cSStella Laurenzo       MlirAttribute attr =
725436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
726436c6c9cSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
727436c6c9cSStella Laurenzo         throw SetPyError(PyExc_KeyError,
728436c6c9cSStella Laurenzo                          "attempt to access a non-existent attribute");
729436c6c9cSStella Laurenzo       }
730436c6c9cSStella Laurenzo       return PyAttribute(self.getContext(), attr);
731436c6c9cSStella Laurenzo     });
732436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
733436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
734436c6c9cSStella Laurenzo         throw SetPyError(PyExc_IndexError,
735436c6c9cSStella Laurenzo                          "attempt to access out of bounds attribute");
736436c6c9cSStella Laurenzo       }
737436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
738436c6c9cSStella Laurenzo       return PyNamedAttribute(
739436c6c9cSStella Laurenzo           namedAttr.attribute,
740436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
741436c6c9cSStella Laurenzo     });
742436c6c9cSStella Laurenzo   }
743436c6c9cSStella Laurenzo };
744436c6c9cSStella Laurenzo 
745436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
746436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
747436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
748436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
749436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
750436c6c9cSStella Laurenzo public:
751436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
752436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
753436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
754436c6c9cSStella Laurenzo 
755436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
756436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
757436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
758436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
759436c6c9cSStella Laurenzo     }
760436c6c9cSStella Laurenzo 
761436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
762436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
763436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
764436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
765436c6c9cSStella Laurenzo     // from float and double.
766436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
767436c6c9cSStella Laurenzo     // querying them on each element access.
768436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
769436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
770436c6c9cSStella Laurenzo     }
771436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
772436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
773436c6c9cSStella Laurenzo     }
774436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
775436c6c9cSStella Laurenzo   }
776436c6c9cSStella Laurenzo 
777436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
778436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
779436c6c9cSStella Laurenzo   }
780436c6c9cSStella Laurenzo };
781436c6c9cSStella Laurenzo 
782436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
783436c6c9cSStella Laurenzo public:
784436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
785436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
786436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
787436c6c9cSStella Laurenzo 
788436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
789436c6c9cSStella Laurenzo     c.def_static(
790436c6c9cSStella Laurenzo         "get",
791436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
792436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
793436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
794436c6c9cSStella Laurenzo         },
795436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
796436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
797436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
798436c6c9cSStella Laurenzo       return PyType(self.getContext()->getRef(),
799436c6c9cSStella Laurenzo                     mlirTypeAttrGetValue(self.get()));
800436c6c9cSStella Laurenzo     });
801436c6c9cSStella Laurenzo   }
802436c6c9cSStella Laurenzo };
803436c6c9cSStella Laurenzo 
804436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
805436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
806436c6c9cSStella Laurenzo public:
807436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
808436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
809436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
810436c6c9cSStella Laurenzo 
811436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
812436c6c9cSStella Laurenzo     c.def_static(
813436c6c9cSStella Laurenzo         "get",
814436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
815436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
816436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
817436c6c9cSStella Laurenzo         },
818436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
819436c6c9cSStella Laurenzo   }
820436c6c9cSStella Laurenzo };
821436c6c9cSStella Laurenzo 
822436c6c9cSStella Laurenzo } // namespace
823436c6c9cSStella Laurenzo 
/// Registers the Python classes for the builtin attributes on module `m`.
///
/// NOTE(review): the order of these calls presumably matters. pybind11 needs
/// a base class registered before any subclass that names it. For example,
/// PyDenseFPElementsAttribute and PyDenseIntElementsAttribute derive from
/// PyDenseElementsAttribute, so keep base classes ahead of their refinements
/// when editing this list.
void mlir::python::populateIRAttributes(py::module &m) {
  PyAffineMapAttribute::bind(m);
  PyArrayAttribute::bind(m);
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  // Base class first; the FP and Int refinements below derive from it.
  PyDenseElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDictAttribute::bind(m);
  PyFlatSymbolRefAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
}
840