1 //===-- Utils..cpp ----------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Utils.h" 14 15 #include "Clauses.h" 16 #include <flang/Lower/AbstractConverter.h> 17 #include <flang/Lower/ConvertType.h> 18 #include <flang/Lower/PFTBuilder.h> 19 #include <flang/Optimizer/Builder/FIRBuilder.h> 20 #include <flang/Parser/parse-tree.h> 21 #include <flang/Parser/tools.h> 22 #include <flang/Semantics/tools.h> 23 #include <llvm/Support/CommandLine.h> 24 25 #include <algorithm> 26 #include <numeric> 27 28 llvm::cl::opt<bool> treatIndexAsSection( 29 "openmp-treat-index-as-section", 30 llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."), 31 llvm::cl::init(true)); 32 33 llvm::cl::opt<bool> enableDelayedPrivatization( 34 "openmp-enable-delayed-privatization", 35 llvm::cl::desc( 36 "Emit `[first]private` variables as clauses on the MLIR ops."), 37 llvm::cl::init(false)); 38 39 llvm::cl::opt<bool> enableDelayedPrivatizationStaging( 40 "openmp-enable-delayed-privatization-staging", 41 llvm::cl::desc("For partially supported constructs, emit `[first]private` " 42 "variables as clauses on the MLIR ops."), 43 llvm::cl::init(false)); 44 45 namespace Fortran { 46 namespace lower { 47 namespace omp { 48 49 int64_t getCollapseValue(const List<Clause> &clauses) { 50 auto iter = llvm::find_if(clauses, [](const Clause &clause) { 51 return clause.id == llvm::omp::Clause::OMPC_collapse; 52 }); 53 if (iter != clauses.end()) { 54 const auto &collapse = std::get<clause::Collapse>(iter->u); 55 return evaluate::ToInt64(collapse.v).value(); 56 } 57 return 1; 58 } 59 60 void genObjectList(const ObjectList &objects, 61 lower::AbstractConverter &converter, 62 llvm::SmallVectorImpl<mlir::Value> &operands) { 63 for (const Object &object : objects) { 64 const semantics::Symbol *sym = object.sym(); 65 assert(sym && "Expected Symbol"); 66 if (mlir::Value variable = converter.getSymbolAddress(*sym)) { 67 operands.push_back(variable); 68 } else if (const auto *details = 69 sym->detailsIf<semantics::HostAssocDetails>()) { 70 operands.push_back(converter.getSymbolAddress(details->symbol())); 71 converter.copySymbolBinding(details->symbol(), *sym); 72 } 73 } 74 } 75 76 mlir::Type getLoopVarType(lower::AbstractConverter &converter, 77 std::size_t loopVarTypeSize) { 78 // OpenMP runtime requires 32-bit or 64-bit loop variables. 79 loopVarTypeSize = loopVarTypeSize * 8; 80 if (loopVarTypeSize < 32) { 81 loopVarTypeSize = 32; 82 } else if (loopVarTypeSize > 64) { 83 loopVarTypeSize = 64; 84 mlir::emitWarning(converter.getCurrentLocation(), 85 "OpenMP loop iteration variable cannot have more than 64 " 86 "bits size and will be narrowed into 64 bits."); 87 } 88 assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) && 89 "OpenMP loop iteration variable size must be transformed into 32-bit " 90 "or 64-bit"); 91 return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize); 92 } 93 94 semantics::Symbol * 95 getIterationVariableSymbol(const lower::pft::Evaluation &eval) { 96 return eval.visit(common::visitors{ 97 [&](const parser::DoConstruct &doLoop) { 98 if (const auto &maybeCtrl = doLoop.GetLoopControl()) { 99 using LoopControl = parser::LoopControl; 100 if (auto *bounds = std::get_if<LoopControl::Bounds>(&maybeCtrl->u)) { 101 static_assert(std::is_same_v<decltype(bounds->name), 102 parser::Scalar<parser::Name>>); 103 return bounds->name.thing.symbol; 104 } 105 } 106 return static_cast<semantics::Symbol *>(nullptr); 107 }, 108 [](auto &&) { return static_cast<semantics::Symbol *>(nullptr); }, 109 }); 110 } 111 112 void gatherFuncAndVarSyms( 113 const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, 114 llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { 115 for (const Object &object : objects) 116 symbolAndClause.emplace_back(clause, *object.sym()); 117 } 118 119 mlir::omp::MapInfoOp 120 createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, 121 mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name, 122 llvm::ArrayRef<mlir::Value> bounds, 123 llvm::ArrayRef<mlir::Value> members, 124 mlir::DenseIntElementsAttr membersIndex, uint64_t mapType, 125 mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, 126 bool partialMap) { 127 if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) { 128 baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr); 129 retTy = baseAddr.getType(); 130 } 131 132 mlir::TypeAttr varType = mlir::TypeAttr::get( 133 llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType()); 134 135 mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>( 136 loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds, 137 builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), 138 builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType), 139 builder.getStringAttr(name), builder.getBoolAttr(partialMap)); 140 141 return op; 142 } 143 144 static int 145 getComponentPlacementInParent(const semantics::Symbol *componentSym) { 146 const auto *derived = componentSym->owner() 147 .derivedTypeSpec() 148 ->typeSymbol() 149 .detailsIf<semantics::DerivedTypeDetails>(); 150 assert(derived && 151 "expected derived type details when processing component symbol"); 152 for (auto [placement, name] : llvm::enumerate(derived->componentNames())) 153 if (name == componentSym->name()) 154 return placement; 155 return -1; 156 } 157 158 static std::optional<Object> 159 getComponentObject(std::optional<Object> object, 160 semantics::SemanticsContext &semaCtx) { 161 if (!object) 162 return std::nullopt; 163 164 auto ref = evaluate::ExtractDataRef(*object.value().ref()); 165 if (!ref) 166 return std::nullopt; 167 168 if (std::holds_alternative<evaluate::Component>(ref->u)) 169 return object; 170 171 auto baseObj = getBaseObject(object.value(), semaCtx); 172 if (!baseObj) 173 return std::nullopt; 174 175 return getComponentObject(baseObj.value(), semaCtx); 176 } 177 178 static void 179 generateMemberPlacementIndices(const Object &object, 180 llvm::SmallVectorImpl<int> &indices, 181 semantics::SemanticsContext &semaCtx) { 182 auto compObj = getComponentObject(object, semaCtx); 183 while (compObj) { 184 indices.push_back(getComponentPlacementInParent(compObj->sym())); 185 compObj = 186 getComponentObject(getBaseObject(compObj.value(), semaCtx), semaCtx); 187 } 188 189 indices = llvm::SmallVector<int>{llvm::reverse(indices)}; 190 } 191 192 void addChildIndexAndMapToParent( 193 const omp::Object &object, 194 std::map<const semantics::Symbol *, 195 llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices, 196 mlir::omp::MapInfoOp &mapOp, semantics::SemanticsContext &semaCtx) { 197 std::optional<evaluate::DataRef> dataRef = ExtractDataRef(object.ref()); 198 assert(dataRef.has_value() && 199 "DataRef could not be extracted during mapping of derived type " 200 "cannot proceed"); 201 const semantics::Symbol *parentSym = &dataRef->GetFirstSymbol(); 202 assert(parentSym && "Could not find parent symbol during lower of " 203 "a component member in OpenMP map clause"); 204 llvm::SmallVector<int> indices; 205 generateMemberPlacementIndices(object, indices, semaCtx); 206 parentMemberIndices[parentSym].push_back({indices, mapOp}); 207 } 208 209 static void calculateShapeAndFillIndices( 210 llvm::SmallVectorImpl<int64_t> &shape, 211 llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData) { 212 shape.push_back(memberPlacementData.size()); 213 size_t largestIndicesSize = 214 std::max_element(memberPlacementData.begin(), memberPlacementData.end(), 215 [](auto a, auto b) { 216 return a.memberPlacementIndices.size() < 217 b.memberPlacementIndices.size(); 218 }) 219 ->memberPlacementIndices.size(); 220 shape.push_back(largestIndicesSize); 221 222 // DenseElementsAttr expects a rectangular shape for the data, so all 223 // index lists have to be of the same length, this emplaces -1 as filler. 224 for (auto &v : memberPlacementData) { 225 if (v.memberPlacementIndices.size() < largestIndicesSize) { 226 auto *prevEnd = v.memberPlacementIndices.end(); 227 v.memberPlacementIndices.resize(largestIndicesSize); 228 std::fill(prevEnd, v.memberPlacementIndices.end(), -1); 229 } 230 } 231 } 232 233 static mlir::DenseIntElementsAttr createDenseElementsAttrFromIndices( 234 llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData, 235 fir::FirOpBuilder &builder) { 236 llvm::SmallVector<int64_t> shape; 237 calculateShapeAndFillIndices(shape, memberPlacementData); 238 239 llvm::SmallVector<int> indicesFlattened = 240 std::accumulate(memberPlacementData.begin(), memberPlacementData.end(), 241 llvm::SmallVector<int>(), 242 [](llvm::SmallVector<int> &x, OmpMapMemberIndicesData y) { 243 x.insert(x.end(), y.memberPlacementIndices.begin(), 244 y.memberPlacementIndices.end()); 245 return x; 246 }); 247 248 return mlir::DenseIntElementsAttr::get( 249 mlir::VectorType::get(shape, 250 mlir::IntegerType::get(builder.getContext(), 32)), 251 indicesFlattened); 252 } 253 254 void insertChildMapInfoIntoParent( 255 lower::AbstractConverter &converter, 256 std::map<const semantics::Symbol *, 257 llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices, 258 llvm::SmallVectorImpl<mlir::Value> &mapOperands, 259 llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms, 260 llvm::SmallVectorImpl<mlir::Type> *mapSymTypes, 261 llvm::SmallVectorImpl<mlir::Location> *mapSymLocs) { 262 for (auto indices : parentMemberIndices) { 263 bool parentExists = false; 264 size_t parentIdx; 265 for (parentIdx = 0; parentIdx < mapSyms.size(); ++parentIdx) { 266 if (mapSyms[parentIdx] == indices.first) { 267 parentExists = true; 268 break; 269 } 270 } 271 272 if (parentExists) { 273 auto mapOp = llvm::cast<mlir::omp::MapInfoOp>( 274 mapOperands[parentIdx].getDefiningOp()); 275 276 // NOTE: To maintain appropriate SSA ordering, we move the parent map 277 // which will now have references to its children after the last 278 // of its members to be generated. This is necessary when a user 279 // has defined a series of parent and children maps where the parent 280 // precedes the children. An alternative, may be to do 281 // delayed generation of map info operations from the clauses and 282 // organize them first before generation. 283 mapOp->moveAfter(indices.second.back().memberMap); 284 285 for (auto memberIndicesData : indices.second) 286 mapOp.getMembersMutable().append( 287 memberIndicesData.memberMap.getResult()); 288 289 mapOp.setMembersIndexAttr(createDenseElementsAttrFromIndices( 290 indices.second, converter.getFirOpBuilder())); 291 } else { 292 // NOTE: We take the map type of the first child, this may not 293 // be the correct thing to do, however, we shall see. For the moment 294 // it allows this to work with enter and exit without causing MLIR 295 // verification issues. The more appropriate thing may be to take 296 // the "main" map type clause from the directive being used. 297 uint64_t mapType = indices.second[0].memberMap.getMapType().value_or(0); 298 299 // create parent to emplace and bind members 300 mlir::Value origSymbol = converter.getSymbolAddress(*indices.first); 301 302 llvm::SmallVector<mlir::Value> members; 303 for (OmpMapMemberIndicesData memberIndicesData : indices.second) 304 members.push_back((mlir::Value)memberIndicesData.memberMap); 305 306 mlir::Value mapOp = createMapInfoOp( 307 converter.getFirOpBuilder(), origSymbol.getLoc(), origSymbol, 308 /*varPtrPtr=*/mlir::Value(), indices.first->name().ToString(), 309 /*bounds=*/{}, members, 310 createDenseElementsAttrFromIndices(indices.second, 311 converter.getFirOpBuilder()), 312 mapType, mlir::omp::VariableCaptureKind::ByRef, origSymbol.getType(), 313 /*partialMap=*/true); 314 315 mapOperands.push_back(mapOp); 316 mapSyms.push_back(indices.first); 317 318 if (mapSymTypes) 319 mapSymTypes->push_back(mapOp.getType()); 320 if (mapSymLocs) 321 mapSymLocs->push_back(mapOp.getLoc()); 322 } 323 } 324 } 325 326 semantics::Symbol *getOmpObjectSymbol(const parser::OmpObject &ompObject) { 327 semantics::Symbol *sym = nullptr; 328 Fortran::common::visit( 329 common::visitors{ 330 [&](const parser::Designator &designator) { 331 if (auto *arrayEle = 332 parser::Unwrap<parser::ArrayElement>(designator)) { 333 // Use getLastName to retrieve the arrays symbol, this will 334 // provide the farthest right symbol (the last) in a designator, 335 // i.e. providing something like the following: 336 // "dtype1%dtype2%array[2:10]", will result in "array" 337 sym = GetLastName(arrayEle->base).symbol; 338 } else if (auto *structComp = 339 parser::Unwrap<parser::StructureComponent>( 340 designator)) { 341 sym = structComp->component.symbol; 342 } else if (const parser::Name *name = 343 semantics::getDesignatorNameIfDataRef(designator)) { 344 sym = name->symbol; 345 } 346 }, 347 [&](const parser::Name &name) { sym = name.symbol; }}, 348 ompObject.u); 349 return sym; 350 } 351 352 } // namespace omp 353 } // namespace lower 354 } // namespace Fortran 355