xref: /llvm-project/mlir/lib/Bindings/Python/IRInterfaces.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
114c92070SAlex Zinenko //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
214c92070SAlex Zinenko //
314c92070SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
414c92070SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
514c92070SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
614c92070SAlex Zinenko //
714c92070SAlex Zinenko //===----------------------------------------------------------------------===//
814c92070SAlex Zinenko 
9dc81dfa0SMehdi Amini #include <cstdint>
10a1fe1f5fSKazu Hirata #include <optional>
11dc81dfa0SMehdi Amini #include <string>
12f22008edSArash Taheri-Dezfouli #include <utility>
13dc81dfa0SMehdi Amini #include <vector>
141fc096afSMehdi Amini 
1514c92070SAlex Zinenko #include "IRModule.h"
1614c92070SAlex Zinenko #include "mlir-c/BuiltinAttributes.h"
17dc81dfa0SMehdi Amini #include "mlir-c/IR.h"
1814c92070SAlex Zinenko #include "mlir-c/Interfaces.h"
19dc81dfa0SMehdi Amini #include "mlir-c/Support.h"
20*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h"
21ee308c99SJacques Pienaar #include "llvm/ADT/STLExtras.h"
22dc81dfa0SMehdi Amini #include "llvm/ADT/SmallVector.h"
2314c92070SAlex Zinenko 
24b56d1ec6SPeter Hawkins namespace nb = nanobind;
2514c92070SAlex Zinenko 
2614c92070SAlex Zinenko namespace mlir {
2714c92070SAlex Zinenko namespace python {
2814c92070SAlex Zinenko 
// Python docstrings attached to the op-interface bindings in this file. The
// explicit "\n" sequences are part of the docstring text seen from Python.

/// Docstring of the interface constructor.
static constexpr const char *constructorDoc =
    "Creates an interface from a given operation/opview object or from a\n"
    "subclass of OpView. Raises ValueError if the operation does not "
    "implement the\n"
    "interface.";

/// Docstring of the `operation` property.
static constexpr const char *operationDoc =
    "Returns an Operation for which the interface was constructed.";

/// Docstring of the `opview` property.
static constexpr const char *opviewDoc =
    "Returns an OpView subclass _instance_ for which the interface was\n"
    "constructed";

/// Docstring of InferTypeOpInterface.inferReturnTypes.
static constexpr const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";

/// Docstring of InferShapedTypeOpInterface.inferReturnTypeComponents.
static constexpr const char *inferReturnTypeComponentsDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return shaped type components. Raises ValueError on failure.";
48f22008edSArash Taheri-Dezfouli 
49f22008edSArash Taheri-Dezfouli namespace {
50f22008edSArash Taheri-Dezfouli 
51f22008edSArash Taheri-Dezfouli /// Takes in an optional ist of operands and converts them into a SmallVector
52f22008edSArash Taheri-Dezfouli /// of MlirVlaues. Returns an empty SmallVector if the list is empty.
53b56d1ec6SPeter Hawkins llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) {
54f22008edSArash Taheri-Dezfouli   llvm::SmallVector<MlirValue> mlirOperands;
55f22008edSArash Taheri-Dezfouli 
56b56d1ec6SPeter Hawkins   if (!operandList || operandList->size() == 0) {
57f22008edSArash Taheri-Dezfouli     return mlirOperands;
58f22008edSArash Taheri-Dezfouli   }
59f22008edSArash Taheri-Dezfouli 
60f22008edSArash Taheri-Dezfouli   // Note: as the list may contain other lists this may not be final size.
61f22008edSArash Taheri-Dezfouli   mlirOperands.reserve(operandList->size());
62f22008edSArash Taheri-Dezfouli   for (const auto &&it : llvm::enumerate(*operandList)) {
63e0ca7e99Smax     if (it.value().is_none())
64e0ca7e99Smax       continue;
65e0ca7e99Smax 
66f22008edSArash Taheri-Dezfouli     PyValue *val;
67f22008edSArash Taheri-Dezfouli     try {
68b56d1ec6SPeter Hawkins       val = nb::cast<PyValue *>(it.value());
69f22008edSArash Taheri-Dezfouli       if (!val)
70b56d1ec6SPeter Hawkins         throw nb::cast_error();
71f22008edSArash Taheri-Dezfouli       mlirOperands.push_back(val->get());
72f22008edSArash Taheri-Dezfouli       continue;
73b56d1ec6SPeter Hawkins     } catch (nb::cast_error &err) {
74f22008edSArash Taheri-Dezfouli       // Intentionally unhandled to try sequence below first.
75f22008edSArash Taheri-Dezfouli       (void)err;
76f22008edSArash Taheri-Dezfouli     }
77f22008edSArash Taheri-Dezfouli 
78f22008edSArash Taheri-Dezfouli     try {
79b56d1ec6SPeter Hawkins       auto vals = nb::cast<nb::sequence>(it.value());
80b56d1ec6SPeter Hawkins       for (nb::handle v : vals) {
81f22008edSArash Taheri-Dezfouli         try {
82b56d1ec6SPeter Hawkins           val = nb::cast<PyValue *>(v);
83f22008edSArash Taheri-Dezfouli           if (!val)
84b56d1ec6SPeter Hawkins             throw nb::cast_error();
85f22008edSArash Taheri-Dezfouli           mlirOperands.push_back(val->get());
86b56d1ec6SPeter Hawkins         } catch (nb::cast_error &err) {
87b56d1ec6SPeter Hawkins           throw nb::value_error(
88f22008edSArash Taheri-Dezfouli               (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
89f22008edSArash Taheri-Dezfouli                " must be a Value or Sequence of Values (" + err.what() + ")")
90b56d1ec6SPeter Hawkins                   .str()
91b56d1ec6SPeter Hawkins                   .c_str());
92f22008edSArash Taheri-Dezfouli         }
93f22008edSArash Taheri-Dezfouli       }
94f22008edSArash Taheri-Dezfouli       continue;
95b56d1ec6SPeter Hawkins     } catch (nb::cast_error &err) {
96b56d1ec6SPeter Hawkins       throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
97f22008edSArash Taheri-Dezfouli                              " must be a Value or Sequence of Values (" +
98f22008edSArash Taheri-Dezfouli                              err.what() + ")")
99b56d1ec6SPeter Hawkins                                 .str()
100b56d1ec6SPeter Hawkins                                 .c_str());
101f22008edSArash Taheri-Dezfouli     }
102f22008edSArash Taheri-Dezfouli 
103b56d1ec6SPeter Hawkins     throw nb::cast_error();
104f22008edSArash Taheri-Dezfouli   }
105f22008edSArash Taheri-Dezfouli 
106f22008edSArash Taheri-Dezfouli   return mlirOperands;
107f22008edSArash Taheri-Dezfouli }
108f22008edSArash Taheri-Dezfouli 
109f22008edSArash Taheri-Dezfouli /// Takes in an optional vector of PyRegions and returns a SmallVector of
110f22008edSArash Taheri-Dezfouli /// MlirRegion. Returns an empty SmallVector if the list is empty.
111f22008edSArash Taheri-Dezfouli llvm::SmallVector<MlirRegion>
112f22008edSArash Taheri-Dezfouli wrapRegions(std::optional<std::vector<PyRegion>> regions) {
113f22008edSArash Taheri-Dezfouli   llvm::SmallVector<MlirRegion> mlirRegions;
114f22008edSArash Taheri-Dezfouli 
115f22008edSArash Taheri-Dezfouli   if (regions) {
116f22008edSArash Taheri-Dezfouli     mlirRegions.reserve(regions->size());
117f22008edSArash Taheri-Dezfouli     for (PyRegion &region : *regions) {
118f22008edSArash Taheri-Dezfouli       mlirRegions.push_back(region);
119f22008edSArash Taheri-Dezfouli     }
120f22008edSArash Taheri-Dezfouli   }
121f22008edSArash Taheri-Dezfouli 
122f22008edSArash Taheri-Dezfouli   return mlirRegions;
123f22008edSArash Taheri-Dezfouli }
124f22008edSArash Taheri-Dezfouli 
125f22008edSArash Taheri-Dezfouli } // namespace
126f22008edSArash Taheri-Dezfouli 
12714c92070SAlex Zinenko /// CRTP base class for Python classes representing MLIR Op interfaces.
12814c92070SAlex Zinenko /// Interface hierarchies are flat so no base class is expected here. The
12914c92070SAlex Zinenko /// derived class is expected to define the following static fields:
13014c92070SAlex Zinenko ///  - `const char *pyClassName` - the name of the Python class to create;
13114c92070SAlex Zinenko ///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
13214c92070SAlex Zinenko ///    of the interface.
13314c92070SAlex Zinenko /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
13414c92070SAlex Zinenko /// interface-specific methods.
13514c92070SAlex Zinenko ///
13614c92070SAlex Zinenko /// An interface class may be constructed from either an Operation/OpView object
13714c92070SAlex Zinenko /// or from a subclass of OpView. In the latter case, only the static interface
13814c92070SAlex Zinenko /// methods are available, similarly to calling ConcereteOp::staticMethod on the
13914c92070SAlex Zinenko /// C++ side. Implementations of concrete interfaces can use the `isStatic`
14014c92070SAlex Zinenko /// method to check whether the interface object was constructed from a class or
14114c92070SAlex Zinenko /// an operation/opview instance. The `getOpName` always succeeds and returns a
14214c92070SAlex Zinenko /// canonical name of the operation suitable for lookups.
14314c92070SAlex Zinenko template <typename ConcreteIface>
14414c92070SAlex Zinenko class PyConcreteOpInterface {
14514c92070SAlex Zinenko protected:
146b56d1ec6SPeter Hawkins   using ClassTy = nb::class_<ConcreteIface>;
14714c92070SAlex Zinenko   using GetTypeIDFunctionTy = MlirTypeID (*)();
14814c92070SAlex Zinenko 
14914c92070SAlex Zinenko public:
15014c92070SAlex Zinenko   /// Constructs an interface instance from an object that is either an
15114c92070SAlex Zinenko   /// operation or a subclass of OpView. In the latter case, only the static
15214c92070SAlex Zinenko   /// methods of the interface are accessible to the caller.
153b56d1ec6SPeter Hawkins   PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
1541fc096afSMehdi Amini       : obj(std::move(object)) {
15514c92070SAlex Zinenko     try {
156b56d1ec6SPeter Hawkins       operation = &nb::cast<PyOperation &>(obj);
157b56d1ec6SPeter Hawkins     } catch (nb::cast_error &) {
15814c92070SAlex Zinenko       // Do nothing.
15914c92070SAlex Zinenko     }
16014c92070SAlex Zinenko 
16114c92070SAlex Zinenko     try {
162b56d1ec6SPeter Hawkins       operation = &nb::cast<PyOpView &>(obj).getOperation();
163b56d1ec6SPeter Hawkins     } catch (nb::cast_error &) {
16414c92070SAlex Zinenko       // Do nothing.
16514c92070SAlex Zinenko     }
16614c92070SAlex Zinenko 
16714c92070SAlex Zinenko     if (operation != nullptr) {
16814c92070SAlex Zinenko       if (!mlirOperationImplementsInterface(*operation,
16914c92070SAlex Zinenko                                             ConcreteIface::getInterfaceID())) {
17014c92070SAlex Zinenko         std::string msg = "the operation does not implement ";
171b56d1ec6SPeter Hawkins         throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
17214c92070SAlex Zinenko       }
17314c92070SAlex Zinenko 
17414c92070SAlex Zinenko       MlirIdentifier identifier = mlirOperationGetName(*operation);
17514c92070SAlex Zinenko       MlirStringRef stringRef = mlirIdentifierStr(identifier);
17614c92070SAlex Zinenko       opName = std::string(stringRef.data, stringRef.length);
17714c92070SAlex Zinenko     } else {
17814c92070SAlex Zinenko       try {
179b56d1ec6SPeter Hawkins         opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
180b56d1ec6SPeter Hawkins       } catch (nb::cast_error &) {
181b56d1ec6SPeter Hawkins         throw nb::type_error(
18214c92070SAlex Zinenko             "Op interface does not refer to an operation or OpView class");
18314c92070SAlex Zinenko       }
18414c92070SAlex Zinenko 
18514c92070SAlex Zinenko       if (!mlirOperationImplementsInterfaceStatic(
18614c92070SAlex Zinenko               mlirStringRefCreate(opName.data(), opName.length()),
18714c92070SAlex Zinenko               context.resolve().get(), ConcreteIface::getInterfaceID())) {
18814c92070SAlex Zinenko         std::string msg = "the operation does not implement ";
189b56d1ec6SPeter Hawkins         throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
19014c92070SAlex Zinenko       }
19114c92070SAlex Zinenko     }
19214c92070SAlex Zinenko   }
19314c92070SAlex Zinenko 
19414c92070SAlex Zinenko   /// Creates the Python bindings for this class in the given module.
195b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
196b56d1ec6SPeter Hawkins     nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
197b56d1ec6SPeter Hawkins     cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
198b56d1ec6SPeter Hawkins             nb::arg("context").none() = nb::none(), constructorDoc)
199b56d1ec6SPeter Hawkins         .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
20014c92070SAlex Zinenko                      operationDoc)
201b56d1ec6SPeter Hawkins         .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
20214c92070SAlex Zinenko     ConcreteIface::bindDerived(cls);
20314c92070SAlex Zinenko   }
20414c92070SAlex Zinenko 
20514c92070SAlex Zinenko   /// Hook for derived classes to add class-specific bindings.
20614c92070SAlex Zinenko   static void bindDerived(ClassTy &cls) {}
20714c92070SAlex Zinenko 
20814c92070SAlex Zinenko   /// Returns `true` if this object was constructed from a subclass of OpView
20914c92070SAlex Zinenko   /// rather than from an operation instance.
21014c92070SAlex Zinenko   bool isStatic() { return operation == nullptr; }
21114c92070SAlex Zinenko 
21214c92070SAlex Zinenko   /// Returns the operation instance from which this object was constructed.
21314c92070SAlex Zinenko   /// Throws a type error if this object was constructed from a subclass of
21414c92070SAlex Zinenko   /// OpView.
215b56d1ec6SPeter Hawkins   nb::object getOperationObject() {
21614c92070SAlex Zinenko     if (operation == nullptr) {
217b56d1ec6SPeter Hawkins       throw nb::type_error("Cannot get an operation from a static interface");
21814c92070SAlex Zinenko     }
21914c92070SAlex Zinenko 
22014c92070SAlex Zinenko     return operation->getRef().releaseObject();
22114c92070SAlex Zinenko   }
22214c92070SAlex Zinenko 
22314c92070SAlex Zinenko   /// Returns the opview of the operation instance from which this object was
22414c92070SAlex Zinenko   /// constructed. Throws a type error if this object was constructed form a
22514c92070SAlex Zinenko   /// subclass of OpView.
226b56d1ec6SPeter Hawkins   nb::object getOpView() {
22714c92070SAlex Zinenko     if (operation == nullptr) {
228b56d1ec6SPeter Hawkins       throw nb::type_error("Cannot get an opview from a static interface");
22914c92070SAlex Zinenko     }
23014c92070SAlex Zinenko 
23114c92070SAlex Zinenko     return operation->createOpView();
23214c92070SAlex Zinenko   }
23314c92070SAlex Zinenko 
23414c92070SAlex Zinenko   /// Returns the canonical name of the operation this interface is constructed
23514c92070SAlex Zinenko   /// from.
23614c92070SAlex Zinenko   const std::string &getOpName() { return opName; }
23714c92070SAlex Zinenko 
23814c92070SAlex Zinenko private:
23914c92070SAlex Zinenko   PyOperation *operation = nullptr;
24014c92070SAlex Zinenko   std::string opName;
241b56d1ec6SPeter Hawkins   nb::object obj;
24214c92070SAlex Zinenko };
24314c92070SAlex Zinenko 
244f22008edSArash Taheri-Dezfouli /// Python wrapper for InferTypeOpInterface. This interface has only static
24514c92070SAlex Zinenko /// methods.
24614c92070SAlex Zinenko class PyInferTypeOpInterface
24714c92070SAlex Zinenko     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
24814c92070SAlex Zinenko public:
24914c92070SAlex Zinenko   using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
25014c92070SAlex Zinenko 
25114c92070SAlex Zinenko   constexpr static const char *pyClassName = "InferTypeOpInterface";
25214c92070SAlex Zinenko   constexpr static GetTypeIDFunctionTy getInterfaceID =
25314c92070SAlex Zinenko       &mlirInferTypeOpInterfaceTypeID;
25414c92070SAlex Zinenko 
25514c92070SAlex Zinenko   /// C-style user-data structure for type appending callback.
25614c92070SAlex Zinenko   struct AppendResultsCallbackData {
25714c92070SAlex Zinenko     std::vector<PyType> &inferredTypes;
25814c92070SAlex Zinenko     PyMlirContext &pyMlirContext;
25914c92070SAlex Zinenko   };
26014c92070SAlex Zinenko 
26114c92070SAlex Zinenko   /// Appends the types provided as the two first arguments to the user-data
26214c92070SAlex Zinenko   /// structure (expects AppendResultsCallbackData).
26314c92070SAlex Zinenko   static void appendResultsCallback(intptr_t nTypes, MlirType *types,
26414c92070SAlex Zinenko                                     void *userData) {
26514c92070SAlex Zinenko     auto *data = static_cast<AppendResultsCallbackData *>(userData);
26614c92070SAlex Zinenko     data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
26714c92070SAlex Zinenko     for (intptr_t i = 0; i < nTypes; ++i) {
268e5639b3fSMehdi Amini       data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
26914c92070SAlex Zinenko     }
27014c92070SAlex Zinenko   }
27114c92070SAlex Zinenko 
27214c92070SAlex Zinenko   /// Given the arguments required to build an operation, attempts to infer its
273ee308c99SJacques Pienaar   /// return types. Throws value_error on failure.
27414c92070SAlex Zinenko   std::vector<PyType>
275b56d1ec6SPeter Hawkins   inferReturnTypes(std::optional<nb::list> operandList,
2765e118f93SMehdi Amini                    std::optional<PyAttribute> attributes, void *properties,
2770a81ace0SKazu Hirata                    std::optional<std::vector<PyRegion>> regions,
27814c92070SAlex Zinenko                    DefaultingPyMlirContext context,
27914c92070SAlex Zinenko                    DefaultingPyLocation location) {
28089b0f1eeSMehdi Amini     llvm::SmallVector<MlirValue> mlirOperands =
28189b0f1eeSMehdi Amini         wrapOperands(std::move(operandList));
28289b0f1eeSMehdi Amini     llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
28314c92070SAlex Zinenko 
28414c92070SAlex Zinenko     std::vector<PyType> inferredTypes;
28514c92070SAlex Zinenko     PyMlirContext &pyContext = context.resolve();
28614c92070SAlex Zinenko     AppendResultsCallbackData data{inferredTypes, pyContext};
28714c92070SAlex Zinenko     MlirStringRef opNameRef =
28814c92070SAlex Zinenko         mlirStringRefCreate(getOpName().data(), getOpName().length());
28914c92070SAlex Zinenko     MlirAttribute attributeDict =
29014c92070SAlex Zinenko         attributes ? attributes->get() : mlirAttributeGetNull();
29114c92070SAlex Zinenko 
29214c92070SAlex Zinenko     MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
29314c92070SAlex Zinenko         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
2945e118f93SMehdi Amini         mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
29514c92070SAlex Zinenko         mlirRegions.data(), &appendResultsCallback, &data);
29614c92070SAlex Zinenko 
29714c92070SAlex Zinenko     if (mlirLogicalResultIsFailure(result)) {
298b56d1ec6SPeter Hawkins       throw nb::value_error("Failed to infer result types");
29914c92070SAlex Zinenko     }
30014c92070SAlex Zinenko 
30114c92070SAlex Zinenko     return inferredTypes;
30214c92070SAlex Zinenko   }
30314c92070SAlex Zinenko 
30414c92070SAlex Zinenko   static void bindDerived(ClassTy &cls) {
30514c92070SAlex Zinenko     cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
306b56d1ec6SPeter Hawkins             nb::arg("operands").none() = nb::none(),
307b56d1ec6SPeter Hawkins             nb::arg("attributes").none() = nb::none(),
308b56d1ec6SPeter Hawkins             nb::arg("properties").none() = nb::none(),
309b56d1ec6SPeter Hawkins             nb::arg("regions").none() = nb::none(),
310b56d1ec6SPeter Hawkins             nb::arg("context").none() = nb::none(),
311b56d1ec6SPeter Hawkins             nb::arg("loc").none() = nb::none(), inferReturnTypesDoc);
31214c92070SAlex Zinenko   }
31314c92070SAlex Zinenko };
31414c92070SAlex Zinenko 
315f22008edSArash Taheri-Dezfouli /// Wrapper around an shaped type components.
316f22008edSArash Taheri-Dezfouli class PyShapedTypeComponents {
317f22008edSArash Taheri-Dezfouli public:
318f22008edSArash Taheri-Dezfouli   PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
319b56d1ec6SPeter Hawkins   PyShapedTypeComponents(nb::list shape, MlirType elementType)
32089b0f1eeSMehdi Amini       : shape(std::move(shape)), elementType(elementType), ranked(true) {}
321b56d1ec6SPeter Hawkins   PyShapedTypeComponents(nb::list shape, MlirType elementType,
322f22008edSArash Taheri-Dezfouli                          MlirAttribute attribute)
32389b0f1eeSMehdi Amini       : shape(std::move(shape)), elementType(elementType), attribute(attribute),
324f22008edSArash Taheri-Dezfouli         ranked(true) {}
325f22008edSArash Taheri-Dezfouli   PyShapedTypeComponents(PyShapedTypeComponents &) = delete;
326ea2e83afSAdrian Kuegel   PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
327f22008edSArash Taheri-Dezfouli       : shape(other.shape), elementType(other.elementType),
328f22008edSArash Taheri-Dezfouli         attribute(other.attribute), ranked(other.ranked) {}
329f22008edSArash Taheri-Dezfouli 
330b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
331b56d1ec6SPeter Hawkins     nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents")
332b56d1ec6SPeter Hawkins         .def_prop_ro(
333f22008edSArash Taheri-Dezfouli             "element_type",
334bfb1ba75Smax             [](PyShapedTypeComponents &self) { return self.elementType; },
335f22008edSArash Taheri-Dezfouli             "Returns the element type of the shaped type components.")
336f22008edSArash Taheri-Dezfouli         .def_static(
337f22008edSArash Taheri-Dezfouli             "get",
338f22008edSArash Taheri-Dezfouli             [](PyType &elementType) {
339f22008edSArash Taheri-Dezfouli               return PyShapedTypeComponents(elementType);
340f22008edSArash Taheri-Dezfouli             },
341b56d1ec6SPeter Hawkins             nb::arg("element_type"),
342f22008edSArash Taheri-Dezfouli             "Create an shaped type components object with only the element "
343f22008edSArash Taheri-Dezfouli             "type.")
344f22008edSArash Taheri-Dezfouli         .def_static(
345f22008edSArash Taheri-Dezfouli             "get",
346b56d1ec6SPeter Hawkins             [](nb::list shape, PyType &elementType) {
34789b0f1eeSMehdi Amini               return PyShapedTypeComponents(std::move(shape), elementType);
348f22008edSArash Taheri-Dezfouli             },
349b56d1ec6SPeter Hawkins             nb::arg("shape"), nb::arg("element_type"),
350f22008edSArash Taheri-Dezfouli             "Create a ranked shaped type components object.")
351f22008edSArash Taheri-Dezfouli         .def_static(
352f22008edSArash Taheri-Dezfouli             "get",
353b56d1ec6SPeter Hawkins             [](nb::list shape, PyType &elementType, PyAttribute &attribute) {
35489b0f1eeSMehdi Amini               return PyShapedTypeComponents(std::move(shape), elementType,
35589b0f1eeSMehdi Amini                                             attribute);
356f22008edSArash Taheri-Dezfouli             },
357b56d1ec6SPeter Hawkins             nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"),
358f22008edSArash Taheri-Dezfouli             "Create a ranked shaped type components object with attribute.")
359b56d1ec6SPeter Hawkins         .def_prop_ro(
360f22008edSArash Taheri-Dezfouli             "has_rank",
361f22008edSArash Taheri-Dezfouli             [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
362f22008edSArash Taheri-Dezfouli             "Returns whether the given shaped type component is ranked.")
363b56d1ec6SPeter Hawkins         .def_prop_ro(
364f22008edSArash Taheri-Dezfouli             "rank",
365b56d1ec6SPeter Hawkins             [](PyShapedTypeComponents &self) -> nb::object {
366f22008edSArash Taheri-Dezfouli               if (!self.ranked) {
367b56d1ec6SPeter Hawkins                 return nb::none();
368f22008edSArash Taheri-Dezfouli               }
369b56d1ec6SPeter Hawkins               return nb::int_(self.shape.size());
370f22008edSArash Taheri-Dezfouli             },
371f22008edSArash Taheri-Dezfouli             "Returns the rank of the given ranked shaped type components. If "
372f22008edSArash Taheri-Dezfouli             "the shaped type components does not have a rank, None is "
373f22008edSArash Taheri-Dezfouli             "returned.")
374b56d1ec6SPeter Hawkins         .def_prop_ro(
375f22008edSArash Taheri-Dezfouli             "shape",
376b56d1ec6SPeter Hawkins             [](PyShapedTypeComponents &self) -> nb::object {
377f22008edSArash Taheri-Dezfouli               if (!self.ranked) {
378b56d1ec6SPeter Hawkins                 return nb::none();
379f22008edSArash Taheri-Dezfouli               }
380b56d1ec6SPeter Hawkins               return nb::list(self.shape);
381f22008edSArash Taheri-Dezfouli             },
382f22008edSArash Taheri-Dezfouli             "Returns the shape of the ranked shaped type components as a list "
383f22008edSArash Taheri-Dezfouli             "of integers. Returns none if the shaped type component does not "
384f22008edSArash Taheri-Dezfouli             "have a rank.");
385f22008edSArash Taheri-Dezfouli   }
386f22008edSArash Taheri-Dezfouli 
387b56d1ec6SPeter Hawkins   nb::object getCapsule();
388b56d1ec6SPeter Hawkins   static PyShapedTypeComponents createFromCapsule(nb::object capsule);
389f22008edSArash Taheri-Dezfouli 
390f22008edSArash Taheri-Dezfouli private:
391b56d1ec6SPeter Hawkins   nb::list shape;
392f22008edSArash Taheri-Dezfouli   MlirType elementType;
393f22008edSArash Taheri-Dezfouli   MlirAttribute attribute;
394f22008edSArash Taheri-Dezfouli   bool ranked{false};
395f22008edSArash Taheri-Dezfouli };
396f22008edSArash Taheri-Dezfouli 
397f22008edSArash Taheri-Dezfouli /// Python wrapper for InferShapedTypeOpInterface. This interface has only
398f22008edSArash Taheri-Dezfouli /// static methods.
399f22008edSArash Taheri-Dezfouli class PyInferShapedTypeOpInterface
400f22008edSArash Taheri-Dezfouli     : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
401f22008edSArash Taheri-Dezfouli public:
402f22008edSArash Taheri-Dezfouli   using PyConcreteOpInterface<
403f22008edSArash Taheri-Dezfouli       PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
404f22008edSArash Taheri-Dezfouli 
405f22008edSArash Taheri-Dezfouli   constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
406f22008edSArash Taheri-Dezfouli   constexpr static GetTypeIDFunctionTy getInterfaceID =
407f22008edSArash Taheri-Dezfouli       &mlirInferShapedTypeOpInterfaceTypeID;
408f22008edSArash Taheri-Dezfouli 
409f22008edSArash Taheri-Dezfouli   /// C-style user-data structure for type appending callback.
410f22008edSArash Taheri-Dezfouli   struct AppendResultsCallbackData {
411f22008edSArash Taheri-Dezfouli     std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
412f22008edSArash Taheri-Dezfouli   };
413f22008edSArash Taheri-Dezfouli 
414f22008edSArash Taheri-Dezfouli   /// Appends the shaped type components provided as unpacked shape, element
415f22008edSArash Taheri-Dezfouli   /// type, attribute to the user-data.
416f22008edSArash Taheri-Dezfouli   static void appendResultsCallback(bool hasRank, intptr_t rank,
417f22008edSArash Taheri-Dezfouli                                     const int64_t *shape, MlirType elementType,
418f22008edSArash Taheri-Dezfouli                                     MlirAttribute attribute, void *userData) {
419f22008edSArash Taheri-Dezfouli     auto *data = static_cast<AppendResultsCallbackData *>(userData);
420f22008edSArash Taheri-Dezfouli     if (!hasRank) {
421f22008edSArash Taheri-Dezfouli       data->inferredShapedTypeComponents.emplace_back(elementType);
422f22008edSArash Taheri-Dezfouli     } else {
423b56d1ec6SPeter Hawkins       nb::list shapeList;
424f22008edSArash Taheri-Dezfouli       for (intptr_t i = 0; i < rank; ++i) {
425f22008edSArash Taheri-Dezfouli         shapeList.append(shape[i]);
426f22008edSArash Taheri-Dezfouli       }
427f22008edSArash Taheri-Dezfouli       data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
428f22008edSArash Taheri-Dezfouli                                                       attribute);
429f22008edSArash Taheri-Dezfouli     }
430f22008edSArash Taheri-Dezfouli   }
431f22008edSArash Taheri-Dezfouli 
432f22008edSArash Taheri-Dezfouli   /// Given the arguments required to build an operation, attempts to infer the
433f22008edSArash Taheri-Dezfouli   /// shaped type components. Throws value_error on failure.
434f22008edSArash Taheri-Dezfouli   std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
435b56d1ec6SPeter Hawkins       std::optional<nb::list> operandList,
436f22008edSArash Taheri-Dezfouli       std::optional<PyAttribute> attributes, void *properties,
437f22008edSArash Taheri-Dezfouli       std::optional<std::vector<PyRegion>> regions,
438f22008edSArash Taheri-Dezfouli       DefaultingPyMlirContext context, DefaultingPyLocation location) {
43989b0f1eeSMehdi Amini     llvm::SmallVector<MlirValue> mlirOperands =
44089b0f1eeSMehdi Amini         wrapOperands(std::move(operandList));
44189b0f1eeSMehdi Amini     llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
442f22008edSArash Taheri-Dezfouli 
443f22008edSArash Taheri-Dezfouli     std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
444f22008edSArash Taheri-Dezfouli     PyMlirContext &pyContext = context.resolve();
445f22008edSArash Taheri-Dezfouli     AppendResultsCallbackData data{inferredShapedTypeComponents};
446f22008edSArash Taheri-Dezfouli     MlirStringRef opNameRef =
447f22008edSArash Taheri-Dezfouli         mlirStringRefCreate(getOpName().data(), getOpName().length());
448f22008edSArash Taheri-Dezfouli     MlirAttribute attributeDict =
449f22008edSArash Taheri-Dezfouli         attributes ? attributes->get() : mlirAttributeGetNull();
450f22008edSArash Taheri-Dezfouli 
451f22008edSArash Taheri-Dezfouli     MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
452f22008edSArash Taheri-Dezfouli         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
453f22008edSArash Taheri-Dezfouli         mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
454f22008edSArash Taheri-Dezfouli         mlirRegions.data(), &appendResultsCallback, &data);
455f22008edSArash Taheri-Dezfouli 
456f22008edSArash Taheri-Dezfouli     if (mlirLogicalResultIsFailure(result)) {
457b56d1ec6SPeter Hawkins       throw nb::value_error("Failed to infer result shape type components");
458f22008edSArash Taheri-Dezfouli     }
459f22008edSArash Taheri-Dezfouli 
460f22008edSArash Taheri-Dezfouli     return inferredShapedTypeComponents;
461f22008edSArash Taheri-Dezfouli   }
462f22008edSArash Taheri-Dezfouli 
463f22008edSArash Taheri-Dezfouli   static void bindDerived(ClassTy &cls) {
464f22008edSArash Taheri-Dezfouli     cls.def("inferReturnTypeComponents",
465f22008edSArash Taheri-Dezfouli             &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
466b56d1ec6SPeter Hawkins             nb::arg("operands").none() = nb::none(),
467b56d1ec6SPeter Hawkins             nb::arg("attributes").none() = nb::none(),
468b56d1ec6SPeter Hawkins             nb::arg("regions").none() = nb::none(),
469b56d1ec6SPeter Hawkins             nb::arg("properties").none() = nb::none(),
470b56d1ec6SPeter Hawkins             nb::arg("context").none() = nb::none(),
471b56d1ec6SPeter Hawkins             nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc);
472f22008edSArash Taheri-Dezfouli   }
473f22008edSArash Taheri-Dezfouli };
474f22008edSArash Taheri-Dezfouli 
475b56d1ec6SPeter Hawkins void populateIRInterfaces(nb::module_ &m) {
476f22008edSArash Taheri-Dezfouli   PyInferTypeOpInterface::bind(m);
477f22008edSArash Taheri-Dezfouli   PyShapedTypeComponents::bind(m);
478f22008edSArash Taheri-Dezfouli   PyInferShapedTypeOpInterface::bind(m);
479f22008edSArash Taheri-Dezfouli }
48014c92070SAlex Zinenko 
48114c92070SAlex Zinenko } // namespace python
48214c92070SAlex Zinenko } // namespace mlir
483