xref: /llvm-project/mlir/lib/Bindings/Python/IRInterfaces.cpp (revision 14c9207063bb00823a5126131e50c93f6e288bd3)
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 #include "mlir-c/BuiltinAttributes.h"
11 #include "mlir-c/Interfaces.h"
12 
13 namespace py = pybind11;
14 
15 namespace mlir {
16 namespace python {
17 
/// Docstring attached to the interface constructor binding.
static constexpr const char *constructorDoc =
    R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";

/// Docstring attached to the read-only `operation` property.
static constexpr const char *operationDoc =
    R"(Returns an Operation for which the interface was constructed.)";

/// Docstring attached to the read-only `opview` property.
static constexpr const char *opviewDoc =
    R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";

/// Docstring attached to the `inferReturnTypes` method binding.
static constexpr const char *inferReturnTypesDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";
33 
34 /// CRTP base class for Python classes representing MLIR Op interfaces.
35 /// Interface hierarchies are flat so no base class is expected here. The
36 /// derived class is expected to define the following static fields:
37 ///  - `const char *pyClassName` - the name of the Python class to create;
38 ///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
39 ///    of the interface.
40 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
41 /// interface-specific methods.
42 ///
43 /// An interface class may be constructed from either an Operation/OpView object
44 /// or from a subclass of OpView. In the latter case, only the static interface
45 /// methods are available, similarly to calling ConcereteOp::staticMethod on the
46 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
47 /// method to check whether the interface object was constructed from a class or
48 /// an operation/opview instance. The `getOpName` always succeeds and returns a
49 /// canonical name of the operation suitable for lookups.
50 template <typename ConcreteIface>
51 class PyConcreteOpInterface {
52 protected:
53   using ClassTy = py::class_<ConcreteIface>;
54   using GetTypeIDFunctionTy = MlirTypeID (*)();
55 
56 public:
57   /// Constructs an interface instance from an object that is either an
58   /// operation or a subclass of OpView. In the latter case, only the static
59   /// methods of the interface are accessible to the caller.
60   PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
61       : obj(object) {
62     try {
63       operation = &py::cast<PyOperation &>(obj);
64     } catch (py::cast_error &err) {
65       // Do nothing.
66     }
67 
68     try {
69       operation = &py::cast<PyOpView &>(obj).getOperation();
70     } catch (py::cast_error &err) {
71       // Do nothing.
72     }
73 
74     if (operation != nullptr) {
75       if (!mlirOperationImplementsInterface(*operation,
76                                             ConcreteIface::getInterfaceID())) {
77         std::string msg = "the operation does not implement ";
78         throw py::value_error(msg + ConcreteIface::pyClassName);
79       }
80 
81       MlirIdentifier identifier = mlirOperationGetName(*operation);
82       MlirStringRef stringRef = mlirIdentifierStr(identifier);
83       opName = std::string(stringRef.data, stringRef.length);
84     } else {
85       try {
86         opName = obj.attr("OPERATION_NAME").template cast<std::string>();
87       } catch (py::cast_error &err) {
88         throw py::type_error(
89             "Op interface does not refer to an operation or OpView class");
90       }
91 
92       if (!mlirOperationImplementsInterfaceStatic(
93               mlirStringRefCreate(opName.data(), opName.length()),
94               context.resolve().get(), ConcreteIface::getInterfaceID())) {
95         std::string msg = "the operation does not implement ";
96         throw py::value_error(msg + ConcreteIface::pyClassName);
97       }
98     }
99   }
100 
101   /// Creates the Python bindings for this class in the given module.
102   static void bind(py::module &m) {
103     py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
104                                   py::module_local());
105     cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
106             py::arg("context") = py::none(), constructorDoc)
107         .def_property_readonly("operation",
108                                &PyConcreteOpInterface::getOperationObject,
109                                operationDoc)
110         .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
111                                opviewDoc);
112     ConcreteIface::bindDerived(cls);
113   }
114 
115   /// Hook for derived classes to add class-specific bindings.
116   static void bindDerived(ClassTy &cls) {}
117 
118   /// Returns `true` if this object was constructed from a subclass of OpView
119   /// rather than from an operation instance.
120   bool isStatic() { return operation == nullptr; }
121 
122   /// Returns the operation instance from which this object was constructed.
123   /// Throws a type error if this object was constructed from a subclass of
124   /// OpView.
125   py::object getOperationObject() {
126     if (operation == nullptr) {
127       throw py::type_error("Cannot get an operation from a static interface");
128     }
129 
130     return operation->getRef().releaseObject();
131   }
132 
133   /// Returns the opview of the operation instance from which this object was
134   /// constructed. Throws a type error if this object was constructed form a
135   /// subclass of OpView.
136   py::object getOpView() {
137     if (operation == nullptr) {
138       throw py::type_error("Cannot get an opview from a static interface");
139     }
140 
141     return operation->createOpView();
142   }
143 
144   /// Returns the canonical name of the operation this interface is constructed
145   /// from.
146   const std::string &getOpName() { return opName; }
147 
148 private:
149   PyOperation *operation = nullptr;
150   std::string opName;
151   py::object obj;
152 };
153 
154 /// Python wrapper for InterTypeOpInterface. This interface has only static
155 /// methods.
156 class PyInferTypeOpInterface
157     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
158 public:
159   using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
160 
161   constexpr static const char *pyClassName = "InferTypeOpInterface";
162   constexpr static GetTypeIDFunctionTy getInterfaceID =
163       &mlirInferTypeOpInterfaceTypeID;
164 
165   /// C-style user-data structure for type appending callback.
166   struct AppendResultsCallbackData {
167     std::vector<PyType> &inferredTypes;
168     PyMlirContext &pyMlirContext;
169   };
170 
171   /// Appends the types provided as the two first arguments to the user-data
172   /// structure (expects AppendResultsCallbackData).
173   static void appendResultsCallback(intptr_t nTypes, MlirType *types,
174                                     void *userData) {
175     auto *data = static_cast<AppendResultsCallbackData *>(userData);
176     data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
177     for (intptr_t i = 0; i < nTypes; ++i) {
178       data->inferredTypes.push_back(
179           PyType(data->pyMlirContext.getRef(), types[i]));
180     }
181   }
182 
183   /// Given the arguments required to build an operation, attempts to infer its
184   /// return types. Throws value_error on faliure.
185   std::vector<PyType>
186   inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands,
187                    llvm::Optional<PyAttribute> attributes,
188                    llvm::Optional<std::vector<PyRegion>> regions,
189                    DefaultingPyMlirContext context,
190                    DefaultingPyLocation location) {
191     llvm::SmallVector<MlirValue> mlirOperands;
192     llvm::SmallVector<MlirRegion> mlirRegions;
193 
194     if (operands) {
195       mlirOperands.reserve(operands->size());
196       for (PyValue &value : *operands) {
197         mlirOperands.push_back(value);
198       }
199     }
200 
201     if (regions) {
202       mlirRegions.reserve(regions->size());
203       for (PyRegion &region : *regions) {
204         mlirRegions.push_back(region);
205       }
206     }
207 
208     std::vector<PyType> inferredTypes;
209     PyMlirContext &pyContext = context.resolve();
210     AppendResultsCallbackData data{inferredTypes, pyContext};
211     MlirStringRef opNameRef =
212         mlirStringRefCreate(getOpName().data(), getOpName().length());
213     MlirAttribute attributeDict =
214         attributes ? attributes->get() : mlirAttributeGetNull();
215 
216     MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
217         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
218         mlirOperands.data(), attributeDict, mlirRegions.size(),
219         mlirRegions.data(), &appendResultsCallback, &data);
220 
221     if (mlirLogicalResultIsFailure(result)) {
222       throw py::value_error("Failed to infer result types");
223     }
224 
225     return inferredTypes;
226   }
227 
228   static void bindDerived(ClassTy &cls) {
229     cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
230             py::arg("operands") = py::none(),
231             py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
232             py::arg("context") = py::none(), py::arg("loc") = py::none(),
233             inferReturnTypesDoc);
234   }
235 };
236 
237 void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
238 
239 } // namespace python
240 } // namespace mlir
241