xref: /llvm-project/flang/runtime/transformational.cpp (revision 478e0b58605c4be16f1590f9b67889290ab45dab)
1 //===-- runtime/transformational.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // Implements the transformational intrinsic functions of Fortran 2018 that
10 // rearrange or duplicate data without (much) regard to type.  These are
11 // CSHIFT, EOSHIFT, PACK, RESHAPE, SPREAD, TRANSPOSE, and UNPACK.
12 //
13 // Many of these are defined in the 2018 standard with text that makes sense
14 // only if argument arrays have lower bounds of one.  Rather than interpret
15 // these cases as implying a hidden constraint, these implementations
16 // work with arbitrary lower bounds.  This may be technically an extension
17 // of the standard but it is more likely to conform with its intent.
18 
19 #include "flang/Runtime/transformational.h"
20 #include "copy.h"
21 #include "terminator.h"
22 #include "tools.h"
23 #include "flang/Common/float128.h"
24 #include "flang/Runtime/descriptor.h"
25 
26 namespace Fortran::runtime {
27 
28 // Utility for CSHIFT & EOSHIFT rank > 1 cases that determines the shift count
29 // for each of the vector sections of the result.
30 class ShiftControl {
31 public:
32   RT_API_ATTRS ShiftControl(const Descriptor &s, Terminator &t, int dim)
33       : shift_{s}, terminator_{t}, shiftRank_{s.rank()}, dim_{dim} {}
34   RT_API_ATTRS void Init(const Descriptor &source, const char *which) {
35     int rank{source.rank()};
36     RUNTIME_CHECK(terminator_, shiftRank_ == 0 || shiftRank_ == rank - 1);
37     auto catAndKind{shift_.type().GetCategoryAndKind()};
38     RUNTIME_CHECK(
39         terminator_, catAndKind && catAndKind->first == TypeCategory::Integer);
40     shiftElemLen_ = catAndKind->second;
41     if (shiftRank_ > 0) {
42       int k{0};
43       for (int j{0}; j < rank; ++j) {
44         if (j + 1 != dim_) {
45           const Dimension &shiftDim{shift_.GetDimension(k)};
46           lb_[k++] = shiftDim.LowerBound();
47           if (shiftDim.Extent() != source.GetDimension(j).Extent()) {
48             terminator_.Crash("%s: on dimension %d, SHIFT= has extent %jd but "
49                               "SOURCE= has extent %jd",
50                 which, k, static_cast<std::intmax_t>(shiftDim.Extent()),
51                 static_cast<std::intmax_t>(source.GetDimension(j).Extent()));
52           }
53         }
54       }
55     } else {
56       shiftCount_ =
57           GetInt64(shift_.OffsetElement<char>(), shiftElemLen_, terminator_);
58     }
59   }
60   RT_API_ATTRS SubscriptValue GetShift(const SubscriptValue resultAt[]) const {
61     if (shiftRank_ > 0) {
62       SubscriptValue shiftAt[maxRank];
63       int k{0};
64       for (int j{0}; j < shiftRank_ + 1; ++j) {
65         if (j + 1 != dim_) {
66           shiftAt[k] = lb_[k] + resultAt[j] - 1;
67           ++k;
68         }
69       }
70       return GetInt64(
71           shift_.Element<char>(shiftAt), shiftElemLen_, terminator_);
72     } else {
73       return shiftCount_; // invariant count extracted in Init()
74     }
75   }
76 
77 private:
78   const Descriptor &shift_;
79   Terminator &terminator_;
80   int shiftRank_;
81   int dim_;
82   SubscriptValue lb_[maxRank];
83   std::size_t shiftElemLen_;
84   SubscriptValue shiftCount_{};
85 };
86 
87 // Fill an EOSHIFT result with default boundary values
88 static RT_API_ATTRS void DefaultInitialize(
89     const Descriptor &result, Terminator &terminator) {
90   auto catAndKind{result.type().GetCategoryAndKind()};
91   RUNTIME_CHECK(
92       terminator, catAndKind && catAndKind->first != TypeCategory::Derived);
93   std::size_t elementLen{result.ElementBytes()};
94   std::size_t bytes{result.Elements() * elementLen};
95   if (catAndKind->first == TypeCategory::Character) {
96     switch (int kind{catAndKind->second}) {
97     case 1:
98       Fortran::runtime::fill_n(result.OffsetElement<char>(), bytes, ' ');
99       break;
100     case 2:
101       Fortran::runtime::fill_n(result.OffsetElement<char16_t>(), bytes / 2,
102           static_cast<char16_t>(' '));
103       break;
104     case 4:
105       Fortran::runtime::fill_n(result.OffsetElement<char32_t>(), bytes / 4,
106           static_cast<char32_t>(' '));
107       break;
108     default:
109       terminator.Crash("not yet implemented: EOSHIFT: CHARACTER kind %d", kind);
110     }
111   } else {
112     std::memset(result.raw().base_addr, 0, bytes);
113   }
114 }
115 
116 static inline RT_API_ATTRS std::size_t AllocateResult(Descriptor &result,
117     const Descriptor &source, int rank, const SubscriptValue extent[],
118     Terminator &terminator, const char *function) {
119   std::size_t elementLen{source.ElementBytes()};
120   const DescriptorAddendum *sourceAddendum{source.Addendum()};
121   result.Establish(source.type(), elementLen, nullptr, rank, extent,
122       CFI_attribute_allocatable, sourceAddendum != nullptr);
123   if (sourceAddendum) {
124     *result.Addendum() = *sourceAddendum;
125   }
126   for (int j{0}; j < rank; ++j) {
127     result.GetDimension(j).SetBounds(1, extent[j]);
128   }
129   if (int stat{result.Allocate()}) {
130     terminator.Crash(
131         "%s: Could not allocate memory for result (stat=%d)", function, stat);
132   }
133   return elementLen;
134 }
135 
136 template <TypeCategory CAT, int KIND>
137 static inline RT_API_ATTRS std::size_t AllocateBesselResult(Descriptor &result,
138     int32_t n1, int32_t n2, Terminator &terminator, const char *function) {
139   int rank{1};
140   SubscriptValue extent[maxRank];
141   for (int j{0}; j < maxRank; j++) {
142     extent[j] = 0;
143   }
144   if (n1 <= n2) {
145     extent[0] = n2 - n1 + 1;
146   }
147 
148   std::size_t elementLen{Descriptor::BytesFor(CAT, KIND)};
149   result.Establish(TypeCode{CAT, KIND}, elementLen, nullptr, rank, extent,
150       CFI_attribute_allocatable, false);
151   for (int j{0}; j < rank; ++j) {
152     result.GetDimension(j).SetBounds(1, extent[j]);
153   }
154   if (int stat{result.Allocate()}) {
155     terminator.Crash(
156         "%s: Could not allocate memory for result (stat=%d)", function, stat);
157   }
158   return elementLen;
159 }
160 
161 template <TypeCategory CAT, int KIND>
162 static inline RT_API_ATTRS void DoBesselJn(Descriptor &result, int32_t n1,
163     int32_t n2, CppTypeFor<CAT, KIND> x, CppTypeFor<CAT, KIND> bn2,
164     CppTypeFor<CAT, KIND> bn2_1, const char *sourceFile, int line) {
165   Terminator terminator{sourceFile, line};
166   AllocateBesselResult<CAT, KIND>(result, n1, n2, terminator, "BESSEL_JN");
167 
168   // The standard requires that n1 and n2 be non-negative. However, some other
169   // compilers generate results even when n1 and/or n2 are negative. For now,
170   // we also do not enforce the non-negativity constraint.
171   if (n2 < n1) {
172     return;
173   }
174 
175   SubscriptValue at[maxRank];
176   for (int j{0}; j < maxRank; ++j) {
177     at[j] = 0;
178   }
179 
180   // if n2 >= n1, there will be at least one element in the result.
181   at[0] = n2 - n1 + 1;
182   *result.Element<CppTypeFor<CAT, KIND>>(at) = bn2;
183 
184   if (n2 == n1) {
185     return;
186   }
187 
188   at[0] = n2 - n1;
189   *result.Element<CppTypeFor<CAT, KIND>>(at) = bn2_1;
190 
191   // Bessel functions of the first kind are stable for a backward recursion
192   // (see https://dlmf.nist.gov/10.74.iv and https://dlmf.nist.gov/10.6.E1).
193   //
194   //     J(n-1, x) = (2.0 / x) * n * J(n, x) - J(n+1, x)
195   //
196   // which is equivalent to
197   //
198   //     J(n, x) = (2.0 / x) * (n + 1) * J(n+1, x) - J(n+2, x)
199   //
200   CppTypeFor<CAT, KIND> bn_2 = bn2;
201   CppTypeFor<CAT, KIND> bn_1 = bn2_1;
202   CppTypeFor<CAT, KIND> twoOverX = 2.0 / x;
203   for (int n{n2 - 2}; n >= n1; --n) {
204     auto bn = twoOverX * (n + 1) * bn_1 - bn_2;
205 
206     at[0] = n - n1 + 1;
207     *result.Element<CppTypeFor<CAT, KIND>>(at) = bn;
208 
209     bn_2 = bn_1;
210     bn_1 = bn;
211   }
212 }
213 
214 template <TypeCategory CAT, int KIND>
215 static inline RT_API_ATTRS void DoBesselJnX0(Descriptor &result, int32_t n1,
216     int32_t n2, const char *sourceFile, int line) {
217   Terminator terminator{sourceFile, line};
218   AllocateBesselResult<CAT, KIND>(result, n1, n2, terminator, "BESSEL_JN");
219 
220   // The standard requires that n1 and n2 be non-negative. However, some other
221   // compilers generate results even when n1 and/or n2 are negative. For now,
222   // we also do not enforce the non-negativity constraint.
223   if (n2 < n1) {
224     return;
225   }
226 
227   SubscriptValue at[maxRank];
228   for (int j{0}; j < maxRank; ++j) {
229     at[j] = 0;
230   }
231 
232   // J(0, 0.0) = 1.0, when n == 0.
233   // J(n, 0.0) = 0.0, when n > 0.
234   at[0] = 1;
235   *result.Element<CppTypeFor<CAT, KIND>>(at) = (n1 == 0) ? 1.0 : 0.0;
236   for (int j{2}; j <= n2 - n1 + 1; ++j) {
237     at[0] = j;
238     *result.Element<CppTypeFor<CAT, KIND>>(at) = 0.0;
239   }
240 }
241 
242 template <TypeCategory CAT, int KIND>
243 static inline RT_API_ATTRS void DoBesselYn(Descriptor &result, int32_t n1,
244     int32_t n2, CppTypeFor<CAT, KIND> x, CppTypeFor<CAT, KIND> bn1,
245     CppTypeFor<CAT, KIND> bn1_1, const char *sourceFile, int line) {
246   Terminator terminator{sourceFile, line};
247   AllocateBesselResult<CAT, KIND>(result, n1, n2, terminator, "BESSEL_YN");
248 
249   // The standard requires that n1 and n2 be non-negative. However, some other
250   // compilers generate results even when n1 and/or n2 are negative. For now,
251   // we also do not enforce the non-negativity constraint.
252   if (n2 < n1) {
253     return;
254   }
255 
256   SubscriptValue at[maxRank];
257   for (int j{0}; j < maxRank; ++j) {
258     at[j] = 0;
259   }
260 
261   // if n2 >= n1, there will be at least one element in the result.
262   at[0] = 1;
263   *result.Element<CppTypeFor<CAT, KIND>>(at) = bn1;
264 
265   if (n2 == n1) {
266     return;
267   }
268 
269   at[0] = 2;
270   *result.Element<CppTypeFor<CAT, KIND>>(at) = bn1_1;
271 
272   // Bessel functions of the second kind are stable for a forward recursion
273   // (see https://dlmf.nist.gov/10.74.iv and https://dlmf.nist.gov/10.6.E1).
274   //
275   //     Y(n+1, x) = (2.0 / x) * n * Y(n, x) - Y(n-1, x)
276   //
277   // which is equivalent to
278   //
279   //     Y(n, x) = (2.0 / x) * (n - 1) * Y(n-1, x) - Y(n-2, x)
280   //
281   CppTypeFor<CAT, KIND> bn_2 = bn1;
282   CppTypeFor<CAT, KIND> bn_1 = bn1_1;
283   CppTypeFor<CAT, KIND> twoOverX = 2.0 / x;
284   for (int n{n1 + 2}; n <= n2; ++n) {
285     auto bn = twoOverX * (n - 1) * bn_1 - bn_2;
286 
287     at[0] = n - n1 + 1;
288     *result.Element<CppTypeFor<CAT, KIND>>(at) = bn;
289 
290     bn_2 = bn_1;
291     bn_1 = bn;
292   }
293 }
294 
295 template <TypeCategory CAT, int KIND>
296 static inline RT_API_ATTRS void DoBesselYnX0(Descriptor &result, int32_t n1,
297     int32_t n2, const char *sourceFile, int line) {
298   Terminator terminator{sourceFile, line};
299   AllocateBesselResult<CAT, KIND>(result, n1, n2, terminator, "BESSEL_YN");
300 
301   // The standard requires that n1 and n2 be non-negative. However, some other
302   // compilers generate results even when n1 and/or n2 are negative. For now,
303   // we also do not enforce the non-negativity constraint.
304   if (n2 < n1) {
305     return;
306   }
307 
308   SubscriptValue at[maxRank];
309   for (int j{0}; j < maxRank; ++j) {
310     at[j] = 0;
311   }
312 
313   // Y(n, 0.0) = -Inf, when n >= 0
314   for (int j{1}; j <= n2 - n1 + 1; ++j) {
315     at[0] = j;
316     *result.Element<CppTypeFor<CAT, KIND>>(at) =
317         -std::numeric_limits<CppTypeFor<CAT, KIND>>::infinity();
318   }
319 }
320 
321 extern "C" {
322 RT_EXT_API_GROUP_BEGIN
323 
324 // BESSEL_JN
325 // TODO: REAL(2 & 3)
326 void RTDEF(BesselJn_4)(Descriptor &result, int32_t n1, int32_t n2,
327     CppTypeFor<TypeCategory::Real, 4> x, CppTypeFor<TypeCategory::Real, 4> bn2,
328     CppTypeFor<TypeCategory::Real, 4> bn2_1, const char *sourceFile, int line) {
329   DoBesselJn<TypeCategory::Real, 4>(
330       result, n1, n2, x, bn2, bn2_1, sourceFile, line);
331 }
332 
333 void RTDEF(BesselJn_8)(Descriptor &result, int32_t n1, int32_t n2,
334     CppTypeFor<TypeCategory::Real, 8> x, CppTypeFor<TypeCategory::Real, 8> bn2,
335     CppTypeFor<TypeCategory::Real, 8> bn2_1, const char *sourceFile, int line) {
336   DoBesselJn<TypeCategory::Real, 8>(
337       result, n1, n2, x, bn2, bn2_1, sourceFile, line);
338 }
339 
340 #if LDBL_MANT_DIG == 64
341 void RTDEF(BesselJn_10)(Descriptor &result, int32_t n1, int32_t n2,
342     CppTypeFor<TypeCategory::Real, 10> x,
343     CppTypeFor<TypeCategory::Real, 10> bn2,
344     CppTypeFor<TypeCategory::Real, 10> bn2_1, const char *sourceFile,
345     int line) {
346   DoBesselJn<TypeCategory::Real, 10>(
347       result, n1, n2, x, bn2, bn2_1, sourceFile, line);
348 }
349 #endif
350 
351 #if LDBL_MANT_DIG == 113 || HAS_FLOAT128
352 void RTDEF(BesselJn_16)(Descriptor &result, int32_t n1, int32_t n2,
353     CppTypeFor<TypeCategory::Real, 16> x,
354     CppTypeFor<TypeCategory::Real, 16> bn2,
355     CppTypeFor<TypeCategory::Real, 16> bn2_1, const char *sourceFile,
356     int line) {
357   DoBesselJn<TypeCategory::Real, 16>(
358       result, n1, n2, x, bn2, bn2_1, sourceFile, line);
359 }
360 #endif
361 
362 // TODO: REAL(2 & 3)
363 void RTDEF(BesselJnX0_4)(Descriptor &result, int32_t n1, int32_t n2,
364     const char *sourceFile, int line) {
365   DoBesselJnX0<TypeCategory::Real, 4>(result, n1, n2, sourceFile, line);
366 }
367 
368 void RTDEF(BesselJnX0_8)(Descriptor &result, int32_t n1, int32_t n2,
369     const char *sourceFile, int line) {
370   DoBesselJnX0<TypeCategory::Real, 8>(result, n1, n2, sourceFile, line);
371 }
372 
373 #if LDBL_MANT_DIG == 64
374 void RTDEF(BesselJnX0_10)(Descriptor &result, int32_t n1, int32_t n2,
375     const char *sourceFile, int line) {
376   DoBesselJnX0<TypeCategory::Real, 10>(result, n1, n2, sourceFile, line);
377 }
378 #endif
379 
380 #if LDBL_MANT_DIG == 113 || HAS_FLOAT128
381 void RTDEF(BesselJnX0_16)(Descriptor &result, int32_t n1, int32_t n2,
382     const char *sourceFile, int line) {
383   DoBesselJnX0<TypeCategory::Real, 16>(result, n1, n2, sourceFile, line);
384 }
385 #endif
386 
387 // BESSEL_YN
388 // TODO: REAL(2 & 3)
389 void RTDEF(BesselYn_4)(Descriptor &result, int32_t n1, int32_t n2,
390     CppTypeFor<TypeCategory::Real, 4> x, CppTypeFor<TypeCategory::Real, 4> bn1,
391     CppTypeFor<TypeCategory::Real, 4> bn1_1, const char *sourceFile, int line) {
392   DoBesselYn<TypeCategory::Real, 4>(
393       result, n1, n2, x, bn1, bn1_1, sourceFile, line);
394 }
395 
396 void RTDEF(BesselYn_8)(Descriptor &result, int32_t n1, int32_t n2,
397     CppTypeFor<TypeCategory::Real, 8> x, CppTypeFor<TypeCategory::Real, 8> bn1,
398     CppTypeFor<TypeCategory::Real, 8> bn1_1, const char *sourceFile, int line) {
399   DoBesselYn<TypeCategory::Real, 8>(
400       result, n1, n2, x, bn1, bn1_1, sourceFile, line);
401 }
402 
403 #if LDBL_MANT_DIG == 64
404 void RTDEF(BesselYn_10)(Descriptor &result, int32_t n1, int32_t n2,
405     CppTypeFor<TypeCategory::Real, 10> x,
406     CppTypeFor<TypeCategory::Real, 10> bn1,
407     CppTypeFor<TypeCategory::Real, 10> bn1_1, const char *sourceFile,
408     int line) {
409   DoBesselYn<TypeCategory::Real, 10>(
410       result, n1, n2, x, bn1, bn1_1, sourceFile, line);
411 }
412 #endif
413 
414 #if LDBL_MANT_DIG == 113 || HAS_FLOAT128
415 void RTDEF(BesselYn_16)(Descriptor &result, int32_t n1, int32_t n2,
416     CppTypeFor<TypeCategory::Real, 16> x,
417     CppTypeFor<TypeCategory::Real, 16> bn1,
418     CppTypeFor<TypeCategory::Real, 16> bn1_1, const char *sourceFile,
419     int line) {
420   DoBesselYn<TypeCategory::Real, 16>(
421       result, n1, n2, x, bn1, bn1_1, sourceFile, line);
422 }
423 #endif
424 
425 // TODO: REAL(2 & 3)
426 void RTDEF(BesselYnX0_4)(Descriptor &result, int32_t n1, int32_t n2,
427     const char *sourceFile, int line) {
428   DoBesselYnX0<TypeCategory::Real, 4>(result, n1, n2, sourceFile, line);
429 }
430 
431 void RTDEF(BesselYnX0_8)(Descriptor &result, int32_t n1, int32_t n2,
432     const char *sourceFile, int line) {
433   DoBesselYnX0<TypeCategory::Real, 8>(result, n1, n2, sourceFile, line);
434 }
435 
436 #if LDBL_MANT_DIG == 64
437 void RTDEF(BesselYnX0_10)(Descriptor &result, int32_t n1, int32_t n2,
438     const char *sourceFile, int line) {
439   DoBesselYnX0<TypeCategory::Real, 10>(result, n1, n2, sourceFile, line);
440 }
441 #endif
442 
443 #if LDBL_MANT_DIG == 113 || HAS_FLOAT128
444 void RTDEF(BesselYnX0_16)(Descriptor &result, int32_t n1, int32_t n2,
445     const char *sourceFile, int line) {
446   DoBesselYnX0<TypeCategory::Real, 16>(result, n1, n2, sourceFile, line);
447 }
448 #endif
449 
450 // CSHIFT where rank of ARRAY argument > 1
451 void RTDEF(Cshift)(Descriptor &result, const Descriptor &source,
452     const Descriptor &shift, int dim, const char *sourceFile, int line) {
453   Terminator terminator{sourceFile, line};
454   int rank{source.rank()};
455   RUNTIME_CHECK(terminator, rank > 1);
456   if (dim < 1 || dim > rank) {
457     terminator.Crash(
458         "CSHIFT: DIM=%d must be >= 1 and <= SOURCE= rank %d", dim, rank);
459   }
460   ShiftControl shiftControl{shift, terminator, dim};
461   shiftControl.Init(source, "CSHIFT");
462   SubscriptValue extent[maxRank];
463   source.GetShape(extent);
464   AllocateResult(result, source, rank, extent, terminator, "CSHIFT");
465   SubscriptValue resultAt[maxRank];
466   for (int j{0}; j < rank; ++j) {
467     resultAt[j] = 1;
468   }
469   SubscriptValue sourceLB[maxRank];
470   source.GetLowerBounds(sourceLB);
471   SubscriptValue dimExtent{extent[dim - 1]};
472   SubscriptValue dimLB{sourceLB[dim - 1]};
473   SubscriptValue &resDim{resultAt[dim - 1]};
474   for (std::size_t n{result.Elements()}; n > 0; n -= dimExtent) {
475     SubscriptValue shiftCount{shiftControl.GetShift(resultAt)};
476     SubscriptValue sourceAt[maxRank];
477     for (int j{0}; j < rank; ++j) {
478       sourceAt[j] = sourceLB[j] + resultAt[j] - 1;
479     }
480     SubscriptValue &sourceDim{sourceAt[dim - 1]};
481     sourceDim = dimLB + shiftCount % dimExtent;
482     if (sourceDim < dimLB) {
483       sourceDim += dimExtent;
484     }
485     for (resDim = 1; resDim <= dimExtent; ++resDim) {
486       CopyElement(result, resultAt, source, sourceAt, terminator);
487       if (++sourceDim == dimLB + dimExtent) {
488         sourceDim = dimLB;
489       }
490     }
491     result.IncrementSubscripts(resultAt);
492   }
493 }
494 
495 // CSHIFT where rank of ARRAY argument == 1
496 void RTDEF(CshiftVector)(Descriptor &result, const Descriptor &source,
497     std::int64_t shift, const char *sourceFile, int line) {
498   Terminator terminator{sourceFile, line};
499   RUNTIME_CHECK(terminator, source.rank() == 1);
500   const Dimension &sourceDim{source.GetDimension(0)};
501   SubscriptValue extent{sourceDim.Extent()};
502   AllocateResult(result, source, 1, &extent, terminator, "CSHIFT");
503   SubscriptValue lb{sourceDim.LowerBound()};
504   for (SubscriptValue j{0}; j < extent; ++j) {
505     SubscriptValue resultAt{1 + j};
506     SubscriptValue sourceAt{lb + (j + shift) % extent};
507     if (sourceAt < lb) {
508       sourceAt += extent;
509     }
510     CopyElement(result, &resultAt, source, &sourceAt, terminator);
511   }
512 }
513 
514 // EOSHIFT of rank > 1
515 void RTDEF(Eoshift)(Descriptor &result, const Descriptor &source,
516     const Descriptor &shift, const Descriptor *boundary, int dim,
517     const char *sourceFile, int line) {
518   Terminator terminator{sourceFile, line};
519   SubscriptValue extent[maxRank];
520   int rank{source.GetShape(extent)};
521   RUNTIME_CHECK(terminator, rank > 1);
522   if (dim < 1 || dim > rank) {
523     terminator.Crash(
524         "EOSHIFT: DIM=%d must be >= 1 and <= SOURCE= rank %d", dim, rank);
525   }
526   std::size_t elementLen{
527       AllocateResult(result, source, rank, extent, terminator, "EOSHIFT")};
528   int boundaryRank{-1};
529   if (boundary) {
530     boundaryRank = boundary->rank();
531     RUNTIME_CHECK(terminator, boundaryRank == 0 || boundaryRank == rank - 1);
532     RUNTIME_CHECK(terminator, boundary->type() == source.type());
533     if (boundary->ElementBytes() != elementLen) {
534       terminator.Crash("EOSHIFT: BOUNDARY= has element byte length %zd, but "
535                        "SOURCE= has length %zd",
536           boundary->ElementBytes(), elementLen);
537     }
538     if (boundaryRank > 0) {
539       int k{0};
540       for (int j{0}; j < rank; ++j) {
541         if (j != dim - 1) {
542           if (boundary->GetDimension(k).Extent() != extent[j]) {
543             terminator.Crash("EOSHIFT: BOUNDARY= has extent %jd on dimension "
544                              "%d but must conform with extent %jd of SOURCE=",
545                 static_cast<std::intmax_t>(boundary->GetDimension(k).Extent()),
546                 k + 1, static_cast<std::intmax_t>(extent[j]));
547           }
548           ++k;
549         }
550       }
551     }
552   }
553   ShiftControl shiftControl{shift, terminator, dim};
554   shiftControl.Init(source, "EOSHIFT");
555   SubscriptValue resultAt[maxRank];
556   for (int j{0}; j < rank; ++j) {
557     resultAt[j] = 1;
558   }
559   if (!boundary) {
560     DefaultInitialize(result, terminator);
561   }
562   SubscriptValue sourceLB[maxRank];
563   source.GetLowerBounds(sourceLB);
564   SubscriptValue boundaryAt[maxRank];
565   if (boundaryRank > 0) {
566     boundary->GetLowerBounds(boundaryAt);
567   }
568   SubscriptValue dimExtent{extent[dim - 1]};
569   SubscriptValue dimLB{sourceLB[dim - 1]};
570   SubscriptValue &resDim{resultAt[dim - 1]};
571   for (std::size_t n{result.Elements()}; n > 0; n -= dimExtent) {
572     SubscriptValue shiftCount{shiftControl.GetShift(resultAt)};
573     SubscriptValue sourceAt[maxRank];
574     for (int j{0}; j < rank; ++j) {
575       sourceAt[j] = sourceLB[j] + resultAt[j] - 1;
576     }
577     SubscriptValue &sourceDim{sourceAt[dim - 1]};
578     sourceDim = dimLB + shiftCount;
579     for (resDim = 1; resDim <= dimExtent; ++resDim) {
580       if (sourceDim >= dimLB && sourceDim < dimLB + dimExtent) {
581         CopyElement(result, resultAt, source, sourceAt, terminator);
582       } else if (boundary) {
583         CopyElement(result, resultAt, *boundary, boundaryAt, terminator);
584       }
585       ++sourceDim;
586     }
587     result.IncrementSubscripts(resultAt);
588     if (boundaryRank > 0) {
589       boundary->IncrementSubscripts(boundaryAt);
590     }
591   }
592 }
593 
594 // EOSHIFT of vector
595 void RTDEF(EoshiftVector)(Descriptor &result, const Descriptor &source,
596     std::int64_t shift, const Descriptor *boundary, const char *sourceFile,
597     int line) {
598   Terminator terminator{sourceFile, line};
599   RUNTIME_CHECK(terminator, source.rank() == 1);
600   SubscriptValue extent{source.GetDimension(0).Extent()};
601   std::size_t elementLen{
602       AllocateResult(result, source, 1, &extent, terminator, "EOSHIFT")};
603   if (boundary) {
604     RUNTIME_CHECK(terminator, boundary->rank() == 0);
605     RUNTIME_CHECK(terminator, boundary->type() == source.type());
606     if (boundary->ElementBytes() != elementLen) {
607       terminator.Crash("EOSHIFT: BOUNDARY= has element byte length %zd but "
608                        "SOURCE= has length %zd",
609           boundary->ElementBytes(), elementLen);
610     }
611   }
612   if (!boundary) {
613     DefaultInitialize(result, terminator);
614   }
615   SubscriptValue lb{source.GetDimension(0).LowerBound()};
616   for (SubscriptValue j{1}; j <= extent; ++j) {
617     SubscriptValue sourceAt{lb + j - 1 + shift};
618     if (sourceAt >= lb && sourceAt < lb + extent) {
619       CopyElement(result, &j, source, &sourceAt, terminator);
620     } else if (boundary) {
621       CopyElement(result, &j, *boundary, 0, terminator);
622     }
623   }
624 }
625 
626 // PACK
627 void RTDEF(Pack)(Descriptor &result, const Descriptor &source,
628     const Descriptor &mask, const Descriptor *vector, const char *sourceFile,
629     int line) {
630   Terminator terminator{sourceFile, line};
631   CheckConformability(source, mask, terminator, "PACK", "ARRAY=", "MASK=");
632   auto maskType{mask.type().GetCategoryAndKind()};
633   RUNTIME_CHECK(
634       terminator, maskType && maskType->first == TypeCategory::Logical);
635   SubscriptValue trues{0};
636   if (mask.rank() == 0) {
637     if (IsLogicalElementTrue(mask, nullptr)) {
638       trues = source.Elements();
639     }
640   } else {
641     SubscriptValue maskAt[maxRank];
642     mask.GetLowerBounds(maskAt);
643     for (std::size_t n{mask.Elements()}; n > 0; --n) {
644       if (IsLogicalElementTrue(mask, maskAt)) {
645         ++trues;
646       }
647       mask.IncrementSubscripts(maskAt);
648     }
649   }
650   SubscriptValue extent{trues};
651   if (vector) {
652     RUNTIME_CHECK(terminator, vector->rank() == 1);
653     RUNTIME_CHECK(terminator, source.type() == vector->type());
654     if (source.ElementBytes() != vector->ElementBytes()) {
655       terminator.Crash("PACK: SOURCE= has element byte length %zd, but VECTOR= "
656                        "has length %zd",
657           source.ElementBytes(), vector->ElementBytes());
658     }
659     extent = vector->GetDimension(0).Extent();
660     if (extent < trues) {
661       terminator.Crash("PACK: VECTOR= has extent %jd but there are %jd MASK= "
662                        "elements that are .TRUE.",
663           static_cast<std::intmax_t>(extent),
664           static_cast<std::intmax_t>(trues));
665     }
666   }
667   AllocateResult(result, source, 1, &extent, terminator, "PACK");
668   SubscriptValue sourceAt[maxRank], resultAt{1};
669   source.GetLowerBounds(sourceAt);
670   if (mask.rank() == 0) {
671     if (IsLogicalElementTrue(mask, nullptr)) {
672       for (SubscriptValue n{trues}; n > 0; --n) {
673         CopyElement(result, &resultAt, source, sourceAt, terminator);
674         ++resultAt;
675         source.IncrementSubscripts(sourceAt);
676       }
677     }
678   } else {
679     SubscriptValue maskAt[maxRank];
680     mask.GetLowerBounds(maskAt);
681     for (std::size_t n{source.Elements()}; n > 0; --n) {
682       if (IsLogicalElementTrue(mask, maskAt)) {
683         CopyElement(result, &resultAt, source, sourceAt, terminator);
684         ++resultAt;
685       }
686       source.IncrementSubscripts(sourceAt);
687       mask.IncrementSubscripts(maskAt);
688     }
689   }
690   if (vector) {
691     SubscriptValue vectorAt{
692         vector->GetDimension(0).LowerBound() + resultAt - 1};
693     for (; resultAt <= extent; ++resultAt, ++vectorAt) {
694       CopyElement(result, &resultAt, *vector, &vectorAt, terminator);
695     }
696   }
697 }
698 
699 // RESHAPE
700 // F2018 16.9.163
701 void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
702     const Descriptor &shape, const Descriptor *pad, const Descriptor *order,
703     const char *sourceFile, int line) {
704   // Compute and check the rank of the result.
705   Terminator terminator{sourceFile, line};
706   RUNTIME_CHECK(terminator, shape.rank() == 1);
707   RUNTIME_CHECK(terminator, shape.type().IsInteger());
708   SubscriptValue resultRank{shape.GetDimension(0).Extent()};
709   if (resultRank < 0 || resultRank > static_cast<SubscriptValue>(maxRank)) {
710     terminator.Crash(
711         "RESHAPE: SHAPE= vector length %jd implies a bad result rank",
712         static_cast<std::intmax_t>(resultRank));
713   }
714 
715   // Extract and check the shape of the result; compute its element count.
716   SubscriptValue resultExtent[maxRank];
717   std::size_t shapeElementBytes{shape.ElementBytes()};
718   std::size_t resultElements{1};
719   SubscriptValue shapeSubscript{shape.GetDimension(0).LowerBound()};
720   for (int j{0}; j < resultRank; ++j, ++shapeSubscript) {
721     resultExtent[j] = GetInt64(
722         shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator);
723     if (resultExtent[j] < 0) {
724       terminator.Crash("RESHAPE: bad value for SHAPE(%d)=%jd", j + 1,
725           static_cast<std::intmax_t>(resultExtent[j]));
726     }
727     resultElements *= resultExtent[j];
728   }
729 
730   // Check that there are sufficient elements in the SOURCE=, or that
731   // the optional PAD= argument is present and nonempty.
732   std::size_t elementBytes{source.ElementBytes()};
733   std::size_t sourceElements{source.Elements()};
734   std::size_t padElements{pad ? pad->Elements() : 0};
735   if (resultElements > sourceElements) {
736     if (padElements <= 0) {
737       terminator.Crash(
738           "RESHAPE: not enough elements, need %zd but only have %zd",
739           resultElements, sourceElements);
740     }
741     if (pad->ElementBytes() != elementBytes) {
742       terminator.Crash("RESHAPE: PAD= has element byte length %zd but SOURCE= "
743                        "has length %zd",
744           pad->ElementBytes(), elementBytes);
745     }
746   }
747 
748   // Extract and check the optional ORDER= argument, which must be a
749   // permutation of [1..resultRank].
750   int dimOrder[maxRank];
751   if (order) {
752     RUNTIME_CHECK(terminator, order->rank() == 1);
753     RUNTIME_CHECK(terminator, order->type().IsInteger());
754     if (order->GetDimension(0).Extent() != resultRank) {
755       terminator.Crash("RESHAPE: the extent of ORDER (%jd) must match the rank"
756                        " of the SHAPE (%d)",
757           static_cast<std::intmax_t>(order->GetDimension(0).Extent()),
758           resultRank);
759     }
760     std::uint64_t values{0};
761     SubscriptValue orderSubscript{order->GetDimension(0).LowerBound()};
762     std::size_t orderElementBytes{order->ElementBytes()};
763     for (SubscriptValue j{0}; j < resultRank; ++j, ++orderSubscript) {
764       auto k{GetInt64(order->Element<char>(&orderSubscript), orderElementBytes,
765           terminator)};
766       if (k < 1 || k > resultRank || ((values >> k) & 1)) {
767         terminator.Crash("RESHAPE: bad value for ORDER element (%jd)",
768             static_cast<std::intmax_t>(k));
769       }
770       values |= std::uint64_t{1} << k;
771       dimOrder[j] = k - 1;
772     }
773   } else {
774     for (int j{0}; j < resultRank; ++j) {
775       dimOrder[j] = j;
776     }
777   }
778 
779   // Allocate result descriptor
780   AllocateResult(
781       result, source, resultRank, resultExtent, terminator, "RESHAPE");
782 
783   // Populate the result's elements.
784   SubscriptValue resultSubscript[maxRank];
785   result.GetLowerBounds(resultSubscript);
786   SubscriptValue sourceSubscript[maxRank];
787   source.GetLowerBounds(sourceSubscript);
788   std::size_t resultElement{0};
789   std::size_t elementsFromSource{std::min(resultElements, sourceElements)};
790   for (; resultElement < elementsFromSource; ++resultElement) {
791     CopyElement(result, resultSubscript, source, sourceSubscript, terminator);
792     source.IncrementSubscripts(sourceSubscript);
793     result.IncrementSubscripts(resultSubscript, dimOrder);
794   }
795   if (resultElement < resultElements) {
796     // Remaining elements come from the optional PAD= argument.
797     SubscriptValue padSubscript[maxRank];
798     pad->GetLowerBounds(padSubscript);
799     for (; resultElement < resultElements; ++resultElement) {
800       CopyElement(result, resultSubscript, *pad, padSubscript, terminator);
801       pad->IncrementSubscripts(padSubscript);
802       result.IncrementSubscripts(resultSubscript, dimOrder);
803     }
804   }
805 }
806 
807 // SPREAD
808 void RTDEF(Spread)(Descriptor &result, const Descriptor &source, int dim,
809     std::int64_t ncopies, const char *sourceFile, int line) {
810   Terminator terminator{sourceFile, line};
811   int rank{source.rank() + 1};
812   RUNTIME_CHECK(terminator, rank <= maxRank);
813   if (dim < 1 || dim > rank) {
814     terminator.Crash("SPREAD: DIM=%d argument for rank-%d source array "
815                      "must be greater than 1 and less than or equal to %d",
816         dim, rank - 1, rank);
817   }
818   ncopies = std::max<std::int64_t>(ncopies, 0);
819   SubscriptValue extent[maxRank];
820   int k{0};
821   for (int j{0}; j < rank; ++j) {
822     extent[j] = j == dim - 1 ? ncopies : source.GetDimension(k++).Extent();
823   }
824   AllocateResult(result, source, rank, extent, terminator, "SPREAD");
825   SubscriptValue resultAt[maxRank];
826   for (int j{0}; j < rank; ++j) {
827     resultAt[j] = 1;
828   }
829   SubscriptValue &resultDim{resultAt[dim - 1]};
830   SubscriptValue sourceAt[maxRank];
831   source.GetLowerBounds(sourceAt);
832   for (std::size_t n{result.Elements()}; n > 0; n -= ncopies) {
833     for (resultDim = 1; resultDim <= ncopies; ++resultDim) {
834       CopyElement(result, resultAt, source, sourceAt, terminator);
835     }
836     result.IncrementSubscripts(resultAt);
837     source.IncrementSubscripts(sourceAt);
838   }
839 }
840 
841 // TRANSPOSE
842 void RTDEF(Transpose)(Descriptor &result, const Descriptor &matrix,
843     const char *sourceFile, int line) {
844   Terminator terminator{sourceFile, line};
845   RUNTIME_CHECK(terminator, matrix.rank() == 2);
846   SubscriptValue extent[2]{
847       matrix.GetDimension(1).Extent(), matrix.GetDimension(0).Extent()};
848   AllocateResult(result, matrix, 2, extent, terminator, "TRANSPOSE");
849   SubscriptValue resultAt[2]{1, 1};
850   SubscriptValue matrixLB[2];
851   matrix.GetLowerBounds(matrixLB);
852   for (std::size_t n{result.Elements()}; n-- > 0;
853        result.IncrementSubscripts(resultAt)) {
854     SubscriptValue matrixAt[2]{
855         matrixLB[0] + resultAt[1] - 1, matrixLB[1] + resultAt[0] - 1};
856     CopyElement(result, resultAt, matrix, matrixAt, terminator);
857   }
858 }
859 
860 // UNPACK
861 void RTDEF(Unpack)(Descriptor &result, const Descriptor &vector,
862     const Descriptor &mask, const Descriptor &field, const char *sourceFile,
863     int line) {
864   Terminator terminator{sourceFile, line};
865   RUNTIME_CHECK(terminator, vector.rank() == 1);
866   int rank{mask.rank()};
867   RUNTIME_CHECK(terminator, rank > 0);
868   SubscriptValue extent[maxRank];
869   mask.GetShape(extent);
870   CheckConformability(mask, field, terminator, "UNPACK", "MASK=", "FIELD=");
871   std::size_t elementLen{
872       AllocateResult(result, field, rank, extent, terminator, "UNPACK")};
873   RUNTIME_CHECK(terminator, vector.type() == field.type());
874   if (vector.ElementBytes() != elementLen) {
875     terminator.Crash(
876         "UNPACK: VECTOR= has element byte length %zd but FIELD= has length %zd",
877         vector.ElementBytes(), elementLen);
878   }
879   SubscriptValue resultAt[maxRank], maskAt[maxRank], fieldAt[maxRank],
880       vectorAt{vector.GetDimension(0).LowerBound()};
881   for (int j{0}; j < rank; ++j) {
882     resultAt[j] = 1;
883   }
884   mask.GetLowerBounds(maskAt);
885   field.GetLowerBounds(fieldAt);
886   SubscriptValue vectorElements{vector.GetDimension(0).Extent()};
887   SubscriptValue vectorLeft{vectorElements};
888   for (std::size_t n{result.Elements()}; n-- > 0;) {
889     if (IsLogicalElementTrue(mask, maskAt)) {
890       if (vectorLeft-- == 0) {
891         terminator.Crash(
892             "UNPACK: VECTOR= argument has fewer elements (%d) than "
893             "MASK= has .TRUE. entries",
894             vectorElements);
895       }
896       CopyElement(result, resultAt, vector, &vectorAt, terminator);
897       ++vectorAt;
898     } else {
899       CopyElement(result, resultAt, field, fieldAt, terminator);
900     }
901     result.IncrementSubscripts(resultAt);
902     mask.IncrementSubscripts(maskAt);
903     field.IncrementSubscripts(fieldAt);
904   }
905 }
906 
907 RT_EXT_API_GROUP_END
908 } // extern "C"
909 } // namespace Fortran::runtime
910