1b56d1ec6SPeter Hawkins //===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++ 2b56d1ec6SPeter Hawkins //-*-===// 3b56d1ec6SPeter Hawkins // 4b56d1ec6SPeter Hawkins // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5b56d1ec6SPeter Hawkins // See https://llvm.org/LICENSE.txt for license information. 6b56d1ec6SPeter Hawkins // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7b56d1ec6SPeter Hawkins // 8b56d1ec6SPeter Hawkins //===----------------------------------------------------------------------===// 9b56d1ec6SPeter Hawkins 10b56d1ec6SPeter Hawkins #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 11b56d1ec6SPeter Hawkins #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 12b56d1ec6SPeter Hawkins 13b56d1ec6SPeter Hawkins #include "mlir-c/Support.h" 14*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h" 15b56d1ec6SPeter Hawkins #include "llvm/ADT/STLExtras.h" 16b56d1ec6SPeter Hawkins #include "llvm/ADT/Twine.h" 17b56d1ec6SPeter Hawkins #include "llvm/Support/DataTypes.h" 18b56d1ec6SPeter Hawkins 19b56d1ec6SPeter Hawkins template <> 20b56d1ec6SPeter Hawkins struct std::iterator_traits<nanobind::detail::fast_iterator> { 21b56d1ec6SPeter Hawkins using value_type = nanobind::handle; 22b56d1ec6SPeter Hawkins using reference = const value_type; 23b56d1ec6SPeter Hawkins using pointer = void; 24b56d1ec6SPeter Hawkins using difference_type = std::ptrdiff_t; 25b56d1ec6SPeter Hawkins using iterator_category = std::forward_iterator_tag; 26b56d1ec6SPeter Hawkins }; 27b56d1ec6SPeter Hawkins 28b56d1ec6SPeter Hawkins namespace mlir { 29b56d1ec6SPeter Hawkins namespace python { 30b56d1ec6SPeter Hawkins 31b56d1ec6SPeter Hawkins /// CRTP template for special wrapper types that are allowed to be passed in as 32b56d1ec6SPeter Hawkins /// 'None' function arguments and can be resolved by some global mechanic if 33b56d1ec6SPeter Hawkins /// so. Such types will raise an error if this global resolution fails, and 34b56d1ec6SPeter Hawkins /// it is actually illegal for them to ever be unresolved. From a user 35b56d1ec6SPeter Hawkins /// perspective, they behave like a smart ptr to the underlying type (i.e. 36b56d1ec6SPeter Hawkins /// 'get' method and operator-> overloaded). 37b56d1ec6SPeter Hawkins /// 38b56d1ec6SPeter Hawkins /// Derived types must provide a method, which is called when an environmental 39b56d1ec6SPeter Hawkins /// resolution is required. It must raise an exception if resolution fails: 40b56d1ec6SPeter Hawkins /// static ReferrentTy &resolve() 41b56d1ec6SPeter Hawkins /// 42b56d1ec6SPeter Hawkins /// They must also provide a parameter description that will be used in 43b56d1ec6SPeter Hawkins /// error messages about mismatched types: 44b56d1ec6SPeter Hawkins /// static constexpr const char kTypeDescription[] = "<Description>"; 45b56d1ec6SPeter Hawkins 46b56d1ec6SPeter Hawkins template <typename DerivedTy, typename T> 47b56d1ec6SPeter Hawkins class Defaulting { 48b56d1ec6SPeter Hawkins public: 49b56d1ec6SPeter Hawkins using ReferrentTy = T; 50b56d1ec6SPeter Hawkins /// Type casters require the type to be default constructible, but using 51b56d1ec6SPeter Hawkins /// such an instance is illegal. 52b56d1ec6SPeter Hawkins Defaulting() = default; 53b56d1ec6SPeter Hawkins Defaulting(ReferrentTy &referrent) : referrent(&referrent) {} 54b56d1ec6SPeter Hawkins 55b56d1ec6SPeter Hawkins ReferrentTy *get() const { return referrent; } 56b56d1ec6SPeter Hawkins ReferrentTy *operator->() { return referrent; } 57b56d1ec6SPeter Hawkins 58b56d1ec6SPeter Hawkins private: 59b56d1ec6SPeter Hawkins ReferrentTy *referrent = nullptr; 60b56d1ec6SPeter Hawkins }; 61b56d1ec6SPeter Hawkins 62b56d1ec6SPeter Hawkins } // namespace python 63b56d1ec6SPeter Hawkins } // namespace mlir 64b56d1ec6SPeter Hawkins 65b56d1ec6SPeter Hawkins namespace nanobind { 66b56d1ec6SPeter Hawkins namespace detail { 67b56d1ec6SPeter Hawkins 68b56d1ec6SPeter Hawkins template <typename DefaultingTy> 69b56d1ec6SPeter Hawkins struct MlirDefaultingCaster { 70*5cd42747SPeter Hawkins NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription)) 71b56d1ec6SPeter Hawkins 72b56d1ec6SPeter Hawkins bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { 73b56d1ec6SPeter Hawkins if (src.is_none()) { 74b56d1ec6SPeter Hawkins // Note that we do want an exception to propagate from here as it will be 75b56d1ec6SPeter Hawkins // the most informative. 76b56d1ec6SPeter Hawkins value = DefaultingTy{DefaultingTy::resolve()}; 77b56d1ec6SPeter Hawkins return true; 78b56d1ec6SPeter Hawkins } 79b56d1ec6SPeter Hawkins 80b56d1ec6SPeter Hawkins // Unlike many casters that chain, these casters are expected to always 81b56d1ec6SPeter Hawkins // succeed, so instead of doing an isinstance check followed by a cast, 82b56d1ec6SPeter Hawkins // just cast in one step and handle the exception. Returning false (vs 83b56d1ec6SPeter Hawkins // letting the exception propagate) causes higher level signature parsing 84b56d1ec6SPeter Hawkins // code to produce nice error messages (other than "Cannot cast..."). 85b56d1ec6SPeter Hawkins try { 86b56d1ec6SPeter Hawkins value = DefaultingTy{ 87b56d1ec6SPeter Hawkins nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)}; 88b56d1ec6SPeter Hawkins return true; 89b56d1ec6SPeter Hawkins } catch (std::exception &) { 90b56d1ec6SPeter Hawkins return false; 91b56d1ec6SPeter Hawkins } 92b56d1ec6SPeter Hawkins } 93b56d1ec6SPeter Hawkins 94b56d1ec6SPeter Hawkins static handle from_cpp(DefaultingTy src, rv_policy policy, 95b56d1ec6SPeter Hawkins cleanup_list *cleanup) noexcept { 96b56d1ec6SPeter Hawkins return nanobind::cast(src, policy); 97b56d1ec6SPeter Hawkins } 98b56d1ec6SPeter Hawkins }; 99b56d1ec6SPeter Hawkins } // namespace detail 100b56d1ec6SPeter Hawkins } // namespace nanobind 101b56d1ec6SPeter Hawkins 102b56d1ec6SPeter Hawkins //------------------------------------------------------------------------------ 103b56d1ec6SPeter Hawkins // Conversion utilities. 104b56d1ec6SPeter Hawkins //------------------------------------------------------------------------------ 105b56d1ec6SPeter Hawkins 106b56d1ec6SPeter Hawkins namespace mlir { 107b56d1ec6SPeter Hawkins 108b56d1ec6SPeter Hawkins /// Accumulates into a python string from a method that accepts an 109b56d1ec6SPeter Hawkins /// MlirStringCallback. 110b56d1ec6SPeter Hawkins struct PyPrintAccumulator { 111b56d1ec6SPeter Hawkins nanobind::list parts; 112b56d1ec6SPeter Hawkins 113b56d1ec6SPeter Hawkins void *getUserData() { return this; } 114b56d1ec6SPeter Hawkins 115b56d1ec6SPeter Hawkins MlirStringCallback getCallback() { 116b56d1ec6SPeter Hawkins return [](MlirStringRef part, void *userData) { 117b56d1ec6SPeter Hawkins PyPrintAccumulator *printAccum = 118b56d1ec6SPeter Hawkins static_cast<PyPrintAccumulator *>(userData); 119b56d1ec6SPeter Hawkins nanobind::str pyPart(part.data, 120b56d1ec6SPeter Hawkins part.length); // Decodes as UTF-8 by default. 121b56d1ec6SPeter Hawkins printAccum->parts.append(std::move(pyPart)); 122b56d1ec6SPeter Hawkins }; 123b56d1ec6SPeter Hawkins } 124b56d1ec6SPeter Hawkins 125b56d1ec6SPeter Hawkins nanobind::str join() { 126b56d1ec6SPeter Hawkins nanobind::str delim("", 0); 127b56d1ec6SPeter Hawkins return nanobind::cast<nanobind::str>(delim.attr("join")(parts)); 128b56d1ec6SPeter Hawkins } 129b56d1ec6SPeter Hawkins }; 130b56d1ec6SPeter Hawkins 131b56d1ec6SPeter Hawkins /// Accumulates int a python file-like object, either writing text (default) 132b56d1ec6SPeter Hawkins /// or binary. 133b56d1ec6SPeter Hawkins class PyFileAccumulator { 134b56d1ec6SPeter Hawkins public: 135b56d1ec6SPeter Hawkins PyFileAccumulator(const nanobind::object &fileObject, bool binary) 136b56d1ec6SPeter Hawkins : pyWriteFunction(fileObject.attr("write")), binary(binary) {} 137b56d1ec6SPeter Hawkins 138b56d1ec6SPeter Hawkins void *getUserData() { return this; } 139b56d1ec6SPeter Hawkins 140b56d1ec6SPeter Hawkins MlirStringCallback getCallback() { 141b56d1ec6SPeter Hawkins return [](MlirStringRef part, void *userData) { 142b56d1ec6SPeter Hawkins nanobind::gil_scoped_acquire acquire; 143b56d1ec6SPeter Hawkins PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); 144b56d1ec6SPeter Hawkins if (accum->binary) { 145b56d1ec6SPeter Hawkins // Note: Still has to copy and not avoidable with this API. 146b56d1ec6SPeter Hawkins nanobind::bytes pyBytes(part.data, part.length); 147b56d1ec6SPeter Hawkins accum->pyWriteFunction(pyBytes); 148b56d1ec6SPeter Hawkins } else { 149b56d1ec6SPeter Hawkins nanobind::str pyStr(part.data, 150b56d1ec6SPeter Hawkins part.length); // Decodes as UTF-8 by default. 151b56d1ec6SPeter Hawkins accum->pyWriteFunction(pyStr); 152b56d1ec6SPeter Hawkins } 153b56d1ec6SPeter Hawkins }; 154b56d1ec6SPeter Hawkins } 155b56d1ec6SPeter Hawkins 156b56d1ec6SPeter Hawkins private: 157b56d1ec6SPeter Hawkins nanobind::object pyWriteFunction; 158b56d1ec6SPeter Hawkins bool binary; 159b56d1ec6SPeter Hawkins }; 160b56d1ec6SPeter Hawkins 161b56d1ec6SPeter Hawkins /// Accumulates into a python string from a method that is expected to make 162b56d1ec6SPeter Hawkins /// one (no more, no less) call to the callback (asserts internally on 163b56d1ec6SPeter Hawkins /// violation). 164b56d1ec6SPeter Hawkins struct PySinglePartStringAccumulator { 165b56d1ec6SPeter Hawkins void *getUserData() { return this; } 166b56d1ec6SPeter Hawkins 167b56d1ec6SPeter Hawkins MlirStringCallback getCallback() { 168b56d1ec6SPeter Hawkins return [](MlirStringRef part, void *userData) { 169b56d1ec6SPeter Hawkins PySinglePartStringAccumulator *accum = 170b56d1ec6SPeter Hawkins static_cast<PySinglePartStringAccumulator *>(userData); 171b56d1ec6SPeter Hawkins assert(!accum->invoked && 172b56d1ec6SPeter Hawkins "PySinglePartStringAccumulator called back multiple times"); 173b56d1ec6SPeter Hawkins accum->invoked = true; 174b56d1ec6SPeter Hawkins accum->value = nanobind::str(part.data, part.length); 175b56d1ec6SPeter Hawkins }; 176b56d1ec6SPeter Hawkins } 177b56d1ec6SPeter Hawkins 178b56d1ec6SPeter Hawkins nanobind::str takeValue() { 179b56d1ec6SPeter Hawkins assert(invoked && "PySinglePartStringAccumulator not called back"); 180b56d1ec6SPeter Hawkins return std::move(value); 181b56d1ec6SPeter Hawkins } 182b56d1ec6SPeter Hawkins 183b56d1ec6SPeter Hawkins private: 184b56d1ec6SPeter Hawkins nanobind::str value; 185b56d1ec6SPeter Hawkins bool invoked = false; 186b56d1ec6SPeter Hawkins }; 187b56d1ec6SPeter Hawkins 188b56d1ec6SPeter Hawkins /// A CRTP base class for pseudo-containers willing to support Python-type 189b56d1ec6SPeter Hawkins /// slicing access on top of indexed access. Calling ::bind on this class 190b56d1ec6SPeter Hawkins /// will define `__len__` as well as `__getitem__` with integer and slice 191b56d1ec6SPeter Hawkins /// arguments. 192b56d1ec6SPeter Hawkins /// 193b56d1ec6SPeter Hawkins /// This is intended for pseudo-containers that can refer to arbitrary slices of 194b56d1ec6SPeter Hawkins /// underlying storage indexed by a single integer. Indexing those with an 195b56d1ec6SPeter Hawkins /// integer produces an instance of ElementTy. Indexing those with a slice 196b56d1ec6SPeter Hawkins /// produces a new instance of Derived, which can be sliced further. 197b56d1ec6SPeter Hawkins /// 198b56d1ec6SPeter Hawkins /// A derived class must provide the following: 199b56d1ec6SPeter Hawkins /// - a `static const char *pyClassName ` field containing the name of the 200b56d1ec6SPeter Hawkins /// Python class to bind; 201b56d1ec6SPeter Hawkins /// - an instance method `intptr_t getRawNumElements()` that returns the 202b56d1ec6SPeter Hawkins /// number 203b56d1ec6SPeter Hawkins /// of elements in the backing container (NOT that of the slice); 204b56d1ec6SPeter Hawkins /// - an instance method `ElementTy getRawElement(intptr_t)` that returns a 205b56d1ec6SPeter Hawkins /// single element at the given linear index (NOT slice index); 206b56d1ec6SPeter Hawkins /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that 207b56d1ec6SPeter Hawkins /// constructs a new instance of the derived pseudo-container with the 208b56d1ec6SPeter Hawkins /// given slice parameters (to be forwarded to the Sliceable constructor). 209b56d1ec6SPeter Hawkins /// 210b56d1ec6SPeter Hawkins /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not 211b56d1ec6SPeter Hawkins /// throw. 212b56d1ec6SPeter Hawkins /// 213b56d1ec6SPeter Hawkins /// A derived class may additionally define: 214b56d1ec6SPeter Hawkins /// - a `static void bindDerived(ClassTy &)` method to bind additional methods 215b56d1ec6SPeter Hawkins /// the python class. 216b56d1ec6SPeter Hawkins template <typename Derived, typename ElementTy> 217b56d1ec6SPeter Hawkins class Sliceable { 218b56d1ec6SPeter Hawkins protected: 219b56d1ec6SPeter Hawkins using ClassTy = nanobind::class_<Derived>; 220b56d1ec6SPeter Hawkins 221b56d1ec6SPeter Hawkins /// Transforms `index` into a legal value to access the underlying sequence. 222b56d1ec6SPeter Hawkins /// Returns <0 on failure. 223b56d1ec6SPeter Hawkins intptr_t wrapIndex(intptr_t index) { 224b56d1ec6SPeter Hawkins if (index < 0) 225b56d1ec6SPeter Hawkins index = length + index; 226b56d1ec6SPeter Hawkins if (index < 0 || index >= length) 227b56d1ec6SPeter Hawkins return -1; 228b56d1ec6SPeter Hawkins return index; 229b56d1ec6SPeter Hawkins } 230b56d1ec6SPeter Hawkins 231b56d1ec6SPeter Hawkins /// Computes the linear index given the current slice properties. 232b56d1ec6SPeter Hawkins intptr_t linearizeIndex(intptr_t index) { 233b56d1ec6SPeter Hawkins intptr_t linearIndex = index * step + startIndex; 234b56d1ec6SPeter Hawkins assert(linearIndex >= 0 && 235b56d1ec6SPeter Hawkins linearIndex < static_cast<Derived *>(this)->getRawNumElements() && 236b56d1ec6SPeter Hawkins "linear index out of bounds, the slice is ill-formed"); 237b56d1ec6SPeter Hawkins return linearIndex; 238b56d1ec6SPeter Hawkins } 239b56d1ec6SPeter Hawkins 240b56d1ec6SPeter Hawkins /// Trait to check if T provides a `maybeDownCast` method. 241b56d1ec6SPeter Hawkins /// Note, you need the & to detect inherited members. 242b56d1ec6SPeter Hawkins template <typename T, typename... Args> 243b56d1ec6SPeter Hawkins using has_maybe_downcast = decltype(&T::maybeDownCast); 244b56d1ec6SPeter Hawkins 245b56d1ec6SPeter Hawkins /// Returns the element at the given slice index. Supports negative indices 246b56d1ec6SPeter Hawkins /// by taking elements in inverse order. Returns a nullptr object if out 247b56d1ec6SPeter Hawkins /// of bounds. 248b56d1ec6SPeter Hawkins nanobind::object getItem(intptr_t index) { 249b56d1ec6SPeter Hawkins // Negative indices mean we count from the end. 250b56d1ec6SPeter Hawkins index = wrapIndex(index); 251b56d1ec6SPeter Hawkins if (index < 0) { 252b56d1ec6SPeter Hawkins PyErr_SetString(PyExc_IndexError, "index out of range"); 253b56d1ec6SPeter Hawkins return {}; 254b56d1ec6SPeter Hawkins } 255b56d1ec6SPeter Hawkins 256b56d1ec6SPeter Hawkins if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value) 257b56d1ec6SPeter Hawkins return static_cast<Derived *>(this) 258b56d1ec6SPeter Hawkins ->getRawElement(linearizeIndex(index)) 259b56d1ec6SPeter Hawkins .maybeDownCast(); 260b56d1ec6SPeter Hawkins else 261b56d1ec6SPeter Hawkins return nanobind::cast( 262b56d1ec6SPeter Hawkins static_cast<Derived *>(this)->getRawElement(linearizeIndex(index))); 263b56d1ec6SPeter Hawkins } 264b56d1ec6SPeter Hawkins 265b56d1ec6SPeter Hawkins /// Returns a new instance of the pseudo-container restricted to the given 266b56d1ec6SPeter Hawkins /// slice. Returns a nullptr object on failure. 267b56d1ec6SPeter Hawkins nanobind::object getItemSlice(PyObject *slice) { 268b56d1ec6SPeter Hawkins ssize_t start, stop, extraStep, sliceLength; 269b56d1ec6SPeter Hawkins if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, 270b56d1ec6SPeter Hawkins &sliceLength) != 0) { 271b56d1ec6SPeter Hawkins PyErr_SetString(PyExc_IndexError, "index out of range"); 272b56d1ec6SPeter Hawkins return {}; 273b56d1ec6SPeter Hawkins } 274b56d1ec6SPeter Hawkins return nanobind::cast(static_cast<Derived *>(this)->slice( 275b56d1ec6SPeter Hawkins startIndex + start * step, sliceLength, step * extraStep)); 276b56d1ec6SPeter Hawkins } 277b56d1ec6SPeter Hawkins 278b56d1ec6SPeter Hawkins public: 279b56d1ec6SPeter Hawkins explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) 280b56d1ec6SPeter Hawkins : startIndex(startIndex), length(length), step(step) { 281b56d1ec6SPeter Hawkins assert(length >= 0 && "expected non-negative slice length"); 282b56d1ec6SPeter Hawkins } 283b56d1ec6SPeter Hawkins 284b56d1ec6SPeter Hawkins /// Returns the `index`-th element in the slice, supports negative indices. 285b56d1ec6SPeter Hawkins /// Throws if the index is out of bounds. 286b56d1ec6SPeter Hawkins ElementTy getElement(intptr_t index) { 287b56d1ec6SPeter Hawkins // Negative indices mean we count from the end. 288b56d1ec6SPeter Hawkins index = wrapIndex(index); 289b56d1ec6SPeter Hawkins if (index < 0) { 290b56d1ec6SPeter Hawkins throw nanobind::index_error("index out of range"); 291b56d1ec6SPeter Hawkins } 292b56d1ec6SPeter Hawkins 293b56d1ec6SPeter Hawkins return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)); 294b56d1ec6SPeter Hawkins } 295b56d1ec6SPeter Hawkins 296b56d1ec6SPeter Hawkins /// Returns the size of slice. 297b56d1ec6SPeter Hawkins intptr_t size() { return length; } 298b56d1ec6SPeter Hawkins 299b56d1ec6SPeter Hawkins /// Returns a new vector (mapped to Python list) containing elements from two 300b56d1ec6SPeter Hawkins /// slices. The new vector is necessary because slices may not be contiguous 301b56d1ec6SPeter Hawkins /// or even come from the same original sequence. 302b56d1ec6SPeter Hawkins std::vector<ElementTy> dunderAdd(Derived &other) { 303b56d1ec6SPeter Hawkins std::vector<ElementTy> elements; 304b56d1ec6SPeter Hawkins elements.reserve(length + other.length); 305b56d1ec6SPeter Hawkins for (intptr_t i = 0; i < length; ++i) { 306b56d1ec6SPeter Hawkins elements.push_back(static_cast<Derived *>(this)->getElement(i)); 307b56d1ec6SPeter Hawkins } 308b56d1ec6SPeter Hawkins for (intptr_t i = 0; i < other.length; ++i) { 309b56d1ec6SPeter Hawkins elements.push_back(static_cast<Derived *>(&other)->getElement(i)); 310b56d1ec6SPeter Hawkins } 311b56d1ec6SPeter Hawkins return elements; 312b56d1ec6SPeter Hawkins } 313b56d1ec6SPeter Hawkins 314b56d1ec6SPeter Hawkins /// Binds the indexing and length methods in the Python class. 315b56d1ec6SPeter Hawkins static void bind(nanobind::module_ &m) { 316b56d1ec6SPeter Hawkins auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName) 317b56d1ec6SPeter Hawkins .def("__add__", &Sliceable::dunderAdd); 318b56d1ec6SPeter Hawkins Derived::bindDerived(clazz); 319b56d1ec6SPeter Hawkins 320b56d1ec6SPeter Hawkins // Manually implement the sequence protocol via the C API. We do this 321b56d1ec6SPeter Hawkins // because it is approx 4x faster than via nanobind, largely because that 322b56d1ec6SPeter Hawkins // formulation requires a C++ exception to be thrown to detect end of 323b56d1ec6SPeter Hawkins // sequence. 324b56d1ec6SPeter Hawkins // Since we are in a C-context, any C++ exception that happens here 325b56d1ec6SPeter Hawkins // will terminate the program. There is nothing in this implementation 326b56d1ec6SPeter Hawkins // that should throw in a non-terminal way, so we forgo further 327b56d1ec6SPeter Hawkins // exception marshalling. 328b56d1ec6SPeter Hawkins // See: https://github.com/pybind/nanobind/issues/2842 329b56d1ec6SPeter Hawkins auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr()); 330b56d1ec6SPeter Hawkins assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && 331b56d1ec6SPeter Hawkins "must be heap type"); 332b56d1ec6SPeter Hawkins heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { 333b56d1ec6SPeter Hawkins auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); 334b56d1ec6SPeter Hawkins return self->length; 335b56d1ec6SPeter Hawkins }; 336b56d1ec6SPeter Hawkins // sq_item is called as part of the sequence protocol for iteration, 337b56d1ec6SPeter Hawkins // list construction, etc. 338b56d1ec6SPeter Hawkins heap_type->as_sequence.sq_item = 339b56d1ec6SPeter Hawkins +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { 340b56d1ec6SPeter Hawkins auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); 341b56d1ec6SPeter Hawkins return self->getItem(index).release().ptr(); 342b56d1ec6SPeter Hawkins }; 343b56d1ec6SPeter Hawkins // mp_subscript is used for both slices and integer lookups. 344b56d1ec6SPeter Hawkins heap_type->as_mapping.mp_subscript = 345b56d1ec6SPeter Hawkins +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { 346b56d1ec6SPeter Hawkins auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); 347b56d1ec6SPeter Hawkins Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); 348b56d1ec6SPeter Hawkins if (!PyErr_Occurred()) { 349b56d1ec6SPeter Hawkins // Integer indexing. 350b56d1ec6SPeter Hawkins return self->getItem(index).release().ptr(); 351b56d1ec6SPeter Hawkins } 352b56d1ec6SPeter Hawkins PyErr_Clear(); 353b56d1ec6SPeter Hawkins 354b56d1ec6SPeter Hawkins // Assume slice-based indexing. 355b56d1ec6SPeter Hawkins if (PySlice_Check(rawSubscript)) { 356b56d1ec6SPeter Hawkins return self->getItemSlice(rawSubscript).release().ptr(); 357b56d1ec6SPeter Hawkins } 358b56d1ec6SPeter Hawkins 359b56d1ec6SPeter Hawkins PyErr_SetString(PyExc_ValueError, "expected integer or slice"); 360b56d1ec6SPeter Hawkins return nullptr; 361b56d1ec6SPeter Hawkins }; 362b56d1ec6SPeter Hawkins } 363b56d1ec6SPeter Hawkins 364b56d1ec6SPeter Hawkins /// Hook for derived classes willing to bind more methods. 365b56d1ec6SPeter Hawkins static void bindDerived(ClassTy &) {} 366b56d1ec6SPeter Hawkins 367b56d1ec6SPeter Hawkins private: 368b56d1ec6SPeter Hawkins intptr_t startIndex; 369b56d1ec6SPeter Hawkins intptr_t length; 370b56d1ec6SPeter Hawkins intptr_t step; 371b56d1ec6SPeter Hawkins }; 372b56d1ec6SPeter Hawkins 373b56d1ec6SPeter Hawkins } // namespace mlir 374b56d1ec6SPeter Hawkins 375b56d1ec6SPeter Hawkins namespace llvm { 376b56d1ec6SPeter Hawkins 377b56d1ec6SPeter Hawkins template <> 378b56d1ec6SPeter Hawkins struct DenseMapInfo<MlirTypeID> { 379b56d1ec6SPeter Hawkins static inline MlirTypeID getEmptyKey() { 380b56d1ec6SPeter Hawkins auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); 381b56d1ec6SPeter Hawkins return mlirTypeIDCreate(pointer); 382b56d1ec6SPeter Hawkins } 383b56d1ec6SPeter Hawkins static inline MlirTypeID getTombstoneKey() { 384b56d1ec6SPeter Hawkins auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); 385b56d1ec6SPeter Hawkins return mlirTypeIDCreate(pointer); 386b56d1ec6SPeter Hawkins } 387b56d1ec6SPeter Hawkins static inline unsigned getHashValue(const MlirTypeID &val) { 388b56d1ec6SPeter Hawkins return mlirTypeIDHashValue(val); 389b56d1ec6SPeter Hawkins } 390b56d1ec6SPeter Hawkins static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) { 391b56d1ec6SPeter Hawkins return mlirTypeIDEqual(lhs, rhs); 392b56d1ec6SPeter Hawkins } 393b56d1ec6SPeter Hawkins }; 394b56d1ec6SPeter Hawkins } // namespace llvm 395b56d1ec6SPeter Hawkins 396b56d1ec6SPeter Hawkins #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 397