xref: /llvm-project/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp (revision fac349a169976f822fb27f03e623fa0d28aec1f3)
1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
13 ///
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
/// for higher dimension arrays, of course).
18 ///
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
24 
25 #include "flang/Common/Fortran.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/FIRBuilder.h"
28 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
29 #include "flang/Optimizer/Builder/Todo.h"
30 #include "flang/Optimizer/Dialect/FIROps.h"
31 #include "flang/Optimizer/Dialect/FIRType.h"
32 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
33 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
34 #include "flang/Optimizer/Transforms/Passes.h"
35 #include "flang/Optimizer/Transforms/Utils.h"
36 #include "flang/Runtime/entry-names.h"
37 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
38 #include "mlir/IR/Matchers.h"
39 #include "mlir/IR/Operation.h"
40 #include "mlir/Pass/Pass.h"
41 #include "mlir/Transforms/DialectConversion.h"
42 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
43 #include "mlir/Transforms/RegionUtils.h"
44 #include "llvm/Support/Debug.h"
45 #include "llvm/Support/raw_ostream.h"
46 #include <llvm/Support/ErrorHandling.h>
47 #include <mlir/Dialect/Arith/IR/Arith.h>
48 #include <mlir/IR/BuiltinTypes.h>
49 #include <mlir/IR/Location.h>
50 #include <mlir/IR/MLIRContext.h>
51 #include <mlir/IR/Value.h>
52 #include <mlir/Support/LLVM.h>
53 #include <optional>
54 
55 namespace fir {
56 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS
57 #include "flang/Optimizer/Transforms/Passes.h.inc"
58 } // namespace fir
59 
60 #define DEBUG_TYPE "flang-simplify-intrinsics"
61 
62 namespace {
63 
64 class SimplifyIntrinsicsPass
65     : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
66   using FunctionTypeGeneratorTy =
67       llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
68   using FunctionBodyGeneratorTy =
69       llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
70   using GenReductionBodyTy = llvm::function_ref<void(
71       fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank,
72       mlir::Type elementType)>;
73 
74 public:
75   using fir::impl::SimplifyIntrinsicsBase<
76       SimplifyIntrinsicsPass>::SimplifyIntrinsicsBase;
77 
78   /// Generate a new function implementing a simplified version
79   /// of a Fortran runtime function defined by \p basename name.
80   /// \p typeGenerator is a callback that generates the new function's type.
81   /// \p bodyGenerator is a callback that generates the new function's body.
82   /// The new function is created in the \p builder's Module.
83   mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
84                                          const mlir::StringRef &basename,
85                                          FunctionTypeGeneratorTy typeGenerator,
86                                          FunctionBodyGeneratorTy bodyGenerator);
87   void runOnOperation() override;
88   void getDependentDialects(mlir::DialectRegistry &registry) const override;
89 
90 private:
91   /// Helper functions to replace a reduction type of call with its
92   /// simplified form. The actual function is generated using a callback
93   /// function.
94   /// \p call is the call to be replaced
95   /// \p kindMap is used to create FIROpBuilder
96   /// \p genBodyFunc is the callback that builds the replacement function
97   void simplifyIntOrFloatReduction(fir::CallOp call,
98                                    const fir::KindMapping &kindMap,
99                                    GenReductionBodyTy genBodyFunc);
100   void simplifyLogicalDim0Reduction(fir::CallOp call,
101                                     const fir::KindMapping &kindMap,
102                                     GenReductionBodyTy genBodyFunc);
103   void simplifyLogicalDim1Reduction(fir::CallOp call,
104                                     const fir::KindMapping &kindMap,
105                                     GenReductionBodyTy genBodyFunc);
106   void simplifyMinMaxlocReduction(fir::CallOp call,
107                                   const fir::KindMapping &kindMap, bool isMax);
108   void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap,
109                              GenReductionBodyTy genBodyFunc,
110                              fir::FirOpBuilder &builder,
111                              const mlir::StringRef &basename,
112                              mlir::Type elementType);
113 };
114 
115 } // namespace
116 
117 /// Create FirOpBuilder with the provided \p op insertion point
118 /// and \p kindMap additionally inheriting FastMathFlags from \p op.
119 static fir::FirOpBuilder
120 getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) {
121   fir::FirOpBuilder builder{op, kindMap};
122   auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
123   if (!fmi)
124     return builder;
125 
126   // Regardless of what default FastMathFlags are used by FirOpBuilder,
127   // override them with FastMathFlags attached to the operation.
128   builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
129   return builder;
130 }
131 
132 /// Generate function type for the simplified version of RTNAME(Sum) and
133 /// similar functions with a fir.box<none> type returning \p elementType.
134 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
135                                          const mlir::Type &elementType) {
136   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
137   return mlir::FunctionType::get(builder.getContext(), {boxType},
138                                  {elementType});
139 }
140 
141 template <typename Op>
142 Op expectOp(mlir::Value val) {
143   if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp()))
144     return op;
145   LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName()
146                           << '\n');
147   return nullptr;
148 }
149 
150 template <typename Op>
151 static mlir::Value findDefSingle(fir::ConvertOp op) {
152   if (auto defOp = expectOp<Op>(op->getOperand(0))) {
153     return defOp.getResult();
154   }
155   return {};
156 }
157 
158 template <typename... Ops>
159 static mlir::Value findDef(fir::ConvertOp op) {
160   mlir::Value defOp;
161   // Loop over the operation types given to see if any match, exiting once
162   // a match is found. Cast to void is needed to avoid compiler complaining
163   // that the result of expression is unused
164   (void)((defOp = findDefSingle<Ops>(op), (defOp)) || ...);
165   return defOp;
166 }
167 
168 static bool isOperandAbsent(mlir::Value val) {
169   if (auto op = expectOp<fir::ConvertOp>(val)) {
170     assert(op->getOperands().size() != 0);
171     return mlir::isa_and_nonnull<fir::AbsentOp>(
172         op->getOperand(0).getDefiningOp());
173   }
174   return false;
175 }
176 
177 static bool isTrueOrNotConstant(mlir::Value val) {
178   if (auto op = expectOp<mlir::arith::ConstantOp>(val)) {
179     return !mlir::matchPattern(val, mlir::m_Zero());
180   }
181   return true;
182 }
183 
184 static bool isZero(mlir::Value val) {
185   if (auto op = expectOp<fir::ConvertOp>(val)) {
186     assert(op->getOperands().size() != 0);
187     if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
188       return mlir::matchPattern(defOp, mlir::m_Zero());
189   }
190   return false;
191 }
192 
193 static mlir::Value findBoxDef(mlir::Value val) {
194   if (auto op = expectOp<fir::ConvertOp>(val)) {
195     assert(op->getOperands().size() != 0);
196     return findDef<fir::EmboxOp, fir::ReboxOp>(op);
197   }
198   return {};
199 }
200 
201 static mlir::Value findMaskDef(mlir::Value val) {
202   if (auto op = expectOp<fir::ConvertOp>(val)) {
203     assert(op->getOperands().size() != 0);
204     return findDef<fir::EmboxOp, fir::ReboxOp, fir::AbsentOp>(op);
205   }
206   return {};
207 }
208 
209 static unsigned getDimCount(mlir::Value val) {
210   // In order to find the dimensions count, we look for EmboxOp/ReboxOp
211   // and take the count from its *result* type. Note that in case
212   // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
213   // have different types.
214   // Actually, we can take the box type from the operand of
215   // the first ConvertOp that has non-opaque box type that we meet
216   // going through the ConvertOp chain.
217   if (mlir::Value emboxVal = findBoxDef(val))
218     if (auto boxTy = mlir::dyn_cast<fir::BoxType>(emboxVal.getType()))
219       if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
220         return seqTy.getDimension();
221   return 0;
222 }
223 
224 /// Given the call operation's box argument \p val, discover
225 /// the element type of the underlying array object.
226 /// \returns the element type or std::nullopt if the type cannot
227 /// be reliably found.
228 /// We expect that the argument is a result of fir.convert
229 /// with the destination type of !fir.box<none>.
230 static std::optional<mlir::Type> getArgElementType(mlir::Value val) {
231   mlir::Operation *defOp;
232   do {
233     defOp = val.getDefiningOp();
234     // Analyze only sequences of convert operations.
235     if (!mlir::isa<fir::ConvertOp>(defOp))
236       return std::nullopt;
237     val = defOp->getOperand(0);
238     // The convert operation is expected to convert from one
239     // box type to another box type.
240     auto boxType = mlir::cast<fir::BoxType>(val.getType());
241     auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
242     if (!mlir::isa<mlir::NoneType>(elementType))
243       return elementType;
244   } while (true);
245 }
246 
247 using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
248     fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
249     mlir::Value)>;
250 using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>(
251     fir::FirOpBuilder &, mlir::Location, mlir::Value)>;
252 
253 /// Generate the reduction loop into \p funcOp.
254 ///
255 /// \p initVal is a function, called to get the initial value for
256 ///    the reduction value
257 /// \p genBody is called to fill in the actual reduciton operation
258 ///    for example add for SUM, MAX for MAXVAL, etc.
259 /// \p rank is the rank of the input argument.
260 /// \p elementType is the type of the elements in the input array,
261 ///    which may be different to the return type.
262 /// \p loopCond is called to generate the condition to continue or
263 ///    not for IterWhile loops
264 /// \p unorderedOrInitalLoopCond contains either a boolean or bool
265 ///    mlir constant, and controls the inital value for while loops
266 ///    or if DoLoop is ordered/unordered.
267 
268 template <typename OP, typename T, int resultIndex>
269 static void
270 genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
271                  fir::InitValGeneratorTy initVal, ContinueLoopGenTy loopCond,
272                  T unorderedOrInitialLoopCond, BodyOpGeneratorTy genBody,
273                  unsigned rank, mlir::Type elementType, mlir::Location loc) {
274 
275   mlir::IndexType idxTy = builder.getIndexType();
276 
277   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
278   mlir::Value arg = args[0];
279 
280   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
281 
282   fir::SequenceType::Shape flatShape(rank,
283                                      fir::SequenceType::getUnknownExtent());
284   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
285   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
286   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
287   mlir::Type resultType = funcOp.getResultTypes()[0];
288   mlir::Value init = initVal(builder, loc, resultType);
289 
290   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
291 
292   assert(rank > 0 && "rank cannot be zero");
293   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
294 
295   // Compute all the upper bounds before the loop nest.
296   // It is not strictly necessary for performance, since the loop nest
297   // does not have any store operations and any LICM optimization
298   // should be able to optimize the redundancy.
299   for (unsigned i = 0; i < rank; ++i) {
300     mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
301     auto dims =
302         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
303     mlir::Value len = dims.getResult(1);
304     // We use C indexing here, so len-1 as loopcount
305     mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
306     bounds.push_back(loopCount);
307   }
308   // Create a loop nest consisting of OP operations.
309   // Collect the loops' induction variables into indices array,
310   // which will be used in the innermost loop to load the input
311   // array's element.
312   // The loops are generated such that the innermost loop processes
313   // the 0 dimension.
314   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
315   for (unsigned i = rank; 0 < i; --i) {
316     mlir::Value step = one;
317     mlir::Value loopCount = bounds[i - 1];
318     auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step,
319                                    unorderedOrInitialLoopCond,
320                                    /*finalCountValue=*/false, init);
321     init = loop.getRegionIterArgs()[resultIndex];
322     indices.push_back(loop.getInductionVar());
323     // Set insertion point to the loop body so that the next loop
324     // is inserted inside the current one.
325     builder.setInsertionPointToStart(loop.getBody());
326   }
327 
328   // Reverse the indices such that they are ordered as:
329   //   <dim-0-idx, dim-1-idx, ...>
330   std::reverse(indices.begin(), indices.end());
331   // We are in the innermost loop: generate the reduction body.
332   mlir::Type eleRefTy = builder.getRefType(elementType);
333   mlir::Value addr =
334       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
335   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
336   mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
337   // Generate vector with condition to continue while loop at [0] and result
338   // from current loop at [1] for IterWhileOp loops, just result at [0] for
339   // DoLoopOp loops.
340   llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal);
341 
342   // Unwind the loop nest and insert ResultOp on each level
343   // to return the updated value of the reduction to the enclosing
344   // loops.
345   for (unsigned i = 0; i < rank; ++i) {
346     auto result = builder.create<fir::ResultOp>(loc, results);
347     // Proceed to the outer loop.
348     auto loop = mlir::cast<OP>(result->getParentOp());
349     results = loop.getResults();
350     // Set insertion point after the loop operation that we have
351     // just processed.
352     builder.setInsertionPointAfter(loop.getOperation());
353   }
354   // End of loop nest. The insertion point is after the outermost loop.
355   // Return the reduction value from the function.
356   builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
357 }
358 
359 static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder,
360                                                   mlir::Location loc,
361                                                   mlir::Value reductionVal) {
362   return {reductionVal};
363 }
364 
365 /// Generate function body of the simplified version of RTNAME(Sum)
366 /// with signature provided by \p funcOp. The caller is responsible
367 /// for saving/restoring the original insertion point of \p builder.
368 /// \p funcOp is expected to be empty on entry to this function.
369 /// \p rank specifies the rank of the input argument.
370 static void genRuntimeSumBody(fir::FirOpBuilder &builder,
371                               mlir::func::FuncOp &funcOp, unsigned rank,
372                               mlir::Type elementType) {
373   // function RTNAME(Sum)<T>x<rank>_simplified(arr)
374   //   T, dimension(:) :: arr
375   //   T sum = 0
376   //   integer iter
377   //   do iter = 0, extent(arr)
378   //     sum = sum + arr[iter]
379   //   end do
380   //   RTNAME(Sum)<T>x<rank>_simplified = sum
381   // end function RTNAME(Sum)<T>x<rank>_simplified
382   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
383                  mlir::Type elementType) {
384     if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
385       const llvm::fltSemantics &sem = ty.getFloatSemantics();
386       return builder.createRealConstant(loc, elementType,
387                                         llvm::APFloat::getZero(sem));
388     }
389     return builder.createIntegerConstant(loc, elementType, 0);
390   };
391 
392   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
393                       mlir::Type elementType, mlir::Value elem1,
394                       mlir::Value elem2) -> mlir::Value {
395     if (mlir::isa<mlir::FloatType>(elementType))
396       return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
397     if (mlir::isa<mlir::IntegerType>(elementType))
398       return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
399 
400     llvm_unreachable("unsupported type");
401     return {};
402   };
403 
404   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
405   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
406 
407   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
408                                            false, genBodyOp, rank, elementType,
409                                            loc);
410 }
411 
412 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
413                                  mlir::func::FuncOp &funcOp, unsigned rank,
414                                  mlir::Type elementType) {
415   auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
416                  mlir::Type elementType) {
417     if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
418       const llvm::fltSemantics &sem = ty.getFloatSemantics();
419       return builder.createRealConstant(
420           loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
421     }
422     unsigned bits = elementType.getIntOrFloatBitWidth();
423     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
424     return builder.createIntegerConstant(loc, elementType, minInt);
425   };
426 
427   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
428                       mlir::Type elementType, mlir::Value elem1,
429                       mlir::Value elem2) -> mlir::Value {
430     if (mlir::isa<mlir::FloatType>(elementType)) {
431       // arith.maxf later converted to llvm.intr.maxnum does not work
432       // correctly for NaNs and -0.0 (see maxnum/minnum pattern matching
433       // in LLVM's InstCombine pass). Moreover, llvm.intr.maxnum
434       // for F128 operands is lowered into fmaxl call by LLVM.
435       // This libm function may not work properly for F128 arguments
436       // on targets where long double is not F128. It is an LLVM issue,
437       // but we just use normal select here to resolve all the cases.
438       auto compare = builder.create<mlir::arith::CmpFOp>(
439           loc, mlir::arith::CmpFPredicate::OGT, elem1, elem2);
440       return builder.create<mlir::arith::SelectOp>(loc, compare, elem1, elem2);
441     }
442     if (mlir::isa<mlir::IntegerType>(elementType))
443       return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
444 
445     llvm_unreachable("unsupported type");
446     return {};
447   };
448 
449   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
450   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
451 
452   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, init, nopLoopCond,
453                                            false, genBodyOp, rank, elementType,
454                                            loc);
455 }
456 
457 static void genRuntimeCountBody(fir::FirOpBuilder &builder,
458                                 mlir::func::FuncOp &funcOp, unsigned rank,
459                                 mlir::Type elementType) {
460   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
461                  mlir::Type elementType) {
462     unsigned bits = elementType.getIntOrFloatBitWidth();
463     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
464     return builder.createIntegerConstant(loc, elementType, zeroInt);
465   };
466 
467   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
468                       mlir::Type elementType, mlir::Value elem1,
469                       mlir::Value elem2) -> mlir::Value {
470     auto zero32 = builder.createIntegerConstant(loc, elementType, 0);
471     auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0);
472     auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1);
473 
474     auto compare = builder.create<mlir::arith::CmpIOp>(
475         loc, mlir::arith::CmpIPredicate::eq, elem1, zero32);
476     auto select =
477         builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64);
478     return builder.create<mlir::arith::AddIOp>(loc, select, elem2);
479   };
480 
481   // Count always gets I32 for elementType as it converts logical input to
482   // logical<4> before passing to the function.
483   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
484   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
485 
486   genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
487                                            false, genBodyOp, rank, elementType,
488                                            loc);
489 }
490 
491 static void genRuntimeAnyBody(fir::FirOpBuilder &builder,
492                               mlir::func::FuncOp &funcOp, unsigned rank,
493                               mlir::Type elementType) {
494   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
495                  mlir::Type elementType) {
496     return builder.createIntegerConstant(loc, elementType, 0);
497   };
498 
499   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
500                       mlir::Type elementType, mlir::Value elem1,
501                       mlir::Value elem2) -> mlir::Value {
502     auto zero = builder.createIntegerConstant(loc, elementType, 0);
503     return builder.create<mlir::arith::CmpIOp>(
504         loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
505   };
506 
507   auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
508                          mlir::Value reductionVal) {
509     auto one1 = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
510     auto eor = builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1);
511     llvm::SmallVector<mlir::Value> results = {eor, reductionVal};
512     return results;
513   };
514 
515   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
516   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
517   mlir::Value ok = builder.createBool(loc, true);
518 
519   genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
520       builder, funcOp, zero, continueCond, ok, genBodyOp, rank, elementType,
521       loc);
522 }
523 
524 static void genRuntimeAllBody(fir::FirOpBuilder &builder,
525                               mlir::func::FuncOp &funcOp, unsigned rank,
526                               mlir::Type elementType) {
527   auto one = [](fir::FirOpBuilder builder, mlir::Location loc,
528                 mlir::Type elementType) {
529     return builder.createIntegerConstant(loc, elementType, 1);
530   };
531 
532   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
533                       mlir::Type elementType, mlir::Value elem1,
534                       mlir::Value elem2) -> mlir::Value {
535     auto zero = builder.createIntegerConstant(loc, elementType, 0);
536     return builder.create<mlir::arith::CmpIOp>(
537         loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
538   };
539 
540   auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
541                          mlir::Value reductionVal) {
542     llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal};
543     return results;
544   };
545 
546   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
547   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
548   mlir::Value ok = builder.createBool(loc, true);
549 
550   genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
551       builder, funcOp, one, continueCond, ok, genBodyOp, rank, elementType,
552       loc);
553 }
554 
555 static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
556                                                unsigned int rank) {
557   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
558   mlir::Type boxRefType = builder.getRefType(boxType);
559 
560   return mlir::FunctionType::get(builder.getContext(),
561                                  {boxRefType, boxType, boxType}, {});
562 }
563 
564 // Produces a loop nest for a Minloc intrinsic.
565 void fir::genMinMaxlocReductionLoop(
566     fir::FirOpBuilder &builder, mlir::Value array,
567     fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody,
568     fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType,
569     mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr,
570     bool maskMayBeLogicalScalar) {
571   mlir::IndexType idxTy = builder.getIndexType();
572 
573   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
574 
575   fir::SequenceType::Shape flatShape(rank,
576                                      fir::SequenceType::getUnknownExtent());
577   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
578   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
579   array = builder.create<fir::ConvertOp>(loc, boxArrTy, array);
580 
581   mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
582   mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
583   mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
584   mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
585   builder.create<fir::StoreOp>(loc, zero, flagRef);
586 
587   mlir::Value init = initVal(builder, loc, elementType);
588   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
589 
590   assert(rank > 0 && "rank cannot be zero");
591   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
592 
593   // Compute all the upper bounds before the loop nest.
594   // It is not strictly necessary for performance, since the loop nest
595   // does not have any store operations and any LICM optimization
596   // should be able to optimize the redundancy.
597   for (unsigned i = 0; i < rank; ++i) {
598     mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
599     auto dims =
600         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
601     mlir::Value len = dims.getResult(1);
602     // We use C indexing here, so len-1 as loopcount
603     mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
604     bounds.push_back(loopCount);
605   }
606   // Create a loop nest consisting of OP operations.
607   // Collect the loops' induction variables into indices array,
608   // which will be used in the innermost loop to load the input
609   // array's element.
610   // The loops are generated such that the innermost loop processes
611   // the 0 dimension.
612   llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
613   for (unsigned i = rank; 0 < i; --i) {
614     mlir::Value step = one;
615     mlir::Value loopCount = bounds[i - 1];
616     auto loop =
617         builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
618                                       /*finalCountValue=*/false, init);
619     init = loop.getRegionIterArgs()[0];
620     indices.push_back(loop.getInductionVar());
621     // Set insertion point to the loop body so that the next loop
622     // is inserted inside the current one.
623     builder.setInsertionPointToStart(loop.getBody());
624   }
625 
626   // Reverse the indices such that they are ordered as:
627   //   <dim-0-idx, dim-1-idx, ...>
628   std::reverse(indices.begin(), indices.end());
629   mlir::Value reductionVal =
630       genBody(builder, loc, elementType, array, flagRef, init, indices);
631 
632   // Unwind the loop nest and insert ResultOp on each level
633   // to return the updated value of the reduction to the enclosing
634   // loops.
635   for (unsigned i = 0; i < rank; ++i) {
636     auto result = builder.create<fir::ResultOp>(loc, reductionVal);
637     // Proceed to the outer loop.
638     auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
639     reductionVal = loop.getResult(0);
640     // Set insertion point after the loop operation that we have
641     // just processed.
642     builder.setInsertionPointAfter(loop.getOperation());
643   }
644   // End of loop nest. The insertion point is after the outermost loop.
645   if (maskMayBeLogicalScalar) {
646     if (fir::IfOp ifOp =
647             mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
648       builder.create<fir::ResultOp>(loc, reductionVal);
649       builder.setInsertionPointAfter(ifOp);
650       // Redefine flagSet to escape scope of ifOp
651       flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
652       reductionVal = ifOp.getResult(0);
653     }
654   }
655 }
656 
657 static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
658                                     mlir::func::FuncOp &funcOp, bool isMax,
659                                     unsigned rank, int maskRank,
660                                     mlir::Type elementType,
661                                     mlir::Type maskElemType,
662                                     mlir::Type resultElemTy, bool isDim) {
663   auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
664                       mlir::Type elementType) {
665     if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
666       const llvm::fltSemantics &sem = ty.getFloatSemantics();
667       llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
668       return builder.createRealConstant(loc, elementType, limit);
669     }
670     unsigned bits = elementType.getIntOrFloatBitWidth();
671     int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
672                                : llvm::APInt::getSignedMaxValue(bits))
673                             .getSExtValue();
674     return builder.createIntegerConstant(loc, elementType, initValue);
675   };
676 
677   mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
678   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
679 
680   mlir::Value mask = funcOp.front().getArgument(2);
681 
682   // Set up result array in case of early exit / 0 length array
683   mlir::IndexType idxTy = builder.getIndexType();
684   mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy);
685   mlir::Type resultHeapTy = fir::HeapType::get(resultTy);
686   mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy);
687 
688   mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0);
689   mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank);
690 
691   mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy);
692   mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize);
693   mlir::Value resultArr = builder.create<fir::EmboxOp>(
694       loc, resultBoxTy, resultArrInit, resultArrShape);
695 
696   mlir::Type resultRefTy = builder.getRefType(resultElemTy);
697 
698   if (maskRank > 0) {
699     fir::SequenceType::Shape flatShape(rank,
700                                        fir::SequenceType::getUnknownExtent());
701     mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType);
702     mlir::Type boxMaskTy = fir::BoxType::get(maskTy);
703     mask = builder.create<fir::ConvertOp>(loc, boxMaskTy, mask);
704   }
705 
706   for (unsigned int i = 0; i < rank; ++i) {
707     mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
708     mlir::Value resultElemAddr =
709         builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index);
710     builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
711   }
712 
713   auto genBodyOp =
714       [&rank, &resultArr, isMax, &mask, &maskElemType, &maskRank](
715           fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType,
716           mlir::Value array, mlir::Value flagRef, mlir::Value reduction,
717           const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
718     // We are in the innermost loop: generate the reduction body.
719     if (maskRank > 0) {
720       mlir::Type logicalRef = builder.getRefType(maskElemType);
721       mlir::Value maskAddr =
722           builder.create<fir::CoordinateOp>(loc, logicalRef, mask, indices);
723       mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr);
724 
725       // fir::IfOp requires argument to be I1 - won't accept logical or any
726       // other Integer.
727       mlir::Type ifCompatType = builder.getI1Type();
728       mlir::Value ifCompatElem =
729           builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
730 
731       llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
732       fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
733                                                  /*withElseRegion=*/true);
734       builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
735     }
736 
737     // Set flag that mask was true at some point
738     mlir::Value flagSet = builder.createIntegerConstant(
739         loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
740     mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
741     mlir::Type eleRefTy = builder.getRefType(elementType);
742     mlir::Value addr =
743         builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
744     mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
745 
746     mlir::Value cmp;
747     if (mlir::isa<mlir::FloatType>(elementType)) {
748       // For FP reductions we want the first smallest value to be used, that
749       // is not NaN. A OGL/OLT condition will usually work for this unless all
750       // the values are Nan or Inf. This follows the same logic as
751       // NumericCompare for Minloc/Maxlox in extrema.cpp.
752       cmp = builder.create<mlir::arith::CmpFOp>(
753           loc,
754           isMax ? mlir::arith::CmpFPredicate::OGT
755                 : mlir::arith::CmpFPredicate::OLT,
756           elem, reduction);
757 
758       mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
759           loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
760       mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
761           loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
762       cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
763       cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
764     } else if (mlir::isa<mlir::IntegerType>(elementType)) {
765       cmp = builder.create<mlir::arith::CmpIOp>(
766           loc,
767           isMax ? mlir::arith::CmpIPredicate::sgt
768                 : mlir::arith::CmpIPredicate::slt,
769           elem, reduction);
770     } else {
771       llvm_unreachable("unsupported type");
772     }
773 
774     // The condition used for the loop is isFirst || <the condition above>.
775     isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
776     isFirst = builder.create<mlir::arith::XOrIOp>(
777         loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
778     cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
779     fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
780                                                /*withElseRegion*/ true);
781 
782     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
783     builder.create<fir::StoreOp>(loc, flagSet, flagRef);
784     mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
785     mlir::Type returnRefTy = builder.getRefType(resultElemTy);
786     mlir::IndexType idxTy = builder.getIndexType();
787 
788     mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1);
789 
790     for (unsigned int i = 0; i < rank; ++i) {
791       mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
792       mlir::Value resultElemAddr =
793           builder.create<fir::CoordinateOp>(loc, returnRefTy, resultArr, index);
794       mlir::Value convert =
795           builder.create<fir::ConvertOp>(loc, resultElemTy, indices[i]);
796       mlir::Value fortranIndex =
797           builder.create<mlir::arith::AddIOp>(loc, convert, one);
798       builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
799     }
800     builder.create<fir::ResultOp>(loc, elem);
801     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
802     builder.create<fir::ResultOp>(loc, reduction);
803     builder.setInsertionPointAfter(ifOp);
804     mlir::Value reductionVal = ifOp.getResult(0);
805 
806     // Close the mask if needed
807     if (maskRank > 0) {
808       fir::IfOp ifOp =
809           mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp());
810       builder.create<fir::ResultOp>(loc, reductionVal);
811       builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
812       builder.create<fir::ResultOp>(loc, reduction);
813       reductionVal = ifOp.getResult(0);
814       builder.setInsertionPointAfter(ifOp);
815     }
816 
817     return reductionVal;
818   };
819 
820   // if mask is a logical scalar, we can check its value before the main loop
821   // and either ignore the fact it is there or exit early.
822   if (maskRank == 0) {
823     mlir::Type logical = builder.getI1Type();
824     mlir::IndexType idxTy = builder.getIndexType();
825 
826     fir::SequenceType::Shape singleElement(1, 1);
827     mlir::Type arrTy = fir::SequenceType::get(singleElement, logical);
828     mlir::Type boxArrTy = fir::BoxType::get(arrTy);
829     mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, mask);
830 
831     mlir::Value indx = builder.createIntegerConstant(loc, idxTy, 0);
832     mlir::Type logicalRefTy = builder.getRefType(logical);
833     mlir::Value condAddr =
834         builder.create<fir::CoordinateOp>(loc, logicalRefTy, array, indx);
835     mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr);
836 
837     fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cond,
838                                                /*withElseRegion=*/true);
839 
840     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
841     mlir::Value basicValue;
842     if (mlir::isa<mlir::IntegerType>(elementType)) {
843       basicValue = builder.createIntegerConstant(loc, elementType, 0);
844     } else {
845       basicValue = builder.createRealConstant(loc, elementType, 0);
846     }
847     builder.create<fir::ResultOp>(loc, basicValue);
848 
849     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
850   }
851   auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
852                       const mlir::Type &resultElemType, mlir::Value resultArr,
853                       mlir::Value index) {
854     mlir::Type resultRefTy = builder.getRefType(resultElemType);
855     return builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr,
856                                              index);
857   };
858 
859   genMinMaxlocReductionLoop(builder, funcOp.front().getArgument(1), init,
860                             genBodyOp, getAddrFn, rank, elementType, loc,
861                             maskElemType, resultArr, maskRank == 0);
862 
863   // Store newly created output array to the reference passed in
864   if (isDim) {
865     mlir::Type resultBoxTy =
866         fir::BoxType::get(fir::HeapType::get(resultElemTy));
867     mlir::Value outputArr = builder.create<fir::ConvertOp>(
868         loc, builder.getRefType(resultBoxTy), funcOp.front().getArgument(0));
869     mlir::Value resultArrScalar = builder.create<fir::ConvertOp>(
870         loc, fir::HeapType::get(resultElemTy), resultArrInit);
871     mlir::Value resultBox =
872         builder.create<fir::EmboxOp>(loc, resultBoxTy, resultArrScalar);
873     builder.create<fir::StoreOp>(loc, resultBox, outputArr);
874   } else {
875     fir::SequenceType::Shape resultShape(1, rank);
876     mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy);
877     mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
878     mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
879     mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
880     mlir::Value outputArr = builder.create<fir::ConvertOp>(
881         loc, outputRefTy, funcOp.front().getArgument(0));
882     builder.create<fir::StoreOp>(loc, resultArr, outputArr);
883   }
884 
885   builder.create<mlir::func::ReturnOp>(loc);
886 }
887 
888 /// Generate function type for the simplified version of RTNAME(DotProduct)
889 /// operating on the given \p elementType.
890 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
891                                             const mlir::Type &elementType) {
892   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
893   return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
894                                  {elementType});
895 }
896 
897 /// Generate function body of the simplified version of RTNAME(DotProduct)
898 /// with signature provided by \p funcOp. The caller is responsible
899 /// for saving/restoring the original insertion point of \p builder.
900 /// \p funcOp is expected to be empty on entry to this function.
901 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
902 /// of the underlying array objects - they are used to generate proper
903 /// element accesses.
904 static void genRuntimeDotBody(fir::FirOpBuilder &builder,
905                               mlir::func::FuncOp &funcOp,
906                               mlir::Type arg1ElementTy,
907                               mlir::Type arg2ElementTy) {
908   // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
909   //   T, dimension(:) :: arr1, arr2
910   //   T product = 0
911   //   integer iter
912   //   do iter = 0, extent(arr1)
913   //     product = product + arr1[iter] * arr2[iter]
914   //   end do
915   //   RTNAME(ADotProduct)<T>_simplified = product
916   // end function RTNAME(DotProduct)<T>_simplified
917   auto loc = mlir::UnknownLoc::get(builder.getContext());
918   mlir::Type resultElementType = funcOp.getResultTypes()[0];
919   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
920 
921   mlir::IndexType idxTy = builder.getIndexType();
922 
923   mlir::Value zero =
924       mlir::isa<mlir::FloatType>(resultElementType)
925           ? builder.createRealConstant(loc, resultElementType, 0.0)
926           : builder.createIntegerConstant(loc, resultElementType, 0);
927 
928   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
929   mlir::Value arg1 = args[0];
930   mlir::Value arg2 = args[1];
931 
932   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
933 
934   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
935   mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
936   mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
937   mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
938   mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
939   mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
940   mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
941   // This version takes the loop trip count from the first argument.
942   // If the first argument's box has unknown (at compilation time)
943   // extent, then it may be better to take the extent from the second
944   // argument - so that after inlining the loop may be better optimized, e.g.
945   // fully unrolled. This requires generating two versions of the simplified
946   // function and some analysis at the call site to choose which version
947   // is more profitable to call.
948   // Note that we can assume that both arguments have the same extent.
949   auto dims =
950       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
951   mlir::Value len = dims.getResult(1);
952   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
953   mlir::Value step = one;
954 
955   // We use C indexing here, so len-1 as loopcount
956   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
957   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
958                                             /*unordered=*/false,
959                                             /*finalCountValue=*/false, zero);
960   mlir::Value sumVal = loop.getRegionIterArgs()[0];
961 
962   // Begin loop code
963   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
964   builder.setInsertionPointToStart(loop.getBody());
965 
966   mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
967   mlir::Value index = loop.getInductionVar();
968   mlir::Value addr1 =
969       builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
970   mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
971   // Convert to the result type.
972   elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
973 
974   mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
975   mlir::Value addr2 =
976       builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
977   mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
978   // Convert to the result type.
979   elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
980 
981   if (mlir::isa<mlir::FloatType>(resultElementType))
982     sumVal = builder.create<mlir::arith::AddFOp>(
983         loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
984   else if (mlir::isa<mlir::IntegerType>(resultElementType))
985     sumVal = builder.create<mlir::arith::AddIOp>(
986         loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
987   else
988     llvm_unreachable("unsupported type");
989 
990   builder.create<fir::ResultOp>(loc, sumVal);
991   // End of loop.
992   builder.restoreInsertionPoint(loopEndPt);
993 
994   mlir::Value resultVal = loop.getResult(0);
995   builder.create<mlir::func::ReturnOp>(loc, resultVal);
996 }
997 
998 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
999     fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
1000     FunctionTypeGeneratorTy typeGenerator,
1001     FunctionBodyGeneratorTy bodyGenerator) {
1002   // WARNING: if the function generated here changes its signature
1003   //          or behavior (the body code), we should probably embed some
1004   //          versioning information into its name, otherwise libraries
1005   //          statically linked with older versions of Flang may stop
1006   //          working with object files created with newer Flang.
1007   //          We can also avoid this by using internal linkage, but
1008   //          this may increase the size of final executable/shared library.
1009   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
1010   // If we already have a function, just return it.
1011   mlir::func::FuncOp newFunc = builder.getNamedFunction(replacementName);
1012   mlir::FunctionType fType = typeGenerator(builder);
1013   if (newFunc) {
1014     assert(newFunc.getFunctionType() == fType &&
1015            "type mismatch for simplified function");
1016     return newFunc;
1017   }
1018 
1019   // Need to build the function!
1020   auto loc = mlir::UnknownLoc::get(builder.getContext());
1021   newFunc = builder.createFunction(loc, replacementName, fType);
1022   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
1023   auto linkage =
1024       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
1025   newFunc->setAttr("llvm.linkage", linkage);
1026 
1027   // Save the position of the original call.
1028   mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
1029 
1030   bodyGenerator(builder, newFunc);
1031 
1032   // Now back to where we were adding code earlier...
1033   builder.restoreInsertionPoint(insertPt);
1034 
1035   return newFunc;
1036 }
1037 
1038 void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction(
1039     fir::CallOp call, const fir::KindMapping &kindMap,
1040     GenReductionBodyTy genBodyFunc) {
1041   // args[1] and args[2] are source filename and line number, ignored.
1042   mlir::Operation::operand_range args = call.getArgs();
1043 
1044   const mlir::Value &dim = args[3];
1045   const mlir::Value &mask = args[4];
1046   // dim is zero when it is absent, which is an implementation
1047   // detail in the runtime library.
1048 
1049   bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
1050   unsigned rank = getDimCount(args[0]);
1051 
1052   // Rank is set to 0 for assumed shape arrays, don't simplify
1053   // in these cases
1054   if (!(dimAndMaskAbsent && rank > 0))
1055     return;
1056 
1057   mlir::Type resultType = call.getResult(0).getType();
1058 
1059   if (!mlir::isa<mlir::FloatType>(resultType) &&
1060       !mlir::isa<mlir::IntegerType>(resultType))
1061     return;
1062 
1063   auto argType = getArgElementType(args[0]);
1064   if (!argType)
1065     return;
1066   assert(*argType == resultType &&
1067          "Argument/result types mismatch in reduction");
1068 
1069   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1070 
1071   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1072   std::string fmfString{builder.getFastMathFlagsString()};
1073   std::string funcName =
1074       (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1075        mlir::Twine{rank} +
1076        // We must mangle the generated function name with FastMathFlags
1077        // value.
1078        (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
1079           .str();
1080 
1081   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1082                         resultType);
1083 }
1084 
1085 void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction(
1086     fir::CallOp call, const fir::KindMapping &kindMap,
1087     GenReductionBodyTy genBodyFunc) {
1088 
1089   mlir::Operation::operand_range args = call.getArgs();
1090   const mlir::Value &dim = args[3];
1091   unsigned rank = getDimCount(args[0]);
1092 
1093   // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1094   // these cases.
1095   if (!(isZero(dim) && rank > 0))
1096     return;
1097 
1098   mlir::Value inputBox = findBoxDef(args[0]);
1099 
1100   mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1101   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1102 
1103   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1104 
1105   // Treating logicals as integers makes things a lot easier
1106   fir::LogicalType logicalType = {
1107       mlir::dyn_cast<fir::LogicalType>(elementType)};
1108   fir::KindTy kind = logicalType.getFKind();
1109   mlir::Type intElementType = builder.getIntegerType(kind * 8);
1110 
1111   // Mangle kind into function name as it is not done by default
1112   std::string funcName =
1113       (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1114        mlir::Twine{kind} + "x" + mlir::Twine{rank})
1115           .str();
1116 
1117   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1118                         intElementType);
1119 }
1120 
1121 void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
1122     fir::CallOp call, const fir::KindMapping &kindMap,
1123     GenReductionBodyTy genBodyFunc) {
1124 
1125   mlir::Operation::operand_range args = call.getArgs();
1126   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1127   mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1128   unsigned rank = getDimCount(args[0]);
1129 
1130   // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1131   // these cases. We check for Dim at the end as some logical functions (Any,
1132   // All) set dim to 1 instead of 0 when the argument is not present.
1133   if (funcNameBase.ends_with("Dim") || !(rank > 0))
1134     return;
1135 
1136   mlir::Value inputBox = findBoxDef(args[0]);
1137   mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1138 
1139   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1140 
1141   // Treating logicals as integers makes things a lot easier
1142   fir::LogicalType logicalType = {
1143       mlir::dyn_cast<fir::LogicalType>(elementType)};
1144   fir::KindTy kind = logicalType.getFKind();
1145   mlir::Type intElementType = builder.getIntegerType(kind * 8);
1146 
1147   // Mangle kind into function name as it is not done by default
1148   std::string funcName =
1149       (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1150        mlir::Twine{kind} + "x" + mlir::Twine{rank})
1151           .str();
1152 
1153   simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1154                         intElementType);
1155 }
1156 
1157 void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
1158     fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) {
1159 
1160   mlir::Operation::operand_range args = call.getArgs();
1161 
1162   mlir::SymbolRefAttr callee = call.getCalleeAttr();
1163   mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1164   bool isDim = funcNameBase.ends_with("Dim");
1165   mlir::Value back = args[isDim ? 7 : 6];
1166   if (isTrueOrNotConstant(back))
1167     return;
1168 
1169   mlir::Value mask = args[isDim ? 6 : 5];
1170   mlir::Value maskDef = findMaskDef(mask);
1171 
1172   // maskDef is set to NULL when the defining op is not one we accept.
1173   // This tends to be because it is a selectOp, in which case let the
1174   // runtime deal with it.
1175   if (maskDef == NULL)
1176     return;
1177 
1178   unsigned rank = getDimCount(args[1]);
1179   if ((isDim && rank != 1) || !(rank > 0))
1180     return;
1181 
1182   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1183   mlir::Location loc = call.getLoc();
1184   auto inputBox = findBoxDef(args[1]);
1185   mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType());
1186 
1187   if (mlir::isa<fir::CharacterType>(inputType))
1188     return;
1189 
1190   int maskRank;
1191   fir::KindTy kind = 0;
1192   mlir::Type logicalElemType = builder.getI1Type();
1193   if (isOperandAbsent(mask)) {
1194     maskRank = -1;
1195   } else {
1196     maskRank = getDimCount(mask);
1197     mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType());
1198     fir::LogicalType logicalFirType = {
1199         mlir::dyn_cast<fir::LogicalType>(maskElemTy)};
1200     kind = logicalFirType.getFKind();
1201     // Convert fir::LogicalType to mlir::Type
1202     logicalElemType = logicalFirType;
1203   }
1204 
1205   mlir::Operation *outputDef = args[0].getDefiningOp();
1206   mlir::Value outputAlloc = outputDef->getOperand(0);
1207   mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType());
1208 
1209   std::string fmfString{builder.getFastMathFlagsString()};
1210   std::string funcName =
1211       (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1212        mlir::Twine{rank} +
1213        (maskRank >= 0
1214             ? "_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank}
1215             : "") +
1216        "_")
1217           .str();
1218 
1219   llvm::raw_string_ostream nameOS(funcName);
1220   outType.print(nameOS);
1221   if (isDim)
1222     nameOS << '_' << inputType;
1223   nameOS << '_' << fmfString;
1224 
1225   auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
1226     return genRuntimeMinlocType(builder, rank);
1227   };
1228   auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
1229                         isMax, isDim](fir::FirOpBuilder &builder,
1230                                       mlir::func::FuncOp &funcOp) {
1231     genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType,
1232                             logicalElemType, outType, isDim);
1233   };
1234 
1235   mlir::func::FuncOp newFunc =
1236       getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1237   builder.create<fir::CallOp>(loc, newFunc,
1238                               mlir::ValueRange{args[0], args[1], mask});
1239   call->dropAllReferences();
1240   call->erase();
1241 }
1242 
1243 void SimplifyIntrinsicsPass::simplifyReductionBody(
1244     fir::CallOp call, const fir::KindMapping &kindMap,
1245     GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder,
1246     const mlir::StringRef &funcName, mlir::Type elementType) {
1247 
1248   mlir::Operation::operand_range args = call.getArgs();
1249 
1250   mlir::Type resultType = call.getResult(0).getType();
1251   unsigned rank = getDimCount(args[0]);
1252 
1253   mlir::Location loc = call.getLoc();
1254 
1255   auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
1256     return genNoneBoxType(builder, resultType);
1257   };
1258   auto bodyGenerator = [&rank, &genBodyFunc,
1259                         &elementType](fir::FirOpBuilder &builder,
1260                                       mlir::func::FuncOp &funcOp) {
1261     genBodyFunc(builder, funcOp, rank, elementType);
1262   };
1263   // Mangle the function name with the rank value as "x<rank>".
1264   mlir::func::FuncOp newFunc =
1265       getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1266   auto newCall =
1267       builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
1268   call->replaceAllUsesWith(newCall.getResults());
1269   call->dropAllReferences();
1270   call->erase();
1271 }
1272 
1273 void SimplifyIntrinsicsPass::runOnOperation() {
1274   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
1275   mlir::ModuleOp module = getOperation();
1276   fir::KindMapping kindMap = fir::getKindMapping(module);
1277   module.walk([&](mlir::Operation *op) {
1278     if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
1279       if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
1280         mlir::StringRef funcName = callee.getLeafReference().getValue();
1281         // Replace call to runtime function for SUM when it has single
1282         // argument (no dim or mask argument) for 1D arrays with either
1283         // Integer4 or Real8 types. Other forms are ignored.
1284         // The new function is added to the module.
1285         //
1286         // Prototype for runtime call (from sum.cpp):
1287         // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
1288         //                int dim, const Descriptor *mask)
1289         //
1290         if (funcName.starts_with(RTNAME_STRING(Sum))) {
1291           simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody);
1292           return;
1293         }
1294         if (funcName.starts_with(RTNAME_STRING(DotProduct))) {
1295           LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
1296           LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
1297                      llvm::dbgs() << "\n");
1298           mlir::Operation::operand_range args = call.getArgs();
1299           const mlir::Value &v1 = args[0];
1300           const mlir::Value &v2 = args[1];
1301           mlir::Location loc = call.getLoc();
1302           fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)};
1303           // Stringize the builder's FastMathFlags flags for mangling
1304           // the generated function name.
1305           std::string fmfString{builder.getFastMathFlagsString()};
1306 
1307           mlir::Type type = call.getResult(0).getType();
1308           if (!mlir::isa<mlir::FloatType>(type) &&
1309               !mlir::isa<mlir::IntegerType>(type))
1310             return;
1311 
1312           // Try to find the element types of the boxed arguments.
1313           auto arg1Type = getArgElementType(v1);
1314           auto arg2Type = getArgElementType(v2);
1315 
1316           if (!arg1Type || !arg2Type)
1317             return;
1318 
1319           // Support only floating point and integer arguments
1320           // now (e.g. logical is skipped here).
1321           if (!arg1Type->isa<mlir::FloatType>() &&
1322               !arg1Type->isa<mlir::IntegerType>())
1323             return;
1324           if (!arg2Type->isa<mlir::FloatType>() &&
1325               !arg2Type->isa<mlir::IntegerType>())
1326             return;
1327 
1328           auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
1329             return genRuntimeDotType(builder, type);
1330           };
1331           auto bodyGenerator = [&arg1Type,
1332                                 &arg2Type](fir::FirOpBuilder &builder,
1333                                            mlir::func::FuncOp &funcOp) {
1334             genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
1335           };
1336 
1337           // Suffix the function name with the element types
1338           // of the arguments.
1339           std::string typedFuncName(funcName);
1340           llvm::raw_string_ostream nameOS(typedFuncName);
1341           // We must mangle the generated function name with FastMathFlags
1342           // value.
1343           if (!fmfString.empty())
1344             nameOS << '_' << fmfString;
1345           nameOS << '_';
1346           arg1Type->print(nameOS);
1347           nameOS << '_';
1348           arg2Type->print(nameOS);
1349 
1350           mlir::func::FuncOp newFunc = getOrCreateFunction(
1351               builder, typedFuncName, typeGenerator, bodyGenerator);
1352           auto newCall = builder.create<fir::CallOp>(loc, newFunc,
1353                                                      mlir::ValueRange{v1, v2});
1354           call->replaceAllUsesWith(newCall.getResults());
1355           call->dropAllReferences();
1356           call->erase();
1357 
1358           LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
1359                      llvm::dbgs() << "\n");
1360           return;
1361         }
1362         if (funcName.starts_with(RTNAME_STRING(Maxval))) {
1363           simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody);
1364           return;
1365         }
1366         if (funcName.starts_with(RTNAME_STRING(Count))) {
1367           simplifyLogicalDim0Reduction(call, kindMap, genRuntimeCountBody);
1368           return;
1369         }
1370         if (funcName.starts_with(RTNAME_STRING(Any))) {
1371           simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAnyBody);
1372           return;
1373         }
1374         if (funcName.ends_with(RTNAME_STRING(All))) {
1375           simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAllBody);
1376           return;
1377         }
1378         if (funcName.starts_with(RTNAME_STRING(Minloc))) {
1379           simplifyMinMaxlocReduction(call, kindMap, false);
1380           return;
1381         }
1382         if (funcName.starts_with(RTNAME_STRING(Maxloc))) {
1383           simplifyMinMaxlocReduction(call, kindMap, true);
1384           return;
1385         }
1386       }
1387     }
1388   });
1389   LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
1390 }
1391 
/// Inserts the LLVM dialect into \p registry; it is the only dialect
/// registered by this function.
///
/// \param registry dialect registry that receives the LLVM dialect.
///
/// NOTE(review): the LinkageAttr mentioned below is presumably created when
/// the pass builds its specialized functions — confirm against the
/// function-creation helper, which is not visible from here.
1392 void SimplifyIntrinsicsPass::getDependentDialects(
1393     mlir::DialectRegistry &registry) const {
1394   // LLVM::LinkageAttr creation requires that the LLVM dialect is loaded.
1395   registry.insert<mlir::LLVM::LLVMDialect>();
1396 }
1397