//===- MapsForPrivatizedSymbols.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
/// \file
/// An OpenMP dialect related pass for FIR/HLFIR which creates MapInfoOp
/// instances for certain privatized symbols.
/// For example, if an allocatable variable is used in a private clause attached
/// to an omp.target op, then the allocatable variable's descriptor will be
/// needed on the device (e.g. GPU). This descriptor needs to be separately
/// mapped onto the device. This pass creates the necessary omp.map.info ops for
/// this.
//===----------------------------------------------------------------------===//
// TODO:
// 1. Before adding omp.map.info, check if we already have an omp.map.info for
// the variable in question.
// 2. Generalize this for more than just omp.target ops.
23 //===----------------------------------------------------------------------===// 24 25 #include "flang/Optimizer/Builder/FIRBuilder.h" 26 #include "flang/Optimizer/Dialect/FIRType.h" 27 #include "flang/Optimizer/Dialect/Support/KindMapping.h" 28 #include "flang/Optimizer/HLFIR/HLFIROps.h" 29 #include "flang/Optimizer/OpenMP/Passes.h" 30 31 #include "mlir/Dialect/Func/IR/FuncOps.h" 32 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 33 #include "mlir/IR/BuiltinAttributes.h" 34 #include "mlir/IR/SymbolTable.h" 35 #include "mlir/Pass/Pass.h" 36 #include "llvm/Frontend/OpenMP/OMPConstants.h" 37 #include "llvm/Support/Debug.h" 38 #include <type_traits> 39 40 #define DEBUG_TYPE "omp-maps-for-privatized-symbols" 41 42 namespace flangomp { 43 #define GEN_PASS_DEF_MAPSFORPRIVATIZEDSYMBOLSPASS 44 #include "flang/Optimizer/OpenMP/Passes.h.inc" 45 } // namespace flangomp 46 using namespace mlir; 47 namespace { 48 class MapsForPrivatizedSymbolsPass 49 : public flangomp::impl::MapsForPrivatizedSymbolsPassBase< 50 MapsForPrivatizedSymbolsPass> { 51 52 omp::MapInfoOp createMapInfo(Location loc, Value var, 53 fir::FirOpBuilder &builder) { 54 uint64_t mapTypeTo = static_cast< 55 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( 56 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); 57 Operation *definingOp = var.getDefiningOp(); 58 auto declOp = llvm::dyn_cast_or_null<hlfir::DeclareOp>(definingOp); 59 assert(declOp && 60 "Expected defining Op of privatized var to be hlfir.declare"); 61 62 // We want the first result of the hlfir.declare op because our goal 63 // is to map the descriptor (fir.box or fir.boxchar) and the first 64 // result for hlfir.declare is the descriptor if a the symbol being 65 // decalred needs a descriptor. 66 Value varPtr = declOp.getBase(); 67 68 // If we do not have a reference to descritor, but the descriptor itself 69 // then we need to store that on the stack so that we can map the 70 // address of the descriptor. 
71 if (mlir::isa<fir::BaseBoxType>(varPtr.getType()) || 72 mlir::isa<fir::BoxCharType>(varPtr.getType())) { 73 OpBuilder::InsertPoint savedInsPoint = builder.saveInsertionPoint(); 74 mlir::Block *allocaBlock = builder.getAllocaBlock(); 75 assert(allocaBlock && "No allocablock found for a funcOp"); 76 builder.setInsertionPointToStart(allocaBlock); 77 auto alloca = builder.create<fir::AllocaOp>(loc, varPtr.getType()); 78 builder.restoreInsertionPoint(savedInsPoint); 79 builder.create<fir::StoreOp>(loc, varPtr, alloca); 80 varPtr = alloca; 81 } 82 return builder.create<omp::MapInfoOp>( 83 loc, varPtr.getType(), varPtr, 84 TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType()) 85 .getElementType()), 86 /*varPtrPtr=*/Value{}, 87 /*members=*/SmallVector<Value>{}, 88 /*member_index=*/mlir::ArrayAttr{}, 89 /*bounds=*/ValueRange{}, 90 builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false), 91 mapTypeTo), 92 builder.getAttr<omp::VariableCaptureKindAttr>( 93 omp::VariableCaptureKind::ByRef), 94 StringAttr(), builder.getBoolAttr(false)); 95 } 96 void addMapInfoOp(omp::TargetOp targetOp, omp::MapInfoOp mapInfoOp) { 97 auto argIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp); 98 unsigned insertIndex = 99 argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs(); 100 targetOp.getMapVarsMutable().append(ValueRange{mapInfoOp}); 101 targetOp.getRegion().insertArgument(insertIndex, mapInfoOp.getType(), 102 mapInfoOp.getLoc()); 103 } 104 void addMapInfoOps(omp::TargetOp targetOp, 105 llvm::SmallVectorImpl<omp::MapInfoOp> &mapInfoOps) { 106 for (auto mapInfoOp : mapInfoOps) 107 addMapInfoOp(targetOp, mapInfoOp); 108 } 109 void runOnOperation() override { 110 ModuleOp module = getOperation()->getParentOfType<ModuleOp>(); 111 fir::KindMapping kindMap = fir::getKindMapping(module); 112 fir::FirOpBuilder builder{module, std::move(kindMap)}; 113 llvm::DenseMap<Operation *, llvm::SmallVector<omp::MapInfoOp, 4>> 114 mapInfoOpsForTarget; 115 116 
getOperation()->walk([&](omp::TargetOp targetOp) { 117 if (targetOp.getPrivateVars().empty()) 118 return; 119 OperandRange privVars = targetOp.getPrivateVars(); 120 llvm::SmallVector<int64_t> privVarMapIdx; 121 122 std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms(); 123 SmallVector<omp::MapInfoOp, 4> mapInfoOps; 124 for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) { 125 126 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym); 127 omp::PrivateClauseOp privatizer = 128 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>( 129 targetOp, privatizerName); 130 if (!privatizer.needsMap()) { 131 privVarMapIdx.push_back(-1); 132 continue; 133 } 134 135 privVarMapIdx.push_back(targetOp.getMapVars().size() + 136 mapInfoOps.size()); 137 138 builder.setInsertionPoint(targetOp); 139 Location loc = targetOp.getLoc(); 140 omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder); 141 mapInfoOps.push_back(mapInfoOp); 142 143 LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n"); 144 LLVM_DEBUG(mapInfoOp.dump()); 145 } 146 if (!mapInfoOps.empty()) { 147 mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps}); 148 targetOp.setPrivateMapsAttr( 149 mlir::DenseI64ArrayAttr::get(targetOp.getContext(), privVarMapIdx)); 150 } 151 }); 152 if (!mapInfoOpsForTarget.empty()) { 153 for (auto &[targetOp, mapInfoOps] : mapInfoOpsForTarget) { 154 addMapInfoOps(static_cast<omp::TargetOp>(targetOp), mapInfoOps); 155 } 156 } 157 } 158 }; 159 } // namespace 160