xref: /llvm-project/mlir/lib/Bindings/Python/IRInterfaces.cpp (revision 1fc096af1e495d121679340b527701a5c0a9ef8b)
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "IRModule.h"
12 #include "mlir-c/BuiltinAttributes.h"
13 #include "mlir-c/Interfaces.h"
14 
15 namespace py = pybind11;
16 
17 namespace mlir {
18 namespace python {
19 
// Docstrings attached to the Python-visible members bound below. They are
// kept as raw string literals so that line breaks appear verbatim in the
// generated Python `__doc__` text.

/// Docstring for the interface constructor (`__init__`).
static constexpr const char *constructorDoc =
    R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";

/// Docstring for the read-only `operation` property.
static constexpr const char *operationDoc =
    R"(Returns an Operation for which the interface was constructed.)";

/// Docstring for the read-only `opview` property.
static constexpr const char *opviewDoc =
    R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";

/// Docstring for `InferTypeOpInterface.inferReturnTypes`.
static constexpr const char *inferReturnTypesDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";
35 
36 /// CRTP base class for Python classes representing MLIR Op interfaces.
37 /// Interface hierarchies are flat so no base class is expected here. The
38 /// derived class is expected to define the following static fields:
39 ///  - `const char *pyClassName` - the name of the Python class to create;
40 ///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
41 ///    of the interface.
42 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
43 /// interface-specific methods.
44 ///
45 /// An interface class may be constructed from either an Operation/OpView object
46 /// or from a subclass of OpView. In the latter case, only the static interface
47 /// methods are available, similarly to calling ConcereteOp::staticMethod on the
48 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
49 /// method to check whether the interface object was constructed from a class or
50 /// an operation/opview instance. The `getOpName` always succeeds and returns a
51 /// canonical name of the operation suitable for lookups.
52 template <typename ConcreteIface>
53 class PyConcreteOpInterface {
54 protected:
55   using ClassTy = py::class_<ConcreteIface>;
56   using GetTypeIDFunctionTy = MlirTypeID (*)();
57 
58 public:
59   /// Constructs an interface instance from an object that is either an
60   /// operation or a subclass of OpView. In the latter case, only the static
61   /// methods of the interface are accessible to the caller.
62   PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
63       : obj(std::move(object)) {
64     try {
65       operation = &py::cast<PyOperation &>(obj);
66     } catch (py::cast_error &err) {
67       // Do nothing.
68     }
69 
70     try {
71       operation = &py::cast<PyOpView &>(obj).getOperation();
72     } catch (py::cast_error &err) {
73       // Do nothing.
74     }
75 
76     if (operation != nullptr) {
77       if (!mlirOperationImplementsInterface(*operation,
78                                             ConcreteIface::getInterfaceID())) {
79         std::string msg = "the operation does not implement ";
80         throw py::value_error(msg + ConcreteIface::pyClassName);
81       }
82 
83       MlirIdentifier identifier = mlirOperationGetName(*operation);
84       MlirStringRef stringRef = mlirIdentifierStr(identifier);
85       opName = std::string(stringRef.data, stringRef.length);
86     } else {
87       try {
88         opName = obj.attr("OPERATION_NAME").template cast<std::string>();
89       } catch (py::cast_error &err) {
90         throw py::type_error(
91             "Op interface does not refer to an operation or OpView class");
92       }
93 
94       if (!mlirOperationImplementsInterfaceStatic(
95               mlirStringRefCreate(opName.data(), opName.length()),
96               context.resolve().get(), ConcreteIface::getInterfaceID())) {
97         std::string msg = "the operation does not implement ";
98         throw py::value_error(msg + ConcreteIface::pyClassName);
99       }
100     }
101   }
102 
103   /// Creates the Python bindings for this class in the given module.
104   static void bind(py::module &m) {
105     py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
106                                   py::module_local());
107     cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
108             py::arg("context") = py::none(), constructorDoc)
109         .def_property_readonly("operation",
110                                &PyConcreteOpInterface::getOperationObject,
111                                operationDoc)
112         .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
113                                opviewDoc);
114     ConcreteIface::bindDerived(cls);
115   }
116 
117   /// Hook for derived classes to add class-specific bindings.
118   static void bindDerived(ClassTy &cls) {}
119 
120   /// Returns `true` if this object was constructed from a subclass of OpView
121   /// rather than from an operation instance.
122   bool isStatic() { return operation == nullptr; }
123 
124   /// Returns the operation instance from which this object was constructed.
125   /// Throws a type error if this object was constructed from a subclass of
126   /// OpView.
127   py::object getOperationObject() {
128     if (operation == nullptr) {
129       throw py::type_error("Cannot get an operation from a static interface");
130     }
131 
132     return operation->getRef().releaseObject();
133   }
134 
135   /// Returns the opview of the operation instance from which this object was
136   /// constructed. Throws a type error if this object was constructed form a
137   /// subclass of OpView.
138   py::object getOpView() {
139     if (operation == nullptr) {
140       throw py::type_error("Cannot get an opview from a static interface");
141     }
142 
143     return operation->createOpView();
144   }
145 
146   /// Returns the canonical name of the operation this interface is constructed
147   /// from.
148   const std::string &getOpName() { return opName; }
149 
150 private:
151   PyOperation *operation = nullptr;
152   std::string opName;
153   py::object obj;
154 };
155 
156 /// Python wrapper for InterTypeOpInterface. This interface has only static
157 /// methods.
158 class PyInferTypeOpInterface
159     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
160 public:
161   using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
162 
163   constexpr static const char *pyClassName = "InferTypeOpInterface";
164   constexpr static GetTypeIDFunctionTy getInterfaceID =
165       &mlirInferTypeOpInterfaceTypeID;
166 
167   /// C-style user-data structure for type appending callback.
168   struct AppendResultsCallbackData {
169     std::vector<PyType> &inferredTypes;
170     PyMlirContext &pyMlirContext;
171   };
172 
173   /// Appends the types provided as the two first arguments to the user-data
174   /// structure (expects AppendResultsCallbackData).
175   static void appendResultsCallback(intptr_t nTypes, MlirType *types,
176                                     void *userData) {
177     auto *data = static_cast<AppendResultsCallbackData *>(userData);
178     data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
179     for (intptr_t i = 0; i < nTypes; ++i) {
180       data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
181     }
182   }
183 
184   /// Given the arguments required to build an operation, attempts to infer its
185   /// return types. Throws value_error on faliure.
186   std::vector<PyType>
187   inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands,
188                    llvm::Optional<PyAttribute> attributes,
189                    llvm::Optional<std::vector<PyRegion>> regions,
190                    DefaultingPyMlirContext context,
191                    DefaultingPyLocation location) {
192     llvm::SmallVector<MlirValue> mlirOperands;
193     llvm::SmallVector<MlirRegion> mlirRegions;
194 
195     if (operands) {
196       mlirOperands.reserve(operands->size());
197       for (PyValue &value : *operands) {
198         mlirOperands.push_back(value);
199       }
200     }
201 
202     if (regions) {
203       mlirRegions.reserve(regions->size());
204       for (PyRegion &region : *regions) {
205         mlirRegions.push_back(region);
206       }
207     }
208 
209     std::vector<PyType> inferredTypes;
210     PyMlirContext &pyContext = context.resolve();
211     AppendResultsCallbackData data{inferredTypes, pyContext};
212     MlirStringRef opNameRef =
213         mlirStringRefCreate(getOpName().data(), getOpName().length());
214     MlirAttribute attributeDict =
215         attributes ? attributes->get() : mlirAttributeGetNull();
216 
217     MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
218         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
219         mlirOperands.data(), attributeDict, mlirRegions.size(),
220         mlirRegions.data(), &appendResultsCallback, &data);
221 
222     if (mlirLogicalResultIsFailure(result)) {
223       throw py::value_error("Failed to infer result types");
224     }
225 
226     return inferredTypes;
227   }
228 
229   static void bindDerived(ClassTy &cls) {
230     cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
231             py::arg("operands") = py::none(),
232             py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
233             py::arg("context") = py::none(), py::arg("loc") = py::none(),
234             inferReturnTypesDoc);
235   }
236 };
237 
238 void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
239 
240 } // namespace python
241 } // namespace mlir
242