1 //===- DialectGPU.cpp - Pybind module for the GPU passes ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/GPU.h" 10 #include "mlir-c/IR.h" 11 #include "mlir-c/Support.h" 12 #include "mlir/Bindings/Python/PybindAdaptors.h" 13 14 #include <pybind11/detail/common.h> 15 #include <pybind11/pybind11.h> 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace mlir::python; 20 using namespace mlir::python::adaptors; 21 22 // ----------------------------------------------------------------------------- 23 // Module initialization. 24 // ----------------------------------------------------------------------------- 25 26 PYBIND11_MODULE(_mlirDialectsGPU, m) { 27 m.doc() = "MLIR GPU Dialect"; 28 //===-------------------------------------------------------------------===// 29 // AsyncTokenType 30 //===-------------------------------------------------------------------===// 31 32 auto mlirGPUAsyncTokenType = 33 mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType); 34 35 mlirGPUAsyncTokenType.def_classmethod( 36 "get", 37 [](py::object cls, MlirContext ctx) { 38 return cls(mlirGPUAsyncTokenTypeGet(ctx)); 39 }, 40 "Gets an instance of AsyncTokenType in the same context", py::arg("cls"), 41 py::arg("ctx") = py::none()); 42 43 //===-------------------------------------------------------------------===// 44 // ObjectAttr 45 //===-------------------------------------------------------------------===// 46 47 mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr) 48 .def_classmethod( 49 "get", 50 [](py::object cls, MlirAttribute target, uint32_t format, 51 py::bytes object, std::optional<MlirAttribute> mlirObjectProps) { 52 py::buffer_info info(py::buffer(object).request()); 53 MlirStringRef objectStrRef = 54 mlirStringRefCreate(static_cast<char *>(info.ptr), info.size); 55 return cls(mlirGPUObjectAttrGet( 56 mlirAttributeGetContext(target), target, format, objectStrRef, 57 mlirObjectProps.has_value() ? *mlirObjectProps 58 : MlirAttribute{nullptr})); 59 }, 60 "cls"_a, "target"_a, "format"_a, "object"_a, 61 "properties"_a = py::none(), "Gets a gpu.object from parameters.") 62 .def_property_readonly( 63 "target", 64 [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); }) 65 .def_property_readonly( 66 "format", 67 [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); }) 68 .def_property_readonly( 69 "object", 70 [](MlirAttribute self) { 71 MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self); 72 return py::bytes(stringRef.data, stringRef.length); 73 }) 74 .def_property_readonly("properties", [](MlirAttribute self) { 75 if (mlirGPUObjectAttrHasProperties(self)) 76 return py::cast(mlirGPUObjectAttrGetProperties(self)); 77 return py::none().cast<py::object>(); 78 }); 79 } 80