1 //===-- OpenACC.cpp -- OpenACC directive lowering -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Lower/OpenACC.h" 14 15 #include "flang/Common/idioms.h" 16 #include "flang/Lower/Bridge.h" 17 #include "flang/Lower/ConvertType.h" 18 #include "flang/Lower/DirectivesCommon.h" 19 #include "flang/Lower/Mangler.h" 20 #include "flang/Lower/PFTBuilder.h" 21 #include "flang/Lower/StatementContext.h" 22 #include "flang/Lower/Support/Utils.h" 23 #include "flang/Optimizer/Builder/BoxValue.h" 24 #include "flang/Optimizer/Builder/Complex.h" 25 #include "flang/Optimizer/Builder/FIRBuilder.h" 26 #include "flang/Optimizer/Builder/HLFIRTools.h" 27 #include "flang/Optimizer/Builder/IntrinsicCall.h" 28 #include "flang/Optimizer/Builder/Todo.h" 29 #include "flang/Parser/parse-tree-visitor.h" 30 #include "flang/Parser/parse-tree.h" 31 #include "flang/Semantics/expression.h" 32 #include "flang/Semantics/scope.h" 33 #include "flang/Semantics/tools.h" 34 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 35 #include "llvm/Frontend/OpenACC/ACC.h.inc" 36 #include "llvm/Support/Debug.h" 37 38 #define DEBUG_TYPE "flang-lower-openacc" 39 40 // Special value for * passed in device_type or gang clauses. 
41 static constexpr std::int64_t starCst = -1; 42 43 static unsigned routineCounter = 0; 44 static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_"; 45 static constexpr llvm::StringRef accPrivateInitName = "acc.private.init"; 46 static constexpr llvm::StringRef accReductionInitName = "acc.reduction.init"; 47 static constexpr llvm::StringRef accFirDescriptorPostfix = "_desc"; 48 49 static mlir::Location 50 genOperandLocation(Fortran::lower::AbstractConverter &converter, 51 const Fortran::parser::AccObject &accObject) { 52 mlir::Location loc = converter.genUnknownLocation(); 53 Fortran::common::visit( 54 Fortran::common::visitors{ 55 [&](const Fortran::parser::Designator &designator) { 56 loc = converter.genLocation(designator.source); 57 }, 58 [&](const Fortran::parser::Name &name) { 59 loc = converter.genLocation(name.source); 60 }}, 61 accObject.u); 62 return loc; 63 } 64 65 static void addOperands(llvm::SmallVectorImpl<mlir::Value> &operands, 66 llvm::SmallVectorImpl<int32_t> &operandSegments, 67 llvm::ArrayRef<mlir::Value> clauseOperands) { 68 operands.append(clauseOperands.begin(), clauseOperands.end()); 69 operandSegments.push_back(clauseOperands.size()); 70 } 71 72 static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands, 73 llvm::SmallVectorImpl<int32_t> &operandSegments, 74 const mlir::Value &clauseOperand) { 75 if (clauseOperand) { 76 operands.push_back(clauseOperand); 77 operandSegments.push_back(1); 78 } else { 79 operandSegments.push_back(0); 80 } 81 } 82 83 template <typename Op> 84 static Op 85 createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc, 86 mlir::Value baseAddr, std::stringstream &name, 87 mlir::SmallVector<mlir::Value> bounds, bool structured, 88 bool implicit, mlir::acc::DataClause dataClause, 89 mlir::Type retTy, llvm::ArrayRef<mlir::Value> async, 90 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, 91 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes, 92 bool unwrapBoxAddr = false, mlir::Value 
isPresent = {}) { 93 mlir::Value varPtrPtr; 94 // The data clause may apply to either the box reference itself or the 95 // pointer to the data it holds. So use `unwrapBoxAddr` to decide. 96 // When we have a box value - assume it refers to the data inside box. 97 if ((fir::isBoxAddress(baseAddr.getType()) && unwrapBoxAddr) || 98 fir::isa_box_type(baseAddr.getType())) { 99 if (isPresent) { 100 mlir::Type ifRetTy = 101 mlir::cast<fir::BaseBoxType>(fir::unwrapRefType(baseAddr.getType())) 102 .getEleTy(); 103 if (!fir::isa_ref_type(ifRetTy)) 104 ifRetTy = fir::ReferenceType::get(ifRetTy); 105 baseAddr = 106 builder 107 .genIfOp(loc, {ifRetTy}, isPresent, 108 /*withElseRegion=*/true) 109 .genThen([&]() { 110 if (fir::isBoxAddress(baseAddr.getType())) 111 baseAddr = builder.create<fir::LoadOp>(loc, baseAddr); 112 mlir::Value boxAddr = 113 builder.create<fir::BoxAddrOp>(loc, baseAddr); 114 builder.create<fir::ResultOp>(loc, mlir::ValueRange{boxAddr}); 115 }) 116 .genElse([&] { 117 mlir::Value absent = 118 builder.create<fir::AbsentOp>(loc, ifRetTy); 119 builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent}); 120 }) 121 .getResults()[0]; 122 } else { 123 if (fir::isBoxAddress(baseAddr.getType())) 124 baseAddr = builder.create<fir::LoadOp>(loc, baseAddr); 125 baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr); 126 } 127 retTy = baseAddr.getType(); 128 } 129 130 llvm::SmallVector<mlir::Value, 8> operands; 131 llvm::SmallVector<int32_t, 8> operandSegments; 132 133 addOperand(operands, operandSegments, baseAddr); 134 addOperand(operands, operandSegments, varPtrPtr); 135 addOperands(operands, operandSegments, bounds); 136 addOperands(operands, operandSegments, async); 137 138 Op op = builder.create<Op>(loc, retTy, operands); 139 op.setNameAttr(builder.getStringAttr(name.str())); 140 op.setStructured(structured); 141 op.setImplicit(implicit); 142 op.setDataClause(dataClause); 143 op.setVarType(mlir::cast<mlir::acc::PointerLikeType>(baseAddr.getType()) 144 
.getElementType()); 145 op->setAttr(Op::getOperandSegmentSizeAttr(), 146 builder.getDenseI32ArrayAttr(operandSegments)); 147 if (!asyncDeviceTypes.empty()) 148 op.setAsyncOperandsDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes)); 149 if (!asyncOnlyDeviceTypes.empty()) 150 op.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes)); 151 return op; 152 } 153 154 static void addDeclareAttr(fir::FirOpBuilder &builder, mlir::Operation *op, 155 mlir::acc::DataClause clause) { 156 if (!op) 157 return; 158 op->setAttr(mlir::acc::getDeclareAttrName(), 159 mlir::acc::DeclareAttr::get(builder.getContext(), 160 mlir::acc::DataClauseAttr::get( 161 builder.getContext(), clause))); 162 } 163 164 static mlir::func::FuncOp 165 createDeclareFunc(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, 166 mlir::Location loc, llvm::StringRef funcName, 167 llvm::SmallVector<mlir::Type> argsTy = {}, 168 llvm::SmallVector<mlir::Location> locs = {}) { 169 auto funcTy = mlir::FunctionType::get(modBuilder.getContext(), argsTy, {}); 170 auto funcOp = modBuilder.create<mlir::func::FuncOp>(loc, funcName, funcTy); 171 funcOp.setVisibility(mlir::SymbolTable::Visibility::Private); 172 builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy, 173 locs); 174 builder.setInsertionPointToEnd(&funcOp.getRegion().back()); 175 builder.create<mlir::func::ReturnOp>(loc); 176 builder.setInsertionPointToStart(&funcOp.getRegion().back()); 177 return funcOp; 178 } 179 180 template <typename Op> 181 static Op 182 createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc, 183 const llvm::SmallVectorImpl<mlir::Value> &operands, 184 const llvm::SmallVectorImpl<int32_t> &operandSegments) { 185 llvm::ArrayRef<mlir::Type> argTy; 186 Op op = builder.create<Op>(loc, argTy, operands); 187 op->setAttr(Op::getOperandSegmentSizeAttr(), 188 builder.getDenseI32ArrayAttr(operandSegments)); 189 return op; 190 } 191 192 template <typename EntryOp> 193 static void 
createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder, 194 fir::FirOpBuilder &builder, 195 mlir::Location loc, mlir::Type descTy, 196 llvm::StringRef funcNamePrefix, 197 std::stringstream &asFortran, 198 mlir::acc::DataClause clause) { 199 auto crtInsPt = builder.saveInsertionPoint(); 200 std::stringstream registerFuncName; 201 registerFuncName << funcNamePrefix.str() 202 << Fortran::lower::declarePostAllocSuffix.str(); 203 204 if (!mlir::isa<fir::ReferenceType>(descTy)) 205 descTy = fir::ReferenceType::get(descTy); 206 auto registerFuncOp = createDeclareFunc( 207 modBuilder, builder, loc, registerFuncName.str(), {descTy}, {loc}); 208 209 llvm::SmallVector<mlir::Value> bounds; 210 std::stringstream asFortranDesc; 211 asFortranDesc << asFortran.str() << accFirDescriptorPostfix.str(); 212 213 // Updating descriptor must occur before the mapping of the data so that 214 // attached data pointer is not overwritten. 215 mlir::acc::UpdateDeviceOp updateDeviceOp = 216 createDataEntryOp<mlir::acc::UpdateDeviceOp>( 217 builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds, 218 /*structured=*/false, /*implicit=*/true, 219 mlir::acc::DataClause::acc_update_device, descTy, 220 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); 221 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1}; 222 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()}; 223 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments); 224 225 mlir::Value desc = 226 builder.create<fir::LoadOp>(loc, registerFuncOp.getArgument(0)); 227 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, desc); 228 addDeclareAttr(builder, boxAddrOp.getOperation(), clause); 229 EntryOp entryOp = createDataEntryOp<EntryOp>( 230 builder, loc, boxAddrOp.getResult(), asFortran, bounds, 231 /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(), 232 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); 233 
builder.create<mlir::acc::DeclareEnterOp>( 234 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()), 235 mlir::ValueRange(entryOp.getAccPtr())); 236 237 modBuilder.setInsertionPointAfter(registerFuncOp); 238 builder.restoreInsertionPoint(crtInsPt); 239 } 240 241 template <typename ExitOp> 242 static void createDeclareDeallocFuncWithArg( 243 mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, mlir::Location loc, 244 mlir::Type descTy, llvm::StringRef funcNamePrefix, 245 std::stringstream &asFortran, mlir::acc::DataClause clause) { 246 auto crtInsPt = builder.saveInsertionPoint(); 247 // Generate the pre dealloc function. 248 std::stringstream preDeallocFuncName; 249 preDeallocFuncName << funcNamePrefix.str() 250 << Fortran::lower::declarePreDeallocSuffix.str(); 251 if (!mlir::isa<fir::ReferenceType>(descTy)) 252 descTy = fir::ReferenceType::get(descTy); 253 auto preDeallocOp = createDeclareFunc( 254 modBuilder, builder, loc, preDeallocFuncName.str(), {descTy}, {loc}); 255 mlir::Value loadOp = 256 builder.create<fir::LoadOp>(loc, preDeallocOp.getArgument(0)); 257 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp); 258 addDeclareAttr(builder, boxAddrOp.getOperation(), clause); 259 260 llvm::SmallVector<mlir::Value> bounds; 261 mlir::acc::GetDevicePtrOp entryOp = 262 createDataEntryOp<mlir::acc::GetDevicePtrOp>( 263 builder, loc, boxAddrOp.getResult(), asFortran, bounds, 264 /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(), 265 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); 266 builder.create<mlir::acc::DeclareExitOp>( 267 loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr())); 268 269 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> || 270 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>) 271 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(), 272 entryOp.getVarPtr(), entryOp.getVarType(), 273 entryOp.getBounds(), entryOp.getAsyncOperands(), 274 
entryOp.getAsyncOperandsDeviceTypeAttr(), 275 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), 276 /*structured=*/false, /*implicit=*/false, 277 builder.getStringAttr(*entryOp.getName())); 278 else 279 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(), 280 entryOp.getBounds(), entryOp.getAsyncOperands(), 281 entryOp.getAsyncOperandsDeviceTypeAttr(), 282 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), 283 /*structured=*/false, /*implicit=*/false, 284 builder.getStringAttr(*entryOp.getName())); 285 286 // Generate the post dealloc function. 287 modBuilder.setInsertionPointAfter(preDeallocOp); 288 std::stringstream postDeallocFuncName; 289 postDeallocFuncName << funcNamePrefix.str() 290 << Fortran::lower::declarePostDeallocSuffix.str(); 291 auto postDeallocOp = createDeclareFunc( 292 modBuilder, builder, loc, postDeallocFuncName.str(), {descTy}, {loc}); 293 loadOp = builder.create<fir::LoadOp>(loc, postDeallocOp.getArgument(0)); 294 asFortran << accFirDescriptorPostfix.str(); 295 mlir::acc::UpdateDeviceOp updateDeviceOp = 296 createDataEntryOp<mlir::acc::UpdateDeviceOp>( 297 builder, loc, loadOp, asFortran, bounds, 298 /*structured=*/false, /*implicit=*/true, 299 mlir::acc::DataClause::acc_update_device, loadOp.getType(), 300 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); 301 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1}; 302 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()}; 303 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments); 304 modBuilder.setInsertionPointAfter(postDeallocOp); 305 builder.restoreInsertionPoint(crtInsPt); 306 } 307 308 Fortran::semantics::Symbol & 309 getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) { 310 if (const auto *designator = 311 std::get_if<Fortran::parser::Designator>(&accObject.u)) { 312 if (const auto *name = 313 Fortran::semantics::getDesignatorNameIfDataRef(*designator)) 314 return *name->symbol; 315 if (const auto 
*arrayElement = 316 Fortran::parser::Unwrap<Fortran::parser::ArrayElement>( 317 *designator)) { 318 const Fortran::parser::Name &name = 319 Fortran::parser::GetLastName(arrayElement->base); 320 return *name.symbol; 321 } 322 if (const auto *component = 323 Fortran::parser::Unwrap<Fortran::parser::StructureComponent>( 324 *designator)) { 325 return *component->component.symbol; 326 } 327 } else if (const auto *name = 328 std::get_if<Fortran::parser::Name>(&accObject.u)) { 329 return *name->symbol; 330 } 331 llvm::report_fatal_error("Could not find symbol"); 332 } 333 334 template <typename Op> 335 static void 336 genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, 337 Fortran::lower::AbstractConverter &converter, 338 Fortran::semantics::SemanticsContext &semanticsContext, 339 Fortran::lower::StatementContext &stmtCtx, 340 llvm::SmallVectorImpl<mlir::Value> &dataOperands, 341 mlir::acc::DataClause dataClause, bool structured, 342 bool implicit, llvm::ArrayRef<mlir::Value> async, 343 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, 344 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes, 345 bool setDeclareAttr = false) { 346 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 347 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; 348 for (const auto &accObject : objectList.v) { 349 llvm::SmallVector<mlir::Value> bounds; 350 std::stringstream asFortran; 351 mlir::Location operandLocation = genOperandLocation(converter, accObject); 352 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); 353 Fortran::semantics::MaybeExpr designator = Fortran::common::visit( 354 [&](auto &&s) { return ea.Analyze(s); }, accObject.u); 355 fir::factory::AddrAndBoundsInfo info = 356 Fortran::lower::gatherDataOperandAddrAndBounds< 357 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( 358 converter, builder, semanticsContext, stmtCtx, symbol, designator, 359 operandLocation, asFortran, bounds, 360 /*treatIndexAsSection=*/true); 361 
LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs())); 362 363 // If the input value is optional and is not a descriptor, we use the 364 // rawInput directly. 365 mlir::Value baseAddr = ((fir::unwrapRefType(info.addr.getType()) != 366 fir::unwrapRefType(info.rawInput.getType())) && 367 info.isPresent) 368 ? info.rawInput 369 : info.addr; 370 Op op = createDataEntryOp<Op>( 371 builder, operandLocation, baseAddr, asFortran, bounds, structured, 372 implicit, dataClause, baseAddr.getType(), async, asyncDeviceTypes, 373 asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true, info.isPresent); 374 dataOperands.push_back(op.getAccPtr()); 375 } 376 } 377 378 template <typename EntryOp, typename ExitOp> 379 static void genDeclareDataOperandOperations( 380 const Fortran::parser::AccObjectList &objectList, 381 Fortran::lower::AbstractConverter &converter, 382 Fortran::semantics::SemanticsContext &semanticsContext, 383 Fortran::lower::StatementContext &stmtCtx, 384 llvm::SmallVectorImpl<mlir::Value> &dataOperands, 385 mlir::acc::DataClause dataClause, bool structured, bool implicit) { 386 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 387 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; 388 for (const auto &accObject : objectList.v) { 389 llvm::SmallVector<mlir::Value> bounds; 390 std::stringstream asFortran; 391 mlir::Location operandLocation = genOperandLocation(converter, accObject); 392 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); 393 Fortran::semantics::MaybeExpr designator = Fortran::common::visit( 394 [&](auto &&s) { return ea.Analyze(s); }, accObject.u); 395 fir::factory::AddrAndBoundsInfo info = 396 Fortran::lower::gatherDataOperandAddrAndBounds< 397 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( 398 converter, builder, semanticsContext, stmtCtx, symbol, designator, 399 operandLocation, asFortran, bounds); 400 LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs())); 401 EntryOp op = 
createDataEntryOp<EntryOp>( 402 builder, operandLocation, info.addr, asFortran, bounds, structured, 403 implicit, dataClause, info.addr.getType(), 404 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); 405 dataOperands.push_back(op.getAccPtr()); 406 addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause); 407 if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) { 408 mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion()); 409 modBuilder.setInsertionPointAfter(builder.getFunction()); 410 std::string prefix = converter.mangleName(symbol); 411 createDeclareAllocFuncWithArg<EntryOp>( 412 modBuilder, builder, operandLocation, info.addr.getType(), prefix, 413 asFortran, dataClause); 414 if constexpr (!std::is_same_v<EntryOp, ExitOp>) 415 createDeclareDeallocFuncWithArg<ExitOp>( 416 modBuilder, builder, operandLocation, info.addr.getType(), prefix, 417 asFortran, dataClause); 418 } 419 } 420 } 421 422 template <typename EntryOp, typename ExitOp, typename Clause> 423 static void genDeclareDataOperandOperationsWithModifier( 424 const Clause *x, Fortran::lower::AbstractConverter &converter, 425 Fortran::semantics::SemanticsContext &semanticsContext, 426 Fortran::lower::StatementContext &stmtCtx, 427 Fortran::parser::AccDataModifier::Modifier mod, 428 llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands, 429 const mlir::acc::DataClause clause, 430 const mlir::acc::DataClause clauseWithModifier) { 431 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v; 432 const auto &accObjectList = 433 std::get<Fortran::parser::AccObjectList>(listWithModifier.t); 434 const auto &modifier = 435 std::get<std::optional<Fortran::parser::AccDataModifier>>( 436 listWithModifier.t); 437 mlir::acc::DataClause dataClause = 438 (modifier && (*modifier).v == mod) ? 
clauseWithModifier : clause; 439 genDeclareDataOperandOperations<EntryOp, ExitOp>( 440 accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands, 441 dataClause, 442 /*structured=*/true, /*implicit=*/false); 443 } 444 445 template <typename EntryOp, typename ExitOp> 446 static void genDataExitOperations(fir::FirOpBuilder &builder, 447 llvm::SmallVector<mlir::Value> operands, 448 bool structured) { 449 for (mlir::Value operand : operands) { 450 auto entryOp = mlir::dyn_cast_or_null<EntryOp>(operand.getDefiningOp()); 451 assert(entryOp && "data entry op expected"); 452 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> || 453 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>) 454 builder.create<ExitOp>( 455 entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getVarPtr(), 456 entryOp.getVarType(), entryOp.getBounds(), entryOp.getAsyncOperands(), 457 entryOp.getAsyncOperandsDeviceTypeAttr(), entryOp.getAsyncOnlyAttr(), 458 entryOp.getDataClause(), structured, entryOp.getImplicit(), 459 builder.getStringAttr(*entryOp.getName())); 460 else 461 builder.create<ExitOp>( 462 entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getBounds(), 463 entryOp.getAsyncOperands(), entryOp.getAsyncOperandsDeviceTypeAttr(), 464 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), structured, 465 entryOp.getImplicit(), builder.getStringAttr(*entryOp.getName())); 466 } 467 } 468 469 fir::ShapeOp genShapeOp(mlir::OpBuilder &builder, fir::SequenceType seqTy, 470 mlir::Location loc) { 471 llvm::SmallVector<mlir::Value> extents; 472 mlir::Type idxTy = builder.getIndexType(); 473 for (auto extent : seqTy.getShape()) 474 extents.push_back(builder.create<mlir::arith::ConstantOp>( 475 loc, idxTy, builder.getIntegerAttr(idxTy, extent))); 476 return builder.create<fir::ShapeOp>(loc, extents); 477 } 478 479 template <typename RecipeOp> 480 static void genPrivateLikeInitRegion(mlir::OpBuilder &builder, RecipeOp recipe, 481 mlir::Type ty, mlir::Location loc) { 482 mlir::Value retVal = 
recipe.getInitRegion().front().getArgument(0); 483 if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) { 484 if (fir::isa_trivial(refTy.getEleTy())) { 485 auto alloca = builder.create<fir::AllocaOp>(loc, refTy.getEleTy()); 486 auto declareOp = builder.create<hlfir::DeclareOp>( 487 loc, alloca, accPrivateInitName, /*shape=*/nullptr, 488 llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr, 489 fir::FortranVariableFlagsAttr{}); 490 retVal = declareOp.getBase(); 491 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>( 492 refTy.getEleTy())) { 493 if (fir::isa_trivial(seqTy.getEleTy())) { 494 mlir::Value shape; 495 llvm::SmallVector<mlir::Value> extents; 496 if (seqTy.hasDynamicExtents()) { 497 // Extents are passed as block arguments. First argument is the 498 // original value. 499 for (unsigned i = 1; i < recipe.getInitRegion().getArguments().size(); 500 ++i) 501 extents.push_back(recipe.getInitRegion().getArgument(i)); 502 shape = builder.create<fir::ShapeOp>(loc, extents); 503 } else { 504 shape = genShapeOp(builder, seqTy, loc); 505 } 506 auto alloca = builder.create<fir::AllocaOp>( 507 loc, seqTy, /*typeparams=*/mlir::ValueRange{}, extents); 508 auto declareOp = builder.create<hlfir::DeclareOp>( 509 loc, alloca, accPrivateInitName, shape, 510 llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr, 511 fir::FortranVariableFlagsAttr{}); 512 retVal = declareOp.getBase(); 513 } 514 } 515 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) { 516 mlir::Type innerTy = fir::extractSequenceType(boxTy); 517 if (!innerTy) 518 TODO(loc, "Unsupported boxed type in OpenACC privatization"); 519 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()}; 520 hlfir::Entity source = hlfir::Entity{retVal}; 521 auto [temp, cleanup] = hlfir::createTempFromMold(loc, firBuilder, source); 522 retVal = temp; 523 } 524 builder.create<mlir::acc::YieldOp>(loc, retVal); 525 } 526 527 mlir::acc::PrivateRecipeOp 528 
Fortran::lower::createOrGetPrivateRecipe(mlir::OpBuilder &builder, 529 llvm::StringRef recipeName, 530 mlir::Location loc, mlir::Type ty) { 531 mlir::ModuleOp mod = 532 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); 533 if (auto recipe = mod.lookupSymbol<mlir::acc::PrivateRecipeOp>(recipeName)) 534 return recipe; 535 536 auto crtPos = builder.saveInsertionPoint(); 537 mlir::OpBuilder modBuilder(mod.getBodyRegion()); 538 auto recipe = 539 modBuilder.create<mlir::acc::PrivateRecipeOp>(loc, recipeName, ty); 540 llvm::SmallVector<mlir::Type> argsTy{ty}; 541 llvm::SmallVector<mlir::Location> argsLoc{loc}; 542 if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) { 543 if (auto seqTy = 544 mlir::dyn_cast_or_null<fir::SequenceType>(refTy.getEleTy())) { 545 if (seqTy.hasDynamicExtents()) { 546 mlir::Type idxTy = builder.getIndexType(); 547 for (unsigned i = 0; i < seqTy.getDimension(); ++i) { 548 argsTy.push_back(idxTy); 549 argsLoc.push_back(loc); 550 } 551 } 552 } 553 } 554 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(), 555 argsTy, argsLoc); 556 builder.setInsertionPointToEnd(&recipe.getInitRegion().back()); 557 genPrivateLikeInitRegion<mlir::acc::PrivateRecipeOp>(builder, recipe, ty, 558 loc); 559 builder.restoreInsertionPoint(crtPos); 560 return recipe; 561 } 562 563 /// Check if the DataBoundsOp is a constant bound (lb and ub are constants or 564 /// extent is a constant). 565 bool isConstantBound(mlir::acc::DataBoundsOp &op) { 566 if (op.getLowerbound() && fir::getIntIfConstant(op.getLowerbound()) && 567 op.getUpperbound() && fir::getIntIfConstant(op.getUpperbound())) 568 return true; 569 if (op.getExtent() && fir::getIntIfConstant(op.getExtent())) 570 return true; 571 return false; 572 } 573 574 /// Return true iff all the bounds are expressed with constant values. 
575 bool areAllBoundConstant(const llvm::SmallVector<mlir::Value> &bounds) { 576 for (auto bound : bounds) { 577 auto dataBound = 578 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); 579 assert(dataBound && "Must be DataBoundOp operation"); 580 if (!isConstantBound(dataBound)) 581 return false; 582 } 583 return true; 584 } 585 586 static llvm::SmallVector<mlir::Value> 587 genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc, 588 mlir::acc::DataBoundsOp &dataBound) { 589 mlir::Type idxTy = builder.getIndexType(); 590 mlir::Value lb, ub, step; 591 if (dataBound.getLowerbound() && 592 fir::getIntIfConstant(dataBound.getLowerbound()) && 593 dataBound.getUpperbound() && 594 fir::getIntIfConstant(dataBound.getUpperbound())) { 595 lb = builder.createIntegerConstant( 596 loc, idxTy, *fir::getIntIfConstant(dataBound.getLowerbound())); 597 ub = builder.createIntegerConstant( 598 loc, idxTy, *fir::getIntIfConstant(dataBound.getUpperbound())); 599 step = builder.createIntegerConstant(loc, idxTy, 1); 600 } else if (dataBound.getExtent()) { 601 lb = builder.createIntegerConstant(loc, idxTy, 0); 602 ub = builder.createIntegerConstant( 603 loc, idxTy, *fir::getIntIfConstant(dataBound.getExtent()) - 1); 604 step = builder.createIntegerConstant(loc, idxTy, 1); 605 } else { 606 llvm::report_fatal_error("Expect constant lb/ub or extent"); 607 } 608 return {lb, ub, step}; 609 } 610 611 static fir::ShapeOp genShapeFromBoundsOrArgs( 612 mlir::Location loc, fir::FirOpBuilder &builder, fir::SequenceType seqTy, 613 const llvm::SmallVector<mlir::Value> &bounds, mlir::ValueRange arguments) { 614 llvm::SmallVector<mlir::Value> args; 615 if (areAllBoundConstant(bounds)) { 616 for (auto bound : llvm::reverse(bounds)) { 617 auto dataBound = 618 mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); 619 args.append(genConstantBounds(builder, loc, dataBound)); 620 } 621 } else { 622 assert(((arguments.size() - 2) / 3 == seqTy.getDimension()) && 623 "Expect 3 block 
arguments per dimension"); 624 for (auto arg : arguments.drop_front(2)) 625 args.push_back(arg); 626 } 627 628 assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3"); 629 llvm::SmallVector<mlir::Value> extents; 630 mlir::Type idxTy = builder.getIndexType(); 631 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 632 mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0); 633 for (unsigned i = 0; i < args.size(); i += 3) { 634 mlir::Value s1 = 635 builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]); 636 mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one); 637 mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]); 638 mlir::Value cmp = builder.create<mlir::arith::CmpIOp>( 639 loc, mlir::arith::CmpIPredicate::sgt, s3, zero); 640 mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero); 641 extents.push_back(ext); 642 } 643 return builder.create<fir::ShapeOp>(loc, extents); 644 } 645 646 static hlfir::DesignateOp::Subscripts 647 getSubscriptsFromArgs(mlir::ValueRange args) { 648 hlfir::DesignateOp::Subscripts triplets; 649 for (unsigned i = 2; i < args.size(); i += 3) 650 triplets.emplace_back( 651 hlfir::DesignateOp::Triplet{args[i], args[i + 1], args[i + 2]}); 652 return triplets; 653 } 654 655 static hlfir::Entity genDesignateWithTriplets( 656 fir::FirOpBuilder &builder, mlir::Location loc, hlfir::Entity &entity, 657 hlfir::DesignateOp::Subscripts &triplets, mlir::Value shape) { 658 llvm::SmallVector<mlir::Value> lenParams; 659 hlfir::genLengthParameters(loc, builder, entity, lenParams); 660 auto designate = builder.create<hlfir::DesignateOp>( 661 loc, entity.getBase().getType(), entity, /*component=*/"", 662 /*componentShape=*/mlir::Value{}, triplets, 663 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, shape, 664 lenParams); 665 return hlfir::Entity{designate.getResult()}; 666 } 667 668 mlir::acc::FirstprivateRecipeOp 
Fortran::lower::createOrGetFirstprivateRecipe( 669 mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc, 670 mlir::Type ty, llvm::SmallVector<mlir::Value> &bounds) { 671 mlir::ModuleOp mod = 672 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); 673 if (auto recipe = 674 mod.lookupSymbol<mlir::acc::FirstprivateRecipeOp>(recipeName)) 675 return recipe; 676 677 auto crtPos = builder.saveInsertionPoint(); 678 mlir::OpBuilder modBuilder(mod.getBodyRegion()); 679 auto recipe = 680 modBuilder.create<mlir::acc::FirstprivateRecipeOp>(loc, recipeName, ty); 681 llvm::SmallVector<mlir::Type> initArgsTy{ty}; 682 llvm::SmallVector<mlir::Location> initArgsLoc{loc}; 683 auto refTy = fir::unwrapRefType(ty); 684 if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(refTy)) { 685 if (seqTy.hasDynamicExtents()) { 686 mlir::Type idxTy = builder.getIndexType(); 687 for (unsigned i = 0; i < seqTy.getDimension(); ++i) { 688 initArgsTy.push_back(idxTy); 689 initArgsLoc.push_back(loc); 690 } 691 } 692 } 693 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(), 694 initArgsTy, initArgsLoc); 695 builder.setInsertionPointToEnd(&recipe.getInitRegion().back()); 696 genPrivateLikeInitRegion<mlir::acc::FirstprivateRecipeOp>(builder, recipe, ty, 697 loc); 698 699 bool allConstantBound = areAllBoundConstant(bounds); 700 llvm::SmallVector<mlir::Type> argsTy{ty, ty}; 701 llvm::SmallVector<mlir::Location> argsLoc{loc, loc}; 702 if (!allConstantBound) { 703 for (mlir::Value bound : llvm::reverse(bounds)) { 704 auto dataBound = 705 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); 706 argsTy.push_back(dataBound.getLowerbound().getType()); 707 argsLoc.push_back(dataBound.getLowerbound().getLoc()); 708 argsTy.push_back(dataBound.getUpperbound().getType()); 709 argsLoc.push_back(dataBound.getUpperbound().getLoc()); 710 argsTy.push_back(dataBound.getStartIdx().getType()); 711 argsLoc.push_back(dataBound.getStartIdx().getLoc()); 712 
} 713 } 714 builder.createBlock(&recipe.getCopyRegion(), recipe.getCopyRegion().end(), 715 argsTy, argsLoc); 716 717 builder.setInsertionPointToEnd(&recipe.getCopyRegion().back()); 718 ty = fir::unwrapRefType(ty); 719 if (fir::isa_trivial(ty)) { 720 mlir::Value initValue = builder.create<fir::LoadOp>( 721 loc, recipe.getCopyRegion().front().getArgument(0)); 722 builder.create<fir::StoreOp>(loc, initValue, 723 recipe.getCopyRegion().front().getArgument(1)); 724 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) { 725 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()}; 726 auto shape = genShapeFromBoundsOrArgs( 727 loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments()); 728 729 auto leftDeclOp = builder.create<hlfir::DeclareOp>( 730 loc, recipe.getCopyRegion().getArgument(0), llvm::StringRef{}, shape, 731 llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr, 732 fir::FortranVariableFlagsAttr{}); 733 auto rightDeclOp = builder.create<hlfir::DeclareOp>( 734 loc, recipe.getCopyRegion().getArgument(1), llvm::StringRef{}, shape, 735 llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr, 736 fir::FortranVariableFlagsAttr{}); 737 738 hlfir::DesignateOp::Subscripts triplets = 739 getSubscriptsFromArgs(recipe.getCopyRegion().getArguments()); 740 auto leftEntity = hlfir::Entity{leftDeclOp.getBase()}; 741 auto left = 742 genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape); 743 auto rightEntity = hlfir::Entity{rightDeclOp.getBase()}; 744 auto right = 745 genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape); 746 747 firBuilder.create<hlfir::AssignOp>(loc, left, right); 748 749 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) { 750 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()}; 751 llvm::SmallVector<mlir::Value> tripletArgs; 752 mlir::Type innerTy = fir::extractSequenceType(boxTy); 753 fir::SequenceType seqTy = 754 
mlir::dyn_cast_or_null<fir::SequenceType>(innerTy); 755 if (!seqTy) 756 TODO(loc, "Unsupported boxed type in OpenACC firstprivate"); 757 758 auto shape = genShapeFromBoundsOrArgs( 759 loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments()); 760 hlfir::DesignateOp::Subscripts triplets = 761 getSubscriptsFromArgs(recipe.getCopyRegion().getArguments()); 762 auto leftEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(0)}; 763 auto left = 764 genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape); 765 auto rightEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(1)}; 766 auto right = 767 genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape); 768 firBuilder.create<hlfir::AssignOp>(loc, left, right); 769 } 770 771 builder.create<mlir::acc::TerminatorOp>(loc); 772 builder.restoreInsertionPoint(crtPos); 773 return recipe; 774 } 775 776 /// Get a string representation of the bounds. 777 std::string getBoundsString(llvm::SmallVector<mlir::Value> &bounds) { 778 std::stringstream boundStr; 779 if (!bounds.empty()) 780 boundStr << "_section_"; 781 llvm::interleave( 782 bounds, 783 [&](mlir::Value bound) { 784 auto boundsOp = 785 mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); 786 if (boundsOp.getLowerbound() && 787 fir::getIntIfConstant(boundsOp.getLowerbound()) && 788 boundsOp.getUpperbound() && 789 fir::getIntIfConstant(boundsOp.getUpperbound())) { 790 boundStr << "lb" << *fir::getIntIfConstant(boundsOp.getLowerbound()) 791 << ".ub" << *fir::getIntIfConstant(boundsOp.getUpperbound()); 792 } else if (boundsOp.getExtent() && 793 fir::getIntIfConstant(boundsOp.getExtent())) { 794 boundStr << "ext" << *fir::getIntIfConstant(boundsOp.getExtent()); 795 } else { 796 boundStr << "?"; 797 } 798 }, 799 [&] { boundStr << "x"; }); 800 return boundStr.str(); 801 } 802 803 /// Rebuild the array type from the acc.bounds operation with constant 804 /// lowerbound/upperbound or extent. 
805 mlir::Type getTypeFromBounds(llvm::SmallVector<mlir::Value> &bounds, 806 mlir::Type ty) { 807 auto seqTy = 808 mlir::dyn_cast_or_null<fir::SequenceType>(fir::unwrapRefType(ty)); 809 if (!bounds.empty() && seqTy) { 810 llvm::SmallVector<int64_t> shape; 811 for (auto b : bounds) { 812 auto boundsOp = 813 mlir::dyn_cast<mlir::acc::DataBoundsOp>(b.getDefiningOp()); 814 if (boundsOp.getLowerbound() && 815 fir::getIntIfConstant(boundsOp.getLowerbound()) && 816 boundsOp.getUpperbound() && 817 fir::getIntIfConstant(boundsOp.getUpperbound())) { 818 int64_t ext = *fir::getIntIfConstant(boundsOp.getUpperbound()) - 819 *fir::getIntIfConstant(boundsOp.getLowerbound()) + 1; 820 shape.push_back(ext); 821 } else if (boundsOp.getExtent() && 822 fir::getIntIfConstant(boundsOp.getExtent())) { 823 shape.push_back(*fir::getIntIfConstant(boundsOp.getExtent())); 824 } else { 825 return ty; // TODO: handle dynamic shaped array slice. 826 } 827 } 828 if (shape.empty() || shape.size() != bounds.size()) 829 return ty; 830 auto newSeqTy = fir::SequenceType::get(shape, seqTy.getEleTy()); 831 if (mlir::isa<fir::ReferenceType, fir::PointerType>(ty)) 832 return fir::ReferenceType::get(newSeqTy); 833 return newSeqTy; 834 } 835 return ty; 836 } 837 838 template <typename RecipeOp> 839 static void 840 genPrivatizations(const Fortran::parser::AccObjectList &objectList, 841 Fortran::lower::AbstractConverter &converter, 842 Fortran::semantics::SemanticsContext &semanticsContext, 843 Fortran::lower::StatementContext &stmtCtx, 844 llvm::SmallVectorImpl<mlir::Value> &dataOperands, 845 llvm::SmallVector<mlir::Attribute> &privatizations, 846 llvm::ArrayRef<mlir::Value> async, 847 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, 848 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes) { 849 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 850 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; 851 for (const auto &accObject : objectList.v) { 852 llvm::SmallVector<mlir::Value> bounds; 
853 std::stringstream asFortran; 854 mlir::Location operandLocation = genOperandLocation(converter, accObject); 855 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); 856 Fortran::semantics::MaybeExpr designator = Fortran::common::visit( 857 [&](auto &&s) { return ea.Analyze(s); }, accObject.u); 858 fir::factory::AddrAndBoundsInfo info = 859 Fortran::lower::gatherDataOperandAddrAndBounds< 860 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( 861 converter, builder, semanticsContext, stmtCtx, symbol, designator, 862 operandLocation, asFortran, bounds); 863 LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs())); 864 865 RecipeOp recipe; 866 mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType()); 867 if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) { 868 std::string recipeName = 869 fir::getTypeAsString(retTy, converter.getKindMap(), 870 Fortran::lower::privatizationRecipePrefix); 871 recipe = Fortran::lower::createOrGetPrivateRecipe(builder, recipeName, 872 operandLocation, retTy); 873 auto op = createDataEntryOp<mlir::acc::PrivateOp>( 874 builder, operandLocation, info.addr, asFortran, bounds, true, 875 /*implicit=*/false, mlir::acc::DataClause::acc_private, retTy, async, 876 asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true); 877 dataOperands.push_back(op.getAccPtr()); 878 } else { 879 std::string suffix = 880 areAllBoundConstant(bounds) ? 
getBoundsString(bounds) : ""; 881 std::string recipeName = fir::getTypeAsString( 882 retTy, converter.getKindMap(), "firstprivatization" + suffix); 883 recipe = Fortran::lower::createOrGetFirstprivateRecipe( 884 builder, recipeName, operandLocation, retTy, bounds); 885 auto op = createDataEntryOp<mlir::acc::FirstprivateOp>( 886 builder, operandLocation, info.addr, asFortran, bounds, true, 887 /*implicit=*/false, mlir::acc::DataClause::acc_firstprivate, retTy, 888 async, asyncDeviceTypes, asyncOnlyDeviceTypes, 889 /*unwrapBoxAddr=*/true); 890 dataOperands.push_back(op.getAccPtr()); 891 } 892 privatizations.push_back(mlir::SymbolRefAttr::get( 893 builder.getContext(), recipe.getSymName().str())); 894 } 895 } 896 897 /// Return the corresponding enum value for the mlir::acc::ReductionOperator 898 /// from the parser representation. 899 static mlir::acc::ReductionOperator 900 getReductionOperator(const Fortran::parser::ReductionOperator &op) { 901 switch (op.v) { 902 case Fortran::parser::ReductionOperator::Operator::Plus: 903 return mlir::acc::ReductionOperator::AccAdd; 904 case Fortran::parser::ReductionOperator::Operator::Multiply: 905 return mlir::acc::ReductionOperator::AccMul; 906 case Fortran::parser::ReductionOperator::Operator::Max: 907 return mlir::acc::ReductionOperator::AccMax; 908 case Fortran::parser::ReductionOperator::Operator::Min: 909 return mlir::acc::ReductionOperator::AccMin; 910 case Fortran::parser::ReductionOperator::Operator::Iand: 911 return mlir::acc::ReductionOperator::AccIand; 912 case Fortran::parser::ReductionOperator::Operator::Ior: 913 return mlir::acc::ReductionOperator::AccIor; 914 case Fortran::parser::ReductionOperator::Operator::Ieor: 915 return mlir::acc::ReductionOperator::AccXor; 916 case Fortran::parser::ReductionOperator::Operator::And: 917 return mlir::acc::ReductionOperator::AccLand; 918 case Fortran::parser::ReductionOperator::Operator::Or: 919 return mlir::acc::ReductionOperator::AccLor; 920 case 
Fortran::parser::ReductionOperator::Operator::Eqv: 921 return mlir::acc::ReductionOperator::AccEqv; 922 case Fortran::parser::ReductionOperator::Operator::Neqv: 923 return mlir::acc::ReductionOperator::AccNeqv; 924 } 925 llvm_unreachable("unexpected reduction operator"); 926 } 927 928 /// Get the initial value for reduction operator. 929 template <typename R> 930 static R getReductionInitValue(mlir::acc::ReductionOperator op, mlir::Type ty) { 931 if (op == mlir::acc::ReductionOperator::AccMin) { 932 // min init value -> largest 933 if constexpr (std::is_same_v<R, llvm::APInt>) { 934 assert(ty.isIntOrIndex() && "expect integer or index type"); 935 return llvm::APInt::getSignedMaxValue(ty.getIntOrFloatBitWidth()); 936 } 937 if constexpr (std::is_same_v<R, llvm::APFloat>) { 938 auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty); 939 assert(floatTy && "expect float type"); 940 return llvm::APFloat::getLargest(floatTy.getFloatSemantics(), 941 /*negative=*/false); 942 } 943 } else if (op == mlir::acc::ReductionOperator::AccMax) { 944 // max init value -> smallest 945 if constexpr (std::is_same_v<R, llvm::APInt>) { 946 assert(ty.isIntOrIndex() && "expect integer or index type"); 947 return llvm::APInt::getSignedMinValue(ty.getIntOrFloatBitWidth()); 948 } 949 if constexpr (std::is_same_v<R, llvm::APFloat>) { 950 auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty); 951 assert(floatTy && "expect float type"); 952 return llvm::APFloat::getSmallest(floatTy.getFloatSemantics(), 953 /*negative=*/true); 954 } 955 } else if (op == mlir::acc::ReductionOperator::AccIand) { 956 if constexpr (std::is_same_v<R, llvm::APInt>) { 957 assert(ty.isIntOrIndex() && "expect integer type"); 958 unsigned bits = ty.getIntOrFloatBitWidth(); 959 return llvm::APInt::getAllOnes(bits); 960 } 961 } else { 962 // +, ior, ieor init value -> 0 963 // * init value -> 1 964 int64_t value = (op == mlir::acc::ReductionOperator::AccMul) ? 
1 : 0; 965 if constexpr (std::is_same_v<R, llvm::APInt>) { 966 assert(ty.isIntOrIndex() && "expect integer or index type"); 967 return llvm::APInt(ty.getIntOrFloatBitWidth(), value, true); 968 } 969 970 if constexpr (std::is_same_v<R, llvm::APFloat>) { 971 assert(mlir::isa<mlir::FloatType>(ty) && "expect float type"); 972 auto floatTy = mlir::dyn_cast<mlir::FloatType>(ty); 973 return llvm::APFloat(floatTy.getFloatSemantics(), value); 974 } 975 976 if constexpr (std::is_same_v<R, int64_t>) 977 return value; 978 } 979 llvm_unreachable("OpenACC reduction unsupported type"); 980 } 981 982 /// Return a constant with the initial value for the reduction operator and 983 /// type combination. 984 static mlir::Value getReductionInitValue(fir::FirOpBuilder &builder, 985 mlir::Location loc, mlir::Type ty, 986 mlir::acc::ReductionOperator op) { 987 if (op == mlir::acc::ReductionOperator::AccLand || 988 op == mlir::acc::ReductionOperator::AccLor || 989 op == mlir::acc::ReductionOperator::AccEqv || 990 op == mlir::acc::ReductionOperator::AccNeqv) { 991 assert(mlir::isa<fir::LogicalType>(ty) && "expect fir.logical type"); 992 bool value = true; // .true. for .and. and .eqv. 993 if (op == mlir::acc::ReductionOperator::AccLor || 994 op == mlir::acc::ReductionOperator::AccNeqv) 995 value = false; // .false. for .or. and .neqv. 
996 return builder.createBool(loc, value); 997 } 998 if (ty.isIntOrIndex()) 999 return builder.create<mlir::arith::ConstantOp>( 1000 loc, ty, 1001 builder.getIntegerAttr(ty, getReductionInitValue<llvm::APInt>(op, ty))); 1002 if (op == mlir::acc::ReductionOperator::AccMin || 1003 op == mlir::acc::ReductionOperator::AccMax) { 1004 if (mlir::isa<mlir::ComplexType>(ty)) 1005 llvm::report_fatal_error( 1006 "min/max reduction not supported for complex type"); 1007 if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty)) 1008 return builder.create<mlir::arith::ConstantOp>( 1009 loc, ty, 1010 builder.getFloatAttr(ty, 1011 getReductionInitValue<llvm::APFloat>(op, ty))); 1012 } else if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty)) { 1013 return builder.create<mlir::arith::ConstantOp>( 1014 loc, ty, 1015 builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty))); 1016 } else if (auto cmplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) { 1017 mlir::Type floatTy = cmplxTy.getElementType(); 1018 mlir::Value realInit = builder.createRealConstant( 1019 loc, floatTy, getReductionInitValue<int64_t>(op, cmplxTy)); 1020 mlir::Value imagInit = builder.createRealConstant(loc, floatTy, 0.0); 1021 return fir::factory::Complex{builder, loc}.createComplex(cmplxTy, realInit, 1022 imagInit); 1023 } 1024 1025 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) 1026 return getReductionInitValue(builder, loc, seqTy.getEleTy(), op); 1027 1028 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) 1029 return getReductionInitValue(builder, loc, boxTy.getEleTy(), op); 1030 1031 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty)) 1032 return getReductionInitValue(builder, loc, heapTy.getEleTy(), op); 1033 1034 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty)) 1035 return getReductionInitValue(builder, loc, ptrTy.getEleTy(), op); 1036 1037 llvm::report_fatal_error("Unsupported OpenACC reduction type"); 1038 } 1039 1040 static mlir::Value 
genReductionInitRegion(fir::FirOpBuilder &builder, 1041 mlir::Location loc, mlir::Type ty, 1042 mlir::acc::ReductionOperator op) { 1043 ty = fir::unwrapRefType(ty); 1044 mlir::Value initValue = getReductionInitValue(builder, loc, ty, op); 1045 if (fir::isa_trivial(ty)) { 1046 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty); 1047 auto declareOp = builder.create<hlfir::DeclareOp>( 1048 loc, alloca, accReductionInitName, /*shape=*/nullptr, 1049 llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr, 1050 fir::FortranVariableFlagsAttr{}); 1051 builder.create<fir::StoreOp>(loc, builder.createConvert(loc, ty, initValue), 1052 declareOp.getBase()); 1053 return declareOp.getBase(); 1054 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) { 1055 if (fir::isa_trivial(seqTy.getEleTy())) { 1056 mlir::Value shape; 1057 auto extents = builder.getBlock()->getArguments().drop_front(1); 1058 if (seqTy.hasDynamicExtents()) 1059 shape = builder.create<fir::ShapeOp>(loc, extents); 1060 else 1061 shape = genShapeOp(builder, seqTy, loc); 1062 mlir::Value alloca = builder.create<fir::AllocaOp>( 1063 loc, seqTy, /*typeparams=*/mlir::ValueRange{}, extents); 1064 auto declareOp = builder.create<hlfir::DeclareOp>( 1065 loc, alloca, accReductionInitName, shape, 1066 llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr, 1067 fir::FortranVariableFlagsAttr{}); 1068 mlir::Type idxTy = builder.getIndexType(); 1069 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); 1070 llvm::SmallVector<fir::DoLoopOp> loops; 1071 llvm::SmallVector<mlir::Value> ivs; 1072 1073 if (seqTy.hasDynamicExtents()) { 1074 builder.create<hlfir::AssignOp>(loc, initValue, declareOp.getBase()); 1075 return declareOp.getBase(); 1076 } 1077 for (auto ext : llvm::reverse(seqTy.getShape())) { 1078 auto lb = builder.createIntegerConstant(loc, idxTy, 0); 1079 auto ub = builder.createIntegerConstant(loc, idxTy, ext - 1); 1080 auto step = builder.createIntegerConstant(loc, idxTy, 1); 1081 
auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step, 1082 /*unordered=*/false); 1083 builder.setInsertionPointToStart(loop.getBody()); 1084 loops.push_back(loop); 1085 ivs.push_back(loop.getInductionVar()); 1086 } 1087 auto coord = builder.create<fir::CoordinateOp>(loc, refTy, 1088 declareOp.getBase(), ivs); 1089 builder.create<fir::StoreOp>(loc, initValue, coord); 1090 builder.setInsertionPointAfter(loops[0]); 1091 return declareOp.getBase(); 1092 } 1093 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) { 1094 mlir::Type innerTy = fir::extractSequenceType(boxTy); 1095 if (!mlir::isa<fir::SequenceType>(innerTy)) 1096 TODO(loc, "Unsupported boxed type for reduction"); 1097 // Create the private copy from the initial fir.box. 1098 hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)}; 1099 auto [temp, cleanup] = hlfir::createTempFromMold(loc, builder, source); 1100 builder.create<hlfir::AssignOp>(loc, initValue, temp); 1101 return temp; 1102 } 1103 llvm::report_fatal_error("Unsupported OpenACC reduction type"); 1104 } 1105 1106 template <typename Op> 1107 static mlir::Value genLogicalCombiner(fir::FirOpBuilder &builder, 1108 mlir::Location loc, mlir::Value value1, 1109 mlir::Value value2) { 1110 mlir::Type i1 = builder.getI1Type(); 1111 mlir::Value v1 = builder.create<fir::ConvertOp>(loc, i1, value1); 1112 mlir::Value v2 = builder.create<fir::ConvertOp>(loc, i1, value2); 1113 mlir::Value combined = builder.create<Op>(loc, v1, v2); 1114 return builder.create<fir::ConvertOp>(loc, value1.getType(), combined); 1115 } 1116 1117 static mlir::Value genComparisonCombiner(fir::FirOpBuilder &builder, 1118 mlir::Location loc, 1119 mlir::arith::CmpIPredicate pred, 1120 mlir::Value value1, 1121 mlir::Value value2) { 1122 mlir::Type i1 = builder.getI1Type(); 1123 mlir::Value v1 = builder.create<fir::ConvertOp>(loc, i1, value1); 1124 mlir::Value v2 = builder.create<fir::ConvertOp>(loc, i1, value2); 1125 mlir::Value add = 
builder.create<mlir::arith::CmpIOp>(loc, pred, v1, v2); 1126 return builder.create<fir::ConvertOp>(loc, value1.getType(), add); 1127 } 1128 1129 static mlir::Value genScalarCombiner(fir::FirOpBuilder &builder, 1130 mlir::Location loc, 1131 mlir::acc::ReductionOperator op, 1132 mlir::Type ty, mlir::Value value1, 1133 mlir::Value value2) { 1134 value1 = builder.loadIfRef(loc, value1); 1135 value2 = builder.loadIfRef(loc, value2); 1136 if (op == mlir::acc::ReductionOperator::AccAdd) { 1137 if (ty.isIntOrIndex()) 1138 return builder.create<mlir::arith::AddIOp>(loc, value1, value2); 1139 if (mlir::isa<mlir::FloatType>(ty)) 1140 return builder.create<mlir::arith::AddFOp>(loc, value1, value2); 1141 if (auto cmplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) 1142 return builder.create<fir::AddcOp>(loc, value1, value2); 1143 TODO(loc, "reduction add type"); 1144 } 1145 1146 if (op == mlir::acc::ReductionOperator::AccMul) { 1147 if (ty.isIntOrIndex()) 1148 return builder.create<mlir::arith::MulIOp>(loc, value1, value2); 1149 if (mlir::isa<mlir::FloatType>(ty)) 1150 return builder.create<mlir::arith::MulFOp>(loc, value1, value2); 1151 if (mlir::isa<mlir::ComplexType>(ty)) 1152 return builder.create<fir::MulcOp>(loc, value1, value2); 1153 TODO(loc, "reduction mul type"); 1154 } 1155 1156 if (op == mlir::acc::ReductionOperator::AccMin) 1157 return fir::genMin(builder, loc, {value1, value2}); 1158 1159 if (op == mlir::acc::ReductionOperator::AccMax) 1160 return fir::genMax(builder, loc, {value1, value2}); 1161 1162 if (op == mlir::acc::ReductionOperator::AccIand) 1163 return builder.create<mlir::arith::AndIOp>(loc, value1, value2); 1164 1165 if (op == mlir::acc::ReductionOperator::AccIor) 1166 return builder.create<mlir::arith::OrIOp>(loc, value1, value2); 1167 1168 if (op == mlir::acc::ReductionOperator::AccXor) 1169 return builder.create<mlir::arith::XOrIOp>(loc, value1, value2); 1170 1171 if (op == mlir::acc::ReductionOperator::AccLand) 1172 return 
genLogicalCombiner<mlir::arith::AndIOp>(builder, loc, value1, 1173 value2); 1174 1175 if (op == mlir::acc::ReductionOperator::AccLor) 1176 return genLogicalCombiner<mlir::arith::OrIOp>(builder, loc, value1, value2); 1177 1178 if (op == mlir::acc::ReductionOperator::AccEqv) 1179 return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::eq, 1180 value1, value2); 1181 1182 if (op == mlir::acc::ReductionOperator::AccNeqv) 1183 return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::ne, 1184 value1, value2); 1185 1186 TODO(loc, "reduction operator"); 1187 } 1188 1189 static hlfir::DesignateOp::Subscripts 1190 getTripletsFromArgs(mlir::acc::ReductionRecipeOp recipe) { 1191 hlfir::DesignateOp::Subscripts triplets; 1192 for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size(); 1193 i += 3) 1194 triplets.emplace_back(hlfir::DesignateOp::Triplet{ 1195 recipe.getCombinerRegion().getArgument(i), 1196 recipe.getCombinerRegion().getArgument(i + 1), 1197 recipe.getCombinerRegion().getArgument(i + 2)}); 1198 return triplets; 1199 } 1200 1201 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 1202 mlir::acc::ReductionOperator op, mlir::Type ty, 1203 mlir::Value value1, mlir::Value value2, 1204 mlir::acc::ReductionRecipeOp &recipe, 1205 llvm::SmallVector<mlir::Value> &bounds, 1206 bool allConstantBound) { 1207 ty = fir::unwrapRefType(ty); 1208 1209 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) { 1210 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); 1211 llvm::SmallVector<fir::DoLoopOp> loops; 1212 llvm::SmallVector<mlir::Value> ivs; 1213 if (seqTy.hasDynamicExtents()) { 1214 auto shape = 1215 genShapeFromBoundsOrArgs(loc, builder, seqTy, bounds, 1216 recipe.getCombinerRegion().getArguments()); 1217 auto v1DeclareOp = builder.create<hlfir::DeclareOp>( 1218 loc, value1, llvm::StringRef{}, shape, llvm::ArrayRef<mlir::Value>{}, 1219 /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{}); 1220 
auto v2DeclareOp = builder.create<hlfir::DeclareOp>( 1221 loc, value2, llvm::StringRef{}, shape, llvm::ArrayRef<mlir::Value>{}, 1222 /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{}); 1223 hlfir::DesignateOp::Subscripts triplets = getTripletsFromArgs(recipe); 1224 1225 llvm::SmallVector<mlir::Value> lenParamsLeft; 1226 auto leftEntity = hlfir::Entity{v1DeclareOp.getBase()}; 1227 hlfir::genLengthParameters(loc, builder, leftEntity, lenParamsLeft); 1228 auto leftDesignate = builder.create<hlfir::DesignateOp>( 1229 loc, v1DeclareOp.getBase().getType(), v1DeclareOp.getBase(), 1230 /*component=*/"", 1231 /*componentShape=*/mlir::Value{}, triplets, 1232 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, 1233 shape, lenParamsLeft); 1234 auto left = hlfir::Entity{leftDesignate.getResult()}; 1235 1236 llvm::SmallVector<mlir::Value> lenParamsRight; 1237 auto rightEntity = hlfir::Entity{v2DeclareOp.getBase()}; 1238 hlfir::genLengthParameters(loc, builder, rightEntity, lenParamsLeft); 1239 auto rightDesignate = builder.create<hlfir::DesignateOp>( 1240 loc, v2DeclareOp.getBase().getType(), v2DeclareOp.getBase(), 1241 /*component=*/"", 1242 /*componentShape=*/mlir::Value{}, triplets, 1243 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, 1244 shape, lenParamsRight); 1245 auto right = hlfir::Entity{rightDesignate.getResult()}; 1246 1247 llvm::SmallVector<mlir::Value, 1> typeParams; 1248 auto genKernel = [&builder, &loc, op, seqTy, &left, &right]( 1249 mlir::Location l, fir::FirOpBuilder &b, 1250 mlir::ValueRange oneBasedIndices) -> hlfir::Entity { 1251 auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices); 1252 auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices); 1253 auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement); 1254 auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement); 1255 return hlfir::Entity{genScalarCombiner( 1256 builder, loc, op, seqTy.getEleTy(), leftVal, rightVal)}; 1257 }; 
1258 mlir::Value elemental = hlfir::genElementalOp( 1259 loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel, 1260 /*isUnordered=*/true); 1261 builder.create<hlfir::AssignOp>(loc, elemental, v1DeclareOp.getBase()); 1262 return; 1263 } 1264 if (allConstantBound) { 1265 // Use the constant bound directly in the combiner region so they do not 1266 // need to be passed as block argument. 1267 for (auto bound : llvm::reverse(bounds)) { 1268 auto dataBound = 1269 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); 1270 llvm::SmallVector<mlir::Value> values = 1271 genConstantBounds(builder, loc, dataBound); 1272 auto loop = 1273 builder.create<fir::DoLoopOp>(loc, values[0], values[1], values[2], 1274 /*unordered=*/false); 1275 builder.setInsertionPointToStart(loop.getBody()); 1276 loops.push_back(loop); 1277 ivs.push_back(loop.getInductionVar()); 1278 } 1279 } else { 1280 // Lowerbound, upperbound and step are passed as block arguments. 1281 [[maybe_unused]] unsigned nbRangeArgs = 1282 recipe.getCombinerRegion().getArguments().size() - 2; 1283 assert((nbRangeArgs / 3 == seqTy.getDimension()) && 1284 "Expect 3 block arguments per dimension"); 1285 for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size(); 1286 i += 3) { 1287 mlir::Value lb = recipe.getCombinerRegion().getArgument(i); 1288 mlir::Value ub = recipe.getCombinerRegion().getArgument(i + 1); 1289 mlir::Value step = recipe.getCombinerRegion().getArgument(i + 2); 1290 auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step, 1291 /*unordered=*/false); 1292 builder.setInsertionPointToStart(loop.getBody()); 1293 loops.push_back(loop); 1294 ivs.push_back(loop.getInductionVar()); 1295 } 1296 } 1297 auto addr1 = builder.create<fir::CoordinateOp>(loc, refTy, value1, ivs); 1298 auto addr2 = builder.create<fir::CoordinateOp>(loc, refTy, value2, ivs); 1299 auto load1 = builder.create<fir::LoadOp>(loc, addr1); 1300 auto load2 = builder.create<fir::LoadOp>(loc, addr2); 1301 mlir::Value 
res = 1302 genScalarCombiner(builder, loc, op, seqTy.getEleTy(), load1, load2); 1303 builder.create<fir::StoreOp>(loc, res, addr1); 1304 builder.setInsertionPointAfter(loops[0]); 1305 } else if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { 1306 mlir::Type innerTy = fir::extractSequenceType(boxTy); 1307 fir::SequenceType seqTy = 1308 mlir::dyn_cast_or_null<fir::SequenceType>(innerTy); 1309 if (!seqTy) 1310 TODO(loc, "Unsupported boxed type in OpenACC reduction"); 1311 1312 auto shape = genShapeFromBoundsOrArgs( 1313 loc, builder, seqTy, bounds, recipe.getCombinerRegion().getArguments()); 1314 hlfir::DesignateOp::Subscripts triplets = 1315 getSubscriptsFromArgs(recipe.getCombinerRegion().getArguments()); 1316 auto leftEntity = hlfir::Entity{value1}; 1317 auto left = 1318 genDesignateWithTriplets(builder, loc, leftEntity, triplets, shape); 1319 auto rightEntity = hlfir::Entity{value2}; 1320 auto right = 1321 genDesignateWithTriplets(builder, loc, rightEntity, triplets, shape); 1322 1323 llvm::SmallVector<mlir::Value, 1> typeParams; 1324 auto genKernel = [&builder, &loc, op, seqTy, &left, &right]( 1325 mlir::Location l, fir::FirOpBuilder &b, 1326 mlir::ValueRange oneBasedIndices) -> hlfir::Entity { 1327 auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices); 1328 auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices); 1329 auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement); 1330 auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement); 1331 return hlfir::Entity{genScalarCombiner(builder, loc, op, seqTy.getEleTy(), 1332 leftVal, rightVal)}; 1333 }; 1334 mlir::Value elemental = hlfir::genElementalOp( 1335 loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel, 1336 /*isUnordered=*/true); 1337 builder.create<hlfir::AssignOp>(loc, elemental, value1); 1338 } else { 1339 mlir::Value res = genScalarCombiner(builder, loc, op, ty, value1, value2); 1340 builder.create<fir::StoreOp>(loc, res, value1); 1341 } 1342 } 1343 1344 
mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe( 1345 fir::FirOpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc, 1346 mlir::Type ty, mlir::acc::ReductionOperator op, 1347 llvm::SmallVector<mlir::Value> &bounds) { 1348 mlir::ModuleOp mod = 1349 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); 1350 if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName)) 1351 return recipe; 1352 1353 auto crtPos = builder.saveInsertionPoint(); 1354 mlir::OpBuilder modBuilder(mod.getBodyRegion()); 1355 auto recipe = 1356 modBuilder.create<mlir::acc::ReductionRecipeOp>(loc, recipeName, ty, op); 1357 llvm::SmallVector<mlir::Type> initArgsTy{ty}; 1358 llvm::SmallVector<mlir::Location> initArgsLoc{loc}; 1359 mlir::Type refTy = fir::unwrapRefType(ty); 1360 if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(refTy)) { 1361 if (seqTy.hasDynamicExtents()) { 1362 mlir::Type idxTy = builder.getIndexType(); 1363 for (unsigned i = 0; i < seqTy.getDimension(); ++i) { 1364 initArgsTy.push_back(idxTy); 1365 initArgsLoc.push_back(loc); 1366 } 1367 } 1368 } 1369 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(), 1370 initArgsTy, initArgsLoc); 1371 builder.setInsertionPointToEnd(&recipe.getInitRegion().back()); 1372 mlir::Value initValue = genReductionInitRegion(builder, loc, ty, op); 1373 builder.create<mlir::acc::YieldOp>(loc, initValue); 1374 1375 // The two first block arguments are the two values to be combined. 1376 // The next arguments are the iteration ranges (lb, ub, step) to be used 1377 // for the combiner if needed. 
1378 llvm::SmallVector<mlir::Type> argsTy{ty, ty}; 1379 llvm::SmallVector<mlir::Location> argsLoc{loc, loc}; 1380 bool allConstantBound = areAllBoundConstant(bounds); 1381 if (!allConstantBound) { 1382 for (mlir::Value bound : llvm::reverse(bounds)) { 1383 auto dataBound = 1384 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); 1385 argsTy.push_back(dataBound.getLowerbound().getType()); 1386 argsLoc.push_back(dataBound.getLowerbound().getLoc()); 1387 argsTy.push_back(dataBound.getUpperbound().getType()); 1388 argsLoc.push_back(dataBound.getUpperbound().getLoc()); 1389 argsTy.push_back(dataBound.getStartIdx().getType()); 1390 argsLoc.push_back(dataBound.getStartIdx().getLoc()); 1391 } 1392 } 1393 builder.createBlock(&recipe.getCombinerRegion(), 1394 recipe.getCombinerRegion().end(), argsTy, argsLoc); 1395 builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back()); 1396 mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0); 1397 mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1); 1398 genCombiner(builder, loc, op, ty, v1, v2, recipe, bounds, allConstantBound); 1399 builder.create<mlir::acc::YieldOp>(loc, v1); 1400 builder.restoreInsertionPoint(crtPos); 1401 return recipe; 1402 } 1403 1404 static bool isSupportedReductionType(mlir::Type ty) { 1405 ty = fir::unwrapRefType(ty); 1406 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) 1407 return isSupportedReductionType(boxTy.getEleTy()); 1408 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) 1409 return isSupportedReductionType(seqTy.getEleTy()); 1410 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty)) 1411 return isSupportedReductionType(heapTy.getEleTy()); 1412 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty)) 1413 return isSupportedReductionType(ptrTy.getEleTy()); 1414 return fir::isa_trivial(ty); 1415 } 1416 1417 static void 1418 genReductions(const Fortran::parser::AccObjectListWithReduction &objectList, 1419 Fortran::lower::AbstractConverter 
&converter, 1420 Fortran::semantics::SemanticsContext &semanticsContext, 1421 Fortran::lower::StatementContext &stmtCtx, 1422 llvm::SmallVectorImpl<mlir::Value> &reductionOperands, 1423 llvm::SmallVector<mlir::Attribute> &reductionRecipes, 1424 llvm::ArrayRef<mlir::Value> async, 1425 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, 1426 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes) { 1427 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 1428 const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t); 1429 const auto &op = std::get<Fortran::parser::ReductionOperator>(objectList.t); 1430 mlir::acc::ReductionOperator mlirOp = getReductionOperator(op); 1431 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; 1432 for (const auto &accObject : objects.v) { 1433 llvm::SmallVector<mlir::Value> bounds; 1434 std::stringstream asFortran; 1435 mlir::Location operandLocation = genOperandLocation(converter, accObject); 1436 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); 1437 Fortran::semantics::MaybeExpr designator = Fortran::common::visit( 1438 [&](auto &&s) { return ea.Analyze(s); }, accObject.u); 1439 fir::factory::AddrAndBoundsInfo info = 1440 Fortran::lower::gatherDataOperandAddrAndBounds< 1441 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( 1442 converter, builder, semanticsContext, stmtCtx, symbol, designator, 1443 operandLocation, asFortran, bounds); 1444 LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs())); 1445 1446 mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType()); 1447 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy)) 1448 reductionTy = seqTy.getEleTy(); 1449 1450 if (!isSupportedReductionType(reductionTy)) 1451 TODO(operandLocation, "reduction with unsupported type"); 1452 1453 auto op = createDataEntryOp<mlir::acc::ReductionOp>( 1454 builder, operandLocation, info.addr, asFortran, bounds, 1455 /*structured=*/true, /*implicit=*/false, 1456 
mlir::acc::DataClause::acc_reduction, info.addr.getType(), async, 1457 asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true); 1458 mlir::Type ty = op.getAccPtr().getType(); 1459 if (!areAllBoundConstant(bounds) || 1460 fir::isAssumedShape(info.addr.getType()) || 1461 fir::isAllocatableOrPointerArray(info.addr.getType())) 1462 ty = info.addr.getType(); 1463 std::string suffix = 1464 areAllBoundConstant(bounds) ? getBoundsString(bounds) : ""; 1465 std::string recipeName = fir::getTypeAsString( 1466 ty, converter.getKindMap(), 1467 ("reduction_" + stringifyReductionOperator(mlirOp)).str() + suffix); 1468 1469 mlir::acc::ReductionRecipeOp recipe = 1470 Fortran::lower::createOrGetReductionRecipe( 1471 builder, recipeName, operandLocation, ty, mlirOp, bounds); 1472 reductionRecipes.push_back(mlir::SymbolRefAttr::get( 1473 builder.getContext(), recipe.getSymName().str())); 1474 reductionOperands.push_back(op.getAccPtr()); 1475 } 1476 } 1477 1478 template <typename Op, typename Terminator> 1479 static Op 1480 createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc, 1481 mlir::Location returnLoc, Fortran::lower::pft::Evaluation &eval, 1482 const llvm::SmallVectorImpl<mlir::Value> &operands, 1483 const llvm::SmallVectorImpl<int32_t> &operandSegments, 1484 bool outerCombined = false, 1485 llvm::SmallVector<mlir::Type> retTy = {}, 1486 mlir::Value yieldValue = {}, mlir::TypeRange argsTy = {}, 1487 llvm::SmallVector<mlir::Location> locs = {}) { 1488 Op op = builder.create<Op>(loc, retTy, operands); 1489 builder.createBlock(&op.getRegion(), op.getRegion().end(), argsTy, locs); 1490 mlir::Block &block = op.getRegion().back(); 1491 builder.setInsertionPointToStart(&block); 1492 1493 op->setAttr(Op::getOperandSegmentSizeAttr(), 1494 builder.getDenseI32ArrayAttr(operandSegments)); 1495 1496 // Place the insertion point to the start of the first block. 
1497 builder.setInsertionPointToStart(&block); 1498 1499 // If it is an unstructured region and is not the outer region of a combined 1500 // construct, create empty blocks for all evaluations. 1501 if (eval.lowerAsUnstructured() && !outerCombined) 1502 Fortran::lower::createEmptyRegionBlocks<mlir::acc::TerminatorOp, 1503 mlir::acc::YieldOp>( 1504 builder, eval.getNestedEvaluations()); 1505 1506 if (yieldValue) { 1507 if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) { 1508 Terminator yieldOp = builder.create<Terminator>(returnLoc, yieldValue); 1509 yieldValue.getDefiningOp()->moveBefore(yieldOp); 1510 } else { 1511 builder.create<Terminator>(returnLoc); 1512 } 1513 } else { 1514 builder.create<Terminator>(returnLoc); 1515 } 1516 builder.setInsertionPointToStart(&block); 1517 return op; 1518 } 1519 1520 static void genAsyncClause(Fortran::lower::AbstractConverter &converter, 1521 const Fortran::parser::AccClause::Async *asyncClause, 1522 mlir::Value &async, bool &addAsyncAttr, 1523 Fortran::lower::StatementContext &stmtCtx) { 1524 const auto &asyncClauseValue = asyncClause->v; 1525 if (asyncClauseValue) { // async has a value. 1526 async = fir::getBase(converter.genExprValue( 1527 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)); 1528 } else { 1529 addAsyncAttr = true; 1530 } 1531 } 1532 1533 static void 1534 genAsyncClause(Fortran::lower::AbstractConverter &converter, 1535 const Fortran::parser::AccClause::Async *asyncClause, 1536 llvm::SmallVector<mlir::Value> &async, 1537 llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes, 1538 llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes, 1539 llvm::SmallVector<mlir::Attribute> &deviceTypeAttrs, 1540 Fortran::lower::StatementContext &stmtCtx) { 1541 const auto &asyncClauseValue = asyncClause->v; 1542 if (asyncClauseValue) { // async has a value. 
1543 mlir::Value asyncValue = fir::getBase(converter.genExprValue( 1544 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)); 1545 for (auto deviceTypeAttr : deviceTypeAttrs) { 1546 async.push_back(asyncValue); 1547 asyncDeviceTypes.push_back(deviceTypeAttr); 1548 } 1549 } else { 1550 for (auto deviceTypeAttr : deviceTypeAttrs) 1551 asyncOnlyDeviceTypes.push_back(deviceTypeAttr); 1552 } 1553 } 1554 1555 static mlir::acc::DeviceType 1556 getDeviceType(Fortran::common::OpenACCDeviceType device) { 1557 switch (device) { 1558 case Fortran::common::OpenACCDeviceType::Star: 1559 return mlir::acc::DeviceType::Star; 1560 case Fortran::common::OpenACCDeviceType::Default: 1561 return mlir::acc::DeviceType::Default; 1562 case Fortran::common::OpenACCDeviceType::Nvidia: 1563 return mlir::acc::DeviceType::Nvidia; 1564 case Fortran::common::OpenACCDeviceType::Radeon: 1565 return mlir::acc::DeviceType::Radeon; 1566 case Fortran::common::OpenACCDeviceType::Host: 1567 return mlir::acc::DeviceType::Host; 1568 case Fortran::common::OpenACCDeviceType::Multicore: 1569 return mlir::acc::DeviceType::Multicore; 1570 case Fortran::common::OpenACCDeviceType::None: 1571 return mlir::acc::DeviceType::None; 1572 } 1573 return mlir::acc::DeviceType::None; 1574 } 1575 1576 static void gatherDeviceTypeAttrs( 1577 fir::FirOpBuilder &builder, 1578 const Fortran::parser::AccClause::DeviceType *deviceTypeClause, 1579 llvm::SmallVector<mlir::Attribute> &deviceTypes) { 1580 const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList = 1581 deviceTypeClause->v; 1582 for (const auto &deviceTypeExpr : deviceTypeExprList.v) 1583 deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( 1584 builder.getContext(), getDeviceType(deviceTypeExpr.v))); 1585 } 1586 1587 static void genIfClause(Fortran::lower::AbstractConverter &converter, 1588 mlir::Location clauseLocation, 1589 const Fortran::parser::AccClause::If *ifClause, 1590 mlir::Value &ifCond, 1591 Fortran::lower::StatementContext &stmtCtx) { 1592 
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 1593 mlir::Value cond = fir::getBase(converter.genExprValue( 1594 *Fortran::semantics::GetExpr(ifClause->v), stmtCtx, &clauseLocation)); 1595 ifCond = firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(), 1596 cond); 1597 } 1598 1599 static void genWaitClause(Fortran::lower::AbstractConverter &converter, 1600 const Fortran::parser::AccClause::Wait *waitClause, 1601 llvm::SmallVectorImpl<mlir::Value> &operands, 1602 mlir::Value &waitDevnum, bool &addWaitAttr, 1603 Fortran::lower::StatementContext &stmtCtx) { 1604 const auto &waitClauseValue = waitClause->v; 1605 if (waitClauseValue) { // wait has a value. 1606 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; 1607 const auto &waitList = 1608 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t); 1609 for (const Fortran::parser::ScalarIntExpr &value : waitList) { 1610 mlir::Value v = fir::getBase( 1611 converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx)); 1612 operands.push_back(v); 1613 } 1614 1615 const auto &waitDevnumValue = 1616 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t); 1617 if (waitDevnumValue) 1618 waitDevnum = fir::getBase(converter.genExprValue( 1619 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)); 1620 } else { 1621 addWaitAttr = true; 1622 } 1623 } 1624 1625 static void genWaitClauseWithDeviceType( 1626 Fortran::lower::AbstractConverter &converter, 1627 const Fortran::parser::AccClause::Wait *waitClause, 1628 llvm::SmallVector<mlir::Value> &waitOperands, 1629 llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes, 1630 llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes, 1631 llvm::SmallVector<bool> &hasDevnums, 1632 llvm::SmallVector<int32_t> &waitOperandsSegments, 1633 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, 1634 Fortran::lower::StatementContext &stmtCtx) { 1635 const auto &waitClauseValue = waitClause->v; 1636 if 
(waitClauseValue) { // wait has a value. 1637 llvm::SmallVector<mlir::Value> waitValues; 1638 1639 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; 1640 const auto &waitDevnumValue = 1641 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t); 1642 bool hasDevnum = false; 1643 if (waitDevnumValue) { 1644 waitValues.push_back(fir::getBase(converter.genExprValue( 1645 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx))); 1646 hasDevnum = true; 1647 } 1648 1649 const auto &waitList = 1650 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t); 1651 for (const Fortran::parser::ScalarIntExpr &value : waitList) { 1652 waitValues.push_back(fir::getBase(converter.genExprValue( 1653 *Fortran::semantics::GetExpr(value), stmtCtx))); 1654 } 1655 1656 for (auto deviceTypeAttr : deviceTypeAttrs) { 1657 for (auto value : waitValues) 1658 waitOperands.push_back(value); 1659 waitOperandsDeviceTypes.push_back(deviceTypeAttr); 1660 waitOperandsSegments.push_back(waitValues.size()); 1661 hasDevnums.push_back(hasDevnum); 1662 } 1663 } else { 1664 for (auto deviceTypeAttr : deviceTypeAttrs) 1665 waitOnlyDeviceTypes.push_back(deviceTypeAttr); 1666 } 1667 } 1668 1669 mlir::Type getTypeFromIvTypeSize(fir::FirOpBuilder &builder, 1670 const Fortran::semantics::Symbol &ivSym) { 1671 std::size_t ivTypeSize = ivSym.size(); 1672 if (ivTypeSize == 0) 1673 llvm::report_fatal_error("unexpected induction variable size"); 1674 // ivTypeSize is in bytes and IntegerType needs to be in bits. 
1675 return builder.getIntegerType(ivTypeSize * 8); 1676 } 1677 1678 static void privatizeIv(Fortran::lower::AbstractConverter &converter, 1679 const Fortran::semantics::Symbol &sym, 1680 mlir::Location loc, 1681 llvm::SmallVector<mlir::Type> &ivTypes, 1682 llvm::SmallVector<mlir::Location> &ivLocs, 1683 llvm::SmallVector<mlir::Value> &privateOperands, 1684 llvm::SmallVector<mlir::Value> &ivPrivate, 1685 llvm::SmallVector<mlir::Attribute> &privatizations, 1686 bool isDoConcurrent = false) { 1687 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 1688 1689 mlir::Type ivTy = getTypeFromIvTypeSize(builder, sym); 1690 ivTypes.push_back(ivTy); 1691 ivLocs.push_back(loc); 1692 mlir::Value ivValue = converter.getSymbolAddress(sym); 1693 if (!ivValue && isDoConcurrent) { 1694 // DO CONCURRENT induction variables are not mapped yet since they are local 1695 // to the DO CONCURRENT scope. 1696 mlir::OpBuilder::InsertPoint insPt = builder.saveInsertionPoint(); 1697 builder.setInsertionPointToStart(builder.getAllocaBlock()); 1698 ivValue = builder.createTemporaryAlloc(loc, ivTy, toStringRef(sym.name())); 1699 builder.restoreInsertionPoint(insPt); 1700 } 1701 1702 std::string recipeName = 1703 fir::getTypeAsString(ivValue.getType(), converter.getKindMap(), 1704 Fortran::lower::privatizationRecipePrefix); 1705 auto recipe = Fortran::lower::createOrGetPrivateRecipe( 1706 builder, recipeName, loc, ivValue.getType()); 1707 1708 std::stringstream asFortran; 1709 auto op = createDataEntryOp<mlir::acc::PrivateOp>( 1710 builder, loc, ivValue, asFortran, {}, true, /*implicit=*/true, 1711 mlir::acc::DataClause::acc_private, ivValue.getType(), 1712 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); 1713 1714 privateOperands.push_back(op.getAccPtr()); 1715 privatizations.push_back(mlir::SymbolRefAttr::get(builder.getContext(), 1716 recipe.getSymName().str())); 1717 1718 // Map the new private iv to its symbol for the scope of the loop. 
bindSymbol 1719 // might create a hlfir.declare op, if so, we map its result in order to 1720 // use the sym value in the scope. 1721 converter.bindSymbol(sym, op.getAccPtr()); 1722 auto privateValue = converter.getSymbolAddress(sym); 1723 if (auto declareOp = 1724 mlir::dyn_cast<hlfir::DeclareOp>(privateValue.getDefiningOp())) 1725 privateValue = declareOp.getResults()[0]; 1726 ivPrivate.push_back(privateValue); 1727 } 1728 1729 static mlir::acc::LoopOp createLoopOp( 1730 Fortran::lower::AbstractConverter &converter, 1731 mlir::Location currentLocation, 1732 Fortran::semantics::SemanticsContext &semanticsContext, 1733 Fortran::lower::StatementContext &stmtCtx, 1734 const Fortran::parser::DoConstruct &outerDoConstruct, 1735 Fortran::lower::pft::Evaluation &eval, 1736 const Fortran::parser::AccClauseList &accClauseList, 1737 std::optional<mlir::acc::CombinedConstructsType> combinedConstructs = 1738 std::nullopt, 1739 bool needEarlyReturnHandling = false) { 1740 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 1741 llvm::SmallVector<mlir::Value> tileOperands, privateOperands, ivPrivate, 1742 reductionOperands, cacheOperands, vectorOperands, workerNumOperands, 1743 gangOperands, lowerbounds, upperbounds, steps; 1744 llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes; 1745 llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments; 1746 llvm::SmallVector<int64_t> collapseValues; 1747 1748 llvm::SmallVector<mlir::Attribute> gangArgTypes; 1749 llvm::SmallVector<mlir::Attribute> seqDeviceTypes, independentDeviceTypes, 1750 autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes, 1751 vectorDeviceTypes, workerNumDeviceTypes, tileOperandsDeviceTypes, 1752 collapseDeviceTypes, gangDeviceTypes, gangOperandsDeviceTypes; 1753 1754 // device_type attribute is set to `none` until a device_type clause is 1755 // encountered. 
1756 llvm::SmallVector<mlir::Attribute> crtDeviceTypes; 1757 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( 1758 builder.getContext(), mlir::acc::DeviceType::None)); 1759 1760 llvm::SmallVector<mlir::Type> ivTypes; 1761 llvm::SmallVector<mlir::Location> ivLocs; 1762 llvm::SmallVector<bool> inclusiveBounds; 1763 1764 llvm::SmallVector<mlir::Location> locs; 1765 locs.push_back(currentLocation); // Location of the directive 1766 Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation(); 1767 bool isDoConcurrent = outerDoConstruct.IsDoConcurrent(); 1768 if (isDoConcurrent) { 1769 locs.push_back(converter.genLocation( 1770 Fortran::parser::FindSourceLocation(outerDoConstruct))); 1771 const Fortran::parser::LoopControl *loopControl = 1772 &*outerDoConstruct.GetLoopControl(); 1773 const auto &concurrent = 1774 std::get<Fortran::parser::LoopControl::Concurrent>(loopControl->u); 1775 if (!std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent.t) 1776 .empty()) 1777 TODO(currentLocation, "DO CONCURRENT with locality spec"); 1778 1779 const auto &concurrentHeader = 1780 std::get<Fortran::parser::ConcurrentHeader>(concurrent.t); 1781 const auto &controls = 1782 std::get<std::list<Fortran::parser::ConcurrentControl>>( 1783 concurrentHeader.t); 1784 for (const auto &control : controls) { 1785 lowerbounds.push_back(fir::getBase(converter.genExprValue( 1786 *Fortran::semantics::GetExpr(std::get<1>(control.t)), stmtCtx))); 1787 upperbounds.push_back(fir::getBase(converter.genExprValue( 1788 *Fortran::semantics::GetExpr(std::get<2>(control.t)), stmtCtx))); 1789 if (const auto &expr = 1790 std::get<std::optional<Fortran::parser::ScalarIntExpr>>( 1791 control.t)) 1792 steps.push_back(fir::getBase(converter.genExprValue( 1793 *Fortran::semantics::GetExpr(*expr), stmtCtx))); 1794 else // If `step` is not present, assume it is `1`. 
1795 steps.push_back(builder.createIntegerConstant( 1796 currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1)); 1797 1798 const auto &name = std::get<Fortran::parser::Name>(control.t); 1799 privatizeIv(converter, *name.symbol, currentLocation, ivTypes, ivLocs, 1800 privateOperands, ivPrivate, privatizations, isDoConcurrent); 1801 1802 inclusiveBounds.push_back(true); 1803 } 1804 } else { 1805 int64_t collapseValue = Fortran::lower::getCollapseValue(accClauseList); 1806 for (unsigned i = 0; i < collapseValue; ++i) { 1807 const Fortran::parser::LoopControl *loopControl; 1808 if (i == 0) { 1809 loopControl = &*outerDoConstruct.GetLoopControl(); 1810 locs.push_back(converter.genLocation( 1811 Fortran::parser::FindSourceLocation(outerDoConstruct))); 1812 } else { 1813 auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>(); 1814 assert(doCons && "expect do construct"); 1815 loopControl = &*doCons->GetLoopControl(); 1816 locs.push_back(converter.genLocation( 1817 Fortran::parser::FindSourceLocation(*doCons))); 1818 } 1819 1820 const Fortran::parser::LoopControl::Bounds *bounds = 1821 std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u); 1822 assert(bounds && "Expected bounds on the loop construct"); 1823 lowerbounds.push_back(fir::getBase(converter.genExprValue( 1824 *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); 1825 upperbounds.push_back(fir::getBase(converter.genExprValue( 1826 *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); 1827 if (bounds->step) 1828 steps.push_back(fir::getBase(converter.genExprValue( 1829 *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); 1830 else // If `step` is not present, assume it is `1`. 
1831 steps.push_back(builder.createIntegerConstant( 1832 currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1)); 1833 1834 Fortran::semantics::Symbol &ivSym = 1835 bounds->name.thing.symbol->GetUltimate(); 1836 privatizeIv(converter, ivSym, currentLocation, ivTypes, ivLocs, 1837 privateOperands, ivPrivate, privatizations); 1838 1839 inclusiveBounds.push_back(true); 1840 1841 if (i < collapseValue - 1) 1842 crtEval = &*std::next(crtEval->getNestedEvaluations().begin()); 1843 } 1844 } 1845 1846 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 1847 mlir::Location clauseLocation = converter.genLocation(clause.source); 1848 if (const auto *gangClause = 1849 std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) { 1850 if (gangClause->v) { 1851 const Fortran::parser::AccGangArgList &x = *gangClause->v; 1852 mlir::SmallVector<mlir::Value> gangValues; 1853 mlir::SmallVector<mlir::Attribute> gangArgs; 1854 for (const Fortran::parser::AccGangArg &gangArg : x.v) { 1855 if (const auto *num = 1856 std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) { 1857 gangValues.push_back(fir::getBase(converter.genExprValue( 1858 *Fortran::semantics::GetExpr(num->v), stmtCtx))); 1859 gangArgs.push_back(mlir::acc::GangArgTypeAttr::get( 1860 builder.getContext(), mlir::acc::GangArgType::Num)); 1861 } else if (const auto *staticArg = 1862 std::get_if<Fortran::parser::AccGangArg::Static>( 1863 &gangArg.u)) { 1864 const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v; 1865 if (sizeExpr.v) { 1866 gangValues.push_back(fir::getBase(converter.genExprValue( 1867 *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx))); 1868 } else { 1869 // * was passed as value and will be represented as a special 1870 // constant. 
1871 gangValues.push_back(builder.createIntegerConstant( 1872 clauseLocation, builder.getIndexType(), starCst)); 1873 } 1874 gangArgs.push_back(mlir::acc::GangArgTypeAttr::get( 1875 builder.getContext(), mlir::acc::GangArgType::Static)); 1876 } else if (const auto *dim = 1877 std::get_if<Fortran::parser::AccGangArg::Dim>( 1878 &gangArg.u)) { 1879 gangValues.push_back(fir::getBase(converter.genExprValue( 1880 *Fortran::semantics::GetExpr(dim->v), stmtCtx))); 1881 gangArgs.push_back(mlir::acc::GangArgTypeAttr::get( 1882 builder.getContext(), mlir::acc::GangArgType::Dim)); 1883 } 1884 } 1885 for (auto crtDeviceTypeAttr : crtDeviceTypes) { 1886 for (const auto &pair : llvm::zip(gangValues, gangArgs)) { 1887 gangOperands.push_back(std::get<0>(pair)); 1888 gangArgTypes.push_back(std::get<1>(pair)); 1889 } 1890 gangOperandsSegments.push_back(gangValues.size()); 1891 gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr); 1892 } 1893 } else { 1894 for (auto crtDeviceTypeAttr : crtDeviceTypes) 1895 gangDeviceTypes.push_back(crtDeviceTypeAttr); 1896 } 1897 } else if (const auto *workerClause = 1898 std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) { 1899 if (workerClause->v) { 1900 mlir::Value workerNumValue = fir::getBase(converter.genExprValue( 1901 *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)); 1902 for (auto crtDeviceTypeAttr : crtDeviceTypes) { 1903 workerNumOperands.push_back(workerNumValue); 1904 workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr); 1905 } 1906 } else { 1907 for (auto crtDeviceTypeAttr : crtDeviceTypes) 1908 workerNumDeviceTypes.push_back(crtDeviceTypeAttr); 1909 } 1910 } else if (const auto *vectorClause = 1911 std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) { 1912 if (vectorClause->v) { 1913 mlir::Value vectorValue = fir::getBase(converter.genExprValue( 1914 *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)); 1915 for (auto crtDeviceTypeAttr : crtDeviceTypes) { 1916 vectorOperands.push_back(vectorValue); 
1917 vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr); 1918 } 1919 } else { 1920 for (auto crtDeviceTypeAttr : crtDeviceTypes) 1921 vectorDeviceTypes.push_back(crtDeviceTypeAttr); 1922 } 1923 } else if (const auto *tileClause = 1924 std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) { 1925 const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v; 1926 llvm::SmallVector<mlir::Value> tileValues; 1927 for (const auto &accTileExpr : accTileExprList.v) { 1928 const auto &expr = 1929 std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>( 1930 accTileExpr.t); 1931 if (expr) { 1932 tileValues.push_back(fir::getBase(converter.genExprValue( 1933 *Fortran::semantics::GetExpr(*expr), stmtCtx))); 1934 } else { 1935 // * was passed as value and will be represented as a special 1936 // constant. 1937 mlir::Value tileStar = builder.createIntegerConstant( 1938 clauseLocation, builder.getIntegerType(32), starCst); 1939 tileValues.push_back(tileStar); 1940 } 1941 } 1942 for (auto crtDeviceTypeAttr : crtDeviceTypes) { 1943 for (auto value : tileValues) 1944 tileOperands.push_back(value); 1945 tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr); 1946 tileOperandsSegments.push_back(tileValues.size()); 1947 } 1948 } else if (const auto *privateClause = 1949 std::get_if<Fortran::parser::AccClause::Private>( 1950 &clause.u)) { 1951 genPrivatizations<mlir::acc::PrivateRecipeOp>( 1952 privateClause->v, converter, semanticsContext, stmtCtx, 1953 privateOperands, privatizations, /*async=*/{}, 1954 /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); 1955 } else if (const auto *reductionClause = 1956 std::get_if<Fortran::parser::AccClause::Reduction>( 1957 &clause.u)) { 1958 genReductions(reductionClause->v, converter, semanticsContext, stmtCtx, 1959 reductionOperands, reductionRecipes, /*async=*/{}, 1960 /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); 1961 } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) { 1962 for (auto 
crtDeviceTypeAttr : crtDeviceTypes) 1963 seqDeviceTypes.push_back(crtDeviceTypeAttr); 1964 } else if (std::get_if<Fortran::parser::AccClause::Independent>( 1965 &clause.u)) { 1966 for (auto crtDeviceTypeAttr : crtDeviceTypes) 1967 independentDeviceTypes.push_back(crtDeviceTypeAttr); 1968 } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) { 1969 for (auto crtDeviceTypeAttr : crtDeviceTypes) 1970 autoDeviceTypes.push_back(crtDeviceTypeAttr); 1971 } else if (const auto *deviceTypeClause = 1972 std::get_if<Fortran::parser::AccClause::DeviceType>( 1973 &clause.u)) { 1974 crtDeviceTypes.clear(); 1975 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes); 1976 } else if (const auto *collapseClause = 1977 std::get_if<Fortran::parser::AccClause::Collapse>( 1978 &clause.u)) { 1979 const Fortran::parser::AccCollapseArg &arg = collapseClause->v; 1980 const auto &force = std::get<bool>(arg.t); 1981 if (force) 1982 TODO(clauseLocation, "OpenACC collapse force modifier"); 1983 1984 const auto &intExpr = 1985 std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t); 1986 const auto *expr = Fortran::semantics::GetExpr(intExpr); 1987 const std::optional<int64_t> collapseValue = 1988 Fortran::evaluate::ToInt64(*expr); 1989 assert(collapseValue && "expect integer value for the collapse clause"); 1990 1991 for (auto crtDeviceTypeAttr : crtDeviceTypes) { 1992 collapseValues.push_back(*collapseValue); 1993 collapseDeviceTypes.push_back(crtDeviceTypeAttr); 1994 } 1995 } 1996 } 1997 1998 // Prepare the operand segment size attribute and the operands value range. 
1999 llvm::SmallVector<mlir::Value> operands; 2000 llvm::SmallVector<int32_t> operandSegments; 2001 addOperands(operands, operandSegments, lowerbounds); 2002 addOperands(operands, operandSegments, upperbounds); 2003 addOperands(operands, operandSegments, steps); 2004 addOperands(operands, operandSegments, gangOperands); 2005 addOperands(operands, operandSegments, workerNumOperands); 2006 addOperands(operands, operandSegments, vectorOperands); 2007 addOperands(operands, operandSegments, tileOperands); 2008 addOperands(operands, operandSegments, cacheOperands); 2009 addOperands(operands, operandSegments, privateOperands); 2010 addOperands(operands, operandSegments, reductionOperands); 2011 2012 llvm::SmallVector<mlir::Type> retTy; 2013 mlir::Value yieldValue; 2014 if (needEarlyReturnHandling) { 2015 mlir::Type i1Ty = builder.getI1Type(); 2016 yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0); 2017 retTy.push_back(i1Ty); 2018 } 2019 2020 auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>( 2021 builder, builder.getFusedLoc(locs), currentLocation, eval, operands, 2022 operandSegments, /*outerCombined=*/false, retTy, yieldValue, ivTypes, 2023 ivLocs); 2024 2025 for (auto [arg, value] : llvm::zip( 2026 loopOp.getLoopRegions().front()->front().getArguments(), ivPrivate)) 2027 builder.create<fir::StoreOp>(currentLocation, arg, value); 2028 2029 loopOp.setInclusiveUpperbound(inclusiveBounds); 2030 2031 if (!gangDeviceTypes.empty()) 2032 loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes)); 2033 if (!gangArgTypes.empty()) 2034 loopOp.setGangOperandsArgTypeAttr(builder.getArrayAttr(gangArgTypes)); 2035 if (!gangOperandsSegments.empty()) 2036 loopOp.setGangOperandsSegmentsAttr( 2037 builder.getDenseI32ArrayAttr(gangOperandsSegments)); 2038 if (!gangOperandsDeviceTypes.empty()) 2039 loopOp.setGangOperandsDeviceTypeAttr( 2040 builder.getArrayAttr(gangOperandsDeviceTypes)); 2041 2042 if (!workerNumDeviceTypes.empty()) 2043 
loopOp.setWorkerAttr(builder.getArrayAttr(workerNumDeviceTypes)); 2044 if (!workerNumOperandsDeviceTypes.empty()) 2045 loopOp.setWorkerNumOperandsDeviceTypeAttr( 2046 builder.getArrayAttr(workerNumOperandsDeviceTypes)); 2047 2048 if (!vectorDeviceTypes.empty()) 2049 loopOp.setVectorAttr(builder.getArrayAttr(vectorDeviceTypes)); 2050 if (!vectorOperandsDeviceTypes.empty()) 2051 loopOp.setVectorOperandsDeviceTypeAttr( 2052 builder.getArrayAttr(vectorOperandsDeviceTypes)); 2053 2054 if (!tileOperandsDeviceTypes.empty()) 2055 loopOp.setTileOperandsDeviceTypeAttr( 2056 builder.getArrayAttr(tileOperandsDeviceTypes)); 2057 if (!tileOperandsSegments.empty()) 2058 loopOp.setTileOperandsSegmentsAttr( 2059 builder.getDenseI32ArrayAttr(tileOperandsSegments)); 2060 2061 if (!seqDeviceTypes.empty()) 2062 loopOp.setSeqAttr(builder.getArrayAttr(seqDeviceTypes)); 2063 if (!independentDeviceTypes.empty()) 2064 loopOp.setIndependentAttr(builder.getArrayAttr(independentDeviceTypes)); 2065 if (!autoDeviceTypes.empty()) 2066 loopOp.setAuto_Attr(builder.getArrayAttr(autoDeviceTypes)); 2067 2068 if (!privatizations.empty()) 2069 loopOp.setPrivatizationsAttr( 2070 mlir::ArrayAttr::get(builder.getContext(), privatizations)); 2071 2072 if (!reductionRecipes.empty()) 2073 loopOp.setReductionRecipesAttr( 2074 mlir::ArrayAttr::get(builder.getContext(), reductionRecipes)); 2075 2076 if (!collapseValues.empty()) 2077 loopOp.setCollapseAttr(builder.getI64ArrayAttr(collapseValues)); 2078 if (!collapseDeviceTypes.empty()) 2079 loopOp.setCollapseDeviceTypeAttr(builder.getArrayAttr(collapseDeviceTypes)); 2080 2081 if (combinedConstructs) 2082 loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get( 2083 builder.getContext(), *combinedConstructs)); 2084 2085 // TODO: retrieve directives from NonLabelDoStmt pft::Evaluation, and add them 2086 // as attribute to the acc.loop as an extra attribute. 
It is not quite clear 2087 // how useful these $dir are in acc contexts, but they could still provide 2088 // more information about the loop acc codegen. They can be obtained by 2089 // looking for the first lexicalSuccessor of eval that is a NonLabelDoStmt, 2090 // and using the related `dirs` member. 2091 2092 return loopOp; 2093 } 2094 2095 static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) { 2096 bool hasReturnStmt = false; 2097 for (auto &e : eval.getNestedEvaluations()) { 2098 e.visit(Fortran::common::visitors{ 2099 [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; }, 2100 [&](const auto &s) {}, 2101 }); 2102 if (e.hasNestedEvaluations()) 2103 hasReturnStmt = hasEarlyReturn(e); 2104 } 2105 return hasReturnStmt; 2106 } 2107 2108 static mlir::Value 2109 genACC(Fortran::lower::AbstractConverter &converter, 2110 Fortran::semantics::SemanticsContext &semanticsContext, 2111 Fortran::lower::pft::Evaluation &eval, 2112 const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { 2113 2114 const auto &beginLoopDirective = 2115 std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t); 2116 const auto &loopDirective = 2117 std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t); 2118 2119 bool needEarlyExitHandling = false; 2120 if (eval.lowerAsUnstructured()) 2121 needEarlyExitHandling = hasEarlyReturn(eval); 2122 2123 mlir::Location currentLocation = 2124 converter.genLocation(beginLoopDirective.source); 2125 Fortran::lower::StatementContext stmtCtx; 2126 2127 assert(loopDirective.v == llvm::acc::ACCD_loop && 2128 "Unsupported OpenACC loop construct"); 2129 (void)loopDirective; 2130 2131 const auto &accClauseList = 2132 std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t); 2133 const auto &outerDoConstruct = 2134 std::get<std::optional<Fortran::parser::DoConstruct>>(loopConstruct.t); 2135 auto loopOp = createLoopOp(converter, currentLocation, semanticsContext, 2136 stmtCtx, *outerDoConstruct, eval, 
accClauseList, 2137 /*combinedConstructs=*/{}, needEarlyExitHandling); 2138 if (needEarlyExitHandling) 2139 return loopOp.getResult(0); 2140 2141 return mlir::Value{}; 2142 } 2143 2144 template <typename Op, typename Clause> 2145 static void genDataOperandOperationsWithModifier( 2146 const Clause *x, Fortran::lower::AbstractConverter &converter, 2147 Fortran::semantics::SemanticsContext &semanticsContext, 2148 Fortran::lower::StatementContext &stmtCtx, 2149 Fortran::parser::AccDataModifier::Modifier mod, 2150 llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands, 2151 const mlir::acc::DataClause clause, 2152 const mlir::acc::DataClause clauseWithModifier, 2153 llvm::ArrayRef<mlir::Value> async, 2154 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, 2155 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes, 2156 bool setDeclareAttr = false) { 2157 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v; 2158 const auto &accObjectList = 2159 std::get<Fortran::parser::AccObjectList>(listWithModifier.t); 2160 const auto &modifier = 2161 std::get<std::optional<Fortran::parser::AccDataModifier>>( 2162 listWithModifier.t); 2163 mlir::acc::DataClause dataClause = 2164 (modifier && (*modifier).v == mod) ? 
clauseWithModifier : clause; 2165 genDataOperandOperations<Op>(accObjectList, converter, semanticsContext, 2166 stmtCtx, dataClauseOperands, dataClause, 2167 /*structured=*/true, /*implicit=*/false, async, 2168 asyncDeviceTypes, asyncOnlyDeviceTypes, 2169 setDeclareAttr); 2170 } 2171 2172 template <typename Op> 2173 static Op createComputeOp( 2174 Fortran::lower::AbstractConverter &converter, 2175 mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval, 2176 Fortran::semantics::SemanticsContext &semanticsContext, 2177 Fortran::lower::StatementContext &stmtCtx, 2178 const Fortran::parser::AccClauseList &accClauseList, 2179 std::optional<mlir::acc::CombinedConstructsType> combinedConstructs = 2180 std::nullopt) { 2181 2182 // Parallel operation operands 2183 mlir::Value ifCond; 2184 mlir::Value selfCond; 2185 llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands, 2186 copyEntryOperands, copyinEntryOperands, copyoutEntryOperands, 2187 createEntryOperands, dataClauseOperands, numGangs, numWorkers, 2188 vectorLength, async; 2189 llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes, 2190 vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes, 2191 waitOperandsDeviceTypes, waitOnlyDeviceTypes; 2192 llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments; 2193 llvm::SmallVector<bool> hasWaitDevnums; 2194 2195 llvm::SmallVector<mlir::Value> reductionOperands, privateOperands, 2196 firstprivateOperands; 2197 llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations, 2198 reductionRecipes; 2199 2200 // Self clause has optional values but can be present with 2201 // no value as well. When there is no value, the op has an attribute to 2202 // represent the clause. 
2203 bool addSelfAttr = false; 2204 2205 bool hasDefaultNone = false; 2206 bool hasDefaultPresent = false; 2207 2208 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 2209 2210 // device_type attribute is set to `none` until a device_type clause is 2211 // encountered. 2212 llvm::SmallVector<mlir::Attribute> crtDeviceTypes; 2213 auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( 2214 builder.getContext(), mlir::acc::DeviceType::None); 2215 crtDeviceTypes.push_back(crtDeviceTypeAttr); 2216 2217 // Lower clauses values mapped to operands and array attributes. 2218 // Keep track of each group of operands separately as clauses can appear 2219 // more than once. 2220 2221 // Process the clauses that may have a specified device_type first. 2222 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 2223 if (const auto *asyncClause = 2224 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) { 2225 genAsyncClause(converter, asyncClause, async, asyncDeviceTypes, 2226 asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx); 2227 } else if (const auto *waitClause = 2228 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) { 2229 genWaitClauseWithDeviceType(converter, waitClause, waitOperands, 2230 waitOperandsDeviceTypes, waitOnlyDeviceTypes, 2231 hasWaitDevnums, waitOperandsSegments, 2232 crtDeviceTypes, stmtCtx); 2233 } else if (const auto *numGangsClause = 2234 std::get_if<Fortran::parser::AccClause::NumGangs>( 2235 &clause.u)) { 2236 llvm::SmallVector<mlir::Value> numGangValues; 2237 for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v) 2238 numGangValues.push_back(fir::getBase(converter.genExprValue( 2239 *Fortran::semantics::GetExpr(expr), stmtCtx))); 2240 for (auto crtDeviceTypeAttr : crtDeviceTypes) { 2241 for (auto value : numGangValues) 2242 numGangs.push_back(value); 2243 numGangsDeviceTypes.push_back(crtDeviceTypeAttr); 2244 numGangsSegments.push_back(numGangValues.size()); 2245 } 2246 } else if (const auto 
*numWorkersClause = 2247 std::get_if<Fortran::parser::AccClause::NumWorkers>( 2248 &clause.u)) { 2249 mlir::Value numWorkerValue = fir::getBase(converter.genExprValue( 2250 *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)); 2251 for (auto crtDeviceTypeAttr : crtDeviceTypes) { 2252 numWorkers.push_back(numWorkerValue); 2253 numWorkersDeviceTypes.push_back(crtDeviceTypeAttr); 2254 } 2255 } else if (const auto *vectorLengthClause = 2256 std::get_if<Fortran::parser::AccClause::VectorLength>( 2257 &clause.u)) { 2258 mlir::Value vectorLengthValue = fir::getBase(converter.genExprValue( 2259 *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)); 2260 for (auto crtDeviceTypeAttr : crtDeviceTypes) { 2261 vectorLength.push_back(vectorLengthValue); 2262 vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr); 2263 } 2264 } else if (const auto *deviceTypeClause = 2265 std::get_if<Fortran::parser::AccClause::DeviceType>( 2266 &clause.u)) { 2267 crtDeviceTypes.clear(); 2268 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes); 2269 } 2270 } 2271 2272 // Process the clauses independent of device_type. 
2273 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 2274 mlir::Location clauseLocation = converter.genLocation(clause.source); 2275 if (const auto *ifClause = 2276 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 2277 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx); 2278 } else if (const auto *selfClause = 2279 std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) { 2280 const std::optional<Fortran::parser::AccSelfClause> &accSelfClause = 2281 selfClause->v; 2282 if (accSelfClause) { 2283 if (const auto *optCondition = 2284 std::get_if<std::optional<Fortran::parser::ScalarLogicalExpr>>( 2285 &(*accSelfClause).u)) { 2286 if (*optCondition) { 2287 mlir::Value cond = fir::getBase(converter.genExprValue( 2288 *Fortran::semantics::GetExpr(*optCondition), stmtCtx)); 2289 selfCond = builder.createConvert(clauseLocation, 2290 builder.getI1Type(), cond); 2291 } 2292 } else if (const auto *accClauseList = 2293 std::get_if<Fortran::parser::AccObjectList>( 2294 &(*accSelfClause).u)) { 2295 // TODO This would be nicer to be done in canonicalization step. 
2296 if (accClauseList->v.size() == 1) { 2297 const auto &accObject = accClauseList->v.front(); 2298 if (const auto *designator = 2299 std::get_if<Fortran::parser::Designator>(&accObject.u)) { 2300 if (const auto *name = 2301 Fortran::semantics::getDesignatorNameIfDataRef( 2302 *designator)) { 2303 auto cond = converter.getSymbolAddress(*name->symbol); 2304 selfCond = builder.createConvert(clauseLocation, 2305 builder.getI1Type(), cond); 2306 } 2307 } 2308 } 2309 } 2310 } else { 2311 addSelfAttr = true; 2312 } 2313 } else if (const auto *copyClause = 2314 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) { 2315 auto crtDataStart = dataClauseOperands.size(); 2316 genDataOperandOperations<mlir::acc::CopyinOp>( 2317 copyClause->v, converter, semanticsContext, stmtCtx, 2318 dataClauseOperands, mlir::acc::DataClause::acc_copy, 2319 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, 2320 asyncOnlyDeviceTypes); 2321 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 2322 dataClauseOperands.end()); 2323 } else if (const auto *copyinClause = 2324 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) { 2325 auto crtDataStart = dataClauseOperands.size(); 2326 genDataOperandOperationsWithModifier<mlir::acc::CopyinOp, 2327 Fortran::parser::AccClause::Copyin>( 2328 copyinClause, converter, semanticsContext, stmtCtx, 2329 Fortran::parser::AccDataModifier::Modifier::ReadOnly, 2330 dataClauseOperands, mlir::acc::DataClause::acc_copyin, 2331 mlir::acc::DataClause::acc_copyin_readonly, async, asyncDeviceTypes, 2332 asyncOnlyDeviceTypes); 2333 copyinEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 2334 dataClauseOperands.end()); 2335 } else if (const auto *copyoutClause = 2336 std::get_if<Fortran::parser::AccClause::Copyout>( 2337 &clause.u)) { 2338 auto crtDataStart = dataClauseOperands.size(); 2339 genDataOperandOperationsWithModifier<mlir::acc::CreateOp, 2340 Fortran::parser::AccClause::Copyout>( 2341 copyoutClause, 
converter, semanticsContext, stmtCtx, 2342 Fortran::parser::AccDataModifier::Modifier::ReadOnly, 2343 dataClauseOperands, mlir::acc::DataClause::acc_copyout, 2344 mlir::acc::DataClause::acc_copyout_zero, async, asyncDeviceTypes, 2345 asyncOnlyDeviceTypes); 2346 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 2347 dataClauseOperands.end()); 2348 } else if (const auto *createClause = 2349 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) { 2350 auto crtDataStart = dataClauseOperands.size(); 2351 genDataOperandOperationsWithModifier<mlir::acc::CreateOp, 2352 Fortran::parser::AccClause::Create>( 2353 createClause, converter, semanticsContext, stmtCtx, 2354 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands, 2355 mlir::acc::DataClause::acc_create, 2356 mlir::acc::DataClause::acc_create_zero, async, asyncDeviceTypes, 2357 asyncOnlyDeviceTypes); 2358 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 2359 dataClauseOperands.end()); 2360 } else if (const auto *noCreateClause = 2361 std::get_if<Fortran::parser::AccClause::NoCreate>( 2362 &clause.u)) { 2363 genDataOperandOperations<mlir::acc::NoCreateOp>( 2364 noCreateClause->v, converter, semanticsContext, stmtCtx, 2365 dataClauseOperands, mlir::acc::DataClause::acc_no_create, 2366 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, 2367 asyncOnlyDeviceTypes); 2368 } else if (const auto *presentClause = 2369 std::get_if<Fortran::parser::AccClause::Present>( 2370 &clause.u)) { 2371 genDataOperandOperations<mlir::acc::PresentOp>( 2372 presentClause->v, converter, semanticsContext, stmtCtx, 2373 dataClauseOperands, mlir::acc::DataClause::acc_present, 2374 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, 2375 asyncOnlyDeviceTypes); 2376 } else if (const auto *devicePtrClause = 2377 std::get_if<Fortran::parser::AccClause::Deviceptr>( 2378 &clause.u)) { 2379 genDataOperandOperations<mlir::acc::DevicePtrOp>( 2380 
devicePtrClause->v, converter, semanticsContext, stmtCtx, 2381 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr, 2382 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, 2383 asyncOnlyDeviceTypes); 2384 } else if (const auto *attachClause = 2385 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) { 2386 auto crtDataStart = dataClauseOperands.size(); 2387 genDataOperandOperations<mlir::acc::AttachOp>( 2388 attachClause->v, converter, semanticsContext, stmtCtx, 2389 dataClauseOperands, mlir::acc::DataClause::acc_attach, 2390 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, 2391 asyncOnlyDeviceTypes); 2392 attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 2393 dataClauseOperands.end()); 2394 } else if (const auto *privateClause = 2395 std::get_if<Fortran::parser::AccClause::Private>( 2396 &clause.u)) { 2397 if (!combinedConstructs) 2398 genPrivatizations<mlir::acc::PrivateRecipeOp>( 2399 privateClause->v, converter, semanticsContext, stmtCtx, 2400 privateOperands, privatizations, async, asyncDeviceTypes, 2401 asyncOnlyDeviceTypes); 2402 } else if (const auto *firstprivateClause = 2403 std::get_if<Fortran::parser::AccClause::Firstprivate>( 2404 &clause.u)) { 2405 genPrivatizations<mlir::acc::FirstprivateRecipeOp>( 2406 firstprivateClause->v, converter, semanticsContext, stmtCtx, 2407 firstprivateOperands, firstPrivatizations, async, asyncDeviceTypes, 2408 asyncOnlyDeviceTypes); 2409 } else if (const auto *reductionClause = 2410 std::get_if<Fortran::parser::AccClause::Reduction>( 2411 &clause.u)) { 2412 // A reduction clause on a combined construct is treated as if it appeared 2413 // on the loop construct. So don't generate a reduction clause when it is 2414 // combined - delay it to the loop. However, a reduction clause on a 2415 // combined construct implies a copy clause so issue an implicit copy 2416 // instead. 
2417 if (!combinedConstructs) { 2418 genReductions(reductionClause->v, converter, semanticsContext, stmtCtx, 2419 reductionOperands, reductionRecipes, async, 2420 asyncDeviceTypes, asyncOnlyDeviceTypes); 2421 } else { 2422 auto crtDataStart = dataClauseOperands.size(); 2423 genDataOperandOperations<mlir::acc::CopyinOp>( 2424 std::get<Fortran::parser::AccObjectList>(reductionClause->v.t), 2425 converter, semanticsContext, stmtCtx, dataClauseOperands, 2426 mlir::acc::DataClause::acc_reduction, 2427 /*structured=*/true, /*implicit=*/true, async, asyncDeviceTypes, 2428 asyncOnlyDeviceTypes); 2429 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 2430 dataClauseOperands.end()); 2431 } 2432 } else if (const auto *defaultClause = 2433 std::get_if<Fortran::parser::AccClause::Default>( 2434 &clause.u)) { 2435 if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none) 2436 hasDefaultNone = true; 2437 else if ((defaultClause->v).v == 2438 llvm::acc::DefaultValue::ACC_Default_present) 2439 hasDefaultPresent = true; 2440 } 2441 } 2442 2443 // Prepare the operand segment size attribute and the operands value range. 
2444 llvm::SmallVector<mlir::Value, 8> operands; 2445 llvm::SmallVector<int32_t, 8> operandSegments; 2446 addOperands(operands, operandSegments, async); 2447 addOperands(operands, operandSegments, waitOperands); 2448 if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) { 2449 addOperands(operands, operandSegments, numGangs); 2450 addOperands(operands, operandSegments, numWorkers); 2451 addOperands(operands, operandSegments, vectorLength); 2452 } 2453 addOperand(operands, operandSegments, ifCond); 2454 addOperand(operands, operandSegments, selfCond); 2455 if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) { 2456 addOperands(operands, operandSegments, reductionOperands); 2457 addOperands(operands, operandSegments, privateOperands); 2458 addOperands(operands, operandSegments, firstprivateOperands); 2459 } 2460 addOperands(operands, operandSegments, dataClauseOperands); 2461 2462 Op computeOp; 2463 if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>) 2464 computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>( 2465 builder, currentLocation, currentLocation, eval, operands, 2466 operandSegments, /*outerCombined=*/combinedConstructs.has_value()); 2467 else 2468 computeOp = createRegionOp<Op, mlir::acc::YieldOp>( 2469 builder, currentLocation, currentLocation, eval, operands, 2470 operandSegments, /*outerCombined=*/combinedConstructs.has_value()); 2471 2472 if (addSelfAttr) 2473 computeOp.setSelfAttrAttr(builder.getUnitAttr()); 2474 2475 if (hasDefaultNone) 2476 computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None); 2477 if (hasDefaultPresent) 2478 computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present); 2479 2480 if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) { 2481 if (!numWorkersDeviceTypes.empty()) 2482 computeOp.setNumWorkersDeviceTypeAttr( 2483 mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes)); 2484 if (!vectorLengthDeviceTypes.empty()) 2485 computeOp.setVectorLengthDeviceTypeAttr( 2486 
mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes)); 2487 if (!numGangsDeviceTypes.empty()) 2488 computeOp.setNumGangsDeviceTypeAttr( 2489 mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes)); 2490 if (!numGangsSegments.empty()) 2491 computeOp.setNumGangsSegmentsAttr( 2492 builder.getDenseI32ArrayAttr(numGangsSegments)); 2493 } 2494 if (!asyncDeviceTypes.empty()) 2495 computeOp.setAsyncOperandsDeviceTypeAttr( 2496 builder.getArrayAttr(asyncDeviceTypes)); 2497 if (!asyncOnlyDeviceTypes.empty()) 2498 computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes)); 2499 2500 if (!waitOperandsDeviceTypes.empty()) 2501 computeOp.setWaitOperandsDeviceTypeAttr( 2502 builder.getArrayAttr(waitOperandsDeviceTypes)); 2503 if (!waitOperandsSegments.empty()) 2504 computeOp.setWaitOperandsSegmentsAttr( 2505 builder.getDenseI32ArrayAttr(waitOperandsSegments)); 2506 if (!hasWaitDevnums.empty()) 2507 computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums)); 2508 if (!waitOnlyDeviceTypes.empty()) 2509 computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes)); 2510 2511 if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) { 2512 if (!privatizations.empty()) 2513 computeOp.setPrivatizationsAttr( 2514 mlir::ArrayAttr::get(builder.getContext(), privatizations)); 2515 if (!reductionRecipes.empty()) 2516 computeOp.setReductionRecipesAttr( 2517 mlir::ArrayAttr::get(builder.getContext(), reductionRecipes)); 2518 if (!firstPrivatizations.empty()) 2519 computeOp.setFirstprivatizationsAttr( 2520 mlir::ArrayAttr::get(builder.getContext(), firstPrivatizations)); 2521 } 2522 2523 if (combinedConstructs) 2524 computeOp.setCombinedAttr(builder.getUnitAttr()); 2525 2526 auto insPt = builder.saveInsertionPoint(); 2527 builder.setInsertionPointAfter(computeOp); 2528 2529 // Create the exit operations after the region. 
2530 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>( 2531 builder, copyEntryOperands, /*structured=*/true); 2532 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::DeleteOp>( 2533 builder, copyinEntryOperands, /*structured=*/true); 2534 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>( 2535 builder, copyoutEntryOperands, /*structured=*/true); 2536 genDataExitOperations<mlir::acc::AttachOp, mlir::acc::DetachOp>( 2537 builder, attachEntryOperands, /*structured=*/true); 2538 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>( 2539 builder, createEntryOperands, /*structured=*/true); 2540 2541 builder.restoreInsertionPoint(insPt); 2542 return computeOp; 2543 } 2544 2545 static void genACCDataOp(Fortran::lower::AbstractConverter &converter, 2546 mlir::Location currentLocation, 2547 Fortran::lower::pft::Evaluation &eval, 2548 Fortran::semantics::SemanticsContext &semanticsContext, 2549 Fortran::lower::StatementContext &stmtCtx, 2550 const Fortran::parser::AccClauseList &accClauseList) { 2551 mlir::Value ifCond; 2552 llvm::SmallVector<mlir::Value> attachEntryOperands, createEntryOperands, 2553 copyEntryOperands, copyinEntryOperands, copyoutEntryOperands, 2554 dataClauseOperands, waitOperands, async; 2555 llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes, 2556 waitOperandsDeviceTypes, waitOnlyDeviceTypes; 2557 llvm::SmallVector<int32_t> waitOperandsSegments; 2558 llvm::SmallVector<bool> hasWaitDevnums; 2559 2560 bool hasDefaultNone = false; 2561 bool hasDefaultPresent = false; 2562 2563 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 2564 2565 // device_type attribute is set to `none` until a device_type clause is 2566 // encountered. 2567 llvm::SmallVector<mlir::Attribute> crtDeviceTypes; 2568 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( 2569 builder.getContext(), mlir::acc::DeviceType::None)); 2570 2571 // Lower clauses values mapped to operands and array attributes. 
2572 // Keep track of each group of operands separately as clauses can appear 2573 // more than once. 2574 2575 // Process the clauses that may have a specified device_type first. 2576 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 2577 if (const auto *asyncClause = 2578 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) { 2579 genAsyncClause(converter, asyncClause, async, asyncDeviceTypes, 2580 asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx); 2581 } else if (const auto *waitClause = 2582 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) { 2583 genWaitClauseWithDeviceType(converter, waitClause, waitOperands, 2584 waitOperandsDeviceTypes, waitOnlyDeviceTypes, 2585 hasWaitDevnums, waitOperandsSegments, 2586 crtDeviceTypes, stmtCtx); 2587 } else if (const auto *deviceTypeClause = 2588 std::get_if<Fortran::parser::AccClause::DeviceType>( 2589 &clause.u)) { 2590 crtDeviceTypes.clear(); 2591 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes); 2592 } 2593 } 2594 2595 // Process the clauses independent of device_type. 
2596 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 2597 mlir::Location clauseLocation = converter.genLocation(clause.source); 2598 if (const auto *ifClause = 2599 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 2600 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx); 2601 } else if (const auto *copyClause = 2602 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) { 2603 auto crtDataStart = dataClauseOperands.size(); 2604 genDataOperandOperations<mlir::acc::CopyinOp>( 2605 copyClause->v, converter, semanticsContext, stmtCtx, 2606 dataClauseOperands, mlir::acc::DataClause::acc_copy, 2607 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, 2608 asyncOnlyDeviceTypes); 2609 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 2610 dataClauseOperands.end()); 2611 } else if (const auto *copyinClause = 2612 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) { 2613 auto crtDataStart = dataClauseOperands.size(); 2614 genDataOperandOperationsWithModifier<mlir::acc::CopyinOp, 2615 Fortran::parser::AccClause::Copyin>( 2616 copyinClause, converter, semanticsContext, stmtCtx, 2617 Fortran::parser::AccDataModifier::Modifier::ReadOnly, 2618 dataClauseOperands, mlir::acc::DataClause::acc_copyin, 2619 mlir::acc::DataClause::acc_copyin_readonly, async, asyncDeviceTypes, 2620 asyncOnlyDeviceTypes); 2621 copyinEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 2622 dataClauseOperands.end()); 2623 } else if (const auto *copyoutClause = 2624 std::get_if<Fortran::parser::AccClause::Copyout>( 2625 &clause.u)) { 2626 auto crtDataStart = dataClauseOperands.size(); 2627 genDataOperandOperationsWithModifier<mlir::acc::CreateOp, 2628 Fortran::parser::AccClause::Copyout>( 2629 copyoutClause, converter, semanticsContext, stmtCtx, 2630 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands, 2631 mlir::acc::DataClause::acc_copyout, 2632 mlir::acc::DataClause::acc_copyout_zero, async, 
asyncDeviceTypes, 2633 asyncOnlyDeviceTypes); 2634 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 2635 dataClauseOperands.end()); 2636 } else if (const auto *createClause = 2637 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) { 2638 auto crtDataStart = dataClauseOperands.size(); 2639 genDataOperandOperationsWithModifier<mlir::acc::CreateOp, 2640 Fortran::parser::AccClause::Create>( 2641 createClause, converter, semanticsContext, stmtCtx, 2642 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands, 2643 mlir::acc::DataClause::acc_create, 2644 mlir::acc::DataClause::acc_create_zero, async, asyncDeviceTypes, 2645 asyncOnlyDeviceTypes); 2646 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 2647 dataClauseOperands.end()); 2648 } else if (const auto *noCreateClause = 2649 std::get_if<Fortran::parser::AccClause::NoCreate>( 2650 &clause.u)) { 2651 genDataOperandOperations<mlir::acc::NoCreateOp>( 2652 noCreateClause->v, converter, semanticsContext, stmtCtx, 2653 dataClauseOperands, mlir::acc::DataClause::acc_no_create, 2654 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, 2655 asyncOnlyDeviceTypes); 2656 } else if (const auto *presentClause = 2657 std::get_if<Fortran::parser::AccClause::Present>( 2658 &clause.u)) { 2659 genDataOperandOperations<mlir::acc::PresentOp>( 2660 presentClause->v, converter, semanticsContext, stmtCtx, 2661 dataClauseOperands, mlir::acc::DataClause::acc_present, 2662 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, 2663 asyncOnlyDeviceTypes); 2664 } else if (const auto *deviceptrClause = 2665 std::get_if<Fortran::parser::AccClause::Deviceptr>( 2666 &clause.u)) { 2667 genDataOperandOperations<mlir::acc::DevicePtrOp>( 2668 deviceptrClause->v, converter, semanticsContext, stmtCtx, 2669 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr, 2670 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, 2671 asyncOnlyDeviceTypes); 2672 } 
else if (const auto *attachClause = 2673 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) { 2674 auto crtDataStart = dataClauseOperands.size(); 2675 genDataOperandOperations<mlir::acc::AttachOp>( 2676 attachClause->v, converter, semanticsContext, stmtCtx, 2677 dataClauseOperands, mlir::acc::DataClause::acc_attach, 2678 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, 2679 asyncOnlyDeviceTypes); 2680 attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 2681 dataClauseOperands.end()); 2682 } else if (const auto *defaultClause = 2683 std::get_if<Fortran::parser::AccClause::Default>( 2684 &clause.u)) { 2685 if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none) 2686 hasDefaultNone = true; 2687 else if ((defaultClause->v).v == 2688 llvm::acc::DefaultValue::ACC_Default_present) 2689 hasDefaultPresent = true; 2690 } 2691 } 2692 2693 // Prepare the operand segment size attribute and the operands value range. 2694 llvm::SmallVector<mlir::Value> operands; 2695 llvm::SmallVector<int32_t> operandSegments; 2696 addOperand(operands, operandSegments, ifCond); 2697 addOperands(operands, operandSegments, async); 2698 addOperands(operands, operandSegments, waitOperands); 2699 addOperands(operands, operandSegments, dataClauseOperands); 2700 2701 if (dataClauseOperands.empty() && !hasDefaultNone && !hasDefaultPresent) 2702 return; 2703 2704 auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>( 2705 builder, currentLocation, currentLocation, eval, operands, 2706 operandSegments); 2707 2708 if (!asyncDeviceTypes.empty()) 2709 dataOp.setAsyncOperandsDeviceTypeAttr( 2710 builder.getArrayAttr(asyncDeviceTypes)); 2711 if (!asyncOnlyDeviceTypes.empty()) 2712 dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes)); 2713 if (!waitOperandsDeviceTypes.empty()) 2714 dataOp.setWaitOperandsDeviceTypeAttr( 2715 builder.getArrayAttr(waitOperandsDeviceTypes)); 2716 if (!waitOperandsSegments.empty()) 2717 
dataOp.setWaitOperandsSegmentsAttr( 2718 builder.getDenseI32ArrayAttr(waitOperandsSegments)); 2719 if (!hasWaitDevnums.empty()) 2720 dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums)); 2721 if (!waitOnlyDeviceTypes.empty()) 2722 dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes)); 2723 2724 if (hasDefaultNone) 2725 dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None); 2726 if (hasDefaultPresent) 2727 dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present); 2728 2729 auto insPt = builder.saveInsertionPoint(); 2730 builder.setInsertionPointAfter(dataOp); 2731 2732 // Create the exit operations after the region. 2733 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>( 2734 builder, copyEntryOperands, /*structured=*/true); 2735 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::DeleteOp>( 2736 builder, copyinEntryOperands, /*structured=*/true); 2737 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>( 2738 builder, copyoutEntryOperands, /*structured=*/true); 2739 genDataExitOperations<mlir::acc::AttachOp, mlir::acc::DetachOp>( 2740 builder, attachEntryOperands, /*structured=*/true); 2741 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>( 2742 builder, createEntryOperands, /*structured=*/true); 2743 2744 builder.restoreInsertionPoint(insPt); 2745 } 2746 2747 static void 2748 genACCHostDataOp(Fortran::lower::AbstractConverter &converter, 2749 mlir::Location currentLocation, 2750 Fortran::lower::pft::Evaluation &eval, 2751 Fortran::semantics::SemanticsContext &semanticsContext, 2752 Fortran::lower::StatementContext &stmtCtx, 2753 const Fortran::parser::AccClauseList &accClauseList) { 2754 mlir::Value ifCond; 2755 llvm::SmallVector<mlir::Value> dataOperands; 2756 bool addIfPresentAttr = false; 2757 2758 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 2759 2760 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 2761 mlir::Location clauseLocation = 
converter.genLocation(clause.source); 2762 if (const auto *ifClause = 2763 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 2764 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx); 2765 } else if (const auto *useDevice = 2766 std::get_if<Fortran::parser::AccClause::UseDevice>( 2767 &clause.u)) { 2768 genDataOperandOperations<mlir::acc::UseDeviceOp>( 2769 useDevice->v, converter, semanticsContext, stmtCtx, dataOperands, 2770 mlir::acc::DataClause::acc_use_device, 2771 /*structured=*/true, /*implicit=*/false, /*async=*/{}, 2772 /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); 2773 } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) { 2774 addIfPresentAttr = true; 2775 } 2776 } 2777 2778 if (ifCond) { 2779 if (auto cst = 2780 mlir::dyn_cast<mlir::arith::ConstantOp>(ifCond.getDefiningOp())) 2781 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(cst.getValue())) { 2782 if (boolAttr.getValue()) { 2783 // get rid of the if condition if it is always true. 2784 ifCond = mlir::Value(); 2785 } else { 2786 // Do not generate the acc.host_data op if the if condition is always 2787 // false. 2788 return; 2789 } 2790 } 2791 } 2792 2793 // Prepare the operand segment size attribute and the operands value range. 
2794 llvm::SmallVector<mlir::Value> operands; 2795 llvm::SmallVector<int32_t> operandSegments; 2796 addOperand(operands, operandSegments, ifCond); 2797 addOperands(operands, operandSegments, dataOperands); 2798 2799 auto hostDataOp = 2800 createRegionOp<mlir::acc::HostDataOp, mlir::acc::TerminatorOp>( 2801 builder, currentLocation, currentLocation, eval, operands, 2802 operandSegments); 2803 2804 if (addIfPresentAttr) 2805 hostDataOp.setIfPresentAttr(builder.getUnitAttr()); 2806 } 2807 2808 static void 2809 genACC(Fortran::lower::AbstractConverter &converter, 2810 Fortran::semantics::SemanticsContext &semanticsContext, 2811 Fortran::lower::pft::Evaluation &eval, 2812 const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { 2813 const auto &beginBlockDirective = 2814 std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t); 2815 const auto &blockDirective = 2816 std::get<Fortran::parser::AccBlockDirective>(beginBlockDirective.t); 2817 const auto &accClauseList = 2818 std::get<Fortran::parser::AccClauseList>(beginBlockDirective.t); 2819 2820 mlir::Location currentLocation = converter.genLocation(blockDirective.source); 2821 Fortran::lower::StatementContext stmtCtx; 2822 2823 if (blockDirective.v == llvm::acc::ACCD_parallel) { 2824 createComputeOp<mlir::acc::ParallelOp>(converter, currentLocation, eval, 2825 semanticsContext, stmtCtx, 2826 accClauseList); 2827 } else if (blockDirective.v == llvm::acc::ACCD_data) { 2828 genACCDataOp(converter, currentLocation, eval, semanticsContext, stmtCtx, 2829 accClauseList); 2830 } else if (blockDirective.v == llvm::acc::ACCD_serial) { 2831 createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval, 2832 semanticsContext, stmtCtx, 2833 accClauseList); 2834 } else if (blockDirective.v == llvm::acc::ACCD_kernels) { 2835 createComputeOp<mlir::acc::KernelsOp>(converter, currentLocation, eval, 2836 semanticsContext, stmtCtx, 2837 accClauseList); 2838 } else if (blockDirective.v == llvm::acc::ACCD_host_data) 
{
    genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
                     stmtCtx, accClauseList);
  }
}

