xref: /llvm-project/flang/lib/Evaluate/shape.cpp (revision 4fed5959974e4a85504667ce47ef03234dd9aec6)
1 //===-- lib/Evaluate/shape.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Evaluate/shape.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/template.h"
12 #include "flang/Evaluate/characteristics.h"
13 #include "flang/Evaluate/check-expression.h"
14 #include "flang/Evaluate/fold.h"
15 #include "flang/Evaluate/intrinsics.h"
16 #include "flang/Evaluate/tools.h"
17 #include "flang/Evaluate/type.h"
18 #include "flang/Parser/message.h"
19 #include "flang/Semantics/symbol.h"
20 #include <functional>
21 
22 using namespace std::placeholders; // _1, _2, &c. for std::bind()
23 
24 namespace Fortran::evaluate {
25 
26 bool IsImpliedShape(const Symbol &original) {
27   const Symbol &symbol{ResolveAssociations(original)};
28   const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()};
29   return details && symbol.attrs().test(semantics::Attr::PARAMETER) &&
30       details->shape().CanBeImpliedShape();
31 }
32 
33 bool IsExplicitShape(const Symbol &original) {
34   const Symbol &symbol{ResolveAssociations(original)};
35   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
36     const auto &shape{details->shape()};
37     return shape.Rank() == 0 ||
38         shape.IsExplicitShape(); // true when scalar, too
39   } else {
40     return symbol
41         .has<semantics::AssocEntityDetails>(); // exprs have explicit shape
42   }
43 }
44 
45 Shape GetShapeHelper::ConstantShape(const Constant<ExtentType> &arrayConstant) {
46   CHECK(arrayConstant.Rank() == 1);
47   Shape result;
48   std::size_t dimensions{arrayConstant.size()};
49   for (std::size_t j{0}; j < dimensions; ++j) {
50     Scalar<ExtentType> extent{arrayConstant.values().at(j)};
51     result.emplace_back(MaybeExtentExpr{ExtentExpr{std::move(extent)}});
52   }
53   return result;
54 }
55 
56 auto GetShapeHelper::AsShapeResult(ExtentExpr &&arrayExpr) const -> Result {
57   if (context_) {
58     arrayExpr = Fold(*context_, std::move(arrayExpr));
59   }
60   if (const auto *constArray{UnwrapConstantValue<ExtentType>(arrayExpr)}) {
61     return ConstantShape(*constArray);
62   }
63   if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
64     Shape result;
65     for (auto &value : *constructor) {
66       auto *expr{std::get_if<ExtentExpr>(&value.u)};
67       if (expr && expr->Rank() == 0) {
68         result.emplace_back(std::move(*expr));
69       } else {
70         return std::nullopt;
71       }
72     }
73     return result;
74   } else {
75     return std::nullopt;
76   }
77 }
78 
79 Shape GetShapeHelper::CreateShape(int rank, NamedEntity &base) const {
80   Shape shape;
81   for (int dimension{0}; dimension < rank; ++dimension) {
82     shape.emplace_back(GetExtent(base, dimension, invariantOnly_));
83   }
84   return shape;
85 }
86 
87 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) {
88   ArrayConstructorValues<ExtentType> values;
89   for (const auto &dim : shape) {
90     if (dim) {
91       values.Push(common::Clone(*dim));
92     } else {
93       return std::nullopt;
94     }
95   }
96   return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
97 }
98 
99 std::optional<Constant<ExtentType>> AsConstantShape(
100     FoldingContext &context, const Shape &shape) {
101   if (auto shapeArray{AsExtentArrayExpr(shape)}) {
102     auto folded{Fold(context, std::move(*shapeArray))};
103     if (auto *p{UnwrapConstantValue<ExtentType>(folded)}) {
104       return std::move(*p);
105     }
106   }
107   return std::nullopt;
108 }
109 
110 Constant<SubscriptInteger> AsConstantShape(const ConstantSubscripts &shape) {
111   using IntType = Scalar<SubscriptInteger>;
112   std::vector<IntType> result;
113   for (auto dim : shape) {
114     result.emplace_back(dim);
115   }
116   return {std::move(result), ConstantSubscripts{GetRank(shape)}};
117 }
118 
119 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) {
120   ConstantSubscripts result;
121   for (const auto &extent : shape.values()) {
122     result.push_back(extent.ToInt64());
123   }
124   return result;
125 }
126 
127 std::optional<ConstantSubscripts> AsConstantExtents(
128     FoldingContext &context, const Shape &shape) {
129   if (auto shapeConstant{AsConstantShape(context, shape)}) {
130     return AsConstantExtents(*shapeConstant);
131   } else {
132     return std::nullopt;
133   }
134 }
135 
136 Shape AsShape(const ConstantSubscripts &shape) {
137   Shape result;
138   for (const auto &extent : shape) {
139     result.emplace_back(ExtentExpr{extent});
140   }
141   return result;
142 }
143 
144 std::optional<Shape> AsShape(const std::optional<ConstantSubscripts> &shape) {
145   if (shape) {
146     return AsShape(*shape);
147   } else {
148     return std::nullopt;
149   }
150 }
151 
152 Shape Fold(FoldingContext &context, Shape &&shape) {
153   for (auto &dim : shape) {
154     dim = Fold(context, std::move(dim));
155   }
156   return std::move(shape);
157 }
158 
159 std::optional<Shape> Fold(
160     FoldingContext &context, std::optional<Shape> &&shape) {
161   if (shape) {
162     return Fold(context, std::move(*shape));
163   } else {
164     return std::nullopt;
165   }
166 }
167 
168 static ExtentExpr ComputeTripCount(
169     ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
170   ExtentExpr strideCopy{common::Clone(stride)};
171   ExtentExpr span{
172       (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
173       std::move(stride)};
174   return ExtentExpr{
175       Extremum<ExtentType>{Ordering::Greater, std::move(span), ExtentExpr{0}}};
176 }
177 
178 ExtentExpr CountTrips(
179     ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
180   return ComputeTripCount(
181       std::move(lower), std::move(upper), std::move(stride));
182 }
183 
184 ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper,
185     const ExtentExpr &stride) {
186   return ComputeTripCount(
187       common::Clone(lower), common::Clone(upper), common::Clone(stride));
188 }
189 
190 MaybeExtentExpr CountTrips(MaybeExtentExpr &&lower, MaybeExtentExpr &&upper,
191     MaybeExtentExpr &&stride) {
192   std::function<ExtentExpr(ExtentExpr &&, ExtentExpr &&, ExtentExpr &&)> bound{
193       std::bind(ComputeTripCount, _1, _2, _3)};
194   return common::MapOptional(
195       std::move(bound), std::move(lower), std::move(upper), std::move(stride));
196 }
197 
198 MaybeExtentExpr GetSize(Shape &&shape) {
199   ExtentExpr extent{1};
200   for (auto &&dim : std::move(shape)) {
201     if (dim) {
202       extent = std::move(extent) * std::move(*dim);
203     } else {
204       return std::nullopt;
205     }
206   }
207   return extent;
208 }
209 
210 ConstantSubscript GetSize(const ConstantSubscripts &shape) {
211   ConstantSubscript size{1};
212   for (auto dim : shape) {
213     CHECK(dim >= 0);
214     size *= dim;
215   }
216   return size;
217 }
218 
219 bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
220   struct MyVisitor : public AnyTraverse<MyVisitor> {
221     using Base = AnyTraverse<MyVisitor>;
222     MyVisitor() : Base{*this} {}
223     using Base::operator();
224     bool operator()(const ImpliedDoIndex &) { return true; }
225   };
226   return MyVisitor{}(expr);
227 }
228 
// Determines lower bound on a dimension.  This can be other than 1 only
// for a reference to a whole array object or component. (See LBOUND, 16.9.109).
// ASSOCIATE construct entities may require traversal of their referents.
//
// RESULT: ExtentExpr when computing a "raw" lower bound (GetRawLowerBound),
// or MaybeExtentExpr when computing LBOUND (GetLBOUND), where an absent
// value means "unknown".
// LBOUND_SEMANTICS: when false, an explicit declared lower bound is returned
// as-is.  When true, the result must agree with LBOUND, which reports 1 for
// an empty dimension.  So an explicit lower bound is returned only when:
//   - it is 1,
//   - the dimension is shown to be nonempty, or
//   - it is the last dimension of an assumed-size array.
// A dimension shown to be empty yields 1.  Otherwise the result is absent.
template <typename RESULT, bool LBOUND_SEMANTICS>
class GetLowerBoundHelper
    : public Traverse<GetLowerBoundHelper<RESULT, LBOUND_SEMANTICS>, RESULT> {
public:
  using Result = RESULT;
  using Base = Traverse<GetLowerBoundHelper, RESULT>;
  using Base::operator();
  // d: zero-based dimension.
  // context: may be null; when present, it is used to fold the extent
  //   UB-LB+1.
  // invariantOnly: only consulted for the last dimension of an assumed-size
  //   array.  When set, a lower bound that is not scope-invariant is
  //   rejected there.
  explicit GetLowerBoundHelper(
      int d, FoldingContext *context, bool invariantOnly)
      : Base{*this}, dimension_{d}, context_{context},
        invariantOnly_{invariantOnly} {}
  // Anything not handled by a specific overload below has lower bound 1.
  static Result Default() { return Result{1}; }
  static Result Combine(Result &&, Result &&) {
    // Operator results and array references always have lower bounds == 1
    return Result{1};
  }

  // Lower bound of dimension_ of the entity |base|, whose last symbol is
  // |symbol0|.
  // - Object entities use their declared array spec, or a descriptor
  //   inquiry when the bound is not explicit and the object is
  //   descriptor-based.
  // - Associate names are handled according to their SELECT RANK or
  //   ASSOCIATE form.
  // Everything else gets the fallback at the end.
  Result GetLowerBound(const Symbol &symbol0, NamedEntity &&base) const {
    const Symbol &symbol{symbol0.GetUltimate()};
    if (const auto *object{
            symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
      int rank{object->shape().Rank()};
      if (dimension_ < rank) {
        const semantics::ShapeSpec &shapeSpec{object->shape()[dimension_]};
        if (shapeSpec.lbound().isExplicit()) {
          if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) {
            if constexpr (LBOUND_SEMANTICS) {
              // ok: the declared lower bound is known to be what LBOUND
              // reports.
              bool ok{false};
              auto lbValue{ToInt64(*lbound)};
              if (dimension_ == rank - 1 && object->IsAssumedSize()) {
                // last dimension of assumed-size dummy array: don't worry
                // about handling an empty dimension
                ok = !invariantOnly_ || IsScopeInvariantExpr(*lbound);
              } else if (lbValue.value_or(0) == 1) {
                // Lower bound is 1, regardless of extent
                ok = true;
              } else if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
                // If we can't prove that the dimension is nonempty,
                // we must be conservative.
                // TODO: simple symbolic math in expression rewriting to
                // cope with cases like A(J:J)
                if (context_) {
                  // Fold UB - LB + 1.  A constant non-positive extent means
                  // the dimension is empty, so LBOUND is 1.
                  auto extent{ToInt64(Fold(*context_,
                      ExtentExpr{*ubound} - ExtentExpr{*lbound} +
                          ExtentExpr{1}))};
                  if (extent) {
                    if (extent <= 0) {
                      return Result{1};
                    }
                    ok = true;
                  } else {
                    ok = false;
                  }
                } else {
                  // No folding context: only bounds that are already
                  // constants can be compared.
                  auto ubValue{ToInt64(*ubound)};
                  if (lbValue && ubValue) {
                    if (*lbValue > *ubValue) {
                      return Result{1};
                    }
                    ok = true;
                  } else {
                    ok = false;
                  }
                }
              }
              // Unproven: report "unknown" rather than a possibly wrong
              // bound.
              return ok ? *lbound : Result{};
            } else {
              return *lbound;
            }
          } else {
            // NOTE(review): the bound is marked explicit but carries no
            // expression here; 1 is reported.
            return Result{1};
          }
        }
        // No explicit lower bound: a descriptor-based entity records it
        // in its descriptor.
        if (IsDescriptor(symbol)) {
          return ExtentExpr{DescriptorInquiry{std::move(base),
              DescriptorInquiry::Field::LowerBound, dimension_}};
        }
      }
    } else if (const auto *assoc{
                   symbol.detailsIf<semantics::AssocEntityDetails>()}) {
      if (assoc->IsAssumedSize()) { // RANK(*)
        return Result{1};
      } else if (assoc->IsAssumedRank()) { // RANK DEFAULT
      } else if (assoc->rank()) { // RANK(n)
        const Symbol &resolved{ResolveAssociations(symbol)};
        if (IsDescriptor(resolved) && dimension_ < *assoc->rank()) {
          return ExtentExpr{DescriptorInquiry{std::move(base),
              DescriptorInquiry::Field::LowerBound, dimension_}};
        }
      } else {
        Result exprLowerBound{((*this)(assoc->expr()))};
        if (IsActuallyConstant(exprLowerBound)) {
          return std::move(exprLowerBound);
        } else {
          // If the lower bound of the associated entity is not resolved to
          // constant expression at the time of the association, it is unsafe
          // to re-evaluate it later in the associate construct. Statements
          // in-between may have modified its operands value.
          return ExtentExpr{DescriptorInquiry{std::move(base),
              DescriptorInquiry::Field::LowerBound, dimension_}};
        }
      }
    }
    // Fallback: unknown under LBOUND semantics, otherwise 1.
    if constexpr (LBOUND_SEMANTICS) {
      return Result{};
    } else {
      return Result{1};
    }
  }

  // A named symbol: its own declared or descriptor lower bound.
  Result operator()(const Symbol &symbol) const {
    return GetLowerBound(symbol, NamedEntity{symbol});
  }

  // A component of a scalar base has the component's declared bounds.  A
  // component of an array base (base rank > 0) is reported as lower bound 1.
  Result operator()(const Component &component) const {
    if (component.base().Rank() == 0) {
      return GetLowerBound(
          component.GetLastSymbol(), NamedEntity{common::Clone(component)});
    }
    return Result{1};
  }

  // An expression has these lower bounds:
  // - a whole symbol or component data reference: that of the symbol
  //   returned by UnwrapWholeSymbolOrComponentDataRef;
  // - a constant: the constant's own lower bounds;
  // - any other (operation) alternative: 1.
  // Category-level Expr types without a Constant alternative are
  // traversed further.
  template <typename T> Result operator()(const Expr<T> &expr) const {
    if (const Symbol * whole{UnwrapWholeSymbolOrComponentDataRef(expr)}) {
      return (*this)(*whole);
    } else if constexpr (common::HasMember<Constant<T>, decltype(expr.u)>) {
      if (const auto *con{std::get_if<Constant<T>>(&expr.u)}) {
        ConstantSubscripts lb{con->lbounds()};
        if (dimension_ < GetRank(lb)) {
          return Result{lb[dimension_]};
        }
      } else { // operation
        return Result{1};
      }
    } else {
      return (*this)(expr.u);
    }
    // Constant of insufficient rank: unknown under LBOUND semantics, else 1.
    if constexpr (LBOUND_SEMANTICS) {
      return Result{};
    } else {
      return Result{1};
    }
  }

