xref: /llvm-project/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp (revision 9bb47f7f8bcc17d90763d201f383d28489b9b071)
1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 //===----------------------------------------------------------------------===//
10 /// \file
/// This pass looks for suitable calls to runtime library functions for
/// intrinsics that can be simplified/specialized and replaces them with calls
/// to specialized functions.
///
/// For example, SUM(arr) can be specialized as a simple function with one loop,
/// compared to the three arguments (plus file & line info) that the runtime
/// call has - when the argument is a 1D-array (multiple loops may be needed
/// for higher dimension arrays, of course)
18 ///
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
24 
25 #include "flang/Common/Fortran.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/FIRBuilder.h"
28 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
29 #include "flang/Optimizer/Builder/Todo.h"
30 #include "flang/Optimizer/Dialect/FIROps.h"
31 #include "flang/Optimizer/Dialect/FIRType.h"
32 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
33 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
34 #include "flang/Optimizer/Transforms/Passes.h"
35 #include "flang/Runtime/entry-names.h"
36 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
37 #include "mlir/IR/Matchers.h"
38 #include "mlir/IR/Operation.h"
39 #include "mlir/Pass/Pass.h"
40 #include "mlir/Transforms/DialectConversion.h"
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
42 #include "mlir/Transforms/RegionUtils.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include <llvm/Support/ErrorHandling.h>
46 #include <mlir/Dialect/Arith/IR/Arith.h>
47 #include <mlir/IR/BuiltinTypes.h>
48 #include <mlir/IR/Location.h>
49 #include <mlir/IR/MLIRContext.h>
50 #include <mlir/IR/Value.h>
51 #include <mlir/Support/LLVM.h>
52 #include <optional>
53 
54 namespace fir {
55 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS
56 #include "flang/Optimizer/Transforms/Passes.h.inc"
57 } // namespace fir
58 
59 #define DEBUG_TYPE "flang-simplify-intrinsics"
60 
namespace {

/// Pass that replaces suitable calls to Fortran runtime implementations of
/// intrinsics (reductions such as SUM) with calls to simplified, specialized
/// functions that it generates into the same module.
class SimplifyIntrinsicsPass
    : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
  /// Callback producing the signature of a generated function.
  using FunctionTypeGeneratorTy =
      llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
  /// Callback filling in the body of a generated (initially empty) function.
  using FunctionBodyGeneratorTy =
      llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
  /// Callback generating a reduction body for an input array of the given
  /// \p rank and \p elementType.
  using GenReductionBodyTy = llvm::function_ref<void(
      fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank,
      mlir::Type elementType)>;

public:
  /// Generate a new function implementing a simplified version
  /// of a Fortran runtime function defined by \p basename name.
  /// \p typeGenerator is a callback that generates the new function's type.
  /// \p bodyGenerator is a callback that generates the new function's body.
  /// The new function is created in the \p builder's Module.
  mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
                                         const mlir::StringRef &basename,
                                         FunctionTypeGeneratorTy typeGenerator,
                                         FunctionBodyGeneratorTy bodyGenerator);
  void runOnOperation() override;
  void getDependentDialects(mlir::DialectRegistry &registry) const override;

