xref: /llvm-project/mlir/lib/Bindings/Python/DialectTransform.cpp (revision 3e1f6d02f755e34a0a12a8dd439fb65f84d6621f)
1 //===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/Transform.h"
10 #include "mlir-c/IR.h"
11 #include "mlir-c/Support.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 
14 namespace py = pybind11;
15 using namespace mlir;
16 using namespace mlir::python;
17 using namespace mlir::python::adaptors;
18 
19 void populateDialectTransformSubmodule(const pybind11::module &m) {
20   //===-------------------------------------------------------------------===//
21   // AnyOpType
22   //===-------------------------------------------------------------------===//
23 
24   auto anyOpType =
25       mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
26   anyOpType.def_classmethod(
27       "get",
28       [](py::object cls, MlirContext ctx) {
29         return cls(mlirTransformAnyOpTypeGet(ctx));
30       },
31       "Get an instance of AnyOpType in the given context.", py::arg("cls"),
32       py::arg("context") = py::none());
33 
34   //===-------------------------------------------------------------------===//
35   // OperationType
36   //===-------------------------------------------------------------------===//
37 
38   auto operationType =
39       mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType);
40   operationType.def_classmethod(
41       "get",
42       [](py::object cls, const std::string &operationName, MlirContext ctx) {
43         MlirStringRef cOperationName =
44             mlirStringRefCreate(operationName.data(), operationName.size());
45         return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
46       },
47       "Get an instance of OperationType for the given kind in the given "
48       "context",
49       py::arg("cls"), py::arg("operation_name"),
50       py::arg("context") = py::none());
51   operationType.def_property_readonly(
52       "operation_name",
53       [](MlirType type) {
54         MlirStringRef operationName =
55             mlirTransformOperationTypeGetOperationName(type);
56         return py::str(operationName.data, operationName.length);
57       },
58       "Get the name of the payload operation accepted by the handle.");
59 }
60 
61 PYBIND11_MODULE(_mlirDialectsTransform, m) {
62   m.doc() = "MLIR Transform dialect.";
63   populateDialectTransformSubmodule(m);
64 }
65