xref: /llvm-project/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp (revision 4b17a8b10ebb69d3bd30ee7714b5ca24f7e944dc)
1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
13 ///
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
/// call has - when the argument is a 1D-array (multiple loops may be needed
/// for higher dimension arrays, of course).
18 ///
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
24 
25 #include "flang/Common/Fortran.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/CUFCommon.h"
28 #include "flang/Optimizer/Builder/FIRBuilder.h"
29 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
30 #include "flang/Optimizer/Builder/Todo.h"
31 #include "flang/Optimizer/Dialect/FIROps.h"
32 #include "flang/Optimizer/Dialect/FIRType.h"
33 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
34 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
35 #include "flang/Optimizer/Transforms/Passes.h"
36 #include "flang/Optimizer/Transforms/Utils.h"
37 #include "flang/Runtime/entry-names.h"
38 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
39 #include "mlir/IR/Matchers.h"
40 #include "mlir/IR/Operation.h"
41 #include "mlir/Pass/Pass.h"
42 #include "mlir/Transforms/DialectConversion.h"
43 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
44 #include "mlir/Transforms/RegionUtils.h"
45 #include "llvm/Support/Debug.h"
46 #include "llvm/Support/raw_ostream.h"
47 #include <llvm/Support/ErrorHandling.h>
48 #include <mlir/Dialect/Arith/IR/Arith.h>
49 #include <mlir/IR/BuiltinTypes.h>
50 #include <mlir/IR/Location.h>
51 #include <mlir/IR/MLIRContext.h>
52 #include <mlir/IR/Value.h>
53 #include <mlir/Support/LLVM.h>
54 #include <optional>
55 
56 namespace fir {
57 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS
58 #include "flang/Optimizer/Transforms/Passes.h.inc"
59 } // namespace fir
60 
61 #define DEBUG_TYPE "flang-simplify-intrinsics"
62 
63 namespace {
64 
/// Pass that looks for calls to runtime library intrinsics and replaces
/// them with calls to specialized functions (see the file header comment).
class SimplifyIntrinsicsPass
    : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
  /// Callback that produces the type of a function to be created.
  using FunctionTypeGeneratorTy =
      llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
  /// Callback that fills in the body of the given function.
  using FunctionBodyGeneratorTy =
      llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
  /// Callback that fills in the body of a reduction function for an input
  /// of the given rank and element type.
  using GenReductionBodyTy = llvm::function_ref<void(
      fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank,
      mlir::Type elementType)>;

public:
  // Inherit the constructors of the generated pass base class.
  using fir::impl::SimplifyIntrinsicsBase<
      SimplifyIntrinsicsPass>::SimplifyIntrinsicsBase;

  /// Generate a new function implementing a simplified version
  /// of a Fortran runtime function defined by \p basename name.
  /// \p typeGenerator is a callback that generates the new function's type.
  /// \p bodyGenerator is a callback that generates the new function's body.
  /// The new function is created in the \p builder's Module.
  /// NOTE(review): the name suggests an already existing function is reused
  /// instead of being regenerated; the definition is outside this excerpt,
  /// so confirm there.
  mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
                                         const mlir::StringRef &basename,
                                         FunctionTypeGeneratorTy typeGenerator,
                                         FunctionBodyGeneratorTy bodyGenerator);
  /// Pass entry point.
  void runOnOperation() override;
  /// Register the dialects this pass depends on into \p registry.
  void getDependentDialects(mlir::DialectRegistry &registry) const override;

