xref: /llvm-project/flang/lib/Lower/OpenMP/Decomposer.cpp (revision aa875cfe11ddec239934e37ce07c1cf7804bb73b)
1 //===-- Decomposer.cpp -- Compound directive decomposition ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "Decomposer.h"

#include "Clauses.h"
#include "Utils.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Semantics/semantics.h"
#include "flang/Tools/CrossToolHelpers.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Frontend/OpenMP/ClauseT.h"
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <iterator>
#include <optional>
#include <utility>
#include <variant>
32 
33 using namespace Fortran;
34 
35 namespace {
36 using namespace Fortran::lower::omp;
37 
38 struct ConstructDecomposition {
39   ConstructDecomposition(mlir::ModuleOp modOp,
40                          semantics::SemanticsContext &semaCtx,
41                          lower::pft::Evaluation &ev,
42                          llvm::omp::Directive compound,
43                          const List<Clause> &clauses)
44       : semaCtx(semaCtx), mod(modOp), eval(ev) {
45     tomp::ConstructDecompositionT decompose(getOpenMPVersionAttribute(modOp),
46                                             *this, compound,
47                                             llvm::ArrayRef(clauses));
48     output = std::move(decompose.output);
49   }
50 
51   // Given an object, return its base object if one exists.
52   std::optional<Object> getBaseObject(const Object &object) {
53     return lower::omp::getBaseObject(object, semaCtx);
54   }
55 
56   // Return the iteration variable of the associated loop if any.
57   std::optional<Object> getLoopIterVar() {
58     if (semantics::Symbol *symbol = getIterationVariableSymbol(eval))
59       return Object{symbol, /*designator=*/{}};
60     return std::nullopt;
61   }
62 
63   semantics::SemanticsContext &semaCtx;
64   mlir::ModuleOp mod;
65   lower::pft::Evaluation &eval;
66   List<UnitConstruct> output;
67 };
68 } // namespace
69 
70 namespace Fortran::lower::omp {
71 LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
72                                                const UnitConstruct &uc) {
73   os << llvm::omp::getOpenMPDirectiveName(uc.id);
74   for (auto [index, clause] : llvm::enumerate(uc.clauses)) {
75     os << (index == 0 ? '\t' : ' ');
76     os << llvm::omp::getOpenMPClauseName(clause.id);
77   }
78   return os;
79 }
80 
81 ConstructQueue buildConstructQueue(
82     mlir::ModuleOp modOp, Fortran::semantics::SemanticsContext &semaCtx,
83     Fortran::lower::pft::Evaluation &eval, const parser::CharBlock &source,
84     llvm::omp::Directive compound, const List<Clause> &clauses) {
85 
86   ConstructDecomposition decompose(modOp, semaCtx, eval, compound, clauses);
87   assert(!decompose.output.empty() && "Construct decomposition failed");
88 
89   for (UnitConstruct &uc : decompose.output) {
90     assert(getLeafConstructs(uc.id).empty() && "unexpected compound directive");
91     //  If some clauses are left without source information, use the directive's
92     //  source.
93     for (auto &clause : uc.clauses)
94       if (clause.source.empty())
95         clause.source = source;
96   }
97 
98   return decompose.output;
99 }
100 
101 bool matchLeafSequence(ConstructQueue::const_iterator item,
102                        const ConstructQueue &queue,
103                        llvm::omp::Directive directive) {
104   llvm::ArrayRef<llvm::omp::Directive> leafDirs =
105       llvm::omp::getLeafConstructsOrSelf(directive);
106 
107   for (auto [dir, leaf] :
108        llvm::zip_longest(leafDirs, llvm::make_range(item, queue.end()))) {
109     if (!dir.has_value() || !leaf.has_value())
110       return false;
111 
112     if (*dir != leaf->id)
113       return false;
114   }
115 
116   return true;
117 }
118 
119 bool isLastItemInQueue(ConstructQueue::const_iterator item,
120                        const ConstructQueue &queue) {
121   return std::next(item) == queue.end();
122 }
123 } // namespace Fortran::lower::omp
124