private:
  /// Helper functions to replace a reduction type of call with its
  /// simplified form. The actual function is generated using a callback
  /// function.
  /// \p call is the call to be replaced
  /// \p kindMap is used to create FIROpBuilder
  /// \p genBodyFunc is the callback that builds the replacement function
  void simplifyIntOrFloatReduction(fir::CallOp call,
                                   const fir::KindMapping &kindMap,
                                   GenReductionBodyTy genBodyFunc);
  // NOTE(review): the Dim0/Dim1 variants presumably handle logical
  // reductions without / with a DIM argument; confirm in the definitions.
  void simplifyLogicalDim0Reduction(fir::CallOp call,
                                    const fir::KindMapping &kindMap,
                                    GenReductionBodyTy genBodyFunc);
  void simplifyLogicalDim1Reduction(fir::CallOp call,
                                    const fir::KindMapping &kindMap,
                                    GenReductionBodyTy genBodyFunc);
  /// Replace a MINLOC/MAXLOC runtime call; \p isMax presumably selects
  /// MAXLOC (true) vs MINLOC (false) — confirm in the definition.
  void simplifyMinMaxlocReduction(fir::CallOp call,
                                  const fir::KindMapping &kindMap, bool isMax);
  /// Shared tail of the reduction simplifications: builds the function named
  /// after \p basename with \p genBodyFunc for \p elementType and rewrites
  /// \p call (details are in the out-of-view definition).
  void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap,
                             GenReductionBodyTy genBodyFunc,
                             fir::FirOpBuilder &builder,
                             const mlir::StringRef &basename,
                             mlir::Type elementType);
};

} // namespace
112 
113 /// Create FirOpBuilder with the provided \p op insertion point
114 /// and \p kindMap additionally inheriting FastMathFlags from \p op.
115 static fir::FirOpBuilder
116 getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) {
117   fir::FirOpBuilder builder{op, kindMap};
118   auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
119   if (!fmi)
120     return builder;
121 
122   // Regardless of what default FastMathFlags are used by FirOpBuilder,
123   // override them with FastMathFlags attached to the operation.
124   builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
125   return builder;
126 }
127 
128 /// Generate function type for the simplified version of RTNAME(Sum) and
129 /// similar functions with a fir.box<none> type returning \p elementType.
130 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
131                                          const mlir::Type &elementType) {
132   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
133   return mlir::FunctionType::get(builder.getContext(), {boxType},
134                                  {elementType});
135 }
136 
137 template <typename Op>
138 Op expectOp(mlir::Value val) {
139   if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp()))
140     return op;
141   LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName()
142                           << '\n');
143   return nullptr;
144 }
145 
146 template <typename Op>
147 static mlir::Value findDefSingle(fir::ConvertOp op) {
148   if (auto defOp = expectOp<Op>(op->getOperand(0))) {
149     return defOp.getResult();
150   }
151   return {};
152 }
153 
154 template <typename... Ops>
155 static mlir::Value findDef(fir::ConvertOp op) {
156   mlir::Value defOp;
157   // Loop over the operation types given to see if any match, exiting once
158   // a match is found. Cast to void is needed to avoid compiler complaining
159   // that the result of expression is unused
160   (void)((defOp = findDefSingle<Ops>(op), (defOp)) || ...);
161   return defOp;
162 }
163 
164 static bool isOperandAbsent(mlir::Value val) {
165   if (auto op = expectOp<fir::ConvertOp>(val)) {
166     assert(op->getOperands().size() != 0);
167     return mlir::isa_and_nonnull<fir::AbsentOp>(
168         op->getOperand(0).getDefiningOp());
169   }
170   return false;
171 }
172 
173 static bool isTrueOrNotConstant(mlir::Value val) {
174   if (auto op = expectOp<mlir::arith::ConstantOp>(val)) {
175     return !mlir::matchPattern(val, mlir::m_Zero());
176   }
177   return true;
178 }
179 
180 static bool isZero(mlir::Value val) {
181   if (auto op = expectOp<fir::ConvertOp>(val)) {
182     assert(op->getOperands().size() != 0);
183     if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
184       return mlir::matchPattern(defOp, mlir::m_Zero());
185   }
186   return false;
187 }
188 
189 static mlir::Value findBoxDef(mlir::Value val) {
190   if (auto op = expectOp<fir::ConvertOp>(val)) {
191     assert(op->getOperands().size() != 0);
192     return findDef<fir::EmboxOp, fir::ReboxOp>(op);
193   }
194   return {};
195 }
196 
197 static mlir::Value findMaskDef(mlir::Value val) {
198   if (auto op = expectOp<fir::ConvertOp>(val)) {
199     assert(op->getOperands().size() != 0);
200     return findDef<fir::EmboxOp, fir::ReboxOp, fir::AbsentOp>(op);
201   }
202   return {};
203 }
204 
205 static unsigned getDimCount(mlir::Value val) {
206   // In order to find the dimensions count, we look for EmboxOp/ReboxOp
207   // and take the count from its *result* type. Note that in case
208   // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
209   // have different types.
210   // Actually, we can take the box type from the operand of
211   // the first ConvertOp that has non-opaque box type that we meet
212   // going through the ConvertOp chain.
213   if (mlir::Value emboxVal = findBoxDef(val))
214     if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
215       if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
216         return seqTy.getDimension();
217   return 0;
218 }
219 
220 /// Given the call operation's box argument \p val, discover
221 /// the element type of the underlying array object.
222 /// \returns the element type or std::nullopt if the type cannot
223 /// be reliably found.
224 /// We expect that the argument is a result of fir.convert
225 /// with the destination type of !fir.box<none>.
226 static std::optional<mlir::Type> getArgElementType(mlir::Value val) {
227   mlir::Operation *defOp;
228   do {
229     defOp = val.getDefiningOp();
230     // Analyze only sequences of convert operations.
231     if (!mlir::isa<fir::ConvertOp>(defOp))
232       return std::nullopt;
233     val = defOp->getOperand(0);
234     // The convert operation is expected to convert from one
235     // box type to another box type.
236     auto boxType = val.getType().cast<fir::BoxType>();
237     auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
238     if (!elementType.isa<mlir::NoneType>())
239       return elementType;
240   } while (true);
241 }
242 
243 using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
244     fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
245     mlir::Value)>;
246 using InitValGeneratorTy = llvm::function_ref<mlir::Value(
247     fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
248 using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>(
249     fir::FirOpBuilder &, mlir::Location, mlir::Value)>;
250 
251 /// Generate the reduction loop into \p funcOp.
252 ///
253 /// \p initVal is a function, called to get the initial value for
254 ///    the reduction value
255 /// \p genBody is called to fill in the actual reduciton operation
256 ///    for example add for SUM, MAX for MAXVAL, etc.
257 /// \p rank is the rank of the input argument.
258 /// \p elementType is the type of the elements in the input array,
259 ///    which may be different to the return type.
260 /// \p loopCond is called to generate the condition to continue or
261 ///    not for IterWhile loops
262 /// \p unorderedOrInitalLoopCond contains either a boolean or bool
263 ///    mlir constant, and controls the inital value for while loops
264 ///    or if DoLoop is ordered/unordered.
265 
266 template <typename OP, typename T, int resultIndex>
267 static void
268 genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
269                  InitValGeneratorTy initVal, ContinueLoopGenTy loopCond,
270                  T unorderedOrInitialLoopCond, BodyOpGeneratorTy genBody,
271                  unsigned rank, mlir::Type elementType, mlir::Location loc) {
272 
273   mlir::IndexType idxTy = builder.getIndexType();
274 
275   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
276   mlir::Value arg = args[0];
277 
278   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
279 
280   fir::SequenceType::Shape flatShape(rank,
281                                      fir::SequenceType::getUnknownExtent());
282   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
283   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
284   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
285   mlir::Type resultType = funcOp.getResultTypes()[0];
286   mlir::Value init = initVal(builder, loc, resultType);
287 
288   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
289 
290   assert(rank > 0 && "rank cannot be zero");
291   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
292 
293   // Compute all the upper bounds before the loop nest.
294   // It is not strictly necessary for performance, since the loop nest
295   // does not have any store operations and any LICM optimization
296   // should be able to optimize the redundancy.
297   for (unsigned i = 0; i < rank; ++i) {
298     mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
299     auto dims =
300         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
301     mlir::Value len = dims.getResult(1);
302     // We use C indexing here, so len-1 as loopcount
303     mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
304     bounds.push_back(loopCount);
305   }
306   // Create a loop nest consisting of OP operations.
307   // Collect the loops' induction variables into indices array,
308   // which will be used in the innermost loop to load the input
309   // array's element.
310   // The loops are generated such that the innermost loop processes
311   // the 0 dimension.
312   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
313   for (unsigned i = rank; 0 < i; --i) {
314     mlir::Value step = one;
315     mlir::Value loopCount = bounds[i - 1];
316     auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step,
317                                    unorderedOrInitialLoopCond,
318                                    /*finalCountValue=*/false, init);
319     init = loop.getRegionIterArgs()[resultIndex];
320     indices.push_back(loop.getInductionVar());
321     // Set insertion point to the loop body so that the next loop
322     // is inserted inside the current one.
323     builder.setInsertionPointToStart(loop.getBody());
324   }
325 
326   // Reverse the indices such that they are ordered as:
327   //   <dim-0-idx, dim-1-idx, ...>
328   std::reverse(indices.begin(), indices.end());
329   // We are in the innermost loop: generate the reduction body.
330   mlir::Type eleRefTy = builder.getRefType(elementType);
331   mlir::Value addr =
332       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
333   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
334   mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
335   // Generate vector with condition to continue while loop at [0] and result
336   // from current loop at [1] for IterWhileOp loops, just result at [0] for
337   // DoLoopOp loops.
338   llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal);
339 
340   // Unwind the loop nest and insert ResultOp on each level
341   // to return the updated value of the reduction to the enclosing
342   // loops.
343   for (unsigned i = 0; i < rank; ++i) {
344     auto result = builder.create<fir::ResultOp>(loc, results);
345     // Proceed to the outer loop.
346     auto loop = mlir::cast<OP>(result->getParentOp());
347     results = loop.getResults();
348     // Set insertion point after the loop operation that we have
349     // just processed.
350     builder.setInsertionPointAfter(loop.getOperation());
351   }
352   // End of loop nest. The insertion point is after the outermost loop.
353   // Return the reduction value from the function.
354   builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
355 }
356 using MinMaxlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
357     fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
358     mlir::Value, llvm::SmallVector<mlir::Value, Fortran::common::maxRank> &)>;
359 
360 static void genMinMaxlocReductionLoop(
361     fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
362     InitValGeneratorTy initVal, MinMaxlocBodyOpGeneratorTy genBody,
363     unsigned rank, mlir::Type elementType, mlir::Location loc, bool hasMask,
364     mlir::Type maskElemType, mlir::Value resultArr) {
365 
366   mlir::IndexType idxTy = builder.getIndexType();
367 
368   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
369   mlir::Value arg = args[1];
370 
371   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
372 
373   fir::SequenceType::Shape flatShape(rank,
374                                      fir::SequenceType::getUnknownExtent());
375   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
376   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
377   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
378 
379   mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
380   mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
381   mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
382   mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
383   builder.create<fir::StoreOp>(loc, zero, flagRef);
384 
385   mlir::Value mask;
386   if (hasMask) {
387     mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType);
388     mlir::Type boxMaskTy = fir::BoxType::get(maskTy);
389     mask = builder.create<fir::ConvertOp>(loc, boxMaskTy, args[2]);
390   }
391 
392   mlir::Value init = initVal(builder, loc, elementType);
393   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
394 
395   assert(rank > 0 && "rank cannot be zero");
396   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
397 
398   // Compute all the upper bounds before the loop nest.
399   // It is not strictly necessary for performance, since the loop nest
400   // does not have any store operations and any LICM optimization
401   // should be able to optimize the redundancy.
402   for (unsigned i = 0; i < rank; ++i) {
403     mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
404     auto dims =
405         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
406     mlir::Value len = dims.getResult(1);
407     // We use C indexing here, so len-1 as loopcount
408     mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
409     bounds.push_back(loopCount);
410   }
411   // Create a loop nest consisting of OP operations.
412   // Collect the loops' induction variables into indices array,
413   // which will be used in the innermost loop to load the input
414   // array's element.
415   // The loops are generated such that the innermost loop processes
416   // the 0 dimension.
417   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
418   for (unsigned i = rank; 0 < i; --i) {
419     mlir::Value step = one;
420     mlir::Value loopCount = bounds[i - 1];
421     auto loop =
422         builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
423                                       /*finalCountValue=*/false, init);
424     init = loop.getRegionIterArgs()[0];
425     indices.push_back(loop.getInductionVar());
426     // Set insertion point to the loop body so that the next loop
427     // is inserted inside the current one.
428     builder.setInsertionPointToStart(loop.getBody());
429   }
430 
431   // Reverse the indices such that they are ordered as:
432   //   <dim-0-idx, dim-1-idx, ...>
433   std::reverse(indices.begin(), indices.end());
434   // We are in the innermost loop: generate the reduction body.
435   if (hasMask) {
436     mlir::Type logicalRef = builder.getRefType(maskElemType);
437     mlir::Value maskAddr =
438         builder.create<fir::CoordinateOp>(loc, logicalRef, mask, indices);
439     mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr);
440 
441     // fir::IfOp requires argument to be I1 - won't accept logical or any other
442     // Integer.
443     mlir::Type ifCompatType = builder.getI1Type();
444     mlir::Value ifCompatElem =
445         builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
446 
447     llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
448     fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
449                                                /*withElseRegion=*/true);
450     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
451   }
452 
453   // Set flag that mask was true at some point
454   builder.create<fir::StoreOp>(loc, flagSet, flagRef);
455   mlir::Type eleRefTy = builder.getRefType(elementType);
456   mlir::Value addr =
457       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
458   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
459 
460   mlir::Value reductionVal =
461       genBody(builder, loc, elementType, elem, init, indices);
462 
463   if (hasMask) {
464     fir::IfOp ifOp =
465         mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp());
466     builder.create<fir::ResultOp>(loc, reductionVal);
467     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
468     builder.create<fir::ResultOp>(loc, init);
469     reductionVal = ifOp.getResult(0);
470     builder.setInsertionPointAfter(ifOp);
471   }
472 
473   // Unwind the loop nest and insert ResultOp on each level
474   // to return the updated value of the reduction to the enclosing
475   // loops.
476   for (unsigned i = 0; i < rank; ++i) {
477     auto result = builder.create<fir::ResultOp>(loc, reductionVal);
478     // Proceed to the outer loop.
479     auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
480     reductionVal = loop.getResult(0);
481     // Set insertion point after the loop operation that we have
482     // just processed.
483     builder.setInsertionPointAfter(loop.getOperation());
484   }
485   // End of loop nest. The insertion point is after the outermost loop.
486   if (fir::IfOp ifOp =
487           mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
488     builder.create<fir::ResultOp>(loc, reductionVal);
489     builder.setInsertionPointAfter(ifOp);
490     // Redefine flagSet to escape scope of ifOp
491     flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
492     reductionVal = ifOp.getResult(0);
493   }
494 
495   // Check for case where array was full of max values.
496   // flag will be 0 if mask was never true, 1 if mask was true as some point,
497   // this is needed to avoid catching cases where we didn't access any elements
498   // e.g. mask=.FALSE.
499   mlir::Value flagValue =
500       builder.create<fir::LoadOp>(loc, resultElemType, flagRef);
501   mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>(
502       loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet);
503   fir::IfOp ifMaskTrueOp =
504       builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false);
505   builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front());
506 
507   mlir::Value testInit = initVal(builder, loc, elementType);
508   fir::IfOp ifMinSetOp;
509   if (elementType.isa<mlir::FloatType>()) {
510     mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
511         loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal);
512     ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
513                                            /*withElseRegion*/ false);
514   } else {
515     mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
516         loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal);
517     ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
518                                            /*withElseRegion*/ false);
519   }
520   builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front());
521 
522   // Load output array with 1s instead of 0s
523   for (unsigned int i = 0; i < rank; ++i) {
524     mlir::Type resultRefTy = builder.getRefType(resultElemType);
525     // mlir::Value one = builder.createIntegerConstant(loc, resultElemType, 1);
526     mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
527     mlir::Value resultElemAddr =
528         builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index);
529     builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr);
530   }
531   builder.setInsertionPointAfter(ifMaskTrueOp);
532   // Store newly created output array to the reference passed in
533   fir::SequenceType::Shape resultShape(1, rank);
534   mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemType);
535   mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
536   mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
537   mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
538 
539   mlir::Value outputArrNone = args[0];
540   mlir::Value outputArr =
541       builder.create<fir::ConvertOp>(loc, outputRefTy, outputArrNone);
542 
543   // Store nearly created array to output array
544   builder.create<fir::StoreOp>(loc, resultArr, outputArr);
545   builder.create<mlir::func::ReturnOp>(loc);
546 }
547 
548 static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder,
549                                                   mlir::Location loc,
550                                                   mlir::Value reductionVal) {
551   return {reductionVal};
552 }
553 
554 /// Generate function body of the simplified version of RTNAME(Sum)
555 /// with signature provided by \p funcOp. The caller is responsible
556 /// for saving/restoring the original insertion point of \p builder.
557 /// \p funcOp is expected to be empty on entry to this function.
558 /// \p rank specifies the rank of the input argument.
559 static void genRuntimeSumBody(fir::FirOpBuilder &builder,
560                               mlir::func::FuncOp &funcOp, unsigned rank,
561                               mlir::Type elementType) {
562   // function RTNAME(Sum)<T>x<rank>_simplified(arr)
563   //   T, dimension(:) :: arr
564   //   T sum = 0
565   //   integer iter
566   //   do iter = 0, extent(arr)
567   //     sum = sum + arr[iter]
568   //   end do
569   //   RTNAME(Sum)<T>x<rank>_simplified = sum
570   // end function RTNAME(Sum)<T>x<rank>_simplified
571   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
572                  mlir::Type elementType) {
573     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
574       const llvm::fltSemantics &sem = ty.getFloatSemantics();
575       return builder.createRealConstant(loc, elementType,
576                                         llvm::APFloat::getZero(sem));
577     }
578     return builder.createIntegerConstant(loc, elementType, 0);
579   };
580 
581   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
582                       mlir::Type elementType, mlir::Value elem1,
583                       mlir::Value elem2) -> mlir::Value {
584     if (elementType.isa<mlir::FloatType>())
585       return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
586     if (elementType.isa<mlir::IntegerType>())
587       return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
588 
589     llvm_unreachable("unsupported type");
590     return {};
591   };
592 
593   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
594   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
595 
596   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
597                                            false, genBodyOp, rank, elementType,
598                                            loc);
599 }
600 
601 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
602                                  mlir::func::FuncOp &funcOp, unsigned rank,
603                                  mlir::Type elementType) {
604   auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
605                  mlir::Type elementType) {
606     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
607       const llvm::fltSemantics &sem = ty.getFloatSemantics();
608       return builder.createRealConstant(
609           loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
610     }
611     unsigned bits = elementType.getIntOrFloatBitWidth();
612     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
613     return builder.createIntegerConstant(loc, elementType, minInt);
614   };
615 
616   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
617                       mlir::Type elementType, mlir::Value elem1,
618                       mlir::Value elem2) -> mlir::Value {
619     if (elementType.isa<mlir::FloatType>()) {
620       // arith.maxf later converted to llvm.intr.maxnum does not work
621       // correctly for NaNs and -0.0 (see maxnum/minnum pattern matching
622       // in LLVM's InstCombine pass). Moreover, llvm.intr.maxnum
623       // for F128 operands is lowered into fmaxl call by LLVM.
624       // This libm function may not work properly for F128 arguments
625       // on targets where long double is not F128. It is an LLVM issue,
626       // but we just use normal select here to resolve all the cases.
627       auto compare = builder.create<mlir::arith::CmpFOp>(
628           loc, mlir::arith::CmpFPredicate::OGT, elem1, elem2);
629       return builder.create<mlir::arith::SelectOp>(loc, compare, elem1, elem2);
630     }
631     if (elementType.isa<mlir::IntegerType>())
632       return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
633 
634     llvm_unreachable("unsupported type");
635     return {};
636   };
637 
638   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
639   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
640 
641   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, init, nopLoopCond,
642                                            false, genBodyOp, rank, elementType,
643                                            loc);
644 }
645 
646 static void genRuntimeCountBody(fir::FirOpBuilder &builder,
647                                 mlir::func::FuncOp &funcOp, unsigned rank,
648                                 mlir::Type elementType) {
649   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
650                  mlir::Type elementType) {
651     unsigned bits = elementType.getIntOrFloatBitWidth();
652     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
653     return builder.createIntegerConstant(loc, elementType, zeroInt);
654   };
655 
656   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
657                       mlir::Type elementType, mlir::Value elem1,
658                       mlir::Value elem2) -> mlir::Value {
659     auto zero32 = builder.createIntegerConstant(loc, elementType, 0);
660     auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0);
661     auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1);
662 
663     auto compare = builder.create<mlir::arith::CmpIOp>(
664         loc, mlir::arith::CmpIPredicate::eq, elem1, zero32);
665     auto select =
666         builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64);
667     return builder.create<mlir::arith::AddIOp>(loc, select, elem2);
668   };
669 
670   // Count always gets I32 for elementType as it converts logical input to
671   // logical<4> before passing to the function.
672   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
673   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
674 
675   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
676                                            false, genBodyOp, rank, elementType,
677                                            loc);
678 }
679 
680 static void genRuntimeAnyBody(fir::FirOpBuilder &builder,
681                               mlir::func::FuncOp &funcOp, unsigned rank,
682                               mlir::Type elementType) {
683   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
684                  mlir::Type elementType) {
685     return builder.createIntegerConstant(loc, elementType, 0);
686   };
687 
688   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
689                       mlir::Type elementType, mlir::Value elem1,
690                       mlir::Value elem2) -> mlir::Value {
691     auto zero = builder.createIntegerConstant(loc, elementType, 0);
692     return builder.create<mlir::arith::CmpIOp>(
693         loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
694   };
695 
696   auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
697                          mlir::Value reductionVal) {
698     auto one1 = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
699     auto eor = builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1);
700     llvm::SmallVector<mlir::Value> results = {eor, reductionVal};
701     return results;
702   };
703 
704   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
705   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
706   mlir::Value ok = builder.createBool(loc, true);
707 
708   genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
709       builder, funcOp, zero, continueCond, ok, genBodyOp, rank, elementType,
710       loc);
711 }
712 
713 static void genRuntimeAllBody(fir::FirOpBuilder &builder,
714                               mlir::func::FuncOp &funcOp, unsigned rank,
715                               mlir::Type elementType) {
716   auto one = [](fir::FirOpBuilder builder, mlir::Location loc,
717                 mlir::Type elementType) {
718     return builder.createIntegerConstant(loc, elementType, 1);
719   };
720 
721   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
722                       mlir::Type elementType, mlir::Value elem1,
723                       mlir::Value elem2) -> mlir::Value {
724     auto zero = builder.createIntegerConstant(loc, elementType, 0);
725     return builder.create<mlir::arith::CmpIOp>(
726         loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
727   };
728 
729   auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
730                          mlir::Value reductionVal) {
731     llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal};
732     return results;
733   };
734 
735   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
736   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
737   mlir::Value ok = builder.createBool(loc, true);
738 
739   genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
740       builder, funcOp, one, continueCond, ok, genBodyOp, rank, elementType,
741       loc);
742 }
743 
744 static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
745                                                unsigned int rank) {
746   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
747   mlir::Type boxRefType = builder.getRefType(boxType);
748 
749   return mlir::FunctionType::get(builder.getContext(),
750                                  {boxRefType, boxType, boxType}, {});
751 }
752 
/// Generate the body of the simplified MINLOC/MAXLOC function into \p funcOp.
///
/// \p isMax selects MAXLOC (true) or MINLOC (false). \p rank is the rank of
/// the input array and \p elementType its element type. \p maskElemType is
/// the element type used when reading an array mask. \p resultElemTy is the
/// integer type of the stored locations.
/// \p maskRank is -1 when MASK is absent, 0 for a scalar MASK (checked once
/// before the loop), and the mask's rank otherwise.
///
/// The function argument at index 2 is the mask. A rank-sized result array
/// is heap-allocated here and zero-filled; it is then handed to
/// genMinMaxlocReductionLoop, which presumably publishes it through the
/// result descriptor argument (TODO confirm).
static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
                                    mlir::func::FuncOp &funcOp, bool isMax,
                                    unsigned rank, int maskRank,
                                    mlir::Type elementType,
                                    mlir::Type maskElemType,
                                    mlir::Type resultElemTy) {
  // Initial "best so far" value: the most negative (MAXLOC) or most positive
  // (MINLOC) finite value of elementType.
  auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
                      mlir::Type elementType) {
    if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
      const llvm::fltSemantics &sem = ty.getFloatSemantics();
      return builder.createRealConstant(
          loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/isMax));
    }
    unsigned bits = elementType.getIntOrFloatBitWidth();
    int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
                               : llvm::APInt::getSignedMaxValue(bits))
                            .getSExtValue();
    return builder.createIntegerConstant(loc, elementType, initValue);
  };

  mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
  builder.setInsertionPointToEnd(funcOp.addEntryBlock());

  mlir::Value mask = funcOp.front().getArgument(2);

  // Set up result array in case of early exit / 0 length array
  mlir::IndexType idxTy = builder.getIndexType();
  mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy);
  mlir::Type resultHeapTy = fir::HeapType::get(resultTy);
  mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy);

  mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0);
  mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank);

  mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy);
  mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize);
  mlir::Value resultArr = builder.create<fir::EmboxOp>(
      loc, resultBoxTy, resultArrInit, resultArrShape);

  mlir::Type resultRefTy = builder.getRefType(resultElemTy);

  // Zero-fill every location slot so that the result is well defined even
  // if no element is ever selected.
  for (unsigned int i = 0; i < rank; ++i) {
    mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
    mlir::Value resultElemAddr =
        builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index);
    builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
  }

  // Per-element step: if elem1 strictly beats the current best elem2, store
  // its (1-based) indices into resultArr and make it the new best. Because the
  // comparison is strict, on ties the earlier-visited element is kept.
  auto genBodyOp =
      [&rank, &resultArr,
       isMax](fir::FirOpBuilder builder, mlir::Location loc,
              mlir::Type elementType, mlir::Value elem1, mlir::Value elem2,
              llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices)
      -> mlir::Value {
    mlir::Value cmp;
    if (elementType.isa<mlir::FloatType>()) {
      cmp = builder.create<mlir::arith::CmpFOp>(
          loc,
          isMax ? mlir::arith::CmpFPredicate::OGT
                : mlir::arith::CmpFPredicate::OLT,
          elem1, elem2);
    } else if (elementType.isa<mlir::IntegerType>()) {
      cmp = builder.create<mlir::arith::CmpIOp>(
          loc,
          isMax ? mlir::arith::CmpIPredicate::sgt
                : mlir::arith::CmpIPredicate::slt,
          elem1, elem2);
    } else {
      llvm_unreachable("unsupported type");
    }

    fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
                                               /*withElseRegion*/ true);

    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
    mlir::Type returnRefTy = builder.getRefType(resultElemTy);
    mlir::IndexType idxTy = builder.getIndexType();

    mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1);

    // Loop indices are 0-based; Fortran locations are 1-based, hence "+ 1".
    for (unsigned int i = 0; i < rank; ++i) {
      mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
      mlir::Value resultElemAddr =
          builder.create<fir::CoordinateOp>(loc, returnRefTy, resultArr, index);
      mlir::Value convert =
          builder.create<fir::ConvertOp>(loc, resultElemTy, indices[i]);
      mlir::Value fortranIndex =
          builder.create<mlir::arith::AddIOp>(loc, convert, one);
      builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
    }
    builder.create<fir::ResultOp>(loc, elem1);
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    builder.create<fir::ResultOp>(loc, elem2);
    builder.setInsertionPointAfter(ifOp);
    return ifOp.getResult(0);
  };

  // if mask is a logical scalar, we can check its value before the main loop
  // and either ignore the fact it is there or exit early.
  if (maskRank == 0) {
    mlir::Type logical = builder.getI1Type();
    mlir::IndexType idxTy = builder.getIndexType();

    // View the scalar mask box as a one-element i1 array and load element 0.
    fir::SequenceType::Shape singleElement(1, 1);
    mlir::Type arrTy = fir::SequenceType::get(singleElement, logical);
    mlir::Type boxArrTy = fir::BoxType::get(arrTy);
    mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, mask);

    mlir::Value indx = builder.createIntegerConstant(loc, idxTy, 0);
    mlir::Type logicalRefTy = builder.getRefType(logical);
    mlir::Value condAddr =
        builder.create<fir::CoordinateOp>(loc, logicalRefTy, array, indx);
    mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr);

    fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cond,
                                               /*withElseRegion=*/true);

    // Mask is false: yield a placeholder value and do no work; the result
    // array keeps the zeros stored above.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    mlir::Value basicValue;
    if (elementType.isa<mlir::IntegerType>()) {
      basicValue = builder.createIntegerConstant(loc, elementType, 0);
    } else {
      basicValue = builder.createRealConstant(loc, elementType, 0);
    }
    builder.create<fir::ResultOp>(loc, basicValue);

    // Mask is true: the main loop is generated inside the then-region. That
    // region is left without a terminator here; presumably
    // genMinMaxlocReductionLoop closes it (TODO confirm).
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  }

  // bit of a hack - maskRank is set to -1 for absent mask arg, so don't
  // generate high level mask or element by element mask.
  bool hasMask = maskRank > 0;
  genMinMaxlocReductionLoop(builder, funcOp, init, genBodyOp, rank, elementType,
                            loc, hasMask, maskElemType, resultArr);
}
889 
890 /// Generate function type for the simplified version of RTNAME(DotProduct)
891 /// operating on the given \p elementType.
892 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
893                                             const mlir::Type &elementType) {
894   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
895   return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
896                                  {elementType});
897 }
898 
899 /// Generate function body of the simplified version of RTNAME(DotProduct)
900 /// with signature provided by \p funcOp. The caller is responsible
901 /// for saving/restoring the original insertion point of \p builder.
902 /// \p funcOp is expected to be empty on entry to this function.
903 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
904 /// of the underlying array objects - they are used to generate proper
905 /// element accesses.
906 static void genRuntimeDotBody(fir::FirOpBuilder &builder,
907                               mlir::func::FuncOp &funcOp,
908                               mlir::Type arg1ElementTy,
909                               mlir::Type arg2ElementTy) {
910   // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
911   //   T, dimension(:) :: arr1, arr2
912   //   T product = 0
913   //   integer iter
914   //   do iter = 0, extent(arr1)
915   //     product = product + arr1[iter] * arr2[iter]
916   //   end do
917   //   RTNAME(ADotProduct)<T>_simplified = product
918   // end function RTNAME(DotProduct)<T>_simplified
919   auto loc = mlir::UnknownLoc::get(builder.getContext());
920   mlir::Type resultElementType = funcOp.getResultTypes()[0];
921   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
922 
923   mlir::IndexType idxTy = builder.getIndexType();
924 
925   mlir::Value zero =
926       resultElementType.isa<mlir::FloatType>()
927           ? builder.createRealConstant(loc, resultElementType, 0.0)
928           : builder.createIntegerConstant(loc, resultElementType, 0);
929 
930   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
931   mlir::Value arg1 = args[0];
932   mlir::Value arg2 = args[1];
933 
934   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
935 
936   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
937   mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
938   mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
939   mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
940   mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
941   mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
942   mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
943   // This version takes the loop trip count from the first argument.
944   // If the first argument's box has unknown (at compilation time)
945   // extent, then it may be better to take the extent from the second
946   // argument - so that after inlining the loop may be better optimized, e.g.
947   // fully unrolled. This requires generating two versions of the simplified
948   // function and some analysis at the call site to choose which version
949   // is more profitable to call.
950   // Note that we can assume that both arguments have the same extent.
951   auto dims =
952       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
953   mlir::Value len = dims.getResult(1);
954   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
955   mlir::Value step = one;
956 
957   // We use C indexing here, so len-1 as loopcount
958   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
959   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
960                                             /*unordered=*/false,
961                                             /*finalCountValue=*/false, zero);
962   mlir::Value sumVal = loop.getRegionIterArgs()[0];
963 
964   // Begin loop code
965   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
966   builder.setInsertionPointToStart(loop.getBody());
967 
968   mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
969   mlir::Value index = loop.getInductionVar();
970   mlir::Value addr1 =
971       builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
972   mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
973   // Convert to the result type.
974   elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
975 
976   mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
977   mlir::Value addr2 =
978       builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
979   mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
980   // Convert to the result type.
981   elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
982 
983   if (resultElementType.isa<mlir::FloatType>())
984     sumVal = builder.create<mlir::arith::AddFOp>(
985         loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
986   else if (resultElementType.isa<mlir::IntegerType>())
987     sumVal = builder.create<mlir::arith::AddIOp>(
988         loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
989   else
990     llvm_unreachable("unsupported type");
991 
992   builder.create<fir::ResultOp>(loc, sumVal);
993   // End of loop.
994   builder.restoreInsertionPoint(loopEndPt);
995 
996   mlir::Value resultVal = loop.getResult(0);
997   builder.create<mlir::func::ReturnOp>(loc, resultVal);
998 }
999 
1000 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
1001     fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
1002     FunctionTypeGeneratorTy typeGenerator,
1003     FunctionBodyGeneratorTy bodyGenerator) {
1004   // WARNING: if the function generated here changes its signature
1005   //          or behavior (the body code), we should probably embed some
1006   //          versioning information into its name, otherwise libraries
1007   //          statically linked with older versions of Flang may stop
1008   //          working with object files created with newer Flang.
1009   //          We can also avoid this by using internal linkage, but
1010   //          this may increase the size of final executable/shared library.
1011   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
1012   mlir::ModuleOp module = builder.getModule();
1013   // If we already have a function, just return it.
1014   mlir::func::FuncOp newFunc =
1015       fir::FirOpBuilder::getNamedFunction(module, replacementName);
1016   mlir::FunctionType fType = typeGenerator(builder);
1017   if (newFunc) {
1018     assert(newFunc.getFunctionType() == fType &&
1019            "type mismatch for simplified function");
1020     return newFunc;
1021   }
1022 
1023   // Need to build the function!
1024   auto loc = mlir::UnknownLoc::get(builder.getContext());
1025   newFunc =
1026       fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
1027   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
1028   auto linkage =
1029       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
1030   newFunc->setAttr("llvm.linkage", linkage);
1031 
1032   // Save the position of the original call.
1033   mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
1034 
1035   bodyGenerator(builder, newFunc);
1036 
1037   // Now back to where we were adding code earlier...
1038   builder.restoreInsertionPoint(insertPt);
1039 
1040   return newFunc;
1041 }
1042 
1043 void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction(
1044     fir::CallOp call, const fir::KindMapping &kindMap,
1045     GenReductionBodyTy genBodyFunc) {
1046   // args[1] and args[2] are source filename and line number, ignored.
1047   mlir::Operation::operand_range args = call.getArgs();
1048 
1049   const mlir::Value &dim = args[3];
1050   const mlir::Value &mask = args[4];
1051   // dim is zero when it is absent, which is an implementation
1052   // detail in the runtime library.
1053 
1054   bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
1055   unsigned rank = getDimCount(args[0]);
1056 
1057   // Rank is set to 0 for assumed shape arrays, don't simplify
1058   // in these cases
1059   if (!(dimAndMaskAbsent && rank > 0))
1060     return;
1061 
1062   mlir::Type resultType = call.getResult(0).getType();
1063 
1064   if (!resultType.isa<mlir::FloatType>() &&
1065       !resultType.isa<mlir::IntegerType>())
1066     return;
1067 
1068   auto argType = getArgElementType(args[0]);
1069   if (!argType)
1070     return;
1071   assert(*argType == resultType &&
1072          "Argument/result types mismatch in reduction");
1073 
1074   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1075 
1076   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1077   std::string fmfString{builder.getFastMathFlagsString()};
1078   std::string funcName =
1079       (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1080        mlir::Twine{rank} +
1081        // We must mangle the generated function name with FastMathFlags
1082        // value.
1083        (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
1084           .str();
1085 
1086   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1087                         resultType);
1088 }
1089 
1090 void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction(
1091     fir::CallOp call, const fir::KindMapping &kindMap,
1092     GenReductionBodyTy genBodyFunc) {
1093 
1094   mlir::Operation::operand_range args = call.getArgs();
1095   const mlir::Value &dim = args[3];
1096   unsigned rank = getDimCount(args[0]);
1097 
1098   // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1099   // these cases.
1100   if (!(isZero(dim) && rank > 0))
1101     return;
1102 
1103   mlir::Value inputBox = findBoxDef(args[0]);
1104 
1105   mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1106   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1107 
1108   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1109 
1110   // Treating logicals as integers makes things a lot easier
1111   fir::LogicalType logicalType = {elementType.dyn_cast<fir::LogicalType>()};
1112   fir::KindTy kind = logicalType.getFKind();
1113   mlir::Type intElementType = builder.getIntegerType(kind * 8);
1114 
1115   // Mangle kind into function name as it is not done by default
1116   std::string funcName =
1117       (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1118        mlir::Twine{kind} + "x" + mlir::Twine{rank})
1119           .str();
1120 
1121   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1122                         intElementType);
1123 }
1124 
1125 void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
1126     fir::CallOp call, const fir::KindMapping &kindMap,
1127     GenReductionBodyTy genBodyFunc) {
1128 
1129   mlir::Operation::operand_range args = call.getArgs();
1130   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1131   mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1132   unsigned rank = getDimCount(args[0]);
1133 
1134   // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1135   // these cases. We check for Dim at the end as some logical functions (Any,
1136   // All) set dim to 1 instead of 0 when the argument is not present.
1137   if (funcNameBase.ends_with("Dim") || !(rank > 0))
1138     return;
1139 
1140   mlir::Value inputBox = findBoxDef(args[0]);
1141   mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1142 
1143   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1144 
1145   // Treating logicals as integers makes things a lot easier
1146   fir::LogicalType logicalType = {elementType.dyn_cast<fir::LogicalType>()};
1147   fir::KindTy kind = logicalType.getFKind();
1148   mlir::Type intElementType = builder.getIntegerType(kind * 8);
1149 
1150   // Mangle kind into function name as it is not done by default
1151   std::string funcName =
1152       (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1153        mlir::Twine{kind} + "x" + mlir::Twine{rank})
1154           .str();
1155 
1156   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1157                         intElementType);
1158 }
1159 
1160 void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
1161     fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) {
1162 
1163   mlir::Operation::operand_range args = call.getArgs();
1164 
1165   mlir::Value back = args[6];
1166   if (isTrueOrNotConstant(back))
1167     return;
1168 
1169   mlir::Value mask = args[5];
1170   mlir::Value maskDef = findMaskDef(mask);
1171 
1172   // maskDef is set to NULL when the defining op is not one we accept.
1173   // This tends to be because it is a selectOp, in which case let the
1174   // runtime deal with it.
1175   if (maskDef == NULL)
1176     return;
1177 
1178   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1179   mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1180   unsigned rank = getDimCount(args[1]);
1181   if (funcNameBase.ends_with("Dim") || !(rank > 0))
1182     return;
1183 
1184   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1185   mlir::Location loc = call.getLoc();
1186   auto inputBox = findBoxDef(args[1]);
1187   mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType());
1188 
1189   if (inputType.isa<fir::CharacterType>())
1190     return;
1191 
1192   int maskRank;
1193   fir::KindTy kind = 0;
1194   mlir::Type logicalElemType = builder.getI1Type();
1195   if (isOperandAbsent(mask)) {
1196     maskRank = -1;
1197   } else {
1198     maskRank = getDimCount(mask);
1199     mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType());
1200     fir::LogicalType logicalFirType = {maskElemTy.dyn_cast<fir::LogicalType>()};
1201     kind = logicalFirType.getFKind();
1202     // Convert fir::LogicalType to mlir::Type
1203     logicalElemType = logicalFirType;
1204   }
1205 
1206   mlir::Operation *outputDef = args[0].getDefiningOp();
1207   mlir::Value outputAlloc = outputDef->getOperand(0);
1208   mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType());
1209 
1210   std::string fmfString{builder.getFastMathFlagsString()};
1211   std::string funcName =
1212       (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1213        mlir::Twine{rank} +
1214        (maskRank >= 0
1215             ? "_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank}
1216             : "") +
1217        "_")
1218           .str();
1219 
1220   llvm::raw_string_ostream nameOS(funcName);
1221   outType.print(nameOS);
1222   nameOS << '_' << fmfString;
1223 
1224   auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
1225     return genRuntimeMinlocType(builder, rank);
1226   };
1227   auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
1228                         isMax](fir::FirOpBuilder &builder,
1229                                mlir::func::FuncOp &funcOp) {
1230     genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType,
1231                             logicalElemType, outType);
1232   };
1233 
1234   mlir::func::FuncOp newFunc =
1235       getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1236   builder.create<fir::CallOp>(loc, newFunc,
1237                               mlir::ValueRange{args[0], args[1], args[5]});
1238   call->dropAllReferences();
1239   call->erase();
1240 }
1241 
1242 void SimplifyIntrinsicsPass::simplifyReductionBody(
1243     fir::CallOp call, const fir::KindMapping &kindMap,
1244     GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder,
1245     const mlir::StringRef &funcName, mlir::Type elementType) {
1246 
1247   mlir::Operation::operand_range args = call.getArgs();
1248 
1249   mlir::Type resultType = call.getResult(0).getType();
1250   unsigned rank = getDimCount(args[0]);
1251 
1252   mlir::Location loc = call.getLoc();
1253 
1254   auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
1255     return genNoneBoxType(builder, resultType);
1256   };
1257   auto bodyGenerator = [&rank, &genBodyFunc,
1258                         &elementType](fir::FirOpBuilder &builder,
1259                                       mlir::func::FuncOp &funcOp) {
1260     genBodyFunc(builder, funcOp, rank, elementType);
1261   };
1262   // Mangle the function name with the rank value as "x<rank>".
1263   mlir::func::FuncOp newFunc =
1264       getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1265   auto newCall =
1266       builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
1267   call->replaceAllUsesWith(newCall.getResults());
1268   call->dropAllReferences();
1269   call->erase();
1270 }
1271 
/// Walk every operation in the module and, for each fir.call to a supported
/// Fortran runtime entry point (SUM, DOT_PRODUCT, MAXVAL, COUNT, ANY, ALL,
/// MINLOC, MAXLOC), try to replace it with a call to a specialised function
/// generated into the module. The per-intrinsic helpers return early and
/// leave the runtime call untouched when their preconditions are not met.
void SimplifyIntrinsicsPass::runOnOperation() {
  LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
  mlir::ModuleOp module = getOperation();
  fir::KindMapping kindMap = fir::getKindMapping(module);
  module.walk([&](mlir::Operation *op) {
    if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
      if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
        mlir::StringRef funcName = callee.getLeafReference().getValue();
        // Replace calls to the runtime SUM entry points when DIM and MASK
        // are absent, the argument's rank is known, and the result is an
        // integer or floating-point type (see simplifyIntOrFloatReduction
        // for the exact checks). Other forms are ignored.
        // The new function is added to the module.
        //
        // Prototype for runtime call (from sum.cpp):
        // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
        //                int dim, const Descriptor *mask)
        //
        if (funcName.starts_with(RTNAME_STRING(Sum))) {
          simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody);
          return;
        }
        // DOT_PRODUCT is handled inline: only the first two call operands
        // (the two arrays) are forwarded to the specialised function, whose
        // name is mangled with the fast-math flags and both element types.
        if (funcName.starts_with(RTNAME_STRING(DotProduct))) {
          LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
          LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
                     llvm::dbgs() << "\n");
          mlir::Operation::operand_range args = call.getArgs();
          const mlir::Value &v1 = args[0];
          const mlir::Value &v2 = args[1];
          mlir::Location loc = call.getLoc();
          fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)};
          // Stringize the builder's FastMathFlags flags for mangling
          // the generated function name.
          std::string fmfString{builder.getFastMathFlagsString()};

          mlir::Type type = call.getResult(0).getType();
          if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
            return;

          // Try to find the element types of the boxed arguments.
          auto arg1Type = getArgElementType(v1);
          auto arg2Type = getArgElementType(v2);

          if (!arg1Type || !arg2Type)
            return;

          // Support only floating point and integer arguments
          // now (e.g. logical is skipped here).
          if (!arg1Type->isa<mlir::FloatType>() &&
              !arg1Type->isa<mlir::IntegerType>())
            return;
          if (!arg2Type->isa<mlir::FloatType>() &&
              !arg2Type->isa<mlir::IntegerType>())
            return;

          auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
            return genRuntimeDotType(builder, type);
          };
          auto bodyGenerator = [&arg1Type,
                                &arg2Type](fir::FirOpBuilder &builder,
                                           mlir::func::FuncOp &funcOp) {
            genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
          };

          // Suffix the function name with the element types
          // of the arguments.
          std::string typedFuncName(funcName);
          llvm::raw_string_ostream nameOS(typedFuncName);
          // We must mangle the generated function name with FastMathFlags
          // value.
          if (!fmfString.empty())
            nameOS << '_' << fmfString;
          nameOS << '_';
          arg1Type->print(nameOS);
          nameOS << '_';
          arg2Type->print(nameOS);

          mlir::func::FuncOp newFunc = getOrCreateFunction(
              builder, typedFuncName, typeGenerator, bodyGenerator);
          auto newCall = builder.create<fir::CallOp>(loc, newFunc,
                                                     mlir::ValueRange{v1, v2});
          call->replaceAllUsesWith(newCall.getResults());
          call->dropAllReferences();
          call->erase();

          LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
                     llvm::dbgs() << "\n");
          return;
        }
        if (funcName.starts_with(RTNAME_STRING(Maxval))) {
          simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody);
          return;
        }
        if (funcName.starts_with(RTNAME_STRING(Count))) {
          simplifyLogicalDim0Reduction(call, kindMap, genRuntimeCountBody);
          return;
        }
        if (funcName.starts_with(RTNAME_STRING(Any))) {
          simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAnyBody);
          return;
        }
        // NOTE(review): ends_with (not starts_with, unlike its neighbours)
        // looks deliberate. Other runtime entry points presumably share the
        // "All" prefix (e.g. the Allocatable* family) and must not be routed
        // here. Confirm before changing it.
        if (funcName.ends_with(RTNAME_STRING(All))) {
          simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAllBody);
          return;
        }
        if (funcName.starts_with(RTNAME_STRING(Minloc))) {
          simplifyMinMaxlocReduction(call, kindMap, false);
          return;
        }
        if (funcName.starts_with(RTNAME_STRING(Maxloc))) {
          simplifyMinMaxlocReduction(call, kindMap, true);
          return;
        }
      }
    }
  });
  LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
}
1389 
1390 void SimplifyIntrinsicsPass::getDependentDialects(
1391     mlir::DialectRegistry &registry) const {
1392   // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
1393   registry.insert<mlir::LLVM::LLVMDialect>();
1394 }
1395 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
1396   return std::make_unique<SimplifyIntrinsicsPass>();
1397 }
1398