1 //===- DialectTransform.cpp - 'transform' dialect submodule ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/Transform.h" 10 #include "mlir-c/IR.h" 11 #include "mlir-c/Support.h" 12 #include "mlir/Bindings/Python/PybindAdaptors.h" 13 #include <pybind11/cast.h> 14 #include <pybind11/detail/common.h> 15 #include <pybind11/pybind11.h> 16 #include <pybind11/pytypes.h> 17 #include <string> 18 19 namespace py = pybind11; 20 using namespace mlir; 21 using namespace mlir::python; 22 using namespace mlir::python::adaptors; 23 24 void populateDialectTransformSubmodule(const pybind11::module &m) { 25 //===-------------------------------------------------------------------===// 26 // AnyOpType 27 //===-------------------------------------------------------------------===// 28 29 auto anyOpType = 30 mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType); 31 anyOpType.def_classmethod( 32 "get", 33 [](py::object cls, MlirContext ctx) { 34 return cls(mlirTransformAnyOpTypeGet(ctx)); 35 }, 36 "Get an instance of AnyOpType in the given context.", py::arg("cls"), 37 py::arg("context") = py::none()); 38 39 //===-------------------------------------------------------------------===// 40 // AnyParamType 41 //===-------------------------------------------------------------------===// 42 43 auto anyParamType = 44 mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType); 45 anyParamType.def_classmethod( 46 "get", 47 [](py::object cls, MlirContext ctx) { 48 return cls(mlirTransformAnyParamTypeGet(ctx)); 49 }, 50 "Get an instance of AnyParamType in the given context.", py::arg("cls"), 51 py::arg("context") = py::none()); 52 53 //===-------------------------------------------------------------------===// 54 // AnyValueType 55 //===-------------------------------------------------------------------===// 56 57 auto anyValueType = 58 mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType); 59 anyValueType.def_classmethod( 60 "get", 61 [](py::object cls, MlirContext ctx) { 62 return cls(mlirTransformAnyValueTypeGet(ctx)); 63 }, 64 "Get an instance of AnyValueType in the given context.", py::arg("cls"), 65 py::arg("context") = py::none()); 66 67 //===-------------------------------------------------------------------===// 68 // OperationType 69 //===-------------------------------------------------------------------===// 70 71 auto operationType = 72 mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType, 73 mlirTransformOperationTypeGetTypeID); 74 operationType.def_classmethod( 75 "get", 76 [](py::object cls, const std::string &operationName, MlirContext ctx) { 77 MlirStringRef cOperationName = 78 mlirStringRefCreate(operationName.data(), operationName.size()); 79 return cls(mlirTransformOperationTypeGet(ctx, cOperationName)); 80 }, 81 "Get an instance of OperationType for the given kind in the given " 82 "context", 83 py::arg("cls"), py::arg("operation_name"), 84 py::arg("context") = py::none()); 85 operationType.def_property_readonly( 86 "operation_name", 87 [](MlirType type) { 88 MlirStringRef operationName = 89 mlirTransformOperationTypeGetOperationName(type); 90 return py::str(operationName.data, operationName.length); 91 }, 92 "Get the name of the payload operation accepted by the handle."); 93 94 //===-------------------------------------------------------------------===// 95 // ParamType 96 //===-------------------------------------------------------------------===// 97 98 auto paramType = 99 mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType); 100 paramType.def_classmethod( 101 "get", 102 [](py::object cls, MlirType type, MlirContext ctx) { 103 return cls(mlirTransformParamTypeGet(ctx, type)); 104 }, 105 "Get an instance of ParamType for the given type in the given context.", 106 py::arg("cls"), py::arg("type"), py::arg("context") = py::none()); 107 paramType.def_property_readonly( 108 "type", 109 [](MlirType type) { 110 MlirType paramType = mlirTransformParamTypeGetType(type); 111 return paramType; 112 }, 113 "Get the type this ParamType is associated with."); 114 } 115 116 PYBIND11_MODULE(_mlirDialectsTransform, m) { 117 m.doc() = "MLIR Transform dialect."; 118 populateDialectTransformSubmodule(m); 119 } 120