181f544d4SKareem Ergawy //===- GenericLoopConversion.cpp ------------------------------------------===// 281f544d4SKareem Ergawy // 381f544d4SKareem Ergawy // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 481f544d4SKareem Ergawy // See https://llvm.org/LICENSE.txt for license information. 581f544d4SKareem Ergawy // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 681f544d4SKareem Ergawy // 781f544d4SKareem Ergawy //===----------------------------------------------------------------------===// 881f544d4SKareem Ergawy 981f544d4SKareem Ergawy #include "flang/Common/OpenMP-utils.h" 1081f544d4SKareem Ergawy 1181f544d4SKareem Ergawy #include "mlir/Dialect/Func/IR/FuncOps.h" 1281f544d4SKareem Ergawy #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 1381f544d4SKareem Ergawy #include "mlir/IR/IRMapping.h" 1481f544d4SKareem Ergawy #include "mlir/Pass/Pass.h" 1581f544d4SKareem Ergawy #include "mlir/Transforms/DialectConversion.h" 1681f544d4SKareem Ergawy 1781f544d4SKareem Ergawy #include <memory> 1881f544d4SKareem Ergawy 1981f544d4SKareem Ergawy namespace flangomp { 2081f544d4SKareem Ergawy #define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS 2181f544d4SKareem Ergawy #include "flang/Optimizer/OpenMP/Passes.h.inc" 2281f544d4SKareem Ergawy } // namespace flangomp 2381f544d4SKareem Ergawy 2481f544d4SKareem Ergawy namespace { 2581f544d4SKareem Ergawy 2681f544d4SKareem Ergawy /// A conversion pattern to handle various combined forms of `omp.loop`. For how 2781f544d4SKareem Ergawy /// combined/composite directive are handled see: 2881f544d4SKareem Ergawy /// https://discourse.llvm.org/t/rfc-representing-combined-composite-constructs-in-the-openmp-dialect/76986. 2981f544d4SKareem Ergawy class GenericLoopConversionPattern 3081f544d4SKareem Ergawy : public mlir::OpConversionPattern<mlir::omp::LoopOp> { 3181f544d4SKareem Ergawy public: 32*1e2d5f79SKareem Ergawy enum class GenericLoopCombinedInfo { Standalone, TeamsLoop, ParallelLoop }; 3381f544d4SKareem Ergawy 3481f544d4SKareem Ergawy using mlir::OpConversionPattern<mlir::omp::LoopOp>::OpConversionPattern; 3581f544d4SKareem Ergawy 3629f7392cSKareem Ergawy explicit GenericLoopConversionPattern(mlir::MLIRContext *ctx) 3729f7392cSKareem Ergawy : mlir::OpConversionPattern<mlir::omp::LoopOp>{ctx} { 3829f7392cSKareem Ergawy // Enable rewrite recursion to make sure nested `loop` directives are 3929f7392cSKareem Ergawy // handled. 4029f7392cSKareem Ergawy this->setHasBoundedRewriteRecursion(true); 4129f7392cSKareem Ergawy } 4229f7392cSKareem Ergawy 4381f544d4SKareem Ergawy mlir::LogicalResult 4481f544d4SKareem Ergawy matchAndRewrite(mlir::omp::LoopOp loopOp, OpAdaptor adaptor, 4581f544d4SKareem Ergawy mlir::ConversionPatternRewriter &rewriter) const override { 4681f544d4SKareem Ergawy assert(mlir::succeeded(checkLoopConversionSupportStatus(loopOp))); 4781f544d4SKareem Ergawy 4829f7392cSKareem Ergawy GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp); 4929f7392cSKareem Ergawy 5029f7392cSKareem Ergawy switch (combinedInfo) { 5129f7392cSKareem Ergawy case GenericLoopCombinedInfo::Standalone: 52d7e561b9SKareem Ergawy rewriteStandaloneLoop(loopOp, rewriter); 5329f7392cSKareem Ergawy break; 54*1e2d5f79SKareem Ergawy case GenericLoopCombinedInfo::ParallelLoop: 55*1e2d5f79SKareem Ergawy llvm_unreachable( 56*1e2d5f79SKareem Ergawy "not yet implemented: Combined `parallel loop` directive"); 5729f7392cSKareem Ergawy break; 58*1e2d5f79SKareem Ergawy case GenericLoopCombinedInfo::TeamsLoop: 5981f544d4SKareem Ergawy rewriteToDistributeParallelDo(loopOp, rewriter); 6029f7392cSKareem Ergawy break; 6129f7392cSKareem Ergawy } 6229f7392cSKareem Ergawy 6381f544d4SKareem Ergawy rewriter.eraseOp(loopOp); 6481f544d4SKareem Ergawy return mlir::success(); 6581f544d4SKareem Ergawy } 6681f544d4SKareem Ergawy 6781f544d4SKareem Ergawy static mlir::LogicalResult 6881f544d4SKareem Ergawy checkLoopConversionSupportStatus(mlir::omp::LoopOp loopOp) { 6981f544d4SKareem Ergawy GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp); 7081f544d4SKareem Ergawy 7181f544d4SKareem Ergawy switch (combinedInfo) { 7229f7392cSKareem Ergawy case GenericLoopCombinedInfo::Standalone: 7329f7392cSKareem Ergawy break; 74*1e2d5f79SKareem Ergawy case GenericLoopCombinedInfo::ParallelLoop: 7581f544d4SKareem Ergawy return loopOp.emitError( 76*1e2d5f79SKareem Ergawy "not yet implemented: Combined `parallel loop` directive"); 77*1e2d5f79SKareem Ergawy case GenericLoopCombinedInfo::TeamsLoop: 7881f544d4SKareem Ergawy break; 7981f544d4SKareem Ergawy } 8081f544d4SKareem Ergawy 8181f544d4SKareem Ergawy auto todo = [&loopOp](mlir::StringRef clauseName) { 8281f544d4SKareem Ergawy return loopOp.emitError() 8381f544d4SKareem Ergawy << "not yet implemented: Unhandled clause " << clauseName << " in " 8481f544d4SKareem Ergawy << loopOp->getName() << " operation"; 8581f544d4SKareem Ergawy }; 8681f544d4SKareem Ergawy 87d7e561b9SKareem Ergawy // For standalone directives, `bind` is already supported. Other combined 88d7e561b9SKareem Ergawy // forms will be supported in a follow-up PR. 89d7e561b9SKareem Ergawy if (combinedInfo != GenericLoopCombinedInfo::Standalone && 90d7e561b9SKareem Ergawy loopOp.getBindKind()) 9181f544d4SKareem Ergawy return todo("bind"); 9281f544d4SKareem Ergawy 9381f544d4SKareem Ergawy if (loopOp.getOrder()) 9481f544d4SKareem Ergawy return todo("order"); 9581f544d4SKareem Ergawy 9681f544d4SKareem Ergawy if (!loopOp.getReductionVars().empty()) 9781f544d4SKareem Ergawy return todo("reduction"); 9881f544d4SKareem Ergawy 99*1e2d5f79SKareem Ergawy // TODO For `teams loop`, check similar constrains to what is checked 10081f544d4SKareem Ergawy // by `TeamsLoopChecker` in SemaOpenMP.cpp. 10181f544d4SKareem Ergawy return mlir::success(); 10281f544d4SKareem Ergawy } 10381f544d4SKareem Ergawy 10481f544d4SKareem Ergawy private: 10581f544d4SKareem Ergawy static GenericLoopCombinedInfo 10681f544d4SKareem Ergawy findGenericLoopCombineInfo(mlir::omp::LoopOp loopOp) { 10781f544d4SKareem Ergawy mlir::Operation *parentOp = loopOp->getParentOp(); 10829f7392cSKareem Ergawy GenericLoopCombinedInfo result = GenericLoopCombinedInfo::Standalone; 10981f544d4SKareem Ergawy 11081f544d4SKareem Ergawy if (auto teamsOp = mlir::dyn_cast_if_present<mlir::omp::TeamsOp>(parentOp)) 111*1e2d5f79SKareem Ergawy result = GenericLoopCombinedInfo::TeamsLoop; 11281f544d4SKareem Ergawy 11381f544d4SKareem Ergawy if (auto parallelOp = 11481f544d4SKareem Ergawy mlir::dyn_cast_if_present<mlir::omp::ParallelOp>(parentOp)) 115*1e2d5f79SKareem Ergawy result = GenericLoopCombinedInfo::ParallelLoop; 11681f544d4SKareem Ergawy 11781f544d4SKareem Ergawy return result; 11881f544d4SKareem Ergawy } 11981f544d4SKareem Ergawy 120d7e561b9SKareem Ergawy void rewriteStandaloneLoop(mlir::omp::LoopOp loopOp, 121d7e561b9SKareem Ergawy mlir::ConversionPatternRewriter &rewriter) const { 122d7e561b9SKareem Ergawy using namespace mlir::omp; 123d7e561b9SKareem Ergawy std::optional<ClauseBindKind> bindKind = loopOp.getBindKind(); 124d7e561b9SKareem Ergawy 125d7e561b9SKareem Ergawy if (!bindKind.has_value()) 126d7e561b9SKareem Ergawy return rewriteToSimdLoop(loopOp, rewriter); 127d7e561b9SKareem Ergawy 128d7e561b9SKareem Ergawy switch (*loopOp.getBindKind()) { 129d7e561b9SKareem Ergawy case ClauseBindKind::Parallel: 130d7e561b9SKareem Ergawy return rewriteToWsloop(loopOp, rewriter); 131d7e561b9SKareem Ergawy case ClauseBindKind::Teams: 132d7e561b9SKareem Ergawy return rewriteToDistrbute(loopOp, rewriter); 133d7e561b9SKareem Ergawy case ClauseBindKind::Thread: 134d7e561b9SKareem Ergawy return rewriteToSimdLoop(loopOp, rewriter); 135d7e561b9SKareem Ergawy } 136d7e561b9SKareem Ergawy } 137d7e561b9SKareem Ergawy 138d7e561b9SKareem Ergawy /// Rewrites standalone `loop` (without `bind` clause or with 139d7e561b9SKareem Ergawy /// `bind(parallel)`) directives to equivalent `simd` constructs. 140d7e561b9SKareem Ergawy /// 14129f7392cSKareem Ergawy /// The reasoning behind this decision is that according to the spec (version 14229f7392cSKareem Ergawy /// 5.2, section 11.7.1): 14329f7392cSKareem Ergawy /// 14429f7392cSKareem Ergawy /// "If the bind clause is not specified on a construct for which it may be 14529f7392cSKareem Ergawy /// specified and the construct is closely nested inside a teams or parallel 14629f7392cSKareem Ergawy /// construct, the effect is as if binding is teams or parallel. If none of 14729f7392cSKareem Ergawy /// those conditions hold, the binding region is not defined." 14829f7392cSKareem Ergawy /// 14929f7392cSKareem Ergawy /// which means that standalone `loop` directives have undefined binding 15029f7392cSKareem Ergawy /// region. Moreover, the spec says (in the next paragraph): 15129f7392cSKareem Ergawy /// 15229f7392cSKareem Ergawy /// "The specified binding region determines the binding thread set. 15329f7392cSKareem Ergawy /// Specifically, if the binding region is a teams region, then the binding 15429f7392cSKareem Ergawy /// thread set is the set of initial threads that are executing that region 15529f7392cSKareem Ergawy /// while if the binding region is a parallel region, then the binding thread 15629f7392cSKareem Ergawy /// set is the team of threads that are executing that region. If the binding 15729f7392cSKareem Ergawy /// region is not defined, then the binding thread set is the encountering 15829f7392cSKareem Ergawy /// thread." 15929f7392cSKareem Ergawy /// 16029f7392cSKareem Ergawy /// which means that the binding thread set for a standalone `loop` directive 16129f7392cSKareem Ergawy /// is only the encountering thread. 16229f7392cSKareem Ergawy /// 16329f7392cSKareem Ergawy /// Since the encountering thread is the binding thread (set) for a 16429f7392cSKareem Ergawy /// standalone `loop` directive, the best we can do in such case is to "simd" 16529f7392cSKareem Ergawy /// the directive. 16629f7392cSKareem Ergawy void rewriteToSimdLoop(mlir::omp::LoopOp loopOp, 16729f7392cSKareem Ergawy mlir::ConversionPatternRewriter &rewriter) const { 168d7e561b9SKareem Ergawy loopOp.emitWarning( 169d7e561b9SKareem Ergawy "Detected standalone OpenMP `loop` directive with thread binding, " 170d7e561b9SKareem Ergawy "the associated loop will be rewritten to `simd`."); 171d7e561b9SKareem Ergawy rewriteToSingleWrapperOp<mlir::omp::SimdOp, mlir::omp::SimdOperands>( 172d7e561b9SKareem Ergawy loopOp, rewriter); 173d7e561b9SKareem Ergawy } 174d7e561b9SKareem Ergawy 175d7e561b9SKareem Ergawy void rewriteToDistrbute(mlir::omp::LoopOp loopOp, 176d7e561b9SKareem Ergawy mlir::ConversionPatternRewriter &rewriter) const { 177d7e561b9SKareem Ergawy rewriteToSingleWrapperOp<mlir::omp::DistributeOp, 178d7e561b9SKareem Ergawy mlir::omp::DistributeOperands>(loopOp, rewriter); 179d7e561b9SKareem Ergawy } 180d7e561b9SKareem Ergawy 181d7e561b9SKareem Ergawy void rewriteToWsloop(mlir::omp::LoopOp loopOp, 182d7e561b9SKareem Ergawy mlir::ConversionPatternRewriter &rewriter) const { 183d7e561b9SKareem Ergawy rewriteToSingleWrapperOp<mlir::omp::WsloopOp, mlir::omp::WsloopOperands>( 184d7e561b9SKareem Ergawy loopOp, rewriter); 185d7e561b9SKareem Ergawy } 186d7e561b9SKareem Ergawy 187d7e561b9SKareem Ergawy // TODO Suggestion by Sergio: tag auto-generated operations for constructs 188d7e561b9SKareem Ergawy // that weren't part of the original program, that would be useful 189d7e561b9SKareem Ergawy // information for debugging purposes later on. This new attribute could be 190d7e561b9SKareem Ergawy // used for `omp.loop`, but also for `do concurrent` transformations, 191d7e561b9SKareem Ergawy // `workshare`, `workdistribute`, etc. The tag could be used for all kinds of 192d7e561b9SKareem Ergawy // auto-generated operations using a dialect attribute (named something like 193d7e561b9SKareem Ergawy // `omp.origin` or `omp.derived`) and perhaps hold the name of the operation 194d7e561b9SKareem Ergawy // it was derived from, the reason it was transformed or something like that 195d7e561b9SKareem Ergawy // we could use when emitting any messages related to it later on. 196d7e561b9SKareem Ergawy template <typename OpTy, typename OpOperandsTy> 197d7e561b9SKareem Ergawy void 198d7e561b9SKareem Ergawy rewriteToSingleWrapperOp(mlir::omp::LoopOp loopOp, 199d7e561b9SKareem Ergawy mlir::ConversionPatternRewriter &rewriter) const { 200d7e561b9SKareem Ergawy OpOperandsTy clauseOps; 201d7e561b9SKareem Ergawy clauseOps.privateVars = loopOp.getPrivateVars(); 20229f7392cSKareem Ergawy 20329f7392cSKareem Ergawy auto privateSyms = loopOp.getPrivateSyms(); 20429f7392cSKareem Ergawy if (privateSyms) 205d7e561b9SKareem Ergawy clauseOps.privateSyms.assign(privateSyms->begin(), privateSyms->end()); 20629f7392cSKareem Ergawy 207d7e561b9SKareem Ergawy Fortran::common::openmp::EntryBlockArgs args; 208d7e561b9SKareem Ergawy args.priv.vars = clauseOps.privateVars; 20929f7392cSKareem Ergawy 210d7e561b9SKareem Ergawy auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps); 211d7e561b9SKareem Ergawy mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion()); 21229f7392cSKareem Ergawy 21329f7392cSKareem Ergawy mlir::IRMapping mapper; 21429f7392cSKareem Ergawy mlir::Block &loopBlock = *loopOp.getRegion().begin(); 21529f7392cSKareem Ergawy 216d7e561b9SKareem Ergawy for (auto [loopOpArg, opArg] : 217d7e561b9SKareem Ergawy llvm::zip_equal(loopBlock.getArguments(), opBlock->getArguments())) 218d7e561b9SKareem Ergawy mapper.map(loopOpArg, opArg); 21929f7392cSKareem Ergawy 22029f7392cSKareem Ergawy rewriter.clone(*loopOp.begin(), mapper); 22129f7392cSKareem Ergawy } 22229f7392cSKareem Ergawy 22381f544d4SKareem Ergawy void rewriteToDistributeParallelDo( 22481f544d4SKareem Ergawy mlir::omp::LoopOp loopOp, 22581f544d4SKareem Ergawy mlir::ConversionPatternRewriter &rewriter) const { 22681f544d4SKareem Ergawy mlir::omp::ParallelOperands parallelClauseOps; 22781f544d4SKareem Ergawy parallelClauseOps.privateVars = loopOp.getPrivateVars(); 22881f544d4SKareem Ergawy 22981f544d4SKareem Ergawy auto privateSyms = loopOp.getPrivateSyms(); 23081f544d4SKareem Ergawy if (privateSyms) 23181f544d4SKareem Ergawy parallelClauseOps.privateSyms.assign(privateSyms->begin(), 23281f544d4SKareem Ergawy privateSyms->end()); 23381f544d4SKareem Ergawy 23481f544d4SKareem Ergawy Fortran::common::openmp::EntryBlockArgs parallelArgs; 23581f544d4SKareem Ergawy parallelArgs.priv.vars = parallelClauseOps.privateVars; 23681f544d4SKareem Ergawy 23781f544d4SKareem Ergawy auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(), 23881f544d4SKareem Ergawy parallelClauseOps); 23981f544d4SKareem Ergawy mlir::Block *parallelBlock = 24081f544d4SKareem Ergawy genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion()); 24181f544d4SKareem Ergawy parallelOp.setComposite(true); 24281f544d4SKareem Ergawy rewriter.setInsertionPoint( 24381f544d4SKareem Ergawy rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc())); 24481f544d4SKareem Ergawy 24581f544d4SKareem Ergawy mlir::omp::DistributeOperands distributeClauseOps; 24681f544d4SKareem Ergawy auto distributeOp = rewriter.create<mlir::omp::DistributeOp>( 24781f544d4SKareem Ergawy loopOp.getLoc(), distributeClauseOps); 24881f544d4SKareem Ergawy distributeOp.setComposite(true); 24981f544d4SKareem Ergawy rewriter.createBlock(&distributeOp.getRegion()); 25081f544d4SKareem Ergawy 25181f544d4SKareem Ergawy mlir::omp::WsloopOperands wsloopClauseOps; 25281f544d4SKareem Ergawy auto wsloopOp = 25381f544d4SKareem Ergawy rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps); 25481f544d4SKareem Ergawy wsloopOp.setComposite(true); 25581f544d4SKareem Ergawy rewriter.createBlock(&wsloopOp.getRegion()); 25681f544d4SKareem Ergawy 25781f544d4SKareem Ergawy mlir::IRMapping mapper; 25881f544d4SKareem Ergawy mlir::Block &loopBlock = *loopOp.getRegion().begin(); 25981f544d4SKareem Ergawy 26081f544d4SKareem Ergawy for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal( 26181f544d4SKareem Ergawy loopBlock.getArguments(), parallelBlock->getArguments())) 26281f544d4SKareem Ergawy mapper.map(loopOpArg, parallelOpArg); 26381f544d4SKareem Ergawy 26481f544d4SKareem Ergawy rewriter.clone(*loopOp.begin(), mapper); 26581f544d4SKareem Ergawy } 26681f544d4SKareem Ergawy }; 26781f544d4SKareem Ergawy 26881f544d4SKareem Ergawy class GenericLoopConversionPass 26981f544d4SKareem Ergawy : public flangomp::impl::GenericLoopConversionPassBase< 27081f544d4SKareem Ergawy GenericLoopConversionPass> { 27181f544d4SKareem Ergawy public: 27281f544d4SKareem Ergawy GenericLoopConversionPass() = default; 27381f544d4SKareem Ergawy 27481f544d4SKareem Ergawy void runOnOperation() override { 27581f544d4SKareem Ergawy mlir::func::FuncOp func = getOperation(); 27681f544d4SKareem Ergawy 27781f544d4SKareem Ergawy if (func.isDeclaration()) 27881f544d4SKareem Ergawy return; 27981f544d4SKareem Ergawy 28081f544d4SKareem Ergawy mlir::MLIRContext *context = &getContext(); 28181f544d4SKareem Ergawy mlir::RewritePatternSet patterns(context); 28281f544d4SKareem Ergawy patterns.insert<GenericLoopConversionPattern>(context); 28381f544d4SKareem Ergawy mlir::ConversionTarget target(*context); 28481f544d4SKareem Ergawy 28581f544d4SKareem Ergawy target.markUnknownOpDynamicallyLegal( 28681f544d4SKareem Ergawy [](mlir::Operation *) { return true; }); 28781f544d4SKareem Ergawy target.addDynamicallyLegalOp<mlir::omp::LoopOp>( 28881f544d4SKareem Ergawy [](mlir::omp::LoopOp loopOp) { 28981f544d4SKareem Ergawy return mlir::failed( 29081f544d4SKareem Ergawy GenericLoopConversionPattern::checkLoopConversionSupportStatus( 29181f544d4SKareem Ergawy loopOp)); 29281f544d4SKareem Ergawy }); 29381f544d4SKareem Ergawy 29481f544d4SKareem Ergawy if (mlir::failed(mlir::applyFullConversion(getOperation(), target, 29581f544d4SKareem Ergawy std::move(patterns)))) { 29681f544d4SKareem Ergawy mlir::emitError(func.getLoc(), "error in converting `omp.loop` op"); 29781f544d4SKareem Ergawy signalPassFailure(); 29881f544d4SKareem Ergawy } 29981f544d4SKareem Ergawy } 30081f544d4SKareem Ergawy }; 30181f544d4SKareem Ergawy } // namespace 302