private:
  int dimension_; // zero-based
  FoldingContext *context_{nullptr};
  bool invariantOnly_{false};
};
381 
382 ExtentExpr GetRawLowerBound(
383     const NamedEntity &base, int dimension, bool invariantOnly) {
384   return GetLowerBoundHelper<ExtentExpr, false>{
385       dimension, nullptr, invariantOnly}(base);
386 }
387 
388 ExtentExpr GetRawLowerBound(FoldingContext &context, const NamedEntity &base,
389     int dimension, bool invariantOnly) {
390   return Fold(context,
391       GetLowerBoundHelper<ExtentExpr, false>{
392           dimension, &context, invariantOnly}(base));
393 }
394 
395 MaybeExtentExpr GetLBOUND(
396     const NamedEntity &base, int dimension, bool invariantOnly) {
397   return GetLowerBoundHelper<MaybeExtentExpr, true>{
398       dimension, nullptr, invariantOnly}(base);
399 }
400 
401 MaybeExtentExpr GetLBOUND(FoldingContext &context, const NamedEntity &base,
402     int dimension, bool invariantOnly) {
403   return Fold(context,
404       GetLowerBoundHelper<MaybeExtentExpr, true>{
405           dimension, &context, invariantOnly}(base));
406 }
407 
408 Shape GetRawLowerBounds(const NamedEntity &base, bool invariantOnly) {
409   Shape result;
410   int rank{base.Rank()};
411   for (int dim{0}; dim < rank; ++dim) {
412     result.emplace_back(GetRawLowerBound(base, dim, invariantOnly));
413   }
414   return result;
415 }
416 
417 Shape GetRawLowerBounds(
418     FoldingContext &context, const NamedEntity &base, bool invariantOnly) {
419   Shape result;
420   int rank{base.Rank()};
421   for (int dim{0}; dim < rank; ++dim) {
422     result.emplace_back(GetRawLowerBound(context, base, dim, invariantOnly));
423   }
424   return result;
425 }
426 
427 Shape GetLBOUNDs(const NamedEntity &base, bool invariantOnly) {
428   Shape result;
429   int rank{base.Rank()};
430   for (int dim{0}; dim < rank; ++dim) {
431     result.emplace_back(GetLBOUND(base, dim, invariantOnly));
432   }
433   return result;
434 }
435 
436 Shape GetLBOUNDs(
437     FoldingContext &context, const NamedEntity &base, bool invariantOnly) {
438   Shape result;
439   int rank{base.Rank()};
440   for (int dim{0}; dim < rank; ++dim) {
441     result.emplace_back(GetLBOUND(context, base, dim, invariantOnly));
442   }
443   return result;
444 }
445 
446 // If the upper and lower bounds are constant, return a constant expression for
447 // the extent.  In particular, if the upper bound is less than the lower bound,
448 // return zero.
449 static MaybeExtentExpr GetNonNegativeExtent(
450     const semantics::ShapeSpec &shapeSpec, bool invariantOnly) {
451   const auto &ubound{shapeSpec.ubound().GetExplicit()};
452   const auto &lbound{shapeSpec.lbound().GetExplicit()};
453   std::optional<ConstantSubscript> uval{ToInt64(ubound)};
454   std::optional<ConstantSubscript> lval{ToInt64(lbound)};
455   if (uval && lval) {
456     if (*uval < *lval) {
457       return ExtentExpr{0};
458     } else {
459       return ExtentExpr{*uval - *lval + 1};
460     }
461   } else if (lbound && ubound &&
462       (!invariantOnly ||
463           (IsScopeInvariantExpr(*lbound) && IsScopeInvariantExpr(*ubound)))) {
464     // Apply effective IDIM (MAX calculation with 0) so thet the
465     // result is never negative
466     if (lval.value_or(0) == 1) {
467       return ExtentExpr{Extremum<SubscriptInteger>{
468           Ordering::Greater, ExtentExpr{0}, common::Clone(*ubound)}};
469     } else {
470       return ExtentExpr{
471           Extremum<SubscriptInteger>{Ordering::Greater, ExtentExpr{0},
472               common::Clone(*ubound) - common::Clone(*lbound) + ExtentExpr{1}}};
473     }
474   } else {
475     return std::nullopt;
476   }
477 }
478 
479 MaybeExtentExpr GetAssociatedExtent(const NamedEntity &base,
480     const semantics::AssocEntityDetails &assoc, int dimension) {
481   if (auto shape{GetShape(assoc.expr())}) {
482     if (dimension < static_cast<int>(shape->size())) {
483       auto &extent{shape->at(dimension)};
484       if (extent && IsActuallyConstant(*extent)) {
485         return std::move(extent);
486       } else {
487         // Otherwise, evaluating the associated expression extent expression
488         // after the associate statement is unsafe given statements inside the
489         // associate may have modified the associated expression operands
490         // values.
491         return ExtentExpr{DescriptorInquiry{
492             NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
493       }
494     }
495   }
496   return std::nullopt;
497 }
498 
499 MaybeExtentExpr GetExtent(
500     const NamedEntity &base, int dimension, bool invariantOnly) {
501   CHECK(dimension >= 0);
502   const Symbol &last{base.GetLastSymbol()};
503   const Symbol &symbol{ResolveAssociations(last)};
504   if (const auto *assoc{last.detailsIf<semantics::AssocEntityDetails>()}) {
505     if (assoc->IsAssumedSize() || assoc->IsAssumedRank()) { // RANK(*)/DEFAULT
506       return std::nullopt;
507     } else if (assoc->rank()) { // RANK(n)
508       if (semantics::IsDescriptor(symbol) && dimension < *assoc->rank()) {
509         return ExtentExpr{DescriptorInquiry{
510             NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
511       }
512     } else {
513       return GetAssociatedExtent(base, *assoc, dimension);
514     }
515   }
516   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
517     if (IsImpliedShape(symbol) && details->init()) {
518       if (auto shape{GetShape(symbol)}) {
519         if (dimension < static_cast<int>(shape->size())) {
520           return std::move(shape->at(dimension));
521         }
522       }
523     } else {
524       int j{0};
525       for (const auto &shapeSpec : details->shape()) {
526         if (j++ == dimension) {
527           if (auto extent{GetNonNegativeExtent(shapeSpec, invariantOnly)}) {
528             return extent;
529           } else if (details->IsAssumedSize() && j == symbol.Rank()) {
530             return std::nullopt;
531           } else if (semantics::IsDescriptor(symbol)) {
532             return ExtentExpr{DescriptorInquiry{NamedEntity{base},
533                 DescriptorInquiry::Field::Extent, dimension}};
534           } else {
535             break;
536           }
537         }
538       }
539     }
540   }
541   return std::nullopt;
542 }
543 
544 MaybeExtentExpr GetExtent(FoldingContext &context, const NamedEntity &base,
545     int dimension, bool invariantOnly) {
546   return Fold(context, GetExtent(base, dimension, invariantOnly));
547 }
548 
549 MaybeExtentExpr GetExtent(const Subscript &subscript, const NamedEntity &base,
550     int dimension, bool invariantOnly) {
551   return common::visit(
552       common::visitors{
553           [&](const Triplet &triplet) -> MaybeExtentExpr {
554             MaybeExtentExpr upper{triplet.upper()};
555             if (!upper) {
556               upper = GetUBOUND(base, dimension, invariantOnly);
557             }
558             MaybeExtentExpr lower{triplet.lower()};
559             if (!lower) {
560               lower = GetLBOUND(base, dimension, invariantOnly);
561             }
562             return CountTrips(std::move(lower), std::move(upper),
563                 MaybeExtentExpr{triplet.stride()});
564           },
565           [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr {
566             if (auto shape{GetShape(subs.value())}) {
567               if (GetRank(*shape) > 0) {
568                 CHECK(GetRank(*shape) == 1); // vector-valued subscript
569                 return std::move(shape->at(0));
570               }
571             }
572             return std::nullopt;
573           },
574       },
575       subscript.u);
576 }
577 
578 MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
579     const NamedEntity &base, int dimension, bool invariantOnly) {
580   return Fold(context, GetExtent(subscript, base, dimension, invariantOnly));
581 }
582 
583 MaybeExtentExpr ComputeUpperBound(
584     ExtentExpr &&lower, MaybeExtentExpr &&extent) {
585   if (extent) {
586     if (ToInt64(lower).value_or(0) == 1) {
587       return std::move(*extent);
588     } else {
589       return std::move(*extent) + std::move(lower) - ExtentExpr{1};
590     }
591   } else {
592     return std::nullopt;
593   }
594 }
595 
596 MaybeExtentExpr ComputeUpperBound(
597     FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) {
598   return Fold(context, ComputeUpperBound(std::move(lower), std::move(extent)));
599 }
600 
601 MaybeExtentExpr GetRawUpperBound(
602     const NamedEntity &base, int dimension, bool invariantOnly) {
603   const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
604   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
605     int rank{details->shape().Rank()};
606     if (dimension < rank) {
607       const auto &bound{details->shape()[dimension].ubound().GetExplicit()};
608       if (bound && (!invariantOnly || IsScopeInvariantExpr(*bound))) {
609         return *bound;
610       } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
611         return std::nullopt;
612       } else {
613         return ComputeUpperBound(
614             GetRawLowerBound(base, dimension), GetExtent(base, dimension));
615       }
616     }
617   } else if (const auto *assoc{
618                  symbol.detailsIf<semantics::AssocEntityDetails>()}) {
619     if (assoc->IsAssumedSize() || assoc->IsAssumedRank()) {
620       return std::nullopt;
621     } else if (assoc->rank() && dimension >= *assoc->rank()) {
622       return std::nullopt;
623     } else if (auto extent{GetAssociatedExtent(base, *assoc, dimension)}) {
624       return ComputeUpperBound(
625           GetRawLowerBound(base, dimension), std::move(extent));
626     }
627   }
628   return std::nullopt;
629 }
630 
631 MaybeExtentExpr GetRawUpperBound(FoldingContext &context,
632     const NamedEntity &base, int dimension, bool invariantOnly) {
633   return Fold(context, GetRawUpperBound(base, dimension, invariantOnly));
634 }
635 
636 static MaybeExtentExpr GetExplicitUBOUND(FoldingContext *context,
637     const semantics::ShapeSpec &shapeSpec, bool invariantOnly) {
638   const auto &ubound{shapeSpec.ubound().GetExplicit()};
639   if (ubound && (!invariantOnly || IsScopeInvariantExpr(*ubound))) {
640     if (auto extent{GetNonNegativeExtent(shapeSpec, invariantOnly)}) {
641       if (auto cstExtent{ToInt64(
642               context ? Fold(*context, std::move(*extent)) : *extent)}) {
643         if (cstExtent > 0) {
644           return *ubound;
645         } else if (cstExtent == 0) {
646           return ExtentExpr{0};
647         }
648       }
649     }
650   }
651   return std::nullopt;
652 }
653 
654 static MaybeExtentExpr GetUBOUND(FoldingContext *context,
655     const NamedEntity &base, int dimension, bool invariantOnly) {
656   const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
657   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
658     int rank{details->shape().Rank()};
659     if (dimension < rank) {
660       const semantics::ShapeSpec &shapeSpec{details->shape()[dimension]};
661       if (auto ubound{GetExplicitUBOUND(context, shapeSpec, invariantOnly)}) {
662         return *ubound;
663       } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
664         return std::nullopt; // UBOUND() folding replaces with -1
665       } else if (auto lb{GetLBOUND(base, dimension, invariantOnly)}) {
666         return ComputeUpperBound(
667             std::move(*lb), GetExtent(base, dimension, invariantOnly));
668       }
669     }
670   } else if (const auto *assoc{
671                  symbol.detailsIf<semantics::AssocEntityDetails>()}) {
672     if (assoc->IsAssumedSize() || assoc->IsAssumedRank()) {
673       return std::nullopt;
674     } else if (assoc->rank()) { // RANK (n)
675       const Symbol &resolved{ResolveAssociations(symbol)};
676       if (IsDescriptor(resolved) && dimension < *assoc->rank()) {
677         ExtentExpr lb{DescriptorInquiry{NamedEntity{base},
678             DescriptorInquiry::Field::LowerBound, dimension}};
679         ExtentExpr extent{DescriptorInquiry{
680             std::move(base), DescriptorInquiry::Field::Extent, dimension}};
681         return ComputeUpperBound(std::move(lb), std::move(extent));
682       }
683     } else if (assoc->expr()) {
684       if (auto extent{GetAssociatedExtent(base, *assoc, dimension)}) {
685         if (auto lb{GetLBOUND(base, dimension, invariantOnly)}) {
686           return ComputeUpperBound(std::move(*lb), std::move(extent));
687         }
688       }
689     }
690   }
691   return std::nullopt;
692 }
693 
694 MaybeExtentExpr GetUBOUND(
695     const NamedEntity &base, int dimension, bool invariantOnly) {
696   return GetUBOUND(nullptr, base, dimension, invariantOnly);
697 }
698 
699 MaybeExtentExpr GetUBOUND(FoldingContext &context, const NamedEntity &base,
700     int dimension, bool invariantOnly) {
701   return Fold(context, GetUBOUND(&context, base, dimension, invariantOnly));
702 }
703 
704 static Shape GetUBOUNDs(
705     FoldingContext *context, const NamedEntity &base, bool invariantOnly) {
706   Shape result;
707   int rank{base.Rank()};
708   for (int dim{0}; dim < rank; ++dim) {
709     result.emplace_back(GetUBOUND(context, base, dim, invariantOnly));
710   }
711   return result;
712 }
713 
714 Shape GetUBOUNDs(
715     FoldingContext &context, const NamedEntity &base, bool invariantOnly) {
716   return Fold(context, GetUBOUNDs(&context, base, invariantOnly));
717 }
718 
719 Shape GetUBOUNDs(const NamedEntity &base, bool invariantOnly) {
720   return GetUBOUNDs(nullptr, base, invariantOnly);
721 }
722 
// Shape of a symbol, dispatched on the details of its ultimate symbol.
// An empty Result means the shape is unknown; ScalarShape() means rank 0.
auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
  return common::visit(
      common::visitors{
          // Objects:
          // - implied-shape named constants take the shape of their
          //   initializer;
          // - assumed-rank objects have unknown shape;
          // - everything else gets one extent per declared dimension.
          [&](const semantics::ObjectEntityDetails &object) {
            if (IsImpliedShape(symbol) && object.init()) {
              return (*this)(object.init());
            } else if (IsAssumedRank(symbol)) {
              return Result{};
            } else {
              int n{object.shape().Rank()};
              NamedEntity base{symbol};
              return Result{CreateShape(n, base)};
            }
          },
          [](const semantics::EntityDetails &) {
            return ScalarShape(); // no dimensions seen
          },
          // Procedure entities take the shape of their interface, if any.
          [&](const semantics::ProcEntityDetails &proc) {
            if (const Symbol * interface{proc.procInterface()}) {
              return (*this)(*interface);
            } else {
              return ScalarShape();
            }
          },
          // Associate names:
          // - SELECT RANK(n) gets n extents computed on the name itself.
          // - Otherwise the selector's shape is used.  Any non-constant
          //   extent is recomputed via GetExtent on the associate name,
          //   because the selector's extent expressions must not be
          //   re-evaluated later.
          [&](const semantics::AssocEntityDetails &assoc) {
            NamedEntity base{symbol};
            if (assoc.rank()) { // SELECT RANK case
              int n{assoc.rank().value()};
              return Result{CreateShape(n, base)};
            } else {
              auto exprShape{((*this)(assoc.expr()))};
              if (exprShape) {
                int rank{static_cast<int>(exprShape->size())};
                for (int dimension{0}; dimension < rank; ++dimension) {
                  auto &extent{(*exprShape)[dimension]};
                  if (extent && !IsActuallyConstant(*extent)) {
                    extent = GetExtent(base, dimension);
                  }
                }
              }
              return exprShape;
            }
          },
          // Functions have the shape of their result.  Subroutines have
          // no shape (unknown).
          [&](const semantics::SubprogramDetails &subp) -> Result {
            if (subp.isFunction()) {
              auto resultShape{(*this)(subp.result())};
              if (resultShape && !useResultSymbolShape_) {
                // Ensure the shape is constant. Otherwise, it may be referring
                // to symbols that belong to the function's scope and are
                // meaningless on the caller side without the related call
                // expression.
                for (auto &extent : *resultShape) {
                  if (extent && !IsActuallyConstant(*extent)) {
                    extent.reset();
                  }
                }
              }
              return resultShape;
            } else {
              return Result{};
            }
          },
          // Type-bound procedure bindings use the bound procedure's symbol.
          [&](const semantics::ProcBindingDetails &binding) {
            return (*this)(binding.symbol());
          },
          [](const semantics::TypeParamDetails &) { return ScalarShape(); },
          // Any other kind of symbol: unknown shape.
          [](const auto &) { return Result{}; },
      },
      symbol.GetUltimate().details());
}
793 
794 auto GetShapeHelper::operator()(const Component &component) const -> Result {
795   const Symbol &symbol{component.GetLastSymbol()};
796   int rank{symbol.Rank()};
797   if (rank == 0) {
798     return (*this)(component.base());
799   } else if (symbol.has<semantics::ObjectEntityDetails>()) {
800     NamedEntity base{Component{component}};
801     return CreateShape(rank, base);
802   } else if (symbol.has<semantics::AssocEntityDetails>()) {
803     NamedEntity base{Component{component}};
804     return Result{CreateShape(rank, base)};
805   } else {
806     return (*this)(symbol);
807   }
808 }
809 
810 auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
811   Shape shape;
812   int dimension{0};
813   const NamedEntity &base{arrayRef.base()};
814   for (const Subscript &ss : arrayRef.subscript()) {
815     if (ss.Rank() > 0) {
816       shape.emplace_back(GetExtent(ss, base, dimension));
817     }
818     ++dimension;
819   }
820   if (shape.empty()) {
821     if (const Component * component{base.UnwrapComponent()}) {
822       return (*this)(component->base());
823     }
824   }
825   return shape;
826 }
827 
828 auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
829   NamedEntity base{coarrayRef.GetBase()};
830   if (coarrayRef.subscript().empty()) {
831     return (*this)(base);
832   } else {
833     Shape shape;
834     int dimension{0};
835     for (const Subscript &ss : coarrayRef.subscript()) {
836       if (ss.Rank() > 0) {
837         shape.emplace_back(GetExtent(ss, base, dimension));
838       }
839       ++dimension;
840     }
841     return shape;
842   }
843 }
844 
845 auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
846   return (*this)(substring.parent());
847 }
848 
849 auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
850   if (call.Rank() == 0) {
851     return ScalarShape();
852   } else if (call.IsElemental()) {
853     // Use the shape of an actual array argument associated with a
854     // non-OPTIONAL dummy object argument.
855     if (context_) {
856       if (auto chars{characteristics::Procedure::FromActuals(
857               call.proc(), call.arguments(), *context_)}) {
858         std::size_t j{0};
859         std::size_t anyArrayArgRank{0};
860         for (const auto &arg : call.arguments()) {
861           if (arg && arg->Rank() > 0 && j < chars->dummyArguments.size()) {
862             anyArrayArgRank = arg->Rank();
863             if (!chars->dummyArguments[j].IsOptional()) {
864               return (*this)(*arg);
865             }
866           }
867           ++j;
868         }
869         if (anyArrayArgRank) {
870           // All dummy array arguments of the procedure are OPTIONAL.
871           // We cannot take the shape from just any array argument,
872           // because all of them might be OPTIONAL dummy arguments
873           // of the caller. Return unknown shape ranked according
874           // to the last actual array argument.
875           return Shape(anyArrayArgRank, MaybeExtentExpr{});
876         }
877       }
878     }
879     return ScalarShape();
880   } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
881     return (*this)(*symbol);
882   } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) {
883     if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
884         intrinsic->name == "ubound") {
885       // For LBOUND/UBOUND, these are the array-valued cases (no DIM=)
886       if (!call.arguments().empty() && call.arguments().front()) {
887         return Shape{
888             MaybeExtentExpr{ExtentExpr{call.arguments().front()->Rank()}}};
889       }
890     } else if (intrinsic->name == "all" || intrinsic->name == "any" ||
891         intrinsic->name == "count" || intrinsic->name == "iall" ||
892         intrinsic->name == "iany" || intrinsic->name == "iparity" ||
893         intrinsic->name == "maxval" || intrinsic->name == "minval" ||
894         intrinsic->name == "norm2" || intrinsic->name == "parity" ||
895         intrinsic->name == "product" || intrinsic->name == "sum") {
896       // Reduction with DIM=
897       if (call.arguments().size() >= 2) {
898         auto arrayShape{
899             (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
900         const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
901         if (arrayShape && dimArg) {
902           if (auto dim{ToInt64(*dimArg)}) {
903             if (*dim >= 1 &&
904                 static_cast<std::size_t>(*dim) <= arrayShape->size()) {
905               arrayShape->erase(arrayShape->begin() + (*dim - 1));
906               return std::move(*arrayShape);
907             }
908           }
909         }
910       }
911     } else if (intrinsic->name == "findloc" || intrinsic->name == "maxloc" ||
912         intrinsic->name == "minloc") {
913       std::size_t dimIndex{intrinsic->name == "findloc" ? 2u : 1u};
914       if (call.arguments().size() > dimIndex) {
915         if (auto arrayShape{
916                 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))}) {
917           auto rank{static_cast<int>(arrayShape->size())};
918           if (const auto *dimArg{
919                   UnwrapExpr<Expr<SomeType>>(call.arguments()[dimIndex])}) {
920             auto dim{ToInt64(*dimArg)};
921             if (dim && *dim >= 1 && *dim <= rank) {
922               arrayShape->erase(arrayShape->begin() + (*dim - 1));
923               return std::move(*arrayShape);
924             }
925           } else {
926             // xxxLOC(no DIM=) result is vector(1:RANK(ARRAY=))
927             return Shape{ExtentExpr{rank}};
928           }
929         }
930       }
931     } else if (intrinsic->name == "cshift" || intrinsic->name == "eoshift") {
932       if (!call.arguments().empty()) {
933         return (*this)(call.arguments()[0]);
934       }
935     } else if (intrinsic->name == "matmul") {
936       if (call.arguments().size() == 2) {
937         if (auto ashape{(*this)(call.arguments()[0])}) {
938           if (auto bshape{(*this)(call.arguments()[1])}) {
939             if (ashape->size() == 1 && bshape->size() == 2) {
940               bshape->erase(bshape->begin());
941               return std::move(*bshape); // matmul(vector, matrix)
942             } else if (ashape->size() == 2 && bshape->size() == 1) {
943               ashape->pop_back();
944               return std::move(*ashape); // matmul(matrix, vector)
945             } else if (ashape->size() == 2 && bshape->size() == 2) {
946               (*ashape)[1] = std::move((*bshape)[1]);
947               return std::move(*ashape); // matmul(matrix, matrix)
948             }
949           }
950         }
951       }
952     } else if (intrinsic->name == "pack") {
953       if (call.arguments().size() >= 3 && call.arguments().at(2)) {
954         // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v)
955         return (*this)(call.arguments().at(2));
956       } else if (call.arguments().size() >= 2 && context_) {
957         if (auto maskShape{(*this)(call.arguments().at(1))}) {
958           if (maskShape->size() == 0) {
959             // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)]
960             if (auto arrayShape{(*this)(call.arguments().at(0))}) {
961               if (auto arraySize{GetSize(std::move(*arrayShape))}) {
962                 ActualArguments toMerge{
963                     ActualArgument{AsGenericExpr(std::move(*arraySize))},
964                     ActualArgument{AsGenericExpr(ExtentExpr{0})},
965                     common::Clone(call.arguments().at(1))};
966                 auto specific{context_->intrinsics().Probe(
967                     CallCharacteristics{"merge"}, toMerge, *context_)};
968                 CHECK(specific);
969                 return Shape{ExtentExpr{FunctionRef<ExtentType>{
970                     ProcedureDesignator{std::move(specific->specificIntrinsic)},
971                     std::move(specific->arguments)}}};
972               }
973             }
974           } else {
975             // Non-scalar MASK= -> [COUNT(mask)]
976             ActualArguments toCount{ActualArgument{common::Clone(
977                 DEREF(call.arguments().at(1).value().UnwrapExpr()))}};
978             auto specific{context_->intrinsics().Probe(
979                 CallCharacteristics{"count"}, toCount, *context_)};
980             CHECK(specific);
981             return Shape{ExtentExpr{FunctionRef<ExtentType>{
982                 ProcedureDesignator{std::move(specific->specificIntrinsic)},
983                 std::move(specific->arguments)}}};
984           }
985         }
986       }
987     } else if (intrinsic->name == "reshape") {
988       if (call.arguments().size() >= 2 && call.arguments().at(1)) {
989         // SHAPE(RESHAPE(array,shape)) -> shape
990         if (const auto *shapeExpr{
991                 call.arguments().at(1).value().UnwrapExpr()}) {
992           auto shapeArg{std::get<Expr<SomeInteger>>(shapeExpr->u)};
993           if (auto result{AsShapeResult(
994                   ConvertToType<ExtentType>(std::move(shapeArg)))}) {
995             return result;
996           }
997         }
998       }
999     } else if (intrinsic->name == "spread") {
1000       // SHAPE(SPREAD(ARRAY,DIM,NCOPIES)) = SHAPE(ARRAY) with NCOPIES inserted
1001       // at position DIM.
1002       if (call.arguments().size() == 3) {
1003         auto arrayShape{
1004             (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
1005         const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
1006         const auto *nCopies{
1007             UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))};
1008         if (arrayShape && dimArg && nCopies) {
1009           if (auto dim{ToInt64(*dimArg)}) {
1010             if (*dim >= 1 &&
1011                 static_cast<std::size_t>(*dim) <= arrayShape->size() + 1) {
1012               arrayShape->emplace(arrayShape->begin() + *dim - 1,
1013                   ConvertToType<ExtentType>(common::Clone(*nCopies)));
1014               return std::move(*arrayShape);
1015             }
1016           }
1017         }
1018       }
1019     } else if (intrinsic->name == "transfer") {
1020       if (call.arguments().size() == 3 && call.arguments().at(2)) {
1021         // SIZE= is present; shape is vector [SIZE=]
1022         if (const auto *size{
1023                 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}) {
1024           return Shape{
1025               MaybeExtentExpr{ConvertToType<ExtentType>(common::Clone(*size))}};
1026         }
1027       } else if (context_) {
1028         if (auto moldTypeAndShape{characteristics::TypeAndShape::Characterize(
1029                 call.arguments().at(1), *context_)}) {
1030           if (GetRank(moldTypeAndShape->shape()) == 0) {
1031             // SIZE= is absent and MOLD= is scalar: result is scalar
1032             return ScalarShape();
1033           } else {
1034             // SIZE= is absent and MOLD= is array: result is vector whose
1035             // length is determined by sizes of types.  See 16.9.193p4 case(ii).
1036             // Note that if sourceBytes is not known to be empty, we
1037             // can fold only when moldElementBytes is known to not be zero;
1038             // the most general case risks a division by zero otherwise.
1039             if (auto sourceTypeAndShape{
1040                     characteristics::TypeAndShape::Characterize(
1041                         call.arguments().at(0), *context_)}) {
1042               if (auto sourceBytes{
1043                       sourceTypeAndShape->MeasureSizeInBytes(*context_)}) {
1044                 *sourceBytes = Fold(*context_, std::move(*sourceBytes));
1045                 if (auto sourceBytesConst{ToInt64(*sourceBytes)}) {
1046                   if (*sourceBytesConst == 0) {
1047                     return Shape{ExtentExpr{0}};
1048                   }
1049                 }
1050                 if (auto moldElementBytes{
1051                         moldTypeAndShape->MeasureElementSizeInBytes(
1052                             *context_, true)}) {
1053                   *moldElementBytes =
1054                       Fold(*context_, std::move(*moldElementBytes));
1055                   auto moldElementBytesConst{ToInt64(*moldElementBytes)};
1056                   if (moldElementBytesConst && *moldElementBytesConst != 0) {
1057                     ExtentExpr extent{Fold(*context_,
1058                         (std::move(*sourceBytes) +
1059                             common::Clone(*moldElementBytes) - ExtentExpr{1}) /
1060                             common::Clone(*moldElementBytes))};
1061                     return Shape{MaybeExtentExpr{std::move(extent)}};
1062                   }
1063                 }
1064               }
1065             }
1066           }
1067         }
1068       }
1069     } else if (intrinsic->name == "transpose") {
1070       if (call.arguments().size() >= 1) {
1071         if (auto shape{(*this)(call.arguments().at(0))}) {
1072           if (shape->size() == 2) {
1073             std::swap((*shape)[0], (*shape)[1]);
1074             return shape;
1075           }
1076         }
1077       }
1078     } else if (intrinsic->name == "unpack") {
1079       if (call.arguments().size() >= 2) {
1080         return (*this)(call.arguments()[1]); // MASK=
1081       }
1082     } else if (intrinsic->characteristics.value().attrs.test(characteristics::
1083                        Procedure::Attr::NullPointer)) { // NULL(MOLD=)
1084       return (*this)(call.arguments());
1085     } else {
1086       // TODO: shapes of other non-elemental intrinsic results
1087     }
1088   }
1089   // The rank is always known even if the extents are not.
1090   return Shape(static_cast<std::size_t>(call.Rank()), MaybeExtentExpr{});
1091 }
1092 
1093 void GetShapeHelper::AccumulateExtent(
1094     ExtentExpr &result, ExtentExpr &&n) const {
1095   result = std::move(result) + std::move(n);
1096   if (context_) {
1097     // Fold during expression creation to avoid creating an expression so
1098     // large we can't evaluate it without overflowing the stack.
1099     result = Fold(*context_, std::move(result));
1100   }
1101 }
1102 
1103 // Check conformance of the passed shapes.
1104 std::optional<bool> CheckConformance(parser::ContextualMessages &messages,
1105     const Shape &left, const Shape &right, CheckConformanceFlags::Flags flags,
1106     const char *leftIs, const char *rightIs) {
1107   int n{GetRank(left)};
1108   if (n == 0 && (flags & CheckConformanceFlags::LeftScalarExpandable)) {
1109     return true;
1110   }
1111   int rn{GetRank(right)};
1112   if (rn == 0 && (flags & CheckConformanceFlags::RightScalarExpandable)) {
1113     return true;
1114   }
1115   if (n != rn) {
1116     messages.Say("Rank of %1$s is %2$d, but %3$s has rank %4$d"_err_en_US,
1117         leftIs, n, rightIs, rn);
1118     return false;
1119   }
1120   for (int j{0}; j < n; ++j) {
1121     if (auto leftDim{ToInt64(left[j])}) {
1122       if (auto rightDim{ToInt64(right[j])}) {
1123         if (*leftDim != *rightDim) {
1124           messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
1125                        "but %4$s has extent %5$jd"_err_en_US,
1126               j + 1, leftIs, *leftDim, rightIs, *rightDim);
1127           return false;
1128         }
1129       } else if (!(flags & CheckConformanceFlags::RightIsDeferredShape)) {
1130         return std::nullopt;
1131       }
1132     } else if (!(flags & CheckConformanceFlags::LeftIsDeferredShape)) {
1133       return std::nullopt;
1134     }
1135   }
1136   return true;
1137 }
1138 
1139 bool IncrementSubscripts(
1140     ConstantSubscripts &indices, const ConstantSubscripts &extents) {
1141   std::size_t rank(indices.size());
1142   CHECK(rank <= extents.size());
1143   for (std::size_t j{0}; j < rank; ++j) {
1144     if (extents[j] < 1) {
1145       return false;
1146     }
1147   }
1148   for (std::size_t j{0}; j < rank; ++j) {
1149     if (indices[j]++ < extents[j]) {
1150       return true;
1151     }
1152     indices[j] = 1;
1153   }
1154   return false;
1155 }
1156 
1157 } // namespace Fortran::evaluate
1158