xref: /llvm-project/flang/lib/Evaluate/fold-logical.cpp (revision 460fc79a080ba5733c30610cceb6ddced37afdd4)
1 //===-- lib/Evaluate/fold-logical.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "fold-implementation.h"
10 #include "fold-reduction.h"
11 #include "flang/Evaluate/check-expression.h"
12 
13 namespace Fortran::evaluate {
14 
15 // for ALL & ANY
16 template <typename T>
17 static Expr<T> FoldAllAny(FoldingContext &context, FunctionRef<T> &&ref,
18     Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
19     Scalar<T> identity) {
20   static_assert(T::category == TypeCategory::Logical);
21   using Element = Scalar<T>;
22   std::optional<int> dim;
23   if (std::optional<Constant<T>> array{
24           ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
25               /*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
26     auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
27       element = (element.*operation)(array->At(at));
28     }};
29     return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
30   }
31   return Expr<T>{std::move(ref)};
32 }
33 
34 template <int KIND>
35 Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
36     FoldingContext &context,
37     FunctionRef<Type<TypeCategory::Logical, KIND>> &&funcRef) {
38   using T = Type<TypeCategory::Logical, KIND>;
39   ActualArguments &args{funcRef.arguments()};
40   auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
41   CHECK(intrinsic);
42   std::string name{intrinsic->name};
43   using SameInt = Type<TypeCategory::Integer, KIND>;
44   if (name == "all") {
45     return FoldAllAny(
46         context, std::move(funcRef), &Scalar<T>::AND, Scalar<T>{true});
47   } else if (name == "any") {
48     return FoldAllAny(
49         context, std::move(funcRef), &Scalar<T>::OR, Scalar<T>{false});
50   } else if (name == "associated") {
51     bool gotConstant{true};
52     const Expr<SomeType> *firstArgExpr{args[0]->UnwrapExpr()};
53     if (!firstArgExpr || !IsNullPointer(*firstArgExpr)) {
54       gotConstant = false;
55     } else if (args[1]) { // There's a second argument
56       const Expr<SomeType> *secondArgExpr{args[1]->UnwrapExpr()};
57       if (!secondArgExpr || !IsNullPointer(*secondArgExpr)) {
58         gotConstant = false;
59       }
60     }
61     return gotConstant ? Expr<T>{false} : Expr<T>{std::move(funcRef)};
62   } else if (name == "bge" || name == "bgt" || name == "ble" || name == "blt") {
63     static_assert(std::is_same_v<Scalar<LargestInt>, BOZLiteralConstant>);
64     // Arguments do not have to be of the same integer type. Convert all
65     // arguments to the biggest integer type before comparing them to
66     // simplify.
67     for (int i{0}; i <= 1; ++i) {
68       if (auto *x{UnwrapExpr<Expr<SomeInteger>>(args[i])}) {
69         *args[i] = AsGenericExpr(
70             Fold(context, ConvertToType<LargestInt>(std::move(*x))));
71       } else if (auto *x{UnwrapExpr<BOZLiteralConstant>(args[i])}) {
72         *args[i] = AsGenericExpr(Constant<LargestInt>{std::move(*x)});
73       }
74     }
75     auto fptr{&Scalar<LargestInt>::BGE};
76     if (name == "bge") { // done in fptr declaration
77     } else if (name == "bgt") {
78       fptr = &Scalar<LargestInt>::BGT;
79     } else if (name == "ble") {
80       fptr = &Scalar<LargestInt>::BLE;
81     } else if (name == "blt") {
82       fptr = &Scalar<LargestInt>::BLT;
83     } else {
84       common::die("missing case to fold intrinsic function %s", name.c_str());
85     }
86     return FoldElementalIntrinsic<T, LargestInt, LargestInt>(context,
87         std::move(funcRef),
88         ScalarFunc<T, LargestInt, LargestInt>(
89             [&fptr](const Scalar<LargestInt> &i, const Scalar<LargestInt> &j) {
90               return Scalar<T>{std::invoke(fptr, i, j)};
91             }));
92   } else if (name == "btest") {
93     if (const auto *ix{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
94       return common::visit(
95           [&](const auto &x) {
96             using IT = ResultType<decltype(x)>;
97             return FoldElementalIntrinsic<T, IT, SameInt>(context,
98                 std::move(funcRef),
99                 ScalarFunc<T, IT, SameInt>(
100                     [&](const Scalar<IT> &x, const Scalar<SameInt> &pos) {
101                       auto posVal{pos.ToInt64()};
102                       if (posVal < 0 || posVal >= x.bits) {
103                         context.messages().Say(
104                             "POS=%jd out of range for BTEST"_err_en_US,
105                             static_cast<std::intmax_t>(posVal));
106                       }
107                       return Scalar<T>{x.BTEST(posVal)};
108                     }));
109           },
110           ix->u);
111     }
112   } else if (name == "extends_type_of") {
113     // Type extension testing with EXTENDS_TYPE_OF() ignores any type
114     // parameters. Returns a constant truth value when the result is known now.
115     if (args[0] && args[1]) {
116       auto t0{args[0]->GetType()};
117       auto t1{args[1]->GetType()};
118       if (t0 && t1) {
119         if (auto result{t0->ExtendsTypeOf(*t1)}) {
120           return Expr<T>{*result};
121         }
122       }
123     }
124   } else if (name == "isnan" || name == "__builtin_ieee_is_nan") {
125     // A warning about an invalid argument is discarded from converting
126     // the argument of isnan() / IEEE_IS_NAN().
127     auto restorer{context.messages().DiscardMessages()};
128     using DefaultReal = Type<TypeCategory::Real, 4>;
129     return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
130         ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
131           return Scalar<T>{x.IsNotANumber()};
132         }));
133   } else if (name == "__builtin_ieee_is_negative") {
134     auto restorer{context.messages().DiscardMessages()};
135     using DefaultReal = Type<TypeCategory::Real, 4>;
136     return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
137         ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
138           return Scalar<T>{x.IsNegative()};
139         }));
140   } else if (name == "__builtin_ieee_is_normal") {
141     auto restorer{context.messages().DiscardMessages()};
142     using DefaultReal = Type<TypeCategory::Real, 4>;
143     return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
144         ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
145           return Scalar<T>{x.IsNormal()};
146         }));
147   } else if (name == "is_contiguous") {
148     if (args.at(0)) {
149       if (auto *expr{args[0]->UnwrapExpr()}) {
150         if (IsSimplyContiguous(*expr, context)) {
151           return Expr<T>{true};
152         }
153       }
154     }
155   } else if (name == "lge" || name == "lgt" || name == "lle" || name == "llt") {
156     // Rewrite LGE/LGT/LLE/LLT into ASCII character relations
157     auto *cx0{UnwrapExpr<Expr<SomeCharacter>>(args[0])};
158     auto *cx1{UnwrapExpr<Expr<SomeCharacter>>(args[1])};
159     if (cx0 && cx1) {
160       return Fold(context,
161           ConvertToType<T>(
162               PackageRelation(name == "lge" ? RelationalOperator::GE
163                       : name == "lgt"       ? RelationalOperator::GT
164                       : name == "lle"       ? RelationalOperator::LE
165                                             : RelationalOperator::LT,
166                   ConvertToType<Ascii>(std::move(*cx0)),
167                   ConvertToType<Ascii>(std::move(*cx1)))));
168     }
169   } else if (name == "logical") {
170     if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
171       return Fold(context, ConvertToType<T>(std::move(*expr)));
172     }
173   } else if (name == "merge") {
174     return FoldMerge<T>(context, std::move(funcRef));
175   } else if (name == "same_type_as") {
176     // Type equality testing with SAME_TYPE_AS() ignores any type parameters.
177     // Returns a constant truth value when the result is known now.
178     if (args[0] && args[1]) {
179       auto t0{args[0]->GetType()};
180       auto t1{args[1]->GetType()};
181       if (t0 && t1) {
182         if (auto result{t0->SameTypeAs(*t1)}) {
183           return Expr<T>{*result};
184         }
185       }
186     }
187   } else if (name == "__builtin_ieee_support_datatype" ||
188       name == "__builtin_ieee_support_denormal" ||
189       name == "__builtin_ieee_support_divide" ||
190       name == "__builtin_ieee_support_divide" ||
191       name == "__builtin_ieee_support_inf" ||
192       name == "__builtin_ieee_support_io" ||
193       name == "__builtin_ieee_support_nan" ||
194       name == "__builtin_ieee_support_sqrt" ||
195       name == "__builtin_ieee_support_standard" ||
196       name == "__builtin_ieee_support_subnormal" ||
197       name == "__builtin_ieee_support_underflow_control") {
198     return Expr<T>{true};
199   }
200   // TODO: dot_product, is_iostat_end,
201   // is_iostat_eor, logical, matmul, out_of_range,
202   // parity, transfer
203   return Expr<T>{std::move(funcRef)};
204 }
205 
206 template <typename T>
207 Expr<LogicalResult> FoldOperation(
208     FoldingContext &context, Relational<T> &&relation) {
209   if (auto array{ApplyElementwise(context, relation,
210           std::function<Expr<LogicalResult>(Expr<T> &&, Expr<T> &&)>{
211               [=](Expr<T> &&x, Expr<T> &&y) {
212                 return Expr<LogicalResult>{Relational<SomeType>{
213                     Relational<T>{relation.opr, std::move(x), std::move(y)}}};
214               }})}) {
215     return *array;
216   }
217   if (auto folded{OperandsAreConstants(relation)}) {
218     bool result{};
219     if constexpr (T::category == TypeCategory::Integer) {
220       result =
221           Satisfies(relation.opr, folded->first.CompareSigned(folded->second));
222     } else if constexpr (T::category == TypeCategory::Real) {
223       result = Satisfies(relation.opr, folded->first.Compare(folded->second));
224     } else if constexpr (T::category == TypeCategory::Complex) {
225       result = (relation.opr == RelationalOperator::EQ) ==
226           folded->first.Equals(folded->second);
227     } else if constexpr (T::category == TypeCategory::Character) {
228       result = Satisfies(relation.opr, Compare(folded->first, folded->second));
229     } else {
230       static_assert(T::category != TypeCategory::Logical);
231     }
232     return Expr<LogicalResult>{Constant<LogicalResult>{result}};
233   }
234   return Expr<LogicalResult>{Relational<SomeType>{std::move(relation)}};
235 }
236 
237 Expr<LogicalResult> FoldOperation(
238     FoldingContext &context, Relational<SomeType> &&relation) {
239   return common::visit(
240       [&](auto &&x) {
241         return Expr<LogicalResult>{FoldOperation(context, std::move(x))};
242       },
243       std::move(relation.u));
244 }
245 
246 template <int KIND>
247 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
248     FoldingContext &context, Not<KIND> &&x) {
249   if (auto array{ApplyElementwise(context, x)}) {
250     return *array;
251   }
252   using Ty = Type<TypeCategory::Logical, KIND>;
253   auto &operand{x.left()};
254   if (auto value{GetScalarConstantValue<Ty>(operand)}) {
255     return Expr<Ty>{Constant<Ty>{!value->IsTrue()}};
256   }
257   return Expr<Ty>{x};
258 }
259 
260 template <int KIND>
261 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
262     FoldingContext &context, LogicalOperation<KIND> &&operation) {
263   using LOGICAL = Type<TypeCategory::Logical, KIND>;
264   if (auto array{ApplyElementwise(context, operation,
265           std::function<Expr<LOGICAL>(Expr<LOGICAL> &&, Expr<LOGICAL> &&)>{
266               [=](Expr<LOGICAL> &&x, Expr<LOGICAL> &&y) {
267                 return Expr<LOGICAL>{LogicalOperation<KIND>{
268                     operation.logicalOperator, std::move(x), std::move(y)}};
269               }})}) {
270     return *array;
271   }
272   if (auto folded{OperandsAreConstants(operation)}) {
273     bool xt{folded->first.IsTrue()}, yt{folded->second.IsTrue()}, result{};
274     switch (operation.logicalOperator) {
275     case LogicalOperator::And:
276       result = xt && yt;
277       break;
278     case LogicalOperator::Or:
279       result = xt || yt;
280       break;
281     case LogicalOperator::Eqv:
282       result = xt == yt;
283       break;
284     case LogicalOperator::Neqv:
285       result = xt != yt;
286       break;
287     case LogicalOperator::Not:
288       DIE("not a binary operator");
289     }
290     return Expr<LOGICAL>{Constant<LOGICAL>{result}};
291   }
292   return Expr<LOGICAL>{std::move(operation)};
293 }
294 
295 #ifdef _MSC_VER // disable bogus warning about missing definitions
296 #pragma warning(disable : 4661)
297 #endif
298 FOR_EACH_LOGICAL_KIND(template class ExpressionBase, )
299 template class ExpressionBase<SomeLogical>;
300 } // namespace Fortran::evaluate
301