xref: /llvm-project/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp (revision 1e2d5f7943d09d658a5fbacf661d2c6c361f857c)
1 //===- GenericLoopConversion.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Common/OpenMP-utils.h"
10 
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
13 #include "mlir/IR/IRMapping.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 #include <memory>
18 
19 namespace flangomp {
20 #define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
21 #include "flang/Optimizer/OpenMP/Passes.h.inc"
22 } // namespace flangomp
23 
24 namespace {
25 
26 /// A conversion pattern to handle various combined forms of `omp.loop`. For how
27 /// combined/composite directive are handled see:
28 /// https://discourse.llvm.org/t/rfc-representing-combined-composite-constructs-in-the-openmp-dialect/76986.
29 class GenericLoopConversionPattern
30     : public mlir::OpConversionPattern<mlir::omp::LoopOp> {
31 public:
32   enum class GenericLoopCombinedInfo { Standalone, TeamsLoop, ParallelLoop };
33 
34   using mlir::OpConversionPattern<mlir::omp::LoopOp>::OpConversionPattern;
35 
36   explicit GenericLoopConversionPattern(mlir::MLIRContext *ctx)
37       : mlir::OpConversionPattern<mlir::omp::LoopOp>{ctx} {
38     // Enable rewrite recursion to make sure nested `loop` directives are
39     // handled.
40     this->setHasBoundedRewriteRecursion(true);
41   }
42 
43   mlir::LogicalResult
44   matchAndRewrite(mlir::omp::LoopOp loopOp, OpAdaptor adaptor,
45                   mlir::ConversionPatternRewriter &rewriter) const override {
46     assert(mlir::succeeded(checkLoopConversionSupportStatus(loopOp)));
47 
48     GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp);
49 
50     switch (combinedInfo) {
51     case GenericLoopCombinedInfo::Standalone:
52       rewriteStandaloneLoop(loopOp, rewriter);
53       break;
54     case GenericLoopCombinedInfo::ParallelLoop:
55       llvm_unreachable(
56           "not yet implemented: Combined `parallel loop` directive");
57       break;
58     case GenericLoopCombinedInfo::TeamsLoop:
59       rewriteToDistributeParallelDo(loopOp, rewriter);
60       break;
61     }
62 
63     rewriter.eraseOp(loopOp);
64     return mlir::success();
65   }
66 
67   static mlir::LogicalResult
68   checkLoopConversionSupportStatus(mlir::omp::LoopOp loopOp) {
69     GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp);
70 
71     switch (combinedInfo) {
72     case GenericLoopCombinedInfo::Standalone:
73       break;
74     case GenericLoopCombinedInfo::ParallelLoop:
75       return loopOp.emitError(
76           "not yet implemented: Combined `parallel loop` directive");
77     case GenericLoopCombinedInfo::TeamsLoop:
78       break;
79     }
80 
81     auto todo = [&loopOp](mlir::StringRef clauseName) {
82       return loopOp.emitError()
83              << "not yet implemented: Unhandled clause " << clauseName << " in "
84              << loopOp->getName() << " operation";
85     };
86 
87     // For standalone directives, `bind` is already supported. Other combined
88     // forms will be supported in a follow-up PR.
89     if (combinedInfo != GenericLoopCombinedInfo::Standalone &&
90         loopOp.getBindKind())
91       return todo("bind");
92 
93     if (loopOp.getOrder())
94       return todo("order");
95 
96     if (!loopOp.getReductionVars().empty())
97       return todo("reduction");
98 
99     // TODO For `teams loop`, check similar constrains to what is checked
100     // by `TeamsLoopChecker` in SemaOpenMP.cpp.
101     return mlir::success();
102   }
103 
104 private:
105   static GenericLoopCombinedInfo
106   findGenericLoopCombineInfo(mlir::omp::LoopOp loopOp) {
107     mlir::Operation *parentOp = loopOp->getParentOp();
108     GenericLoopCombinedInfo result = GenericLoopCombinedInfo::Standalone;
109 
110     if (auto teamsOp = mlir::dyn_cast_if_present<mlir::omp::TeamsOp>(parentOp))
111       result = GenericLoopCombinedInfo::TeamsLoop;
112 
113     if (auto parallelOp =
114             mlir::dyn_cast_if_present<mlir::omp::ParallelOp>(parentOp))
115       result = GenericLoopCombinedInfo::ParallelLoop;
116 
117     return result;
118   }
119 
120   void rewriteStandaloneLoop(mlir::omp::LoopOp loopOp,
121                              mlir::ConversionPatternRewriter &rewriter) const {
122     using namespace mlir::omp;
123     std::optional<ClauseBindKind> bindKind = loopOp.getBindKind();
124 
125     if (!bindKind.has_value())
126       return rewriteToSimdLoop(loopOp, rewriter);
127 
128     switch (*loopOp.getBindKind()) {
129     case ClauseBindKind::Parallel:
130       return rewriteToWsloop(loopOp, rewriter);
131     case ClauseBindKind::Teams:
132       return rewriteToDistrbute(loopOp, rewriter);
133     case ClauseBindKind::Thread:
134       return rewriteToSimdLoop(loopOp, rewriter);
135     }
136   }
137 
138   /// Rewrites standalone `loop` (without `bind` clause or with
139   /// `bind(parallel)`) directives to equivalent `simd` constructs.
140   ///
141   /// The reasoning behind this decision is that according to the spec (version
142   /// 5.2, section 11.7.1):
143   ///
144   /// "If the bind clause is not specified on a construct for which it may be
145   /// specified and the construct is closely nested inside a teams or parallel
146   /// construct, the effect is as if binding is teams or parallel. If none of
147   /// those conditions hold, the binding region is not defined."
148   ///
149   /// which means that standalone `loop` directives have undefined binding
150   /// region. Moreover, the spec says (in the next paragraph):
151   ///
152   /// "The specified binding region determines the binding thread set.
153   /// Specifically, if the binding region is a teams region, then the binding
154   /// thread set is the set of initial threads that are executing that region
155   /// while if the binding region is a parallel region, then the binding thread
156   /// set is the team of threads that are executing that region. If the binding
157   /// region is not defined, then the binding thread set is the encountering
158   /// thread."
159   ///
160   /// which means that the binding thread set for a standalone `loop` directive
161   /// is only the encountering thread.
162   ///
163   /// Since the encountering thread is the binding thread (set) for a
164   /// standalone `loop` directive, the best we can do in such case is to "simd"
165   /// the directive.
166   void rewriteToSimdLoop(mlir::omp::LoopOp loopOp,
167                          mlir::ConversionPatternRewriter &rewriter) const {
168     loopOp.emitWarning(
169         "Detected standalone OpenMP `loop` directive with thread binding, "
170         "the associated loop will be rewritten to `simd`.");
171     rewriteToSingleWrapperOp<mlir::omp::SimdOp, mlir::omp::SimdOperands>(
172         loopOp, rewriter);
173   }
174 
175   void rewriteToDistrbute(mlir::omp::LoopOp loopOp,
176                           mlir::ConversionPatternRewriter &rewriter) const {
177     rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
178                              mlir::omp::DistributeOperands>(loopOp, rewriter);
179   }
180 
181   void rewriteToWsloop(mlir::omp::LoopOp loopOp,
182                        mlir::ConversionPatternRewriter &rewriter) const {
183     rewriteToSingleWrapperOp<mlir::omp::WsloopOp, mlir::omp::WsloopOperands>(
184         loopOp, rewriter);
185   }
186 
187   // TODO Suggestion by Sergio: tag auto-generated operations for constructs
188   // that weren't part of the original program, that would be useful
189   // information for debugging purposes later on. This new attribute could be
190   // used for `omp.loop`, but also for `do concurrent` transformations,
191   // `workshare`, `workdistribute`, etc. The tag could be used for all kinds of
192   // auto-generated operations using a dialect attribute (named something like
193   // `omp.origin` or `omp.derived`) and perhaps hold the name of the operation
194   // it was derived from, the reason it was transformed or something like that
195   // we could use when emitting any messages related to it later on.
196   template <typename OpTy, typename OpOperandsTy>
197   void
198   rewriteToSingleWrapperOp(mlir::omp::LoopOp loopOp,
199                            mlir::ConversionPatternRewriter &rewriter) const {
200     OpOperandsTy clauseOps;
201     clauseOps.privateVars = loopOp.getPrivateVars();
202 
203     auto privateSyms = loopOp.getPrivateSyms();
204     if (privateSyms)
205       clauseOps.privateSyms.assign(privateSyms->begin(), privateSyms->end());
206 
207     Fortran::common::openmp::EntryBlockArgs args;
208     args.priv.vars = clauseOps.privateVars;
209 
210     auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps);
211     mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion());
212 
213     mlir::IRMapping mapper;
214     mlir::Block &loopBlock = *loopOp.getRegion().begin();
215 
216     for (auto [loopOpArg, opArg] :
217          llvm::zip_equal(loopBlock.getArguments(), opBlock->getArguments()))
218       mapper.map(loopOpArg, opArg);
219 
220     rewriter.clone(*loopOp.begin(), mapper);
221   }
222 
223   void rewriteToDistributeParallelDo(
224       mlir::omp::LoopOp loopOp,
225       mlir::ConversionPatternRewriter &rewriter) const {
226     mlir::omp::ParallelOperands parallelClauseOps;
227     parallelClauseOps.privateVars = loopOp.getPrivateVars();
228 
229     auto privateSyms = loopOp.getPrivateSyms();
230     if (privateSyms)
231       parallelClauseOps.privateSyms.assign(privateSyms->begin(),
232                                            privateSyms->end());
233 
234     Fortran::common::openmp::EntryBlockArgs parallelArgs;
235     parallelArgs.priv.vars = parallelClauseOps.privateVars;
236 
237     auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
238                                                              parallelClauseOps);
239     mlir::Block *parallelBlock =
240         genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
241     parallelOp.setComposite(true);
242     rewriter.setInsertionPoint(
243         rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
244 
245     mlir::omp::DistributeOperands distributeClauseOps;
246     auto distributeOp = rewriter.create<mlir::omp::DistributeOp>(
247         loopOp.getLoc(), distributeClauseOps);
248     distributeOp.setComposite(true);
249     rewriter.createBlock(&distributeOp.getRegion());
250 
251     mlir::omp::WsloopOperands wsloopClauseOps;
252     auto wsloopOp =
253         rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
254     wsloopOp.setComposite(true);
255     rewriter.createBlock(&wsloopOp.getRegion());
256 
257     mlir::IRMapping mapper;
258     mlir::Block &loopBlock = *loopOp.getRegion().begin();
259 
260     for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
261              loopBlock.getArguments(), parallelBlock->getArguments()))
262       mapper.map(loopOpArg, parallelOpArg);
263 
264     rewriter.clone(*loopOp.begin(), mapper);
265   }
266 };
267 
268 class GenericLoopConversionPass
269     : public flangomp::impl::GenericLoopConversionPassBase<
270           GenericLoopConversionPass> {
271 public:
272   GenericLoopConversionPass() = default;
273 
274   void runOnOperation() override {
275     mlir::func::FuncOp func = getOperation();
276 
277     if (func.isDeclaration())
278       return;
279 
280     mlir::MLIRContext *context = &getContext();
281     mlir::RewritePatternSet patterns(context);
282     patterns.insert<GenericLoopConversionPattern>(context);
283     mlir::ConversionTarget target(*context);
284 
285     target.markUnknownOpDynamicallyLegal(
286         [](mlir::Operation *) { return true; });
287     target.addDynamicallyLegalOp<mlir::omp::LoopOp>(
288         [](mlir::omp::LoopOp loopOp) {
289           return mlir::failed(
290               GenericLoopConversionPattern::checkLoopConversionSupportStatus(
291                   loopOp));
292         });
293 
294     if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
295                                                std::move(patterns)))) {
296       mlir::emitError(func.getLoc(), "error in converting `omp.loop` op");
297       signalPassFailure();
298     }
299   }
300 };
301 } // namespace
302