1 //===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 #ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H 13 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H 14 15 #include "Clauses.h" 16 #include "ReductionProcessor.h" 17 #include "Utils.h" 18 #include "flang/Lower/AbstractConverter.h" 19 #include "flang/Lower/Bridge.h" 20 #include "flang/Lower/DirectivesCommon.h" 21 #include "flang/Optimizer/Builder/Todo.h" 22 #include "flang/Parser/dump-parse-tree.h" 23 #include "flang/Parser/parse-tree.h" 24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 25 26 namespace fir { 27 class FirOpBuilder; 28 } // namespace fir 29 30 namespace Fortran { 31 namespace lower { 32 namespace omp { 33 34 /// Class that handles the processing of OpenMP clauses. 35 /// 36 /// Its `process<ClauseName>()` methods perform MLIR code generation for their 37 /// corresponding clause if it is present in the clause list. Otherwise, they 38 /// will return `false` to signal that the clause was not found. 39 /// 40 /// The intended use of this class is to move clause processing outside of 41 /// construct processing, since the same clauses can appear attached to 42 /// different constructs and constructs can be combined, so that code 43 /// duplication is minimized. 44 /// 45 /// Each construct-lowering function only calls the `process<ClauseName>()` 46 /// methods that relate to clauses that can impact the lowering of that 47 /// construct. 48 class ClauseProcessor { 49 public: 50 ClauseProcessor(lower::AbstractConverter &converter, 51 semantics::SemanticsContext &semaCtx, 52 const List<Clause> &clauses) 53 : converter(converter), semaCtx(semaCtx), clauses(clauses) {} 54 55 // 'Unique' clauses: They can appear at most once in the clause list. 56 bool processBare(mlir::omp::BareClauseOps &result) const; 57 bool processBind(mlir::omp::BindClauseOps &result) const; 58 bool 59 processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval, 60 mlir::omp::LoopRelatedClauseOps &result, 61 llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const; 62 bool processDevice(lower::StatementContext &stmtCtx, 63 mlir::omp::DeviceClauseOps &result) const; 64 bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const; 65 bool processDistSchedule(lower::StatementContext &stmtCtx, 66 mlir::omp::DistScheduleClauseOps &result) const; 67 bool processFilter(lower::StatementContext &stmtCtx, 68 mlir::omp::FilterClauseOps &result) const; 69 bool processFinal(lower::StatementContext &stmtCtx, 70 mlir::omp::FinalClauseOps &result) const; 71 bool processHasDeviceAddr( 72 mlir::omp::HasDeviceAddrClauseOps &result, 73 llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const; 74 bool processHint(mlir::omp::HintClauseOps &result) const; 75 bool processMergeable(mlir::omp::MergeableClauseOps &result) const; 76 bool processNowait(mlir::omp::NowaitClauseOps &result) const; 77 bool processNumTeams(lower::StatementContext &stmtCtx, 78 mlir::omp::NumTeamsClauseOps &result) const; 79 bool processNumThreads(lower::StatementContext &stmtCtx, 80 mlir::omp::NumThreadsClauseOps &result) const; 81 bool processOrder(mlir::omp::OrderClauseOps &result) const; 82 bool processOrdered(mlir::omp::OrderedClauseOps &result) const; 83 bool processPriority(lower::StatementContext &stmtCtx, 84 mlir::omp::PriorityClauseOps &result) const; 85 bool processProcBind(mlir::omp::ProcBindClauseOps &result) const; 86 bool processSafelen(mlir::omp::SafelenClauseOps &result) const; 87 bool processSchedule(lower::StatementContext &stmtCtx, 88 mlir::omp::ScheduleClauseOps &result) const; 89 bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const; 90 bool processThreadLimit(lower::StatementContext &stmtCtx, 91 mlir::omp::ThreadLimitClauseOps &result) const; 92 bool processUntied(mlir::omp::UntiedClauseOps &result) const; 93 94 bool processDetach(mlir::omp::DetachClauseOps &result) const; 95 // 'Repeatable' clauses: They can appear multiple times in the clause list. 96 bool processAligned(mlir::omp::AlignedClauseOps &result) const; 97 bool processAllocate(mlir::omp::AllocateClauseOps &result) const; 98 bool processCopyin() const; 99 bool processCopyprivate(mlir::Location currentLocation, 100 mlir::omp::CopyprivateClauseOps &result) const; 101 bool processDepend(mlir::omp::DependClauseOps &result) const; 102 bool 103 processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; 104 bool processIf(omp::clause::If::DirectiveNameModifier directiveName, 105 mlir::omp::IfClauseOps &result) const; 106 bool processIsDevicePtr( 107 mlir::omp::IsDevicePtrClauseOps &result, 108 llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const; 109 bool 110 processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; 111 112 // This method is used to process a map clause. 113 // The optional parameter mapSyms is used to store the original Fortran symbol 114 // for the map operands. It may be used later on to create the block_arguments 115 // for some of the directives that require it. 116 bool processMap(mlir::Location currentLocation, 117 lower::StatementContext &stmtCtx, 118 mlir::omp::MapClauseOps &result, 119 llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = 120 nullptr) const; 121 bool processMotionClauses(lower::StatementContext &stmtCtx, 122 mlir::omp::MapClauseOps &result); 123 bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const; 124 bool processReduction( 125 mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, 126 llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const; 127 bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; 128 bool processUseDeviceAddr( 129 lower::StatementContext &stmtCtx, 130 mlir::omp::UseDeviceAddrClauseOps &result, 131 llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const; 132 bool processUseDevicePtr( 133 lower::StatementContext &stmtCtx, 134 mlir::omp::UseDevicePtrClauseOps &result, 135 llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const; 136 137 // Call this method for these clauses that should be supported but are not 138 // implemented yet. It triggers a compilation error if any of the given 139 // clauses is found. 140 template <typename... Ts> 141 void processTODO(mlir::Location currentLocation, 142 llvm::omp::Directive directive) const; 143 144 private: 145 using ClauseIterator = List<Clause>::const_iterator; 146 147 /// Utility to find a clause within a range in the clause list. 148 template <typename T> 149 static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end); 150 151 /// Return the first instance of the given clause found in the clause list or 152 /// `nullptr` if not present. If more than one instance is expected, use 153 /// `findRepeatableClause` instead. 154 template <typename T> 155 const T *findUniqueClause(const parser::CharBlock **source = nullptr) const; 156 157 /// Call `callbackFn` for each occurrence of the given clause. Return `true` 158 /// if at least one instance was found. 159 template <typename T> 160 bool findRepeatableClause( 161 std::function<void(const T &, const parser::CharBlock &source)> 162 callbackFn) const; 163 164 /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. 165 template <typename T> 166 bool markClauseOccurrence(mlir::UnitAttr &result) const; 167 168 void processMapObjects( 169 lower::StatementContext &stmtCtx, mlir::Location clauseLocation, 170 const omp::ObjectList &objects, 171 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, 172 std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices, 173 llvm::SmallVectorImpl<mlir::Value> &mapVars, 174 llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const; 175 176 lower::AbstractConverter &converter; 177 semantics::SemanticsContext &semaCtx; 178 List<Clause> clauses; 179 }; 180 181 template <typename... Ts> 182 void ClauseProcessor::processTODO(mlir::Location currentLocation, 183 llvm::omp::Directive directive) const { 184 auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) { 185 if (!x) 186 return; 187 TODO(currentLocation, 188 "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() + 189 " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() + 190 " construct"); 191 }; 192 193 for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it) 194 (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...); 195 } 196 197 template <typename T> 198 ClauseProcessor::ClauseIterator 199 ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) { 200 for (ClauseIterator it = begin; it != end; ++it) { 201 if (std::get_if<T>(&it->u)) 202 return it; 203 } 204 205 return end; 206 } 207 208 template <typename T> 209 const T * 210 ClauseProcessor::findUniqueClause(const parser::CharBlock **source) const { 211 ClauseIterator it = findClause<T>(clauses.begin(), clauses.end()); 212 if (it != clauses.end()) { 213 if (source) 214 *source = &it->source; 215 return &std::get<T>(it->u); 216 } 217 return nullptr; 218 } 219 220 template <typename T> 221 bool ClauseProcessor::findRepeatableClause( 222 std::function<void(const T &, const parser::CharBlock &source)> callbackFn) 223 const { 224 bool found = false; 225 ClauseIterator nextIt, endIt = clauses.end(); 226 for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) { 227 nextIt = findClause<T>(it, endIt); 228 229 if (nextIt != endIt) { 230 callbackFn(std::get<T>(nextIt->u), nextIt->source); 231 found = true; 232 ++nextIt; 233 } 234 } 235 return found; 236 } 237 238 template <typename T> 239 bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const { 240 if (findUniqueClause<T>()) { 241 result = converter.getFirOpBuilder().getUnitAttr(); 242 return true; 243 } 244 return false; 245 } 246 247 } // namespace omp 248 } // namespace lower 249 } // namespace Fortran 250 251 #endif // FORTRAN_LOWER_CLAUSEPROCESSOR_H 252