xref: /llvm-project/flang/lib/Evaluate/fold-logical.cpp (revision 67f15e76ebd23f659168f845f9660d3e5e9c606c)
1 //===-- lib/Evaluate/fold-logical.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "fold-implementation.h"
10 #include "fold-reduction.h"
11 #include "flang/Evaluate/check-expression.h"
12 #include "flang/Runtime/magic-numbers.h"
13 
14 namespace Fortran::evaluate {
15 
16 template <typename T>
17 static std::optional<Expr<SomeType>> ZeroExtend(const Constant<T> &c) {
18   std::vector<Scalar<LargestInt>> exts;
19   for (const auto &v : c.values()) {
20     exts.push_back(Scalar<LargestInt>::ConvertUnsigned(v).value);
21   }
22   return AsGenericExpr(
23       Constant<LargestInt>(std::move(exts), ConstantSubscripts(c.shape())));
24 }
25 
26 // for ALL, ANY & PARITY
27 template <typename T>
28 static Expr<T> FoldAllAnyParity(FoldingContext &context, FunctionRef<T> &&ref,
29     Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
30     Scalar<T> identity) {
31   static_assert(T::category == TypeCategory::Logical);
32   std::optional<int> dim;
33   if (std::optional<Constant<T>> array{
34           ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
35               /*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
36     OperationAccumulator accumulator{*array, operation};
37     return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
38   }
39   return Expr<T>{std::move(ref)};
40 }
41 
42 template <int KIND>
43 Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
44     FoldingContext &context,
45     FunctionRef<Type<TypeCategory::Logical, KIND>> &&funcRef) {
46   using T = Type<TypeCategory::Logical, KIND>;
47   ActualArguments &args{funcRef.arguments()};
48   auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
49   CHECK(intrinsic);
50   std::string name{intrinsic->name};
51   using SameInt = Type<TypeCategory::Integer, KIND>;
52   if (name == "all") {
53     return FoldAllAnyParity(
54         context, std::move(funcRef), &Scalar<T>::AND, Scalar<T>{true});
55   } else if (name == "any") {
56     return FoldAllAnyParity(
57         context, std::move(funcRef), &Scalar<T>::OR, Scalar<T>{false});
58   } else if (name == "associated") {
59     bool gotConstant{true};
60     const Expr<SomeType> *firstArgExpr{args[0]->UnwrapExpr()};
61     if (!firstArgExpr || !IsNullPointer(*firstArgExpr)) {
62       gotConstant = false;
63     } else if (args[1]) { // There's a second argument
64       const Expr<SomeType> *secondArgExpr{args[1]->UnwrapExpr()};
65       if (!secondArgExpr || !IsNullPointer(*secondArgExpr)) {
66         gotConstant = false;
67       }
68     }
69     return gotConstant ? Expr<T>{false} : Expr<T>{std::move(funcRef)};
70   } else if (name == "bge" || name == "bgt" || name == "ble" || name == "blt") {
71     static_assert(std::is_same_v<Scalar<LargestInt>, BOZLiteralConstant>);
72 
73     // The arguments to these intrinsics can be of different types. In that
74     // case, the shorter of the two would need to be zero-extended to match
75     // the size of the other. If at least one of the operands is not a constant,
76     // the zero-extending will be done during lowering. Otherwise, the folding
77     // must be done here.
78     std::optional<Expr<SomeType>> constArgs[2];
79     for (int i{0}; i <= 1; i++) {
80       if (BOZLiteralConstant * x{UnwrapExpr<BOZLiteralConstant>(args[i])}) {
81         constArgs[i] = AsGenericExpr(Constant<LargestInt>{std::move(*x)});
82       } else if (auto *x{UnwrapExpr<Expr<SomeInteger>>(args[i])}) {
83         common::visit(
84             [&](const auto &ix) {
85               using IntT = typename std::decay_t<decltype(ix)>::Result;
86               if (auto *c{UnwrapConstantValue<IntT>(ix)}) {
87                 constArgs[i] = ZeroExtend(*c);
88               }
89             },
90             x->u);
91       }
92     }
93 
94     if (constArgs[0] && constArgs[1]) {
95       auto fptr{&Scalar<LargestInt>::BGE};
96       if (name == "bge") { // done in fptr declaration
97       } else if (name == "bgt") {
98         fptr = &Scalar<LargestInt>::BGT;
99       } else if (name == "ble") {
100         fptr = &Scalar<LargestInt>::BLE;
101       } else if (name == "blt") {
102         fptr = &Scalar<LargestInt>::BLT;
103       } else {
104         common::die("missing case to fold intrinsic function %s", name.c_str());
105       }
106 
107       for (int i{0}; i <= 1; i++) {
108         *args[i] = std::move(constArgs[i].value());
109       }
110 
111       return FoldElementalIntrinsic<T, LargestInt, LargestInt>(context,
112           std::move(funcRef),
113           ScalarFunc<T, LargestInt, LargestInt>(
114               [&fptr](
115                   const Scalar<LargestInt> &i, const Scalar<LargestInt> &j) {
116                 return Scalar<T>{std::invoke(fptr, i, j)};
117               }));
118     } else {
119       return Expr<T>{std::move(funcRef)};
120     }
121   } else if (name == "btest") {
122     if (const auto *ix{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
123       return common::visit(
124           [&](const auto &x) {
125             using IT = ResultType<decltype(x)>;
126             return FoldElementalIntrinsic<T, IT, SameInt>(context,
127                 std::move(funcRef),
128                 ScalarFunc<T, IT, SameInt>(
129                     [&](const Scalar<IT> &x, const Scalar<SameInt> &pos) {
130                       auto posVal{pos.ToInt64()};
131                       if (posVal < 0 || posVal >= x.bits) {
132                         context.messages().Say(
133                             "POS=%jd out of range for BTEST"_err_en_US,
134                             static_cast<std::intmax_t>(posVal));
135                       }
136                       return Scalar<T>{x.BTEST(posVal)};
137                     }));
138           },
139           ix->u);
140     }
141   } else if (name == "dot_product") {
142     return FoldDotProduct<T>(context, std::move(funcRef));
143   } else if (name == "extends_type_of") {
144     // Type extension testing with EXTENDS_TYPE_OF() ignores any type
145     // parameters. Returns a constant truth value when the result is known now.
146     if (args[0] && args[1]) {
147       auto t0{args[0]->GetType()};
148       auto t1{args[1]->GetType()};
149       if (t0 && t1) {
150         if (auto result{t0->ExtendsTypeOf(*t1)}) {
151           return Expr<T>{*result};
152         }
153       }
154     }
155   } else if (name == "isnan" || name == "__builtin_ieee_is_nan") {
156     // Only replace the type of the function if we can do the fold
157     if (args[0] && args[0]->UnwrapExpr() &&
158         IsActuallyConstant(*args[0]->UnwrapExpr())) {
159       auto restorer{context.messages().DiscardMessages()};
160       using DefaultReal = Type<TypeCategory::Real, 4>;
161       return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
162           ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
163             return Scalar<T>{x.IsNotANumber()};
164           }));
165     }
166   } else if (name == "__builtin_ieee_is_negative") {
167     auto restorer{context.messages().DiscardMessages()};
168     using DefaultReal = Type<TypeCategory::Real, 4>;
169     if (args[0] && args[0]->UnwrapExpr() &&
170         IsActuallyConstant(*args[0]->UnwrapExpr())) {
171       return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
172           ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
173             return Scalar<T>{x.IsNegative()};
174           }));
175     }
176   } else if (name == "__builtin_ieee_is_normal") {
177     auto restorer{context.messages().DiscardMessages()};
178     using DefaultReal = Type<TypeCategory::Real, 4>;
179     if (args[0] && args[0]->UnwrapExpr() &&
180         IsActuallyConstant(*args[0]->UnwrapExpr())) {
181       return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
182           ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
183             return Scalar<T>{x.IsNormal()};
184           }));
185     }
186   } else if (name == "is_contiguous") {
187     if (args.at(0)) {
188       if (auto *expr{args[0]->UnwrapExpr()}) {
189         if (auto contiguous{IsContiguous(*expr, context)}) {
190           return Expr<T>{*contiguous};
191         }
192       } else if (auto *assumedType{args[0]->GetAssumedTypeDummy()}) {
193         if (auto contiguous{IsContiguous(*assumedType, context)}) {
194           return Expr<T>{*contiguous};
195         }
196       }
197     }
198   } else if (name == "is_iostat_end") {
199     if (args[0] && args[0]->UnwrapExpr() &&
200         IsActuallyConstant(*args[0]->UnwrapExpr())) {
201       using Int64 = Type<TypeCategory::Integer, 8>;
202       return FoldElementalIntrinsic<T, Int64>(context, std::move(funcRef),
203           ScalarFunc<T, Int64>([](const Scalar<Int64> &x) {
204             return Scalar<T>{x.ToInt64() == FORTRAN_RUNTIME_IOSTAT_END};
205           }));
206     }
207   } else if (name == "is_iostat_eor") {
208     if (args[0] && args[0]->UnwrapExpr() &&
209         IsActuallyConstant(*args[0]->UnwrapExpr())) {
210       using Int64 = Type<TypeCategory::Integer, 8>;
211       return FoldElementalIntrinsic<T, Int64>(context, std::move(funcRef),
212           ScalarFunc<T, Int64>([](const Scalar<Int64> &x) {
213             return Scalar<T>{x.ToInt64() == FORTRAN_RUNTIME_IOSTAT_EOR};
214           }));
215     }
216   } else if (name == "lge" || name == "lgt" || name == "lle" || name == "llt") {
217     // Rewrite LGE/LGT/LLE/LLT into ASCII character relations
218     auto *cx0{UnwrapExpr<Expr<SomeCharacter>>(args[0])};
219     auto *cx1{UnwrapExpr<Expr<SomeCharacter>>(args[1])};
220     if (cx0 && cx1) {
221       return Fold(context,
222           ConvertToType<T>(
223               PackageRelation(name == "lge" ? RelationalOperator::GE
224                       : name == "lgt"       ? RelationalOperator::GT
225                       : name == "lle"       ? RelationalOperator::LE
226                                             : RelationalOperator::LT,
227                   ConvertToType<Ascii>(std::move(*cx0)),
228                   ConvertToType<Ascii>(std::move(*cx1)))));
229     }
230   } else if (name == "logical") {
231     if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
232       return Fold(context, ConvertToType<T>(std::move(*expr)));
233     }
234   } else if (name == "out_of_range") {
235     if (Expr<SomeType> * cx{UnwrapExpr<Expr<SomeType>>(args[0])}) {
236       auto restorer{context.messages().DiscardMessages()};
237       *args[0] = Fold(context, std::move(*cx));
238       if (Expr<SomeType> & folded{DEREF(args[0].value().UnwrapExpr())};
239           IsActuallyConstant(folded)) {
240         std::optional<std::vector<typename T::Scalar>> result;
241         if (Expr<SomeReal> * realMold{UnwrapExpr<Expr<SomeReal>>(args[1])}) {
242           if (const auto *xInt{UnwrapExpr<Expr<SomeInteger>>(folded)}) {
243             result.emplace();
244             std::visit(
245                 [&](const auto &mold, const auto &x) {
246                   using RealType =
247                       typename std::decay_t<decltype(mold)>::Result;
248                   static_assert(RealType::category == TypeCategory::Real);
249                   using Scalar = typename RealType::Scalar;
250                   using xType = typename std::decay_t<decltype(x)>::Result;
251                   const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
252                   for (const auto &elt : xConst.values()) {
253                     result->emplace_back(
254                         Scalar::template FromInteger(elt).flags.test(
255                             RealFlag::Overflow));
256                   }
257                 },
258                 realMold->u, xInt->u);
259           } else if (const auto *xReal{UnwrapExpr<Expr<SomeReal>>(folded)}) {
260             result.emplace();
261             std::visit(
262                 [&](const auto &mold, const auto &x) {
263                   using RealType =
264                       typename std::decay_t<decltype(mold)>::Result;
265                   static_assert(RealType::category == TypeCategory::Real);
266                   using Scalar = typename RealType::Scalar;
267                   using xType = typename std::decay_t<decltype(x)>::Result;
268                   const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
269                   for (const auto &elt : xConst.values()) {
270                     result->emplace_back(elt.IsFinite() &&
271                         Scalar::template Convert(elt).flags.test(
272                             RealFlag::Overflow));
273                   }
274                 },
275                 realMold->u, xReal->u);
276           }
277         } else if (Expr<SomeInteger> *
278             intMold{UnwrapExpr<Expr<SomeInteger>>(args[1])}) {
279           if (const auto *xInt{UnwrapExpr<Expr<SomeInteger>>(folded)}) {
280             result.emplace();
281             std::visit(
282                 [&](const auto &mold, const auto &x) {
283                   using IntType = typename std::decay_t<decltype(mold)>::Result;
284                   static_assert(IntType::category == TypeCategory::Integer);
285                   using Scalar = typename IntType::Scalar;
286                   using xType = typename std::decay_t<decltype(x)>::Result;
287                   const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
288                   for (const auto &elt : xConst.values()) {
289                     result->emplace_back(
290                         Scalar::template ConvertSigned(elt).overflow);
291                   }
292                 },
293                 intMold->u, xInt->u);
294           } else if (Expr<SomeLogical> *
295                          cRound{args.size() >= 3
296                                  ? UnwrapExpr<Expr<SomeLogical>>(args[2])
297                                  : nullptr};
298                      !cRound || IsActuallyConstant(*args[2]->UnwrapExpr())) {
299             if (const auto *xReal{UnwrapExpr<Expr<SomeReal>>(folded)}) {
300               common::RoundingMode roundingMode{common::RoundingMode::ToZero};
301               if (cRound &&
302                   common::visit(
303                       [](const auto &x) {
304                         using xType =
305                             typename std::decay_t<decltype(x)>::Result;
306                         return GetScalarConstantValue<xType>(x)
307                             .value()
308                             .IsTrue();
309                       },
310                       cRound->u)) {
311                 // ROUND=.TRUE. - convert with NINT()
312                 roundingMode = common::RoundingMode::TiesAwayFromZero;
313               }
314               result.emplace();
315               std::visit(
316                   [&](const auto &mold, const auto &x) {
317                     using IntType =
318                         typename std::decay_t<decltype(mold)>::Result;
319                     static_assert(IntType::category == TypeCategory::Integer);
320                     using Scalar = typename IntType::Scalar;
321                     using xType = typename std::decay_t<decltype(x)>::Result;
322                     const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
323                     for (const auto &elt : xConst.values()) {
324                       // Note that OUT_OF_RANGE(Inf/NaN) is .TRUE. for the
325                       // real->integer case, but not for  real->real.
326                       result->emplace_back(!elt.IsFinite() ||
327                           elt.template ToInteger<Scalar>(roundingMode)
328                               .flags.test(RealFlag::Overflow));
329                     }
330                   },
331                   intMold->u, xReal->u);
332             }
333           }
334         }
335         if (result) {
336           if (auto extents{GetConstantExtents(context, folded)}) {
337             return Expr<T>{
338                 Constant<T>{std::move(*result), std::move(*extents)}};
339           }
340         }
341       }
342     }
343   } else if (name == "parity") {
344     return FoldAllAnyParity(
345         context, std::move(funcRef), &Scalar<T>::NEQV, Scalar<T>{false});
346   } else if (name == "same_type_as") {
347     // Type equality testing with SAME_TYPE_AS() ignores any type parameters.
348     // Returns a constant truth value when the result is known now.
349     if (args[0] && args[1]) {
350       auto t0{args[0]->GetType()};
351       auto t1{args[1]->GetType()};
352       if (t0 && t1) {
353         if (auto result{t0->SameTypeAs(*t1)}) {
354           return Expr<T>{*result};
355         }
356       }
357     }
358   } else if (name == "__builtin_ieee_support_datatype" ||
359       name == "__builtin_ieee_support_denormal" ||
360       name == "__builtin_ieee_support_divide" ||
361       name == "__builtin_ieee_support_inf" ||
362       name == "__builtin_ieee_support_io" ||
363       name == "__builtin_ieee_support_nan" ||
364       name == "__builtin_ieee_support_sqrt" ||
365       name == "__builtin_ieee_support_standard" ||
366       name == "__builtin_ieee_support_subnormal" ||
367       name == "__builtin_ieee_support_underflow_control") {
368     return Expr<T>{true};
369   }
370   // TODO: logical, matmul, parity
371   return Expr<T>{std::move(funcRef)};
372 }
373 
374 template <typename T>
375 Expr<LogicalResult> FoldOperation(
376     FoldingContext &context, Relational<T> &&relation) {
377   if (auto array{ApplyElementwise(context, relation,
378           std::function<Expr<LogicalResult>(Expr<T> &&, Expr<T> &&)>{
379               [=](Expr<T> &&x, Expr<T> &&y) {
380                 return Expr<LogicalResult>{Relational<SomeType>{
381                     Relational<T>{relation.opr, std::move(x), std::move(y)}}};
382               }})}) {
383     return *array;
384   }
385   if (auto folded{OperandsAreConstants(relation)}) {
386     bool result{};
387     if constexpr (T::category == TypeCategory::Integer) {
388       result =
389           Satisfies(relation.opr, folded->first.CompareSigned(folded->second));
390     } else if constexpr (T::category == TypeCategory::Real) {
391       result = Satisfies(relation.opr, folded->first.Compare(folded->second));
392     } else if constexpr (T::category == TypeCategory::Complex) {
393       result = (relation.opr == RelationalOperator::EQ) ==
394           folded->first.Equals(folded->second);
395     } else if constexpr (T::category == TypeCategory::Character) {
396       result = Satisfies(relation.opr, Compare(folded->first, folded->second));
397     } else {
398       static_assert(T::category != TypeCategory::Logical);
399     }
400     return Expr<LogicalResult>{Constant<LogicalResult>{result}};
401   }
402   return Expr<LogicalResult>{Relational<SomeType>{std::move(relation)}};
403 }
404 
405 Expr<LogicalResult> FoldOperation(
406     FoldingContext &context, Relational<SomeType> &&relation) {
407   return common::visit(
408       [&](auto &&x) {
409         return Expr<LogicalResult>{FoldOperation(context, std::move(x))};
410       },
411       std::move(relation.u));
412 }
413 
414 template <int KIND>
415 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
416     FoldingContext &context, Not<KIND> &&x) {
417   if (auto array{ApplyElementwise(context, x)}) {
418     return *array;
419   }
420   using Ty = Type<TypeCategory::Logical, KIND>;
421   auto &operand{x.left()};
422   if (auto value{GetScalarConstantValue<Ty>(operand)}) {
423     return Expr<Ty>{Constant<Ty>{!value->IsTrue()}};
424   }
425   return Expr<Ty>{x};
426 }
427 
428 template <int KIND>
429 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
430     FoldingContext &context, LogicalOperation<KIND> &&operation) {
431   using LOGICAL = Type<TypeCategory::Logical, KIND>;
432   if (auto array{ApplyElementwise(context, operation,
433           std::function<Expr<LOGICAL>(Expr<LOGICAL> &&, Expr<LOGICAL> &&)>{
434               [=](Expr<LOGICAL> &&x, Expr<LOGICAL> &&y) {
435                 return Expr<LOGICAL>{LogicalOperation<KIND>{
436                     operation.logicalOperator, std::move(x), std::move(y)}};
437               }})}) {
438     return *array;
439   }
440   if (auto folded{OperandsAreConstants(operation)}) {
441     bool xt{folded->first.IsTrue()}, yt{folded->second.IsTrue()}, result{};
442     switch (operation.logicalOperator) {
443     case LogicalOperator::And:
444       result = xt && yt;
445       break;
446     case LogicalOperator::Or:
447       result = xt || yt;
448       break;
449     case LogicalOperator::Eqv:
450       result = xt == yt;
451       break;
452     case LogicalOperator::Neqv:
453       result = xt != yt;
454       break;
455     case LogicalOperator::Not:
456       DIE("not a binary operator");
457     }
458     return Expr<LOGICAL>{Constant<LOGICAL>{result}};
459   }
460   return Expr<LOGICAL>{std::move(operation)};
461 }
462 
// MSVC warning C4661 ("no suitable definition provided for explicit template
// instantiation request") is reported for the instantiations below even
// though the definitions exist; silence it.
#ifdef _MSC_VER // disable bogus warning about missing definitions
#pragma warning(disable : 4661)
#endif
// Explicitly instantiate ExpressionBase in this translation unit for the
// LOGICAL types: FOR_EACH_LOGICAL_KIND presumably expands the given prefix
// once per supported LOGICAL kind (TODO confirm against its definition),
// followed by the kind-erased SomeLogical.
FOR_EACH_LOGICAL_KIND(template class ExpressionBase, )
template class ExpressionBase<SomeLogical>;
468 } // namespace Fortran::evaluate
469