xref: /llvm-project/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp (revision 72e8b9aeaa3f584f223bc59924812df69a09a48b)
143752a2aSFabian Mora //===- ModuleToBinary.cpp - Transforms GPU modules to GPU binaries ----------=//
243752a2aSFabian Mora //
343752a2aSFabian Mora // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
443752a2aSFabian Mora // See https://llvm.org/LICENSE.txt for license information.
543752a2aSFabian Mora // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
643752a2aSFabian Mora //
743752a2aSFabian Mora //===----------------------------------------------------------------------===//
843752a2aSFabian Mora //
943752a2aSFabian Mora // This file implements the `GpuModuleToBinaryPass` pass, transforming GPU
1043752a2aSFabian Mora // modules into GPU binaries.
1143752a2aSFabian Mora //
1243752a2aSFabian Mora //===----------------------------------------------------------------------===//
1343752a2aSFabian Mora 
1443752a2aSFabian Mora #include "mlir/Dialect/GPU/Transforms/Passes.h"
1543752a2aSFabian Mora 
1643752a2aSFabian Mora #include "mlir/Dialect/Func/IR/FuncOps.h"
1743752a2aSFabian Mora #include "mlir/Dialect/GPU/IR/GPUDialect.h"
187c4e8c6aSNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
197c4e8c6aSNicolas Vasilache #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
207c4e8c6aSNicolas Vasilache #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
212dace045SSang Ik Lee #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
2243752a2aSFabian Mora #include "mlir/IR/BuiltinOps.h"
2343752a2aSFabian Mora #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2443752a2aSFabian Mora 
2543752a2aSFabian Mora #include "llvm/ADT/STLExtras.h"
2643752a2aSFabian Mora #include "llvm/ADT/StringSwitch.h"
2743752a2aSFabian Mora 
2843752a2aSFabian Mora using namespace mlir;
2943752a2aSFabian Mora using namespace mlir::gpu;
3043752a2aSFabian Mora 
3143752a2aSFabian Mora namespace mlir {
3243752a2aSFabian Mora #define GEN_PASS_DEF_GPUMODULETOBINARYPASS
3343752a2aSFabian Mora #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
3443752a2aSFabian Mora } // namespace mlir
3543752a2aSFabian Mora 
3643752a2aSFabian Mora namespace {
3743752a2aSFabian Mora class GpuModuleToBinaryPass
3843752a2aSFabian Mora     : public impl::GpuModuleToBinaryPassBase<GpuModuleToBinaryPass> {
3943752a2aSFabian Mora public:
4043752a2aSFabian Mora   using Base::Base;
4143752a2aSFabian Mora   void runOnOperation() final;
4243752a2aSFabian Mora };
4343752a2aSFabian Mora } // namespace
4443752a2aSFabian Mora 
4543752a2aSFabian Mora void GpuModuleToBinaryPass::runOnOperation() {
4643752a2aSFabian Mora   RewritePatternSet patterns(&getContext());
475093413aSFabian Mora   auto targetFormat =
485093413aSFabian Mora       llvm::StringSwitch<std::optional<CompilationTarget>>(compilationTarget)
495093413aSFabian Mora           .Cases("offloading", "llvm", CompilationTarget::Offload)
505093413aSFabian Mora           .Cases("assembly", "isa", CompilationTarget::Assembly)
515093413aSFabian Mora           .Cases("binary", "bin", CompilationTarget::Binary)
525093413aSFabian Mora           .Cases("fatbinary", "fatbin", CompilationTarget::Fatbin)
535093413aSFabian Mora           .Default(std::nullopt);
545093413aSFabian Mora   if (!targetFormat)
5543752a2aSFabian Mora     getOperation()->emitError() << "Invalid format specified.";
56444abb39SFabian Mora 
57444abb39SFabian Mora   // Lazy symbol table builder callback.
58444abb39SFabian Mora   std::optional<SymbolTable> parentTable;
59444abb39SFabian Mora   auto lazyTableBuilder = [&]() -> SymbolTable * {
60444abb39SFabian Mora     // Build the table if it has not been built.
61444abb39SFabian Mora     if (!parentTable) {
62444abb39SFabian Mora       Operation *table = SymbolTable::getNearestSymbolTable(getOperation());
63444abb39SFabian Mora       // It's up to the target attribute to determine if failing to find a
64444abb39SFabian Mora       // symbol table is an error.
65444abb39SFabian Mora       if (!table)
66444abb39SFabian Mora         return nullptr;
67444abb39SFabian Mora       parentTable = SymbolTable(table);
68444abb39SFabian Mora     }
69444abb39SFabian Mora     return &parentTable.value();
70444abb39SFabian Mora   };
71*72e8b9aeSMehdi Amini   SmallVector<Attribute> librariesToLink;
72*72e8b9aeSMehdi Amini   for (const std::string &path : linkFiles)
73*72e8b9aeSMehdi Amini     librariesToLink.push_back(StringAttr::get(&getContext(), path));
74*72e8b9aeSMehdi Amini   TargetOptions targetOptions(toolkitPath, librariesToLink, cmdOptions,
75*72e8b9aeSMehdi Amini                               elfSection, *targetFormat, lazyTableBuilder);
7643752a2aSFabian Mora   if (failed(transformGpuModulesToBinaries(
77dc6ce608SFabian Mora           getOperation(), OffloadingLLVMTranslationAttrInterface(nullptr),
7843752a2aSFabian Mora           targetOptions)))
7943752a2aSFabian Mora     return signalPassFailure();
8043752a2aSFabian Mora }
8143752a2aSFabian Mora 
8243752a2aSFabian Mora namespace {
8343752a2aSFabian Mora LogicalResult moduleSerializer(GPUModuleOp op,
8443752a2aSFabian Mora                                OffloadingLLVMTranslationAttrInterface handler,
8543752a2aSFabian Mora                                const TargetOptions &targetOptions) {
8643752a2aSFabian Mora   OpBuilder builder(op->getContext());
8743752a2aSFabian Mora   SmallVector<Attribute> objects;
88419c45a3SFabian Mora   // Fail if there are no target attributes
89419c45a3SFabian Mora   if (!op.getTargetsAttr())
90419c45a3SFabian Mora     return op.emitError("the module has no target attributes");
9143752a2aSFabian Mora   // Serialize all targets.
9243752a2aSFabian Mora   for (auto targetAttr : op.getTargetsAttr()) {
9343752a2aSFabian Mora     assert(targetAttr && "Target attribute cannot be null.");
9443752a2aSFabian Mora     auto target = dyn_cast<gpu::TargetAttrInterface>(targetAttr);
9543752a2aSFabian Mora     assert(target &&
9643752a2aSFabian Mora            "Target attribute doesn't implements `TargetAttrInterface`.");
975093413aSFabian Mora     std::optional<SmallVector<char, 0>> serializedModule =
9843752a2aSFabian Mora         target.serializeToObject(op, targetOptions);
995093413aSFabian Mora     if (!serializedModule) {
10043752a2aSFabian Mora       op.emitError("An error happened while serializing the module.");
10143752a2aSFabian Mora       return failure();
10243752a2aSFabian Mora     }
10343752a2aSFabian Mora 
104fd36a7b9SFabian Mora     Attribute object =
105fd36a7b9SFabian Mora         target.createObject(op, *serializedModule, targetOptions);
1065093413aSFabian Mora     if (!object) {
1075093413aSFabian Mora       op.emitError("An error happened while creating the object.");
1085093413aSFabian Mora       return failure();
1095093413aSFabian Mora     }
1105093413aSFabian Mora     objects.push_back(object);
11143752a2aSFabian Mora   }
1125b4f2b90SFabian Mora   if (auto moduleHandler =
1135b4f2b90SFabian Mora           dyn_cast_or_null<OffloadingLLVMTranslationAttrInterface>(
1145b4f2b90SFabian Mora               op.getOffloadingHandlerAttr());
1155b4f2b90SFabian Mora       !handler && moduleHandler)
1165b4f2b90SFabian Mora     handler = moduleHandler;
11743752a2aSFabian Mora   builder.setInsertionPointAfter(op);
11843752a2aSFabian Mora   builder.create<gpu::BinaryOp>(op.getLoc(), op.getName(), handler,
11943752a2aSFabian Mora                                 builder.getArrayAttr(objects));
12043752a2aSFabian Mora   op->erase();
12143752a2aSFabian Mora   return success();
12243752a2aSFabian Mora }
12343752a2aSFabian Mora } // namespace
12443752a2aSFabian Mora 
12543752a2aSFabian Mora LogicalResult mlir::gpu::transformGpuModulesToBinaries(
12643752a2aSFabian Mora     Operation *op, OffloadingLLVMTranslationAttrInterface handler,
12743752a2aSFabian Mora     const gpu::TargetOptions &targetOptions) {
12843752a2aSFabian Mora   for (Region &region : op->getRegions())
12943752a2aSFabian Mora     for (Block &block : region.getBlocks())
13043752a2aSFabian Mora       for (auto module :
13143752a2aSFabian Mora            llvm::make_early_inc_range(block.getOps<GPUModuleOp>()))
13243752a2aSFabian Mora         if (failed(moduleSerializer(module, handler, targetOptions)))
13343752a2aSFabian Mora           return failure();
13443752a2aSFabian Mora   return success();
13543752a2aSFabian Mora }
136