xref: /llvm-project/mlir/lib/Bindings/Python/IRInterfaces.cpp (revision 89b0f1ee340b31f09d148ca542a41a9870fb0fad)
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cstdint>
10 #include <optional>
11 #include <pybind11/cast.h>
12 #include <pybind11/detail/common.h>
13 #include <pybind11/pybind11.h>
14 #include <pybind11/pytypes.h>
15 #include <string>
16 #include <utility>
17 #include <vector>
18 
19 #include "IRModule.h"
20 #include "mlir-c/BuiltinAttributes.h"
21 #include "mlir-c/IR.h"
22 #include "mlir-c/Interfaces.h"
23 #include "mlir-c/Support.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 
27 namespace py = pybind11;
28 
29 namespace mlir {
30 namespace python {
31 
// Python docstrings attached to the op interface bindings defined below.

/// Docstring for the interface constructors (`__init__`).
static constexpr const char *constructorDoc =
    R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";

/// Docstring for the read-only `operation` property.
static constexpr const char *operationDoc =
    R"(Returns an Operation for which the interface was constructed.)";

/// Docstring for the read-only `opview` property.
static constexpr const char *opviewDoc =
    R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";

/// Docstring for `InferTypeOpInterface.inferReturnTypes`.
static constexpr const char *inferReturnTypesDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";

/// Docstring for `InferShapedTypeOpInterface.inferReturnTypeComponents`.
static constexpr const char *inferReturnTypeComponentsDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return shaped type components. Raises ValueError on failure.)";
51 
52 namespace {
53 
54 /// Takes in an optional ist of operands and converts them into a SmallVector
55 /// of MlirVlaues. Returns an empty SmallVector if the list is empty.
56 llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
57   llvm::SmallVector<MlirValue> mlirOperands;
58 
59   if (!operandList || operandList->empty()) {
60     return mlirOperands;
61   }
62 
63   // Note: as the list may contain other lists this may not be final size.
64   mlirOperands.reserve(operandList->size());
65   for (const auto &&it : llvm::enumerate(*operandList)) {
66     if (it.value().is_none())
67       continue;
68 
69     PyValue *val;
70     try {
71       val = py::cast<PyValue *>(it.value());
72       if (!val)
73         throw py::cast_error();
74       mlirOperands.push_back(val->get());
75       continue;
76     } catch (py::cast_error &err) {
77       // Intentionally unhandled to try sequence below first.
78       (void)err;
79     }
80 
81     try {
82       auto vals = py::cast<py::sequence>(it.value());
83       for (py::object v : vals) {
84         try {
85           val = py::cast<PyValue *>(v);
86           if (!val)
87             throw py::cast_error();
88           mlirOperands.push_back(val->get());
89         } catch (py::cast_error &err) {
90           throw py::value_error(
91               (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
92                " must be a Value or Sequence of Values (" + err.what() + ")")
93                   .str());
94         }
95       }
96       continue;
97     } catch (py::cast_error &err) {
98       throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
99                              " must be a Value or Sequence of Values (" +
100                              err.what() + ")")
101                                 .str());
102     }
103 
104     throw py::cast_error();
105   }
106 
107   return mlirOperands;
108 }
109 
110 /// Takes in an optional vector of PyRegions and returns a SmallVector of
111 /// MlirRegion. Returns an empty SmallVector if the list is empty.
112 llvm::SmallVector<MlirRegion>
113 wrapRegions(std::optional<std::vector<PyRegion>> regions) {
114   llvm::SmallVector<MlirRegion> mlirRegions;
115 
116   if (regions) {
117     mlirRegions.reserve(regions->size());
118     for (PyRegion &region : *regions) {
119       mlirRegions.push_back(region);
120     }
121   }
122 
123   return mlirRegions;
124 }
125 
126 } // namespace
127 
128 /// CRTP base class for Python classes representing MLIR Op interfaces.
129 /// Interface hierarchies are flat so no base class is expected here. The
130 /// derived class is expected to define the following static fields:
131 ///  - `const char *pyClassName` - the name of the Python class to create;
132 ///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
133 ///    of the interface.
134 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
135 /// interface-specific methods.
136 ///
137 /// An interface class may be constructed from either an Operation/OpView object
138 /// or from a subclass of OpView. In the latter case, only the static interface
139 /// methods are available, similarly to calling ConcereteOp::staticMethod on the
140 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
141 /// method to check whether the interface object was constructed from a class or
142 /// an operation/opview instance. The `getOpName` always succeeds and returns a
143 /// canonical name of the operation suitable for lookups.
144 template <typename ConcreteIface>
145 class PyConcreteOpInterface {
146 protected:
147   using ClassTy = py::class_<ConcreteIface>;
148   using GetTypeIDFunctionTy = MlirTypeID (*)();
149 
150 public:
151   /// Constructs an interface instance from an object that is either an
152   /// operation or a subclass of OpView. In the latter case, only the static
153   /// methods of the interface are accessible to the caller.
154   PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
155       : obj(std::move(object)) {
156     try {
157       operation = &py::cast<PyOperation &>(obj);
158     } catch (py::cast_error &) {
159       // Do nothing.
160     }
161 
162     try {
163       operation = &py::cast<PyOpView &>(obj).getOperation();
164     } catch (py::cast_error &) {
165       // Do nothing.
166     }
167 
168     if (operation != nullptr) {
169       if (!mlirOperationImplementsInterface(*operation,
170                                             ConcreteIface::getInterfaceID())) {
171         std::string msg = "the operation does not implement ";
172         throw py::value_error(msg + ConcreteIface::pyClassName);
173       }
174 
175       MlirIdentifier identifier = mlirOperationGetName(*operation);
176       MlirStringRef stringRef = mlirIdentifierStr(identifier);
177       opName = std::string(stringRef.data, stringRef.length);
178     } else {
179       try {
180         opName = obj.attr("OPERATION_NAME").template cast<std::string>();
181       } catch (py::cast_error &) {
182         throw py::type_error(
183             "Op interface does not refer to an operation or OpView class");
184       }
185 
186       if (!mlirOperationImplementsInterfaceStatic(
187               mlirStringRefCreate(opName.data(), opName.length()),
188               context.resolve().get(), ConcreteIface::getInterfaceID())) {
189         std::string msg = "the operation does not implement ";
190         throw py::value_error(msg + ConcreteIface::pyClassName);
191       }
192     }
193   }
194 
195   /// Creates the Python bindings for this class in the given module.
196   static void bind(py::module &m) {
197     py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
198                                   py::module_local());
199     cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
200             py::arg("context") = py::none(), constructorDoc)
201         .def_property_readonly("operation",
202                                &PyConcreteOpInterface::getOperationObject,
203                                operationDoc)
204         .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
205                                opviewDoc);
206     ConcreteIface::bindDerived(cls);
207   }
208 
209   /// Hook for derived classes to add class-specific bindings.
210   static void bindDerived(ClassTy &cls) {}
211 
212   /// Returns `true` if this object was constructed from a subclass of OpView
213   /// rather than from an operation instance.
214   bool isStatic() { return operation == nullptr; }
215 
216   /// Returns the operation instance from which this object was constructed.
217   /// Throws a type error if this object was constructed from a subclass of
218   /// OpView.
219   py::object getOperationObject() {
220     if (operation == nullptr) {
221       throw py::type_error("Cannot get an operation from a static interface");
222     }
223 
224     return operation->getRef().releaseObject();
225   }
226 
227   /// Returns the opview of the operation instance from which this object was
228   /// constructed. Throws a type error if this object was constructed form a
229   /// subclass of OpView.
230   py::object getOpView() {
231     if (operation == nullptr) {
232       throw py::type_error("Cannot get an opview from a static interface");
233     }
234 
235     return operation->createOpView();
236   }
237 
238   /// Returns the canonical name of the operation this interface is constructed
239   /// from.
240   const std::string &getOpName() { return opName; }
241 
242 private:
243   PyOperation *operation = nullptr;
244   std::string opName;
245   py::object obj;
246 };
247 
248 /// Python wrapper for InferTypeOpInterface. This interface has only static
249 /// methods.
250 class PyInferTypeOpInterface
251     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
252 public:
253   using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
254 
255   constexpr static const char *pyClassName = "InferTypeOpInterface";
256   constexpr static GetTypeIDFunctionTy getInterfaceID =
257       &mlirInferTypeOpInterfaceTypeID;
258 
259   /// C-style user-data structure for type appending callback.
260   struct AppendResultsCallbackData {
261     std::vector<PyType> &inferredTypes;
262     PyMlirContext &pyMlirContext;
263   };
264 
265   /// Appends the types provided as the two first arguments to the user-data
266   /// structure (expects AppendResultsCallbackData).
267   static void appendResultsCallback(intptr_t nTypes, MlirType *types,
268                                     void *userData) {
269     auto *data = static_cast<AppendResultsCallbackData *>(userData);
270     data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
271     for (intptr_t i = 0; i < nTypes; ++i) {
272       data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
273     }
274   }
275 
276   /// Given the arguments required to build an operation, attempts to infer its
277   /// return types. Throws value_error on failure.
278   std::vector<PyType>
279   inferReturnTypes(std::optional<py::list> operandList,
280                    std::optional<PyAttribute> attributes, void *properties,
281                    std::optional<std::vector<PyRegion>> regions,
282                    DefaultingPyMlirContext context,
283                    DefaultingPyLocation location) {
284     llvm::SmallVector<MlirValue> mlirOperands =
285         wrapOperands(std::move(operandList));
286     llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
287 
288     std::vector<PyType> inferredTypes;
289     PyMlirContext &pyContext = context.resolve();
290     AppendResultsCallbackData data{inferredTypes, pyContext};
291     MlirStringRef opNameRef =
292         mlirStringRefCreate(getOpName().data(), getOpName().length());
293     MlirAttribute attributeDict =
294         attributes ? attributes->get() : mlirAttributeGetNull();
295 
296     MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
297         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
298         mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
299         mlirRegions.data(), &appendResultsCallback, &data);
300 
301     if (mlirLogicalResultIsFailure(result)) {
302       throw py::value_error("Failed to infer result types");
303     }
304 
305     return inferredTypes;
306   }
307 
308   static void bindDerived(ClassTy &cls) {
309     cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
310             py::arg("operands") = py::none(),
311             py::arg("attributes") = py::none(),
312             py::arg("properties") = py::none(), py::arg("regions") = py::none(),
313             py::arg("context") = py::none(), py::arg("loc") = py::none(),
314             inferReturnTypesDoc);
315   }
316 };
317 
318 /// Wrapper around an shaped type components.
319 class PyShapedTypeComponents {
320 public:
321   PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
322   PyShapedTypeComponents(py::list shape, MlirType elementType)
323       : shape(std::move(shape)), elementType(elementType), ranked(true) {}
324   PyShapedTypeComponents(py::list shape, MlirType elementType,
325                          MlirAttribute attribute)
326       : shape(std::move(shape)), elementType(elementType), attribute(attribute),
327         ranked(true) {}
328   PyShapedTypeComponents(PyShapedTypeComponents &) = delete;
329   PyShapedTypeComponents(PyShapedTypeComponents &&other)
330       : shape(other.shape), elementType(other.elementType),
331         attribute(other.attribute), ranked(other.ranked) {}
332 
333   static void bind(py::module &m) {
334     py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
335                                        py::module_local())
336         .def_property_readonly(
337             "element_type",
338             [](PyShapedTypeComponents &self) { return self.elementType; },
339             "Returns the element type of the shaped type components.")
340         .def_static(
341             "get",
342             [](PyType &elementType) {
343               return PyShapedTypeComponents(elementType);
344             },
345             py::arg("element_type"),
346             "Create an shaped type components object with only the element "
347             "type.")
348         .def_static(
349             "get",
350             [](py::list shape, PyType &elementType) {
351               return PyShapedTypeComponents(std::move(shape), elementType);
352             },
353             py::arg("shape"), py::arg("element_type"),
354             "Create a ranked shaped type components object.")
355         .def_static(
356             "get",
357             [](py::list shape, PyType &elementType, PyAttribute &attribute) {
358               return PyShapedTypeComponents(std::move(shape), elementType,
359                                             attribute);
360             },
361             py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
362             "Create a ranked shaped type components object with attribute.")
363         .def_property_readonly(
364             "has_rank",
365             [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
366             "Returns whether the given shaped type component is ranked.")
367         .def_property_readonly(
368             "rank",
369             [](PyShapedTypeComponents &self) -> py::object {
370               if (!self.ranked) {
371                 return py::none();
372               }
373               return py::int_(self.shape.size());
374             },
375             "Returns the rank of the given ranked shaped type components. If "
376             "the shaped type components does not have a rank, None is "
377             "returned.")
378         .def_property_readonly(
379             "shape",
380             [](PyShapedTypeComponents &self) -> py::object {
381               if (!self.ranked) {
382                 return py::none();
383               }
384               return py::list(self.shape);
385             },
386             "Returns the shape of the ranked shaped type components as a list "
387             "of integers. Returns none if the shaped type component does not "
388             "have a rank.");
389   }
390 
391   pybind11::object getCapsule();
392   static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);
393 
394 private:
395   py::list shape;
396   MlirType elementType;
397   MlirAttribute attribute;
398   bool ranked{false};
399 };
400 
401 /// Python wrapper for InferShapedTypeOpInterface. This interface has only
402 /// static methods.
403 class PyInferShapedTypeOpInterface
404     : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
405 public:
406   using PyConcreteOpInterface<
407       PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
408 
409   constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
410   constexpr static GetTypeIDFunctionTy getInterfaceID =
411       &mlirInferShapedTypeOpInterfaceTypeID;
412 
413   /// C-style user-data structure for type appending callback.
414   struct AppendResultsCallbackData {
415     std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
416   };
417 
418   /// Appends the shaped type components provided as unpacked shape, element
419   /// type, attribute to the user-data.
420   static void appendResultsCallback(bool hasRank, intptr_t rank,
421                                     const int64_t *shape, MlirType elementType,
422                                     MlirAttribute attribute, void *userData) {
423     auto *data = static_cast<AppendResultsCallbackData *>(userData);
424     if (!hasRank) {
425       data->inferredShapedTypeComponents.emplace_back(elementType);
426     } else {
427       py::list shapeList;
428       for (intptr_t i = 0; i < rank; ++i) {
429         shapeList.append(shape[i]);
430       }
431       data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
432                                                       attribute);
433     }
434   }
435 
436   /// Given the arguments required to build an operation, attempts to infer the
437   /// shaped type components. Throws value_error on failure.
438   std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
439       std::optional<py::list> operandList,
440       std::optional<PyAttribute> attributes, void *properties,
441       std::optional<std::vector<PyRegion>> regions,
442       DefaultingPyMlirContext context, DefaultingPyLocation location) {
443     llvm::SmallVector<MlirValue> mlirOperands =
444         wrapOperands(std::move(operandList));
445     llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
446 
447     std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
448     PyMlirContext &pyContext = context.resolve();
449     AppendResultsCallbackData data{inferredShapedTypeComponents};
450     MlirStringRef opNameRef =
451         mlirStringRefCreate(getOpName().data(), getOpName().length());
452     MlirAttribute attributeDict =
453         attributes ? attributes->get() : mlirAttributeGetNull();
454 
455     MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
456         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
457         mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
458         mlirRegions.data(), &appendResultsCallback, &data);
459 
460     if (mlirLogicalResultIsFailure(result)) {
461       throw py::value_error("Failed to infer result shape type components");
462     }
463 
464     return inferredShapedTypeComponents;
465   }
466 
467   static void bindDerived(ClassTy &cls) {
468     cls.def("inferReturnTypeComponents",
469             &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
470             py::arg("operands") = py::none(),
471             py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
472             py::arg("properties") = py::none(), py::arg("context") = py::none(),
473             py::arg("loc") = py::none(), inferReturnTypeComponentsDoc);
474   }
475 };
476 
477 void populateIRInterfaces(py::module &m) {
478   PyInferTypeOpInterface::bind(m);
479   PyShapedTypeComponents::bind(m);
480   PyInferShapedTypeOpInterface::bind(m);
481 }
482 
483 } // namespace python
484 } // namespace mlir
485