1 //===- DialectGPU.cpp - Pybind module for the GPU passes ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/GPU.h" 10 #include "mlir-c/IR.h" 11 #include "mlir-c/Support.h" 12 #include "mlir/Bindings/Python/NanobindAdaptors.h" 13 #include "mlir/Bindings/Python/Nanobind.h" 14 15 namespace nb = nanobind; 16 using namespace nanobind::literals; 17 18 using namespace mlir; 19 using namespace mlir::python; 20 using namespace mlir::python::nanobind_adaptors; 21 22 // ----------------------------------------------------------------------------- 23 // Module initialization. 24 // ----------------------------------------------------------------------------- 25 26 NB_MODULE(_mlirDialectsGPU, m) { 27 m.doc() = "MLIR GPU Dialect"; 28 //===-------------------------------------------------------------------===// 29 // AsyncTokenType 30 //===-------------------------------------------------------------------===// 31 32 auto mlirGPUAsyncTokenType = 33 mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType); 34 35 mlirGPUAsyncTokenType.def_classmethod( 36 "get", 37 [](nb::object cls, MlirContext ctx) { 38 return cls(mlirGPUAsyncTokenTypeGet(ctx)); 39 }, 40 "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"), 41 nb::arg("ctx").none() = nb::none()); 42 43 //===-------------------------------------------------------------------===// 44 // ObjectAttr 45 //===-------------------------------------------------------------------===// 46 47 mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr) 48 .def_classmethod( 49 "get", 50 [](nb::object cls, MlirAttribute target, uint32_t format, 51 nb::bytes object, std::optional<MlirAttribute> mlirObjectProps, 52 std::optional<MlirAttribute> mlirKernelsAttr) { 53 MlirStringRef objectStrRef = mlirStringRefCreate( 54 static_cast<char *>(const_cast<void *>(object.data())), 55 object.size()); 56 return cls(mlirGPUObjectAttrGetWithKernels( 57 mlirAttributeGetContext(target), target, format, objectStrRef, 58 mlirObjectProps.has_value() ? *mlirObjectProps 59 : MlirAttribute{nullptr}, 60 mlirKernelsAttr.has_value() ? *mlirKernelsAttr 61 : MlirAttribute{nullptr})); 62 }, 63 "cls"_a, "target"_a, "format"_a, "object"_a, 64 "properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(), 65 "Gets a gpu.object from parameters.") 66 .def_property_readonly( 67 "target", 68 [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); }) 69 .def_property_readonly( 70 "format", 71 [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); }) 72 .def_property_readonly( 73 "object", 74 [](MlirAttribute self) { 75 MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self); 76 return nb::bytes(stringRef.data, stringRef.length); 77 }) 78 .def_property_readonly("properties", 79 [](MlirAttribute self) -> nb::object { 80 if (mlirGPUObjectAttrHasProperties(self)) 81 return nb::cast( 82 mlirGPUObjectAttrGetProperties(self)); 83 return nb::none(); 84 }) 85 .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object { 86 if (mlirGPUObjectAttrHasKernels(self)) 87 return nb::cast(mlirGPUObjectAttrGetKernels(self)); 88 return nb::none(); 89 }); 90 } 91