xref: /llvm-project/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp (revision f9734b9df15bc1eea84ef00973c2e5560e70c27d)
1 //===- MapsForPrivatizedSymbols.cpp ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// An OpenMP dialect related pass for FIR/HLFIR which creates MapInfoOp
12 /// instances for certain privatized symbols.
/// For example, if an allocatable variable is used in a private clause attached
/// to an omp.target op, then the allocatable variable's descriptor will be
15 /// needed on the device (e.g. GPU). This descriptor needs to be separately
16 /// mapped onto the device. This pass creates the necessary omp.map.info ops for
17 /// this.
18 //===----------------------------------------------------------------------===//
19 // TODO:
20 // 1. Before adding omp.map.info, check if we already have an omp.map.info for
21 // the variable in question.
22 // 2. Generalize this for more than just omp.target ops.
23 //===----------------------------------------------------------------------===//
24 
25 #include "flang/Optimizer/Builder/FIRBuilder.h"
26 #include "flang/Optimizer/Dialect/FIRType.h"
27 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
28 #include "flang/Optimizer/HLFIR/HLFIROps.h"
29 #include "flang/Optimizer/OpenMP/Passes.h"
30 
31 #include "mlir/Dialect/Func/IR/FuncOps.h"
32 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
33 #include "mlir/IR/BuiltinAttributes.h"
34 #include "mlir/IR/SymbolTable.h"
35 #include "mlir/Pass/Pass.h"
36 #include "llvm/Frontend/OpenMP/OMPConstants.h"
37 #include "llvm/Support/Debug.h"
38 #include <type_traits>
39 
40 #define DEBUG_TYPE "omp-maps-for-privatized-symbols"
41 
42 namespace flangomp {
43 #define GEN_PASS_DEF_MAPSFORPRIVATIZEDSYMBOLSPASS
44 #include "flang/Optimizer/OpenMP/Passes.h.inc"
45 } // namespace flangomp
46 using namespace mlir;
47 namespace {
48 class MapsForPrivatizedSymbolsPass
49     : public flangomp::impl::MapsForPrivatizedSymbolsPassBase<
50           MapsForPrivatizedSymbolsPass> {
51 
52   omp::MapInfoOp createMapInfo(Location loc, Value var,
53                                fir::FirOpBuilder &builder) {
54     uint64_t mapTypeTo = static_cast<
55         std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
56         llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
57     Operation *definingOp = var.getDefiningOp();
58     auto declOp = llvm::dyn_cast_or_null<hlfir::DeclareOp>(definingOp);
59     assert(declOp &&
60            "Expected defining Op of privatized var to be hlfir.declare");
61 
62     // We want the first result of the hlfir.declare op because our goal
63     // is to map the descriptor (fir.box or fir.boxchar) and the first
64     // result for hlfir.declare is the descriptor if a the symbol being
65     // decalred needs a descriptor.
66     Value varPtr = declOp.getBase();
67 
68     // If we do not have a reference to descritor, but the descriptor itself
69     // then we need to store that on the stack so that we can map the
70     // address of the descriptor.
71     if (mlir::isa<fir::BaseBoxType>(varPtr.getType()) ||
72         mlir::isa<fir::BoxCharType>(varPtr.getType())) {
73       OpBuilder::InsertPoint savedInsPoint = builder.saveInsertionPoint();
74       mlir::Block *allocaBlock = builder.getAllocaBlock();
75       assert(allocaBlock && "No allocablock  found for a funcOp");
76       builder.setInsertionPointToStart(allocaBlock);
77       auto alloca = builder.create<fir::AllocaOp>(loc, varPtr.getType());
78       builder.restoreInsertionPoint(savedInsPoint);
79       builder.create<fir::StoreOp>(loc, varPtr, alloca);
80       varPtr = alloca;
81     }
82     return builder.create<omp::MapInfoOp>(
83         loc, varPtr.getType(), varPtr,
84         TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType())
85                           .getElementType()),
86         /*varPtrPtr=*/Value{},
87         /*members=*/SmallVector<Value>{},
88         /*member_index=*/mlir::ArrayAttr{},
89         /*bounds=*/ValueRange{},
90         builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
91                                mapTypeTo),
92         builder.getAttr<omp::VariableCaptureKindAttr>(
93             omp::VariableCaptureKind::ByRef),
94         StringAttr(), builder.getBoolAttr(false));
95   }
96   void addMapInfoOp(omp::TargetOp targetOp, omp::MapInfoOp mapInfoOp) {
97     auto argIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
98     unsigned insertIndex =
99         argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs();
100     targetOp.getMapVarsMutable().append(ValueRange{mapInfoOp});
101     targetOp.getRegion().insertArgument(insertIndex, mapInfoOp.getType(),
102                                         mapInfoOp.getLoc());
103   }
104   void addMapInfoOps(omp::TargetOp targetOp,
105                      llvm::SmallVectorImpl<omp::MapInfoOp> &mapInfoOps) {
106     for (auto mapInfoOp : mapInfoOps)
107       addMapInfoOp(targetOp, mapInfoOp);
108   }
109   void runOnOperation() override {
110     ModuleOp module = getOperation()->getParentOfType<ModuleOp>();
111     fir::KindMapping kindMap = fir::getKindMapping(module);
112     fir::FirOpBuilder builder{module, std::move(kindMap)};
113     llvm::DenseMap<Operation *, llvm::SmallVector<omp::MapInfoOp, 4>>
114         mapInfoOpsForTarget;
115 
116     getOperation()->walk([&](omp::TargetOp targetOp) {
117       if (targetOp.getPrivateVars().empty())
118         return;
119       OperandRange privVars = targetOp.getPrivateVars();
120       llvm::SmallVector<int64_t> privVarMapIdx;
121 
122       std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
123       SmallVector<omp::MapInfoOp, 4> mapInfoOps;
124       for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
125 
126         SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
127         omp::PrivateClauseOp privatizer =
128             SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
129                 targetOp, privatizerName);
130         if (!privatizer.needsMap()) {
131           privVarMapIdx.push_back(-1);
132           continue;
133         }
134 
135         privVarMapIdx.push_back(targetOp.getMapVars().size() +
136                                 mapInfoOps.size());
137 
138         builder.setInsertionPoint(targetOp);
139         Location loc = targetOp.getLoc();
140         omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
141         mapInfoOps.push_back(mapInfoOp);
142 
143         LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
144         LLVM_DEBUG(mapInfoOp.dump());
145       }
146       if (!mapInfoOps.empty()) {
147         mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
148         targetOp.setPrivateMapsAttr(
149             mlir::DenseI64ArrayAttr::get(targetOp.getContext(), privVarMapIdx));
150       }
151     });
152     if (!mapInfoOpsForTarget.empty()) {
153       for (auto &[targetOp, mapInfoOps] : mapInfoOpsForTarget) {
154         addMapInfoOps(static_cast<omp::TargetOp>(targetOp), mapInfoOps);
155       }
156     }
157   }
158 };
159 } // namespace
160