xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision e5639b3fa45f68c53e8e676749b0f1aacd16b41d)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9436c6c9cSStella Laurenzo #include "IRModule.h"
10436c6c9cSStella Laurenzo 
11436c6c9cSStella Laurenzo #include "PybindUtils.h"
12436c6c9cSStella Laurenzo 
13436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
14436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
15436c6c9cSStella Laurenzo 
16436c6c9cSStella Laurenzo namespace py = pybind11;
17436c6c9cSStella Laurenzo using namespace mlir;
18436c6c9cSStella Laurenzo using namespace mlir::python;
19436c6c9cSStella Laurenzo 
205d6d30edSStella Laurenzo using llvm::Optional;
21436c6c9cSStella Laurenzo using llvm::SmallVector;
22436c6c9cSStella Laurenzo using llvm::Twine;
23436c6c9cSStella Laurenzo 
245d6d30edSStella Laurenzo //------------------------------------------------------------------------------
255d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
265d6d30edSStella Laurenzo //------------------------------------------------------------------------------
275d6d30edSStella Laurenzo 
// Python-facing docstring describing how a DenseElementsAttr is built from a
// Python buffer/array, including the `array`, `signless`, `type`, `shape` and
// `context` arguments. Presumably attached to the DenseElementsAttr `get`
// binding further down this file -- NOTE(review): confirm the binding site.
// This text is user-visible at runtime; edit it as documentation, not code.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
695d6d30edSStella Laurenzo 
70436c6c9cSStella Laurenzo namespace {
71436c6c9cSStella Laurenzo 
72436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
73436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
74436c6c9cSStella Laurenzo }
75436c6c9cSStella Laurenzo 
76436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
77436c6c9cSStella Laurenzo public:
78436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
79436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
80436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
81436c6c9cSStella Laurenzo 
82436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
83436c6c9cSStella Laurenzo     c.def_static(
84436c6c9cSStella Laurenzo         "get",
85436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
86436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
87436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
88436c6c9cSStella Laurenzo         },
89436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
90436c6c9cSStella Laurenzo   }
91436c6c9cSStella Laurenzo };
92436c6c9cSStella Laurenzo 
93ed9e52f3SAlex Zinenko template <typename T>
94ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
95ed9e52f3SAlex Zinenko   try {
96ed9e52f3SAlex Zinenko     return object.cast<T>();
97ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
98ed9e52f3SAlex Zinenko     std::string msg =
99ed9e52f3SAlex Zinenko         std::string(
100ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
101ed9e52f3SAlex Zinenko         err.what() + ")";
102ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
103ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
104ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
105ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
106ed9e52f3SAlex Zinenko                       err.what() + ")";
107ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
108ed9e52f3SAlex Zinenko   }
109ed9e52f3SAlex Zinenko }
110ed9e52f3SAlex Zinenko 
111436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
112436c6c9cSStella Laurenzo public:
113436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
114436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
115436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
116436c6c9cSStella Laurenzo 
117436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
118436c6c9cSStella Laurenzo   public:
119436c6c9cSStella Laurenzo     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
120436c6c9cSStella Laurenzo 
121436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
122436c6c9cSStella Laurenzo 
123436c6c9cSStella Laurenzo     PyAttribute dunderNext() {
124436c6c9cSStella Laurenzo       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
125436c6c9cSStella Laurenzo         throw py::stop_iteration();
126436c6c9cSStella Laurenzo       }
127436c6c9cSStella Laurenzo       return PyAttribute(attr.getContext(),
128436c6c9cSStella Laurenzo                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
129436c6c9cSStella Laurenzo     }
130436c6c9cSStella Laurenzo 
131436c6c9cSStella Laurenzo     static void bind(py::module &m) {
132f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
133f05ff4f7SStella Laurenzo                                            py::module_local())
134436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
135436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
136436c6c9cSStella Laurenzo     }
137436c6c9cSStella Laurenzo 
138436c6c9cSStella Laurenzo   private:
139436c6c9cSStella Laurenzo     PyAttribute attr;
140436c6c9cSStella Laurenzo     int nextIndex = 0;
141436c6c9cSStella Laurenzo   };
142436c6c9cSStella Laurenzo 
143ed9e52f3SAlex Zinenko   PyAttribute getItem(intptr_t i) {
144ed9e52f3SAlex Zinenko     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
145ed9e52f3SAlex Zinenko   }
146ed9e52f3SAlex Zinenko 
147436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
148436c6c9cSStella Laurenzo     c.def_static(
149436c6c9cSStella Laurenzo         "get",
150436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
151436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
152436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
153436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
154ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
155436c6c9cSStella Laurenzo           }
156436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
157436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
158436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
159436c6c9cSStella Laurenzo         },
160436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
161436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
162436c6c9cSStella Laurenzo     c.def("__getitem__",
163436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
164436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
165436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
166ed9e52f3SAlex Zinenko             return arr.getItem(i);
167436c6c9cSStella Laurenzo           })
168436c6c9cSStella Laurenzo         .def("__len__",
169436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
170436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
171436c6c9cSStella Laurenzo              })
172436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
173436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
174436c6c9cSStella Laurenzo         });
175ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
176ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
177ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
178ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
179ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
180ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
181ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
182ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
183ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
184ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
185ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
186ed9e52f3SAlex Zinenko     });
187436c6c9cSStella Laurenzo   }
188436c6c9cSStella Laurenzo };
189436c6c9cSStella Laurenzo 
190436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
191436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
192436c6c9cSStella Laurenzo public:
193436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
194436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
195436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
196436c6c9cSStella Laurenzo 
197436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
198436c6c9cSStella Laurenzo     c.def_static(
199436c6c9cSStella Laurenzo         "get",
200436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
201436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
202436c6c9cSStella Laurenzo           // TODO: Rework error reporting once diagnostic engine is exposed
203436c6c9cSStella Laurenzo           // in C API.
204436c6c9cSStella Laurenzo           if (mlirAttributeIsNull(attr)) {
205436c6c9cSStella Laurenzo             throw SetPyError(PyExc_ValueError,
206436c6c9cSStella Laurenzo                              Twine("invalid '") +
207436c6c9cSStella Laurenzo                                  py::repr(py::cast(type)).cast<std::string>() +
208436c6c9cSStella Laurenzo                                  "' and expected floating point type.");
209436c6c9cSStella Laurenzo           }
210436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
211436c6c9cSStella Laurenzo         },
212436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
213436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
214436c6c9cSStella Laurenzo     c.def_static(
215436c6c9cSStella Laurenzo         "get_f32",
216436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
217436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
218436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
219436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
220436c6c9cSStella Laurenzo         },
221436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
222436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
223436c6c9cSStella Laurenzo     c.def_static(
224436c6c9cSStella Laurenzo         "get_f64",
225436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
226436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
227436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
228436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
229436c6c9cSStella Laurenzo         },
230436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
231436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
232436c6c9cSStella Laurenzo     c.def_property_readonly(
233436c6c9cSStella Laurenzo         "value",
234436c6c9cSStella Laurenzo         [](PyFloatAttribute &self) {
235436c6c9cSStella Laurenzo           return mlirFloatAttrGetValueDouble(self);
236436c6c9cSStella Laurenzo         },
237436c6c9cSStella Laurenzo         "Returns the value of the float point attribute");
238436c6c9cSStella Laurenzo   }
239436c6c9cSStella Laurenzo };
240436c6c9cSStella Laurenzo 
241436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
242436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
243436c6c9cSStella Laurenzo public:
244436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
245436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
246436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
247436c6c9cSStella Laurenzo 
248436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
249436c6c9cSStella Laurenzo     c.def_static(
250436c6c9cSStella Laurenzo         "get",
251436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
252436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
253436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
254436c6c9cSStella Laurenzo         },
255436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
256436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
257436c6c9cSStella Laurenzo     c.def_property_readonly(
258436c6c9cSStella Laurenzo         "value",
259436c6c9cSStella Laurenzo         [](PyIntegerAttribute &self) {
260436c6c9cSStella Laurenzo           return mlirIntegerAttrGetValueInt(self);
261436c6c9cSStella Laurenzo         },
262436c6c9cSStella Laurenzo         "Returns the value of the integer attribute");
263436c6c9cSStella Laurenzo   }
264436c6c9cSStella Laurenzo };
265436c6c9cSStella Laurenzo 
266436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
267436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
268436c6c9cSStella Laurenzo public:
269436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
270436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
271436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
272436c6c9cSStella Laurenzo 
273436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
274436c6c9cSStella Laurenzo     c.def_static(
275436c6c9cSStella Laurenzo         "get",
276436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
277436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
278436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
279436c6c9cSStella Laurenzo         },
280436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
281436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
282436c6c9cSStella Laurenzo     c.def_property_readonly(
283436c6c9cSStella Laurenzo         "value",
284436c6c9cSStella Laurenzo         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
285436c6c9cSStella Laurenzo         "Returns the value of the bool attribute");
286436c6c9cSStella Laurenzo   }
287436c6c9cSStella Laurenzo };
288436c6c9cSStella Laurenzo 
289436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
290436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
291436c6c9cSStella Laurenzo public:
292436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
293436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
294436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
295436c6c9cSStella Laurenzo 
296436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
297436c6c9cSStella Laurenzo     c.def_static(
298436c6c9cSStella Laurenzo         "get",
299436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
300436c6c9cSStella Laurenzo           MlirAttribute attr =
301436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
302436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
303436c6c9cSStella Laurenzo         },
304436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
305436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
306436c6c9cSStella Laurenzo     c.def_property_readonly(
307436c6c9cSStella Laurenzo         "value",
308436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
309436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
310436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
311436c6c9cSStella Laurenzo         },
312436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
313436c6c9cSStella Laurenzo   }
314436c6c9cSStella Laurenzo };
315436c6c9cSStella Laurenzo 
316436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
317436c6c9cSStella Laurenzo public:
318436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
319436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
320436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
321436c6c9cSStella Laurenzo 
322436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
323436c6c9cSStella Laurenzo     c.def_static(
324436c6c9cSStella Laurenzo         "get",
325436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
326436c6c9cSStella Laurenzo           MlirAttribute attr =
327436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
328436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
329436c6c9cSStella Laurenzo         },
330436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
331436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
332436c6c9cSStella Laurenzo     c.def_static(
333436c6c9cSStella Laurenzo         "get_typed",
334436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
335436c6c9cSStella Laurenzo           MlirAttribute attr =
336436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
337436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
338436c6c9cSStella Laurenzo         },
339a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
340436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
341436c6c9cSStella Laurenzo     c.def_property_readonly(
342436c6c9cSStella Laurenzo         "value",
343436c6c9cSStella Laurenzo         [](PyStringAttribute &self) {
344436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirStringAttrGetValue(self);
345436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
346436c6c9cSStella Laurenzo         },
347436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
348436c6c9cSStella Laurenzo   }
349436c6c9cSStella Laurenzo };
350436c6c9cSStella Laurenzo 
351436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
352436c6c9cSStella Laurenzo class PyDenseElementsAttribute
353436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
354436c6c9cSStella Laurenzo public:
355436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
356436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
357436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
358436c6c9cSStella Laurenzo 
359436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
3605d6d30edSStella Laurenzo   getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
3615d6d30edSStella Laurenzo                 Optional<std::vector<int64_t>> explicitShape,
362436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
363436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
364436c6c9cSStella Laurenzo     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
365436c6c9cSStella Laurenzo     Py_buffer *view = new Py_buffer();
366436c6c9cSStella Laurenzo     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
367436c6c9cSStella Laurenzo       delete view;
368436c6c9cSStella Laurenzo       throw py::error_already_set();
369436c6c9cSStella Laurenzo     }
370436c6c9cSStella Laurenzo     py::buffer_info arrayInfo(view);
3715d6d30edSStella Laurenzo     SmallVector<int64_t> shape;
3725d6d30edSStella Laurenzo     if (explicitShape) {
3735d6d30edSStella Laurenzo       shape.append(explicitShape->begin(), explicitShape->end());
3745d6d30edSStella Laurenzo     } else {
3755d6d30edSStella Laurenzo       shape.append(arrayInfo.shape.begin(),
3765d6d30edSStella Laurenzo                    arrayInfo.shape.begin() + arrayInfo.ndim);
3775d6d30edSStella Laurenzo     }
378436c6c9cSStella Laurenzo 
3795d6d30edSStella Laurenzo     MlirAttribute encodingAttr = mlirAttributeGetNull();
380436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
3815d6d30edSStella Laurenzo 
3825d6d30edSStella Laurenzo     // Detect format codes that are suitable for bulk loading. This includes
3835d6d30edSStella Laurenzo     // all byte aligned integer and floating point types up to 8 bytes.
3845d6d30edSStella Laurenzo     // Notably, this excludes, bool (which needs to be bit-packed) and
3855d6d30edSStella Laurenzo     // other exotics which do not have a direct representation in the buffer
3865d6d30edSStella Laurenzo     // protocol (i.e. complex, etc).
3875d6d30edSStella Laurenzo     Optional<MlirType> bulkLoadElementType;
3885d6d30edSStella Laurenzo     if (explicitType) {
3895d6d30edSStella Laurenzo       bulkLoadElementType = *explicitType;
3905d6d30edSStella Laurenzo     } else if (arrayInfo.format == "f") {
391436c6c9cSStella Laurenzo       // f32
392436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
3935d6d30edSStella Laurenzo       bulkLoadElementType = mlirF32TypeGet(context);
394436c6c9cSStella Laurenzo     } else if (arrayInfo.format == "d") {
395436c6c9cSStella Laurenzo       // f64
396436c6c9cSStella Laurenzo       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
3975d6d30edSStella Laurenzo       bulkLoadElementType = mlirF64TypeGet(context);
3985d6d30edSStella Laurenzo     } else if (arrayInfo.format == "e") {
3995d6d30edSStella Laurenzo       // f16
4005d6d30edSStella Laurenzo       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
4015d6d30edSStella Laurenzo       bulkLoadElementType = mlirF16TypeGet(context);
402436c6c9cSStella Laurenzo     } else if (isSignedIntegerFormat(arrayInfo.format)) {
403436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
404436c6c9cSStella Laurenzo         // i32
4055d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
406436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 32);
407436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
408436c6c9cSStella Laurenzo         // i64
4095d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
410436c6c9cSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 64);
4115d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
4125d6d30edSStella Laurenzo         // i8
4135d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
4145d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 8);
4155d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
4165d6d30edSStella Laurenzo         // i16
4175d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
4185d6d30edSStella Laurenzo                                        : mlirIntegerTypeSignedGet(context, 16);
419436c6c9cSStella Laurenzo       }
420436c6c9cSStella Laurenzo     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
421436c6c9cSStella Laurenzo       if (arrayInfo.itemsize == 4) {
422436c6c9cSStella Laurenzo         // unsigned i32
4235d6d30edSStella Laurenzo         bulkLoadElementType = signless
424436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 32)
425436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 32);
426436c6c9cSStella Laurenzo       } else if (arrayInfo.itemsize == 8) {
427436c6c9cSStella Laurenzo         // unsigned i64
4285d6d30edSStella Laurenzo         bulkLoadElementType = signless
429436c6c9cSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 64)
430436c6c9cSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 64);
4315d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 1) {
4325d6d30edSStella Laurenzo         // i8
4335d6d30edSStella Laurenzo         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
4345d6d30edSStella Laurenzo                                        : mlirIntegerTypeUnsignedGet(context, 8);
4355d6d30edSStella Laurenzo       } else if (arrayInfo.itemsize == 2) {
4365d6d30edSStella Laurenzo         // i16
4375d6d30edSStella Laurenzo         bulkLoadElementType = signless
4385d6d30edSStella Laurenzo                                   ? mlirIntegerTypeGet(context, 16)
4395d6d30edSStella Laurenzo                                   : mlirIntegerTypeUnsignedGet(context, 16);
440436c6c9cSStella Laurenzo       }
441436c6c9cSStella Laurenzo     }
4425d6d30edSStella Laurenzo     if (bulkLoadElementType) {
4435d6d30edSStella Laurenzo       auto shapedType = mlirRankedTensorTypeGet(
4445d6d30edSStella Laurenzo           shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
4455d6d30edSStella Laurenzo       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
4465d6d30edSStella Laurenzo       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
4475d6d30edSStella Laurenzo           shapedType, rawBufferSize, arrayInfo.ptr);
4485d6d30edSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
4495d6d30edSStella Laurenzo         throw std::invalid_argument(
4505d6d30edSStella Laurenzo             "DenseElementsAttr could not be constructed from the given buffer. "
4515d6d30edSStella Laurenzo             "This may mean that the Python buffer layout does not match that "
4525d6d30edSStella Laurenzo             "MLIR expected layout and is a bug.");
4535d6d30edSStella Laurenzo       }
4545d6d30edSStella Laurenzo       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
4555d6d30edSStella Laurenzo     }
456436c6c9cSStella Laurenzo 
4575d6d30edSStella Laurenzo     throw std::invalid_argument(
4585d6d30edSStella Laurenzo         std::string("unimplemented array format conversion from format: ") +
4595d6d30edSStella Laurenzo         arrayInfo.format);
460436c6c9cSStella Laurenzo   }
461436c6c9cSStella Laurenzo 
462436c6c9cSStella Laurenzo   static PyDenseElementsAttribute getSplat(PyType shapedType,
463436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
464436c6c9cSStella Laurenzo     auto contextWrapper =
465436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
466436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
467436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
468436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
469436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
470436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
471436c6c9cSStella Laurenzo     }
472436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
473436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
474436c6c9cSStella Laurenzo       std::string message =
475436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
476436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
477436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
478436c6c9cSStella Laurenzo     }
479436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
480436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
481436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
482436c6c9cSStella Laurenzo       std::string message =
483436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
484436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
485436c6c9cSStella Laurenzo       message.append(", element=");
486436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
487436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, message);
488436c6c9cSStella Laurenzo     }
489436c6c9cSStella Laurenzo 
490436c6c9cSStella Laurenzo     MlirAttribute elements =
491436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
492436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
493436c6c9cSStella Laurenzo   }
494436c6c9cSStella Laurenzo 
495436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
496436c6c9cSStella Laurenzo 
497436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
4985d6d30edSStella Laurenzo     if (mlirDenseElementsAttrIsSplat(*this)) {
499c5f445d1SStella Laurenzo       // TODO: Currently crashes the program.
5005d6d30edSStella Laurenzo       // Reported as https://github.com/pybind/pybind11/issues/3336
501c5f445d1SStella Laurenzo       throw std::invalid_argument(
502c5f445d1SStella Laurenzo           "unsupported data type for conversion to Python buffer");
5035d6d30edSStella Laurenzo     }
5045d6d30edSStella Laurenzo 
505436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
506436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
5075d6d30edSStella Laurenzo     std::string format;
508436c6c9cSStella Laurenzo 
509436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
510436c6c9cSStella Laurenzo       // f32
5115d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
51202b6fb21SMehdi Amini     }
51302b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
514436c6c9cSStella Laurenzo       // f64
5155d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
516bb56c2b3SMehdi Amini     }
517bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
5185d6d30edSStella Laurenzo       // f16
5195d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
520bb56c2b3SMehdi Amini     }
521bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
522436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
523436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
524436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
525436c6c9cSStella Laurenzo         // i32
5265d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
527*e5639b3fSMehdi Amini       }
528*e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
529436c6c9cSStella Laurenzo         // unsigned i32
5305d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
531436c6c9cSStella Laurenzo       }
532436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
533436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
534436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
535436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
536436c6c9cSStella Laurenzo         // i64
5375d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
538*e5639b3fSMehdi Amini       }
539*e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
540436c6c9cSStella Laurenzo         // unsigned i64
5415d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
5425d6d30edSStella Laurenzo       }
5435d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
5445d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
5455d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
5465d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
5475d6d30edSStella Laurenzo         // i8
5485d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
549*e5639b3fSMehdi Amini       }
550*e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
5515d6d30edSStella Laurenzo         // unsigned i8
5525d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
5535d6d30edSStella Laurenzo       }
5545d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
5555d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
5565d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
5575d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
5585d6d30edSStella Laurenzo         // i16
5595d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
560*e5639b3fSMehdi Amini       }
561*e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
5625d6d30edSStella Laurenzo         // unsigned i16
5635d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
564436c6c9cSStella Laurenzo       }
565436c6c9cSStella Laurenzo     }
566436c6c9cSStella Laurenzo 
567c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
5685d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
569c5f445d1SStella Laurenzo     throw std::invalid_argument(
570c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
571436c6c9cSStella Laurenzo   }
572436c6c9cSStella Laurenzo 
573436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
574436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
575436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
576436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
5775d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
578436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
5795d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
580436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
581436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
582436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
583436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
584436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
585436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
586436c6c9cSStella Laurenzo                                })
587436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
588436c6c9cSStella Laurenzo   }
589436c6c9cSStella Laurenzo 
590436c6c9cSStella Laurenzo private:
/// Returns true when the first character of the buffer `format` string is
/// one of the unsigned integer codes B, H, I, L or Q. Only the first
/// character is inspected; an empty format yields false.
static bool isUnsignedIntegerFormat(const std::string &format) {
  if (format.empty())
    return false;
  switch (format.front()) {
  case 'B':
  case 'H':
  case 'I':
  case 'L':
  case 'Q':
    return true;
  default:
    return false;
  }
}
598436c6c9cSStella Laurenzo 
/// Returns true when the first character of the buffer `format` string is
/// one of the signed integer codes b, h, i, l or q. Only the first character
/// is inspected; an empty format yields false.
static bool isSignedIntegerFormat(const std::string &format) {
  if (format.empty())
    return false;
  switch (format.front()) {
  case 'b':
  case 'h':
  case 'i':
  case 'l':
  case 'q':
    return true;
  default:
    return false;
  }
}
606436c6c9cSStella Laurenzo 
607436c6c9cSStella Laurenzo   template <typename Type>
608436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
6095d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
610436c6c9cSStella Laurenzo     intptr_t rank = mlirShapedTypeGetRank(shapedType);
611436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
612436c6c9cSStella Laurenzo     // Buffer is configured for read-only access below.
613436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
614436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
615436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
616436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
617436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
618436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
619436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
620436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
621436c6c9cSStella Laurenzo     intptr_t strideFactor = 1;
622436c6c9cSStella Laurenzo     for (intptr_t i = 1; i < rank; ++i) {
623436c6c9cSStella Laurenzo       strideFactor = 1;
624436c6c9cSStella Laurenzo       for (intptr_t j = i; j < rank; ++j) {
625436c6c9cSStella Laurenzo         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
626436c6c9cSStella Laurenzo       }
627436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type) * strideFactor);
628436c6c9cSStella Laurenzo     }
629436c6c9cSStella Laurenzo     strides.push_back(sizeof(Type));
6305d6d30edSStella Laurenzo     std::string format;
6315d6d30edSStella Laurenzo     if (explicitFormat) {
6325d6d30edSStella Laurenzo       format = explicitFormat;
6335d6d30edSStella Laurenzo     } else {
6345d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
6355d6d30edSStella Laurenzo     }
6365d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
6375d6d30edSStella Laurenzo                            /*readonly=*/true);
638436c6c9cSStella Laurenzo   }
639436c6c9cSStella Laurenzo }; // namespace
640436c6c9cSStella Laurenzo 
641436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
642436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
643436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
644436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
645436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
646436c6c9cSStella Laurenzo public:
647436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
648436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
649436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
650436c6c9cSStella Laurenzo 
651436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
652436c6c9cSStella Laurenzo   /// out of range.
653436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
654436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
655436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
656436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
657436c6c9cSStella Laurenzo     }
658436c6c9cSStella Laurenzo 
659436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
660436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
661436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
662436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
663436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
664436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
665436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
666436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
667436c6c9cSStella Laurenzo     // querying them on each element access.
668436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
669436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
670436c6c9cSStella Laurenzo     if (isUnsigned) {
671436c6c9cSStella Laurenzo       if (width == 1) {
672436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
673436c6c9cSStella Laurenzo       }
674436c6c9cSStella Laurenzo       if (width == 32) {
675436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
676436c6c9cSStella Laurenzo       }
677436c6c9cSStella Laurenzo       if (width == 64) {
678436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
679436c6c9cSStella Laurenzo       }
680436c6c9cSStella Laurenzo     } else {
681436c6c9cSStella Laurenzo       if (width == 1) {
682436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
683436c6c9cSStella Laurenzo       }
684436c6c9cSStella Laurenzo       if (width == 32) {
685436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
686436c6c9cSStella Laurenzo       }
687436c6c9cSStella Laurenzo       if (width == 64) {
688436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
689436c6c9cSStella Laurenzo       }
690436c6c9cSStella Laurenzo     }
691436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
692436c6c9cSStella Laurenzo   }
693436c6c9cSStella Laurenzo 
694436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
695436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
696436c6c9cSStella Laurenzo   }
697436c6c9cSStella Laurenzo };
698436c6c9cSStella Laurenzo 
699436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
700436c6c9cSStella Laurenzo public:
701436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
702436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
703436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
704436c6c9cSStella Laurenzo 
705436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
706436c6c9cSStella Laurenzo 
7079fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
7089fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
7099fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
7109fb1086bSAdrian Kuegel   }
7119fb1086bSAdrian Kuegel 
712436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
7139fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
714436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
715436c6c9cSStella Laurenzo     c.def_static(
716436c6c9cSStella Laurenzo         "get",
717436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
718436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
719436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
720436c6c9cSStella Laurenzo           for (auto &it : attributes) {
72102b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
722436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
723436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
72402b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
725436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
72602b6fb21SMehdi Amini                 mlirAttr));
727436c6c9cSStella Laurenzo           }
728436c6c9cSStella Laurenzo           MlirAttribute attr =
729436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
730436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
731436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
732436c6c9cSStella Laurenzo         },
733ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
734436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
735436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
736436c6c9cSStella Laurenzo       MlirAttribute attr =
737436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
738436c6c9cSStella Laurenzo       if (mlirAttributeIsNull(attr)) {
739436c6c9cSStella Laurenzo         throw SetPyError(PyExc_KeyError,
740436c6c9cSStella Laurenzo                          "attempt to access a non-existent attribute");
741436c6c9cSStella Laurenzo       }
742436c6c9cSStella Laurenzo       return PyAttribute(self.getContext(), attr);
743436c6c9cSStella Laurenzo     });
744436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
745436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
746436c6c9cSStella Laurenzo         throw SetPyError(PyExc_IndexError,
747436c6c9cSStella Laurenzo                          "attempt to access out of bounds attribute");
748436c6c9cSStella Laurenzo       }
749436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
750436c6c9cSStella Laurenzo       return PyNamedAttribute(
751436c6c9cSStella Laurenzo           namedAttr.attribute,
752436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
753436c6c9cSStella Laurenzo     });
754436c6c9cSStella Laurenzo   }
755436c6c9cSStella Laurenzo };
756436c6c9cSStella Laurenzo 
757436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
758436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
759436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
760436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
761436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
762436c6c9cSStella Laurenzo public:
763436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
764436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
765436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
766436c6c9cSStella Laurenzo 
767436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
768436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
769436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
770436c6c9cSStella Laurenzo                        "attempt to access out of bounds element");
771436c6c9cSStella Laurenzo     }
772436c6c9cSStella Laurenzo 
773436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
774436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
775436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
776436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
777436c6c9cSStella Laurenzo     // from float and double.
778436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
779436c6c9cSStella Laurenzo     // querying them on each element access.
780436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
781436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
782436c6c9cSStella Laurenzo     }
783436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
784436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
785436c6c9cSStella Laurenzo     }
786436c6c9cSStella Laurenzo     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
787436c6c9cSStella Laurenzo   }
788436c6c9cSStella Laurenzo 
789436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
790436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
791436c6c9cSStella Laurenzo   }
792436c6c9cSStella Laurenzo };
793436c6c9cSStella Laurenzo 
794436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
795436c6c9cSStella Laurenzo public:
796436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
797436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
798436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
799436c6c9cSStella Laurenzo 
800436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
801436c6c9cSStella Laurenzo     c.def_static(
802436c6c9cSStella Laurenzo         "get",
803436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
804436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
805436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
806436c6c9cSStella Laurenzo         },
807436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
808436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
809436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
810436c6c9cSStella Laurenzo       return PyType(self.getContext()->getRef(),
811436c6c9cSStella Laurenzo                     mlirTypeAttrGetValue(self.get()));
812436c6c9cSStella Laurenzo     });
813436c6c9cSStella Laurenzo   }
814436c6c9cSStella Laurenzo };
815436c6c9cSStella Laurenzo 
816436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
817436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
818436c6c9cSStella Laurenzo public:
819436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
820436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
821436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
822436c6c9cSStella Laurenzo 
823436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
824436c6c9cSStella Laurenzo     c.def_static(
825436c6c9cSStella Laurenzo         "get",
826436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
827436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
828436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
829436c6c9cSStella Laurenzo         },
830436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
831436c6c9cSStella Laurenzo   }
832436c6c9cSStella Laurenzo };
833436c6c9cSStella Laurenzo 
834436c6c9cSStella Laurenzo } // namespace
835436c6c9cSStella Laurenzo 
836436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
837436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
838436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
839436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
840436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
841436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
842436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
843436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
844436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
845436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
846436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
847436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
848436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
849436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
850436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
851436c6c9cSStella Laurenzo }
852