1 //===- CUFToLLVMIRTranslation.cpp - Translate CUF dialect to LLVM IR ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR CUF dialect and LLVM IR. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h" 14 #include "flang/Optimizer/Dialect/CUF/CUFOps.h" 15 #include "flang/Runtime/entry-names.h" 16 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" 17 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 18 #include "llvm/ADT/TypeSwitch.h" 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/Module.h" 21 #include "llvm/Support/FormatVariadic.h" 22 23 using namespace mlir; 24 25 namespace { 26 27 LogicalResult registerModule(cuf::RegisterModuleOp op, 28 llvm::IRBuilderBase &builder, 29 LLVM::ModuleTranslation &moduleTranslation) { 30 std::string binaryIdentifier = 31 op.getName().getLeafReference().str() + "_bin_cst"; 32 llvm::Module *module = moduleTranslation.getLLVMModule(); 33 llvm::Value *binary = module->getGlobalVariable(binaryIdentifier, true); 34 if (!binary) 35 return op.emitError() << "Couldn't find the binary: " << binaryIdentifier; 36 37 llvm::Type *ptrTy = builder.getPtrTy(0); 38 llvm::FunctionCallee fct = module->getOrInsertFunction( 39 RTNAME_STRING(CUFRegisterModule), 40 llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy}), false)); 41 auto *handle = builder.CreateCall(fct, {binary}); 42 moduleTranslation.mapValue(op->getResults().front()) = handle; 43 return mlir::success(); 44 } 45 46 llvm::Value *getOrCreateFunctionName(llvm::Module *module, 47 llvm::IRBuilderBase &builder, 48 llvm::StringRef moduleName, 49 llvm::StringRef kernelName) { 50 std::string globalName = 51 std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, kernelName)); 52 53 if (llvm::GlobalVariable *gv = module->getGlobalVariable(globalName)) 54 return gv; 55 56 return builder.CreateGlobalString(kernelName, globalName); 57 } 58 59 LogicalResult registerKernel(cuf::RegisterKernelOp op, 60 llvm::IRBuilderBase &builder, 61 LLVM::ModuleTranslation &moduleTranslation) { 62 llvm::Module *module = moduleTranslation.getLLVMModule(); 63 llvm::Type *ptrTy = builder.getPtrTy(0); 64 llvm::FunctionCallee fct = module->getOrInsertFunction( 65 RTNAME_STRING(CUFRegisterFunction), 66 llvm::FunctionType::get( 67 ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false)); 68 llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr()); 69 if (!modulePtr) 70 return op.emitError() << "Couldn't find the module ptr"; 71 llvm::Function *fctSym = 72 moduleTranslation.lookupFunction(op.getKernelName().str()); 73 if (!fctSym) 74 return op.emitError() << "Couldn't find kernel name symbol: " 75 << op.getKernelName().str(); 76 builder.CreateCall(fct, {modulePtr, fctSym, 77 getOrCreateFunctionName( 78 module, builder, op.getKernelModuleName().str(), 79 op.getKernelName().str())}); 80 return mlir::success(); 81 } 82 83 class CUFDialectLLVMIRTranslationInterface 84 : public LLVMTranslationDialectInterface { 85 public: 86 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 87 88 LogicalResult 89 convertOperation(Operation *operation, llvm::IRBuilderBase &builder, 90 LLVM::ModuleTranslation &moduleTranslation) const override { 91 return llvm::TypeSwitch<Operation *, LogicalResult>(operation) 92 .Case([&](cuf::RegisterModuleOp op) { 93 return registerModule(op, builder, moduleTranslation); 94 }) 95 .Case([&](cuf::RegisterKernelOp op) { 96 return registerKernel(op, builder, moduleTranslation); 97 }) 98 .Default([&](Operation *op) { 99 return op->emitError("unsupported GPU operation: ") << op->getName(); 100 }); 101 } 102 }; 103 104 } // namespace 105 106 void cuf::registerCUFDialectTranslation(DialectRegistry ®istry) { 107 registry.insert<cuf::CUFDialect>(); 108 registry.addExtension(+[](MLIRContext *ctx, cuf::CUFDialect *dialect) { 109 dialect->addInterfaces<CUFDialectLLVMIRTranslationInterface>(); 110 }); 111 } 112