xref: /llvm-project/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp (revision 72e8b9aeaa3f584f223bc59924812df69a09a48b)
1 //===- ModuleToBinary.cpp - Transforms GPU modules to GPU binaries ----------=//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the `GpuModuleToBinaryPass` pass, transforming GPU
10 // modules into GPU binaries.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/GPU/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
20 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
21 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringSwitch.h"
27 
28 using namespace mlir;
29 using namespace mlir::gpu;
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_GPUMODULETOBINARYPASS
33 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
34 } // namespace mlir
35 
36 namespace {
37 class GpuModuleToBinaryPass
38     : public impl::GpuModuleToBinaryPassBase<GpuModuleToBinaryPass> {
39 public:
40   using Base::Base;
41   void runOnOperation() final;
42 };
43 } // namespace
44 
45 void GpuModuleToBinaryPass::runOnOperation() {
46   RewritePatternSet patterns(&getContext());
47   auto targetFormat =
48       llvm::StringSwitch<std::optional<CompilationTarget>>(compilationTarget)
49           .Cases("offloading", "llvm", CompilationTarget::Offload)
50           .Cases("assembly", "isa", CompilationTarget::Assembly)
51           .Cases("binary", "bin", CompilationTarget::Binary)
52           .Cases("fatbinary", "fatbin", CompilationTarget::Fatbin)
53           .Default(std::nullopt);
54   if (!targetFormat)
55     getOperation()->emitError() << "Invalid format specified.";
56 
57   // Lazy symbol table builder callback.
58   std::optional<SymbolTable> parentTable;
59   auto lazyTableBuilder = [&]() -> SymbolTable * {
60     // Build the table if it has not been built.
61     if (!parentTable) {
62       Operation *table = SymbolTable::getNearestSymbolTable(getOperation());
63       // It's up to the target attribute to determine if failing to find a
64       // symbol table is an error.
65       if (!table)
66         return nullptr;
67       parentTable = SymbolTable(table);
68     }
69     return &parentTable.value();
70   };
71   SmallVector<Attribute> librariesToLink;
72   for (const std::string &path : linkFiles)
73     librariesToLink.push_back(StringAttr::get(&getContext(), path));
74   TargetOptions targetOptions(toolkitPath, librariesToLink, cmdOptions,
75                               elfSection, *targetFormat, lazyTableBuilder);
76   if (failed(transformGpuModulesToBinaries(
77           getOperation(), OffloadingLLVMTranslationAttrInterface(nullptr),
78           targetOptions)))
79     return signalPassFailure();
80 }
81 
82 namespace {
83 LogicalResult moduleSerializer(GPUModuleOp op,
84                                OffloadingLLVMTranslationAttrInterface handler,
85                                const TargetOptions &targetOptions) {
86   OpBuilder builder(op->getContext());
87   SmallVector<Attribute> objects;
88   // Fail if there are no target attributes
89   if (!op.getTargetsAttr())
90     return op.emitError("the module has no target attributes");
91   // Serialize all targets.
92   for (auto targetAttr : op.getTargetsAttr()) {
93     assert(targetAttr && "Target attribute cannot be null.");
94     auto target = dyn_cast<gpu::TargetAttrInterface>(targetAttr);
95     assert(target &&
96            "Target attribute doesn't implements `TargetAttrInterface`.");
97     std::optional<SmallVector<char, 0>> serializedModule =
98         target.serializeToObject(op, targetOptions);
99     if (!serializedModule) {
100       op.emitError("An error happened while serializing the module.");
101       return failure();
102     }
103 
104     Attribute object =
105         target.createObject(op, *serializedModule, targetOptions);
106     if (!object) {
107       op.emitError("An error happened while creating the object.");
108       return failure();
109     }
110     objects.push_back(object);
111   }
112   if (auto moduleHandler =
113           dyn_cast_or_null<OffloadingLLVMTranslationAttrInterface>(
114               op.getOffloadingHandlerAttr());
115       !handler && moduleHandler)
116     handler = moduleHandler;
117   builder.setInsertionPointAfter(op);
118   builder.create<gpu::BinaryOp>(op.getLoc(), op.getName(), handler,
119                                 builder.getArrayAttr(objects));
120   op->erase();
121   return success();
122 }
123 } // namespace
124 
125 LogicalResult mlir::gpu::transformGpuModulesToBinaries(
126     Operation *op, OffloadingLLVMTranslationAttrInterface handler,
127     const gpu::TargetOptions &targetOptions) {
128   for (Region &region : op->getRegions())
129     for (Block &block : region.getBlocks())
130       for (auto module :
131            llvm::make_early_inc_range(block.getOps<GPUModuleOp>()))
132         if (failed(moduleSerializer(module, handler, targetOptions)))
133           return failure();
134   return success();
135 }
136