private:
  /// Helper functions to replace a reduction type of call with its
  /// simplified form. The actual function is generated using a callback
  /// function.
  /// \p call is the call to be replaced
  /// \p kindMap is used to create FirOpBuilder
  /// \p genBodyFunc is the callback that builds the replacement function
  void simplifyIntOrFloatReduction(fir::CallOp call,
                                   const fir::KindMapping &kindMap,
                                   GenReductionBodyTy genBodyFunc);
  void simplifyLogicalDim0Reduction(fir::CallOp call,
                                    const fir::KindMapping &kindMap,
                                    GenReductionBodyTy genBodyFunc);
  void simplifyLogicalDim1Reduction(fir::CallOp call,
                                    const fir::KindMapping &kindMap,
                                    GenReductionBodyTy genBodyFunc);
  /// Same idea for MINLOC/MAXLOC calls; \p isMax presumably selects the
  /// MAXLOC flavour (confirm against the definition).
  void simplifyMinMaxlocReduction(fir::CallOp call,
                                  const fir::KindMapping &kindMap, bool isMax);
  /// Shared part of the reduction helpers above. In addition to their
  /// arguments it receives an already created \p builder, the \p basename
  /// of the function and the \p elementType of the input array.
  void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap,
                             GenReductionBodyTy genBodyFunc,
                             fir::FirOpBuilder &builder,
                             const mlir::StringRef &basename,
                             mlir::Type elementType);
};
115 
116 } // namespace
117 
118 /// Create FirOpBuilder with the provided \p op insertion point
119 /// and \p kindMap additionally inheriting FastMathFlags from \p op.
120 static fir::FirOpBuilder
121 getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) {
122   fir::FirOpBuilder builder{op, kindMap};
123   auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
124   if (!fmi)
125     return builder;
126 
127   // Regardless of what default FastMathFlags are used by FirOpBuilder,
128   // override them with FastMathFlags attached to the operation.
129   builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
130   return builder;
131 }
132 
133 /// Generate function type for the simplified version of RTNAME(Sum) and
134 /// similar functions with a fir.box<none> type returning \p elementType.
135 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
136                                          const mlir::Type &elementType) {
137   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
138   return mlir::FunctionType::get(builder.getContext(), {boxType},
139                                  {elementType});
140 }
141 
142 template <typename Op>
143 Op expectOp(mlir::Value val) {
144   if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp()))
145     return op;
146   LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName()
147                           << '\n');
148   return nullptr;
149 }
150 
151 template <typename Op>
152 static mlir::Value findDefSingle(fir::ConvertOp op) {
153   if (auto defOp = expectOp<Op>(op->getOperand(0))) {
154     return defOp.getResult();
155   }
156   return {};
157 }
158 
159 template <typename... Ops>
160 static mlir::Value findDef(fir::ConvertOp op) {
161   mlir::Value defOp;
162   // Loop over the operation types given to see if any match, exiting once
163   // a match is found. Cast to void is needed to avoid compiler complaining
164   // that the result of expression is unused
165   (void)((defOp = findDefSingle<Ops>(op), (defOp)) || ...);
166   return defOp;
167 }
168 
169 static bool isOperandAbsent(mlir::Value val) {
170   if (auto op = expectOp<fir::ConvertOp>(val)) {
171     assert(op->getOperands().size() != 0);
172     return mlir::isa_and_nonnull<fir::AbsentOp>(
173         op->getOperand(0).getDefiningOp());
174   }
175   return false;
176 }
177 
178 static bool isTrueOrNotConstant(mlir::Value val) {
179   if (auto op = expectOp<mlir::arith::ConstantOp>(val)) {
180     return !mlir::matchPattern(val, mlir::m_Zero());
181   }
182   return true;
183 }
184 
185 static bool isZero(mlir::Value val) {
186   if (auto op = expectOp<fir::ConvertOp>(val)) {
187     assert(op->getOperands().size() != 0);
188     if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
189       return mlir::matchPattern(defOp, mlir::m_Zero());
190   }
191   return false;
192 }
193 
194 static mlir::Value findBoxDef(mlir::Value val) {
195   if (auto op = expectOp<fir::ConvertOp>(val)) {
196     assert(op->getOperands().size() != 0);
197     return findDef<fir::EmboxOp, fir::ReboxOp>(op);
198   }
199   return {};
200 }
201 
202 static mlir::Value findMaskDef(mlir::Value val) {
203   if (auto op = expectOp<fir::ConvertOp>(val)) {
204     assert(op->getOperands().size() != 0);
205     return findDef<fir::EmboxOp, fir::ReboxOp, fir::AbsentOp>(op);
206   }
207   return {};
208 }
209 
210 static unsigned getDimCount(mlir::Value val) {
211   // In order to find the dimensions count, we look for EmboxOp/ReboxOp
212   // and take the count from its *result* type. Note that in case
213   // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
214   // have different types.
215   // Actually, we can take the box type from the operand of
216   // the first ConvertOp that has non-opaque box type that we meet
217   // going through the ConvertOp chain.
218   if (mlir::Value emboxVal = findBoxDef(val))
219     if (auto boxTy = mlir::dyn_cast<fir::BoxType>(emboxVal.getType()))
220       if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
221         return seqTy.getDimension();
222   return 0;
223 }
224 
225 /// Given the call operation's box argument \p val, discover
226 /// the element type of the underlying array object.
227 /// \returns the element type or std::nullopt if the type cannot
228 /// be reliably found.
229 /// We expect that the argument is a result of fir.convert
230 /// with the destination type of !fir.box<none>.
231 static std::optional<mlir::Type> getArgElementType(mlir::Value val) {
232   mlir::Operation *defOp;
233   do {
234     defOp = val.getDefiningOp();
235     // Analyze only sequences of convert operations.
236     if (!mlir::isa<fir::ConvertOp>(defOp))
237       return std::nullopt;
238     val = defOp->getOperand(0);
239     // The convert operation is expected to convert from one
240     // box type to another box type.
241     auto boxType = mlir::cast<fir::BoxType>(val.getType());
242     auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
243     if (!mlir::isa<mlir::NoneType>(elementType))
244       return elementType;
245   } while (true);
246 }
247 
/// Callback generating the reduction operation itself. genReductionLoop
/// calls it with the builder, the location, the element type of the input
/// array, the loaded array element and the current reduction value;
/// the returned value is the updated reduction value.
using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
    fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
    mlir::Value)>;
/// Callback generating the values passed to the loop terminator.
/// genReductionLoop calls it with the builder, the location and
/// the value returned by the BodyOpGeneratorTy callback.
using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>(
    fir::FirOpBuilder &, mlir::Location, mlir::Value)>;
253 
/// Generate the reduction loop into \p funcOp.
///
/// \tparam OP is the loop operation used to build the loop nest.
/// \tparam T is the type of \p unorderedOrInitialLoopCond.
/// \tparam resultIndex is the position of the reduction value among
///    the loop's region iteration arguments and among the values
///    returned by \p loopCond.
///
/// \p initVal is a function, called to get the initial value for
///    the reduction value. It is called with the result type of
///    \p funcOp, not with \p elementType.
/// \p genBody is called to fill in the actual reduction operation
///    for example add for SUM, MAX for MAXVAL, etc.
/// \p rank is the rank of the input argument. It must not be zero.
/// \p elementType is the type of the elements in the input array,
///    which may be different to the return type.
/// \p loopCond is called to generate the condition to continue or
///    not for IterWhile loops
/// \p unorderedOrInitialLoopCond contains either a boolean or bool
///    mlir constant, and controls the initial value for while loops
///    or if DoLoop is ordered/unordered.
/// \p loc is the location attached to all generated operations.

