xref: /llvm-project/flang/lib/Semantics/check-cuda.cpp (revision 3d59e30cbcfea475594aaf1c69388c0503f846ef)
1 //===-- lib/Semantics/check-cuda.cpp ----------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "check-cuda.h"
10 #include "flang/Common/template.h"
11 #include "flang/Evaluate/fold.h"
12 #include "flang/Evaluate/tools.h"
13 #include "flang/Evaluate/traverse.h"
14 #include "flang/Parser/parse-tree-visitor.h"
15 #include "flang/Parser/parse-tree.h"
16 #include "flang/Parser/tools.h"
17 #include "flang/Semantics/expression.h"
18 #include "flang/Semantics/symbol.h"
19 #include "flang/Semantics/tools.h"
20 
21 // Once labeled DO constructs have been canonicalized and their parse subtrees
22 // transformed into parser::DoConstructs, scan the parser::Blocks of the program
23 // and merge adjacent CUFKernelDoConstructs and DoConstructs whenever the
24 // CUFKernelDoConstruct doesn't already have an embedded DoConstruct.  Also
25 // emit errors about improper or missing DoConstructs.
26 
27 namespace Fortran::parser {
28 struct Mutator {
29   template <typename A> bool Pre(A &) { return true; }
30   template <typename A> void Post(A &) {}
31   bool Pre(Block &);
32 };
33 
34 bool Mutator::Pre(Block &block) {
35   for (auto iter{block.begin()}; iter != block.end(); ++iter) {
36     if (auto *kernel{Unwrap<CUFKernelDoConstruct>(*iter)}) {
37       auto &nested{std::get<std::optional<DoConstruct>>(kernel->t)};
38       if (!nested) {
39         if (auto next{iter}; ++next != block.end()) {
40           if (auto *doConstruct{Unwrap<DoConstruct>(*next)}) {
41             nested = std::move(*doConstruct);
42             block.erase(next);
43           }
44         }
45       }
46     } else {
47       Walk(*iter, *this);
48     }
49   }
50   return false;
51 }
52 } // namespace Fortran::parser
53 
54 namespace Fortran::semantics {
55 
56 bool CanonicalizeCUDA(parser::Program &program) {
57   parser::Mutator mutator;
58   parser::Walk(program, mutator);
59   return true;
60 }
61 
62 using MaybeMsg = std::optional<parser::MessageFormattedText>;
63 
64 // Traverses an evaluate::Expr<> in search of unsupported operations
65 // on the device.
66 
67 struct DeviceExprChecker
68     : public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
69   using Result = MaybeMsg;
70   using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
71   DeviceExprChecker() : Base(*this) {}
72   using Base::operator();
73   Result operator()(const evaluate::ProcedureDesignator &x) const {
74     if (const Symbol * sym{x.GetInterfaceSymbol()}) {
75       const auto *subp{
76           sym->GetUltimate().detailsIf<semantics::SubprogramDetails>()};
77       if (subp) {
78         if (auto attrs{subp->cudaSubprogramAttrs()}) {
79           if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
80               *attrs == common::CUDASubprogramAttrs::Device) {
81             return {};
82           }
83         }
84       }
85     } else if (x.GetSpecificIntrinsic()) {
86       // TODO(CUDA): Check for unsupported intrinsics here
87       return {};
88     }
89     return parser::MessageFormattedText(
90         "'%s' may not be called in device code"_err_en_US, x.GetName());
91   }
92 };
93 
94 struct FindHostArray
95     : public evaluate::AnyTraverse<FindHostArray, const Symbol *> {
96   using Result = const Symbol *;
97   using Base = evaluate::AnyTraverse<FindHostArray, Result>;
98   FindHostArray() : Base(*this) {}
99   using Base::operator();
100   Result operator()(const evaluate::Component &x) const {
101     const Symbol &symbol{x.GetLastSymbol()};
102     if (IsAllocatableOrPointer(symbol)) {
103       if (Result hostArray{(*this)(symbol)}) {
104         return hostArray;
105       }
106     }
107     return (*this)(x.base());
108   }
109   Result operator()(const Symbol &symbol) const {
110     if (const auto *details{
111             symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
112       if (details->IsArray() &&
113           !symbol.attrs().test(Fortran::semantics::Attr::PARAMETER) &&
114           (!details->cudaDataAttr() ||
115               (details->cudaDataAttr() &&
116                   *details->cudaDataAttr() != common::CUDADataAttr::Device &&
117                   *details->cudaDataAttr() != common::CUDADataAttr::Constant &&
118                   *details->cudaDataAttr() != common::CUDADataAttr::Managed &&
119                   *details->cudaDataAttr() != common::CUDADataAttr::Shared &&
120                   *details->cudaDataAttr() != common::CUDADataAttr::Unified))) {
121         return &symbol;
122       }
123     }
124     return nullptr;
125   }
126 };
127 
128 template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
129   if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
130     return DeviceExprChecker{}(expr->typedExpr);
131   }
132   return {};
133 }
134 
135 template <typename A>
136 static void CheckUnwrappedExpr(
137     SemanticsContext &context, SourceName at, const A &x) {
138   if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
139     if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {
140       context.Say(at, std::move(*msg));
141     }
142   }
143 }
144 
145 template <bool CUF_KERNEL> struct ActionStmtChecker {
146   template <typename A> static MaybeMsg WhyNotOk(const A &x) {
147     if constexpr (ConstraintTrait<A>) {
148       return WhyNotOk(x.thing);
149     } else if constexpr (WrapperTrait<A>) {
150       return WhyNotOk(x.v);
151     } else if constexpr (UnionTrait<A>) {
152       return WhyNotOk(x.u);
153     } else if constexpr (TupleTrait<A>) {
154       return WhyNotOk(x.t);
155     } else {
156       return parser::MessageFormattedText{
157           "Statement may not appear in device code"_err_en_US};
158     }
159   }
160   template <typename A>
161   static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
162     return WhyNotOk(x.value());
163   }
164   template <typename... As>
165   static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
166     return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
167   }
168   template <std::size_t J = 0, typename... As>
169   static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
170     if constexpr (J == sizeof...(As)) {
171       return {};
172     } else if (auto msg{WhyNotOk(std::get<J>(x))}) {
173       return msg;
174     } else {
175       return WhyNotOk<(J + 1)>(x);
176     }
177   }
178   template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
179     for (const auto &y : x) {
180       if (MaybeMsg result{WhyNotOk(y)}) {
181         return result;
182       }
183     }
184     return {};
185   }
186   template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
187     if (x) {
188       return WhyNotOk(*x);
189     } else {
190       return {};
191     }
192   }
193   template <typename A>
194   static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
195     return WhyNotOk(x.statement);
196   }
197   template <typename A>
198   static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
199     return WhyNotOk(x.statement);
200   }
201   static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
202     return {}; // AllocateObjects are checked elsewhere
203   }
204   static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
205     return parser::MessageFormattedText(
206         "A coarray may not be allocated on the device"_err_en_US);
207   }
208   static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
209     return {}; // AllocateObjects are checked elsewhere
210   }
211   static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
212     return DeviceExprChecker{}(x.typedAssignment);
213   }
214   static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
215     return DeviceExprChecker{}(x.typedCall);
216   }
217   static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
218   static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
219     if (auto result{
220             CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
221       return result;
222     }
223     return WhyNotOk(
224         std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
225             .statement);
226   }
227   static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
228     for (const auto &y : x.v) {
229       if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
230         return result;
231       }
232     }
233     return {};
234   }
235   static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
236     return DeviceExprChecker{}(x.typedAssignment);
237   }
238 };
239 
240 template <bool IsCUFKernelDo> class DeviceContextChecker {
241 public:
242   explicit DeviceContextChecker(SemanticsContext &c) : context_{c} {}
243   void CheckSubprogram(const parser::Name &name, const parser::Block &body) {
244     if (name.symbol) {
245       const auto *subp{
246           name.symbol->GetUltimate().detailsIf<SubprogramDetails>()};
247       if (subp && subp->moduleInterface()) {
248         subp = subp->moduleInterface()
249                    ->GetUltimate()
250                    .detailsIf<SubprogramDetails>();
251       }
252       if (subp &&
253           subp->cudaSubprogramAttrs().value_or(
254               common::CUDASubprogramAttrs::Host) !=
255               common::CUDASubprogramAttrs::Host) {
256         Check(body);
257       }
258     }
259   }
260   void Check(const parser::Block &block) {
261     for (const auto &epc : block) {
262       Check(epc);
263     }
264   }
265 
266 private:
267   void Check(const parser::ExecutionPartConstruct &epc) {
268     common::visit(
269         common::visitors{
270             [&](const parser::ExecutableConstruct &x) { Check(x); },
271             [&](const parser::Statement<common::Indirection<parser::EntryStmt>>
272                     &x) {
273               context_.Say(x.source,
274                   "Device code may not contain an ENTRY statement"_err_en_US);
275             },
276             [](const parser::Statement<common::Indirection<parser::FormatStmt>>
277                     &) {},
278             [](const parser::Statement<common::Indirection<parser::DataStmt>>
279                     &) {},
280             [](const parser::Statement<
281                 common::Indirection<parser::NamelistStmt>> &) {},
282             [](const parser::ErrorRecovery &) {},
283         },
284         epc.u);
285   }
286   void Check(const parser::ExecutableConstruct &ec) {
287     common::visit(
288         common::visitors{
289             [&](const parser::Statement<parser::ActionStmt> &stmt) {
290               Check(stmt.statement, stmt.source);
291             },
292             [&](const common::Indirection<parser::DoConstruct> &x) {
293               if (const std::optional<parser::LoopControl> &control{
294                       x.value().GetLoopControl()}) {
295                 common::visit([&](const auto &y) { Check(y); }, control->u);
296               }
297               Check(std::get<parser::Block>(x.value().t));
298             },
299             [&](const common::Indirection<parser::BlockConstruct> &x) {
300               Check(std::get<parser::Block>(x.value().t));
301             },
302             [&](const common::Indirection<parser::IfConstruct> &x) {
303               Check(x.value());
304             },
305             [&](const common::Indirection<parser::CaseConstruct> &x) {
306               const auto &caseList{
307                   std::get<std::list<parser::CaseConstruct::Case>>(
308                       x.value().t)};
309               for (const parser::CaseConstruct::Case &c : caseList) {
310                 Check(std::get<parser::Block>(c.t));
311               }
312             },
313             [&](const auto &x) {
314               if (auto source{parser::GetSource(x)}) {
315                 context_.Say(*source,
316                     "Statement may not appear in device code"_err_en_US);
317               }
318             },
319         },
320         ec.u);
321   }
322   template <typename SEEK, typename A>
323   static const SEEK *GetIOControl(const A &stmt) {
324     for (const auto &spec : stmt.controls) {
325       if (const auto *result{std::get_if<SEEK>(&spec.u)}) {
326         return result;
327       }
328     }
329     return nullptr;
330   }
331   template <typename A> static bool IsInternalIO(const A &stmt) {
332     if (stmt.iounit.has_value()) {
333       return std::holds_alternative<Fortran::parser::Variable>(stmt.iounit->u);
334     }
335     if (auto *unit{GetIOControl<Fortran::parser::IoUnit>(stmt)}) {
336       return std::holds_alternative<Fortran::parser::Variable>(unit->u);
337     }
338     return false;
339   }
340   void WarnOnIoStmt(const parser::CharBlock &source) {
341     context_.Warn(common::UsageWarning::CUDAUsage, source,
342         "I/O statement might not be supported on device"_warn_en_US);
343   }
344   template <typename A>
345   void WarnIfNotInternal(const A &stmt, const parser::CharBlock &source) {
346     if (!IsInternalIO(stmt)) {
347       WarnOnIoStmt(source);
348     }
349   }
350   template <typename A>
351   void ErrorIfHostSymbol(const A &expr, parser::CharBlock source) {
352     if (const Symbol * hostArray{FindHostArray{}(expr)}) {
353       context_.Say(source,
354           "Host array '%s' cannot be present in device context"_err_en_US,
355           hostArray->name());
356     }
357   }
358   void ErrorInCUFKernel(parser::CharBlock source) {
359     if (IsCUFKernelDo) {
360       context_.Say(
361           source, "Statement may not appear in cuf kernel code"_err_en_US);
362     }
363   }
364   void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) {
365     common::visit(
366         common::visitors{
367             [&](const common::Indirection<parser::CycleStmt> &) {
368               ErrorInCUFKernel(source);
369             },
370             [&](const common::Indirection<parser::ExitStmt> &) {
371               ErrorInCUFKernel(source);
372             },
373             [&](const common::Indirection<parser::GotoStmt> &) {
374               ErrorInCUFKernel(source);
375             },
376             [&](const common::Indirection<parser::StopStmt> &) { return; },
377             [&](const common::Indirection<parser::PrintStmt> &) {},
378             [&](const common::Indirection<parser::WriteStmt> &x) {
379               if (x.value().format) { // Formatted write to '*' or '6'
380                 if (std::holds_alternative<Fortran::parser::Star>(
381                         x.value().format->u)) {
382                   if (x.value().iounit) {
383                     if (std::holds_alternative<Fortran::parser::Star>(
384                             x.value().iounit->u)) {
385                       return;
386                     }
387                   }
388                 }
389               }
390               WarnIfNotInternal(x.value(), source);
391             },
392             [&](const common::Indirection<parser::CloseStmt> &x) {
393               WarnOnIoStmt(source);
394             },
395             [&](const common::Indirection<parser::EndfileStmt> &x) {
396               WarnOnIoStmt(source);
397             },
398             [&](const common::Indirection<parser::OpenStmt> &x) {
399               WarnOnIoStmt(source);
400             },
401             [&](const common::Indirection<parser::ReadStmt> &x) {
402               WarnIfNotInternal(x.value(), source);
403             },
404             [&](const common::Indirection<parser::InquireStmt> &x) {
405               WarnOnIoStmt(source);
406             },
407             [&](const common::Indirection<parser::RewindStmt> &x) {
408               WarnOnIoStmt(source);
409             },
410             [&](const common::Indirection<parser::BackspaceStmt> &x) {
411               WarnOnIoStmt(source);
412             },
413             [&](const common::Indirection<parser::IfStmt> &x) {
414               Check(x.value());
415             },
416             [&](const common::Indirection<parser::AssignmentStmt> &x) {
417               if (const evaluate::Assignment *
418                   assign{semantics::GetAssignment(x.value())}) {
419                 ErrorIfHostSymbol(assign->lhs, source);
420                 ErrorIfHostSymbol(assign->rhs, source);
421               }
422               if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
423                 context_.Say(source, std::move(*msg));
424               }
425             },
426             [&](const auto &x) {
427               if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
428                 context_.Say(source, std::move(*msg));
429               }
430             },
431         },
432         stmt.u);
433   }
434   void Check(const parser::IfConstruct &ic) {
435     const auto &ifS{std::get<parser::Statement<parser::IfThenStmt>>(ic.t)};
436     CheckUnwrappedExpr(context_, ifS.source,
437         std::get<parser::ScalarLogicalExpr>(ifS.statement.t));
438     Check(std::get<parser::Block>(ic.t));
439     for (const auto &eib :
440         std::get<std::list<parser::IfConstruct::ElseIfBlock>>(ic.t)) {
441       const auto &eIfS{std::get<parser::Statement<parser::ElseIfStmt>>(eib.t)};
442       CheckUnwrappedExpr(context_, eIfS.source,
443           std::get<parser::ScalarLogicalExpr>(eIfS.statement.t));
444       Check(std::get<parser::Block>(eib.t));
445     }
446     if (const auto &eb{
447             std::get<std::optional<parser::IfConstruct::ElseBlock>>(ic.t)}) {
448       Check(std::get<parser::Block>(eb->t));
449     }
450   }
451   void Check(const parser::IfStmt &is) {
452     const auto &uS{
453         std::get<parser::UnlabeledStatement<parser::ActionStmt>>(is.t)};
454     CheckUnwrappedExpr(
455         context_, uS.source, std::get<parser::ScalarLogicalExpr>(is.t));
456     Check(uS.statement, uS.source);
457   }
458   void Check(const parser::LoopControl::Bounds &bounds) {
459     Check(bounds.lower);
460     Check(bounds.upper);
461     if (bounds.step) {
462       Check(*bounds.step);
463     }
464   }
465   void Check(const parser::LoopControl::Concurrent &x) {
466     const auto &header{std::get<parser::ConcurrentHeader>(x.t)};
467     for (const auto &cc :
468         std::get<std::list<parser::ConcurrentControl>>(header.t)) {
469       Check(std::get<1>(cc.t));
470       Check(std::get<2>(cc.t));
471       if (const auto &step{
472               std::get<std::optional<parser::ScalarIntExpr>>(cc.t)}) {
473         Check(*step);
474       }
475     }
476     if (const auto &mask{
477             std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) {
478       Check(*mask);
479     }
480   }
481   void Check(const parser::ScalarLogicalExpr &x) {
482     Check(DEREF(parser::Unwrap<parser::Expr>(x)));
483   }
484   void Check(const parser::ScalarIntExpr &x) {
485     Check(DEREF(parser::Unwrap<parser::Expr>(x)));
486   }
487   void Check(const parser::ScalarExpr &x) {
488     Check(DEREF(parser::Unwrap<parser::Expr>(x)));
489   }
490   void Check(const parser::Expr &expr) {
491     if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
492       context_.Say(expr.source, std::move(*msg));
493     }
494   }
495 
496   SemanticsContext &context_;
497 };
498 
499 void CUDAChecker::Enter(const parser::SubroutineSubprogram &x) {
500   DeviceContextChecker<false>{context_}.CheckSubprogram(
501       std::get<parser::Name>(
502           std::get<parser::Statement<parser::SubroutineStmt>>(x.t).statement.t),
503       std::get<parser::ExecutionPart>(x.t).v);
504 }
505 
506 void CUDAChecker::Enter(const parser::FunctionSubprogram &x) {
507   DeviceContextChecker<false>{context_}.CheckSubprogram(
508       std::get<parser::Name>(
509           std::get<parser::Statement<parser::FunctionStmt>>(x.t).statement.t),
510       std::get<parser::ExecutionPart>(x.t).v);
511 }
512 
513 void CUDAChecker::Enter(const parser::SeparateModuleSubprogram &x) {
514   DeviceContextChecker<false>{context_}.CheckSubprogram(
515       std::get<parser::Statement<parser::MpSubprogramStmt>>(x.t).statement.v,
516       std::get<parser::ExecutionPart>(x.t).v);
517 }
518 
519 // !$CUF KERNEL DO semantic checks
520 
521 static int DoConstructTightNesting(
522     const parser::DoConstruct *doConstruct, const parser::Block *&innerBlock) {
523   if (!doConstruct ||
524       (!doConstruct->IsDoNormal() && !doConstruct->IsDoConcurrent())) {
525     return 0;
526   }
527   innerBlock = &std::get<parser::Block>(doConstruct->t);
528   if (innerBlock->size() == 1) {
529     if (const auto *execConstruct{
530             std::get_if<parser::ExecutableConstruct>(&innerBlock->front().u)}) {
531       if (const auto *next{
532               std::get_if<common::Indirection<parser::DoConstruct>>(
533                   &execConstruct->u)}) {
534         return 1 + DoConstructTightNesting(&next->value(), innerBlock);
535       }
536     }
537   }
538   return 1;
539 }
540 
541 static void CheckReduce(
542     SemanticsContext &context, const parser::CUFReduction &reduce) {
543   auto op{std::get<parser::CUFReduction::Operator>(reduce.t).v};
544   for (const auto &var :
545       std::get<std::list<parser::Scalar<parser::Variable>>>(reduce.t)) {
546     if (const auto &typedExprPtr{var.thing.typedExpr};
547         typedExprPtr && typedExprPtr->v) {
548       const auto &expr{*typedExprPtr->v};
549       if (auto type{expr.GetType()}) {
550         auto cat{type->category()};
551         bool isOk{false};
552         switch (op) {
553         case parser::ReductionOperator::Operator::Plus:
554         case parser::ReductionOperator::Operator::Multiply:
555         case parser::ReductionOperator::Operator::Max:
556         case parser::ReductionOperator::Operator::Min:
557           isOk = cat == TypeCategory::Integer || cat == TypeCategory::Real ||
558               cat == TypeCategory::Complex;
559           break;
560         case parser::ReductionOperator::Operator::Iand:
561         case parser::ReductionOperator::Operator::Ior:
562         case parser::ReductionOperator::Operator::Ieor:
563           isOk = cat == TypeCategory::Integer;
564           break;
565         case parser::ReductionOperator::Operator::And:
566         case parser::ReductionOperator::Operator::Or:
567         case parser::ReductionOperator::Operator::Eqv:
568         case parser::ReductionOperator::Operator::Neqv:
569           isOk = cat == TypeCategory::Logical;
570           break;
571         }
572         if (!isOk) {
573           context.Say(var.thing.GetSource(),
574               "!$CUF KERNEL DO REDUCE operation is not acceptable for a variable with type %s"_err_en_US,
575               type->AsFortran());
576         }
577       }
578     }
579   }
580 }
581 
582 void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) {
583   auto source{std::get<parser::CUFKernelDoConstruct::Directive>(x.t).source};
584   const auto &directive{std::get<parser::CUFKernelDoConstruct::Directive>(x.t)};
585   std::int64_t depth{1};
586   if (auto expr{AnalyzeExpr(context_,
587           std::get<std::optional<parser::ScalarIntConstantExpr>>(
588               directive.t))}) {
589     depth = evaluate::ToInt64(expr).value_or(0);
590     if (depth <= 0) {
591       context_.Say(source,
592           "!$CUF KERNEL DO (%jd): loop nesting depth must be positive"_err_en_US,
593           std::intmax_t{depth});
594       depth = 1;
595     }
596   }
597   const parser::DoConstruct *doConstruct{common::GetPtrFromOptional(
598       std::get<std::optional<parser::DoConstruct>>(x.t))};
599   const parser::Block *innerBlock{nullptr};
600   if (DoConstructTightNesting(doConstruct, innerBlock) < depth) {
601     context_.Say(source,
602         "!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US,
603         std::intmax_t{depth});
604   }
605   if (innerBlock) {
606     DeviceContextChecker<true>{context_}.Check(*innerBlock);
607   }
608   for (const auto &reduce :
609       std::get<std::list<parser::CUFReduction>>(directive.t)) {
610     CheckReduce(context_, reduce);
611   }
612   inCUFKernelDoConstruct_ = true;
613 }
614 
615 void CUDAChecker::Leave(const parser::CUFKernelDoConstruct &) {
616   inCUFKernelDoConstruct_ = false;
617 }
618 
619 void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
620   auto lhsLoc{std::get<parser::Variable>(x.t).GetSource()};
621   const auto &scope{context_.FindScope(lhsLoc)};
622   const Scope &progUnit{GetProgramUnitContaining(scope)};
623   if (IsCUDADeviceContext(&progUnit) || inCUFKernelDoConstruct_) {
624     return; // Data transfer with assignment is only perform on host.
625   }
626 
627   const evaluate::Assignment *assign{semantics::GetAssignment(x)};
628   if (!assign) {
629     return;
630   }
631 
632   int nbLhs{evaluate::GetNbOfCUDADeviceSymbols(assign->lhs)};
633   int nbRhs{evaluate::GetNbOfCUDADeviceSymbols(assign->rhs)};
634 
635   // device to host transfer with more than one device object on the rhs is not
636   // legal.
637   if (nbLhs == 0 && nbRhs > 1) {
638     context_.Say(lhsLoc,
639         "More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);
640   }
641 }
642 
643 } // namespace Fortran::semantics
644