xref: /llvm-project/mlir/lib/Bindings/Python/DialectTransform.cpp (revision 8b134d0b3593a511bff3046d89d35a7d09c469bd)
1 //===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/Transform.h"
10 #include "mlir-c/IR.h"
11 #include "mlir-c/Support.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 
14 namespace py = pybind11;
15 using namespace mlir;
16 using namespace mlir::python;
17 using namespace mlir::python::adaptors;
18 
19 void populateDialectTransformSubmodule(const pybind11::module &m) {
20   //===-------------------------------------------------------------------===//
21   // AnyOpType
22   //===-------------------------------------------------------------------===//
23 
24   auto anyOpType =
25       mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
26   anyOpType.def_classmethod(
27       "get",
28       [](py::object cls, MlirContext ctx) {
29         return cls(mlirTransformAnyOpTypeGet(ctx));
30       },
31       "Get an instance of AnyOpType in the given context.", py::arg("cls"),
32       py::arg("context") = py::none());
33 
34   //===-------------------------------------------------------------------===//
35   // AnyValueType
36   //===-------------------------------------------------------------------===//
37 
38   auto anyValueType =
39       mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType);
40   anyValueType.def_classmethod(
41       "get",
42       [](py::object cls, MlirContext ctx) {
43         return cls(mlirTransformAnyValueTypeGet(ctx));
44       },
45       "Get an instance of AnyValueType in the given context.", py::arg("cls"),
46       py::arg("context") = py::none());
47 
48   //===-------------------------------------------------------------------===//
49   // OperationType
50   //===-------------------------------------------------------------------===//
51 
52   auto operationType =
53       mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
54                          mlirTransformOperationTypeGetTypeID);
55   operationType.def_classmethod(
56       "get",
57       [](py::object cls, const std::string &operationName, MlirContext ctx) {
58         MlirStringRef cOperationName =
59             mlirStringRefCreate(operationName.data(), operationName.size());
60         return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
61       },
62       "Get an instance of OperationType for the given kind in the given "
63       "context",
64       py::arg("cls"), py::arg("operation_name"),
65       py::arg("context") = py::none());
66   operationType.def_property_readonly(
67       "operation_name",
68       [](MlirType type) {
69         MlirStringRef operationName =
70             mlirTransformOperationTypeGetOperationName(type);
71         return py::str(operationName.data, operationName.length);
72       },
73       "Get the name of the payload operation accepted by the handle.");
74 }
75 
76 PYBIND11_MODULE(_mlirDialectsTransform, m) {
77   m.doc() = "MLIR Transform dialect.";
78   populateDialectTransformSubmodule(m);
79 }
80