1 //===- DialectTransform.cpp - 'transform' dialect submodule ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/Transform.h" 10 #include "mlir-c/IR.h" 11 #include "mlir-c/Support.h" 12 #include "mlir/Bindings/Python/PybindAdaptors.h" 13 #include <pybind11/cast.h> 14 #include <pybind11/detail/common.h> 15 #include <pybind11/pybind11.h> 16 #include <pybind11/pytypes.h> 17 #include <string> 18 19 namespace py = pybind11; 20 using namespace mlir; 21 using namespace mlir::python; 22 using namespace mlir::python::adaptors; 23 24 void populateDialectTransformSubmodule(const pybind11::module &m) { 25 //===-------------------------------------------------------------------===// 26 // AnyOpType 27 //===-------------------------------------------------------------------===// 28 29 auto anyOpType = 30 mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType, 31 mlirTransformAnyOpTypeGetTypeID); 32 anyOpType.def_classmethod( 33 "get", 34 [](py::object cls, MlirContext ctx) { 35 return cls(mlirTransformAnyOpTypeGet(ctx)); 36 }, 37 "Get an instance of AnyOpType in the given context.", py::arg("cls"), 38 py::arg("context") = py::none()); 39 40 //===-------------------------------------------------------------------===// 41 // AnyParamType 42 //===-------------------------------------------------------------------===// 43 44 auto anyParamType = 45 mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType, 46 mlirTransformAnyParamTypeGetTypeID); 47 anyParamType.def_classmethod( 48 "get", 49 [](py::object cls, MlirContext ctx) { 50 return cls(mlirTransformAnyParamTypeGet(ctx)); 51 }, 52 "Get an instance of AnyParamType in the given context.", py::arg("cls"), 53 py::arg("context") = py::none()); 54 55 //===-------------------------------------------------------------------===// 56 // AnyValueType 57 //===-------------------------------------------------------------------===// 58 59 auto anyValueType = 60 mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType, 61 mlirTransformAnyValueTypeGetTypeID); 62 anyValueType.def_classmethod( 63 "get", 64 [](py::object cls, MlirContext ctx) { 65 return cls(mlirTransformAnyValueTypeGet(ctx)); 66 }, 67 "Get an instance of AnyValueType in the given context.", py::arg("cls"), 68 py::arg("context") = py::none()); 69 70 //===-------------------------------------------------------------------===// 71 // OperationType 72 //===-------------------------------------------------------------------===// 73 74 auto operationType = 75 mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType, 76 mlirTransformOperationTypeGetTypeID); 77 operationType.def_classmethod( 78 "get", 79 [](py::object cls, const std::string &operationName, MlirContext ctx) { 80 MlirStringRef cOperationName = 81 mlirStringRefCreate(operationName.data(), operationName.size()); 82 return cls(mlirTransformOperationTypeGet(ctx, cOperationName)); 83 }, 84 "Get an instance of OperationType for the given kind in the given " 85 "context", 86 py::arg("cls"), py::arg("operation_name"), 87 py::arg("context") = py::none()); 88 operationType.def_property_readonly( 89 "operation_name", 90 [](MlirType type) { 91 MlirStringRef operationName = 92 mlirTransformOperationTypeGetOperationName(type); 93 return py::str(operationName.data, operationName.length); 94 }, 95 "Get the name of the payload operation accepted by the handle."); 96 97 //===-------------------------------------------------------------------===// 98 // ParamType 99 //===-------------------------------------------------------------------===// 100 101 auto paramType = 102 mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType, 103 mlirTransformParamTypeGetTypeID); 104 paramType.def_classmethod( 105 "get", 106 [](py::object cls, MlirType type, MlirContext ctx) { 107 return cls(mlirTransformParamTypeGet(ctx, type)); 108 }, 109 "Get an instance of ParamType for the given type in the given context.", 110 py::arg("cls"), py::arg("type"), py::arg("context") = py::none()); 111 paramType.def_property_readonly( 112 "type", 113 [](MlirType type) { 114 MlirType paramType = mlirTransformParamTypeGetType(type); 115 return paramType; 116 }, 117 "Get the type this ParamType is associated with."); 118 } 119 120 PYBIND11_MODULE(_mlirDialectsTransform, m) { 121 m.doc() = "MLIR Transform dialect."; 122 populateDialectTransformSubmodule(m); 123 } 124