xref: /llvm-project/mlir/lib/Bindings/Python/DialectTransform.cpp (revision 285a229f205ae67dca48c8eac8206a115320c677)
1 //===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/Transform.h"
10 #include "mlir-c/IR.h"
11 #include "mlir-c/Support.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 #include <pybind11/cast.h>
14 #include <pybind11/detail/common.h>
15 #include <pybind11/pybind11.h>
16 #include <pybind11/pytypes.h>
17 #include <string>
18 
19 namespace py = pybind11;
20 using namespace mlir;
21 using namespace mlir::python;
22 using namespace mlir::python::adaptors;
23 
24 void populateDialectTransformSubmodule(const pybind11::module &m) {
25   //===-------------------------------------------------------------------===//
26   // AnyOpType
27   //===-------------------------------------------------------------------===//
28 
29   auto anyOpType =
30       mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
31   anyOpType.def_classmethod(
32       "get",
33       [](py::object cls, MlirContext ctx) {
34         return cls(mlirTransformAnyOpTypeGet(ctx));
35       },
36       "Get an instance of AnyOpType in the given context.", py::arg("cls"),
37       py::arg("context") = py::none());
38 
39   //===-------------------------------------------------------------------===//
40   // AnyParamType
41   //===-------------------------------------------------------------------===//
42 
43   auto anyParamType =
44       mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
45   anyParamType.def_classmethod(
46       "get",
47       [](py::object cls, MlirContext ctx) {
48         return cls(mlirTransformAnyParamTypeGet(ctx));
49       },
50       "Get an instance of AnyParamType in the given context.", py::arg("cls"),
51       py::arg("context") = py::none());
52 
53   //===-------------------------------------------------------------------===//
54   // AnyValueType
55   //===-------------------------------------------------------------------===//
56 
57   auto anyValueType =
58       mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType);
59   anyValueType.def_classmethod(
60       "get",
61       [](py::object cls, MlirContext ctx) {
62         return cls(mlirTransformAnyValueTypeGet(ctx));
63       },
64       "Get an instance of AnyValueType in the given context.", py::arg("cls"),
65       py::arg("context") = py::none());
66 
67   //===-------------------------------------------------------------------===//
68   // OperationType
69   //===-------------------------------------------------------------------===//
70 
71   auto operationType =
72       mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
73                          mlirTransformOperationTypeGetTypeID);
74   operationType.def_classmethod(
75       "get",
76       [](py::object cls, const std::string &operationName, MlirContext ctx) {
77         MlirStringRef cOperationName =
78             mlirStringRefCreate(operationName.data(), operationName.size());
79         return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
80       },
81       "Get an instance of OperationType for the given kind in the given "
82       "context",
83       py::arg("cls"), py::arg("operation_name"),
84       py::arg("context") = py::none());
85   operationType.def_property_readonly(
86       "operation_name",
87       [](MlirType type) {
88         MlirStringRef operationName =
89             mlirTransformOperationTypeGetOperationName(type);
90         return py::str(operationName.data, operationName.length);
91       },
92       "Get the name of the payload operation accepted by the handle.");
93 
94   //===-------------------------------------------------------------------===//
95   // ParamType
96   //===-------------------------------------------------------------------===//
97 
98   auto paramType =
99       mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
100   paramType.def_classmethod(
101       "get",
102       [](py::object cls, MlirType type, MlirContext ctx) {
103         return cls(mlirTransformParamTypeGet(ctx, type));
104       },
105       "Get an instance of ParamType for the given type in the given context.",
106       py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
107   paramType.def_property_readonly(
108       "type",
109       [](MlirType type) {
110         MlirType paramType = mlirTransformParamTypeGetType(type);
111         return paramType;
112       },
113       "Get the type this ParamType is associated with.");
114 }
115 
116 PYBIND11_MODULE(_mlirDialectsTransform, m) {
117   m.doc() = "MLIR Transform dialect.";
118   populateDialectTransformSubmodule(m);
119 }
120