/// Lower an OpenACC combined construct (`kernels loop`, `parallel loop` or
/// `serial loop`). The matching compute op is created first, then the loop op
/// for the associated DO construct; both are tagged with the same
/// CombinedConstructsType so the pair can be recognized as one construct.
/// Any other combined directive aborts with a fatal error.
static void
genACC(Fortran::lower::AbstractConverter &converter,
       Fortran::semantics::SemanticsContext &semanticsContext,
       Fortran::lower::pft::Evaluation &eval,
       const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) {
  const auto &beginCombinedDirective =
      std::get<Fortran::parser::AccBeginCombinedDirective>(combinedConstruct.t);
  const auto &combinedDirective =
      std::get<Fortran::parser::AccCombinedDirective>(beginCombinedDirective.t);
  const auto &accClauseList =
      std::get<Fortran::parser::AccClauseList>(beginCombinedDirective.t);
  // NOTE(review): dereferenced below without a check; presumably the parser
  // always attaches the associated DO construct to a combined construct.
  const auto &outerDoConstruct =
      std::get<std::optional<Fortran::parser::DoConstruct>>(
          combinedConstruct.t);

  mlir::Location currentLocation =
      converter.genLocation(beginCombinedDirective.source);
  Fortran::lower::StatementContext stmtCtx;

  if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
    createComputeOp<mlir::acc::KernelsOp>(
        converter, currentLocation, eval, semanticsContext, stmtCtx,
        accClauseList, mlir::acc::CombinedConstructsType::KernelsLoop);
    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
                 *outerDoConstruct, eval, accClauseList,
                 mlir::acc::CombinedConstructsType::KernelsLoop);
  } else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
    createComputeOp<mlir::acc::ParallelOp>(
        converter, currentLocation, eval, semanticsContext, stmtCtx,
        accClauseList, mlir::acc::CombinedConstructsType::ParallelLoop);
    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
                 *outerDoConstruct, eval, accClauseList,
                 mlir::acc::CombinedConstructsType::ParallelLoop);
  } else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
    createComputeOp<mlir::acc::SerialOp>(
        converter, currentLocation, eval, semanticsContext, stmtCtx,
        accClauseList, mlir::acc::CombinedConstructsType::SerialLoop);
    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
                 *outerDoConstruct, eval, accClauseList,
                 mlir::acc::CombinedConstructsType::SerialLoop);
  } else {
    llvm::report_fatal_error("Unknown combined construct encountered");
  }
}

