xref: /llvm-project/flang/lib/Semantics/check-case.cpp (revision fc97d2e68b03bc2979395e84b645e5b3ba35aecd)
1 //===-- lib/Semantics/check-case.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "check-case.h"
#include "flang/Common/idioms.h"
#include "flang/Common/reference.h"
#include "flang/Common/template.h"
#include "flang/Evaluate/common.h"
#include "flang/Evaluate/fold.h"
#include "flang/Evaluate/type.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/tools.h"
#include <tuple>
19 
20 namespace Fortran::semantics {
21 
22 template <typename T> class CaseValues {
23 public:
24   CaseValues(SemanticsContext &c, const evaluate::DynamicType &t)
25       : context_{c}, caseExprType_{t} {}
26 
27   void Check(const std::list<parser::CaseConstruct::Case> &cases) {
28     for (const parser::CaseConstruct::Case &c : cases) {
29       AddCase(c);
30     }
31     if (!hasErrors_) {
32       cases_.sort(Comparator{});
33       if (!AreCasesDisjoint()) { // C1149
34         ReportConflictingCases();
35       }
36     }
37   }
38 
39 private:
40   using Value = evaluate::Scalar<T>;
41 
42   void AddCase(const parser::CaseConstruct::Case &c) {
43     const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)};
44     const parser::CaseStmt &caseStmt{stmt.statement};
45     const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)};
46     common::visit(
47         common::visitors{
48             [&](const std::list<parser::CaseValueRange> &ranges) {
49               for (const auto &range : ranges) {
50                 auto pair{ComputeBounds(range)};
51                 if (pair.first && pair.second && *pair.first > *pair.second) {
52                   context_.Warn(common::UsageWarning::EmptyCase, stmt.source,
53                       "CASE has lower bound greater than upper bound"_warn_en_US);
54                 } else {
55                   if constexpr (T::category == TypeCategory::Logical) { // C1148
56                     if ((pair.first || pair.second) &&
57                         (!pair.first || !pair.second ||
58                             *pair.first != *pair.second)) {
59                       context_.Say(stmt.source,
60                           "CASE range is not allowed for LOGICAL"_err_en_US);
61                     }
62                   }
63                   cases_.emplace_back(stmt);
64                   cases_.back().lower = std::move(pair.first);
65                   cases_.back().upper = std::move(pair.second);
66                 }
67               }
68             },
69             [&](const parser::Default &) { cases_.emplace_front(stmt); },
70         },
71         selector.u);
72   }
73 
74   std::optional<Value> GetValue(const parser::CaseValue &caseValue) {
75     const parser::Expr &expr{caseValue.thing.thing.value()};
76     auto *x{expr.typedExpr.get()};
77     if (x && x->v) { // C1147
78       auto type{x->v->GetType()};
79       if (type && type->category() == caseExprType_.category() &&
80           (type->category() != TypeCategory::Character ||
81               type->kind() == caseExprType_.kind())) {
82         parser::Messages buffer; // discarded folding messages
83         parser::ContextualMessages foldingMessages{expr.source, &buffer};
84         evaluate::FoldingContext foldingContext{
85             context_.foldingContext(), foldingMessages};
86         auto folded{evaluate::Fold(foldingContext, SomeExpr{*x->v})};
87         if (auto converted{evaluate::Fold(foldingContext,
88                 evaluate::ConvertToType(T::GetType(), SomeExpr{folded}))}) {
89           if (auto value{evaluate::GetScalarConstantValue<T>(*converted)}) {
90             auto back{evaluate::Fold(foldingContext,
91                 evaluate::ConvertToType(*type, SomeExpr{*converted}))};
92             if (back == folded) {
93               x->v = converted;
94               return value;
95             } else {
96               context_.Warn(common::UsageWarning::CaseOverflow, expr.source,
97                   "CASE value (%s) overflows type (%s) of SELECT CASE expression"_warn_en_US,
98                   folded.AsFortran(), caseExprType_.AsFortran());
99               hasErrors_ = true;
100               return std::nullopt;
101             }
102           }
103         }
104         context_.Say(expr.source,
105             "CASE value (%s) must be a constant scalar"_err_en_US,
106             x->v->AsFortran());
107       } else {
108         std::string typeStr{type ? type->AsFortran() : "typeless"s};
109         context_.Say(expr.source,
110             "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
111             typeStr, caseExprType_.AsFortran());
112       }
113       hasErrors_ = true;
114     }
115     return std::nullopt;
116   }
117 
118   using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>;
119   PairOfValues ComputeBounds(const parser::CaseValueRange &range) {
120     return common::visit(
121         common::visitors{
122             [&](const parser::CaseValue &x) {
123               auto value{GetValue(x)};
124               return PairOfValues{value, value};
125             },
126             [&](const parser::CaseValueRange::Range &x) {
127               std::optional<Value> lo, hi;
128               if (x.lower) {
129                 lo = GetValue(*x.lower);
130               }
131               if (x.upper) {
132                 hi = GetValue(*x.upper);
133               }
134               if ((x.lower && !lo) || (x.upper && !hi)) {
135                 return PairOfValues{}; // error case
136               }
137               return PairOfValues{std::move(lo), std::move(hi)};
138             },
139         },
140         range.u);
141   }
142 
143   struct Case {
144     explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {}
145     bool IsDefault() const { return !lower && !upper; }
146     std::string AsFortran() const {
147       std::string result;
148       {
149         llvm::raw_string_ostream bs{result};
150         if (lower) {
151           evaluate::Constant<T>{*lower}.AsFortran(bs << '(');
152           if (!upper) {
153             bs << ':';
154           } else if (*lower != *upper) {
155             evaluate::Constant<T>{*upper}.AsFortran(bs << ':');
156           }
157           bs << ')';
158         } else if (upper) {
159           evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')';
160         } else {
161           bs << "DEFAULT";
162         }
163       }
164       return result;
165     }
166 
167     const parser::Statement<parser::CaseStmt> &stmt;
168     std::optional<Value> lower, upper;
169   };
170 
171   // Defines a comparator for use with std::list<>::sort().
172   // Returns true if and only if the highest value in range x is less
173   // than the least value in range y.  The DEFAULT case is arbitrarily
174   // defined to be less than all others.  When two ranges overlap,
175   // neither is less than the other.
176   struct Comparator {
177     bool operator()(const Case &x, const Case &y) const {
178       if (x.IsDefault()) {
179         return !y.IsDefault();
180       } else {
181         return x.upper && y.lower && *x.upper < *y.lower;
182       }
183     }
184   };
185 
186   bool AreCasesDisjoint() const {
187     auto endIter{cases_.end()};
188     for (auto iter{cases_.begin()}; iter != endIter; ++iter) {
189       auto next{iter};
190       if (++next != endIter && !Comparator{}(*iter, *next)) {
191         return false;
192       }
193     }
194     return true;
195   }
196 
197   // This has quadratic time, but only runs in error cases
198   void ReportConflictingCases() {
199     for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) {
200       parser::Message *msg{nullptr};
201       for (auto p{cases_.begin()}; p != cases_.end(); ++p) {
202         if (p->stmt.source.begin() < iter->stmt.source.begin() &&
203             !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) {
204           if (!msg) {
205             msg = &context_.Say(iter->stmt.source,
206                 "CASE %s conflicts with previous cases"_err_en_US,
207                 iter->AsFortran());
208           }
209           msg->Attach(
210               p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran());
211         }
212       }
213     }
214   }
215 
216   SemanticsContext &context_;
217   const evaluate::DynamicType &caseExprType_;
218   std::list<Case> cases_;
219   bool hasErrors_{false};
220 };
221 
222 template <TypeCategory CAT> struct TypeVisitor {
223   using Result = bool;
224   using Types = evaluate::CategoryTypes<CAT>;
225   template <typename T> Result Test() {
226     if (T::kind == exprType.kind()) {
227       CaseValues<T>(context, exprType).Check(caseList);
228       return true;
229     } else {
230       return false;
231     }
232   }
233   SemanticsContext &context;
234   const evaluate::DynamicType &exprType;
235   const std::list<parser::CaseConstruct::Case> &caseList;
236 };
237 
238 void CaseChecker::Enter(const parser::CaseConstruct &construct) {
239   const auto &selectCaseStmt{
240       std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)};
241   const auto &selectCase{selectCaseStmt.statement};
242   const auto &selectExpr{
243       std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing};
244   const auto *x{GetExpr(context_, selectExpr)};
245   if (!x) {
246     return; // expression semantics failed
247   }
248   if (auto exprType{x->GetType()}) {
249     const auto &caseList{
250         std::get<std::list<parser::CaseConstruct::Case>>(construct.t)};
251     switch (exprType->category()) {
252     case TypeCategory::Integer:
253       common::SearchTypes(
254           TypeVisitor<TypeCategory::Integer>{context_, *exprType, caseList});
255       return;
256     case TypeCategory::Unsigned:
257       common::SearchTypes(
258           TypeVisitor<TypeCategory::Unsigned>{context_, *exprType, caseList});
259       return;
260     case TypeCategory::Logical:
261       CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType}
262           .Check(caseList);
263       return;
264     case TypeCategory::Character:
265       common::SearchTypes(
266           TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList});
267       return;
268     default:
269       break;
270     }
271   }
272   context_.Say(selectExpr.source,
273       context_.IsEnabled(common::LanguageFeature::Unsigned)
274           ? "SELECT CASE expression must be integer, unsigned, logical, or character"_err_en_US
275           : "SELECT CASE expression must be integer, logical, or character"_err_en_US);
276 }
277 } // namespace Fortran::semantics
278