1 //===- DialectTransform.cpp - 'transform' dialect submodule ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <string> 10 11 #include "mlir-c/Dialect/Transform.h" 12 #include "mlir-c/IR.h" 13 #include "mlir-c/Support.h" 14 #include "mlir/Bindings/Python/NanobindAdaptors.h" 15 #include "mlir/Bindings/Python/Nanobind.h" 16 17 namespace nb = nanobind; 18 using namespace mlir; 19 using namespace mlir::python; 20 using namespace mlir::python::nanobind_adaptors; 21 22 void populateDialectTransformSubmodule(const nb::module_ &m) { 23 //===-------------------------------------------------------------------===// 24 // AnyOpType 25 //===-------------------------------------------------------------------===// 26 27 auto anyOpType = 28 mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType, 29 mlirTransformAnyOpTypeGetTypeID); 30 anyOpType.def_classmethod( 31 "get", 32 [](nb::object cls, MlirContext ctx) { 33 return cls(mlirTransformAnyOpTypeGet(ctx)); 34 }, 35 "Get an instance of AnyOpType in the given context.", nb::arg("cls"), 36 nb::arg("context").none() = nb::none()); 37 38 //===-------------------------------------------------------------------===// 39 // AnyParamType 40 //===-------------------------------------------------------------------===// 41 42 auto anyParamType = 43 mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType, 44 mlirTransformAnyParamTypeGetTypeID); 45 anyParamType.def_classmethod( 46 "get", 47 [](nb::object cls, MlirContext ctx) { 48 return cls(mlirTransformAnyParamTypeGet(ctx)); 49 }, 50 "Get an instance of AnyParamType in the given context.", nb::arg("cls"), 51 nb::arg("context").none() = nb::none()); 52 53 //===-------------------------------------------------------------------===// 54 // AnyValueType 55 //===-------------------------------------------------------------------===// 56 57 auto anyValueType = 58 mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType, 59 mlirTransformAnyValueTypeGetTypeID); 60 anyValueType.def_classmethod( 61 "get", 62 [](nb::object cls, MlirContext ctx) { 63 return cls(mlirTransformAnyValueTypeGet(ctx)); 64 }, 65 "Get an instance of AnyValueType in the given context.", nb::arg("cls"), 66 nb::arg("context").none() = nb::none()); 67 68 //===-------------------------------------------------------------------===// 69 // OperationType 70 //===-------------------------------------------------------------------===// 71 72 auto operationType = 73 mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType, 74 mlirTransformOperationTypeGetTypeID); 75 operationType.def_classmethod( 76 "get", 77 [](nb::object cls, const std::string &operationName, MlirContext ctx) { 78 MlirStringRef cOperationName = 79 mlirStringRefCreate(operationName.data(), operationName.size()); 80 return cls(mlirTransformOperationTypeGet(ctx, cOperationName)); 81 }, 82 "Get an instance of OperationType for the given kind in the given " 83 "context", 84 nb::arg("cls"), nb::arg("operation_name"), 85 nb::arg("context").none() = nb::none()); 86 operationType.def_property_readonly( 87 "operation_name", 88 [](MlirType type) { 89 MlirStringRef operationName = 90 mlirTransformOperationTypeGetOperationName(type); 91 return nb::str(operationName.data, operationName.length); 92 }, 93 "Get the name of the payload operation accepted by the handle."); 94 95 //===-------------------------------------------------------------------===// 96 // ParamType 97 //===-------------------------------------------------------------------===// 98 99 auto paramType = 100 mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType, 101 mlirTransformParamTypeGetTypeID); 102 paramType.def_classmethod( 103 "get", 104 [](nb::object cls, MlirType type, MlirContext ctx) { 105 return cls(mlirTransformParamTypeGet(ctx, type)); 106 }, 107 "Get an instance of ParamType for the given type in the given context.", 108 nb::arg("cls"), nb::arg("type"), nb::arg("context").none() = nb::none()); 109 paramType.def_property_readonly( 110 "type", 111 [](MlirType type) { 112 MlirType paramType = mlirTransformParamTypeGetType(type); 113 return paramType; 114 }, 115 "Get the type this ParamType is associated with."); 116 } 117 118 NB_MODULE(_mlirDialectsTransform, m) { 119 m.doc() = "MLIR Transform dialect."; 120 populateDialectTransformSubmodule(m); 121 } 122