/// Lower the standalone `enter data` directive to an acc::EnterDataOp.
///
/// The async clause is lowered first because it applies to every data entry
/// operation generated for copyin/create/attach. The if and wait clauses map
/// to operands of the op itself. Any clause other than if, wait, copyin,
/// create, attach or async is reported as a fatal error.
static void
genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
                  mlir::Location currentLocation,
                  Fortran::semantics::SemanticsContext &semanticsContext,
                  Fortran::lower::StatementContext &stmtCtx,
                  const Fortran::parser::AccClauseList &accClauseList) {
  mlir::Value ifCond, async, waitDevnum;
  llvm::SmallVector<mlir::Value> waitOperands, dataClauseOperands;

  // Async, wait and self clause have optional values but can be present with
  // no value as well. When there is no value, the op has an attribute to
  // represent the clause.
  bool addAsyncAttr = false;
  bool addWaitAttr = false;

  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  // Lower clauses values mapped to operands.
  // Keep track of each group of operands separately as clauses can appear
  // more than once.

  // Process the async clause first.
  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
    if (const auto *asyncClause =
            std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
    }
  }

  // The async clause of 'enter data' applies to all device types,
  // so propagate the async clause to copyin/create/attach ops
  // as if it is an async clause without preceding device_type clause.
  llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes;
  llvm::SmallVector<mlir::Value> asyncValues;
  auto noneDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
      firOpBuilder.getContext(), mlir::acc::DeviceType::None);
  // A value-less `async` is recorded as "async only"; an `async(expr)` is
  // recorded as an async operand. Both are keyed on the `none` device type.
  if (addAsyncAttr) {
    asyncOnlyDeviceTypes.push_back(noneDeviceTypeAttr);
  } else if (async) {
    asyncValues.push_back(async);
    asyncDeviceTypes.push_back(noneDeviceTypeAttr);
  }

  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
    mlir::Location clauseLocation = converter.genLocation(clause.source);
    if (const auto *ifClause =
            std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
      genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
    } else if (const auto *waitClause =
                   std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
                    addWaitAttr, stmtCtx);
    } else if (const auto *copyinClause =
                   std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
      const Fortran::parser::AccObjectListWithModifier &listWithModifier =
          copyinClause->v;
      const auto &accObjectList =
          std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
      genDataOperandOperations<mlir::acc::CopyinOp>(
          accObjectList, converter, semanticsContext, stmtCtx,
          dataClauseOperands, mlir::acc::DataClause::acc_copyin, false,
          /*implicit=*/false, asyncValues, asyncDeviceTypes,
          asyncOnlyDeviceTypes);
    } else if (const auto *createClause =
                   std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
      const Fortran::parser::AccObjectListWithModifier &listWithModifier =
          createClause->v;
      const auto &accObjectList =
          std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
      const auto &modifier =
          std::get<std::optional<Fortran::parser::AccDataModifier>>(
              listWithModifier.t);
      // NOTE: this local shadows the loop variable `clause` for the rest of
      // this branch. A `zero` modifier selects acc_create_zero.
      mlir::acc::DataClause clause = mlir::acc::DataClause::acc_create;
      if (modifier &&
          (*modifier).v == Fortran::parser::AccDataModifier::Modifier::Zero)
        clause = mlir::acc::DataClause::acc_create_zero;
      genDataOperandOperations<mlir::acc::CreateOp>(
          accObjectList, converter, semanticsContext, stmtCtx,
          dataClauseOperands, clause, false, /*implicit=*/false, asyncValues,
          asyncDeviceTypes, asyncOnlyDeviceTypes);
    } else if (const auto *attachClause =
                   std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
      genDataOperandOperations<mlir::acc::AttachOp>(
          attachClause->v, converter, semanticsContext, stmtCtx,
          dataClauseOperands, mlir::acc::DataClause::acc_attach, false,
          /*implicit=*/false, asyncValues, asyncDeviceTypes,
          asyncOnlyDeviceTypes);
    } else if (!std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
      // Async was already handled in the first pass; anything else is
      // unexpected for `enter data`.
      llvm::report_fatal_error(
          "Unknown clause in ENTER DATA directive lowering");
    }
  }

  // Prepare the operand segment size attribute and the operands value range.
