xref: /llvm-project/mlir/lib/Bindings/Python/IRAttributes.cpp (revision 0a68171b3c67503f7143856580f1b22a93ef566e)
1436c6c9cSStella Laurenzo //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9a1fe1f5fSKazu Hirata #include <optional>
1071a25454SPeter Hawkins #include <string_view>
114811270bSmax #include <utility>
121fc096afSMehdi Amini 
13436c6c9cSStella Laurenzo #include "IRModule.h"
14436c6c9cSStella Laurenzo 
15436c6c9cSStella Laurenzo #include "PybindUtils.h"
16436c6c9cSStella Laurenzo 
1771a25454SPeter Hawkins #include "llvm/ADT/ScopeExit.h"
18c912f0e7Spranavm-nvidia #include "llvm/Support/raw_ostream.h"
1971a25454SPeter Hawkins 
20436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
21436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
22bfb1ba75Smax #include "mlir/Bindings/Python/PybindAdaptors.h"
23436c6c9cSStella Laurenzo 
24436c6c9cSStella Laurenzo namespace py = pybind11;
25436c6c9cSStella Laurenzo using namespace mlir;
26436c6c9cSStella Laurenzo using namespace mlir::python;
27436c6c9cSStella Laurenzo 
28436c6c9cSStella Laurenzo using llvm::SmallVector;
29436c6c9cSStella Laurenzo 
305d6d30edSStella Laurenzo //------------------------------------------------------------------------------
315d6d30edSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
325d6d30edSStella Laurenzo //------------------------------------------------------------------------------
335d6d30edSStella Laurenzo 
// Python docstring for `DenseElementsAttr.get` (construction from a buffer or
// array). Kept as a compile-time constant so it can be handed to pybind11
// directly as a `const char *`.
static constexpr char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
755d6d30edSStella Laurenzo 
// Python docstring for `DenseElementsAttr.get` when constructing from a list
// of attributes. The "Raises" section names the `type` argument documented in
// "Args" (it previously referred to a nonexistent `shaped_type` argument).
static const char kDenseElementsAttrGetFromListDocstring[] =
    R"(Gets a DenseElementsAttr from a Python list of attributes.

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
  attrs: A list of attributes.
  type: The desired shape and type of the resulting DenseElementsAttr.
    If not provided, the element type is determined based on the type
    of the 0th attribute and the shape is `[len(attrs)]`.
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the attributes does not match the type
    specified by `type`.
)";
96c912f0e7Spranavm-nvidia 
// Python docstring for `DenseResourceElementsAttr.get_from_buffer`-style
// construction from a Python buffer. Compile-time constant; handed to
// pybind11 as a plain `const char *`.
static constexpr char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
  buffer: The array or buffer to convert.
  name: Name to provide to the resource (may be changed upon collision).
  type: The explicit ShapedType to construct the attribute with.
  context: Explicit context, if not from context manager.

Returns:
  DenseResourceElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
