xref: /llvm-project/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp (revision 4b17a8b10ebb69d3bd30ee7714b5ca24f7e944dc)
1 //===-- CUFAddConstructor.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/BoxValue.h"
10 #include "flang/Optimizer/Builder/CUFCommon.h"
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
13 #include "flang/Optimizer/Builder/Todo.h"
14 #include "flang/Optimizer/CodeGen/Target.h"
15 #include "flang/Optimizer/CodeGen/TypeConverter.h"
16 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
17 #include "flang/Optimizer/Dialect/FIRAttr.h"
18 #include "flang/Optimizer/Dialect/FIRDialect.h"
19 #include "flang/Optimizer/Dialect/FIROps.h"
20 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/Support/DataLayout.h"
23 #include "flang/Runtime/CUDA/registration.h"
24 #include "flang/Runtime/entry-names.h"
25 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/Pass/Pass.h"
29 #include "llvm/ADT/SmallVector.h"
30 
31 namespace fir {
32 #define GEN_PASS_DEF_CUFADDCONSTRUCTOR
33 #include "flang/Optimizer/Transforms/Passes.h.inc"
34 } // namespace fir
35 
36 using namespace Fortran::runtime::cuda;
37 
38 namespace {
39 
40 static constexpr llvm::StringRef cudaFortranCtorName{
41     "__cudaFortranConstructor"};
42 
43 struct CUFAddConstructor
44     : public fir::impl::CUFAddConstructorBase<CUFAddConstructor> {
45 
46   void runOnOperation() override {
47     mlir::ModuleOp mod = getOperation();
48     mlir::SymbolTable symTab(mod);
49     mlir::OpBuilder opBuilder{mod.getBodyRegion()};
50     fir::FirOpBuilder builder(opBuilder, mod);
51     fir::KindMapping kindMap{fir::getKindMapping(mod)};
52     builder.setInsertionPointToEnd(mod.getBody());
53     mlir::Location loc = mod.getLoc();
54     auto *ctx = mod.getContext();
55     auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx);
56     auto idxTy = builder.getIndexType();
57     auto funcTy =
58         mlir::LLVM::LLVMFunctionType::get(voidTy, {}, /*isVarArg=*/false);
59     std::optional<mlir::DataLayout> dl =
60         fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/false);
61     if (!dl) {
62       mlir::emitError(mod.getLoc(),
63                       "data layout attribute is required to perform " +
64                           getName() + "pass");
65     }
66 
67     // Symbol reference to CUFRegisterAllocator.
68     builder.setInsertionPointToEnd(mod.getBody());
69     auto registerFuncOp = builder.create<mlir::LLVM::LLVMFuncOp>(
70         loc, RTNAME_STRING(CUFRegisterAllocator), funcTy);
71     registerFuncOp.setVisibility(mlir::SymbolTable::Visibility::Private);
72     auto cufRegisterAllocatorRef = mlir::SymbolRefAttr::get(
73         mod.getContext(), RTNAME_STRING(CUFRegisterAllocator));
74     builder.setInsertionPointToEnd(mod.getBody());
75 
76     // Create the constructor function that call CUFRegisterAllocator.
77     auto func = builder.create<mlir::LLVM::LLVMFuncOp>(loc, cudaFortranCtorName,
78                                                        funcTy);
79     func.setLinkage(mlir::LLVM::Linkage::Internal);
80     builder.setInsertionPointToStart(func.addEntryBlock(builder));
81     builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
82 
83     auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
84     if (gpuMod) {
85       auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
86       auto registeredMod = builder.create<cuf::RegisterModuleOp>(
87           loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
88 
89       fir::LLVMTypeConverter typeConverter(mod, /*applyTBAA=*/false,
90                                            /*forceUnifiedTBAATree=*/false, *dl);
91       // Register kernels
92       for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
93         if (func.isKernel()) {
94           auto kernelName = mlir::SymbolRefAttr::get(
95               builder.getStringAttr(cudaDeviceModuleName),
96               {mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
97           builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
98         }
99       }
100 
101       // Register variables
102       for (fir::GlobalOp globalOp : mod.getOps<fir::GlobalOp>()) {
103         auto attr = globalOp.getDataAttrAttr();
104         if (!attr)
105           continue;
106 
107         mlir::func::FuncOp func;
108         switch (attr.getValue()) {
109         case cuf::DataAttribute::Device:
110         case cuf::DataAttribute::Constant: {
111           func = fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>(
112               loc, builder);
113           auto fTy = func.getFunctionType();
114 
115           // Global variable name
116           std::string gblNameStr = globalOp.getSymbol().getValue().str();
117           gblNameStr += '\0';
118           mlir::Value gblName = fir::getBase(
119               fir::factory::createStringLiteral(builder, loc, gblNameStr));
120 
121           // Global variable size
122           std::optional<uint64_t> size;
123           if (auto boxTy =
124                   mlir::dyn_cast<fir::BaseBoxType>(globalOp.getType())) {
125             mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
126             size = dl->getTypeSizeInBits(structTy) / 8;
127           }
128           if (!size) {
129             size = fir::getTypeSizeAndAlignmentOrCrash(loc, globalOp.getType(),
130                                                        *dl, kindMap)
131                        .first;
132           }
133           auto sizeVal = builder.createIntegerConstant(loc, idxTy, *size);
134 
135           // Global variable address
136           mlir::Value addr = builder.create<fir::AddrOfOp>(
137               loc, globalOp.resultType(), globalOp.getSymbol());
138 
139           llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
140               builder, loc, fTy, registeredMod, addr, gblName, sizeVal)};
141           builder.create<fir::CallOp>(loc, func, args);
142         } break;
143         case cuf::DataAttribute::Managed:
144           TODO(loc, "registration of managed variables");
145         default:
146           break;
147         }
148       }
149     }
150     builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});
151 
152     // Create the llvm.global_ctor with the function.
153     // TODO: We might want to have a utility that retrieve it if already
154     // created and adds new functions.
155     builder.setInsertionPointToEnd(mod.getBody());
156     llvm::SmallVector<mlir::Attribute> funcs;
157     funcs.push_back(
158         mlir::FlatSymbolRefAttr::get(mod.getContext(), func.getSymName()));
159     llvm::SmallVector<int> priorities;
160     priorities.push_back(0);
161     builder.create<mlir::LLVM::GlobalCtorsOp>(
162         mod.getLoc(), builder.getArrayAttr(funcs),
163         builder.getI32ArrayAttr(priorities));
164   }
165 };
166 
167 } // end anonymous namespace
168