xref: /llvm-project/mlir/lib/Bindings/Python/IRInterfaces.cpp (revision ee308c99ed0877edc286870089219179a2c64a9e)
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 #include <optional>
11 
12 #include "IRModule.h"
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/Interfaces.h"
15 #include "llvm/ADT/STLExtras.h"
16 
17 namespace py = pybind11;
18 
19 namespace mlir {
20 namespace python {
21 
// Python docstrings attached to the interface bindings below. Each is kept as
// a sequence of adjacent literals so the embedded line breaks are explicit.

/// Docstring of the interface class constructor.
static constexpr const char *constructorDoc =
    "Creates an interface from a given operation/opview object or from a\n"
    "subclass of OpView. Raises ValueError if the operation does not "
    "implement the\n"
    "interface.";

/// Docstring of the `operation` property.
static constexpr const char *operationDoc =
    "Returns an Operation for which the interface was constructed.";

/// Docstring of the `opview` property.
static constexpr const char *opviewDoc =
    "Returns an OpView subclass _instance_ for which the interface was\n"
    "constructed";

/// Docstring of `InferTypeOpInterface.inferReturnTypes`.
static constexpr const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";
37 
38 /// CRTP base class for Python classes representing MLIR Op interfaces.
39 /// Interface hierarchies are flat so no base class is expected here. The
40 /// derived class is expected to define the following static fields:
41 ///  - `const char *pyClassName` - the name of the Python class to create;
42 ///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
43 ///    of the interface.
44 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
45 /// interface-specific methods.
46 ///
47 /// An interface class may be constructed from either an Operation/OpView object
48 /// or from a subclass of OpView. In the latter case, only the static interface
49 /// methods are available, similarly to calling ConcereteOp::staticMethod on the
50 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
51 /// method to check whether the interface object was constructed from a class or
52 /// an operation/opview instance. The `getOpName` always succeeds and returns a
53 /// canonical name of the operation suitable for lookups.
54 template <typename ConcreteIface>
55 class PyConcreteOpInterface {
56 protected:
57   using ClassTy = py::class_<ConcreteIface>;
58   using GetTypeIDFunctionTy = MlirTypeID (*)();
59 
60 public:
61   /// Constructs an interface instance from an object that is either an
62   /// operation or a subclass of OpView. In the latter case, only the static
63   /// methods of the interface are accessible to the caller.
64   PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
65       : obj(std::move(object)) {
66     try {
67       operation = &py::cast<PyOperation &>(obj);
68     } catch (py::cast_error &) {
69       // Do nothing.
70     }
71 
72     try {
73       operation = &py::cast<PyOpView &>(obj).getOperation();
74     } catch (py::cast_error &) {
75       // Do nothing.
76     }
77 
78     if (operation != nullptr) {
79       if (!mlirOperationImplementsInterface(*operation,
80                                             ConcreteIface::getInterfaceID())) {
81         std::string msg = "the operation does not implement ";
82         throw py::value_error(msg + ConcreteIface::pyClassName);
83       }
84 
85       MlirIdentifier identifier = mlirOperationGetName(*operation);
86       MlirStringRef stringRef = mlirIdentifierStr(identifier);
87       opName = std::string(stringRef.data, stringRef.length);
88     } else {
89       try {
90         opName = obj.attr("OPERATION_NAME").template cast<std::string>();
91       } catch (py::cast_error &) {
92         throw py::type_error(
93             "Op interface does not refer to an operation or OpView class");
94       }
95 
96       if (!mlirOperationImplementsInterfaceStatic(
97               mlirStringRefCreate(opName.data(), opName.length()),
98               context.resolve().get(), ConcreteIface::getInterfaceID())) {
99         std::string msg = "the operation does not implement ";
100         throw py::value_error(msg + ConcreteIface::pyClassName);
101       }
102     }
103   }
104 
105   /// Creates the Python bindings for this class in the given module.
106   static void bind(py::module &m) {
107     py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
108                                   py::module_local());
109     cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
110             py::arg("context") = py::none(), constructorDoc)
111         .def_property_readonly("operation",
112                                &PyConcreteOpInterface::getOperationObject,
113                                operationDoc)
114         .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
115                                opviewDoc);
116     ConcreteIface::bindDerived(cls);
117   }
118 
119   /// Hook for derived classes to add class-specific bindings.
120   static void bindDerived(ClassTy &cls) {}
121 
122   /// Returns `true` if this object was constructed from a subclass of OpView
123   /// rather than from an operation instance.
124   bool isStatic() { return operation == nullptr; }
125 
126   /// Returns the operation instance from which this object was constructed.
127   /// Throws a type error if this object was constructed from a subclass of
128   /// OpView.
129   py::object getOperationObject() {
130     if (operation == nullptr) {
131       throw py::type_error("Cannot get an operation from a static interface");
132     }
133 
134     return operation->getRef().releaseObject();
135   }
136 
137   /// Returns the opview of the operation instance from which this object was
138   /// constructed. Throws a type error if this object was constructed form a
139   /// subclass of OpView.
140   py::object getOpView() {
141     if (operation == nullptr) {
142       throw py::type_error("Cannot get an opview from a static interface");
143     }
144 
145     return operation->createOpView();
146   }
147 
148   /// Returns the canonical name of the operation this interface is constructed
149   /// from.
150   const std::string &getOpName() { return opName; }
151 
152 private:
153   PyOperation *operation = nullptr;
154   std::string opName;
155   py::object obj;
156 };
157 
158 /// Python wrapper for InterTypeOpInterface. This interface has only static
159 /// methods.
160 class PyInferTypeOpInterface
161     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
162 public:
163   using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
164 
165   constexpr static const char *pyClassName = "InferTypeOpInterface";
166   constexpr static GetTypeIDFunctionTy getInterfaceID =
167       &mlirInferTypeOpInterfaceTypeID;
168 
169   /// C-style user-data structure for type appending callback.
170   struct AppendResultsCallbackData {
171     std::vector<PyType> &inferredTypes;
172     PyMlirContext &pyMlirContext;
173   };
174 
175   /// Appends the types provided as the two first arguments to the user-data
176   /// structure (expects AppendResultsCallbackData).
177   static void appendResultsCallback(intptr_t nTypes, MlirType *types,
178                                     void *userData) {
179     auto *data = static_cast<AppendResultsCallbackData *>(userData);
180     data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
181     for (intptr_t i = 0; i < nTypes; ++i) {
182       data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
183     }
184   }
185 
186   /// Given the arguments required to build an operation, attempts to infer its
187   /// return types. Throws value_error on failure.
188   std::vector<PyType>
189   inferReturnTypes(std::optional<py::list> operandList,
190                    std::optional<PyAttribute> attributes,
191                    std::optional<std::vector<PyRegion>> regions,
192                    DefaultingPyMlirContext context,
193                    DefaultingPyLocation location) {
194     llvm::SmallVector<MlirValue> mlirOperands;
195     llvm::SmallVector<MlirRegion> mlirRegions;
196 
197     if (operandList && !operandList->empty()) {
198       // Note: as the list may contain other lists this may not be final size.
199       mlirOperands.reserve(operandList->size());
200       for (const auto& it : llvm::enumerate(*operandList)) {
201         PyValue* val;
202         try {
203           val = py::cast<PyValue *>(it.value());
204           if (!val)
205             throw py::cast_error();
206           mlirOperands.push_back(val->get());
207           continue;
208         } catch (py::cast_error &err) {
209         }
210 
211         try {
212           auto vals = py::cast<py::sequence>(it.value());
213           for (py::object v : vals) {
214             try {
215               val = py::cast<PyValue *>(v);
216               if (!val)
217                 throw py::cast_error();
218               mlirOperands.push_back(val->get());
219             } catch (py::cast_error &err) {
220               throw py::value_error(
221                   (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
222                    " must be a Value or Sequence of Values (" + err.what() +
223                    ")")
224                       .str());
225             }
226           }
227           continue;
228         } catch (py::cast_error &err) {
229           throw py::value_error(
230               (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
231                " must be a Value or Sequence of Values (" + err.what() + ")")
232                   .str());
233         }
234 
235         throw py::cast_error();
236       }
237     }
238 
239     if (regions) {
240       mlirRegions.reserve(regions->size());
241       for (PyRegion &region : *regions) {
242         mlirRegions.push_back(region);
243       }
244     }
245 
246     std::vector<PyType> inferredTypes;
247     PyMlirContext &pyContext = context.resolve();
248     AppendResultsCallbackData data{inferredTypes, pyContext};
249     MlirStringRef opNameRef =
250         mlirStringRefCreate(getOpName().data(), getOpName().length());
251     MlirAttribute attributeDict =
252         attributes ? attributes->get() : mlirAttributeGetNull();
253 
254     MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
255         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
256         mlirOperands.data(), attributeDict, mlirRegions.size(),
257         mlirRegions.data(), &appendResultsCallback, &data);
258 
259     if (mlirLogicalResultIsFailure(result)) {
260       throw py::value_error("Failed to infer result types");
261     }
262 
263     return inferredTypes;
264   }
265 
266   static void bindDerived(ClassTy &cls) {
267     cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
268             py::arg("operands") = py::none(),
269             py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
270             py::arg("context") = py::none(), py::arg("loc") = py::none(),
271             inferReturnTypesDoc);
272   }
273 };
274 
275 void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
276 
277 } // namespace python
278 } // namespace mlir
279