122f66cd9e9SStella Laurenzo 
123436c6c9cSStella Laurenzo namespace {
124436c6c9cSStella Laurenzo 
125436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
126436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
127436c6c9cSStella Laurenzo }
128436c6c9cSStella Laurenzo 
129436c6c9cSStella Laurenzo class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
130436c6c9cSStella Laurenzo public:
131436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
132436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "AffineMapAttr";
133436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1349566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1359566ee28Smax       mlirAffineMapAttrGetTypeID;
136436c6c9cSStella Laurenzo 
137436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
138436c6c9cSStella Laurenzo     c.def_static(
139436c6c9cSStella Laurenzo         "get",
140436c6c9cSStella Laurenzo         [](PyAffineMap &affineMap) {
141436c6c9cSStella Laurenzo           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
142436c6c9cSStella Laurenzo           return PyAffineMapAttribute(affineMap.getContext(), attr);
143436c6c9cSStella Laurenzo         },
144436c6c9cSStella Laurenzo         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
145c36b4248SBimo     c.def_property_readonly("value", mlirAffineMapAttrGetValue,
146c36b4248SBimo                             "Returns the value of the AffineMap attribute");
147436c6c9cSStella Laurenzo   }
148436c6c9cSStella Laurenzo };
149436c6c9cSStella Laurenzo 
150334873feSAmy Wang class PyIntegerSetAttribute
151334873feSAmy Wang     : public PyConcreteAttribute<PyIntegerSetAttribute> {
152334873feSAmy Wang public:
153334873feSAmy Wang   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
154334873feSAmy Wang   static constexpr const char *pyClassName = "IntegerSetAttr";
155334873feSAmy Wang   using PyConcreteAttribute::PyConcreteAttribute;
156334873feSAmy Wang   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
157334873feSAmy Wang       mlirIntegerSetAttrGetTypeID;
158334873feSAmy Wang 
159334873feSAmy Wang   static void bindDerived(ClassTy &c) {
160334873feSAmy Wang     c.def_static(
161334873feSAmy Wang         "get",
162334873feSAmy Wang         [](PyIntegerSet &integerSet) {
163334873feSAmy Wang           MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
164334873feSAmy Wang           return PyIntegerSetAttribute(integerSet.getContext(), attr);
165334873feSAmy Wang         },
166334873feSAmy Wang         py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
167334873feSAmy Wang   }
168334873feSAmy Wang };
169334873feSAmy Wang 
170ed9e52f3SAlex Zinenko template <typename T>
171ed9e52f3SAlex Zinenko static T pyTryCast(py::handle object) {
172ed9e52f3SAlex Zinenko   try {
173ed9e52f3SAlex Zinenko     return object.cast<T>();
174ed9e52f3SAlex Zinenko   } catch (py::cast_error &err) {
175ed9e52f3SAlex Zinenko     std::string msg =
176ed9e52f3SAlex Zinenko         std::string(
177ed9e52f3SAlex Zinenko             "Invalid attribute when attempting to create an ArrayAttribute (") +
178ed9e52f3SAlex Zinenko         err.what() + ")";
179ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
180ed9e52f3SAlex Zinenko   } catch (py::reference_cast_error &err) {
181ed9e52f3SAlex Zinenko     std::string msg = std::string("Invalid attribute (None?) when attempting "
182ed9e52f3SAlex Zinenko                                   "to create an ArrayAttribute (") +
183ed9e52f3SAlex Zinenko                       err.what() + ")";
184ed9e52f3SAlex Zinenko     throw py::cast_error(msg);
185ed9e52f3SAlex Zinenko   }
186ed9e52f3SAlex Zinenko }
187ed9e52f3SAlex Zinenko 
188619fd8c2SJeff Niu /// A python-wrapped dense array attribute with an element type and a derived
189619fd8c2SJeff Niu /// implementation class.
190619fd8c2SJeff Niu template <typename EltTy, typename DerivedT>
191133624acSJeff Niu class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
192619fd8c2SJeff Niu public:
193133624acSJeff Niu   using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
194619fd8c2SJeff Niu 
195619fd8c2SJeff Niu   /// Iterator over the integer elements of a dense array.
196619fd8c2SJeff Niu   class PyDenseArrayIterator {
197619fd8c2SJeff Niu   public:
1984a1b1196SMehdi Amini     PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
199619fd8c2SJeff Niu 
200619fd8c2SJeff Niu     /// Return a copy of the iterator.
201619fd8c2SJeff Niu     PyDenseArrayIterator dunderIter() { return *this; }
202619fd8c2SJeff Niu 
203619fd8c2SJeff Niu     /// Return the next element.
204619fd8c2SJeff Niu     EltTy dunderNext() {
205619fd8c2SJeff Niu       // Throw if the index has reached the end.
206619fd8c2SJeff Niu       if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
207619fd8c2SJeff Niu         throw py::stop_iteration();
208619fd8c2SJeff Niu       return DerivedT::getElement(attr.get(), nextIndex++);
209619fd8c2SJeff Niu     }
210619fd8c2SJeff Niu 
211619fd8c2SJeff Niu     /// Bind the iterator class.
212619fd8c2SJeff Niu     static void bind(py::module &m) {
213619fd8c2SJeff Niu       py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
214619fd8c2SJeff Niu                                        py::module_local())
215619fd8c2SJeff Niu           .def("__iter__", &PyDenseArrayIterator::dunderIter)
216619fd8c2SJeff Niu           .def("__next__", &PyDenseArrayIterator::dunderNext);
217619fd8c2SJeff Niu     }
218619fd8c2SJeff Niu 
219619fd8c2SJeff Niu   private:
220619fd8c2SJeff Niu     /// The referenced dense array attribute.
221619fd8c2SJeff Niu     PyAttribute attr;
222619fd8c2SJeff Niu     /// The next index to read.
223619fd8c2SJeff Niu     int nextIndex = 0;
224619fd8c2SJeff Niu   };
225619fd8c2SJeff Niu 
226619fd8c2SJeff Niu   /// Get the element at the given index.
227619fd8c2SJeff Niu   EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
228619fd8c2SJeff Niu 
229619fd8c2SJeff Niu   /// Bind the attribute class.
230133624acSJeff Niu   static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
231619fd8c2SJeff Niu     // Bind the constructor.
232619fd8c2SJeff Niu     c.def_static(
233619fd8c2SJeff Niu         "get",
234619fd8c2SJeff Niu         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
2358dcb6722SIngo Müller           return getAttribute(values, ctx->getRef());
236619fd8c2SJeff Niu         },
237619fd8c2SJeff Niu         py::arg("values"), py::arg("context") = py::none(),
238619fd8c2SJeff Niu         "Gets a uniqued dense array attribute");
239619fd8c2SJeff Niu     // Bind the array methods.
240133624acSJeff Niu     c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
241619fd8c2SJeff Niu       if (i >= mlirDenseArrayGetNumElements(arr))
242619fd8c2SJeff Niu         throw py::index_error("DenseArray index out of range");
243619fd8c2SJeff Niu       return arr.getItem(i);
244619fd8c2SJeff Niu     });
245133624acSJeff Niu     c.def("__len__", [](const DerivedT &arr) {
246619fd8c2SJeff Niu       return mlirDenseArrayGetNumElements(arr);
247619fd8c2SJeff Niu     });
248133624acSJeff Niu     c.def("__iter__",
249133624acSJeff Niu           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
2504a1b1196SMehdi Amini     c.def("__add__", [](DerivedT &arr, const py::list &extras) {
251619fd8c2SJeff Niu       std::vector<EltTy> values;
252619fd8c2SJeff Niu       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
253619fd8c2SJeff Niu       values.reserve(numOldElements + py::len(extras));
254619fd8c2SJeff Niu       for (intptr_t i = 0; i < numOldElements; ++i)
255619fd8c2SJeff Niu         values.push_back(arr.getItem(i));
256619fd8c2SJeff Niu       for (py::handle attr : extras)
257619fd8c2SJeff Niu         values.push_back(pyTryCast<EltTy>(attr));
2588dcb6722SIngo Müller       return getAttribute(values, arr.getContext());
259619fd8c2SJeff Niu     });
260619fd8c2SJeff Niu   }
2618dcb6722SIngo Müller 
2628dcb6722SIngo Müller private:
2638dcb6722SIngo Müller   static DerivedT getAttribute(const std::vector<EltTy> &values,
2648dcb6722SIngo Müller                                PyMlirContextRef ctx) {
2658dcb6722SIngo Müller     if constexpr (std::is_same_v<EltTy, bool>) {
2668dcb6722SIngo Müller       std::vector<int> intValues(values.begin(), values.end());
2678dcb6722SIngo Müller       MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
2688dcb6722SIngo Müller                                                   intValues.data());
2698dcb6722SIngo Müller       return DerivedT(ctx, attr);
2708dcb6722SIngo Müller     } else {
2718dcb6722SIngo Müller       MlirAttribute attr =
2728dcb6722SIngo Müller           DerivedT::getAttribute(ctx->get(), values.size(), values.data());
2738dcb6722SIngo Müller       return DerivedT(ctx, attr);
2748dcb6722SIngo Müller     }
2758dcb6722SIngo Müller   }
276619fd8c2SJeff Niu };
277619fd8c2SJeff Niu 
278619fd8c2SJeff Niu /// Instantiate the python dense array classes.
279619fd8c2SJeff Niu struct PyDenseBoolArrayAttribute
2808dcb6722SIngo Müller     : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
281619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
282619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
283619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseBoolArrayGetElement;
284619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseBoolArrayAttr";
285619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
286619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
287619fd8c2SJeff Niu };
288619fd8c2SJeff Niu struct PyDenseI8ArrayAttribute
289619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
290619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
291619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI8ArrayGet;
292619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI8ArrayGetElement;
293619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI8ArrayAttr";
294619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
295619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
296619fd8c2SJeff Niu };
297619fd8c2SJeff Niu struct PyDenseI16ArrayAttribute
298619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
299619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
300619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI16ArrayGet;
301619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI16ArrayGetElement;
302619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI16ArrayAttr";
303619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
304619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
305619fd8c2SJeff Niu };
306619fd8c2SJeff Niu struct PyDenseI32ArrayAttribute
307619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
308619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
309619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI32ArrayGet;
310619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI32ArrayGetElement;
311619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI32ArrayAttr";
312619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
313619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
314619fd8c2SJeff Niu };
315619fd8c2SJeff Niu struct PyDenseI64ArrayAttribute
316619fd8c2SJeff Niu     : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
317619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
318619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseI64ArrayGet;
319619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseI64ArrayGetElement;
320619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseI64ArrayAttr";
321619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
322619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
323619fd8c2SJeff Niu };
324619fd8c2SJeff Niu struct PyDenseF32ArrayAttribute
325619fd8c2SJeff Niu     : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
326619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
327619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF32ArrayGet;
328619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF32ArrayGetElement;
329619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF32ArrayAttr";
330619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
331619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
332619fd8c2SJeff Niu };
333619fd8c2SJeff Niu struct PyDenseF64ArrayAttribute
334619fd8c2SJeff Niu     : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
335619fd8c2SJeff Niu   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
336619fd8c2SJeff Niu   static constexpr auto getAttribute = mlirDenseF64ArrayGet;
337619fd8c2SJeff Niu   static constexpr auto getElement = mlirDenseF64ArrayGetElement;
338619fd8c2SJeff Niu   static constexpr const char *pyClassName = "DenseF64ArrayAttr";
339619fd8c2SJeff Niu   static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
340619fd8c2SJeff Niu   using PyDenseArrayAttribute::PyDenseArrayAttribute;
341619fd8c2SJeff Niu };
342619fd8c2SJeff Niu 
343436c6c9cSStella Laurenzo class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
344436c6c9cSStella Laurenzo public:
345436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
346436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "ArrayAttr";
347436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
3489566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
3499566ee28Smax       mlirArrayAttrGetTypeID;
350436c6c9cSStella Laurenzo 
351436c6c9cSStella Laurenzo   class PyArrayAttributeIterator {
352436c6c9cSStella Laurenzo   public:
3531fc096afSMehdi Amini     PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
354436c6c9cSStella Laurenzo 
355436c6c9cSStella Laurenzo     PyArrayAttributeIterator &dunderIter() { return *this; }
356436c6c9cSStella Laurenzo 
357974c1596SRahul Kayaith     MlirAttribute dunderNext() {
358bca88952SJeff Niu       // TODO: Throw is an inefficient way to stop iteration.
359bca88952SJeff Niu       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
360436c6c9cSStella Laurenzo         throw py::stop_iteration();
361974c1596SRahul Kayaith       return mlirArrayAttrGetElement(attr.get(), nextIndex++);
362436c6c9cSStella Laurenzo     }
363436c6c9cSStella Laurenzo 
364436c6c9cSStella Laurenzo     static void bind(py::module &m) {
365f05ff4f7SStella Laurenzo       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
366f05ff4f7SStella Laurenzo                                            py::module_local())
367436c6c9cSStella Laurenzo           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
368436c6c9cSStella Laurenzo           .def("__next__", &PyArrayAttributeIterator::dunderNext);
369436c6c9cSStella Laurenzo     }
370436c6c9cSStella Laurenzo 
371436c6c9cSStella Laurenzo   private:
372436c6c9cSStella Laurenzo     PyAttribute attr;
373436c6c9cSStella Laurenzo     int nextIndex = 0;
374436c6c9cSStella Laurenzo   };
375436c6c9cSStella Laurenzo 
376974c1596SRahul Kayaith   MlirAttribute getItem(intptr_t i) {
377974c1596SRahul Kayaith     return mlirArrayAttrGetElement(*this, i);
378ed9e52f3SAlex Zinenko   }
379ed9e52f3SAlex Zinenko 
380436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
381436c6c9cSStella Laurenzo     c.def_static(
382436c6c9cSStella Laurenzo         "get",
383436c6c9cSStella Laurenzo         [](py::list attributes, DefaultingPyMlirContext context) {
384436c6c9cSStella Laurenzo           SmallVector<MlirAttribute> mlirAttributes;
385436c6c9cSStella Laurenzo           mlirAttributes.reserve(py::len(attributes));
386436c6c9cSStella Laurenzo           for (auto attribute : attributes) {
387ed9e52f3SAlex Zinenko             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
388436c6c9cSStella Laurenzo           }
389436c6c9cSStella Laurenzo           MlirAttribute attr = mlirArrayAttrGet(
390436c6c9cSStella Laurenzo               context->get(), mlirAttributes.size(), mlirAttributes.data());
391436c6c9cSStella Laurenzo           return PyArrayAttribute(context->getRef(), attr);
392436c6c9cSStella Laurenzo         },
393436c6c9cSStella Laurenzo         py::arg("attributes"), py::arg("context") = py::none(),
394436c6c9cSStella Laurenzo         "Gets a uniqued Array attribute");
395436c6c9cSStella Laurenzo     c.def("__getitem__",
396436c6c9cSStella Laurenzo           [](PyArrayAttribute &arr, intptr_t i) {
397436c6c9cSStella Laurenzo             if (i >= mlirArrayAttrGetNumElements(arr))
398436c6c9cSStella Laurenzo               throw py::index_error("ArrayAttribute index out of range");
399ed9e52f3SAlex Zinenko             return arr.getItem(i);
400436c6c9cSStella Laurenzo           })
401436c6c9cSStella Laurenzo         .def("__len__",
402436c6c9cSStella Laurenzo              [](const PyArrayAttribute &arr) {
403436c6c9cSStella Laurenzo                return mlirArrayAttrGetNumElements(arr);
404436c6c9cSStella Laurenzo              })
405436c6c9cSStella Laurenzo         .def("__iter__", [](const PyArrayAttribute &arr) {
406436c6c9cSStella Laurenzo           return PyArrayAttributeIterator(arr);
407436c6c9cSStella Laurenzo         });
408ed9e52f3SAlex Zinenko     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
409ed9e52f3SAlex Zinenko       std::vector<MlirAttribute> attributes;
410ed9e52f3SAlex Zinenko       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
411ed9e52f3SAlex Zinenko       attributes.reserve(numOldElements + py::len(extras));
412ed9e52f3SAlex Zinenko       for (intptr_t i = 0; i < numOldElements; ++i)
413ed9e52f3SAlex Zinenko         attributes.push_back(arr.getItem(i));
414ed9e52f3SAlex Zinenko       for (py::handle attr : extras)
415ed9e52f3SAlex Zinenko         attributes.push_back(pyTryCast<PyAttribute>(attr));
416ed9e52f3SAlex Zinenko       MlirAttribute arrayAttr = mlirArrayAttrGet(
417ed9e52f3SAlex Zinenko           arr.getContext()->get(), attributes.size(), attributes.data());
418ed9e52f3SAlex Zinenko       return PyArrayAttribute(arr.getContext(), arrayAttr);
419ed9e52f3SAlex Zinenko     });
420436c6c9cSStella Laurenzo   }
421436c6c9cSStella Laurenzo };
422436c6c9cSStella Laurenzo 
423436c6c9cSStella Laurenzo /// Float Point Attribute subclass - FloatAttr.
424436c6c9cSStella Laurenzo class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
425436c6c9cSStella Laurenzo public:
426436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
427436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FloatAttr";
428436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
4299566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
4309566ee28Smax       mlirFloatAttrGetTypeID;
431436c6c9cSStella Laurenzo 
432436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
433436c6c9cSStella Laurenzo     c.def_static(
434436c6c9cSStella Laurenzo         "get",
435436c6c9cSStella Laurenzo         [](PyType &type, double value, DefaultingPyLocation loc) {
4363ea4c501SRahul Kayaith           PyMlirContext::ErrorCapture errors(loc->getContext());
437436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
4383ea4c501SRahul Kayaith           if (mlirAttributeIsNull(attr))
4393ea4c501SRahul Kayaith             throw MLIRError("Invalid attribute", errors.take());
440436c6c9cSStella Laurenzo           return PyFloatAttribute(type.getContext(), attr);
441436c6c9cSStella Laurenzo         },
442436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
443436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a type");
444436c6c9cSStella Laurenzo     c.def_static(
445436c6c9cSStella Laurenzo         "get_f32",
446436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
447436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
448436c6c9cSStella Laurenzo               context->get(), mlirF32TypeGet(context->get()), value);
449436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
450436c6c9cSStella Laurenzo         },
451436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
452436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f32 type");
453436c6c9cSStella Laurenzo     c.def_static(
454436c6c9cSStella Laurenzo         "get_f64",
455436c6c9cSStella Laurenzo         [](double value, DefaultingPyMlirContext context) {
456436c6c9cSStella Laurenzo           MlirAttribute attr = mlirFloatAttrDoubleGet(
457436c6c9cSStella Laurenzo               context->get(), mlirF64TypeGet(context->get()), value);
458436c6c9cSStella Laurenzo           return PyFloatAttribute(context->getRef(), attr);
459436c6c9cSStella Laurenzo         },
460436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
461436c6c9cSStella Laurenzo         "Gets an uniqued float point attribute associated to a f64 type");
4622a5d4974SIngo Müller     c.def_property_readonly("value", mlirFloatAttrGetValueDouble,
4632a5d4974SIngo Müller                             "Returns the value of the float attribute");
4642a5d4974SIngo Müller     c.def("__float__", mlirFloatAttrGetValueDouble,
4652a5d4974SIngo Müller           "Converts the value of the float attribute to a Python float");
466436c6c9cSStella Laurenzo   }
467436c6c9cSStella Laurenzo };
468436c6c9cSStella Laurenzo 
469436c6c9cSStella Laurenzo /// Integer Attribute subclass - IntegerAttr.
470436c6c9cSStella Laurenzo class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
471436c6c9cSStella Laurenzo public:
472436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
473436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "IntegerAttr";
474436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
475436c6c9cSStella Laurenzo 
476436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
477436c6c9cSStella Laurenzo     c.def_static(
478436c6c9cSStella Laurenzo         "get",
479436c6c9cSStella Laurenzo         [](PyType &type, int64_t value) {
480436c6c9cSStella Laurenzo           MlirAttribute attr = mlirIntegerAttrGet(type, value);
481436c6c9cSStella Laurenzo           return PyIntegerAttribute(type.getContext(), attr);
482436c6c9cSStella Laurenzo         },
483436c6c9cSStella Laurenzo         py::arg("type"), py::arg("value"),
484436c6c9cSStella Laurenzo         "Gets an uniqued integer attribute associated to a type");
4852a5d4974SIngo Müller     c.def_property_readonly("value", toPyInt,
4862a5d4974SIngo Müller                             "Returns the value of the integer attribute");
4872a5d4974SIngo Müller     c.def("__int__", toPyInt,
4882a5d4974SIngo Müller           "Converts the value of the integer attribute to a Python int");
4892a5d4974SIngo Müller     c.def_property_readonly_static("static_typeid",
4902a5d4974SIngo Müller                                    [](py::object & /*class*/) -> MlirTypeID {
4912a5d4974SIngo Müller                                      return mlirIntegerAttrGetTypeID();
4922a5d4974SIngo Müller                                    });
4932a5d4974SIngo Müller   }
4942a5d4974SIngo Müller 
4952a5d4974SIngo Müller private:
4962a5d4974SIngo Müller   static py::int_ toPyInt(PyIntegerAttribute &self) {
497e9db306dSrkayaith     MlirType type = mlirAttributeGetType(self);
498e9db306dSrkayaith     if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
499436c6c9cSStella Laurenzo       return mlirIntegerAttrGetValueInt(self);
500e9db306dSrkayaith     if (mlirIntegerTypeIsSigned(type))
501e9db306dSrkayaith       return mlirIntegerAttrGetValueSInt(self);
502e9db306dSrkayaith     return mlirIntegerAttrGetValueUInt(self);
503436c6c9cSStella Laurenzo   }
504436c6c9cSStella Laurenzo };
505436c6c9cSStella Laurenzo 
506436c6c9cSStella Laurenzo /// Bool Attribute subclass - BoolAttr.
507436c6c9cSStella Laurenzo class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
508436c6c9cSStella Laurenzo public:
509436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
510436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BoolAttr";
511436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
512436c6c9cSStella Laurenzo 
513436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
514436c6c9cSStella Laurenzo     c.def_static(
515436c6c9cSStella Laurenzo         "get",
516436c6c9cSStella Laurenzo         [](bool value, DefaultingPyMlirContext context) {
517436c6c9cSStella Laurenzo           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
518436c6c9cSStella Laurenzo           return PyBoolAttribute(context->getRef(), attr);
519436c6c9cSStella Laurenzo         },
520436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
521436c6c9cSStella Laurenzo         "Gets an uniqued bool attribute");
5222a5d4974SIngo Müller     c.def_property_readonly("value", mlirBoolAttrGetValue,
523436c6c9cSStella Laurenzo                             "Returns the value of the bool attribute");
5242a5d4974SIngo Müller     c.def("__bool__", mlirBoolAttrGetValue,
5252a5d4974SIngo Müller           "Converts the value of the bool attribute to a Python bool");
526436c6c9cSStella Laurenzo   }
527436c6c9cSStella Laurenzo };
528436c6c9cSStella Laurenzo 
5294eee9ef9Smax class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
5304eee9ef9Smax public:
5314eee9ef9Smax   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
5324eee9ef9Smax   static constexpr const char *pyClassName = "SymbolRefAttr";
5334eee9ef9Smax   using PyConcreteAttribute::PyConcreteAttribute;
5344eee9ef9Smax 
5354eee9ef9Smax   static MlirAttribute fromList(const std::vector<std::string> &symbols,
5364eee9ef9Smax                                 PyMlirContext &context) {
5374eee9ef9Smax     if (symbols.empty())
5384eee9ef9Smax       throw std::runtime_error("SymbolRefAttr must be composed of at least "
5394eee9ef9Smax                                "one symbol.");
5404eee9ef9Smax     MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
5414eee9ef9Smax     SmallVector<MlirAttribute, 3> referenceAttrs;
5424eee9ef9Smax     for (size_t i = 1; i < symbols.size(); ++i) {
5434eee9ef9Smax       referenceAttrs.push_back(
5444eee9ef9Smax           mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
5454eee9ef9Smax     }
5464eee9ef9Smax     return mlirSymbolRefAttrGet(context.get(), rootSymbol,
5474eee9ef9Smax                                 referenceAttrs.size(), referenceAttrs.data());
5484eee9ef9Smax   }
5494eee9ef9Smax 
5504eee9ef9Smax   static void bindDerived(ClassTy &c) {
5514eee9ef9Smax     c.def_static(
5524eee9ef9Smax         "get",
5534eee9ef9Smax         [](const std::vector<std::string> &symbols,
5544eee9ef9Smax            DefaultingPyMlirContext context) {
5554eee9ef9Smax           return PySymbolRefAttribute::fromList(symbols, context.resolve());
5564eee9ef9Smax         },
5574eee9ef9Smax         py::arg("symbols"), py::arg("context") = py::none(),
5584eee9ef9Smax         "Gets a uniqued SymbolRef attribute from a list of symbol names");
5594eee9ef9Smax     c.def_property_readonly(
5604eee9ef9Smax         "value",
5614eee9ef9Smax         [](PySymbolRefAttribute &self) {
5624eee9ef9Smax           std::vector<std::string> symbols = {
5634eee9ef9Smax               unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
5644eee9ef9Smax           for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
5654eee9ef9Smax                ++i)
5664eee9ef9Smax             symbols.push_back(
5674eee9ef9Smax                 unwrap(mlirSymbolRefAttrGetRootReference(
5684eee9ef9Smax                            mlirSymbolRefAttrGetNestedReference(self, i)))
5694eee9ef9Smax                     .str());
5704eee9ef9Smax           return symbols;
5714eee9ef9Smax         },
5724eee9ef9Smax         "Returns the value of the SymbolRef attribute as a list[str]");
5734eee9ef9Smax   }
5744eee9ef9Smax };
5754eee9ef9Smax 
576436c6c9cSStella Laurenzo class PyFlatSymbolRefAttribute
577436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
578436c6c9cSStella Laurenzo public:
579436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
580436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
581436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
582436c6c9cSStella Laurenzo 
583436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
584436c6c9cSStella Laurenzo     c.def_static(
585436c6c9cSStella Laurenzo         "get",
586436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
587436c6c9cSStella Laurenzo           MlirAttribute attr =
588436c6c9cSStella Laurenzo               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
589436c6c9cSStella Laurenzo           return PyFlatSymbolRefAttribute(context->getRef(), attr);
590436c6c9cSStella Laurenzo         },
591436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
592436c6c9cSStella Laurenzo         "Gets a uniqued FlatSymbolRef attribute");
593436c6c9cSStella Laurenzo     c.def_property_readonly(
594436c6c9cSStella Laurenzo         "value",
595436c6c9cSStella Laurenzo         [](PyFlatSymbolRefAttribute &self) {
596436c6c9cSStella Laurenzo           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
597436c6c9cSStella Laurenzo           return py::str(stringRef.data, stringRef.length);
598436c6c9cSStella Laurenzo         },
599436c6c9cSStella Laurenzo         "Returns the value of the FlatSymbolRef attribute as a string");
600436c6c9cSStella Laurenzo   }
601436c6c9cSStella Laurenzo };
602436c6c9cSStella Laurenzo 
6035c3861b2SYun Long class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
6045c3861b2SYun Long public:
6055c3861b2SYun Long   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
6065c3861b2SYun Long   static constexpr const char *pyClassName = "OpaqueAttr";
6075c3861b2SYun Long   using PyConcreteAttribute::PyConcreteAttribute;
6089566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
6099566ee28Smax       mlirOpaqueAttrGetTypeID;
6105c3861b2SYun Long 
6115c3861b2SYun Long   static void bindDerived(ClassTy &c) {
6125c3861b2SYun Long     c.def_static(
6135c3861b2SYun Long         "get",
6145c3861b2SYun Long         [](std::string dialectNamespace, py::buffer buffer, PyType &type,
6155c3861b2SYun Long            DefaultingPyMlirContext context) {
6165c3861b2SYun Long           const py::buffer_info bufferInfo = buffer.request();
6175c3861b2SYun Long           intptr_t bufferSize = bufferInfo.size;
6185c3861b2SYun Long           MlirAttribute attr = mlirOpaqueAttrGet(
6195c3861b2SYun Long               context->get(), toMlirStringRef(dialectNamespace), bufferSize,
6205c3861b2SYun Long               static_cast<char *>(bufferInfo.ptr), type);
6215c3861b2SYun Long           return PyOpaqueAttribute(context->getRef(), attr);
6225c3861b2SYun Long         },
6235c3861b2SYun Long         py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
6245c3861b2SYun Long         py::arg("context") = py::none(), "Gets an Opaque attribute.");
6255c3861b2SYun Long     c.def_property_readonly(
6265c3861b2SYun Long         "dialect_namespace",
6275c3861b2SYun Long         [](PyOpaqueAttribute &self) {
6285c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
6295c3861b2SYun Long           return py::str(stringRef.data, stringRef.length);
6305c3861b2SYun Long         },
6315c3861b2SYun Long         "Returns the dialect namespace for the Opaque attribute as a string");
6325c3861b2SYun Long     c.def_property_readonly(
6335c3861b2SYun Long         "data",
6345c3861b2SYun Long         [](PyOpaqueAttribute &self) {
6355c3861b2SYun Long           MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
63662bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
6375c3861b2SYun Long         },
63862bf6c2eSChris Jones         "Returns the data for the Opaqued attributes as `bytes`");
6395c3861b2SYun Long   }
6405c3861b2SYun Long };
6415c3861b2SYun Long 
642436c6c9cSStella Laurenzo class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
643436c6c9cSStella Laurenzo public:
644436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
645436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "StringAttr";
646436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
6479566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
6489566ee28Smax       mlirStringAttrGetTypeID;
649436c6c9cSStella Laurenzo 
650436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
651436c6c9cSStella Laurenzo     c.def_static(
652436c6c9cSStella Laurenzo         "get",
653436c6c9cSStella Laurenzo         [](std::string value, DefaultingPyMlirContext context) {
654436c6c9cSStella Laurenzo           MlirAttribute attr =
655436c6c9cSStella Laurenzo               mlirStringAttrGet(context->get(), toMlirStringRef(value));
656436c6c9cSStella Laurenzo           return PyStringAttribute(context->getRef(), attr);
657436c6c9cSStella Laurenzo         },
658436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
659436c6c9cSStella Laurenzo         "Gets a uniqued string attribute");
660436c6c9cSStella Laurenzo     c.def_static(
661436c6c9cSStella Laurenzo         "get_typed",
662436c6c9cSStella Laurenzo         [](PyType &type, std::string value) {
663436c6c9cSStella Laurenzo           MlirAttribute attr =
664436c6c9cSStella Laurenzo               mlirStringAttrTypedGet(type, toMlirStringRef(value));
665436c6c9cSStella Laurenzo           return PyStringAttribute(type.getContext(), attr);
666436c6c9cSStella Laurenzo         },
667a6e7d024SStella Laurenzo         py::arg("type"), py::arg("value"),
668436c6c9cSStella Laurenzo         "Gets a uniqued string attribute associated to a type");
6699f533548SIngo Müller     c.def_property_readonly(
6709f533548SIngo Müller         "value",
6719f533548SIngo Müller         [](PyStringAttribute &self) {
6729f533548SIngo Müller           MlirStringRef stringRef = mlirStringAttrGetValue(self);
6739f533548SIngo Müller           return py::str(stringRef.data, stringRef.length);
6749f533548SIngo Müller         },
675436c6c9cSStella Laurenzo         "Returns the value of the string attribute");
67662bf6c2eSChris Jones     c.def_property_readonly(
67762bf6c2eSChris Jones         "value_bytes",
67862bf6c2eSChris Jones         [](PyStringAttribute &self) {
67962bf6c2eSChris Jones           MlirStringRef stringRef = mlirStringAttrGetValue(self);
68062bf6c2eSChris Jones           return py::bytes(stringRef.data, stringRef.length);
68162bf6c2eSChris Jones         },
68262bf6c2eSChris Jones         "Returns the value of the string attribute as `bytes`");
683436c6c9cSStella Laurenzo   }
684436c6c9cSStella Laurenzo };
685436c6c9cSStella Laurenzo 
686436c6c9cSStella Laurenzo // TODO: Support construction of string elements.
687436c6c9cSStella Laurenzo class PyDenseElementsAttribute
688436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseElementsAttribute> {
689436c6c9cSStella Laurenzo public:
690436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
691436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseElementsAttr";
692436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
693436c6c9cSStella Laurenzo 
694436c6c9cSStella Laurenzo   static PyDenseElementsAttribute
695c912f0e7Spranavm-nvidia   getFromList(py::list attributes, std::optional<PyType> explicitType,
696c912f0e7Spranavm-nvidia               DefaultingPyMlirContext contextWrapper) {
697c912f0e7Spranavm-nvidia 
698c912f0e7Spranavm-nvidia     const size_t numAttributes = py::len(attributes);
699c912f0e7Spranavm-nvidia     if (numAttributes == 0)
700c912f0e7Spranavm-nvidia       throw py::value_error("Attributes list must be non-empty.");
701c912f0e7Spranavm-nvidia 
702c912f0e7Spranavm-nvidia     MlirType shapedType;
703c912f0e7Spranavm-nvidia     if (explicitType) {
704c912f0e7Spranavm-nvidia       if ((!mlirTypeIsAShaped(*explicitType) ||
705c912f0e7Spranavm-nvidia            !mlirShapedTypeHasStaticShape(*explicitType))) {
706c912f0e7Spranavm-nvidia 
707c912f0e7Spranavm-nvidia         std::string message;
708c912f0e7Spranavm-nvidia         llvm::raw_string_ostream os(message);
709c912f0e7Spranavm-nvidia         os << "Expected a static ShapedType for the shaped_type parameter: "
710c912f0e7Spranavm-nvidia            << py::repr(py::cast(*explicitType));
711095b41c6SJOE1994         throw py::value_error(message);
712c912f0e7Spranavm-nvidia       }
713c912f0e7Spranavm-nvidia       shapedType = *explicitType;
714c912f0e7Spranavm-nvidia     } else {
715c912f0e7Spranavm-nvidia       SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
716c912f0e7Spranavm-nvidia       shapedType = mlirRankedTensorTypeGet(
717c912f0e7Spranavm-nvidia           shape.size(), shape.data(),
718c912f0e7Spranavm-nvidia           mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
719c912f0e7Spranavm-nvidia           mlirAttributeGetNull());
720c912f0e7Spranavm-nvidia     }
721c912f0e7Spranavm-nvidia 
722c912f0e7Spranavm-nvidia     SmallVector<MlirAttribute> mlirAttributes;
723c912f0e7Spranavm-nvidia     mlirAttributes.reserve(numAttributes);
724c912f0e7Spranavm-nvidia     for (const py::handle &attribute : attributes) {
725c912f0e7Spranavm-nvidia       MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
726c912f0e7Spranavm-nvidia       MlirType attrType = mlirAttributeGetType(mlirAttribute);
727c912f0e7Spranavm-nvidia       mlirAttributes.push_back(mlirAttribute);
728c912f0e7Spranavm-nvidia 
729c912f0e7Spranavm-nvidia       if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
730c912f0e7Spranavm-nvidia         std::string message;
731c912f0e7Spranavm-nvidia         llvm::raw_string_ostream os(message);
732c912f0e7Spranavm-nvidia         os << "All attributes must be of the same type and match "
733c912f0e7Spranavm-nvidia            << "the type parameter: expected=" << py::repr(py::cast(shapedType))
734c912f0e7Spranavm-nvidia            << ", but got=" << py::repr(py::cast(attrType));
735095b41c6SJOE1994         throw py::value_error(message);
736c912f0e7Spranavm-nvidia       }
737c912f0e7Spranavm-nvidia     }
738c912f0e7Spranavm-nvidia 
739c912f0e7Spranavm-nvidia     MlirAttribute elements = mlirDenseElementsAttrGet(
740c912f0e7Spranavm-nvidia         shapedType, mlirAttributes.size(), mlirAttributes.data());
741c912f0e7Spranavm-nvidia 
742c912f0e7Spranavm-nvidia     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
743c912f0e7Spranavm-nvidia   }
744c912f0e7Spranavm-nvidia 
745c912f0e7Spranavm-nvidia   static PyDenseElementsAttribute
7460a81ace0SKazu Hirata   getFromBuffer(py::buffer array, bool signless,
7470a81ace0SKazu Hirata                 std::optional<PyType> explicitType,
7480a81ace0SKazu Hirata                 std::optional<std::vector<int64_t>> explicitShape,
749436c6c9cSStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
750436c6c9cSStella Laurenzo     // Request a contiguous view. In exotic cases, this will cause a copy.
75171a25454SPeter Hawkins     int flags = PyBUF_ND;
75271a25454SPeter Hawkins     if (!explicitType) {
75371a25454SPeter Hawkins       flags |= PyBUF_FORMAT;
75471a25454SPeter Hawkins     }
75571a25454SPeter Hawkins     Py_buffer view;
75671a25454SPeter Hawkins     if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
757436c6c9cSStella Laurenzo       throw py::error_already_set();
758436c6c9cSStella Laurenzo     }
75971a25454SPeter Hawkins     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
760*0a68171bSDmitri Gribenko     SmallVector<int64_t> shape;
761*0a68171bSDmitri Gribenko     if (explicitShape) {
762*0a68171bSDmitri Gribenko       shape.append(explicitShape->begin(), explicitShape->end());
763*0a68171bSDmitri Gribenko     } else {
764*0a68171bSDmitri Gribenko       shape.append(view.shape, view.shape + view.ndim);
765*0a68171bSDmitri Gribenko     }
766436c6c9cSStella Laurenzo 
767*0a68171bSDmitri Gribenko     MlirAttribute encodingAttr = mlirAttributeGetNull();
768436c6c9cSStella Laurenzo     MlirContext context = contextWrapper->get();
769*0a68171bSDmitri Gribenko 
770*0a68171bSDmitri Gribenko     // Detect format codes that are suitable for bulk loading. This includes
771*0a68171bSDmitri Gribenko     // all byte aligned integer and floating point types up to 8 bytes.
772*0a68171bSDmitri Gribenko     // Notably, this excludes, bool (which needs to be bit-packed) and
773*0a68171bSDmitri Gribenko     // other exotics which do not have a direct representation in the buffer
774*0a68171bSDmitri Gribenko     // protocol (i.e. complex, etc).
775*0a68171bSDmitri Gribenko     std::optional<MlirType> bulkLoadElementType;
776*0a68171bSDmitri Gribenko     if (explicitType) {
777*0a68171bSDmitri Gribenko       bulkLoadElementType = *explicitType;
778*0a68171bSDmitri Gribenko     } else {
779*0a68171bSDmitri Gribenko       std::string_view format(view.format);
780*0a68171bSDmitri Gribenko       if (format == "f") {
781*0a68171bSDmitri Gribenko         // f32
782*0a68171bSDmitri Gribenko         assert(view.itemsize == 4 && "mismatched array itemsize");
783*0a68171bSDmitri Gribenko         bulkLoadElementType = mlirF32TypeGet(context);
784*0a68171bSDmitri Gribenko       } else if (format == "d") {
785*0a68171bSDmitri Gribenko         // f64
786*0a68171bSDmitri Gribenko         assert(view.itemsize == 8 && "mismatched array itemsize");
787*0a68171bSDmitri Gribenko         bulkLoadElementType = mlirF64TypeGet(context);
788*0a68171bSDmitri Gribenko       } else if (format == "e") {
789*0a68171bSDmitri Gribenko         // f16
790*0a68171bSDmitri Gribenko         assert(view.itemsize == 2 && "mismatched array itemsize");
791*0a68171bSDmitri Gribenko         bulkLoadElementType = mlirF16TypeGet(context);
792*0a68171bSDmitri Gribenko       } else if (isSignedIntegerFormat(format)) {
793*0a68171bSDmitri Gribenko         if (view.itemsize == 4) {
794*0a68171bSDmitri Gribenko           // i32
795*0a68171bSDmitri Gribenko           bulkLoadElementType = signless
796*0a68171bSDmitri Gribenko                                     ? mlirIntegerTypeGet(context, 32)
797*0a68171bSDmitri Gribenko                                     : mlirIntegerTypeSignedGet(context, 32);
798*0a68171bSDmitri Gribenko         } else if (view.itemsize == 8) {
799*0a68171bSDmitri Gribenko           // i64
800*0a68171bSDmitri Gribenko           bulkLoadElementType = signless
801*0a68171bSDmitri Gribenko                                     ? mlirIntegerTypeGet(context, 64)
802*0a68171bSDmitri Gribenko                                     : mlirIntegerTypeSignedGet(context, 64);
803*0a68171bSDmitri Gribenko         } else if (view.itemsize == 1) {
804*0a68171bSDmitri Gribenko           // i8
805*0a68171bSDmitri Gribenko           bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
806*0a68171bSDmitri Gribenko                                          : mlirIntegerTypeSignedGet(context, 8);
807*0a68171bSDmitri Gribenko         } else if (view.itemsize == 2) {
808*0a68171bSDmitri Gribenko           // i16
809*0a68171bSDmitri Gribenko           bulkLoadElementType = signless
810*0a68171bSDmitri Gribenko                                     ? mlirIntegerTypeGet(context, 16)
811*0a68171bSDmitri Gribenko                                     : mlirIntegerTypeSignedGet(context, 16);
812*0a68171bSDmitri Gribenko         }
813*0a68171bSDmitri Gribenko       } else if (isUnsignedIntegerFormat(format)) {
814*0a68171bSDmitri Gribenko         if (view.itemsize == 4) {
815*0a68171bSDmitri Gribenko           // unsigned i32
816*0a68171bSDmitri Gribenko           bulkLoadElementType = signless
817*0a68171bSDmitri Gribenko                                     ? mlirIntegerTypeGet(context, 32)
818*0a68171bSDmitri Gribenko                                     : mlirIntegerTypeUnsignedGet(context, 32);
819*0a68171bSDmitri Gribenko         } else if (view.itemsize == 8) {
820*0a68171bSDmitri Gribenko           // unsigned i64
821*0a68171bSDmitri Gribenko           bulkLoadElementType = signless
822*0a68171bSDmitri Gribenko                                     ? mlirIntegerTypeGet(context, 64)
823*0a68171bSDmitri Gribenko                                     : mlirIntegerTypeUnsignedGet(context, 64);
824*0a68171bSDmitri Gribenko         } else if (view.itemsize == 1) {
825*0a68171bSDmitri Gribenko           // i8
826*0a68171bSDmitri Gribenko           bulkLoadElementType = signless
827*0a68171bSDmitri Gribenko                                     ? mlirIntegerTypeGet(context, 8)
828*0a68171bSDmitri Gribenko                                     : mlirIntegerTypeUnsignedGet(context, 8);
829*0a68171bSDmitri Gribenko         } else if (view.itemsize == 2) {
830*0a68171bSDmitri Gribenko           // i16
831*0a68171bSDmitri Gribenko           bulkLoadElementType = signless
832*0a68171bSDmitri Gribenko                                     ? mlirIntegerTypeGet(context, 16)
833*0a68171bSDmitri Gribenko                                     : mlirIntegerTypeUnsignedGet(context, 16);
834*0a68171bSDmitri Gribenko         }
835*0a68171bSDmitri Gribenko       }
836*0a68171bSDmitri Gribenko       if (!bulkLoadElementType) {
837*0a68171bSDmitri Gribenko         throw std::invalid_argument(
838*0a68171bSDmitri Gribenko             std::string("unimplemented array format conversion from format: ") +
839*0a68171bSDmitri Gribenko             std::string(format));
840*0a68171bSDmitri Gribenko       }
841*0a68171bSDmitri Gribenko     }
842*0a68171bSDmitri Gribenko 
843*0a68171bSDmitri Gribenko     MlirType shapedType;
844*0a68171bSDmitri Gribenko     if (mlirTypeIsAShaped(*bulkLoadElementType)) {
845*0a68171bSDmitri Gribenko       if (explicitShape) {
846*0a68171bSDmitri Gribenko         throw std::invalid_argument("Shape can only be specified explicitly "
847*0a68171bSDmitri Gribenko                                     "when the type is not a shaped type.");
848*0a68171bSDmitri Gribenko       }
849*0a68171bSDmitri Gribenko       shapedType = *bulkLoadElementType;
850*0a68171bSDmitri Gribenko     } else {
851*0a68171bSDmitri Gribenko       shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
852*0a68171bSDmitri Gribenko                                            *bulkLoadElementType, encodingAttr);
853*0a68171bSDmitri Gribenko     }
854*0a68171bSDmitri Gribenko     size_t rawBufferSize = view.len;
855*0a68171bSDmitri Gribenko     MlirAttribute attr =
856*0a68171bSDmitri Gribenko         mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
8575d6d30edSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
8585d6d30edSStella Laurenzo       throw std::invalid_argument(
8595d6d30edSStella Laurenzo           "DenseElementsAttr could not be constructed from the given buffer. "
8605d6d30edSStella Laurenzo           "This may mean that the Python buffer layout does not match that "
8615d6d30edSStella Laurenzo           "MLIR expected layout and is a bug.");
8625d6d30edSStella Laurenzo     }
8635d6d30edSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
8645d6d30edSStella Laurenzo   }
865436c6c9cSStella Laurenzo 
8661fc096afSMehdi Amini   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
867436c6c9cSStella Laurenzo                                            PyAttribute &elementAttr) {
868436c6c9cSStella Laurenzo     auto contextWrapper =
869436c6c9cSStella Laurenzo         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
870436c6c9cSStella Laurenzo     if (!mlirAttributeIsAInteger(elementAttr) &&
871436c6c9cSStella Laurenzo         !mlirAttributeIsAFloat(elementAttr)) {
872436c6c9cSStella Laurenzo       std::string message = "Illegal element type for DenseElementsAttr: ";
873436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
8744811270bSmax       throw py::value_error(message);
875436c6c9cSStella Laurenzo     }
876436c6c9cSStella Laurenzo     if (!mlirTypeIsAShaped(shapedType) ||
877436c6c9cSStella Laurenzo         !mlirShapedTypeHasStaticShape(shapedType)) {
878436c6c9cSStella Laurenzo       std::string message =
879436c6c9cSStella Laurenzo           "Expected a static ShapedType for the shaped_type parameter: ";
880436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
8814811270bSmax       throw py::value_error(message);
882436c6c9cSStella Laurenzo     }
883436c6c9cSStella Laurenzo     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
884436c6c9cSStella Laurenzo     MlirType attrType = mlirAttributeGetType(elementAttr);
885436c6c9cSStella Laurenzo     if (!mlirTypeEqual(shapedElementType, attrType)) {
886436c6c9cSStella Laurenzo       std::string message =
887436c6c9cSStella Laurenzo           "Shaped element type and attribute type must be equal: shaped=";
888436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(shapedType)));
889436c6c9cSStella Laurenzo       message.append(", element=");
890436c6c9cSStella Laurenzo       message.append(py::repr(py::cast(elementAttr)));
8914811270bSmax       throw py::value_error(message);
892436c6c9cSStella Laurenzo     }
893436c6c9cSStella Laurenzo 
894436c6c9cSStella Laurenzo     MlirAttribute elements =
895436c6c9cSStella Laurenzo         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
896436c6c9cSStella Laurenzo     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
897436c6c9cSStella Laurenzo   }
898436c6c9cSStella Laurenzo 
899436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
900436c6c9cSStella Laurenzo 
901436c6c9cSStella Laurenzo   py::buffer_info accessBuffer() {
902436c6c9cSStella Laurenzo     MlirType shapedType = mlirAttributeGetType(*this);
903436c6c9cSStella Laurenzo     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
9045d6d30edSStella Laurenzo     std::string format;
905436c6c9cSStella Laurenzo 
906436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(elementType)) {
907436c6c9cSStella Laurenzo       // f32
9085d6d30edSStella Laurenzo       return bufferInfo<float>(shapedType);
90902b6fb21SMehdi Amini     }
91002b6fb21SMehdi Amini     if (mlirTypeIsAF64(elementType)) {
911436c6c9cSStella Laurenzo       // f64
9125d6d30edSStella Laurenzo       return bufferInfo<double>(shapedType);
913bb56c2b3SMehdi Amini     }
914bb56c2b3SMehdi Amini     if (mlirTypeIsAF16(elementType)) {
9155d6d30edSStella Laurenzo       // f16
9165d6d30edSStella Laurenzo       return bufferInfo<uint16_t>(shapedType, "e");
917bb56c2b3SMehdi Amini     }
918ef1b735dSmax     if (mlirTypeIsAIndex(elementType)) {
919ef1b735dSmax       // Same as IndexType::kInternalStorageBitWidth
920ef1b735dSmax       return bufferInfo<int64_t>(shapedType);
921ef1b735dSmax     }
922bb56c2b3SMehdi Amini     if (mlirTypeIsAInteger(elementType) &&
923436c6c9cSStella Laurenzo         mlirIntegerTypeGetWidth(elementType) == 32) {
924436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
925436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
926436c6c9cSStella Laurenzo         // i32
9275d6d30edSStella Laurenzo         return bufferInfo<int32_t>(shapedType);
928e5639b3fSMehdi Amini       }
929e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
930436c6c9cSStella Laurenzo         // unsigned i32
9315d6d30edSStella Laurenzo         return bufferInfo<uint32_t>(shapedType);
932436c6c9cSStella Laurenzo       }
933436c6c9cSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
934436c6c9cSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 64) {
935436c6c9cSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
936436c6c9cSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
937436c6c9cSStella Laurenzo         // i64
9385d6d30edSStella Laurenzo         return bufferInfo<int64_t>(shapedType);
939e5639b3fSMehdi Amini       }
940e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
941436c6c9cSStella Laurenzo         // unsigned i64
9425d6d30edSStella Laurenzo         return bufferInfo<uint64_t>(shapedType);
9435d6d30edSStella Laurenzo       }
9445d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
9455d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 8) {
9465d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
9475d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
9485d6d30edSStella Laurenzo         // i8
9495d6d30edSStella Laurenzo         return bufferInfo<int8_t>(shapedType);
950e5639b3fSMehdi Amini       }
951e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
9525d6d30edSStella Laurenzo         // unsigned i8
9535d6d30edSStella Laurenzo         return bufferInfo<uint8_t>(shapedType);
9545d6d30edSStella Laurenzo       }
9555d6d30edSStella Laurenzo     } else if (mlirTypeIsAInteger(elementType) &&
9565d6d30edSStella Laurenzo                mlirIntegerTypeGetWidth(elementType) == 16) {
9575d6d30edSStella Laurenzo       if (mlirIntegerTypeIsSignless(elementType) ||
9585d6d30edSStella Laurenzo           mlirIntegerTypeIsSigned(elementType)) {
9595d6d30edSStella Laurenzo         // i16
9605d6d30edSStella Laurenzo         return bufferInfo<int16_t>(shapedType);
961e5639b3fSMehdi Amini       }
962e5639b3fSMehdi Amini       if (mlirIntegerTypeIsUnsigned(elementType)) {
9635d6d30edSStella Laurenzo         // unsigned i16
9645d6d30edSStella Laurenzo         return bufferInfo<uint16_t>(shapedType);
965436c6c9cSStella Laurenzo       }
966436c6c9cSStella Laurenzo     }
967436c6c9cSStella Laurenzo 
968c5f445d1SStella Laurenzo     // TODO: Currently crashes the program.
9695d6d30edSStella Laurenzo     // Reported as https://github.com/pybind/pybind11/issues/3336
970c5f445d1SStella Laurenzo     throw std::invalid_argument(
971c5f445d1SStella Laurenzo         "unsupported data type for conversion to Python buffer");
972436c6c9cSStella Laurenzo   }
973436c6c9cSStella Laurenzo 
974436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
975436c6c9cSStella Laurenzo     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
976436c6c9cSStella Laurenzo         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
977436c6c9cSStella Laurenzo                     py::arg("array"), py::arg("signless") = true,
9785d6d30edSStella Laurenzo                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
979436c6c9cSStella Laurenzo                     py::arg("context") = py::none(),
9805d6d30edSStella Laurenzo                     kDenseElementsAttrGetDocstring)
981c912f0e7Spranavm-nvidia         .def_static("get", PyDenseElementsAttribute::getFromList,
982c912f0e7Spranavm-nvidia                     py::arg("attrs"), py::arg("type") = py::none(),
983c912f0e7Spranavm-nvidia                     py::arg("context") = py::none(),
984c912f0e7Spranavm-nvidia                     kDenseElementsAttrGetFromListDocstring)
985436c6c9cSStella Laurenzo         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
986436c6c9cSStella Laurenzo                     py::arg("shaped_type"), py::arg("element_attr"),
987436c6c9cSStella Laurenzo                     "Gets a DenseElementsAttr where all values are the same")
988436c6c9cSStella Laurenzo         .def_property_readonly("is_splat",
989436c6c9cSStella Laurenzo                                [](PyDenseElementsAttribute &self) -> bool {
990436c6c9cSStella Laurenzo                                  return mlirDenseElementsAttrIsSplat(self);
991436c6c9cSStella Laurenzo                                })
99291259963SAdam Paszke         .def("get_splat_value",
993974c1596SRahul Kayaith              [](PyDenseElementsAttribute &self) {
994974c1596SRahul Kayaith                if (!mlirDenseElementsAttrIsSplat(self))
9954811270bSmax                  throw py::value_error(
99691259963SAdam Paszke                      "get_splat_value called on a non-splat attribute");
997974c1596SRahul Kayaith                return mlirDenseElementsAttrGetSplatValue(self);
99891259963SAdam Paszke              })
999436c6c9cSStella Laurenzo         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
1000436c6c9cSStella Laurenzo   }
1001436c6c9cSStella Laurenzo 
1002436c6c9cSStella Laurenzo private:
100371a25454SPeter Hawkins   static bool isUnsignedIntegerFormat(std::string_view format) {
1004436c6c9cSStella Laurenzo     if (format.empty())
1005436c6c9cSStella Laurenzo       return false;
1006436c6c9cSStella Laurenzo     char code = format[0];
1007436c6c9cSStella Laurenzo     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
1008436c6c9cSStella Laurenzo            code == 'Q';
1009436c6c9cSStella Laurenzo   }
1010436c6c9cSStella Laurenzo 
101171a25454SPeter Hawkins   static bool isSignedIntegerFormat(std::string_view format) {
1012436c6c9cSStella Laurenzo     if (format.empty())
1013436c6c9cSStella Laurenzo       return false;
1014436c6c9cSStella Laurenzo     char code = format[0];
1015436c6c9cSStella Laurenzo     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
1016436c6c9cSStella Laurenzo            code == 'q';
1017436c6c9cSStella Laurenzo   }
1018436c6c9cSStella Laurenzo 
1019436c6c9cSStella Laurenzo   template <typename Type>
1020436c6c9cSStella Laurenzo   py::buffer_info bufferInfo(MlirType shapedType,
10215d6d30edSStella Laurenzo                              const char *explicitFormat = nullptr) {
1022*0a68171bSDmitri Gribenko     intptr_t rank = mlirShapedTypeGetRank(shapedType);
1023436c6c9cSStella Laurenzo     // Prepare the data for the buffer_info.
1024*0a68171bSDmitri Gribenko     // Buffer is configured for read-only access below.
1025436c6c9cSStella Laurenzo     Type *data = static_cast<Type *>(
1026436c6c9cSStella Laurenzo         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1027436c6c9cSStella Laurenzo     // Prepare the shape for the buffer_info.
1028436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> shape;
1029436c6c9cSStella Laurenzo     for (intptr_t i = 0; i < rank; ++i)
1030436c6c9cSStella Laurenzo       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1031436c6c9cSStella Laurenzo     // Prepare the strides for the buffer_info.
1032436c6c9cSStella Laurenzo     SmallVector<intptr_t, 4> strides;
1033f0e847d0SRahul Kayaith     if (mlirDenseElementsAttrIsSplat(*this)) {
1034f0e847d0SRahul Kayaith       // Splats are special, only the single value is stored.
1035f0e847d0SRahul Kayaith       strides.assign(rank, 0);
1036f0e847d0SRahul Kayaith     } else {
1037436c6c9cSStella Laurenzo       for (intptr_t i = 1; i < rank; ++i) {
1038f0e847d0SRahul Kayaith         intptr_t strideFactor = 1;
1039f0e847d0SRahul Kayaith         for (intptr_t j = i; j < rank; ++j)
1040436c6c9cSStella Laurenzo           strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1041436c6c9cSStella Laurenzo         strides.push_back(sizeof(Type) * strideFactor);
1042436c6c9cSStella Laurenzo       }
1043436c6c9cSStella Laurenzo       strides.push_back(sizeof(Type));
1044f0e847d0SRahul Kayaith     }
10455d6d30edSStella Laurenzo     std::string format;
10465d6d30edSStella Laurenzo     if (explicitFormat) {
10475d6d30edSStella Laurenzo       format = explicitFormat;
10485d6d30edSStella Laurenzo     } else {
10495d6d30edSStella Laurenzo       format = py::format_descriptor<Type>::format();
10505d6d30edSStella Laurenzo     }
10515d6d30edSStella Laurenzo     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
10525d6d30edSStella Laurenzo                            /*readonly=*/true);
1053436c6c9cSStella Laurenzo   }
1054436c6c9cSStella Laurenzo }; // namespace
1055436c6c9cSStella Laurenzo 
1056436c6c9cSStella Laurenzo /// Refinement of the PyDenseElementsAttribute for attributes containing integer
1057436c6c9cSStella Laurenzo /// (and boolean) values. Supports element access.
1058436c6c9cSStella Laurenzo class PyDenseIntElementsAttribute
1059436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1060436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1061436c6c9cSStella Laurenzo public:
1062436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1063436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseIntElementsAttr";
1064436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1065436c6c9cSStella Laurenzo 
1066436c6c9cSStella Laurenzo   /// Returns the element at the given linear position. Asserts if the index is
1067436c6c9cSStella Laurenzo   /// out of range.
1068436c6c9cSStella Laurenzo   py::int_ dunderGetItem(intptr_t pos) {
1069436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
10704811270bSmax       throw py::index_error("attempt to access out of bounds element");
1071436c6c9cSStella Laurenzo     }
1072436c6c9cSStella Laurenzo 
1073436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1074436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1075436c6c9cSStella Laurenzo     assert(mlirTypeIsAInteger(type) &&
1076436c6c9cSStella Laurenzo            "expected integer element type in dense int elements attribute");
1077436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1078436c6c9cSStella Laurenzo     // elemental type of the attribute. py::int_ is implicitly constructible
1079436c6c9cSStella Laurenzo     // from any C++ integral type and handles bitwidth correctly.
1080436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1081436c6c9cSStella Laurenzo     // querying them on each element access.
1082436c6c9cSStella Laurenzo     unsigned width = mlirIntegerTypeGetWidth(type);
1083436c6c9cSStella Laurenzo     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1084436c6c9cSStella Laurenzo     if (isUnsigned) {
1085436c6c9cSStella Laurenzo       if (width == 1) {
1086436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
1087436c6c9cSStella Laurenzo       }
1088308d8b8cSRahul Kayaith       if (width == 8) {
1089308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt8Value(*this, pos);
1090308d8b8cSRahul Kayaith       }
1091308d8b8cSRahul Kayaith       if (width == 16) {
1092308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetUInt16Value(*this, pos);
1093308d8b8cSRahul Kayaith       }
1094436c6c9cSStella Laurenzo       if (width == 32) {
1095436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
1096436c6c9cSStella Laurenzo       }
1097436c6c9cSStella Laurenzo       if (width == 64) {
1098436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
1099436c6c9cSStella Laurenzo       }
1100436c6c9cSStella Laurenzo     } else {
1101436c6c9cSStella Laurenzo       if (width == 1) {
1102436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetBoolValue(*this, pos);
1103436c6c9cSStella Laurenzo       }
1104308d8b8cSRahul Kayaith       if (width == 8) {
1105308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt8Value(*this, pos);
1106308d8b8cSRahul Kayaith       }
1107308d8b8cSRahul Kayaith       if (width == 16) {
1108308d8b8cSRahul Kayaith         return mlirDenseElementsAttrGetInt16Value(*this, pos);
1109308d8b8cSRahul Kayaith       }
1110436c6c9cSStella Laurenzo       if (width == 32) {
1111436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt32Value(*this, pos);
1112436c6c9cSStella Laurenzo       }
1113436c6c9cSStella Laurenzo       if (width == 64) {
1114436c6c9cSStella Laurenzo         return mlirDenseElementsAttrGetInt64Value(*this, pos);
1115436c6c9cSStella Laurenzo       }
1116436c6c9cSStella Laurenzo     }
11174811270bSmax     throw py::type_error("Unsupported integer type");
1118436c6c9cSStella Laurenzo   }
1119436c6c9cSStella Laurenzo 
1120436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1121436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1122436c6c9cSStella Laurenzo   }
1123436c6c9cSStella Laurenzo };
1124436c6c9cSStella Laurenzo 
1125f66cd9e9SStella Laurenzo class PyDenseResourceElementsAttribute
1126f66cd9e9SStella Laurenzo     : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1127f66cd9e9SStella Laurenzo public:
1128f66cd9e9SStella Laurenzo   static constexpr IsAFunctionTy isaFunction =
1129f66cd9e9SStella Laurenzo       mlirAttributeIsADenseResourceElements;
1130f66cd9e9SStella Laurenzo   static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1131f66cd9e9SStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1132f66cd9e9SStella Laurenzo 
1133f66cd9e9SStella Laurenzo   static PyDenseResourceElementsAttribute
1134962bf002SMehdi Amini   getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type,
1135f66cd9e9SStella Laurenzo                 std::optional<size_t> alignment, bool isMutable,
1136f66cd9e9SStella Laurenzo                 DefaultingPyMlirContext contextWrapper) {
1137f66cd9e9SStella Laurenzo     if (!mlirTypeIsAShaped(type)) {
1138f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1139f66cd9e9SStella Laurenzo           "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1140f66cd9e9SStella Laurenzo     }
1141f66cd9e9SStella Laurenzo 
1142f66cd9e9SStella Laurenzo     // Do not request any conversions as we must ensure to use caller
1143f66cd9e9SStella Laurenzo     // managed memory.
1144f66cd9e9SStella Laurenzo     int flags = PyBUF_STRIDES;
1145f66cd9e9SStella Laurenzo     std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1146f66cd9e9SStella Laurenzo     if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1147f66cd9e9SStella Laurenzo       throw py::error_already_set();
1148f66cd9e9SStella Laurenzo     }
1149f66cd9e9SStella Laurenzo 
1150f66cd9e9SStella Laurenzo     // This scope releaser will only release if we haven't yet transferred
1151f66cd9e9SStella Laurenzo     // ownership.
1152f66cd9e9SStella Laurenzo     auto freeBuffer = llvm::make_scope_exit([&]() {
1153f66cd9e9SStella Laurenzo       if (view)
1154f66cd9e9SStella Laurenzo         PyBuffer_Release(view.get());
1155f66cd9e9SStella Laurenzo     });
1156f66cd9e9SStella Laurenzo 
1157f66cd9e9SStella Laurenzo     if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1158f66cd9e9SStella Laurenzo       throw std::invalid_argument("Contiguous buffer is required.");
1159f66cd9e9SStella Laurenzo     }
1160f66cd9e9SStella Laurenzo 
1161f66cd9e9SStella Laurenzo     // Infer alignment to be the stride of one element if not explicit.
1162f66cd9e9SStella Laurenzo     size_t inferredAlignment;
1163f66cd9e9SStella Laurenzo     if (alignment)
1164f66cd9e9SStella Laurenzo       inferredAlignment = *alignment;
1165f66cd9e9SStella Laurenzo     else
1166f66cd9e9SStella Laurenzo       inferredAlignment = view->strides[view->ndim - 1];
1167f66cd9e9SStella Laurenzo 
1168f66cd9e9SStella Laurenzo     // The userData is a Py_buffer* that the deleter owns.
1169f66cd9e9SStella Laurenzo     auto deleter = [](void *userData, const void *data, size_t size,
1170f66cd9e9SStella Laurenzo                       size_t align) {
1171f66cd9e9SStella Laurenzo       Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1172f66cd9e9SStella Laurenzo       PyBuffer_Release(ownedView);
1173f66cd9e9SStella Laurenzo       delete ownedView;
1174f66cd9e9SStella Laurenzo     };
1175f66cd9e9SStella Laurenzo 
1176f66cd9e9SStella Laurenzo     size_t rawBufferSize = view->len;
1177f66cd9e9SStella Laurenzo     MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1178f66cd9e9SStella Laurenzo         type, toMlirStringRef(name), view->buf, rawBufferSize,
1179f66cd9e9SStella Laurenzo         inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1180f66cd9e9SStella Laurenzo     if (mlirAttributeIsNull(attr)) {
1181f66cd9e9SStella Laurenzo       throw std::invalid_argument(
1182f66cd9e9SStella Laurenzo           "DenseResourceElementsAttr could not be constructed from the given "
1183f66cd9e9SStella Laurenzo           "buffer. "
1184f66cd9e9SStella Laurenzo           "This may mean that the Python buffer layout does not match that "
1185f66cd9e9SStella Laurenzo           "MLIR expected layout and is a bug.");
1186f66cd9e9SStella Laurenzo     }
1187f66cd9e9SStella Laurenzo     view.release();
1188f66cd9e9SStella Laurenzo     return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1189f66cd9e9SStella Laurenzo   }
1190f66cd9e9SStella Laurenzo 
1191f66cd9e9SStella Laurenzo   static void bindDerived(ClassTy &c) {
1192f66cd9e9SStella Laurenzo     c.def_static("get_from_buffer",
1193f66cd9e9SStella Laurenzo                  PyDenseResourceElementsAttribute::getFromBuffer,
1194f66cd9e9SStella Laurenzo                  py::arg("array"), py::arg("name"), py::arg("type"),
1195f66cd9e9SStella Laurenzo                  py::arg("alignment") = py::none(),
1196f66cd9e9SStella Laurenzo                  py::arg("is_mutable") = false, py::arg("context") = py::none(),
1197f66cd9e9SStella Laurenzo                  kDenseResourceElementsAttrGetFromBufferDocstring);
1198f66cd9e9SStella Laurenzo   }
1199f66cd9e9SStella Laurenzo };
1200f66cd9e9SStella Laurenzo 
1201436c6c9cSStella Laurenzo class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1202436c6c9cSStella Laurenzo public:
1203436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1204436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DictAttr";
1205436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
12069566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
12079566ee28Smax       mlirDictionaryAttrGetTypeID;
1208436c6c9cSStella Laurenzo 
1209436c6c9cSStella Laurenzo   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1210436c6c9cSStella Laurenzo 
12119fb1086bSAdrian Kuegel   bool dunderContains(const std::string &name) {
12129fb1086bSAdrian Kuegel     return !mlirAttributeIsNull(
12139fb1086bSAdrian Kuegel         mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
12149fb1086bSAdrian Kuegel   }
12159fb1086bSAdrian Kuegel 
1216436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
12179fb1086bSAdrian Kuegel     c.def("__contains__", &PyDictAttribute::dunderContains);
1218436c6c9cSStella Laurenzo     c.def("__len__", &PyDictAttribute::dunderLen);
1219436c6c9cSStella Laurenzo     c.def_static(
1220436c6c9cSStella Laurenzo         "get",
1221436c6c9cSStella Laurenzo         [](py::dict attributes, DefaultingPyMlirContext context) {
1222436c6c9cSStella Laurenzo           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1223436c6c9cSStella Laurenzo           mlirNamedAttributes.reserve(attributes.size());
1224436c6c9cSStella Laurenzo           for (auto &it : attributes) {
122502b6fb21SMehdi Amini             auto &mlirAttr = it.second.cast<PyAttribute &>();
1226436c6c9cSStella Laurenzo             auto name = it.first.cast<std::string>();
1227436c6c9cSStella Laurenzo             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
122802b6fb21SMehdi Amini                 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1229436c6c9cSStella Laurenzo                                   toMlirStringRef(name)),
123002b6fb21SMehdi Amini                 mlirAttr));
1231436c6c9cSStella Laurenzo           }
1232436c6c9cSStella Laurenzo           MlirAttribute attr =
1233436c6c9cSStella Laurenzo               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1234436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1235436c6c9cSStella Laurenzo           return PyDictAttribute(context->getRef(), attr);
1236436c6c9cSStella Laurenzo         },
1237ed9e52f3SAlex Zinenko         py::arg("value") = py::dict(), py::arg("context") = py::none(),
1238436c6c9cSStella Laurenzo         "Gets an uniqued dict attribute");
1239436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1240436c6c9cSStella Laurenzo       MlirAttribute attr =
1241436c6c9cSStella Laurenzo           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1242974c1596SRahul Kayaith       if (mlirAttributeIsNull(attr))
12434811270bSmax         throw py::key_error("attempt to access a non-existent attribute");
1244974c1596SRahul Kayaith       return attr;
1245436c6c9cSStella Laurenzo     });
1246436c6c9cSStella Laurenzo     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1247436c6c9cSStella Laurenzo       if (index < 0 || index >= self.dunderLen()) {
12484811270bSmax         throw py::index_error("attempt to access out of bounds attribute");
1249436c6c9cSStella Laurenzo       }
1250436c6c9cSStella Laurenzo       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1251436c6c9cSStella Laurenzo       return PyNamedAttribute(
1252436c6c9cSStella Laurenzo           namedAttr.attribute,
1253436c6c9cSStella Laurenzo           std::string(mlirIdentifierStr(namedAttr.name).data));
1254436c6c9cSStella Laurenzo     });
1255436c6c9cSStella Laurenzo   }
1256436c6c9cSStella Laurenzo };
1257436c6c9cSStella Laurenzo 
1258436c6c9cSStella Laurenzo /// Refinement of PyDenseElementsAttribute for attributes containing
1259436c6c9cSStella Laurenzo /// floating-point values. Supports element access.
1260436c6c9cSStella Laurenzo class PyDenseFPElementsAttribute
1261436c6c9cSStella Laurenzo     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1262436c6c9cSStella Laurenzo                                  PyDenseElementsAttribute> {
1263436c6c9cSStella Laurenzo public:
1264436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1265436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "DenseFPElementsAttr";
1266436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
1267436c6c9cSStella Laurenzo 
1268436c6c9cSStella Laurenzo   py::float_ dunderGetItem(intptr_t pos) {
1269436c6c9cSStella Laurenzo     if (pos < 0 || pos >= dunderLen()) {
12704811270bSmax       throw py::index_error("attempt to access out of bounds element");
1271436c6c9cSStella Laurenzo     }
1272436c6c9cSStella Laurenzo 
1273436c6c9cSStella Laurenzo     MlirType type = mlirAttributeGetType(*this);
1274436c6c9cSStella Laurenzo     type = mlirShapedTypeGetElementType(type);
1275436c6c9cSStella Laurenzo     // Dispatch element extraction to an appropriate C function based on the
1276436c6c9cSStella Laurenzo     // elemental type of the attribute. py::float_ is implicitly constructible
1277436c6c9cSStella Laurenzo     // from float and double.
1278436c6c9cSStella Laurenzo     // TODO: consider caching the type properties in the constructor to avoid
1279436c6c9cSStella Laurenzo     // querying them on each element access.
1280436c6c9cSStella Laurenzo     if (mlirTypeIsAF32(type)) {
1281436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetFloatValue(*this, pos);
1282436c6c9cSStella Laurenzo     }
1283436c6c9cSStella Laurenzo     if (mlirTypeIsAF64(type)) {
1284436c6c9cSStella Laurenzo       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1285436c6c9cSStella Laurenzo     }
12864811270bSmax     throw py::type_error("Unsupported floating-point type");
1287436c6c9cSStella Laurenzo   }
1288436c6c9cSStella Laurenzo 
1289436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1290436c6c9cSStella Laurenzo     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1291436c6c9cSStella Laurenzo   }
1292436c6c9cSStella Laurenzo };
1293436c6c9cSStella Laurenzo 
1294436c6c9cSStella Laurenzo class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1295436c6c9cSStella Laurenzo public:
1296436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1297436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "TypeAttr";
1298436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
12999566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
13009566ee28Smax       mlirTypeAttrGetTypeID;
1301436c6c9cSStella Laurenzo 
1302436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1303436c6c9cSStella Laurenzo     c.def_static(
1304436c6c9cSStella Laurenzo         "get",
1305436c6c9cSStella Laurenzo         [](PyType value, DefaultingPyMlirContext context) {
1306436c6c9cSStella Laurenzo           MlirAttribute attr = mlirTypeAttrGet(value.get());
1307436c6c9cSStella Laurenzo           return PyTypeAttribute(context->getRef(), attr);
1308436c6c9cSStella Laurenzo         },
1309436c6c9cSStella Laurenzo         py::arg("value"), py::arg("context") = py::none(),
1310436c6c9cSStella Laurenzo         "Gets a uniqued Type attribute");
1311436c6c9cSStella Laurenzo     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1312bfb1ba75Smax       return mlirTypeAttrGetValue(self.get());
1313436c6c9cSStella Laurenzo     });
1314436c6c9cSStella Laurenzo   }
1315436c6c9cSStella Laurenzo };
1316436c6c9cSStella Laurenzo 
1317436c6c9cSStella Laurenzo /// Unit Attribute subclass. Unit attributes don't have values.
1318436c6c9cSStella Laurenzo class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1319436c6c9cSStella Laurenzo public:
1320436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1321436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "UnitAttr";
1322436c6c9cSStella Laurenzo   using PyConcreteAttribute::PyConcreteAttribute;
13239566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
13249566ee28Smax       mlirUnitAttrGetTypeID;
1325436c6c9cSStella Laurenzo 
1326436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1327436c6c9cSStella Laurenzo     c.def_static(
1328436c6c9cSStella Laurenzo         "get",
1329436c6c9cSStella Laurenzo         [](DefaultingPyMlirContext context) {
1330436c6c9cSStella Laurenzo           return PyUnitAttribute(context->getRef(),
1331436c6c9cSStella Laurenzo                                  mlirUnitAttrGet(context->get()));
1332436c6c9cSStella Laurenzo         },
1333436c6c9cSStella Laurenzo         py::arg("context") = py::none(), "Create a Unit attribute.");
1334436c6c9cSStella Laurenzo   }
1335436c6c9cSStella Laurenzo };
1336436c6c9cSStella Laurenzo 
1337ac2e2d65SDenys Shabalin /// Strided layout attribute subclass.
1338ac2e2d65SDenys Shabalin class PyStridedLayoutAttribute
1339ac2e2d65SDenys Shabalin     : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1340ac2e2d65SDenys Shabalin public:
1341ac2e2d65SDenys Shabalin   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1342ac2e2d65SDenys Shabalin   static constexpr const char *pyClassName = "StridedLayoutAttr";
1343ac2e2d65SDenys Shabalin   using PyConcreteAttribute::PyConcreteAttribute;
13449566ee28Smax   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
13459566ee28Smax       mlirStridedLayoutAttrGetTypeID;
1346ac2e2d65SDenys Shabalin 
1347ac2e2d65SDenys Shabalin   static void bindDerived(ClassTy &c) {
1348ac2e2d65SDenys Shabalin     c.def_static(
1349ac2e2d65SDenys Shabalin         "get",
1350ac2e2d65SDenys Shabalin         [](int64_t offset, const std::vector<int64_t> strides,
1351ac2e2d65SDenys Shabalin            DefaultingPyMlirContext ctx) {
1352ac2e2d65SDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1353ac2e2d65SDenys Shabalin               ctx->get(), offset, strides.size(), strides.data());
1354ac2e2d65SDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1355ac2e2d65SDenys Shabalin         },
1356ac2e2d65SDenys Shabalin         py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1357ac2e2d65SDenys Shabalin         "Gets a strided layout attribute.");
1358e3fd612eSDenys Shabalin     c.def_static(
1359e3fd612eSDenys Shabalin         "get_fully_dynamic",
1360e3fd612eSDenys Shabalin         [](int64_t rank, DefaultingPyMlirContext ctx) {
1361e3fd612eSDenys Shabalin           auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1362e3fd612eSDenys Shabalin           std::vector<int64_t> strides(rank);
1363e3fd612eSDenys Shabalin           std::fill(strides.begin(), strides.end(), dynamic);
1364e3fd612eSDenys Shabalin           MlirAttribute attr = mlirStridedLayoutAttrGet(
1365e3fd612eSDenys Shabalin               ctx->get(), dynamic, strides.size(), strides.data());
1366e3fd612eSDenys Shabalin           return PyStridedLayoutAttribute(ctx->getRef(), attr);
1367e3fd612eSDenys Shabalin         },
1368e3fd612eSDenys Shabalin         py::arg("rank"), py::arg("context") = py::none(),
1369e3fd612eSDenys Shabalin         "Gets a strided layout attribute with dynamic offset and strides of a "
1370e3fd612eSDenys Shabalin         "given rank.");
1371ac2e2d65SDenys Shabalin     c.def_property_readonly(
1372ac2e2d65SDenys Shabalin         "offset",
1373ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1374ac2e2d65SDenys Shabalin           return mlirStridedLayoutAttrGetOffset(self);
1375ac2e2d65SDenys Shabalin         },
1376ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1377ac2e2d65SDenys Shabalin     c.def_property_readonly(
1378ac2e2d65SDenys Shabalin         "strides",
1379ac2e2d65SDenys Shabalin         [](PyStridedLayoutAttribute &self) {
1380ac2e2d65SDenys Shabalin           intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1381ac2e2d65SDenys Shabalin           std::vector<int64_t> strides(size);
1382ac2e2d65SDenys Shabalin           for (intptr_t i = 0; i < size; i++) {
1383ac2e2d65SDenys Shabalin             strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1384ac2e2d65SDenys Shabalin           }
1385ac2e2d65SDenys Shabalin           return strides;
1386ac2e2d65SDenys Shabalin         },
1387ac2e2d65SDenys Shabalin         "Returns the value of the float point attribute");
1388ac2e2d65SDenys Shabalin   }
1389ac2e2d65SDenys Shabalin };
1390ac2e2d65SDenys Shabalin 
13919566ee28Smax py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
13929566ee28Smax   if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
13939566ee28Smax     return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
13949566ee28Smax   if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
13959566ee28Smax     return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
13969566ee28Smax   if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
13979566ee28Smax     return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
13989566ee28Smax   if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
13999566ee28Smax     return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
14009566ee28Smax   if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
14019566ee28Smax     return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
14029566ee28Smax   if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
14039566ee28Smax     return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
14049566ee28Smax   if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
14059566ee28Smax     return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
14069566ee28Smax   std::string msg =
14079566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
14089566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
14099566ee28Smax   throw py::cast_error(msg);
14109566ee28Smax }
14119566ee28Smax 
14129566ee28Smax py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
14139566ee28Smax   if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
14149566ee28Smax     return py::cast(PyDenseFPElementsAttribute(pyAttribute));
14159566ee28Smax   if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
14169566ee28Smax     return py::cast(PyDenseIntElementsAttribute(pyAttribute));
14179566ee28Smax   std::string msg =
14189566ee28Smax       std::string(
14199566ee28Smax           "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
14209566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
14219566ee28Smax   throw py::cast_error(msg);
14229566ee28Smax }
14239566ee28Smax 
14249566ee28Smax py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
14259566ee28Smax   if (PyBoolAttribute::isaFunction(pyAttribute))
14269566ee28Smax     return py::cast(PyBoolAttribute(pyAttribute));
14279566ee28Smax   if (PyIntegerAttribute::isaFunction(pyAttribute))
14289566ee28Smax     return py::cast(PyIntegerAttribute(pyAttribute));
14299566ee28Smax   std::string msg =
14309566ee28Smax       std::string("Can't cast unknown element type DenseArrayAttr (") +
14319566ee28Smax       std::string(py::repr(py::cast(pyAttribute))) + ")";
14329566ee28Smax   throw py::cast_error(msg);
14339566ee28Smax }
14349566ee28Smax 
14354eee9ef9Smax py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
14364eee9ef9Smax   if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
14374eee9ef9Smax     return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
14384eee9ef9Smax   if (PySymbolRefAttribute::isaFunction(pyAttribute))
14394eee9ef9Smax     return py::cast(PySymbolRefAttribute(pyAttribute));
14404eee9ef9Smax   std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
14414eee9ef9Smax                     std::string(py::repr(py::cast(pyAttribute))) + ")";
14424eee9ef9Smax   throw py::cast_error(msg);
14434eee9ef9Smax }
14444eee9ef9Smax 
1445436c6c9cSStella Laurenzo } // namespace
1446436c6c9cSStella Laurenzo 
1447436c6c9cSStella Laurenzo void mlir::python::populateIRAttributes(py::module &m) {
1448436c6c9cSStella Laurenzo   PyAffineMapAttribute::bind(m);
1449619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::bind(m);
1450619fd8c2SJeff Niu   PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1451619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::bind(m);
1452619fd8c2SJeff Niu   PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1453619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::bind(m);
1454619fd8c2SJeff Niu   PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1455619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::bind(m);
1456619fd8c2SJeff Niu   PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1457619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::bind(m);
1458619fd8c2SJeff Niu   PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1459619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::bind(m);
1460619fd8c2SJeff Niu   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1461619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::bind(m);
1462619fd8c2SJeff Niu   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
14639566ee28Smax   PyGlobals::get().registerTypeCaster(
14649566ee28Smax       mlirDenseArrayAttrGetTypeID(),
14659566ee28Smax       pybind11::cpp_function(denseArrayAttributeCaster));
1466619fd8c2SJeff Niu 
1467436c6c9cSStella Laurenzo   PyArrayAttribute::bind(m);
1468436c6c9cSStella Laurenzo   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1469436c6c9cSStella Laurenzo   PyBoolAttribute::bind(m);
1470436c6c9cSStella Laurenzo   PyDenseElementsAttribute::bind(m);
1471436c6c9cSStella Laurenzo   PyDenseFPElementsAttribute::bind(m);
1472436c6c9cSStella Laurenzo   PyDenseIntElementsAttribute::bind(m);
14739566ee28Smax   PyGlobals::get().registerTypeCaster(
14749566ee28Smax       mlirDenseIntOrFPElementsAttrGetTypeID(),
14759566ee28Smax       pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
1476f66cd9e9SStella Laurenzo   PyDenseResourceElementsAttribute::bind(m);
14779566ee28Smax 
1478436c6c9cSStella Laurenzo   PyDictAttribute::bind(m);
14794eee9ef9Smax   PySymbolRefAttribute::bind(m);
14804eee9ef9Smax   PyGlobals::get().registerTypeCaster(
14814eee9ef9Smax       mlirSymbolRefAttrGetTypeID(),
14824eee9ef9Smax       pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
14834eee9ef9Smax 
1484436c6c9cSStella Laurenzo   PyFlatSymbolRefAttribute::bind(m);
14855c3861b2SYun Long   PyOpaqueAttribute::bind(m);
1486436c6c9cSStella Laurenzo   PyFloatAttribute::bind(m);
1487436c6c9cSStella Laurenzo   PyIntegerAttribute::bind(m);
1488334873feSAmy Wang   PyIntegerSetAttribute::bind(m);
1489436c6c9cSStella Laurenzo   PyStringAttribute::bind(m);
1490436c6c9cSStella Laurenzo   PyTypeAttribute::bind(m);
14919566ee28Smax   PyGlobals::get().registerTypeCaster(
14929566ee28Smax       mlirIntegerAttrGetTypeID(),
14939566ee28Smax       pybind11::cpp_function(integerOrBoolAttributeCaster));
1494436c6c9cSStella Laurenzo   PyUnitAttribute::bind(m);
1495ac2e2d65SDenys Shabalin 
1496ac2e2d65SDenys Shabalin   PyStridedLayoutAttribute::bind(m);
1497436c6c9cSStella Laurenzo }
1498