llvm::SmallVector<mlir::Value, 16> operands;
  llvm::SmallVector<int32_t, 8> operandSegments;
  // Segment order: ifCond, async, waitDevnum, waitOperands, data clauses.
  addOperand(operands, operandSegments, ifCond);
  addOperand(operands, operandSegments, async);
  addOperand(operands, operandSegments, waitDevnum);
  addOperands(operands, operandSegments, waitOperands);
  addOperands(operands, operandSegments, dataClauseOperands);

  mlir::acc::EnterDataOp enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>(
      firOpBuilder, currentLocation, operands, operandSegments);

  if (addAsyncAttr)
    enterDataOp.setAsyncAttr(firOpBuilder.getUnitAttr());
  if (addWaitAttr)
    enterDataOp.setWaitAttr(firOpBuilder.getUnitAttr());
}

/// Lower the standalone `exit data` directive to an acc::ExitDataOp.
///
/// copyout/delete/detach clauses first produce acc::GetDevicePtrOp
/// operations, kept in one list per clause kind. The ExitDataOp consumes all
/// of them as data operands. Afterwards, genDataExitOperations emits the
/// matching CopyoutOp / DeleteOp / DetachOp for each list. `finalize` only
/// sets a unit attribute. Clauses not handled here are silently ignored;
/// enter data reports them as fatal errors instead.
static void
genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
                 mlir::Location currentLocation,
                 Fortran::semantics::SemanticsContext &semanticsContext,
                 Fortran::lower::StatementContext &stmtCtx,
                 const Fortran::parser::AccClauseList &accClauseList) {
  mlir::Value ifCond, async, waitDevnum;
  llvm::SmallVector<mlir::Value> waitOperands, dataClauseOperands,
      copyoutOperands, deleteOperands, detachOperands;

  // Async and wait clause have optional values but can be present with
  // no value as well. When there is no value, the op has an attribute to
  // represent the clause.
  bool addAsyncAttr = false;
  bool addWaitAttr = false;
  bool addFinalizeAttr = false;

  fir::FirOpBuilder &builder = converter.getFirOpBuilder();

  // Lower clauses values mapped to operands.
  // Keep track of each group of operands separately as clauses can appear
  // more than once.

  // Process the async clause first.
  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
    if (const auto *asyncClause =
            std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
    }
  }

  // The async clause of 'exit data' applies to all device types,
  // so propagate the async clause to the copyout/delete/detach ops
  // as if it is an async clause without preceding device_type clause.
  llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes;
  llvm::SmallVector<mlir::Value> asyncValues;
  auto noneDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
      builder.getContext(), mlir::acc::DeviceType::None);
  if (addAsyncAttr) {
    asyncOnlyDeviceTypes.push_back(noneDeviceTypeAttr);
  } else if (async) {
    asyncValues.push_back(async);
    asyncDeviceTypes.push_back(noneDeviceTypeAttr);
  }

  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
    mlir::Location clauseLocation = converter.genLocation(clause.source);
    if (const auto *ifClause =
            std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
      genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
    } else if (const auto *waitClause =
                   std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
                    addWaitAttr, stmtCtx);
    } else if (const auto *copyoutClause =
                   std::get_if<Fortran::parser::AccClause::Copyout>(
                       &clause.u)) {
      const Fortran::parser::AccObjectListWithModifier &listWithModifier =
          copyoutClause->v;
      const auto &accObjectList =
          std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
      genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
          accObjectList, converter, semanticsContext, stmtCtx, copyoutOperands,
          mlir::acc::DataClause::acc_copyout, false, /*implicit=*/false,
          asyncValues, asyncDeviceTypes, asyncOnlyDeviceTypes);
    } else if (const auto *deleteClause =
                   std::get_if<Fortran::parser::AccClause::Delete>(&clause.u)) {
      genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
          deleteClause->v, converter, semanticsContext, stmtCtx, deleteOperands,
          mlir::acc::DataClause::acc_delete, false, /*implicit=*/false,
          asyncValues, asyncDeviceTypes, asyncOnlyDeviceTypes);
    } else if (const auto *detachClause =
                   std::get_if<Fortran::parser::AccClause::Detach>(&clause.u)) {
      genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
          detachClause->v, converter, semanticsContext, stmtCtx, detachOperands,
          mlir::acc::DataClause::acc_detach, false, /*implicit=*/false,
          asyncValues, asyncDeviceTypes, asyncOnlyDeviceTypes);
    } else if (std::get_if<Fortran::parser::AccClause::Finalize>(&clause.u)) {
      addFinalizeAttr = true;
    }
  }

  // The op takes a single data-operand group ordered copyout, delete, detach.
  dataClauseOperands.append(copyoutOperands);
  dataClauseOperands.append(deleteOperands);
  dataClauseOperands.append(detachOperands);

  // Prepare the operand segment size attribute and the operands value range.
  llvm::SmallVector<mlir::Value, 14> operands;
  llvm::SmallVector<int32_t, 7> operandSegments;
  addOperand(operands, operandSegments, ifCond);
  addOperand(operands, operandSegments, async);
  addOperand(operands, operandSegments, waitDevnum);
  addOperands(operands, operandSegments, waitOperands);
  addOperands(operands, operandSegments, dataClauseOperands);

  mlir::acc::ExitDataOp exitDataOp = createSimpleOp<mlir::acc::ExitDataOp>(
      builder, currentLocation, operands, operandSegments);

  if (addAsyncAttr)
    exitDataOp.setAsyncAttr(builder.getUnitAttr());
  if (addWaitAttr)
    exitDataOp.setWaitAttr(builder.getUnitAttr());
  if (addFinalizeAttr)
    exitDataOp.setFinalizeAttr(builder.getUnitAttr());

  // The data exit operations are generated after the ExitDataOp is created.
  genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::CopyoutOp>(
      builder, copyoutOperands, /*structured=*/false);
  genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::DeleteOp>(
      builder, deleteOperands, /*structured=*/false);
  genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::DetachOp>(
      builder, detachOperands, /*structured=*/false);
}

