xref: /llvm-project/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp (revision c6e7b4a61ab8718d9ac9d1d32f7d2d0cd0b19a7f)
1 //===- CUFToLLVMIRTranslation.cpp - Translate CUF dialect to LLVM IR ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR CUF dialect and LLVM IR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
14 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
15 #include "flang/Runtime/entry-names.h"
16 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
17 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/Support/FormatVariadic.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 
27 LogicalResult registerModule(cuf::RegisterModuleOp op,
28                              llvm::IRBuilderBase &builder,
29                              LLVM::ModuleTranslation &moduleTranslation) {
30   std::string binaryIdentifier =
31       op.getName().getLeafReference().str() + "_bin_cst";
32   llvm::Module *module = moduleTranslation.getLLVMModule();
33   llvm::Value *binary = module->getGlobalVariable(binaryIdentifier, true);
34   if (!binary)
35     return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
36 
37   llvm::Type *ptrTy = builder.getPtrTy(0);
38   llvm::FunctionCallee fct = module->getOrInsertFunction(
39       RTNAME_STRING(CUFRegisterModule),
40       llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy}), false));
41   auto *handle = builder.CreateCall(fct, {binary});
42   moduleTranslation.mapValue(op->getResults().front()) = handle;
43   return mlir::success();
44 }
45 
46 llvm::Value *getOrCreateFunctionName(llvm::Module *module,
47                                      llvm::IRBuilderBase &builder,
48                                      llvm::StringRef moduleName,
49                                      llvm::StringRef kernelName) {
50   std::string globalName =
51       std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, kernelName));
52 
53   if (llvm::GlobalVariable *gv = module->getGlobalVariable(globalName))
54     return gv;
55 
56   return builder.CreateGlobalString(kernelName, globalName);
57 }
58 
59 LogicalResult registerKernel(cuf::RegisterKernelOp op,
60                              llvm::IRBuilderBase &builder,
61                              LLVM::ModuleTranslation &moduleTranslation) {
62   llvm::Module *module = moduleTranslation.getLLVMModule();
63   llvm::Type *ptrTy = builder.getPtrTy(0);
64   llvm::FunctionCallee fct = module->getOrInsertFunction(
65       RTNAME_STRING(CUFRegisterFunction),
66       llvm::FunctionType::get(
67           ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false));
68   llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
69   if (!modulePtr)
70     return op.emitError() << "Couldn't find the module ptr";
71   llvm::Function *fctSym =
72       moduleTranslation.lookupFunction(op.getKernelName().str());
73   if (!fctSym)
74     return op.emitError() << "Couldn't find kernel name symbol: "
75                           << op.getKernelName().str();
76   builder.CreateCall(fct, {modulePtr, fctSym,
77                            getOrCreateFunctionName(
78                                module, builder, op.getKernelModuleName().str(),
79                                op.getKernelName().str())});
80   return mlir::success();
81 }
82 
83 class CUFDialectLLVMIRTranslationInterface
84     : public LLVMTranslationDialectInterface {
85 public:
86   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
87 
88   LogicalResult
89   convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
90                    LLVM::ModuleTranslation &moduleTranslation) const override {
91     return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
92         .Case([&](cuf::RegisterModuleOp op) {
93           return registerModule(op, builder, moduleTranslation);
94         })
95         .Case([&](cuf::RegisterKernelOp op) {
96           return registerKernel(op, builder, moduleTranslation);
97         })
98         .Default([&](Operation *op) {
99           return op->emitError("unsupported GPU operation: ") << op->getName();
100         });
101   }
102 };
103 
104 } // namespace
105 
106 void cuf::registerCUFDialectTranslation(DialectRegistry &registry) {
107   registry.insert<cuf::CUFDialect>();
108   registry.addExtension(+[](MLIRContext *ctx, cuf::CUFDialect *dialect) {
109     dialect->addInterfaces<CUFDialectLLVMIRTranslationInterface>();
110   });
111 }
112