1 //===-- Utils..cpp ----------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Utils.h" 14 15 #include "Clauses.h" 16 #include <flang/Lower/AbstractConverter.h> 17 #include <flang/Lower/ConvertType.h> 18 #include <flang/Lower/PFTBuilder.h> 19 #include <flang/Optimizer/Builder/FIRBuilder.h> 20 #include <flang/Parser/parse-tree.h> 21 #include <flang/Parser/tools.h> 22 #include <flang/Semantics/tools.h> 23 #include <llvm/Support/CommandLine.h> 24 25 #include <algorithm> 26 #include <numeric> 27 28 llvm::cl::opt<bool> treatIndexAsSection( 29 "openmp-treat-index-as-section", 30 llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."), 31 llvm::cl::init(true)); 32 33 llvm::cl::opt<bool> enableDelayedPrivatization( 34 "openmp-enable-delayed-privatization", 35 llvm::cl::desc( 36 "Emit `[first]private` variables as clauses on the MLIR ops."), 37 llvm::cl::init(false)); 38 39 namespace Fortran { 40 namespace lower { 41 namespace omp { 42 43 int64_t getCollapseValue(const List<Clause> &clauses) { 44 auto iter = llvm::find_if(clauses, [](const Clause &clause) { 45 return clause.id == llvm::omp::Clause::OMPC_collapse; 46 }); 47 if (iter != clauses.end()) { 48 const auto &collapse = std::get<clause::Collapse>(iter->u); 49 return evaluate::ToInt64(collapse.v).value(); 50 } 51 return 1; 52 } 53 54 uint32_t getOpenMPVersion(mlir::ModuleOp mod) { 55 if (mlir::Attribute verAttr = mod->getAttr("omp.version")) 56 return llvm::cast<mlir::omp::VersionAttr>(verAttr).getVersion(); 57 llvm_unreachable("Expecting OpenMP version attribute in module"); 58 } 59 60 void genObjectList(const ObjectList &objects, 61 Fortran::lower::AbstractConverter &converter, 62 llvm::SmallVectorImpl<mlir::Value> &operands) { 63 for (const Object &object : objects) { 64 const Fortran::semantics::Symbol *sym = object.id(); 65 assert(sym && "Expected Symbol"); 66 if (mlir::Value variable = converter.getSymbolAddress(*sym)) { 67 operands.push_back(variable); 68 } else if (const auto *details = 69 sym->detailsIf<Fortran::semantics::HostAssocDetails>()) { 70 operands.push_back(converter.getSymbolAddress(details->symbol())); 71 converter.copySymbolBinding(details->symbol(), *sym); 72 } 73 } 74 } 75 76 mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter, 77 std::size_t loopVarTypeSize) { 78 // OpenMP runtime requires 32-bit or 64-bit loop variables. 79 loopVarTypeSize = loopVarTypeSize * 8; 80 if (loopVarTypeSize < 32) { 81 loopVarTypeSize = 32; 82 } else if (loopVarTypeSize > 64) { 83 loopVarTypeSize = 64; 84 mlir::emitWarning(converter.getCurrentLocation(), 85 "OpenMP loop iteration variable cannot have more than 64 " 86 "bits size and will be narrowed into 64 bits."); 87 } 88 assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) && 89 "OpenMP loop iteration variable size must be transformed into 32-bit " 90 "or 64-bit"); 91 return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize); 92 } 93 94 Fortran::semantics::Symbol * 95 getIterationVariableSymbol(const Fortran::lower::pft::Evaluation &eval) { 96 return eval.visit(Fortran::common::visitors{ 97 [&](const Fortran::parser::DoConstruct &doLoop) { 98 if (const auto &maybeCtrl = doLoop.GetLoopControl()) { 99 using LoopControl = Fortran::parser::LoopControl; 100 if (auto *bounds = std::get_if<LoopControl::Bounds>(&maybeCtrl->u)) { 101 static_assert( 102 std::is_same_v<decltype(bounds->name), 103 Fortran::parser::Scalar<Fortran::parser::Name>>); 104 return bounds->name.thing.symbol; 105 } 106 } 107 return static_cast<Fortran::semantics::Symbol *>(nullptr); 108 }, 109 [](auto &&) { 110 return static_cast<Fortran::semantics::Symbol *>(nullptr); 111 }, 112 }); 113 } 114 115 void gatherFuncAndVarSyms( 116 const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, 117 llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { 118 for (const Object &object : objects) 119 symbolAndClause.emplace_back(clause, *object.id()); 120 } 121 122 mlir::omp::MapInfoOp 123 createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, 124 mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name, 125 llvm::ArrayRef<mlir::Value> bounds, 126 llvm::ArrayRef<mlir::Value> members, 127 mlir::DenseIntElementsAttr membersIndex, uint64_t mapType, 128 mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, 129 bool partialMap) { 130 if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) { 131 baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr); 132 retTy = baseAddr.getType(); 133 } 134 135 mlir::TypeAttr varType = mlir::TypeAttr::get( 136 llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType()); 137 138 mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>( 139 loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds, 140 builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), 141 builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType), 142 builder.getStringAttr(name), builder.getBoolAttr(partialMap)); 143 144 return op; 145 } 146 147 static int 148 getComponentPlacementInParent(const Fortran::semantics::Symbol *componentSym) { 149 const auto *derived = 150 componentSym->owner() 151 .derivedTypeSpec() 152 ->typeSymbol() 153 .detailsIf<Fortran::semantics::DerivedTypeDetails>(); 154 assert(derived && 155 "expected derived type details when processing component symbol"); 156 for (auto [placement, name] : llvm::enumerate(derived->componentNames())) 157 if (name == componentSym->name()) 158 return placement; 159 return -1; 160 } 161 162 static std::optional<Object> 163 getComponentObject(std::optional<Object> object, 164 Fortran::semantics::SemanticsContext &semaCtx) { 165 if (!object) 166 return std::nullopt; 167 168 auto ref = evaluate::ExtractDataRef(*object.value().ref()); 169 if (!ref) 170 return std::nullopt; 171 172 if (std::holds_alternative<evaluate::Component>(ref->u)) 173 return object; 174 175 auto baseObj = getBaseObject(object.value(), semaCtx); 176 if (!baseObj) 177 return std::nullopt; 178 179 return getComponentObject(baseObj.value(), semaCtx); 180 } 181 182 static void 183 generateMemberPlacementIndices(const Object &object, 184 llvm::SmallVectorImpl<int> &indices, 185 Fortran::semantics::SemanticsContext &semaCtx) { 186 auto compObj = getComponentObject(object, semaCtx); 187 while (compObj) { 188 indices.push_back(getComponentPlacementInParent(compObj->id())); 189 compObj = 190 getComponentObject(getBaseObject(compObj.value(), semaCtx), semaCtx); 191 } 192 193 indices = llvm::SmallVector<int>{llvm::reverse(indices)}; 194 } 195 196 void addChildIndexAndMapToParent( 197 const omp::Object &object, 198 std::map<const Fortran::semantics::Symbol *, 199 llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices, 200 mlir::omp::MapInfoOp &mapOp, 201 Fortran::semantics::SemanticsContext &semaCtx) { 202 std::optional<Fortran::evaluate::DataRef> dataRef = 203 ExtractDataRef(object.designator); 204 assert(dataRef.has_value() && 205 "DataRef could not be extracted during mapping of derived type " 206 "cannot proceed"); 207 const Fortran::semantics::Symbol *parentSym = &dataRef->GetFirstSymbol(); 208 assert(parentSym && "Could not find parent symbol during lower of " 209 "a component member in OpenMP map clause"); 210 llvm::SmallVector<int> indices; 211 generateMemberPlacementIndices(object, indices, semaCtx); 212 parentMemberIndices[parentSym].push_back({indices, mapOp}); 213 } 214 215 static void calculateShapeAndFillIndices( 216 llvm::SmallVectorImpl<int64_t> &shape, 217 llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData) { 218 shape.push_back(memberPlacementData.size()); 219 size_t largestIndicesSize = 220 std::max_element(memberPlacementData.begin(), memberPlacementData.end(), 221 [](auto a, auto b) { 222 return a.memberPlacementIndices.size() < 223 b.memberPlacementIndices.size(); 224 }) 225 ->memberPlacementIndices.size(); 226 shape.push_back(largestIndicesSize); 227 228 // DenseElementsAttr expects a rectangular shape for the data, so all 229 // index lists have to be of the same length, this emplaces -1 as filler. 230 for (auto &v : memberPlacementData) { 231 if (v.memberPlacementIndices.size() < largestIndicesSize) { 232 auto *prevEnd = v.memberPlacementIndices.end(); 233 v.memberPlacementIndices.resize(largestIndicesSize); 234 std::fill(prevEnd, v.memberPlacementIndices.end(), -1); 235 } 236 } 237 } 238 239 static mlir::DenseIntElementsAttr createDenseElementsAttrFromIndices( 240 llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData, 241 fir::FirOpBuilder &builder) { 242 llvm::SmallVector<int64_t> shape; 243 calculateShapeAndFillIndices(shape, memberPlacementData); 244 245 llvm::SmallVector<int> indicesFlattened = std::accumulate( 246 memberPlacementData.begin(), memberPlacementData.end(), 247 llvm::SmallVector<int>(), 248 [](llvm::SmallVector<int> &x, OmpMapMemberIndicesData y) { 249 x.insert(x.end(), y.memberPlacementIndices.begin(), 250 y.memberPlacementIndices.end()); 251 return x; 252 }); 253 254 return mlir::DenseIntElementsAttr::get( 255 mlir::VectorType::get(shape, 256 mlir::IntegerType::get(builder.getContext(), 32)), 257 indicesFlattened); 258 } 259 260 void insertChildMapInfoIntoParent( 261 Fortran::lower::AbstractConverter &converter, 262 std::map<const Fortran::semantics::Symbol *, 263 llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices, 264 llvm::SmallVectorImpl<mlir::Value> &mapOperands, 265 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &mapSyms, 266 llvm::SmallVectorImpl<mlir::Type> *mapSymTypes, 267 llvm::SmallVectorImpl<mlir::Location> *mapSymLocs) { 268 for (auto indices : parentMemberIndices) { 269 bool parentExists = false; 270 size_t parentIdx; 271 for (parentIdx = 0; parentIdx < mapSyms.size(); ++parentIdx) { 272 if (mapSyms[parentIdx] == indices.first) { 273 parentExists = true; 274 break; 275 } 276 } 277 278 if (parentExists) { 279 auto mapOp = llvm::cast<mlir::omp::MapInfoOp>( 280 mapOperands[parentIdx].getDefiningOp()); 281 282 // NOTE: To maintain appropriate SSA ordering, we move the parent map 283 // which will now have references to its children after the last 284 // of its members to be generated. This is necessary when a user 285 // has defined a series of parent and children maps where the parent 286 // precedes the children. An alternative, may be to do 287 // delayed generation of map info operations from the clauses and 288 // organize them first before generation. 289 mapOp->moveAfter(indices.second.back().memberMap); 290 291 for (auto memberIndicesData : indices.second) 292 mapOp.getMembersMutable().append( 293 memberIndicesData.memberMap.getResult()); 294 295 mapOp.setMembersIndexAttr(createDenseElementsAttrFromIndices( 296 indices.second, converter.getFirOpBuilder())); 297 } else { 298 // NOTE: We take the map type of the first child, this may not 299 // be the correct thing to do, however, we shall see. For the moment 300 // it allows this to work with enter and exit without causing MLIR 301 // verification issues. The more appropriate thing may be to take 302 // the "main" map type clause from the directive being used. 303 uint64_t mapType = indices.second[0].memberMap.getMapType().value_or(0); 304 305 // create parent to emplace and bind members 306 mlir::Value origSymbol = converter.getSymbolAddress(*indices.first); 307 308 llvm::SmallVector<mlir::Value> members; 309 for (OmpMapMemberIndicesData memberIndicesData : indices.second) 310 members.push_back((mlir::Value)memberIndicesData.memberMap); 311 312 mlir::Value mapOp = createMapInfoOp( 313 converter.getFirOpBuilder(), origSymbol.getLoc(), origSymbol, 314 /*varPtrPtr=*/mlir::Value(), indices.first->name().ToString(), 315 /*bounds=*/{}, members, 316 createDenseElementsAttrFromIndices(indices.second, 317 converter.getFirOpBuilder()), 318 mapType, mlir::omp::VariableCaptureKind::ByRef, origSymbol.getType(), 319 /*partialMap=*/true); 320 321 mapOperands.push_back(mapOp); 322 mapSyms.push_back(indices.first); 323 324 if (mapSymTypes) 325 mapSymTypes->push_back(mapOp.getType()); 326 if (mapSymLocs) 327 mapSymLocs->push_back(mapOp.getLoc()); 328 } 329 } 330 } 331 332 Fortran::semantics::Symbol * 333 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { 334 Fortran::semantics::Symbol *sym = nullptr; 335 std::visit( 336 Fortran::common::visitors{ 337 [&](const Fortran::parser::Designator &designator) { 338 if (auto *arrayEle = 339 Fortran::parser::Unwrap<Fortran::parser::ArrayElement>( 340 designator)) { 341 // Use getLastName to retrieve the arrays symbol, this will 342 // provide the farthest right symbol (the last) in a designator, 343 // i.e. providing something like the following: 344 // "dtype1%dtype2%array[2:10]", will result in "array" 345 sym = GetLastName(arrayEle->base).symbol; 346 } else if (auto *structComp = Fortran::parser::Unwrap< 347 Fortran::parser::StructureComponent>(designator)) { 348 sym = structComp->component.symbol; 349 } else if (const Fortran::parser::Name *name = 350 Fortran::semantics::getDesignatorNameIfDataRef( 351 designator)) { 352 sym = name->symbol; 353 } 354 }, 355 [&](const Fortran::parser::Name &name) { sym = name.symbol; }}, 356 ompObject.u); 357 return sym; 358 } 359 360 } // namespace omp 361 } // namespace lower 362 } // namespace Fortran 363