1 //===- GenericLoopConversion.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Common/OpenMP-utils.h" 10 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 13 #include "mlir/IR/IRMapping.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 #include <memory> 18 19 namespace flangomp { 20 #define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS 21 #include "flang/Optimizer/OpenMP/Passes.h.inc" 22 } // namespace flangomp 23 24 namespace { 25 26 /// A conversion pattern to handle various combined forms of `omp.loop`. For how 27 /// combined/composite directive are handled see: 28 /// https://discourse.llvm.org/t/rfc-representing-combined-composite-constructs-in-the-openmp-dialect/76986. 29 class GenericLoopConversionPattern 30 : public mlir::OpConversionPattern<mlir::omp::LoopOp> { 31 public: 32 enum class GenericLoopCombinedInfo { Standalone, TeamsLoop, ParallelLoop }; 33 34 using mlir::OpConversionPattern<mlir::omp::LoopOp>::OpConversionPattern; 35 36 explicit GenericLoopConversionPattern(mlir::MLIRContext *ctx) 37 : mlir::OpConversionPattern<mlir::omp::LoopOp>{ctx} { 38 // Enable rewrite recursion to make sure nested `loop` directives are 39 // handled. 40 this->setHasBoundedRewriteRecursion(true); 41 } 42 43 mlir::LogicalResult 44 matchAndRewrite(mlir::omp::LoopOp loopOp, OpAdaptor adaptor, 45 mlir::ConversionPatternRewriter &rewriter) const override { 46 assert(mlir::succeeded(checkLoopConversionSupportStatus(loopOp))); 47 48 GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp); 49 50 switch (combinedInfo) { 51 case GenericLoopCombinedInfo::Standalone: 52 rewriteStandaloneLoop(loopOp, rewriter); 53 break; 54 case GenericLoopCombinedInfo::ParallelLoop: 55 llvm_unreachable( 56 "not yet implemented: Combined `parallel loop` directive"); 57 break; 58 case GenericLoopCombinedInfo::TeamsLoop: 59 rewriteToDistributeParallelDo(loopOp, rewriter); 60 break; 61 } 62 63 rewriter.eraseOp(loopOp); 64 return mlir::success(); 65 } 66 67 static mlir::LogicalResult 68 checkLoopConversionSupportStatus(mlir::omp::LoopOp loopOp) { 69 GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp); 70 71 switch (combinedInfo) { 72 case GenericLoopCombinedInfo::Standalone: 73 break; 74 case GenericLoopCombinedInfo::ParallelLoop: 75 return loopOp.emitError( 76 "not yet implemented: Combined `parallel loop` directive"); 77 case GenericLoopCombinedInfo::TeamsLoop: 78 break; 79 } 80 81 auto todo = [&loopOp](mlir::StringRef clauseName) { 82 return loopOp.emitError() 83 << "not yet implemented: Unhandled clause " << clauseName << " in " 84 << loopOp->getName() << " operation"; 85 }; 86 87 // For standalone directives, `bind` is already supported. Other combined 88 // forms will be supported in a follow-up PR. 89 if (combinedInfo != GenericLoopCombinedInfo::Standalone && 90 loopOp.getBindKind()) 91 return todo("bind"); 92 93 if (loopOp.getOrder()) 94 return todo("order"); 95 96 if (!loopOp.getReductionVars().empty()) 97 return todo("reduction"); 98 99 // TODO For `teams loop`, check similar constrains to what is checked 100 // by `TeamsLoopChecker` in SemaOpenMP.cpp. 101 return mlir::success(); 102 } 103 104 private: 105 static GenericLoopCombinedInfo 106 findGenericLoopCombineInfo(mlir::omp::LoopOp loopOp) { 107 mlir::Operation *parentOp = loopOp->getParentOp(); 108 GenericLoopCombinedInfo result = GenericLoopCombinedInfo::Standalone; 109 110 if (auto teamsOp = mlir::dyn_cast_if_present<mlir::omp::TeamsOp>(parentOp)) 111 result = GenericLoopCombinedInfo::TeamsLoop; 112 113 if (auto parallelOp = 114 mlir::dyn_cast_if_present<mlir::omp::ParallelOp>(parentOp)) 115 result = GenericLoopCombinedInfo::ParallelLoop; 116 117 return result; 118 } 119 120 void rewriteStandaloneLoop(mlir::omp::LoopOp loopOp, 121 mlir::ConversionPatternRewriter &rewriter) const { 122 using namespace mlir::omp; 123 std::optional<ClauseBindKind> bindKind = loopOp.getBindKind(); 124 125 if (!bindKind.has_value()) 126 return rewriteToSimdLoop(loopOp, rewriter); 127 128 switch (*loopOp.getBindKind()) { 129 case ClauseBindKind::Parallel: 130 return rewriteToWsloop(loopOp, rewriter); 131 case ClauseBindKind::Teams: 132 return rewriteToDistrbute(loopOp, rewriter); 133 case ClauseBindKind::Thread: 134 return rewriteToSimdLoop(loopOp, rewriter); 135 } 136 } 137 138 /// Rewrites standalone `loop` (without `bind` clause or with 139 /// `bind(parallel)`) directives to equivalent `simd` constructs. 140 /// 141 /// The reasoning behind this decision is that according to the spec (version 142 /// 5.2, section 11.7.1): 143 /// 144 /// "If the bind clause is not specified on a construct for which it may be 145 /// specified and the construct is closely nested inside a teams or parallel 146 /// construct, the effect is as if binding is teams or parallel. If none of 147 /// those conditions hold, the binding region is not defined." 148 /// 149 /// which means that standalone `loop` directives have undefined binding 150 /// region. Moreover, the spec says (in the next paragraph): 151 /// 152 /// "The specified binding region determines the binding thread set. 153 /// Specifically, if the binding region is a teams region, then the binding 154 /// thread set is the set of initial threads that are executing that region 155 /// while if the binding region is a parallel region, then the binding thread 156 /// set is the team of threads that are executing that region. If the binding 157 /// region is not defined, then the binding thread set is the encountering 158 /// thread." 159 /// 160 /// which means that the binding thread set for a standalone `loop` directive 161 /// is only the encountering thread. 162 /// 163 /// Since the encountering thread is the binding thread (set) for a 164 /// standalone `loop` directive, the best we can do in such case is to "simd" 165 /// the directive. 166 void rewriteToSimdLoop(mlir::omp::LoopOp loopOp, 167 mlir::ConversionPatternRewriter &rewriter) const { 168 loopOp.emitWarning( 169 "Detected standalone OpenMP `loop` directive with thread binding, " 170 "the associated loop will be rewritten to `simd`."); 171 rewriteToSingleWrapperOp<mlir::omp::SimdOp, mlir::omp::SimdOperands>( 172 loopOp, rewriter); 173 } 174 175 void rewriteToDistrbute(mlir::omp::LoopOp loopOp, 176 mlir::ConversionPatternRewriter &rewriter) const { 177 rewriteToSingleWrapperOp<mlir::omp::DistributeOp, 178 mlir::omp::DistributeOperands>(loopOp, rewriter); 179 } 180 181 void rewriteToWsloop(mlir::omp::LoopOp loopOp, 182 mlir::ConversionPatternRewriter &rewriter) const { 183 rewriteToSingleWrapperOp<mlir::omp::WsloopOp, mlir::omp::WsloopOperands>( 184 loopOp, rewriter); 185 } 186 187 // TODO Suggestion by Sergio: tag auto-generated operations for constructs 188 // that weren't part of the original program, that would be useful 189 // information for debugging purposes later on. This new attribute could be 190 // used for `omp.loop`, but also for `do concurrent` transformations, 191 // `workshare`, `workdistribute`, etc. The tag could be used for all kinds of 192 // auto-generated operations using a dialect attribute (named something like 193 // `omp.origin` or `omp.derived`) and perhaps hold the name of the operation 194 // it was derived from, the reason it was transformed or something like that 195 // we could use when emitting any messages related to it later on. 196 template <typename OpTy, typename OpOperandsTy> 197 void 198 rewriteToSingleWrapperOp(mlir::omp::LoopOp loopOp, 199 mlir::ConversionPatternRewriter &rewriter) const { 200 OpOperandsTy clauseOps; 201 clauseOps.privateVars = loopOp.getPrivateVars(); 202 203 auto privateSyms = loopOp.getPrivateSyms(); 204 if (privateSyms) 205 clauseOps.privateSyms.assign(privateSyms->begin(), privateSyms->end()); 206 207 Fortran::common::openmp::EntryBlockArgs args; 208 args.priv.vars = clauseOps.privateVars; 209 210 auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps); 211 mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion()); 212 213 mlir::IRMapping mapper; 214 mlir::Block &loopBlock = *loopOp.getRegion().begin(); 215 216 for (auto [loopOpArg, opArg] : 217 llvm::zip_equal(loopBlock.getArguments(), opBlock->getArguments())) 218 mapper.map(loopOpArg, opArg); 219 220 rewriter.clone(*loopOp.begin(), mapper); 221 } 222 223 void rewriteToDistributeParallelDo( 224 mlir::omp::LoopOp loopOp, 225 mlir::ConversionPatternRewriter &rewriter) const { 226 mlir::omp::ParallelOperands parallelClauseOps; 227 parallelClauseOps.privateVars = loopOp.getPrivateVars(); 228 229 auto privateSyms = loopOp.getPrivateSyms(); 230 if (privateSyms) 231 parallelClauseOps.privateSyms.assign(privateSyms->begin(), 232 privateSyms->end()); 233 234 Fortran::common::openmp::EntryBlockArgs parallelArgs; 235 parallelArgs.priv.vars = parallelClauseOps.privateVars; 236 237 auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(), 238 parallelClauseOps); 239 mlir::Block *parallelBlock = 240 genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion()); 241 parallelOp.setComposite(true); 242 rewriter.setInsertionPoint( 243 rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc())); 244 245 mlir::omp::DistributeOperands distributeClauseOps; 246 auto distributeOp = rewriter.create<mlir::omp::DistributeOp>( 247 loopOp.getLoc(), distributeClauseOps); 248 distributeOp.setComposite(true); 249 rewriter.createBlock(&distributeOp.getRegion()); 250 251 mlir::omp::WsloopOperands wsloopClauseOps; 252 auto wsloopOp = 253 rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps); 254 wsloopOp.setComposite(true); 255 rewriter.createBlock(&wsloopOp.getRegion()); 256 257 mlir::IRMapping mapper; 258 mlir::Block &loopBlock = *loopOp.getRegion().begin(); 259 260 for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal( 261 loopBlock.getArguments(), parallelBlock->getArguments())) 262 mapper.map(loopOpArg, parallelOpArg); 263 264 rewriter.clone(*loopOp.begin(), mapper); 265 } 266 }; 267 268 class GenericLoopConversionPass 269 : public flangomp::impl::GenericLoopConversionPassBase< 270 GenericLoopConversionPass> { 271 public: 272 GenericLoopConversionPass() = default; 273 274 void runOnOperation() override { 275 mlir::func::FuncOp func = getOperation(); 276 277 if (func.isDeclaration()) 278 return; 279 280 mlir::MLIRContext *context = &getContext(); 281 mlir::RewritePatternSet patterns(context); 282 patterns.insert<GenericLoopConversionPattern>(context); 283 mlir::ConversionTarget target(*context); 284 285 target.markUnknownOpDynamicallyLegal( 286 [](mlir::Operation *) { return true; }); 287 target.addDynamicallyLegalOp<mlir::omp::LoopOp>( 288 [](mlir::omp::LoopOp loopOp) { 289 return mlir::failed( 290 GenericLoopConversionPattern::checkLoopConversionSupportStatus( 291 loopOp)); 292 }); 293 294 if (mlir::failed(mlir::applyFullConversion(getOperation(), target, 295 std::move(patterns)))) { 296 mlir::emitError(func.getLoc(), "error in converting `omp.loop` op"); 297 signalPassFailure(); 298 } 299 } 300 }; 301 } // namespace 302