/// Lower the `init` / `shutdown` directives; `Op` is the acc op to create.
/// The if, device_num and device_type clauses are gathered. The device types
/// are attached as an array attribute only when at least one is present.
template <typename Op>
static void
genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
                     mlir::Location currentLocation,
                     const Fortran::parser::AccClauseList &accClauseList) {
  mlir::Value ifCond, deviceNum;

  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  Fortran::lower::StatementContext stmtCtx;
  llvm::SmallVector<mlir::Attribute> deviceTypes;

  // Lower clauses values mapped to operands.
  // Keep track of each group of operands separately as clauses can appear
  // more than once.
3127 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 3128 mlir::Location clauseLocation = converter.genLocation(clause.source); 3129 if (const auto *ifClause = 3130 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 3131 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx); 3132 } else if (const auto *deviceNumClause = 3133 std::get_if<Fortran::parser::AccClause::DeviceNum>( 3134 &clause.u)) { 3135 deviceNum = fir::getBase(converter.genExprValue( 3136 *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx)); 3137 } else if (const auto *deviceTypeClause = 3138 std::get_if<Fortran::parser::AccClause::DeviceType>( 3139 &clause.u)) { 3140 gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes); 3141 } 3142 } 3143 3144 // Prepare the operand segment size attribute and the operands value range. 3145 llvm::SmallVector<mlir::Value, 6> operands; 3146 llvm::SmallVector<int32_t, 2> operandSegments; 3147 3148 addOperand(operands, operandSegments, deviceNum); 3149 addOperand(operands, operandSegments, ifCond); 3150 3151 Op op = 3152 createSimpleOp<Op>(builder, currentLocation, operands, operandSegments); 3153 if (!deviceTypes.empty()) 3154 op.setDeviceTypesAttr( 3155 mlir::ArrayAttr::get(builder.getContext(), deviceTypes)); 3156 } 3157 3158 void genACCSetOp(Fortran::lower::AbstractConverter &converter, 3159 mlir::Location currentLocation, 3160 const Fortran::parser::AccClauseList &accClauseList) { 3161 mlir::Value ifCond, deviceNum, defaultAsync; 3162 llvm::SmallVector<mlir::Value> deviceTypeOperands; 3163 3164 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 3165 Fortran::lower::StatementContext stmtCtx; 3166 llvm::SmallVector<mlir::Attribute> deviceTypes; 3167 3168 // Lower clauses values mapped to operands. 3169 // Keep track of each group of operands separately as clauses can appear 3170 // more than once. 
3171 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 3172 mlir::Location clauseLocation = converter.genLocation(clause.source); 3173 if (const auto *ifClause = 3174 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 3175 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx); 3176 } else if (const auto *defaultAsyncClause = 3177 std::get_if<Fortran::parser::AccClause::DefaultAsync>( 3178 &clause.u)) { 3179 defaultAsync = fir::getBase(converter.genExprValue( 3180 *Fortran::semantics::GetExpr(defaultAsyncClause->v), stmtCtx)); 3181 } else if (const auto *deviceNumClause = 3182 std::get_if<Fortran::parser::AccClause::DeviceNum>( 3183 &clause.u)) { 3184 deviceNum = fir::getBase(converter.genExprValue( 3185 *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx)); 3186 } else if (const auto *deviceTypeClause = 3187 std::get_if<Fortran::parser::AccClause::DeviceType>( 3188 &clause.u)) { 3189 gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes); 3190 } 3191 } 3192 3193 // Prepare the operand segment size attribute and the operands value range. 3194 llvm::SmallVector<mlir::Value> operands; 3195 llvm::SmallVector<int32_t, 3> operandSegments; 3196 addOperand(operands, operandSegments, defaultAsync); 3197 addOperand(operands, operandSegments, deviceNum); 3198 addOperand(operands, operandSegments, ifCond); 3199 3200 auto op = createSimpleOp<mlir::acc::SetOp>(builder, currentLocation, operands, 3201 operandSegments); 3202 if (!deviceTypes.empty()) { 3203 assert(deviceTypes.size() == 1 && "expect only one value for acc.set"); 3204 op.setDeviceTypeAttr(mlir::cast<mlir::acc::DeviceTypeAttr>(deviceTypes[0])); 3205 } 3206 } 3207 3208 static inline mlir::ArrayAttr 3209 getArrayAttr(fir::FirOpBuilder &b, 3210 llvm::SmallVector<mlir::Attribute> &attributes) { 3211 return attributes.empty() ? 
nullptr : b.getArrayAttr(attributes); 3212 } 3213 3214 static inline mlir::ArrayAttr 3215 getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) { 3216 return values.empty() ? nullptr : b.getBoolArrayAttr(values); 3217 } 3218 3219 static inline mlir::DenseI32ArrayAttr 3220 getDenseI32ArrayAttr(fir::FirOpBuilder &builder, 3221 llvm::SmallVector<int32_t> &values) { 3222 return values.empty() ? nullptr : builder.getDenseI32ArrayAttr(values); 3223 } 3224 3225 static void 3226 genACCUpdateOp(Fortran::lower::AbstractConverter &converter, 3227 mlir::Location currentLocation, 3228 Fortran::semantics::SemanticsContext &semanticsContext, 3229 Fortran::lower::StatementContext &stmtCtx, 3230 const Fortran::parser::AccClauseList &accClauseList) { 3231 mlir::Value ifCond; 3232 llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands, 3233 waitOperands, deviceTypeOperands, asyncOperands; 3234 llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes, 3235 asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes; 3236 llvm::SmallVector<bool> hasWaitDevnums; 3237 llvm::SmallVector<int32_t> waitOperandsSegments; 3238 3239 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 3240 3241 // device_type attribute is set to `none` until a device_type clause is 3242 // encountered. 3243 llvm::SmallVector<mlir::Attribute> crtDeviceTypes; 3244 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( 3245 builder.getContext(), mlir::acc::DeviceType::None)); 3246 3247 bool ifPresent = false; 3248 3249 // Lower clauses values mapped to operands and array attributes. 3250 // Keep track of each group of operands separately as clauses can appear 3251 // more than once. 3252 3253 // Process the clauses that may have a specified device_type first. 
3254 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 3255 if (const auto *asyncClause = 3256 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) { 3257 genAsyncClause(converter, asyncClause, asyncOperands, 3258 asyncOperandsDeviceTypes, asyncOnlyDeviceTypes, 3259 crtDeviceTypes, stmtCtx); 3260 } else if (const auto *waitClause = 3261 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) { 3262 genWaitClauseWithDeviceType(converter, waitClause, waitOperands, 3263 waitOperandsDeviceTypes, waitOnlyDeviceTypes, 3264 hasWaitDevnums, waitOperandsSegments, 3265 crtDeviceTypes, stmtCtx); 3266 } else if (const auto *deviceTypeClause = 3267 std::get_if<Fortran::parser::AccClause::DeviceType>( 3268 &clause.u)) { 3269 crtDeviceTypes.clear(); 3270 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes); 3271 } 3272 } 3273 3274 // Process the clauses independent of device_type. 3275 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 3276 mlir::Location clauseLocation = converter.genLocation(clause.source); 3277 if (const auto *ifClause = 3278 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 3279 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx); 3280 } else if (const auto *hostClause = 3281 std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) { 3282 genDataOperandOperations<mlir::acc::GetDevicePtrOp>( 3283 hostClause->v, converter, semanticsContext, stmtCtx, 3284 updateHostOperands, mlir::acc::DataClause::acc_update_host, false, 3285 /*implicit=*/false, asyncOperands, asyncOperandsDeviceTypes, 3286 asyncOnlyDeviceTypes); 3287 } else if (const auto *deviceClause = 3288 std::get_if<Fortran::parser::AccClause::Device>(&clause.u)) { 3289 genDataOperandOperations<mlir::acc::UpdateDeviceOp>( 3290 deviceClause->v, converter, semanticsContext, stmtCtx, 3291 dataClauseOperands, mlir::acc::DataClause::acc_update_device, false, 3292 /*implicit=*/false, asyncOperands, asyncOperandsDeviceTypes, 3293 
asyncOnlyDeviceTypes); 3294 } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) { 3295 ifPresent = true; 3296 } else if (const auto *selfClause = 3297 std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) { 3298 const std::optional<Fortran::parser::AccSelfClause> &accSelfClause = 3299 selfClause->v; 3300 const auto *accObjectList = 3301 std::get_if<Fortran::parser::AccObjectList>(&(*accSelfClause).u); 3302 assert(accObjectList && "expect AccObjectList"); 3303 genDataOperandOperations<mlir::acc::GetDevicePtrOp>( 3304 *accObjectList, converter, semanticsContext, stmtCtx, 3305 updateHostOperands, mlir::acc::DataClause::acc_update_self, false, 3306 /*implicit=*/false, asyncOperands, asyncOperandsDeviceTypes, 3307 asyncOnlyDeviceTypes); 3308 } 3309 } 3310 3311 dataClauseOperands.append(updateHostOperands); 3312 3313 builder.create<mlir::acc::UpdateOp>( 3314 currentLocation, ifCond, asyncOperands, 3315 getArrayAttr(builder, asyncOperandsDeviceTypes), 3316 getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands, 3317 getDenseI32ArrayAttr(builder, waitOperandsSegments), 3318 getArrayAttr(builder, waitOperandsDeviceTypes), 3319 getBoolArrayAttr(builder, hasWaitDevnums), 3320 getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands, 3321 ifPresent); 3322 3323 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>( 3324 builder, updateHostOperands, /*structured=*/false); 3325 } 3326 3327 static void 3328 genACC(Fortran::lower::AbstractConverter &converter, 3329 Fortran::semantics::SemanticsContext &semanticsContext, 3330 const Fortran::parser::OpenACCStandaloneConstruct &standaloneConstruct) { 3331 const auto &standaloneDirective = 3332 std::get<Fortran::parser::AccStandaloneDirective>(standaloneConstruct.t); 3333 const auto &accClauseList = 3334 std::get<Fortran::parser::AccClauseList>(standaloneConstruct.t); 3335 3336 mlir::Location currentLocation = 3337 converter.genLocation(standaloneDirective.source); 3338 
Fortran::lower::StatementContext stmtCtx;

  if (standaloneDirective.v == llvm::acc::Directive::ACCD_enter_data) {
    genACCEnterDataOp(converter, currentLocation, semanticsContext, stmtCtx,
                      accClauseList);
  } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_exit_data) {
    genACCExitDataOp(converter, currentLocation, semanticsContext, stmtCtx,
                     accClauseList);
  } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_init) {
    genACCInitShutdownOp<mlir::acc::InitOp>(converter, currentLocation,
                                            accClauseList);
  } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_shutdown) {
    genACCInitShutdownOp<mlir::acc::ShutdownOp>(converter, currentLocation,
                                                accClauseList);
  } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) {
    genACCSetOp(converter, currentLocation, accClauseList);
  } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) {
    genACCUpdateOp(converter, currentLocation, semanticsContext, stmtCtx,
                   accClauseList);
  }
}

