xref: /llvm-project/flang/lib/Semantics/check-case.cpp (revision fc97d2e68b03bc2979395e84b645e5b3ba35aecd)
17a77c20dSpeter klausler //===-- lib/Semantics/check-case.cpp --------------------------------------===//
27a77c20dSpeter klausler //
37a77c20dSpeter klausler // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47a77c20dSpeter klausler // See https://llvm.org/LICENSE.txt for license information.
57a77c20dSpeter klausler // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67a77c20dSpeter klausler //
77a77c20dSpeter klausler //===----------------------------------------------------------------------===//
87a77c20dSpeter klausler 
97a77c20dSpeter klausler #include "check-case.h"
107a77c20dSpeter klausler #include "flang/Common/idioms.h"
117a77c20dSpeter klausler #include "flang/Common/reference.h"
1211ddb84bSpeter klausler #include "flang/Common/template.h"
137a77c20dSpeter klausler #include "flang/Evaluate/fold.h"
147a77c20dSpeter klausler #include "flang/Evaluate/type.h"
157a77c20dSpeter klausler #include "flang/Parser/parse-tree.h"
167a77c20dSpeter klausler #include "flang/Semantics/semantics.h"
177a77c20dSpeter klausler #include "flang/Semantics/tools.h"
187a77c20dSpeter klausler #include <tuple>
197a77c20dSpeter klausler 
207a77c20dSpeter klausler namespace Fortran::semantics {
217a77c20dSpeter klausler 
227a77c20dSpeter klausler template <typename T> class CaseValues {
237a77c20dSpeter klausler public:
247a77c20dSpeter klausler   CaseValues(SemanticsContext &c, const evaluate::DynamicType &t)
257a77c20dSpeter klausler       : context_{c}, caseExprType_{t} {}
267a77c20dSpeter klausler 
277a77c20dSpeter klausler   void Check(const std::list<parser::CaseConstruct::Case> &cases) {
287a77c20dSpeter klausler     for (const parser::CaseConstruct::Case &c : cases) {
297a77c20dSpeter klausler       AddCase(c);
307a77c20dSpeter klausler     }
317a77c20dSpeter klausler     if (!hasErrors_) {
327a77c20dSpeter klausler       cases_.sort(Comparator{});
337a77c20dSpeter klausler       if (!AreCasesDisjoint()) { // C1149
347a77c20dSpeter klausler         ReportConflictingCases();
357a77c20dSpeter klausler       }
367a77c20dSpeter klausler     }
377a77c20dSpeter klausler   }
387a77c20dSpeter klausler 
397a77c20dSpeter klausler private:
407a77c20dSpeter klausler   using Value = evaluate::Scalar<T>;
417a77c20dSpeter klausler 
427a77c20dSpeter klausler   void AddCase(const parser::CaseConstruct::Case &c) {
437a77c20dSpeter klausler     const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)};
447a77c20dSpeter klausler     const parser::CaseStmt &caseStmt{stmt.statement};
457a77c20dSpeter klausler     const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)};
46cd03e96fSPeter Klausler     common::visit(
477a77c20dSpeter klausler         common::visitors{
487a77c20dSpeter klausler             [&](const std::list<parser::CaseValueRange> &ranges) {
497a77c20dSpeter klausler               for (const auto &range : ranges) {
507a77c20dSpeter klausler                 auto pair{ComputeBounds(range)};
517a77c20dSpeter klausler                 if (pair.first && pair.second && *pair.first > *pair.second) {
520f973ac7SPeter Klausler                   context_.Warn(common::UsageWarning::EmptyCase, stmt.source,
53a53967cdSPeter Klausler                       "CASE has lower bound greater than upper bound"_warn_en_US);
547a77c20dSpeter klausler                 } else {
557a77c20dSpeter klausler                   if constexpr (T::category == TypeCategory::Logical) { // C1148
567a77c20dSpeter klausler                     if ((pair.first || pair.second) &&
577a77c20dSpeter klausler                         (!pair.first || !pair.second ||
587a77c20dSpeter klausler                             *pair.first != *pair.second)) {
597a77c20dSpeter klausler                       context_.Say(stmt.source,
607a77c20dSpeter klausler                           "CASE range is not allowed for LOGICAL"_err_en_US);
617a77c20dSpeter klausler                     }
627a77c20dSpeter klausler                   }
637a77c20dSpeter klausler                   cases_.emplace_back(stmt);
647a77c20dSpeter klausler                   cases_.back().lower = std::move(pair.first);
657a77c20dSpeter klausler                   cases_.back().upper = std::move(pair.second);
667a77c20dSpeter klausler                 }
677a77c20dSpeter klausler               }
687a77c20dSpeter klausler             },
697a77c20dSpeter klausler             [&](const parser::Default &) { cases_.emplace_front(stmt); },
707a77c20dSpeter klausler         },
717a77c20dSpeter klausler         selector.u);
727a77c20dSpeter klausler   }
737a77c20dSpeter klausler 
747a77c20dSpeter klausler   std::optional<Value> GetValue(const parser::CaseValue &caseValue) {
757a77c20dSpeter klausler     const parser::Expr &expr{caseValue.thing.thing.value()};
767a77c20dSpeter klausler     auto *x{expr.typedExpr.get()};
777a77c20dSpeter klausler     if (x && x->v) { // C1147
787a77c20dSpeter klausler       auto type{x->v->GetType()};
797a77c20dSpeter klausler       if (type && type->category() == caseExprType_.category() &&
807a77c20dSpeter klausler           (type->category() != TypeCategory::Character ||
817a77c20dSpeter klausler               type->kind() == caseExprType_.kind())) {
82a73f7abaSPeter Klausler         parser::Messages buffer; // discarded folding messages
83a73f7abaSPeter Klausler         parser::ContextualMessages foldingMessages{expr.source, &buffer};
84a73f7abaSPeter Klausler         evaluate::FoldingContext foldingContext{
85a73f7abaSPeter Klausler             context_.foldingContext(), foldingMessages};
86a73f7abaSPeter Klausler         auto folded{evaluate::Fold(foldingContext, SomeExpr{*x->v})};
87a73f7abaSPeter Klausler         if (auto converted{evaluate::Fold(foldingContext,
88a73f7abaSPeter Klausler                 evaluate::ConvertToType(T::GetType(), SomeExpr{folded}))}) {
89a73f7abaSPeter Klausler           if (auto value{evaluate::GetScalarConstantValue<T>(*converted)}) {
90a73f7abaSPeter Klausler             auto back{evaluate::Fold(foldingContext,
91a73f7abaSPeter Klausler                 evaluate::ConvertToType(*type, SomeExpr{*converted}))};
92a73f7abaSPeter Klausler             if (back == folded) {
93a73f7abaSPeter Klausler               x->v = converted;
94a73f7abaSPeter Klausler               return value;
95a73f7abaSPeter Klausler             } else {
960f973ac7SPeter Klausler               context_.Warn(common::UsageWarning::CaseOverflow, expr.source,
9784c6dc96SPeter Klausler                   "CASE value (%s) overflows type (%s) of SELECT CASE expression"_warn_en_US,
98a73f7abaSPeter Klausler                   folded.AsFortran(), caseExprType_.AsFortran());
99a73f7abaSPeter Klausler               hasErrors_ = true;
100a73f7abaSPeter Klausler               return std::nullopt;
1017a77c20dSpeter klausler             }
1027a77c20dSpeter klausler           }
103a73f7abaSPeter Klausler         }
104a73f7abaSPeter Klausler         context_.Say(expr.source,
105a73f7abaSPeter Klausler             "CASE value (%s) must be a constant scalar"_err_en_US,
106a73f7abaSPeter Klausler             x->v->AsFortran());
1077a77c20dSpeter klausler       } else {
1087a77c20dSpeter klausler         std::string typeStr{type ? type->AsFortran() : "typeless"s};
1097a77c20dSpeter klausler         context_.Say(expr.source,
1107a77c20dSpeter klausler             "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
1117a77c20dSpeter klausler             typeStr, caseExprType_.AsFortran());
1127a77c20dSpeter klausler       }
1137a77c20dSpeter klausler       hasErrors_ = true;
1147a77c20dSpeter klausler     }
1157a77c20dSpeter klausler     return std::nullopt;
1167a77c20dSpeter klausler   }
1177a77c20dSpeter klausler 
1187a77c20dSpeter klausler   using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>;
1197a77c20dSpeter klausler   PairOfValues ComputeBounds(const parser::CaseValueRange &range) {
120cd03e96fSPeter Klausler     return common::visit(
121cd03e96fSPeter Klausler         common::visitors{
1227a77c20dSpeter klausler             [&](const parser::CaseValue &x) {
1237a77c20dSpeter klausler               auto value{GetValue(x)};
1247a77c20dSpeter klausler               return PairOfValues{value, value};
1257a77c20dSpeter klausler             },
1267a77c20dSpeter klausler             [&](const parser::CaseValueRange::Range &x) {
1277a77c20dSpeter klausler               std::optional<Value> lo, hi;
1287a77c20dSpeter klausler               if (x.lower) {
1297a77c20dSpeter klausler                 lo = GetValue(*x.lower);
1307a77c20dSpeter klausler               }
1317a77c20dSpeter klausler               if (x.upper) {
1327a77c20dSpeter klausler                 hi = GetValue(*x.upper);
1337a77c20dSpeter klausler               }
1347a77c20dSpeter klausler               if ((x.lower && !lo) || (x.upper && !hi)) {
1357a77c20dSpeter klausler                 return PairOfValues{}; // error case
1367a77c20dSpeter klausler               }
1377a77c20dSpeter klausler               return PairOfValues{std::move(lo), std::move(hi)};
1387a77c20dSpeter klausler             },
1397a77c20dSpeter klausler         },
1407a77c20dSpeter klausler         range.u);
1417a77c20dSpeter klausler   }
1427a77c20dSpeter klausler 
1437a77c20dSpeter klausler   struct Case {
1447a77c20dSpeter klausler     explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {}
1457a77c20dSpeter klausler     bool IsDefault() const { return !lower && !upper; }
1467a77c20dSpeter klausler     std::string AsFortran() const {
1477a77c20dSpeter klausler       std::string result;
1487a77c20dSpeter klausler       {
1497a77c20dSpeter klausler         llvm::raw_string_ostream bs{result};
1507a77c20dSpeter klausler         if (lower) {
1517a77c20dSpeter klausler           evaluate::Constant<T>{*lower}.AsFortran(bs << '(');
1527a77c20dSpeter klausler           if (!upper) {
1537a77c20dSpeter klausler             bs << ':';
1547a77c20dSpeter klausler           } else if (*lower != *upper) {
1557a77c20dSpeter klausler             evaluate::Constant<T>{*upper}.AsFortran(bs << ':');
1567a77c20dSpeter klausler           }
1577a77c20dSpeter klausler           bs << ')';
1587a77c20dSpeter klausler         } else if (upper) {
1597a77c20dSpeter klausler           evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')';
1607a77c20dSpeter klausler         } else {
1617a77c20dSpeter klausler           bs << "DEFAULT";
1627a77c20dSpeter klausler         }
1637a77c20dSpeter klausler       }
1647a77c20dSpeter klausler       return result;
1657a77c20dSpeter klausler     }
1667a77c20dSpeter klausler 
1677a77c20dSpeter klausler     const parser::Statement<parser::CaseStmt> &stmt;
1687a77c20dSpeter klausler     std::optional<Value> lower, upper;
1697a77c20dSpeter klausler   };
1707a77c20dSpeter klausler 
1717a77c20dSpeter klausler   // Defines a comparator for use with std::list<>::sort().
1727a77c20dSpeter klausler   // Returns true if and only if the highest value in range x is less
1737a77c20dSpeter klausler   // than the least value in range y.  The DEFAULT case is arbitrarily
1747a77c20dSpeter klausler   // defined to be less than all others.  When two ranges overlap,
1757a77c20dSpeter klausler   // neither is less than the other.
1767a77c20dSpeter klausler   struct Comparator {
1777a77c20dSpeter klausler     bool operator()(const Case &x, const Case &y) const {
1787a77c20dSpeter klausler       if (x.IsDefault()) {
1797a77c20dSpeter klausler         return !y.IsDefault();
1807a77c20dSpeter klausler       } else {
1817a77c20dSpeter klausler         return x.upper && y.lower && *x.upper < *y.lower;
1827a77c20dSpeter klausler       }
1837a77c20dSpeter klausler     }
1847a77c20dSpeter klausler   };
1857a77c20dSpeter klausler 
1867a77c20dSpeter klausler   bool AreCasesDisjoint() const {
1877a77c20dSpeter klausler     auto endIter{cases_.end()};
1887a77c20dSpeter klausler     for (auto iter{cases_.begin()}; iter != endIter; ++iter) {
1897a77c20dSpeter klausler       auto next{iter};
1907a77c20dSpeter klausler       if (++next != endIter && !Comparator{}(*iter, *next)) {
1917a77c20dSpeter klausler         return false;
1927a77c20dSpeter klausler       }
1937a77c20dSpeter klausler     }
1947a77c20dSpeter klausler     return true;
1957a77c20dSpeter klausler   }
1967a77c20dSpeter klausler 
1977a77c20dSpeter klausler   // This has quadratic time, but only runs in error cases
1987a77c20dSpeter klausler   void ReportConflictingCases() {
1997a77c20dSpeter klausler     for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) {
2007a77c20dSpeter klausler       parser::Message *msg{nullptr};
2017a77c20dSpeter klausler       for (auto p{cases_.begin()}; p != cases_.end(); ++p) {
2027a77c20dSpeter klausler         if (p->stmt.source.begin() < iter->stmt.source.begin() &&
2037a77c20dSpeter klausler             !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) {
2047a77c20dSpeter klausler           if (!msg) {
2057a77c20dSpeter klausler             msg = &context_.Say(iter->stmt.source,
2067a77c20dSpeter klausler                 "CASE %s conflicts with previous cases"_err_en_US,
2077a77c20dSpeter klausler                 iter->AsFortran());
2087a77c20dSpeter klausler           }
2097a77c20dSpeter klausler           msg->Attach(
2107a77c20dSpeter klausler               p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran());
2117a77c20dSpeter klausler         }
2127a77c20dSpeter klausler       }
2137a77c20dSpeter klausler     }
2147a77c20dSpeter klausler   }
2157a77c20dSpeter klausler 
2167a77c20dSpeter klausler   SemanticsContext &context_;
2177a77c20dSpeter klausler   const evaluate::DynamicType &caseExprType_;
2187a77c20dSpeter klausler   std::list<Case> cases_;
2197a77c20dSpeter klausler   bool hasErrors_{false};
2207a77c20dSpeter klausler };
2217a77c20dSpeter klausler 
22211ddb84bSpeter klausler template <TypeCategory CAT> struct TypeVisitor {
22311ddb84bSpeter klausler   using Result = bool;
22411ddb84bSpeter klausler   using Types = evaluate::CategoryTypes<CAT>;
22511ddb84bSpeter klausler   template <typename T> Result Test() {
22611ddb84bSpeter klausler     if (T::kind == exprType.kind()) {
22711ddb84bSpeter klausler       CaseValues<T>(context, exprType).Check(caseList);
22811ddb84bSpeter klausler       return true;
22911ddb84bSpeter klausler     } else {
23011ddb84bSpeter klausler       return false;
23111ddb84bSpeter klausler     }
23211ddb84bSpeter klausler   }
23311ddb84bSpeter klausler   SemanticsContext &context;
23411ddb84bSpeter klausler   const evaluate::DynamicType &exprType;
23511ddb84bSpeter klausler   const std::list<parser::CaseConstruct::Case> &caseList;
23611ddb84bSpeter klausler };
23711ddb84bSpeter klausler 
2387a77c20dSpeter klausler void CaseChecker::Enter(const parser::CaseConstruct &construct) {
2397a77c20dSpeter klausler   const auto &selectCaseStmt{
2407a77c20dSpeter klausler       std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)};
2417a77c20dSpeter klausler   const auto &selectCase{selectCaseStmt.statement};
2427a77c20dSpeter klausler   const auto &selectExpr{
2437a77c20dSpeter klausler       std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing};
2447e225423SPeter Klausler   const auto *x{GetExpr(context_, selectExpr)};
2457a77c20dSpeter klausler   if (!x) {
2467a77c20dSpeter klausler     return; // expression semantics failed
2477a77c20dSpeter klausler   }
2487a77c20dSpeter klausler   if (auto exprType{x->GetType()}) {
2497a77c20dSpeter klausler     const auto &caseList{
2507a77c20dSpeter klausler         std::get<std::list<parser::CaseConstruct::Case>>(construct.t)};
2517a77c20dSpeter klausler     switch (exprType->category()) {
2527a77c20dSpeter klausler     case TypeCategory::Integer:
25311ddb84bSpeter klausler       common::SearchTypes(
25411ddb84bSpeter klausler           TypeVisitor<TypeCategory::Integer>{context_, *exprType, caseList});
2557a77c20dSpeter klausler       return;
256*fc97d2e6SPeter Klausler     case TypeCategory::Unsigned:
257*fc97d2e6SPeter Klausler       common::SearchTypes(
258*fc97d2e6SPeter Klausler           TypeVisitor<TypeCategory::Unsigned>{context_, *exprType, caseList});
259*fc97d2e6SPeter Klausler       return;
2607a77c20dSpeter klausler     case TypeCategory::Logical:
2617a77c20dSpeter klausler       CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType}
2627a77c20dSpeter klausler           .Check(caseList);
2637a77c20dSpeter klausler       return;
2647a77c20dSpeter klausler     case TypeCategory::Character:
26511ddb84bSpeter klausler       common::SearchTypes(
26611ddb84bSpeter klausler           TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList});
2677a77c20dSpeter klausler       return;
2687a77c20dSpeter klausler     default:
2697a77c20dSpeter klausler       break;
2707a77c20dSpeter klausler     }
2717a77c20dSpeter klausler   }
2727a77c20dSpeter klausler   context_.Say(selectExpr.source,
273*fc97d2e6SPeter Klausler       context_.IsEnabled(common::LanguageFeature::Unsigned)
274*fc97d2e6SPeter Klausler           ? "SELECT CASE expression must be integer, unsigned, logical, or character"_err_en_US
275*fc97d2e6SPeter Klausler           : "SELECT CASE expression must be integer, logical, or character"_err_en_US);
2767a77c20dSpeter klausler }
2777a77c20dSpeter klausler } // namespace Fortran::semantics
278