xref: /llvm-project/flang/lib/Evaluate/shape.cpp (revision 16c4b320fe9544f9556c0d1d733f5c50f1ba0da3)
1 //===-- lib/Evaluate/shape.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Evaluate/shape.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/template.h"
12 #include "flang/Evaluate/characteristics.h"
13 #include "flang/Evaluate/check-expression.h"
14 #include "flang/Evaluate/fold.h"
15 #include "flang/Evaluate/intrinsics.h"
16 #include "flang/Evaluate/tools.h"
17 #include "flang/Evaluate/type.h"
18 #include "flang/Parser/message.h"
19 #include "flang/Semantics/symbol.h"
20 #include <functional>
21 
22 using namespace std::placeholders; // _1, _2, &c. for std::bind()
23 
24 namespace Fortran::evaluate {
25 
26 bool IsImpliedShape(const Symbol &original) {
27   const Symbol &symbol{ResolveAssociations(original)};
28   const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()};
29   return details && symbol.attrs().test(semantics::Attr::PARAMETER) &&
30       details->shape().CanBeImpliedShape();
31 }
32 
33 bool IsExplicitShape(const Symbol &original) {
34   const Symbol &symbol{ResolveAssociations(original)};
35   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
36     const auto &shape{details->shape()};
37     return shape.Rank() == 0 ||
38         shape.IsExplicitShape(); // true when scalar, too
39   } else {
40     return symbol
41         .has<semantics::AssocEntityDetails>(); // exprs have explicit shape
42   }
43 }
44 
45 Shape GetShapeHelper::ConstantShape(const Constant<ExtentType> &arrayConstant) {
46   CHECK(arrayConstant.Rank() == 1);
47   Shape result;
48   std::size_t dimensions{arrayConstant.size()};
49   for (std::size_t j{0}; j < dimensions; ++j) {
50     Scalar<ExtentType> extent{arrayConstant.values().at(j)};
51     result.emplace_back(MaybeExtentExpr{ExtentExpr{std::move(extent)}});
52   }
53   return result;
54 }
55 
56 auto GetShapeHelper::AsShapeResult(ExtentExpr &&arrayExpr) const -> Result {
57   if (context_) {
58     arrayExpr = Fold(*context_, std::move(arrayExpr));
59   }
60   if (const auto *constArray{UnwrapConstantValue<ExtentType>(arrayExpr)}) {
61     return ConstantShape(*constArray);
62   }
63   if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
64     Shape result;
65     for (auto &value : *constructor) {
66       auto *expr{std::get_if<ExtentExpr>(&value.u)};
67       if (expr && expr->Rank() == 0) {
68         result.emplace_back(std::move(*expr));
69       } else {
70         return std::nullopt;
71       }
72     }
73     return result;
74   } else {
75     return std::nullopt;
76   }
77 }
78 
79 Shape GetShapeHelper::CreateShape(int rank, NamedEntity &base) {
80   Shape shape;
81   for (int dimension{0}; dimension < rank; ++dimension) {
82     shape.emplace_back(GetExtent(base, dimension));
83   }
84   return shape;
85 }
86 
87 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) {
88   ArrayConstructorValues<ExtentType> values;
89   for (const auto &dim : shape) {
90     if (dim) {
91       values.Push(common::Clone(*dim));
92     } else {
93       return std::nullopt;
94     }
95   }
96   return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
97 }
98 
99 std::optional<Constant<ExtentType>> AsConstantShape(
100     FoldingContext &context, const Shape &shape) {
101   if (auto shapeArray{AsExtentArrayExpr(shape)}) {
102     auto folded{Fold(context, std::move(*shapeArray))};
103     if (auto *p{UnwrapConstantValue<ExtentType>(folded)}) {
104       return std::move(*p);
105     }
106   }
107   return std::nullopt;
108 }
109 
110 Constant<SubscriptInteger> AsConstantShape(const ConstantSubscripts &shape) {
111   using IntType = Scalar<SubscriptInteger>;
112   std::vector<IntType> result;
113   for (auto dim : shape) {
114     result.emplace_back(dim);
115   }
116   return {std::move(result), ConstantSubscripts{GetRank(shape)}};
117 }
118 
119 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) {
120   ConstantSubscripts result;
121   for (const auto &extent : shape.values()) {
122     result.push_back(extent.ToInt64());
123   }
124   return result;
125 }
126 
127 std::optional<ConstantSubscripts> AsConstantExtents(
128     FoldingContext &context, const Shape &shape) {
129   if (auto shapeConstant{AsConstantShape(context, shape)}) {
130     return AsConstantExtents(*shapeConstant);
131   } else {
132     return std::nullopt;
133   }
134 }
135 
136 Shape AsShape(const ConstantSubscripts &shape) {
137   Shape result;
138   for (const auto &extent : shape) {
139     result.emplace_back(ExtentExpr{extent});
140   }
141   return result;
142 }
143 
144 std::optional<Shape> AsShape(const std::optional<ConstantSubscripts> &shape) {
145   if (shape) {
146     return AsShape(*shape);
147   } else {
148     return std::nullopt;
149   }
150 }
151 
152 Shape Fold(FoldingContext &context, Shape &&shape) {
153   for (auto &dim : shape) {
154     dim = Fold(context, std::move(dim));
155   }
156   return std::move(shape);
157 }
158 
159 std::optional<Shape> Fold(
160     FoldingContext &context, std::optional<Shape> &&shape) {
161   if (shape) {
162     return Fold(context, std::move(*shape));
163   } else {
164     return std::nullopt;
165   }
166 }
167 
168 static ExtentExpr ComputeTripCount(
169     ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
170   ExtentExpr strideCopy{common::Clone(stride)};
171   ExtentExpr span{
172       (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
173       std::move(stride)};
174   return ExtentExpr{
175       Extremum<ExtentType>{Ordering::Greater, std::move(span), ExtentExpr{0}}};
176 }
177 
178 ExtentExpr CountTrips(
179     ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
180   return ComputeTripCount(
181       std::move(lower), std::move(upper), std::move(stride));
182 }
183 
184 ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper,
185     const ExtentExpr &stride) {
186   return ComputeTripCount(
187       common::Clone(lower), common::Clone(upper), common::Clone(stride));
188 }
189 
190 MaybeExtentExpr CountTrips(MaybeExtentExpr &&lower, MaybeExtentExpr &&upper,
191     MaybeExtentExpr &&stride) {
192   std::function<ExtentExpr(ExtentExpr &&, ExtentExpr &&, ExtentExpr &&)> bound{
193       std::bind(ComputeTripCount, _1, _2, _3)};
194   return common::MapOptional(
195       std::move(bound), std::move(lower), std::move(upper), std::move(stride));
196 }
197 
198 MaybeExtentExpr GetSize(Shape &&shape) {
199   ExtentExpr extent{1};
200   for (auto &&dim : std::move(shape)) {
201     if (dim) {
202       extent = std::move(extent) * std::move(*dim);
203     } else {
204       return std::nullopt;
205     }
206   }
207   return extent;
208 }
209 
210 ConstantSubscript GetSize(const ConstantSubscripts &shape) {
211   ConstantSubscript size{1};
212   for (auto dim : shape) {
213     CHECK(dim >= 0);
214     size *= dim;
215   }
216   return size;
217 }
218 
219 bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
220   struct MyVisitor : public AnyTraverse<MyVisitor> {
221     using Base = AnyTraverse<MyVisitor>;
222     MyVisitor() : Base{*this} {}
223     using Base::operator();
224     bool operator()(const ImpliedDoIndex &) { return true; }
225   };
226   return MyVisitor{}(expr);
227 }
228 
229 // Determines lower bound on a dimension.  This can be other than 1 only
230 // for a reference to a whole array object or component. (See LBOUND, 16.9.109).
231 // ASSOCIATE construct entities may require traversal of their referents.
232 template <typename RESULT, bool LBOUND_SEMANTICS>
233 class GetLowerBoundHelper
234     : public Traverse<GetLowerBoundHelper<RESULT, LBOUND_SEMANTICS>, RESULT> {
235 public:
236   using Result = RESULT;
237   using Base = Traverse<GetLowerBoundHelper, RESULT>;
238   using Base::operator();
239   explicit GetLowerBoundHelper(int d, FoldingContext *context)
240       : Base{*this}, dimension_{d}, context_{context} {}
241   static Result Default() { return Result{1}; }
242   static Result Combine(Result &&, Result &&) {
243     // Operator results and array references always have lower bounds == 1
244     return Result{1};
245   }
246 
247   Result GetLowerBound(const Symbol &symbol0, NamedEntity &&base) const {
248     const Symbol &symbol{symbol0.GetUltimate()};
249     if (const auto *details{
250             symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
251       int rank{details->shape().Rank()};
252       if (dimension_ < rank) {
253         const semantics::ShapeSpec &shapeSpec{details->shape()[dimension_]};
254         if (shapeSpec.lbound().isExplicit()) {
255           if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) {
256             if constexpr (LBOUND_SEMANTICS) {
257               bool ok{false};
258               auto lbValue{ToInt64(*lbound)};
259               if (dimension_ == rank - 1 && details->IsAssumedSize()) {
260                 // last dimension of assumed-size dummy array: don't worry
261                 // about handling an empty dimension
262                 ok = IsScopeInvariantExpr(*lbound);
263               } else if (lbValue.value_or(0) == 1) {
264                 // Lower bound is 1, regardless of extent
265                 ok = true;
266               } else if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
267                 // If we can't prove that the dimension is nonempty,
268                 // we must be conservative.
269                 // TODO: simple symbolic math in expression rewriting to
270                 // cope with cases like A(J:J)
271                 if (context_) {
272                   auto extent{ToInt64(Fold(*context_,
273                       ExtentExpr{*ubound} - ExtentExpr{*lbound} +
274                           ExtentExpr{1}))};
275                   if (extent) {
276                     if (extent <= 0) {
277                       return Result{1};
278                     }
279                     ok = true;
280                   } else {
281                     ok = false;
282                   }
283                 } else {
284                   auto ubValue{ToInt64(*ubound)};
285                   if (lbValue && ubValue) {
286                     if (*lbValue > *ubValue) {
287                       return Result{1};
288                     }
289                     ok = true;
290                   } else {
291                     ok = false;
292                   }
293                 }
294               }
295               return ok ? *lbound : Result{};
296             } else {
297               return *lbound;
298             }
299           } else {
300             return Result{1};
301           }
302         }
303         if (IsDescriptor(symbol)) {
304           return ExtentExpr{DescriptorInquiry{std::move(base),
305               DescriptorInquiry::Field::LowerBound, dimension_}};
306         }
307       }
308     } else if (const auto *assoc{
309                    symbol.detailsIf<semantics::AssocEntityDetails>()}) {
310       if (assoc->rank()) { // SELECT RANK case
311         const Symbol &resolved{ResolveAssociations(symbol)};
312         if (IsDescriptor(resolved) && dimension_ < *assoc->rank()) {
313           return ExtentExpr{DescriptorInquiry{std::move(base),
314               DescriptorInquiry::Field::LowerBound, dimension_}};
315         }
316       } else {
317         Result exprLowerBound{((*this)(assoc->expr()))};
318         if (IsActuallyConstant(exprLowerBound)) {
319           return std::move(exprLowerBound);
320         } else {
321           // If the lower bound of the associated entity is not resolved to
322           // constant expression at the time of the association, it is unsafe
323           // to re-evaluate it later in the associate construct. Statements
324           // in-between may have modified its operands value.
325           return ExtentExpr{DescriptorInquiry{std::move(base),
326               DescriptorInquiry::Field::LowerBound, dimension_}};
327         }
328       }
329     }
330     if constexpr (LBOUND_SEMANTICS) {
331       return Result{};
332     } else {
333       return Result{1};
334     }
335   }
336 
337   Result operator()(const Symbol &symbol) const {
338     return GetLowerBound(symbol, NamedEntity{symbol});
339   }
340 
341   Result operator()(const Component &component) const {
342     if (component.base().Rank() == 0) {
343       return GetLowerBound(
344           component.GetLastSymbol(), NamedEntity{common::Clone(component)});
345     }
346     return Result{1};
347   }
348 
349   template <typename T> Result operator()(const Expr<T> &expr) const {
350     if (const Symbol * whole{UnwrapWholeSymbolOrComponentDataRef(expr)}) {
351       return (*this)(*whole);
352     } else if constexpr (common::HasMember<Constant<T>, decltype(expr.u)>) {
353       if (const auto *con{std::get_if<Constant<T>>(&expr.u)}) {
354         ConstantSubscripts lb{con->lbounds()};
355         if (dimension_ < GetRank(lb)) {
356           return Result{lb[dimension_]};
357         }
358       } else { // operation
359         return Result{1};
360       }
361     } else {
362       return (*this)(expr.u);
363     }
364     if constexpr (LBOUND_SEMANTICS) {
365       return Result{};
366     } else {
367       return Result{1};
368     }
369   }
370 
371 private:
372   int dimension_; // zero-based
373   FoldingContext *context_{nullptr};
374 };
375 
376 ExtentExpr GetRawLowerBound(const NamedEntity &base, int dimension) {
377   return GetLowerBoundHelper<ExtentExpr, false>{dimension, nullptr}(base);
378 }
379 
380 ExtentExpr GetRawLowerBound(
381     FoldingContext &context, const NamedEntity &base, int dimension) {
382   return Fold(context,
383       GetLowerBoundHelper<ExtentExpr, false>{dimension, &context}(base));
384 }
385 
386 MaybeExtentExpr GetLBOUND(const NamedEntity &base, int dimension) {
387   return GetLowerBoundHelper<MaybeExtentExpr, true>{dimension, nullptr}(base);
388 }
389 
390 MaybeExtentExpr GetLBOUND(
391     FoldingContext &context, const NamedEntity &base, int dimension) {
392   return Fold(context,
393       GetLowerBoundHelper<MaybeExtentExpr, true>{dimension, &context}(base));
394 }
395 
396 Shape GetRawLowerBounds(const NamedEntity &base) {
397   Shape result;
398   int rank{base.Rank()};
399   for (int dim{0}; dim < rank; ++dim) {
400     result.emplace_back(GetRawLowerBound(base, dim));
401   }
402   return result;
403 }
404 
405 Shape GetRawLowerBounds(FoldingContext &context, const NamedEntity &base) {
406   Shape result;
407   int rank{base.Rank()};
408   for (int dim{0}; dim < rank; ++dim) {
409     result.emplace_back(GetRawLowerBound(context, base, dim));
410   }
411   return result;
412 }
413 
414 Shape GetLBOUNDs(const NamedEntity &base) {
415   Shape result;
416   int rank{base.Rank()};
417   for (int dim{0}; dim < rank; ++dim) {
418     result.emplace_back(GetLBOUND(base, dim));
419   }
420   return result;
421 }
422 
423 Shape GetLBOUNDs(FoldingContext &context, const NamedEntity &base) {
424   Shape result;
425   int rank{base.Rank()};
426   for (int dim{0}; dim < rank; ++dim) {
427     result.emplace_back(GetLBOUND(context, base, dim));
428   }
429   return result;
430 }
431 
432 // If the upper and lower bounds are constant, return a constant expression for
433 // the extent.  In particular, if the upper bound is less than the lower bound,
434 // return zero.
435 static MaybeExtentExpr GetNonNegativeExtent(
436     const semantics::ShapeSpec &shapeSpec) {
437   const auto &ubound{shapeSpec.ubound().GetExplicit()};
438   const auto &lbound{shapeSpec.lbound().GetExplicit()};
439   std::optional<ConstantSubscript> uval{ToInt64(ubound)};
440   std::optional<ConstantSubscript> lval{ToInt64(lbound)};
441   if (uval && lval) {
442     if (*uval < *lval) {
443       return ExtentExpr{0};
444     } else {
445       return ExtentExpr{*uval - *lval + 1};
446     }
447   } else if (lbound && ubound && IsScopeInvariantExpr(*lbound) &&
448       IsScopeInvariantExpr(*ubound)) {
449     // Apply effective IDIM (MAX calculation with 0) so thet the
450     // result is never negative
451     if (lval.value_or(0) == 1) {
452       return ExtentExpr{Extremum<SubscriptInteger>{
453           Ordering::Greater, ExtentExpr{0}, common::Clone(*ubound)}};
454     } else {
455       return ExtentExpr{
456           Extremum<SubscriptInteger>{Ordering::Greater, ExtentExpr{0},
457               common::Clone(*ubound) - common::Clone(*lbound) + ExtentExpr{1}}};
458     }
459   } else {
460     return std::nullopt;
461   }
462 }
463 
464 MaybeExtentExpr GetAssociatedExtent(const NamedEntity &base,
465     const semantics::AssocEntityDetails &assoc, int dimension) {
466   if (auto shape{GetShape(assoc.expr())}) {
467     if (dimension < static_cast<int>(shape->size())) {
468       auto &extent{shape->at(dimension)};
469       if (extent && IsActuallyConstant(*extent)) {
470         return std::move(extent);
471       } else {
472         // Otherwise, evaluating the associated expression extent expression
473         // after the associate statement is unsafe given statements inside the
474         // associate may have modified the associated expression operands
475         // values.
476         return ExtentExpr{DescriptorInquiry{
477             NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
478       }
479     }
480   }
481   return std::nullopt;
482 }
483 
484 MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) {
485   CHECK(dimension >= 0);
486   const Symbol &last{base.GetLastSymbol()};
487   const Symbol &symbol{ResolveAssociationsExceptSelectRank(last)};
488   if (const auto *assoc{last.detailsIf<semantics::AssocEntityDetails>()}) {
489     if (assoc->rank()) { // SELECT RANK case
490       if (semantics::IsDescriptor(symbol) && dimension < *assoc->rank()) {
491         return ExtentExpr{DescriptorInquiry{
492             NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
493       }
494     } else {
495       return GetAssociatedExtent(base, *assoc, dimension);
496     }
497   }
498   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
499     if (IsImpliedShape(symbol) && details->init()) {
500       if (auto shape{GetShape(symbol)}) {
501         if (dimension < static_cast<int>(shape->size())) {
502           return std::move(shape->at(dimension));
503         }
504       }
505     } else {
506       int j{0};
507       for (const auto &shapeSpec : details->shape()) {
508         if (j++ == dimension) {
509           if (auto extent{GetNonNegativeExtent(shapeSpec)}) {
510             return extent;
511           } else if (details->IsAssumedSize() && j == symbol.Rank()) {
512             return std::nullopt;
513           } else if (semantics::IsDescriptor(symbol)) {
514             return ExtentExpr{DescriptorInquiry{NamedEntity{base},
515                 DescriptorInquiry::Field::Extent, dimension}};
516           } else {
517             break;
518           }
519         }
520       }
521     }
522   }
523   return std::nullopt;
524 }
525 
526 MaybeExtentExpr GetExtent(
527     FoldingContext &context, const NamedEntity &base, int dimension) {
528   return Fold(context, GetExtent(base, dimension));
529 }
530 
531 MaybeExtentExpr GetExtent(
532     const Subscript &subscript, const NamedEntity &base, int dimension) {
533   return common::visit(
534       common::visitors{
535           [&](const Triplet &triplet) -> MaybeExtentExpr {
536             MaybeExtentExpr upper{triplet.upper()};
537             if (!upper) {
538               upper = GetUBOUND(base, dimension);
539             }
540             MaybeExtentExpr lower{triplet.lower()};
541             if (!lower) {
542               lower = GetLBOUND(base, dimension);
543             }
544             return CountTrips(std::move(lower), std::move(upper),
545                 MaybeExtentExpr{triplet.stride()});
546           },
547           [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr {
548             if (auto shape{GetShape(subs.value())}) {
549               if (GetRank(*shape) > 0) {
550                 CHECK(GetRank(*shape) == 1); // vector-valued subscript
551                 return std::move(shape->at(0));
552               }
553             }
554             return std::nullopt;
555           },
556       },
557       subscript.u);
558 }
559 
560 MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
561     const NamedEntity &base, int dimension) {
562   return Fold(context, GetExtent(subscript, base, dimension));
563 }
564 
565 MaybeExtentExpr ComputeUpperBound(
566     ExtentExpr &&lower, MaybeExtentExpr &&extent) {
567   if (extent) {
568     if (ToInt64(lower).value_or(0) == 1) {
569       return std::move(*extent);
570     } else {
571       return std::move(*extent) + std::move(lower) - ExtentExpr{1};
572     }
573   } else {
574     return std::nullopt;
575   }
576 }
577 
578 MaybeExtentExpr ComputeUpperBound(
579     FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) {
580   return Fold(context, ComputeUpperBound(std::move(lower), std::move(extent)));
581 }
582 
583 MaybeExtentExpr GetRawUpperBound(const NamedEntity &base, int dimension) {
584   const Symbol &symbol{
585       ResolveAssociationsExceptSelectRank(base.GetLastSymbol())};
586   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
587     int rank{details->shape().Rank()};
588     if (dimension < rank) {
589       const auto &bound{details->shape()[dimension].ubound().GetExplicit()};
590       if (bound && IsScopeInvariantExpr(*bound)) {
591         return *bound;
592       } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
593         return std::nullopt;
594       } else {
595         return ComputeUpperBound(
596             GetRawLowerBound(base, dimension), GetExtent(base, dimension));
597       }
598     }
599   } else if (const auto *assoc{
600                  symbol.detailsIf<semantics::AssocEntityDetails>()}) {
601     if (auto extent{GetAssociatedExtent(base, *assoc, dimension)}) {
602       return ComputeUpperBound(
603           GetRawLowerBound(base, dimension), std::move(extent));
604     }
605   }
606   return std::nullopt;
607 }
608 
609 MaybeExtentExpr GetRawUpperBound(
610     FoldingContext &context, const NamedEntity &base, int dimension) {
611   return Fold(context, GetRawUpperBound(base, dimension));
612 }
613 
614 static MaybeExtentExpr GetExplicitUBOUND(
615     FoldingContext *context, const semantics::ShapeSpec &shapeSpec) {
616   const auto &ubound{shapeSpec.ubound().GetExplicit()};
617   if (ubound && IsScopeInvariantExpr(*ubound)) {
618     if (auto extent{GetNonNegativeExtent(shapeSpec)}) {
619       if (auto cstExtent{ToInt64(
620               context ? Fold(*context, std::move(*extent)) : *extent)}) {
621         if (cstExtent > 0) {
622           return *ubound;
623         } else if (cstExtent == 0) {
624           return ExtentExpr{0};
625         }
626       }
627     }
628   }
629   return std::nullopt;
630 }
631 
632 static MaybeExtentExpr GetUBOUND(
633     FoldingContext *context, const NamedEntity &base, int dimension) {
634   const Symbol &symbol{
635       ResolveAssociationsExceptSelectRank(base.GetLastSymbol())};
636   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
637     int rank{details->shape().Rank()};
638     if (dimension < rank) {
639       const semantics::ShapeSpec &shapeSpec{details->shape()[dimension]};
640       if (auto ubound{GetExplicitUBOUND(context, shapeSpec)}) {
641         return *ubound;
642       } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
643         return std::nullopt; // UBOUND() folding replaces with -1
644       } else if (auto lb{GetLBOUND(base, dimension)}) {
645         return ComputeUpperBound(std::move(*lb), GetExtent(base, dimension));
646       }
647     }
648   } else if (const auto *assoc{
649                  symbol.detailsIf<semantics::AssocEntityDetails>()}) {
650     if (assoc->rank()) { // SELECT RANK case
651       const Symbol &resolved{ResolveAssociations(symbol)};
652       if (IsDescriptor(resolved) && dimension < *assoc->rank()) {
653         ExtentExpr lb{DescriptorInquiry{NamedEntity{base},
654             DescriptorInquiry::Field::LowerBound, dimension}};
655         ExtentExpr extent{DescriptorInquiry{
656             std::move(base), DescriptorInquiry::Field::Extent, dimension}};
657         return ComputeUpperBound(std::move(lb), std::move(extent));
658       }
659     } else if (assoc->expr()) {
660       if (auto extent{GetAssociatedExtent(base, *assoc, dimension)}) {
661         if (auto lb{GetLBOUND(base, dimension)}) {
662           return ComputeUpperBound(std::move(*lb), std::move(extent));
663         }
664       }
665     }
666   }
667   return std::nullopt;
668 }
669 
670 MaybeExtentExpr GetUBOUND(const NamedEntity &base, int dimension) {
671   return GetUBOUND(nullptr, base, dimension);
672 }
673 
674 MaybeExtentExpr GetUBOUND(
675     FoldingContext &context, const NamedEntity &base, int dimension) {
676   return Fold(context, GetUBOUND(&context, base, dimension));
677 }
678 
679 static Shape GetUBOUNDs(FoldingContext *context, const NamedEntity &base) {
680   Shape result;
681   int rank{base.Rank()};
682   for (int dim{0}; dim < rank; ++dim) {
683     result.emplace_back(GetUBOUND(context, base, dim));
684   }
685   return result;
686 }
687 
688 Shape GetUBOUNDs(FoldingContext &context, const NamedEntity &base) {
689   return Fold(context, GetUBOUNDs(&context, base));
690 }
691 
692 Shape GetUBOUNDs(const NamedEntity &base) { return GetUBOUNDs(nullptr, base); }
693 
694 auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
695   return common::visit(
696       common::visitors{
697           [&](const semantics::ObjectEntityDetails &object) {
698             if (IsImpliedShape(symbol) && object.init()) {
699               return (*this)(object.init());
700             } else if (IsAssumedRank(symbol)) {
701               return Result{};
702             } else {
703               int n{object.shape().Rank()};
704               NamedEntity base{symbol};
705               return Result{CreateShape(n, base)};
706             }
707           },
708           [](const semantics::EntityDetails &) {
709             return ScalarShape(); // no dimensions seen
710           },
711           [&](const semantics::ProcEntityDetails &proc) {
712             if (const Symbol * interface{proc.procInterface()}) {
713               return (*this)(*interface);
714             } else {
715               return ScalarShape();
716             }
717           },
718           [&](const semantics::AssocEntityDetails &assoc) {
719             NamedEntity base{symbol};
720             if (assoc.rank()) { // SELECT RANK case
721               int n{assoc.rank().value()};
722               return Result{CreateShape(n, base)};
723             } else {
724               auto exprShape{((*this)(assoc.expr()))};
725               if (exprShape) {
726                 int rank{static_cast<int>(exprShape->size())};
727                 for (int dimension{0}; dimension < rank; ++dimension) {
728                   auto &extent{(*exprShape)[dimension]};
729                   if (extent && !IsActuallyConstant(*extent)) {
730                     extent = GetExtent(base, dimension);
731                   }
732                 }
733               }
734               return exprShape;
735             }
736           },
737           [&](const semantics::SubprogramDetails &subp) -> Result {
738             if (subp.isFunction()) {
739               auto resultShape{(*this)(subp.result())};
740               if (resultShape && !useResultSymbolShape_) {
741                 // Ensure the shape is constant. Otherwise, it may be referring
742                 // to symbols that belong to the subroutine scope and are
743                 // meaningless on the caller side without the related call
744                 // expression.
745                 for (auto &extent : *resultShape) {
746                   if (extent && !IsActuallyConstant(*extent)) {
747                     extent.reset();
748                   }
749                 }
750               }
751               return resultShape;
752             } else {
753               return Result{};
754             }
755           },
756           [&](const semantics::ProcBindingDetails &binding) {
757             return (*this)(binding.symbol());
758           },
759           [](const semantics::TypeParamDetails &) { return ScalarShape(); },
760           [](const auto &) { return Result{}; },
761       },
762       symbol.GetUltimate().details());
763 }
764 
765 auto GetShapeHelper::operator()(const Component &component) const -> Result {
766   const Symbol &symbol{component.GetLastSymbol()};
767   int rank{symbol.Rank()};
768   if (rank == 0) {
769     return (*this)(component.base());
770   } else if (symbol.has<semantics::ObjectEntityDetails>()) {
771     NamedEntity base{Component{component}};
772     return CreateShape(rank, base);
773   } else if (symbol.has<semantics::AssocEntityDetails>()) {
774     NamedEntity base{Component{component}};
775     return Result{CreateShape(rank, base)};
776   } else {
777     return (*this)(symbol);
778   }
779 }
780 
781 auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
782   Shape shape;
783   int dimension{0};
784   const NamedEntity &base{arrayRef.base()};
785   for (const Subscript &ss : arrayRef.subscript()) {
786     if (ss.Rank() > 0) {
787       shape.emplace_back(GetExtent(ss, base, dimension));
788     }
789     ++dimension;
790   }
791   if (shape.empty()) {
792     if (const Component * component{base.UnwrapComponent()}) {
793       return (*this)(component->base());
794     }
795   }
796   return shape;
797 }
798 
799 auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
800   NamedEntity base{coarrayRef.GetBase()};
801   if (coarrayRef.subscript().empty()) {
802     return (*this)(base);
803   } else {
804     Shape shape;
805     int dimension{0};
806     for (const Subscript &ss : coarrayRef.subscript()) {
807       if (ss.Rank() > 0) {
808         shape.emplace_back(GetExtent(ss, base, dimension));
809       }
810       ++dimension;
811     }
812     return shape;
813   }
814 }
815 
816 auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
817   return (*this)(substring.parent());
818 }
819 
// Computes the shape of the result of a function reference.
//  - A rank-0 call yields a scalar (empty) shape.
//  - An elemental call takes the shape of the first actual array argument
//    whose corresponding dummy argument is not OPTIONAL.
//  - A call through a procedure symbol takes the shape of that symbol.
//  - A call to a specific intrinsic derives the result shape from its
//    arguments for the intrinsics handled explicitly below.
// When no extents can be derived, the result has rank call.Rank() with
// every extent unknown.
auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
  if (call.Rank() == 0) {
    return ScalarShape();
  } else if (call.IsElemental()) {
    // Use the shape of an actual array argument associated with a
    // non-OPTIONAL dummy object argument.
    if (context_) {
      if (auto chars{characteristics::Procedure::FromActuals(
              call.proc(), call.arguments(), *context_)}) {
        // j indexes chars->dummyArguments for the current actual argument
        // (actuals and dummies are matched by position).
        std::size_t j{0};
        std::size_t anyArrayArgRank{0};
        for (const auto &arg : call.arguments()) {
          if (arg && arg->Rank() > 0 && j < chars->dummyArguments.size()) {
            anyArrayArgRank = arg->Rank();
            if (!chars->dummyArguments[j].IsOptional()) {
              return (*this)(*arg);
            }
          }
          ++j;
        }
        if (anyArrayArgRank) {
          // All dummy array arguments of the procedure are OPTIONAL.
          // We cannot take the shape from just any array argument,
          // because all of them might be OPTIONAL dummy arguments
          // of the caller. Return unknown shape ranked according
          // to the last actual array argument.
          return Shape(anyArrayArgRank, MaybeExtentExpr{});
        }
      }
    }
    // NOTE(review): also reached when context_ is null or the procedure's
    // characteristics cannot be determined; a scalar shape is then reported
    // even though call.Rank() > 0 on this path.  Confirm this is intended
    // rather than falling through to the rank-preserving default below.
    return ScalarShape();
  } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
    return (*this)(*symbol);
  } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) {
    if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
        intrinsic->name == "ubound") {
      // For LBOUND/UBOUND, these are the array-valued cases (no DIM=)
      // The result is a vector with one element per dimension of the
      // first argument.
      if (!call.arguments().empty() && call.arguments().front()) {
        return Shape{
            MaybeExtentExpr{ExtentExpr{call.arguments().front()->Rank()}}};
      }
    } else if (intrinsic->name == "all" || intrinsic->name == "any" ||
        intrinsic->name == "count" || intrinsic->name == "iall" ||
        intrinsic->name == "iany" || intrinsic->name == "iparity" ||
        intrinsic->name == "maxval" || intrinsic->name == "minval" ||
        intrinsic->name == "norm2" || intrinsic->name == "parity" ||
        intrinsic->name == "product" || intrinsic->name == "sum") {
      // Reduction with DIM=
      // The result has the first argument's shape with dimension DIM
      // removed, provided DIM= is a constant in 1..rank.
      if (call.arguments().size() >= 2) {
        auto arrayShape{
            (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
        const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
        if (arrayShape && dimArg) {
          if (auto dim{ToInt64(*dimArg)}) {
            if (*dim >= 1 &&
                static_cast<std::size_t>(*dim) <= arrayShape->size()) {
              arrayShape->erase(arrayShape->begin() + (*dim - 1));
              return std::move(*arrayShape);
            }
          }
        }
      }
    } else if (intrinsic->name == "findloc" || intrinsic->name == "maxloc" ||
        intrinsic->name == "minloc") {
      // DIM= is argument 2 (0-based) for FINDLOC and argument 1 for
      // MAXLOC/MINLOC.
      std::size_t dimIndex{intrinsic->name == "findloc" ? 2u : 1u};
      if (call.arguments().size() > dimIndex) {
        if (auto arrayShape{
                (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))}) {
          auto rank{static_cast<int>(arrayShape->size())};
          if (const auto *dimArg{
                  UnwrapExpr<Expr<SomeType>>(call.arguments()[dimIndex])}) {
            // Constant DIM= in range: drop that dimension from ARRAY='s shape.
            auto dim{ToInt64(*dimArg)};
            if (dim && *dim >= 1 && *dim <= rank) {
              arrayShape->erase(arrayShape->begin() + (*dim - 1));
              return std::move(*arrayShape);
            }
          } else {
            // xxxLOC(no DIM=) result is vector(1:RANK(ARRAY=))
            return Shape{ExtentExpr{rank}};
          }
        }
      }
    } else if (intrinsic->name == "cshift" || intrinsic->name == "eoshift") {
      // The result has the shape of the first argument.
      if (!call.arguments().empty()) {
        return (*this)(call.arguments()[0]);
      }
    } else if (intrinsic->name == "matmul") {
      if (call.arguments().size() == 2) {
        if (auto ashape{(*this)(call.arguments()[0])}) {
          if (auto bshape{(*this)(call.arguments()[1])}) {
            if (ashape->size() == 1 && bshape->size() == 2) {
              bshape->erase(bshape->begin());
              return std::move(*bshape); // matmul(vector, matrix)
            } else if (ashape->size() == 2 && bshape->size() == 1) {
              ashape->pop_back();
              return std::move(*ashape); // matmul(matrix, vector)
            } else if (ashape->size() == 2 && bshape->size() == 2) {
              (*ashape)[1] = std::move((*bshape)[1]);
              return std::move(*ashape); // matmul(matrix, matrix)
            }
          }
        }
      }
    } else if (intrinsic->name == "pack") {
      // Each case below yields a single extent: taken from VECTOR= when
      // present, otherwise computed from MASK=.
      if (call.arguments().size() >= 3 && call.arguments().at(2)) {
        // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v)
        return (*this)(call.arguments().at(2));
      } else if (call.arguments().size() >= 2 && context_) {
        if (auto maskShape{(*this)(call.arguments().at(1))}) {
          if (maskShape->size() == 0) {
            // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)]
            if (auto arrayShape{(*this)(call.arguments().at(0))}) {
              if (auto arraySize{GetSize(std::move(*arrayShape))}) {
                ActualArguments toMerge{
                    ActualArgument{AsGenericExpr(std::move(*arraySize))},
                    ActualArgument{AsGenericExpr(ExtentExpr{0})},
                    common::Clone(call.arguments().at(1))};
                auto specific{context_->intrinsics().Probe(
                    CallCharacteristics{"merge"}, toMerge, *context_)};
                CHECK(specific);
                return Shape{ExtentExpr{FunctionRef<ExtentType>{
                    ProcedureDesignator{std::move(specific->specificIntrinsic)},
                    std::move(specific->arguments)}}};
              }
            }
          } else {
            // Non-scalar MASK= -> [COUNT(mask)]
            ActualArguments toCount{ActualArgument{common::Clone(
                DEREF(call.arguments().at(1).value().UnwrapExpr()))}};
            auto specific{context_->intrinsics().Probe(
                CallCharacteristics{"count"}, toCount, *context_)};
            CHECK(specific);
            return Shape{ExtentExpr{FunctionRef<ExtentType>{
                ProcedureDesignator{std::move(specific->specificIntrinsic)},
                std::move(specific->arguments)}}};
          }
        }
      }
    } else if (intrinsic->name == "reshape") {
      if (call.arguments().size() >= 2 && call.arguments().at(1)) {
        // SHAPE(RESHAPE(array,shape)) -> shape
        if (const auto *shapeExpr{
                call.arguments().at(1).value().UnwrapExpr()}) {
          // NOTE(review): std::get assumes SHAPE= is already known to be an
          // INTEGER expression; a non-integer here would fail in std::get.
          auto shapeArg{std::get<Expr<SomeInteger>>(shapeExpr->u)};
          if (auto result{AsShapeResult(
                  ConvertToType<ExtentType>(std::move(shapeArg)))}) {
            return result;
          }
        }
      }
    } else if (intrinsic->name == "spread") {
      // SHAPE(SPREAD(ARRAY,DIM,NCOPIES)) = SHAPE(ARRAY) with NCOPIES inserted
      // at position DIM.
      if (call.arguments().size() == 3) {
        auto arrayShape{
            (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
        const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
        const auto *nCopies{
            UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))};
        if (arrayShape && dimArg && nCopies) {
          if (auto dim{ToInt64(*dimArg)}) {
            // DIM may be one past ARRAY='s rank (append a new last dimension).
            if (*dim >= 1 &&
                static_cast<std::size_t>(*dim) <= arrayShape->size() + 1) {
              arrayShape->emplace(arrayShape->begin() + *dim - 1,
                  ConvertToType<ExtentType>(common::Clone(*nCopies)));
              return std::move(*arrayShape);
            }
          }
        }
      }
    } else if (intrinsic->name == "transfer") {
      if (call.arguments().size() == 3 && call.arguments().at(2)) {
        // SIZE= is present; shape is vector [SIZE=]
        if (const auto *size{
                UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}) {
          return Shape{
              MaybeExtentExpr{ConvertToType<ExtentType>(common::Clone(*size))}};
        }
      } else if (context_) {
        if (auto moldTypeAndShape{characteristics::TypeAndShape::Characterize(
                call.arguments().at(1), *context_)}) {
          if (GetRank(moldTypeAndShape->shape()) == 0) {
            // SIZE= is absent and MOLD= is scalar: result is scalar
            return ScalarShape();
          } else {
            // SIZE= is absent and MOLD= is array: result is vector whose
            // length is determined by sizes of types.  See 16.9.193p4 case(ii).
            // Note that if sourceBytes is not known to be empty, we
            // can fold only when moldElementBytes is known to not be zero;
            // the most general case risks a division by zero otherwise.
            if (auto sourceTypeAndShape{
                    characteristics::TypeAndShape::Characterize(
                        call.arguments().at(0), *context_)}) {
              if (auto sourceBytes{
                      sourceTypeAndShape->MeasureSizeInBytes(*context_)}) {
                *sourceBytes = Fold(*context_, std::move(*sourceBytes));
                // A SOURCE= known to occupy zero bytes gives an empty vector.
                if (auto sourceBytesConst{ToInt64(*sourceBytes)}) {
                  if (*sourceBytesConst == 0) {
                    return Shape{ExtentExpr{0}};
                  }
                }
                if (auto moldElementBytes{
                        moldTypeAndShape->MeasureElementSizeInBytes(
                            *context_, true)}) {
                  *moldElementBytes =
                      Fold(*context_, std::move(*moldElementBytes));
                  auto moldElementBytesConst{ToInt64(*moldElementBytes)};
                  if (moldElementBytesConst && *moldElementBytesConst != 0) {
                    // Ceiling division: (source + mold - 1) / mold.
                    ExtentExpr extent{Fold(*context_,
                        (std::move(*sourceBytes) +
                            common::Clone(*moldElementBytes) - ExtentExpr{1}) /
                            common::Clone(*moldElementBytes))};
                    return Shape{MaybeExtentExpr{std::move(extent)}};
                  }
                }
              }
            }
          }
        }
      }
    } else if (intrinsic->name == "transpose") {
      // Swap the two extents of a rank-2 argument.
      if (call.arguments().size() >= 1) {
        if (auto shape{(*this)(call.arguments().at(0))}) {
          if (shape->size() == 2) {
            std::swap((*shape)[0], (*shape)[1]);
            return shape;
          }
        }
      }
    } else if (intrinsic->name == "unpack") {
      if (call.arguments().size() >= 2) {
        return (*this)(call.arguments()[1]); // MASK=
      }
    } else if (intrinsic->characteristics.value().attrs.test(characteristics::
                       Procedure::Attr::NullPointer)) { // NULL(MOLD=)
      // Delegates to the shape of the actual argument list (MOLD=, if any).
      // NOTE(review): confirm how the ActualArguments overload combines
      // its entries.
      return (*this)(call.arguments());
    } else {
      // TODO: shapes of other non-elemental intrinsic results
    }
  }
  // The rank is always known even if the extents are not.
  return Shape(static_cast<std::size_t>(call.Rank()), MaybeExtentExpr{});
}
1063 
1064 void GetShapeHelper::AccumulateExtent(
1065     ExtentExpr &result, ExtentExpr &&n) const {
1066   result = std::move(result) + std::move(n);
1067   if (context_) {
1068     // Fold during expression creation to avoid creating an expression so
1069     // large we can't evalute it without overflowing the stack.
1070     result = Fold(*context_, std::move(result));
1071   }
1072 }
1073 
1074 // Check conformance of the passed shapes.
1075 std::optional<bool> CheckConformance(parser::ContextualMessages &messages,
1076     const Shape &left, const Shape &right, CheckConformanceFlags::Flags flags,
1077     const char *leftIs, const char *rightIs) {
1078   int n{GetRank(left)};
1079   if (n == 0 && (flags & CheckConformanceFlags::LeftScalarExpandable)) {
1080     return true;
1081   }
1082   int rn{GetRank(right)};
1083   if (rn == 0 && (flags & CheckConformanceFlags::RightScalarExpandable)) {
1084     return true;
1085   }
1086   if (n != rn) {
1087     messages.Say("Rank of %1$s is %2$d, but %3$s has rank %4$d"_err_en_US,
1088         leftIs, n, rightIs, rn);
1089     return false;
1090   }
1091   for (int j{0}; j < n; ++j) {
1092     if (auto leftDim{ToInt64(left[j])}) {
1093       if (auto rightDim{ToInt64(right[j])}) {
1094         if (*leftDim != *rightDim) {
1095           messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
1096                        "but %4$s has extent %5$jd"_err_en_US,
1097               j + 1, leftIs, *leftDim, rightIs, *rightDim);
1098           return false;
1099         }
1100       } else if (!(flags & CheckConformanceFlags::RightIsDeferredShape)) {
1101         return std::nullopt;
1102       }
1103     } else if (!(flags & CheckConformanceFlags::LeftIsDeferredShape)) {
1104       return std::nullopt;
1105     }
1106   }
1107   return true;
1108 }
1109 
1110 bool IncrementSubscripts(
1111     ConstantSubscripts &indices, const ConstantSubscripts &extents) {
1112   std::size_t rank(indices.size());
1113   CHECK(rank <= extents.size());
1114   for (std::size_t j{0}; j < rank; ++j) {
1115     if (extents[j] < 1) {
1116       return false;
1117     }
1118   }
1119   for (std::size_t j{0}; j < rank; ++j) {
1120     if (indices[j]++ < extents[j]) {
1121       return true;
1122     }
1123     indices[j] = 1;
1124   }
1125   return false;
1126 }
1127 
1128 } // namespace Fortran::evaluate
1129