xref: /llvm-project/flang/lib/Semantics/canonicalize-omp.cpp (revision adeff9f63a24f60b0bf240bf13e40bbf7c1dd0e8)
1 //===-- lib/Semantics/canonicalize-omp.cpp --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "canonicalize-omp.h"
10 #include "flang/Parser/parse-tree-visitor.h"
11 
// After loop canonicalization, rewrite the OpenMP parse tree to make OpenMP
// constructs more structured, so that they provide explicit scopes for later
// structural checks and semantic analysis:
//   1. Move a structured DoConstruct and its OmpEndLoopDirective into the
//      OpenMPLoopConstruct. Compilation will not proceed in case of errors
//      after this pass.
//   2. Associate declarative OMP allocation directives with their
//      respective executable allocation directive.
//   3. Move OpenMP utility constructs that end a specification part to the
//      beginning of the corresponding execution part.
//   4. TBD
21 namespace Fortran::semantics {
22 
23 using namespace parser::literals;
24 
25 class CanonicalizationOfOmp {
26 public:
27   template <typename T> bool Pre(T &) { return true; }
28   template <typename T> void Post(T &) {}
29   CanonicalizationOfOmp(parser::Messages &messages) : messages_{messages} {}
30 
31   void Post(parser::Block &block) {
32     for (auto it{block.begin()}; it != block.end(); ++it) {
33       if (auto *ompCons{GetConstructIf<parser::OpenMPConstruct>(*it)}) {
34         // OpenMPLoopConstruct
35         if (auto *ompLoop{
36                 std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
37           RewriteOpenMPLoopConstruct(*ompLoop, block, it);
38         }
39       } else if (auto *endDir{
40                      GetConstructIf<parser::OmpEndLoopDirective>(*it)}) {
41         // Unmatched OmpEndLoopDirective
42         auto &dir{std::get<parser::OmpLoopDirective>(endDir->t)};
43         messages_.Say(dir.source,
44             "The %s directive must follow the DO loop associated with the "
45             "loop construct"_err_en_US,
46             parser::ToUpperCaseLetters(dir.source.ToString()));
47       }
48     } // Block list
49   }
50 
51   void Post(parser::ExecutionPart &body) { RewriteOmpAllocations(body); }
52 
53   // Pre-visit all constructs that have both a specification part and
54   // an execution part, and store the connection between the two.
55   bool Pre(parser::BlockConstruct &x) {
56     auto *spec = &std::get<parser::BlockSpecificationPart>(x.t).v;
57     auto *block = &std::get<parser::Block>(x.t);
58     blockForSpec_.insert(std::make_pair(spec, block));
59     return true;
60   }
61   bool Pre(parser::MainProgram &x) {
62     auto *spec = &std::get<parser::SpecificationPart>(x.t);
63     auto *block = &std::get<parser::ExecutionPart>(x.t).v;
64     blockForSpec_.insert(std::make_pair(spec, block));
65     return true;
66   }
67   bool Pre(parser::FunctionSubprogram &x) {
68     auto *spec = &std::get<parser::SpecificationPart>(x.t);
69     auto *block = &std::get<parser::ExecutionPart>(x.t).v;
70     blockForSpec_.insert(std::make_pair(spec, block));
71     return true;
72   }
73   bool Pre(parser::SubroutineSubprogram &x) {
74     auto *spec = &std::get<parser::SpecificationPart>(x.t);
75     auto *block = &std::get<parser::ExecutionPart>(x.t).v;
76     blockForSpec_.insert(std::make_pair(spec, block));
77     return true;
78   }
79   bool Pre(parser::SeparateModuleSubprogram &x) {
80     auto *spec = &std::get<parser::SpecificationPart>(x.t);
81     auto *block = &std::get<parser::ExecutionPart>(x.t).v;
82     blockForSpec_.insert(std::make_pair(spec, block));
83     return true;
84   }
85 
86   void Post(parser::SpecificationPart &spec) {
87     CanonicalizeUtilityConstructs(spec);
88   }
89 
90 private:
91   template <typename T> T *GetConstructIf(parser::ExecutionPartConstruct &x) {
92     if (auto *y{std::get_if<parser::ExecutableConstruct>(&x.u)}) {
93       if (auto *z{std::get_if<common::Indirection<T>>(&y->u)}) {
94         return &z->value();
95       }
96     }
97     return nullptr;
98   }
99 
100   template <typename T> T *GetOmpIf(parser::ExecutionPartConstruct &x) {
101     if (auto *construct{GetConstructIf<parser::OpenMPConstruct>(x)}) {
102       if (auto *omp{std::get_if<T>(&construct->u)}) {
103         return omp;
104       }
105     }
106     return nullptr;
107   }
108 
109   void RewriteOpenMPLoopConstruct(parser::OpenMPLoopConstruct &x,
110       parser::Block &block, parser::Block::iterator it) {
111     // Check the sequence of DoConstruct and OmpEndLoopDirective
112     // in the same iteration
113     //
114     // Original:
115     //   ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
116     //     OmpBeginLoopDirective
117     //   ExecutableConstruct -> DoConstruct
118     //   ExecutableConstruct -> OmpEndLoopDirective (if available)
119     //
120     // After rewriting:
121     //   ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
122     //     OmpBeginLoopDirective
123     //     DoConstruct
124     //     OmpEndLoopDirective (if available)
125     parser::Block::iterator nextIt;
126     auto &beginDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
127     auto &dir{std::get<parser::OmpLoopDirective>(beginDir.t)};
128 
129     nextIt = it;
130     while (++nextIt != block.end()) {
131       // Ignore compiler directives.
132       if (GetConstructIf<parser::CompilerDirective>(*nextIt))
133         continue;
134 
135       if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
136         if (doCons->GetLoopControl()) {
137           // move DoConstruct
138           std::get<std::optional<parser::DoConstruct>>(x.t) =
139               std::move(*doCons);
140           nextIt = block.erase(nextIt);
141           // try to match OmpEndLoopDirective
142           if (nextIt != block.end()) {
143             if (auto *endDir{
144                     GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
145               std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
146                   std::move(*endDir);
147               block.erase(nextIt);
148             }
149           }
150         } else {
151           messages_.Say(dir.source,
152               "DO loop after the %s directive must have loop control"_err_en_US,
153               parser::ToUpperCaseLetters(dir.source.ToString()));
154         }
155       } else {
156         messages_.Say(dir.source,
157             "A DO loop must follow the %s directive"_err_en_US,
158             parser::ToUpperCaseLetters(dir.source.ToString()));
159       }
160       // If we get here, we either found a loop, or issued an error message.
161       return;
162     }
163   }
164 
165   void RewriteOmpAllocations(parser::ExecutionPart &body) {
166     // Rewrite leading declarative allocations so they are nested
167     // within their respective executable allocate directive
168     //
169     // Original:
170     //   ExecutionPartConstruct -> OpenMPDeclarativeAllocate
171     //   ExecutionPartConstruct -> OpenMPDeclarativeAllocate
172     //   ExecutionPartConstruct -> OpenMPExecutableAllocate
173     //
174     // After rewriting:
175     //   ExecutionPartConstruct -> OpenMPExecutableAllocate
176     //     ExecutionPartConstruct -> OpenMPDeclarativeAllocate
177     //     ExecutionPartConstruct -> OpenMPDeclarativeAllocate
178     for (auto it = body.v.rbegin(); it != body.v.rend();) {
179       if (auto *exec = GetOmpIf<parser::OpenMPExecutableAllocate>(*(it++))) {
180         parser::OpenMPDeclarativeAllocate *decl;
181         std::list<parser::OpenMPDeclarativeAllocate> subAllocates;
182         while (it != body.v.rend() &&
183             (decl = GetOmpIf<parser::OpenMPDeclarativeAllocate>(*it))) {
184           subAllocates.push_front(std::move(*decl));
185           it = decltype(it)(body.v.erase(std::next(it).base()));
186         }
187         if (!subAllocates.empty()) {
188           std::get<std::optional<std::list<parser::OpenMPDeclarativeAllocate>>>(
189               exec->t) = {std::move(subAllocates)};
190         }
191       }
192     }
193   }
194 
195   // Canonicalization of utility constructs.
196   //
197   // This addresses the issue of utility constructs that appear at the
198   // boundary between the specification and the execution parts, e.g.
199   //   subroutine foo
200   //     integer :: x     ! Specification
201   //     !$omp nothing
202   //     x = 1            ! Execution
203   //     ...
204   //   end
205   //
206   // Utility constructs (error and nothing) can appear in both the
207   // specification part and the execution part, except "error at(execution)",
208   // which cannot be present in the specification part (whereas any utility
209   // construct can be in the execution part).
210   // When a utility construct is at the boundary, it should preferably be
211   // parsed as an element of the execution part, but since the specification
212   // part is parsed first, the utility construct ends up belonging to the
213   // specification part.
214   //
215   // To allow the likes of the following code to compile, move all utility
216   // construct that are at the end of the specification part to the beginning
217   // of the execution part.
218   //
219   // subroutine foo
220   //   !$omp error at(execution)  ! Initially parsed as declarative construct.
221   //                              ! Move it to the execution part.
222   // end
223 
224   void CanonicalizeUtilityConstructs(parser::SpecificationPart &spec) {
225     auto found = blockForSpec_.find(&spec);
226     if (found == blockForSpec_.end()) {
227       // There is no corresponding execution part, so there is nothing to do.
228       return;
229     }
230     parser::Block &block = *found->second;
231 
232     // There are two places where an OpenMP declarative construct can
233     // show up in the tuple in specification part:
234     // (1) in std::list<OpenMPDeclarativeConstruct>, or
235     // (2) in std::list<DeclarationConstruct>.
236     // The case (1) is only possible is the list (2) is empty.
237 
238     auto &omps =
239         std::get<std::list<parser::OpenMPDeclarativeConstruct>>(spec.t);
240     auto &decls = std::get<std::list<parser::DeclarationConstruct>>(spec.t);
241 
242     if (!decls.empty()) {
243       MoveUtilityConstructsFromDecls(decls, block);
244     } else {
245       MoveUtilityConstructsFromOmps(omps, block);
246     }
247   }
248 
249   void MoveUtilityConstructsFromDecls(
250       std::list<parser::DeclarationConstruct> &decls, parser::Block &block) {
251     // Find the trailing range of DeclarationConstructs that are OpenMP
252     // utility construct, that are to be moved to the execution part.
253     std::list<parser::DeclarationConstruct>::reverse_iterator rlast = [&]() {
254       for (auto rit = decls.rbegin(), rend = decls.rend(); rit != rend; ++rit) {
255         parser::DeclarationConstruct &dc = *rit;
256         if (!std::holds_alternative<parser::SpecificationConstruct>(dc.u)) {
257           return rit;
258         }
259         auto &sc = std::get<parser::SpecificationConstruct>(dc.u);
260         using OpenMPDeclarativeConstruct =
261             common::Indirection<parser::OpenMPDeclarativeConstruct>;
262         if (!std::holds_alternative<OpenMPDeclarativeConstruct>(sc.u)) {
263           return rit;
264         }
265         // Got OpenMPDeclarativeConstruct. If it's not a utility construct
266         // then stop.
267         auto &odc = std::get<OpenMPDeclarativeConstruct>(sc.u).value();
268         if (!std::holds_alternative<parser::OpenMPUtilityConstruct>(odc.u)) {
269           return rit;
270         }
271       }
272       return decls.rend();
273     }();
274 
275     std::transform(decls.rbegin(), rlast, std::front_inserter(block),
276         [](parser::DeclarationConstruct &dc) {
277           auto &sc = std::get<parser::SpecificationConstruct>(dc.u);
278           using OpenMPDeclarativeConstruct =
279               common::Indirection<parser::OpenMPDeclarativeConstruct>;
280           auto &oc = std::get<OpenMPDeclarativeConstruct>(sc.u).value();
281           auto &ut = std::get<parser::OpenMPUtilityConstruct>(oc.u);
282 
283           return parser::ExecutionPartConstruct(parser::ExecutableConstruct(
284               common::Indirection(parser::OpenMPConstruct(std::move(ut)))));
285         });
286 
287     decls.erase(rlast.base(), decls.end());
288   }
289 
290   void MoveUtilityConstructsFromOmps(
291       std::list<parser::OpenMPDeclarativeConstruct> &omps,
292       parser::Block &block) {
293     using OpenMPDeclarativeConstruct = parser::OpenMPDeclarativeConstruct;
294     // Find the trailing range of OpenMPDeclarativeConstruct that are OpenMP
295     // utility construct, that are to be moved to the execution part.
296     std::list<OpenMPDeclarativeConstruct>::reverse_iterator rlast = [&]() {
297       for (auto rit = omps.rbegin(), rend = omps.rend(); rit != rend; ++rit) {
298         OpenMPDeclarativeConstruct &dc = *rit;
299         if (!std::holds_alternative<parser::OpenMPUtilityConstruct>(dc.u)) {
300           return rit;
301         }
302       }
303       return omps.rend();
304     }();
305 
306     std::transform(omps.rbegin(), rlast, std::front_inserter(block),
307         [](parser::OpenMPDeclarativeConstruct &dc) {
308           auto &ut = std::get<parser::OpenMPUtilityConstruct>(dc.u);
309           return parser::ExecutionPartConstruct(parser::ExecutableConstruct(
310               common::Indirection(parser::OpenMPConstruct(std::move(ut)))));
311         });
312 
313     omps.erase(rlast.base(), omps.end());
314   }
315 
316   // Mapping from the specification parts to the blocks that follow in the
317   // same construct. This is for converting utility constructs to executable
318   // constructs.
319   std::map<parser::SpecificationPart *, parser::Block *> blockForSpec_;
320   parser::Messages &messages_;
321 };
322 
323 bool CanonicalizeOmp(parser::Messages &messages, parser::Program &program) {
324   CanonicalizationOfOmp omp{messages};
325   Walk(program, omp);
326   return !messages.AnyFatalError();
327 }
328 } // namespace Fortran::semantics
329