xref: /llvm-project/mlir/lib/Bindings/Python/DialectGPU.cpp (revision 6e6da74c8b936e457ca5e56a828823ae6a9f9066)
//===- DialectGPU.cpp - Pybind module for the GPU dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/GPU.h"
10 #include "mlir-c/IR.h"
11 #include "mlir-c/Support.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 
14 #include <pybind11/detail/common.h>
15 #include <pybind11/pybind11.h>
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 using namespace mlir::python::adaptors;
21 
22 // -----------------------------------------------------------------------------
23 // Module initialization.
24 // -----------------------------------------------------------------------------
25 
26 PYBIND11_MODULE(_mlirDialectsGPU, m) {
27   m.doc() = "MLIR GPU Dialect";
28 
29   //===-------------------------------------------------------------------===//
30   // ObjectAttr
31   //===-------------------------------------------------------------------===//
32 
33   mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
34       .def_classmethod(
35           "get",
36           [](py::object cls, MlirAttribute target, uint32_t format,
37              py::bytes object, std::optional<MlirAttribute> mlirObjectProps) {
38             py::buffer_info info(py::buffer(object).request());
39             MlirStringRef objectStrRef =
40                 mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
41             return cls(mlirGPUObjectAttrGet(
42                 mlirAttributeGetContext(target), target, format, objectStrRef,
43                 mlirObjectProps.has_value() ? *mlirObjectProps
44                                             : MlirAttribute{nullptr}));
45           },
46           "cls"_a, "target"_a, "format"_a, "object"_a,
47           "properties"_a = py::none(), "Gets a gpu.object from parameters.")
48       .def_property_readonly(
49           "target",
50           [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
51       .def_property_readonly(
52           "format",
53           [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
54       .def_property_readonly(
55           "object",
56           [](MlirAttribute self) {
57             MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
58             return py::bytes(stringRef.data, stringRef.length);
59           })
60       .def_property_readonly("properties", [](MlirAttribute self) {
61         if (mlirGPUObjectAttrHasProperties(self))
62           return py::cast(mlirGPUObjectAttrGetProperties(self));
63         return py::none().cast<py::object>();
64       });
65 }
66