1 //===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 #ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H 13 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H 14 15 #include "Clauses.h" 16 #include "DirectivesCommon.h" 17 #include "ReductionProcessor.h" 18 #include "Utils.h" 19 #include "flang/Lower/AbstractConverter.h" 20 #include "flang/Lower/Bridge.h" 21 #include "flang/Optimizer/Builder/Todo.h" 22 #include "flang/Parser/dump-parse-tree.h" 23 #include "flang/Parser/parse-tree.h" 24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 25 26 namespace fir { 27 class FirOpBuilder; 28 } // namespace fir 29 30 namespace Fortran { 31 namespace lower { 32 namespace omp { 33 34 /// Class that handles the processing of OpenMP clauses. 35 /// 36 /// Its `process<ClauseName>()` methods perform MLIR code generation for their 37 /// corresponding clause if it is present in the clause list. Otherwise, they 38 /// will return `false` to signal that the clause was not found. 39 /// 40 /// The intended use of this class is to move clause processing outside of 41 /// construct processing, since the same clauses can appear attached to 42 /// different constructs and constructs can be combined, so that code 43 /// duplication is minimized. 44 /// 45 /// Each construct-lowering function only calls the `process<ClauseName>()` 46 /// methods that relate to clauses that can impact the lowering of that 47 /// construct. 48 class ClauseProcessor { 49 public: 50 ClauseProcessor(lower::AbstractConverter &converter, 51 semantics::SemanticsContext &semaCtx, 52 const List<Clause> &clauses) 53 : converter(converter), semaCtx(semaCtx), clauses(clauses) {} 54 55 // 'Unique' clauses: They can appear at most once in the clause list. 56 bool 57 processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval, 58 mlir::omp::CollapseClauseOps &result, 59 llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const; 60 bool processDefault() const; 61 bool processDevice(lower::StatementContext &stmtCtx, 62 mlir::omp::DeviceClauseOps &result) const; 63 bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const; 64 bool processFinal(lower::StatementContext &stmtCtx, 65 mlir::omp::FinalClauseOps &result) const; 66 bool processHasDeviceAddr( 67 mlir::omp::HasDeviceAddrClauseOps &result, 68 llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes, 69 llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs, 70 llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const; 71 bool processHint(mlir::omp::HintClauseOps &result) const; 72 bool processMergeable(mlir::omp::MergeableClauseOps &result) const; 73 bool processNowait(mlir::omp::NowaitClauseOps &result) const; 74 bool processNumTeams(lower::StatementContext &stmtCtx, 75 mlir::omp::NumTeamsClauseOps &result) const; 76 bool processNumThreads(lower::StatementContext &stmtCtx, 77 mlir::omp::NumThreadsClauseOps &result) const; 78 bool processOrdered(mlir::omp::OrderedClauseOps &result) const; 79 bool processPriority(lower::StatementContext &stmtCtx, 80 mlir::omp::PriorityClauseOps &result) const; 81 bool processProcBind(mlir::omp::ProcBindClauseOps &result) const; 82 bool processSafelen(mlir::omp::SafelenClauseOps &result) const; 83 bool processSchedule(lower::StatementContext &stmtCtx, 84 mlir::omp::ScheduleClauseOps &result) const; 85 bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const; 86 bool processThreadLimit(lower::StatementContext &stmtCtx, 87 mlir::omp::ThreadLimitClauseOps &result) const; 88 bool processUntied(mlir::omp::UntiedClauseOps &result) const; 89 90 // 'Repeatable' clauses: They can appear multiple times in the clause list. 91 bool processAllocate(mlir::omp::AllocateClauseOps &result) const; 92 bool processCopyin() const; 93 bool processCopyprivate(mlir::Location currentLocation, 94 mlir::omp::CopyprivateClauseOps &result) const; 95 bool processDepend(mlir::omp::DependClauseOps &result) const; 96 bool 97 processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; 98 bool processIf(omp::clause::If::DirectiveNameModifier directiveName, 99 mlir::omp::IfClauseOps &result) const; 100 bool processIsDevicePtr( 101 mlir::omp::IsDevicePtrClauseOps &result, 102 llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes, 103 llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs, 104 llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const; 105 bool 106 processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; 107 108 // This method is used to process a map clause. 109 // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to 110 // store the original type, location and Fortran symbol for the map operands. 111 // They may be used later on to create the block_arguments for some of the 112 // target directives that require it. 113 bool processMap( 114 mlir::Location currentLocation, lower::StatementContext &stmtCtx, 115 mlir::omp::MapClauseOps &result, 116 llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = nullptr, 117 llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr, 118 llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const; 119 bool processReduction( 120 mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, 121 llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr, 122 llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms = 123 nullptr) const; 124 bool processSectionsReduction(mlir::Location currentLocation, 125 mlir::omp::ReductionClauseOps &result) const; 126 bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; 127 bool processUseDeviceAddr( 128 mlir::omp::UseDeviceClauseOps &result, 129 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, 130 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, 131 llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const; 132 bool processUseDevicePtr( 133 mlir::omp::UseDeviceClauseOps &result, 134 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, 135 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, 136 llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const; 137 138 template <typename T> 139 bool processMotionClauses(lower::StatementContext &stmtCtx, 140 mlir::omp::MapClauseOps &result); 141 142 // Call this method for these clauses that should be supported but are not 143 // implemented yet. It triggers a compilation error if any of the given 144 // clauses is found. 145 template <typename... Ts> 146 void processTODO(mlir::Location currentLocation, 147 llvm::omp::Directive directive) const; 148 149 private: 150 using ClauseIterator = List<Clause>::const_iterator; 151 152 /// Utility to find a clause within a range in the clause list. 153 template <typename T> 154 static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end); 155 156 /// Return the first instance of the given clause found in the clause list or 157 /// `nullptr` if not present. If more than one instance is expected, use 158 /// `findRepeatableClause` instead. 159 template <typename T> 160 const T *findUniqueClause(const parser::CharBlock **source = nullptr) const; 161 162 /// Call `callbackFn` for each occurrence of the given clause. Return `true` 163 /// if at least one instance was found. 164 template <typename T> 165 bool findRepeatableClause( 166 std::function<void(const T &, const parser::CharBlock &source)> 167 callbackFn) const; 168 169 /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. 170 template <typename T> 171 bool markClauseOccurrence(mlir::UnitAttr &result) const; 172 173 lower::AbstractConverter &converter; 174 semantics::SemanticsContext &semaCtx; 175 List<Clause> clauses; 176 }; 177 178 template <typename T> 179 bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx, 180 mlir::omp::MapClauseOps &result) { 181 std::map<const semantics::Symbol *, 182 llvm::SmallVector<OmpMapMemberIndicesData>> 183 parentMemberIndices; 184 llvm::SmallVector<const semantics::Symbol *> mapSymbols; 185 186 bool clauseFound = findRepeatableClause<T>( 187 [&](const T &clause, const parser::CharBlock &source) { 188 mlir::Location clauseLocation = converter.genLocation(source); 189 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 190 191 static_assert(std::is_same_v<T, omp::clause::To> || 192 std::is_same_v<T, omp::clause::From>); 193 194 // TODO Support motion modifiers: present, mapper, iterator. 195 constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = 196 std::is_same_v<T, omp::clause::To> 197 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO 198 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; 199 200 auto &objects = std::get<ObjectList>(clause.t); 201 for (const omp::Object &object : objects) { 202 llvm::SmallVector<mlir::Value> bounds; 203 std::stringstream asFortran; 204 205 lower::AddrAndBoundsInfo info = 206 lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp, 207 mlir::omp::MapBoundsType>( 208 converter, firOpBuilder, semaCtx, stmtCtx, *object.id(), 209 object.ref(), clauseLocation, asFortran, bounds, 210 treatIndexAsSection); 211 212 auto origSymbol = converter.getSymbolAddress(*object.id()); 213 mlir::Value symAddr = info.addr; 214 if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) 215 symAddr = origSymbol; 216 217 // Explicit map captures are captured ByRef by default, 218 // optimisation passes may alter this to ByCopy or other capture 219 // types to optimise 220 mlir::omp::MapInfoOp mapOp = createMapInfoOp( 221 firOpBuilder, clauseLocation, symAddr, 222 /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds, 223 /*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{}, 224 static_cast< 225 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( 226 mapTypeBits), 227 mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); 228 229 if (object.id()->owner().IsDerivedType()) { 230 addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, 231 semaCtx); 232 } else { 233 result.mapVars.push_back(mapOp); 234 mapSymbols.push_back(object.id()); 235 } 236 } 237 }); 238 239 insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars, 240 mapSymbols, 241 /*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr); 242 return clauseFound; 243 } 244 245 template <typename... Ts> 246 void ClauseProcessor::processTODO(mlir::Location currentLocation, 247 llvm::omp::Directive directive) const { 248 auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) { 249 if (!x) 250 return; 251 TODO(currentLocation, 252 "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() + 253 " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() + 254 " construct"); 255 }; 256 257 for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it) 258 (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...); 259 } 260 261 template <typename T> 262 ClauseProcessor::ClauseIterator 263 ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) { 264 for (ClauseIterator it = begin; it != end; ++it) { 265 if (std::get_if<T>(&it->u)) 266 return it; 267 } 268 269 return end; 270 } 271 272 template <typename T> 273 const T * 274 ClauseProcessor::findUniqueClause(const parser::CharBlock **source) const { 275 ClauseIterator it = findClause<T>(clauses.begin(), clauses.end()); 276 if (it != clauses.end()) { 277 if (source) 278 *source = &it->source; 279 return &std::get<T>(it->u); 280 } 281 return nullptr; 282 } 283 284 template <typename T> 285 bool ClauseProcessor::findRepeatableClause( 286 std::function<void(const T &, const parser::CharBlock &source)> callbackFn) 287 const { 288 bool found = false; 289 ClauseIterator nextIt, endIt = clauses.end(); 290 for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) { 291 nextIt = findClause<T>(it, endIt); 292 293 if (nextIt != endIt) { 294 callbackFn(std::get<T>(nextIt->u), nextIt->source); 295 found = true; 296 ++nextIt; 297 } 298 } 299 return found; 300 } 301 302 template <typename T> 303 bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const { 304 if (findUniqueClause<T>()) { 305 result = converter.getFirOpBuilder().getUnitAttr(); 306 return true; 307 } 308 return false; 309 } 310 311 } // namespace omp 312 } // namespace lower 313 } // namespace Fortran 314 315 #endif // FORTRAN_LOWER_CLAUSEPROCESSOR_H 316