114c92070SAlex Zinenko //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// 214c92070SAlex Zinenko // 314c92070SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 414c92070SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 514c92070SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 614c92070SAlex Zinenko // 714c92070SAlex Zinenko //===----------------------------------------------------------------------===// 814c92070SAlex Zinenko 9dc81dfa0SMehdi Amini #include <cstdint> 10a1fe1f5fSKazu Hirata #include <optional> 11dc81dfa0SMehdi Amini #include <string> 12f22008edSArash Taheri-Dezfouli #include <utility> 13dc81dfa0SMehdi Amini #include <vector> 141fc096afSMehdi Amini 1514c92070SAlex Zinenko #include "IRModule.h" 1614c92070SAlex Zinenko #include "mlir-c/BuiltinAttributes.h" 17dc81dfa0SMehdi Amini #include "mlir-c/IR.h" 1814c92070SAlex Zinenko #include "mlir-c/Interfaces.h" 19dc81dfa0SMehdi Amini #include "mlir-c/Support.h" 20*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h" 21ee308c99SJacques Pienaar #include "llvm/ADT/STLExtras.h" 22dc81dfa0SMehdi Amini #include "llvm/ADT/SmallVector.h" 2314c92070SAlex Zinenko 24b56d1ec6SPeter Hawkins namespace nb = nanobind; 2514c92070SAlex Zinenko 2614c92070SAlex Zinenko namespace mlir { 2714c92070SAlex Zinenko namespace python { 2814c92070SAlex Zinenko 2914c92070SAlex Zinenko constexpr static const char *constructorDoc = 3014c92070SAlex Zinenko R"(Creates an interface from a given operation/opview object or from a 3114c92070SAlex Zinenko subclass of OpView. Raises ValueError if the operation does not implement the 3214c92070SAlex Zinenko interface.)"; 3314c92070SAlex Zinenko 3414c92070SAlex Zinenko constexpr static const char *operationDoc = 3514c92070SAlex Zinenko R"(Returns an Operation for which the interface was constructed.)"; 3614c92070SAlex Zinenko 3714c92070SAlex Zinenko constexpr static const char *opviewDoc = 3814c92070SAlex Zinenko R"(Returns an OpView subclass _instance_ for which the interface was 3914c92070SAlex Zinenko constructed)"; 4014c92070SAlex Zinenko 4114c92070SAlex Zinenko constexpr static const char *inferReturnTypesDoc = 4214c92070SAlex Zinenko R"(Given the arguments required to build an operation, attempts to infer 4314c92070SAlex Zinenko its return types. Raises ValueError on failure.)"; 4414c92070SAlex Zinenko 45f22008edSArash Taheri-Dezfouli constexpr static const char *inferReturnTypeComponentsDoc = 46f22008edSArash Taheri-Dezfouli R"(Given the arguments required to build an operation, attempts to infer 47f22008edSArash Taheri-Dezfouli its return shaped type components. Raises ValueError on failure.)"; 48f22008edSArash Taheri-Dezfouli 49f22008edSArash Taheri-Dezfouli namespace { 50f22008edSArash Taheri-Dezfouli 51f22008edSArash Taheri-Dezfouli /// Takes in an optional ist of operands and converts them into a SmallVector 52f22008edSArash Taheri-Dezfouli /// of MlirVlaues. Returns an empty SmallVector if the list is empty. 53b56d1ec6SPeter Hawkins llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) { 54f22008edSArash Taheri-Dezfouli llvm::SmallVector<MlirValue> mlirOperands; 55f22008edSArash Taheri-Dezfouli 56b56d1ec6SPeter Hawkins if (!operandList || operandList->size() == 0) { 57f22008edSArash Taheri-Dezfouli return mlirOperands; 58f22008edSArash Taheri-Dezfouli } 59f22008edSArash Taheri-Dezfouli 60f22008edSArash Taheri-Dezfouli // Note: as the list may contain other lists this may not be final size. 61f22008edSArash Taheri-Dezfouli mlirOperands.reserve(operandList->size()); 62f22008edSArash Taheri-Dezfouli for (const auto &&it : llvm::enumerate(*operandList)) { 63e0ca7e99Smax if (it.value().is_none()) 64e0ca7e99Smax continue; 65e0ca7e99Smax 66f22008edSArash Taheri-Dezfouli PyValue *val; 67f22008edSArash Taheri-Dezfouli try { 68b56d1ec6SPeter Hawkins val = nb::cast<PyValue *>(it.value()); 69f22008edSArash Taheri-Dezfouli if (!val) 70b56d1ec6SPeter Hawkins throw nb::cast_error(); 71f22008edSArash Taheri-Dezfouli mlirOperands.push_back(val->get()); 72f22008edSArash Taheri-Dezfouli continue; 73b56d1ec6SPeter Hawkins } catch (nb::cast_error &err) { 74f22008edSArash Taheri-Dezfouli // Intentionally unhandled to try sequence below first. 75f22008edSArash Taheri-Dezfouli (void)err; 76f22008edSArash Taheri-Dezfouli } 77f22008edSArash Taheri-Dezfouli 78f22008edSArash Taheri-Dezfouli try { 79b56d1ec6SPeter Hawkins auto vals = nb::cast<nb::sequence>(it.value()); 80b56d1ec6SPeter Hawkins for (nb::handle v : vals) { 81f22008edSArash Taheri-Dezfouli try { 82b56d1ec6SPeter Hawkins val = nb::cast<PyValue *>(v); 83f22008edSArash Taheri-Dezfouli if (!val) 84b56d1ec6SPeter Hawkins throw nb::cast_error(); 85f22008edSArash Taheri-Dezfouli mlirOperands.push_back(val->get()); 86b56d1ec6SPeter Hawkins } catch (nb::cast_error &err) { 87b56d1ec6SPeter Hawkins throw nb::value_error( 88f22008edSArash Taheri-Dezfouli (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 89f22008edSArash Taheri-Dezfouli " must be a Value or Sequence of Values (" + err.what() + ")") 90b56d1ec6SPeter Hawkins .str() 91b56d1ec6SPeter Hawkins .c_str()); 92f22008edSArash Taheri-Dezfouli } 93f22008edSArash Taheri-Dezfouli } 94f22008edSArash Taheri-Dezfouli continue; 95b56d1ec6SPeter Hawkins } catch (nb::cast_error &err) { 96b56d1ec6SPeter Hawkins throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + 97f22008edSArash Taheri-Dezfouli " must be a Value or Sequence of Values (" + 98f22008edSArash Taheri-Dezfouli err.what() + ")") 99b56d1ec6SPeter Hawkins .str() 100b56d1ec6SPeter Hawkins .c_str()); 101f22008edSArash Taheri-Dezfouli } 102f22008edSArash Taheri-Dezfouli 103b56d1ec6SPeter Hawkins throw nb::cast_error(); 104f22008edSArash Taheri-Dezfouli } 105f22008edSArash Taheri-Dezfouli 106f22008edSArash Taheri-Dezfouli return mlirOperands; 107f22008edSArash Taheri-Dezfouli } 108f22008edSArash Taheri-Dezfouli 109f22008edSArash Taheri-Dezfouli /// Takes in an optional vector of PyRegions and returns a SmallVector of 110f22008edSArash Taheri-Dezfouli /// MlirRegion. Returns an empty SmallVector if the list is empty. 111f22008edSArash Taheri-Dezfouli llvm::SmallVector<MlirRegion> 112f22008edSArash Taheri-Dezfouli wrapRegions(std::optional<std::vector<PyRegion>> regions) { 113f22008edSArash Taheri-Dezfouli llvm::SmallVector<MlirRegion> mlirRegions; 114f22008edSArash Taheri-Dezfouli 115f22008edSArash Taheri-Dezfouli if (regions) { 116f22008edSArash Taheri-Dezfouli mlirRegions.reserve(regions->size()); 117f22008edSArash Taheri-Dezfouli for (PyRegion ®ion : *regions) { 118f22008edSArash Taheri-Dezfouli mlirRegions.push_back(region); 119f22008edSArash Taheri-Dezfouli } 120f22008edSArash Taheri-Dezfouli } 121f22008edSArash Taheri-Dezfouli 122f22008edSArash Taheri-Dezfouli return mlirRegions; 123f22008edSArash Taheri-Dezfouli } 124f22008edSArash Taheri-Dezfouli 125f22008edSArash Taheri-Dezfouli } // namespace 126f22008edSArash Taheri-Dezfouli 12714c92070SAlex Zinenko /// CRTP base class for Python classes representing MLIR Op interfaces. 12814c92070SAlex Zinenko /// Interface hierarchies are flat so no base class is expected here. The 12914c92070SAlex Zinenko /// derived class is expected to define the following static fields: 13014c92070SAlex Zinenko /// - `const char *pyClassName` - the name of the Python class to create; 13114c92070SAlex Zinenko /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID 13214c92070SAlex Zinenko /// of the interface. 13314c92070SAlex Zinenko /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind 13414c92070SAlex Zinenko /// interface-specific methods. 13514c92070SAlex Zinenko /// 13614c92070SAlex Zinenko /// An interface class may be constructed from either an Operation/OpView object 13714c92070SAlex Zinenko /// or from a subclass of OpView. In the latter case, only the static interface 13814c92070SAlex Zinenko /// methods are available, similarly to calling ConcereteOp::staticMethod on the 13914c92070SAlex Zinenko /// C++ side. Implementations of concrete interfaces can use the `isStatic` 14014c92070SAlex Zinenko /// method to check whether the interface object was constructed from a class or 14114c92070SAlex Zinenko /// an operation/opview instance. The `getOpName` always succeeds and returns a 14214c92070SAlex Zinenko /// canonical name of the operation suitable for lookups. 14314c92070SAlex Zinenko template <typename ConcreteIface> 14414c92070SAlex Zinenko class PyConcreteOpInterface { 14514c92070SAlex Zinenko protected: 146b56d1ec6SPeter Hawkins using ClassTy = nb::class_<ConcreteIface>; 14714c92070SAlex Zinenko using GetTypeIDFunctionTy = MlirTypeID (*)(); 14814c92070SAlex Zinenko 14914c92070SAlex Zinenko public: 15014c92070SAlex Zinenko /// Constructs an interface instance from an object that is either an 15114c92070SAlex Zinenko /// operation or a subclass of OpView. In the latter case, only the static 15214c92070SAlex Zinenko /// methods of the interface are accessible to the caller. 153b56d1ec6SPeter Hawkins PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context) 1541fc096afSMehdi Amini : obj(std::move(object)) { 15514c92070SAlex Zinenko try { 156b56d1ec6SPeter Hawkins operation = &nb::cast<PyOperation &>(obj); 157b56d1ec6SPeter Hawkins } catch (nb::cast_error &) { 15814c92070SAlex Zinenko // Do nothing. 15914c92070SAlex Zinenko } 16014c92070SAlex Zinenko 16114c92070SAlex Zinenko try { 162b56d1ec6SPeter Hawkins operation = &nb::cast<PyOpView &>(obj).getOperation(); 163b56d1ec6SPeter Hawkins } catch (nb::cast_error &) { 16414c92070SAlex Zinenko // Do nothing. 16514c92070SAlex Zinenko } 16614c92070SAlex Zinenko 16714c92070SAlex Zinenko if (operation != nullptr) { 16814c92070SAlex Zinenko if (!mlirOperationImplementsInterface(*operation, 16914c92070SAlex Zinenko ConcreteIface::getInterfaceID())) { 17014c92070SAlex Zinenko std::string msg = "the operation does not implement "; 171b56d1ec6SPeter Hawkins throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); 17214c92070SAlex Zinenko } 17314c92070SAlex Zinenko 17414c92070SAlex Zinenko MlirIdentifier identifier = mlirOperationGetName(*operation); 17514c92070SAlex Zinenko MlirStringRef stringRef = mlirIdentifierStr(identifier); 17614c92070SAlex Zinenko opName = std::string(stringRef.data, stringRef.length); 17714c92070SAlex Zinenko } else { 17814c92070SAlex Zinenko try { 179b56d1ec6SPeter Hawkins opName = nb::cast<std::string>(obj.attr("OPERATION_NAME")); 180b56d1ec6SPeter Hawkins } catch (nb::cast_error &) { 181b56d1ec6SPeter Hawkins throw nb::type_error( 18214c92070SAlex Zinenko "Op interface does not refer to an operation or OpView class"); 18314c92070SAlex Zinenko } 18414c92070SAlex Zinenko 18514c92070SAlex Zinenko if (!mlirOperationImplementsInterfaceStatic( 18614c92070SAlex Zinenko mlirStringRefCreate(opName.data(), opName.length()), 18714c92070SAlex Zinenko context.resolve().get(), ConcreteIface::getInterfaceID())) { 18814c92070SAlex Zinenko std::string msg = "the operation does not implement "; 189b56d1ec6SPeter Hawkins throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); 19014c92070SAlex Zinenko } 19114c92070SAlex Zinenko } 19214c92070SAlex Zinenko } 19314c92070SAlex Zinenko 19414c92070SAlex Zinenko /// Creates the Python bindings for this class in the given module. 195b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 196b56d1ec6SPeter Hawkins nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName); 197b56d1ec6SPeter Hawkins cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"), 198b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), constructorDoc) 199b56d1ec6SPeter Hawkins .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject, 20014c92070SAlex Zinenko operationDoc) 201b56d1ec6SPeter Hawkins .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc); 20214c92070SAlex Zinenko ConcreteIface::bindDerived(cls); 20314c92070SAlex Zinenko } 20414c92070SAlex Zinenko 20514c92070SAlex Zinenko /// Hook for derived classes to add class-specific bindings. 20614c92070SAlex Zinenko static void bindDerived(ClassTy &cls) {} 20714c92070SAlex Zinenko 20814c92070SAlex Zinenko /// Returns `true` if this object was constructed from a subclass of OpView 20914c92070SAlex Zinenko /// rather than from an operation instance. 21014c92070SAlex Zinenko bool isStatic() { return operation == nullptr; } 21114c92070SAlex Zinenko 21214c92070SAlex Zinenko /// Returns the operation instance from which this object was constructed. 21314c92070SAlex Zinenko /// Throws a type error if this object was constructed from a subclass of 21414c92070SAlex Zinenko /// OpView. 215b56d1ec6SPeter Hawkins nb::object getOperationObject() { 21614c92070SAlex Zinenko if (operation == nullptr) { 217b56d1ec6SPeter Hawkins throw nb::type_error("Cannot get an operation from a static interface"); 21814c92070SAlex Zinenko } 21914c92070SAlex Zinenko 22014c92070SAlex Zinenko return operation->getRef().releaseObject(); 22114c92070SAlex Zinenko } 22214c92070SAlex Zinenko 22314c92070SAlex Zinenko /// Returns the opview of the operation instance from which this object was 22414c92070SAlex Zinenko /// constructed. Throws a type error if this object was constructed form a 22514c92070SAlex Zinenko /// subclass of OpView. 226b56d1ec6SPeter Hawkins nb::object getOpView() { 22714c92070SAlex Zinenko if (operation == nullptr) { 228b56d1ec6SPeter Hawkins throw nb::type_error("Cannot get an opview from a static interface"); 22914c92070SAlex Zinenko } 23014c92070SAlex Zinenko 23114c92070SAlex Zinenko return operation->createOpView(); 23214c92070SAlex Zinenko } 23314c92070SAlex Zinenko 23414c92070SAlex Zinenko /// Returns the canonical name of the operation this interface is constructed 23514c92070SAlex Zinenko /// from. 23614c92070SAlex Zinenko const std::string &getOpName() { return opName; } 23714c92070SAlex Zinenko 23814c92070SAlex Zinenko private: 23914c92070SAlex Zinenko PyOperation *operation = nullptr; 24014c92070SAlex Zinenko std::string opName; 241b56d1ec6SPeter Hawkins nb::object obj; 24214c92070SAlex Zinenko }; 24314c92070SAlex Zinenko 244f22008edSArash Taheri-Dezfouli /// Python wrapper for InferTypeOpInterface. This interface has only static 24514c92070SAlex Zinenko /// methods. 24614c92070SAlex Zinenko class PyInferTypeOpInterface 24714c92070SAlex Zinenko : public PyConcreteOpInterface<PyInferTypeOpInterface> { 24814c92070SAlex Zinenko public: 24914c92070SAlex Zinenko using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; 25014c92070SAlex Zinenko 25114c92070SAlex Zinenko constexpr static const char *pyClassName = "InferTypeOpInterface"; 25214c92070SAlex Zinenko constexpr static GetTypeIDFunctionTy getInterfaceID = 25314c92070SAlex Zinenko &mlirInferTypeOpInterfaceTypeID; 25414c92070SAlex Zinenko 25514c92070SAlex Zinenko /// C-style user-data structure for type appending callback. 25614c92070SAlex Zinenko struct AppendResultsCallbackData { 25714c92070SAlex Zinenko std::vector<PyType> &inferredTypes; 25814c92070SAlex Zinenko PyMlirContext &pyMlirContext; 25914c92070SAlex Zinenko }; 26014c92070SAlex Zinenko 26114c92070SAlex Zinenko /// Appends the types provided as the two first arguments to the user-data 26214c92070SAlex Zinenko /// structure (expects AppendResultsCallbackData). 26314c92070SAlex Zinenko static void appendResultsCallback(intptr_t nTypes, MlirType *types, 26414c92070SAlex Zinenko void *userData) { 26514c92070SAlex Zinenko auto *data = static_cast<AppendResultsCallbackData *>(userData); 26614c92070SAlex Zinenko data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); 26714c92070SAlex Zinenko for (intptr_t i = 0; i < nTypes; ++i) { 268e5639b3fSMehdi Amini data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); 26914c92070SAlex Zinenko } 27014c92070SAlex Zinenko } 27114c92070SAlex Zinenko 27214c92070SAlex Zinenko /// Given the arguments required to build an operation, attempts to infer its 273ee308c99SJacques Pienaar /// return types. Throws value_error on failure. 27414c92070SAlex Zinenko std::vector<PyType> 275b56d1ec6SPeter Hawkins inferReturnTypes(std::optional<nb::list> operandList, 2765e118f93SMehdi Amini std::optional<PyAttribute> attributes, void *properties, 2770a81ace0SKazu Hirata std::optional<std::vector<PyRegion>> regions, 27814c92070SAlex Zinenko DefaultingPyMlirContext context, 27914c92070SAlex Zinenko DefaultingPyLocation location) { 28089b0f1eeSMehdi Amini llvm::SmallVector<MlirValue> mlirOperands = 28189b0f1eeSMehdi Amini wrapOperands(std::move(operandList)); 28289b0f1eeSMehdi Amini llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); 28314c92070SAlex Zinenko 28414c92070SAlex Zinenko std::vector<PyType> inferredTypes; 28514c92070SAlex Zinenko PyMlirContext &pyContext = context.resolve(); 28614c92070SAlex Zinenko AppendResultsCallbackData data{inferredTypes, pyContext}; 28714c92070SAlex Zinenko MlirStringRef opNameRef = 28814c92070SAlex Zinenko mlirStringRefCreate(getOpName().data(), getOpName().length()); 28914c92070SAlex Zinenko MlirAttribute attributeDict = 29014c92070SAlex Zinenko attributes ? attributes->get() : mlirAttributeGetNull(); 29114c92070SAlex Zinenko 29214c92070SAlex Zinenko MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( 29314c92070SAlex Zinenko opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 2945e118f93SMehdi Amini mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 29514c92070SAlex Zinenko mlirRegions.data(), &appendResultsCallback, &data); 29614c92070SAlex Zinenko 29714c92070SAlex Zinenko if (mlirLogicalResultIsFailure(result)) { 298b56d1ec6SPeter Hawkins throw nb::value_error("Failed to infer result types"); 29914c92070SAlex Zinenko } 30014c92070SAlex Zinenko 30114c92070SAlex Zinenko return inferredTypes; 30214c92070SAlex Zinenko } 30314c92070SAlex Zinenko 30414c92070SAlex Zinenko static void bindDerived(ClassTy &cls) { 30514c92070SAlex Zinenko cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, 306b56d1ec6SPeter Hawkins nb::arg("operands").none() = nb::none(), 307b56d1ec6SPeter Hawkins nb::arg("attributes").none() = nb::none(), 308b56d1ec6SPeter Hawkins nb::arg("properties").none() = nb::none(), 309b56d1ec6SPeter Hawkins nb::arg("regions").none() = nb::none(), 310b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 311b56d1ec6SPeter Hawkins nb::arg("loc").none() = nb::none(), inferReturnTypesDoc); 31214c92070SAlex Zinenko } 31314c92070SAlex Zinenko }; 31414c92070SAlex Zinenko 315f22008edSArash Taheri-Dezfouli /// Wrapper around an shaped type components. 316f22008edSArash Taheri-Dezfouli class PyShapedTypeComponents { 317f22008edSArash Taheri-Dezfouli public: 318f22008edSArash Taheri-Dezfouli PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} 319b56d1ec6SPeter Hawkins PyShapedTypeComponents(nb::list shape, MlirType elementType) 32089b0f1eeSMehdi Amini : shape(std::move(shape)), elementType(elementType), ranked(true) {} 321b56d1ec6SPeter Hawkins PyShapedTypeComponents(nb::list shape, MlirType elementType, 322f22008edSArash Taheri-Dezfouli MlirAttribute attribute) 32389b0f1eeSMehdi Amini : shape(std::move(shape)), elementType(elementType), attribute(attribute), 324f22008edSArash Taheri-Dezfouli ranked(true) {} 325f22008edSArash Taheri-Dezfouli PyShapedTypeComponents(PyShapedTypeComponents &) = delete; 326ea2e83afSAdrian Kuegel PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept 327f22008edSArash Taheri-Dezfouli : shape(other.shape), elementType(other.elementType), 328f22008edSArash Taheri-Dezfouli attribute(other.attribute), ranked(other.ranked) {} 329f22008edSArash Taheri-Dezfouli 330b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 331b56d1ec6SPeter Hawkins nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents") 332b56d1ec6SPeter Hawkins .def_prop_ro( 333f22008edSArash Taheri-Dezfouli "element_type", 334bfb1ba75Smax [](PyShapedTypeComponents &self) { return self.elementType; }, 335f22008edSArash Taheri-Dezfouli "Returns the element type of the shaped type components.") 336f22008edSArash Taheri-Dezfouli .def_static( 337f22008edSArash Taheri-Dezfouli "get", 338f22008edSArash Taheri-Dezfouli [](PyType &elementType) { 339f22008edSArash Taheri-Dezfouli return PyShapedTypeComponents(elementType); 340f22008edSArash Taheri-Dezfouli }, 341b56d1ec6SPeter Hawkins nb::arg("element_type"), 342f22008edSArash Taheri-Dezfouli "Create an shaped type components object with only the element " 343f22008edSArash Taheri-Dezfouli "type.") 344f22008edSArash Taheri-Dezfouli .def_static( 345f22008edSArash Taheri-Dezfouli "get", 346b56d1ec6SPeter Hawkins [](nb::list shape, PyType &elementType) { 34789b0f1eeSMehdi Amini return PyShapedTypeComponents(std::move(shape), elementType); 348f22008edSArash Taheri-Dezfouli }, 349b56d1ec6SPeter Hawkins nb::arg("shape"), nb::arg("element_type"), 350f22008edSArash Taheri-Dezfouli "Create a ranked shaped type components object.") 351f22008edSArash Taheri-Dezfouli .def_static( 352f22008edSArash Taheri-Dezfouli "get", 353b56d1ec6SPeter Hawkins [](nb::list shape, PyType &elementType, PyAttribute &attribute) { 35489b0f1eeSMehdi Amini return PyShapedTypeComponents(std::move(shape), elementType, 35589b0f1eeSMehdi Amini attribute); 356f22008edSArash Taheri-Dezfouli }, 357b56d1ec6SPeter Hawkins nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"), 358f22008edSArash Taheri-Dezfouli "Create a ranked shaped type components object with attribute.") 359b56d1ec6SPeter Hawkins .def_prop_ro( 360f22008edSArash Taheri-Dezfouli "has_rank", 361f22008edSArash Taheri-Dezfouli [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, 362f22008edSArash Taheri-Dezfouli "Returns whether the given shaped type component is ranked.") 363b56d1ec6SPeter Hawkins .def_prop_ro( 364f22008edSArash Taheri-Dezfouli "rank", 365b56d1ec6SPeter Hawkins [](PyShapedTypeComponents &self) -> nb::object { 366f22008edSArash Taheri-Dezfouli if (!self.ranked) { 367b56d1ec6SPeter Hawkins return nb::none(); 368f22008edSArash Taheri-Dezfouli } 369b56d1ec6SPeter Hawkins return nb::int_(self.shape.size()); 370f22008edSArash Taheri-Dezfouli }, 371f22008edSArash Taheri-Dezfouli "Returns the rank of the given ranked shaped type components. If " 372f22008edSArash Taheri-Dezfouli "the shaped type components does not have a rank, None is " 373f22008edSArash Taheri-Dezfouli "returned.") 374b56d1ec6SPeter Hawkins .def_prop_ro( 375f22008edSArash Taheri-Dezfouli "shape", 376b56d1ec6SPeter Hawkins [](PyShapedTypeComponents &self) -> nb::object { 377f22008edSArash Taheri-Dezfouli if (!self.ranked) { 378b56d1ec6SPeter Hawkins return nb::none(); 379f22008edSArash Taheri-Dezfouli } 380b56d1ec6SPeter Hawkins return nb::list(self.shape); 381f22008edSArash Taheri-Dezfouli }, 382f22008edSArash Taheri-Dezfouli "Returns the shape of the ranked shaped type components as a list " 383f22008edSArash Taheri-Dezfouli "of integers. Returns none if the shaped type component does not " 384f22008edSArash Taheri-Dezfouli "have a rank."); 385f22008edSArash Taheri-Dezfouli } 386f22008edSArash Taheri-Dezfouli 387b56d1ec6SPeter Hawkins nb::object getCapsule(); 388b56d1ec6SPeter Hawkins static PyShapedTypeComponents createFromCapsule(nb::object capsule); 389f22008edSArash Taheri-Dezfouli 390f22008edSArash Taheri-Dezfouli private: 391b56d1ec6SPeter Hawkins nb::list shape; 392f22008edSArash Taheri-Dezfouli MlirType elementType; 393f22008edSArash Taheri-Dezfouli MlirAttribute attribute; 394f22008edSArash Taheri-Dezfouli bool ranked{false}; 395f22008edSArash Taheri-Dezfouli }; 396f22008edSArash Taheri-Dezfouli 397f22008edSArash Taheri-Dezfouli /// Python wrapper for InferShapedTypeOpInterface. This interface has only 398f22008edSArash Taheri-Dezfouli /// static methods. 399f22008edSArash Taheri-Dezfouli class PyInferShapedTypeOpInterface 400f22008edSArash Taheri-Dezfouli : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> { 401f22008edSArash Taheri-Dezfouli public: 402f22008edSArash Taheri-Dezfouli using PyConcreteOpInterface< 403f22008edSArash Taheri-Dezfouli PyInferShapedTypeOpInterface>::PyConcreteOpInterface; 404f22008edSArash Taheri-Dezfouli 405f22008edSArash Taheri-Dezfouli constexpr static const char *pyClassName = "InferShapedTypeOpInterface"; 406f22008edSArash Taheri-Dezfouli constexpr static GetTypeIDFunctionTy getInterfaceID = 407f22008edSArash Taheri-Dezfouli &mlirInferShapedTypeOpInterfaceTypeID; 408f22008edSArash Taheri-Dezfouli 409f22008edSArash Taheri-Dezfouli /// C-style user-data structure for type appending callback. 410f22008edSArash Taheri-Dezfouli struct AppendResultsCallbackData { 411f22008edSArash Taheri-Dezfouli std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents; 412f22008edSArash Taheri-Dezfouli }; 413f22008edSArash Taheri-Dezfouli 414f22008edSArash Taheri-Dezfouli /// Appends the shaped type components provided as unpacked shape, element 415f22008edSArash Taheri-Dezfouli /// type, attribute to the user-data. 416f22008edSArash Taheri-Dezfouli static void appendResultsCallback(bool hasRank, intptr_t rank, 417f22008edSArash Taheri-Dezfouli const int64_t *shape, MlirType elementType, 418f22008edSArash Taheri-Dezfouli MlirAttribute attribute, void *userData) { 419f22008edSArash Taheri-Dezfouli auto *data = static_cast<AppendResultsCallbackData *>(userData); 420f22008edSArash Taheri-Dezfouli if (!hasRank) { 421f22008edSArash Taheri-Dezfouli data->inferredShapedTypeComponents.emplace_back(elementType); 422f22008edSArash Taheri-Dezfouli } else { 423b56d1ec6SPeter Hawkins nb::list shapeList; 424f22008edSArash Taheri-Dezfouli for (intptr_t i = 0; i < rank; ++i) { 425f22008edSArash Taheri-Dezfouli shapeList.append(shape[i]); 426f22008edSArash Taheri-Dezfouli } 427f22008edSArash Taheri-Dezfouli data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, 428f22008edSArash Taheri-Dezfouli attribute); 429f22008edSArash Taheri-Dezfouli } 430f22008edSArash Taheri-Dezfouli } 431f22008edSArash Taheri-Dezfouli 432f22008edSArash Taheri-Dezfouli /// Given the arguments required to build an operation, attempts to infer the 433f22008edSArash Taheri-Dezfouli /// shaped type components. Throws value_error on failure. 434f22008edSArash Taheri-Dezfouli std::vector<PyShapedTypeComponents> inferReturnTypeComponents( 435b56d1ec6SPeter Hawkins std::optional<nb::list> operandList, 436f22008edSArash Taheri-Dezfouli std::optional<PyAttribute> attributes, void *properties, 437f22008edSArash Taheri-Dezfouli std::optional<std::vector<PyRegion>> regions, 438f22008edSArash Taheri-Dezfouli DefaultingPyMlirContext context, DefaultingPyLocation location) { 43989b0f1eeSMehdi Amini llvm::SmallVector<MlirValue> mlirOperands = 44089b0f1eeSMehdi Amini wrapOperands(std::move(operandList)); 44189b0f1eeSMehdi Amini llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); 442f22008edSArash Taheri-Dezfouli 443f22008edSArash Taheri-Dezfouli std::vector<PyShapedTypeComponents> inferredShapedTypeComponents; 444f22008edSArash Taheri-Dezfouli PyMlirContext &pyContext = context.resolve(); 445f22008edSArash Taheri-Dezfouli AppendResultsCallbackData data{inferredShapedTypeComponents}; 446f22008edSArash Taheri-Dezfouli MlirStringRef opNameRef = 447f22008edSArash Taheri-Dezfouli mlirStringRefCreate(getOpName().data(), getOpName().length()); 448f22008edSArash Taheri-Dezfouli MlirAttribute attributeDict = 449f22008edSArash Taheri-Dezfouli attributes ? attributes->get() : mlirAttributeGetNull(); 450f22008edSArash Taheri-Dezfouli 451f22008edSArash Taheri-Dezfouli MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( 452f22008edSArash Taheri-Dezfouli opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), 453f22008edSArash Taheri-Dezfouli mlirOperands.data(), attributeDict, properties, mlirRegions.size(), 454f22008edSArash Taheri-Dezfouli mlirRegions.data(), &appendResultsCallback, &data); 455f22008edSArash Taheri-Dezfouli 456f22008edSArash Taheri-Dezfouli if (mlirLogicalResultIsFailure(result)) { 457b56d1ec6SPeter Hawkins throw nb::value_error("Failed to infer result shape type components"); 458f22008edSArash Taheri-Dezfouli } 459f22008edSArash Taheri-Dezfouli 460f22008edSArash Taheri-Dezfouli return inferredShapedTypeComponents; 461f22008edSArash Taheri-Dezfouli } 462f22008edSArash Taheri-Dezfouli 463f22008edSArash Taheri-Dezfouli static void bindDerived(ClassTy &cls) { 464f22008edSArash Taheri-Dezfouli cls.def("inferReturnTypeComponents", 465f22008edSArash Taheri-Dezfouli &PyInferShapedTypeOpInterface::inferReturnTypeComponents, 466b56d1ec6SPeter Hawkins nb::arg("operands").none() = nb::none(), 467b56d1ec6SPeter Hawkins nb::arg("attributes").none() = nb::none(), 468b56d1ec6SPeter Hawkins nb::arg("regions").none() = nb::none(), 469b56d1ec6SPeter Hawkins nb::arg("properties").none() = nb::none(), 470b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 471b56d1ec6SPeter Hawkins nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc); 472f22008edSArash Taheri-Dezfouli } 473f22008edSArash Taheri-Dezfouli }; 474f22008edSArash Taheri-Dezfouli 475b56d1ec6SPeter Hawkins void populateIRInterfaces(nb::module_ &m) { 476f22008edSArash Taheri-Dezfouli PyInferTypeOpInterface::bind(m); 477f22008edSArash Taheri-Dezfouli PyShapedTypeComponents::bind(m); 478f22008edSArash Taheri-Dezfouli PyInferShapedTypeOpInterface::bind(m); 479f22008edSArash Taheri-Dezfouli } 48014c92070SAlex Zinenko 48114c92070SAlex Zinenko } // namespace python 48214c92070SAlex Zinenko } // namespace mlir 483