//===-- ClauseProcessor.cpp -------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "ClauseProcessor.h"
#include "Clauses.h"

#include "flang/Lower/PFTBuilder.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/tools.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"

// NOTE(review): in this copy of the file the explicit template arguments
// (e.g. the target type of `mlir::dyn_cast`, the alternative selected by
// `std::get`, the clause type passed to `findUniqueClause`) are missing, so
// the code below does not compile as written; restore them from version
// control before building.

namespace Fortran {
namespace lower {
namespace omp {

/// Check for unsupported map operand types.
///
/// Looks through a reference wrapper (if the cast succeeds) and, when the
/// resulting type is a box whose element type fails the `mlir::isa` check,
/// reports a TODO at \p location.
static void checkMapType(mlir::Location location, mlir::Type type) {
  if (auto refType = mlir::dyn_cast(type))
    type = refType.getElementType();
  if (auto boxType = mlir::dyn_cast_or_null(type))
    if (!mlir::isa(boxType.getElementType()))
      TODO(location, "OMPD_target_data MapOperand BoxType");
}

/// Map a SCHEDULE ordering modifier (MONOTONIC / NONMONOTONIC) to the
/// corresponding MLIR OpenMP schedule modifier. Falls back to `none` for any
/// value not covered by the switch.
static mlir::omp::ScheduleModifier
translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) {
  switch (m) {
  case omp::clause::Schedule::OrderingModifier::Monotonic:
    return mlir::omp::ScheduleModifier::monotonic;
  case omp::clause::Schedule::OrderingModifier::Nonmonotonic:
    return mlir::omp::ScheduleModifier::nonmonotonic;
  }
  return mlir::omp::ScheduleModifier::none;
}

/// Return the MLIR ordering modifier of a SCHEDULE clause, or `none` when the
/// clause carries no ordering modifier.
static mlir::omp::ScheduleModifier
getScheduleModifier(const omp::clause::Schedule &clause) {
  using Schedule = omp::clause::Schedule;
  const auto &modifier =
      std::get>(clause.t);
  if (modifier)
    return translateScheduleModifier(*modifier);
  return mlir::omp::ScheduleModifier::none;
}

/// Return `simd` when the SCHEDULE clause carries the SIMD chunk modifier,
/// otherwise `none`.
static mlir::omp::ScheduleModifier
getSimdModifier(const omp::clause::Schedule &clause) {
  using Schedule = omp::clause::Schedule;
  const auto &modifier =
      std::get>(clause.t);
  if (modifier && *modifier == Schedule::ChunkModifier::Simd)
    return mlir::omp::ScheduleModifier::simd;
  return mlir::omp::ScheduleModifier::none;
}

/// Lower one ALLOCATE clause.
///
/// Appends one allocator value per listed object to \p allocatorOperands and
/// the lowered objects themselves to \p allocateOperands, so both vectors grow
/// by the same count. An ALIGN modifier is reported as a TODO.
static void
genAllocateClause(lower::AbstractConverter &converter,
                  const omp::clause::Allocate &clause,
                  llvm::SmallVectorImpl &allocatorOperands,
                  llvm::SmallVectorImpl &allocateOperands) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
  mlir::Location currentLocation = converter.getCurrentLocation();
  lower::StatementContext stmtCtx;

  auto &objects = std::get(clause.t);

  using Allocate = omp::clause::Allocate;
  // ALIGN in this context is unimplemented
  if (std::get>(clause.t))
    TODO(currentLocation, "OmpAllocateClause ALIGN modifier");

  // Check if allocate clause has allocator specified. If so, add it
  // to list of allocators, otherwise, add default allocator to
  // list of allocators.
  using ComplexModifier = Allocate::AllocatorComplexModifier;
  if (auto &mod = std::get>(clause.t)) {
    mlir::Value operand = fir::getBase(converter.genExprValue(mod->v, stmtCtx));
    allocatorOperands.append(objects.size(), operand);
  } else {
    // The default allocator is represented by the i32 constant 1.
    mlir::Value operand = firOpBuilder.createIntegerConstant(
        currentLocation, firOpBuilder.getI32Type(), 1);
    allocatorOperands.append(objects.size(), operand);
  }

  genObjectList(objects, converter, allocateOperands);
}

/// Translate the binding (TEAMS / PARALLEL / THREAD) of a BIND clause into the
/// corresponding MLIR OpenMP attribute.
static mlir::omp::ClauseBindKindAttr
genBindKindAttr(fir::FirOpBuilder &firOpBuilder,
                const omp::clause::Bind &clause) {
  mlir::omp::ClauseBindKind bindKind;
  switch (clause.v) {
  case omp::clause::Bind::Binding::Teams:
    bindKind = mlir::omp::ClauseBindKind::Teams;
    break;
  case omp::clause::Bind::Binding::Parallel:
    bindKind = mlir::omp::ClauseBindKind::Parallel;
    break;
  case omp::clause::Bind::Binding::Thread:
    bindKind = mlir::omp::ClauseBindKind::Thread;
    break;
  }
  return mlir::omp::ClauseBindKindAttr::get(firOpBuilder.getContext(),
                                            bindKind);
}

/// Translate the affinity policy (MASTER / CLOSE / SPREAD / PRIMARY) of a
/// PROC_BIND clause into the corresponding MLIR OpenMP attribute.
static mlir::omp::ClauseProcBindKindAttr
genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder,
                    const omp::clause::ProcBind &clause) {
  mlir::omp::ClauseProcBindKind procBindKind;
  switch (clause.v) {
  case omp::clause::ProcBind::AffinityPolicy::Master:
    procBindKind = mlir::omp::ClauseProcBindKind::Master;
    break;
  case omp::clause::ProcBind::AffinityPolicy::Close:
    procBindKind = mlir::omp::ClauseProcBindKind::Close;
    break;
  case omp::clause::ProcBind::AffinityPolicy::Spread:
    procBindKind = mlir::omp::ClauseProcBindKind::Spread;
    break;
  case omp::clause::ProcBind::AffinityPolicy::Primary:
    procBindKind = mlir::omp::ClauseProcBindKind::Primary;
    break;
  }
  return mlir::omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(),
                                                procBindKind);
}

