xref: /llvm-project/flang/lib/Semantics/canonicalize-omp.cpp (revision 5faf45a3d24e603cbc8fe4eb45da386653dae5e5)
1 //===-- lib/Semantics/canonicalize-omp.cpp --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "canonicalize-omp.h"
10 #include "flang/Parser/parse-tree-visitor.h"
11 
// After Loop Canonicalization, rewrite the OpenMP parse tree to make OpenMP
// constructs more structured, giving them explicit scopes for later
// structural checks and semantic analysis.
//   1. Move the structured DoConstruct and OmpEndLoopDirective into the
//      OpenMPLoopConstruct. If this pass reports errors, compilation will
//      not proceed past it.
//   2. Associate declarative OMP allocation directives with their
//      respective executable allocation directive.
//   3. TBD
21 namespace Fortran::semantics {
22 
23 using namespace parser::literals;
24 
25 class CanonicalizationOfOmp {
26 public:
27   template <typename T> bool Pre(T &) { return true; }
28   template <typename T> void Post(T &) {}
29   CanonicalizationOfOmp(parser::Messages &messages) : messages_{messages} {}
30 
31   void Post(parser::Block &block) {
32     for (auto it{block.begin()}; it != block.end(); ++it) {
33       if (auto *ompCons{GetConstructIf<parser::OpenMPConstruct>(*it)}) {
34         // OpenMPLoopConstruct
35         if (auto *ompLoop{
36                 std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
37           RewriteOpenMPLoopConstruct(*ompLoop, block, it);
38         }
39       } else if (auto *endDir{
40                      GetConstructIf<parser::OmpEndLoopDirective>(*it)}) {
41         // Unmatched OmpEndLoopDirective
42         auto &dir{std::get<parser::OmpLoopDirective>(endDir->t)};
43         messages_.Say(dir.source,
44             "The %s directive must follow the DO loop associated with the "
45             "loop construct"_err_en_US,
46             parser::ToUpperCaseLetters(dir.source.ToString()));
47       }
48     } // Block list
49   }
50 
51   void Post(parser::ExecutionPart &body) { RewriteOmpAllocations(body); }
52 
53 private:
54   template <typename T> T *GetConstructIf(parser::ExecutionPartConstruct &x) {
55     if (auto *y{std::get_if<parser::ExecutableConstruct>(&x.u)}) {
56       if (auto *z{std::get_if<common::Indirection<T>>(&y->u)}) {
57         return &z->value();
58       }
59     }
60     return nullptr;
61   }
62 
63   template <typename T> T *GetOmpIf(parser::ExecutionPartConstruct &x) {
64     if (auto *construct{GetConstructIf<parser::OpenMPConstruct>(x)}) {
65       if (auto *omp{std::get_if<T>(&construct->u)}) {
66         return omp;
67       }
68     }
69     return nullptr;
70   }
71 
72   void RewriteOpenMPLoopConstruct(parser::OpenMPLoopConstruct &x,
73       parser::Block &block, parser::Block::iterator it) {
74     // Check the sequence of DoConstruct and OmpEndLoopDirective
75     // in the same iteration
76     //
77     // Original:
78     //   ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
79     //     OmpBeginLoopDirective
80     //   ExecutableConstruct -> DoConstruct
81     //   ExecutableConstruct -> OmpEndLoopDirective (if available)
82     //
83     // After rewriting:
84     //   ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
85     //     OmpBeginLoopDirective
86     //     DoConstruct
87     //     OmpEndLoopDirective (if available)
88     parser::Block::iterator nextIt;
89     auto &beginDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
90     auto &dir{std::get<parser::OmpLoopDirective>(beginDir.t)};
91 
92     nextIt = it;
93     if (++nextIt != block.end()) {
94       if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
95         if (doCons->GetLoopControl()) {
96           // move DoConstruct
97           std::get<std::optional<parser::DoConstruct>>(x.t) =
98               std::move(*doCons);
99           nextIt = block.erase(nextIt);
100           // try to match OmpEndLoopDirective
101           if (nextIt != block.end()) {
102             if (auto *endDir{
103                     GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
104               std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
105                   std::move(*endDir);
106               block.erase(nextIt);
107             }
108           }
109         } else {
110           messages_.Say(dir.source,
111               "DO loop after the %s directive must have loop control"_err_en_US,
112               parser::ToUpperCaseLetters(dir.source.ToString()));
113         }
114         return; // found do-loop
115       }
116     }
117     messages_.Say(dir.source,
118         "A DO loop must follow the %s directive"_err_en_US,
119         parser::ToUpperCaseLetters(dir.source.ToString()));
120   }
121 
122   void RewriteOmpAllocations(parser::ExecutionPart &body) {
123     // Rewrite leading declarative allocations so they are nested
124     // within their respective executable allocate directive
125     //
126     // Original:
127     //   ExecutionPartConstruct -> OpenMPDeclarativeAllocate
128     //   ExecutionPartConstruct -> OpenMPDeclarativeAllocate
129     //   ExecutionPartConstruct -> OpenMPExecutableAllocate
130     //
131     // After rewriting:
132     //   ExecutionPartConstruct -> OpenMPExecutableAllocate
133     //     ExecutionPartConstruct -> OpenMPDeclarativeAllocate
134     //     ExecutionPartConstruct -> OpenMPDeclarativeAllocate
135     for (auto it = body.v.rbegin(); it != body.v.rend();) {
136       if (auto *exec = GetOmpIf<parser::OpenMPExecutableAllocate>(*(it++))) {
137         parser::OpenMPDeclarativeAllocate *decl;
138         std::list<parser::OpenMPDeclarativeAllocate> subAllocates;
139         while (it != body.v.rend() &&
140             (decl = GetOmpIf<parser::OpenMPDeclarativeAllocate>(*it))) {
141           subAllocates.push_front(std::move(*decl));
142           it = decltype(it)(body.v.erase(std::next(it).base()));
143         }
144         if (!subAllocates.empty()) {
145           std::get<std::optional<std::list<parser::OpenMPDeclarativeAllocate>>>(
146               exec->t) = {std::move(subAllocates)};
147         }
148       }
149     }
150   }
151 
152   parser::Messages &messages_;
153 };
154 
155 bool CanonicalizeOmp(parser::Messages &messages, parser::Program &program) {
156   CanonicalizationOfOmp omp{messages};
157   Walk(program, omp);
158   return !messages.AnyFatalError();
159 }
160 } // namespace Fortran::semantics
161