template <typename OP, typename T, int resultIndex>
static void
genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
                 fir::InitValGeneratorTy initVal, ContinueLoopGenTy loopCond,
                 T unorderedOrInitialLoopCond, BodyOpGeneratorTy genBody,
                 unsigned rank, mlir::Type elementType, mlir::Location loc) {

  mlir::IndexType idxTy = builder.getIndexType();

  // The input array is the first argument of the function.
  mlir::Block::BlockArgListType args = funcOp.front().getArguments();
  mlir::Value arg = args[0];

  mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);

  // Convert the argument into a box of a rank-dimensional array
  // of elementType with all the extents unknown.
  fir::SequenceType::Shape flatShape(rank,
                                     fir::SequenceType::getUnknownExtent());
  mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
  mlir::Type boxArrTy = fir::BoxType::get(arrTy);
  mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
  mlir::Type resultType = funcOp.getResultTypes()[0];
  mlir::Value init = initVal(builder, loc, resultType);

  llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;

  assert(rank > 0 && "rank cannot be zero");
  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);

  // Compute all the upper bounds before the loop nest.
  // It is not strictly necessary for performance, since the loop nest
  // does not have any store operations and any LICM optimization
  // should be able to optimize the redundancy.
  for (unsigned i = 0; i < rank; ++i) {
    mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
    auto dims =
        builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
    mlir::Value len = dims.getResult(1);
    // We use C indexing here, so len-1 as loopcount
    mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
    bounds.push_back(loopCount);
  }
  // Create a loop nest consisting of OP operations.
  // Collect the loops' induction variables into indices array,
  // which will be used in the innermost loop to load the input
  // array's element.
  // The loops are generated such that the innermost loop processes
  // the 0 dimension.
  llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
  for (unsigned i = rank; 0 < i; --i) {
    mlir::Value step = one;
    mlir::Value loopCount = bounds[i - 1];
    auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step,
                                   unorderedOrInitialLoopCond,
                                   /*finalCountValue=*/false, init);
    // Inside the loop the reduction value is the loop's iteration argument.
    init = loop.getRegionIterArgs()[resultIndex];
    indices.push_back(loop.getInductionVar());
    // Set insertion point to the loop body so that the next loop
    // is inserted inside the current one.
    builder.setInsertionPointToStart(loop.getBody());
  }

  // Reverse the indices such that they are ordered as:
  //   <dim-0-idx, dim-1-idx, ...>
  std::reverse(indices.begin(), indices.end());
  // We are in the innermost loop: generate the reduction body.
  mlir::Type eleRefTy = builder.getRefType(elementType);
  mlir::Value addr =
      builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
  mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
  mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
  // Generate vector with condition to continue while loop at [0] and result
  // from current loop at [1] for IterWhileOp loops, just result at [0] for
  // DoLoopOp loops.
  llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal);

  // Unwind the loop nest and insert ResultOp on each level
  // to return the updated value of the reduction to the enclosing
  // loops.
  for (unsigned i = 0; i < rank; ++i) {
    auto result = builder.create<fir::ResultOp>(loc, results);
    // Proceed to the outer loop.
    auto loop = mlir::cast<OP>(result->getParentOp());
    results = loop.getResults();
    // Set insertion point after the loop operation that we have
    // just processed.
    builder.setInsertionPointAfter(loop.getOperation());
  }
  // End of loop nest. The insertion point is after the outermost loop.
  // Return the reduction value from the function.
  builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
}
359 
360 static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder,
361                                                   mlir::Location loc,
362                                                   mlir::Value reductionVal) {
363   return {reductionVal};
364 }
365 
366 /// Generate function body of the simplified version of RTNAME(Sum)
367 /// with signature provided by \p funcOp. The caller is responsible
368 /// for saving/restoring the original insertion point of \p builder.
369 /// \p funcOp is expected to be empty on entry to this function.
370 /// \p rank specifies the rank of the input argument.
371 static void genRuntimeSumBody(fir::FirOpBuilder &builder,
372                               mlir::func::FuncOp &funcOp, unsigned rank,
373                               mlir::Type elementType) {
374   // function RTNAME(Sum)<T>x<rank>_simplified(arr)
375   //   T, dimension(:) :: arr
376   //   T sum = 0
377   //   integer iter
378   //   do iter = 0, extent(arr)
379   //     sum = sum + arr[iter]
380   //   end do
381   //   RTNAME(Sum)<T>x<rank>_simplified = sum
382   // end function RTNAME(Sum)<T>x<rank>_simplified
383   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
384                  mlir::Type elementType) {
385     if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
386       const llvm::fltSemantics &sem = ty.getFloatSemantics();
387       return builder.createRealConstant(loc, elementType,
388                                         llvm::APFloat::getZero(sem));
389     }
390     return builder.createIntegerConstant(loc, elementType, 0);
391   };
392 
393   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
394                       mlir::Type elementType, mlir::Value elem1,
395                       mlir::Value elem2) -> mlir::Value {
396     if (mlir::isa<mlir::FloatType>(elementType))
397       return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
398     if (mlir::isa<mlir::IntegerType>(elementType))
399       return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
400 
401     llvm_unreachable("unsupported type");
402     return {};
403   };
404 
405   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
406   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
407 
408   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
409                                            false, genBodyOp, rank, elementType,
410                                            loc);
411 }
412 
413 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
414                                  mlir::func::FuncOp &funcOp, unsigned rank,
415                                  mlir::Type elementType) {
416   auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
417                  mlir::Type elementType) {
418     if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
419       const llvm::fltSemantics &sem = ty.getFloatSemantics();
420       return builder.createRealConstant(
421           loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
422     }
423     unsigned bits = elementType.getIntOrFloatBitWidth();
424     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
425     return builder.createIntegerConstant(loc, elementType, minInt);
426   };
427 
428   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
429                       mlir::Type elementType, mlir::Value elem1,
430                       mlir::Value elem2) -> mlir::Value {
431     if (mlir::isa<mlir::FloatType>(elementType)) {
432       // arith.maxf later converted to llvm.intr.maxnum does not work
433       // correctly for NaNs and -0.0 (see maxnum/minnum pattern matching
434       // in LLVM's InstCombine pass). Moreover, llvm.intr.maxnum
435       // for F128 operands is lowered into fmaxl call by LLVM.
436       // This libm function may not work properly for F128 arguments
437       // on targets where long double is not F128. It is an LLVM issue,
438       // but we just use normal select here to resolve all the cases.
439       auto compare = builder.create<mlir::arith::CmpFOp>(
440           loc, mlir::arith::CmpFPredicate::OGT, elem1, elem2);
441       return builder.create<mlir::arith::SelectOp>(loc, compare, elem1, elem2);
442     }
443     if (mlir::isa<mlir::IntegerType>(elementType))
444       return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
445 
446     llvm_unreachable("unsupported type");
447     return {};
448   };
449 
450   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
451   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
452 
453   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, init, nopLoopCond,
454                                            false, genBodyOp, rank, elementType,
455                                            loc);
456 }
457 
458 static void genRuntimeCountBody(fir::FirOpBuilder &builder,
459                                 mlir::func::FuncOp &funcOp, unsigned rank,
460                                 mlir::Type elementType) {
461   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
462                  mlir::Type elementType) {
463     unsigned bits = elementType.getIntOrFloatBitWidth();
464     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
465     return builder.createIntegerConstant(loc, elementType, zeroInt);
466   };
467 
468   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
469                       mlir::Type elementType, mlir::Value elem1,
470                       mlir::Value elem2) -> mlir::Value {
471     auto zero32 = builder.createIntegerConstant(loc, elementType, 0);
472     auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0);
473     auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1);
474 
475     auto compare = builder.create<mlir::arith::CmpIOp>(
476         loc, mlir::arith::CmpIPredicate::eq, elem1, zero32);
477     auto select =
478         builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64);
479     return builder.create<mlir::arith::AddIOp>(loc, select, elem2);
480   };
481 
482   // Count always gets I32 for elementType as it converts logical input to
483   // logical<4> before passing to the function.
484   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
485   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
486 
487   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
488                                            false, genBodyOp, rank, elementType,
489                                            loc);
490 }
491 
492 static void genRuntimeAnyBody(fir::FirOpBuilder &builder,
493                               mlir::func::FuncOp &funcOp, unsigned rank,
494                               mlir::Type elementType) {
495   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
496                  mlir::Type elementType) {
497     return builder.createIntegerConstant(loc, elementType, 0);
498   };
499 
500   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
501                       mlir::Type elementType, mlir::Value elem1,
502                       mlir::Value elem2) -> mlir::Value {
503     auto zero = builder.createIntegerConstant(loc, elementType, 0);
504     return builder.create<mlir::arith::CmpIOp>(
505         loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
506   };
507 
508   auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
509                          mlir::Value reductionVal) {
510     auto one1 = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
511     auto eor = builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1);
512     llvm::SmallVector<mlir::Value> results = {eor, reductionVal};
513     return results;
514   };
515 
516   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
517   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
518   mlir::Value ok = builder.createBool(loc, true);
519 
520   genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
521       builder, funcOp, zero, continueCond, ok, genBodyOp, rank, elementType,
522       loc);
523 }
524 
525 static void genRuntimeAllBody(fir::FirOpBuilder &builder,
526                               mlir::func::FuncOp &funcOp, unsigned rank,
527                               mlir::Type elementType) {
528   auto one = [](fir::FirOpBuilder builder, mlir::Location loc,
529                 mlir::Type elementType) {
530     return builder.createIntegerConstant(loc, elementType, 1);
531   };
532 
533   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
534                       mlir::Type elementType, mlir::Value elem1,
535                       mlir::Value elem2) -> mlir::Value {
536     auto zero = builder.createIntegerConstant(loc, elementType, 0);
537     return builder.create<mlir::arith::CmpIOp>(
538         loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
539   };
540 
541   auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
542                          mlir::Value reductionVal) {
543     llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal};
544     return results;
545   };
546 
547   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
548   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
549   mlir::Value ok = builder.createBool(loc, true);
550 
551   genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
552       builder, funcOp, one, continueCond, ok, genBodyOp, rank, elementType,
553       loc);
554 }
555 
556 static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
557                                                unsigned int rank) {
558   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
559   mlir::Type boxRefType = builder.getRefType(boxType);
560 
561   return mlir::FunctionType::get(builder.getContext(),
562                                  {boxRefType, boxType, boxType}, {});
563 }
564 
// Produces a loop nest for a Minloc/Maxloc intrinsic at the current
// insertion point of \p builder.
//
// \p array is the input array; it is converted to a box of
//    a \p rank dimensional array of \p elementType.
// \p initVal is called to get the initial value of the reduction.
// \p genBody is called inside the innermost loop to generate
//    the reduction operation; it receives the converted array,
//    the reference to a temporary flag (initialized to zero),
//    the current reduction value and the indices of the current element.
// \p rank is the rank of the input array. It must not be zero.
// \p resultArr is only used to find the element type of the result.
// \p maskMayBeLogicalScalar tells that the loop nest may have been
//    generated inside a fir.if operation, which has to be terminated.
// NOTE(review): \p getAddrFn and \p maskElemType are not referenced in this
// function, and neither flagSet nor the final value of reductionVal is
// read - confirm whether this is intended.
void fir::genMinMaxlocReductionLoop(
    fir::FirOpBuilder &builder, mlir::Value array,
    fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody,
    fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType,
    mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr,
    bool maskMayBeLogicalScalar) {
  mlir::IndexType idxTy = builder.getIndexType();

  mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);

  // Convert the array into a box of a rank-dimensional array
  // of elementType with all the extents unknown.
  fir::SequenceType::Shape flatShape(rank,
                                     fir::SequenceType::getUnknownExtent());
  mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
  mlir::Type boxArrTy = fir::BoxType::get(arrTy);
  array = builder.create<fir::ConvertOp>(loc, boxArrTy, array);

  // Create a temporary flag of the result element type and reset it.
  mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
  mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
  mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
  mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
  builder.create<fir::StoreOp>(loc, zero, flagRef);

  mlir::Value init = initVal(builder, loc, elementType);
  llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;

  assert(rank > 0 && "rank cannot be zero");
  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);

  // Compute all the upper bounds before the loop nest.
  // It is not strictly necessary for performance, since the loop nest
  // does not have any store operations and any LICM optimization
  // should be able to optimize the redundancy.
  for (unsigned i = 0; i < rank; ++i) {
    mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
    auto dims =
        builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
    mlir::Value len = dims.getResult(1);
    // We use C indexing here, so len-1 as loopcount
    mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
    bounds.push_back(loopCount);
  }
  // Create a loop nest consisting of fir.do_loop operations.
  // Collect the loops' induction variables into indices array,
  // which will be used in the innermost loop to load the input
  // array's element.
  // The loops are generated such that the innermost loop processes
  // the 0 dimension.
  llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
  for (unsigned i = rank; 0 < i; --i) {
    mlir::Value step = one;
    mlir::Value loopCount = bounds[i - 1];
    auto loop =
        builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
                                      /*finalCountValue=*/false, init);
    // Inside the loop the reduction value is the loop's iteration argument.
    init = loop.getRegionIterArgs()[0];
    indices.push_back(loop.getInductionVar());
    // Set insertion point to the loop body so that the next loop
    // is inserted inside the current one.
    builder.setInsertionPointToStart(loop.getBody());
  }

  // Reverse the indices such that they are ordered as:
  //   <dim-0-idx, dim-1-idx, ...>
  std::reverse(indices.begin(), indices.end());
  // We are in the innermost loop: generate the reduction body.
  mlir::Value reductionVal =
      genBody(builder, loc, elementType, array, flagRef, init, indices);

  // Unwind the loop nest and insert ResultOp on each level
  // to return the updated value of the reduction to the enclosing
  // loops.
  for (unsigned i = 0; i < rank; ++i) {
    auto result = builder.create<fir::ResultOp>(loc, reductionVal);
    // Proceed to the outer loop.
    auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
    reductionVal = loop.getResult(0);
    // Set insertion point after the loop operation that we have
    // just processed.
    builder.setInsertionPointAfter(loop.getOperation());
  }
  // End of loop nest. The insertion point is after the outermost loop.
  if (maskMayBeLogicalScalar) {
    // If the loop nest is located inside a fir.if operation, terminate
    // its block with the reduction value and continue after the fir.if.
    if (fir::IfOp ifOp =
            mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
      builder.create<fir::ResultOp>(loc, reductionVal);
      builder.setInsertionPointAfter(ifOp);
      // Redefine flagSet to escape scope of ifOp
      flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
      reductionVal = ifOp.getResult(0);
    }
  }
}
657 
658 static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
659                                     mlir::func::FuncOp &funcOp, bool isMax,
660                                     unsigned rank, int maskRank,
661                                     mlir::Type elementType,
662                                     mlir::Type maskElemType,
663                                     mlir::Type resultElemTy, bool isDim) {
664   auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
665                       mlir::Type elementType) {
666     if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
667       const llvm::fltSemantics &sem = ty.getFloatSemantics();
668       llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
669       return builder.createRealConstant(loc, elementType, limit);
670     }
671     unsigned bits = elementType.getIntOrFloatBitWidth();
672     int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
673                                : llvm::APInt::getSignedMaxValue(bits))
674                             .getSExtValue();
675     return builder.createIntegerConstant(loc, elementType, initValue);
676   };
677 
678   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
679   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
680 
681   mlir::Value mask = funcOp.front().getArgument(2);
682 
683   // Set up result array in case of early exit / 0 length array
684   mlir::IndexType idxTy = builder.getIndexType();
685   mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy);
686   mlir::Type resultHeapTy = fir::HeapType::get(resultTy);
687   mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy);
688 
689   mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0);
690   mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank);
691 
692   mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy);
693   mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize);
694   mlir::Value resultArr = builder.create<fir::EmboxOp>(
695       loc, resultBoxTy, resultArrInit, resultArrShape);
696 
697   mlir::Type resultRefTy = builder.getRefType(resultElemTy);
698 
699   if (maskRank > 0) {
700     fir::SequenceType::Shape flatShape(rank,
701                                        fir::SequenceType::getUnknownExtent());
702     mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType);
703     mlir::Type boxMaskTy = fir::BoxType::get(maskTy);
704     mask = builder.create<fir::ConvertOp>(loc, boxMaskTy, mask);
705   }
706 
707   for (unsigned int i = 0; i < rank; ++i) {
708     mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
709     mlir::Value resultElemAddr =
710         builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index);
711     builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
712   }
713 
714   auto genBodyOp =
715       [&rank, &resultArr, isMax, &mask, &maskElemType, &maskRank](
716           fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType,
717           mlir::Value array, mlir::Value flagRef, mlir::Value reduction,
718           const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
719     // We are in the innermost loop: generate the reduction body.
720     if (maskRank > 0) {
721       mlir::Type logicalRef = builder.getRefType(maskElemType);
722       mlir::Value maskAddr =
723           builder.create<fir::CoordinateOp>(loc, logicalRef, mask, indices);
724       mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr);
725 
726       // fir::IfOp requires argument to be I1 - won't accept logical or any
727       // other Integer.
728       mlir::Type ifCompatType = builder.getI1Type();
729       mlir::Value ifCompatElem =
730           builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
731 
732       llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
733       fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
734                                                  /*withElseRegion=*/true);
735       builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
736     }
737 
738     // Set flag that mask was true at some point
739     mlir::Value flagSet = builder.createIntegerConstant(
740         loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
741     mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
742     mlir::Type eleRefTy = builder.getRefType(elementType);
743     mlir::Value addr =
744         builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
745     mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
746 
747     mlir::Value cmp;
748     if (mlir::isa<mlir::FloatType>(elementType)) {
749       // For FP reductions we want the first smallest value to be used, that
750       // is not NaN. A OGL/OLT condition will usually work for this unless all
751       // the values are Nan or Inf. This follows the same logic as
752       // NumericCompare for Minloc/Maxlox in extrema.cpp.
753       cmp = builder.create<mlir::arith::CmpFOp>(
754           loc,
755           isMax ? mlir::arith::CmpFPredicate::OGT
756                 : mlir::arith::CmpFPredicate::OLT,
757           elem, reduction);
758 
759       mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
760           loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
761       mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
762           loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
763       cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
764       cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
765     } else if (mlir::isa<mlir::IntegerType>(elementType)) {
766       cmp = builder.create<mlir::arith::CmpIOp>(
767           loc,
768           isMax ? mlir::arith::CmpIPredicate::sgt
769                 : mlir::arith::CmpIPredicate::slt,
770           elem, reduction);
771     } else {
772       llvm_unreachable("unsupported type");
773     }
774 
775     // The condition used for the loop is isFirst || <the condition above>.
776     isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
777     isFirst = builder.create<mlir::arith::XOrIOp>(
778         loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
779     cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
780     fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
781                                                /*withElseRegion*/ true);
782 
783     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
784     builder.create<fir::StoreOp>(loc, flagSet, flagRef);
785     mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
786     mlir::Type returnRefTy = builder.getRefType(resultElemTy);
787     mlir::IndexType idxTy = builder.getIndexType();
788 
789     mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1);
790 
791     for (unsigned int i = 0; i < rank; ++i) {
792       mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
793       mlir::Value resultElemAddr =
794           builder.create<fir::CoordinateOp>(loc, returnRefTy, resultArr, index);
795       mlir::Value convert =
796           builder.create<fir::ConvertOp>(loc, resultElemTy, indices[i]);
797       mlir::Value fortranIndex =
798           builder.create<mlir::arith::AddIOp>(loc, convert, one);
799       builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
800     }
801     builder.create<fir::ResultOp>(loc, elem);
802     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
803     builder.create<fir::ResultOp>(loc, reduction);
804     builder.setInsertionPointAfter(ifOp);
805     mlir::Value reductionVal = ifOp.getResult(0);
806 
807     // Close the mask if needed
808     if (maskRank > 0) {
809       fir::IfOp ifOp =
810           mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp());
811       builder.create<fir::ResultOp>(loc, reductionVal);
812       builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
813       builder.create<fir::ResultOp>(loc, reduction);
814       reductionVal = ifOp.getResult(0);
815       builder.setInsertionPointAfter(ifOp);
816     }
817 
818     return reductionVal;
819   };
820 
821   // if mask is a logical scalar, we can check its value before the main loop
822   // and either ignore the fact it is there or exit early.
823   if (maskRank == 0) {
824     mlir::Type logical = builder.getI1Type();
825     mlir::IndexType idxTy = builder.getIndexType();
826 
827     fir::SequenceType::Shape singleElement(1, 1);
828     mlir::Type arrTy = fir::SequenceType::get(singleElement, logical);
829     mlir::Type boxArrTy = fir::BoxType::get(arrTy);
830     mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, mask);
831 
832     mlir::Value indx = builder.createIntegerConstant(loc, idxTy, 0);
833     mlir::Type logicalRefTy = builder.getRefType(logical);
834     mlir::Value condAddr =
835         builder.create<fir::CoordinateOp>(loc, logicalRefTy, array, indx);
836     mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr);
837 
838     fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cond,
839                                                /*withElseRegion=*/true);
840 
841     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
842     mlir::Value basicValue;
843     if (mlir::isa<mlir::IntegerType>(elementType)) {
844       basicValue = builder.createIntegerConstant(loc, elementType, 0);
845     } else {
846       basicValue = builder.createRealConstant(loc, elementType, 0);
847     }
848     builder.create<fir::ResultOp>(loc, basicValue);
849 
850     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
851   }
852   auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
853                       const mlir::Type &resultElemType, mlir::Value resultArr,
854                       mlir::Value index) {
855     mlir::Type resultRefTy = builder.getRefType(resultElemType);
856     return builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr,
857                                              index);
858   };
859 
860   genMinMaxlocReductionLoop(builder, funcOp.front().getArgument(1), init,
861                             genBodyOp, getAddrFn, rank, elementType, loc,
862                             maskElemType, resultArr, maskRank == 0);
863 
864   // Store newly created output array to the reference passed in
865   if (isDim) {
866     mlir::Type resultBoxTy =
867         fir::BoxType::get(fir::HeapType::get(resultElemTy));
868     mlir::Value outputArr = builder.create<fir::ConvertOp>(
869         loc, builder.getRefType(resultBoxTy), funcOp.front().getArgument(0));
870     mlir::Value resultArrScalar = builder.create<fir::ConvertOp>(
871         loc, fir::HeapType::get(resultElemTy), resultArrInit);
872     mlir::Value resultBox =
873         builder.create<fir::EmboxOp>(loc, resultBoxTy, resultArrScalar);
874     builder.create<fir::StoreOp>(loc, resultBox, outputArr);
875   } else {
876     fir::SequenceType::Shape resultShape(1, rank);
877     mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy);
878     mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
879     mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
880     mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
881     mlir::Value outputArr = builder.create<fir::ConvertOp>(
882         loc, outputRefTy, funcOp.front().getArgument(0));
883     builder.create<fir::StoreOp>(loc, resultArr, outputArr);
884   }
885 
886   builder.create<mlir::func::ReturnOp>(loc);
887 }
888 
889 /// Generate function type for the simplified version of RTNAME(DotProduct)
890 /// operating on the given \p elementType.
891 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
892                                             const mlir::Type &elementType) {
893   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
894   return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
895                                  {elementType});
896 }
897 
898 /// Generate function body of the simplified version of RTNAME(DotProduct)
899 /// with signature provided by \p funcOp. The caller is responsible
900 /// for saving/restoring the original insertion point of \p builder.
901 /// \p funcOp is expected to be empty on entry to this function.
902 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
903 /// of the underlying array objects - they are used to generate proper
904 /// element accesses.
905 static void genRuntimeDotBody(fir::FirOpBuilder &builder,
906                               mlir::func::FuncOp &funcOp,
907                               mlir::Type arg1ElementTy,
908                               mlir::Type arg2ElementTy) {
909   // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
910   //   T, dimension(:) :: arr1, arr2
911   //   T product = 0
912   //   integer iter
913   //   do iter = 0, extent(arr1)
914   //     product = product + arr1[iter] * arr2[iter]
915   //   end do
916   //   RTNAME(ADotProduct)<T>_simplified = product
917   // end function RTNAME(DotProduct)<T>_simplified
918   auto loc = mlir::UnknownLoc::get(builder.getContext());
919   mlir::Type resultElementType = funcOp.getResultTypes()[0];
920   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
921 
922   mlir::IndexType idxTy = builder.getIndexType();
923 
924   mlir::Value zero =
925       mlir::isa<mlir::FloatType>(resultElementType)
926           ? builder.createRealConstant(loc, resultElementType, 0.0)
927           : builder.createIntegerConstant(loc, resultElementType, 0);
928 
929   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
930   mlir::Value arg1 = args[0];
931   mlir::Value arg2 = args[1];
932 
933   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
934 
935   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
936   mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
937   mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
938   mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
939   mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
940   mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
941   mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
942   // This version takes the loop trip count from the first argument.
943   // If the first argument's box has unknown (at compilation time)
944   // extent, then it may be better to take the extent from the second
945   // argument - so that after inlining the loop may be better optimized, e.g.
946   // fully unrolled. This requires generating two versions of the simplified
947   // function and some analysis at the call site to choose which version
948   // is more profitable to call.
949   // Note that we can assume that both arguments have the same extent.
950   auto dims =
951       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
952   mlir::Value len = dims.getResult(1);
953   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
954   mlir::Value step = one;
955 
956   // We use C indexing here, so len-1 as loopcount
957   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
958   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
959                                             /*unordered=*/false,
960                                             /*finalCountValue=*/false, zero);
961   mlir::Value sumVal = loop.getRegionIterArgs()[0];
962 
963   // Begin loop code
964   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
965   builder.setInsertionPointToStart(loop.getBody());
966 
967   mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
968   mlir::Value index = loop.getInductionVar();
969   mlir::Value addr1 =
970       builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
971   mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
972   // Convert to the result type.
973   elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
974 
975   mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
976   mlir::Value addr2 =
977       builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
978   mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
979   // Convert to the result type.
980   elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
981 
982   if (mlir::isa<mlir::FloatType>(resultElementType))
983     sumVal = builder.create<mlir::arith::AddFOp>(
984         loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
985   else if (mlir::isa<mlir::IntegerType>(resultElementType))
986     sumVal = builder.create<mlir::arith::AddIOp>(
987         loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
988   else
989     llvm_unreachable("unsupported type");
990 
991   builder.create<fir::ResultOp>(loc, sumVal);
992   // End of loop.
993   builder.restoreInsertionPoint(loopEndPt);
994 
995   mlir::Value resultVal = loop.getResult(0);
996   builder.create<mlir::func::ReturnOp>(loc, resultVal);
997 }
998 
999 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
1000     fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
1001     FunctionTypeGeneratorTy typeGenerator,
1002     FunctionBodyGeneratorTy bodyGenerator) {
1003   // WARNING: if the function generated here changes its signature
1004   //          or behavior (the body code), we should probably embed some
1005   //          versioning information into its name, otherwise libraries
1006   //          statically linked with older versions of Flang may stop
1007   //          working with object files created with newer Flang.
1008   //          We can also avoid this by using internal linkage, but
1009   //          this may increase the size of final executable/shared library.
1010   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
1011   // If we already have a function, just return it.
1012   mlir::func::FuncOp newFunc = builder.getNamedFunction(replacementName);
1013   mlir::FunctionType fType = typeGenerator(builder);
1014   if (newFunc) {
1015     assert(newFunc.getFunctionType() == fType &&
1016            "type mismatch for simplified function");
1017     return newFunc;
1018   }
1019 
1020   // Need to build the function!
1021   auto loc = mlir::UnknownLoc::get(builder.getContext());
1022   newFunc = builder.createFunction(loc, replacementName, fType);
1023   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
1024   auto linkage =
1025       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
1026   newFunc->setAttr("llvm.linkage", linkage);
1027 
1028   // Save the position of the original call.
1029   mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
1030 
1031   bodyGenerator(builder, newFunc);
1032 
1033   // Now back to where we were adding code earlier...
1034   builder.restoreInsertionPoint(insertPt);
1035 
1036   return newFunc;
1037 }
1038 
1039 void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction(
1040     fir::CallOp call, const fir::KindMapping &kindMap,
1041     GenReductionBodyTy genBodyFunc) {
1042   // args[1] and args[2] are source filename and line number, ignored.
1043   mlir::Operation::operand_range args = call.getArgs();
1044 
1045   const mlir::Value &dim = args[3];
1046   const mlir::Value &mask = args[4];
1047   // dim is zero when it is absent, which is an implementation
1048   // detail in the runtime library.
1049 
1050   bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
1051   unsigned rank = getDimCount(args[0]);
1052 
1053   // Rank is set to 0 for assumed shape arrays, don't simplify
1054   // in these cases
1055   if (!(dimAndMaskAbsent && rank > 0))
1056     return;
1057 
1058   mlir::Type resultType = call.getResult(0).getType();
1059 
1060   if (!mlir::isa<mlir::FloatType>(resultType) &&
1061       !mlir::isa<mlir::IntegerType>(resultType))
1062     return;
1063 
1064   auto argType = getArgElementType(args[0]);
1065   if (!argType)
1066     return;
1067   assert(*argType == resultType &&
1068          "Argument/result types mismatch in reduction");
1069 
1070   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1071 
1072   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1073   std::string fmfString{builder.getFastMathFlagsString()};
1074   std::string funcName =
1075       (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1076        mlir::Twine{rank} +
1077        // We must mangle the generated function name with FastMathFlags
1078        // value.
1079        (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
1080           .str();
1081 
1082   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1083                         resultType);
1084 }
1085 
1086 void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction(
1087     fir::CallOp call, const fir::KindMapping &kindMap,
1088     GenReductionBodyTy genBodyFunc) {
1089 
1090   mlir::Operation::operand_range args = call.getArgs();
1091   const mlir::Value &dim = args[3];
1092   unsigned rank = getDimCount(args[0]);
1093 
1094   // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1095   // these cases.
1096   if (!(isZero(dim) && rank > 0))
1097     return;
1098 
1099   mlir::Value inputBox = findBoxDef(args[0]);
1100 
1101   mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1102   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1103 
1104   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1105 
1106   // Treating logicals as integers makes things a lot easier
1107   fir::LogicalType logicalType = {
1108       mlir::dyn_cast<fir::LogicalType>(elementType)};
1109   fir::KindTy kind = logicalType.getFKind();
1110   mlir::Type intElementType = builder.getIntegerType(kind * 8);
1111 
1112   // Mangle kind into function name as it is not done by default
1113   std::string funcName =
1114       (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1115        mlir::Twine{kind} + "x" + mlir::Twine{rank})
1116           .str();
1117 
1118   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1119                         intElementType);
1120 }
1121 
1122 void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
1123     fir::CallOp call, const fir::KindMapping &kindMap,
1124     GenReductionBodyTy genBodyFunc) {
1125 
1126   mlir::Operation::operand_range args = call.getArgs();
1127   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1128   mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1129   unsigned rank = getDimCount(args[0]);
1130 
1131   // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1132   // these cases. We check for Dim at the end as some logical functions (Any,
1133   // All) set dim to 1 instead of 0 when the argument is not present.
1134   if (funcNameBase.ends_with("Dim") || !(rank > 0))
1135     return;
1136 
1137   mlir::Value inputBox = findBoxDef(args[0]);
1138   mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1139 
1140   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1141 
1142   // Treating logicals as integers makes things a lot easier
1143   fir::LogicalType logicalType = {
1144       mlir::dyn_cast<fir::LogicalType>(elementType)};
1145   fir::KindTy kind = logicalType.getFKind();
1146   mlir::Type intElementType = builder.getIntegerType(kind * 8);
1147 
1148   // Mangle kind into function name as it is not done by default
1149   std::string funcName =
1150       (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1151        mlir::Twine{kind} + "x" + mlir::Twine{rank})
1152           .str();
1153 
1154   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1155                         intElementType);
1156 }
1157 
1158 void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
1159     fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) {
1160 
1161   mlir::Operation::operand_range args = call.getArgs();
1162 
1163   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1164   mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1165   bool isDim = funcNameBase.ends_with("Dim");
1166   mlir::Value back = args[isDim ? 7 : 6];
1167   if (isTrueOrNotConstant(back))
1168     return;
1169 
1170   mlir::Value mask = args[isDim ? 6 : 5];
1171   mlir::Value maskDef = findMaskDef(mask);
1172 
1173   // maskDef is set to NULL when the defining op is not one we accept.
1174   // This tends to be because it is a selectOp, in which case let the
1175   // runtime deal with it.
1176   if (maskDef == NULL)
1177     return;
1178 
1179   unsigned rank = getDimCount(args[1]);
1180   if ((isDim && rank != 1) || !(rank > 0))
1181     return;
1182 
1183   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1184   mlir::Location loc = call.getLoc();
1185   auto inputBox = findBoxDef(args[1]);
1186   mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType());
1187 
1188   if (mlir::isa<fir::CharacterType>(inputType))
1189     return;
1190 
1191   int maskRank;
1192   fir::KindTy kind = 0;
1193   mlir::Type logicalElemType = builder.getI1Type();
1194   if (isOperandAbsent(mask)) {
1195     maskRank = -1;
1196   } else {
1197     maskRank = getDimCount(mask);
1198     mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType());
1199     fir::LogicalType logicalFirType = {
1200         mlir::dyn_cast<fir::LogicalType>(maskElemTy)};
1201     kind = logicalFirType.getFKind();
1202     // Convert fir::LogicalType to mlir::Type
1203     logicalElemType = logicalFirType;
1204   }
1205 
1206   mlir::Operation *outputDef = args[0].getDefiningOp();
1207   mlir::Value outputAlloc = outputDef->getOperand(0);
1208   mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType());
1209 
1210   std::string fmfString{builder.getFastMathFlagsString()};
1211   std::string funcName =
1212       (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1213        mlir::Twine{rank} +
1214        (maskRank >= 0
1215             ? "_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank}
1216             : "") +
1217        "_")
1218           .str();
1219 
1220   llvm::raw_string_ostream nameOS(funcName);
1221   outType.print(nameOS);
1222   if (isDim)
1223     nameOS << '_' << inputType;
1224   nameOS << '_' << fmfString;
1225 
1226   auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
1227     return genRuntimeMinlocType(builder, rank);
1228   };
1229   auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
1230                         isMax, isDim](fir::FirOpBuilder &builder,
1231                                       mlir::func::FuncOp &funcOp) {
1232     genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType,
1233                             logicalElemType, outType, isDim);
1234   };
1235 
1236   mlir::func::FuncOp newFunc =
1237       getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1238   builder.create<fir::CallOp>(loc, newFunc,
1239                               mlir::ValueRange{args[0], args[1], mask});
1240   call->dropAllReferences();
1241   call->erase();
1242 }
1243 
1244 void SimplifyIntrinsicsPass::simplifyReductionBody(
1245     fir::CallOp call, const fir::KindMapping &kindMap,
1246     GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder,
1247     const mlir::StringRef &funcName, mlir::Type elementType) {
1248 
1249   mlir::Operation::operand_range args = call.getArgs();
1250 
1251   mlir::Type resultType = call.getResult(0).getType();
1252   unsigned rank = getDimCount(args[0]);
1253 
1254   mlir::Location loc = call.getLoc();
1255 
1256   auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
1257     return genNoneBoxType(builder, resultType);
1258   };
1259   auto bodyGenerator = [&rank, &genBodyFunc,
1260                         &elementType](fir::FirOpBuilder &builder,
1261                                       mlir::func::FuncOp &funcOp) {
1262     genBodyFunc(builder, funcOp, rank, elementType);
1263   };
1264   // Mangle the function name with the rank value as "x<rank>".
1265   mlir::func::FuncOp newFunc =
1266       getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1267   auto newCall =
1268       builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
1269   call->replaceAllUsesWith(newCall.getResults());
1270   call->dropAllReferences();
1271   call->erase();
1272 }
1273 
1274 void SimplifyIntrinsicsPass::runOnOperation() {
1275   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
1276   mlir::ModuleOp module = getOperation();
1277   fir::KindMapping kindMap = fir::getKindMapping(module);
1278   module.walk([&](mlir::Operation *op) {
1279     if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
1280       if (cuf::isInCUDADeviceContext(op))
1281         return;
1282       if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
1283         mlir::StringRef funcName = callee.getLeafReference().getValue();
1284         // Replace call to runtime function for SUM when it has single
1285         // argument (no dim or mask argument) for 1D arrays with either
1286         // Integer4 or Real8 types. Other forms are ignored.
1287         // The new function is added to the module.
1288         //
1289         // Prototype for runtime call (from sum.cpp):
1290         // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
1291         //                int dim, const Descriptor *mask)
1292         //
1293         if (funcName.starts_with(RTNAME_STRING(Sum))) {
1294           simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody);
1295           return;
1296         }
1297         if (funcName.starts_with(RTNAME_STRING(DotProduct))) {
1298           LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
1299           LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
1300                      llvm::dbgs() << "\n");
1301           mlir::Operation::operand_range args = call.getArgs();
1302           const mlir::Value &v1 = args[0];
1303           const mlir::Value &v2 = args[1];
1304           mlir::Location loc = call.getLoc();
1305           fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)};
1306           // Stringize the builder's FastMathFlags flags for mangling
1307           // the generated function name.
1308           std::string fmfString{builder.getFastMathFlagsString()};
1309 
1310           mlir::Type type = call.getResult(0).getType();
1311           if (!mlir::isa<mlir::FloatType>(type) &&
1312               !mlir::isa<mlir::IntegerType>(type))
1313             return;
1314 
1315           // Try to find the element types of the boxed arguments.
1316           auto arg1Type = getArgElementType(v1);
1317           auto arg2Type = getArgElementType(v2);
1318 
1319           if (!arg1Type || !arg2Type)
1320             return;
1321 
1322           // Support only floating point and integer arguments
1323           // now (e.g. logical is skipped here).
1324           if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg1Type))
1325             return;
1326           if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg2Type))
1327             return;
1328 
1329           auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
1330             return genRuntimeDotType(builder, type);
1331           };
1332           auto bodyGenerator = [&arg1Type,
1333                                 &arg2Type](fir::FirOpBuilder &builder,
1334                                            mlir::func::FuncOp &funcOp) {
1335             genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
1336           };
1337 
1338           // Suffix the function name with the element types
1339           // of the arguments.
1340           std::string typedFuncName(funcName);
1341           llvm::raw_string_ostream nameOS(typedFuncName);
1342           // We must mangle the generated function name with FastMathFlags
1343           // value.
1344           if (!fmfString.empty())
1345             nameOS << '_' << fmfString;
1346           nameOS << '_';
1347           arg1Type->print(nameOS);
1348           nameOS << '_';
1349           arg2Type->print(nameOS);
1350 
1351           mlir::func::FuncOp newFunc = getOrCreateFunction(
1352               builder, typedFuncName, typeGenerator, bodyGenerator);
1353           auto newCall = builder.create<fir::CallOp>(loc, newFunc,
1354                                                      mlir::ValueRange{v1, v2});
1355           call->replaceAllUsesWith(newCall.getResults());
1356           call->dropAllReferences();
1357           call->erase();
1358 
1359           LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
1360                      llvm::dbgs() << "\n");
1361           return;
1362         }
1363         if (funcName.starts_with(RTNAME_STRING(Maxval))) {
1364           simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody);
1365           return;
1366         }
1367         if (funcName.starts_with(RTNAME_STRING(Count))) {
1368           simplifyLogicalDim0Reduction(call, kindMap, genRuntimeCountBody);
1369           return;
1370         }
1371         if (funcName.starts_with(RTNAME_STRING(Any))) {
1372           simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAnyBody);
1373           return;
1374         }
1375         if (funcName.ends_with(RTNAME_STRING(All))) {
1376           simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAllBody);
1377           return;
1378         }
1379         if (funcName.starts_with(RTNAME_STRING(Minloc))) {
1380           simplifyMinMaxlocReduction(call, kindMap, false);
1381           return;
1382         }
1383         if (funcName.starts_with(RTNAME_STRING(Maxloc))) {
1384           simplifyMinMaxlocReduction(call, kindMap, true);
1385           return;
1386         }
1387       }
1388     }
1389   });
1390   LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
1391 }
1392 
1393 void SimplifyIntrinsicsPass::getDependentDialects(
1394     mlir::DialectRegistry &registry) const {
1395   // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
1396   registry.insert<mlir::LLVM::LLVMDialect>();
1397 }
1398