/// Translate a task dependence type into a ClauseTaskDependAttr.
///
/// DEPOBJ is reported as a TODO. SINK and SOURCE are treated as unreachable
/// here; callers are expected to have filtered them out before this point.
static mlir::omp::ClauseTaskDependAttr
genDependKindAttr(lower::AbstractConverter &converter,
                  const omp::clause::DependenceType kind) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
  mlir::Location currentLocation = converter.getCurrentLocation();

  mlir::omp::ClauseTaskDepend pbKind;
  switch (kind) {
  case omp::clause::DependenceType::In:
    pbKind = mlir::omp::ClauseTaskDepend::taskdependin;
    break;
  case omp::clause::DependenceType::Out:
    pbKind = mlir::omp::ClauseTaskDepend::taskdependout;
    break;
  case omp::clause::DependenceType::Inout:
    pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
    break;
  case omp::clause::DependenceType::Mutexinoutset:
    pbKind = mlir::omp::ClauseTaskDepend::taskdependmutexinoutset;
    break;
  case omp::clause::DependenceType::Inoutset:
    pbKind = mlir::omp::ClauseTaskDepend::taskdependinoutset;
    break;
  case omp::clause::DependenceType::Depobj:
    TODO(currentLocation, "DEPOBJ dependence-type");
    break;
  case omp::clause::DependenceType::Sink:
  case omp::clause::DependenceType::Source:
    llvm_unreachable("unhandled parser task dependence type");
    break;
  }
  return mlir::omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(),
                                              pbKind);
}

/// Lower the condition of an IF clause to an i1 value.
///
/// Returns a null value when the clause has a directive-name modifier that
/// differs from \p directiveName. A clause without a modifier always applies.
static mlir::Value
getIfClauseOperand(lower::AbstractConverter &converter,
                   const omp::clause::If &clause,
                   omp::clause::If::DirectiveNameModifier directiveName,
                   mlir::Location clauseLocation) {
  // Only consider the clause if it's intended for the given directive.
  auto &directive =
      std::get>(clause.t);
  if (directive && directive.value() != directiveName)
    return nullptr;

  lower::StatementContext stmtCtx;
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
  mlir::Value ifVal = fir::getBase(
      converter.genExprValue(std::get(clause.t), stmtCtx));
  return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
                                    ifVal);
}

/// Lower the objects of a device pointer/address clause into \p operands and
/// record their symbols in \p useDeviceSyms.
///
/// Note: the checkMapType loop visits every value in \p operands, including
/// values appended by earlier calls that shared the same vector.
static void addUseDeviceClause(
    lower::AbstractConverter &converter, const omp::ObjectList &objects,
    llvm::SmallVectorImpl &operands,
    llvm::SmallVectorImpl &useDeviceSyms) {
  genObjectList(objects, converter, operands);
  for (mlir::Value &operand : operands)
    checkMapType(operand.getLoc(), operand.getType());

  for (const omp::Object &object : objects)
    useDeviceSyms.push_back(object.sym());
}

/// Convert every collected lower bound, upper bound and step in \p result to
/// the integer type that getLoopVarType returns for \p loopVarTypeSize.
///
/// Assumes the three bound/step vectors have the same length; the loop is
/// driven by the size of `loopLowerBounds`.
static void convertLoopBounds(lower::AbstractConverter &converter,
                              mlir::Location loc,
                              mlir::omp::LoopRelatedClauseOps &result,
                              std::size_t loopVarTypeSize) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
  // The types of lower bound, upper bound, and step are converted into the
  // type of the loop variable if necessary.
  mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
  for (unsigned it = 0; it < (unsigned)result.loopLowerBounds.size(); it++) {
    result.loopLowerBounds[it] = firOpBuilder.createConvert(
        loc, loopVarType, result.loopLowerBounds[it]);
    result.loopUpperBounds[it] = firOpBuilder.createConvert(
        loc, loopVarType, result.loopUpperBounds[it]);
    result.loopSteps[it] =
        firOpBuilder.createConvert(loc, loopVarType, result.loopSteps[it]);
  }
}

//===----------------------------------------------------------------------===//
// ClauseProcessor unique clauses
//===----------------------------------------------------------------------===//

/// Record whether the (unique) BARE clause is present.
bool ClauseProcessor::processBare(mlir::omp::BareClauseOps &result) const {
  return markClauseOccurrence(result.bare);
}

/// Lower the BIND clause, if present, into `result.bindKind`.
bool ClauseProcessor::processBind(mlir::omp::BindClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
    result.bindKind = genBindKindAttr(firOpBuilder, *clause);
    return true;
  }
  return false;
}

