xref: /llvm-project/mlir/lib/Bindings/Python/NanobindUtils.h (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++
2 //-*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
12 
13 #include "mlir-c/Support.h"
14 #include "mlir/Bindings/Python/Nanobind.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/DataTypes.h"
18 
19 template <>
20 struct std::iterator_traits<nanobind::detail::fast_iterator> {
21   using value_type = nanobind::handle;
22   using reference = const value_type;
23   using pointer = void;
24   using difference_type = std::ptrdiff_t;
25   using iterator_category = std::forward_iterator_tag;
26 };
27 
28 namespace mlir {
29 namespace python {
30 
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanism if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";

template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;
  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;
  /// Wraps `referrent`. Only its address is stored; the wrapper does not own
  /// the object.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer (null for a default-constructed instance).
  ReferrentTy *get() const { return referrent; }
  /// Smart-pointer style member access. Const-qualified, like get(), so that
  /// const wrappers can be dereferenced as well; the pointee stays mutable.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
61 
62 } // namespace python
63 } // namespace mlir
64 
65 namespace nanobind {
66 namespace detail {
67 
68 template <typename DefaultingTy>
69 struct MlirDefaultingCaster {
70   NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription))
71 
72   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
73     if (src.is_none()) {
74       // Note that we do want an exception to propagate from here as it will be
75       // the most informative.
76       value = DefaultingTy{DefaultingTy::resolve()};
77       return true;
78     }
79 
80     // Unlike many casters that chain, these casters are expected to always
81     // succeed, so instead of doing an isinstance check followed by a cast,
82     // just cast in one step and handle the exception. Returning false (vs
83     // letting the exception propagate) causes higher level signature parsing
84     // code to produce nice error messages (other than "Cannot cast...").
85     try {
86       value = DefaultingTy{
87           nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
88       return true;
89     } catch (std::exception &) {
90       return false;
91     }
92   }
93 
94   static handle from_cpp(DefaultingTy src, rv_policy policy,
95                          cleanup_list *cleanup) noexcept {
96     return nanobind::cast(src, policy);
97   }
98 };
99 } // namespace detail
100 } // namespace nanobind
101 
102 //------------------------------------------------------------------------------
103 // Conversion utilities.
104 //------------------------------------------------------------------------------
105 
106 namespace mlir {
107 
108 /// Accumulates into a python string from a method that accepts an
109 /// MlirStringCallback.
110 struct PyPrintAccumulator {
111   nanobind::list parts;
112 
113   void *getUserData() { return this; }
114 
115   MlirStringCallback getCallback() {
116     return [](MlirStringRef part, void *userData) {
117       PyPrintAccumulator *printAccum =
118           static_cast<PyPrintAccumulator *>(userData);
119       nanobind::str pyPart(part.data,
120                            part.length); // Decodes as UTF-8 by default.
121       printAccum->parts.append(std::move(pyPart));
122     };
123   }
124 
125   nanobind::str join() {
126     nanobind::str delim("", 0);
127     return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
128   }
129 };
130 
131 /// Accumulates int a python file-like object, either writing text (default)
132 /// or binary.
133 class PyFileAccumulator {
134 public:
135   PyFileAccumulator(const nanobind::object &fileObject, bool binary)
136       : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
137 
138   void *getUserData() { return this; }
139 
140   MlirStringCallback getCallback() {
141     return [](MlirStringRef part, void *userData) {
142       nanobind::gil_scoped_acquire acquire;
143       PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
144       if (accum->binary) {
145         // Note: Still has to copy and not avoidable with this API.
146         nanobind::bytes pyBytes(part.data, part.length);
147         accum->pyWriteFunction(pyBytes);
148       } else {
149         nanobind::str pyStr(part.data,
150                             part.length); // Decodes as UTF-8 by default.
151         accum->pyWriteFunction(pyStr);
152       }
153     };
154   }
155 
156 private:
157   nanobind::object pyWriteFunction;
158   bool binary;
159 };
160 
161 /// Accumulates into a python string from a method that is expected to make
162 /// one (no more, no less) call to the callback (asserts internally on
163 /// violation).
164 struct PySinglePartStringAccumulator {
165   void *getUserData() { return this; }
166 
167   MlirStringCallback getCallback() {
168     return [](MlirStringRef part, void *userData) {
169       PySinglePartStringAccumulator *accum =
170           static_cast<PySinglePartStringAccumulator *>(userData);
171       assert(!accum->invoked &&
172              "PySinglePartStringAccumulator called back multiple times");
173       accum->invoked = true;
174       accum->value = nanobind::str(part.data, part.length);
175     };
176   }
177 
178   nanobind::str takeValue() {
179     assert(invoked && "PySinglePartStringAccumulator not called back");
180     return std::move(value);
181   }
182 
183 private:
184   nanobind::str value;
185   bool invoked = false;
186 };
187 
188 /// A CRTP base class for pseudo-containers willing to support Python-type
189 /// slicing access on top of indexed access. Calling ::bind on this class
190 /// will define `__len__` as well as `__getitem__` with integer and slice
191 /// arguments.
192 ///
193 /// This is intended for pseudo-containers that can refer to arbitrary slices of
194 /// underlying storage indexed by a single integer. Indexing those with an
195 /// integer produces an instance of ElementTy. Indexing those with a slice
196 /// produces a new instance of Derived, which can be sliced further.
197 ///
198 /// A derived class must provide the following:
199 ///   - a `static const char *pyClassName ` field containing the name of the
200 ///     Python class to bind;
201 ///   - an instance method `intptr_t getRawNumElements()` that returns the
202 ///   number
203 ///     of elements in the backing container (NOT that of the slice);
204 ///   - an instance method `ElementTy getRawElement(intptr_t)` that returns a
205 ///     single element at the given linear index (NOT slice index);
206 ///   - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
207 ///     constructs a new instance of the derived pseudo-container with the
208 ///     given slice parameters (to be forwarded to the Sliceable constructor).
209 ///
210 /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
211 /// throw.
212 ///
213 /// A derived class may additionally define:
214 ///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
215 ///     the python class.
216 template <typename Derived, typename ElementTy>
217 class Sliceable {
218 protected:
219   using ClassTy = nanobind::class_<Derived>;
220 
221   /// Transforms `index` into a legal value to access the underlying sequence.
222   /// Returns <0 on failure.
223   intptr_t wrapIndex(intptr_t index) {
224     if (index < 0)
225       index = length + index;
226     if (index < 0 || index >= length)
227       return -1;
228     return index;
229   }
230 
231   /// Computes the linear index given the current slice properties.
232   intptr_t linearizeIndex(intptr_t index) {
233     intptr_t linearIndex = index * step + startIndex;
234     assert(linearIndex >= 0 &&
235            linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
236            "linear index out of bounds, the slice is ill-formed");
237     return linearIndex;
238   }
239 
240   /// Trait to check if T provides a `maybeDownCast` method.
241   /// Note, you need the & to detect inherited members.
242   template <typename T, typename... Args>
243   using has_maybe_downcast = decltype(&T::maybeDownCast);
244 
245   /// Returns the element at the given slice index. Supports negative indices
246   /// by taking elements in inverse order. Returns a nullptr object if out
247   /// of bounds.
248   nanobind::object getItem(intptr_t index) {
249     // Negative indices mean we count from the end.
250     index = wrapIndex(index);
251     if (index < 0) {
252       PyErr_SetString(PyExc_IndexError, "index out of range");
253       return {};
254     }
255 
256     if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
257       return static_cast<Derived *>(this)
258           ->getRawElement(linearizeIndex(index))
259           .maybeDownCast();
260     else
261       return nanobind::cast(
262           static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
263   }
264 
265   /// Returns a new instance of the pseudo-container restricted to the given
266   /// slice. Returns a nullptr object on failure.
267   nanobind::object getItemSlice(PyObject *slice) {
268     ssize_t start, stop, extraStep, sliceLength;
269     if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
270                              &sliceLength) != 0) {
271       PyErr_SetString(PyExc_IndexError, "index out of range");
272       return {};
273     }
274     return nanobind::cast(static_cast<Derived *>(this)->slice(
275         startIndex + start * step, sliceLength, step * extraStep));
276   }
277 
278 public:
279   explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
280       : startIndex(startIndex), length(length), step(step) {
281     assert(length >= 0 && "expected non-negative slice length");
282   }
283 
284   /// Returns the `index`-th element in the slice, supports negative indices.
285   /// Throws if the index is out of bounds.
286   ElementTy getElement(intptr_t index) {
287     // Negative indices mean we count from the end.
288     index = wrapIndex(index);
289     if (index < 0) {
290       throw nanobind::index_error("index out of range");
291     }
292 
293     return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
294   }
295 
296   /// Returns the size of slice.
297   intptr_t size() { return length; }
298 
299   /// Returns a new vector (mapped to Python list) containing elements from two
300   /// slices. The new vector is necessary because slices may not be contiguous
301   /// or even come from the same original sequence.
302   std::vector<ElementTy> dunderAdd(Derived &other) {
303     std::vector<ElementTy> elements;
304     elements.reserve(length + other.length);
305     for (intptr_t i = 0; i < length; ++i) {
306       elements.push_back(static_cast<Derived *>(this)->getElement(i));
307     }
308     for (intptr_t i = 0; i < other.length; ++i) {
309       elements.push_back(static_cast<Derived *>(&other)->getElement(i));
310     }
311     return elements;
312   }
313 
314   /// Binds the indexing and length methods in the Python class.
315   static void bind(nanobind::module_ &m) {
316     auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName)
317                      .def("__add__", &Sliceable::dunderAdd);
318     Derived::bindDerived(clazz);
319 
320     // Manually implement the sequence protocol via the C API. We do this
321     // because it is approx 4x faster than via nanobind, largely because that
322     // formulation requires a C++ exception to be thrown to detect end of
323     // sequence.
324     // Since we are in a C-context, any C++ exception that happens here
325     // will terminate the program. There is nothing in this implementation
326     // that should throw in a non-terminal way, so we forgo further
327     // exception marshalling.
328     // See: https://github.com/pybind/nanobind/issues/2842
329     auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
330     assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
331            "must be heap type");
332     heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
333       auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
334       return self->length;
335     };
336     // sq_item is called as part of the sequence protocol for iteration,
337     // list construction, etc.
338     heap_type->as_sequence.sq_item =
339         +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
340       auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
341       return self->getItem(index).release().ptr();
342     };
343     // mp_subscript is used for both slices and integer lookups.
344     heap_type->as_mapping.mp_subscript =
345         +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
346       auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
347       Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
348       if (!PyErr_Occurred()) {
349         // Integer indexing.
350         return self->getItem(index).release().ptr();
351       }
352       PyErr_Clear();
353 
354       // Assume slice-based indexing.
355       if (PySlice_Check(rawSubscript)) {
356         return self->getItemSlice(rawSubscript).release().ptr();
357       }
358 
359       PyErr_SetString(PyExc_ValueError, "expected integer or slice");
360       return nullptr;
361     };
362   }
363 
364   /// Hook for derived classes willing to bind more methods.
365   static void bindDerived(ClassTy &) {}
366 
367 private:
368   intptr_t startIndex;
369   intptr_t length;
370   intptr_t step;
371 };
372 
373 } // namespace mlir
374 
375 namespace llvm {
376 
377 template <>
378 struct DenseMapInfo<MlirTypeID> {
379   static inline MlirTypeID getEmptyKey() {
380     auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
381     return mlirTypeIDCreate(pointer);
382   }
383   static inline MlirTypeID getTombstoneKey() {
384     auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
385     return mlirTypeIDCreate(pointer);
386   }
387   static inline unsigned getHashValue(const MlirTypeID &val) {
388     return mlirTypeIDHashValue(val);
389   }
390   static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
391     return mlirTypeIDEqual(lhs, rhs);
392   }
393 };
394 } // namespace llvm
395 
396 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
397