xref: /llvm-project/mlir/lib/Bindings/Python/NanobindUtils.h (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
//===- NanobindUtils.h - Utilities for interop with nanobind ----*- C++ -*-===//
3b56d1ec6SPeter Hawkins //
4b56d1ec6SPeter Hawkins // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5b56d1ec6SPeter Hawkins // See https://llvm.org/LICENSE.txt for license information.
6b56d1ec6SPeter Hawkins // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7b56d1ec6SPeter Hawkins //
8b56d1ec6SPeter Hawkins //===----------------------------------------------------------------------===//
9b56d1ec6SPeter Hawkins 
10b56d1ec6SPeter Hawkins #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11b56d1ec6SPeter Hawkins #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
12b56d1ec6SPeter Hawkins 
13b56d1ec6SPeter Hawkins #include "mlir-c/Support.h"
14*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h"
15b56d1ec6SPeter Hawkins #include "llvm/ADT/STLExtras.h"
16b56d1ec6SPeter Hawkins #include "llvm/ADT/Twine.h"
17b56d1ec6SPeter Hawkins #include "llvm/Support/DataTypes.h"
18b56d1ec6SPeter Hawkins 
19b56d1ec6SPeter Hawkins template <>
20b56d1ec6SPeter Hawkins struct std::iterator_traits<nanobind::detail::fast_iterator> {
21b56d1ec6SPeter Hawkins   using value_type = nanobind::handle;
22b56d1ec6SPeter Hawkins   using reference = const value_type;
23b56d1ec6SPeter Hawkins   using pointer = void;
24b56d1ec6SPeter Hawkins   using difference_type = std::ptrdiff_t;
25b56d1ec6SPeter Hawkins   using iterator_category = std::forward_iterator_tag;
26b56d1ec6SPeter Hawkins };
27b56d1ec6SPeter Hawkins 
28b56d1ec6SPeter Hawkins namespace mlir {
29b56d1ec6SPeter Hawkins namespace python {
30b56d1ec6SPeter Hawkins 
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart pointer to the underlying type (i.e.
/// a 'get' method and an overloaded operator->).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;

  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal: it holds a null pointer.
  Defaulting() = default;

  /// Wraps `referrent` without taking ownership; only its address is stored,
  /// so the referenced object must outlive this wrapper.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer (nullptr for a default-constructed wrapper).
  ReferrentTy *get() const { return referrent; }

  /// Member access on the wrapped object. Const-qualified, like get(), so the
  /// wrapper can also be dereferenced when it is const or bound by const
  /// reference; constness of the wrapper does not propagate to the referent.
  ReferrentTy *operator->() const { return referrent; }

private:
  /// Non-owning pointer to the resolved object.
  ReferrentTy *referrent = nullptr;
};
61b56d1ec6SPeter Hawkins 
62b56d1ec6SPeter Hawkins } // namespace python
63b56d1ec6SPeter Hawkins } // namespace mlir
64b56d1ec6SPeter Hawkins 
65b56d1ec6SPeter Hawkins namespace nanobind {
66b56d1ec6SPeter Hawkins namespace detail {
67b56d1ec6SPeter Hawkins 
68b56d1ec6SPeter Hawkins template <typename DefaultingTy>
69b56d1ec6SPeter Hawkins struct MlirDefaultingCaster {
70*5cd42747SPeter Hawkins   NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription))
71b56d1ec6SPeter Hawkins 
72b56d1ec6SPeter Hawkins   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
73b56d1ec6SPeter Hawkins     if (src.is_none()) {
74b56d1ec6SPeter Hawkins       // Note that we do want an exception to propagate from here as it will be
75b56d1ec6SPeter Hawkins       // the most informative.
76b56d1ec6SPeter Hawkins       value = DefaultingTy{DefaultingTy::resolve()};
77b56d1ec6SPeter Hawkins       return true;
78b56d1ec6SPeter Hawkins     }
79b56d1ec6SPeter Hawkins 
80b56d1ec6SPeter Hawkins     // Unlike many casters that chain, these casters are expected to always
81b56d1ec6SPeter Hawkins     // succeed, so instead of doing an isinstance check followed by a cast,
82b56d1ec6SPeter Hawkins     // just cast in one step and handle the exception. Returning false (vs
83b56d1ec6SPeter Hawkins     // letting the exception propagate) causes higher level signature parsing
84b56d1ec6SPeter Hawkins     // code to produce nice error messages (other than "Cannot cast...").
85b56d1ec6SPeter Hawkins     try {
86b56d1ec6SPeter Hawkins       value = DefaultingTy{
87b56d1ec6SPeter Hawkins           nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
88b56d1ec6SPeter Hawkins       return true;
89b56d1ec6SPeter Hawkins     } catch (std::exception &) {
90b56d1ec6SPeter Hawkins       return false;
91b56d1ec6SPeter Hawkins     }
92b56d1ec6SPeter Hawkins   }
93b56d1ec6SPeter Hawkins 
94b56d1ec6SPeter Hawkins   static handle from_cpp(DefaultingTy src, rv_policy policy,
95b56d1ec6SPeter Hawkins                          cleanup_list *cleanup) noexcept {
96b56d1ec6SPeter Hawkins     return nanobind::cast(src, policy);
97b56d1ec6SPeter Hawkins   }
98b56d1ec6SPeter Hawkins };
99b56d1ec6SPeter Hawkins } // namespace detail
100b56d1ec6SPeter Hawkins } // namespace nanobind
101b56d1ec6SPeter Hawkins 
102b56d1ec6SPeter Hawkins //------------------------------------------------------------------------------
103b56d1ec6SPeter Hawkins // Conversion utilities.
104b56d1ec6SPeter Hawkins //------------------------------------------------------------------------------
105b56d1ec6SPeter Hawkins 
106b56d1ec6SPeter Hawkins namespace mlir {
107b56d1ec6SPeter Hawkins 
108b56d1ec6SPeter Hawkins /// Accumulates into a python string from a method that accepts an
109b56d1ec6SPeter Hawkins /// MlirStringCallback.
110b56d1ec6SPeter Hawkins struct PyPrintAccumulator {
111b56d1ec6SPeter Hawkins   nanobind::list parts;
112b56d1ec6SPeter Hawkins 
113b56d1ec6SPeter Hawkins   void *getUserData() { return this; }
114b56d1ec6SPeter Hawkins 
115b56d1ec6SPeter Hawkins   MlirStringCallback getCallback() {
116b56d1ec6SPeter Hawkins     return [](MlirStringRef part, void *userData) {
117b56d1ec6SPeter Hawkins       PyPrintAccumulator *printAccum =
118b56d1ec6SPeter Hawkins           static_cast<PyPrintAccumulator *>(userData);
119b56d1ec6SPeter Hawkins       nanobind::str pyPart(part.data,
120b56d1ec6SPeter Hawkins                            part.length); // Decodes as UTF-8 by default.
121b56d1ec6SPeter Hawkins       printAccum->parts.append(std::move(pyPart));
122b56d1ec6SPeter Hawkins     };
123b56d1ec6SPeter Hawkins   }
124b56d1ec6SPeter Hawkins 
125b56d1ec6SPeter Hawkins   nanobind::str join() {
126b56d1ec6SPeter Hawkins     nanobind::str delim("", 0);
127b56d1ec6SPeter Hawkins     return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
128b56d1ec6SPeter Hawkins   }
129b56d1ec6SPeter Hawkins };
130b56d1ec6SPeter Hawkins 
131b56d1ec6SPeter Hawkins /// Accumulates int a python file-like object, either writing text (default)
132b56d1ec6SPeter Hawkins /// or binary.
133b56d1ec6SPeter Hawkins class PyFileAccumulator {
134b56d1ec6SPeter Hawkins public:
135b56d1ec6SPeter Hawkins   PyFileAccumulator(const nanobind::object &fileObject, bool binary)
136b56d1ec6SPeter Hawkins       : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
137b56d1ec6SPeter Hawkins 
138b56d1ec6SPeter Hawkins   void *getUserData() { return this; }
139b56d1ec6SPeter Hawkins 
140b56d1ec6SPeter Hawkins   MlirStringCallback getCallback() {
141b56d1ec6SPeter Hawkins     return [](MlirStringRef part, void *userData) {
142b56d1ec6SPeter Hawkins       nanobind::gil_scoped_acquire acquire;
143b56d1ec6SPeter Hawkins       PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
144b56d1ec6SPeter Hawkins       if (accum->binary) {
145b56d1ec6SPeter Hawkins         // Note: Still has to copy and not avoidable with this API.
146b56d1ec6SPeter Hawkins         nanobind::bytes pyBytes(part.data, part.length);
147b56d1ec6SPeter Hawkins         accum->pyWriteFunction(pyBytes);
148b56d1ec6SPeter Hawkins       } else {
149b56d1ec6SPeter Hawkins         nanobind::str pyStr(part.data,
150b56d1ec6SPeter Hawkins                             part.length); // Decodes as UTF-8 by default.
151b56d1ec6SPeter Hawkins         accum->pyWriteFunction(pyStr);
152b56d1ec6SPeter Hawkins       }
153b56d1ec6SPeter Hawkins     };
154b56d1ec6SPeter Hawkins   }
155b56d1ec6SPeter Hawkins 
156b56d1ec6SPeter Hawkins private:
157b56d1ec6SPeter Hawkins   nanobind::object pyWriteFunction;
158b56d1ec6SPeter Hawkins   bool binary;
159b56d1ec6SPeter Hawkins };
160b56d1ec6SPeter Hawkins 
161b56d1ec6SPeter Hawkins /// Accumulates into a python string from a method that is expected to make
162b56d1ec6SPeter Hawkins /// one (no more, no less) call to the callback (asserts internally on
163b56d1ec6SPeter Hawkins /// violation).
164b56d1ec6SPeter Hawkins struct PySinglePartStringAccumulator {
165b56d1ec6SPeter Hawkins   void *getUserData() { return this; }
166b56d1ec6SPeter Hawkins 
167b56d1ec6SPeter Hawkins   MlirStringCallback getCallback() {
168b56d1ec6SPeter Hawkins     return [](MlirStringRef part, void *userData) {
169b56d1ec6SPeter Hawkins       PySinglePartStringAccumulator *accum =
170b56d1ec6SPeter Hawkins           static_cast<PySinglePartStringAccumulator *>(userData);
171b56d1ec6SPeter Hawkins       assert(!accum->invoked &&
172b56d1ec6SPeter Hawkins              "PySinglePartStringAccumulator called back multiple times");
173b56d1ec6SPeter Hawkins       accum->invoked = true;
174b56d1ec6SPeter Hawkins       accum->value = nanobind::str(part.data, part.length);
175b56d1ec6SPeter Hawkins     };
176b56d1ec6SPeter Hawkins   }
177b56d1ec6SPeter Hawkins 
178b56d1ec6SPeter Hawkins   nanobind::str takeValue() {
179b56d1ec6SPeter Hawkins     assert(invoked && "PySinglePartStringAccumulator not called back");
180b56d1ec6SPeter Hawkins     return std::move(value);
181b56d1ec6SPeter Hawkins   }
182b56d1ec6SPeter Hawkins 
183b56d1ec6SPeter Hawkins private:
184b56d1ec6SPeter Hawkins   nanobind::str value;
185b56d1ec6SPeter Hawkins   bool invoked = false;
186b56d1ec6SPeter Hawkins };
187b56d1ec6SPeter Hawkins 
188b56d1ec6SPeter Hawkins /// A CRTP base class for pseudo-containers willing to support Python-type
189b56d1ec6SPeter Hawkins /// slicing access on top of indexed access. Calling ::bind on this class
190b56d1ec6SPeter Hawkins /// will define `__len__` as well as `__getitem__` with integer and slice
191b56d1ec6SPeter Hawkins /// arguments.
192b56d1ec6SPeter Hawkins ///
193b56d1ec6SPeter Hawkins /// This is intended for pseudo-containers that can refer to arbitrary slices of
194b56d1ec6SPeter Hawkins /// underlying storage indexed by a single integer. Indexing those with an
195b56d1ec6SPeter Hawkins /// integer produces an instance of ElementTy. Indexing those with a slice
196b56d1ec6SPeter Hawkins /// produces a new instance of Derived, which can be sliced further.
197b56d1ec6SPeter Hawkins ///
198b56d1ec6SPeter Hawkins /// A derived class must provide the following:
199b56d1ec6SPeter Hawkins ///   - a `static const char *pyClassName ` field containing the name of the
200b56d1ec6SPeter Hawkins ///     Python class to bind;
201b56d1ec6SPeter Hawkins ///   - an instance method `intptr_t getRawNumElements()` that returns the
202b56d1ec6SPeter Hawkins ///   number
203b56d1ec6SPeter Hawkins ///     of elements in the backing container (NOT that of the slice);
204b56d1ec6SPeter Hawkins ///   - an instance method `ElementTy getRawElement(intptr_t)` that returns a
205b56d1ec6SPeter Hawkins ///     single element at the given linear index (NOT slice index);
206b56d1ec6SPeter Hawkins ///   - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
207b56d1ec6SPeter Hawkins ///     constructs a new instance of the derived pseudo-container with the
208b56d1ec6SPeter Hawkins ///     given slice parameters (to be forwarded to the Sliceable constructor).
209b56d1ec6SPeter Hawkins ///
210b56d1ec6SPeter Hawkins /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
211b56d1ec6SPeter Hawkins /// throw.
212b56d1ec6SPeter Hawkins ///
213b56d1ec6SPeter Hawkins /// A derived class may additionally define:
214b56d1ec6SPeter Hawkins ///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
215b56d1ec6SPeter Hawkins ///     the python class.
216b56d1ec6SPeter Hawkins template <typename Derived, typename ElementTy>
217b56d1ec6SPeter Hawkins class Sliceable {
218b56d1ec6SPeter Hawkins protected:
219b56d1ec6SPeter Hawkins   using ClassTy = nanobind::class_<Derived>;
220b56d1ec6SPeter Hawkins 
221b56d1ec6SPeter Hawkins   /// Transforms `index` into a legal value to access the underlying sequence.
222b56d1ec6SPeter Hawkins   /// Returns <0 on failure.
223b56d1ec6SPeter Hawkins   intptr_t wrapIndex(intptr_t index) {
224b56d1ec6SPeter Hawkins     if (index < 0)
225b56d1ec6SPeter Hawkins       index = length + index;
226b56d1ec6SPeter Hawkins     if (index < 0 || index >= length)
227b56d1ec6SPeter Hawkins       return -1;
228b56d1ec6SPeter Hawkins     return index;
229b56d1ec6SPeter Hawkins   }
230b56d1ec6SPeter Hawkins 
231b56d1ec6SPeter Hawkins   /// Computes the linear index given the current slice properties.
232b56d1ec6SPeter Hawkins   intptr_t linearizeIndex(intptr_t index) {
233b56d1ec6SPeter Hawkins     intptr_t linearIndex = index * step + startIndex;
234b56d1ec6SPeter Hawkins     assert(linearIndex >= 0 &&
235b56d1ec6SPeter Hawkins            linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
236b56d1ec6SPeter Hawkins            "linear index out of bounds, the slice is ill-formed");
237b56d1ec6SPeter Hawkins     return linearIndex;
238b56d1ec6SPeter Hawkins   }
239b56d1ec6SPeter Hawkins 
240b56d1ec6SPeter Hawkins   /// Trait to check if T provides a `maybeDownCast` method.
241b56d1ec6SPeter Hawkins   /// Note, you need the & to detect inherited members.
242b56d1ec6SPeter Hawkins   template <typename T, typename... Args>
243b56d1ec6SPeter Hawkins   using has_maybe_downcast = decltype(&T::maybeDownCast);
244b56d1ec6SPeter Hawkins 
245b56d1ec6SPeter Hawkins   /// Returns the element at the given slice index. Supports negative indices
246b56d1ec6SPeter Hawkins   /// by taking elements in inverse order. Returns a nullptr object if out
247b56d1ec6SPeter Hawkins   /// of bounds.
248b56d1ec6SPeter Hawkins   nanobind::object getItem(intptr_t index) {
249b56d1ec6SPeter Hawkins     // Negative indices mean we count from the end.
250b56d1ec6SPeter Hawkins     index = wrapIndex(index);
251b56d1ec6SPeter Hawkins     if (index < 0) {
252b56d1ec6SPeter Hawkins       PyErr_SetString(PyExc_IndexError, "index out of range");
253b56d1ec6SPeter Hawkins       return {};
254b56d1ec6SPeter Hawkins     }
255b56d1ec6SPeter Hawkins 
256b56d1ec6SPeter Hawkins     if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
257b56d1ec6SPeter Hawkins       return static_cast<Derived *>(this)
258b56d1ec6SPeter Hawkins           ->getRawElement(linearizeIndex(index))
259b56d1ec6SPeter Hawkins           .maybeDownCast();
260b56d1ec6SPeter Hawkins     else
261b56d1ec6SPeter Hawkins       return nanobind::cast(
262b56d1ec6SPeter Hawkins           static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
263b56d1ec6SPeter Hawkins   }
264b56d1ec6SPeter Hawkins 
265b56d1ec6SPeter Hawkins   /// Returns a new instance of the pseudo-container restricted to the given
266b56d1ec6SPeter Hawkins   /// slice. Returns a nullptr object on failure.
267b56d1ec6SPeter Hawkins   nanobind::object getItemSlice(PyObject *slice) {
268b56d1ec6SPeter Hawkins     ssize_t start, stop, extraStep, sliceLength;
269b56d1ec6SPeter Hawkins     if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
270b56d1ec6SPeter Hawkins                              &sliceLength) != 0) {
271b56d1ec6SPeter Hawkins       PyErr_SetString(PyExc_IndexError, "index out of range");
272b56d1ec6SPeter Hawkins       return {};
273b56d1ec6SPeter Hawkins     }
274b56d1ec6SPeter Hawkins     return nanobind::cast(static_cast<Derived *>(this)->slice(
275b56d1ec6SPeter Hawkins         startIndex + start * step, sliceLength, step * extraStep));
276b56d1ec6SPeter Hawkins   }
277b56d1ec6SPeter Hawkins 
278b56d1ec6SPeter Hawkins public:
279b56d1ec6SPeter Hawkins   explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
280b56d1ec6SPeter Hawkins       : startIndex(startIndex), length(length), step(step) {
281b56d1ec6SPeter Hawkins     assert(length >= 0 && "expected non-negative slice length");
282b56d1ec6SPeter Hawkins   }
283b56d1ec6SPeter Hawkins 
284b56d1ec6SPeter Hawkins   /// Returns the `index`-th element in the slice, supports negative indices.
285b56d1ec6SPeter Hawkins   /// Throws if the index is out of bounds.
286b56d1ec6SPeter Hawkins   ElementTy getElement(intptr_t index) {
287b56d1ec6SPeter Hawkins     // Negative indices mean we count from the end.
288b56d1ec6SPeter Hawkins     index = wrapIndex(index);
289b56d1ec6SPeter Hawkins     if (index < 0) {
290b56d1ec6SPeter Hawkins       throw nanobind::index_error("index out of range");
291b56d1ec6SPeter Hawkins     }
292b56d1ec6SPeter Hawkins 
293b56d1ec6SPeter Hawkins     return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
294b56d1ec6SPeter Hawkins   }
295b56d1ec6SPeter Hawkins 
296b56d1ec6SPeter Hawkins   /// Returns the size of slice.
297b56d1ec6SPeter Hawkins   intptr_t size() { return length; }
298b56d1ec6SPeter Hawkins 
299b56d1ec6SPeter Hawkins   /// Returns a new vector (mapped to Python list) containing elements from two
300b56d1ec6SPeter Hawkins   /// slices. The new vector is necessary because slices may not be contiguous
301b56d1ec6SPeter Hawkins   /// or even come from the same original sequence.
302b56d1ec6SPeter Hawkins   std::vector<ElementTy> dunderAdd(Derived &other) {
303b56d1ec6SPeter Hawkins     std::vector<ElementTy> elements;
304b56d1ec6SPeter Hawkins     elements.reserve(length + other.length);
305b56d1ec6SPeter Hawkins     for (intptr_t i = 0; i < length; ++i) {
306b56d1ec6SPeter Hawkins       elements.push_back(static_cast<Derived *>(this)->getElement(i));
307b56d1ec6SPeter Hawkins     }
308b56d1ec6SPeter Hawkins     for (intptr_t i = 0; i < other.length; ++i) {
309b56d1ec6SPeter Hawkins       elements.push_back(static_cast<Derived *>(&other)->getElement(i));
310b56d1ec6SPeter Hawkins     }
311b56d1ec6SPeter Hawkins     return elements;
312b56d1ec6SPeter Hawkins   }
313b56d1ec6SPeter Hawkins 
314b56d1ec6SPeter Hawkins   /// Binds the indexing and length methods in the Python class.
315b56d1ec6SPeter Hawkins   static void bind(nanobind::module_ &m) {
316b56d1ec6SPeter Hawkins     auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName)
317b56d1ec6SPeter Hawkins                      .def("__add__", &Sliceable::dunderAdd);
318b56d1ec6SPeter Hawkins     Derived::bindDerived(clazz);
319b56d1ec6SPeter Hawkins 
320b56d1ec6SPeter Hawkins     // Manually implement the sequence protocol via the C API. We do this
321b56d1ec6SPeter Hawkins     // because it is approx 4x faster than via nanobind, largely because that
322b56d1ec6SPeter Hawkins     // formulation requires a C++ exception to be thrown to detect end of
323b56d1ec6SPeter Hawkins     // sequence.
324b56d1ec6SPeter Hawkins     // Since we are in a C-context, any C++ exception that happens here
325b56d1ec6SPeter Hawkins     // will terminate the program. There is nothing in this implementation
326b56d1ec6SPeter Hawkins     // that should throw in a non-terminal way, so we forgo further
327b56d1ec6SPeter Hawkins     // exception marshalling.
328b56d1ec6SPeter Hawkins     // See: https://github.com/pybind/nanobind/issues/2842
329b56d1ec6SPeter Hawkins     auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
330b56d1ec6SPeter Hawkins     assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
331b56d1ec6SPeter Hawkins            "must be heap type");
332b56d1ec6SPeter Hawkins     heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
333b56d1ec6SPeter Hawkins       auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
334b56d1ec6SPeter Hawkins       return self->length;
335b56d1ec6SPeter Hawkins     };
336b56d1ec6SPeter Hawkins     // sq_item is called as part of the sequence protocol for iteration,
337b56d1ec6SPeter Hawkins     // list construction, etc.
338b56d1ec6SPeter Hawkins     heap_type->as_sequence.sq_item =
339b56d1ec6SPeter Hawkins         +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
340b56d1ec6SPeter Hawkins       auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
341b56d1ec6SPeter Hawkins       return self->getItem(index).release().ptr();
342b56d1ec6SPeter Hawkins     };
343b56d1ec6SPeter Hawkins     // mp_subscript is used for both slices and integer lookups.
344b56d1ec6SPeter Hawkins     heap_type->as_mapping.mp_subscript =
345b56d1ec6SPeter Hawkins         +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
346b56d1ec6SPeter Hawkins       auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
347b56d1ec6SPeter Hawkins       Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
348b56d1ec6SPeter Hawkins       if (!PyErr_Occurred()) {
349b56d1ec6SPeter Hawkins         // Integer indexing.
350b56d1ec6SPeter Hawkins         return self->getItem(index).release().ptr();
351b56d1ec6SPeter Hawkins       }
352b56d1ec6SPeter Hawkins       PyErr_Clear();
353b56d1ec6SPeter Hawkins 
354b56d1ec6SPeter Hawkins       // Assume slice-based indexing.
355b56d1ec6SPeter Hawkins       if (PySlice_Check(rawSubscript)) {
356b56d1ec6SPeter Hawkins         return self->getItemSlice(rawSubscript).release().ptr();
357b56d1ec6SPeter Hawkins       }
358b56d1ec6SPeter Hawkins 
359b56d1ec6SPeter Hawkins       PyErr_SetString(PyExc_ValueError, "expected integer or slice");
360b56d1ec6SPeter Hawkins       return nullptr;
361b56d1ec6SPeter Hawkins     };
362b56d1ec6SPeter Hawkins   }
363b56d1ec6SPeter Hawkins 
364b56d1ec6SPeter Hawkins   /// Hook for derived classes willing to bind more methods.
365b56d1ec6SPeter Hawkins   static void bindDerived(ClassTy &) {}
366b56d1ec6SPeter Hawkins 
367b56d1ec6SPeter Hawkins private:
368b56d1ec6SPeter Hawkins   intptr_t startIndex;
369b56d1ec6SPeter Hawkins   intptr_t length;
370b56d1ec6SPeter Hawkins   intptr_t step;
371b56d1ec6SPeter Hawkins };
372b56d1ec6SPeter Hawkins 
373b56d1ec6SPeter Hawkins } // namespace mlir
374b56d1ec6SPeter Hawkins 
375b56d1ec6SPeter Hawkins namespace llvm {
376b56d1ec6SPeter Hawkins 
377b56d1ec6SPeter Hawkins template <>
378b56d1ec6SPeter Hawkins struct DenseMapInfo<MlirTypeID> {
379b56d1ec6SPeter Hawkins   static inline MlirTypeID getEmptyKey() {
380b56d1ec6SPeter Hawkins     auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
381b56d1ec6SPeter Hawkins     return mlirTypeIDCreate(pointer);
382b56d1ec6SPeter Hawkins   }
383b56d1ec6SPeter Hawkins   static inline MlirTypeID getTombstoneKey() {
384b56d1ec6SPeter Hawkins     auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
385b56d1ec6SPeter Hawkins     return mlirTypeIDCreate(pointer);
386b56d1ec6SPeter Hawkins   }
387b56d1ec6SPeter Hawkins   static inline unsigned getHashValue(const MlirTypeID &val) {
388b56d1ec6SPeter Hawkins     return mlirTypeIDHashValue(val);
389b56d1ec6SPeter Hawkins   }
390b56d1ec6SPeter Hawkins   static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
391b56d1ec6SPeter Hawkins     return mlirTypeIDEqual(lhs, rhs);
392b56d1ec6SPeter Hawkins   }
393b56d1ec6SPeter Hawkins };
394b56d1ec6SPeter Hawkins } // namespace llvm
395b56d1ec6SPeter Hawkins 
396b56d1ec6SPeter Hawkins #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
397