/// Collect lower bounds, upper bounds, steps and induction-variable symbols
/// for the loop nest associated with \p eval.
///
/// The number of loops visited is the COLLAPSE value, or 1 when there is no
/// COLLAPSE clause. All bounds are then converted to a common loop-variable
/// type that is sized by the largest induction variable seen. Returns true
/// only if a COLLAPSE clause was present.
bool ClauseProcessor::processCollapse(
    mlir::Location currentLocation, lower::pft::Evaluation &eval,
    mlir::omp::LoopRelatedClauseOps &result,
    llvm::SmallVectorImpl &iv) const {
  bool found = false;
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  // Collect the loops to collapse.
  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
  if (doConstructEval->getIf()->IsDoConcurrent()) {
    TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
  }

  std::int64_t collapseValue = 1l;
  if (auto *clause = findUniqueClause()) {
    collapseValue = evaluate::ToInt64(clause->v).value();
    found = true;
  }

  std::size_t loopVarTypeSize = 0;
  do {
    lower::pft::Evaluation *doLoop =
        &doConstructEval->getFirstNestedEvaluation();
    auto *doStmt = doLoop->getIf();
    assert(doStmt && "Expected do loop to be in the nested evaluation");
    const auto &loopControl =
        std::get>(doStmt->t);
    const parser::LoopControl::Bounds *bounds =
        std::get_if(&loopControl->u);
    assert(bounds && "Expected bounds for worksharing do loop");
    lower::StatementContext stmtCtx;
    result.loopLowerBounds.push_back(fir::getBase(
        converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)));
    result.loopUpperBounds.push_back(fir::getBase(
        converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)));
    if (bounds->step) {
      result.loopSteps.push_back(fir::getBase(
          converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)));
    } else {
      // If `step` is not present, assume it as `1`. The i32 constant is
      // converted to the loop-variable type by convertLoopBounds below.
      result.loopSteps.push_back(firOpBuilder.createIntegerConstant(
          currentLocation, firOpBuilder.getIntegerType(32), 1));
    }
    iv.push_back(bounds->name.thing.symbol);
    loopVarTypeSize = std::max(loopVarTypeSize,
                               bounds->name.thing.symbol->GetUltimate().size());
    collapseValue--;
    // Step to the second nested evaluation of the current construct.
    // Presumably this is the next (inner) DO construct that follows the DO
    // statement; TODO confirm against the PFT layout.
    doConstructEval =
        &*std::next(doConstructEval->getNestedEvaluations().begin());
  } while (collapseValue > 0);

  convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);

  return found;
}

/// Lower the DEVICE clause, if present, into `result.device`. The ANCESTOR
/// device modifier is reported as a TODO.
bool ClauseProcessor::processDevice(lower::StatementContext &stmtCtx,
                                    mlir::omp::DeviceClauseOps &result) const {
  const parser::CharBlock *source = nullptr;
  if (auto *clause = findUniqueClause(&source)) {
    mlir::Location clauseLocation = converter.genLocation(*source);
    if (auto deviceModifier =
            std::get>(
                clause->t)) {
      if (deviceModifier == omp::clause::Device::DeviceModifier::Ancestor) {
        TODO(clauseLocation, "OMPD_target Device Modifier Ancestor");
      }
    }
    const auto &deviceExpr = std::get(clause->t);
    result.device = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
    return true;
  }
  return false;
}

/// Lower the DEVICE_TYPE clause of a DECLARE TARGET directive, if present,
/// into `result.deviceType`.
bool ClauseProcessor::processDeviceType(
    mlir::omp::DeviceTypeClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    // Case: declare target ... device_type(any | host | nohost)
    switch (clause->v) {
    case omp::clause::DeviceType::DeviceTypeDescription::Nohost:
      result.deviceType = mlir::omp::DeclareTargetDeviceType::nohost;
      break;
    case omp::clause::DeviceType::DeviceTypeDescription::Host:
      result.deviceType = mlir::omp::DeclareTargetDeviceType::host;
      break;
    case omp::clause::DeviceType::DeviceTypeDescription::Any:
      result.deviceType = mlir::omp::DeclareTargetDeviceType::any;
      break;
    }
    return true;
  }
  return false;
}

/// Lower the DIST_SCHEDULE clause, if present. The schedule is always marked
/// static, and the optional chunk size is lowered when it is given.
bool ClauseProcessor::processDistSchedule(
    lower::StatementContext &stmtCtx,
    mlir::omp::DistScheduleClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    result.distScheduleStatic = converter.getFirOpBuilder().getUnitAttr();
    const auto &chunkSize = std::get>(clause->t);
    if (chunkSize)
      result.distScheduleChunkSize =
          fir::getBase(converter.genExprValue(*chunkSize, stmtCtx));
    return true;
  }
  return false;
}

/// Lower the FILTER clause, if present, into `result.filteredThreadId`.
bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
                                    mlir::omp::FilterClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    result.filteredThreadId =
        fir::getBase(converter.genExprValue(clause->v, stmtCtx));
    return true;
  }
  return false;
}

/// Lower the FINAL clause, if present, into an i1 value in `result.final`.
bool ClauseProcessor::processFinal(lower::StatementContext &stmtCtx,
                                   mlir::omp::FinalClauseOps &result) const {
  const parser::CharBlock *source = nullptr;
  if (auto *clause = findUniqueClause(&source)) {
    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
    mlir::Location clauseLocation = converter.genLocation(*source);

    mlir::Value finalVal =
        fir::getBase(converter.genExprValue(clause->v, stmtCtx));
    result.final = firOpBuilder.createConvert(
        clauseLocation, firOpBuilder.getI1Type(), finalVal);
    return true;
  }
  return false;
}

/// Lower the HINT clause, if present, into an i64 attribute. The hint must
/// fold to an integer constant, because the optional result of ToInt64 is
/// dereferenced without a check.
bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
    int64_t hintValue = *evaluate::ToInt64(clause->v);
    result.hint = firOpBuilder.getI64IntegerAttr(hintValue);
    return true;
  }
  return false;
}

