xref: /llvm-project/mlir/lib/Target/LLVM/ROCDL/Utils.cpp (revision 016e1eb9c86923bf6a9669697f6be8309d12b78c)
1 //===- Utils.cpp - MLIR ROCDL target utils ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines ROCDL target-related utility classes and functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Target/LLVM/ROCDL/Utils.h"
14 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
16 
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/Frontend/Offloading/Utility.h"
19 
20 using namespace mlir;
21 using namespace mlir::ROCDL;
22 
23 std::optional<DenseMap<StringAttr, NamedAttrList>>
24 mlir::ROCDL::getAMDHSAKernelsELFMetadata(Builder &builder,
25                                          ArrayRef<char> elfData) {
26   uint16_t elfABIVersion;
27   llvm::StringMap<llvm::offloading::amdgpu::AMDGPUKernelMetaData> kernels;
28   llvm::MemoryBufferRef buffer(StringRef(elfData.data(), elfData.size()),
29                                "buffer");
30   // Get the metadata.
31   llvm::Error error = llvm::offloading::amdgpu::getAMDGPUMetaDataFromImage(
32       buffer, kernels, elfABIVersion);
33   // Return `nullopt` if the metadata couldn't be retrieved.
34   if (error) {
35     llvm::consumeError(std::move(error));
36     return std::nullopt;
37   }
38   // Helper lambda for converting values.
39   auto getI32Array = [&builder](const uint32_t *array) {
40     return builder.getDenseI32ArrayAttr({static_cast<int32_t>(array[0]),
41                                          static_cast<int32_t>(array[1]),
42                                          static_cast<int32_t>(array[2])});
43   };
44   DenseMap<StringAttr, NamedAttrList> kernelMD;
45   for (const auto &[name, kernel] : kernels) {
46     NamedAttrList attrs;
47     // Add kernel metadata.
48     attrs.append("agpr_count", builder.getI64IntegerAttr(kernel.AGPRCount));
49     attrs.append("sgpr_count", builder.getI64IntegerAttr(kernel.SGPRCount));
50     attrs.append("vgpr_count", builder.getI64IntegerAttr(kernel.VGPRCount));
51     attrs.append("sgpr_spill_count",
52                  builder.getI64IntegerAttr(kernel.SGPRSpillCount));
53     attrs.append("vgpr_spill_count",
54                  builder.getI64IntegerAttr(kernel.VGPRSpillCount));
55     attrs.append("wavefront_size",
56                  builder.getI64IntegerAttr(kernel.WavefrontSize));
57     attrs.append("max_flat_workgroup_size",
58                  builder.getI64IntegerAttr(kernel.MaxFlatWorkgroupSize));
59     attrs.append("group_segment_fixed_size",
60                  builder.getI64IntegerAttr(kernel.GroupSegmentList));
61     attrs.append("private_segment_fixed_size",
62                  builder.getI64IntegerAttr(kernel.PrivateSegmentSize));
63     attrs.append("reqd_workgroup_size",
64                  getI32Array(kernel.RequestedWorkgroupSize));
65     attrs.append("workgroup_size_hint", getI32Array(kernel.WorkgroupSizeHint));
66     kernelMD[builder.getStringAttr(name)] = std::move(attrs);
67   }
68   return std::move(kernelMD);
69 }
70 
71 gpu::KernelTableAttr mlir::ROCDL::getKernelMetadata(Operation *gpuModule,
72                                                     ArrayRef<char> elfData) {
73   auto module = cast<gpu::GPUModuleOp>(gpuModule);
74   Builder builder(module.getContext());
75   SmallVector<gpu::KernelMetadataAttr> kernels;
76   std::optional<DenseMap<StringAttr, NamedAttrList>> mdMapOrNull =
77       getAMDHSAKernelsELFMetadata(builder, elfData);
78   for (auto funcOp : module.getBody()->getOps<LLVM::LLVMFuncOp>()) {
79     if (!funcOp->getDiscardableAttr("rocdl.kernel"))
80       continue;
81     kernels.push_back(gpu::KernelMetadataAttr::get(
82         funcOp, mdMapOrNull ? builder.getDictionaryAttr(
83                                   mdMapOrNull->lookup(funcOp.getNameAttr()))
84                             : nullptr));
85   }
86   return gpu::KernelTableAttr::get(gpuModule->getContext(), kernels);
87 }
88