xref: /llvm-project/flang/include/flang/Lower/Support/Utils.h (revision 77d8cfb3c50e3341d65af1f9e442004bbd77af9b)
1 //===-- Lower/Support/Utils.h -- utilities ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef FORTRAN_LOWER_SUPPORT_UTILS_H
14 #define FORTRAN_LOWER_SUPPORT_UTILS_H
15 
16 #include "flang/Common/indirection.h"
17 #include "flang/Parser/char-block.h"
18 #include "flang/Semantics/tools.h"
19 #include "mlir/Dialect/Arith/IR/Arith.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/IR/BuiltinAttributes.h"
22 #include "llvm/ADT/StringRef.h"
23 #include <cstdint>
24 
25 namespace Fortran::lower {
26 using SomeExpr = Fortran::evaluate::Expr<Fortran::evaluate::SomeType>;
27 } // end namespace Fortran::lower
28 
29 //===----------------------------------------------------------------------===//
30 // Small inline helper functions to deal with repetitive, clumsy conversions.
31 //===----------------------------------------------------------------------===//
32 
33 /// Convert an F18 CharBlock to an LLVM StringRef.
toStringRef(const Fortran::parser::CharBlock & cb)34 inline llvm::StringRef toStringRef(const Fortran::parser::CharBlock &cb) {
35   return {cb.begin(), cb.size()};
36 }
37 
/// Template helper to remove Fortran::common::Indirection wrappers.
///
/// Fall-through overload: a value that is not wrapped is handed back unchanged,
/// by reference (no copy is made).
template <typename A>
const A &removeIndirection(const A &value) {
  return value;
}
43 template <typename A>
removeIndirection(const Fortran::common::Indirection<A> & a)44 const A &removeIndirection(const Fortran::common::Indirection<A> &a) {
45   return a.value();
46 }
47 
48 /// Clone subexpression and wrap it as a generic `Fortran::evaluate::Expr`.
49 template <typename A>
toEvExpr(const A & x)50 static Fortran::lower::SomeExpr toEvExpr(const A &x) {
51   return Fortran::evaluate::AsGenericExpr(Fortran::common::Clone(x));
52 }
53 
54 template <Fortran::common::TypeCategory FROM>
ignoreEvConvert(const Fortran::evaluate::Convert<Fortran::evaluate::Type<Fortran::common::TypeCategory::Integer,8>,FROM> & x)55 static Fortran::lower::SomeExpr ignoreEvConvert(
56     const Fortran::evaluate::Convert<
57         Fortran::evaluate::Type<Fortran::common::TypeCategory::Integer, 8>,
58         FROM> &x) {
59   return toEvExpr(x.left());
60 }
61 template <typename A>
ignoreEvConvert(const A & x)62 static Fortran::lower::SomeExpr ignoreEvConvert(const A &x) {
63   return toEvExpr(x);
64 }
65 
66 /// A vector subscript expression may be wrapped with a cast to INTEGER*8.
67 /// Get rid of it here so the vector can be loaded. Add it back when
68 /// generating the elemental evaluation (inside the loop nest).
69 inline Fortran::lower::SomeExpr
ignoreEvConvert(const Fortran::evaluate::Expr<Fortran::evaluate::Type<Fortran::common::TypeCategory::Integer,8>> & x)70 ignoreEvConvert(const Fortran::evaluate::Expr<Fortran::evaluate::Type<
71                     Fortran::common::TypeCategory::Integer, 8>> &x) {
72   return Fortran::common::visit(
73       [](const auto &v) { return ignoreEvConvert(v); }, x.u);
74 }
75 
/// Zip two containers of the same size together and flatten the pairs.
/// `flatZip [1;2] [3;4]` yields `[1;3;2;4]`.
///
/// A must support begin()/end(), size(), default construction and
/// emplace_back(). The sizes are asserted to be equal. If assertions are
/// disabled and the sizes differ, the walk stops at the end of the shorter
/// container.
template <typename A>
A flatZip(const A &container1, const A &container2) {
  assert(container1.size() == container2.size());
  A interleaved;
  auto lhs = container1.begin();
  auto rhs = container2.begin();
  const auto lhsEnd = container1.end();
  const auto rhsEnd = container2.end();
  while (lhs != lhsEnd && rhs != rhsEnd) {
    interleaved.emplace_back(*lhs);
    interleaved.emplace_back(*rhs);
    ++lhs;
    ++rhs;
  }
  return interleaved;
}
88 
89 namespace Fortran::lower {
// Fortran::evaluate::Expr values are functional values organized like an AST.
// A Fortran::evaluate::Expr is meant to be moved and cloned. Using the
// front-end tools can often cause copies and extra wrapper classes to be added
// to any Fortran::evaluate::Expr. These values should not be assumed or relied
// upon to have an *object* identity. They are deeply recursive, irregular
// structures built from a large number of classes that do not use inheritance,
// and as a result they require a large volume of boilerplate code.
//
// By contrast, LLVM data structures make ubiquitous assumptions about an
// object's identity via pointers to the object. An object's location in memory
// is thus very often an identifying relation.
101 
102 // This class defines a hash computation of a Fortran::evaluate::Expr tree value
103 // so it can be used with llvm::DenseMap. The Fortran::evaluate::Expr need not
104 // have the same address.
105 class HashEvaluateExpr {
106 public:
107   // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an
108   // identity property.
getHashValue(const Fortran::semantics::Symbol & x)109   static unsigned getHashValue(const Fortran::semantics::Symbol &x) {
110     return static_cast<unsigned>(reinterpret_cast<std::intptr_t>(&x));
111   }
112   template <typename A, bool COPY>
getHashValue(const Fortran::common::Indirection<A,COPY> & x)113   static unsigned getHashValue(const Fortran::common::Indirection<A, COPY> &x) {
114     return getHashValue(x.value());
115   }
116   template <typename A>
getHashValue(const std::optional<A> & x)117   static unsigned getHashValue(const std::optional<A> &x) {
118     if (x.has_value())
119       return getHashValue(x.value());
120     return 0u;
121   }
getHashValue(const Fortran::evaluate::Subscript & x)122   static unsigned getHashValue(const Fortran::evaluate::Subscript &x) {
123     return Fortran::common::visit(
124         [&](const auto &v) { return getHashValue(v); }, x.u);
125   }
getHashValue(const Fortran::evaluate::Triplet & x)126   static unsigned getHashValue(const Fortran::evaluate::Triplet &x) {
127     return getHashValue(x.lower()) - getHashValue(x.upper()) * 5u -
128            getHashValue(x.stride()) * 11u;
129   }
getHashValue(const Fortran::evaluate::Component & x)130   static unsigned getHashValue(const Fortran::evaluate::Component &x) {
131     return getHashValue(x.base()) * 83u - getHashValue(x.GetLastSymbol());
132   }
getHashValue(const Fortran::evaluate::ArrayRef & x)133   static unsigned getHashValue(const Fortran::evaluate::ArrayRef &x) {
134     unsigned subs = 1u;
135     for (const Fortran::evaluate::Subscript &v : x.subscript())
136       subs -= getHashValue(v);
137     return getHashValue(x.base()) * 89u - subs;
138   }
getHashValue(const Fortran::evaluate::CoarrayRef & x)139   static unsigned getHashValue(const Fortran::evaluate::CoarrayRef &x) {
140     unsigned subs = 1u;
141     for (const Fortran::evaluate::Subscript &v : x.subscript())
142       subs -= getHashValue(v);
143     unsigned cosubs = 3u;
144     for (const Fortran::evaluate::Expr<Fortran::evaluate::SubscriptInteger> &v :
145          x.cosubscript())
146       cosubs -= getHashValue(v);
147     unsigned syms = 7u;
148     for (const Fortran::evaluate::SymbolRef &v : x.base())
149       syms += getHashValue(v);
150     return syms * 97u - subs - cosubs + getHashValue(x.stat()) + 257u +
151            getHashValue(x.team());
152   }
getHashValue(const Fortran::evaluate::NamedEntity & x)153   static unsigned getHashValue(const Fortran::evaluate::NamedEntity &x) {
154     if (x.IsSymbol())
155       return getHashValue(x.GetFirstSymbol()) * 11u;
156     return getHashValue(x.GetComponent()) * 13u;
157   }
getHashValue(const Fortran::evaluate::DataRef & x)158   static unsigned getHashValue(const Fortran::evaluate::DataRef &x) {
159     return Fortran::common::visit(
160         [&](const auto &v) { return getHashValue(v); }, x.u);
161   }
getHashValue(const Fortran::evaluate::ComplexPart & x)162   static unsigned getHashValue(const Fortran::evaluate::ComplexPart &x) {
163     return getHashValue(x.complex()) - static_cast<unsigned>(x.part());
164   }
165   template <Fortran::common::TypeCategory TC1, int KIND,
166             Fortran::common::TypeCategory TC2>
getHashValue(const Fortran::evaluate::Convert<Fortran::evaluate::Type<TC1,KIND>,TC2> & x)167   static unsigned getHashValue(
168       const Fortran::evaluate::Convert<Fortran::evaluate::Type<TC1, KIND>, TC2>
169           &x) {
170     return getHashValue(x.left()) - (static_cast<unsigned>(TC1) + 2u) -
171            (static_cast<unsigned>(KIND) + 5u);
172   }
173   template <int KIND>
174   static unsigned
getHashValue(const Fortran::evaluate::ComplexComponent<KIND> & x)175   getHashValue(const Fortran::evaluate::ComplexComponent<KIND> &x) {
176     return getHashValue(x.left()) -
177            (static_cast<unsigned>(x.isImaginaryPart) + 1u) * 3u;
178   }
179   template <typename T>
getHashValue(const Fortran::evaluate::Parentheses<T> & x)180   static unsigned getHashValue(const Fortran::evaluate::Parentheses<T> &x) {
181     return getHashValue(x.left()) * 17u;
182   }
183   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Negate<Fortran::evaluate::Type<TC,KIND>> & x)184   static unsigned getHashValue(
185       const Fortran::evaluate::Negate<Fortran::evaluate::Type<TC, KIND>> &x) {
186     return getHashValue(x.left()) - (static_cast<unsigned>(TC) + 5u) -
187            (static_cast<unsigned>(KIND) + 7u);
188   }
189   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Add<Fortran::evaluate::Type<TC,KIND>> & x)190   static unsigned getHashValue(
191       const Fortran::evaluate::Add<Fortran::evaluate::Type<TC, KIND>> &x) {
192     return (getHashValue(x.left()) + getHashValue(x.right())) * 23u +
193            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
194   }
195   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Subtract<Fortran::evaluate::Type<TC,KIND>> & x)196   static unsigned getHashValue(
197       const Fortran::evaluate::Subtract<Fortran::evaluate::Type<TC, KIND>> &x) {
198     return (getHashValue(x.left()) - getHashValue(x.right())) * 19u +
199            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
200   }
201   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Multiply<Fortran::evaluate::Type<TC,KIND>> & x)202   static unsigned getHashValue(
203       const Fortran::evaluate::Multiply<Fortran::evaluate::Type<TC, KIND>> &x) {
204     return (getHashValue(x.left()) + getHashValue(x.right())) * 29u +
205            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
206   }
207   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Divide<Fortran::evaluate::Type<TC,KIND>> & x)208   static unsigned getHashValue(
209       const Fortran::evaluate::Divide<Fortran::evaluate::Type<TC, KIND>> &x) {
210     return (getHashValue(x.left()) - getHashValue(x.right())) * 31u +
211            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
212   }
213   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Power<Fortran::evaluate::Type<TC,KIND>> & x)214   static unsigned getHashValue(
215       const Fortran::evaluate::Power<Fortran::evaluate::Type<TC, KIND>> &x) {
216     return (getHashValue(x.left()) - getHashValue(x.right())) * 37u +
217            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
218   }
219   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Extremum<Fortran::evaluate::Type<TC,KIND>> & x)220   static unsigned getHashValue(
221       const Fortran::evaluate::Extremum<Fortran::evaluate::Type<TC, KIND>> &x) {
222     return (getHashValue(x.left()) + getHashValue(x.right())) * 41u +
223            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) +
224            static_cast<unsigned>(x.ordering) * 7u;
225   }
226   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::RealToIntPower<Fortran::evaluate::Type<TC,KIND>> & x)227   static unsigned getHashValue(
228       const Fortran::evaluate::RealToIntPower<Fortran::evaluate::Type<TC, KIND>>
229           &x) {
230     return (getHashValue(x.left()) - getHashValue(x.right())) * 43u +
231            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
232   }
233   template <int KIND>
234   static unsigned
getHashValue(const Fortran::evaluate::ComplexConstructor<KIND> & x)235   getHashValue(const Fortran::evaluate::ComplexConstructor<KIND> &x) {
236     return (getHashValue(x.left()) - getHashValue(x.right())) * 47u +
237            static_cast<unsigned>(KIND);
238   }
239   template <int KIND>
getHashValue(const Fortran::evaluate::Concat<KIND> & x)240   static unsigned getHashValue(const Fortran::evaluate::Concat<KIND> &x) {
241     return (getHashValue(x.left()) - getHashValue(x.right())) * 53u +
242            static_cast<unsigned>(KIND);
243   }
244   template <int KIND>
getHashValue(const Fortran::evaluate::SetLength<KIND> & x)245   static unsigned getHashValue(const Fortran::evaluate::SetLength<KIND> &x) {
246     return (getHashValue(x.left()) - getHashValue(x.right())) * 59u +
247            static_cast<unsigned>(KIND);
248   }
getHashValue(const Fortran::semantics::SymbolRef & sym)249   static unsigned getHashValue(const Fortran::semantics::SymbolRef &sym) {
250     return getHashValue(sym.get());
251   }
getHashValue(const Fortran::evaluate::Substring & x)252   static unsigned getHashValue(const Fortran::evaluate::Substring &x) {
253     return 61u *
254                Fortran::common::visit(
255                    [&](const auto &p) { return getHashValue(p); }, x.parent()) -
256            getHashValue(x.lower()) - (getHashValue(x.lower()) + 1u);
257   }
258   static unsigned
getHashValue(const Fortran::evaluate::StaticDataObject::Pointer & x)259   getHashValue(const Fortran::evaluate::StaticDataObject::Pointer &x) {
260     return llvm::hash_value(x->name());
261   }
getHashValue(const Fortran::evaluate::SpecificIntrinsic & x)262   static unsigned getHashValue(const Fortran::evaluate::SpecificIntrinsic &x) {
263     return llvm::hash_value(x.name);
264   }
265   template <typename A>
getHashValue(const Fortran::evaluate::Constant<A> & x)266   static unsigned getHashValue(const Fortran::evaluate::Constant<A> &x) {
267     // FIXME: Should hash the content.
268     return 103u;
269   }
getHashValue(const Fortran::evaluate::ActualArgument & x)270   static unsigned getHashValue(const Fortran::evaluate::ActualArgument &x) {
271     if (const Fortran::evaluate::Symbol *sym = x.GetAssumedTypeDummy())
272       return getHashValue(*sym);
273     return getHashValue(*x.UnwrapExpr());
274   }
275   static unsigned
getHashValue(const Fortran::evaluate::ProcedureDesignator & x)276   getHashValue(const Fortran::evaluate::ProcedureDesignator &x) {
277     return Fortran::common::visit(
278         [&](const auto &v) { return getHashValue(v); }, x.u);
279   }
getHashValue(const Fortran::evaluate::ProcedureRef & x)280   static unsigned getHashValue(const Fortran::evaluate::ProcedureRef &x) {
281     unsigned args = 13u;
282     for (const std::optional<Fortran::evaluate::ActualArgument> &v :
283          x.arguments())
284       args -= getHashValue(v);
285     return getHashValue(x.proc()) * 101u - args;
286   }
287   template <typename A>
288   static unsigned
getHashValue(const Fortran::evaluate::ArrayConstructor<A> & x)289   getHashValue(const Fortran::evaluate::ArrayConstructor<A> &x) {
290     // FIXME: hash the contents.
291     return 127u;
292   }
getHashValue(const Fortran::evaluate::ImpliedDoIndex & x)293   static unsigned getHashValue(const Fortran::evaluate::ImpliedDoIndex &x) {
294     return llvm::hash_value(toStringRef(x.name).str()) * 131u;
295   }
getHashValue(const Fortran::evaluate::TypeParamInquiry & x)296   static unsigned getHashValue(const Fortran::evaluate::TypeParamInquiry &x) {
297     return getHashValue(x.base()) * 137u - getHashValue(x.parameter()) * 3u;
298   }
getHashValue(const Fortran::evaluate::DescriptorInquiry & x)299   static unsigned getHashValue(const Fortran::evaluate::DescriptorInquiry &x) {
300     return getHashValue(x.base()) * 139u -
301            static_cast<unsigned>(x.field()) * 13u +
302            static_cast<unsigned>(x.dimension());
303   }
304   static unsigned
getHashValue(const Fortran::evaluate::StructureConstructor & x)305   getHashValue(const Fortran::evaluate::StructureConstructor &x) {
306     // FIXME: hash the contents.
307     return 149u;
308   }
309   template <int KIND>
getHashValue(const Fortran::evaluate::Not<KIND> & x)310   static unsigned getHashValue(const Fortran::evaluate::Not<KIND> &x) {
311     return getHashValue(x.left()) * 61u + static_cast<unsigned>(KIND);
312   }
313   template <int KIND>
314   static unsigned
getHashValue(const Fortran::evaluate::LogicalOperation<KIND> & x)315   getHashValue(const Fortran::evaluate::LogicalOperation<KIND> &x) {
316     unsigned result = getHashValue(x.left()) + getHashValue(x.right());
317     return result * 67u + static_cast<unsigned>(x.logicalOperator) * 5u;
318   }
319   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Relational<Fortran::evaluate::Type<TC,KIND>> & x)320   static unsigned getHashValue(
321       const Fortran::evaluate::Relational<Fortran::evaluate::Type<TC, KIND>>
322           &x) {
323     return (getHashValue(x.left()) + getHashValue(x.right())) * 71u +
324            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) +
325            static_cast<unsigned>(x.opr) * 11u;
326   }
327   template <typename A>
getHashValue(const Fortran::evaluate::Expr<A> & x)328   static unsigned getHashValue(const Fortran::evaluate::Expr<A> &x) {
329     return Fortran::common::visit(
330         [&](const auto &v) { return getHashValue(v); }, x.u);
331   }
getHashValue(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> & x)332   static unsigned getHashValue(
333       const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
334     return Fortran::common::visit(
335         [&](const auto &v) { return getHashValue(v); }, x.u);
336   }
337   template <typename A>
getHashValue(const Fortran::evaluate::Designator<A> & x)338   static unsigned getHashValue(const Fortran::evaluate::Designator<A> &x) {
339     return Fortran::common::visit(
340         [&](const auto &v) { return getHashValue(v); }, x.u);
341   }
342   template <int BITS>
343   static unsigned
getHashValue(const Fortran::evaluate::value::Integer<BITS> & x)344   getHashValue(const Fortran::evaluate::value::Integer<BITS> &x) {
345     return static_cast<unsigned>(x.ToSInt());
346   }
getHashValue(const Fortran::evaluate::NullPointer & x)347   static unsigned getHashValue(const Fortran::evaluate::NullPointer &x) {
348     return ~179u;
349   }
350 };
351 
// Define the equality test needed to use Fortran::evaluate::Expr values as
// llvm::DenseMap keys.
354 class IsEqualEvaluateExpr {
355 public:
356   // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an
357   // identity property.
isEqual(const Fortran::semantics::Symbol & x,const Fortran::semantics::Symbol & y)358   static bool isEqual(const Fortran::semantics::Symbol &x,
359                       const Fortran::semantics::Symbol &y) {
360     return isEqual(&x, &y);
361   }
isEqual(const Fortran::semantics::Symbol * x,const Fortran::semantics::Symbol * y)362   static bool isEqual(const Fortran::semantics::Symbol *x,
363                       const Fortran::semantics::Symbol *y) {
364     return x == y;
365   }
366   template <typename A, bool COPY>
isEqual(const Fortran::common::Indirection<A,COPY> & x,const Fortran::common::Indirection<A,COPY> & y)367   static bool isEqual(const Fortran::common::Indirection<A, COPY> &x,
368                       const Fortran::common::Indirection<A, COPY> &y) {
369     return isEqual(x.value(), y.value());
370   }
371   template <typename A>
isEqual(const std::optional<A> & x,const std::optional<A> & y)372   static bool isEqual(const std::optional<A> &x, const std::optional<A> &y) {
373     if (x.has_value() && y.has_value())
374       return isEqual(x.value(), y.value());
375     return !x.has_value() && !y.has_value();
376   }
377   template <typename A>
isEqual(const std::vector<A> & x,const std::vector<A> & y)378   static bool isEqual(const std::vector<A> &x, const std::vector<A> &y) {
379     if (x.size() != y.size())
380       return false;
381     const std::size_t size = x.size();
382     for (std::remove_const_t<decltype(size)> i = 0; i < size; ++i)
383       if (!isEqual(x[i], y[i]))
384         return false;
385     return true;
386   }
isEqual(const Fortran::evaluate::Subscript & x,const Fortran::evaluate::Subscript & y)387   static bool isEqual(const Fortran::evaluate::Subscript &x,
388                       const Fortran::evaluate::Subscript &y) {
389     return Fortran::common::visit(
390         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
391   }
isEqual(const Fortran::evaluate::Triplet & x,const Fortran::evaluate::Triplet & y)392   static bool isEqual(const Fortran::evaluate::Triplet &x,
393                       const Fortran::evaluate::Triplet &y) {
394     return isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()) &&
395            isEqual(x.stride(), y.stride());
396   }
isEqual(const Fortran::evaluate::Component & x,const Fortran::evaluate::Component & y)397   static bool isEqual(const Fortran::evaluate::Component &x,
398                       const Fortran::evaluate::Component &y) {
399     return isEqual(x.base(), y.base()) &&
400            isEqual(x.GetLastSymbol(), y.GetLastSymbol());
401   }
isEqual(const Fortran::evaluate::ArrayRef & x,const Fortran::evaluate::ArrayRef & y)402   static bool isEqual(const Fortran::evaluate::ArrayRef &x,
403                       const Fortran::evaluate::ArrayRef &y) {
404     return isEqual(x.base(), y.base()) && isEqual(x.subscript(), y.subscript());
405   }
isEqual(const Fortran::evaluate::CoarrayRef & x,const Fortran::evaluate::CoarrayRef & y)406   static bool isEqual(const Fortran::evaluate::CoarrayRef &x,
407                       const Fortran::evaluate::CoarrayRef &y) {
408     return isEqual(x.base(), y.base()) &&
409            isEqual(x.subscript(), y.subscript()) &&
410            isEqual(x.cosubscript(), y.cosubscript()) &&
411            isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team());
412   }
isEqual(const Fortran::evaluate::NamedEntity & x,const Fortran::evaluate::NamedEntity & y)413   static bool isEqual(const Fortran::evaluate::NamedEntity &x,
414                       const Fortran::evaluate::NamedEntity &y) {
415     if (x.IsSymbol() && y.IsSymbol())
416       return isEqual(x.GetFirstSymbol(), y.GetFirstSymbol());
417     return !x.IsSymbol() && !y.IsSymbol() &&
418            isEqual(x.GetComponent(), y.GetComponent());
419   }
isEqual(const Fortran::evaluate::DataRef & x,const Fortran::evaluate::DataRef & y)420   static bool isEqual(const Fortran::evaluate::DataRef &x,
421                       const Fortran::evaluate::DataRef &y) {
422     return Fortran::common::visit(
423         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
424   }
isEqual(const Fortran::evaluate::ComplexPart & x,const Fortran::evaluate::ComplexPart & y)425   static bool isEqual(const Fortran::evaluate::ComplexPart &x,
426                       const Fortran::evaluate::ComplexPart &y) {
427     return isEqual(x.complex(), y.complex()) && x.part() == y.part();
428   }
429   template <typename A, Fortran::common::TypeCategory TC2>
isEqual(const Fortran::evaluate::Convert<A,TC2> & x,const Fortran::evaluate::Convert<A,TC2> & y)430   static bool isEqual(const Fortran::evaluate::Convert<A, TC2> &x,
431                       const Fortran::evaluate::Convert<A, TC2> &y) {
432     return isEqual(x.left(), y.left());
433   }
434   template <int KIND>
isEqual(const Fortran::evaluate::ComplexComponent<KIND> & x,const Fortran::evaluate::ComplexComponent<KIND> & y)435   static bool isEqual(const Fortran::evaluate::ComplexComponent<KIND> &x,
436                       const Fortran::evaluate::ComplexComponent<KIND> &y) {
437     return isEqual(x.left(), y.left()) &&
438            x.isImaginaryPart == y.isImaginaryPart;
439   }
440   template <typename T>
isEqual(const Fortran::evaluate::Parentheses<T> & x,const Fortran::evaluate::Parentheses<T> & y)441   static bool isEqual(const Fortran::evaluate::Parentheses<T> &x,
442                       const Fortran::evaluate::Parentheses<T> &y) {
443     return isEqual(x.left(), y.left());
444   }
445   template <typename A>
isEqual(const Fortran::evaluate::Negate<A> & x,const Fortran::evaluate::Negate<A> & y)446   static bool isEqual(const Fortran::evaluate::Negate<A> &x,
447                       const Fortran::evaluate::Negate<A> &y) {
448     return isEqual(x.left(), y.left());
449   }
450   template <typename A>
isBinaryEqual(const A & x,const A & y)451   static bool isBinaryEqual(const A &x, const A &y) {
452     return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right());
453   }
454   template <typename A>
isEqual(const Fortran::evaluate::Add<A> & x,const Fortran::evaluate::Add<A> & y)455   static bool isEqual(const Fortran::evaluate::Add<A> &x,
456                       const Fortran::evaluate::Add<A> &y) {
457     return isBinaryEqual(x, y);
458   }
459   template <typename A>
isEqual(const Fortran::evaluate::Subtract<A> & x,const Fortran::evaluate::Subtract<A> & y)460   static bool isEqual(const Fortran::evaluate::Subtract<A> &x,
461                       const Fortran::evaluate::Subtract<A> &y) {
462     return isBinaryEqual(x, y);
463   }
464   template <typename A>
isEqual(const Fortran::evaluate::Multiply<A> & x,const Fortran::evaluate::Multiply<A> & y)465   static bool isEqual(const Fortran::evaluate::Multiply<A> &x,
466                       const Fortran::evaluate::Multiply<A> &y) {
467     return isBinaryEqual(x, y);
468   }
469   template <typename A>
isEqual(const Fortran::evaluate::Divide<A> & x,const Fortran::evaluate::Divide<A> & y)470   static bool isEqual(const Fortran::evaluate::Divide<A> &x,
471                       const Fortran::evaluate::Divide<A> &y) {
472     return isBinaryEqual(x, y);
473   }
474   template <typename A>
isEqual(const Fortran::evaluate::Power<A> & x,const Fortran::evaluate::Power<A> & y)475   static bool isEqual(const Fortran::evaluate::Power<A> &x,
476                       const Fortran::evaluate::Power<A> &y) {
477     return isBinaryEqual(x, y);
478   }
479   template <typename A>
isEqual(const Fortran::evaluate::Extremum<A> & x,const Fortran::evaluate::Extremum<A> & y)480   static bool isEqual(const Fortran::evaluate::Extremum<A> &x,
481                       const Fortran::evaluate::Extremum<A> &y) {
482     return isBinaryEqual(x, y);
483   }
484   template <typename A>
isEqual(const Fortran::evaluate::RealToIntPower<A> & x,const Fortran::evaluate::RealToIntPower<A> & y)485   static bool isEqual(const Fortran::evaluate::RealToIntPower<A> &x,
486                       const Fortran::evaluate::RealToIntPower<A> &y) {
487     return isBinaryEqual(x, y);
488   }
489   template <int KIND>
isEqual(const Fortran::evaluate::ComplexConstructor<KIND> & x,const Fortran::evaluate::ComplexConstructor<KIND> & y)490   static bool isEqual(const Fortran::evaluate::ComplexConstructor<KIND> &x,
491                       const Fortran::evaluate::ComplexConstructor<KIND> &y) {
492     return isBinaryEqual(x, y);
493   }
494   template <int KIND>
isEqual(const Fortran::evaluate::Concat<KIND> & x,const Fortran::evaluate::Concat<KIND> & y)495   static bool isEqual(const Fortran::evaluate::Concat<KIND> &x,
496                       const Fortran::evaluate::Concat<KIND> &y) {
497     return isBinaryEqual(x, y);
498   }
499   template <int KIND>
isEqual(const Fortran::evaluate::SetLength<KIND> & x,const Fortran::evaluate::SetLength<KIND> & y)500   static bool isEqual(const Fortran::evaluate::SetLength<KIND> &x,
501                       const Fortran::evaluate::SetLength<KIND> &y) {
502     return isBinaryEqual(x, y);
503   }
isEqual(const Fortran::semantics::SymbolRef & x,const Fortran::semantics::SymbolRef & y)504   static bool isEqual(const Fortran::semantics::SymbolRef &x,
505                       const Fortran::semantics::SymbolRef &y) {
506     return isEqual(x.get(), y.get());
507   }
isEqual(const Fortran::evaluate::Substring & x,const Fortran::evaluate::Substring & y)508   static bool isEqual(const Fortran::evaluate::Substring &x,
509                       const Fortran::evaluate::Substring &y) {
510     return Fortran::common::visit(
511                [&](const auto &p, const auto &q) { return isEqual(p, q); },
512                x.parent(), y.parent()) &&
513            isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper());
514   }
isEqual(const Fortran::evaluate::StaticDataObject::Pointer & x,const Fortran::evaluate::StaticDataObject::Pointer & y)515   static bool isEqual(const Fortran::evaluate::StaticDataObject::Pointer &x,
516                       const Fortran::evaluate::StaticDataObject::Pointer &y) {
517     return x->name() == y->name();
518   }
isEqual(const Fortran::evaluate::SpecificIntrinsic & x,const Fortran::evaluate::SpecificIntrinsic & y)519   static bool isEqual(const Fortran::evaluate::SpecificIntrinsic &x,
520                       const Fortran::evaluate::SpecificIntrinsic &y) {
521     return x.name == y.name;
522   }
523   template <typename A>
isEqual(const Fortran::evaluate::Constant<A> & x,const Fortran::evaluate::Constant<A> & y)524   static bool isEqual(const Fortran::evaluate::Constant<A> &x,
525                       const Fortran::evaluate::Constant<A> &y) {
526     return x == y;
527   }
isEqual(const Fortran::evaluate::ActualArgument & x,const Fortran::evaluate::ActualArgument & y)528   static bool isEqual(const Fortran::evaluate::ActualArgument &x,
529                       const Fortran::evaluate::ActualArgument &y) {
530     if (const Fortran::evaluate::Symbol *xs = x.GetAssumedTypeDummy()) {
531       if (const Fortran::evaluate::Symbol *ys = y.GetAssumedTypeDummy())
532         return isEqual(*xs, *ys);
533       return false;
534     }
535     return !y.GetAssumedTypeDummy() &&
536            isEqual(*x.UnwrapExpr(), *y.UnwrapExpr());
537   }
isEqual(const Fortran::evaluate::ProcedureDesignator & x,const Fortran::evaluate::ProcedureDesignator & y)538   static bool isEqual(const Fortran::evaluate::ProcedureDesignator &x,
539                       const Fortran::evaluate::ProcedureDesignator &y) {
540     return Fortran::common::visit(
541         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
542   }
isEqual(const Fortran::evaluate::ProcedureRef & x,const Fortran::evaluate::ProcedureRef & y)543   static bool isEqual(const Fortran::evaluate::ProcedureRef &x,
544                       const Fortran::evaluate::ProcedureRef &y) {
545     return isEqual(x.proc(), y.proc()) && isEqual(x.arguments(), y.arguments());
546   }
547   template <typename A>
isEqual(const Fortran::evaluate::ArrayConstructor<A> & x,const Fortran::evaluate::ArrayConstructor<A> & y)548   static bool isEqual(const Fortran::evaluate::ArrayConstructor<A> &x,
549                       const Fortran::evaluate::ArrayConstructor<A> &y) {
550     llvm::report_fatal_error("not implemented");
551   }
isEqual(const Fortran::evaluate::ImpliedDoIndex & x,const Fortran::evaluate::ImpliedDoIndex & y)552   static bool isEqual(const Fortran::evaluate::ImpliedDoIndex &x,
553                       const Fortran::evaluate::ImpliedDoIndex &y) {
554     return toStringRef(x.name) == toStringRef(y.name);
555   }
isEqual(const Fortran::evaluate::TypeParamInquiry & x,const Fortran::evaluate::TypeParamInquiry & y)556   static bool isEqual(const Fortran::evaluate::TypeParamInquiry &x,
557                       const Fortran::evaluate::TypeParamInquiry &y) {
558     return isEqual(x.base(), y.base()) && isEqual(x.parameter(), y.parameter());
559   }
isEqual(const Fortran::evaluate::DescriptorInquiry & x,const Fortran::evaluate::DescriptorInquiry & y)560   static bool isEqual(const Fortran::evaluate::DescriptorInquiry &x,
561                       const Fortran::evaluate::DescriptorInquiry &y) {
562     return isEqual(x.base(), y.base()) && x.field() == y.field() &&
563            x.dimension() == y.dimension();
564   }
isEqual(const Fortran::evaluate::StructureConstructor & x,const Fortran::evaluate::StructureConstructor & y)565   static bool isEqual(const Fortran::evaluate::StructureConstructor &x,
566                       const Fortran::evaluate::StructureConstructor &y) {
567     const auto &xValues = x.values();
568     const auto &yValues = y.values();
569     if (xValues.size() != yValues.size())
570       return false;
571     if (x.derivedTypeSpec() != y.derivedTypeSpec())
572       return false;
573     for (const auto &[xSymbol, xValue] : xValues) {
574       auto yIt = yValues.find(xSymbol);
575       // This should probably never happen, since the derived type
576       // should be the same.
577       if (yIt == yValues.end())
578         return false;
579       if (!isEqual(xValue, yIt->second))
580         return false;
581     }
582     return true;
583   }
584   template <int KIND>
isEqual(const Fortran::evaluate::Not<KIND> & x,const Fortran::evaluate::Not<KIND> & y)585   static bool isEqual(const Fortran::evaluate::Not<KIND> &x,
586                       const Fortran::evaluate::Not<KIND> &y) {
587     return isEqual(x.left(), y.left());
588   }
589   template <int KIND>
isEqual(const Fortran::evaluate::LogicalOperation<KIND> & x,const Fortran::evaluate::LogicalOperation<KIND> & y)590   static bool isEqual(const Fortran::evaluate::LogicalOperation<KIND> &x,
591                       const Fortran::evaluate::LogicalOperation<KIND> &y) {
592     return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right());
593   }
594   template <typename A>
isEqual(const Fortran::evaluate::Relational<A> & x,const Fortran::evaluate::Relational<A> & y)595   static bool isEqual(const Fortran::evaluate::Relational<A> &x,
596                       const Fortran::evaluate::Relational<A> &y) {
597     return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right());
598   }
599   template <typename A>
isEqual(const Fortran::evaluate::Expr<A> & x,const Fortran::evaluate::Expr<A> & y)600   static bool isEqual(const Fortran::evaluate::Expr<A> &x,
601                       const Fortran::evaluate::Expr<A> &y) {
602     return Fortran::common::visit(
603         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
604   }
605   static bool
isEqual(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> & x,const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> & y)606   isEqual(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x,
607           const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &y) {
608     return Fortran::common::visit(
609         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
610   }
611   template <typename A>
isEqual(const Fortran::evaluate::Designator<A> & x,const Fortran::evaluate::Designator<A> & y)612   static bool isEqual(const Fortran::evaluate::Designator<A> &x,
613                       const Fortran::evaluate::Designator<A> &y) {
614     return Fortran::common::visit(
615         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
616   }
617   template <int BITS>
isEqual(const Fortran::evaluate::value::Integer<BITS> & x,const Fortran::evaluate::value::Integer<BITS> & y)618   static bool isEqual(const Fortran::evaluate::value::Integer<BITS> &x,
619                       const Fortran::evaluate::value::Integer<BITS> &y) {
620     return x == y;
621   }
isEqual(const Fortran::evaluate::NullPointer & x,const Fortran::evaluate::NullPointer & y)622   static bool isEqual(const Fortran::evaluate::NullPointer &x,
623                       const Fortran::evaluate::NullPointer &y) {
624     return true;
625   }
/// Catch-all for operands of two different C++ types, e.g. mismatched
/// variant alternatives reached through a visit. Values of distinct types
/// are never considered equal. The enable_if keeps this overload out of
/// same-type comparisons.
template <typename A, typename B,
          std::enable_if_t<!std::is_same_v<A, B>, bool> = true>
static bool isEqual(const A &, const B &) {
  return false;
}
631 };
632 
getHashValue(const Fortran::lower::SomeExpr * x)633 static inline unsigned getHashValue(const Fortran::lower::SomeExpr *x) {
634   return HashEvaluateExpr::getHashValue(*x);
635 }
636 
637 static bool isEqual(const Fortran::lower::SomeExpr *x,
638                     const Fortran::lower::SomeExpr *y);
639 } // end namespace Fortran::lower
640 
641 // DenseMapInfo for pointers to Fortran::lower::SomeExpr.
642 namespace llvm {
643 template <>
644 struct DenseMapInfo<const Fortran::lower::SomeExpr *> {
645   static inline const Fortran::lower::SomeExpr *getEmptyKey() {
646     return reinterpret_cast<Fortran::lower::SomeExpr *>(~0);
647   }
648   static inline const Fortran::lower::SomeExpr *getTombstoneKey() {
649     return reinterpret_cast<Fortran::lower::SomeExpr *>(~0 - 1);
650   }
651   static unsigned getHashValue(const Fortran::lower::SomeExpr *v) {
652     return Fortran::lower::getHashValue(v);
653   }
654   static bool isEqual(const Fortran::lower::SomeExpr *lhs,
655                       const Fortran::lower::SomeExpr *rhs) {
656     return Fortran::lower::isEqual(lhs, rhs);
657   }
658 };
659 } // namespace llvm
660 
661 namespace Fortran::lower {
662 static inline bool isEqual(const Fortran::lower::SomeExpr *x,
663                            const Fortran::lower::SomeExpr *y) {
664   const auto *empty =
665       llvm::DenseMapInfo<const Fortran::lower::SomeExpr *>::getEmptyKey();
666   const auto *tombstone =
667       llvm::DenseMapInfo<const Fortran::lower::SomeExpr *>::getTombstoneKey();
668   if (x == empty || y == empty || x == tombstone || y == tombstone)
669     return x == y;
670   return x == y || IsEqualEvaluateExpr::isEqual(*x, *y);
671 }
672 } // end namespace Fortran::lower
673 
674 #endif // FORTRAN_LOWER_SUPPORT_UTILS_H
675