xref: /llvm-project/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp (revision 544a3cb65b6b9b1455f9294d1764f47a7b8673b7)
1 //===-- CUFDeviceGlobal.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Common/Fortran.h"
10 #include "flang/Optimizer/Builder/CUFCommon.h"
11 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
12 #include "flang/Optimizer/Dialect/FIRDialect.h"
13 #include "flang/Optimizer/Dialect/FIROps.h"
14 #include "flang/Optimizer/HLFIR/HLFIROps.h"
15 #include "flang/Optimizer/Support/InternalNames.h"
16 #include "flang/Runtime/CUDA/common.h"
17 #include "flang/Runtime/allocatable.h"
18 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19 #include "mlir/IR/SymbolTable.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/DenseSet.h"
23 
24 namespace fir {
25 #define GEN_PASS_DEF_CUFDEVICEGLOBAL
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
27 } // namespace fir
28 
29 namespace {
30 
31 static void processAddrOfOp(fir::AddrOfOp addrOfOp,
32                             mlir::SymbolTable &symbolTable,
33                             llvm::DenseSet<fir::GlobalOp> &candidates,
34                             bool recurseInGlobal) {
35   if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
36           addrOfOp.getSymbol().getRootReference().getValue())) {
37     // TO DO: limit candidates to non-scalars. Scalars appear to have been
38     // folded in already.
39     if (recurseInGlobal)
40       globalOp.walk([&](fir::AddrOfOp op) {
41         processAddrOfOp(op, symbolTable, candidates, recurseInGlobal);
42       });
43     candidates.insert(globalOp);
44   }
45 }
46 
47 static void processEmboxOp(fir::EmboxOp emboxOp, mlir::SymbolTable &symbolTable,
48                            llvm::DenseSet<fir::GlobalOp> &candidates) {
49   if (auto recTy = mlir::dyn_cast<fir::RecordType>(
50           fir::unwrapRefType(emboxOp.getMemref().getType()))) {
51     if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
52             fir::NameUniquer::getTypeDescriptorName(recTy.getName()))) {
53       if (!candidates.contains(globalOp)) {
54         globalOp.walk([&](fir::AddrOfOp op) {
55           processAddrOfOp(op, symbolTable, candidates,
56                           /*recurseInGlobal=*/true);
57         });
58         candidates.insert(globalOp);
59       }
60     }
61   }
62 }
63 
64 static void
65 prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
66                              mlir::SymbolTable &symbolTable,
67                              llvm::DenseSet<fir::GlobalOp> &candidates) {
68   auto cudaProcAttr{
69       funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
70   if (cudaProcAttr && cudaProcAttr.getValue() != cuf::ProcAttribute::Host) {
71     funcOp.walk([&](fir::AddrOfOp op) {
72       processAddrOfOp(op, symbolTable, candidates, /*recurseInGlobal=*/false);
73     });
74     funcOp.walk(
75         [&](fir::EmboxOp op) { processEmboxOp(op, symbolTable, candidates); });
76   }
77 }
78 
79 class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
80 public:
81   void runOnOperation() override {
82     mlir::Operation *op = getOperation();
83     mlir::ModuleOp mod = mlir::dyn_cast<mlir::ModuleOp>(op);
84     if (!mod)
85       return signalPassFailure();
86 
87     llvm::DenseSet<fir::GlobalOp> candidates;
88     mlir::SymbolTable symTable(mod);
89     mod.walk([&](mlir::func::FuncOp funcOp) {
90       prepareImplicitDeviceGlobals(funcOp, symTable, candidates);
91       return mlir::WalkResult::advance();
92     });
93     mod.walk([&](cuf::KernelOp kernelOp) {
94       kernelOp.walk([&](fir::AddrOfOp addrOfOp) {
95         processAddrOfOp(addrOfOp, symTable, candidates,
96                         /*recurseInGlobal=*/false);
97       });
98     });
99 
100     // Copying the device global variable into the gpu module
101     mlir::SymbolTable parentSymTable(mod);
102     auto gpuMod = cuf::getOrCreateGPUModule(mod, parentSymTable);
103     if (!gpuMod)
104       return signalPassFailure();
105     mlir::SymbolTable gpuSymTable(gpuMod);
106     for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
107       if (cuf::isRegisteredDeviceGlobal(globalOp))
108         candidates.insert(globalOp);
109     }
110     for (auto globalOp : candidates) {
111       auto globalName{globalOp.getSymbol().getValue()};
112       if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
113         break;
114       }
115       gpuSymTable.insert(globalOp->clone());
116     }
117   }
118 };
119 } // namespace
120