xref: /llvm-project/flang/lib/Semantics/canonicalize-omp.cpp (revision d1f510cca8e966bd1742bf17256bfec99dcdf229)
1 //===-- lib/Semantics/canonicalize-omp.cpp --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "canonicalize-omp.h"
10 #include "flang/Parser/parse-tree-visitor.h"
11 
// After loop canonicalization, rewrite the OpenMP parse tree so that OpenMP
// constructs are more structured and provide explicit scopes for later
// structural checks and semantic analysis:
//   1. Move the structured DoConstruct and OmpEndLoopDirective into the
//      OpenMPLoopConstruct. Compilation will not proceed if this pass
//      reports errors.
//   2. Associate declarative OMP allocation directives with their
//      respective executable allocation directive.
//   3. TBD
21 namespace Fortran::semantics {
22 
23 using namespace parser::literals;
24 
25 class CanonicalizationOfOmp {
26 public:
27   template <typename T> bool Pre(T &) { return true; }
28   template <typename T> void Post(T &) {}
29   CanonicalizationOfOmp(parser::Messages &messages) : messages_{messages} {}
30 
31   void Post(parser::Block &block) {
32     for (auto it{block.begin()}; it != block.end(); ++it) {
33       if (auto *ompCons{GetConstructIf<parser::OpenMPConstruct>(*it)}) {
34         // OpenMPLoopConstruct
35         if (auto *ompLoop{
36                 std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
37           RewriteOpenMPLoopConstruct(*ompLoop, block, it);
38         }
39       } else if (auto *endDir{
40                      GetConstructIf<parser::OmpEndLoopDirective>(*it)}) {
41         // Unmatched OmpEndLoopDirective
42         auto &dir{std::get<parser::OmpLoopDirective>(endDir->t)};
43         messages_.Say(dir.source,
44             "The %s directive must follow the DO loop associated with the "
45             "loop construct"_err_en_US,
46             parser::ToUpperCaseLetters(dir.source.ToString()));
47       }
48     } // Block list
49   }
50 
51   void Post(parser::ExecutionPart &body) { RewriteOmpAllocations(body); }
52 
53 private:
54   template <typename T> T *GetConstructIf(parser::ExecutionPartConstruct &x) {
55     if (auto *y{std::get_if<parser::ExecutableConstruct>(&x.u)}) {
56       if (auto *z{std::get_if<common::Indirection<T>>(&y->u)}) {
57         return &z->value();
58       }
59     }
60     return nullptr;
61   }
62 
63   template <typename T> T *GetOmpIf(parser::ExecutionPartConstruct &x) {
64     if (auto *construct{GetConstructIf<parser::OpenMPConstruct>(x)}) {
65       if (auto *omp{std::get_if<T>(&construct->u)}) {
66         return omp;
67       }
68     }
69     return nullptr;
70   }
71 
72   void RewriteOpenMPLoopConstruct(parser::OpenMPLoopConstruct &x,
73       parser::Block &block, parser::Block::iterator it) {
74     // Check the sequence of DoConstruct and OmpEndLoopDirective
75     // in the same iteration
76     //
77     // Original:
78     //   ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
79     //     OmpBeginLoopDirective
80     //   ExecutableConstruct -> DoConstruct
81     //   ExecutableConstruct -> OmpEndLoopDirective (if available)
82     //
83     // After rewriting:
84     //   ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
85     //     OmpBeginLoopDirective
86     //     DoConstruct
87     //     OmpEndLoopDirective (if available)
88     parser::Block::iterator nextIt;
89     auto &beginDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
90     auto &dir{std::get<parser::OmpLoopDirective>(beginDir.t)};
91 
92     nextIt = it;
93     while (++nextIt != block.end()) {
94       // Ignore compiler directives.
95       if (GetConstructIf<parser::CompilerDirective>(*nextIt))
96         continue;
97 
98       if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
99         if (doCons->GetLoopControl()) {
100           // move DoConstruct
101           std::get<std::optional<parser::DoConstruct>>(x.t) =
102               std::move(*doCons);
103           nextIt = block.erase(nextIt);
104           // try to match OmpEndLoopDirective
105           if (nextIt != block.end()) {
106             if (auto *endDir{
107                     GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
108               std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
109                   std::move(*endDir);
110               block.erase(nextIt);
111             }
112           }
113         } else {
114           messages_.Say(dir.source,
115               "DO loop after the %s directive must have loop control"_err_en_US,
116               parser::ToUpperCaseLetters(dir.source.ToString()));
117         }
118       } else {
119         messages_.Say(dir.source,
120             "A DO loop must follow the %s directive"_err_en_US,
121             parser::ToUpperCaseLetters(dir.source.ToString()));
122       }
123       // If we get here, we either found a loop, or issued an error message.
124       return;
125     }
126   }
127 
128   void RewriteOmpAllocations(parser::ExecutionPart &body) {
129     // Rewrite leading declarative allocations so they are nested
130     // within their respective executable allocate directive
131     //
132     // Original:
133     //   ExecutionPartConstruct -> OpenMPDeclarativeAllocate
134     //   ExecutionPartConstruct -> OpenMPDeclarativeAllocate
135     //   ExecutionPartConstruct -> OpenMPExecutableAllocate
136     //
137     // After rewriting:
138     //   ExecutionPartConstruct -> OpenMPExecutableAllocate
139     //     ExecutionPartConstruct -> OpenMPDeclarativeAllocate
140     //     ExecutionPartConstruct -> OpenMPDeclarativeAllocate
141     for (auto it = body.v.rbegin(); it != body.v.rend();) {
142       if (auto *exec = GetOmpIf<parser::OpenMPExecutableAllocate>(*(it++))) {
143         parser::OpenMPDeclarativeAllocate *decl;
144         std::list<parser::OpenMPDeclarativeAllocate> subAllocates;
145         while (it != body.v.rend() &&
146             (decl = GetOmpIf<parser::OpenMPDeclarativeAllocate>(*it))) {
147           subAllocates.push_front(std::move(*decl));
148           it = decltype(it)(body.v.erase(std::next(it).base()));
149         }
150         if (!subAllocates.empty()) {
151           std::get<std::optional<std::list<parser::OpenMPDeclarativeAllocate>>>(
152               exec->t) = {std::move(subAllocates)};
153         }
154       }
155     }
156   }
157 
158   parser::Messages &messages_;
159 };
160 
161 bool CanonicalizeOmp(parser::Messages &messages, parser::Program &program) {
162   CanonicalizationOfOmp omp{messages};
163   Walk(program, omp);
164   return !messages.AnyFatalError();
165 }
166 } // namespace Fortran::semantics
167