xref: /llvm-project/flang/lib/Evaluate/fold-designator.cpp (revision 9696355484152eda5684e0ec6249f4c423f08e42)
//===-- lib/Evaluate/fold-designator.cpp ------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Evaluate/fold-designator.h"
10 #include "flang/Semantics/tools.h"
11 
12 namespace Fortran::evaluate {
13 
// Emits out-of-line special member definitions for OffsetSymbol via the
// project macro (presumably the defaulted copy/move constructors and
// assignment operators, judging by the macro's name -- confirm against the
// macro's definition).
DEFINE_DEFAULT_CONSTRUCTORS_AND_ASSIGNMENTS(OffsetSymbol)
15 
16 std::optional<OffsetSymbol> DesignatorFolder::FoldDesignator(
17     const Symbol &symbol, ConstantSubscript which) {
18   if (!getLastComponent_ && IsAllocatableOrPointer(symbol)) {
19     // A pointer may appear as a DATA statement object if it is the
20     // rightmost symbol in a designator and has no subscripts.
21     // An allocatable may appear if its initializer is NULL().
22     if (which > 0) {
23       isEmpty_ = true;
24     } else {
25       return OffsetSymbol{symbol, symbol.size()};
26     }
27   } else if (symbol.has<semantics::ObjectEntityDetails>() &&
28       !IsNamedConstant(symbol)) {
29     if (auto type{DynamicType::From(symbol)}) {
30       if (auto extents{GetConstantExtents(context_, symbol)}) {
31         if (auto bytes{ToInt64(
32                 type->MeasureSizeInBytes(context_, GetRank(*extents) > 0))}) {
33           OffsetSymbol result{symbol, static_cast<std::size_t>(*bytes)};
34           if (which < GetSize(*extents)) {
35             result.Augment(*bytes * which);
36             return result;
37           } else {
38             isEmpty_ = true;
39           }
40         }
41       }
42     }
43   }
44   return std::nullopt;
45 }
46 
47 std::optional<OffsetSymbol> DesignatorFolder::FoldDesignator(
48     const ArrayRef &x, ConstantSubscript which) {
49   const Symbol &array{x.base().GetLastSymbol()};
50   if (auto type{DynamicType::From(array)}) {
51     if (auto extents{GetConstantExtents(context_, array)}) {
52       if (auto bytes{ToInt64(type->MeasureSizeInBytes(context_, true))}) {
53         Shape lbs{GetLBOUNDs(context_, x.base())};
54         if (auto lowerBounds{AsConstantExtents(context_, lbs)}) {
55           std::optional<OffsetSymbol> result;
56           if (!x.base().IsSymbol() &&
57               x.base().GetComponent().base().Rank() > 0) {
58             // A(:)%B(1) - apply elementNumber_ to base
59             result = FoldDesignator(x.base(), which);
60             which = 0;
61           } else { // A(1)%B(:) - apply elementNumber_ to subscripts
62             result = FoldDesignator(x.base(), 0);
63           }
64           if (!result) {
65             return std::nullopt;
66           }
67           auto stride{*bytes};
68           int dim{0};
69           for (const Subscript &subscript : x.subscript()) {
70             ConstantSubscript lower{lowerBounds->at(dim)};
71             ConstantSubscript extent{extents->at(dim)};
72             ConstantSubscript upper{lower + extent - 1};
73             if (!common::visit(
74                     common::visitors{
75                         [&](const IndirectSubscriptIntegerExpr &expr) {
76                           auto folded{
77                               Fold(context_, common::Clone(expr.value()))};
78                           if (auto value{UnwrapConstantValue<SubscriptInteger>(
79                                   folded)}) {
80                             CHECK(value->Rank() <= 1);
81                             if (value->size() != 0) {
82                               // Apply subscript, possibly vector-valued
83                               auto quotient{which / value->size()};
84                               auto remainder{which - value->size() * quotient};
85                               ConstantSubscript at{
86                                   value->values().at(remainder).ToInt64()};
87                               if (at < lower || at > upper) {
88                                 isOutOfRange_ = true;
89                               }
90                               result->Augment((at - lower) * stride);
91                               which = quotient;
92                               return true;
93                             } else {
94                               isEmpty_ = true;
95                             }
96                           }
97                           return false;
98                         },
99                         [&](const Triplet &triplet) {
100                           auto start{ToInt64(Fold(context_,
101                               triplet.lower().value_or(ExtentExpr{lower})))};
102                           auto end{ToInt64(Fold(context_,
103                               triplet.upper().value_or(ExtentExpr{upper})))};
104                           auto step{ToInt64(Fold(context_, triplet.stride()))};
105                           if (start && end && step) {
106                             if (*step != 0) {
107                               ConstantSubscript range{
108                                   (*end - *start + *step) / *step};
109                               if (range > 0) {
110                                 auto quotient{which / range};
111                                 auto remainder{which - range * quotient};
112                                 auto j{*start + remainder * *step};
113                                 result->Augment((j - lower) * stride);
114                                 which = quotient;
115                                 return true;
116                               } else {
117                                 isEmpty_ = true;
118                               }
119                             }
120                           }
121                           return false;
122                         },
123                     },
124                     subscript.u)) {
125               return std::nullopt;
126             }
127             ++dim;
128             stride *= extent;
129           }
130           if (which > 0) {
131             isEmpty_ = true;
132           } else {
133             return result;
134           }
135         }
136       }
137     }
138   }
139   return std::nullopt;
140 }
141 
142 std::optional<OffsetSymbol> DesignatorFolder::FoldDesignator(
143     const Component &component, ConstantSubscript which) {
144   const Symbol &comp{component.GetLastSymbol()};
145   if (getLastComponent_) {
146     return FoldDesignator(comp, which);
147   } else {
148     const DataRef &base{component.base()};
149     std::optional<OffsetSymbol> baseResult, compResult;
150     if (base.Rank() == 0) { // A%X(:) - apply "which" to component
151       baseResult = FoldDesignator(base, 0);
152       compResult = FoldDesignator(comp, which);
153     } else { // A(:)%X - apply "which" to base
154       baseResult = FoldDesignator(base, which);
155       compResult = FoldDesignator(comp, 0);
156     }
157     if (baseResult && compResult) {
158       OffsetSymbol result{baseResult->symbol(), compResult->size()};
159       result.Augment(
160           baseResult->offset() + compResult->offset() + comp.offset());
161       return {std::move(result)};
162     } else {
163       return std::nullopt;
164     }
165   }
166 }
167 
168 std::optional<OffsetSymbol> DesignatorFolder::FoldDesignator(
169     const ComplexPart &z, ConstantSubscript which) {
170   if (auto result{FoldDesignator(z.complex(), which)}) {
171     result->set_size(result->size() >> 1);
172     if (z.part() == ComplexPart::Part::IM) {
173       result->Augment(result->size());
174     }
175     return result;
176   } else {
177     return std::nullopt;
178   }
179 }
180 
181 std::optional<OffsetSymbol> DesignatorFolder::FoldDesignator(
182     const DataRef &dataRef, ConstantSubscript which) {
183   return common::visit(
184       [&](const auto &x) { return FoldDesignator(x, which); }, dataRef.u);
185 }
186 
187 std::optional<OffsetSymbol> DesignatorFolder::FoldDesignator(
188     const NamedEntity &entity, ConstantSubscript which) {
189   return entity.IsSymbol() ? FoldDesignator(entity.GetLastSymbol(), which)
190                            : FoldDesignator(entity.GetComponent(), which);
191 }
192 
// Coarray references are not folded: this overload ignores both of its
// arguments and always reports failure with std::nullopt.
std::optional<OffsetSymbol> DesignatorFolder::FoldDesignator(
    const CoarrayRef &, ConstantSubscript) {
  return std::nullopt;
}
197 
198 std::optional<OffsetSymbol> DesignatorFolder::FoldDesignator(
199     const ProcedureDesignator &proc, ConstantSubscript which) {
200   if (const Symbol * symbol{proc.GetSymbol()}) {
201     if (const Component * component{proc.GetComponent()}) {
202       return FoldDesignator(*component, which);
203     } else if (which > 0) {
204       isEmpty_ = true;
205     } else {
206       return FoldDesignator(*symbol, 0);
207     }
208   }
209   return std::nullopt;
210 }
211 
212 // Conversions of offset symbols (back) to Designators
213 
214 // Reconstructs subscripts.
215 // "offset" is decremented in place to hold remaining component offset.
216 static std::optional<ArrayRef> OffsetToArrayRef(FoldingContext &context,
217     NamedEntity &&entity, const Shape &shape, const DynamicType &elementType,
218     ConstantSubscript &offset) {
219   auto extents{AsConstantExtents(context, shape)};
220   Shape lbs{GetRawLowerBounds(context, entity)};
221   auto lower{AsConstantExtents(context, lbs)};
222   auto elementBytes{ToInt64(elementType.MeasureSizeInBytes(context, true))};
223   if (!extents || HasNegativeExtent(*extents) || !lower || !elementBytes ||
224       *elementBytes <= 0) {
225     return std::nullopt;
226   }
227   int rank{GetRank(shape)};
228   CHECK(extents->size() == static_cast<std::size_t>(rank) &&
229       lower->size() == extents->size());
230   auto element{offset / static_cast<std::size_t>(*elementBytes)};
231   std::vector<Subscript> subscripts;
232   auto at{element};
233   for (int dim{0}; dim + 1 < rank; ++dim) {
234     auto extent{(*extents)[dim]};
235     if (extent <= 0) {
236       return std::nullopt;
237     }
238     auto quotient{at / extent};
239     auto remainder{at - quotient * extent};
240     subscripts.emplace_back(ExtentExpr{(*lower)[dim] + remainder});
241     at = quotient;
242   }
243   // This final subscript might be out of range for use in error reporting.
244   subscripts.emplace_back(ExtentExpr{(*lower)[rank - 1] + at});
245   offset -= element * static_cast<std::size_t>(*elementBytes);
246   return ArrayRef{std::move(entity), std::move(subscripts)};
247 }
248 
249 // Maps an offset back to a component, when unambiguous.
250 static const Symbol *OffsetToUniqueComponent(
251     const semantics::DerivedTypeSpec &spec, ConstantSubscript offset) {
252   const Symbol *result{nullptr};
253   if (const semantics::Scope * scope{spec.scope()}) {
254     for (const auto &pair : *scope) {
255       const Symbol &component{*pair.second};
256       if (offset >= static_cast<ConstantSubscript>(component.offset()) &&
257           offset < static_cast<ConstantSubscript>(
258                        component.offset() + component.size())) {
259         if (result) {
260           return nullptr; // MAP overlap or error recovery
261         }
262         result = &component;
263       }
264     }
265   }
266   return result;
267 }
268 
269 // Converts an offset into subscripts &/or component references.  Recursive.
270 // Any remaining offset is left in place in the "offset" reference argument.
271 static std::optional<DataRef> OffsetToDataRef(FoldingContext &context,
272     NamedEntity &&entity, ConstantSubscript &offset, std::size_t size) {
273   const Symbol &symbol{entity.GetLastSymbol()};
274   if (IsAllocatableOrPointer(symbol)) {
275     return entity.IsSymbol() ? DataRef{symbol}
276                              : DataRef{std::move(entity.GetComponent())};
277   } else if (std::optional<DynamicType> type{DynamicType::From(symbol)}) {
278     std::optional<DataRef> result;
279     if (!type->IsUnlimitedPolymorphic()) {
280       if (std::optional<Shape> shape{GetShape(context, symbol)}) {
281         if (GetRank(*shape) > 0) {
282           if (auto aref{OffsetToArrayRef(
283                   context, std::move(entity), *shape, *type, offset)}) {
284             result = DataRef{std::move(*aref)};
285           }
286         } else {
287           result = entity.IsSymbol()
288               ? DataRef{symbol}
289               : DataRef{std::move(entity.GetComponent())};
290         }
291         if (result && type->category() == TypeCategory::Derived &&
292             size <= result->GetLastSymbol().size()) {
293           if (const Symbol *
294               component{OffsetToUniqueComponent(
295                   type->GetDerivedTypeSpec(), offset)}) {
296             offset -= component->offset();
297             return OffsetToDataRef(context,
298                 NamedEntity{Component{std::move(*result), *component}}, offset,
299                 size);
300           }
301         }
302       }
303     }
304     return result;
305   } else {
306     return std::nullopt;
307   }
308 }
309 
310 // Reconstructs a Designator from a symbol, an offset, and a size.
311 // Returns a ProcedureDesignator in the case of a whole procedure pointer.
312 std::optional<Expr<SomeType>> OffsetToDesignator(FoldingContext &context,
313     const Symbol &baseSymbol, ConstantSubscript offset, std::size_t size) {
314   if (offset < 0) {
315     return std::nullopt;
316   } else if (std::optional<DataRef> dataRef{OffsetToDataRef(
317                  context, NamedEntity{baseSymbol}, offset, size)}) {
318     const Symbol &symbol{dataRef->GetLastSymbol()};
319     if (IsProcedurePointer(symbol)) {
320       if (std::holds_alternative<SymbolRef>(dataRef->u)) {
321         return Expr<SomeType>{ProcedureDesignator{symbol}};
322       } else if (auto *component{std::get_if<Component>(&dataRef->u)}) {
323         return Expr<SomeType>{ProcedureDesignator{std::move(*component)}};
324       }
325     } else if (std::optional<Expr<SomeType>> result{
326                    AsGenericExpr(std::move(*dataRef))}) {
327       if (IsAllocatableOrPointer(symbol)) {
328       } else if (auto type{DynamicType::From(symbol)}) {
329         if (auto elementBytes{
330                 ToInt64(type->MeasureSizeInBytes(context, true))}) {
331           if (auto *zExpr{std::get_if<Expr<SomeComplex>>(&result->u)}) {
332             if (size * 2 > static_cast<std::size_t>(*elementBytes)) {
333               return result;
334             } else if (offset == 0 || offset * 2 == *elementBytes) {
335               // Pick a COMPLEX component
336               auto part{
337                   offset == 0 ? ComplexPart::Part::RE : ComplexPart::Part::IM};
338               return common::visit(
339                   [&](const auto &z) -> std::optional<Expr<SomeType>> {
340                     using PartType = typename ResultType<decltype(z)>::Part;
341                     return AsGenericExpr(Designator<PartType>{ComplexPart{
342                         ExtractDataRef(std::move(*zExpr)).value(), part}});
343                   },
344                   zExpr->u);
345             }
346           } else if (auto *cExpr{
347                          std::get_if<Expr<SomeCharacter>>(&result->u)}) {
348             if (offset > 0 || size != static_cast<std::size_t>(*elementBytes)) {
349               // Select a substring
350               return common::visit(
351                   [&](const auto &x) -> std::optional<Expr<SomeType>> {
352                     using T = typename std::decay_t<decltype(x)>::Result;
353                     return AsGenericExpr(Designator<T>{
354                         Substring{ExtractDataRef(std::move(*cExpr)).value(),
355                             std::optional<Expr<SubscriptInteger>>{
356                                 1 + (offset / T::kind)},
357                             std::optional<Expr<SubscriptInteger>>{
358                                 1 + ((offset + size - 1) / T::kind)}}});
359                   },
360                   cExpr->u);
361             }
362           }
363         }
364       }
365       if (offset == 0) {
366         return result;
367       }
368     }
369   }
370   return std::nullopt;
371 }
372 
373 std::optional<Expr<SomeType>> OffsetToDesignator(
374     FoldingContext &context, const OffsetSymbol &offsetSymbol) {
375   return OffsetToDesignator(context, offsetSymbol.symbol(),
376       offsetSymbol.offset(), offsetSymbol.size());
377 }
378 
379 ConstantObjectPointer ConstantObjectPointer::From(
380     FoldingContext &context, const Expr<SomeType> &expr) {
381   auto extents{GetConstantExtents(context, expr)};
382   CHECK(extents);
383   std::optional<uint64_t> optElements{TotalElementCount(*extents)};
384   CHECK(optElements);
385   uint64_t elements{*optElements};
386   CHECK(elements > 0);
387   int rank{GetRank(*extents)};
388   ConstantSubscripts at(rank, 1);
389   ConstantObjectPointer::Dimensions dimensions(rank);
390   for (int j{0}; j < rank; ++j) {
391     dimensions[j].extent = (*extents)[j];
392   }
393   DesignatorFolder designatorFolder{context};
394   const Symbol *symbol{nullptr};
395   ConstantSubscript baseOffset{0};
396   std::size_t elementSize{0};
397   for (std::size_t j{0}; j < elements; ++j) {
398     auto folded{designatorFolder.FoldDesignator(expr)};
399     CHECK(folded);
400     if (j == 0) {
401       symbol = &folded->symbol();
402       baseOffset = folded->offset();
403       elementSize = folded->size();
404     } else {
405       CHECK(symbol == &folded->symbol());
406       CHECK(elementSize == folded->size());
407     }
408     int twoDim{-1};
409     for (int k{0}; k < rank; ++k) {
410       if (at[k] == 2 && twoDim == -1) {
411         twoDim = k;
412       } else if (at[k] != 1) {
413         twoDim = -2;
414       }
415     }
416     if (twoDim >= 0) {
417       // Exactly one subscript is a 2 and the rest are 1.
418       dimensions[twoDim].byteStride = folded->offset() - baseOffset;
419     }
420     ConstantSubscript checkOffset{baseOffset};
421     for (int k{0}; k < rank; ++k) {
422       checkOffset += (at[k] - 1) * dimensions[twoDim].byteStride;
423     }
424     CHECK(checkOffset == folded->offset());
425     CHECK(IncrementSubscripts(at, *extents) == (j + 1 < elements));
426   }
427   CHECK(!designatorFolder.FoldDesignator(expr));
428   return ConstantObjectPointer{
429       DEREF(symbol), elementSize, std::move(dimensions)};
430 }
431 } // namespace Fortran::evaluate
432