xref: /llvm-project/mlir/lib/Bindings/Python/DialectGPU.cpp (revision f8ff9094711b74d3f695f7571f6390f8a481fc52)
1 //===- DialectGPU.cpp - Pybind module for the GPU passes ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/GPU.h"
10 #include "mlir-c/IR.h"
11 #include "mlir-c/Support.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 
14 #include <pybind11/detail/common.h>
15 #include <pybind11/pybind11.h>
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 using namespace mlir::python::adaptors;
21 
22 // -----------------------------------------------------------------------------
23 // Module initialization.
24 // -----------------------------------------------------------------------------
25 
26 PYBIND11_MODULE(_mlirDialectsGPU, m) {
27   m.doc() = "MLIR GPU Dialect";
28   //===-------------------------------------------------------------------===//
29   // AsyncTokenType
30   //===-------------------------------------------------------------------===//
31 
32   auto mlirGPUAsyncTokenType =
33       mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
34 
35   mlirGPUAsyncTokenType.def_classmethod(
36       "get",
37       [](py::object cls, MlirContext ctx) {
38         return cls(mlirGPUAsyncTokenTypeGet(ctx));
39       },
40       "Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
41       py::arg("ctx") = py::none());
42 
43   //===-------------------------------------------------------------------===//
44   // ObjectAttr
45   //===-------------------------------------------------------------------===//
46 
47   mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
48       .def_classmethod(
49           "get",
50           [](py::object cls, MlirAttribute target, uint32_t format,
51              py::bytes object, std::optional<MlirAttribute> mlirObjectProps) {
52             py::buffer_info info(py::buffer(object).request());
53             MlirStringRef objectStrRef =
54                 mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
55             return cls(mlirGPUObjectAttrGet(
56                 mlirAttributeGetContext(target), target, format, objectStrRef,
57                 mlirObjectProps.has_value() ? *mlirObjectProps
58                                             : MlirAttribute{nullptr}));
59           },
60           "cls"_a, "target"_a, "format"_a, "object"_a,
61           "properties"_a = py::none(), "Gets a gpu.object from parameters.")
62       .def_property_readonly(
63           "target",
64           [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
65       .def_property_readonly(
66           "format",
67           [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
68       .def_property_readonly(
69           "object",
70           [](MlirAttribute self) {
71             MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
72             return py::bytes(stringRef.data, stringRef.length);
73           })
74       .def_property_readonly("properties", [](MlirAttribute self) {
75         if (mlirGPUObjectAttrHasProperties(self))
76           return py::cast(mlirGPUObjectAttrGetProperties(self));
77         return py::none().cast<py::object>();
78       });
79 }
80