/// Lower the standalone `wait` construct to an acc::WaitOp.
///
/// The optional wait argument supplies the list of wait operands and an
/// optional devnum. The if and async clauses come from the clause list.
/// Operand segment order is: wait operands, async, devnum, if condition.
static void genACC(Fortran::lower::AbstractConverter &converter,
                   const Fortran::parser::OpenACCWaitConstruct &waitConstruct) {

  const auto &waitArgument =
      std::get<std::optional<Fortran::parser::AccWaitArgument>>(
          waitConstruct.t);
  const auto &accClauseList =
      std::get<Fortran::parser::AccClauseList>(waitConstruct.t);

  mlir::Value ifCond, waitDevnum, async;
  llvm::SmallVector<mlir::Value> waitOperands;

  // Async clause have optional values but can be present with
  // no value as well. When there is no value, the op has an attribute to
  // represent the clause.
  bool addAsyncAttr = false;

  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
  mlir::Location currentLocation = converter.genLocation(waitConstruct.source);
  Fortran::lower::StatementContext stmtCtx;

  if (waitArgument) { // wait has a value.
    const Fortran::parser::AccWaitArgument &waitArg = *waitArgument;
    const auto &waitList =
        std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
    // Each expression in the wait list becomes one wait operand.
    for (const Fortran::parser::ScalarIntExpr &value : waitList) {
      mlir::Value v = fir::getBase(
          converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx));
      waitOperands.push_back(v);
    }

    const auto &waitDevnumValue =
        std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
    if (waitDevnumValue)
      waitDevnum = fir::getBase(converter.genExprValue(
          *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
  }

  // Lower clauses values mapped to operands.
  // Keep track of each group of operands separately as clauses can appear
  // more than once.
  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
    mlir::Location clauseLocation = converter.genLocation(clause.source);
    if (const auto *ifClause =
            std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
      genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
    } else if (const auto *asyncClause =
                   std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
    }
  }

  // Prepare the operand segment size attribute and the operands value range.
  llvm::SmallVector<mlir::Value> operands;
  llvm::SmallVector<int32_t> operandSegments;
  addOperands(operands, operandSegments, waitOperands);
  addOperand(operands, operandSegments, async);
  addOperand(operands, operandSegments, waitDevnum);
  addOperand(operands, operandSegments, ifCond);

  mlir::acc::WaitOp waitOp = createSimpleOp<mlir::acc::WaitOp>(
      firOpBuilder, currentLocation, operands, operandSegments);

  if (addAsyncAttr)
    waitOp.setAsyncAttr(firOpBuilder.getUnitAttr());
}

