xref: /llvm-project/mlir/lib/Bindings/Python/IRInterfaces.cpp (revision 41bd35b58bb482fd466aa4b13aa44a810ad6470f)
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <nanobind/nanobind.h>
10 #include <nanobind/stl/optional.h>
11 #include <nanobind/stl/vector.h>
12 
13 #include <cstdint>
14 #include <optional>
15 #include <string>
16 #include <utility>
17 #include <vector>
18 
19 #include "IRModule.h"
20 #include "mlir-c/BuiltinAttributes.h"
21 #include "mlir-c/IR.h"
22 #include "mlir-c/Interfaces.h"
23 #include "mlir-c/Support.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 
27 namespace nb = nanobind;
28 
29 namespace mlir {
30 namespace python {
31 
// Python docstrings attached to the interface bindings defined in this file.

constexpr static const char *constructorDoc =
    R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";

constexpr static const char *operationDoc =
    R"(Returns an Operation for which the interface was constructed.)";

constexpr static const char *opviewDoc =
    R"(Returns an OpView subclass _instance_ for which the interface was
constructed.)";

constexpr static const char *inferReturnTypesDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";

constexpr static const char *inferReturnTypeComponentsDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return shaped type components. Raises ValueError on failure.)";
51 
52 namespace {
53 
54 /// Takes in an optional ist of operands and converts them into a SmallVector
55 /// of MlirVlaues. Returns an empty SmallVector if the list is empty.
56 llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) {
57   llvm::SmallVector<MlirValue> mlirOperands;
58 
59   if (!operandList || operandList->size() == 0) {
60     return mlirOperands;
61   }
62 
63   // Note: as the list may contain other lists this may not be final size.
64   mlirOperands.reserve(operandList->size());
65   for (const auto &&it : llvm::enumerate(*operandList)) {
66     if (it.value().is_none())
67       continue;
68 
69     PyValue *val;
70     try {
71       val = nb::cast<PyValue *>(it.value());
72       if (!val)
73         throw nb::cast_error();
74       mlirOperands.push_back(val->get());
75       continue;
76     } catch (nb::cast_error &err) {
77       // Intentionally unhandled to try sequence below first.
78       (void)err;
79     }
80 
81     try {
82       auto vals = nb::cast<nb::sequence>(it.value());
83       for (nb::handle v : vals) {
84         try {
85           val = nb::cast<PyValue *>(v);
86           if (!val)
87             throw nb::cast_error();
88           mlirOperands.push_back(val->get());
89         } catch (nb::cast_error &err) {
90           throw nb::value_error(
91               (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
92                " must be a Value or Sequence of Values (" + err.what() + ")")
93                   .str()
94                   .c_str());
95         }
96       }
97       continue;
98     } catch (nb::cast_error &err) {
99       throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
100                              " must be a Value or Sequence of Values (" +
101                              err.what() + ")")
102                                 .str()
103                                 .c_str());
104     }
105 
106     throw nb::cast_error();
107   }
108 
109   return mlirOperands;
110 }
111 
112 /// Takes in an optional vector of PyRegions and returns a SmallVector of
113 /// MlirRegion. Returns an empty SmallVector if the list is empty.
114 llvm::SmallVector<MlirRegion>
115 wrapRegions(std::optional<std::vector<PyRegion>> regions) {
116   llvm::SmallVector<MlirRegion> mlirRegions;
117 
118   if (regions) {
119     mlirRegions.reserve(regions->size());
120     for (PyRegion &region : *regions) {
121       mlirRegions.push_back(region);
122     }
123   }
124 
125   return mlirRegions;
126 }
127 
128 } // namespace
129 
130 /// CRTP base class for Python classes representing MLIR Op interfaces.
131 /// Interface hierarchies are flat so no base class is expected here. The
132 /// derived class is expected to define the following static fields:
133 ///  - `const char *pyClassName` - the name of the Python class to create;
134 ///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
135 ///    of the interface.
136 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
137 /// interface-specific methods.
138 ///
139 /// An interface class may be constructed from either an Operation/OpView object
140 /// or from a subclass of OpView. In the latter case, only the static interface
141 /// methods are available, similarly to calling ConcereteOp::staticMethod on the
142 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
143 /// method to check whether the interface object was constructed from a class or
144 /// an operation/opview instance. The `getOpName` always succeeds and returns a
145 /// canonical name of the operation suitable for lookups.
146 template <typename ConcreteIface>
147 class PyConcreteOpInterface {
148 protected:
149   using ClassTy = nb::class_<ConcreteIface>;
150   using GetTypeIDFunctionTy = MlirTypeID (*)();
151 
152 public:
153   /// Constructs an interface instance from an object that is either an
154   /// operation or a subclass of OpView. In the latter case, only the static
155   /// methods of the interface are accessible to the caller.
156   PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
157       : obj(std::move(object)) {
158     try {
159       operation = &nb::cast<PyOperation &>(obj);
160     } catch (nb::cast_error &) {
161       // Do nothing.
162     }
163 
164     try {
165       operation = &nb::cast<PyOpView &>(obj).getOperation();
166     } catch (nb::cast_error &) {
167       // Do nothing.
168     }
169 
170     if (operation != nullptr) {
171       if (!mlirOperationImplementsInterface(*operation,
172                                             ConcreteIface::getInterfaceID())) {
173         std::string msg = "the operation does not implement ";
174         throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
175       }
176 
177       MlirIdentifier identifier = mlirOperationGetName(*operation);
178       MlirStringRef stringRef = mlirIdentifierStr(identifier);
179       opName = std::string(stringRef.data, stringRef.length);
180     } else {
181       try {
182         opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
183       } catch (nb::cast_error &) {
184         throw nb::type_error(
185             "Op interface does not refer to an operation or OpView class");
186       }
187 
188       if (!mlirOperationImplementsInterfaceStatic(
189               mlirStringRefCreate(opName.data(), opName.length()),
190               context.resolve().get(), ConcreteIface::getInterfaceID())) {
191         std::string msg = "the operation does not implement ";
192         throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
193       }
194     }
195   }
196 
197   /// Creates the Python bindings for this class in the given module.
198   static void bind(nb::module_ &m) {
199     nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
200     cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
201             nb::arg("context").none() = nb::none(), constructorDoc)
202         .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
203                      operationDoc)
204         .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
205     ConcreteIface::bindDerived(cls);
206   }
207 
208   /// Hook for derived classes to add class-specific bindings.
209   static void bindDerived(ClassTy &cls) {}
210 
211   /// Returns `true` if this object was constructed from a subclass of OpView
212   /// rather than from an operation instance.
213   bool isStatic() { return operation == nullptr; }
214 
215   /// Returns the operation instance from which this object was constructed.
216   /// Throws a type error if this object was constructed from a subclass of
217   /// OpView.
218   nb::object getOperationObject() {
219     if (operation == nullptr) {
220       throw nb::type_error("Cannot get an operation from a static interface");
221     }
222 
223     return operation->getRef().releaseObject();
224   }
225 
226   /// Returns the opview of the operation instance from which this object was
227   /// constructed. Throws a type error if this object was constructed form a
228   /// subclass of OpView.
229   nb::object getOpView() {
230     if (operation == nullptr) {
231       throw nb::type_error("Cannot get an opview from a static interface");
232     }
233 
234     return operation->createOpView();
235   }
236 
237   /// Returns the canonical name of the operation this interface is constructed
238   /// from.
239   const std::string &getOpName() { return opName; }
240 
241 private:
242   PyOperation *operation = nullptr;
243   std::string opName;
244   nb::object obj;
245 };
246 
247 /// Python wrapper for InferTypeOpInterface. This interface has only static
248 /// methods.
249 class PyInferTypeOpInterface
250     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
251 public:
252   using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
253 
254   constexpr static const char *pyClassName = "InferTypeOpInterface";
255   constexpr static GetTypeIDFunctionTy getInterfaceID =
256       &mlirInferTypeOpInterfaceTypeID;
257 
258   /// C-style user-data structure for type appending callback.
259   struct AppendResultsCallbackData {
260     std::vector<PyType> &inferredTypes;
261     PyMlirContext &pyMlirContext;
262   };
263 
264   /// Appends the types provided as the two first arguments to the user-data
265   /// structure (expects AppendResultsCallbackData).
266   static void appendResultsCallback(intptr_t nTypes, MlirType *types,
267                                     void *userData) {
268     auto *data = static_cast<AppendResultsCallbackData *>(userData);
269     data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
270     for (intptr_t i = 0; i < nTypes; ++i) {
271       data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
272     }
273   }
274 
275   /// Given the arguments required to build an operation, attempts to infer its
276   /// return types. Throws value_error on failure.
277   std::vector<PyType>
278   inferReturnTypes(std::optional<nb::list> operandList,
279                    std::optional<PyAttribute> attributes, void *properties,
280                    std::optional<std::vector<PyRegion>> regions,
281                    DefaultingPyMlirContext context,
282                    DefaultingPyLocation location) {
283     llvm::SmallVector<MlirValue> mlirOperands =
284         wrapOperands(std::move(operandList));
285     llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
286 
287     std::vector<PyType> inferredTypes;
288     PyMlirContext &pyContext = context.resolve();
289     AppendResultsCallbackData data{inferredTypes, pyContext};
290     MlirStringRef opNameRef =
291         mlirStringRefCreate(getOpName().data(), getOpName().length());
292     MlirAttribute attributeDict =
293         attributes ? attributes->get() : mlirAttributeGetNull();
294 
295     MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
296         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
297         mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
298         mlirRegions.data(), &appendResultsCallback, &data);
299 
300     if (mlirLogicalResultIsFailure(result)) {
301       throw nb::value_error("Failed to infer result types");
302     }
303 
304     return inferredTypes;
305   }
306 
307   static void bindDerived(ClassTy &cls) {
308     cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
309             nb::arg("operands").none() = nb::none(),
310             nb::arg("attributes").none() = nb::none(),
311             nb::arg("properties").none() = nb::none(),
312             nb::arg("regions").none() = nb::none(),
313             nb::arg("context").none() = nb::none(),
314             nb::arg("loc").none() = nb::none(), inferReturnTypesDoc);
315   }
316 };
317 
318 /// Wrapper around an shaped type components.
319 class PyShapedTypeComponents {
320 public:
321   PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
322   PyShapedTypeComponents(nb::list shape, MlirType elementType)
323       : shape(std::move(shape)), elementType(elementType), ranked(true) {}
324   PyShapedTypeComponents(nb::list shape, MlirType elementType,
325                          MlirAttribute attribute)
326       : shape(std::move(shape)), elementType(elementType), attribute(attribute),
327         ranked(true) {}
328   PyShapedTypeComponents(PyShapedTypeComponents &) = delete;
329   PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
330       : shape(other.shape), elementType(other.elementType),
331         attribute(other.attribute), ranked(other.ranked) {}
332 
333   static void bind(nb::module_ &m) {
334     nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents")
335         .def_prop_ro(
336             "element_type",
337             [](PyShapedTypeComponents &self) { return self.elementType; },
338             "Returns the element type of the shaped type components.")
339         .def_static(
340             "get",
341             [](PyType &elementType) {
342               return PyShapedTypeComponents(elementType);
343             },
344             nb::arg("element_type"),
345             "Create an shaped type components object with only the element "
346             "type.")
347         .def_static(
348             "get",
349             [](nb::list shape, PyType &elementType) {
350               return PyShapedTypeComponents(std::move(shape), elementType);
351             },
352             nb::arg("shape"), nb::arg("element_type"),
353             "Create a ranked shaped type components object.")
354         .def_static(
355             "get",
356             [](nb::list shape, PyType &elementType, PyAttribute &attribute) {
357               return PyShapedTypeComponents(std::move(shape), elementType,
358                                             attribute);
359             },
360             nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"),
361             "Create a ranked shaped type components object with attribute.")
362         .def_prop_ro(
363             "has_rank",
364             [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
365             "Returns whether the given shaped type component is ranked.")
366         .def_prop_ro(
367             "rank",
368             [](PyShapedTypeComponents &self) -> nb::object {
369               if (!self.ranked) {
370                 return nb::none();
371               }
372               return nb::int_(self.shape.size());
373             },
374             "Returns the rank of the given ranked shaped type components. If "
375             "the shaped type components does not have a rank, None is "
376             "returned.")
377         .def_prop_ro(
378             "shape",
379             [](PyShapedTypeComponents &self) -> nb::object {
380               if (!self.ranked) {
381                 return nb::none();
382               }
383               return nb::list(self.shape);
384             },
385             "Returns the shape of the ranked shaped type components as a list "
386             "of integers. Returns none if the shaped type component does not "
387             "have a rank.");
388   }
389 
390   nb::object getCapsule();
391   static PyShapedTypeComponents createFromCapsule(nb::object capsule);
392 
393 private:
394   nb::list shape;
395   MlirType elementType;
396   MlirAttribute attribute;
397   bool ranked{false};
398 };
399 
400 /// Python wrapper for InferShapedTypeOpInterface. This interface has only
401 /// static methods.
402 class PyInferShapedTypeOpInterface
403     : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
404 public:
405   using PyConcreteOpInterface<
406       PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
407 
408   constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
409   constexpr static GetTypeIDFunctionTy getInterfaceID =
410       &mlirInferShapedTypeOpInterfaceTypeID;
411 
412   /// C-style user-data structure for type appending callback.
413   struct AppendResultsCallbackData {
414     std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
415   };
416 
417   /// Appends the shaped type components provided as unpacked shape, element
418   /// type, attribute to the user-data.
419   static void appendResultsCallback(bool hasRank, intptr_t rank,
420                                     const int64_t *shape, MlirType elementType,
421                                     MlirAttribute attribute, void *userData) {
422     auto *data = static_cast<AppendResultsCallbackData *>(userData);
423     if (!hasRank) {
424       data->inferredShapedTypeComponents.emplace_back(elementType);
425     } else {
426       nb::list shapeList;
427       for (intptr_t i = 0; i < rank; ++i) {
428         shapeList.append(shape[i]);
429       }
430       data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
431                                                       attribute);
432     }
433   }
434 
435   /// Given the arguments required to build an operation, attempts to infer the
436   /// shaped type components. Throws value_error on failure.
437   std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
438       std::optional<nb::list> operandList,
439       std::optional<PyAttribute> attributes, void *properties,
440       std::optional<std::vector<PyRegion>> regions,
441       DefaultingPyMlirContext context, DefaultingPyLocation location) {
442     llvm::SmallVector<MlirValue> mlirOperands =
443         wrapOperands(std::move(operandList));
444     llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
445 
446     std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
447     PyMlirContext &pyContext = context.resolve();
448     AppendResultsCallbackData data{inferredShapedTypeComponents};
449     MlirStringRef opNameRef =
450         mlirStringRefCreate(getOpName().data(), getOpName().length());
451     MlirAttribute attributeDict =
452         attributes ? attributes->get() : mlirAttributeGetNull();
453 
454     MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
455         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
456         mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
457         mlirRegions.data(), &appendResultsCallback, &data);
458 
459     if (mlirLogicalResultIsFailure(result)) {
460       throw nb::value_error("Failed to infer result shape type components");
461     }
462 
463     return inferredShapedTypeComponents;
464   }
465 
466   static void bindDerived(ClassTy &cls) {
467     cls.def("inferReturnTypeComponents",
468             &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
469             nb::arg("operands").none() = nb::none(),
470             nb::arg("attributes").none() = nb::none(),
471             nb::arg("regions").none() = nb::none(),
472             nb::arg("properties").none() = nb::none(),
473             nb::arg("context").none() = nb::none(),
474             nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc);
475   }
476 };
477 
478 void populateIRInterfaces(nb::module_ &m) {
479   PyInferTypeOpInterface::bind(m);
480   PyShapedTypeComponents::bind(m);
481   PyInferShapedTypeOpInterface::bind(m);
482 }
483 
484 } // namespace python
485 } // namespace mlir
486