1 //===-- CUFAddConstructor.cpp ---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/BoxValue.h" 10 #include "flang/Optimizer/Builder/CUFCommon.h" 11 #include "flang/Optimizer/Builder/FIRBuilder.h" 12 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" 13 #include "flang/Optimizer/Builder/Todo.h" 14 #include "flang/Optimizer/CodeGen/Target.h" 15 #include "flang/Optimizer/CodeGen/TypeConverter.h" 16 #include "flang/Optimizer/Dialect/CUF/CUFOps.h" 17 #include "flang/Optimizer/Dialect/FIRAttr.h" 18 #include "flang/Optimizer/Dialect/FIRDialect.h" 19 #include "flang/Optimizer/Dialect/FIROps.h" 20 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 21 #include "flang/Optimizer/Dialect/FIRType.h" 22 #include "flang/Optimizer/Support/DataLayout.h" 23 #include "flang/Runtime/CUDA/registration.h" 24 #include "flang/Runtime/entry-names.h" 25 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 27 #include "mlir/IR/Value.h" 28 #include "mlir/Pass/Pass.h" 29 #include "llvm/ADT/SmallVector.h" 30 31 namespace fir { 32 #define GEN_PASS_DEF_CUFADDCONSTRUCTOR 33 #include "flang/Optimizer/Transforms/Passes.h.inc" 34 } // namespace fir 35 36 using namespace Fortran::runtime::cuda; 37 38 namespace { 39 40 static constexpr llvm::StringRef cudaFortranCtorName{ 41 "__cudaFortranConstructor"}; 42 43 struct CUFAddConstructor 44 : public fir::impl::CUFAddConstructorBase<CUFAddConstructor> { 45 46 void runOnOperation() override { 47 mlir::ModuleOp mod = getOperation(); 48 mlir::SymbolTable symTab(mod); 49 mlir::OpBuilder opBuilder{mod.getBodyRegion()}; 50 fir::FirOpBuilder builder(opBuilder, mod); 51 fir::KindMapping kindMap{fir::getKindMapping(mod)}; 52 builder.setInsertionPointToEnd(mod.getBody()); 53 mlir::Location loc = mod.getLoc(); 54 auto *ctx = mod.getContext(); 55 auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx); 56 auto idxTy = builder.getIndexType(); 57 auto funcTy = 58 mlir::LLVM::LLVMFunctionType::get(voidTy, {}, /*isVarArg=*/false); 59 std::optional<mlir::DataLayout> dl = 60 fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/false); 61 if (!dl) { 62 mlir::emitError(mod.getLoc(), 63 "data layout attribute is required to perform " + 64 getName() + "pass"); 65 } 66 67 // Symbol reference to CUFRegisterAllocator. 68 builder.setInsertionPointToEnd(mod.getBody()); 69 auto registerFuncOp = builder.create<mlir::LLVM::LLVMFuncOp>( 70 loc, RTNAME_STRING(CUFRegisterAllocator), funcTy); 71 registerFuncOp.setVisibility(mlir::SymbolTable::Visibility::Private); 72 auto cufRegisterAllocatorRef = mlir::SymbolRefAttr::get( 73 mod.getContext(), RTNAME_STRING(CUFRegisterAllocator)); 74 builder.setInsertionPointToEnd(mod.getBody()); 75 76 // Create the constructor function that call CUFRegisterAllocator. 77 auto func = builder.create<mlir::LLVM::LLVMFuncOp>(loc, cudaFortranCtorName, 78 funcTy); 79 func.setLinkage(mlir::LLVM::Linkage::Internal); 80 builder.setInsertionPointToStart(func.addEntryBlock(builder)); 81 builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef); 82 83 auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName); 84 if (gpuMod) { 85 auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx); 86 auto registeredMod = builder.create<cuf::RegisterModuleOp>( 87 loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName())); 88 89 fir::LLVMTypeConverter typeConverter(mod, /*applyTBAA=*/false, 90 /*forceUnifiedTBAATree=*/false, *dl); 91 // Register kernels 92 for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) { 93 if (func.isKernel()) { 94 auto kernelName = mlir::SymbolRefAttr::get( 95 builder.getStringAttr(cudaDeviceModuleName), 96 {mlir::SymbolRefAttr::get(builder.getContext(), func.getName())}); 97 builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod); 98 } 99 } 100 101 // Register variables 102 for (fir::GlobalOp globalOp : mod.getOps<fir::GlobalOp>()) { 103 auto attr = globalOp.getDataAttrAttr(); 104 if (!attr) 105 continue; 106 107 mlir::func::FuncOp func; 108 switch (attr.getValue()) { 109 case cuf::DataAttribute::Device: 110 case cuf::DataAttribute::Constant: { 111 func = fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>( 112 loc, builder); 113 auto fTy = func.getFunctionType(); 114 115 // Global variable name 116 std::string gblNameStr = globalOp.getSymbol().getValue().str(); 117 gblNameStr += '\0'; 118 mlir::Value gblName = fir::getBase( 119 fir::factory::createStringLiteral(builder, loc, gblNameStr)); 120 121 // Global variable size 122 std::optional<uint64_t> size; 123 if (auto boxTy = 124 mlir::dyn_cast<fir::BaseBoxType>(globalOp.getType())) { 125 mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); 126 size = dl->getTypeSizeInBits(structTy) / 8; 127 } 128 if (!size) { 129 size = fir::getTypeSizeAndAlignmentOrCrash(loc, globalOp.getType(), 130 *dl, kindMap) 131 .first; 132 } 133 auto sizeVal = builder.createIntegerConstant(loc, idxTy, *size); 134 135 // Global variable address 136 mlir::Value addr = builder.create<fir::AddrOfOp>( 137 loc, globalOp.resultType(), globalOp.getSymbol()); 138 139 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( 140 builder, loc, fTy, registeredMod, addr, gblName, sizeVal)}; 141 builder.create<fir::CallOp>(loc, func, args); 142 } break; 143 case cuf::DataAttribute::Managed: 144 TODO(loc, "registration of managed variables"); 145 default: 146 break; 147 } 148 } 149 } 150 builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{}); 151 152 // Create the llvm.global_ctor with the function. 153 // TODO: We might want to have a utility that retrieve it if already 154 // created and adds new functions. 155 builder.setInsertionPointToEnd(mod.getBody()); 156 llvm::SmallVector<mlir::Attribute> funcs; 157 funcs.push_back( 158 mlir::FlatSymbolRefAttr::get(mod.getContext(), func.getSymName())); 159 llvm::SmallVector<int> priorities; 160 priorities.push_back(0); 161 builder.create<mlir::LLVM::GlobalCtorsOp>( 162 mod.getLoc(), builder.getArrayAttr(funcs), 163 builder.getI32ArrayAttr(priorities)); 164 } 165 }; 166 167 } // end anonymous namespace 168