xref: /llvm-project/mlir/lib/CAPI/Dialect/GPU.cpp (revision 016e1eb9c86923bf6a9669697f6be8309d12b78c)
1 //===- GPU.cpp - C Interface for GPU dialect ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/GPU.h"
10 #include "mlir/CAPI/Registration.h"
11 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
12 #include "llvm/Support/Casting.h"
13 
14 using namespace mlir;
15 
// Registers the GPU dialect (gpu::GPUDialect, namespace "gpu") with the C API.
// The macro comes from mlir/CAPI/Registration.h; see that header for the exact
// symbols it generates (presumably the dialect-handle getter — verify there).
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, gpu::GPUDialect)
17 
18 //===-------------------------------------------------------------------===//
19 // AsyncTokenType
20 //===-------------------------------------------------------------------===//
21 
22 bool mlirTypeIsAGPUAsyncTokenType(MlirType type) {
23   return isa<gpu::AsyncTokenType>(unwrap(type));
24 }
25 
26 MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx) {
27   return wrap(gpu::AsyncTokenType::get(unwrap(ctx)));
28 }
29 
30 //===---------------------------------------------------------------------===//
31 // ObjectAttr
32 //===---------------------------------------------------------------------===//
33 
34 bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr) {
35   return llvm::isa<gpu::ObjectAttr>(unwrap(attr));
36 }
37 
38 MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target,
39                                    uint32_t format, MlirStringRef objectStrRef,
40                                    MlirAttribute mlirObjectProps) {
41   MLIRContext *ctx = unwrap(mlirCtx);
42   llvm::StringRef object = unwrap(objectStrRef);
43   DictionaryAttr objectProps;
44   if (mlirObjectProps.ptr != nullptr)
45     objectProps = llvm::cast<DictionaryAttr>(unwrap(mlirObjectProps));
46   return wrap(gpu::ObjectAttr::get(
47       ctx, unwrap(target), static_cast<gpu::CompilationTarget>(format),
48       StringAttr::get(ctx, object), objectProps, nullptr));
49 }
50 
51 MlirAttribute mlirGPUObjectAttrGetWithKernels(MlirContext mlirCtx,
52                                               MlirAttribute target,
53                                               uint32_t format,
54                                               MlirStringRef objectStrRef,
55                                               MlirAttribute mlirObjectProps,
56                                               MlirAttribute mlirKernelsAttr) {
57   MLIRContext *ctx = unwrap(mlirCtx);
58   llvm::StringRef object = unwrap(objectStrRef);
59   DictionaryAttr objectProps;
60   if (mlirObjectProps.ptr != nullptr)
61     objectProps = llvm::cast<DictionaryAttr>(unwrap(mlirObjectProps));
62   gpu::KernelTableAttr kernels;
63   if (mlirKernelsAttr.ptr != nullptr)
64     kernels = llvm::cast<gpu::KernelTableAttr>(unwrap(mlirKernelsAttr));
65   return wrap(gpu::ObjectAttr::get(
66       ctx, unwrap(target), static_cast<gpu::CompilationTarget>(format),
67       StringAttr::get(ctx, object), objectProps, kernels));
68 }
69 
70 MlirAttribute mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr) {
71   gpu::ObjectAttr objectAttr =
72       llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
73   return wrap(objectAttr.getTarget());
74 }
75 
76 uint32_t mlirGPUObjectAttrGetFormat(MlirAttribute mlirObjectAttr) {
77   gpu::ObjectAttr objectAttr =
78       llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
79   return static_cast<uint32_t>(objectAttr.getFormat());
80 }
81 
82 MlirStringRef mlirGPUObjectAttrGetObject(MlirAttribute mlirObjectAttr) {
83   gpu::ObjectAttr objectAttr =
84       llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
85   llvm::StringRef object = objectAttr.getObject();
86   return mlirStringRefCreate(object.data(), object.size());
87 }
88 
89 bool mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr) {
90   gpu::ObjectAttr objectAttr =
91       llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
92   return objectAttr.getProperties() != nullptr;
93 }
94 
95 MlirAttribute mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr) {
96   gpu::ObjectAttr objectAttr =
97       llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
98   return wrap(objectAttr.getProperties());
99 }
100 
101 bool mlirGPUObjectAttrHasKernels(MlirAttribute mlirObjectAttr) {
102   gpu::ObjectAttr objectAttr =
103       llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
104   return objectAttr.getKernels() != nullptr;
105 }
106 
107 MlirAttribute mlirGPUObjectAttrGetKernels(MlirAttribute mlirObjectAttr) {
108   gpu::ObjectAttr objectAttr =
109       llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
110   return wrap(objectAttr.getKernels());
111 }
112