xref: /llvm-project/mlir/lib/Bindings/Python/IRInterfaces.cpp (revision f22008ed89eac028cd70f91de3adf41a481f6d22)
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <optional>
10 #include <utility>
11 
12 #include "IRModule.h"
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/Interfaces.h"
15 #include "llvm/ADT/STLExtras.h"
16 
17 namespace py = pybind11;
18 
19 namespace mlir {
20 namespace python {
21 
/// Docstring attached to the constructor binding of every op interface class.
static constexpr const char *constructorDoc =
    R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";

/// Docstring of the read-only `operation` property.
static constexpr const char *operationDoc =
    R"(Returns an Operation for which the interface was constructed.)";

/// Docstring of the read-only `opview` property.
static constexpr const char *opviewDoc =
    R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";

/// Docstring of `InferTypeOpInterface.inferReturnTypes`.
static constexpr const char *inferReturnTypesDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";

/// Docstring of `InferShapedTypeOpInterface.inferReturnTypeComponents`.
static constexpr const char *inferReturnTypeComponentsDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return shaped type components. Raises ValueError on failure.)";
41 
42 namespace {
43 
44 /// Takes in an optional ist of operands and converts them into a SmallVector
45 /// of MlirVlaues. Returns an empty SmallVector if the list is empty.
46 llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
47   llvm::SmallVector<MlirValue> mlirOperands;
48 
49   if (!operandList || operandList->empty()) {
50     return mlirOperands;
51   }
52 
53   // Note: as the list may contain other lists this may not be final size.
54   mlirOperands.reserve(operandList->size());
55   for (const auto &&it : llvm::enumerate(*operandList)) {
56     PyValue *val;
57     try {
58       val = py::cast<PyValue *>(it.value());
59       if (!val)
60         throw py::cast_error();
61       mlirOperands.push_back(val->get());
62       continue;
63     } catch (py::cast_error &err) {
64       // Intentionally unhandled to try sequence below first.
65       (void)err;
66     }
67 
68     try {
69       auto vals = py::cast<py::sequence>(it.value());
70       for (py::object v : vals) {
71         try {
72           val = py::cast<PyValue *>(v);
73           if (!val)
74             throw py::cast_error();
75           mlirOperands.push_back(val->get());
76         } catch (py::cast_error &err) {
77           throw py::value_error(
78               (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
79                " must be a Value or Sequence of Values (" + err.what() + ")")
80                   .str());
81         }
82       }
83       continue;
84     } catch (py::cast_error &err) {
85       throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
86                              " must be a Value or Sequence of Values (" +
87                              err.what() + ")")
88                                 .str());
89     }
90 
91     throw py::cast_error();
92   }
93 
94   return mlirOperands;
95 }
96 
97 /// Takes in an optional vector of PyRegions and returns a SmallVector of
98 /// MlirRegion. Returns an empty SmallVector if the list is empty.
99 llvm::SmallVector<MlirRegion>
100 wrapRegions(std::optional<std::vector<PyRegion>> regions) {
101   llvm::SmallVector<MlirRegion> mlirRegions;
102 
103   if (regions) {
104     mlirRegions.reserve(regions->size());
105     for (PyRegion &region : *regions) {
106       mlirRegions.push_back(region);
107     }
108   }
109 
110   return mlirRegions;
111 }
112 
113 } // namespace
114 
115 /// CRTP base class for Python classes representing MLIR Op interfaces.
116 /// Interface hierarchies are flat so no base class is expected here. The
117 /// derived class is expected to define the following static fields:
118 ///  - `const char *pyClassName` - the name of the Python class to create;
119 ///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
120 ///    of the interface.
121 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
122 /// interface-specific methods.
123 ///
124 /// An interface class may be constructed from either an Operation/OpView object
125 /// or from a subclass of OpView. In the latter case, only the static interface
126 /// methods are available, similarly to calling ConcereteOp::staticMethod on the
127 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
128 /// method to check whether the interface object was constructed from a class or
129 /// an operation/opview instance. The `getOpName` always succeeds and returns a
130 /// canonical name of the operation suitable for lookups.
131 template <typename ConcreteIface>
132 class PyConcreteOpInterface {
133 protected:
134   using ClassTy = py::class_<ConcreteIface>;
135   using GetTypeIDFunctionTy = MlirTypeID (*)();
136 
137 public:
138   /// Constructs an interface instance from an object that is either an
139   /// operation or a subclass of OpView. In the latter case, only the static
140   /// methods of the interface are accessible to the caller.
141   PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
142       : obj(std::move(object)) {
143     try {
144       operation = &py::cast<PyOperation &>(obj);
145     } catch (py::cast_error &) {
146       // Do nothing.
147     }
148 
149     try {
150       operation = &py::cast<PyOpView &>(obj).getOperation();
151     } catch (py::cast_error &) {
152       // Do nothing.
153     }
154 
155     if (operation != nullptr) {
156       if (!mlirOperationImplementsInterface(*operation,
157                                             ConcreteIface::getInterfaceID())) {
158         std::string msg = "the operation does not implement ";
159         throw py::value_error(msg + ConcreteIface::pyClassName);
160       }
161 
162       MlirIdentifier identifier = mlirOperationGetName(*operation);
163       MlirStringRef stringRef = mlirIdentifierStr(identifier);
164       opName = std::string(stringRef.data, stringRef.length);
165     } else {
166       try {
167         opName = obj.attr("OPERATION_NAME").template cast<std::string>();
168       } catch (py::cast_error &) {
169         throw py::type_error(
170             "Op interface does not refer to an operation or OpView class");
171       }
172 
173       if (!mlirOperationImplementsInterfaceStatic(
174               mlirStringRefCreate(opName.data(), opName.length()),
175               context.resolve().get(), ConcreteIface::getInterfaceID())) {
176         std::string msg = "the operation does not implement ";
177         throw py::value_error(msg + ConcreteIface::pyClassName);
178       }
179     }
180   }
181 
182   /// Creates the Python bindings for this class in the given module.
183   static void bind(py::module &m) {
184     py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
185                                   py::module_local());
186     cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
187             py::arg("context") = py::none(), constructorDoc)
188         .def_property_readonly("operation",
189                                &PyConcreteOpInterface::getOperationObject,
190                                operationDoc)
191         .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
192                                opviewDoc);
193     ConcreteIface::bindDerived(cls);
194   }
195 
196   /// Hook for derived classes to add class-specific bindings.
197   static void bindDerived(ClassTy &cls) {}
198 
199   /// Returns `true` if this object was constructed from a subclass of OpView
200   /// rather than from an operation instance.
201   bool isStatic() { return operation == nullptr; }
202 
203   /// Returns the operation instance from which this object was constructed.
204   /// Throws a type error if this object was constructed from a subclass of
205   /// OpView.
206   py::object getOperationObject() {
207     if (operation == nullptr) {
208       throw py::type_error("Cannot get an operation from a static interface");
209     }
210 
211     return operation->getRef().releaseObject();
212   }
213 
214   /// Returns the opview of the operation instance from which this object was
215   /// constructed. Throws a type error if this object was constructed form a
216   /// subclass of OpView.
217   py::object getOpView() {
218     if (operation == nullptr) {
219       throw py::type_error("Cannot get an opview from a static interface");
220     }
221 
222     return operation->createOpView();
223   }
224 
225   /// Returns the canonical name of the operation this interface is constructed
226   /// from.
227   const std::string &getOpName() { return opName; }
228 
229 private:
230   PyOperation *operation = nullptr;
231   std::string opName;
232   py::object obj;
233 };
234 
235 /// Python wrapper for InferTypeOpInterface. This interface has only static
236 /// methods.
237 class PyInferTypeOpInterface
238     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
239 public:
240   using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
241 
242   constexpr static const char *pyClassName = "InferTypeOpInterface";
243   constexpr static GetTypeIDFunctionTy getInterfaceID =
244       &mlirInferTypeOpInterfaceTypeID;
245 
246   /// C-style user-data structure for type appending callback.
247   struct AppendResultsCallbackData {
248     std::vector<PyType> &inferredTypes;
249     PyMlirContext &pyMlirContext;
250   };
251 
252   /// Appends the types provided as the two first arguments to the user-data
253   /// structure (expects AppendResultsCallbackData).
254   static void appendResultsCallback(intptr_t nTypes, MlirType *types,
255                                     void *userData) {
256     auto *data = static_cast<AppendResultsCallbackData *>(userData);
257     data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
258     for (intptr_t i = 0; i < nTypes; ++i) {
259       data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
260     }
261   }
262 
263   /// Given the arguments required to build an operation, attempts to infer its
264   /// return types. Throws value_error on failure.
265   std::vector<PyType>
266   inferReturnTypes(std::optional<py::list> operandList,
267                    std::optional<PyAttribute> attributes, void *properties,
268                    std::optional<std::vector<PyRegion>> regions,
269                    DefaultingPyMlirContext context,
270                    DefaultingPyLocation location) {
271     llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList);
272     llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions);
273 
274     std::vector<PyType> inferredTypes;
275     PyMlirContext &pyContext = context.resolve();
276     AppendResultsCallbackData data{inferredTypes, pyContext};
277     MlirStringRef opNameRef =
278         mlirStringRefCreate(getOpName().data(), getOpName().length());
279     MlirAttribute attributeDict =
280         attributes ? attributes->get() : mlirAttributeGetNull();
281 
282     MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
283         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
284         mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
285         mlirRegions.data(), &appendResultsCallback, &data);
286 
287     if (mlirLogicalResultIsFailure(result)) {
288       throw py::value_error("Failed to infer result types");
289     }
290 
291     return inferredTypes;
292   }
293 
294   static void bindDerived(ClassTy &cls) {
295     cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
296             py::arg("operands") = py::none(),
297             py::arg("attributes") = py::none(),
298             py::arg("properties") = py::none(), py::arg("regions") = py::none(),
299             py::arg("context") = py::none(), py::arg("loc") = py::none(),
300             inferReturnTypesDoc);
301   }
302 };
303 
304 /// Wrapper around an shaped type components.
305 class PyShapedTypeComponents {
306 public:
307   PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
308   PyShapedTypeComponents(py::list shape, MlirType elementType)
309       : shape(shape), elementType(elementType), ranked(true) {}
310   PyShapedTypeComponents(py::list shape, MlirType elementType,
311                          MlirAttribute attribute)
312       : shape(shape), elementType(elementType), attribute(attribute),
313         ranked(true) {}
314   PyShapedTypeComponents(PyShapedTypeComponents &) = delete;
315   PyShapedTypeComponents(PyShapedTypeComponents &&other)
316       : shape(other.shape), elementType(other.elementType),
317         attribute(other.attribute), ranked(other.ranked) {}
318 
319   static void bind(py::module &m) {
320     py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
321                                        py::module_local())
322         .def_property_readonly(
323             "element_type",
324             [](PyShapedTypeComponents &self) {
325               return PyType(PyMlirContext::forContext(
326                                 mlirTypeGetContext(self.elementType)),
327                             self.elementType);
328             },
329             "Returns the element type of the shaped type components.")
330         .def_static(
331             "get",
332             [](PyType &elementType) {
333               return PyShapedTypeComponents(elementType);
334             },
335             py::arg("element_type"),
336             "Create an shaped type components object with only the element "
337             "type.")
338         .def_static(
339             "get",
340             [](py::list shape, PyType &elementType) {
341               return PyShapedTypeComponents(shape, elementType);
342             },
343             py::arg("shape"), py::arg("element_type"),
344             "Create a ranked shaped type components object.")
345         .def_static(
346             "get",
347             [](py::list shape, PyType &elementType, PyAttribute &attribute) {
348               return PyShapedTypeComponents(shape, elementType, attribute);
349             },
350             py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
351             "Create a ranked shaped type components object with attribute.")
352         .def_property_readonly(
353             "has_rank",
354             [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
355             "Returns whether the given shaped type component is ranked.")
356         .def_property_readonly(
357             "rank",
358             [](PyShapedTypeComponents &self) -> py::object {
359               if (!self.ranked) {
360                 return py::none();
361               }
362               return py::int_(self.shape.size());
363             },
364             "Returns the rank of the given ranked shaped type components. If "
365             "the shaped type components does not have a rank, None is "
366             "returned.")
367         .def_property_readonly(
368             "shape",
369             [](PyShapedTypeComponents &self) -> py::object {
370               if (!self.ranked) {
371                 return py::none();
372               }
373               return py::list(self.shape);
374             },
375             "Returns the shape of the ranked shaped type components as a list "
376             "of integers. Returns none if the shaped type component does not "
377             "have a rank.");
378   }
379 
380   pybind11::object getCapsule();
381   static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);
382 
383 private:
384   py::list shape;
385   MlirType elementType;
386   MlirAttribute attribute;
387   bool ranked{false};
388 };
389 
390 /// Python wrapper for InferShapedTypeOpInterface. This interface has only
391 /// static methods.
392 class PyInferShapedTypeOpInterface
393     : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
394 public:
395   using PyConcreteOpInterface<
396       PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
397 
398   constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
399   constexpr static GetTypeIDFunctionTy getInterfaceID =
400       &mlirInferShapedTypeOpInterfaceTypeID;
401 
402   /// C-style user-data structure for type appending callback.
403   struct AppendResultsCallbackData {
404     std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
405   };
406 
407   /// Appends the shaped type components provided as unpacked shape, element
408   /// type, attribute to the user-data.
409   static void appendResultsCallback(bool hasRank, intptr_t rank,
410                                     const int64_t *shape, MlirType elementType,
411                                     MlirAttribute attribute, void *userData) {
412     auto *data = static_cast<AppendResultsCallbackData *>(userData);
413     if (!hasRank) {
414       data->inferredShapedTypeComponents.emplace_back(elementType);
415     } else {
416       py::list shapeList;
417       for (intptr_t i = 0; i < rank; ++i) {
418         shapeList.append(shape[i]);
419       }
420       data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
421                                                       attribute);
422     }
423   }
424 
425   /// Given the arguments required to build an operation, attempts to infer the
426   /// shaped type components. Throws value_error on failure.
427   std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
428       std::optional<py::list> operandList,
429       std::optional<PyAttribute> attributes, void *properties,
430       std::optional<std::vector<PyRegion>> regions,
431       DefaultingPyMlirContext context, DefaultingPyLocation location) {
432     llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList);
433     llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions);
434 
435     std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
436     PyMlirContext &pyContext = context.resolve();
437     AppendResultsCallbackData data{inferredShapedTypeComponents};
438     MlirStringRef opNameRef =
439         mlirStringRefCreate(getOpName().data(), getOpName().length());
440     MlirAttribute attributeDict =
441         attributes ? attributes->get() : mlirAttributeGetNull();
442 
443     MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
444         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
445         mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
446         mlirRegions.data(), &appendResultsCallback, &data);
447 
448     if (mlirLogicalResultIsFailure(result)) {
449       throw py::value_error("Failed to infer result shape type components");
450     }
451 
452     return inferredShapedTypeComponents;
453   }
454 
455   static void bindDerived(ClassTy &cls) {
456     cls.def("inferReturnTypeComponents",
457             &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
458             py::arg("operands") = py::none(),
459             py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
460             py::arg("properties") = py::none(), py::arg("context") = py::none(),
461             py::arg("loc") = py::none(), inferReturnTypeComponentsDoc);
462   }
463 };
464 
/// Registers the Python classes defined in this file (InferTypeOpInterface,
/// ShapedTypeComponents and InferShapedTypeOpInterface) in the module `m`.
void populateIRInterfaces(py::module &m) {
  PyInferTypeOpInterface::bind(m);
  PyShapedTypeComponents::bind(m);
  PyInferShapedTypeOpInterface::bind(m);
}
470 
471 } // namespace python
472 } // namespace mlir
473