xref: /llvm-project/mlir/lib/Bindings/Python/DialectTransform.cpp (revision bfb1ba752655bf09b35c486f6cc9817dbedfb1bb)
1 //===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/Transform.h"
10 #include "mlir-c/IR.h"
11 #include "mlir-c/Support.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 
14 namespace py = pybind11;
15 using namespace mlir;
16 using namespace mlir::python;
17 using namespace mlir::python::adaptors;
18 
19 void populateDialectTransformSubmodule(const pybind11::module &m) {
20   //===-------------------------------------------------------------------===//
21   // AnyOpType
22   //===-------------------------------------------------------------------===//
23 
24   auto anyOpType =
25       mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
26   anyOpType.def_classmethod(
27       "get",
28       [](py::object cls, MlirContext ctx) {
29         return cls(mlirTransformAnyOpTypeGet(ctx));
30       },
31       "Get an instance of AnyOpType in the given context.", py::arg("cls"),
32       py::arg("context") = py::none());
33 
34   //===-------------------------------------------------------------------===//
35   // OperationType
36   //===-------------------------------------------------------------------===//
37 
38   auto operationType =
39       mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
40                          mlirTransformOperationTypeGetTypeID);
41   operationType.def_classmethod(
42       "get",
43       [](py::object cls, const std::string &operationName, MlirContext ctx) {
44         MlirStringRef cOperationName =
45             mlirStringRefCreate(operationName.data(), operationName.size());
46         return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
47       },
48       "Get an instance of OperationType for the given kind in the given "
49       "context",
50       py::arg("cls"), py::arg("operation_name"),
51       py::arg("context") = py::none());
52   operationType.def_property_readonly(
53       "operation_name",
54       [](MlirType type) {
55         MlirStringRef operationName =
56             mlirTransformOperationTypeGetOperationName(type);
57         return py::str(operationName.data, operationName.length);
58       },
59       "Get the name of the payload operation accepted by the handle.");
60 }
61 
62 PYBIND11_MODULE(_mlirDialectsTransform, m) {
63   m.doc() = "MLIR Transform dialect.";
64   populateDialectTransformSubmodule(m);
65 }
66