1 //===-- lib/Semantics/check-case.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "check-case.h" 10 #include "flang/Common/idioms.h" 11 #include "flang/Common/reference.h" 12 #include "flang/Common/template.h" 13 #include "flang/Evaluate/fold.h" 14 #include "flang/Evaluate/type.h" 15 #include "flang/Parser/parse-tree.h" 16 #include "flang/Semantics/semantics.h" 17 #include "flang/Semantics/tools.h" 18 #include <tuple> 19 20 namespace Fortran::semantics { 21 22 template <typename T> class CaseValues { 23 public: 24 CaseValues(SemanticsContext &c, const evaluate::DynamicType &t) 25 : context_{c}, caseExprType_{t} {} 26 27 void Check(const std::list<parser::CaseConstruct::Case> &cases) { 28 for (const parser::CaseConstruct::Case &c : cases) { 29 AddCase(c); 30 } 31 if (!hasErrors_) { 32 cases_.sort(Comparator{}); 33 if (!AreCasesDisjoint()) { // C1149 34 ReportConflictingCases(); 35 } 36 } 37 } 38 39 private: 40 using Value = evaluate::Scalar<T>; 41 42 void AddCase(const parser::CaseConstruct::Case &c) { 43 const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)}; 44 const parser::CaseStmt &caseStmt{stmt.statement}; 45 const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)}; 46 common::visit( 47 common::visitors{ 48 [&](const std::list<parser::CaseValueRange> &ranges) { 49 for (const auto &range : ranges) { 50 auto pair{ComputeBounds(range)}; 51 if (pair.first && pair.second && *pair.first > *pair.second) { 52 context_.Warn(common::UsageWarning::EmptyCase, stmt.source, 53 "CASE has lower bound greater than upper bound"_warn_en_US); 54 } else { 55 if constexpr (T::category == TypeCategory::Logical) { // C1148 56 if ((pair.first || pair.second) && 57 (!pair.first || !pair.second || 58 *pair.first != *pair.second)) { 59 context_.Say(stmt.source, 60 "CASE range is not allowed for LOGICAL"_err_en_US); 61 } 62 } 63 cases_.emplace_back(stmt); 64 cases_.back().lower = std::move(pair.first); 65 cases_.back().upper = std::move(pair.second); 66 } 67 } 68 }, 69 [&](const parser::Default &) { cases_.emplace_front(stmt); }, 70 }, 71 selector.u); 72 } 73 74 std::optional<Value> GetValue(const parser::CaseValue &caseValue) { 75 const parser::Expr &expr{caseValue.thing.thing.value()}; 76 auto *x{expr.typedExpr.get()}; 77 if (x && x->v) { // C1147 78 auto type{x->v->GetType()}; 79 if (type && type->category() == caseExprType_.category() && 80 (type->category() != TypeCategory::Character || 81 type->kind() == caseExprType_.kind())) { 82 parser::Messages buffer; // discarded folding messages 83 parser::ContextualMessages foldingMessages{expr.source, &buffer}; 84 evaluate::FoldingContext foldingContext{ 85 context_.foldingContext(), foldingMessages}; 86 auto folded{evaluate::Fold(foldingContext, SomeExpr{*x->v})}; 87 if (auto converted{evaluate::Fold(foldingContext, 88 evaluate::ConvertToType(T::GetType(), SomeExpr{folded}))}) { 89 if (auto value{evaluate::GetScalarConstantValue<T>(*converted)}) { 90 auto back{evaluate::Fold(foldingContext, 91 evaluate::ConvertToType(*type, SomeExpr{*converted}))}; 92 if (back == folded) { 93 x->v = converted; 94 return value; 95 } else { 96 context_.Warn(common::UsageWarning::CaseOverflow, expr.source, 97 "CASE value (%s) overflows type (%s) of SELECT CASE expression"_warn_en_US, 98 folded.AsFortran(), caseExprType_.AsFortran()); 99 hasErrors_ = true; 100 return std::nullopt; 101 } 102 } 103 } 104 context_.Say(expr.source, 105 "CASE value (%s) must be a constant scalar"_err_en_US, 106 x->v->AsFortran()); 107 } else { 108 std::string typeStr{type ? type->AsFortran() : "typeless"s}; 109 context_.Say(expr.source, 110 "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US, 111 typeStr, caseExprType_.AsFortran()); 112 } 113 hasErrors_ = true; 114 } 115 return std::nullopt; 116 } 117 118 using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>; 119 PairOfValues ComputeBounds(const parser::CaseValueRange &range) { 120 return common::visit( 121 common::visitors{ 122 [&](const parser::CaseValue &x) { 123 auto value{GetValue(x)}; 124 return PairOfValues{value, value}; 125 }, 126 [&](const parser::CaseValueRange::Range &x) { 127 std::optional<Value> lo, hi; 128 if (x.lower) { 129 lo = GetValue(*x.lower); 130 } 131 if (x.upper) { 132 hi = GetValue(*x.upper); 133 } 134 if ((x.lower && !lo) || (x.upper && !hi)) { 135 return PairOfValues{}; // error case 136 } 137 return PairOfValues{std::move(lo), std::move(hi)}; 138 }, 139 }, 140 range.u); 141 } 142 143 struct Case { 144 explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {} 145 bool IsDefault() const { return !lower && !upper; } 146 std::string AsFortran() const { 147 std::string result; 148 { 149 llvm::raw_string_ostream bs{result}; 150 if (lower) { 151 evaluate::Constant<T>{*lower}.AsFortran(bs << '('); 152 if (!upper) { 153 bs << ':'; 154 } else if (*lower != *upper) { 155 evaluate::Constant<T>{*upper}.AsFortran(bs << ':'); 156 } 157 bs << ')'; 158 } else if (upper) { 159 evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')'; 160 } else { 161 bs << "DEFAULT"; 162 } 163 } 164 return result; 165 } 166 167 const parser::Statement<parser::CaseStmt> &stmt; 168 std::optional<Value> lower, upper; 169 }; 170 171 // Defines a comparator for use with std::list<>::sort(). 172 // Returns true if and only if the highest value in range x is less 173 // than the least value in range y. The DEFAULT case is arbitrarily 174 // defined to be less than all others. When two ranges overlap, 175 // neither is less than the other. 176 struct Comparator { 177 bool operator()(const Case &x, const Case &y) const { 178 if (x.IsDefault()) { 179 return !y.IsDefault(); 180 } else { 181 return x.upper && y.lower && *x.upper < *y.lower; 182 } 183 } 184 }; 185 186 bool AreCasesDisjoint() const { 187 auto endIter{cases_.end()}; 188 for (auto iter{cases_.begin()}; iter != endIter; ++iter) { 189 auto next{iter}; 190 if (++next != endIter && !Comparator{}(*iter, *next)) { 191 return false; 192 } 193 } 194 return true; 195 } 196 197 // This has quadratic time, but only runs in error cases 198 void ReportConflictingCases() { 199 for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) { 200 parser::Message *msg{nullptr}; 201 for (auto p{cases_.begin()}; p != cases_.end(); ++p) { 202 if (p->stmt.source.begin() < iter->stmt.source.begin() && 203 !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) { 204 if (!msg) { 205 msg = &context_.Say(iter->stmt.source, 206 "CASE %s conflicts with previous cases"_err_en_US, 207 iter->AsFortran()); 208 } 209 msg->Attach( 210 p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran()); 211 } 212 } 213 } 214 } 215 216 SemanticsContext &context_; 217 const evaluate::DynamicType &caseExprType_; 218 std::list<Case> cases_; 219 bool hasErrors_{false}; 220 }; 221 222 template <TypeCategory CAT> struct TypeVisitor { 223 using Result = bool; 224 using Types = evaluate::CategoryTypes<CAT>; 225 template <typename T> Result Test() { 226 if (T::kind == exprType.kind()) { 227 CaseValues<T>(context, exprType).Check(caseList); 228 return true; 229 } else { 230 return false; 231 } 232 } 233 SemanticsContext &context; 234 const evaluate::DynamicType &exprType; 235 const std::list<parser::CaseConstruct::Case> &caseList; 236 }; 237 238 void CaseChecker::Enter(const parser::CaseConstruct &construct) { 239 const auto &selectCaseStmt{ 240 std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)}; 241 const auto &selectCase{selectCaseStmt.statement}; 242 const auto &selectExpr{ 243 std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing}; 244 const auto *x{GetExpr(context_, selectExpr)}; 245 if (!x) { 246 return; // expression semantics failed 247 } 248 if (auto exprType{x->GetType()}) { 249 const auto &caseList{ 250 std::get<std::list<parser::CaseConstruct::Case>>(construct.t)}; 251 switch (exprType->category()) { 252 case TypeCategory::Integer: 253 common::SearchTypes( 254 TypeVisitor<TypeCategory::Integer>{context_, *exprType, caseList}); 255 return; 256 case TypeCategory::Unsigned: 257 common::SearchTypes( 258 TypeVisitor<TypeCategory::Unsigned>{context_, *exprType, caseList}); 259 return; 260 case TypeCategory::Logical: 261 CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType} 262 .Check(caseList); 263 return; 264 case TypeCategory::Character: 265 common::SearchTypes( 266 TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList}); 267 return; 268 default: 269 break; 270 } 271 } 272 context_.Say(selectExpr.source, 273 context_.IsEnabled(common::LanguageFeature::Unsigned) 274 ? "SELECT CASE expression must be integer, unsigned, logical, or character"_err_en_US 275 : "SELECT CASE expression must be integer, logical, or character"_err_en_US); 276 } 277 } // namespace Fortran::semantics 278