/// Record whether the MERGEABLE clause is present.
bool ClauseProcessor::processMergeable(
    mlir::omp::MergeableClauseOps &result) const {
  return markClauseOccurrence(result.mergeable);
}

/// Record whether the NOWAIT clause is present.
bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const {
  return markClauseOccurrence(result.nowait);
}

/// Lower the NUM_TEAMS clause, if present. Only the upper bound of the single
/// expected list element is lowered, into `result.numTeamsUpper`.
bool ClauseProcessor::processNumTeams(
    lower::StatementContext &stmtCtx,
    mlir::omp::NumTeamsClauseOps &result) const {
  // TODO Get lower and upper bounds for num_teams when parser is updated to
  // accept both.
  if (auto *clause = findUniqueClause()) {
    // The num_teams directive accepts a list of team lower/upper bounds.
    // This is an extension to support grid specification for ompx_bare.
    // Here, only expect a single element in the list.
    assert(clause->v.size() == 1);
    // auto lowerBound = std::get>(clause->v[0]->t);
    auto &upperBound = std::get(clause->v[0].t);
    result.numTeamsUpper =
        fir::getBase(converter.genExprValue(upperBound, stmtCtx));
    return true;
  }
  return false;
}

/// Lower the NUM_THREADS clause, if present, into `result.numThreads`.
bool ClauseProcessor::processNumThreads(
    lower::StatementContext &stmtCtx,
    mlir::omp::NumThreadsClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
    result.numThreads =
        fir::getBase(converter.genExprValue(clause->v, stmtCtx));
    return true;
  }
  return false;
}

/// Lower the ORDER clause, if present. The order kind is always CONCURRENT.
/// The modifier is UNCONSTRAINED when requested and REPRODUCIBLE otherwise.
bool ClauseProcessor::processOrder(mlir::omp::OrderClauseOps &result) const {
  using Order = omp::clause::Order;
  if (auto *clause = findUniqueClause()) {
    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
    result.order = mlir::omp::ClauseOrderKindAttr::get(
        firOpBuilder.getContext(), mlir::omp::ClauseOrderKind::Concurrent);
    const auto &modifier =
        std::get>(clause->t);
    if (modifier && *modifier == Order::OrderModifier::Unconstrained) {
      result.orderMod = mlir::omp::OrderModifierAttr::get(
          firOpBuilder.getContext(), mlir::omp::OrderModifier::unconstrained);
    } else {
      // "If order-modifier is not unconstrained, the behavior is as if the
      // reproducible modifier is present."
      result.orderMod = mlir::omp::OrderModifierAttr::get(
          firOpBuilder.getContext(), mlir::omp::OrderModifier::reproducible);
    }
    return true;
  }
  return false;
}

/// Lower the ORDERED clause, if present, into an i64 attribute. A clause
/// without an argument is recorded as 0.
bool ClauseProcessor::processOrdered(
    mlir::omp::OrderedClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
    int64_t orderedClauseValue = 0l;
    if (clause->v.has_value())
      orderedClauseValue = *evaluate::ToInt64(*clause->v);
    result.ordered = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
    return true;
  }
  return false;
}

/// Lower the PRIORITY clause, if present, into `result.priority`.
bool ClauseProcessor::processPriority(
    lower::StatementContext &stmtCtx,
    mlir::omp::PriorityClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    result.priority = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
    return true;
  }
  return false;
}

/// Lower the DETACH clause, if present. The event handle is the address of
/// the clause's symbol.
bool ClauseProcessor::processDetach(mlir::omp::DetachClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    semantics::Symbol *sym = clause->v.sym();
    mlir::Value symVal = converter.getSymbolAddress(*sym);
    result.eventHandle = symVal;
    return true;
  }
  return false;
}

/// Lower the PROC_BIND clause, if present, into `result.procBindKind`.
bool ClauseProcessor::processProcBind(
    mlir::omp::ProcBindClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
    result.procBindKind = genProcBindKindAttr(firOpBuilder, *clause);
    return true;
  }
  return false;
}

/// Lower the SAFELEN clause, if present, into an i64 attribute. The value must
/// fold to an integer constant, because the optional is dereferenced without
/// a check.
bool ClauseProcessor::processSafelen(
    mlir::omp::SafelenClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
    const std::optional safelenVal = evaluate::ToInt64(clause->v);
    result.safelen = firOpBuilder.getI64IntegerAttr(*safelenVal);
    return true;
  }
  return false;
}

/// Lower the SCHEDULE clause, if present. This sets the schedule kind, an
/// optional ordering modifier, the SIMD flag, and an optional chunk size.
bool ClauseProcessor::processSchedule(
    lower::StatementContext &stmtCtx,
    mlir::omp::ScheduleClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
    mlir::MLIRContext *context = firOpBuilder.getContext();
    const auto &scheduleType = std::get(clause->t);

    mlir::omp::ClauseScheduleKind scheduleKind;
    switch (scheduleType) {
    case omp::clause::Schedule::Kind::Static:
      scheduleKind = mlir::omp::ClauseScheduleKind::Static;
      break;
    case omp::clause::Schedule::Kind::Dynamic:
      scheduleKind = mlir::omp::ClauseScheduleKind::Dynamic;
      break;
    case omp::clause::Schedule::Kind::Guided:
      scheduleKind = mlir::omp::ClauseScheduleKind::Guided;
      break;
    case omp::clause::Schedule::Kind::Auto:
      scheduleKind = mlir::omp::ClauseScheduleKind::Auto;
      break;
    case omp::clause::Schedule::Kind::Runtime:
      scheduleKind = mlir::omp::ClauseScheduleKind::Runtime;
      break;
    }

    result.scheduleKind =
        mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);

    // Only emit a modifier attribute when an ordering modifier was given.
    mlir::omp::ScheduleModifier scheduleMod = getScheduleModifier(*clause);
    if (scheduleMod != mlir::omp::ScheduleModifier::none)
      result.scheduleMod =
          mlir::omp::ScheduleModifierAttr::get(context, scheduleMod);

    if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
      result.scheduleSimd = firOpBuilder.getUnitAttr();

    if (const auto &chunkExpr = std::get(clause->t))
      result.scheduleChunk =
          fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));

    return true;
  }
  return false;
}