/// Create a module-level declare op of type `GlobalOp` (a global constructor
/// or destructor) named `declareGlobalName` for the FIR global `globalOp`.
///
/// The op's single block takes the address of the global, creates an
/// `EntryOp` data entry on it (unstructured, with the given clause and
/// `implicit` flag), and records it with `DeclareOp`. DeclareEnterOp returns
/// a declare token; other declare ops get a null token operand. For
/// destructors an `ExitOp` is also emitted. The block is terminated with
/// acc::TerminatorOp, and `modBuilder` is left positioned after the new op.
template <typename GlobalOp, typename EntryOp, typename DeclareOp,
          typename ExitOp>
static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
                                  fir::FirOpBuilder &builder,
                                  mlir::Location loc, fir::GlobalOp globalOp,
                                  mlir::acc::DataClause clause,
                                  const std::string &declareGlobalName,
                                  bool implicit, std::stringstream &asFortran) {
  GlobalOp declareGlobalOp =
      modBuilder.create<GlobalOp>(loc, declareGlobalName);
  builder.createBlock(&declareGlobalOp.getRegion(),
                      declareGlobalOp.getRegion().end(), {}, {});
  builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back());

  fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
      loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
  addDeclareAttr(builder, addrOp, clause);

  llvm::SmallVector<mlir::Value> bounds;
  EntryOp entryOp = createDataEntryOp<EntryOp>(
      builder, loc, addrOp.getResTy(), asFortran, bounds,
      /*structured=*/false, implicit, clause, addrOp.getResTy().getType(),
      /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
  if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>)
    builder.create<DeclareOp>(
        loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
        mlir::ValueRange(entryOp.getAccPtr()));
  else
    builder.create<DeclareOp>(loc, mlir::Value{},
                              mlir::ValueRange(entryOp.getAccPtr()));
if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) {
    builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
                           entryOp.getBounds(), entryOp.getAsyncOperands(),
                           entryOp.getAsyncOperandsDeviceTypeAttr(),
                           entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
                           /*structured=*/false, /*implicit=*/false,
                           builder.getStringAttr(*entryOp.getName()));
  }
  builder.create<mlir::acc::TerminatorOp>(loc);
  modBuilder.setInsertionPointAfter(declareGlobalOp);
}

/// Create the post-allocation function for a descriptor (box) global. The
/// function is named `<global symbol name>` + declarePostAllocSuffix.
///
/// The function first refreshes the descriptor on the device (an
/// UpdateDeviceOp wrapped in an acc::UpdateOp). It then takes the data
/// address out of the loaded descriptor, maps it with `EntryOp`, and
/// registers the mapping with acc::DeclareEnterOp. `modBuilder` is left
/// positioned after the new function.
template <typename EntryOp>
static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
                                   fir::FirOpBuilder &builder,
                                   mlir::Location loc, fir::GlobalOp &globalOp,
                                   mlir::acc::DataClause clause) {
  std::stringstream registerFuncName;
  registerFuncName << globalOp.getSymName().str()
                   << Fortran::lower::declarePostAllocSuffix.str();
  auto registerFuncOp =
      createDeclareFunc(modBuilder, builder, loc, registerFuncName.str());

  fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
      loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());

  // Fortran-facing names: the data uses the demangled symbol name; the
  // descriptor uses the same name plus accFirDescriptorPostfix.
  std::stringstream asFortran;
  asFortran << Fortran::lower::mangle::demangleName(globalOp.getSymName());
  std::stringstream asFortranDesc;
  asFortranDesc << asFortran.str() << accFirDescriptorPostfix.str();
  llvm::SmallVector<mlir::Value> bounds;

  // Updating descriptor must occur before the mapping of the data so that
  // attached data pointer is not overwritten.
  mlir::acc::UpdateDeviceOp updateDeviceOp =
      createDataEntryOp<mlir::acc::UpdateDeviceOp>(
          builder, loc, addrOp, asFortranDesc, bounds,
          /*structured=*/false, /*implicit=*/true,
          mlir::acc::DataClause::acc_update_device, addrOp.getType(),
          /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
  // NOTE(review): four operand segments, only the last one populated;
  // presumably the last segment is the UpdateOp's data operands. Confirm
  // against the acc.update operand order.
  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
  llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
  createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);

  auto loadOp = builder.create<fir::LoadOp>(loc, addrOp.getResult());
  fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
  addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
  EntryOp entryOp = createDataEntryOp<EntryOp>(
      builder, loc, boxAddrOp.getResult(), asFortran, bounds,
      /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(),
      /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
  builder.create<mlir::acc::DeclareEnterOp>(
      loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
      mlir::ValueRange(entryOp.getAccPtr()));

  modBuilder.setInsertionPointAfter(registerFuncOp);
}

/// Action to be performed on deallocation are split in two distinct functions.
/// - Pre deallocation function includes all the action to be performed before
///   the actual deallocation is done on the host side.
/// - Post deallocation function includes update to the descriptor.
template <typename ExitOp>
static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
                                     fir::FirOpBuilder &builder,
                                     mlir::Location loc,
                                     fir::GlobalOp &globalOp,
                                     mlir::acc::DataClause clause) {

  // Generate the pre dealloc function.
  std::stringstream preDeallocFuncName;
  preDeallocFuncName << globalOp.getSymName().str()
                     << Fortran::lower::declarePreDeallocSuffix.str();
  auto preDeallocOp =
      createDeclareFunc(modBuilder, builder, loc, preDeallocFuncName.str());
  fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
      loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
  auto loadOp = builder.create<fir::LoadOp>(loc, addrOp.getResult());
  fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
  addDeclareAttr(builder, boxAddrOp.getOperation(), clause);

  std::stringstream asFortran;
  asFortran << Fortran::lower::mangle::demangleName(globalOp.getSymName());
  llvm::SmallVector<mlir::Value> bounds;
  mlir::acc::GetDevicePtrOp entryOp =
      createDataEntryOp<mlir::acc::GetDevicePtrOp>(
          builder, loc, boxAddrOp.getResult(), asFortran, bounds,
          /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(),
          /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});

  builder.create<mlir::acc::DeclareExitOp>(
      loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));

  // CopyoutOp and UpdateHostOp take the host variable pointer in addition to
  // the device pointer; the other exit ops do not.
  if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
                std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
    builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
                           entryOp.getVarPtr(), entryOp.getBounds(),
                           entryOp.getAsyncOperands(),
                           entryOp.getAsyncOperandsDeviceTypeAttr(),
                           entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
                           /*structured=*/false, /*implicit=*/false,
                           builder.getStringAttr(*entryOp.getName()));
  else
    builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
                           entryOp.getBounds(), entryOp.getAsyncOperands(),
                           entryOp.getAsyncOperandsDeviceTypeAttr(),
                           entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
                           /*structured=*/false, /*implicit=*/false,
                           builder.getStringAttr(*entryOp.getName()));
