xref: /llvm-project/flang/lib/Evaluate/fold-logical.cpp (revision c757418869c01f5ee08f05661debabbba92edcf9)
1 //===-- lib/Evaluate/fold-logical.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "fold-implementation.h"
10 #include "flang/Evaluate/check-expression.h"
11 
12 namespace Fortran::evaluate {
13 
// Attempts to fold a call to an intrinsic function whose result type is
// LOGICAL(KIND). Cases handled here (every other call falls through and is
// returned unchanged, wrapped in an Expr):
//  - ALL / ANY without DIM, when MASK is a constant of this LOGICAL kind;
//  - ASSOCIATED, folded to .FALSE. only when POINTER (and TARGET, if present)
//    satisfy IsNullPointer(); otherwise the call is kept;
//  - BGE / BGT / BLE / BLT, via FoldElementalIntrinsic on INTEGER(16) operands;
//  - IS_CONTIGUOUS, folded only to .TRUE. (when IsSimplyContiguous holds);
//  - MERGE, delegated to FoldMerge.
// The BGE-family branch rewrites the arguments of `funcRef` in place before
// attempting the fold.
14 template <int KIND>
15 Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
16     FoldingContext &context,
17     FunctionRef<Type<TypeCategory::Logical, KIND>> &&funcRef) {
18   using T = Type<TypeCategory::Logical, KIND>;
19   ActualArguments &args{funcRef.arguments()};
20   auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
21   CHECK(intrinsic); // the callee must be a specific intrinsic
22   std::string name{intrinsic->name};
23   if (name == "all") {
24     if (!args[1]) { // TODO: ALL(x,DIM=d)
25       if (const auto *constant{UnwrapConstantValue<T>(args[0])}) {
        // .TRUE. unless some element is .FALSE.; an empty MASK gives .TRUE.
26         bool result{true};
27         for (const auto &element : constant->values()) {
28           if (!element.IsTrue()) {
29             result = false;
30             break;
31           }
32         }
33         return Expr<T>{result};
34       }
35     }
36   } else if (name == "any") {
37     if (!args[1]) { // TODO: ANY(x,DIM=d)
38       if (const auto *constant{UnwrapConstantValue<T>(args[0])}) {
        // .FALSE. unless some element is .TRUE.; an empty MASK gives .FALSE.
39         bool result{false};
40         for (const auto &element : constant->values()) {
41           if (element.IsTrue()) {
42             result = true;
43             break;
44           }
45         }
46         return Expr<T>{result};
47       }
48     }
49   } else if (name == "associated") {
    // Only the "all null" case is folded: POINTER satisfies IsNullPointer and
    // TARGET is either absent or also satisfies IsNullPointer; the result is
    // then .FALSE. args[0] is dereferenced unchecked, so POINTER is assumed to
    // always be present.
    // NOTE(review): a disassociated POINTER appears to make the result
    // .FALSE. regardless of TARGET, so this looks conservative. Confirm
    // (including TARGET side effects) before widening.
50     bool gotConstant{true};
51     const Expr<SomeType> *firstArgExpr{args[0]->UnwrapExpr()};
52     if (!firstArgExpr || !IsNullPointer(*firstArgExpr)) {
53       gotConstant = false;
54     } else if (args[1]) { // There's a second argument
55       const Expr<SomeType> *secondArgExpr{args[1]->UnwrapExpr()};
56       if (!secondArgExpr || !IsNullPointer(*secondArgExpr)) {
57         gotConstant = false;
58       }
59     }
60     return gotConstant ? Expr<T>{false} : Expr<T>{std::move(funcRef)};
61   } else if (name == "bge" || name == "bgt" || name == "ble" || name == "blt") {
62     using LargestInt = Type<TypeCategory::Integer, 16>;
63     static_assert(std::is_same_v<Scalar<LargestInt>, BOZLiteralConstant>);
64     // Arguments do not have to be of the same integer type. Convert all
65     // arguments to the biggest integer type before comparing them to
66     // simplify.
    // NOTE(review): the converted values are written back into `args`, i.e.
    // into `funcRef` itself. If FoldElementalIntrinsic cannot fold (e.g. for
    // non-constant operands), the call it hands back presumably carries the
    // INTEGER(16) conversions rather than the original arguments.
    // NOTE(review): ConvertToType between INTEGER kinds looks like a value
    // (sign-extending) conversion. Fortran compares bit sequences of unequal
    // length by padding the shorter one with zero bits on the left. If so,
    // e.g. BGE(-1_4, -1_8) would fold to .TRUE. instead of .FALSE. Confirm.
67     for (int i{0}; i <= 1; ++i) {
68       if (auto *x{UnwrapExpr<Expr<SomeInteger>>(args[i])}) {
69         *args[i] = AsGenericExpr(
70             Fold(context, ConvertToType<LargestInt>(std::move(*x))));
71       } else if (auto *x{UnwrapExpr<BOZLiteralConstant>(args[i])}) {
        // A BOZ literal already has the INTEGER(16) scalar representation
        // (see the static_assert above), so it is wrapped directly.
72         *args[i] = AsGenericExpr(Constant<LargestInt>{std::move(*x)});
73       }
74     }
    // BGE is the default; the chain below overrides it for the other names.
75     auto fptr{&Scalar<LargestInt>::BGE};
76     if (name == "bge") { // done in fptr declaration
77     } else if (name == "bgt") {
78       fptr = &Scalar<LargestInt>::BGT;
79     } else if (name == "ble") {
80       fptr = &Scalar<LargestInt>::BLE;
81     } else if (name == "blt") {
82       fptr = &Scalar<LargestInt>::BLT;
83     } else {
      // Unreachable: the enclosing condition admits only the four names above.
84       common::die("missing case to fold intrinsic function %s", name.c_str());
85     }
    // The lambda captures the local `fptr` by reference. That is only safe if
    // the ScalarFunc is not retained past this call; presumably
    // FoldElementalIntrinsic uses it synchronously (verify).
86     return FoldElementalIntrinsic<T, LargestInt, LargestInt>(context,
87         std::move(funcRef),
88         ScalarFunc<T, LargestInt, LargestInt>(
89             [&fptr](const Scalar<LargestInt> &i, const Scalar<LargestInt> &j) {
90               return Scalar<T>{std::invoke(fptr, i, j)};
91             }));
92   } else if (name == "is_contiguous") {
    // Folds only to .TRUE. An object that is not simply contiguous may still
    // be contiguous at run time, so the call is kept rather than folded to
    // .FALSE.
93     if (args.at(0)) {
94       if (auto *expr{args[0]->UnwrapExpr()}) {
95         if (IsSimplyContiguous(*expr, context.intrinsics())) {
96           return Expr<T>{true};
97         }
98       }
99     }
100   } else if (name == "merge") {
101     return FoldMerge<T>(context, std::move(funcRef));
102   }
103   // TODO: btest, cshift, dot_product, eoshift, is_iostat_end,
104   // is_iostat_eor, lge, lgt, lle, llt, logical, matmul, out_of_range,
105   // pack, parity, reduce, spread, transfer, transpose, unpack,
106   // extends_type_of, same_type_as
  // Not folded: return the original call unchanged.
107   return Expr<T>{std::move(funcRef)};
108 }
109 
110 template <typename T>
111 Expr<LogicalResult> FoldOperation(
112     FoldingContext &context, Relational<T> &&relation) {
113   if (auto array{ApplyElementwise(context, relation,
114           std::function<Expr<LogicalResult>(Expr<T> &&, Expr<T> &&)>{
115               [=](Expr<T> &&x, Expr<T> &&y) {
116                 return Expr<LogicalResult>{Relational<SomeType>{
117                     Relational<T>{relation.opr, std::move(x), std::move(y)}}};
118               }})}) {
119     return *array;
120   }
121   if (auto folded{OperandsAreConstants(relation)}) {
122     bool result{};
123     if constexpr (T::category == TypeCategory::Integer) {
124       result =
125           Satisfies(relation.opr, folded->first.CompareSigned(folded->second));
126     } else if constexpr (T::category == TypeCategory::Real) {
127       result = Satisfies(relation.opr, folded->first.Compare(folded->second));
128     } else if constexpr (T::category == TypeCategory::Character) {
129       result = Satisfies(relation.opr, Compare(folded->first, folded->second));
130     } else {
131       static_assert(T::category != TypeCategory::Complex &&
132           T::category != TypeCategory::Logical);
133     }
134     return Expr<LogicalResult>{Constant<LogicalResult>{result}};
135   }
136   return Expr<LogicalResult>{Relational<SomeType>{std::move(relation)}};
137 }
138 
139 Expr<LogicalResult> FoldOperation(
140     FoldingContext &context, Relational<SomeType> &&relation) {
141   return std::visit(
142       [&](auto &&x) {
143         return Expr<LogicalResult>{FoldOperation(context, std::move(x))};
144       },
145       std::move(relation.u));
146 }
147 
148 template <int KIND>
149 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
150     FoldingContext &context, Not<KIND> &&x) {
151   if (auto array{ApplyElementwise(context, x)}) {
152     return *array;
153   }
154   using Ty = Type<TypeCategory::Logical, KIND>;
155   auto &operand{x.left()};
156   if (auto value{GetScalarConstantValue<Ty>(operand)}) {
157     return Expr<Ty>{Constant<Ty>{!value->IsTrue()}};
158   }
159   return Expr<Ty>{x};
160 }
161 
162 template <int KIND>
163 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
164     FoldingContext &context, LogicalOperation<KIND> &&operation) {
165   using LOGICAL = Type<TypeCategory::Logical, KIND>;
166   if (auto array{ApplyElementwise(context, operation,
167           std::function<Expr<LOGICAL>(Expr<LOGICAL> &&, Expr<LOGICAL> &&)>{
168               [=](Expr<LOGICAL> &&x, Expr<LOGICAL> &&y) {
169                 return Expr<LOGICAL>{LogicalOperation<KIND>{
170                     operation.logicalOperator, std::move(x), std::move(y)}};
171               }})}) {
172     return *array;
173   }
174   if (auto folded{OperandsAreConstants(operation)}) {
175     bool xt{folded->first.IsTrue()}, yt{folded->second.IsTrue()}, result{};
176     switch (operation.logicalOperator) {
177     case LogicalOperator::And:
178       result = xt && yt;
179       break;
180     case LogicalOperator::Or:
181       result = xt || yt;
182       break;
183     case LogicalOperator::Eqv:
184       result = xt == yt;
185       break;
186     case LogicalOperator::Neqv:
187       result = xt != yt;
188       break;
189     case LogicalOperator::Not:
190       DIE("not a binary operator");
191     }
192     return Expr<LOGICAL>{Constant<LOGICAL>{result}};
193   }
194   return Expr<LOGICAL>{std::move(operation)};
195 }
196 
197 FOR_EACH_LOGICAL_KIND(template class ExpressionBase, )
198 template class ExpressionBase<SomeLogical>;
199 } // namespace Fortran::evaluate
200