xref: /llvm-project/flang/lib/Evaluate/fold-logical.cpp (revision 54784b1831651a5f47ebafdb1f79a3380ac4a46b)
1 //===-- lib/Evaluate/fold-logical.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "fold-implementation.h"
10 #include "fold-reduction.h"
11 #include "flang/Evaluate/check-expression.h"
12 
13 namespace Fortran::evaluate {
14 
15 template <typename T>
16 static std::optional<Expr<SomeType>> ZeroExtend(const Constant<T> &c) {
17   std::vector<Scalar<LargestInt>> exts;
18   for (const auto &v : c.values()) {
19     exts.push_back(Scalar<LargestInt>::ConvertUnsigned(v).value);
20   }
21   return AsGenericExpr(
22       Constant<LargestInt>(std::move(exts), ConstantSubscripts(c.shape())));
23 }
24 
25 // for ALL, ANY & PARITY
26 template <typename T>
27 static Expr<T> FoldAllAnyParity(FoldingContext &context, FunctionRef<T> &&ref,
28     Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
29     Scalar<T> identity) {
30   static_assert(T::category == TypeCategory::Logical);
31   using Element = Scalar<T>;
32   std::optional<int> dim;
33   if (std::optional<Constant<T>> array{
34           ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
35               /*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
36     auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
37       element = (element.*operation)(array->At(at));
38     }};
39     return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
40   }
41   return Expr<T>{std::move(ref)};
42 }
43 
44 template <int KIND>
45 Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
46     FoldingContext &context,
47     FunctionRef<Type<TypeCategory::Logical, KIND>> &&funcRef) {
48   using T = Type<TypeCategory::Logical, KIND>;
49   ActualArguments &args{funcRef.arguments()};
50   auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
51   CHECK(intrinsic);
52   std::string name{intrinsic->name};
53   using SameInt = Type<TypeCategory::Integer, KIND>;
54   if (name == "all") {
55     return FoldAllAnyParity(
56         context, std::move(funcRef), &Scalar<T>::AND, Scalar<T>{true});
57   } else if (name == "any") {
58     return FoldAllAnyParity(
59         context, std::move(funcRef), &Scalar<T>::OR, Scalar<T>{false});
60   } else if (name == "associated") {
61     bool gotConstant{true};
62     const Expr<SomeType> *firstArgExpr{args[0]->UnwrapExpr()};
63     if (!firstArgExpr || !IsNullPointer(*firstArgExpr)) {
64       gotConstant = false;
65     } else if (args[1]) { // There's a second argument
66       const Expr<SomeType> *secondArgExpr{args[1]->UnwrapExpr()};
67       if (!secondArgExpr || !IsNullPointer(*secondArgExpr)) {
68         gotConstant = false;
69       }
70     }
71     return gotConstant ? Expr<T>{false} : Expr<T>{std::move(funcRef)};
72   } else if (name == "bge" || name == "bgt" || name == "ble" || name == "blt") {
73     static_assert(std::is_same_v<Scalar<LargestInt>, BOZLiteralConstant>);
74 
75     // The arguments to these intrinsics can be of different types. In that
76     // case, the shorter of the two would need to be zero-extended to match
77     // the size of the other. If at least one of the operands is not a constant,
78     // the zero-extending will be done during lowering. Otherwise, the folding
79     // must be done here.
80     std::optional<Expr<SomeType>> constArgs[2];
81     for (int i{0}; i <= 1; i++) {
82       if (BOZLiteralConstant * x{UnwrapExpr<BOZLiteralConstant>(args[i])}) {
83         constArgs[i] = AsGenericExpr(Constant<LargestInt>{std::move(*x)});
84       } else if (auto *x{UnwrapExpr<Expr<SomeInteger>>(args[i])}) {
85         common::visit(
86             [&](const auto &ix) {
87               using IntT = typename std::decay_t<decltype(ix)>::Result;
88               if (auto *c{UnwrapConstantValue<IntT>(ix)}) {
89                 constArgs[i] = ZeroExtend(*c);
90               }
91             },
92             x->u);
93       }
94     }
95 
96     if (constArgs[0] && constArgs[1]) {
97       auto fptr{&Scalar<LargestInt>::BGE};
98       if (name == "bge") { // done in fptr declaration
99       } else if (name == "bgt") {
100         fptr = &Scalar<LargestInt>::BGT;
101       } else if (name == "ble") {
102         fptr = &Scalar<LargestInt>::BLE;
103       } else if (name == "blt") {
104         fptr = &Scalar<LargestInt>::BLT;
105       } else {
106         common::die("missing case to fold intrinsic function %s", name.c_str());
107       }
108 
109       for (int i{0}; i <= 1; i++) {
110         *args[i] = std::move(constArgs[i].value());
111       }
112 
113       return FoldElementalIntrinsic<T, LargestInt, LargestInt>(context,
114           std::move(funcRef),
115           ScalarFunc<T, LargestInt, LargestInt>(
116               [&fptr](
117                   const Scalar<LargestInt> &i, const Scalar<LargestInt> &j) {
118                 return Scalar<T>{std::invoke(fptr, i, j)};
119               }));
120     } else {
121       return Expr<T>{std::move(funcRef)};
122     }
123   } else if (name == "btest") {
124     if (const auto *ix{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
125       return common::visit(
126           [&](const auto &x) {
127             using IT = ResultType<decltype(x)>;
128             return FoldElementalIntrinsic<T, IT, SameInt>(context,
129                 std::move(funcRef),
130                 ScalarFunc<T, IT, SameInt>(
131                     [&](const Scalar<IT> &x, const Scalar<SameInt> &pos) {
132                       auto posVal{pos.ToInt64()};
133                       if (posVal < 0 || posVal >= x.bits) {
134                         context.messages().Say(
135                             "POS=%jd out of range for BTEST"_err_en_US,
136                             static_cast<std::intmax_t>(posVal));
137                       }
138                       return Scalar<T>{x.BTEST(posVal)};
139                     }));
140           },
141           ix->u);
142     }
143   } else if (name == "dot_product") {
144     return FoldDotProduct<T>(context, std::move(funcRef));
145   } else if (name == "extends_type_of") {
146     // Type extension testing with EXTENDS_TYPE_OF() ignores any type
147     // parameters. Returns a constant truth value when the result is known now.
148     if (args[0] && args[1]) {
149       auto t0{args[0]->GetType()};
150       auto t1{args[1]->GetType()};
151       if (t0 && t1) {
152         if (auto result{t0->ExtendsTypeOf(*t1)}) {
153           return Expr<T>{*result};
154         }
155       }
156     }
157   } else if (name == "isnan" || name == "__builtin_ieee_is_nan") {
158     // Only replace the type of the function if we can do the fold
159     if (args[0] && args[0]->UnwrapExpr() &&
160         IsActuallyConstant(*args[0]->UnwrapExpr())) {
161       auto restorer{context.messages().DiscardMessages()};
162       using DefaultReal = Type<TypeCategory::Real, 4>;
163       return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
164           ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
165             return Scalar<T>{x.IsNotANumber()};
166           }));
167     }
168   } else if (name == "__builtin_ieee_is_negative") {
169     auto restorer{context.messages().DiscardMessages()};
170     using DefaultReal = Type<TypeCategory::Real, 4>;
171     if (args[0] && args[0]->UnwrapExpr() &&
172         IsActuallyConstant(*args[0]->UnwrapExpr())) {
173       return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
174           ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
175             return Scalar<T>{x.IsNegative()};
176           }));
177     }
178   } else if (name == "__builtin_ieee_is_normal") {
179     auto restorer{context.messages().DiscardMessages()};
180     using DefaultReal = Type<TypeCategory::Real, 4>;
181     if (args[0] && args[0]->UnwrapExpr() &&
182         IsActuallyConstant(*args[0]->UnwrapExpr())) {
183       return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
184           ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
185             return Scalar<T>{x.IsNormal()};
186           }));
187     }
188   } else if (name == "is_contiguous") {
189     if (args.at(0)) {
190       if (auto *expr{args[0]->UnwrapExpr()}) {
191         if (auto contiguous{IsContiguous(*expr, context)}) {
192           return Expr<T>{*contiguous};
193         }
194       } else if (auto *assumedType{args[0]->GetAssumedTypeDummy()}) {
195         if (auto contiguous{IsContiguous(*assumedType, context)}) {
196           return Expr<T>{*contiguous};
197         }
198       }
199     }
200   } else if (name == "lge" || name == "lgt" || name == "lle" || name == "llt") {
201     // Rewrite LGE/LGT/LLE/LLT into ASCII character relations
202     auto *cx0{UnwrapExpr<Expr<SomeCharacter>>(args[0])};
203     auto *cx1{UnwrapExpr<Expr<SomeCharacter>>(args[1])};
204     if (cx0 && cx1) {
205       return Fold(context,
206           ConvertToType<T>(
207               PackageRelation(name == "lge" ? RelationalOperator::GE
208                       : name == "lgt"       ? RelationalOperator::GT
209                       : name == "lle"       ? RelationalOperator::LE
210                                             : RelationalOperator::LT,
211                   ConvertToType<Ascii>(std::move(*cx0)),
212                   ConvertToType<Ascii>(std::move(*cx1)))));
213     }
214   } else if (name == "logical") {
215     if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
216       return Fold(context, ConvertToType<T>(std::move(*expr)));
217     }
218   } else if (name == "out_of_range") {
219     if (Expr<SomeType> * cx{UnwrapExpr<Expr<SomeType>>(args[0])}) {
220       auto restorer{context.messages().DiscardMessages()};
221       *args[0] = Fold(context, std::move(*cx));
222       if (Expr<SomeType> & folded{DEREF(args[0].value().UnwrapExpr())};
223           IsActuallyConstant(folded)) {
224         std::optional<std::vector<typename T::Scalar>> result;
225         if (Expr<SomeReal> * realMold{UnwrapExpr<Expr<SomeReal>>(args[1])}) {
226           if (const auto *xInt{UnwrapExpr<Expr<SomeInteger>>(folded)}) {
227             result.emplace();
228             std::visit(
229                 [&](const auto &mold, const auto &x) {
230                   using RealType =
231                       typename std::decay_t<decltype(mold)>::Result;
232                   static_assert(RealType::category == TypeCategory::Real);
233                   using Scalar = typename RealType::Scalar;
234                   using xType = typename std::decay_t<decltype(x)>::Result;
235                   const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
236                   for (const auto &elt : xConst.values()) {
237                     result->emplace_back(
238                         Scalar::template FromInteger(elt).flags.test(
239                             RealFlag::Overflow));
240                   }
241                 },
242                 realMold->u, xInt->u);
243           } else if (const auto *xReal{UnwrapExpr<Expr<SomeReal>>(folded)}) {
244             result.emplace();
245             std::visit(
246                 [&](const auto &mold, const auto &x) {
247                   using RealType =
248                       typename std::decay_t<decltype(mold)>::Result;
249                   static_assert(RealType::category == TypeCategory::Real);
250                   using Scalar = typename RealType::Scalar;
251                   using xType = typename std::decay_t<decltype(x)>::Result;
252                   const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
253                   for (const auto &elt : xConst.values()) {
254                     result->emplace_back(elt.IsFinite() &&
255                         Scalar::template Convert(elt).flags.test(
256                             RealFlag::Overflow));
257                   }
258                 },
259                 realMold->u, xReal->u);
260           }
261         } else if (Expr<SomeInteger> *
262             intMold{UnwrapExpr<Expr<SomeInteger>>(args[1])}) {
263           if (const auto *xInt{UnwrapExpr<Expr<SomeInteger>>(folded)}) {
264             result.emplace();
265             std::visit(
266                 [&](const auto &mold, const auto &x) {
267                   using IntType = typename std::decay_t<decltype(mold)>::Result;
268                   static_assert(IntType::category == TypeCategory::Integer);
269                   using Scalar = typename IntType::Scalar;
270                   using xType = typename std::decay_t<decltype(x)>::Result;
271                   const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
272                   for (const auto &elt : xConst.values()) {
273                     result->emplace_back(
274                         Scalar::template ConvertSigned(elt).overflow);
275                   }
276                 },
277                 intMold->u, xInt->u);
278           } else if (Expr<SomeLogical> *
279                          cRound{args.size() >= 3
280                                  ? UnwrapExpr<Expr<SomeLogical>>(args[2])
281                                  : nullptr};
282                      !cRound || IsActuallyConstant(*args[2]->UnwrapExpr())) {
283             if (const auto *xReal{UnwrapExpr<Expr<SomeReal>>(folded)}) {
284               common::RoundingMode roundingMode{common::RoundingMode::ToZero};
285               if (cRound &&
286                   common::visit(
287                       [](const auto &x) {
288                         using xType =
289                             typename std::decay_t<decltype(x)>::Result;
290                         return GetScalarConstantValue<xType>(x)
291                             .value()
292                             .IsTrue();
293                       },
294                       cRound->u)) {
295                 // ROUND=.TRUE. - convert with NINT()
296                 roundingMode = common::RoundingMode::TiesAwayFromZero;
297               }
298               result.emplace();
299               std::visit(
300                   [&](const auto &mold, const auto &x) {
301                     using IntType =
302                         typename std::decay_t<decltype(mold)>::Result;
303                     static_assert(IntType::category == TypeCategory::Integer);
304                     using Scalar = typename IntType::Scalar;
305                     using xType = typename std::decay_t<decltype(x)>::Result;
306                     const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
307                     for (const auto &elt : xConst.values()) {
308                       // Note that OUT_OF_RANGE(Inf/NaN) is .TRUE. for the
309                       // real->integer case, but not for  real->real.
310                       result->emplace_back(!elt.IsFinite() ||
311                           elt.template ToInteger<Scalar>(roundingMode)
312                               .flags.test(RealFlag::Overflow));
313                     }
314                   },
315                   intMold->u, xReal->u);
316             }
317           }
318         }
319         if (result) {
320           if (auto extents{GetConstantExtents(context, folded)}) {
321             return Expr<T>{
322                 Constant<T>{std::move(*result), std::move(*extents)}};
323           }
324         }
325       }
326     }
327   } else if (name == "parity") {
328     return FoldAllAnyParity(
329         context, std::move(funcRef), &Scalar<T>::NEQV, Scalar<T>{false});
330   } else if (name == "same_type_as") {
331     // Type equality testing with SAME_TYPE_AS() ignores any type parameters.
332     // Returns a constant truth value when the result is known now.
333     if (args[0] && args[1]) {
334       auto t0{args[0]->GetType()};
335       auto t1{args[1]->GetType()};
336       if (t0 && t1) {
337         if (auto result{t0->SameTypeAs(*t1)}) {
338           return Expr<T>{*result};
339         }
340       }
341     }
342   } else if (name == "__builtin_ieee_support_datatype" ||
343       name == "__builtin_ieee_support_denormal" ||
344       name == "__builtin_ieee_support_divide" ||
345       name == "__builtin_ieee_support_inf" ||
346       name == "__builtin_ieee_support_io" ||
347       name == "__builtin_ieee_support_nan" ||
348       name == "__builtin_ieee_support_sqrt" ||
349       name == "__builtin_ieee_support_standard" ||
350       name == "__builtin_ieee_support_subnormal" ||
351       name == "__builtin_ieee_support_underflow_control") {
352     return Expr<T>{true};
353   }
354   // TODO: is_iostat_end, is_iostat_eor, logical, matmul, parity
355   return Expr<T>{std::move(funcRef)};
356 }
357 
358 template <typename T>
359 Expr<LogicalResult> FoldOperation(
360     FoldingContext &context, Relational<T> &&relation) {
361   if (auto array{ApplyElementwise(context, relation,
362           std::function<Expr<LogicalResult>(Expr<T> &&, Expr<T> &&)>{
363               [=](Expr<T> &&x, Expr<T> &&y) {
364                 return Expr<LogicalResult>{Relational<SomeType>{
365                     Relational<T>{relation.opr, std::move(x), std::move(y)}}};
366               }})}) {
367     return *array;
368   }
369   if (auto folded{OperandsAreConstants(relation)}) {
370     bool result{};
371     if constexpr (T::category == TypeCategory::Integer) {
372       result =
373           Satisfies(relation.opr, folded->first.CompareSigned(folded->second));
374     } else if constexpr (T::category == TypeCategory::Real) {
375       result = Satisfies(relation.opr, folded->first.Compare(folded->second));
376     } else if constexpr (T::category == TypeCategory::Complex) {
377       result = (relation.opr == RelationalOperator::EQ) ==
378           folded->first.Equals(folded->second);
379     } else if constexpr (T::category == TypeCategory::Character) {
380       result = Satisfies(relation.opr, Compare(folded->first, folded->second));
381     } else {
382       static_assert(T::category != TypeCategory::Logical);
383     }
384     return Expr<LogicalResult>{Constant<LogicalResult>{result}};
385   }
386   return Expr<LogicalResult>{Relational<SomeType>{std::move(relation)}};
387 }
388 
389 Expr<LogicalResult> FoldOperation(
390     FoldingContext &context, Relational<SomeType> &&relation) {
391   return common::visit(
392       [&](auto &&x) {
393         return Expr<LogicalResult>{FoldOperation(context, std::move(x))};
394       },
395       std::move(relation.u));
396 }
397 
398 template <int KIND>
399 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
400     FoldingContext &context, Not<KIND> &&x) {
401   if (auto array{ApplyElementwise(context, x)}) {
402     return *array;
403   }
404   using Ty = Type<TypeCategory::Logical, KIND>;
405   auto &operand{x.left()};
406   if (auto value{GetScalarConstantValue<Ty>(operand)}) {
407     return Expr<Ty>{Constant<Ty>{!value->IsTrue()}};
408   }
409   return Expr<Ty>{x};
410 }
411 
412 template <int KIND>
413 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
414     FoldingContext &context, LogicalOperation<KIND> &&operation) {
415   using LOGICAL = Type<TypeCategory::Logical, KIND>;
416   if (auto array{ApplyElementwise(context, operation,
417           std::function<Expr<LOGICAL>(Expr<LOGICAL> &&, Expr<LOGICAL> &&)>{
418               [=](Expr<LOGICAL> &&x, Expr<LOGICAL> &&y) {
419                 return Expr<LOGICAL>{LogicalOperation<KIND>{
420                     operation.logicalOperator, std::move(x), std::move(y)}};
421               }})}) {
422     return *array;
423   }
424   if (auto folded{OperandsAreConstants(operation)}) {
425     bool xt{folded->first.IsTrue()}, yt{folded->second.IsTrue()}, result{};
426     switch (operation.logicalOperator) {
427     case LogicalOperator::And:
428       result = xt && yt;
429       break;
430     case LogicalOperator::Or:
431       result = xt || yt;
432       break;
433     case LogicalOperator::Eqv:
434       result = xt == yt;
435       break;
436     case LogicalOperator::Neqv:
437       result = xt != yt;
438       break;
439     case LogicalOperator::Not:
440       DIE("not a binary operator");
441     }
442     return Expr<LOGICAL>{Constant<LOGICAL>{result}};
443   }
444   return Expr<LOGICAL>{std::move(operation)};
445 }
446 
// Explicit instantiation definitions of ExpressionBase for the LOGICAL types,
// emitted in this translation unit.  FOR_EACH_LOGICAL_KIND presumably expands
// to one instantiation per supported LOGICAL kind (verify against its
// definition).
#ifdef _MSC_VER // disable bogus warning about missing definitions
#pragma warning(disable : 4661)
#endif
FOR_EACH_LOGICAL_KIND(template class ExpressionBase, )
template class ExpressionBase<SomeLogical>;
451 template class ExpressionBase<SomeLogical>;
452 } // namespace Fortran::evaluate
453