1 //===-- Utils..cpp ----------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Utils.h" 14 15 #include "Clauses.h" 16 #include <flang/Lower/AbstractConverter.h> 17 #include <flang/Lower/ConvertType.h> 18 #include <flang/Lower/PFTBuilder.h> 19 #include <flang/Optimizer/Builder/FIRBuilder.h> 20 #include <flang/Optimizer/Builder/Todo.h> 21 #include <flang/Parser/parse-tree.h> 22 #include <flang/Parser/tools.h> 23 #include <flang/Semantics/tools.h> 24 #include <llvm/Support/CommandLine.h> 25 26 #include <algorithm> 27 #include <numeric> 28 29 llvm::cl::opt<bool> treatIndexAsSection( 30 "openmp-treat-index-as-section", 31 llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."), 32 llvm::cl::init(true)); 33 34 llvm::cl::opt<bool> enableDelayedPrivatization( 35 "openmp-enable-delayed-privatization", 36 llvm::cl::desc( 37 "Emit `[first]private` variables as clauses on the MLIR ops."), 38 llvm::cl::init(true)); 39 40 llvm::cl::opt<bool> enableDelayedPrivatizationStaging( 41 "openmp-enable-delayed-privatization-staging", 42 llvm::cl::desc("For partially supported constructs, emit `[first]private` " 43 "variables as clauses on the MLIR ops."), 44 llvm::cl::init(false)); 45 46 namespace Fortran { 47 namespace lower { 48 namespace omp { 49 50 int64_t getCollapseValue(const List<Clause> &clauses) { 51 auto iter = llvm::find_if(clauses, [](const Clause &clause) { 52 return clause.id == llvm::omp::Clause::OMPC_collapse; 53 }); 54 if (iter != clauses.end()) { 55 const auto &collapse = std::get<clause::Collapse>(iter->u); 56 return evaluate::ToInt64(collapse.v).value(); 57 } 58 return 1; 59 } 60 61 void genObjectList(const ObjectList &objects, 62 lower::AbstractConverter &converter, 63 llvm::SmallVectorImpl<mlir::Value> &operands) { 64 for (const Object &object : objects) { 65 const semantics::Symbol *sym = object.sym(); 66 assert(sym && "Expected Symbol"); 67 if (mlir::Value variable = converter.getSymbolAddress(*sym)) { 68 operands.push_back(variable); 69 } else if (const auto *details = 70 sym->detailsIf<semantics::HostAssocDetails>()) { 71 operands.push_back(converter.getSymbolAddress(details->symbol())); 72 converter.copySymbolBinding(details->symbol(), *sym); 73 } 74 } 75 } 76 77 mlir::Type getLoopVarType(lower::AbstractConverter &converter, 78 std::size_t loopVarTypeSize) { 79 // OpenMP runtime requires 32-bit or 64-bit loop variables. 80 loopVarTypeSize = loopVarTypeSize * 8; 81 if (loopVarTypeSize < 32) { 82 loopVarTypeSize = 32; 83 } else if (loopVarTypeSize > 64) { 84 loopVarTypeSize = 64; 85 mlir::emitWarning(converter.getCurrentLocation(), 86 "OpenMP loop iteration variable cannot have more than 64 " 87 "bits size and will be narrowed into 64 bits."); 88 } 89 assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) && 90 "OpenMP loop iteration variable size must be transformed into 32-bit " 91 "or 64-bit"); 92 return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize); 93 } 94 95 semantics::Symbol * 96 getIterationVariableSymbol(const lower::pft::Evaluation &eval) { 97 return eval.visit(common::visitors{ 98 [&](const parser::DoConstruct &doLoop) { 99 if (const auto &maybeCtrl = doLoop.GetLoopControl()) { 100 using LoopControl = parser::LoopControl; 101 if (auto *bounds = std::get_if<LoopControl::Bounds>(&maybeCtrl->u)) { 102 static_assert(std::is_same_v<decltype(bounds->name), 103 parser::Scalar<parser::Name>>); 104 return bounds->name.thing.symbol; 105 } 106 } 107 return static_cast<semantics::Symbol *>(nullptr); 108 }, 109 [](auto &&) { return static_cast<semantics::Symbol *>(nullptr); }, 110 }); 111 } 112 113 void gatherFuncAndVarSyms( 114 const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, 115 llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { 116 for (const Object &object : objects) 117 symbolAndClause.emplace_back(clause, *object.sym()); 118 } 119 120 mlir::omp::MapInfoOp 121 createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, 122 mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name, 123 llvm::ArrayRef<mlir::Value> bounds, 124 llvm::ArrayRef<mlir::Value> members, 125 mlir::DenseIntElementsAttr membersIndex, uint64_t mapType, 126 mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, 127 bool partialMap) { 128 if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) { 129 baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr); 130 retTy = baseAddr.getType(); 131 } 132 133 mlir::TypeAttr varType = mlir::TypeAttr::get( 134 llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType()); 135 136 // For types with unknown extents such as <2x?xi32> we discard the incomplete 137 // type info and only retain the base type. The correct dimensions are later 138 // recovered through the bounds info. 139 if (auto seqType = llvm::dyn_cast<fir::SequenceType>(varType.getValue())) 140 if (seqType.hasDynamicExtents()) 141 varType = mlir::TypeAttr::get(seqType.getEleTy()); 142 143 mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>( 144 loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds, 145 builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), 146 builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType), 147 builder.getStringAttr(name), builder.getBoolAttr(partialMap)); 148 149 return op; 150 } 151 152 static int 153 getComponentPlacementInParent(const semantics::Symbol *componentSym) { 154 const auto *derived = componentSym->owner() 155 .derivedTypeSpec() 156 ->typeSymbol() 157 .detailsIf<semantics::DerivedTypeDetails>(); 158 assert(derived && 159 "expected derived type details when processing component symbol"); 160 for (auto [placement, name] : llvm::enumerate(derived->componentNames())) 161 if (name == componentSym->name()) 162 return placement; 163 return -1; 164 } 165 166 static std::optional<Object> 167 getComponentObject(std::optional<Object> object, 168 semantics::SemanticsContext &semaCtx) { 169 if (!object) 170 return std::nullopt; 171 172 auto ref = evaluate::ExtractDataRef(*object.value().ref()); 173 if (!ref) 174 return std::nullopt; 175 176 if (std::holds_alternative<evaluate::Component>(ref->u)) 177 return object; 178 179 auto baseObj = getBaseObject(object.value(), semaCtx); 180 if (!baseObj) 181 return std::nullopt; 182 183 return getComponentObject(baseObj.value(), semaCtx); 184 } 185 186 static void 187 generateMemberPlacementIndices(const Object &object, 188 llvm::SmallVectorImpl<int> &indices, 189 semantics::SemanticsContext &semaCtx) { 190 auto compObj = getComponentObject(object, semaCtx); 191 while (compObj) { 192 indices.push_back(getComponentPlacementInParent(compObj->sym())); 193 compObj = 194 getComponentObject(getBaseObject(compObj.value(), semaCtx), semaCtx); 195 } 196 197 indices = llvm::SmallVector<int>{llvm::reverse(indices)}; 198 } 199 200 void addChildIndexAndMapToParent( 201 const omp::Object &object, 202 std::map<const semantics::Symbol *, 203 llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices, 204 mlir::omp::MapInfoOp &mapOp, semantics::SemanticsContext &semaCtx) { 205 std::optional<evaluate::DataRef> dataRef = ExtractDataRef(object.ref()); 206 assert(dataRef.has_value() && 207 "DataRef could not be extracted during mapping of derived type " 208 "cannot proceed"); 209 const semantics::Symbol *parentSym = &dataRef->GetFirstSymbol(); 210 assert(parentSym && "Could not find parent symbol during lower of " 211 "a component member in OpenMP map clause"); 212 llvm::SmallVector<int> indices; 213 generateMemberPlacementIndices(object, indices, semaCtx); 214 parentMemberIndices[parentSym].push_back({indices, mapOp}); 215 } 216 217 static void calculateShapeAndFillIndices( 218 llvm::SmallVectorImpl<int64_t> &shape, 219 llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData) { 220 shape.push_back(memberPlacementData.size()); 221 size_t largestIndicesSize = 222 std::max_element(memberPlacementData.begin(), memberPlacementData.end(), 223 [](auto a, auto b) { 224 return a.memberPlacementIndices.size() < 225 b.memberPlacementIndices.size(); 226 }) 227 ->memberPlacementIndices.size(); 228 shape.push_back(largestIndicesSize); 229 230 // DenseElementsAttr expects a rectangular shape for the data, so all 231 // index lists have to be of the same length, this emplaces -1 as filler. 232 for (auto &v : memberPlacementData) { 233 if (v.memberPlacementIndices.size() < largestIndicesSize) { 234 auto *prevEnd = v.memberPlacementIndices.end(); 235 v.memberPlacementIndices.resize(largestIndicesSize); 236 std::fill(prevEnd, v.memberPlacementIndices.end(), -1); 237 } 238 } 239 } 240 241 static mlir::DenseIntElementsAttr createDenseElementsAttrFromIndices( 242 llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData, 243 fir::FirOpBuilder &builder) { 244 llvm::SmallVector<int64_t> shape; 245 calculateShapeAndFillIndices(shape, memberPlacementData); 246 247 llvm::SmallVector<int> indicesFlattened = 248 std::accumulate(memberPlacementData.begin(), memberPlacementData.end(), 249 llvm::SmallVector<int>(), 250 [](llvm::SmallVector<int> &x, OmpMapMemberIndicesData y) { 251 x.insert(x.end(), y.memberPlacementIndices.begin(), 252 y.memberPlacementIndices.end()); 253 return x; 254 }); 255 256 return mlir::DenseIntElementsAttr::get( 257 mlir::VectorType::get(shape, 258 mlir::IntegerType::get(builder.getContext(), 32)), 259 indicesFlattened); 260 } 261 262 void insertChildMapInfoIntoParent( 263 lower::AbstractConverter &converter, 264 std::map<const semantics::Symbol *, 265 llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices, 266 llvm::SmallVectorImpl<mlir::Value> &mapOperands, 267 llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms, 268 llvm::SmallVectorImpl<mlir::Type> *mapSymTypes, 269 llvm::SmallVectorImpl<mlir::Location> *mapSymLocs) { 270 for (auto indices : parentMemberIndices) { 271 bool parentExists = false; 272 size_t parentIdx; 273 for (parentIdx = 0; parentIdx < mapSyms.size(); ++parentIdx) { 274 if (mapSyms[parentIdx] == indices.first) { 275 parentExists = true; 276 break; 277 } 278 } 279 280 if (parentExists) { 281 auto mapOp = llvm::cast<mlir::omp::MapInfoOp>( 282 mapOperands[parentIdx].getDefiningOp()); 283 284 // NOTE: To maintain appropriate SSA ordering, we move the parent map 285 // which will now have references to its children after the last 286 // of its members to be generated. This is necessary when a user 287 // has defined a series of parent and children maps where the parent 288 // precedes the children. An alternative, may be to do 289 // delayed generation of map info operations from the clauses and 290 // organize them first before generation. 291 mapOp->moveAfter(indices.second.back().memberMap); 292 293 for (auto memberIndicesData : indices.second) 294 mapOp.getMembersMutable().append( 295 memberIndicesData.memberMap.getResult()); 296 297 mapOp.setMembersIndexAttr(createDenseElementsAttrFromIndices( 298 indices.second, converter.getFirOpBuilder())); 299 } else { 300 // NOTE: We take the map type of the first child, this may not 301 // be the correct thing to do, however, we shall see. For the moment 302 // it allows this to work with enter and exit without causing MLIR 303 // verification issues. The more appropriate thing may be to take 304 // the "main" map type clause from the directive being used. 305 uint64_t mapType = indices.second[0].memberMap.getMapType().value_or(0); 306 307 // create parent to emplace and bind members 308 mlir::Value origSymbol = converter.getSymbolAddress(*indices.first); 309 310 llvm::SmallVector<mlir::Value> members; 311 for (OmpMapMemberIndicesData memberIndicesData : indices.second) 312 members.push_back((mlir::Value)memberIndicesData.memberMap); 313 314 mlir::Value mapOp = createMapInfoOp( 315 converter.getFirOpBuilder(), origSymbol.getLoc(), origSymbol, 316 /*varPtrPtr=*/mlir::Value(), indices.first->name().ToString(), 317 /*bounds=*/{}, members, 318 createDenseElementsAttrFromIndices(indices.second, 319 converter.getFirOpBuilder()), 320 mapType, mlir::omp::VariableCaptureKind::ByRef, origSymbol.getType(), 321 /*partialMap=*/true); 322 323 mapOperands.push_back(mapOp); 324 mapSyms.push_back(indices.first); 325 326 if (mapSymTypes) 327 mapSymTypes->push_back(mapOp.getType()); 328 if (mapSymLocs) 329 mapSymLocs->push_back(mapOp.getLoc()); 330 } 331 } 332 } 333 334 semantics::Symbol *getOmpObjectSymbol(const parser::OmpObject &ompObject) { 335 semantics::Symbol *sym = nullptr; 336 Fortran::common::visit( 337 common::visitors{ 338 [&](const parser::Designator &designator) { 339 if (auto *arrayEle = 340 parser::Unwrap<parser::ArrayElement>(designator)) { 341 // Use getLastName to retrieve the arrays symbol, this will 342 // provide the farthest right symbol (the last) in a designator, 343 // i.e. providing something like the following: 344 // "dtype1%dtype2%array[2:10]", will result in "array" 345 sym = GetLastName(arrayEle->base).symbol; 346 } else if (auto *structComp = 347 parser::Unwrap<parser::StructureComponent>( 348 designator)) { 349 sym = structComp->component.symbol; 350 } else if (const parser::Name *name = 351 semantics::getDesignatorNameIfDataRef(designator)) { 352 sym = name->symbol; 353 } 354 }, 355 [&](const parser::Name &name) { sym = name.symbol; }}, 356 ompObject.u); 357 return sym; 358 } 359 360 void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp, 361 mlir::Location loc) { 362 using Lastprivate = omp::clause::Lastprivate; 363 auto &maybeMod = 364 std::get<std::optional<Lastprivate::LastprivateModifier>>(lastp.t); 365 if (maybeMod) { 366 assert(*maybeMod == Lastprivate::LastprivateModifier::Conditional && 367 "Unexpected lastprivate modifier"); 368 TODO(loc, "lastprivate clause with CONDITIONAL modifier"); 369 } 370 } 371 372 } // namespace omp 373 } // namespace lower 374 } // namespace Fortran 375