xref: /llvm-project/mlir/lib/Bindings/Python/IRInterfaces.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cstdint>
10 #include <optional>
11 #include <string>
12 #include <utility>
13 #include <vector>
14 
15 #include "IRModule.h"
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/IR.h"
18 #include "mlir-c/Interfaces.h"
19 #include "mlir-c/Support.h"
20 #include "mlir/Bindings/Python/Nanobind.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 
24 namespace nb = nanobind;
25 
26 namespace mlir {
27 namespace python {
28 
// Docstrings attached to the Python bindings defined in this file. Each one is
// spelled as adjacent string pieces that the compiler concatenates; the
// explicit '\n' characters reproduce the docstrings' line breaks.

/// Docstring of the interface constructor.
static constexpr const char *constructorDoc =
    "Creates an interface from a given operation/opview object or from a\n"
    "subclass of OpView. Raises ValueError if the operation does not "
    "implement the\n"
    "interface.";

/// Docstring of the `operation` property.
static constexpr const char *operationDoc =
    "Returns an Operation for which the interface was constructed.";

/// Docstring of the `opview` property.
static constexpr const char *opviewDoc =
    "Returns an OpView subclass _instance_ for which the interface was\n"
    "constructed";

/// Docstring of `InferTypeOpInterface.inferReturnTypes`.
static constexpr const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";

/// Docstring of `InferShapedTypeOpInterface.inferReturnTypeComponents`.
static constexpr const char *inferReturnTypeComponentsDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return shaped type components. Raises ValueError on failure.";
48 
49 namespace {
50 
51 /// Takes in an optional ist of operands and converts them into a SmallVector
52 /// of MlirVlaues. Returns an empty SmallVector if the list is empty.
53 llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) {
54   llvm::SmallVector<MlirValue> mlirOperands;
55 
56   if (!operandList || operandList->size() == 0) {
57     return mlirOperands;
58   }
59 
60   // Note: as the list may contain other lists this may not be final size.
61   mlirOperands.reserve(operandList->size());
62   for (const auto &&it : llvm::enumerate(*operandList)) {
63     if (it.value().is_none())
64       continue;
65 
66     PyValue *val;
67     try {
68       val = nb::cast<PyValue *>(it.value());
69       if (!val)
70         throw nb::cast_error();
71       mlirOperands.push_back(val->get());
72       continue;
73     } catch (nb::cast_error &err) {
74       // Intentionally unhandled to try sequence below first.
75       (void)err;
76     }
77 
78     try {
79       auto vals = nb::cast<nb::sequence>(it.value());
80       for (nb::handle v : vals) {
81         try {
82           val = nb::cast<PyValue *>(v);
83           if (!val)
84             throw nb::cast_error();
85           mlirOperands.push_back(val->get());
86         } catch (nb::cast_error &err) {
87           throw nb::value_error(
88               (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
89                " must be a Value or Sequence of Values (" + err.what() + ")")
90                   .str()
91                   .c_str());
92         }
93       }
94       continue;
95     } catch (nb::cast_error &err) {
96       throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
97                              " must be a Value or Sequence of Values (" +
98                              err.what() + ")")
99                                 .str()
100                                 .c_str());
101     }
102 
103     throw nb::cast_error();
104   }
105 
106   return mlirOperands;
107 }
108 
109 /// Takes in an optional vector of PyRegions and returns a SmallVector of
110 /// MlirRegion. Returns an empty SmallVector if the list is empty.
111 llvm::SmallVector<MlirRegion>
112 wrapRegions(std::optional<std::vector<PyRegion>> regions) {
113   llvm::SmallVector<MlirRegion> mlirRegions;
114 
115   if (regions) {
116     mlirRegions.reserve(regions->size());
117     for (PyRegion &region : *regions) {
118       mlirRegions.push_back(region);
119     }
120   }
121 
122   return mlirRegions;
123 }
124 
125 } // namespace
126 
127 /// CRTP base class for Python classes representing MLIR Op interfaces.
128 /// Interface hierarchies are flat so no base class is expected here. The
129 /// derived class is expected to define the following static fields:
130 ///  - `const char *pyClassName` - the name of the Python class to create;
131 ///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
132 ///    of the interface.
133 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
134 /// interface-specific methods.
135 ///
136 /// An interface class may be constructed from either an Operation/OpView object
137 /// or from a subclass of OpView. In the latter case, only the static interface
138 /// methods are available, similarly to calling ConcereteOp::staticMethod on the
139 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
140 /// method to check whether the interface object was constructed from a class or
141 /// an operation/opview instance. The `getOpName` always succeeds and returns a
142 /// canonical name of the operation suitable for lookups.
143 template <typename ConcreteIface>
144 class PyConcreteOpInterface {
145 protected:
146   using ClassTy = nb::class_<ConcreteIface>;
147   using GetTypeIDFunctionTy = MlirTypeID (*)();
148 
149 public:
150   /// Constructs an interface instance from an object that is either an
151   /// operation or a subclass of OpView. In the latter case, only the static
152   /// methods of the interface are accessible to the caller.
153   PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
154       : obj(std::move(object)) {
155     try {
156       operation = &nb::cast<PyOperation &>(obj);
157     } catch (nb::cast_error &) {
158       // Do nothing.
159     }
160 
161     try {
162       operation = &nb::cast<PyOpView &>(obj).getOperation();
163     } catch (nb::cast_error &) {
164       // Do nothing.
165     }
166 
167     if (operation != nullptr) {
168       if (!mlirOperationImplementsInterface(*operation,
169                                             ConcreteIface::getInterfaceID())) {
170         std::string msg = "the operation does not implement ";
171         throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
172       }
173 
174       MlirIdentifier identifier = mlirOperationGetName(*operation);
175       MlirStringRef stringRef = mlirIdentifierStr(identifier);
176       opName = std::string(stringRef.data, stringRef.length);
177     } else {
178       try {
179         opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
180       } catch (nb::cast_error &) {
181         throw nb::type_error(
182             "Op interface does not refer to an operation or OpView class");
183       }
184 
185       if (!mlirOperationImplementsInterfaceStatic(
186               mlirStringRefCreate(opName.data(), opName.length()),
187               context.resolve().get(), ConcreteIface::getInterfaceID())) {
188         std::string msg = "the operation does not implement ";
189         throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
190       }
191     }
192   }
193 
194   /// Creates the Python bindings for this class in the given module.
195   static void bind(nb::module_ &m) {
196     nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
197     cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
198             nb::arg("context").none() = nb::none(), constructorDoc)
199         .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
200                      operationDoc)
201         .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
202     ConcreteIface::bindDerived(cls);
203   }
204 
205   /// Hook for derived classes to add class-specific bindings.
206   static void bindDerived(ClassTy &cls) {}
207 
208   /// Returns `true` if this object was constructed from a subclass of OpView
209   /// rather than from an operation instance.
210   bool isStatic() { return operation == nullptr; }
211 
212   /// Returns the operation instance from which this object was constructed.
213   /// Throws a type error if this object was constructed from a subclass of
214   /// OpView.
215   nb::object getOperationObject() {
216     if (operation == nullptr) {
217       throw nb::type_error("Cannot get an operation from a static interface");
218     }
219 
220     return operation->getRef().releaseObject();
221   }
222 
223   /// Returns the opview of the operation instance from which this object was
224   /// constructed. Throws a type error if this object was constructed form a
225   /// subclass of OpView.
226   nb::object getOpView() {
227     if (operation == nullptr) {
228       throw nb::type_error("Cannot get an opview from a static interface");
229     }
230 
231     return operation->createOpView();
232   }
233 
234   /// Returns the canonical name of the operation this interface is constructed
235   /// from.
236   const std::string &getOpName() { return opName; }
237 
238 private:
239   PyOperation *operation = nullptr;
240   std::string opName;
241   nb::object obj;
242 };
243 
244 /// Python wrapper for InferTypeOpInterface. This interface has only static
245 /// methods.
246 class PyInferTypeOpInterface
247     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
248 public:
249   using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
250 
251   constexpr static const char *pyClassName = "InferTypeOpInterface";
252   constexpr static GetTypeIDFunctionTy getInterfaceID =
253       &mlirInferTypeOpInterfaceTypeID;
254 
255   /// C-style user-data structure for type appending callback.
256   struct AppendResultsCallbackData {
257     std::vector<PyType> &inferredTypes;
258     PyMlirContext &pyMlirContext;
259   };
260 
261   /// Appends the types provided as the two first arguments to the user-data
262   /// structure (expects AppendResultsCallbackData).
263   static void appendResultsCallback(intptr_t nTypes, MlirType *types,
264                                     void *userData) {
265     auto *data = static_cast<AppendResultsCallbackData *>(userData);
266     data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
267     for (intptr_t i = 0; i < nTypes; ++i) {
268       data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
269     }
270   }
271 
272   /// Given the arguments required to build an operation, attempts to infer its
273   /// return types. Throws value_error on failure.
274   std::vector<PyType>
275   inferReturnTypes(std::optional<nb::list> operandList,
276                    std::optional<PyAttribute> attributes, void *properties,
277                    std::optional<std::vector<PyRegion>> regions,
278                    DefaultingPyMlirContext context,
279                    DefaultingPyLocation location) {
280     llvm::SmallVector<MlirValue> mlirOperands =
281         wrapOperands(std::move(operandList));
282     llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
283 
284     std::vector<PyType> inferredTypes;
285     PyMlirContext &pyContext = context.resolve();
286     AppendResultsCallbackData data{inferredTypes, pyContext};
287     MlirStringRef opNameRef =
288         mlirStringRefCreate(getOpName().data(), getOpName().length());
289     MlirAttribute attributeDict =
290         attributes ? attributes->get() : mlirAttributeGetNull();
291 
292     MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
293         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
294         mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
295         mlirRegions.data(), &appendResultsCallback, &data);
296 
297     if (mlirLogicalResultIsFailure(result)) {
298       throw nb::value_error("Failed to infer result types");
299     }
300 
301     return inferredTypes;
302   }
303 
304   static void bindDerived(ClassTy &cls) {
305     cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
306             nb::arg("operands").none() = nb::none(),
307             nb::arg("attributes").none() = nb::none(),
308             nb::arg("properties").none() = nb::none(),
309             nb::arg("regions").none() = nb::none(),
310             nb::arg("context").none() = nb::none(),
311             nb::arg("loc").none() = nb::none(), inferReturnTypesDoc);
312   }
313 };
314 
315 /// Wrapper around an shaped type components.
316 class PyShapedTypeComponents {
317 public:
318   PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
319   PyShapedTypeComponents(nb::list shape, MlirType elementType)
320       : shape(std::move(shape)), elementType(elementType), ranked(true) {}
321   PyShapedTypeComponents(nb::list shape, MlirType elementType,
322                          MlirAttribute attribute)
323       : shape(std::move(shape)), elementType(elementType), attribute(attribute),
324         ranked(true) {}
325   PyShapedTypeComponents(PyShapedTypeComponents &) = delete;
326   PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
327       : shape(other.shape), elementType(other.elementType),
328         attribute(other.attribute), ranked(other.ranked) {}
329 
330   static void bind(nb::module_ &m) {
331     nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents")
332         .def_prop_ro(
333             "element_type",
334             [](PyShapedTypeComponents &self) { return self.elementType; },
335             "Returns the element type of the shaped type components.")
336         .def_static(
337             "get",
338             [](PyType &elementType) {
339               return PyShapedTypeComponents(elementType);
340             },
341             nb::arg("element_type"),
342             "Create an shaped type components object with only the element "
343             "type.")
344         .def_static(
345             "get",
346             [](nb::list shape, PyType &elementType) {
347               return PyShapedTypeComponents(std::move(shape), elementType);
348             },
349             nb::arg("shape"), nb::arg("element_type"),
350             "Create a ranked shaped type components object.")
351         .def_static(
352             "get",
353             [](nb::list shape, PyType &elementType, PyAttribute &attribute) {
354               return PyShapedTypeComponents(std::move(shape), elementType,
355                                             attribute);
356             },
357             nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"),
358             "Create a ranked shaped type components object with attribute.")
359         .def_prop_ro(
360             "has_rank",
361             [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
362             "Returns whether the given shaped type component is ranked.")
363         .def_prop_ro(
364             "rank",
365             [](PyShapedTypeComponents &self) -> nb::object {
366               if (!self.ranked) {
367                 return nb::none();
368               }
369               return nb::int_(self.shape.size());
370             },
371             "Returns the rank of the given ranked shaped type components. If "
372             "the shaped type components does not have a rank, None is "
373             "returned.")
374         .def_prop_ro(
375             "shape",
376             [](PyShapedTypeComponents &self) -> nb::object {
377               if (!self.ranked) {
378                 return nb::none();
379               }
380               return nb::list(self.shape);
381             },
382             "Returns the shape of the ranked shaped type components as a list "
383             "of integers. Returns none if the shaped type component does not "
384             "have a rank.");
385   }
386 
387   nb::object getCapsule();
388   static PyShapedTypeComponents createFromCapsule(nb::object capsule);
389 
390 private:
391   nb::list shape;
392   MlirType elementType;
393   MlirAttribute attribute;
394   bool ranked{false};
395 };
396 
397 /// Python wrapper for InferShapedTypeOpInterface. This interface has only
398 /// static methods.
399 class PyInferShapedTypeOpInterface
400     : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
401 public:
402   using PyConcreteOpInterface<
403       PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
404 
405   constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
406   constexpr static GetTypeIDFunctionTy getInterfaceID =
407       &mlirInferShapedTypeOpInterfaceTypeID;
408 
409   /// C-style user-data structure for type appending callback.
410   struct AppendResultsCallbackData {
411     std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
412   };
413 
414   /// Appends the shaped type components provided as unpacked shape, element
415   /// type, attribute to the user-data.
416   static void appendResultsCallback(bool hasRank, intptr_t rank,
417                                     const int64_t *shape, MlirType elementType,
418                                     MlirAttribute attribute, void *userData) {
419     auto *data = static_cast<AppendResultsCallbackData *>(userData);
420     if (!hasRank) {
421       data->inferredShapedTypeComponents.emplace_back(elementType);
422     } else {
423       nb::list shapeList;
424       for (intptr_t i = 0; i < rank; ++i) {
425         shapeList.append(shape[i]);
426       }
427       data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
428                                                       attribute);
429     }
430   }
431 
432   /// Given the arguments required to build an operation, attempts to infer the
433   /// shaped type components. Throws value_error on failure.
434   std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
435       std::optional<nb::list> operandList,
436       std::optional<PyAttribute> attributes, void *properties,
437       std::optional<std::vector<PyRegion>> regions,
438       DefaultingPyMlirContext context, DefaultingPyLocation location) {
439     llvm::SmallVector<MlirValue> mlirOperands =
440         wrapOperands(std::move(operandList));
441     llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
442 
443     std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
444     PyMlirContext &pyContext = context.resolve();
445     AppendResultsCallbackData data{inferredShapedTypeComponents};
446     MlirStringRef opNameRef =
447         mlirStringRefCreate(getOpName().data(), getOpName().length());
448     MlirAttribute attributeDict =
449         attributes ? attributes->get() : mlirAttributeGetNull();
450 
451     MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
452         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
453         mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
454         mlirRegions.data(), &appendResultsCallback, &data);
455 
456     if (mlirLogicalResultIsFailure(result)) {
457       throw nb::value_error("Failed to infer result shape type components");
458     }
459 
460     return inferredShapedTypeComponents;
461   }
462 
463   static void bindDerived(ClassTy &cls) {
464     cls.def("inferReturnTypeComponents",
465             &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
466             nb::arg("operands").none() = nb::none(),
467             nb::arg("attributes").none() = nb::none(),
468             nb::arg("regions").none() = nb::none(),
469             nb::arg("properties").none() = nb::none(),
470             nb::arg("context").none() = nb::none(),
471             nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc);
472   }
473 };
474 
475 void populateIRInterfaces(nb::module_ &m) {
476   PyInferTypeOpInterface::bind(m);
477   PyShapedTypeComponents::bind(m);
478   PyInferShapedTypeOpInterface::bind(m);
479 }
480 
481 } // namespace python
482 } // namespace mlir
483