3567 // Generate the post dealloc function. 3568 modBuilder.setInsertionPointAfter(preDeallocOp); 3569 std::stringstream postDeallocFuncName; 3570 postDeallocFuncName << globalOp.getSymName().str() 3571 << Fortran::lower::declarePostDeallocSuffix.str(); 3572 auto postDeallocOp = 3573 createDeclareFunc(modBuilder, builder, loc, postDeallocFuncName.str()); 3574 3575 addrOp = builder.create<fir::AddrOfOp>( 3576 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol()); 3577 asFortran << accFirDescriptorPostfix.str(); 3578 mlir::acc::UpdateDeviceOp updateDeviceOp = 3579 createDataEntryOp<mlir::acc::UpdateDeviceOp>( 3580 builder, loc, addrOp, asFortran, bounds, 3581 /*structured=*/false, /*implicit=*/true, 3582 mlir::acc::DataClause::acc_update_device, addrOp.getType(), 3583 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); 3584 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1}; 3585 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()}; 3586 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments); 3587 modBuilder.setInsertionPointAfter(postDeallocOp); 3588 } 3589 3590 template <typename EntryOp, typename ExitOp> 3591 static void genGlobalCtors(Fortran::lower::AbstractConverter &converter, 3592 mlir::OpBuilder &modBuilder, 3593 const Fortran::parser::AccObjectList &accObjectList, 3594 mlir::acc::DataClause clause) { 3595 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 3596 for (const auto &accObject : accObjectList.v) { 3597 mlir::Location operandLocation = genOperandLocation(converter, accObject); 3598 Fortran::common::visit( 3599 Fortran::common::visitors{ 3600 [&](const Fortran::parser::Designator &designator) { 3601 if (const auto *name = 3602 Fortran::semantics::getDesignatorNameIfDataRef( 3603 designator)) { 3604 std::string globalName = converter.mangleName(*name->symbol); 3605 fir::GlobalOp globalOp = builder.getNamedGlobal(globalName); 3606 std::stringstream 
declareGlobalCtorName; 3607 declareGlobalCtorName << globalName << "_acc_ctor"; 3608 std::stringstream declareGlobalDtorName; 3609 declareGlobalDtorName << globalName << "_acc_dtor"; 3610 std::stringstream asFortran; 3611 asFortran << name->symbol->name().ToString(); 3612 3613 if (builder.getModule() 3614 .lookupSymbol<mlir::acc::GlobalConstructorOp>( 3615 declareGlobalCtorName.str())) 3616 return; 3617 3618 if (!globalOp) { 3619 if (Fortran::semantics::FindEquivalenceSet(*name->symbol)) { 3620 for (Fortran::semantics::EquivalenceObject eqObj : 3621 *Fortran::semantics::FindEquivalenceSet( 3622 *name->symbol)) { 3623 std::string eqName = converter.mangleName(eqObj.symbol); 3624 globalOp = builder.getNamedGlobal(eqName); 3625 if (globalOp) 3626 break; 3627 } 3628 3629 if (!globalOp) 3630 llvm::report_fatal_error( 3631 "could not retrieve global symbol"); 3632 } else { 3633 llvm::report_fatal_error( 3634 "could not retrieve global symbol"); 3635 } 3636 } 3637 3638 addDeclareAttr(builder, globalOp.getOperation(), clause); 3639 auto crtPos = builder.saveInsertionPoint(); 3640 modBuilder.setInsertionPointAfter(globalOp); 3641 if (mlir::isa<fir::BaseBoxType>( 3642 fir::unwrapRefType(globalOp.getType()))) { 3643 createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, 3644 mlir::acc::CopyinOp, 3645 mlir::acc::DeclareEnterOp, ExitOp>( 3646 modBuilder, builder, operandLocation, globalOp, clause, 3647 declareGlobalCtorName.str(), /*implicit=*/true, 3648 asFortran); 3649 createDeclareAllocFunc<EntryOp>( 3650 modBuilder, builder, operandLocation, globalOp, clause); 3651 if constexpr (!std::is_same_v<EntryOp, ExitOp>) 3652 createDeclareDeallocFunc<ExitOp>( 3653 modBuilder, builder, operandLocation, globalOp, clause); 3654 } else { 3655 createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, EntryOp, 3656 mlir::acc::DeclareEnterOp, ExitOp>( 3657 modBuilder, builder, operandLocation, globalOp, clause, 3658 declareGlobalCtorName.str(), /*implicit=*/false, 3659 asFortran); 3660 } 3661 
if constexpr (!std::is_same_v<EntryOp, ExitOp>) { 3662 createDeclareGlobalOp<mlir::acc::GlobalDestructorOp, 3663 mlir::acc::GetDevicePtrOp, 3664 mlir::acc::DeclareExitOp, ExitOp>( 3665 modBuilder, builder, operandLocation, globalOp, clause, 3666 declareGlobalDtorName.str(), /*implicit=*/false, 3667 asFortran); 3668 } 3669 builder.restoreInsertionPoint(crtPos); 3670 } 3671 }, 3672 [&](const Fortran::parser::Name &name) { 3673 TODO(operandLocation, "OpenACC Global Ctor from parser::Name"); 3674 }}, 3675 accObject.u); 3676 } 3677 } 3678 3679 template <typename Clause, typename EntryOp, typename ExitOp> 3680 static void 3681 genGlobalCtorsWithModifier(Fortran::lower::AbstractConverter &converter, 3682 mlir::OpBuilder &modBuilder, const Clause *x, 3683 Fortran::parser::AccDataModifier::Modifier mod, 3684 const mlir::acc::DataClause clause, 3685 const mlir::acc::DataClause clauseWithModifier) { 3686 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v; 3687 const auto &accObjectList = 3688 std::get<Fortran::parser::AccObjectList>(listWithModifier.t); 3689 const auto &modifier = 3690 std::get<std::optional<Fortran::parser::AccDataModifier>>( 3691 listWithModifier.t); 3692 mlir::acc::DataClause dataClause = 3693 (modifier && (*modifier).v == mod) ? 
clauseWithModifier : clause; 3694 genGlobalCtors<EntryOp, ExitOp>(converter, modBuilder, accObjectList, 3695 dataClause); 3696 } 3697 3698 static void 3699 genDeclareInFunction(Fortran::lower::AbstractConverter &converter, 3700 Fortran::semantics::SemanticsContext &semanticsContext, 3701 Fortran::lower::StatementContext &openAccCtx, 3702 mlir::Location loc, 3703 const Fortran::parser::AccClauseList &accClauseList) { 3704 llvm::SmallVector<mlir::Value> dataClauseOperands, copyEntryOperands, 3705 copyinEntryOperands, createEntryOperands, copyoutEntryOperands, 3706 deviceResidentEntryOperands; 3707 Fortran::lower::StatementContext stmtCtx; 3708 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 3709 3710 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 3711 if (const auto *copyClause = 3712 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) { 3713 auto crtDataStart = dataClauseOperands.size(); 3714 genDeclareDataOperandOperations<mlir::acc::CopyinOp, 3715 mlir::acc::CopyoutOp>( 3716 copyClause->v, converter, semanticsContext, stmtCtx, 3717 dataClauseOperands, mlir::acc::DataClause::acc_copy, 3718 /*structured=*/true, /*implicit=*/false); 3719 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 3720 dataClauseOperands.end()); 3721 } else if (const auto *createClause = 3722 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) { 3723 const Fortran::parser::AccObjectListWithModifier &listWithModifier = 3724 createClause->v; 3725 const auto &accObjectList = 3726 std::get<Fortran::parser::AccObjectList>(listWithModifier.t); 3727 auto crtDataStart = dataClauseOperands.size(); 3728 genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>( 3729 accObjectList, converter, semanticsContext, stmtCtx, 3730 dataClauseOperands, mlir::acc::DataClause::acc_create, 3731 /*structured=*/true, /*implicit=*/false); 3732 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 3733 dataClauseOperands.end()); 
3734 } else if (const auto *presentClause = 3735 std::get_if<Fortran::parser::AccClause::Present>( 3736 &clause.u)) { 3737 genDeclareDataOperandOperations<mlir::acc::PresentOp, 3738 mlir::acc::PresentOp>( 3739 presentClause->v, converter, semanticsContext, stmtCtx, 3740 dataClauseOperands, mlir::acc::DataClause::acc_present, 3741 /*structured=*/true, /*implicit=*/false); 3742 } else if (const auto *copyinClause = 3743 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) { 3744 auto crtDataStart = dataClauseOperands.size(); 3745 genDeclareDataOperandOperationsWithModifier<mlir::acc::CopyinOp, 3746 mlir::acc::DeleteOp>( 3747 copyinClause, converter, semanticsContext, stmtCtx, 3748 Fortran::parser::AccDataModifier::Modifier::ReadOnly, 3749 dataClauseOperands, mlir::acc::DataClause::acc_copyin, 3750 mlir::acc::DataClause::acc_copyin_readonly); 3751 copyinEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 3752 dataClauseOperands.end()); 3753 } else if (const auto *copyoutClause = 3754 std::get_if<Fortran::parser::AccClause::Copyout>( 3755 &clause.u)) { 3756 const Fortran::parser::AccObjectListWithModifier &listWithModifier = 3757 copyoutClause->v; 3758 const auto &accObjectList = 3759 std::get<Fortran::parser::AccObjectList>(listWithModifier.t); 3760 auto crtDataStart = dataClauseOperands.size(); 3761 genDeclareDataOperandOperations<mlir::acc::CreateOp, 3762 mlir::acc::CopyoutOp>( 3763 accObjectList, converter, semanticsContext, stmtCtx, 3764 dataClauseOperands, mlir::acc::DataClause::acc_copyout, 3765 /*structured=*/true, /*implicit=*/false); 3766 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart, 3767 dataClauseOperands.end()); 3768 } else if (const auto *devicePtrClause = 3769 std::get_if<Fortran::parser::AccClause::Deviceptr>( 3770 &clause.u)) { 3771 genDeclareDataOperandOperations<mlir::acc::DevicePtrOp, 3772 mlir::acc::DevicePtrOp>( 3773 devicePtrClause->v, converter, semanticsContext, stmtCtx, 3774 dataClauseOperands, 
mlir::acc::DataClause::acc_deviceptr, 3775 /*structured=*/true, /*implicit=*/false); 3776 } else if (const auto *linkClause = 3777 std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) { 3778 genDeclareDataOperandOperations<mlir::acc::DeclareLinkOp, 3779 mlir::acc::DeclareLinkOp>( 3780 linkClause->v, converter, semanticsContext, stmtCtx, 3781 dataClauseOperands, mlir::acc::DataClause::acc_declare_link, 3782 /*structured=*/true, /*implicit=*/false); 3783 } else if (const auto *deviceResidentClause = 3784 std::get_if<Fortran::parser::AccClause::DeviceResident>( 3785 &clause.u)) { 3786 auto crtDataStart = dataClauseOperands.size(); 3787 genDeclareDataOperandOperations<mlir::acc::DeclareDeviceResidentOp, 3788 mlir::acc::DeleteOp>( 3789 deviceResidentClause->v, converter, semanticsContext, stmtCtx, 3790 dataClauseOperands, 3791 mlir::acc::DataClause::acc_declare_device_resident, 3792 /*structured=*/true, /*implicit=*/false); 3793 deviceResidentEntryOperands.append( 3794 dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end()); 3795 } else { 3796 mlir::Location clauseLocation = converter.genLocation(clause.source); 3797 TODO(clauseLocation, "clause on declare directive"); 3798 } 3799 } 3800 3801 mlir::func::FuncOp funcOp = builder.getFunction(); 3802 auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>(); 3803 mlir::Value declareToken; 3804 if (ops.empty()) { 3805 declareToken = builder.create<mlir::acc::DeclareEnterOp>( 3806 loc, mlir::acc::DeclareTokenType::get(builder.getContext()), 3807 dataClauseOperands); 3808 } else { 3809 auto declareOp = *ops.begin(); 3810 auto newDeclareOp = builder.create<mlir::acc::DeclareEnterOp>( 3811 loc, mlir::acc::DeclareTokenType::get(builder.getContext()), 3812 declareOp.getDataClauseOperands()); 3813 newDeclareOp.getDataClauseOperandsMutable().append(dataClauseOperands); 3814 declareToken = newDeclareOp.getToken(); 3815 declareOp.erase(); 3816 } 3817 3818 openAccCtx.attachCleanup([&builder, loc, createEntryOperands, 
3819 copyEntryOperands, copyinEntryOperands, 3820 copyoutEntryOperands, deviceResidentEntryOperands, 3821 declareToken]() { 3822 llvm::SmallVector<mlir::Value> operands; 3823 operands.append(createEntryOperands); 3824 operands.append(deviceResidentEntryOperands); 3825 operands.append(copyEntryOperands); 3826 operands.append(copyinEntryOperands); 3827 operands.append(copyoutEntryOperands); 3828 3829 mlir::func::FuncOp funcOp = builder.getFunction(); 3830 auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>(); 3831 if (ops.empty()) { 3832 builder.create<mlir::acc::DeclareExitOp>(loc, declareToken, operands); 3833 } else { 3834 auto declareOp = *ops.begin(); 3835 declareOp.getDataClauseOperandsMutable().append(operands); 3836 } 3837 3838 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>( 3839 builder, createEntryOperands, /*structured=*/true); 3840 genDataExitOperations<mlir::acc::DeclareDeviceResidentOp, 3841 mlir::acc::DeleteOp>( 3842 builder, deviceResidentEntryOperands, /*structured=*/true); 3843 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>( 3844 builder, copyEntryOperands, /*structured=*/true); 3845 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::DeleteOp>( 3846 builder, copyinEntryOperands, /*structured=*/true); 3847 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>( 3848 builder, copyoutEntryOperands, /*structured=*/true); 3849 }); 3850 } 3851 3852 static void 3853 genDeclareInModule(Fortran::lower::AbstractConverter &converter, 3854 mlir::ModuleOp moduleOp, 3855 const Fortran::parser::AccClauseList &accClauseList) { 3856 mlir::OpBuilder modBuilder(moduleOp.getBodyRegion()); 3857 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 3858 if (const auto *createClause = 3859 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) { 3860 const Fortran::parser::AccObjectListWithModifier &listWithModifier = 3861 createClause->v; 3862 const auto &accObjectList = 3863 
std::get<Fortran::parser::AccObjectList>(listWithModifier.t); 3864 genGlobalCtors<mlir::acc::CreateOp, mlir::acc::DeleteOp>( 3865 converter, modBuilder, accObjectList, 3866 mlir::acc::DataClause::acc_create); 3867 } else if (const auto *copyinClause = 3868 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) { 3869 genGlobalCtorsWithModifier<Fortran::parser::AccClause::Copyin, 3870 mlir::acc::CopyinOp, mlir::acc::DeleteOp>( 3871 converter, modBuilder, copyinClause, 3872 Fortran::parser::AccDataModifier::Modifier::ReadOnly, 3873 mlir::acc::DataClause::acc_copyin, 3874 mlir::acc::DataClause::acc_copyin_readonly); 3875 } else if (const auto *deviceResidentClause = 3876 std::get_if<Fortran::parser::AccClause::DeviceResident>( 3877 &clause.u)) { 3878 genGlobalCtors<mlir::acc::DeclareDeviceResidentOp, mlir::acc::DeleteOp>( 3879 converter, modBuilder, deviceResidentClause->v, 3880 mlir::acc::DataClause::acc_declare_device_resident); 3881 } else if (const auto *linkClause = 3882 std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) { 3883 genGlobalCtors<mlir::acc::DeclareLinkOp, mlir::acc::DeclareLinkOp>( 3884 converter, modBuilder, linkClause->v, 3885 mlir::acc::DataClause::acc_declare_link); 3886 } else { 3887 llvm::report_fatal_error("unsupported clause on DECLARE directive"); 3888 } 3889 } 3890 } 3891 3892 static void genACC(Fortran::lower::AbstractConverter &converter, 3893 Fortran::semantics::SemanticsContext &semanticsContext, 3894 Fortran::lower::StatementContext &openAccCtx, 3895 const Fortran::parser::OpenACCStandaloneDeclarativeConstruct 3896 &declareConstruct) { 3897 3898 const auto &declarativeDir = 3899 std::get<Fortran::parser::AccDeclarativeDirective>(declareConstruct.t); 3900 mlir::Location directiveLocation = 3901 converter.genLocation(declarativeDir.source); 3902 const auto &accClauseList = 3903 std::get<Fortran::parser::AccClauseList>(declareConstruct.t); 3904 3905 if (declarativeDir.v == llvm::acc::Directive::ACCD_declare) { 3906 
fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 3907 auto moduleOp = 3908 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); 3909 auto funcOp = 3910 builder.getBlock()->getParent()->getParentOfType<mlir::func::FuncOp>(); 3911 if (funcOp) 3912 genDeclareInFunction(converter, semanticsContext, openAccCtx, 3913 directiveLocation, accClauseList); 3914 else if (moduleOp) 3915 genDeclareInModule(converter, moduleOp, accClauseList); 3916 return; 3917 } 3918 llvm_unreachable("unsupported declarative directive"); 3919 } 3920 3921 static bool hasDeviceType(llvm::SmallVector<mlir::Attribute> &arrayAttr, 3922 mlir::acc::DeviceType deviceType) { 3923 for (auto attr : arrayAttr) { 3924 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); 3925 if (deviceTypeAttr.getValue() == deviceType) 3926 return true; 3927 } 3928 return false; 3929 } 3930 3931 template <typename RetTy, typename AttrTy> 3932 static std::optional<RetTy> 3933 getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes, 3934 llvm::SmallVector<mlir::Attribute> &deviceTypes, 3935 mlir::acc::DeviceType deviceType) { 3936 assert(attributes.size() == deviceTypes.size() && 3937 "expect same number of attributes"); 3938 for (auto it : llvm::enumerate(deviceTypes)) { 3939 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(it.value()); 3940 if (deviceTypeAttr.getValue() == deviceType) { 3941 if constexpr (std::is_same_v<mlir::StringAttr, AttrTy>) { 3942 auto strAttr = mlir::dyn_cast<AttrTy>(attributes[it.index()]); 3943 return strAttr.getValue(); 3944 } else if constexpr (std::is_same_v<mlir::IntegerAttr, AttrTy>) { 3945 auto intAttr = 3946 mlir::dyn_cast<mlir::IntegerAttr>(attributes[it.index()]); 3947 return intAttr.getInt(); 3948 } 3949 } 3950 } 3951 return std::nullopt; 3952 } 3953 3954 static bool compareDeviceTypeInfo( 3955 mlir::acc::RoutineOp op, 3956 llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr, 3957 
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr, 3958 llvm::SmallVector<mlir::Attribute> &gangArrayAttr, 3959 llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr, 3960 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr, 3961 llvm::SmallVector<mlir::Attribute> &seqArrayAttr, 3962 llvm::SmallVector<mlir::Attribute> &workerArrayAttr, 3963 llvm::SmallVector<mlir::Attribute> &vectorArrayAttr) { 3964 for (uint32_t dtypeInt = 0; 3965 dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) { 3966 auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt); 3967 if (op.getBindNameValue(dtype) != 3968 getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>( 3969 bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype)) 3970 return false; 3971 if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype)) 3972 return false; 3973 if (op.getGangDimValue(dtype) != 3974 getAttributeValueByDeviceType<int64_t, mlir::IntegerAttr>( 3975 gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype)) 3976 return false; 3977 if (op.hasSeq(dtype) != hasDeviceType(seqArrayAttr, dtype)) 3978 return false; 3979 if (op.hasWorker(dtype) != hasDeviceType(workerArrayAttr, dtype)) 3980 return false; 3981 if (op.hasVector(dtype) != hasDeviceType(vectorArrayAttr, dtype)) 3982 return false; 3983 } 3984 return true; 3985 } 3986 3987 static void attachRoutineInfo(mlir::func::FuncOp func, 3988 mlir::SymbolRefAttr routineAttr) { 3989 llvm::SmallVector<mlir::SymbolRefAttr> routines; 3990 if (func.getOperation()->hasAttr(mlir::acc::getRoutineInfoAttrName())) { 3991 auto routineInfo = 3992 func.getOperation()->getAttrOfType<mlir::acc::RoutineInfoAttr>( 3993 mlir::acc::getRoutineInfoAttrName()); 3994 routines.append(routineInfo.getAccRoutines().begin(), 3995 routineInfo.getAccRoutines().end()); 3996 } 3997 routines.push_back(routineAttr); 3998 func.getOperation()->setAttr( 3999 mlir::acc::getRoutineInfoAttrName(), 4000 mlir::acc::RoutineInfoAttr::get(func.getContext(), 
routines)); 4001 } 4002 4003 void Fortran::lower::genOpenACCRoutineConstruct( 4004 Fortran::lower::AbstractConverter &converter, 4005 Fortran::semantics::SemanticsContext &semanticsContext, mlir::ModuleOp mod, 4006 const Fortran::parser::OpenACCRoutineConstruct &routineConstruct, 4007 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) { 4008 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 4009 mlir::Location loc = converter.genLocation(routineConstruct.source); 4010 std::optional<Fortran::parser::Name> name = 4011 std::get<std::optional<Fortran::parser::Name>>(routineConstruct.t); 4012 const auto &clauses = 4013 std::get<Fortran::parser::AccClauseList>(routineConstruct.t); 4014 mlir::func::FuncOp funcOp; 4015 std::string funcName; 4016 if (name) { 4017 funcName = converter.mangleName(*name->symbol); 4018 funcOp = 4019 builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName); 4020 } else { 4021 Fortran::semantics::Scope &scope = 4022 semanticsContext.FindScope(routineConstruct.source); 4023 const Fortran::semantics::Scope &progUnit{GetProgramUnitContaining(scope)}; 4024 const auto *subpDetails{ 4025 progUnit.symbol() 4026 ? progUnit.symbol() 4027 ->detailsIf<Fortran::semantics::SubprogramDetails>() 4028 : nullptr}; 4029 if (subpDetails && subpDetails->isInterface()) { 4030 funcName = converter.mangleName(*progUnit.symbol()); 4031 funcOp = 4032 builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName); 4033 } else { 4034 funcOp = builder.getFunction(); 4035 funcName = funcOp.getName(); 4036 } 4037 } 4038 bool hasNohost = false; 4039 4040 llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes, 4041 workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes, 4042 gangDimDeviceTypes, gangDimValues; 4043 4044 // device_type attribute is set to `none` until a device_type clause is 4045 // encountered. 
4046 llvm::SmallVector<mlir::Attribute> crtDeviceTypes; 4047 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( 4048 builder.getContext(), mlir::acc::DeviceType::None)); 4049 4050 for (const Fortran::parser::AccClause &clause : clauses.v) { 4051 if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) { 4052 for (auto crtDeviceTypeAttr : crtDeviceTypes) 4053 seqDeviceTypes.push_back(crtDeviceTypeAttr); 4054 } else if (const auto *gangClause = 4055 std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) { 4056 if (gangClause->v) { 4057 const Fortran::parser::AccGangArgList &x = *gangClause->v; 4058 for (const Fortran::parser::AccGangArg &gangArg : x.v) { 4059 if (const auto *dim = 4060 std::get_if<Fortran::parser::AccGangArg::Dim>(&gangArg.u)) { 4061 const std::optional<int64_t> dimValue = Fortran::evaluate::ToInt64( 4062 *Fortran::semantics::GetExpr(dim->v)); 4063 if (!dimValue) 4064 mlir::emitError(loc, 4065 "dim value must be a constant positive integer"); 4066 mlir::Attribute gangDimAttr = 4067 builder.getIntegerAttr(builder.getI64Type(), *dimValue); 4068 for (auto crtDeviceTypeAttr : crtDeviceTypes) { 4069 gangDimValues.push_back(gangDimAttr); 4070 gangDimDeviceTypes.push_back(crtDeviceTypeAttr); 4071 } 4072 } 4073 } 4074 } else { 4075 for (auto crtDeviceTypeAttr : crtDeviceTypes) 4076 gangDeviceTypes.push_back(crtDeviceTypeAttr); 4077 } 4078 } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) { 4079 for (auto crtDeviceTypeAttr : crtDeviceTypes) 4080 vectorDeviceTypes.push_back(crtDeviceTypeAttr); 4081 } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) { 4082 for (auto crtDeviceTypeAttr : crtDeviceTypes) 4083 workerDeviceTypes.push_back(crtDeviceTypeAttr); 4084 } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) { 4085 hasNohost = true; 4086 } else if (const auto *bindClause = 4087 std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) { 4088 if (const auto *name = 4089 
std::get_if<Fortran::parser::Name>(&bindClause->v.u)) { 4090 mlir::Attribute bindNameAttr = 4091 builder.getStringAttr(converter.mangleName(*name->symbol)); 4092 for (auto crtDeviceTypeAttr : crtDeviceTypes) { 4093 bindNames.push_back(bindNameAttr); 4094 bindNameDeviceTypes.push_back(crtDeviceTypeAttr); 4095 } 4096 } else if (const auto charExpr = 4097 std::get_if<Fortran::parser::ScalarDefaultCharExpr>( 4098 &bindClause->v.u)) { 4099 const std::optional<std::string> name = 4100 Fortran::semantics::GetConstExpr<std::string>(semanticsContext, 4101 *charExpr); 4102 if (!name) 4103 mlir::emitError(loc, "Could not retrieve the bind name"); 4104 4105 mlir::Attribute bindNameAttr = builder.getStringAttr(*name); 4106 for (auto crtDeviceTypeAttr : crtDeviceTypes) { 4107 bindNames.push_back(bindNameAttr); 4108 bindNameDeviceTypes.push_back(crtDeviceTypeAttr); 4109 } 4110 } 4111 } else if (const auto *deviceTypeClause = 4112 std::get_if<Fortran::parser::AccClause::DeviceType>( 4113 &clause.u)) { 4114 crtDeviceTypes.clear(); 4115 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes); 4116 } 4117 } 4118 4119 mlir::OpBuilder modBuilder(mod.getBodyRegion()); 4120 std::stringstream routineOpName; 4121 routineOpName << accRoutinePrefix.str() << routineCounter++; 4122 4123 for (auto routineOp : mod.getOps<mlir::acc::RoutineOp>()) { 4124 if (routineOp.getFuncName().str().compare(funcName) == 0) { 4125 // If the routine is already specified with the same clauses, just skip 4126 // the operation creation. 4127 if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes, 4128 gangDeviceTypes, gangDimValues, 4129 gangDimDeviceTypes, seqDeviceTypes, 4130 workerDeviceTypes, vectorDeviceTypes) && 4131 routineOp.getNohost() == hasNohost) 4132 return; 4133 mlir::emitError(loc, "Routine already specified with different clauses"); 4134 } 4135 } 4136 4137 modBuilder.create<mlir::acc::RoutineOp>( 4138 loc, routineOpName.str(), funcName, 4139 bindNames.empty() ? 
nullptr : builder.getArrayAttr(bindNames), 4140 bindNameDeviceTypes.empty() ? nullptr 4141 : builder.getArrayAttr(bindNameDeviceTypes), 4142 workerDeviceTypes.empty() ? nullptr 4143 : builder.getArrayAttr(workerDeviceTypes), 4144 vectorDeviceTypes.empty() ? nullptr 4145 : builder.getArrayAttr(vectorDeviceTypes), 4146 seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes), 4147 hasNohost, /*implicit=*/false, 4148 gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes), 4149 gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues), 4150 gangDimDeviceTypes.empty() ? nullptr 4151 : builder.getArrayAttr(gangDimDeviceTypes)); 4152 4153 if (funcOp) 4154 attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str())); 4155 else 4156 // FuncOp is not lowered yet. Keep the information so the routine info 4157 // can be attached later to the funcOp. 4158 accRoutineInfos.push_back(std::make_pair( 4159 funcName, builder.getSymbolRefAttr(routineOpName.str()))); 4160 } 4161 4162 void Fortran::lower::finalizeOpenACCRoutineAttachment( 4163 mlir::ModuleOp mod, 4164 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) { 4165 for (auto &mapping : accRoutineInfos) { 4166 mlir::func::FuncOp funcOp = 4167 mod.lookupSymbol<mlir::func::FuncOp>(mapping.first); 4168 if (!funcOp) 4169 mlir::emitWarning(mod.getLoc(), 4170 llvm::Twine("function '") + llvm::Twine(mapping.first) + 4171 llvm::Twine("' in acc routine directive is not " 4172 "found in this translation unit.")); 4173 else 4174 attachRoutineInfo(funcOp, mapping.second); 4175 } 4176 accRoutineInfos.clear(); 4177 } 4178 4179 static void 4180 genACC(Fortran::lower::AbstractConverter &converter, 4181 Fortran::lower::pft::Evaluation &eval, 4182 const Fortran::parser::OpenACCAtomicConstruct &atomicConstruct) { 4183 4184 mlir::Location loc = converter.genLocation(atomicConstruct.source); 4185 Fortran::common::visit( 4186 Fortran::common::visitors{ 4187 [&](const 
Fortran::parser::AccAtomicRead &atomicRead) { 4188 Fortran::lower::genOmpAccAtomicRead<Fortran::parser::AccAtomicRead, 4189 void>(converter, atomicRead, 4190 loc); 4191 }, 4192 [&](const Fortran::parser::AccAtomicWrite &atomicWrite) { 4193 Fortran::lower::genOmpAccAtomicWrite< 4194 Fortran::parser::AccAtomicWrite, void>(converter, atomicWrite, 4195 loc); 4196 }, 4197 [&](const Fortran::parser::AccAtomicUpdate &atomicUpdate) { 4198 Fortran::lower::genOmpAccAtomicUpdate< 4199 Fortran::parser::AccAtomicUpdate, void>(converter, atomicUpdate, 4200 loc); 4201 }, 4202 [&](const Fortran::parser::AccAtomicCapture &atomicCapture) { 4203 Fortran::lower::genOmpAccAtomicCapture< 4204 Fortran::parser::AccAtomicCapture, void>(converter, 4205 atomicCapture, loc); 4206 }, 4207 }, 4208 atomicConstruct.u); 4209 } 4210 4211 static void 4212 genACC(Fortran::lower::AbstractConverter &converter, 4213 Fortran::semantics::SemanticsContext &semanticsContext, 4214 const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) { 4215 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 4216 auto loopOp = builder.getRegion().getParentOfType<mlir::acc::LoopOp>(); 4217 auto crtPos = builder.saveInsertionPoint(); 4218 if (loopOp) { 4219 builder.setInsertionPoint(loopOp); 4220 Fortran::lower::StatementContext stmtCtx; 4221 llvm::SmallVector<mlir::Value> cacheOperands; 4222 const Fortran::parser::AccObjectListWithModifier &listWithModifier = 4223 std::get<Fortran::parser::AccObjectListWithModifier>(cacheConstruct.t); 4224 const auto &accObjectList = 4225 std::get<Fortran::parser::AccObjectList>(listWithModifier.t); 4226 const auto &modifier = 4227 std::get<std::optional<Fortran::parser::AccDataModifier>>( 4228 listWithModifier.t); 4229 4230 mlir::acc::DataClause dataClause = mlir::acc::DataClause::acc_cache; 4231 if (modifier && 4232 (*modifier).v == Fortran::parser::AccDataModifier::Modifier::ReadOnly) 4233 dataClause = mlir::acc::DataClause::acc_cache_readonly; 4234 
genDataOperandOperations<mlir::acc::CacheOp>( 4235 accObjectList, converter, semanticsContext, stmtCtx, cacheOperands, 4236 dataClause, 4237 /*structured=*/true, /*implicit=*/false, 4238 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}, 4239 /*setDeclareAttr*/ false); 4240 loopOp.getCacheOperandsMutable().append(cacheOperands); 4241 } else { 4242 llvm::report_fatal_error( 4243 "could not find loop to attach OpenACC cache information."); 4244 } 4245 builder.restoreInsertionPoint(crtPos); 4246 } 4247 4248 mlir::Value Fortran::lower::genOpenACCConstruct( 4249 Fortran::lower::AbstractConverter &converter, 4250 Fortran::semantics::SemanticsContext &semanticsContext, 4251 Fortran::lower::pft::Evaluation &eval, 4252 const Fortran::parser::OpenACCConstruct &accConstruct) { 4253 4254 mlir::Value exitCond; 4255 Fortran::common::visit( 4256 common::visitors{ 4257 [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { 4258 genACC(converter, semanticsContext, eval, blockConstruct); 4259 }, 4260 [&](const Fortran::parser::OpenACCCombinedConstruct 4261 &combinedConstruct) { 4262 genACC(converter, semanticsContext, eval, combinedConstruct); 4263 }, 4264 [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { 4265 exitCond = genACC(converter, semanticsContext, eval, loopConstruct); 4266 }, 4267 [&](const Fortran::parser::OpenACCStandaloneConstruct 4268 &standaloneConstruct) { 4269 genACC(converter, semanticsContext, standaloneConstruct); 4270 }, 4271 [&](const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) { 4272 genACC(converter, semanticsContext, cacheConstruct); 4273 }, 4274 [&](const Fortran::parser::OpenACCWaitConstruct &waitConstruct) { 4275 genACC(converter, waitConstruct); 4276 }, 4277 [&](const Fortran::parser::OpenACCAtomicConstruct &atomicConstruct) { 4278 genACC(converter, eval, atomicConstruct); 4279 }, 4280 [&](const Fortran::parser::OpenACCEndConstruct &) { 4281 // No op 4282 }, 4283 }, 4284 accConstruct.u); 4285 
return exitCond; 4286 } 4287 4288 void Fortran::lower::genOpenACCDeclarativeConstruct( 4289 Fortran::lower::AbstractConverter &converter, 4290 Fortran::semantics::SemanticsContext &semanticsContext, 4291 Fortran::lower::StatementContext &openAccCtx, 4292 const Fortran::parser::OpenACCDeclarativeConstruct &accDeclConstruct, 4293 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) { 4294 4295 Fortran::common::visit( 4296 common::visitors{ 4297 [&](const Fortran::parser::OpenACCStandaloneDeclarativeConstruct 4298 &standaloneDeclarativeConstruct) { 4299 genACC(converter, semanticsContext, openAccCtx, 4300 standaloneDeclarativeConstruct); 4301 }, 4302 [&](const Fortran::parser::OpenACCRoutineConstruct 4303 &routineConstruct) { 4304 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 4305 mlir::ModuleOp mod = builder.getModule(); 4306 Fortran::lower::genOpenACCRoutineConstruct( 4307 converter, semanticsContext, mod, routineConstruct, 4308 accRoutineInfos); 4309 }, 4310 }, 4311 accDeclConstruct.u); 4312 } 4313 4314 void Fortran::lower::attachDeclarePostAllocAction( 4315 AbstractConverter &converter, fir::FirOpBuilder &builder, 4316 const Fortran::semantics::Symbol &sym) { 4317 std::stringstream fctName; 4318 fctName << converter.mangleName(sym) << declarePostAllocSuffix.str(); 4319 mlir::Operation *op = &builder.getInsertionBlock()->back(); 4320 4321 if (auto resOp = mlir::dyn_cast<fir::ResultOp>(*op)) { 4322 assert(resOp.getOperands().size() == 0 && 4323 "expect only fir.result op with no operand"); 4324 op = op->getPrevNode(); 4325 } 4326 assert(op && "expect operation to attach the post allocation action"); 4327 4328 if (op->hasAttr(mlir::acc::getDeclareActionAttrName())) { 4329 auto attr = op->getAttrOfType<mlir::acc::DeclareActionAttr>( 4330 mlir::acc::getDeclareActionAttrName()); 4331 op->setAttr(mlir::acc::getDeclareActionAttrName(), 4332 mlir::acc::DeclareActionAttr::get( 4333 builder.getContext(), attr.getPreAlloc(), 4334 
/*postAlloc=*/builder.getSymbolRefAttr(fctName.str()), 4335 attr.getPreDealloc(), attr.getPostDealloc())); 4336 } else { 4337 op->setAttr(mlir::acc::getDeclareActionAttrName(), 4338 mlir::acc::DeclareActionAttr::get( 4339 builder.getContext(), 4340 /*preAlloc=*/{}, 4341 /*postAlloc=*/builder.getSymbolRefAttr(fctName.str()), 4342 /*preDealloc=*/{}, /*postDealloc=*/{})); 4343 } 4344 } 4345 4346 void Fortran::lower::attachDeclarePreDeallocAction( 4347 AbstractConverter &converter, fir::FirOpBuilder &builder, 4348 mlir::Value beginOpValue, const Fortran::semantics::Symbol &sym) { 4349 if (!sym.test(Fortran::semantics::Symbol::Flag::AccCreate) && 4350 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyIn) && 4351 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyInReadOnly) && 4352 !sym.test(Fortran::semantics::Symbol::Flag::AccCopy) && 4353 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyOut) && 4354 !sym.test(Fortran::semantics::Symbol::Flag::AccDeviceResident)) 4355 return; 4356 4357 std::stringstream fctName; 4358 fctName << converter.mangleName(sym) << declarePreDeallocSuffix.str(); 4359 4360 auto *op = beginOpValue.getDefiningOp(); 4361 if (op->hasAttr(mlir::acc::getDeclareActionAttrName())) { 4362 auto attr = op->getAttrOfType<mlir::acc::DeclareActionAttr>( 4363 mlir::acc::getDeclareActionAttrName()); 4364 op->setAttr(mlir::acc::getDeclareActionAttrName(), 4365 mlir::acc::DeclareActionAttr::get( 4366 builder.getContext(), attr.getPreAlloc(), 4367 attr.getPostAlloc(), 4368 /*preDealloc=*/builder.getSymbolRefAttr(fctName.str()), 4369 attr.getPostDealloc())); 4370 } else { 4371 op->setAttr(mlir::acc::getDeclareActionAttrName(), 4372 mlir::acc::DeclareActionAttr::get( 4373 builder.getContext(), 4374 /*preAlloc=*/{}, /*postAlloc=*/{}, 4375 /*preDealloc=*/builder.getSymbolRefAttr(fctName.str()), 4376 /*postDealloc=*/{})); 4377 } 4378 } 4379 4380 void Fortran::lower::attachDeclarePostDeallocAction( 4381 AbstractConverter &converter, fir::FirOpBuilder &builder, 4382 
const Fortran::semantics::Symbol &sym) { 4383 if (!sym.test(Fortran::semantics::Symbol::Flag::AccCreate) && 4384 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyIn) && 4385 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyInReadOnly) && 4386 !sym.test(Fortran::semantics::Symbol::Flag::AccCopy) && 4387 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyOut) && 4388 !sym.test(Fortran::semantics::Symbol::Flag::AccDeviceResident)) 4389 return; 4390 4391 std::stringstream fctName; 4392 fctName << converter.mangleName(sym) << declarePostDeallocSuffix.str(); 4393 mlir::Operation *op = &builder.getInsertionBlock()->back(); 4394 if (auto resOp = mlir::dyn_cast<fir::ResultOp>(*op)) { 4395 assert(resOp.getOperands().size() == 0 && 4396 "expect only fir.result op with no operand"); 4397 op = op->getPrevNode(); 4398 } 4399 assert(op && "expect operation to attach the post deallocation action"); 4400 if (op->hasAttr(mlir::acc::getDeclareActionAttrName())) { 4401 auto attr = op->getAttrOfType<mlir::acc::DeclareActionAttr>( 4402 mlir::acc::getDeclareActionAttrName()); 4403 op->setAttr(mlir::acc::getDeclareActionAttrName(), 4404 mlir::acc::DeclareActionAttr::get( 4405 builder.getContext(), attr.getPreAlloc(), 4406 attr.getPostAlloc(), attr.getPreDealloc(), 4407 /*postDealloc=*/builder.getSymbolRefAttr(fctName.str()))); 4408 } else { 4409 op->setAttr(mlir::acc::getDeclareActionAttrName(), 4410 mlir::acc::DeclareActionAttr::get( 4411 builder.getContext(), 4412 /*preAlloc=*/{}, /*postAlloc=*/{}, /*preDealloc=*/{}, 4413 /*postDealloc=*/builder.getSymbolRefAttr(fctName.str()))); 4414 } 4415 } 4416 4417 void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder, 4418 mlir::Operation *op, 4419 mlir::Location loc) { 4420 if (mlir::isa<mlir::acc::ParallelOp, mlir::acc::LoopOp>(op)) 4421 builder.create<mlir::acc::YieldOp>(loc); 4422 else 4423 builder.create<mlir::acc::TerminatorOp>(loc); 4424 } 4425 4426 bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) { 4427 
if (builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>()) 4428 return true; 4429 return false; 4430 } 4431 4432 void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside( 4433 fir::FirOpBuilder &builder) { 4434 if (auto loopOp = 4435 builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>()) 4436 builder.setInsertionPointAfter(loopOp); 4437 } 4438 4439 void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder, 4440 mlir::Location loc) { 4441 mlir::Value yieldValue = 4442 builder.createIntegerConstant(loc, builder.getI1Type(), 1); 4443 builder.create<mlir::acc::YieldOp>(loc, yieldValue); 4444 } 4445 4446 int64_t Fortran::lower::getCollapseValue( 4447 const Fortran::parser::AccClauseList &clauseList) { 4448 for (const Fortran::parser::AccClause &clause : clauseList.v) { 4449 if (const auto *collapseClause = 4450 std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) { 4451 const parser::AccCollapseArg &arg = collapseClause->v; 4452 const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)}; 4453 return *Fortran::semantics::GetIntValue(collapseValue); 4454 } 4455 } 4456 return 1; 4457 } 4458