xref: /llvm-project/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp (revision 4b17a8b10ebb69d3bd30ee7714b5ca24f7e944dc)
16e193b5cSMats Petersson //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
26e193b5cSMats Petersson //
36e193b5cSMats Petersson // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
46e193b5cSMats Petersson // See https://llvm.org/LICENSE.txt for license information.
56e193b5cSMats Petersson // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66e193b5cSMats Petersson //
76e193b5cSMats Petersson //===----------------------------------------------------------------------===//
86e193b5cSMats Petersson 
96e193b5cSMats Petersson //===----------------------------------------------------------------------===//
106e193b5cSMats Petersson /// \file
116e193b5cSMats Petersson /// This pass looks for suitable calls to runtime library for intrinsics that
126e193b5cSMats Petersson /// can be simplified/specialized and replaces with a specialized function.
136e193b5cSMats Petersson ///
146e193b5cSMats Petersson /// For example, SUM(arr) can be specialized as a simple function with one loop,
156e193b5cSMats Petersson /// compared to the three arguments (plus file & line info) that the runtime
166e193b5cSMats Petersson /// call has - when the argument is a 1D-array (multiple loops may be needed
176e193b5cSMats Petersson /// for higher dimension arrays, of course).
186e193b5cSMats Petersson ///
196e193b5cSMats Petersson /// The general idea is that besides making the call simpler, it can also be
206e193b5cSMats Petersson /// inlined by other passes that run after this pass, which further improves
216e193b5cSMats Petersson /// performance, particularly when the work done in the function is trivial
226e193b5cSMats Petersson /// and small in size.
236e193b5cSMats Petersson //===----------------------------------------------------------------------===//
246e193b5cSMats Petersson 
25614cd721SSacha Ballantyne #include "flang/Common/Fortran.h"
266e193b5cSMats Petersson #include "flang/Optimizer/Builder/BoxValue.h"
27*4b17a8b1SValentin Clement (バレンタイン クレメン) #include "flang/Optimizer/Builder/CUFCommon.h"
286e193b5cSMats Petersson #include "flang/Optimizer/Builder/FIRBuilder.h"
29614cd721SSacha Ballantyne #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
306e193b5cSMats Petersson #include "flang/Optimizer/Builder/Todo.h"
316e193b5cSMats Petersson #include "flang/Optimizer/Dialect/FIROps.h"
326e193b5cSMats Petersson #include "flang/Optimizer/Dialect/FIRType.h"
33b07ef9e7SRenaud-K #include "flang/Optimizer/Dialect/Support/FIRContext.h"
3420fba03fSSacha Ballantyne #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
356e193b5cSMats Petersson #include "flang/Optimizer/Transforms/Passes.h"
36815a8465SDavid Green #include "flang/Optimizer/Transforms/Utils.h"
37aa94eb38SMats Petersson #include "flang/Runtime/entry-names.h"
3867d0d7acSMichele Scuttari #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
396e193b5cSMats Petersson #include "mlir/IR/Matchers.h"
40614cd721SSacha Ballantyne #include "mlir/IR/Operation.h"
416e193b5cSMats Petersson #include "mlir/Pass/Pass.h"
426e193b5cSMats Petersson #include "mlir/Transforms/DialectConversion.h"
436e193b5cSMats Petersson #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
446e193b5cSMats Petersson #include "mlir/Transforms/RegionUtils.h"
451d5e7a49SSlava Zakharin #include "llvm/Support/Debug.h"
4656eda98fSSlava Zakharin #include "llvm/Support/raw_ostream.h"
47614cd721SSacha Ballantyne #include <llvm/Support/ErrorHandling.h>
48614cd721SSacha Ballantyne #include <mlir/Dialect/Arith/IR/Arith.h>
49614cd721SSacha Ballantyne #include <mlir/IR/BuiltinTypes.h>
5020fba03fSSacha Ballantyne #include <mlir/IR/Location.h>
5120fba03fSSacha Ballantyne #include <mlir/IR/MLIRContext.h>
5220fba03fSSacha Ballantyne #include <mlir/IR/Value.h>
5320fba03fSSacha Ballantyne #include <mlir/Support/LLVM.h>
544d4d4785SKazu Hirata #include <optional>
551d5e7a49SSlava Zakharin 
5667d0d7acSMichele Scuttari namespace fir {
5767d0d7acSMichele Scuttari #define GEN_PASS_DEF_SIMPLIFYINTRINSICS
5867d0d7acSMichele Scuttari #include "flang/Optimizer/Transforms/Passes.h.inc"
5967d0d7acSMichele Scuttari } // namespace fir
6067d0d7acSMichele Scuttari 
611d5e7a49SSlava Zakharin #define DEBUG_TYPE "flang-simplify-intrinsics"
626e193b5cSMats Petersson 
636e193b5cSMats Petersson namespace {
646e193b5cSMats Petersson 
656e193b5cSMats Petersson class SimplifyIntrinsicsPass
6667d0d7acSMichele Scuttari     : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
6780dcc907SSlava Zakharin   using FunctionTypeGeneratorTy =
6843159b58SMats Petersson       llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
6980dcc907SSlava Zakharin   using FunctionBodyGeneratorTy =
7043159b58SMats Petersson       llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
7143159b58SMats Petersson   using GenReductionBodyTy = llvm::function_ref<void(
7220fba03fSSacha Ballantyne       fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank,
7320fba03fSSacha Ballantyne       mlir::Type elementType)>;
7480dcc907SSlava Zakharin 
756e193b5cSMats Petersson public:
7681442f8dSTom Eccles   using fir::impl::SimplifyIntrinsicsBase<
7781442f8dSTom Eccles       SimplifyIntrinsicsPass>::SimplifyIntrinsicsBase;
7881442f8dSTom Eccles 
7980dcc907SSlava Zakharin   /// Generate a new function implementing a simplified version
8080dcc907SSlava Zakharin   /// of a Fortran runtime function defined by \p basename name.
8180dcc907SSlava Zakharin   /// \p typeGenerator is a callback that generates the new function's type.
8280dcc907SSlava Zakharin   /// \p bodyGenerator is a callback that generates the new function's body.
8380dcc907SSlava Zakharin   /// The new function is created in the \p builder's Module.
8480dcc907SSlava Zakharin   mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
8580dcc907SSlava Zakharin                                          const mlir::StringRef &basename,
8680dcc907SSlava Zakharin                                          FunctionTypeGeneratorTy typeGenerator,
8780dcc907SSlava Zakharin                                          FunctionBodyGeneratorTy bodyGenerator);
886e193b5cSMats Petersson   void runOnOperation() override;
891d5e7a49SSlava Zakharin   void getDependentDialects(mlir::DialectRegistry &registry) const override;
9043159b58SMats Petersson 
9143159b58SMats Petersson private:
927d2e1987SSacha Ballantyne   /// Helper functions to replace a reduction type of call with its
9343159b58SMats Petersson   /// simplified form. The actual function is generated using a callback
9443159b58SMats Petersson   /// function.
9543159b58SMats Petersson   /// \p call is the call to be replaced
9643159b58SMats Petersson   /// \p kindMap is used to create FIROpBuilder
9743159b58SMats Petersson   /// \p genBodyFunc is the callback that builds the replacement function
987d2e1987SSacha Ballantyne   void simplifyIntOrFloatReduction(fir::CallOp call,
997d2e1987SSacha Ballantyne                                    const fir::KindMapping &kindMap,
10043159b58SMats Petersson                                    GenReductionBodyTy genBodyFunc);
10120fba03fSSacha Ballantyne   void simplifyLogicalDim0Reduction(fir::CallOp call,
10220fba03fSSacha Ballantyne                                     const fir::KindMapping &kindMap,
10320fba03fSSacha Ballantyne                                     GenReductionBodyTy genBodyFunc);
10420fba03fSSacha Ballantyne   void simplifyLogicalDim1Reduction(fir::CallOp call,
1057d2e1987SSacha Ballantyne                                     const fir::KindMapping &kindMap,
1067d2e1987SSacha Ballantyne                                     GenReductionBodyTy genBodyFunc);
1079bb47f7fSDavid Green   void simplifyMinMaxlocReduction(fir::CallOp call,
1089bb47f7fSDavid Green                                   const fir::KindMapping &kindMap, bool isMax);
1097d2e1987SSacha Ballantyne   void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap,
1107d2e1987SSacha Ballantyne                              GenReductionBodyTy genBodyFunc,
1117d2e1987SSacha Ballantyne                              fir::FirOpBuilder &builder,
11220fba03fSSacha Ballantyne                              const mlir::StringRef &basename,
11320fba03fSSacha Ballantyne                              mlir::Type elementType);
1146e193b5cSMats Petersson };
1156e193b5cSMats Petersson 
1166e193b5cSMats Petersson } // namespace
1176e193b5cSMats Petersson 
118ffe1661fSSlava Zakharin /// Create FirOpBuilder with the provided \p op insertion point
119ffe1661fSSlava Zakharin /// and \p kindMap additionally inheriting FastMathFlags from \p op.
120ffe1661fSSlava Zakharin static fir::FirOpBuilder
121ffe1661fSSlava Zakharin getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) {
122ffe1661fSSlava Zakharin   fir::FirOpBuilder builder{op, kindMap};
123ffe1661fSSlava Zakharin   auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
124ffe1661fSSlava Zakharin   if (!fmi)
125ffe1661fSSlava Zakharin     return builder;
126ffe1661fSSlava Zakharin 
127ffe1661fSSlava Zakharin   // Regardless of what default FastMathFlags are used by FirOpBuilder,
128ffe1661fSSlava Zakharin   // override them with FastMathFlags attached to the operation.
129ffe1661fSSlava Zakharin   builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
130ffe1661fSSlava Zakharin   return builder;
131ffe1661fSSlava Zakharin }
132ffe1661fSSlava Zakharin 
133aa94eb38SMats Petersson /// Generate function type for the simplified version of RTNAME(Sum) and
134afa520abSMats Petersson /// similar functions with a fir.box<none> type returning \p elementType.
135afa520abSMats Petersson static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
13680dcc907SSlava Zakharin                                          const mlir::Type &elementType) {
13780dcc907SSlava Zakharin   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
13880dcc907SSlava Zakharin   return mlir::FunctionType::get(builder.getContext(), {boxType},
13980dcc907SSlava Zakharin                                  {elementType});
14080dcc907SSlava Zakharin }
1416e193b5cSMats Petersson 
142614cd721SSacha Ballantyne template <typename Op>
143614cd721SSacha Ballantyne Op expectOp(mlir::Value val) {
144614cd721SSacha Ballantyne   if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp()))
145614cd721SSacha Ballantyne     return op;
146614cd721SSacha Ballantyne   LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName()
147614cd721SSacha Ballantyne                           << '\n');
148614cd721SSacha Ballantyne   return nullptr;
149614cd721SSacha Ballantyne }
150614cd721SSacha Ballantyne 
151614cd721SSacha Ballantyne template <typename Op>
152614cd721SSacha Ballantyne static mlir::Value findDefSingle(fir::ConvertOp op) {
153614cd721SSacha Ballantyne   if (auto defOp = expectOp<Op>(op->getOperand(0))) {
154614cd721SSacha Ballantyne     return defOp.getResult();
155614cd721SSacha Ballantyne   }
156614cd721SSacha Ballantyne   return {};
157614cd721SSacha Ballantyne }
158614cd721SSacha Ballantyne 
159614cd721SSacha Ballantyne template <typename... Ops>
160614cd721SSacha Ballantyne static mlir::Value findDef(fir::ConvertOp op) {
161614cd721SSacha Ballantyne   mlir::Value defOp;
162614cd721SSacha Ballantyne   // Loop over the operation types given to see if any match, exiting once
163614cd721SSacha Ballantyne   // a match is found. Cast to void is needed to avoid compiler complaining
164614cd721SSacha Ballantyne   // that the result of expression is unused
165614cd721SSacha Ballantyne   (void)((defOp = findDefSingle<Ops>(op), (defOp)) || ...);
166614cd721SSacha Ballantyne   return defOp;
167614cd721SSacha Ballantyne }
168614cd721SSacha Ballantyne 
169614cd721SSacha Ballantyne static bool isOperandAbsent(mlir::Value val) {
170614cd721SSacha Ballantyne   if (auto op = expectOp<fir::ConvertOp>(val)) {
171614cd721SSacha Ballantyne     assert(op->getOperands().size() != 0);
172614cd721SSacha Ballantyne     return mlir::isa_and_nonnull<fir::AbsentOp>(
173614cd721SSacha Ballantyne         op->getOperand(0).getDefiningOp());
174614cd721SSacha Ballantyne   }
175614cd721SSacha Ballantyne   return false;
176614cd721SSacha Ballantyne }
177614cd721SSacha Ballantyne 
178614cd721SSacha Ballantyne static bool isTrueOrNotConstant(mlir::Value val) {
179614cd721SSacha Ballantyne   if (auto op = expectOp<mlir::arith::ConstantOp>(val)) {
180614cd721SSacha Ballantyne     return !mlir::matchPattern(val, mlir::m_Zero());
181614cd721SSacha Ballantyne   }
182614cd721SSacha Ballantyne   return true;
183614cd721SSacha Ballantyne }
184614cd721SSacha Ballantyne 
185614cd721SSacha Ballantyne static bool isZero(mlir::Value val) {
186614cd721SSacha Ballantyne   if (auto op = expectOp<fir::ConvertOp>(val)) {
187614cd721SSacha Ballantyne     assert(op->getOperands().size() != 0);
188614cd721SSacha Ballantyne     if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
189614cd721SSacha Ballantyne       return mlir::matchPattern(defOp, mlir::m_Zero());
190614cd721SSacha Ballantyne   }
191614cd721SSacha Ballantyne   return false;
192614cd721SSacha Ballantyne }
193614cd721SSacha Ballantyne 
194614cd721SSacha Ballantyne static mlir::Value findBoxDef(mlir::Value val) {
195614cd721SSacha Ballantyne   if (auto op = expectOp<fir::ConvertOp>(val)) {
196614cd721SSacha Ballantyne     assert(op->getOperands().size() != 0);
197614cd721SSacha Ballantyne     return findDef<fir::EmboxOp, fir::ReboxOp>(op);
198614cd721SSacha Ballantyne   }
199614cd721SSacha Ballantyne   return {};
200614cd721SSacha Ballantyne }
201614cd721SSacha Ballantyne 
202614cd721SSacha Ballantyne static mlir::Value findMaskDef(mlir::Value val) {
203614cd721SSacha Ballantyne   if (auto op = expectOp<fir::ConvertOp>(val)) {
204614cd721SSacha Ballantyne     assert(op->getOperands().size() != 0);
205614cd721SSacha Ballantyne     return findDef<fir::EmboxOp, fir::ReboxOp, fir::AbsentOp>(op);
206614cd721SSacha Ballantyne   }
207614cd721SSacha Ballantyne   return {};
208614cd721SSacha Ballantyne }
209614cd721SSacha Ballantyne 
210614cd721SSacha Ballantyne static unsigned getDimCount(mlir::Value val) {
211614cd721SSacha Ballantyne   // In order to find the dimensions count, we look for EmboxOp/ReboxOp
212614cd721SSacha Ballantyne   // and take the count from its *result* type. Note that in case
213614cd721SSacha Ballantyne   // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
214614cd721SSacha Ballantyne   // have different types.
215614cd721SSacha Ballantyne   // Actually, we can take the box type from the operand of
216614cd721SSacha Ballantyne   // the first ConvertOp that has non-opaque box type that we meet
217614cd721SSacha Ballantyne   // going through the ConvertOp chain.
218614cd721SSacha Ballantyne   if (mlir::Value emboxVal = findBoxDef(val))
219fac349a1SChristian Sigg     if (auto boxTy = mlir::dyn_cast<fir::BoxType>(emboxVal.getType()))
220fac349a1SChristian Sigg       if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
221614cd721SSacha Ballantyne         return seqTy.getDimension();
222614cd721SSacha Ballantyne   return 0;
223614cd721SSacha Ballantyne }
224614cd721SSacha Ballantyne 
225614cd721SSacha Ballantyne /// Given the call operation's box argument \p val, discover
226614cd721SSacha Ballantyne /// the element type of the underlying array object.
227614cd721SSacha Ballantyne /// \returns the element type or std::nullopt if the type cannot
228614cd721SSacha Ballantyne /// be reliably found.
229614cd721SSacha Ballantyne /// We expect that the argument is a result of fir.convert
230614cd721SSacha Ballantyne /// with the destination type of !fir.box<none>.
231614cd721SSacha Ballantyne static std::optional<mlir::Type> getArgElementType(mlir::Value val) {
232614cd721SSacha Ballantyne   mlir::Operation *defOp;
233614cd721SSacha Ballantyne   do {
234614cd721SSacha Ballantyne     defOp = val.getDefiningOp();
235614cd721SSacha Ballantyne     // Analyze only sequences of convert operations.
236614cd721SSacha Ballantyne     if (!mlir::isa<fir::ConvertOp>(defOp))
237614cd721SSacha Ballantyne       return std::nullopt;
238614cd721SSacha Ballantyne     val = defOp->getOperand(0);
239614cd721SSacha Ballantyne     // The convert operation is expected to convert from one
240614cd721SSacha Ballantyne     // box type to another box type.
241fac349a1SChristian Sigg     auto boxType = mlir::cast<fir::BoxType>(val.getType());
242614cd721SSacha Ballantyne     auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
243fac349a1SChristian Sigg     if (!mlir::isa<mlir::NoneType>(elementType))
244614cd721SSacha Ballantyne       return elementType;
245614cd721SSacha Ballantyne   } while (true);
246614cd721SSacha Ballantyne }
247614cd721SSacha Ballantyne 
24843159b58SMats Petersson using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
24943159b58SMats Petersson     fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
25043159b58SMats Petersson     mlir::Value)>;
25120fba03fSSacha Ballantyne using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>(
25220fba03fSSacha Ballantyne     fir::FirOpBuilder &, mlir::Location, mlir::Value)>;
253afa520abSMats Petersson 
254afa520abSMats Petersson /// Generate the reduction loop into \p funcOp.
255afa520abSMats Petersson ///
256afa520abSMats Petersson /// \p initVal is a function, called to get the initial value for
257afa520abSMats Petersson ///    the reduction value
258afa520abSMats Petersson /// \p genBody is called to fill in the actual reduciton operation
259afa520abSMats Petersson ///    for example add for SUM, MAX for MAXVAL, etc.
2608bd76ac1SSlava Zakharin /// \p rank is the rank of the input argument.
26120fba03fSSacha Ballantyne /// \p elementType is the type of the elements in the input array,
26220fba03fSSacha Ballantyne ///    which may be different to the return type.
26320fba03fSSacha Ballantyne /// \p loopCond is called to generate the condition to continue or
26420fba03fSSacha Ballantyne ///    not for IterWhile loops
26520fba03fSSacha Ballantyne /// \p unorderedOrInitalLoopCond contains either a boolean or bool
26620fba03fSSacha Ballantyne ///    mlir constant, and controls the inital value for while loops
26720fba03fSSacha Ballantyne ///    or if DoLoop is ordered/unordered.
26820fba03fSSacha Ballantyne 
26920fba03fSSacha Ballantyne template <typename OP, typename T, int resultIndex>
27020fba03fSSacha Ballantyne static void
27120fba03fSSacha Ballantyne genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
272223d3dabSDavid Green                  fir::InitValGeneratorTy initVal, ContinueLoopGenTy loopCond,
27320fba03fSSacha Ballantyne                  T unorderedOrInitialLoopCond, BodyOpGeneratorTy genBody,
27420fba03fSSacha Ballantyne                  unsigned rank, mlir::Type elementType, mlir::Location loc) {
2756e193b5cSMats Petersson 
2766e193b5cSMats Petersson   mlir::IndexType idxTy = builder.getIndexType();
2776e193b5cSMats Petersson 
27880dcc907SSlava Zakharin   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
2796e193b5cSMats Petersson   mlir::Value arg = args[0];
2806e193b5cSMats Petersson 
2816e193b5cSMats Petersson   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
2826e193b5cSMats Petersson 
2838bd76ac1SSlava Zakharin   fir::SequenceType::Shape flatShape(rank,
2848bd76ac1SSlava Zakharin                                      fir::SequenceType::getUnknownExtent());
28580dcc907SSlava Zakharin   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
2866e193b5cSMats Petersson   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
2876e193b5cSMats Petersson   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
2887d2e1987SSacha Ballantyne   mlir::Type resultType = funcOp.getResultTypes()[0];
2897d2e1987SSacha Ballantyne   mlir::Value init = initVal(builder, loc, resultType);
2906e193b5cSMats Petersson 
291614cd721SSacha Ballantyne   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
2928bd76ac1SSlava Zakharin 
2938bd76ac1SSlava Zakharin   assert(rank > 0 && "rank cannot be zero");
2948bd76ac1SSlava Zakharin   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
2958bd76ac1SSlava Zakharin 
2968bd76ac1SSlava Zakharin   // Compute all the upper bounds before the loop nest.
2978bd76ac1SSlava Zakharin   // It is not strictly necessary for performance, since the loop nest
2988bd76ac1SSlava Zakharin   // does not have any store operations and any LICM optimization
2998bd76ac1SSlava Zakharin   // should be able to optimize the redundancy.
3008bd76ac1SSlava Zakharin   for (unsigned i = 0; i < rank; ++i) {
3018bd76ac1SSlava Zakharin     mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
3028bd76ac1SSlava Zakharin     auto dims =
3038bd76ac1SSlava Zakharin         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
3048bd76ac1SSlava Zakharin     mlir::Value len = dims.getResult(1);
3056e193b5cSMats Petersson     // We use C indexing here, so len-1 as loopcount
3066e193b5cSMats Petersson     mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
3078bd76ac1SSlava Zakharin     bounds.push_back(loopCount);
3088bd76ac1SSlava Zakharin   }
30920fba03fSSacha Ballantyne   // Create a loop nest consisting of OP operations.
3108bd76ac1SSlava Zakharin   // Collect the loops' induction variables into indices array,
3118bd76ac1SSlava Zakharin   // which will be used in the innermost loop to load the input
3128bd76ac1SSlava Zakharin   // array's element.
3138bd76ac1SSlava Zakharin   // The loops are generated such that the innermost loop processes
3148bd76ac1SSlava Zakharin   // the 0 dimension.
315614cd721SSacha Ballantyne   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
3168bd76ac1SSlava Zakharin   for (unsigned i = rank; 0 < i; --i) {
3178bd76ac1SSlava Zakharin     mlir::Value step = one;
3188bd76ac1SSlava Zakharin     mlir::Value loopCount = bounds[i - 1];
31920fba03fSSacha Ballantyne     auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step,
32020fba03fSSacha Ballantyne                                    unorderedOrInitialLoopCond,
321afa520abSMats Petersson                                    /*finalCountValue=*/false, init);
32298ecc3acSSacha Ballantyne     init = loop.getRegionIterArgs()[resultIndex];
3238bd76ac1SSlava Zakharin     indices.push_back(loop.getInductionVar());
3248bd76ac1SSlava Zakharin     // Set insertion point to the loop body so that the next loop
3258bd76ac1SSlava Zakharin     // is inserted inside the current one.
3266e193b5cSMats Petersson     builder.setInsertionPointToStart(loop.getBody());
3278bd76ac1SSlava Zakharin   }
3286e193b5cSMats Petersson 
3298bd76ac1SSlava Zakharin   // Reverse the indices such that they are ordered as:
3308bd76ac1SSlava Zakharin   //   <dim-0-idx, dim-1-idx, ...>
3318bd76ac1SSlava Zakharin   std::reverse(indices.begin(), indices.end());
3328bd76ac1SSlava Zakharin   // We are in the innermost loop: generate the reduction body.
33380dcc907SSlava Zakharin   mlir::Type eleRefTy = builder.getRefType(elementType);
3346e193b5cSMats Petersson   mlir::Value addr =
3358bd76ac1SSlava Zakharin       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
3366e193b5cSMats Petersson   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
3378bd76ac1SSlava Zakharin   mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
33820fba03fSSacha Ballantyne   // Generate vector with condition to continue while loop at [0] and result
33920fba03fSSacha Ballantyne   // from current loop at [1] for IterWhileOp loops, just result at [0] for
34020fba03fSSacha Ballantyne   // DoLoopOp loops.
34120fba03fSSacha Ballantyne   llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal);
3426e193b5cSMats Petersson 
3438bd76ac1SSlava Zakharin   // Unwind the loop nest and insert ResultOp on each level
3448bd76ac1SSlava Zakharin   // to return the updated value of the reduction to the enclosing
3458bd76ac1SSlava Zakharin   // loops.
3468bd76ac1SSlava Zakharin   for (unsigned i = 0; i < rank; ++i) {
34720fba03fSSacha Ballantyne     auto result = builder.create<fir::ResultOp>(loc, results);
3488bd76ac1SSlava Zakharin     // Proceed to the outer loop.
34920fba03fSSacha Ballantyne     auto loop = mlir::cast<OP>(result->getParentOp());
35020fba03fSSacha Ballantyne     results = loop.getResults();
3518bd76ac1SSlava Zakharin     // Set insertion point after the loop operation that we have
3528bd76ac1SSlava Zakharin     // just processed.
3538bd76ac1SSlava Zakharin     builder.setInsertionPointAfter(loop.getOperation());
3548bd76ac1SSlava Zakharin   }
3558bd76ac1SSlava Zakharin   // End of loop nest. The insertion point is after the outermost loop.
3568bd76ac1SSlava Zakharin   // Return the reduction value from the function.
35720fba03fSSacha Ballantyne   builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
35820fba03fSSacha Ballantyne }
359614cd721SSacha Ballantyne 
36020fba03fSSacha Ballantyne static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder,
361614cd721SSacha Ballantyne                                                   mlir::Location loc,
36220fba03fSSacha Ballantyne                                                   mlir::Value reductionVal) {
36320fba03fSSacha Ballantyne   return {reductionVal};
36480dcc907SSlava Zakharin }
36580dcc907SSlava Zakharin 
366aa94eb38SMats Petersson /// Generate function body of the simplified version of RTNAME(Sum)
367afa520abSMats Petersson /// with signature provided by \p funcOp. The caller is responsible
368afa520abSMats Petersson /// for saving/restoring the original insertion point of \p builder.
369afa520abSMats Petersson /// \p funcOp is expected to be empty on entry to this function.
3708bd76ac1SSlava Zakharin /// \p rank specifies the rank of the input argument.
371aa94eb38SMats Petersson static void genRuntimeSumBody(fir::FirOpBuilder &builder,
37220fba03fSSacha Ballantyne                               mlir::func::FuncOp &funcOp, unsigned rank,
37320fba03fSSacha Ballantyne                               mlir::Type elementType) {
3748bd76ac1SSlava Zakharin   // function RTNAME(Sum)<T>x<rank>_simplified(arr)
375afa520abSMats Petersson   //   T, dimension(:) :: arr
376afa520abSMats Petersson   //   T sum = 0
377afa520abSMats Petersson   //   integer iter
378afa520abSMats Petersson   //   do iter = 0, extent(arr)
379afa520abSMats Petersson   //     sum = sum + arr[iter]
380afa520abSMats Petersson   //   end do
3818bd76ac1SSlava Zakharin   //   RTNAME(Sum)<T>x<rank>_simplified = sum
3828bd76ac1SSlava Zakharin   // end function RTNAME(Sum)<T>x<rank>_simplified
383afa520abSMats Petersson   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
384afa520abSMats Petersson                  mlir::Type elementType) {
385fac349a1SChristian Sigg     if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
3862b138567SSlava Zakharin       const llvm::fltSemantics &sem = ty.getFloatSemantics();
3872b138567SSlava Zakharin       return builder.createRealConstant(loc, elementType,
3882b138567SSlava Zakharin                                         llvm::APFloat::getZero(sem));
3892b138567SSlava Zakharin     }
3902b138567SSlava Zakharin     return builder.createIntegerConstant(loc, elementType, 0);
391afa520abSMats Petersson   };
392afa520abSMats Petersson 
393afa520abSMats Petersson   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
394afa520abSMats Petersson                       mlir::Type elementType, mlir::Value elem1,
395afa520abSMats Petersson                       mlir::Value elem2) -> mlir::Value {
396fac349a1SChristian Sigg     if (mlir::isa<mlir::FloatType>(elementType))
397afa520abSMats Petersson       return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
398fac349a1SChristian Sigg     if (mlir::isa<mlir::IntegerType>(elementType))
399afa520abSMats Petersson       return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
400afa520abSMats Petersson 
401afa520abSMats Petersson     llvm_unreachable("unsupported type");
402afa520abSMats Petersson     return {};
403afa520abSMats Petersson   };
404afa520abSMats Petersson 
40520fba03fSSacha Ballantyne   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
40620fba03fSSacha Ballantyne   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
4077d2e1987SSacha Ballantyne 
40820fba03fSSacha Ballantyne   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
40920fba03fSSacha Ballantyne                                            false, genBodyOp, rank, elementType,
41020fba03fSSacha Ballantyne                                            loc);
411afa520abSMats Petersson }
412afa520abSMats Petersson 
413aa94eb38SMats Petersson static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
41420fba03fSSacha Ballantyne                                  mlir::func::FuncOp &funcOp, unsigned rank,
41520fba03fSSacha Ballantyne                                  mlir::Type elementType) {
416afa520abSMats Petersson   auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
417afa520abSMats Petersson                  mlir::Type elementType) {
418fac349a1SChristian Sigg     if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
419afa520abSMats Petersson       const llvm::fltSemantics &sem = ty.getFloatSemantics();
420afa520abSMats Petersson       return builder.createRealConstant(
421afa520abSMats Petersson           loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
422afa520abSMats Petersson     }
423afa520abSMats Petersson     unsigned bits = elementType.getIntOrFloatBitWidth();
424afa520abSMats Petersson     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
425afa520abSMats Petersson     return builder.createIntegerConstant(loc, elementType, minInt);
426afa520abSMats Petersson   };
427afa520abSMats Petersson 
428afa520abSMats Petersson   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
429afa520abSMats Petersson                       mlir::Type elementType, mlir::Value elem1,
430afa520abSMats Petersson                       mlir::Value elem2) -> mlir::Value {
431fac349a1SChristian Sigg     if (mlir::isa<mlir::FloatType>(elementType)) {
43289b98c13SSlava Zakharin       // arith.maxf later converted to llvm.intr.maxnum does not work
43389b98c13SSlava Zakharin       // correctly for NaNs and -0.0 (see maxnum/minnum pattern matching
43489b98c13SSlava Zakharin       // in LLVM's InstCombine pass). Moreover, llvm.intr.maxnum
43589b98c13SSlava Zakharin       // for F128 operands is lowered into fmaxl call by LLVM.
43689b98c13SSlava Zakharin       // This libm function may not work properly for F128 arguments
43789b98c13SSlava Zakharin       // on targets where long double is not F128. It is an LLVM issue,
43889b98c13SSlava Zakharin       // but we just use normal select here to resolve all the cases.
43989b98c13SSlava Zakharin       auto compare = builder.create<mlir::arith::CmpFOp>(
44089b98c13SSlava Zakharin           loc, mlir::arith::CmpFPredicate::OGT, elem1, elem2);
44189b98c13SSlava Zakharin       return builder.create<mlir::arith::SelectOp>(loc, compare, elem1, elem2);
44289b98c13SSlava Zakharin     }
443fac349a1SChristian Sigg     if (mlir::isa<mlir::IntegerType>(elementType))
444afa520abSMats Petersson       return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
445afa520abSMats Petersson 
446afa520abSMats Petersson     llvm_unreachable("unsupported type");
447afa520abSMats Petersson     return {};
448afa520abSMats Petersson   };
4497d2e1987SSacha Ballantyne 
45020fba03fSSacha Ballantyne   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
45120fba03fSSacha Ballantyne   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
4527d2e1987SSacha Ballantyne 
45320fba03fSSacha Ballantyne   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, init, nopLoopCond,
45420fba03fSSacha Ballantyne                                            false, genBodyOp, rank, elementType,
45520fba03fSSacha Ballantyne                                            loc);
4567d2e1987SSacha Ballantyne }
4577d2e1987SSacha Ballantyne 
4587d2e1987SSacha Ballantyne static void genRuntimeCountBody(fir::FirOpBuilder &builder,
45920fba03fSSacha Ballantyne                                 mlir::func::FuncOp &funcOp, unsigned rank,
46020fba03fSSacha Ballantyne                                 mlir::Type elementType) {
4617d2e1987SSacha Ballantyne   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
4627d2e1987SSacha Ballantyne                  mlir::Type elementType) {
4637d2e1987SSacha Ballantyne     unsigned bits = elementType.getIntOrFloatBitWidth();
4647d2e1987SSacha Ballantyne     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
4657d2e1987SSacha Ballantyne     return builder.createIntegerConstant(loc, elementType, zeroInt);
4667d2e1987SSacha Ballantyne   };
4677d2e1987SSacha Ballantyne 
4687d2e1987SSacha Ballantyne   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
4697d2e1987SSacha Ballantyne                       mlir::Type elementType, mlir::Value elem1,
4707d2e1987SSacha Ballantyne                       mlir::Value elem2) -> mlir::Value {
47179dccdedSSacha Ballantyne     auto zero32 = builder.createIntegerConstant(loc, elementType, 0);
4727d2e1987SSacha Ballantyne     auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0);
4737d2e1987SSacha Ballantyne     auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1);
4747d2e1987SSacha Ballantyne 
4757d2e1987SSacha Ballantyne     auto compare = builder.create<mlir::arith::CmpIOp>(
4767d2e1987SSacha Ballantyne         loc, mlir::arith::CmpIPredicate::eq, elem1, zero32);
4777d2e1987SSacha Ballantyne     auto select =
4787d2e1987SSacha Ballantyne         builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64);
4797d2e1987SSacha Ballantyne     return builder.create<mlir::arith::AddIOp>(loc, select, elem2);
4807d2e1987SSacha Ballantyne   };
4817d2e1987SSacha Ballantyne 
48220fba03fSSacha Ballantyne   // Count always gets I32 for elementType as it converts logical input to
48320fba03fSSacha Ballantyne   // logical<4> before passing to the function.
48420fba03fSSacha Ballantyne   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
48520fba03fSSacha Ballantyne   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
4867d2e1987SSacha Ballantyne 
48720fba03fSSacha Ballantyne   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
48820fba03fSSacha Ballantyne                                            false, genBodyOp, rank, elementType,
48920fba03fSSacha Ballantyne                                            loc);
49020fba03fSSacha Ballantyne }
49120fba03fSSacha Ballantyne 
49220fba03fSSacha Ballantyne static void genRuntimeAnyBody(fir::FirOpBuilder &builder,
49320fba03fSSacha Ballantyne                               mlir::func::FuncOp &funcOp, unsigned rank,
49420fba03fSSacha Ballantyne                               mlir::Type elementType) {
49520fba03fSSacha Ballantyne   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
49620fba03fSSacha Ballantyne                  mlir::Type elementType) {
49720fba03fSSacha Ballantyne     return builder.createIntegerConstant(loc, elementType, 0);
49820fba03fSSacha Ballantyne   };
49920fba03fSSacha Ballantyne 
50020fba03fSSacha Ballantyne   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
50120fba03fSSacha Ballantyne                       mlir::Type elementType, mlir::Value elem1,
50220fba03fSSacha Ballantyne                       mlir::Value elem2) -> mlir::Value {
50320fba03fSSacha Ballantyne     auto zero = builder.createIntegerConstant(loc, elementType, 0);
50420fba03fSSacha Ballantyne     return builder.create<mlir::arith::CmpIOp>(
50520fba03fSSacha Ballantyne         loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
50620fba03fSSacha Ballantyne   };
50720fba03fSSacha Ballantyne 
50820fba03fSSacha Ballantyne   auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
50920fba03fSSacha Ballantyne                          mlir::Value reductionVal) {
51020fba03fSSacha Ballantyne     auto one1 = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
51120fba03fSSacha Ballantyne     auto eor = builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1);
51220fba03fSSacha Ballantyne     llvm::SmallVector<mlir::Value> results = {eor, reductionVal};
51320fba03fSSacha Ballantyne     return results;
51420fba03fSSacha Ballantyne   };
51520fba03fSSacha Ballantyne 
51620fba03fSSacha Ballantyne   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
51720fba03fSSacha Ballantyne   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
51820fba03fSSacha Ballantyne   mlir::Value ok = builder.createBool(loc, true);
51920fba03fSSacha Ballantyne 
52020fba03fSSacha Ballantyne   genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
52120fba03fSSacha Ballantyne       builder, funcOp, zero, continueCond, ok, genBodyOp, rank, elementType,
52220fba03fSSacha Ballantyne       loc);
52320fba03fSSacha Ballantyne }
52420fba03fSSacha Ballantyne 
52520fba03fSSacha Ballantyne static void genRuntimeAllBody(fir::FirOpBuilder &builder,
52620fba03fSSacha Ballantyne                               mlir::func::FuncOp &funcOp, unsigned rank,
52720fba03fSSacha Ballantyne                               mlir::Type elementType) {
52820fba03fSSacha Ballantyne   auto one = [](fir::FirOpBuilder builder, mlir::Location loc,
52920fba03fSSacha Ballantyne                 mlir::Type elementType) {
53020fba03fSSacha Ballantyne     return builder.createIntegerConstant(loc, elementType, 1);
53120fba03fSSacha Ballantyne   };
53220fba03fSSacha Ballantyne 
53320fba03fSSacha Ballantyne   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
53420fba03fSSacha Ballantyne                       mlir::Type elementType, mlir::Value elem1,
53520fba03fSSacha Ballantyne                       mlir::Value elem2) -> mlir::Value {
53620fba03fSSacha Ballantyne     auto zero = builder.createIntegerConstant(loc, elementType, 0);
53720fba03fSSacha Ballantyne     return builder.create<mlir::arith::CmpIOp>(
53820fba03fSSacha Ballantyne         loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
53920fba03fSSacha Ballantyne   };
54020fba03fSSacha Ballantyne 
54120fba03fSSacha Ballantyne   auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
54220fba03fSSacha Ballantyne                          mlir::Value reductionVal) {
54320fba03fSSacha Ballantyne     llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal};
54420fba03fSSacha Ballantyne     return results;
54520fba03fSSacha Ballantyne   };
54620fba03fSSacha Ballantyne 
54720fba03fSSacha Ballantyne   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
54820fba03fSSacha Ballantyne   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
54920fba03fSSacha Ballantyne   mlir::Value ok = builder.createBool(loc, true);
55020fba03fSSacha Ballantyne 
55120fba03fSSacha Ballantyne   genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
55220fba03fSSacha Ballantyne       builder, funcOp, one, continueCond, ok, genBodyOp, rank, elementType,
55320fba03fSSacha Ballantyne       loc);
554afa520abSMats Petersson }
555afa520abSMats Petersson 
556614cd721SSacha Ballantyne static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
557614cd721SSacha Ballantyne                                                unsigned int rank) {
558614cd721SSacha Ballantyne   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
559614cd721SSacha Ballantyne   mlir::Type boxRefType = builder.getRefType(boxType);
560614cd721SSacha Ballantyne 
561614cd721SSacha Ballantyne   return mlir::FunctionType::get(builder.getContext(),
562614cd721SSacha Ballantyne                                  {boxRefType, boxType, boxType}, {});
563614cd721SSacha Ballantyne }
564614cd721SSacha Ballantyne 
565815a8465SDavid Green // Produces a loop nest for a Minloc intrinsic.
566815a8465SDavid Green void fir::genMinMaxlocReductionLoop(
567815a8465SDavid Green     fir::FirOpBuilder &builder, mlir::Value array,
568815a8465SDavid Green     fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody,
569815a8465SDavid Green     fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType,
570815a8465SDavid Green     mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr,
571815a8465SDavid Green     bool maskMayBeLogicalScalar) {
572815a8465SDavid Green   mlir::IndexType idxTy = builder.getIndexType();
573815a8465SDavid Green 
574815a8465SDavid Green   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
575815a8465SDavid Green 
576815a8465SDavid Green   fir::SequenceType::Shape flatShape(rank,
577815a8465SDavid Green                                      fir::SequenceType::getUnknownExtent());
578815a8465SDavid Green   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
579815a8465SDavid Green   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
580815a8465SDavid Green   array = builder.create<fir::ConvertOp>(loc, boxArrTy, array);
581815a8465SDavid Green 
582815a8465SDavid Green   mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
583815a8465SDavid Green   mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
584815a8465SDavid Green   mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
585815a8465SDavid Green   mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
586815a8465SDavid Green   builder.create<fir::StoreOp>(loc, zero, flagRef);
587815a8465SDavid Green 
588815a8465SDavid Green   mlir::Value init = initVal(builder, loc, elementType);
589815a8465SDavid Green   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
590815a8465SDavid Green 
591815a8465SDavid Green   assert(rank > 0 && "rank cannot be zero");
592815a8465SDavid Green   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
593815a8465SDavid Green 
594815a8465SDavid Green   // Compute all the upper bounds before the loop nest.
595815a8465SDavid Green   // It is not strictly necessary for performance, since the loop nest
596815a8465SDavid Green   // does not have any store operations and any LICM optimization
597815a8465SDavid Green   // should be able to optimize the redundancy.
598815a8465SDavid Green   for (unsigned i = 0; i < rank; ++i) {
599815a8465SDavid Green     mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
600815a8465SDavid Green     auto dims =
601815a8465SDavid Green         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
602815a8465SDavid Green     mlir::Value len = dims.getResult(1);
603815a8465SDavid Green     // We use C indexing here, so len-1 as loopcount
604815a8465SDavid Green     mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
605815a8465SDavid Green     bounds.push_back(loopCount);
606815a8465SDavid Green   }
607815a8465SDavid Green   // Create a loop nest consisting of OP operations.
608815a8465SDavid Green   // Collect the loops' induction variables into indices array,
609815a8465SDavid Green   // which will be used in the innermost loop to load the input
610815a8465SDavid Green   // array's element.
611815a8465SDavid Green   // The loops are generated such that the innermost loop processes
612815a8465SDavid Green   // the 0 dimension.
613815a8465SDavid Green   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
614815a8465SDavid Green   for (unsigned i = rank; 0 < i; --i) {
615815a8465SDavid Green     mlir::Value step = one;
616815a8465SDavid Green     mlir::Value loopCount = bounds[i - 1];
617815a8465SDavid Green     auto loop =
618815a8465SDavid Green         builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
619815a8465SDavid Green                                       /*finalCountValue=*/false, init);
620815a8465SDavid Green     init = loop.getRegionIterArgs()[0];
621815a8465SDavid Green     indices.push_back(loop.getInductionVar());
622815a8465SDavid Green     // Set insertion point to the loop body so that the next loop
623815a8465SDavid Green     // is inserted inside the current one.
624815a8465SDavid Green     builder.setInsertionPointToStart(loop.getBody());
625815a8465SDavid Green   }
626815a8465SDavid Green 
627815a8465SDavid Green   // Reverse the indices such that they are ordered as:
628815a8465SDavid Green   //   <dim-0-idx, dim-1-idx, ...>
629815a8465SDavid Green   std::reverse(indices.begin(), indices.end());
630815a8465SDavid Green   mlir::Value reductionVal =
631815a8465SDavid Green       genBody(builder, loc, elementType, array, flagRef, init, indices);
632815a8465SDavid Green 
633815a8465SDavid Green   // Unwind the loop nest and insert ResultOp on each level
634815a8465SDavid Green   // to return the updated value of the reduction to the enclosing
635815a8465SDavid Green   // loops.
636815a8465SDavid Green   for (unsigned i = 0; i < rank; ++i) {
637815a8465SDavid Green     auto result = builder.create<fir::ResultOp>(loc, reductionVal);
638815a8465SDavid Green     // Proceed to the outer loop.
639815a8465SDavid Green     auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
640815a8465SDavid Green     reductionVal = loop.getResult(0);
641815a8465SDavid Green     // Set insertion point after the loop operation that we have
642815a8465SDavid Green     // just processed.
643815a8465SDavid Green     builder.setInsertionPointAfter(loop.getOperation());
644815a8465SDavid Green   }
645815a8465SDavid Green   // End of loop nest. The insertion point is after the outermost loop.
646815a8465SDavid Green   if (maskMayBeLogicalScalar) {
647815a8465SDavid Green     if (fir::IfOp ifOp =
648815a8465SDavid Green             mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
649815a8465SDavid Green       builder.create<fir::ResultOp>(loc, reductionVal);
650815a8465SDavid Green       builder.setInsertionPointAfter(ifOp);
651815a8465SDavid Green       // Redefine flagSet to escape scope of ifOp
652815a8465SDavid Green       flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
653815a8465SDavid Green       reductionVal = ifOp.getResult(0);
654815a8465SDavid Green     }
655815a8465SDavid Green   }
656815a8465SDavid Green }
657815a8465SDavid Green 
6589bb47f7fSDavid Green static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
6599bb47f7fSDavid Green                                     mlir::func::FuncOp &funcOp, bool isMax,
6609bb47f7fSDavid Green                                     unsigned rank, int maskRank,
6619bb47f7fSDavid Green                                     mlir::Type elementType,
662614cd721SSacha Ballantyne                                     mlir::Type maskElemType,
6632a95fe48SDavid Green                                     mlir::Type resultElemTy, bool isDim) {
6649bb47f7fSDavid Green   auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
665614cd721SSacha Ballantyne                       mlir::Type elementType) {
666fac349a1SChristian Sigg     if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
667614cd721SSacha Ballantyne       const llvm::fltSemantics &sem = ty.getFloatSemantics();
66872428962SDavid Green       llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
66972428962SDavid Green       return builder.createRealConstant(loc, elementType, limit);
670614cd721SSacha Ballantyne     }
671614cd721SSacha Ballantyne     unsigned bits = elementType.getIntOrFloatBitWidth();
6729bb47f7fSDavid Green     int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
6739bb47f7fSDavid Green                                : llvm::APInt::getSignedMaxValue(bits))
6749bb47f7fSDavid Green                             .getSExtValue();
6759bb47f7fSDavid Green     return builder.createIntegerConstant(loc, elementType, initValue);
676614cd721SSacha Ballantyne   };
677614cd721SSacha Ballantyne 
678614cd721SSacha Ballantyne   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
679614cd721SSacha Ballantyne   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
680614cd721SSacha Ballantyne 
681614cd721SSacha Ballantyne   mlir::Value mask = funcOp.front().getArgument(2);
682614cd721SSacha Ballantyne 
683614cd721SSacha Ballantyne   // Set up result array in case of early exit / 0 length array
684614cd721SSacha Ballantyne   mlir::IndexType idxTy = builder.getIndexType();
685614cd721SSacha Ballantyne   mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy);
686614cd721SSacha Ballantyne   mlir::Type resultHeapTy = fir::HeapType::get(resultTy);
687614cd721SSacha Ballantyne   mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy);
688614cd721SSacha Ballantyne 
689614cd721SSacha Ballantyne   mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0);
690614cd721SSacha Ballantyne   mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank);
691614cd721SSacha Ballantyne 
692614cd721SSacha Ballantyne   mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy);
693614cd721SSacha Ballantyne   mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize);
694614cd721SSacha Ballantyne   mlir::Value resultArr = builder.create<fir::EmboxOp>(
695614cd721SSacha Ballantyne       loc, resultBoxTy, resultArrInit, resultArrShape);
696614cd721SSacha Ballantyne 
697614cd721SSacha Ballantyne   mlir::Type resultRefTy = builder.getRefType(resultElemTy);
698614cd721SSacha Ballantyne 
699223d3dabSDavid Green   if (maskRank > 0) {
700223d3dabSDavid Green     fir::SequenceType::Shape flatShape(rank,
701223d3dabSDavid Green                                        fir::SequenceType::getUnknownExtent());
702223d3dabSDavid Green     mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType);
703223d3dabSDavid Green     mlir::Type boxMaskTy = fir::BoxType::get(maskTy);
704223d3dabSDavid Green     mask = builder.create<fir::ConvertOp>(loc, boxMaskTy, mask);
705223d3dabSDavid Green   }
706223d3dabSDavid Green 
707614cd721SSacha Ballantyne   for (unsigned int i = 0; i < rank; ++i) {
708614cd721SSacha Ballantyne     mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
709614cd721SSacha Ballantyne     mlir::Value resultElemAddr =
710614cd721SSacha Ballantyne         builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index);
711614cd721SSacha Ballantyne     builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
712614cd721SSacha Ballantyne   }
713614cd721SSacha Ballantyne 
714614cd721SSacha Ballantyne   auto genBodyOp =
715223d3dabSDavid Green       [&rank, &resultArr, isMax, &mask, &maskElemType, &maskRank](
716223d3dabSDavid Green           fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType,
717223d3dabSDavid Green           mlir::Value array, mlir::Value flagRef, mlir::Value reduction,
718223d3dabSDavid Green           const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
719223d3dabSDavid Green     // We are in the innermost loop: generate the reduction body.
720223d3dabSDavid Green     if (maskRank > 0) {
721223d3dabSDavid Green       mlir::Type logicalRef = builder.getRefType(maskElemType);
722223d3dabSDavid Green       mlir::Value maskAddr =
723223d3dabSDavid Green           builder.create<fir::CoordinateOp>(loc, logicalRef, mask, indices);
724223d3dabSDavid Green       mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr);
725223d3dabSDavid Green 
726223d3dabSDavid Green       // fir::IfOp requires argument to be I1 - won't accept logical or any
727223d3dabSDavid Green       // other Integer.
728223d3dabSDavid Green       mlir::Type ifCompatType = builder.getI1Type();
729223d3dabSDavid Green       mlir::Value ifCompatElem =
730223d3dabSDavid Green           builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
731223d3dabSDavid Green 
732223d3dabSDavid Green       llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
733223d3dabSDavid Green       fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
734223d3dabSDavid Green                                                  /*withElseRegion=*/true);
735223d3dabSDavid Green       builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
736223d3dabSDavid Green     }
737223d3dabSDavid Green 
738223d3dabSDavid Green     // Set flag that mask was true at some point
739223d3dabSDavid Green     mlir::Value flagSet = builder.createIntegerConstant(
740223d3dabSDavid Green         loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
74172428962SDavid Green     mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
742223d3dabSDavid Green     mlir::Type eleRefTy = builder.getRefType(elementType);
743223d3dabSDavid Green     mlir::Value addr =
744223d3dabSDavid Green         builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
745223d3dabSDavid Green     mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
746223d3dabSDavid Green 
747614cd721SSacha Ballantyne     mlir::Value cmp;
748fac349a1SChristian Sigg     if (mlir::isa<mlir::FloatType>(elementType)) {
74972428962SDavid Green       // For FP reductions we want the first smallest value to be used, that
75072428962SDavid Green       // is not NaN. A OGL/OLT condition will usually work for this unless all
75172428962SDavid Green       // the values are Nan or Inf. This follows the same logic as
75272428962SDavid Green       // NumericCompare for Minloc/Maxlox in extrema.cpp.
753614cd721SSacha Ballantyne       cmp = builder.create<mlir::arith::CmpFOp>(
7549bb47f7fSDavid Green           loc,
7559bb47f7fSDavid Green           isMax ? mlir::arith::CmpFPredicate::OGT
7569bb47f7fSDavid Green                 : mlir::arith::CmpFPredicate::OLT,
757223d3dabSDavid Green           elem, reduction);
75872428962SDavid Green 
75972428962SDavid Green       mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
76072428962SDavid Green           loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
76172428962SDavid Green       mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
76272428962SDavid Green           loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
76372428962SDavid Green       cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
76472428962SDavid Green       cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
765fac349a1SChristian Sigg     } else if (mlir::isa<mlir::IntegerType>(elementType)) {
766614cd721SSacha Ballantyne       cmp = builder.create<mlir::arith::CmpIOp>(
7679bb47f7fSDavid Green           loc,
7689bb47f7fSDavid Green           isMax ? mlir::arith::CmpIPredicate::sgt
7699bb47f7fSDavid Green                 : mlir::arith::CmpIPredicate::slt,
770223d3dabSDavid Green           elem, reduction);
771614cd721SSacha Ballantyne     } else {
772614cd721SSacha Ballantyne       llvm_unreachable("unsupported type");
773614cd721SSacha Ballantyne     }
774614cd721SSacha Ballantyne 
77572428962SDavid Green     // The condition used for the loop is isFirst || <the condition above>.
77672428962SDavid Green     isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
77772428962SDavid Green     isFirst = builder.create<mlir::arith::XOrIOp>(
77872428962SDavid Green         loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
77972428962SDavid Green     cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
780614cd721SSacha Ballantyne     fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
781614cd721SSacha Ballantyne                                                /*withElseRegion*/ true);
782614cd721SSacha Ballantyne 
783614cd721SSacha Ballantyne     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
78472428962SDavid Green     builder.create<fir::StoreOp>(loc, flagSet, flagRef);
785614cd721SSacha Ballantyne     mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
786614cd721SSacha Ballantyne     mlir::Type returnRefTy = builder.getRefType(resultElemTy);
787614cd721SSacha Ballantyne     mlir::IndexType idxTy = builder.getIndexType();
788614cd721SSacha Ballantyne 
789614cd721SSacha Ballantyne     mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1);
790614cd721SSacha Ballantyne 
791614cd721SSacha Ballantyne     for (unsigned int i = 0; i < rank; ++i) {
792614cd721SSacha Ballantyne       mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
793614cd721SSacha Ballantyne       mlir::Value resultElemAddr =
794614cd721SSacha Ballantyne           builder.create<fir::CoordinateOp>(loc, returnRefTy, resultArr, index);
795614cd721SSacha Ballantyne       mlir::Value convert =
796614cd721SSacha Ballantyne           builder.create<fir::ConvertOp>(loc, resultElemTy, indices[i]);
797614cd721SSacha Ballantyne       mlir::Value fortranIndex =
798614cd721SSacha Ballantyne           builder.create<mlir::arith::AddIOp>(loc, convert, one);
799614cd721SSacha Ballantyne       builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
800614cd721SSacha Ballantyne     }
801223d3dabSDavid Green     builder.create<fir::ResultOp>(loc, elem);
802614cd721SSacha Ballantyne     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
803223d3dabSDavid Green     builder.create<fir::ResultOp>(loc, reduction);
804614cd721SSacha Ballantyne     builder.setInsertionPointAfter(ifOp);
805223d3dabSDavid Green     mlir::Value reductionVal = ifOp.getResult(0);
806223d3dabSDavid Green 
807223d3dabSDavid Green     // Close the mask if needed
808223d3dabSDavid Green     if (maskRank > 0) {
809223d3dabSDavid Green       fir::IfOp ifOp =
810223d3dabSDavid Green           mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp());
811223d3dabSDavid Green       builder.create<fir::ResultOp>(loc, reductionVal);
812223d3dabSDavid Green       builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
813223d3dabSDavid Green       builder.create<fir::ResultOp>(loc, reduction);
814223d3dabSDavid Green       reductionVal = ifOp.getResult(0);
815223d3dabSDavid Green       builder.setInsertionPointAfter(ifOp);
816223d3dabSDavid Green     }
817223d3dabSDavid Green 
818223d3dabSDavid Green     return reductionVal;
819614cd721SSacha Ballantyne   };
820614cd721SSacha Ballantyne 
821614cd721SSacha Ballantyne   // if mask is a logical scalar, we can check its value before the main loop
822614cd721SSacha Ballantyne   // and either ignore the fact it is there or exit early.
823614cd721SSacha Ballantyne   if (maskRank == 0) {
824614cd721SSacha Ballantyne     mlir::Type logical = builder.getI1Type();
825614cd721SSacha Ballantyne     mlir::IndexType idxTy = builder.getIndexType();
826614cd721SSacha Ballantyne 
827614cd721SSacha Ballantyne     fir::SequenceType::Shape singleElement(1, 1);
828614cd721SSacha Ballantyne     mlir::Type arrTy = fir::SequenceType::get(singleElement, logical);
829614cd721SSacha Ballantyne     mlir::Type boxArrTy = fir::BoxType::get(arrTy);
830614cd721SSacha Ballantyne     mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, mask);
831614cd721SSacha Ballantyne 
832614cd721SSacha Ballantyne     mlir::Value indx = builder.createIntegerConstant(loc, idxTy, 0);
833614cd721SSacha Ballantyne     mlir::Type logicalRefTy = builder.getRefType(logical);
834614cd721SSacha Ballantyne     mlir::Value condAddr =
835614cd721SSacha Ballantyne         builder.create<fir::CoordinateOp>(loc, logicalRefTy, array, indx);
836614cd721SSacha Ballantyne     mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr);
837614cd721SSacha Ballantyne 
838614cd721SSacha Ballantyne     fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cond,
839614cd721SSacha Ballantyne                                                /*withElseRegion=*/true);
840614cd721SSacha Ballantyne 
841614cd721SSacha Ballantyne     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
842614cd721SSacha Ballantyne     mlir::Value basicValue;
843fac349a1SChristian Sigg     if (mlir::isa<mlir::IntegerType>(elementType)) {
844614cd721SSacha Ballantyne       basicValue = builder.createIntegerConstant(loc, elementType, 0);
845614cd721SSacha Ballantyne     } else {
846614cd721SSacha Ballantyne       basicValue = builder.createRealConstant(loc, elementType, 0);
847614cd721SSacha Ballantyne     }
848614cd721SSacha Ballantyne     builder.create<fir::ResultOp>(loc, basicValue);
849614cd721SSacha Ballantyne 
850614cd721SSacha Ballantyne     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
851614cd721SSacha Ballantyne   }
852223d3dabSDavid Green   auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
853223d3dabSDavid Green                       const mlir::Type &resultElemType, mlir::Value resultArr,
854223d3dabSDavid Green                       mlir::Value index) {
855223d3dabSDavid Green     mlir::Type resultRefTy = builder.getRefType(resultElemType);
856223d3dabSDavid Green     return builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr,
857223d3dabSDavid Green                                              index);
858223d3dabSDavid Green   };
859614cd721SSacha Ballantyne 
860223d3dabSDavid Green   genMinMaxlocReductionLoop(builder, funcOp.front().getArgument(1), init,
861223d3dabSDavid Green                             genBodyOp, getAddrFn, rank, elementType, loc,
862223d3dabSDavid Green                             maskElemType, resultArr, maskRank == 0);
863223d3dabSDavid Green 
864223d3dabSDavid Green   // Store newly created output array to the reference passed in
8652a95fe48SDavid Green   if (isDim) {
8662a95fe48SDavid Green     mlir::Type resultBoxTy =
8672a95fe48SDavid Green         fir::BoxType::get(fir::HeapType::get(resultElemTy));
8682a95fe48SDavid Green     mlir::Value outputArr = builder.create<fir::ConvertOp>(
8692a95fe48SDavid Green         loc, builder.getRefType(resultBoxTy), funcOp.front().getArgument(0));
8702a95fe48SDavid Green     mlir::Value resultArrScalar = builder.create<fir::ConvertOp>(
8712a95fe48SDavid Green         loc, fir::HeapType::get(resultElemTy), resultArrInit);
8722a95fe48SDavid Green     mlir::Value resultBox =
8732a95fe48SDavid Green         builder.create<fir::EmboxOp>(loc, resultBoxTy, resultArrScalar);
8742a95fe48SDavid Green     builder.create<fir::StoreOp>(loc, resultBox, outputArr);
8752a95fe48SDavid Green   } else {
876223d3dabSDavid Green     fir::SequenceType::Shape resultShape(1, rank);
877223d3dabSDavid Green     mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy);
878223d3dabSDavid Green     mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
879223d3dabSDavid Green     mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
880223d3dabSDavid Green     mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
881223d3dabSDavid Green     mlir::Value outputArr = builder.create<fir::ConvertOp>(
882223d3dabSDavid Green         loc, outputRefTy, funcOp.front().getArgument(0));
883223d3dabSDavid Green     builder.create<fir::StoreOp>(loc, resultArr, outputArr);
8842a95fe48SDavid Green   }
8852a95fe48SDavid Green 
886223d3dabSDavid Green   builder.create<mlir::func::ReturnOp>(loc);
887614cd721SSacha Ballantyne }
888614cd721SSacha Ballantyne 
889aa94eb38SMats Petersson /// Generate function type for the simplified version of RTNAME(DotProduct)
8901d5e7a49SSlava Zakharin /// operating on the given \p elementType.
891aa94eb38SMats Petersson static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
8921d5e7a49SSlava Zakharin                                             const mlir::Type &elementType) {
8931d5e7a49SSlava Zakharin   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
8941d5e7a49SSlava Zakharin   return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
8951d5e7a49SSlava Zakharin                                  {elementType});
8961d5e7a49SSlava Zakharin }
8971d5e7a49SSlava Zakharin 
898aa94eb38SMats Petersson /// Generate function body of the simplified version of RTNAME(DotProduct)
8991d5e7a49SSlava Zakharin /// with signature provided by \p funcOp. The caller is responsible
9001d5e7a49SSlava Zakharin /// for saving/restoring the original insertion point of \p builder.
9011d5e7a49SSlava Zakharin /// \p funcOp is expected to be empty on entry to this function.
90256eda98fSSlava Zakharin /// \p arg1ElementTy and \p arg2ElementTy specify elements types
90356eda98fSSlava Zakharin /// of the underlying array objects - they are used to generate proper
90456eda98fSSlava Zakharin /// element accesses.
905aa94eb38SMats Petersson static void genRuntimeDotBody(fir::FirOpBuilder &builder,
90656eda98fSSlava Zakharin                               mlir::func::FuncOp &funcOp,
90756eda98fSSlava Zakharin                               mlir::Type arg1ElementTy,
90856eda98fSSlava Zakharin                               mlir::Type arg2ElementTy) {
909aa94eb38SMats Petersson   // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
9101d5e7a49SSlava Zakharin   //   T, dimension(:) :: arr1, arr2
9111d5e7a49SSlava Zakharin   //   T product = 0
9121d5e7a49SSlava Zakharin   //   integer iter
9131d5e7a49SSlava Zakharin   //   do iter = 0, extent(arr1)
9141d5e7a49SSlava Zakharin   //     product = product + arr1[iter] * arr2[iter]
9151d5e7a49SSlava Zakharin   //   end do
916aa94eb38SMats Petersson   //   RTNAME(ADotProduct)<T>_simplified = product
917aa94eb38SMats Petersson   // end function RTNAME(DotProduct)<T>_simplified
9181d5e7a49SSlava Zakharin   auto loc = mlir::UnknownLoc::get(builder.getContext());
91956eda98fSSlava Zakharin   mlir::Type resultElementType = funcOp.getResultTypes()[0];
9201d5e7a49SSlava Zakharin   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
9211d5e7a49SSlava Zakharin 
9221d5e7a49SSlava Zakharin   mlir::IndexType idxTy = builder.getIndexType();
9231d5e7a49SSlava Zakharin 
92456eda98fSSlava Zakharin   mlir::Value zero =
925fac349a1SChristian Sigg       mlir::isa<mlir::FloatType>(resultElementType)
92656eda98fSSlava Zakharin           ? builder.createRealConstant(loc, resultElementType, 0.0)
92756eda98fSSlava Zakharin           : builder.createIntegerConstant(loc, resultElementType, 0);
9281d5e7a49SSlava Zakharin 
9291d5e7a49SSlava Zakharin   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
9301d5e7a49SSlava Zakharin   mlir::Value arg1 = args[0];
9311d5e7a49SSlava Zakharin   mlir::Value arg2 = args[1];
9321d5e7a49SSlava Zakharin 
9331d5e7a49SSlava Zakharin   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
9341d5e7a49SSlava Zakharin 
9351d5e7a49SSlava Zakharin   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
93656eda98fSSlava Zakharin   mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
93756eda98fSSlava Zakharin   mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
93856eda98fSSlava Zakharin   mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
93956eda98fSSlava Zakharin   mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
94056eda98fSSlava Zakharin   mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
94156eda98fSSlava Zakharin   mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
9421d5e7a49SSlava Zakharin   // This version takes the loop trip count from the first argument.
9431d5e7a49SSlava Zakharin   // If the first argument's box has unknown (at compilation time)
9441d5e7a49SSlava Zakharin   // extent, then it may be better to take the extent from the second
9451d5e7a49SSlava Zakharin   // argument - so that after inlining the loop may be better optimized, e.g.
9461d5e7a49SSlava Zakharin   // fully unrolled. This requires generating two versions of the simplified
9471d5e7a49SSlava Zakharin   // function and some analysis at the call site to choose which version
9481d5e7a49SSlava Zakharin   // is more profitable to call.
9491d5e7a49SSlava Zakharin   // Note that we can assume that both arguments have the same extent.
9501d5e7a49SSlava Zakharin   auto dims =
9511d5e7a49SSlava Zakharin       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
9521d5e7a49SSlava Zakharin   mlir::Value len = dims.getResult(1);
9531d5e7a49SSlava Zakharin   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
9541d5e7a49SSlava Zakharin   mlir::Value step = one;
9551d5e7a49SSlava Zakharin 
9561d5e7a49SSlava Zakharin   // We use C indexing here, so len-1 as loopcount
9571d5e7a49SSlava Zakharin   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
9581d5e7a49SSlava Zakharin   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
9591d5e7a49SSlava Zakharin                                             /*unordered=*/false,
9601d5e7a49SSlava Zakharin                                             /*finalCountValue=*/false, zero);
9611d5e7a49SSlava Zakharin   mlir::Value sumVal = loop.getRegionIterArgs()[0];
9621d5e7a49SSlava Zakharin 
9631d5e7a49SSlava Zakharin   // Begin loop code
9641d5e7a49SSlava Zakharin   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
9651d5e7a49SSlava Zakharin   builder.setInsertionPointToStart(loop.getBody());
9661d5e7a49SSlava Zakharin 
96756eda98fSSlava Zakharin   mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
9681d5e7a49SSlava Zakharin   mlir::Value index = loop.getInductionVar();
9691d5e7a49SSlava Zakharin   mlir::Value addr1 =
97056eda98fSSlava Zakharin       builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
9711d5e7a49SSlava Zakharin   mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
97256eda98fSSlava Zakharin   // Convert to the result type.
97356eda98fSSlava Zakharin   elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
9741d5e7a49SSlava Zakharin 
97556eda98fSSlava Zakharin   mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
97656eda98fSSlava Zakharin   mlir::Value addr2 =
97756eda98fSSlava Zakharin       builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
97856eda98fSSlava Zakharin   mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
97956eda98fSSlava Zakharin   // Convert to the result type.
98056eda98fSSlava Zakharin   elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
98156eda98fSSlava Zakharin 
982fac349a1SChristian Sigg   if (mlir::isa<mlir::FloatType>(resultElementType))
9831d5e7a49SSlava Zakharin     sumVal = builder.create<mlir::arith::AddFOp>(
9841d5e7a49SSlava Zakharin         loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
985fac349a1SChristian Sigg   else if (mlir::isa<mlir::IntegerType>(resultElementType))
9861d5e7a49SSlava Zakharin     sumVal = builder.create<mlir::arith::AddIOp>(
9871d5e7a49SSlava Zakharin         loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
9881d5e7a49SSlava Zakharin   else
9891d5e7a49SSlava Zakharin     llvm_unreachable("unsupported type");
9901d5e7a49SSlava Zakharin 
9911d5e7a49SSlava Zakharin   builder.create<fir::ResultOp>(loc, sumVal);
9921d5e7a49SSlava Zakharin   // End of loop.
9931d5e7a49SSlava Zakharin   builder.restoreInsertionPoint(loopEndPt);
9941d5e7a49SSlava Zakharin 
9951d5e7a49SSlava Zakharin   mlir::Value resultVal = loop.getResult(0);
9961d5e7a49SSlava Zakharin   builder.create<mlir::func::ReturnOp>(loc, resultVal);
9971d5e7a49SSlava Zakharin }
9981d5e7a49SSlava Zakharin 
99980dcc907SSlava Zakharin mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
100080dcc907SSlava Zakharin     fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
100180dcc907SSlava Zakharin     FunctionTypeGeneratorTy typeGenerator,
100280dcc907SSlava Zakharin     FunctionBodyGeneratorTy bodyGenerator) {
100380dcc907SSlava Zakharin   // WARNING: if the function generated here changes its signature
100480dcc907SSlava Zakharin   //          or behavior (the body code), we should probably embed some
100580dcc907SSlava Zakharin   //          versioning information into its name, otherwise libraries
100680dcc907SSlava Zakharin   //          statically linked with older versions of Flang may stop
100780dcc907SSlava Zakharin   //          working with object files created with newer Flang.
100880dcc907SSlava Zakharin   //          We can also avoid this by using internal linkage, but
100980dcc907SSlava Zakharin   //          this may increase the size of final executable/shared library.
101080dcc907SSlava Zakharin   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
101180dcc907SSlava Zakharin   // If we already have a function, just return it.
1012a4798bb0SjeanPerier   mlir::func::FuncOp newFunc = builder.getNamedFunction(replacementName);
101380dcc907SSlava Zakharin   mlir::FunctionType fType = typeGenerator(builder);
101480dcc907SSlava Zakharin   if (newFunc) {
101580dcc907SSlava Zakharin     assert(newFunc.getFunctionType() == fType &&
101680dcc907SSlava Zakharin            "type mismatch for simplified function");
101780dcc907SSlava Zakharin     return newFunc;
101880dcc907SSlava Zakharin   }
101980dcc907SSlava Zakharin 
102080dcc907SSlava Zakharin   // Need to build the function!
102180dcc907SSlava Zakharin   auto loc = mlir::UnknownLoc::get(builder.getContext());
1022a4798bb0SjeanPerier   newFunc = builder.createFunction(loc, replacementName, fType);
102380dcc907SSlava Zakharin   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
102480dcc907SSlava Zakharin   auto linkage =
102580dcc907SSlava Zakharin       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
102680dcc907SSlava Zakharin   newFunc->setAttr("llvm.linkage", linkage);
102780dcc907SSlava Zakharin 
102880dcc907SSlava Zakharin   // Save the position of the original call.
102980dcc907SSlava Zakharin   mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
103080dcc907SSlava Zakharin 
103180dcc907SSlava Zakharin   bodyGenerator(builder, newFunc);
10326e193b5cSMats Petersson 
10336e193b5cSMats Petersson   // Now back to where we were adding code earlier...
10346e193b5cSMats Petersson   builder.restoreInsertionPoint(insertPt);
10356e193b5cSMats Petersson 
10366e193b5cSMats Petersson   return newFunc;
10376e193b5cSMats Petersson }
10386e193b5cSMats Petersson 
10397d2e1987SSacha Ballantyne void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction(
10407d2e1987SSacha Ballantyne     fir::CallOp call, const fir::KindMapping &kindMap,
104143159b58SMats Petersson     GenReductionBodyTy genBodyFunc) {
104243159b58SMats Petersson   // args[1] and args[2] are source filename and line number, ignored.
10437d2e1987SSacha Ballantyne   mlir::Operation::operand_range args = call.getArgs();
10447d2e1987SSacha Ballantyne 
104543159b58SMats Petersson   const mlir::Value &dim = args[3];
104643159b58SMats Petersson   const mlir::Value &mask = args[4];
104743159b58SMats Petersson   // dim is zero when it is absent, which is an implementation
104843159b58SMats Petersson   // detail in the runtime library.
10497d2e1987SSacha Ballantyne 
105043159b58SMats Petersson   bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
105143159b58SMats Petersson   unsigned rank = getDimCount(args[0]);
10522b138567SSlava Zakharin 
1053bb94d33aSSacha Ballantyne   // Rank is set to 0 for assumed shape arrays, don't simplify
1054bb94d33aSSacha Ballantyne   // in these cases
10557d2e1987SSacha Ballantyne   if (!(dimAndMaskAbsent && rank > 0))
10567d2e1987SSacha Ballantyne     return;
10577d2e1987SSacha Ballantyne 
10582b138567SSlava Zakharin   mlir::Type resultType = call.getResult(0).getType();
10597d2e1987SSacha Ballantyne 
1060fac349a1SChristian Sigg   if (!mlir::isa<mlir::FloatType>(resultType) &&
1061fac349a1SChristian Sigg       !mlir::isa<mlir::IntegerType>(resultType))
106243159b58SMats Petersson     return;
10632b138567SSlava Zakharin 
10642b138567SSlava Zakharin   auto argType = getArgElementType(args[0]);
10652b138567SSlava Zakharin   if (!argType)
10662b138567SSlava Zakharin     return;
10672b138567SSlava Zakharin   assert(*argType == resultType &&
10682b138567SSlava Zakharin          "Argument/result types mismatch in reduction");
10692b138567SSlava Zakharin 
10707d2e1987SSacha Ballantyne   mlir::SymbolRefAttr callee = call.getCalleeAttr();
10717d2e1987SSacha Ballantyne 
10727d2e1987SSacha Ballantyne   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1073f52c64b1SDavid Truby   std::string fmfString{builder.getFastMathFlagsString()};
10747d2e1987SSacha Ballantyne   std::string funcName =
10757d2e1987SSacha Ballantyne       (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
10767d2e1987SSacha Ballantyne        mlir::Twine{rank} +
10777d2e1987SSacha Ballantyne        // We must mangle the generated function name with FastMathFlags
10787d2e1987SSacha Ballantyne        // value.
10797d2e1987SSacha Ballantyne        (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
10807d2e1987SSacha Ballantyne           .str();
10817d2e1987SSacha Ballantyne 
108220fba03fSSacha Ballantyne   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
108320fba03fSSacha Ballantyne                         resultType);
10847d2e1987SSacha Ballantyne }
10857d2e1987SSacha Ballantyne 
108620fba03fSSacha Ballantyne void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction(
10877d2e1987SSacha Ballantyne     fir::CallOp call, const fir::KindMapping &kindMap,
10887d2e1987SSacha Ballantyne     GenReductionBodyTy genBodyFunc) {
10897d2e1987SSacha Ballantyne 
10907d2e1987SSacha Ballantyne   mlir::Operation::operand_range args = call.getArgs();
10917d2e1987SSacha Ballantyne   const mlir::Value &dim = args[3];
1092bb94d33aSSacha Ballantyne   unsigned rank = getDimCount(args[0]);
10937d2e1987SSacha Ballantyne 
109420fba03fSSacha Ballantyne   // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
109520fba03fSSacha Ballantyne   // these cases.
1096bb94d33aSSacha Ballantyne   if (!(isZero(dim) && rank > 0))
10977d2e1987SSacha Ballantyne     return;
10987d2e1987SSacha Ballantyne 
109920fba03fSSacha Ballantyne   mlir::Value inputBox = findBoxDef(args[0]);
110020fba03fSSacha Ballantyne 
110120fba03fSSacha Ballantyne   mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
11027d2e1987SSacha Ballantyne   mlir::SymbolRefAttr callee = call.getCalleeAttr();
11037d2e1987SSacha Ballantyne 
11047d2e1987SSacha Ballantyne   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
110520fba03fSSacha Ballantyne 
110620fba03fSSacha Ballantyne   // Treating logicals as integers makes things a lot easier
1107fac349a1SChristian Sigg   fir::LogicalType logicalType = {
1108fac349a1SChristian Sigg       mlir::dyn_cast<fir::LogicalType>(elementType)};
110920fba03fSSacha Ballantyne   fir::KindTy kind = logicalType.getFKind();
1110614cd721SSacha Ballantyne   mlir::Type intElementType = builder.getIntegerType(kind * 8);
111120fba03fSSacha Ballantyne 
111220fba03fSSacha Ballantyne   // Mangle kind into function name as it is not done by default
11137d2e1987SSacha Ballantyne   std::string funcName =
111420fba03fSSacha Ballantyne       (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
111520fba03fSSacha Ballantyne        mlir::Twine{kind} + "x" + mlir::Twine{rank})
11167d2e1987SSacha Ballantyne           .str();
11177d2e1987SSacha Ballantyne 
111820fba03fSSacha Ballantyne   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
111920fba03fSSacha Ballantyne                         intElementType);
112020fba03fSSacha Ballantyne }
112120fba03fSSacha Ballantyne 
112220fba03fSSacha Ballantyne void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
112320fba03fSSacha Ballantyne     fir::CallOp call, const fir::KindMapping &kindMap,
112420fba03fSSacha Ballantyne     GenReductionBodyTy genBodyFunc) {
112520fba03fSSacha Ballantyne 
112620fba03fSSacha Ballantyne   mlir::Operation::operand_range args = call.getArgs();
112720fba03fSSacha Ballantyne   mlir::SymbolRefAttr callee = call.getCalleeAttr();
112820fba03fSSacha Ballantyne   mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
112920fba03fSSacha Ballantyne   unsigned rank = getDimCount(args[0]);
113020fba03fSSacha Ballantyne 
113120fba03fSSacha Ballantyne   // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
113220fba03fSSacha Ballantyne   // these cases. We check for Dim at the end as some logical functions (Any,
113320fba03fSSacha Ballantyne   // All) set dim to 1 instead of 0 when the argument is not present.
113420fba03fSSacha Ballantyne   if (funcNameBase.ends_with("Dim") || !(rank > 0))
113520fba03fSSacha Ballantyne     return;
113620fba03fSSacha Ballantyne 
113720fba03fSSacha Ballantyne   mlir::Value inputBox = findBoxDef(args[0]);
113820fba03fSSacha Ballantyne   mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
113920fba03fSSacha Ballantyne 
114020fba03fSSacha Ballantyne   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
114120fba03fSSacha Ballantyne 
114220fba03fSSacha Ballantyne   // Treating logicals as integers makes things a lot easier
1143fac349a1SChristian Sigg   fir::LogicalType logicalType = {
1144fac349a1SChristian Sigg       mlir::dyn_cast<fir::LogicalType>(elementType)};
114520fba03fSSacha Ballantyne   fir::KindTy kind = logicalType.getFKind();
1146614cd721SSacha Ballantyne   mlir::Type intElementType = builder.getIntegerType(kind * 8);
114720fba03fSSacha Ballantyne 
114820fba03fSSacha Ballantyne   // Mangle kind into function name as it is not done by default
114920fba03fSSacha Ballantyne   std::string funcName =
115020fba03fSSacha Ballantyne       (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
115120fba03fSSacha Ballantyne        mlir::Twine{kind} + "x" + mlir::Twine{rank})
115220fba03fSSacha Ballantyne           .str();
115320fba03fSSacha Ballantyne 
115420fba03fSSacha Ballantyne   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
115520fba03fSSacha Ballantyne                         intElementType);
11567d2e1987SSacha Ballantyne }
11577d2e1987SSacha Ballantyne 
11589bb47f7fSDavid Green void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
11599bb47f7fSDavid Green     fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) {
1160614cd721SSacha Ballantyne 
1161614cd721SSacha Ballantyne   mlir::Operation::operand_range args = call.getArgs();
1162614cd721SSacha Ballantyne 
11632a95fe48SDavid Green   mlir::SymbolRefAttr callee = call.getCalleeAttr();
11642a95fe48SDavid Green   mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
11652a95fe48SDavid Green   bool isDim = funcNameBase.ends_with("Dim");
11662a95fe48SDavid Green   mlir::Value back = args[isDim ? 7 : 6];
1167614cd721SSacha Ballantyne   if (isTrueOrNotConstant(back))
1168614cd721SSacha Ballantyne     return;
1169614cd721SSacha Ballantyne 
11702a95fe48SDavid Green   mlir::Value mask = args[isDim ? 6 : 5];
1171614cd721SSacha Ballantyne   mlir::Value maskDef = findMaskDef(mask);
1172614cd721SSacha Ballantyne 
1173614cd721SSacha Ballantyne   // maskDef is set to NULL when the defining op is not one we accept.
1174614cd721SSacha Ballantyne   // This tends to be because it is a selectOp, in which case let the
1175614cd721SSacha Ballantyne   // runtime deal with it.
1176614cd721SSacha Ballantyne   if (maskDef == NULL)
1177614cd721SSacha Ballantyne     return;
1178614cd721SSacha Ballantyne 
1179614cd721SSacha Ballantyne   unsigned rank = getDimCount(args[1]);
11802a95fe48SDavid Green   if ((isDim && rank != 1) || !(rank > 0))
1181614cd721SSacha Ballantyne     return;
1182614cd721SSacha Ballantyne 
1183614cd721SSacha Ballantyne   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1184614cd721SSacha Ballantyne   mlir::Location loc = call.getLoc();
1185614cd721SSacha Ballantyne   auto inputBox = findBoxDef(args[1]);
1186614cd721SSacha Ballantyne   mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType());
1187614cd721SSacha Ballantyne 
1188fac349a1SChristian Sigg   if (mlir::isa<fir::CharacterType>(inputType))
1189614cd721SSacha Ballantyne     return;
1190614cd721SSacha Ballantyne 
1191614cd721SSacha Ballantyne   int maskRank;
1192614cd721SSacha Ballantyne   fir::KindTy kind = 0;
1193242bb0b6SSacha Ballantyne   mlir::Type logicalElemType = builder.getI1Type();
1194614cd721SSacha Ballantyne   if (isOperandAbsent(mask)) {
1195614cd721SSacha Ballantyne     maskRank = -1;
1196614cd721SSacha Ballantyne   } else {
1197614cd721SSacha Ballantyne     maskRank = getDimCount(mask);
1198614cd721SSacha Ballantyne     mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType());
1199fac349a1SChristian Sigg     fir::LogicalType logicalFirType = {
1200fac349a1SChristian Sigg         mlir::dyn_cast<fir::LogicalType>(maskElemTy)};
1201242bb0b6SSacha Ballantyne     kind = logicalFirType.getFKind();
1202242bb0b6SSacha Ballantyne     // Convert fir::LogicalType to mlir::Type
1203242bb0b6SSacha Ballantyne     logicalElemType = logicalFirType;
1204614cd721SSacha Ballantyne   }
1205614cd721SSacha Ballantyne 
1206614cd721SSacha Ballantyne   mlir::Operation *outputDef = args[0].getDefiningOp();
1207614cd721SSacha Ballantyne   mlir::Value outputAlloc = outputDef->getOperand(0);
1208614cd721SSacha Ballantyne   mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType());
1209614cd721SSacha Ballantyne 
1210f52c64b1SDavid Truby   std::string fmfString{builder.getFastMathFlagsString()};
1211614cd721SSacha Ballantyne   std::string funcName =
1212614cd721SSacha Ballantyne       (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1213614cd721SSacha Ballantyne        mlir::Twine{rank} +
1214614cd721SSacha Ballantyne        (maskRank >= 0
1215614cd721SSacha Ballantyne             ? "_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank}
1216614cd721SSacha Ballantyne             : "") +
1217614cd721SSacha Ballantyne        "_")
1218614cd721SSacha Ballantyne           .str();
1219614cd721SSacha Ballantyne 
1220614cd721SSacha Ballantyne   llvm::raw_string_ostream nameOS(funcName);
1221614cd721SSacha Ballantyne   outType.print(nameOS);
12222a95fe48SDavid Green   if (isDim)
12232a95fe48SDavid Green     nameOS << '_' << inputType;
1224614cd721SSacha Ballantyne   nameOS << '_' << fmfString;
1225614cd721SSacha Ballantyne 
1226614cd721SSacha Ballantyne   auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
1227614cd721SSacha Ballantyne     return genRuntimeMinlocType(builder, rank);
1228614cd721SSacha Ballantyne   };
12299bb47f7fSDavid Green   auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
12302a95fe48SDavid Green                         isMax, isDim](fir::FirOpBuilder &builder,
1231614cd721SSacha Ballantyne                                       mlir::func::FuncOp &funcOp) {
12329bb47f7fSDavid Green     genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType,
12332a95fe48SDavid Green                             logicalElemType, outType, isDim);
1234614cd721SSacha Ballantyne   };
1235614cd721SSacha Ballantyne 
1236614cd721SSacha Ballantyne   mlir::func::FuncOp newFunc =
1237614cd721SSacha Ballantyne       getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1238614cd721SSacha Ballantyne   builder.create<fir::CallOp>(loc, newFunc,
12392a95fe48SDavid Green                               mlir::ValueRange{args[0], args[1], mask});
1240614cd721SSacha Ballantyne   call->dropAllReferences();
1241614cd721SSacha Ballantyne   call->erase();
1242614cd721SSacha Ballantyne }
1243614cd721SSacha Ballantyne 
12447d2e1987SSacha Ballantyne void SimplifyIntrinsicsPass::simplifyReductionBody(
12457d2e1987SSacha Ballantyne     fir::CallOp call, const fir::KindMapping &kindMap,
12467d2e1987SSacha Ballantyne     GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder,
124720fba03fSSacha Ballantyne     const mlir::StringRef &funcName, mlir::Type elementType) {
12487d2e1987SSacha Ballantyne 
12497d2e1987SSacha Ballantyne   mlir::Operation::operand_range args = call.getArgs();
12507d2e1987SSacha Ballantyne 
12517d2e1987SSacha Ballantyne   mlir::Type resultType = call.getResult(0).getType();
12527d2e1987SSacha Ballantyne   unsigned rank = getDimCount(args[0]);
12537d2e1987SSacha Ballantyne 
12547d2e1987SSacha Ballantyne   mlir::Location loc = call.getLoc();
12557d2e1987SSacha Ballantyne 
12562b138567SSlava Zakharin   auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
12572b138567SSlava Zakharin     return genNoneBoxType(builder, resultType);
125843159b58SMats Petersson   };
125920fba03fSSacha Ballantyne   auto bodyGenerator = [&rank, &genBodyFunc,
126020fba03fSSacha Ballantyne                         &elementType](fir::FirOpBuilder &builder,
12618bd76ac1SSlava Zakharin                                       mlir::func::FuncOp &funcOp) {
126220fba03fSSacha Ballantyne     genBodyFunc(builder, funcOp, rank, elementType);
12638bd76ac1SSlava Zakharin   };
12648bd76ac1SSlava Zakharin   // Mangle the function name with the rank value as "x<rank>".
126543159b58SMats Petersson   mlir::func::FuncOp newFunc =
12668bd76ac1SSlava Zakharin       getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
126743159b58SMats Petersson   auto newCall =
126843159b58SMats Petersson       builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
126943159b58SMats Petersson   call->replaceAllUsesWith(newCall.getResults());
127043159b58SMats Petersson   call->dropAllReferences();
127143159b58SMats Petersson   call->erase();
127243159b58SMats Petersson }
127343159b58SMats Petersson 
12746e193b5cSMats Petersson void SimplifyIntrinsicsPass::runOnOperation() {
12751d5e7a49SSlava Zakharin   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
12766e193b5cSMats Petersson   mlir::ModuleOp module = getOperation();
12776e193b5cSMats Petersson   fir::KindMapping kindMap = fir::getKindMapping(module);
12786e193b5cSMats Petersson   module.walk([&](mlir::Operation *op) {
12796e193b5cSMats Petersson     if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
1280a76609ddSValentin Clement (バレンタイン クレメン)       if (cuf::isInCUDADeviceContext(op))
1281a76609ddSValentin Clement (バレンタイン クレメン)         return;
12826e193b5cSMats Petersson       if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
12836e193b5cSMats Petersson         mlir::StringRef funcName = callee.getLeafReference().getValue();
12846e193b5cSMats Petersson         // Replace call to runtime function for SUM when it has single
12856e193b5cSMats Petersson         // argument (no dim or mask argument) for 1D arrays with either
12866e193b5cSMats Petersson         // Integer4 or Real8 types. Other forms are ignored.
12876e193b5cSMats Petersson         // The new function is added to the module.
12886e193b5cSMats Petersson         //
12896e193b5cSMats Petersson         // Prototype for runtime call (from sum.cpp):
12906e193b5cSMats Petersson         // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
12916e193b5cSMats Petersson         //                int dim, const Descriptor *mask)
129211db65baSSlava Zakharin         //
129311efcceaSKazu Hirata         if (funcName.starts_with(RTNAME_STRING(Sum))) {
12947d2e1987SSacha Ballantyne           simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody);
12951d5e7a49SSlava Zakharin           return;
12961d5e7a49SSlava Zakharin         }
129711efcceaSKazu Hirata         if (funcName.starts_with(RTNAME_STRING(DotProduct))) {
12981d5e7a49SSlava Zakharin           LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
12991d5e7a49SSlava Zakharin           LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
13001d5e7a49SSlava Zakharin                      llvm::dbgs() << "\n");
13011d5e7a49SSlava Zakharin           mlir::Operation::operand_range args = call.getArgs();
13021d5e7a49SSlava Zakharin           const mlir::Value &v1 = args[0];
13031d5e7a49SSlava Zakharin           const mlir::Value &v2 = args[1];
13041d5e7a49SSlava Zakharin           mlir::Location loc = call.getLoc();
1305ffe1661fSSlava Zakharin           fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)};
1306ffe1661fSSlava Zakharin           // Stringize the builder's FastMathFlags flags for mangling
1307ffe1661fSSlava Zakharin           // the generated function name.
1308f52c64b1SDavid Truby           std::string fmfString{builder.getFastMathFlagsString()};
1309afa520abSMats Petersson 
13101d5e7a49SSlava Zakharin           mlir::Type type = call.getResult(0).getType();
1311fac349a1SChristian Sigg           if (!mlir::isa<mlir::FloatType>(type) &&
1312fac349a1SChristian Sigg               !mlir::isa<mlir::IntegerType>(type))
13131d5e7a49SSlava Zakharin             return;
13141d5e7a49SSlava Zakharin 
131556eda98fSSlava Zakharin           // Try to find the element types of the boxed arguments.
131656eda98fSSlava Zakharin           auto arg1Type = getArgElementType(v1);
131756eda98fSSlava Zakharin           auto arg2Type = getArgElementType(v2);
131856eda98fSSlava Zakharin 
131956eda98fSSlava Zakharin           if (!arg1Type || !arg2Type)
132056eda98fSSlava Zakharin             return;
132156eda98fSSlava Zakharin 
132256eda98fSSlava Zakharin           // Support only floating point and integer arguments
132356eda98fSSlava Zakharin           // now (e.g. logical is skipped here).
1324bd9fdce6SChristian Sigg           if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg1Type))
132556eda98fSSlava Zakharin             return;
1326bd9fdce6SChristian Sigg           if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg2Type))
132756eda98fSSlava Zakharin             return;
132856eda98fSSlava Zakharin 
13291d5e7a49SSlava Zakharin           auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
1330aa94eb38SMats Petersson             return genRuntimeDotType(builder, type);
13311d5e7a49SSlava Zakharin           };
133256eda98fSSlava Zakharin           auto bodyGenerator = [&arg1Type,
133356eda98fSSlava Zakharin                                 &arg2Type](fir::FirOpBuilder &builder,
133456eda98fSSlava Zakharin                                            mlir::func::FuncOp &funcOp) {
1335aa94eb38SMats Petersson             genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
133656eda98fSSlava Zakharin           };
133756eda98fSSlava Zakharin 
133856eda98fSSlava Zakharin           // Suffix the function name with the element types
133956eda98fSSlava Zakharin           // of the arguments.
134056eda98fSSlava Zakharin           std::string typedFuncName(funcName);
134156eda98fSSlava Zakharin           llvm::raw_string_ostream nameOS(typedFuncName);
1342ffe1661fSSlava Zakharin           // We must mangle the generated function name with FastMathFlags
1343ffe1661fSSlava Zakharin           // value.
1344ffe1661fSSlava Zakharin           if (!fmfString.empty())
1345ffe1661fSSlava Zakharin             nameOS << '_' << fmfString;
1346ffe1661fSSlava Zakharin           nameOS << '_';
134756eda98fSSlava Zakharin           arg1Type->print(nameOS);
1348ffe1661fSSlava Zakharin           nameOS << '_';
134956eda98fSSlava Zakharin           arg2Type->print(nameOS);
135056eda98fSSlava Zakharin 
13511d5e7a49SSlava Zakharin           mlir::func::FuncOp newFunc = getOrCreateFunction(
135256eda98fSSlava Zakharin               builder, typedFuncName, typeGenerator, bodyGenerator);
13531d5e7a49SSlava Zakharin           auto newCall = builder.create<fir::CallOp>(loc, newFunc,
13541d5e7a49SSlava Zakharin                                                      mlir::ValueRange{v1, v2});
13551d5e7a49SSlava Zakharin           call->replaceAllUsesWith(newCall.getResults());
13561d5e7a49SSlava Zakharin           call->dropAllReferences();
13571d5e7a49SSlava Zakharin           call->erase();
13581d5e7a49SSlava Zakharin 
13591d5e7a49SSlava Zakharin           LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
13601d5e7a49SSlava Zakharin                      llvm::dbgs() << "\n");
13611d5e7a49SSlava Zakharin           return;
13626e193b5cSMats Petersson         }
136311efcceaSKazu Hirata         if (funcName.starts_with(RTNAME_STRING(Maxval))) {
13647d2e1987SSacha Ballantyne           simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody);
13657d2e1987SSacha Ballantyne           return;
13667d2e1987SSacha Ballantyne         }
136711efcceaSKazu Hirata         if (funcName.starts_with(RTNAME_STRING(Count))) {
136820fba03fSSacha Ballantyne           simplifyLogicalDim0Reduction(call, kindMap, genRuntimeCountBody);
136920fba03fSSacha Ballantyne           return;
137020fba03fSSacha Ballantyne         }
137111efcceaSKazu Hirata         if (funcName.starts_with(RTNAME_STRING(Any))) {
137220fba03fSSacha Ballantyne           simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAnyBody);
137320fba03fSSacha Ballantyne           return;
137420fba03fSSacha Ballantyne         }
137511efcceaSKazu Hirata         if (funcName.ends_with(RTNAME_STRING(All))) {
137620fba03fSSacha Ballantyne           simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAllBody);
1377afa520abSMats Petersson           return;
1378afa520abSMats Petersson         }
137911efcceaSKazu Hirata         if (funcName.starts_with(RTNAME_STRING(Minloc))) {
13809bb47f7fSDavid Green           simplifyMinMaxlocReduction(call, kindMap, false);
13819bb47f7fSDavid Green           return;
13829bb47f7fSDavid Green         }
13839bb47f7fSDavid Green         if (funcName.starts_with(RTNAME_STRING(Maxloc))) {
13849bb47f7fSDavid Green           simplifyMinMaxlocReduction(call, kindMap, true);
1385614cd721SSacha Ballantyne           return;
1386614cd721SSacha Ballantyne         }
13876e193b5cSMats Petersson       }
13886e193b5cSMats Petersson     }
13896e193b5cSMats Petersson   });
13901d5e7a49SSlava Zakharin   LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
13916e193b5cSMats Petersson }
13926e193b5cSMats Petersson 
/// Adds the LLVM dialect to the set of dialects this pass depends on.
///
/// The only action taken is inserting mlir::LLVM::LLVMDialect into
/// \p registry; per the note in the body, this is needed because creating an
/// LLVM::LinkageAttr requires the LLVM dialect to be loaded.
///
/// \param registry dialect registry that receives the LLVM dialect entry.
13931d5e7a49SSlava Zakharin void SimplifyIntrinsicsPass::getDependentDialects(
13941d5e7a49SSlava Zakharin     mlir::DialectRegistry &registry) const {
13951d5e7a49SSlava Zakharin   // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
13961d5e7a49SSlava Zakharin   registry.insert<mlir::LLVM::LLVMDialect>();
13971d5e7a49SSlava Zakharin }
1398