1 //===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++ 2 //-*-===// 3 // 4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 //===----------------------------------------------------------------------===// 9 10 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 11 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 12 13 #include "mlir-c/Support.h" 14 #include "mlir/Bindings/Python/Nanobind.h" 15 #include "llvm/ADT/STLExtras.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/DataTypes.h" 18 19 template <> 20 struct std::iterator_traits<nanobind::detail::fast_iterator> { 21 using value_type = nanobind::handle; 22 using reference = const value_type; 23 using pointer = void; 24 using difference_type = std::ptrdiff_t; 25 using iterator_category = std::forward_iterator_tag; 26 }; 27 28 namespace mlir { 29 namespace python { 30 31 /// CRTP template for special wrapper types that are allowed to be passed in as 32 /// 'None' function arguments and can be resolved by some global mechanic if 33 /// so. Such types will raise an error if this global resolution fails, and 34 /// it is actually illegal for them to ever be unresolved. From a user 35 /// perspective, they behave like a smart ptr to the underlying type (i.e. 36 /// 'get' method and operator-> overloaded). 37 /// 38 /// Derived types must provide a method, which is called when an environmental 39 /// resolution is required. It must raise an exception if resolution fails: 40 /// static ReferrentTy &resolve() 41 /// 42 /// They must also provide a parameter description that will be used in 43 /// error messages about mismatched types: 44 /// static constexpr const char kTypeDescription[] = "<Description>"; 45 46 template <typename DerivedTy, typename T> 47 class Defaulting { 48 public: 49 using ReferrentTy = T; 50 /// Type casters require the type to be default constructible, but using 51 /// such an instance is illegal. 52 Defaulting() = default; 53 Defaulting(ReferrentTy &referrent) : referrent(&referrent) {} 54 55 ReferrentTy *get() const { return referrent; } 56 ReferrentTy *operator->() { return referrent; } 57 58 private: 59 ReferrentTy *referrent = nullptr; 60 }; 61 62 } // namespace python 63 } // namespace mlir 64 65 namespace nanobind { 66 namespace detail { 67 68 template <typename DefaultingTy> 69 struct MlirDefaultingCaster { 70 NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription)) 71 72 bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { 73 if (src.is_none()) { 74 // Note that we do want an exception to propagate from here as it will be 75 // the most informative. 76 value = DefaultingTy{DefaultingTy::resolve()}; 77 return true; 78 } 79 80 // Unlike many casters that chain, these casters are expected to always 81 // succeed, so instead of doing an isinstance check followed by a cast, 82 // just cast in one step and handle the exception. Returning false (vs 83 // letting the exception propagate) causes higher level signature parsing 84 // code to produce nice error messages (other than "Cannot cast..."). 85 try { 86 value = DefaultingTy{ 87 nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)}; 88 return true; 89 } catch (std::exception &) { 90 return false; 91 } 92 } 93 94 static handle from_cpp(DefaultingTy src, rv_policy policy, 95 cleanup_list *cleanup) noexcept { 96 return nanobind::cast(src, policy); 97 } 98 }; 99 } // namespace detail 100 } // namespace nanobind 101 102 //------------------------------------------------------------------------------ 103 // Conversion utilities. 104 //------------------------------------------------------------------------------ 105 106 namespace mlir { 107 108 /// Accumulates into a python string from a method that accepts an 109 /// MlirStringCallback. 110 struct PyPrintAccumulator { 111 nanobind::list parts; 112 113 void *getUserData() { return this; } 114 115 MlirStringCallback getCallback() { 116 return [](MlirStringRef part, void *userData) { 117 PyPrintAccumulator *printAccum = 118 static_cast<PyPrintAccumulator *>(userData); 119 nanobind::str pyPart(part.data, 120 part.length); // Decodes as UTF-8 by default. 121 printAccum->parts.append(std::move(pyPart)); 122 }; 123 } 124 125 nanobind::str join() { 126 nanobind::str delim("", 0); 127 return nanobind::cast<nanobind::str>(delim.attr("join")(parts)); 128 } 129 }; 130 131 /// Accumulates int a python file-like object, either writing text (default) 132 /// or binary. 133 class PyFileAccumulator { 134 public: 135 PyFileAccumulator(const nanobind::object &fileObject, bool binary) 136 : pyWriteFunction(fileObject.attr("write")), binary(binary) {} 137 138 void *getUserData() { return this; } 139 140 MlirStringCallback getCallback() { 141 return [](MlirStringRef part, void *userData) { 142 nanobind::gil_scoped_acquire acquire; 143 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); 144 if (accum->binary) { 145 // Note: Still has to copy and not avoidable with this API. 146 nanobind::bytes pyBytes(part.data, part.length); 147 accum->pyWriteFunction(pyBytes); 148 } else { 149 nanobind::str pyStr(part.data, 150 part.length); // Decodes as UTF-8 by default. 151 accum->pyWriteFunction(pyStr); 152 } 153 }; 154 } 155 156 private: 157 nanobind::object pyWriteFunction; 158 bool binary; 159 }; 160 161 /// Accumulates into a python string from a method that is expected to make 162 /// one (no more, no less) call to the callback (asserts internally on 163 /// violation). 164 struct PySinglePartStringAccumulator { 165 void *getUserData() { return this; } 166 167 MlirStringCallback getCallback() { 168 return [](MlirStringRef part, void *userData) { 169 PySinglePartStringAccumulator *accum = 170 static_cast<PySinglePartStringAccumulator *>(userData); 171 assert(!accum->invoked && 172 "PySinglePartStringAccumulator called back multiple times"); 173 accum->invoked = true; 174 accum->value = nanobind::str(part.data, part.length); 175 }; 176 } 177 178 nanobind::str takeValue() { 179 assert(invoked && "PySinglePartStringAccumulator not called back"); 180 return std::move(value); 181 } 182 183 private: 184 nanobind::str value; 185 bool invoked = false; 186 }; 187 188 /// A CRTP base class for pseudo-containers willing to support Python-type 189 /// slicing access on top of indexed access. Calling ::bind on this class 190 /// will define `__len__` as well as `__getitem__` with integer and slice 191 /// arguments. 192 /// 193 /// This is intended for pseudo-containers that can refer to arbitrary slices of 194 /// underlying storage indexed by a single integer. Indexing those with an 195 /// integer produces an instance of ElementTy. Indexing those with a slice 196 /// produces a new instance of Derived, which can be sliced further. 197 /// 198 /// A derived class must provide the following: 199 /// - a `static const char *pyClassName ` field containing the name of the 200 /// Python class to bind; 201 /// - an instance method `intptr_t getRawNumElements()` that returns the 202 /// number 203 /// of elements in the backing container (NOT that of the slice); 204 /// - an instance method `ElementTy getRawElement(intptr_t)` that returns a 205 /// single element at the given linear index (NOT slice index); 206 /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that 207 /// constructs a new instance of the derived pseudo-container with the 208 /// given slice parameters (to be forwarded to the Sliceable constructor). 209 /// 210 /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not 211 /// throw. 212 /// 213 /// A derived class may additionally define: 214 /// - a `static void bindDerived(ClassTy &)` method to bind additional methods 215 /// the python class. 216 template <typename Derived, typename ElementTy> 217 class Sliceable { 218 protected: 219 using ClassTy = nanobind::class_<Derived>; 220 221 /// Transforms `index` into a legal value to access the underlying sequence. 222 /// Returns <0 on failure. 223 intptr_t wrapIndex(intptr_t index) { 224 if (index < 0) 225 index = length + index; 226 if (index < 0 || index >= length) 227 return -1; 228 return index; 229 } 230 231 /// Computes the linear index given the current slice properties. 232 intptr_t linearizeIndex(intptr_t index) { 233 intptr_t linearIndex = index * step + startIndex; 234 assert(linearIndex >= 0 && 235 linearIndex < static_cast<Derived *>(this)->getRawNumElements() && 236 "linear index out of bounds, the slice is ill-formed"); 237 return linearIndex; 238 } 239 240 /// Trait to check if T provides a `maybeDownCast` method. 241 /// Note, you need the & to detect inherited members. 242 template <typename T, typename... Args> 243 using has_maybe_downcast = decltype(&T::maybeDownCast); 244 245 /// Returns the element at the given slice index. Supports negative indices 246 /// by taking elements in inverse order. Returns a nullptr object if out 247 /// of bounds. 248 nanobind::object getItem(intptr_t index) { 249 // Negative indices mean we count from the end. 250 index = wrapIndex(index); 251 if (index < 0) { 252 PyErr_SetString(PyExc_IndexError, "index out of range"); 253 return {}; 254 } 255 256 if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value) 257 return static_cast<Derived *>(this) 258 ->getRawElement(linearizeIndex(index)) 259 .maybeDownCast(); 260 else 261 return nanobind::cast( 262 static_cast<Derived *>(this)->getRawElement(linearizeIndex(index))); 263 } 264 265 /// Returns a new instance of the pseudo-container restricted to the given 266 /// slice. Returns a nullptr object on failure. 267 nanobind::object getItemSlice(PyObject *slice) { 268 ssize_t start, stop, extraStep, sliceLength; 269 if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, 270 &sliceLength) != 0) { 271 PyErr_SetString(PyExc_IndexError, "index out of range"); 272 return {}; 273 } 274 return nanobind::cast(static_cast<Derived *>(this)->slice( 275 startIndex + start * step, sliceLength, step * extraStep)); 276 } 277 278 public: 279 explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) 280 : startIndex(startIndex), length(length), step(step) { 281 assert(length >= 0 && "expected non-negative slice length"); 282 } 283 284 /// Returns the `index`-th element in the slice, supports negative indices. 285 /// Throws if the index is out of bounds. 286 ElementTy getElement(intptr_t index) { 287 // Negative indices mean we count from the end. 288 index = wrapIndex(index); 289 if (index < 0) { 290 throw nanobind::index_error("index out of range"); 291 } 292 293 return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)); 294 } 295 296 /// Returns the size of slice. 297 intptr_t size() { return length; } 298 299 /// Returns a new vector (mapped to Python list) containing elements from two 300 /// slices. The new vector is necessary because slices may not be contiguous 301 /// or even come from the same original sequence. 302 std::vector<ElementTy> dunderAdd(Derived &other) { 303 std::vector<ElementTy> elements; 304 elements.reserve(length + other.length); 305 for (intptr_t i = 0; i < length; ++i) { 306 elements.push_back(static_cast<Derived *>(this)->getElement(i)); 307 } 308 for (intptr_t i = 0; i < other.length; ++i) { 309 elements.push_back(static_cast<Derived *>(&other)->getElement(i)); 310 } 311 return elements; 312 } 313 314 /// Binds the indexing and length methods in the Python class. 315 static void bind(nanobind::module_ &m) { 316 auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName) 317 .def("__add__", &Sliceable::dunderAdd); 318 Derived::bindDerived(clazz); 319 320 // Manually implement the sequence protocol via the C API. We do this 321 // because it is approx 4x faster than via nanobind, largely because that 322 // formulation requires a C++ exception to be thrown to detect end of 323 // sequence. 324 // Since we are in a C-context, any C++ exception that happens here 325 // will terminate the program. There is nothing in this implementation 326 // that should throw in a non-terminal way, so we forgo further 327 // exception marshalling. 328 // See: https://github.com/pybind/nanobind/issues/2842 329 auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr()); 330 assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && 331 "must be heap type"); 332 heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { 333 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); 334 return self->length; 335 }; 336 // sq_item is called as part of the sequence protocol for iteration, 337 // list construction, etc. 338 heap_type->as_sequence.sq_item = 339 +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { 340 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); 341 return self->getItem(index).release().ptr(); 342 }; 343 // mp_subscript is used for both slices and integer lookups. 344 heap_type->as_mapping.mp_subscript = 345 +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { 346 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); 347 Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); 348 if (!PyErr_Occurred()) { 349 // Integer indexing. 350 return self->getItem(index).release().ptr(); 351 } 352 PyErr_Clear(); 353 354 // Assume slice-based indexing. 355 if (PySlice_Check(rawSubscript)) { 356 return self->getItemSlice(rawSubscript).release().ptr(); 357 } 358 359 PyErr_SetString(PyExc_ValueError, "expected integer or slice"); 360 return nullptr; 361 }; 362 } 363 364 /// Hook for derived classes willing to bind more methods. 365 static void bindDerived(ClassTy &) {} 366 367 private: 368 intptr_t startIndex; 369 intptr_t length; 370 intptr_t step; 371 }; 372 373 } // namespace mlir 374 375 namespace llvm { 376 377 template <> 378 struct DenseMapInfo<MlirTypeID> { 379 static inline MlirTypeID getEmptyKey() { 380 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); 381 return mlirTypeIDCreate(pointer); 382 } 383 static inline MlirTypeID getTombstoneKey() { 384 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); 385 return mlirTypeIDCreate(pointer); 386 } 387 static inline unsigned getHashValue(const MlirTypeID &val) { 388 return mlirTypeIDHashValue(val); 389 } 390 static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) { 391 return mlirTypeIDEqual(lhs, rhs); 392 } 393 }; 394 } // namespace llvm 395 396 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 397