xref: /llvm-project/mlir/lib/Bindings/Python/DialectGPU.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===- DialectGPU.cpp - Pybind module for the GPU passes ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/GPU.h"
10 #include "mlir-c/IR.h"
11 #include "mlir-c/Support.h"
12 #include "mlir/Bindings/Python/NanobindAdaptors.h"
13 #include "mlir/Bindings/Python/Nanobind.h"
14 
15 namespace nb = nanobind;
16 using namespace nanobind::literals;
17 
18 using namespace mlir;
19 using namespace mlir::python;
20 using namespace mlir::python::nanobind_adaptors;
21 
22 // -----------------------------------------------------------------------------
23 // Module initialization.
24 // -----------------------------------------------------------------------------
25 
26 NB_MODULE(_mlirDialectsGPU, m) {
27   m.doc() = "MLIR GPU Dialect";
28   //===-------------------------------------------------------------------===//
29   // AsyncTokenType
30   //===-------------------------------------------------------------------===//
31 
32   auto mlirGPUAsyncTokenType =
33       mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
34 
35   mlirGPUAsyncTokenType.def_classmethod(
36       "get",
37       [](nb::object cls, MlirContext ctx) {
38         return cls(mlirGPUAsyncTokenTypeGet(ctx));
39       },
40       "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
41       nb::arg("ctx").none() = nb::none());
42 
43   //===-------------------------------------------------------------------===//
44   // ObjectAttr
45   //===-------------------------------------------------------------------===//
46 
47   mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
48       .def_classmethod(
49           "get",
50           [](nb::object cls, MlirAttribute target, uint32_t format,
51              nb::bytes object, std::optional<MlirAttribute> mlirObjectProps,
52              std::optional<MlirAttribute> mlirKernelsAttr) {
53             MlirStringRef objectStrRef = mlirStringRefCreate(
54                 static_cast<char *>(const_cast<void *>(object.data())),
55                 object.size());
56             return cls(mlirGPUObjectAttrGetWithKernels(
57                 mlirAttributeGetContext(target), target, format, objectStrRef,
58                 mlirObjectProps.has_value() ? *mlirObjectProps
59                                             : MlirAttribute{nullptr},
60                 mlirKernelsAttr.has_value() ? *mlirKernelsAttr
61                                             : MlirAttribute{nullptr}));
62           },
63           "cls"_a, "target"_a, "format"_a, "object"_a,
64           "properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(),
65           "Gets a gpu.object from parameters.")
66       .def_property_readonly(
67           "target",
68           [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
69       .def_property_readonly(
70           "format",
71           [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
72       .def_property_readonly(
73           "object",
74           [](MlirAttribute self) {
75             MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
76             return nb::bytes(stringRef.data, stringRef.length);
77           })
78       .def_property_readonly("properties",
79                              [](MlirAttribute self) -> nb::object {
80                                if (mlirGPUObjectAttrHasProperties(self))
81                                  return nb::cast(
82                                      mlirGPUObjectAttrGetProperties(self));
83                                return nb::none();
84                              })
85       .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
86         if (mlirGPUObjectAttrHasKernels(self))
87           return nb::cast(mlirGPUObjectAttrGetKernels(self));
88         return nb::none();
89       });
90 }
91