xref: /llvm-project/mlir/lib/Bindings/Python/DialectTransform.cpp (revision 681eacc1b670fd7137d8677fef6fc76c6e37dca9)
1 //===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/Transform.h"
10 #include "mlir-c/IR.h"
11 #include "mlir-c/Support.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 #include <pybind11/cast.h>
14 #include <pybind11/detail/common.h>
15 #include <pybind11/pybind11.h>
16 #include <pybind11/pytypes.h>
17 #include <string>
18 
19 namespace py = pybind11;
20 using namespace mlir;
21 using namespace mlir::python;
22 using namespace mlir::python::adaptors;
23 
24 void populateDialectTransformSubmodule(const pybind11::module &m) {
25   //===-------------------------------------------------------------------===//
26   // AnyOpType
27   //===-------------------------------------------------------------------===//
28 
29   auto anyOpType =
30       mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
31                          mlirTransformAnyOpTypeGetTypeID);
32   anyOpType.def_classmethod(
33       "get",
34       [](py::object cls, MlirContext ctx) {
35         return cls(mlirTransformAnyOpTypeGet(ctx));
36       },
37       "Get an instance of AnyOpType in the given context.", py::arg("cls"),
38       py::arg("context") = py::none());
39 
40   //===-------------------------------------------------------------------===//
41   // AnyParamType
42   //===-------------------------------------------------------------------===//
43 
44   auto anyParamType =
45       mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
46                          mlirTransformAnyParamTypeGetTypeID);
47   anyParamType.def_classmethod(
48       "get",
49       [](py::object cls, MlirContext ctx) {
50         return cls(mlirTransformAnyParamTypeGet(ctx));
51       },
52       "Get an instance of AnyParamType in the given context.", py::arg("cls"),
53       py::arg("context") = py::none());
54 
55   //===-------------------------------------------------------------------===//
56   // AnyValueType
57   //===-------------------------------------------------------------------===//
58 
59   auto anyValueType =
60       mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
61                          mlirTransformAnyValueTypeGetTypeID);
62   anyValueType.def_classmethod(
63       "get",
64       [](py::object cls, MlirContext ctx) {
65         return cls(mlirTransformAnyValueTypeGet(ctx));
66       },
67       "Get an instance of AnyValueType in the given context.", py::arg("cls"),
68       py::arg("context") = py::none());
69 
70   //===-------------------------------------------------------------------===//
71   // OperationType
72   //===-------------------------------------------------------------------===//
73 
74   auto operationType =
75       mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
76                          mlirTransformOperationTypeGetTypeID);
77   operationType.def_classmethod(
78       "get",
79       [](py::object cls, const std::string &operationName, MlirContext ctx) {
80         MlirStringRef cOperationName =
81             mlirStringRefCreate(operationName.data(), operationName.size());
82         return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
83       },
84       "Get an instance of OperationType for the given kind in the given "
85       "context",
86       py::arg("cls"), py::arg("operation_name"),
87       py::arg("context") = py::none());
88   operationType.def_property_readonly(
89       "operation_name",
90       [](MlirType type) {
91         MlirStringRef operationName =
92             mlirTransformOperationTypeGetOperationName(type);
93         return py::str(operationName.data, operationName.length);
94       },
95       "Get the name of the payload operation accepted by the handle.");
96 
97   //===-------------------------------------------------------------------===//
98   // ParamType
99   //===-------------------------------------------------------------------===//
100 
101   auto paramType =
102       mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
103                          mlirTransformParamTypeGetTypeID);
104   paramType.def_classmethod(
105       "get",
106       [](py::object cls, MlirType type, MlirContext ctx) {
107         return cls(mlirTransformParamTypeGet(ctx, type));
108       },
109       "Get an instance of ParamType for the given type in the given context.",
110       py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
111   paramType.def_property_readonly(
112       "type",
113       [](MlirType type) {
114         MlirType paramType = mlirTransformParamTypeGetType(type);
115         return paramType;
116       },
117       "Get the type this ParamType is associated with.");
118 }
119 
120 PYBIND11_MODULE(_mlirDialectsTransform, m) {
121   m.doc() = "MLIR Transform dialect.";
122   populateDialectTransformSubmodule(m);
123 }
124