//===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
/// \file
/// This pass looks for suitable calls to runtime library for intrinsics that
/// can be simplified/specialized and replaces them with a specialized function.
///
/// For example, SUM(arr) can be specialized as a simple function with one loop,
/// compared to the three arguments (plus file & line info) that the runtime
/// call has - when the argument is a 1D-array (multiple loops may be needed
/// for higher dimension arrays, of course)
///
/// The general idea is that besides making the call simpler, it can also be
/// inlined by other passes that run after this pass, which further improves
/// performance, particularly when the work done in the function is trivial
/// and small in size.
236e193b5cSMats Petersson //===----------------------------------------------------------------------===// 246e193b5cSMats Petersson 25614cd721SSacha Ballantyne #include "flang/Common/Fortran.h" 266e193b5cSMats Petersson #include "flang/Optimizer/Builder/BoxValue.h" 27*4b17a8b1SValentin Clement (バレンタイン クレメン) #include "flang/Optimizer/Builder/CUFCommon.h" 286e193b5cSMats Petersson #include "flang/Optimizer/Builder/FIRBuilder.h" 29614cd721SSacha Ballantyne #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 306e193b5cSMats Petersson #include "flang/Optimizer/Builder/Todo.h" 316e193b5cSMats Petersson #include "flang/Optimizer/Dialect/FIROps.h" 326e193b5cSMats Petersson #include "flang/Optimizer/Dialect/FIRType.h" 33b07ef9e7SRenaud-K #include "flang/Optimizer/Dialect/Support/FIRContext.h" 3420fba03fSSacha Ballantyne #include "flang/Optimizer/HLFIR/HLFIRDialect.h" 356e193b5cSMats Petersson #include "flang/Optimizer/Transforms/Passes.h" 36815a8465SDavid Green #include "flang/Optimizer/Transforms/Utils.h" 37aa94eb38SMats Petersson #include "flang/Runtime/entry-names.h" 3867d0d7acSMichele Scuttari #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 396e193b5cSMats Petersson #include "mlir/IR/Matchers.h" 40614cd721SSacha Ballantyne #include "mlir/IR/Operation.h" 416e193b5cSMats Petersson #include "mlir/Pass/Pass.h" 426e193b5cSMats Petersson #include "mlir/Transforms/DialectConversion.h" 436e193b5cSMats Petersson #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 446e193b5cSMats Petersson #include "mlir/Transforms/RegionUtils.h" 451d5e7a49SSlava Zakharin #include "llvm/Support/Debug.h" 4656eda98fSSlava Zakharin #include "llvm/Support/raw_ostream.h" 47614cd721SSacha Ballantyne #include <llvm/Support/ErrorHandling.h> 48614cd721SSacha Ballantyne #include <mlir/Dialect/Arith/IR/Arith.h> 49614cd721SSacha Ballantyne #include <mlir/IR/BuiltinTypes.h> 5020fba03fSSacha Ballantyne #include <mlir/IR/Location.h> 5120fba03fSSacha Ballantyne #include <mlir/IR/MLIRContext.h> 
5220fba03fSSacha Ballantyne #include <mlir/IR/Value.h> 5320fba03fSSacha Ballantyne #include <mlir/Support/LLVM.h> 544d4d4785SKazu Hirata #include <optional> 551d5e7a49SSlava Zakharin 5667d0d7acSMichele Scuttari namespace fir { 5767d0d7acSMichele Scuttari #define GEN_PASS_DEF_SIMPLIFYINTRINSICS 5867d0d7acSMichele Scuttari #include "flang/Optimizer/Transforms/Passes.h.inc" 5967d0d7acSMichele Scuttari } // namespace fir 6067d0d7acSMichele Scuttari 611d5e7a49SSlava Zakharin #define DEBUG_TYPE "flang-simplify-intrinsics" 626e193b5cSMats Petersson 636e193b5cSMats Petersson namespace { 646e193b5cSMats Petersson 656e193b5cSMats Petersson class SimplifyIntrinsicsPass 6667d0d7acSMichele Scuttari : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> { 6780dcc907SSlava Zakharin using FunctionTypeGeneratorTy = 6843159b58SMats Petersson llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>; 6980dcc907SSlava Zakharin using FunctionBodyGeneratorTy = 7043159b58SMats Petersson llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>; 7143159b58SMats Petersson using GenReductionBodyTy = llvm::function_ref<void( 7220fba03fSSacha Ballantyne fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank, 7320fba03fSSacha Ballantyne mlir::Type elementType)>; 7480dcc907SSlava Zakharin 756e193b5cSMats Petersson public: 7681442f8dSTom Eccles using fir::impl::SimplifyIntrinsicsBase< 7781442f8dSTom Eccles SimplifyIntrinsicsPass>::SimplifyIntrinsicsBase; 7881442f8dSTom Eccles 7980dcc907SSlava Zakharin /// Generate a new function implementing a simplified version 8080dcc907SSlava Zakharin /// of a Fortran runtime function defined by \p basename name. 8180dcc907SSlava Zakharin /// \p typeGenerator is a callback that generates the new function's type. 8280dcc907SSlava Zakharin /// \p bodyGenerator is a callback that generates the new function's body. 8380dcc907SSlava Zakharin /// The new function is created in the \p builder's Module. 
8480dcc907SSlava Zakharin mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder, 8580dcc907SSlava Zakharin const mlir::StringRef &basename, 8680dcc907SSlava Zakharin FunctionTypeGeneratorTy typeGenerator, 8780dcc907SSlava Zakharin FunctionBodyGeneratorTy bodyGenerator); 886e193b5cSMats Petersson void runOnOperation() override; 891d5e7a49SSlava Zakharin void getDependentDialects(mlir::DialectRegistry ®istry) const override; 9043159b58SMats Petersson 9143159b58SMats Petersson private: 927d2e1987SSacha Ballantyne /// Helper functions to replace a reduction type of call with its 9343159b58SMats Petersson /// simplified form. The actual function is generated using a callback 9443159b58SMats Petersson /// function. 9543159b58SMats Petersson /// \p call is the call to be replaced 9643159b58SMats Petersson /// \p kindMap is used to create FIROpBuilder 9743159b58SMats Petersson /// \p genBodyFunc is the callback that builds the replacement function 987d2e1987SSacha Ballantyne void simplifyIntOrFloatReduction(fir::CallOp call, 997d2e1987SSacha Ballantyne const fir::KindMapping &kindMap, 10043159b58SMats Petersson GenReductionBodyTy genBodyFunc); 10120fba03fSSacha Ballantyne void simplifyLogicalDim0Reduction(fir::CallOp call, 10220fba03fSSacha Ballantyne const fir::KindMapping &kindMap, 10320fba03fSSacha Ballantyne GenReductionBodyTy genBodyFunc); 10420fba03fSSacha Ballantyne void simplifyLogicalDim1Reduction(fir::CallOp call, 1057d2e1987SSacha Ballantyne const fir::KindMapping &kindMap, 1067d2e1987SSacha Ballantyne GenReductionBodyTy genBodyFunc); 1079bb47f7fSDavid Green void simplifyMinMaxlocReduction(fir::CallOp call, 1089bb47f7fSDavid Green const fir::KindMapping &kindMap, bool isMax); 1097d2e1987SSacha Ballantyne void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap, 1107d2e1987SSacha Ballantyne GenReductionBodyTy genBodyFunc, 1117d2e1987SSacha Ballantyne fir::FirOpBuilder &builder, 11220fba03fSSacha Ballantyne const mlir::StringRef 
&basename, 11320fba03fSSacha Ballantyne mlir::Type elementType); 1146e193b5cSMats Petersson }; 1156e193b5cSMats Petersson 1166e193b5cSMats Petersson } // namespace 1176e193b5cSMats Petersson 118ffe1661fSSlava Zakharin /// Create FirOpBuilder with the provided \p op insertion point 119ffe1661fSSlava Zakharin /// and \p kindMap additionally inheriting FastMathFlags from \p op. 120ffe1661fSSlava Zakharin static fir::FirOpBuilder 121ffe1661fSSlava Zakharin getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) { 122ffe1661fSSlava Zakharin fir::FirOpBuilder builder{op, kindMap}; 123ffe1661fSSlava Zakharin auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op); 124ffe1661fSSlava Zakharin if (!fmi) 125ffe1661fSSlava Zakharin return builder; 126ffe1661fSSlava Zakharin 127ffe1661fSSlava Zakharin // Regardless of what default FastMathFlags are used by FirOpBuilder, 128ffe1661fSSlava Zakharin // override them with FastMathFlags attached to the operation. 129ffe1661fSSlava Zakharin builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue()); 130ffe1661fSSlava Zakharin return builder; 131ffe1661fSSlava Zakharin } 132ffe1661fSSlava Zakharin 133aa94eb38SMats Petersson /// Generate function type for the simplified version of RTNAME(Sum) and 134afa520abSMats Petersson /// similar functions with a fir.box<none> type returning \p elementType. 
135afa520abSMats Petersson static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder, 13680dcc907SSlava Zakharin const mlir::Type &elementType) { 13780dcc907SSlava Zakharin mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); 13880dcc907SSlava Zakharin return mlir::FunctionType::get(builder.getContext(), {boxType}, 13980dcc907SSlava Zakharin {elementType}); 14080dcc907SSlava Zakharin } 1416e193b5cSMats Petersson 142614cd721SSacha Ballantyne template <typename Op> 143614cd721SSacha Ballantyne Op expectOp(mlir::Value val) { 144614cd721SSacha Ballantyne if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp())) 145614cd721SSacha Ballantyne return op; 146614cd721SSacha Ballantyne LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName() 147614cd721SSacha Ballantyne << '\n'); 148614cd721SSacha Ballantyne return nullptr; 149614cd721SSacha Ballantyne } 150614cd721SSacha Ballantyne 151614cd721SSacha Ballantyne template <typename Op> 152614cd721SSacha Ballantyne static mlir::Value findDefSingle(fir::ConvertOp op) { 153614cd721SSacha Ballantyne if (auto defOp = expectOp<Op>(op->getOperand(0))) { 154614cd721SSacha Ballantyne return defOp.getResult(); 155614cd721SSacha Ballantyne } 156614cd721SSacha Ballantyne return {}; 157614cd721SSacha Ballantyne } 158614cd721SSacha Ballantyne 159614cd721SSacha Ballantyne template <typename... Ops> 160614cd721SSacha Ballantyne static mlir::Value findDef(fir::ConvertOp op) { 161614cd721SSacha Ballantyne mlir::Value defOp; 162614cd721SSacha Ballantyne // Loop over the operation types given to see if any match, exiting once 163614cd721SSacha Ballantyne // a match is found. 
Cast to void is needed to avoid compiler complaining 164614cd721SSacha Ballantyne // that the result of expression is unused 165614cd721SSacha Ballantyne (void)((defOp = findDefSingle<Ops>(op), (defOp)) || ...); 166614cd721SSacha Ballantyne return defOp; 167614cd721SSacha Ballantyne } 168614cd721SSacha Ballantyne 169614cd721SSacha Ballantyne static bool isOperandAbsent(mlir::Value val) { 170614cd721SSacha Ballantyne if (auto op = expectOp<fir::ConvertOp>(val)) { 171614cd721SSacha Ballantyne assert(op->getOperands().size() != 0); 172614cd721SSacha Ballantyne return mlir::isa_and_nonnull<fir::AbsentOp>( 173614cd721SSacha Ballantyne op->getOperand(0).getDefiningOp()); 174614cd721SSacha Ballantyne } 175614cd721SSacha Ballantyne return false; 176614cd721SSacha Ballantyne } 177614cd721SSacha Ballantyne 178614cd721SSacha Ballantyne static bool isTrueOrNotConstant(mlir::Value val) { 179614cd721SSacha Ballantyne if (auto op = expectOp<mlir::arith::ConstantOp>(val)) { 180614cd721SSacha Ballantyne return !mlir::matchPattern(val, mlir::m_Zero()); 181614cd721SSacha Ballantyne } 182614cd721SSacha Ballantyne return true; 183614cd721SSacha Ballantyne } 184614cd721SSacha Ballantyne 185614cd721SSacha Ballantyne static bool isZero(mlir::Value val) { 186614cd721SSacha Ballantyne if (auto op = expectOp<fir::ConvertOp>(val)) { 187614cd721SSacha Ballantyne assert(op->getOperands().size() != 0); 188614cd721SSacha Ballantyne if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp()) 189614cd721SSacha Ballantyne return mlir::matchPattern(defOp, mlir::m_Zero()); 190614cd721SSacha Ballantyne } 191614cd721SSacha Ballantyne return false; 192614cd721SSacha Ballantyne } 193614cd721SSacha Ballantyne 194614cd721SSacha Ballantyne static mlir::Value findBoxDef(mlir::Value val) { 195614cd721SSacha Ballantyne if (auto op = expectOp<fir::ConvertOp>(val)) { 196614cd721SSacha Ballantyne assert(op->getOperands().size() != 0); 197614cd721SSacha Ballantyne return findDef<fir::EmboxOp, fir::ReboxOp>(op); 
198614cd721SSacha Ballantyne } 199614cd721SSacha Ballantyne return {}; 200614cd721SSacha Ballantyne } 201614cd721SSacha Ballantyne 202614cd721SSacha Ballantyne static mlir::Value findMaskDef(mlir::Value val) { 203614cd721SSacha Ballantyne if (auto op = expectOp<fir::ConvertOp>(val)) { 204614cd721SSacha Ballantyne assert(op->getOperands().size() != 0); 205614cd721SSacha Ballantyne return findDef<fir::EmboxOp, fir::ReboxOp, fir::AbsentOp>(op); 206614cd721SSacha Ballantyne } 207614cd721SSacha Ballantyne return {}; 208614cd721SSacha Ballantyne } 209614cd721SSacha Ballantyne 210614cd721SSacha Ballantyne static unsigned getDimCount(mlir::Value val) { 211614cd721SSacha Ballantyne // In order to find the dimensions count, we look for EmboxOp/ReboxOp 212614cd721SSacha Ballantyne // and take the count from its *result* type. Note that in case 213614cd721SSacha Ballantyne // of sliced emboxing the operand and the result of EmboxOp/ReboxOp 214614cd721SSacha Ballantyne // have different types. 215614cd721SSacha Ballantyne // Actually, we can take the box type from the operand of 216614cd721SSacha Ballantyne // the first ConvertOp that has non-opaque box type that we meet 217614cd721SSacha Ballantyne // going through the ConvertOp chain. 218614cd721SSacha Ballantyne if (mlir::Value emboxVal = findBoxDef(val)) 219fac349a1SChristian Sigg if (auto boxTy = mlir::dyn_cast<fir::BoxType>(emboxVal.getType())) 220fac349a1SChristian Sigg if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy())) 221614cd721SSacha Ballantyne return seqTy.getDimension(); 222614cd721SSacha Ballantyne return 0; 223614cd721SSacha Ballantyne } 224614cd721SSacha Ballantyne 225614cd721SSacha Ballantyne /// Given the call operation's box argument \p val, discover 226614cd721SSacha Ballantyne /// the element type of the underlying array object. 227614cd721SSacha Ballantyne /// \returns the element type or std::nullopt if the type cannot 228614cd721SSacha Ballantyne /// be reliably found. 
229614cd721SSacha Ballantyne /// We expect that the argument is a result of fir.convert 230614cd721SSacha Ballantyne /// with the destination type of !fir.box<none>. 231614cd721SSacha Ballantyne static std::optional<mlir::Type> getArgElementType(mlir::Value val) { 232614cd721SSacha Ballantyne mlir::Operation *defOp; 233614cd721SSacha Ballantyne do { 234614cd721SSacha Ballantyne defOp = val.getDefiningOp(); 235614cd721SSacha Ballantyne // Analyze only sequences of convert operations. 236614cd721SSacha Ballantyne if (!mlir::isa<fir::ConvertOp>(defOp)) 237614cd721SSacha Ballantyne return std::nullopt; 238614cd721SSacha Ballantyne val = defOp->getOperand(0); 239614cd721SSacha Ballantyne // The convert operation is expected to convert from one 240614cd721SSacha Ballantyne // box type to another box type. 241fac349a1SChristian Sigg auto boxType = mlir::cast<fir::BoxType>(val.getType()); 242614cd721SSacha Ballantyne auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType); 243fac349a1SChristian Sigg if (!mlir::isa<mlir::NoneType>(elementType)) 244614cd721SSacha Ballantyne return elementType; 245614cd721SSacha Ballantyne } while (true); 246614cd721SSacha Ballantyne } 247614cd721SSacha Ballantyne 24843159b58SMats Petersson using BodyOpGeneratorTy = llvm::function_ref<mlir::Value( 24943159b58SMats Petersson fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value, 25043159b58SMats Petersson mlir::Value)>; 25120fba03fSSacha Ballantyne using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>( 25220fba03fSSacha Ballantyne fir::FirOpBuilder &, mlir::Location, mlir::Value)>; 253afa520abSMats Petersson 254afa520abSMats Petersson /// Generate the reduction loop into \p funcOp. 
255afa520abSMats Petersson /// 256afa520abSMats Petersson /// \p initVal is a function, called to get the initial value for 257afa520abSMats Petersson /// the reduction value 258afa520abSMats Petersson /// \p genBody is called to fill in the actual reduciton operation 259afa520abSMats Petersson /// for example add for SUM, MAX for MAXVAL, etc. 2608bd76ac1SSlava Zakharin /// \p rank is the rank of the input argument. 26120fba03fSSacha Ballantyne /// \p elementType is the type of the elements in the input array, 26220fba03fSSacha Ballantyne /// which may be different to the return type. 26320fba03fSSacha Ballantyne /// \p loopCond is called to generate the condition to continue or 26420fba03fSSacha Ballantyne /// not for IterWhile loops 26520fba03fSSacha Ballantyne /// \p unorderedOrInitalLoopCond contains either a boolean or bool 26620fba03fSSacha Ballantyne /// mlir constant, and controls the inital value for while loops 26720fba03fSSacha Ballantyne /// or if DoLoop is ordered/unordered. 
26820fba03fSSacha Ballantyne 26920fba03fSSacha Ballantyne template <typename OP, typename T, int resultIndex> 27020fba03fSSacha Ballantyne static void 27120fba03fSSacha Ballantyne genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, 272223d3dabSDavid Green fir::InitValGeneratorTy initVal, ContinueLoopGenTy loopCond, 27320fba03fSSacha Ballantyne T unorderedOrInitialLoopCond, BodyOpGeneratorTy genBody, 27420fba03fSSacha Ballantyne unsigned rank, mlir::Type elementType, mlir::Location loc) { 2756e193b5cSMats Petersson 2766e193b5cSMats Petersson mlir::IndexType idxTy = builder.getIndexType(); 2776e193b5cSMats Petersson 27880dcc907SSlava Zakharin mlir::Block::BlockArgListType args = funcOp.front().getArguments(); 2796e193b5cSMats Petersson mlir::Value arg = args[0]; 2806e193b5cSMats Petersson 2816e193b5cSMats Petersson mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 2826e193b5cSMats Petersson 2838bd76ac1SSlava Zakharin fir::SequenceType::Shape flatShape(rank, 2848bd76ac1SSlava Zakharin fir::SequenceType::getUnknownExtent()); 28580dcc907SSlava Zakharin mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType); 2866e193b5cSMats Petersson mlir::Type boxArrTy = fir::BoxType::get(arrTy); 2876e193b5cSMats Petersson mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg); 2887d2e1987SSacha Ballantyne mlir::Type resultType = funcOp.getResultTypes()[0]; 2897d2e1987SSacha Ballantyne mlir::Value init = initVal(builder, loc, resultType); 2906e193b5cSMats Petersson 291614cd721SSacha Ballantyne llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds; 2928bd76ac1SSlava Zakharin 2938bd76ac1SSlava Zakharin assert(rank > 0 && "rank cannot be zero"); 2948bd76ac1SSlava Zakharin mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 2958bd76ac1SSlava Zakharin 2968bd76ac1SSlava Zakharin // Compute all the upper bounds before the loop nest. 
2978bd76ac1SSlava Zakharin // It is not strictly necessary for performance, since the loop nest 2988bd76ac1SSlava Zakharin // does not have any store operations and any LICM optimization 2998bd76ac1SSlava Zakharin // should be able to optimize the redundancy. 3008bd76ac1SSlava Zakharin for (unsigned i = 0; i < rank; ++i) { 3018bd76ac1SSlava Zakharin mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); 3028bd76ac1SSlava Zakharin auto dims = 3038bd76ac1SSlava Zakharin builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx); 3048bd76ac1SSlava Zakharin mlir::Value len = dims.getResult(1); 3056e193b5cSMats Petersson // We use C indexing here, so len-1 as loopcount 3066e193b5cSMats Petersson mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 3078bd76ac1SSlava Zakharin bounds.push_back(loopCount); 3088bd76ac1SSlava Zakharin } 30920fba03fSSacha Ballantyne // Create a loop nest consisting of OP operations. 3108bd76ac1SSlava Zakharin // Collect the loops' induction variables into indices array, 3118bd76ac1SSlava Zakharin // which will be used in the innermost loop to load the input 3128bd76ac1SSlava Zakharin // array's element. 3138bd76ac1SSlava Zakharin // The loops are generated such that the innermost loop processes 3148bd76ac1SSlava Zakharin // the 0 dimension. 
315614cd721SSacha Ballantyne llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices; 3168bd76ac1SSlava Zakharin for (unsigned i = rank; 0 < i; --i) { 3178bd76ac1SSlava Zakharin mlir::Value step = one; 3188bd76ac1SSlava Zakharin mlir::Value loopCount = bounds[i - 1]; 31920fba03fSSacha Ballantyne auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step, 32020fba03fSSacha Ballantyne unorderedOrInitialLoopCond, 321afa520abSMats Petersson /*finalCountValue=*/false, init); 32298ecc3acSSacha Ballantyne init = loop.getRegionIterArgs()[resultIndex]; 3238bd76ac1SSlava Zakharin indices.push_back(loop.getInductionVar()); 3248bd76ac1SSlava Zakharin // Set insertion point to the loop body so that the next loop 3258bd76ac1SSlava Zakharin // is inserted inside the current one. 3266e193b5cSMats Petersson builder.setInsertionPointToStart(loop.getBody()); 3278bd76ac1SSlava Zakharin } 3286e193b5cSMats Petersson 3298bd76ac1SSlava Zakharin // Reverse the indices such that they are ordered as: 3308bd76ac1SSlava Zakharin // <dim-0-idx, dim-1-idx, ...> 3318bd76ac1SSlava Zakharin std::reverse(indices.begin(), indices.end()); 3328bd76ac1SSlava Zakharin // We are in the innermost loop: generate the reduction body. 33380dcc907SSlava Zakharin mlir::Type eleRefTy = builder.getRefType(elementType); 3346e193b5cSMats Petersson mlir::Value addr = 3358bd76ac1SSlava Zakharin builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices); 3366e193b5cSMats Petersson mlir::Value elem = builder.create<fir::LoadOp>(loc, addr); 3378bd76ac1SSlava Zakharin mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init); 33820fba03fSSacha Ballantyne // Generate vector with condition to continue while loop at [0] and result 33920fba03fSSacha Ballantyne // from current loop at [1] for IterWhileOp loops, just result at [0] for 34020fba03fSSacha Ballantyne // DoLoopOp loops. 
34120fba03fSSacha Ballantyne llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal); 3426e193b5cSMats Petersson 3438bd76ac1SSlava Zakharin // Unwind the loop nest and insert ResultOp on each level 3448bd76ac1SSlava Zakharin // to return the updated value of the reduction to the enclosing 3458bd76ac1SSlava Zakharin // loops. 3468bd76ac1SSlava Zakharin for (unsigned i = 0; i < rank; ++i) { 34720fba03fSSacha Ballantyne auto result = builder.create<fir::ResultOp>(loc, results); 3488bd76ac1SSlava Zakharin // Proceed to the outer loop. 34920fba03fSSacha Ballantyne auto loop = mlir::cast<OP>(result->getParentOp()); 35020fba03fSSacha Ballantyne results = loop.getResults(); 3518bd76ac1SSlava Zakharin // Set insertion point after the loop operation that we have 3528bd76ac1SSlava Zakharin // just processed. 3538bd76ac1SSlava Zakharin builder.setInsertionPointAfter(loop.getOperation()); 3548bd76ac1SSlava Zakharin } 3558bd76ac1SSlava Zakharin // End of loop nest. The insertion point is after the outermost loop. 3568bd76ac1SSlava Zakharin // Return the reduction value from the function. 35720fba03fSSacha Ballantyne builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]); 35820fba03fSSacha Ballantyne } 359614cd721SSacha Ballantyne 36020fba03fSSacha Ballantyne static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder, 361614cd721SSacha Ballantyne mlir::Location loc, 36220fba03fSSacha Ballantyne mlir::Value reductionVal) { 36320fba03fSSacha Ballantyne return {reductionVal}; 36480dcc907SSlava Zakharin } 36580dcc907SSlava Zakharin 366aa94eb38SMats Petersson /// Generate function body of the simplified version of RTNAME(Sum) 367afa520abSMats Petersson /// with signature provided by \p funcOp. The caller is responsible 368afa520abSMats Petersson /// for saving/restoring the original insertion point of \p builder. 369afa520abSMats Petersson /// \p funcOp is expected to be empty on entry to this function. 
3708bd76ac1SSlava Zakharin /// \p rank specifies the rank of the input argument. 371aa94eb38SMats Petersson static void genRuntimeSumBody(fir::FirOpBuilder &builder, 37220fba03fSSacha Ballantyne mlir::func::FuncOp &funcOp, unsigned rank, 37320fba03fSSacha Ballantyne mlir::Type elementType) { 3748bd76ac1SSlava Zakharin // function RTNAME(Sum)<T>x<rank>_simplified(arr) 375afa520abSMats Petersson // T, dimension(:) :: arr 376afa520abSMats Petersson // T sum = 0 377afa520abSMats Petersson // integer iter 378afa520abSMats Petersson // do iter = 0, extent(arr) 379afa520abSMats Petersson // sum = sum + arr[iter] 380afa520abSMats Petersson // end do 3818bd76ac1SSlava Zakharin // RTNAME(Sum)<T>x<rank>_simplified = sum 3828bd76ac1SSlava Zakharin // end function RTNAME(Sum)<T>x<rank>_simplified 383afa520abSMats Petersson auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, 384afa520abSMats Petersson mlir::Type elementType) { 385fac349a1SChristian Sigg if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) { 3862b138567SSlava Zakharin const llvm::fltSemantics &sem = ty.getFloatSemantics(); 3872b138567SSlava Zakharin return builder.createRealConstant(loc, elementType, 3882b138567SSlava Zakharin llvm::APFloat::getZero(sem)); 3892b138567SSlava Zakharin } 3902b138567SSlava Zakharin return builder.createIntegerConstant(loc, elementType, 0); 391afa520abSMats Petersson }; 392afa520abSMats Petersson 393afa520abSMats Petersson auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 394afa520abSMats Petersson mlir::Type elementType, mlir::Value elem1, 395afa520abSMats Petersson mlir::Value elem2) -> mlir::Value { 396fac349a1SChristian Sigg if (mlir::isa<mlir::FloatType>(elementType)) 397afa520abSMats Petersson return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2); 398fac349a1SChristian Sigg if (mlir::isa<mlir::IntegerType>(elementType)) 399afa520abSMats Petersson return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2); 400afa520abSMats Petersson 
401afa520abSMats Petersson llvm_unreachable("unsupported type"); 402afa520abSMats Petersson return {}; 403afa520abSMats Petersson }; 404afa520abSMats Petersson 40520fba03fSSacha Ballantyne mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 40620fba03fSSacha Ballantyne builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 4077d2e1987SSacha Ballantyne 40820fba03fSSacha Ballantyne genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond, 40920fba03fSSacha Ballantyne false, genBodyOp, rank, elementType, 41020fba03fSSacha Ballantyne loc); 411afa520abSMats Petersson } 412afa520abSMats Petersson 413aa94eb38SMats Petersson static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder, 41420fba03fSSacha Ballantyne mlir::func::FuncOp &funcOp, unsigned rank, 41520fba03fSSacha Ballantyne mlir::Type elementType) { 416afa520abSMats Petersson auto init = [](fir::FirOpBuilder builder, mlir::Location loc, 417afa520abSMats Petersson mlir::Type elementType) { 418fac349a1SChristian Sigg if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) { 419afa520abSMats Petersson const llvm::fltSemantics &sem = ty.getFloatSemantics(); 420afa520abSMats Petersson return builder.createRealConstant( 421afa520abSMats Petersson loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true)); 422afa520abSMats Petersson } 423afa520abSMats Petersson unsigned bits = elementType.getIntOrFloatBitWidth(); 424afa520abSMats Petersson int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 425afa520abSMats Petersson return builder.createIntegerConstant(loc, elementType, minInt); 426afa520abSMats Petersson }; 427afa520abSMats Petersson 428afa520abSMats Petersson auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 429afa520abSMats Petersson mlir::Type elementType, mlir::Value elem1, 430afa520abSMats Petersson mlir::Value elem2) -> mlir::Value { 431fac349a1SChristian Sigg if (mlir::isa<mlir::FloatType>(elementType)) { 43289b98c13SSlava 
Zakharin // arith.maxf later converted to llvm.intr.maxnum does not work 43389b98c13SSlava Zakharin // correctly for NaNs and -0.0 (see maxnum/minnum pattern matching 43489b98c13SSlava Zakharin // in LLVM's InstCombine pass). Moreover, llvm.intr.maxnum 43589b98c13SSlava Zakharin // for F128 operands is lowered into fmaxl call by LLVM. 43689b98c13SSlava Zakharin // This libm function may not work properly for F128 arguments 43789b98c13SSlava Zakharin // on targets where long double is not F128. It is an LLVM issue, 43889b98c13SSlava Zakharin // but we just use normal select here to resolve all the cases. 43989b98c13SSlava Zakharin auto compare = builder.create<mlir::arith::CmpFOp>( 44089b98c13SSlava Zakharin loc, mlir::arith::CmpFPredicate::OGT, elem1, elem2); 44189b98c13SSlava Zakharin return builder.create<mlir::arith::SelectOp>(loc, compare, elem1, elem2); 44289b98c13SSlava Zakharin } 443fac349a1SChristian Sigg if (mlir::isa<mlir::IntegerType>(elementType)) 444afa520abSMats Petersson return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2); 445afa520abSMats Petersson 446afa520abSMats Petersson llvm_unreachable("unsupported type"); 447afa520abSMats Petersson return {}; 448afa520abSMats Petersson }; 4497d2e1987SSacha Ballantyne 45020fba03fSSacha Ballantyne mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 45120fba03fSSacha Ballantyne builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 4527d2e1987SSacha Ballantyne 45320fba03fSSacha Ballantyne genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, init, nopLoopCond, 45420fba03fSSacha Ballantyne false, genBodyOp, rank, elementType, 45520fba03fSSacha Ballantyne loc); 4567d2e1987SSacha Ballantyne } 4577d2e1987SSacha Ballantyne 4587d2e1987SSacha Ballantyne static void genRuntimeCountBody(fir::FirOpBuilder &builder, 45920fba03fSSacha Ballantyne mlir::func::FuncOp &funcOp, unsigned rank, 46020fba03fSSacha Ballantyne mlir::Type elementType) { 4617d2e1987SSacha Ballantyne auto zero = 
[](fir::FirOpBuilder builder, mlir::Location loc, 4627d2e1987SSacha Ballantyne mlir::Type elementType) { 4637d2e1987SSacha Ballantyne unsigned bits = elementType.getIntOrFloatBitWidth(); 4647d2e1987SSacha Ballantyne int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 4657d2e1987SSacha Ballantyne return builder.createIntegerConstant(loc, elementType, zeroInt); 4667d2e1987SSacha Ballantyne }; 4677d2e1987SSacha Ballantyne 4687d2e1987SSacha Ballantyne auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 4697d2e1987SSacha Ballantyne mlir::Type elementType, mlir::Value elem1, 4707d2e1987SSacha Ballantyne mlir::Value elem2) -> mlir::Value { 47179dccdedSSacha Ballantyne auto zero32 = builder.createIntegerConstant(loc, elementType, 0); 4727d2e1987SSacha Ballantyne auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0); 4737d2e1987SSacha Ballantyne auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1); 4747d2e1987SSacha Ballantyne 4757d2e1987SSacha Ballantyne auto compare = builder.create<mlir::arith::CmpIOp>( 4767d2e1987SSacha Ballantyne loc, mlir::arith::CmpIPredicate::eq, elem1, zero32); 4777d2e1987SSacha Ballantyne auto select = 4787d2e1987SSacha Ballantyne builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64); 4797d2e1987SSacha Ballantyne return builder.create<mlir::arith::AddIOp>(loc, select, elem2); 4807d2e1987SSacha Ballantyne }; 4817d2e1987SSacha Ballantyne 48220fba03fSSacha Ballantyne // Count always gets I32 for elementType as it converts logical input to 48320fba03fSSacha Ballantyne // logical<4> before passing to the function. 
48420fba03fSSacha Ballantyne mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 48520fba03fSSacha Ballantyne builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 4867d2e1987SSacha Ballantyne 48720fba03fSSacha Ballantyne genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond, 48820fba03fSSacha Ballantyne false, genBodyOp, rank, elementType, 48920fba03fSSacha Ballantyne loc); 49020fba03fSSacha Ballantyne } 49120fba03fSSacha Ballantyne 49220fba03fSSacha Ballantyne static void genRuntimeAnyBody(fir::FirOpBuilder &builder, 49320fba03fSSacha Ballantyne mlir::func::FuncOp &funcOp, unsigned rank, 49420fba03fSSacha Ballantyne mlir::Type elementType) { 49520fba03fSSacha Ballantyne auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, 49620fba03fSSacha Ballantyne mlir::Type elementType) { 49720fba03fSSacha Ballantyne return builder.createIntegerConstant(loc, elementType, 0); 49820fba03fSSacha Ballantyne }; 49920fba03fSSacha Ballantyne 50020fba03fSSacha Ballantyne auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 50120fba03fSSacha Ballantyne mlir::Type elementType, mlir::Value elem1, 50220fba03fSSacha Ballantyne mlir::Value elem2) -> mlir::Value { 50320fba03fSSacha Ballantyne auto zero = builder.createIntegerConstant(loc, elementType, 0); 50420fba03fSSacha Ballantyne return builder.create<mlir::arith::CmpIOp>( 50520fba03fSSacha Ballantyne loc, mlir::arith::CmpIPredicate::ne, elem1, zero); 50620fba03fSSacha Ballantyne }; 50720fba03fSSacha Ballantyne 50820fba03fSSacha Ballantyne auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc, 50920fba03fSSacha Ballantyne mlir::Value reductionVal) { 51020fba03fSSacha Ballantyne auto one1 = builder.createIntegerConstant(loc, builder.getI1Type(), 1); 51120fba03fSSacha Ballantyne auto eor = builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1); 51220fba03fSSacha Ballantyne llvm::SmallVector<mlir::Value> results = {eor, reductionVal}; 51320fba03fSSacha 
Ballantyne return results; 51420fba03fSSacha Ballantyne }; 51520fba03fSSacha Ballantyne 51620fba03fSSacha Ballantyne mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 51720fba03fSSacha Ballantyne builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 51820fba03fSSacha Ballantyne mlir::Value ok = builder.createBool(loc, true); 51920fba03fSSacha Ballantyne 52020fba03fSSacha Ballantyne genReductionLoop<fir::IterWhileOp, mlir::Value, 1>( 52120fba03fSSacha Ballantyne builder, funcOp, zero, continueCond, ok, genBodyOp, rank, elementType, 52220fba03fSSacha Ballantyne loc); 52320fba03fSSacha Ballantyne } 52420fba03fSSacha Ballantyne 52520fba03fSSacha Ballantyne static void genRuntimeAllBody(fir::FirOpBuilder &builder, 52620fba03fSSacha Ballantyne mlir::func::FuncOp &funcOp, unsigned rank, 52720fba03fSSacha Ballantyne mlir::Type elementType) { 52820fba03fSSacha Ballantyne auto one = [](fir::FirOpBuilder builder, mlir::Location loc, 52920fba03fSSacha Ballantyne mlir::Type elementType) { 53020fba03fSSacha Ballantyne return builder.createIntegerConstant(loc, elementType, 1); 53120fba03fSSacha Ballantyne }; 53220fba03fSSacha Ballantyne 53320fba03fSSacha Ballantyne auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 53420fba03fSSacha Ballantyne mlir::Type elementType, mlir::Value elem1, 53520fba03fSSacha Ballantyne mlir::Value elem2) -> mlir::Value { 53620fba03fSSacha Ballantyne auto zero = builder.createIntegerConstant(loc, elementType, 0); 53720fba03fSSacha Ballantyne return builder.create<mlir::arith::CmpIOp>( 53820fba03fSSacha Ballantyne loc, mlir::arith::CmpIPredicate::ne, elem1, zero); 53920fba03fSSacha Ballantyne }; 54020fba03fSSacha Ballantyne 54120fba03fSSacha Ballantyne auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc, 54220fba03fSSacha Ballantyne mlir::Value reductionVal) { 54320fba03fSSacha Ballantyne llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal}; 54420fba03fSSacha Ballantyne return results; 
54520fba03fSSacha Ballantyne }; 54620fba03fSSacha Ballantyne 54720fba03fSSacha Ballantyne mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 54820fba03fSSacha Ballantyne builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 54920fba03fSSacha Ballantyne mlir::Value ok = builder.createBool(loc, true); 55020fba03fSSacha Ballantyne 55120fba03fSSacha Ballantyne genReductionLoop<fir::IterWhileOp, mlir::Value, 1>( 55220fba03fSSacha Ballantyne builder, funcOp, one, continueCond, ok, genBodyOp, rank, elementType, 55320fba03fSSacha Ballantyne loc); 554afa520abSMats Petersson } 555afa520abSMats Petersson 556614cd721SSacha Ballantyne static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder, 557614cd721SSacha Ballantyne unsigned int rank) { 558614cd721SSacha Ballantyne mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); 559614cd721SSacha Ballantyne mlir::Type boxRefType = builder.getRefType(boxType); 560614cd721SSacha Ballantyne 561614cd721SSacha Ballantyne return mlir::FunctionType::get(builder.getContext(), 562614cd721SSacha Ballantyne {boxRefType, boxType, boxType}, {}); 563614cd721SSacha Ballantyne } 564614cd721SSacha Ballantyne 565815a8465SDavid Green // Produces a loop nest for a Minloc intrinsic. 
566815a8465SDavid Green void fir::genMinMaxlocReductionLoop( 567815a8465SDavid Green fir::FirOpBuilder &builder, mlir::Value array, 568815a8465SDavid Green fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody, 569815a8465SDavid Green fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType, 570815a8465SDavid Green mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr, 571815a8465SDavid Green bool maskMayBeLogicalScalar) { 572815a8465SDavid Green mlir::IndexType idxTy = builder.getIndexType(); 573815a8465SDavid Green 574815a8465SDavid Green mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 575815a8465SDavid Green 576815a8465SDavid Green fir::SequenceType::Shape flatShape(rank, 577815a8465SDavid Green fir::SequenceType::getUnknownExtent()); 578815a8465SDavid Green mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType); 579815a8465SDavid Green mlir::Type boxArrTy = fir::BoxType::get(arrTy); 580815a8465SDavid Green array = builder.create<fir::ConvertOp>(loc, boxArrTy, array); 581815a8465SDavid Green 582815a8465SDavid Green mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType()); 583815a8465SDavid Green mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1); 584815a8465SDavid Green mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0); 585815a8465SDavid Green mlir::Value flagRef = builder.createTemporary(loc, resultElemType); 586815a8465SDavid Green builder.create<fir::StoreOp>(loc, zero, flagRef); 587815a8465SDavid Green 588815a8465SDavid Green mlir::Value init = initVal(builder, loc, elementType); 589815a8465SDavid Green llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds; 590815a8465SDavid Green 591815a8465SDavid Green assert(rank > 0 && "rank cannot be zero"); 592815a8465SDavid Green mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 593815a8465SDavid Green 594815a8465SDavid Green // Compute all the upper 
bounds before the loop nest. 595815a8465SDavid Green // It is not strictly necessary for performance, since the loop nest 596815a8465SDavid Green // does not have any store operations and any LICM optimization 597815a8465SDavid Green // should be able to optimize the redundancy. 598815a8465SDavid Green for (unsigned i = 0; i < rank; ++i) { 599815a8465SDavid Green mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); 600815a8465SDavid Green auto dims = 601815a8465SDavid Green builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx); 602815a8465SDavid Green mlir::Value len = dims.getResult(1); 603815a8465SDavid Green // We use C indexing here, so len-1 as loopcount 604815a8465SDavid Green mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 605815a8465SDavid Green bounds.push_back(loopCount); 606815a8465SDavid Green } 607815a8465SDavid Green // Create a loop nest consisting of OP operations. 608815a8465SDavid Green // Collect the loops' induction variables into indices array, 609815a8465SDavid Green // which will be used in the innermost loop to load the input 610815a8465SDavid Green // array's element. 611815a8465SDavid Green // The loops are generated such that the innermost loop processes 612815a8465SDavid Green // the 0 dimension. 
613815a8465SDavid Green llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices; 614815a8465SDavid Green for (unsigned i = rank; 0 < i; --i) { 615815a8465SDavid Green mlir::Value step = one; 616815a8465SDavid Green mlir::Value loopCount = bounds[i - 1]; 617815a8465SDavid Green auto loop = 618815a8465SDavid Green builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false, 619815a8465SDavid Green /*finalCountValue=*/false, init); 620815a8465SDavid Green init = loop.getRegionIterArgs()[0]; 621815a8465SDavid Green indices.push_back(loop.getInductionVar()); 622815a8465SDavid Green // Set insertion point to the loop body so that the next loop 623815a8465SDavid Green // is inserted inside the current one. 624815a8465SDavid Green builder.setInsertionPointToStart(loop.getBody()); 625815a8465SDavid Green } 626815a8465SDavid Green 627815a8465SDavid Green // Reverse the indices such that they are ordered as: 628815a8465SDavid Green // <dim-0-idx, dim-1-idx, ...> 629815a8465SDavid Green std::reverse(indices.begin(), indices.end()); 630815a8465SDavid Green mlir::Value reductionVal = 631815a8465SDavid Green genBody(builder, loc, elementType, array, flagRef, init, indices); 632815a8465SDavid Green 633815a8465SDavid Green // Unwind the loop nest and insert ResultOp on each level 634815a8465SDavid Green // to return the updated value of the reduction to the enclosing 635815a8465SDavid Green // loops. 636815a8465SDavid Green for (unsigned i = 0; i < rank; ++i) { 637815a8465SDavid Green auto result = builder.create<fir::ResultOp>(loc, reductionVal); 638815a8465SDavid Green // Proceed to the outer loop. 639815a8465SDavid Green auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp()); 640815a8465SDavid Green reductionVal = loop.getResult(0); 641815a8465SDavid Green // Set insertion point after the loop operation that we have 642815a8465SDavid Green // just processed. 
643815a8465SDavid Green builder.setInsertionPointAfter(loop.getOperation()); 644815a8465SDavid Green } 645815a8465SDavid Green // End of loop nest. The insertion point is after the outermost loop. 646815a8465SDavid Green if (maskMayBeLogicalScalar) { 647815a8465SDavid Green if (fir::IfOp ifOp = 648815a8465SDavid Green mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) { 649815a8465SDavid Green builder.create<fir::ResultOp>(loc, reductionVal); 650815a8465SDavid Green builder.setInsertionPointAfter(ifOp); 651815a8465SDavid Green // Redefine flagSet to escape scope of ifOp 652815a8465SDavid Green flagSet = builder.createIntegerConstant(loc, resultElemType, 1); 653815a8465SDavid Green reductionVal = ifOp.getResult(0); 654815a8465SDavid Green } 655815a8465SDavid Green } 656815a8465SDavid Green } 657815a8465SDavid Green 6589bb47f7fSDavid Green static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder, 6599bb47f7fSDavid Green mlir::func::FuncOp &funcOp, bool isMax, 6609bb47f7fSDavid Green unsigned rank, int maskRank, 6619bb47f7fSDavid Green mlir::Type elementType, 662614cd721SSacha Ballantyne mlir::Type maskElemType, 6632a95fe48SDavid Green mlir::Type resultElemTy, bool isDim) { 6649bb47f7fSDavid Green auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc, 665614cd721SSacha Ballantyne mlir::Type elementType) { 666fac349a1SChristian Sigg if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) { 667614cd721SSacha Ballantyne const llvm::fltSemantics &sem = ty.getFloatSemantics(); 66872428962SDavid Green llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax); 66972428962SDavid Green return builder.createRealConstant(loc, elementType, limit); 670614cd721SSacha Ballantyne } 671614cd721SSacha Ballantyne unsigned bits = elementType.getIntOrFloatBitWidth(); 6729bb47f7fSDavid Green int64_t initValue = (isMax ? 
llvm::APInt::getSignedMinValue(bits) 6739bb47f7fSDavid Green : llvm::APInt::getSignedMaxValue(bits)) 6749bb47f7fSDavid Green .getSExtValue(); 6759bb47f7fSDavid Green return builder.createIntegerConstant(loc, elementType, initValue); 676614cd721SSacha Ballantyne }; 677614cd721SSacha Ballantyne 678614cd721SSacha Ballantyne mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 679614cd721SSacha Ballantyne builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 680614cd721SSacha Ballantyne 681614cd721SSacha Ballantyne mlir::Value mask = funcOp.front().getArgument(2); 682614cd721SSacha Ballantyne 683614cd721SSacha Ballantyne // Set up result array in case of early exit / 0 length array 684614cd721SSacha Ballantyne mlir::IndexType idxTy = builder.getIndexType(); 685614cd721SSacha Ballantyne mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy); 686614cd721SSacha Ballantyne mlir::Type resultHeapTy = fir::HeapType::get(resultTy); 687614cd721SSacha Ballantyne mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy); 688614cd721SSacha Ballantyne 689614cd721SSacha Ballantyne mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0); 690614cd721SSacha Ballantyne mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank); 691614cd721SSacha Ballantyne 692614cd721SSacha Ballantyne mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy); 693614cd721SSacha Ballantyne mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize); 694614cd721SSacha Ballantyne mlir::Value resultArr = builder.create<fir::EmboxOp>( 695614cd721SSacha Ballantyne loc, resultBoxTy, resultArrInit, resultArrShape); 696614cd721SSacha Ballantyne 697614cd721SSacha Ballantyne mlir::Type resultRefTy = builder.getRefType(resultElemTy); 698614cd721SSacha Ballantyne 699223d3dabSDavid Green if (maskRank > 0) { 700223d3dabSDavid Green fir::SequenceType::Shape flatShape(rank, 701223d3dabSDavid Green 
fir::SequenceType::getUnknownExtent()); 702223d3dabSDavid Green mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType); 703223d3dabSDavid Green mlir::Type boxMaskTy = fir::BoxType::get(maskTy); 704223d3dabSDavid Green mask = builder.create<fir::ConvertOp>(loc, boxMaskTy, mask); 705223d3dabSDavid Green } 706223d3dabSDavid Green 707614cd721SSacha Ballantyne for (unsigned int i = 0; i < rank; ++i) { 708614cd721SSacha Ballantyne mlir::Value index = builder.createIntegerConstant(loc, idxTy, i); 709614cd721SSacha Ballantyne mlir::Value resultElemAddr = 710614cd721SSacha Ballantyne builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index); 711614cd721SSacha Ballantyne builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr); 712614cd721SSacha Ballantyne } 713614cd721SSacha Ballantyne 714614cd721SSacha Ballantyne auto genBodyOp = 715223d3dabSDavid Green [&rank, &resultArr, isMax, &mask, &maskElemType, &maskRank]( 716223d3dabSDavid Green fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType, 717223d3dabSDavid Green mlir::Value array, mlir::Value flagRef, mlir::Value reduction, 718223d3dabSDavid Green const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value { 719223d3dabSDavid Green // We are in the innermost loop: generate the reduction body. 720223d3dabSDavid Green if (maskRank > 0) { 721223d3dabSDavid Green mlir::Type logicalRef = builder.getRefType(maskElemType); 722223d3dabSDavid Green mlir::Value maskAddr = 723223d3dabSDavid Green builder.create<fir::CoordinateOp>(loc, logicalRef, mask, indices); 724223d3dabSDavid Green mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr); 725223d3dabSDavid Green 726223d3dabSDavid Green // fir::IfOp requires argument to be I1 - won't accept logical or any 727223d3dabSDavid Green // other Integer. 
728223d3dabSDavid Green mlir::Type ifCompatType = builder.getI1Type(); 729223d3dabSDavid Green mlir::Value ifCompatElem = 730223d3dabSDavid Green builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem); 731223d3dabSDavid Green 732223d3dabSDavid Green llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType}; 733223d3dabSDavid Green fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, ifCompatElem, 734223d3dabSDavid Green /*withElseRegion=*/true); 735223d3dabSDavid Green builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 736223d3dabSDavid Green } 737223d3dabSDavid Green 738223d3dabSDavid Green // Set flag that mask was true at some point 739223d3dabSDavid Green mlir::Value flagSet = builder.createIntegerConstant( 740223d3dabSDavid Green loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1); 74172428962SDavid Green mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef); 742223d3dabSDavid Green mlir::Type eleRefTy = builder.getRefType(elementType); 743223d3dabSDavid Green mlir::Value addr = 744223d3dabSDavid Green builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices); 745223d3dabSDavid Green mlir::Value elem = builder.create<fir::LoadOp>(loc, addr); 746223d3dabSDavid Green 747614cd721SSacha Ballantyne mlir::Value cmp; 748fac349a1SChristian Sigg if (mlir::isa<mlir::FloatType>(elementType)) { 74972428962SDavid Green // For FP reductions we want the first smallest value to be used, that 75072428962SDavid Green // is not NaN. A OGL/OLT condition will usually work for this unless all 75172428962SDavid Green // the values are Nan or Inf. This follows the same logic as 75272428962SDavid Green // NumericCompare for Minloc/Maxlox in extrema.cpp. 753614cd721SSacha Ballantyne cmp = builder.create<mlir::arith::CmpFOp>( 7549bb47f7fSDavid Green loc, 7559bb47f7fSDavid Green isMax ? 
mlir::arith::CmpFPredicate::OGT 7569bb47f7fSDavid Green : mlir::arith::CmpFPredicate::OLT, 757223d3dabSDavid Green elem, reduction); 75872428962SDavid Green 75972428962SDavid Green mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>( 76072428962SDavid Green loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction); 76172428962SDavid Green mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>( 76272428962SDavid Green loc, mlir::arith::CmpFPredicate::OEQ, elem, elem); 76372428962SDavid Green cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2); 76472428962SDavid Green cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan); 765fac349a1SChristian Sigg } else if (mlir::isa<mlir::IntegerType>(elementType)) { 766614cd721SSacha Ballantyne cmp = builder.create<mlir::arith::CmpIOp>( 7679bb47f7fSDavid Green loc, 7689bb47f7fSDavid Green isMax ? mlir::arith::CmpIPredicate::sgt 7699bb47f7fSDavid Green : mlir::arith::CmpIPredicate::slt, 770223d3dabSDavid Green elem, reduction); 771614cd721SSacha Ballantyne } else { 772614cd721SSacha Ballantyne llvm_unreachable("unsupported type"); 773614cd721SSacha Ballantyne } 774614cd721SSacha Ballantyne 77572428962SDavid Green // The condition used for the loop is isFirst || <the condition above>. 
77672428962SDavid Green isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst); 77772428962SDavid Green isFirst = builder.create<mlir::arith::XOrIOp>( 77872428962SDavid Green loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1)); 77972428962SDavid Green cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst); 780614cd721SSacha Ballantyne fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp, 781614cd721SSacha Ballantyne /*withElseRegion*/ true); 782614cd721SSacha Ballantyne 783614cd721SSacha Ballantyne builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 78472428962SDavid Green builder.create<fir::StoreOp>(loc, flagSet, flagRef); 785614cd721SSacha Ballantyne mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType()); 786614cd721SSacha Ballantyne mlir::Type returnRefTy = builder.getRefType(resultElemTy); 787614cd721SSacha Ballantyne mlir::IndexType idxTy = builder.getIndexType(); 788614cd721SSacha Ballantyne 789614cd721SSacha Ballantyne mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1); 790614cd721SSacha Ballantyne 791614cd721SSacha Ballantyne for (unsigned int i = 0; i < rank; ++i) { 792614cd721SSacha Ballantyne mlir::Value index = builder.createIntegerConstant(loc, idxTy, i); 793614cd721SSacha Ballantyne mlir::Value resultElemAddr = 794614cd721SSacha Ballantyne builder.create<fir::CoordinateOp>(loc, returnRefTy, resultArr, index); 795614cd721SSacha Ballantyne mlir::Value convert = 796614cd721SSacha Ballantyne builder.create<fir::ConvertOp>(loc, resultElemTy, indices[i]); 797614cd721SSacha Ballantyne mlir::Value fortranIndex = 798614cd721SSacha Ballantyne builder.create<mlir::arith::AddIOp>(loc, convert, one); 799614cd721SSacha Ballantyne builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr); 800614cd721SSacha Ballantyne } 801223d3dabSDavid Green builder.create<fir::ResultOp>(loc, elem); 802614cd721SSacha Ballantyne 
builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 803223d3dabSDavid Green builder.create<fir::ResultOp>(loc, reduction); 804614cd721SSacha Ballantyne builder.setInsertionPointAfter(ifOp); 805223d3dabSDavid Green mlir::Value reductionVal = ifOp.getResult(0); 806223d3dabSDavid Green 807223d3dabSDavid Green // Close the mask if needed 808223d3dabSDavid Green if (maskRank > 0) { 809223d3dabSDavid Green fir::IfOp ifOp = 810223d3dabSDavid Green mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp()); 811223d3dabSDavid Green builder.create<fir::ResultOp>(loc, reductionVal); 812223d3dabSDavid Green builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 813223d3dabSDavid Green builder.create<fir::ResultOp>(loc, reduction); 814223d3dabSDavid Green reductionVal = ifOp.getResult(0); 815223d3dabSDavid Green builder.setInsertionPointAfter(ifOp); 816223d3dabSDavid Green } 817223d3dabSDavid Green 818223d3dabSDavid Green return reductionVal; 819614cd721SSacha Ballantyne }; 820614cd721SSacha Ballantyne 821614cd721SSacha Ballantyne // if mask is a logical scalar, we can check its value before the main loop 822614cd721SSacha Ballantyne // and either ignore the fact it is there or exit early. 
823614cd721SSacha Ballantyne if (maskRank == 0) { 824614cd721SSacha Ballantyne mlir::Type logical = builder.getI1Type(); 825614cd721SSacha Ballantyne mlir::IndexType idxTy = builder.getIndexType(); 826614cd721SSacha Ballantyne 827614cd721SSacha Ballantyne fir::SequenceType::Shape singleElement(1, 1); 828614cd721SSacha Ballantyne mlir::Type arrTy = fir::SequenceType::get(singleElement, logical); 829614cd721SSacha Ballantyne mlir::Type boxArrTy = fir::BoxType::get(arrTy); 830614cd721SSacha Ballantyne mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, mask); 831614cd721SSacha Ballantyne 832614cd721SSacha Ballantyne mlir::Value indx = builder.createIntegerConstant(loc, idxTy, 0); 833614cd721SSacha Ballantyne mlir::Type logicalRefTy = builder.getRefType(logical); 834614cd721SSacha Ballantyne mlir::Value condAddr = 835614cd721SSacha Ballantyne builder.create<fir::CoordinateOp>(loc, logicalRefTy, array, indx); 836614cd721SSacha Ballantyne mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr); 837614cd721SSacha Ballantyne 838614cd721SSacha Ballantyne fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cond, 839614cd721SSacha Ballantyne /*withElseRegion=*/true); 840614cd721SSacha Ballantyne 841614cd721SSacha Ballantyne builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 842614cd721SSacha Ballantyne mlir::Value basicValue; 843fac349a1SChristian Sigg if (mlir::isa<mlir::IntegerType>(elementType)) { 844614cd721SSacha Ballantyne basicValue = builder.createIntegerConstant(loc, elementType, 0); 845614cd721SSacha Ballantyne } else { 846614cd721SSacha Ballantyne basicValue = builder.createRealConstant(loc, elementType, 0); 847614cd721SSacha Ballantyne } 848614cd721SSacha Ballantyne builder.create<fir::ResultOp>(loc, basicValue); 849614cd721SSacha Ballantyne 850614cd721SSacha Ballantyne builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 851614cd721SSacha Ballantyne } 852223d3dabSDavid Green auto getAddrFn = 
[](fir::FirOpBuilder builder, mlir::Location loc, 853223d3dabSDavid Green const mlir::Type &resultElemType, mlir::Value resultArr, 854223d3dabSDavid Green mlir::Value index) { 855223d3dabSDavid Green mlir::Type resultRefTy = builder.getRefType(resultElemType); 856223d3dabSDavid Green return builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, 857223d3dabSDavid Green index); 858223d3dabSDavid Green }; 859614cd721SSacha Ballantyne 860223d3dabSDavid Green genMinMaxlocReductionLoop(builder, funcOp.front().getArgument(1), init, 861223d3dabSDavid Green genBodyOp, getAddrFn, rank, elementType, loc, 862223d3dabSDavid Green maskElemType, resultArr, maskRank == 0); 863223d3dabSDavid Green 864223d3dabSDavid Green // Store newly created output array to the reference passed in 8652a95fe48SDavid Green if (isDim) { 8662a95fe48SDavid Green mlir::Type resultBoxTy = 8672a95fe48SDavid Green fir::BoxType::get(fir::HeapType::get(resultElemTy)); 8682a95fe48SDavid Green mlir::Value outputArr = builder.create<fir::ConvertOp>( 8692a95fe48SDavid Green loc, builder.getRefType(resultBoxTy), funcOp.front().getArgument(0)); 8702a95fe48SDavid Green mlir::Value resultArrScalar = builder.create<fir::ConvertOp>( 8712a95fe48SDavid Green loc, fir::HeapType::get(resultElemTy), resultArrInit); 8722a95fe48SDavid Green mlir::Value resultBox = 8732a95fe48SDavid Green builder.create<fir::EmboxOp>(loc, resultBoxTy, resultArrScalar); 8742a95fe48SDavid Green builder.create<fir::StoreOp>(loc, resultBox, outputArr); 8752a95fe48SDavid Green } else { 876223d3dabSDavid Green fir::SequenceType::Shape resultShape(1, rank); 877223d3dabSDavid Green mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy); 878223d3dabSDavid Green mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy); 879223d3dabSDavid Green mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy); 880223d3dabSDavid Green mlir::Type outputRefTy = builder.getRefType(outputBoxTy); 881223d3dabSDavid Green mlir::Value 
outputArr = builder.create<fir::ConvertOp>( 882223d3dabSDavid Green loc, outputRefTy, funcOp.front().getArgument(0)); 883223d3dabSDavid Green builder.create<fir::StoreOp>(loc, resultArr, outputArr); 8842a95fe48SDavid Green } 8852a95fe48SDavid Green 886223d3dabSDavid Green builder.create<mlir::func::ReturnOp>(loc); 887614cd721SSacha Ballantyne } 888614cd721SSacha Ballantyne 889aa94eb38SMats Petersson /// Generate function type for the simplified version of RTNAME(DotProduct) 8901d5e7a49SSlava Zakharin /// operating on the given \p elementType. 891aa94eb38SMats Petersson static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder, 8921d5e7a49SSlava Zakharin const mlir::Type &elementType) { 8931d5e7a49SSlava Zakharin mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); 8941d5e7a49SSlava Zakharin return mlir::FunctionType::get(builder.getContext(), {boxType, boxType}, 8951d5e7a49SSlava Zakharin {elementType}); 8961d5e7a49SSlava Zakharin } 8971d5e7a49SSlava Zakharin 898aa94eb38SMats Petersson /// Generate function body of the simplified version of RTNAME(DotProduct) 8991d5e7a49SSlava Zakharin /// with signature provided by \p funcOp. The caller is responsible 9001d5e7a49SSlava Zakharin /// for saving/restoring the original insertion point of \p builder. 9011d5e7a49SSlava Zakharin /// \p funcOp is expected to be empty on entry to this function. 90256eda98fSSlava Zakharin /// \p arg1ElementTy and \p arg2ElementTy specify elements types 90356eda98fSSlava Zakharin /// of the underlying array objects - they are used to generate proper 90456eda98fSSlava Zakharin /// element accesses. 
905aa94eb38SMats Petersson static void genRuntimeDotBody(fir::FirOpBuilder &builder, 90656eda98fSSlava Zakharin mlir::func::FuncOp &funcOp, 90756eda98fSSlava Zakharin mlir::Type arg1ElementTy, 90856eda98fSSlava Zakharin mlir::Type arg2ElementTy) { 909aa94eb38SMats Petersson // function RTNAME(DotProduct)<T>_simplified(arr1, arr2) 9101d5e7a49SSlava Zakharin // T, dimension(:) :: arr1, arr2 9111d5e7a49SSlava Zakharin // T product = 0 9121d5e7a49SSlava Zakharin // integer iter 9131d5e7a49SSlava Zakharin // do iter = 0, extent(arr1) 9141d5e7a49SSlava Zakharin // product = product + arr1[iter] * arr2[iter] 9151d5e7a49SSlava Zakharin // end do 916aa94eb38SMats Petersson // RTNAME(ADotProduct)<T>_simplified = product 917aa94eb38SMats Petersson // end function RTNAME(DotProduct)<T>_simplified 9181d5e7a49SSlava Zakharin auto loc = mlir::UnknownLoc::get(builder.getContext()); 91956eda98fSSlava Zakharin mlir::Type resultElementType = funcOp.getResultTypes()[0]; 9201d5e7a49SSlava Zakharin builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 9211d5e7a49SSlava Zakharin 9221d5e7a49SSlava Zakharin mlir::IndexType idxTy = builder.getIndexType(); 9231d5e7a49SSlava Zakharin 92456eda98fSSlava Zakharin mlir::Value zero = 925fac349a1SChristian Sigg mlir::isa<mlir::FloatType>(resultElementType) 92656eda98fSSlava Zakharin ? 
builder.createRealConstant(loc, resultElementType, 0.0) 92756eda98fSSlava Zakharin : builder.createIntegerConstant(loc, resultElementType, 0); 9281d5e7a49SSlava Zakharin 9291d5e7a49SSlava Zakharin mlir::Block::BlockArgListType args = funcOp.front().getArguments(); 9301d5e7a49SSlava Zakharin mlir::Value arg1 = args[0]; 9311d5e7a49SSlava Zakharin mlir::Value arg2 = args[1]; 9321d5e7a49SSlava Zakharin 9331d5e7a49SSlava Zakharin mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 9341d5e7a49SSlava Zakharin 9351d5e7a49SSlava Zakharin fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()}; 93656eda98fSSlava Zakharin mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy); 93756eda98fSSlava Zakharin mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1); 93856eda98fSSlava Zakharin mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1); 93956eda98fSSlava Zakharin mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy); 94056eda98fSSlava Zakharin mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2); 94156eda98fSSlava Zakharin mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2); 9421d5e7a49SSlava Zakharin // This version takes the loop trip count from the first argument. 9431d5e7a49SSlava Zakharin // If the first argument's box has unknown (at compilation time) 9441d5e7a49SSlava Zakharin // extent, then it may be better to take the extent from the second 9451d5e7a49SSlava Zakharin // argument - so that after inlining the loop may be better optimized, e.g. 9461d5e7a49SSlava Zakharin // fully unrolled. This requires generating two versions of the simplified 9471d5e7a49SSlava Zakharin // function and some analysis at the call site to choose which version 9481d5e7a49SSlava Zakharin // is more profitable to call. 9491d5e7a49SSlava Zakharin // Note that we can assume that both arguments have the same extent. 
9501d5e7a49SSlava Zakharin auto dims = 9511d5e7a49SSlava Zakharin builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx); 9521d5e7a49SSlava Zakharin mlir::Value len = dims.getResult(1); 9531d5e7a49SSlava Zakharin mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 9541d5e7a49SSlava Zakharin mlir::Value step = one; 9551d5e7a49SSlava Zakharin 9561d5e7a49SSlava Zakharin // We use C indexing here, so len-1 as loopcount 9571d5e7a49SSlava Zakharin mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 9581d5e7a49SSlava Zakharin auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, 9591d5e7a49SSlava Zakharin /*unordered=*/false, 9601d5e7a49SSlava Zakharin /*finalCountValue=*/false, zero); 9611d5e7a49SSlava Zakharin mlir::Value sumVal = loop.getRegionIterArgs()[0]; 9621d5e7a49SSlava Zakharin 9631d5e7a49SSlava Zakharin // Begin loop code 9641d5e7a49SSlava Zakharin mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint(); 9651d5e7a49SSlava Zakharin builder.setInsertionPointToStart(loop.getBody()); 9661d5e7a49SSlava Zakharin 96756eda98fSSlava Zakharin mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy); 9681d5e7a49SSlava Zakharin mlir::Value index = loop.getInductionVar(); 9691d5e7a49SSlava Zakharin mlir::Value addr1 = 97056eda98fSSlava Zakharin builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index); 9711d5e7a49SSlava Zakharin mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1); 97256eda98fSSlava Zakharin // Convert to the result type. 
97356eda98fSSlava Zakharin elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1); 9741d5e7a49SSlava Zakharin 97556eda98fSSlava Zakharin mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy); 97656eda98fSSlava Zakharin mlir::Value addr2 = 97756eda98fSSlava Zakharin builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index); 97856eda98fSSlava Zakharin mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2); 97956eda98fSSlava Zakharin // Convert to the result type. 98056eda98fSSlava Zakharin elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2); 98156eda98fSSlava Zakharin 982fac349a1SChristian Sigg if (mlir::isa<mlir::FloatType>(resultElementType)) 9831d5e7a49SSlava Zakharin sumVal = builder.create<mlir::arith::AddFOp>( 9841d5e7a49SSlava Zakharin loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal); 985fac349a1SChristian Sigg else if (mlir::isa<mlir::IntegerType>(resultElementType)) 9861d5e7a49SSlava Zakharin sumVal = builder.create<mlir::arith::AddIOp>( 9871d5e7a49SSlava Zakharin loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal); 9881d5e7a49SSlava Zakharin else 9891d5e7a49SSlava Zakharin llvm_unreachable("unsupported type"); 9901d5e7a49SSlava Zakharin 9911d5e7a49SSlava Zakharin builder.create<fir::ResultOp>(loc, sumVal); 9921d5e7a49SSlava Zakharin // End of loop. 
9931d5e7a49SSlava Zakharin builder.restoreInsertionPoint(loopEndPt); 9941d5e7a49SSlava Zakharin 9951d5e7a49SSlava Zakharin mlir::Value resultVal = loop.getResult(0); 9961d5e7a49SSlava Zakharin builder.create<mlir::func::ReturnOp>(loc, resultVal); 9971d5e7a49SSlava Zakharin } 9981d5e7a49SSlava Zakharin 99980dcc907SSlava Zakharin mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction( 100080dcc907SSlava Zakharin fir::FirOpBuilder &builder, const mlir::StringRef &baseName, 100180dcc907SSlava Zakharin FunctionTypeGeneratorTy typeGenerator, 100280dcc907SSlava Zakharin FunctionBodyGeneratorTy bodyGenerator) { 100380dcc907SSlava Zakharin // WARNING: if the function generated here changes its signature 100480dcc907SSlava Zakharin // or behavior (the body code), we should probably embed some 100580dcc907SSlava Zakharin // versioning information into its name, otherwise libraries 100680dcc907SSlava Zakharin // statically linked with older versions of Flang may stop 100780dcc907SSlava Zakharin // working with object files created with newer Flang. 100880dcc907SSlava Zakharin // We can also avoid this by using internal linkage, but 100980dcc907SSlava Zakharin // this may increase the size of final executable/shared library. 101080dcc907SSlava Zakharin std::string replacementName = mlir::Twine{baseName, "_simplified"}.str(); 101180dcc907SSlava Zakharin // If we already have a function, just return it. 1012a4798bb0SjeanPerier mlir::func::FuncOp newFunc = builder.getNamedFunction(replacementName); 101380dcc907SSlava Zakharin mlir::FunctionType fType = typeGenerator(builder); 101480dcc907SSlava Zakharin if (newFunc) { 101580dcc907SSlava Zakharin assert(newFunc.getFunctionType() == fType && 101680dcc907SSlava Zakharin "type mismatch for simplified function"); 101780dcc907SSlava Zakharin return newFunc; 101880dcc907SSlava Zakharin } 101980dcc907SSlava Zakharin 102080dcc907SSlava Zakharin // Need to build the function! 
102180dcc907SSlava Zakharin auto loc = mlir::UnknownLoc::get(builder.getContext()); 1022a4798bb0SjeanPerier newFunc = builder.createFunction(loc, replacementName, fType); 102380dcc907SSlava Zakharin auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR; 102480dcc907SSlava Zakharin auto linkage = 102580dcc907SSlava Zakharin mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); 102680dcc907SSlava Zakharin newFunc->setAttr("llvm.linkage", linkage); 102780dcc907SSlava Zakharin 102880dcc907SSlava Zakharin // Save the position of the original call. 102980dcc907SSlava Zakharin mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint(); 103080dcc907SSlava Zakharin 103180dcc907SSlava Zakharin bodyGenerator(builder, newFunc); 10326e193b5cSMats Petersson 10336e193b5cSMats Petersson // Now back to where we were adding code earlier... 10346e193b5cSMats Petersson builder.restoreInsertionPoint(insertPt); 10356e193b5cSMats Petersson 10366e193b5cSMats Petersson return newFunc; 10376e193b5cSMats Petersson } 10386e193b5cSMats Petersson 10397d2e1987SSacha Ballantyne void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction( 10407d2e1987SSacha Ballantyne fir::CallOp call, const fir::KindMapping &kindMap, 104143159b58SMats Petersson GenReductionBodyTy genBodyFunc) { 104243159b58SMats Petersson // args[1] and args[2] are source filename and line number, ignored. 10437d2e1987SSacha Ballantyne mlir::Operation::operand_range args = call.getArgs(); 10447d2e1987SSacha Ballantyne 104543159b58SMats Petersson const mlir::Value &dim = args[3]; 104643159b58SMats Petersson const mlir::Value &mask = args[4]; 104743159b58SMats Petersson // dim is zero when it is absent, which is an implementation 104843159b58SMats Petersson // detail in the runtime library. 
10497d2e1987SSacha Ballantyne 105043159b58SMats Petersson bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask); 105143159b58SMats Petersson unsigned rank = getDimCount(args[0]); 10522b138567SSlava Zakharin 1053bb94d33aSSacha Ballantyne // Rank is set to 0 for assumed shape arrays, don't simplify 1054bb94d33aSSacha Ballantyne // in these cases 10557d2e1987SSacha Ballantyne if (!(dimAndMaskAbsent && rank > 0)) 10567d2e1987SSacha Ballantyne return; 10577d2e1987SSacha Ballantyne 10582b138567SSlava Zakharin mlir::Type resultType = call.getResult(0).getType(); 10597d2e1987SSacha Ballantyne 1060fac349a1SChristian Sigg if (!mlir::isa<mlir::FloatType>(resultType) && 1061fac349a1SChristian Sigg !mlir::isa<mlir::IntegerType>(resultType)) 106243159b58SMats Petersson return; 10632b138567SSlava Zakharin 10642b138567SSlava Zakharin auto argType = getArgElementType(args[0]); 10652b138567SSlava Zakharin if (!argType) 10662b138567SSlava Zakharin return; 10672b138567SSlava Zakharin assert(*argType == resultType && 10682b138567SSlava Zakharin "Argument/result types mismatch in reduction"); 10692b138567SSlava Zakharin 10707d2e1987SSacha Ballantyne mlir::SymbolRefAttr callee = call.getCalleeAttr(); 10717d2e1987SSacha Ballantyne 10727d2e1987SSacha Ballantyne fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1073f52c64b1SDavid Truby std::string fmfString{builder.getFastMathFlagsString()}; 10747d2e1987SSacha Ballantyne std::string funcName = 10757d2e1987SSacha Ballantyne (mlir::Twine{callee.getLeafReference().getValue(), "x"} + 10767d2e1987SSacha Ballantyne mlir::Twine{rank} + 10777d2e1987SSacha Ballantyne // We must mangle the generated function name with FastMathFlags 10787d2e1987SSacha Ballantyne // value. 10797d2e1987SSacha Ballantyne (fmfString.empty() ? 
mlir::Twine{} : mlir::Twine{"_", fmfString})) 10807d2e1987SSacha Ballantyne .str(); 10817d2e1987SSacha Ballantyne 108220fba03fSSacha Ballantyne simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 108320fba03fSSacha Ballantyne resultType); 10847d2e1987SSacha Ballantyne } 10857d2e1987SSacha Ballantyne 108620fba03fSSacha Ballantyne void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction( 10877d2e1987SSacha Ballantyne fir::CallOp call, const fir::KindMapping &kindMap, 10887d2e1987SSacha Ballantyne GenReductionBodyTy genBodyFunc) { 10897d2e1987SSacha Ballantyne 10907d2e1987SSacha Ballantyne mlir::Operation::operand_range args = call.getArgs(); 10917d2e1987SSacha Ballantyne const mlir::Value &dim = args[3]; 1092bb94d33aSSacha Ballantyne unsigned rank = getDimCount(args[0]); 10937d2e1987SSacha Ballantyne 109420fba03fSSacha Ballantyne // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in 109520fba03fSSacha Ballantyne // these cases. 1096bb94d33aSSacha Ballantyne if (!(isZero(dim) && rank > 0)) 10977d2e1987SSacha Ballantyne return; 10987d2e1987SSacha Ballantyne 109920fba03fSSacha Ballantyne mlir::Value inputBox = findBoxDef(args[0]); 110020fba03fSSacha Ballantyne 110120fba03fSSacha Ballantyne mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType()); 11027d2e1987SSacha Ballantyne mlir::SymbolRefAttr callee = call.getCalleeAttr(); 11037d2e1987SSacha Ballantyne 11047d2e1987SSacha Ballantyne fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 110520fba03fSSacha Ballantyne 110620fba03fSSacha Ballantyne // Treating logicals as integers makes things a lot easier 1107fac349a1SChristian Sigg fir::LogicalType logicalType = { 1108fac349a1SChristian Sigg mlir::dyn_cast<fir::LogicalType>(elementType)}; 110920fba03fSSacha Ballantyne fir::KindTy kind = logicalType.getFKind(); 1110614cd721SSacha Ballantyne mlir::Type intElementType = builder.getIntegerType(kind * 8); 111120fba03fSSacha Ballantyne 
111220fba03fSSacha Ballantyne // Mangle kind into function name as it is not done by default 11137d2e1987SSacha Ballantyne std::string funcName = 111420fba03fSSacha Ballantyne (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} + 111520fba03fSSacha Ballantyne mlir::Twine{kind} + "x" + mlir::Twine{rank}) 11167d2e1987SSacha Ballantyne .str(); 11177d2e1987SSacha Ballantyne 111820fba03fSSacha Ballantyne simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 111920fba03fSSacha Ballantyne intElementType); 112020fba03fSSacha Ballantyne } 112120fba03fSSacha Ballantyne 112220fba03fSSacha Ballantyne void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction( 112320fba03fSSacha Ballantyne fir::CallOp call, const fir::KindMapping &kindMap, 112420fba03fSSacha Ballantyne GenReductionBodyTy genBodyFunc) { 112520fba03fSSacha Ballantyne 112620fba03fSSacha Ballantyne mlir::Operation::operand_range args = call.getArgs(); 112720fba03fSSacha Ballantyne mlir::SymbolRefAttr callee = call.getCalleeAttr(); 112820fba03fSSacha Ballantyne mlir::StringRef funcNameBase = callee.getLeafReference().getValue(); 112920fba03fSSacha Ballantyne unsigned rank = getDimCount(args[0]); 113020fba03fSSacha Ballantyne 113120fba03fSSacha Ballantyne // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in 113220fba03fSSacha Ballantyne // these cases. We check for Dim at the end as some logical functions (Any, 113320fba03fSSacha Ballantyne // All) set dim to 1 instead of 0 when the argument is not present. 
113420fba03fSSacha Ballantyne if (funcNameBase.ends_with("Dim") || !(rank > 0)) 113520fba03fSSacha Ballantyne return; 113620fba03fSSacha Ballantyne 113720fba03fSSacha Ballantyne mlir::Value inputBox = findBoxDef(args[0]); 113820fba03fSSacha Ballantyne mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType()); 113920fba03fSSacha Ballantyne 114020fba03fSSacha Ballantyne fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 114120fba03fSSacha Ballantyne 114220fba03fSSacha Ballantyne // Treating logicals as integers makes things a lot easier 1143fac349a1SChristian Sigg fir::LogicalType logicalType = { 1144fac349a1SChristian Sigg mlir::dyn_cast<fir::LogicalType>(elementType)}; 114520fba03fSSacha Ballantyne fir::KindTy kind = logicalType.getFKind(); 1146614cd721SSacha Ballantyne mlir::Type intElementType = builder.getIntegerType(kind * 8); 114720fba03fSSacha Ballantyne 114820fba03fSSacha Ballantyne // Mangle kind into function name as it is not done by default 114920fba03fSSacha Ballantyne std::string funcName = 115020fba03fSSacha Ballantyne (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} + 115120fba03fSSacha Ballantyne mlir::Twine{kind} + "x" + mlir::Twine{rank}) 115220fba03fSSacha Ballantyne .str(); 115320fba03fSSacha Ballantyne 115420fba03fSSacha Ballantyne simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 115520fba03fSSacha Ballantyne intElementType); 11567d2e1987SSacha Ballantyne } 11577d2e1987SSacha Ballantyne 11589bb47f7fSDavid Green void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction( 11599bb47f7fSDavid Green fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) { 1160614cd721SSacha Ballantyne 1161614cd721SSacha Ballantyne mlir::Operation::operand_range args = call.getArgs(); 1162614cd721SSacha Ballantyne 11632a95fe48SDavid Green mlir::SymbolRefAttr callee = call.getCalleeAttr(); 11642a95fe48SDavid Green mlir::StringRef funcNameBase = callee.getLeafReference().getValue(); 
11652a95fe48SDavid Green bool isDim = funcNameBase.ends_with("Dim"); 11662a95fe48SDavid Green mlir::Value back = args[isDim ? 7 : 6]; 1167614cd721SSacha Ballantyne if (isTrueOrNotConstant(back)) 1168614cd721SSacha Ballantyne return; 1169614cd721SSacha Ballantyne 11702a95fe48SDavid Green mlir::Value mask = args[isDim ? 6 : 5]; 1171614cd721SSacha Ballantyne mlir::Value maskDef = findMaskDef(mask); 1172614cd721SSacha Ballantyne 1173614cd721SSacha Ballantyne // maskDef is set to NULL when the defining op is not one we accept. 1174614cd721SSacha Ballantyne // This tends to be because it is a selectOp, in which case let the 1175614cd721SSacha Ballantyne // runtime deal with it. 1176614cd721SSacha Ballantyne if (maskDef == NULL) 1177614cd721SSacha Ballantyne return; 1178614cd721SSacha Ballantyne 1179614cd721SSacha Ballantyne unsigned rank = getDimCount(args[1]); 11802a95fe48SDavid Green if ((isDim && rank != 1) || !(rank > 0)) 1181614cd721SSacha Ballantyne return; 1182614cd721SSacha Ballantyne 1183614cd721SSacha Ballantyne fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1184614cd721SSacha Ballantyne mlir::Location loc = call.getLoc(); 1185614cd721SSacha Ballantyne auto inputBox = findBoxDef(args[1]); 1186614cd721SSacha Ballantyne mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType()); 1187614cd721SSacha Ballantyne 1188fac349a1SChristian Sigg if (mlir::isa<fir::CharacterType>(inputType)) 1189614cd721SSacha Ballantyne return; 1190614cd721SSacha Ballantyne 1191614cd721SSacha Ballantyne int maskRank; 1192614cd721SSacha Ballantyne fir::KindTy kind = 0; 1193242bb0b6SSacha Ballantyne mlir::Type logicalElemType = builder.getI1Type(); 1194614cd721SSacha Ballantyne if (isOperandAbsent(mask)) { 1195614cd721SSacha Ballantyne maskRank = -1; 1196614cd721SSacha Ballantyne } else { 1197614cd721SSacha Ballantyne maskRank = getDimCount(mask); 1198614cd721SSacha Ballantyne mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType()); 
1199fac349a1SChristian Sigg fir::LogicalType logicalFirType = { 1200fac349a1SChristian Sigg mlir::dyn_cast<fir::LogicalType>(maskElemTy)}; 1201242bb0b6SSacha Ballantyne kind = logicalFirType.getFKind(); 1202242bb0b6SSacha Ballantyne // Convert fir::LogicalType to mlir::Type 1203242bb0b6SSacha Ballantyne logicalElemType = logicalFirType; 1204614cd721SSacha Ballantyne } 1205614cd721SSacha Ballantyne 1206614cd721SSacha Ballantyne mlir::Operation *outputDef = args[0].getDefiningOp(); 1207614cd721SSacha Ballantyne mlir::Value outputAlloc = outputDef->getOperand(0); 1208614cd721SSacha Ballantyne mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType()); 1209614cd721SSacha Ballantyne 1210f52c64b1SDavid Truby std::string fmfString{builder.getFastMathFlagsString()}; 1211614cd721SSacha Ballantyne std::string funcName = 1212614cd721SSacha Ballantyne (mlir::Twine{callee.getLeafReference().getValue(), "x"} + 1213614cd721SSacha Ballantyne mlir::Twine{rank} + 1214614cd721SSacha Ballantyne (maskRank >= 0 1215614cd721SSacha Ballantyne ? 
"_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank} 1216614cd721SSacha Ballantyne : "") + 1217614cd721SSacha Ballantyne "_") 1218614cd721SSacha Ballantyne .str(); 1219614cd721SSacha Ballantyne 1220614cd721SSacha Ballantyne llvm::raw_string_ostream nameOS(funcName); 1221614cd721SSacha Ballantyne outType.print(nameOS); 12222a95fe48SDavid Green if (isDim) 12232a95fe48SDavid Green nameOS << '_' << inputType; 1224614cd721SSacha Ballantyne nameOS << '_' << fmfString; 1225614cd721SSacha Ballantyne 1226614cd721SSacha Ballantyne auto typeGenerator = [rank](fir::FirOpBuilder &builder) { 1227614cd721SSacha Ballantyne return genRuntimeMinlocType(builder, rank); 1228614cd721SSacha Ballantyne }; 12299bb47f7fSDavid Green auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType, 12302a95fe48SDavid Green isMax, isDim](fir::FirOpBuilder &builder, 1231614cd721SSacha Ballantyne mlir::func::FuncOp &funcOp) { 12329bb47f7fSDavid Green genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType, 12332a95fe48SDavid Green logicalElemType, outType, isDim); 1234614cd721SSacha Ballantyne }; 1235614cd721SSacha Ballantyne 1236614cd721SSacha Ballantyne mlir::func::FuncOp newFunc = 1237614cd721SSacha Ballantyne getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); 1238614cd721SSacha Ballantyne builder.create<fir::CallOp>(loc, newFunc, 12392a95fe48SDavid Green mlir::ValueRange{args[0], args[1], mask}); 1240614cd721SSacha Ballantyne call->dropAllReferences(); 1241614cd721SSacha Ballantyne call->erase(); 1242614cd721SSacha Ballantyne } 1243614cd721SSacha Ballantyne 12447d2e1987SSacha Ballantyne void SimplifyIntrinsicsPass::simplifyReductionBody( 12457d2e1987SSacha Ballantyne fir::CallOp call, const fir::KindMapping &kindMap, 12467d2e1987SSacha Ballantyne GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder, 124720fba03fSSacha Ballantyne const mlir::StringRef &funcName, mlir::Type elementType) { 12487d2e1987SSacha Ballantyne 
12497d2e1987SSacha Ballantyne mlir::Operation::operand_range args = call.getArgs(); 12507d2e1987SSacha Ballantyne 12517d2e1987SSacha Ballantyne mlir::Type resultType = call.getResult(0).getType(); 12527d2e1987SSacha Ballantyne unsigned rank = getDimCount(args[0]); 12537d2e1987SSacha Ballantyne 12547d2e1987SSacha Ballantyne mlir::Location loc = call.getLoc(); 12557d2e1987SSacha Ballantyne 12562b138567SSlava Zakharin auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) { 12572b138567SSlava Zakharin return genNoneBoxType(builder, resultType); 125843159b58SMats Petersson }; 125920fba03fSSacha Ballantyne auto bodyGenerator = [&rank, &genBodyFunc, 126020fba03fSSacha Ballantyne &elementType](fir::FirOpBuilder &builder, 12618bd76ac1SSlava Zakharin mlir::func::FuncOp &funcOp) { 126220fba03fSSacha Ballantyne genBodyFunc(builder, funcOp, rank, elementType); 12638bd76ac1SSlava Zakharin }; 12648bd76ac1SSlava Zakharin // Mangle the function name with the rank value as "x<rank>". 126543159b58SMats Petersson mlir::func::FuncOp newFunc = 12668bd76ac1SSlava Zakharin getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); 126743159b58SMats Petersson auto newCall = 126843159b58SMats Petersson builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]}); 126943159b58SMats Petersson call->replaceAllUsesWith(newCall.getResults()); 127043159b58SMats Petersson call->dropAllReferences(); 127143159b58SMats Petersson call->erase(); 127243159b58SMats Petersson } 127343159b58SMats Petersson 12746e193b5cSMats Petersson void SimplifyIntrinsicsPass::runOnOperation() { 12751d5e7a49SSlava Zakharin LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n"); 12766e193b5cSMats Petersson mlir::ModuleOp module = getOperation(); 12776e193b5cSMats Petersson fir::KindMapping kindMap = fir::getKindMapping(module); 12786e193b5cSMats Petersson module.walk([&](mlir::Operation *op) { 12796e193b5cSMats Petersson if (auto call = mlir::dyn_cast<fir::CallOp>(op)) { 
1280a76609ddSValentin Clement (バレンタイン クレメン) if (cuf::isInCUDADeviceContext(op)) 1281a76609ddSValentin Clement (バレンタイン クレメン) return; 12826e193b5cSMats Petersson if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) { 12836e193b5cSMats Petersson mlir::StringRef funcName = callee.getLeafReference().getValue(); 12846e193b5cSMats Petersson // Replace call to runtime function for SUM when it has single 12856e193b5cSMats Petersson // argument (no dim or mask argument) for 1D arrays with either 12866e193b5cSMats Petersson // Integer4 or Real8 types. Other forms are ignored. 12876e193b5cSMats Petersson // The new function is added to the module. 12886e193b5cSMats Petersson // 12896e193b5cSMats Petersson // Prototype for runtime call (from sum.cpp): 12906e193b5cSMats Petersson // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line, 12916e193b5cSMats Petersson // int dim, const Descriptor *mask) 129211db65baSSlava Zakharin // 129311efcceaSKazu Hirata if (funcName.starts_with(RTNAME_STRING(Sum))) { 12947d2e1987SSacha Ballantyne simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody); 12951d5e7a49SSlava Zakharin return; 12961d5e7a49SSlava Zakharin } 129711efcceaSKazu Hirata if (funcName.starts_with(RTNAME_STRING(DotProduct))) { 12981d5e7a49SSlava Zakharin LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n"); 12991d5e7a49SSlava Zakharin LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump(); 13001d5e7a49SSlava Zakharin llvm::dbgs() << "\n"); 13011d5e7a49SSlava Zakharin mlir::Operation::operand_range args = call.getArgs(); 13021d5e7a49SSlava Zakharin const mlir::Value &v1 = args[0]; 13031d5e7a49SSlava Zakharin const mlir::Value &v2 = args[1]; 13041d5e7a49SSlava Zakharin mlir::Location loc = call.getLoc(); 1305ffe1661fSSlava Zakharin fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)}; 1306ffe1661fSSlava Zakharin // Stringize the builder's FastMathFlags flags for mangling 1307ffe1661fSSlava Zakharin // the generated function name. 
1308f52c64b1SDavid Truby std::string fmfString{builder.getFastMathFlagsString()}; 1309afa520abSMats Petersson 13101d5e7a49SSlava Zakharin mlir::Type type = call.getResult(0).getType(); 1311fac349a1SChristian Sigg if (!mlir::isa<mlir::FloatType>(type) && 1312fac349a1SChristian Sigg !mlir::isa<mlir::IntegerType>(type)) 13131d5e7a49SSlava Zakharin return; 13141d5e7a49SSlava Zakharin 131556eda98fSSlava Zakharin // Try to find the element types of the boxed arguments. 131656eda98fSSlava Zakharin auto arg1Type = getArgElementType(v1); 131756eda98fSSlava Zakharin auto arg2Type = getArgElementType(v2); 131856eda98fSSlava Zakharin 131956eda98fSSlava Zakharin if (!arg1Type || !arg2Type) 132056eda98fSSlava Zakharin return; 132156eda98fSSlava Zakharin 132256eda98fSSlava Zakharin // Support only floating point and integer arguments 132356eda98fSSlava Zakharin // now (e.g. logical is skipped here). 1324bd9fdce6SChristian Sigg if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg1Type)) 132556eda98fSSlava Zakharin return; 1326bd9fdce6SChristian Sigg if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg2Type)) 132756eda98fSSlava Zakharin return; 132856eda98fSSlava Zakharin 13291d5e7a49SSlava Zakharin auto typeGenerator = [&type](fir::FirOpBuilder &builder) { 1330aa94eb38SMats Petersson return genRuntimeDotType(builder, type); 13311d5e7a49SSlava Zakharin }; 133256eda98fSSlava Zakharin auto bodyGenerator = [&arg1Type, 133356eda98fSSlava Zakharin &arg2Type](fir::FirOpBuilder &builder, 133456eda98fSSlava Zakharin mlir::func::FuncOp &funcOp) { 1335aa94eb38SMats Petersson genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type); 133656eda98fSSlava Zakharin }; 133756eda98fSSlava Zakharin 133856eda98fSSlava Zakharin // Suffix the function name with the element types 133956eda98fSSlava Zakharin // of the arguments. 
134056eda98fSSlava Zakharin std::string typedFuncName(funcName); 134156eda98fSSlava Zakharin llvm::raw_string_ostream nameOS(typedFuncName); 1342ffe1661fSSlava Zakharin // We must mangle the generated function name with FastMathFlags 1343ffe1661fSSlava Zakharin // value. 1344ffe1661fSSlava Zakharin if (!fmfString.empty()) 1345ffe1661fSSlava Zakharin nameOS << '_' << fmfString; 1346ffe1661fSSlava Zakharin nameOS << '_'; 134756eda98fSSlava Zakharin arg1Type->print(nameOS); 1348ffe1661fSSlava Zakharin nameOS << '_'; 134956eda98fSSlava Zakharin arg2Type->print(nameOS); 135056eda98fSSlava Zakharin 13511d5e7a49SSlava Zakharin mlir::func::FuncOp newFunc = getOrCreateFunction( 135256eda98fSSlava Zakharin builder, typedFuncName, typeGenerator, bodyGenerator); 13531d5e7a49SSlava Zakharin auto newCall = builder.create<fir::CallOp>(loc, newFunc, 13541d5e7a49SSlava Zakharin mlir::ValueRange{v1, v2}); 13551d5e7a49SSlava Zakharin call->replaceAllUsesWith(newCall.getResults()); 13561d5e7a49SSlava Zakharin call->dropAllReferences(); 13571d5e7a49SSlava Zakharin call->erase(); 13581d5e7a49SSlava Zakharin 13591d5e7a49SSlava Zakharin LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump(); 13601d5e7a49SSlava Zakharin llvm::dbgs() << "\n"); 13611d5e7a49SSlava Zakharin return; 13626e193b5cSMats Petersson } 136311efcceaSKazu Hirata if (funcName.starts_with(RTNAME_STRING(Maxval))) { 13647d2e1987SSacha Ballantyne simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody); 13657d2e1987SSacha Ballantyne return; 13667d2e1987SSacha Ballantyne } 136711efcceaSKazu Hirata if (funcName.starts_with(RTNAME_STRING(Count))) { 136820fba03fSSacha Ballantyne simplifyLogicalDim0Reduction(call, kindMap, genRuntimeCountBody); 136920fba03fSSacha Ballantyne return; 137020fba03fSSacha Ballantyne } 137111efcceaSKazu Hirata if (funcName.starts_with(RTNAME_STRING(Any))) { 137220fba03fSSacha Ballantyne simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAnyBody); 137320fba03fSSacha Ballantyne return; 
137420fba03fSSacha Ballantyne } 137511efcceaSKazu Hirata if (funcName.ends_with(RTNAME_STRING(All))) { 137620fba03fSSacha Ballantyne simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAllBody); 1377afa520abSMats Petersson return; 1378afa520abSMats Petersson } 137911efcceaSKazu Hirata if (funcName.starts_with(RTNAME_STRING(Minloc))) { 13809bb47f7fSDavid Green simplifyMinMaxlocReduction(call, kindMap, false); 13819bb47f7fSDavid Green return; 13829bb47f7fSDavid Green } 13839bb47f7fSDavid Green if (funcName.starts_with(RTNAME_STRING(Maxloc))) { 13849bb47f7fSDavid Green simplifyMinMaxlocReduction(call, kindMap, true); 1385614cd721SSacha Ballantyne return; 1386614cd721SSacha Ballantyne } 13876e193b5cSMats Petersson } 13886e193b5cSMats Petersson } 13896e193b5cSMats Petersson }); 13901d5e7a49SSlava Zakharin LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n"); 13916e193b5cSMats Petersson } 13926e193b5cSMats Petersson 13931d5e7a49SSlava Zakharin void SimplifyIntrinsicsPass::getDependentDialects( 13941d5e7a49SSlava Zakharin mlir::DialectRegistry ®istry) const { 13951d5e7a49SSlava Zakharin // LLVM::LinkageAttr creation requires that LLVM dialect is loaded. 13961d5e7a49SSlava Zakharin registry.insert<mlir::LLVM::LLVMDialect>(); 13971d5e7a49SSlava Zakharin } 1398