1 //===-- CUFDeviceGlobal.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Common/Fortran.h" 10 #include "flang/Optimizer/Builder/CUFCommon.h" 11 #include "flang/Optimizer/Dialect/CUF/CUFOps.h" 12 #include "flang/Optimizer/Dialect/FIRDialect.h" 13 #include "flang/Optimizer/Dialect/FIROps.h" 14 #include "flang/Optimizer/HLFIR/HLFIROps.h" 15 #include "flang/Optimizer/Support/InternalNames.h" 16 #include "flang/Runtime/CUDA/common.h" 17 #include "flang/Runtime/allocatable.h" 18 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 19 #include "mlir/IR/SymbolTable.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "llvm/ADT/DenseSet.h" 23 24 namespace fir { 25 #define GEN_PASS_DEF_CUFDEVICEGLOBAL 26 #include "flang/Optimizer/Transforms/Passes.h.inc" 27 } // namespace fir 28 29 namespace { 30 31 static void processAddrOfOp(fir::AddrOfOp addrOfOp, 32 mlir::SymbolTable &symbolTable, 33 llvm::DenseSet<fir::GlobalOp> &candidates, 34 bool recurseInGlobal) { 35 if (auto globalOp = symbolTable.lookup<fir::GlobalOp>( 36 addrOfOp.getSymbol().getRootReference().getValue())) { 37 // TO DO: limit candidates to non-scalars. Scalars appear to have been 38 // folded in already. 39 if (recurseInGlobal) 40 globalOp.walk([&](fir::AddrOfOp op) { 41 processAddrOfOp(op, symbolTable, candidates, recurseInGlobal); 42 }); 43 candidates.insert(globalOp); 44 } 45 } 46 47 static void processEmboxOp(fir::EmboxOp emboxOp, mlir::SymbolTable &symbolTable, 48 llvm::DenseSet<fir::GlobalOp> &candidates) { 49 if (auto recTy = mlir::dyn_cast<fir::RecordType>( 50 fir::unwrapRefType(emboxOp.getMemref().getType()))) { 51 if (auto globalOp = symbolTable.lookup<fir::GlobalOp>( 52 fir::NameUniquer::getTypeDescriptorName(recTy.getName()))) { 53 if (!candidates.contains(globalOp)) { 54 globalOp.walk([&](fir::AddrOfOp op) { 55 processAddrOfOp(op, symbolTable, candidates, 56 /*recurseInGlobal=*/true); 57 }); 58 candidates.insert(globalOp); 59 } 60 } 61 } 62 } 63 64 static void 65 prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp, 66 mlir::SymbolTable &symbolTable, 67 llvm::DenseSet<fir::GlobalOp> &candidates) { 68 auto cudaProcAttr{ 69 funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())}; 70 if (cudaProcAttr && cudaProcAttr.getValue() != cuf::ProcAttribute::Host) { 71 funcOp.walk([&](fir::AddrOfOp op) { 72 processAddrOfOp(op, symbolTable, candidates, /*recurseInGlobal=*/false); 73 }); 74 funcOp.walk( 75 [&](fir::EmboxOp op) { processEmboxOp(op, symbolTable, candidates); }); 76 } 77 } 78 79 class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> { 80 public: 81 void runOnOperation() override { 82 mlir::Operation *op = getOperation(); 83 mlir::ModuleOp mod = mlir::dyn_cast<mlir::ModuleOp>(op); 84 if (!mod) 85 return signalPassFailure(); 86 87 llvm::DenseSet<fir::GlobalOp> candidates; 88 mlir::SymbolTable symTable(mod); 89 mod.walk([&](mlir::func::FuncOp funcOp) { 90 prepareImplicitDeviceGlobals(funcOp, symTable, candidates); 91 return mlir::WalkResult::advance(); 92 }); 93 mod.walk([&](cuf::KernelOp kernelOp) { 94 kernelOp.walk([&](fir::AddrOfOp addrOfOp) { 95 processAddrOfOp(addrOfOp, symTable, candidates, 96 /*recurseInGlobal=*/false); 97 }); 98 }); 99 100 // Copying the device global variable into the gpu module 101 mlir::SymbolTable parentSymTable(mod); 102 auto gpuMod = cuf::getOrCreateGPUModule(mod, parentSymTable); 103 if (!gpuMod) 104 return signalPassFailure(); 105 mlir::SymbolTable gpuSymTable(gpuMod); 106 for (auto globalOp : mod.getOps<fir::GlobalOp>()) { 107 if (cuf::isRegisteredDeviceGlobal(globalOp)) 108 candidates.insert(globalOp); 109 } 110 for (auto globalOp : candidates) { 111 auto globalName{globalOp.getSymbol().getValue()}; 112 if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) { 113 break; 114 } 115 gpuSymTable.insert(globalOp->clone()); 116 } 117 } 118 }; 119 } // namespace 120