/// Lower the SIMDLEN clause, if present, into an i64 attribute. The value must
/// fold to an integer constant, because the optional is dereferenced without
/// a check.
bool ClauseProcessor::processSimdlen(
    mlir::omp::SimdlenClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
    const std::optional simdlenVal = evaluate::ToInt64(clause->v);
    result.simdlen = firOpBuilder.getI64IntegerAttr(*simdlenVal);
    return true;
  }
  return false;
}

/// Lower the THREAD_LIMIT clause, if present, into `result.threadLimit`.
bool ClauseProcessor::processThreadLimit(
    lower::StatementContext &stmtCtx,
    mlir::omp::ThreadLimitClauseOps &result) const {
  if (auto *clause = findUniqueClause()) {
    result.threadLimit =
        fir::getBase(converter.genExprValue(clause->v, stmtCtx));
    return true;
  }
  return false;
}

/// Record whether the UNTIED clause is present.
bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
  return markClauseOccurrence(result.untied);
}
//===----------------------------------------------------------------------===// // ClauseProcessor repeatable clauses //===----------------------------------------------------------------------===// static llvm::StringMap getTargetFeatures(mlir::ModuleOp module) { llvm::StringMap featuresMap; llvm::SmallVector targetFeaturesVec; if (mlir::LLVM::TargetFeaturesAttr features = fir::getTargetFeatures(module)) { llvm::ArrayRef featureAttrs = features.getFeatures(); for (auto &featureAttr : featureAttrs) { llvm::StringRef featureKeyString = featureAttr.strref(); featuresMap[featureKeyString.substr(1)] = (featureKeyString[0] == '+'); } } return featuresMap; } static void addAlignedClause(lower::AbstractConverter &converter, const omp::clause::Aligned &clause, llvm::SmallVectorImpl &alignedVars, llvm::SmallVectorImpl &alignments) { using Aligned = omp::clause::Aligned; lower::StatementContext stmtCtx; mlir::IntegerAttr alignmentValueAttr; int64_t alignment = 0; fir::FirOpBuilder &builder = converter.getFirOpBuilder(); if (auto &alignmentValueParserExpr = std::get>(clause.t)) { mlir::Value operand = fir::getBase( converter.genExprValue(*alignmentValueParserExpr, stmtCtx)); alignment = *fir::getIntIfConstant(operand); } else { llvm::StringMap featuresMap = getTargetFeatures(builder.getModule()); llvm::Triple triple = fir::getTargetTriple(builder.getModule()); alignment = llvm::OpenMPIRBuilder::getOpenMPDefaultSimdAlign(triple, featuresMap); } // The default alignment for some targets is equal to 0. // Do not generate alignment assumption if alignment is less than or equal to // 0. 
if (alignment > 0) { // alignment value must be power of 2 assert((alignment & (alignment - 1)) == 0 && "alignment is not power of 2"); auto &objects = std::get(clause.t); if (!objects.empty()) genObjectList(objects, converter, alignedVars); alignmentValueAttr = builder.getI64IntegerAttr(alignment); // All the list items in a aligned clause will have same alignment for (std::size_t i = 0; i < objects.size(); i++) alignments.push_back(alignmentValueAttr); } } bool ClauseProcessor::processAligned( mlir::omp::AlignedClauseOps &result) const { return findRepeatableClause( [&](const omp::clause::Aligned &clause, const parser::CharBlock &) { addAlignedClause(converter, clause, result.alignedVars, result.alignments); }); } bool ClauseProcessor::processAllocate( mlir::omp::AllocateClauseOps &result) const { return findRepeatableClause( [&](const omp::clause::Allocate &clause, const parser::CharBlock &) { genAllocateClause(converter, clause, result.allocatorVars, result.allocateVars); }); } bool ClauseProcessor::processCopyin() const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint(); firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); auto checkAndCopyHostAssociateVar = [&](semantics::Symbol *sym, mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) { assert(sym->has() && "No host-association found"); if (converter.isPresentShallowLookup(*sym)) converter.copyHostAssociateVar(*sym, copyAssignIP); }; bool hasCopyin = findRepeatableClause( [&](const omp::clause::Copyin &clause, const parser::CharBlock &) { for (const omp::Object &object : clause.v) { semantics::Symbol *sym = object.sym(); assert(sym && "Expecting symbol"); if (const auto *commonDetails = sym->detailsIf()) { for (const auto &mem : commonDetails->objects()) checkAndCopyHostAssociateVar(&*mem, &insPt); break; } assert(sym->has() && "No host-association found"); checkAndCopyHostAssociateVar(sym); } }); // [OMP 
5.0, 2.19.6.1] The copy is done after the team is formed and prior to // the execution of the associated structured block. Emit implicit barrier to // synchronize threads and avoid data races on propagation master's thread // values of threadprivate variables to local instances of that variables of // all other implicit threads. // All copies are inserted at either "insPt" (i.e. immediately before it), // or at some earlier point (as determined by "copyHostAssociateVar"). // Unless the insertion point is given to "copyHostAssociateVar" explicitly, // it will not restore the builder's insertion point. Since the copies may be // inserted in any order (not following the execution order), make sure the // barrier is inserted following all of them. firOpBuilder.restoreInsertionPoint(insPt); if (hasCopyin) firOpBuilder.create(converter.getCurrentLocation()); return hasCopyin; } /// Class that extracts information from the specified type. class TypeInfo { public: TypeInfo(mlir::Type ty) { typeScan(ty); } // Returns the length of character types. std::optional getCharLength() const { return charLen; } // Returns the shape of array types. llvm::ArrayRef getShape() const { return shape; } // Is the type inside a box? bool isBox() const { return inBox; } private: void typeScan(mlir::Type type); std::optional charLen; llvm::SmallVector shape; bool inBox = false; }; void TypeInfo::typeScan(mlir::Type ty) { if (auto sty = mlir::dyn_cast(ty)) { assert(shape.empty() && !sty.getShape().empty()); shape = llvm::SmallVector(sty.getShape()); typeScan(sty.getEleTy()); } else if (auto bty = mlir::dyn_cast(ty)) { inBox = true; typeScan(bty.getEleTy()); } else if (auto cty = mlir::dyn_cast(ty)) { charLen = cty.getLen(); } else if (auto hty = mlir::dyn_cast(ty)) { typeScan(hty.getEleTy()); } else if (auto pty = mlir::dyn_cast(ty)) { typeScan(pty.getEleTy()); } else { // The scan ends when reaching any built-in or record type. 
assert(ty.isIntOrIndexOrFloat() || mlir::isa(ty) || mlir::isa(ty) || mlir::isa(ty)); } } // Create a function that performs a copy between two variables, compatible // with their types and attributes. static mlir::func::FuncOp createCopyFunc(mlir::Location loc, lower::AbstractConverter &converter, mlir::Type varType, fir::FortranVariableFlagsEnum varAttrs) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::ModuleOp module = builder.getModule(); mlir::Type eleTy = mlir::cast(varType).getEleTy(); TypeInfo typeInfo(eleTy); std::string copyFuncName = fir::getTypeAsString(eleTy, builder.getKindMap(), "_copy"); if (auto decl = module.lookupSymbol(copyFuncName)) return decl; // create function mlir::OpBuilder::InsertionGuard guard(builder); mlir::OpBuilder modBuilder(module.getBodyRegion()); llvm::SmallVector argsTy = {varType, varType}; auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {}); mlir::func::FuncOp funcOp = modBuilder.create(loc, copyFuncName, funcType); funcOp.setVisibility(mlir::SymbolTable::Visibility::Private); fir::factory::setInternalLinkage(funcOp); builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy, {loc, loc}); builder.setInsertionPointToStart(&funcOp.getRegion().back()); // generate body fir::FortranVariableFlagsAttr attrs; if (varAttrs != fir::FortranVariableFlagsEnum::None) attrs = fir::FortranVariableFlagsAttr::get(builder.getContext(), varAttrs); llvm::SmallVector typeparams; if (typeInfo.getCharLength().has_value()) { mlir::Value charLen = builder.createIntegerConstant( loc, builder.getCharacterLengthType(), *typeInfo.getCharLength()); typeparams.push_back(charLen); } mlir::Value shape; if (!typeInfo.isBox() && !typeInfo.getShape().empty()) { llvm::SmallVector extents; for (auto extent : typeInfo.getShape()) extents.push_back( builder.createIntegerConstant(loc, builder.getIndexType(), extent)); shape = builder.create(loc, extents); } auto declDst = builder.create( loc, 
funcOp.getArgument(0), copyFuncName + "_dst", shape, typeparams, /*dummy_scope=*/nullptr, attrs); auto declSrc = builder.create( loc, funcOp.getArgument(1), copyFuncName + "_src", shape, typeparams, /*dummy_scope=*/nullptr, attrs); converter.copyVar(loc, declDst.getBase(), declSrc.getBase(), varAttrs); builder.create(loc); return funcOp; } bool ClauseProcessor::processCopyprivate( mlir::Location currentLocation, mlir::omp::CopyprivateClauseOps &result) const { auto addCopyPrivateVar = [&](semantics::Symbol *sym) { mlir::Value symVal = converter.getSymbolAddress(*sym); auto declOp = symVal.getDefiningOp(); if (!declOp) fir::emitFatalError(currentLocation, "COPYPRIVATE is supported only in HLFIR mode"); symVal = declOp.getBase(); mlir::Type symType = symVal.getType(); fir::FortranVariableFlagsEnum attrs = declOp.getFortranAttrs().has_value() ? *declOp.getFortranAttrs() : fir::FortranVariableFlagsEnum::None; mlir::Value cpVar = symVal; // CopyPrivate variables must be passed by reference. However, in the case // of assumed shapes/vla the type is not a !fir.ref, but a !fir.box. // In these cases to retrieve the appropriate !fir.ref> to // access the data we need we must perform an alloca and then store to it // and retrieve the data from the new alloca. 
if (mlir::isa(symType)) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); auto alloca = builder.create(currentLocation, symType); builder.create(currentLocation, symVal, alloca); cpVar = alloca; } result.copyprivateVars.push_back(cpVar); mlir::func::FuncOp funcOp = createCopyFunc(currentLocation, converter, cpVar.getType(), attrs); result.copyprivateSyms.push_back(mlir::SymbolRefAttr::get(funcOp)); }; bool hasCopyPrivate = findRepeatableClause( [&](const clause::Copyprivate &clause, const parser::CharBlock &) { for (const Object &object : clause.v) { semantics::Symbol *sym = object.sym(); if (const auto *commonDetails = sym->detailsIf()) { for (const auto &mem : commonDetails->objects()) addCopyPrivateVar(&*mem); break; } addCopyPrivateVar(sym); } }); return hasCopyPrivate; } bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const { auto process = [&](const omp::clause::Depend &clause, const parser::CharBlock &) { using Depend = omp::clause::Depend; if (!std::holds_alternative(clause.u)) { TODO(converter.getCurrentLocation(), "DEPEND clause with SINK or SOURCE is not supported yet"); } auto &taskDep = std::get(clause.u); auto depType = std::get(taskDep.t); auto &objects = std::get(taskDep.t); if (std::get>(taskDep.t)) { TODO(converter.getCurrentLocation(), "Support for iterator modifiers is not implemented yet"); } mlir::omp::ClauseTaskDependAttr dependTypeOperand = genDependKindAttr(converter, depType); result.dependKinds.append(objects.size(), dependTypeOperand); for (const omp::Object &object : objects) { assert(object.ref() && "Expecting designator"); if (evaluate::ExtractSubstring(*object.ref())) { TODO(converter.getCurrentLocation(), "substring not supported for task depend"); } else if (evaluate::IsArrayElement(*object.ref())) { TODO(converter.getCurrentLocation(), "array sections not supported for task depend"); } semantics::Symbol *sym = object.sym(); const mlir::Value variable = converter.getSymbolAddress(*sym); 
result.dependVars.push_back(variable); } }; return findRepeatableClause(process); } bool ClauseProcessor::processHasDeviceAddr( mlir::omp::HasDeviceAddrClauseOps &result, llvm::SmallVectorImpl &isDeviceSyms) const { return findRepeatableClause( [&](const omp::clause::HasDeviceAddr &devAddrClause, const parser::CharBlock &) { addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars, isDeviceSyms); }); } bool ClauseProcessor::processIf( omp::clause::If::DirectiveNameModifier directiveName, mlir::omp::IfClauseOps &result) const { bool found = false; findRepeatableClause([&](const omp::clause::If &clause, const parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); mlir::Value operand = getIfClauseOperand(converter, clause, directiveName, clauseLocation); // Assume that, at most, a single 'if' clause will be applicable to the // given directive. if (operand) { result.ifExpr = operand; found = true; } }); return found; } bool ClauseProcessor::processIsDevicePtr( mlir::omp::IsDevicePtrClauseOps &result, llvm::SmallVectorImpl &isDeviceSyms) const { return findRepeatableClause( [&](const omp::clause::IsDevicePtr &devPtrClause, const parser::CharBlock &) { addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars, isDeviceSyms); }); } bool ClauseProcessor::processLink( llvm::SmallVectorImpl &result) const { return findRepeatableClause( [&](const omp::clause::Link &clause, const parser::CharBlock &) { // Case: declare target link(var1, var2)... 
gatherFuncAndVarSyms( clause.v, mlir::omp::DeclareTargetCaptureClause::link, result); }); } void ClauseProcessor::processMapObjects( lower::StatementContext &stmtCtx, mlir::Location clauseLocation, const omp::ObjectList &objects, llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, std::map &parentMemberIndices, llvm::SmallVectorImpl &mapVars, llvm::SmallVectorImpl &mapSyms) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); for (const omp::Object &object : objects) { llvm::SmallVector bounds; std::stringstream asFortran; std::optional parentObj; fir::factory::AddrAndBoundsInfo info = lower::gatherDataOperandAddrAndBounds( converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(), object.ref(), clauseLocation, asFortran, bounds, treatIndexAsSection); mlir::Value baseOp = info.rawInput; if (object.sym()->owner().IsDerivedType()) { omp::ObjectList objectList = gatherObjectsOf(object, semaCtx); assert(!objectList.empty() && "could not find parent objects of derived type member"); parentObj = objectList[0]; parentMemberIndices.emplace(parentObj.value(), OmpMapParentAndMemberData{}); if (isMemberOrParentAllocatableOrPointer(object, semaCtx)) { llvm::SmallVector indices; generateMemberPlacementIndices(object, indices, semaCtx); baseOp = createParentSymAndGenIntermediateMaps( clauseLocation, converter, semaCtx, stmtCtx, objectList, indices, parentMemberIndices[parentObj.value()], asFortran.str(), mapTypeBits); } } // Explicit map captures are captured ByRef by default, // optimisation passes may alter this to ByCopy or other capture // types to optimise auto location = mlir::NameLoc::get( mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()), baseOp.getLoc()); mlir::omp::MapInfoOp mapOp = createMapInfoOp( firOpBuilder, location, baseOp, /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds, /*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{}, static_cast< std::underlying_type_t>( mapTypeBits), mlir::omp::VariableCaptureKind::ByRef, 
baseOp.getType()); if (parentObj.has_value()) { parentMemberIndices[parentObj.value()].addChildIndexAndMapToParent( object, mapOp, semaCtx); } else { mapVars.push_back(mapOp); mapSyms.push_back(object.sym()); } } } bool ClauseProcessor::processMap( mlir::Location currentLocation, lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result, llvm::SmallVectorImpl *mapSyms) const { // We always require tracking of symbols, even if the caller does not, // so we create an optionally used local set of symbols when the mapSyms // argument is not present. llvm::SmallVector localMapSyms; llvm::SmallVectorImpl *ptrMapSyms = mapSyms ? mapSyms : &localMapSyms; std::map parentMemberIndices; auto process = [&](const omp::clause::Map &clause, const parser::CharBlock &source) { using Map = omp::clause::Map; mlir::Location clauseLocation = converter.genLocation(source); const auto &[mapType, typeMods, mappers, iterator, objects] = clause.t; llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; // If the map type is specified, then process it else Tofrom is the // default. Map::MapType type = mapType.value_or(Map::MapType::Tofrom); switch (type) { case Map::MapType::To: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; break; case Map::MapType::From: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; break; case Map::MapType::Tofrom: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; break; case Map::MapType::Alloc: case Map::MapType::Release: // alloc and release is the default map_type for the Target Data // Ops, i.e. if no bits for map_type is supplied then alloc/release // is implicitly assumed based on the target directive. Default // value for Target Data and Enter Data is alloc and for Exit Data // it is release. 
break; case Map::MapType::Delete: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; } if (typeMods) { if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Always)) mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; // Diagnose unimplemented map-type-modifiers. if (llvm::any_of(*typeMods, [](Map::MapTypeModifier m) { return m != Map::MapTypeModifier::Always; })) { TODO(currentLocation, "Map type modifiers (other than 'ALWAYS')" " are not implemented yet"); } } if (iterator) { TODO(currentLocation, "Support for iterator modifiers is not implemented yet"); } if (mappers) { TODO(currentLocation, "Support for mapper modifiers is not implemented yet"); } processMapObjects(stmtCtx, clauseLocation, std::get(clause.t), mapTypeBits, parentMemberIndices, result.mapVars, *ptrMapSyms); }; bool clauseFound = findRepeatableClause(process); insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices, result.mapVars, *ptrMapSyms); return clauseFound; } bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result) { std::map parentMemberIndices; llvm::SmallVector mapSymbols; auto callbackFn = [&](const auto &clause, const parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); const auto &[expectation, mapper, iterator, objects] = clause.t; // TODO Support motion modifiers: present, mapper, iterator. if (expectation) { TODO(clauseLocation, "PRESENT modifier is not supported yet"); } else if (mapper) { TODO(clauseLocation, "Mapper modifier is not supported yet"); } else if (iterator) { TODO(clauseLocation, "Iterator modifier is not supported yet"); } constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = std::is_same_v, omp::clause::To> ? 
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; processMapObjects(stmtCtx, clauseLocation, objects, mapTypeBits, parentMemberIndices, result.mapVars, mapSymbols); }; bool clauseFound = findRepeatableClause(callbackFn); clauseFound = findRepeatableClause(callbackFn) || clauseFound; insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices, result.mapVars, mapSymbols); return clauseFound; } bool ClauseProcessor::processNontemporal( mlir::omp::NontemporalClauseOps &result) const { return findRepeatableClause( [&](const omp::clause::Nontemporal &clause, const parser::CharBlock &) { for (const Object &object : clause.v) { semantics::Symbol *sym = object.sym(); mlir::Value symVal = converter.getSymbolAddress(*sym); result.nontemporalVars.push_back(symVal); } }); } bool ClauseProcessor::processReduction( mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, llvm::SmallVectorImpl &outReductionSyms) const { return findRepeatableClause( [&](const omp::clause::Reduction &clause, const parser::CharBlock &) { llvm::SmallVector reductionVars; llvm::SmallVector reduceVarByRef; llvm::SmallVector reductionDeclSymbols; llvm::SmallVector reductionSyms; ReductionProcessor rp; rp.addDeclareReduction(currentLocation, converter, clause, reductionVars, reduceVarByRef, reductionDeclSymbols, reductionSyms); // Copy local lists into the output. llvm::copy(reductionVars, std::back_inserter(result.reductionVars)); llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref)); llvm::copy(reductionDeclSymbols, std::back_inserter(result.reductionSyms)); llvm::copy(reductionSyms, std::back_inserter(outReductionSyms)); }); } bool ClauseProcessor::processTo( llvm::SmallVectorImpl &result) const { return findRepeatableClause( [&](const omp::clause::To &clause, const parser::CharBlock &) { // Case: declare target to(func, var1, var2)... 
gatherFuncAndVarSyms(std::get(clause.t), mlir::omp::DeclareTargetCaptureClause::to, result); }); } bool ClauseProcessor::processEnter( llvm::SmallVectorImpl &result) const { return findRepeatableClause( [&](const omp::clause::Enter &clause, const parser::CharBlock &) { // Case: declare target enter(func, var1, var2)... gatherFuncAndVarSyms( clause.v, mlir::omp::DeclareTargetCaptureClause::enter, result); }); } bool ClauseProcessor::processUseDeviceAddr( lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result, llvm::SmallVectorImpl &useDeviceSyms) const { std::map parentMemberIndices; bool clauseFound = findRepeatableClause( [&](const omp::clause::UseDeviceAddr &clause, const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDeviceAddrVars, useDeviceSyms); }); insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices, result.useDeviceAddrVars, useDeviceSyms); return clauseFound; } bool ClauseProcessor::processUseDevicePtr( lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result, llvm::SmallVectorImpl &useDeviceSyms) const { std::map parentMemberIndices; bool clauseFound = findRepeatableClause( [&](const omp::clause::UseDevicePtr &clause, const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDevicePtrVars, useDeviceSyms); }); insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices, result.useDevicePtrVars, useDeviceSyms); return 
clauseFound; } } // namespace omp } // namespace lower } // namespace Fortran