xref: /llvm-project/mlir/lib/Bindings/Python/DialectGPU.cpp (revision 016e1eb9c86923bf6a9669697f6be8309d12b78c)
//===- DialectGPU.cpp - Pybind module for the GPU dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/GPU.h"
10 #include "mlir-c/IR.h"
11 #include "mlir-c/Support.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 
14 #include <pybind11/detail/common.h>
15 #include <pybind11/pybind11.h>
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 using namespace mlir::python::adaptors;
21 
22 // -----------------------------------------------------------------------------
23 // Module initialization.
24 // -----------------------------------------------------------------------------
25 
26 PYBIND11_MODULE(_mlirDialectsGPU, m) {
27   m.doc() = "MLIR GPU Dialect";
28   //===-------------------------------------------------------------------===//
29   // AsyncTokenType
30   //===-------------------------------------------------------------------===//
31 
32   auto mlirGPUAsyncTokenType =
33       mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
34 
35   mlirGPUAsyncTokenType.def_classmethod(
36       "get",
37       [](py::object cls, MlirContext ctx) {
38         return cls(mlirGPUAsyncTokenTypeGet(ctx));
39       },
40       "Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
41       py::arg("ctx") = py::none());
42 
43   //===-------------------------------------------------------------------===//
44   // ObjectAttr
45   //===-------------------------------------------------------------------===//
46 
47   mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
48       .def_classmethod(
49           "get",
50           [](py::object cls, MlirAttribute target, uint32_t format,
51              py::bytes object, std::optional<MlirAttribute> mlirObjectProps,
52              std::optional<MlirAttribute> mlirKernelsAttr) {
53             py::buffer_info info(py::buffer(object).request());
54             MlirStringRef objectStrRef =
55                 mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
56             return cls(mlirGPUObjectAttrGetWithKernels(
57                 mlirAttributeGetContext(target), target, format, objectStrRef,
58                 mlirObjectProps.has_value() ? *mlirObjectProps
59                                             : MlirAttribute{nullptr},
60                 mlirKernelsAttr.has_value() ? *mlirKernelsAttr
61                                             : MlirAttribute{nullptr}));
62           },
63           "cls"_a, "target"_a, "format"_a, "object"_a,
64           "properties"_a = py::none(), "kernels"_a = py::none(),
65           "Gets a gpu.object from parameters.")
66       .def_property_readonly(
67           "target",
68           [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
69       .def_property_readonly(
70           "format",
71           [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
72       .def_property_readonly(
73           "object",
74           [](MlirAttribute self) {
75             MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
76             return py::bytes(stringRef.data, stringRef.length);
77           })
78       .def_property_readonly("properties",
79                              [](MlirAttribute self) {
80                                if (mlirGPUObjectAttrHasProperties(self))
81                                  return py::cast(
82                                      mlirGPUObjectAttrGetProperties(self));
83                                return py::none().cast<py::object>();
84                              })
85       .def_property_readonly("kernels", [](MlirAttribute self) {
86         if (mlirGPUObjectAttrHasKernels(self))
87           return py::cast(mlirGPUObjectAttrGetKernels(self));
88         return py::none().cast<py::object>();
89       });
90 }
91