1 //===- DialectGPU.cpp - Pybind module for the GPU passes ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/GPU.h" 10 #include "mlir-c/IR.h" 11 #include "mlir-c/Support.h" 12 #include "mlir/Bindings/Python/PybindAdaptors.h" 13 14 #include <pybind11/detail/common.h> 15 #include <pybind11/pybind11.h> 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace mlir::python; 20 using namespace mlir::python::adaptors; 21 22 // ----------------------------------------------------------------------------- 23 // Module initialization. 24 // ----------------------------------------------------------------------------- 25 26 PYBIND11_MODULE(_mlirDialectsGPU, m) { 27 m.doc() = "MLIR GPU Dialect"; 28 29 //===-------------------------------------------------------------------===// 30 // ObjectAttr 31 //===-------------------------------------------------------------------===// 32 33 mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr) 34 .def_classmethod( 35 "get", 36 [](py::object cls, MlirAttribute target, uint32_t format, 37 py::bytes object, std::optional<MlirAttribute> mlirObjectProps) { 38 py::buffer_info info(py::buffer(object).request()); 39 MlirStringRef objectStrRef = 40 mlirStringRefCreate(static_cast<char *>(info.ptr), info.size); 41 return cls(mlirGPUObjectAttrGet( 42 mlirAttributeGetContext(target), target, format, objectStrRef, 43 mlirObjectProps.has_value() ? *mlirObjectProps 44 : MlirAttribute{nullptr})); 45 }, 46 "cls"_a, "target"_a, "format"_a, "object"_a, 47 "properties"_a = py::none(), "Gets a gpu.object from parameters.") 48 .def_property_readonly( 49 "target", 50 [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); }) 51 .def_property_readonly( 52 "format", 53 [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); }) 54 .def_property_readonly( 55 "object", 56 [](MlirAttribute self) { 57 MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self); 58 return py::bytes(stringRef.data, stringRef.length); 59 }) 60 .def_property_readonly("properties", [](MlirAttribute self) { 61 if (mlirGPUObjectAttrHasProperties(self)) 62 return py::cast(mlirGPUObjectAttrGetProperties(self)); 63 return py::none().cast<py::object>(); 64 }); 65 } 66