xref: /llvm-project/flang/lib/Evaluate/fold-reduction.cpp (revision 9d084982a38526fb8ba15867ecceec3e9b82a6f8)
1 //===-- lib/Evaluate/fold-reduction.cpp -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "fold-reduction.h"
10 
11 namespace Fortran::evaluate {
CheckReductionDIM(std::optional<int> & dim,FoldingContext & context,ActualArguments & arg,std::optional<int> dimIndex,int rank)12 bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &context,
13     ActualArguments &arg, std::optional<int> dimIndex, int rank) {
14   if (!dimIndex || static_cast<std::size_t>(*dimIndex) >= arg.size() ||
15       !arg[*dimIndex]) {
16     dim.reset();
17     return true; // no DIM= argument
18   }
19   if (auto *dimConst{
20           Folder<SubscriptInteger>{context}.Folding(arg[*dimIndex])}) {
21     if (auto dimScalar{dimConst->GetScalarValue()}) {
22       auto dimVal{dimScalar->ToInt64()};
23       if (dimVal >= 1 && dimVal <= rank) {
24         dim = dimVal;
25         return true; // DIM= exists and is a valid constant
26       } else {
27         context.messages().Say(
28             "DIM=%jd is not valid for an array of rank %d"_err_en_US,
29             static_cast<std::intmax_t>(dimVal), rank);
30       }
31     }
32   }
33   return false; // DIM= bad or not scalar constant
34 }
35 
GetReductionMASK(std::optional<ActualArgument> & maskArg,const ConstantSubscripts & shape,FoldingContext & context)36 Constant<LogicalResult> *GetReductionMASK(
37     std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape,
38     FoldingContext &context) {
39   Constant<LogicalResult> *mask{
40       Folder<LogicalResult>{context}.Folding(maskArg)};
41   if (mask &&
42       !CheckConformance(context.messages(), AsShape(shape),
43           AsShape(mask->shape()), CheckConformanceFlags::RightScalarExpandable,
44           "ARRAY=", "MASK=")
45            .value_or(false)) {
46     mask = nullptr;
47   }
48   return mask;
49 }
50 } // namespace Fortran::evaluate
51