xref: /llvm-project/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp (revision 1e2d5f7943d09d658a5fbacf661d2c6c361f857c)
181f544d4SKareem Ergawy //===- GenericLoopConversion.cpp ------------------------------------------===//
281f544d4SKareem Ergawy //
381f544d4SKareem Ergawy // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
481f544d4SKareem Ergawy // See https://llvm.org/LICENSE.txt for license information.
581f544d4SKareem Ergawy // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
681f544d4SKareem Ergawy //
781f544d4SKareem Ergawy //===----------------------------------------------------------------------===//
881f544d4SKareem Ergawy 
981f544d4SKareem Ergawy #include "flang/Common/OpenMP-utils.h"
1081f544d4SKareem Ergawy 
1181f544d4SKareem Ergawy #include "mlir/Dialect/Func/IR/FuncOps.h"
1281f544d4SKareem Ergawy #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
1381f544d4SKareem Ergawy #include "mlir/IR/IRMapping.h"
1481f544d4SKareem Ergawy #include "mlir/Pass/Pass.h"
1581f544d4SKareem Ergawy #include "mlir/Transforms/DialectConversion.h"
1681f544d4SKareem Ergawy 
1781f544d4SKareem Ergawy #include <memory>
1881f544d4SKareem Ergawy 
1981f544d4SKareem Ergawy namespace flangomp {
2081f544d4SKareem Ergawy #define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
2181f544d4SKareem Ergawy #include "flang/Optimizer/OpenMP/Passes.h.inc"
2281f544d4SKareem Ergawy } // namespace flangomp
2381f544d4SKareem Ergawy 
2481f544d4SKareem Ergawy namespace {
2581f544d4SKareem Ergawy 
2681f544d4SKareem Ergawy /// A conversion pattern to handle various combined forms of `omp.loop`. For how
2781f544d4SKareem Ergawy /// combined/composite directive are handled see:
2881f544d4SKareem Ergawy /// https://discourse.llvm.org/t/rfc-representing-combined-composite-constructs-in-the-openmp-dialect/76986.
2981f544d4SKareem Ergawy class GenericLoopConversionPattern
3081f544d4SKareem Ergawy     : public mlir::OpConversionPattern<mlir::omp::LoopOp> {
3181f544d4SKareem Ergawy public:
32*1e2d5f79SKareem Ergawy   enum class GenericLoopCombinedInfo { Standalone, TeamsLoop, ParallelLoop };
3381f544d4SKareem Ergawy 
3481f544d4SKareem Ergawy   using mlir::OpConversionPattern<mlir::omp::LoopOp>::OpConversionPattern;
3581f544d4SKareem Ergawy 
3629f7392cSKareem Ergawy   explicit GenericLoopConversionPattern(mlir::MLIRContext *ctx)
3729f7392cSKareem Ergawy       : mlir::OpConversionPattern<mlir::omp::LoopOp>{ctx} {
3829f7392cSKareem Ergawy     // Enable rewrite recursion to make sure nested `loop` directives are
3929f7392cSKareem Ergawy     // handled.
4029f7392cSKareem Ergawy     this->setHasBoundedRewriteRecursion(true);
4129f7392cSKareem Ergawy   }
4229f7392cSKareem Ergawy 
4381f544d4SKareem Ergawy   mlir::LogicalResult
4481f544d4SKareem Ergawy   matchAndRewrite(mlir::omp::LoopOp loopOp, OpAdaptor adaptor,
4581f544d4SKareem Ergawy                   mlir::ConversionPatternRewriter &rewriter) const override {
4681f544d4SKareem Ergawy     assert(mlir::succeeded(checkLoopConversionSupportStatus(loopOp)));
4781f544d4SKareem Ergawy 
4829f7392cSKareem Ergawy     GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp);
4929f7392cSKareem Ergawy 
5029f7392cSKareem Ergawy     switch (combinedInfo) {
5129f7392cSKareem Ergawy     case GenericLoopCombinedInfo::Standalone:
52d7e561b9SKareem Ergawy       rewriteStandaloneLoop(loopOp, rewriter);
5329f7392cSKareem Ergawy       break;
54*1e2d5f79SKareem Ergawy     case GenericLoopCombinedInfo::ParallelLoop:
55*1e2d5f79SKareem Ergawy       llvm_unreachable(
56*1e2d5f79SKareem Ergawy           "not yet implemented: Combined `parallel loop` directive");
5729f7392cSKareem Ergawy       break;
58*1e2d5f79SKareem Ergawy     case GenericLoopCombinedInfo::TeamsLoop:
5981f544d4SKareem Ergawy       rewriteToDistributeParallelDo(loopOp, rewriter);
6029f7392cSKareem Ergawy       break;
6129f7392cSKareem Ergawy     }
6229f7392cSKareem Ergawy 
6381f544d4SKareem Ergawy     rewriter.eraseOp(loopOp);
6481f544d4SKareem Ergawy     return mlir::success();
6581f544d4SKareem Ergawy   }
6681f544d4SKareem Ergawy 
6781f544d4SKareem Ergawy   static mlir::LogicalResult
6881f544d4SKareem Ergawy   checkLoopConversionSupportStatus(mlir::omp::LoopOp loopOp) {
6981f544d4SKareem Ergawy     GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp);
7081f544d4SKareem Ergawy 
7181f544d4SKareem Ergawy     switch (combinedInfo) {
7229f7392cSKareem Ergawy     case GenericLoopCombinedInfo::Standalone:
7329f7392cSKareem Ergawy       break;
74*1e2d5f79SKareem Ergawy     case GenericLoopCombinedInfo::ParallelLoop:
7581f544d4SKareem Ergawy       return loopOp.emitError(
76*1e2d5f79SKareem Ergawy           "not yet implemented: Combined `parallel loop` directive");
77*1e2d5f79SKareem Ergawy     case GenericLoopCombinedInfo::TeamsLoop:
7881f544d4SKareem Ergawy       break;
7981f544d4SKareem Ergawy     }
8081f544d4SKareem Ergawy 
8181f544d4SKareem Ergawy     auto todo = [&loopOp](mlir::StringRef clauseName) {
8281f544d4SKareem Ergawy       return loopOp.emitError()
8381f544d4SKareem Ergawy              << "not yet implemented: Unhandled clause " << clauseName << " in "
8481f544d4SKareem Ergawy              << loopOp->getName() << " operation";
8581f544d4SKareem Ergawy     };
8681f544d4SKareem Ergawy 
87d7e561b9SKareem Ergawy     // For standalone directives, `bind` is already supported. Other combined
88d7e561b9SKareem Ergawy     // forms will be supported in a follow-up PR.
89d7e561b9SKareem Ergawy     if (combinedInfo != GenericLoopCombinedInfo::Standalone &&
90d7e561b9SKareem Ergawy         loopOp.getBindKind())
9181f544d4SKareem Ergawy       return todo("bind");
9281f544d4SKareem Ergawy 
9381f544d4SKareem Ergawy     if (loopOp.getOrder())
9481f544d4SKareem Ergawy       return todo("order");
9581f544d4SKareem Ergawy 
9681f544d4SKareem Ergawy     if (!loopOp.getReductionVars().empty())
9781f544d4SKareem Ergawy       return todo("reduction");
9881f544d4SKareem Ergawy 
99*1e2d5f79SKareem Ergawy     // TODO For `teams loop`, check similar constrains to what is checked
10081f544d4SKareem Ergawy     // by `TeamsLoopChecker` in SemaOpenMP.cpp.
10181f544d4SKareem Ergawy     return mlir::success();
10281f544d4SKareem Ergawy   }
10381f544d4SKareem Ergawy 
10481f544d4SKareem Ergawy private:
10581f544d4SKareem Ergawy   static GenericLoopCombinedInfo
10681f544d4SKareem Ergawy   findGenericLoopCombineInfo(mlir::omp::LoopOp loopOp) {
10781f544d4SKareem Ergawy     mlir::Operation *parentOp = loopOp->getParentOp();
10829f7392cSKareem Ergawy     GenericLoopCombinedInfo result = GenericLoopCombinedInfo::Standalone;
10981f544d4SKareem Ergawy 
11081f544d4SKareem Ergawy     if (auto teamsOp = mlir::dyn_cast_if_present<mlir::omp::TeamsOp>(parentOp))
111*1e2d5f79SKareem Ergawy       result = GenericLoopCombinedInfo::TeamsLoop;
11281f544d4SKareem Ergawy 
11381f544d4SKareem Ergawy     if (auto parallelOp =
11481f544d4SKareem Ergawy             mlir::dyn_cast_if_present<mlir::omp::ParallelOp>(parentOp))
115*1e2d5f79SKareem Ergawy       result = GenericLoopCombinedInfo::ParallelLoop;
11681f544d4SKareem Ergawy 
11781f544d4SKareem Ergawy     return result;
11881f544d4SKareem Ergawy   }
11981f544d4SKareem Ergawy 
120d7e561b9SKareem Ergawy   void rewriteStandaloneLoop(mlir::omp::LoopOp loopOp,
121d7e561b9SKareem Ergawy                              mlir::ConversionPatternRewriter &rewriter) const {
122d7e561b9SKareem Ergawy     using namespace mlir::omp;
123d7e561b9SKareem Ergawy     std::optional<ClauseBindKind> bindKind = loopOp.getBindKind();
124d7e561b9SKareem Ergawy 
125d7e561b9SKareem Ergawy     if (!bindKind.has_value())
126d7e561b9SKareem Ergawy       return rewriteToSimdLoop(loopOp, rewriter);
127d7e561b9SKareem Ergawy 
128d7e561b9SKareem Ergawy     switch (*loopOp.getBindKind()) {
129d7e561b9SKareem Ergawy     case ClauseBindKind::Parallel:
130d7e561b9SKareem Ergawy       return rewriteToWsloop(loopOp, rewriter);
131d7e561b9SKareem Ergawy     case ClauseBindKind::Teams:
132d7e561b9SKareem Ergawy       return rewriteToDistrbute(loopOp, rewriter);
133d7e561b9SKareem Ergawy     case ClauseBindKind::Thread:
134d7e561b9SKareem Ergawy       return rewriteToSimdLoop(loopOp, rewriter);
135d7e561b9SKareem Ergawy     }
136d7e561b9SKareem Ergawy   }
137d7e561b9SKareem Ergawy 
138d7e561b9SKareem Ergawy   /// Rewrites standalone `loop` (without `bind` clause or with
139d7e561b9SKareem Ergawy   /// `bind(parallel)`) directives to equivalent `simd` constructs.
140d7e561b9SKareem Ergawy   ///
14129f7392cSKareem Ergawy   /// The reasoning behind this decision is that according to the spec (version
14229f7392cSKareem Ergawy   /// 5.2, section 11.7.1):
14329f7392cSKareem Ergawy   ///
14429f7392cSKareem Ergawy   /// "If the bind clause is not specified on a construct for which it may be
14529f7392cSKareem Ergawy   /// specified and the construct is closely nested inside a teams or parallel
14629f7392cSKareem Ergawy   /// construct, the effect is as if binding is teams or parallel. If none of
14729f7392cSKareem Ergawy   /// those conditions hold, the binding region is not defined."
14829f7392cSKareem Ergawy   ///
14929f7392cSKareem Ergawy   /// which means that standalone `loop` directives have undefined binding
15029f7392cSKareem Ergawy   /// region. Moreover, the spec says (in the next paragraph):
15129f7392cSKareem Ergawy   ///
15229f7392cSKareem Ergawy   /// "The specified binding region determines the binding thread set.
15329f7392cSKareem Ergawy   /// Specifically, if the binding region is a teams region, then the binding
15429f7392cSKareem Ergawy   /// thread set is the set of initial threads that are executing that region
15529f7392cSKareem Ergawy   /// while if the binding region is a parallel region, then the binding thread
15629f7392cSKareem Ergawy   /// set is the team of threads that are executing that region. If the binding
15729f7392cSKareem Ergawy   /// region is not defined, then the binding thread set is the encountering
15829f7392cSKareem Ergawy   /// thread."
15929f7392cSKareem Ergawy   ///
16029f7392cSKareem Ergawy   /// which means that the binding thread set for a standalone `loop` directive
16129f7392cSKareem Ergawy   /// is only the encountering thread.
16229f7392cSKareem Ergawy   ///
16329f7392cSKareem Ergawy   /// Since the encountering thread is the binding thread (set) for a
16429f7392cSKareem Ergawy   /// standalone `loop` directive, the best we can do in such case is to "simd"
16529f7392cSKareem Ergawy   /// the directive.
16629f7392cSKareem Ergawy   void rewriteToSimdLoop(mlir::omp::LoopOp loopOp,
16729f7392cSKareem Ergawy                          mlir::ConversionPatternRewriter &rewriter) const {
168d7e561b9SKareem Ergawy     loopOp.emitWarning(
169d7e561b9SKareem Ergawy         "Detected standalone OpenMP `loop` directive with thread binding, "
170d7e561b9SKareem Ergawy         "the associated loop will be rewritten to `simd`.");
171d7e561b9SKareem Ergawy     rewriteToSingleWrapperOp<mlir::omp::SimdOp, mlir::omp::SimdOperands>(
172d7e561b9SKareem Ergawy         loopOp, rewriter);
173d7e561b9SKareem Ergawy   }
174d7e561b9SKareem Ergawy 
175d7e561b9SKareem Ergawy   void rewriteToDistrbute(mlir::omp::LoopOp loopOp,
176d7e561b9SKareem Ergawy                           mlir::ConversionPatternRewriter &rewriter) const {
177d7e561b9SKareem Ergawy     rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
178d7e561b9SKareem Ergawy                              mlir::omp::DistributeOperands>(loopOp, rewriter);
179d7e561b9SKareem Ergawy   }
180d7e561b9SKareem Ergawy 
181d7e561b9SKareem Ergawy   void rewriteToWsloop(mlir::omp::LoopOp loopOp,
182d7e561b9SKareem Ergawy                        mlir::ConversionPatternRewriter &rewriter) const {
183d7e561b9SKareem Ergawy     rewriteToSingleWrapperOp<mlir::omp::WsloopOp, mlir::omp::WsloopOperands>(
184d7e561b9SKareem Ergawy         loopOp, rewriter);
185d7e561b9SKareem Ergawy   }
186d7e561b9SKareem Ergawy 
187d7e561b9SKareem Ergawy   // TODO Suggestion by Sergio: tag auto-generated operations for constructs
188d7e561b9SKareem Ergawy   // that weren't part of the original program, that would be useful
189d7e561b9SKareem Ergawy   // information for debugging purposes later on. This new attribute could be
190d7e561b9SKareem Ergawy   // used for `omp.loop`, but also for `do concurrent` transformations,
191d7e561b9SKareem Ergawy   // `workshare`, `workdistribute`, etc. The tag could be used for all kinds of
192d7e561b9SKareem Ergawy   // auto-generated operations using a dialect attribute (named something like
193d7e561b9SKareem Ergawy   // `omp.origin` or `omp.derived`) and perhaps hold the name of the operation
194d7e561b9SKareem Ergawy   // it was derived from, the reason it was transformed or something like that
195d7e561b9SKareem Ergawy   // we could use when emitting any messages related to it later on.
196d7e561b9SKareem Ergawy   template <typename OpTy, typename OpOperandsTy>
197d7e561b9SKareem Ergawy   void
198d7e561b9SKareem Ergawy   rewriteToSingleWrapperOp(mlir::omp::LoopOp loopOp,
199d7e561b9SKareem Ergawy                            mlir::ConversionPatternRewriter &rewriter) const {
200d7e561b9SKareem Ergawy     OpOperandsTy clauseOps;
201d7e561b9SKareem Ergawy     clauseOps.privateVars = loopOp.getPrivateVars();
20229f7392cSKareem Ergawy 
20329f7392cSKareem Ergawy     auto privateSyms = loopOp.getPrivateSyms();
20429f7392cSKareem Ergawy     if (privateSyms)
205d7e561b9SKareem Ergawy       clauseOps.privateSyms.assign(privateSyms->begin(), privateSyms->end());
20629f7392cSKareem Ergawy 
207d7e561b9SKareem Ergawy     Fortran::common::openmp::EntryBlockArgs args;
208d7e561b9SKareem Ergawy     args.priv.vars = clauseOps.privateVars;
20929f7392cSKareem Ergawy 
210d7e561b9SKareem Ergawy     auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps);
211d7e561b9SKareem Ergawy     mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion());
21229f7392cSKareem Ergawy 
21329f7392cSKareem Ergawy     mlir::IRMapping mapper;
21429f7392cSKareem Ergawy     mlir::Block &loopBlock = *loopOp.getRegion().begin();
21529f7392cSKareem Ergawy 
216d7e561b9SKareem Ergawy     for (auto [loopOpArg, opArg] :
217d7e561b9SKareem Ergawy          llvm::zip_equal(loopBlock.getArguments(), opBlock->getArguments()))
218d7e561b9SKareem Ergawy       mapper.map(loopOpArg, opArg);
21929f7392cSKareem Ergawy 
22029f7392cSKareem Ergawy     rewriter.clone(*loopOp.begin(), mapper);
22129f7392cSKareem Ergawy   }
22229f7392cSKareem Ergawy 
22381f544d4SKareem Ergawy   void rewriteToDistributeParallelDo(
22481f544d4SKareem Ergawy       mlir::omp::LoopOp loopOp,
22581f544d4SKareem Ergawy       mlir::ConversionPatternRewriter &rewriter) const {
22681f544d4SKareem Ergawy     mlir::omp::ParallelOperands parallelClauseOps;
22781f544d4SKareem Ergawy     parallelClauseOps.privateVars = loopOp.getPrivateVars();
22881f544d4SKareem Ergawy 
22981f544d4SKareem Ergawy     auto privateSyms = loopOp.getPrivateSyms();
23081f544d4SKareem Ergawy     if (privateSyms)
23181f544d4SKareem Ergawy       parallelClauseOps.privateSyms.assign(privateSyms->begin(),
23281f544d4SKareem Ergawy                                            privateSyms->end());
23381f544d4SKareem Ergawy 
23481f544d4SKareem Ergawy     Fortran::common::openmp::EntryBlockArgs parallelArgs;
23581f544d4SKareem Ergawy     parallelArgs.priv.vars = parallelClauseOps.privateVars;
23681f544d4SKareem Ergawy 
23781f544d4SKareem Ergawy     auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
23881f544d4SKareem Ergawy                                                              parallelClauseOps);
23981f544d4SKareem Ergawy     mlir::Block *parallelBlock =
24081f544d4SKareem Ergawy         genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
24181f544d4SKareem Ergawy     parallelOp.setComposite(true);
24281f544d4SKareem Ergawy     rewriter.setInsertionPoint(
24381f544d4SKareem Ergawy         rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
24481f544d4SKareem Ergawy 
24581f544d4SKareem Ergawy     mlir::omp::DistributeOperands distributeClauseOps;
24681f544d4SKareem Ergawy     auto distributeOp = rewriter.create<mlir::omp::DistributeOp>(
24781f544d4SKareem Ergawy         loopOp.getLoc(), distributeClauseOps);
24881f544d4SKareem Ergawy     distributeOp.setComposite(true);
24981f544d4SKareem Ergawy     rewriter.createBlock(&distributeOp.getRegion());
25081f544d4SKareem Ergawy 
25181f544d4SKareem Ergawy     mlir::omp::WsloopOperands wsloopClauseOps;
25281f544d4SKareem Ergawy     auto wsloopOp =
25381f544d4SKareem Ergawy         rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
25481f544d4SKareem Ergawy     wsloopOp.setComposite(true);
25581f544d4SKareem Ergawy     rewriter.createBlock(&wsloopOp.getRegion());
25681f544d4SKareem Ergawy 
25781f544d4SKareem Ergawy     mlir::IRMapping mapper;
25881f544d4SKareem Ergawy     mlir::Block &loopBlock = *loopOp.getRegion().begin();
25981f544d4SKareem Ergawy 
26081f544d4SKareem Ergawy     for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
26181f544d4SKareem Ergawy              loopBlock.getArguments(), parallelBlock->getArguments()))
26281f544d4SKareem Ergawy       mapper.map(loopOpArg, parallelOpArg);
26381f544d4SKareem Ergawy 
26481f544d4SKareem Ergawy     rewriter.clone(*loopOp.begin(), mapper);
26581f544d4SKareem Ergawy   }
26681f544d4SKareem Ergawy };
26781f544d4SKareem Ergawy 
26881f544d4SKareem Ergawy class GenericLoopConversionPass
26981f544d4SKareem Ergawy     : public flangomp::impl::GenericLoopConversionPassBase<
27081f544d4SKareem Ergawy           GenericLoopConversionPass> {
27181f544d4SKareem Ergawy public:
27281f544d4SKareem Ergawy   GenericLoopConversionPass() = default;
27381f544d4SKareem Ergawy 
27481f544d4SKareem Ergawy   void runOnOperation() override {
27581f544d4SKareem Ergawy     mlir::func::FuncOp func = getOperation();
27681f544d4SKareem Ergawy 
27781f544d4SKareem Ergawy     if (func.isDeclaration())
27881f544d4SKareem Ergawy       return;
27981f544d4SKareem Ergawy 
28081f544d4SKareem Ergawy     mlir::MLIRContext *context = &getContext();
28181f544d4SKareem Ergawy     mlir::RewritePatternSet patterns(context);
28281f544d4SKareem Ergawy     patterns.insert<GenericLoopConversionPattern>(context);
28381f544d4SKareem Ergawy     mlir::ConversionTarget target(*context);
28481f544d4SKareem Ergawy 
28581f544d4SKareem Ergawy     target.markUnknownOpDynamicallyLegal(
28681f544d4SKareem Ergawy         [](mlir::Operation *) { return true; });
28781f544d4SKareem Ergawy     target.addDynamicallyLegalOp<mlir::omp::LoopOp>(
28881f544d4SKareem Ergawy         [](mlir::omp::LoopOp loopOp) {
28981f544d4SKareem Ergawy           return mlir::failed(
29081f544d4SKareem Ergawy               GenericLoopConversionPattern::checkLoopConversionSupportStatus(
29181f544d4SKareem Ergawy                   loopOp));
29281f544d4SKareem Ergawy         });
29381f544d4SKareem Ergawy 
29481f544d4SKareem Ergawy     if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
29581f544d4SKareem Ergawy                                                std::move(patterns)))) {
29681f544d4SKareem Ergawy       mlir::emitError(func.getLoc(), "error in converting `omp.loop` op");
29781f544d4SKareem Ergawy       signalPassFailure();
29881f544d4SKareem Ergawy     }
29981f544d4SKareem Ergawy   }
30081f544d4SKareem Ergawy };
30181f544d4SKareem Ergawy } // namespace
302