xref: /llvm-project/flang/runtime/findloc.cpp (revision 9f3a6114807b66738585af060012927bd0f05b88)
1 //===-- runtime/findloc.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // Implements FINDLOC for all required operand types and shapes and result
10 // integer kinds.
11 
12 #include "reduction-templates.h"
13 #include "flang/Runtime/character.h"
14 #include "flang/Runtime/reduction.h"
15 #include <cinttypes>
16 #include <complex>
17 
18 namespace Fortran::runtime {
19 
20 template <TypeCategory CAT1, int KIND1, TypeCategory CAT2, int KIND2>
21 struct Equality {
22   using Type1 = CppTypeFor<CAT1, KIND1>;
23   using Type2 = CppTypeFor<CAT2, KIND2>;
24   RT_API_ATTRS bool operator()(const Descriptor &array,
25       const SubscriptValue at[], const Descriptor &target) const {
26     if constexpr (KIND1 >= KIND2) {
27       return *array.Element<Type1>(at) ==
28           static_cast<Type1>(*target.OffsetElement<Type2>());
29     } else {
30       return static_cast<Type2>(*array.Element<Type1>(at)) ==
31           *target.OffsetElement<Type2>();
32     }
33   }
34 };
35 
36 template <int KIND1, int KIND2>
37 struct Equality<TypeCategory::Complex, KIND1, TypeCategory::Complex, KIND2> {
38   using Type1 = CppTypeFor<TypeCategory::Complex, KIND1>;
39   using Type2 = CppTypeFor<TypeCategory::Complex, KIND2>;
40   RT_API_ATTRS bool operator()(const Descriptor &array,
41       const SubscriptValue at[], const Descriptor &target) const {
42     const Type1 &xz{*array.Element<Type1>(at)};
43     const Type2 &tz{*target.OffsetElement<Type2>()};
44     return xz.real() == tz.real() && xz.imag() == tz.imag();
45   }
46 };
47 
48 template <int KIND1, TypeCategory CAT2, int KIND2>
49 struct Equality<TypeCategory::Complex, KIND1, CAT2, KIND2> {
50   using Type1 = CppTypeFor<TypeCategory::Complex, KIND1>;
51   using Type2 = CppTypeFor<CAT2, KIND2>;
52   RT_API_ATTRS bool operator()(const Descriptor &array,
53       const SubscriptValue at[], const Descriptor &target) const {
54     const Type1 &z{*array.Element<Type1>(at)};
55     return z.imag() == 0 && z.real() == *target.OffsetElement<Type2>();
56   }
57 };
58 
59 template <TypeCategory CAT1, int KIND1, int KIND2>
60 struct Equality<CAT1, KIND1, TypeCategory::Complex, KIND2> {
61   using Type1 = CppTypeFor<CAT1, KIND1>;
62   using Type2 = CppTypeFor<TypeCategory::Complex, KIND2>;
63   RT_API_ATTRS bool operator()(const Descriptor &array,
64       const SubscriptValue at[], const Descriptor &target) const {
65     const Type2 &z{*target.OffsetElement<Type2>()};
66     return *array.Element<Type1>(at) == z.real() && z.imag() == 0;
67   }
68 };
69 
70 template <int KIND> struct CharacterEquality {
71   using Type = CppTypeFor<TypeCategory::Character, KIND>;
72   RT_API_ATTRS bool operator()(const Descriptor &array,
73       const SubscriptValue at[], const Descriptor &target) const {
74     return CharacterScalarCompare<Type>(array.Element<Type>(at),
75                target.OffsetElement<Type>(),
76                array.ElementBytes() / static_cast<unsigned>(KIND),
77                target.ElementBytes() / static_cast<unsigned>(KIND)) == 0;
78   }
79 };
80 
81 struct LogicalEquivalence {
82   RT_API_ATTRS bool operator()(const Descriptor &array,
83       const SubscriptValue at[], const Descriptor &target) const {
84     return IsLogicalElementTrue(array, at) ==
85         IsLogicalElementTrue(target, at /*ignored*/);
86   }
87 };
88 
89 template <typename EQUALITY> class LocationAccumulator {
90 public:
91   RT_API_ATTRS LocationAccumulator(
92       const Descriptor &array, const Descriptor &target, bool back)
93       : array_{array}, target_{target}, back_{back} {}
94   RT_API_ATTRS void Reinitialize() { gotAnything_ = false; }
95   template <typename A>
96   RT_API_ATTRS void GetResult(A *p, int zeroBasedDim = -1) {
97     if (zeroBasedDim >= 0) {
98       *p = gotAnything_ ? location_[zeroBasedDim] -
99               array_.GetDimension(zeroBasedDim).LowerBound() + 1
100                         : 0;
101     } else if (gotAnything_) {
102       for (int j{0}; j < rank_; ++j) {
103         p[j] = location_[j] - array_.GetDimension(j).LowerBound() + 1;
104       }
105     } else {
106       // no unmasked hits? result is all zeroes
107       for (int j{0}; j < rank_; ++j) {
108         p[j] = 0;
109       }
110     }
111   }
112   template <typename IGNORED>
113   RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
114     if (equality_(array_, at, target_)) {
115       gotAnything_ = true;
116       for (int j{0}; j < rank_; ++j) {
117         location_[j] = at[j];
118       }
119       return back_;
120     } else {
121       return true;
122     }
123   }
124 
125 private:
126   const Descriptor &array_;
127   const Descriptor &target_;
128   const bool back_{false};
129   const int rank_{array_.rank()};
130   bool gotAnything_{false};
131   SubscriptValue location_[maxRank];
132   const EQUALITY equality_{};
133 };
134 
135 template <TypeCategory XCAT, int XKIND, TypeCategory TARGET_CAT>
136 struct TotalNumericFindlocHelper {
137   template <int TARGET_KIND> struct Functor {
138     RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x,
139         const Descriptor &target, int kind, int dim, const Descriptor *mask,
140         bool back, Terminator &terminator) const {
141       using Eq = Equality<XCAT, XKIND, TARGET_CAT, TARGET_KIND>;
142       using Accumulator = LocationAccumulator<Eq>;
143       Accumulator accumulator{x, target, back};
144       DoTotalReduction<void>(x, dim, mask, accumulator, "FINDLOC", terminator);
145       ApplyIntegerKind<LocationResultHelper<Accumulator>::template Functor,
146           void>(kind, terminator, accumulator, result);
147     }
148   };
149 };
150 
151 template <TypeCategory CAT,
152     template <TypeCategory XCAT, int XKIND, TypeCategory TARGET_CAT>
153     class HELPER>
154 struct NumericFindlocHelper {
155   template <int KIND> struct Functor {
156     RT_API_ATTRS void operator()(TypeCategory targetCat, int targetKind,
157         Descriptor &result, const Descriptor &x, const Descriptor &target,
158         int kind, int dim, const Descriptor *mask, bool back,
159         Terminator &terminator) const {
160       switch (targetCat) {
161       case TypeCategory::Integer:
162       case TypeCategory::Unsigned:
163         ApplyIntegerKind<
164             HELPER<CAT, KIND, TypeCategory::Integer>::template Functor, void>(
165             targetKind, terminator, result, x, target, kind, dim, mask, back,
166             terminator);
167         break;
168       case TypeCategory::Real:
169         ApplyFloatingPointKind<
170             HELPER<CAT, KIND, TypeCategory::Real>::template Functor, void>(
171             targetKind, terminator, result, x, target, kind, dim, mask, back,
172             terminator);
173         break;
174       case TypeCategory::Complex:
175         ApplyFloatingPointKind<
176             HELPER<CAT, KIND, TypeCategory::Complex>::template Functor, void>(
177             targetKind, terminator, result, x, target, kind, dim, mask, back,
178             terminator);
179         break;
180       default:
181         terminator.Crash(
182             "FINDLOC: bad target category %d for array category %d",
183             static_cast<int>(targetCat), static_cast<int>(CAT));
184       }
185     }
186   };
187 };
188 
189 template <int KIND> struct CharacterFindlocHelper {
190   RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x,
191       const Descriptor &target, int kind, const Descriptor *mask, bool back,
192       Terminator &terminator) {
193     using Accumulator = LocationAccumulator<CharacterEquality<KIND>>;
194     Accumulator accumulator{x, target, back};
195     DoTotalReduction<void>(x, 0, mask, accumulator, "FINDLOC", terminator);
196     ApplyIntegerKind<LocationResultHelper<Accumulator>::template Functor, void>(
197         kind, terminator, accumulator, result);
198   }
199 };
200 
201 static RT_API_ATTRS void LogicalFindlocHelper(Descriptor &result,
202     const Descriptor &x, const Descriptor &target, int kind,
203     const Descriptor *mask, bool back, Terminator &terminator) {
204   using Accumulator = LocationAccumulator<LogicalEquivalence>;
205   Accumulator accumulator{x, target, back};
206   DoTotalReduction<void>(x, 0, mask, accumulator, "FINDLOC", terminator);
207   ApplyIntegerKind<LocationResultHelper<Accumulator>::template Functor, void>(
208       kind, terminator, accumulator, result);
209 }
210 
211 extern "C" {
212 RT_EXT_API_GROUP_BEGIN
213 
214 void RTDEF(Findloc)(Descriptor &result, const Descriptor &x,
215     const Descriptor &target, int kind, const char *source, int line,
216     const Descriptor *mask, bool back) {
217   int rank{x.rank()};
218   SubscriptValue extent[1]{rank};
219   result.Establish(TypeCategory::Integer, kind, nullptr, 1, extent,
220       CFI_attribute_allocatable);
221   result.GetDimension(0).SetBounds(1, extent[0]);
222   Terminator terminator{source, line};
223   if (int stat{result.Allocate()}) {
224     terminator.Crash(
225         "FINDLOC: could not allocate memory for result; STAT=%d", stat);
226   }
227   CheckIntegerKind(terminator, kind, "FINDLOC");
228   auto xType{x.type().GetCategoryAndKind()};
229   auto targetType{target.type().GetCategoryAndKind()};
230   RUNTIME_CHECK(terminator, xType.has_value() && targetType.has_value());
231   switch (xType->first) {
232   case TypeCategory::Integer:
233   case TypeCategory::Unsigned:
234     ApplyIntegerKind<NumericFindlocHelper<TypeCategory::Integer,
235                          TotalNumericFindlocHelper>::template Functor,
236         void>(xType->second, terminator, targetType->first, targetType->second,
237         result, x, target, kind, 0, mask, back, terminator);
238     break;
239   case TypeCategory::Real:
240     ApplyFloatingPointKind<NumericFindlocHelper<TypeCategory::Real,
241                                TotalNumericFindlocHelper>::template Functor,
242         void>(xType->second, terminator, targetType->first, targetType->second,
243         result, x, target, kind, 0, mask, back, terminator);
244     break;
245   case TypeCategory::Complex:
246     ApplyFloatingPointKind<NumericFindlocHelper<TypeCategory::Complex,
247                                TotalNumericFindlocHelper>::template Functor,
248         void>(xType->second, terminator, targetType->first, targetType->second,
249         result, x, target, kind, 0, mask, back, terminator);
250     break;
251   case TypeCategory::Character:
252     RUNTIME_CHECK(terminator,
253         targetType->first == TypeCategory::Character &&
254             targetType->second == xType->second);
255     ApplyCharacterKind<CharacterFindlocHelper, void>(xType->second, terminator,
256         result, x, target, kind, mask, back, terminator);
257     break;
258   case TypeCategory::Logical:
259     RUNTIME_CHECK(terminator, targetType->first == TypeCategory::Logical);
260     LogicalFindlocHelper(result, x, target, kind, mask, back, terminator);
261     break;
262   default:
263     terminator.Crash(
264         "FINDLOC: bad data type code (%d) for array", x.type().raw());
265   }
266 }
267 
268 RT_EXT_API_GROUP_END
269 } // extern "C"
270 
271 // FINDLOC with DIM=
272 
273 template <TypeCategory XCAT, int XKIND, TypeCategory TARGET_CAT>
274 struct PartialNumericFindlocHelper {
275   template <int TARGET_KIND> struct Functor {
276     RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x,
277         const Descriptor &target, int kind, int dim, const Descriptor *mask,
278         bool back, Terminator &terminator) const {
279       using Eq = Equality<XCAT, XKIND, TARGET_CAT, TARGET_KIND>;
280       using Accumulator = LocationAccumulator<Eq>;
281       Accumulator accumulator{x, target, back};
282       ApplyIntegerKind<PartialLocationHelper<Accumulator>::template Functor,
283           void>(kind, terminator, result, x, dim, mask, terminator, "FINDLOC",
284           accumulator);
285     }
286   };
287 };
288 
289 template <int KIND> struct PartialCharacterFindlocHelper {
290   RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x,
291       const Descriptor &target, int kind, int dim, const Descriptor *mask,
292       bool back, Terminator &terminator) {
293     using Accumulator = LocationAccumulator<CharacterEquality<KIND>>;
294     Accumulator accumulator{x, target, back};
295     ApplyIntegerKind<PartialLocationHelper<Accumulator>::template Functor,
296         void>(kind, terminator, result, x, dim, mask, terminator, "FINDLOC",
297         accumulator);
298   }
299 };
300 
301 static RT_API_ATTRS void PartialLogicalFindlocHelper(Descriptor &result,
302     const Descriptor &x, const Descriptor &target, int kind, int dim,
303     const Descriptor *mask, bool back, Terminator &terminator) {
304   using Accumulator = LocationAccumulator<LogicalEquivalence>;
305   Accumulator accumulator{x, target, back};
306   ApplyIntegerKind<PartialLocationHelper<Accumulator>::template Functor, void>(
307       kind, terminator, result, x, dim, mask, terminator, "FINDLOC",
308       accumulator);
309 }
310 
311 extern "C" {
312 RT_EXT_API_GROUP_BEGIN
313 
314 void RTDEF(FindlocDim)(Descriptor &result, const Descriptor &x,
315     const Descriptor &target, int kind, int dim, const char *source, int line,
316     const Descriptor *mask, bool back) {
317   Terminator terminator{source, line};
318   CheckIntegerKind(terminator, kind, "FINDLOC");
319   auto xType{x.type().GetCategoryAndKind()};
320   auto targetType{target.type().GetCategoryAndKind()};
321   RUNTIME_CHECK(terminator, xType.has_value() && targetType.has_value());
322   switch (xType->first) {
323   case TypeCategory::Integer:
324   case TypeCategory::Unsigned:
325     ApplyIntegerKind<NumericFindlocHelper<TypeCategory::Integer,
326                          PartialNumericFindlocHelper>::template Functor,
327         void>(xType->second, terminator, targetType->first, targetType->second,
328         result, x, target, kind, dim, mask, back, terminator);
329     break;
330   case TypeCategory::Real:
331     ApplyFloatingPointKind<NumericFindlocHelper<TypeCategory::Real,
332                                PartialNumericFindlocHelper>::template Functor,
333         void>(xType->second, terminator, targetType->first, targetType->second,
334         result, x, target, kind, dim, mask, back, terminator);
335     break;
336   case TypeCategory::Complex:
337     ApplyFloatingPointKind<NumericFindlocHelper<TypeCategory::Complex,
338                                PartialNumericFindlocHelper>::template Functor,
339         void>(xType->second, terminator, targetType->first, targetType->second,
340         result, x, target, kind, dim, mask, back, terminator);
341     break;
342   case TypeCategory::Character:
343     RUNTIME_CHECK(terminator,
344         targetType->first == TypeCategory::Character &&
345             targetType->second == xType->second);
346     ApplyCharacterKind<PartialCharacterFindlocHelper, void>(xType->second,
347         terminator, result, x, target, kind, dim, mask, back, terminator);
348     break;
349   case TypeCategory::Logical:
350     RUNTIME_CHECK(terminator, targetType->first == TypeCategory::Logical);
351     PartialLogicalFindlocHelper(
352         result, x, target, kind, dim, mask, back, terminator);
353     break;
354   default:
355     terminator.Crash(
356         "FINDLOC: bad data type code (%d) for array", x.type().raw());
357   }
358 }
359 
360 RT_EXT_API_GROUP_END
361 } // extern "C"
362 } // namespace Fortran::runtime
363