//===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
/// \file
/// This pass looks for suitable calls to the runtime library for intrinsics
/// that can be simplified/specialized and replaces them with a specialized
/// function.
///
/// For example, SUM(arr) can be specialized as a simple function with one loop,
/// compared to the three arguments (plus file & line info) that the runtime
/// call has - when the argument is a 1D-array (multiple loops may be needed
/// for higher dimension arrays, of course).
///
/// The general idea is that besides making the call simpler, it can also be
/// inlined by other passes that run after this pass, which further improves
/// performance, particularly when the work done in the function is trivial
/// and small in size.
23 //===----------------------------------------------------------------------===// 24 25 #include "flang/Common/Fortran.h" 26 #include "flang/Optimizer/Builder/BoxValue.h" 27 #include "flang/Optimizer/Builder/FIRBuilder.h" 28 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 29 #include "flang/Optimizer/Builder/Todo.h" 30 #include "flang/Optimizer/Dialect/FIROps.h" 31 #include "flang/Optimizer/Dialect/FIRType.h" 32 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 33 #include "flang/Optimizer/HLFIR/HLFIRDialect.h" 34 #include "flang/Optimizer/Transforms/Passes.h" 35 #include "flang/Optimizer/Transforms/Utils.h" 36 #include "flang/Runtime/entry-names.h" 37 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 38 #include "mlir/IR/Matchers.h" 39 #include "mlir/IR/Operation.h" 40 #include "mlir/Pass/Pass.h" 41 #include "mlir/Transforms/DialectConversion.h" 42 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 43 #include "mlir/Transforms/RegionUtils.h" 44 #include "llvm/Support/Debug.h" 45 #include "llvm/Support/raw_ostream.h" 46 #include <llvm/Support/ErrorHandling.h> 47 #include <mlir/Dialect/Arith/IR/Arith.h> 48 #include <mlir/IR/BuiltinTypes.h> 49 #include <mlir/IR/Location.h> 50 #include <mlir/IR/MLIRContext.h> 51 #include <mlir/IR/Value.h> 52 #include <mlir/Support/LLVM.h> 53 #include <optional> 54 55 namespace fir { 56 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS 57 #include "flang/Optimizer/Transforms/Passes.h.inc" 58 } // namespace fir 59 60 #define DEBUG_TYPE "flang-simplify-intrinsics" 61 62 namespace { 63 64 class SimplifyIntrinsicsPass 65 : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> { 66 using FunctionTypeGeneratorTy = 67 llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>; 68 using FunctionBodyGeneratorTy = 69 llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>; 70 using GenReductionBodyTy = llvm::function_ref<void( 71 fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank, 72 mlir::Type 
elementType)>; 73 74 public: 75 using fir::impl::SimplifyIntrinsicsBase< 76 SimplifyIntrinsicsPass>::SimplifyIntrinsicsBase; 77 78 /// Generate a new function implementing a simplified version 79 /// of a Fortran runtime function defined by \p basename name. 80 /// \p typeGenerator is a callback that generates the new function's type. 81 /// \p bodyGenerator is a callback that generates the new function's body. 82 /// The new function is created in the \p builder's Module. 83 mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder, 84 const mlir::StringRef &basename, 85 FunctionTypeGeneratorTy typeGenerator, 86 FunctionBodyGeneratorTy bodyGenerator); 87 void runOnOperation() override; 88 void getDependentDialects(mlir::DialectRegistry ®istry) const override; 89 90 private: 91 /// Helper functions to replace a reduction type of call with its 92 /// simplified form. The actual function is generated using a callback 93 /// function. 94 /// \p call is the call to be replaced 95 /// \p kindMap is used to create FIROpBuilder 96 /// \p genBodyFunc is the callback that builds the replacement function 97 void simplifyIntOrFloatReduction(fir::CallOp call, 98 const fir::KindMapping &kindMap, 99 GenReductionBodyTy genBodyFunc); 100 void simplifyLogicalDim0Reduction(fir::CallOp call, 101 const fir::KindMapping &kindMap, 102 GenReductionBodyTy genBodyFunc); 103 void simplifyLogicalDim1Reduction(fir::CallOp call, 104 const fir::KindMapping &kindMap, 105 GenReductionBodyTy genBodyFunc); 106 void simplifyMinMaxlocReduction(fir::CallOp call, 107 const fir::KindMapping &kindMap, bool isMax); 108 void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap, 109 GenReductionBodyTy genBodyFunc, 110 fir::FirOpBuilder &builder, 111 const mlir::StringRef &basename, 112 mlir::Type elementType); 113 }; 114 115 } // namespace 116 117 /// Create FirOpBuilder with the provided \p op insertion point 118 /// and \p kindMap additionally inheriting FastMathFlags from \p 
op. 119 static fir::FirOpBuilder 120 getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) { 121 fir::FirOpBuilder builder{op, kindMap}; 122 auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op); 123 if (!fmi) 124 return builder; 125 126 // Regardless of what default FastMathFlags are used by FirOpBuilder, 127 // override them with FastMathFlags attached to the operation. 128 builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue()); 129 return builder; 130 } 131 132 /// Generate function type for the simplified version of RTNAME(Sum) and 133 /// similar functions with a fir.box<none> type returning \p elementType. 134 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder, 135 const mlir::Type &elementType) { 136 mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); 137 return mlir::FunctionType::get(builder.getContext(), {boxType}, 138 {elementType}); 139 } 140 141 template <typename Op> 142 Op expectOp(mlir::Value val) { 143 if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp())) 144 return op; 145 LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName() 146 << '\n'); 147 return nullptr; 148 } 149 150 template <typename Op> 151 static mlir::Value findDefSingle(fir::ConvertOp op) { 152 if (auto defOp = expectOp<Op>(op->getOperand(0))) { 153 return defOp.getResult(); 154 } 155 return {}; 156 } 157 158 template <typename... Ops> 159 static mlir::Value findDef(fir::ConvertOp op) { 160 mlir::Value defOp; 161 // Loop over the operation types given to see if any match, exiting once 162 // a match is found. 
Cast to void is needed to avoid compiler complaining 163 // that the result of expression is unused 164 (void)((defOp = findDefSingle<Ops>(op), (defOp)) || ...); 165 return defOp; 166 } 167 168 static bool isOperandAbsent(mlir::Value val) { 169 if (auto op = expectOp<fir::ConvertOp>(val)) { 170 assert(op->getOperands().size() != 0); 171 return mlir::isa_and_nonnull<fir::AbsentOp>( 172 op->getOperand(0).getDefiningOp()); 173 } 174 return false; 175 } 176 177 static bool isTrueOrNotConstant(mlir::Value val) { 178 if (auto op = expectOp<mlir::arith::ConstantOp>(val)) { 179 return !mlir::matchPattern(val, mlir::m_Zero()); 180 } 181 return true; 182 } 183 184 static bool isZero(mlir::Value val) { 185 if (auto op = expectOp<fir::ConvertOp>(val)) { 186 assert(op->getOperands().size() != 0); 187 if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp()) 188 return mlir::matchPattern(defOp, mlir::m_Zero()); 189 } 190 return false; 191 } 192 193 static mlir::Value findBoxDef(mlir::Value val) { 194 if (auto op = expectOp<fir::ConvertOp>(val)) { 195 assert(op->getOperands().size() != 0); 196 return findDef<fir::EmboxOp, fir::ReboxOp>(op); 197 } 198 return {}; 199 } 200 201 static mlir::Value findMaskDef(mlir::Value val) { 202 if (auto op = expectOp<fir::ConvertOp>(val)) { 203 assert(op->getOperands().size() != 0); 204 return findDef<fir::EmboxOp, fir::ReboxOp, fir::AbsentOp>(op); 205 } 206 return {}; 207 } 208 209 static unsigned getDimCount(mlir::Value val) { 210 // In order to find the dimensions count, we look for EmboxOp/ReboxOp 211 // and take the count from its *result* type. Note that in case 212 // of sliced emboxing the operand and the result of EmboxOp/ReboxOp 213 // have different types. 214 // Actually, we can take the box type from the operand of 215 // the first ConvertOp that has non-opaque box type that we meet 216 // going through the ConvertOp chain. 
217 if (mlir::Value emboxVal = findBoxDef(val)) 218 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(emboxVal.getType())) 219 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy())) 220 return seqTy.getDimension(); 221 return 0; 222 } 223 224 /// Given the call operation's box argument \p val, discover 225 /// the element type of the underlying array object. 226 /// \returns the element type or std::nullopt if the type cannot 227 /// be reliably found. 228 /// We expect that the argument is a result of fir.convert 229 /// with the destination type of !fir.box<none>. 230 static std::optional<mlir::Type> getArgElementType(mlir::Value val) { 231 mlir::Operation *defOp; 232 do { 233 defOp = val.getDefiningOp(); 234 // Analyze only sequences of convert operations. 235 if (!mlir::isa<fir::ConvertOp>(defOp)) 236 return std::nullopt; 237 val = defOp->getOperand(0); 238 // The convert operation is expected to convert from one 239 // box type to another box type. 240 auto boxType = mlir::cast<fir::BoxType>(val.getType()); 241 auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType); 242 if (!mlir::isa<mlir::NoneType>(elementType)) 243 return elementType; 244 } while (true); 245 } 246 247 using BodyOpGeneratorTy = llvm::function_ref<mlir::Value( 248 fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value, 249 mlir::Value)>; 250 using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>( 251 fir::FirOpBuilder &, mlir::Location, mlir::Value)>; 252 253 /// Generate the reduction loop into \p funcOp. 254 /// 255 /// \p initVal is a function, called to get the initial value for 256 /// the reduction value 257 /// \p genBody is called to fill in the actual reduciton operation 258 /// for example add for SUM, MAX for MAXVAL, etc. 259 /// \p rank is the rank of the input argument. 260 /// \p elementType is the type of the elements in the input array, 261 /// which may be different to the return type. 
262 /// \p loopCond is called to generate the condition to continue or 263 /// not for IterWhile loops 264 /// \p unorderedOrInitalLoopCond contains either a boolean or bool 265 /// mlir constant, and controls the inital value for while loops 266 /// or if DoLoop is ordered/unordered. 267 268 template <typename OP, typename T, int resultIndex> 269 static void 270 genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, 271 fir::InitValGeneratorTy initVal, ContinueLoopGenTy loopCond, 272 T unorderedOrInitialLoopCond, BodyOpGeneratorTy genBody, 273 unsigned rank, mlir::Type elementType, mlir::Location loc) { 274 275 mlir::IndexType idxTy = builder.getIndexType(); 276 277 mlir::Block::BlockArgListType args = funcOp.front().getArguments(); 278 mlir::Value arg = args[0]; 279 280 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 281 282 fir::SequenceType::Shape flatShape(rank, 283 fir::SequenceType::getUnknownExtent()); 284 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType); 285 mlir::Type boxArrTy = fir::BoxType::get(arrTy); 286 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg); 287 mlir::Type resultType = funcOp.getResultTypes()[0]; 288 mlir::Value init = initVal(builder, loc, resultType); 289 290 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds; 291 292 assert(rank > 0 && "rank cannot be zero"); 293 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 294 295 // Compute all the upper bounds before the loop nest. 296 // It is not strictly necessary for performance, since the loop nest 297 // does not have any store operations and any LICM optimization 298 // should be able to optimize the redundancy. 
299 for (unsigned i = 0; i < rank; ++i) { 300 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); 301 auto dims = 302 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx); 303 mlir::Value len = dims.getResult(1); 304 // We use C indexing here, so len-1 as loopcount 305 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 306 bounds.push_back(loopCount); 307 } 308 // Create a loop nest consisting of OP operations. 309 // Collect the loops' induction variables into indices array, 310 // which will be used in the innermost loop to load the input 311 // array's element. 312 // The loops are generated such that the innermost loop processes 313 // the 0 dimension. 314 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices; 315 for (unsigned i = rank; 0 < i; --i) { 316 mlir::Value step = one; 317 mlir::Value loopCount = bounds[i - 1]; 318 auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step, 319 unorderedOrInitialLoopCond, 320 /*finalCountValue=*/false, init); 321 init = loop.getRegionIterArgs()[resultIndex]; 322 indices.push_back(loop.getInductionVar()); 323 // Set insertion point to the loop body so that the next loop 324 // is inserted inside the current one. 325 builder.setInsertionPointToStart(loop.getBody()); 326 } 327 328 // Reverse the indices such that they are ordered as: 329 // <dim-0-idx, dim-1-idx, ...> 330 std::reverse(indices.begin(), indices.end()); 331 // We are in the innermost loop: generate the reduction body. 332 mlir::Type eleRefTy = builder.getRefType(elementType); 333 mlir::Value addr = 334 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices); 335 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr); 336 mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init); 337 // Generate vector with condition to continue while loop at [0] and result 338 // from current loop at [1] for IterWhileOp loops, just result at [0] for 339 // DoLoopOp loops. 
340 llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal); 341 342 // Unwind the loop nest and insert ResultOp on each level 343 // to return the updated value of the reduction to the enclosing 344 // loops. 345 for (unsigned i = 0; i < rank; ++i) { 346 auto result = builder.create<fir::ResultOp>(loc, results); 347 // Proceed to the outer loop. 348 auto loop = mlir::cast<OP>(result->getParentOp()); 349 results = loop.getResults(); 350 // Set insertion point after the loop operation that we have 351 // just processed. 352 builder.setInsertionPointAfter(loop.getOperation()); 353 } 354 // End of loop nest. The insertion point is after the outermost loop. 355 // Return the reduction value from the function. 356 builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]); 357 } 358 359 static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder, 360 mlir::Location loc, 361 mlir::Value reductionVal) { 362 return {reductionVal}; 363 } 364 365 /// Generate function body of the simplified version of RTNAME(Sum) 366 /// with signature provided by \p funcOp. The caller is responsible 367 /// for saving/restoring the original insertion point of \p builder. 368 /// \p funcOp is expected to be empty on entry to this function. 369 /// \p rank specifies the rank of the input argument. 
370 static void genRuntimeSumBody(fir::FirOpBuilder &builder, 371 mlir::func::FuncOp &funcOp, unsigned rank, 372 mlir::Type elementType) { 373 // function RTNAME(Sum)<T>x<rank>_simplified(arr) 374 // T, dimension(:) :: arr 375 // T sum = 0 376 // integer iter 377 // do iter = 0, extent(arr) 378 // sum = sum + arr[iter] 379 // end do 380 // RTNAME(Sum)<T>x<rank>_simplified = sum 381 // end function RTNAME(Sum)<T>x<rank>_simplified 382 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, 383 mlir::Type elementType) { 384 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) { 385 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 386 return builder.createRealConstant(loc, elementType, 387 llvm::APFloat::getZero(sem)); 388 } 389 return builder.createIntegerConstant(loc, elementType, 0); 390 }; 391 392 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 393 mlir::Type elementType, mlir::Value elem1, 394 mlir::Value elem2) -> mlir::Value { 395 if (mlir::isa<mlir::FloatType>(elementType)) 396 return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2); 397 if (mlir::isa<mlir::IntegerType>(elementType)) 398 return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2); 399 400 llvm_unreachable("unsupported type"); 401 return {}; 402 }; 403 404 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 405 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 406 407 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond, 408 false, genBodyOp, rank, elementType, 409 loc); 410 } 411 412 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder, 413 mlir::func::FuncOp &funcOp, unsigned rank, 414 mlir::Type elementType) { 415 auto init = [](fir::FirOpBuilder builder, mlir::Location loc, 416 mlir::Type elementType) { 417 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) { 418 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 419 return builder.createRealConstant( 420 loc, elementType, 
llvm::APFloat::getLargest(sem, /*Negative=*/true)); 421 } 422 unsigned bits = elementType.getIntOrFloatBitWidth(); 423 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 424 return builder.createIntegerConstant(loc, elementType, minInt); 425 }; 426 427 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 428 mlir::Type elementType, mlir::Value elem1, 429 mlir::Value elem2) -> mlir::Value { 430 if (mlir::isa<mlir::FloatType>(elementType)) { 431 // arith.maxf later converted to llvm.intr.maxnum does not work 432 // correctly for NaNs and -0.0 (see maxnum/minnum pattern matching 433 // in LLVM's InstCombine pass). Moreover, llvm.intr.maxnum 434 // for F128 operands is lowered into fmaxl call by LLVM. 435 // This libm function may not work properly for F128 arguments 436 // on targets where long double is not F128. It is an LLVM issue, 437 // but we just use normal select here to resolve all the cases. 438 auto compare = builder.create<mlir::arith::CmpFOp>( 439 loc, mlir::arith::CmpFPredicate::OGT, elem1, elem2); 440 return builder.create<mlir::arith::SelectOp>(loc, compare, elem1, elem2); 441 } 442 if (mlir::isa<mlir::IntegerType>(elementType)) 443 return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2); 444 445 llvm_unreachable("unsupported type"); 446 return {}; 447 }; 448 449 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 450 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 451 452 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, init, nopLoopCond, 453 false, genBodyOp, rank, elementType, 454 loc); 455 } 456 457 static void genRuntimeCountBody(fir::FirOpBuilder &builder, 458 mlir::func::FuncOp &funcOp, unsigned rank, 459 mlir::Type elementType) { 460 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, 461 mlir::Type elementType) { 462 unsigned bits = elementType.getIntOrFloatBitWidth(); 463 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 464 return 
builder.createIntegerConstant(loc, elementType, zeroInt); 465 }; 466 467 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 468 mlir::Type elementType, mlir::Value elem1, 469 mlir::Value elem2) -> mlir::Value { 470 auto zero32 = builder.createIntegerConstant(loc, elementType, 0); 471 auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0); 472 auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1); 473 474 auto compare = builder.create<mlir::arith::CmpIOp>( 475 loc, mlir::arith::CmpIPredicate::eq, elem1, zero32); 476 auto select = 477 builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64); 478 return builder.create<mlir::arith::AddIOp>(loc, select, elem2); 479 }; 480 481 // Count always gets I32 for elementType as it converts logical input to 482 // logical<4> before passing to the function. 483 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 484 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 485 486 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond, 487 false, genBodyOp, rank, elementType, 488 loc); 489 } 490 491 static void genRuntimeAnyBody(fir::FirOpBuilder &builder, 492 mlir::func::FuncOp &funcOp, unsigned rank, 493 mlir::Type elementType) { 494 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, 495 mlir::Type elementType) { 496 return builder.createIntegerConstant(loc, elementType, 0); 497 }; 498 499 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 500 mlir::Type elementType, mlir::Value elem1, 501 mlir::Value elem2) -> mlir::Value { 502 auto zero = builder.createIntegerConstant(loc, elementType, 0); 503 return builder.create<mlir::arith::CmpIOp>( 504 loc, mlir::arith::CmpIPredicate::ne, elem1, zero); 505 }; 506 507 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc, 508 mlir::Value reductionVal) { 509 auto one1 = builder.createIntegerConstant(loc, builder.getI1Type(), 1); 510 auto eor = 
builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1); 511 llvm::SmallVector<mlir::Value> results = {eor, reductionVal}; 512 return results; 513 }; 514 515 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 516 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 517 mlir::Value ok = builder.createBool(loc, true); 518 519 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>( 520 builder, funcOp, zero, continueCond, ok, genBodyOp, rank, elementType, 521 loc); 522 } 523 524 static void genRuntimeAllBody(fir::FirOpBuilder &builder, 525 mlir::func::FuncOp &funcOp, unsigned rank, 526 mlir::Type elementType) { 527 auto one = [](fir::FirOpBuilder builder, mlir::Location loc, 528 mlir::Type elementType) { 529 return builder.createIntegerConstant(loc, elementType, 1); 530 }; 531 532 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 533 mlir::Type elementType, mlir::Value elem1, 534 mlir::Value elem2) -> mlir::Value { 535 auto zero = builder.createIntegerConstant(loc, elementType, 0); 536 return builder.create<mlir::arith::CmpIOp>( 537 loc, mlir::arith::CmpIPredicate::ne, elem1, zero); 538 }; 539 540 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc, 541 mlir::Value reductionVal) { 542 llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal}; 543 return results; 544 }; 545 546 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 547 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 548 mlir::Value ok = builder.createBool(loc, true); 549 550 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>( 551 builder, funcOp, one, continueCond, ok, genBodyOp, rank, elementType, 552 loc); 553 } 554 555 static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder, 556 unsigned int rank) { 557 mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); 558 mlir::Type boxRefType = builder.getRefType(boxType); 559 560 return mlir::FunctionType::get(builder.getContext(), 561 {boxRefType, 
boxType, boxType}, {}); 562 } 563 564 // Produces a loop nest for a Minloc intrinsic. 565 void fir::genMinMaxlocReductionLoop( 566 fir::FirOpBuilder &builder, mlir::Value array, 567 fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody, 568 fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType, 569 mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr, 570 bool maskMayBeLogicalScalar) { 571 mlir::IndexType idxTy = builder.getIndexType(); 572 573 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 574 575 fir::SequenceType::Shape flatShape(rank, 576 fir::SequenceType::getUnknownExtent()); 577 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType); 578 mlir::Type boxArrTy = fir::BoxType::get(arrTy); 579 array = builder.create<fir::ConvertOp>(loc, boxArrTy, array); 580 581 mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType()); 582 mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1); 583 mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0); 584 mlir::Value flagRef = builder.createTemporary(loc, resultElemType); 585 builder.create<fir::StoreOp>(loc, zero, flagRef); 586 587 mlir::Value init = initVal(builder, loc, elementType); 588 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds; 589 590 assert(rank > 0 && "rank cannot be zero"); 591 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 592 593 // Compute all the upper bounds before the loop nest. 594 // It is not strictly necessary for performance, since the loop nest 595 // does not have any store operations and any LICM optimization 596 // should be able to optimize the redundancy. 
597 for (unsigned i = 0; i < rank; ++i) { 598 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); 599 auto dims = 600 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx); 601 mlir::Value len = dims.getResult(1); 602 // We use C indexing here, so len-1 as loopcount 603 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 604 bounds.push_back(loopCount); 605 } 606 // Create a loop nest consisting of OP operations. 607 // Collect the loops' induction variables into indices array, 608 // which will be used in the innermost loop to load the input 609 // array's element. 610 // The loops are generated such that the innermost loop processes 611 // the 0 dimension. 612 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices; 613 for (unsigned i = rank; 0 < i; --i) { 614 mlir::Value step = one; 615 mlir::Value loopCount = bounds[i - 1]; 616 auto loop = 617 builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false, 618 /*finalCountValue=*/false, init); 619 init = loop.getRegionIterArgs()[0]; 620 indices.push_back(loop.getInductionVar()); 621 // Set insertion point to the loop body so that the next loop 622 // is inserted inside the current one. 623 builder.setInsertionPointToStart(loop.getBody()); 624 } 625 626 // Reverse the indices such that they are ordered as: 627 // <dim-0-idx, dim-1-idx, ...> 628 std::reverse(indices.begin(), indices.end()); 629 mlir::Value reductionVal = 630 genBody(builder, loc, elementType, array, flagRef, init, indices); 631 632 // Unwind the loop nest and insert ResultOp on each level 633 // to return the updated value of the reduction to the enclosing 634 // loops. 635 for (unsigned i = 0; i < rank; ++i) { 636 auto result = builder.create<fir::ResultOp>(loc, reductionVal); 637 // Proceed to the outer loop. 
638 auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp()); 639 reductionVal = loop.getResult(0); 640 // Set insertion point after the loop operation that we have 641 // just processed. 642 builder.setInsertionPointAfter(loop.getOperation()); 643 } 644 // End of loop nest. The insertion point is after the outermost loop. 645 if (maskMayBeLogicalScalar) { 646 if (fir::IfOp ifOp = 647 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) { 648 builder.create<fir::ResultOp>(loc, reductionVal); 649 builder.setInsertionPointAfter(ifOp); 650 // Redefine flagSet to escape scope of ifOp 651 flagSet = builder.createIntegerConstant(loc, resultElemType, 1); 652 reductionVal = ifOp.getResult(0); 653 } 654 } 655 } 656 657 static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder, 658 mlir::func::FuncOp &funcOp, bool isMax, 659 unsigned rank, int maskRank, 660 mlir::Type elementType, 661 mlir::Type maskElemType, 662 mlir::Type resultElemTy, bool isDim) { 663 auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc, 664 mlir::Type elementType) { 665 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) { 666 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 667 llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax); 668 return builder.createRealConstant(loc, elementType, limit); 669 } 670 unsigned bits = elementType.getIntOrFloatBitWidth(); 671 int64_t initValue = (isMax ? 
llvm::APInt::getSignedMinValue(bits) 672 : llvm::APInt::getSignedMaxValue(bits)) 673 .getSExtValue(); 674 return builder.createIntegerConstant(loc, elementType, initValue); 675 }; 676 677 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 678 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 679 680 mlir::Value mask = funcOp.front().getArgument(2); 681 682 // Set up result array in case of early exit / 0 length array 683 mlir::IndexType idxTy = builder.getIndexType(); 684 mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy); 685 mlir::Type resultHeapTy = fir::HeapType::get(resultTy); 686 mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy); 687 688 mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0); 689 mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank); 690 691 mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy); 692 mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize); 693 mlir::Value resultArr = builder.create<fir::EmboxOp>( 694 loc, resultBoxTy, resultArrInit, resultArrShape); 695 696 mlir::Type resultRefTy = builder.getRefType(resultElemTy); 697 698 if (maskRank > 0) { 699 fir::SequenceType::Shape flatShape(rank, 700 fir::SequenceType::getUnknownExtent()); 701 mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType); 702 mlir::Type boxMaskTy = fir::BoxType::get(maskTy); 703 mask = builder.create<fir::ConvertOp>(loc, boxMaskTy, mask); 704 } 705 706 for (unsigned int i = 0; i < rank; ++i) { 707 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i); 708 mlir::Value resultElemAddr = 709 builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index); 710 builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr); 711 } 712 713 auto genBodyOp = 714 [&rank, &resultArr, isMax, &mask, &maskElemType, &maskRank]( 715 fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType, 716 mlir::Value 
array, mlir::Value flagRef, mlir::Value reduction, 717 const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value { 718 // We are in the innermost loop: generate the reduction body. 719 if (maskRank > 0) { 720 mlir::Type logicalRef = builder.getRefType(maskElemType); 721 mlir::Value maskAddr = 722 builder.create<fir::CoordinateOp>(loc, logicalRef, mask, indices); 723 mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr); 724 725 // fir::IfOp requires argument to be I1 - won't accept logical or any 726 // other Integer. 727 mlir::Type ifCompatType = builder.getI1Type(); 728 mlir::Value ifCompatElem = 729 builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem); 730 731 llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType}; 732 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, ifCompatElem, 733 /*withElseRegion=*/true); 734 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 735 } 736 737 // Set flag that mask was true at some point 738 mlir::Value flagSet = builder.createIntegerConstant( 739 loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1); 740 mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef); 741 mlir::Type eleRefTy = builder.getRefType(elementType); 742 mlir::Value addr = 743 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices); 744 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr); 745 746 mlir::Value cmp; 747 if (mlir::isa<mlir::FloatType>(elementType)) { 748 // For FP reductions we want the first smallest value to be used, that 749 // is not NaN. A OGL/OLT condition will usually work for this unless all 750 // the values are Nan or Inf. This follows the same logic as 751 // NumericCompare for Minloc/Maxlox in extrema.cpp. 752 cmp = builder.create<mlir::arith::CmpFOp>( 753 loc, 754 isMax ? 
mlir::arith::CmpFPredicate::OGT 755 : mlir::arith::CmpFPredicate::OLT, 756 elem, reduction); 757 758 mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>( 759 loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction); 760 mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>( 761 loc, mlir::arith::CmpFPredicate::OEQ, elem, elem); 762 cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2); 763 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan); 764 } else if (mlir::isa<mlir::IntegerType>(elementType)) { 765 cmp = builder.create<mlir::arith::CmpIOp>( 766 loc, 767 isMax ? mlir::arith::CmpIPredicate::sgt 768 : mlir::arith::CmpIPredicate::slt, 769 elem, reduction); 770 } else { 771 llvm_unreachable("unsupported type"); 772 } 773 774 // The condition used for the loop is isFirst || <the condition above>. 775 isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst); 776 isFirst = builder.create<mlir::arith::XOrIOp>( 777 loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1)); 778 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst); 779 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp, 780 /*withElseRegion*/ true); 781 782 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 783 builder.create<fir::StoreOp>(loc, flagSet, flagRef); 784 mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType()); 785 mlir::Type returnRefTy = builder.getRefType(resultElemTy); 786 mlir::IndexType idxTy = builder.getIndexType(); 787 788 mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1); 789 790 for (unsigned int i = 0; i < rank; ++i) { 791 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i); 792 mlir::Value resultElemAddr = 793 builder.create<fir::CoordinateOp>(loc, returnRefTy, resultArr, index); 794 mlir::Value convert = 795 builder.create<fir::ConvertOp>(loc, resultElemTy, indices[i]); 796 mlir::Value fortranIndex = 797 
builder.create<mlir::arith::AddIOp>(loc, convert, one); 798 builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr); 799 } 800 builder.create<fir::ResultOp>(loc, elem); 801 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 802 builder.create<fir::ResultOp>(loc, reduction); 803 builder.setInsertionPointAfter(ifOp); 804 mlir::Value reductionVal = ifOp.getResult(0); 805 806 // Close the mask if needed 807 if (maskRank > 0) { 808 fir::IfOp ifOp = 809 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp()); 810 builder.create<fir::ResultOp>(loc, reductionVal); 811 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 812 builder.create<fir::ResultOp>(loc, reduction); 813 reductionVal = ifOp.getResult(0); 814 builder.setInsertionPointAfter(ifOp); 815 } 816 817 return reductionVal; 818 }; 819 820 // if mask is a logical scalar, we can check its value before the main loop 821 // and either ignore the fact it is there or exit early. 822 if (maskRank == 0) { 823 mlir::Type logical = builder.getI1Type(); 824 mlir::IndexType idxTy = builder.getIndexType(); 825 826 fir::SequenceType::Shape singleElement(1, 1); 827 mlir::Type arrTy = fir::SequenceType::get(singleElement, logical); 828 mlir::Type boxArrTy = fir::BoxType::get(arrTy); 829 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, mask); 830 831 mlir::Value indx = builder.createIntegerConstant(loc, idxTy, 0); 832 mlir::Type logicalRefTy = builder.getRefType(logical); 833 mlir::Value condAddr = 834 builder.create<fir::CoordinateOp>(loc, logicalRefTy, array, indx); 835 mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr); 836 837 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cond, 838 /*withElseRegion=*/true); 839 840 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 841 mlir::Value basicValue; 842 if (mlir::isa<mlir::IntegerType>(elementType)) { 843 basicValue = builder.createIntegerConstant(loc, elementType, 0); 844 } else { 845 
basicValue = builder.createRealConstant(loc, elementType, 0); 846 } 847 builder.create<fir::ResultOp>(loc, basicValue); 848 849 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 850 } 851 auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc, 852 const mlir::Type &resultElemType, mlir::Value resultArr, 853 mlir::Value index) { 854 mlir::Type resultRefTy = builder.getRefType(resultElemType); 855 return builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, 856 index); 857 }; 858 859 genMinMaxlocReductionLoop(builder, funcOp.front().getArgument(1), init, 860 genBodyOp, getAddrFn, rank, elementType, loc, 861 maskElemType, resultArr, maskRank == 0); 862 863 // Store newly created output array to the reference passed in 864 if (isDim) { 865 mlir::Type resultBoxTy = 866 fir::BoxType::get(fir::HeapType::get(resultElemTy)); 867 mlir::Value outputArr = builder.create<fir::ConvertOp>( 868 loc, builder.getRefType(resultBoxTy), funcOp.front().getArgument(0)); 869 mlir::Value resultArrScalar = builder.create<fir::ConvertOp>( 870 loc, fir::HeapType::get(resultElemTy), resultArrInit); 871 mlir::Value resultBox = 872 builder.create<fir::EmboxOp>(loc, resultBoxTy, resultArrScalar); 873 builder.create<fir::StoreOp>(loc, resultBox, outputArr); 874 } else { 875 fir::SequenceType::Shape resultShape(1, rank); 876 mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy); 877 mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy); 878 mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy); 879 mlir::Type outputRefTy = builder.getRefType(outputBoxTy); 880 mlir::Value outputArr = builder.create<fir::ConvertOp>( 881 loc, outputRefTy, funcOp.front().getArgument(0)); 882 builder.create<fir::StoreOp>(loc, resultArr, outputArr); 883 } 884 885 builder.create<mlir::func::ReturnOp>(loc); 886 } 887 888 /// Generate function type for the simplified version of RTNAME(DotProduct) 889 /// operating on the given \p elementType. 
890 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder, 891 const mlir::Type &elementType) { 892 mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); 893 return mlir::FunctionType::get(builder.getContext(), {boxType, boxType}, 894 {elementType}); 895 } 896 897 /// Generate function body of the simplified version of RTNAME(DotProduct) 898 /// with signature provided by \p funcOp. The caller is responsible 899 /// for saving/restoring the original insertion point of \p builder. 900 /// \p funcOp is expected to be empty on entry to this function. 901 /// \p arg1ElementTy and \p arg2ElementTy specify elements types 902 /// of the underlying array objects - they are used to generate proper 903 /// element accesses. 904 static void genRuntimeDotBody(fir::FirOpBuilder &builder, 905 mlir::func::FuncOp &funcOp, 906 mlir::Type arg1ElementTy, 907 mlir::Type arg2ElementTy) { 908 // function RTNAME(DotProduct)<T>_simplified(arr1, arr2) 909 // T, dimension(:) :: arr1, arr2 910 // T product = 0 911 // integer iter 912 // do iter = 0, extent(arr1) 913 // product = product + arr1[iter] * arr2[iter] 914 // end do 915 // RTNAME(ADotProduct)<T>_simplified = product 916 // end function RTNAME(DotProduct)<T>_simplified 917 auto loc = mlir::UnknownLoc::get(builder.getContext()); 918 mlir::Type resultElementType = funcOp.getResultTypes()[0]; 919 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 920 921 mlir::IndexType idxTy = builder.getIndexType(); 922 923 mlir::Value zero = 924 mlir::isa<mlir::FloatType>(resultElementType) 925 ? 
builder.createRealConstant(loc, resultElementType, 0.0) 926 : builder.createIntegerConstant(loc, resultElementType, 0); 927 928 mlir::Block::BlockArgListType args = funcOp.front().getArguments(); 929 mlir::Value arg1 = args[0]; 930 mlir::Value arg2 = args[1]; 931 932 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 933 934 fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()}; 935 mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy); 936 mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1); 937 mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1); 938 mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy); 939 mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2); 940 mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2); 941 // This version takes the loop trip count from the first argument. 942 // If the first argument's box has unknown (at compilation time) 943 // extent, then it may be better to take the extent from the second 944 // argument - so that after inlining the loop may be better optimized, e.g. 945 // fully unrolled. This requires generating two versions of the simplified 946 // function and some analysis at the call site to choose which version 947 // is more profitable to call. 948 // Note that we can assume that both arguments have the same extent. 
949 auto dims = 950 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx); 951 mlir::Value len = dims.getResult(1); 952 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 953 mlir::Value step = one; 954 955 // We use C indexing here, so len-1 as loopcount 956 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 957 auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, 958 /*unordered=*/false, 959 /*finalCountValue=*/false, zero); 960 mlir::Value sumVal = loop.getRegionIterArgs()[0]; 961 962 // Begin loop code 963 mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint(); 964 builder.setInsertionPointToStart(loop.getBody()); 965 966 mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy); 967 mlir::Value index = loop.getInductionVar(); 968 mlir::Value addr1 = 969 builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index); 970 mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1); 971 // Convert to the result type. 972 elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1); 973 974 mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy); 975 mlir::Value addr2 = 976 builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index); 977 mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2); 978 // Convert to the result type. 979 elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2); 980 981 if (mlir::isa<mlir::FloatType>(resultElementType)) 982 sumVal = builder.create<mlir::arith::AddFOp>( 983 loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal); 984 else if (mlir::isa<mlir::IntegerType>(resultElementType)) 985 sumVal = builder.create<mlir::arith::AddIOp>( 986 loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal); 987 else 988 llvm_unreachable("unsupported type"); 989 990 builder.create<fir::ResultOp>(loc, sumVal); 991 // End of loop. 
992 builder.restoreInsertionPoint(loopEndPt); 993 994 mlir::Value resultVal = loop.getResult(0); 995 builder.create<mlir::func::ReturnOp>(loc, resultVal); 996 } 997 998 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction( 999 fir::FirOpBuilder &builder, const mlir::StringRef &baseName, 1000 FunctionTypeGeneratorTy typeGenerator, 1001 FunctionBodyGeneratorTy bodyGenerator) { 1002 // WARNING: if the function generated here changes its signature 1003 // or behavior (the body code), we should probably embed some 1004 // versioning information into its name, otherwise libraries 1005 // statically linked with older versions of Flang may stop 1006 // working with object files created with newer Flang. 1007 // We can also avoid this by using internal linkage, but 1008 // this may increase the size of final executable/shared library. 1009 std::string replacementName = mlir::Twine{baseName, "_simplified"}.str(); 1010 // If we already have a function, just return it. 1011 mlir::func::FuncOp newFunc = builder.getNamedFunction(replacementName); 1012 mlir::FunctionType fType = typeGenerator(builder); 1013 if (newFunc) { 1014 assert(newFunc.getFunctionType() == fType && 1015 "type mismatch for simplified function"); 1016 return newFunc; 1017 } 1018 1019 // Need to build the function! 1020 auto loc = mlir::UnknownLoc::get(builder.getContext()); 1021 newFunc = builder.createFunction(loc, replacementName, fType); 1022 auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR; 1023 auto linkage = 1024 mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); 1025 newFunc->setAttr("llvm.linkage", linkage); 1026 1027 // Save the position of the original call. 1028 mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint(); 1029 1030 bodyGenerator(builder, newFunc); 1031 1032 // Now back to where we were adding code earlier... 
1033 builder.restoreInsertionPoint(insertPt); 1034 1035 return newFunc; 1036 } 1037 1038 void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction( 1039 fir::CallOp call, const fir::KindMapping &kindMap, 1040 GenReductionBodyTy genBodyFunc) { 1041 // args[1] and args[2] are source filename and line number, ignored. 1042 mlir::Operation::operand_range args = call.getArgs(); 1043 1044 const mlir::Value &dim = args[3]; 1045 const mlir::Value &mask = args[4]; 1046 // dim is zero when it is absent, which is an implementation 1047 // detail in the runtime library. 1048 1049 bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask); 1050 unsigned rank = getDimCount(args[0]); 1051 1052 // Rank is set to 0 for assumed shape arrays, don't simplify 1053 // in these cases 1054 if (!(dimAndMaskAbsent && rank > 0)) 1055 return; 1056 1057 mlir::Type resultType = call.getResult(0).getType(); 1058 1059 if (!mlir::isa<mlir::FloatType>(resultType) && 1060 !mlir::isa<mlir::IntegerType>(resultType)) 1061 return; 1062 1063 auto argType = getArgElementType(args[0]); 1064 if (!argType) 1065 return; 1066 assert(*argType == resultType && 1067 "Argument/result types mismatch in reduction"); 1068 1069 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1070 1071 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1072 std::string fmfString{builder.getFastMathFlagsString()}; 1073 std::string funcName = 1074 (mlir::Twine{callee.getLeafReference().getValue(), "x"} + 1075 mlir::Twine{rank} + 1076 // We must mangle the generated function name with FastMathFlags 1077 // value. 1078 (fmfString.empty() ? 
mlir::Twine{} : mlir::Twine{"_", fmfString})) 1079 .str(); 1080 1081 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 1082 resultType); 1083 } 1084 1085 void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction( 1086 fir::CallOp call, const fir::KindMapping &kindMap, 1087 GenReductionBodyTy genBodyFunc) { 1088 1089 mlir::Operation::operand_range args = call.getArgs(); 1090 const mlir::Value &dim = args[3]; 1091 unsigned rank = getDimCount(args[0]); 1092 1093 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in 1094 // these cases. 1095 if (!(isZero(dim) && rank > 0)) 1096 return; 1097 1098 mlir::Value inputBox = findBoxDef(args[0]); 1099 1100 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType()); 1101 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1102 1103 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1104 1105 // Treating logicals as integers makes things a lot easier 1106 fir::LogicalType logicalType = { 1107 mlir::dyn_cast<fir::LogicalType>(elementType)}; 1108 fir::KindTy kind = logicalType.getFKind(); 1109 mlir::Type intElementType = builder.getIntegerType(kind * 8); 1110 1111 // Mangle kind into function name as it is not done by default 1112 std::string funcName = 1113 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} + 1114 mlir::Twine{kind} + "x" + mlir::Twine{rank}) 1115 .str(); 1116 1117 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 1118 intElementType); 1119 } 1120 1121 void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction( 1122 fir::CallOp call, const fir::KindMapping &kindMap, 1123 GenReductionBodyTy genBodyFunc) { 1124 1125 mlir::Operation::operand_range args = call.getArgs(); 1126 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1127 mlir::StringRef funcNameBase = callee.getLeafReference().getValue(); 1128 unsigned rank = getDimCount(args[0]); 1129 1130 // getDimCount returns a rank of 0 for assumed shape arrays, don't 
simplify in 1131 // these cases. We check for Dim at the end as some logical functions (Any, 1132 // All) set dim to 1 instead of 0 when the argument is not present. 1133 if (funcNameBase.ends_with("Dim") || !(rank > 0)) 1134 return; 1135 1136 mlir::Value inputBox = findBoxDef(args[0]); 1137 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType()); 1138 1139 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1140 1141 // Treating logicals as integers makes things a lot easier 1142 fir::LogicalType logicalType = { 1143 mlir::dyn_cast<fir::LogicalType>(elementType)}; 1144 fir::KindTy kind = logicalType.getFKind(); 1145 mlir::Type intElementType = builder.getIntegerType(kind * 8); 1146 1147 // Mangle kind into function name as it is not done by default 1148 std::string funcName = 1149 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} + 1150 mlir::Twine{kind} + "x" + mlir::Twine{rank}) 1151 .str(); 1152 1153 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 1154 intElementType); 1155 } 1156 1157 void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction( 1158 fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) { 1159 1160 mlir::Operation::operand_range args = call.getArgs(); 1161 1162 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1163 mlir::StringRef funcNameBase = callee.getLeafReference().getValue(); 1164 bool isDim = funcNameBase.ends_with("Dim"); 1165 mlir::Value back = args[isDim ? 7 : 6]; 1166 if (isTrueOrNotConstant(back)) 1167 return; 1168 1169 mlir::Value mask = args[isDim ? 6 : 5]; 1170 mlir::Value maskDef = findMaskDef(mask); 1171 1172 // maskDef is set to NULL when the defining op is not one we accept. 1173 // This tends to be because it is a selectOp, in which case let the 1174 // runtime deal with it. 
1175 if (maskDef == NULL) 1176 return; 1177 1178 unsigned rank = getDimCount(args[1]); 1179 if ((isDim && rank != 1) || !(rank > 0)) 1180 return; 1181 1182 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1183 mlir::Location loc = call.getLoc(); 1184 auto inputBox = findBoxDef(args[1]); 1185 mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType()); 1186 1187 if (mlir::isa<fir::CharacterType>(inputType)) 1188 return; 1189 1190 int maskRank; 1191 fir::KindTy kind = 0; 1192 mlir::Type logicalElemType = builder.getI1Type(); 1193 if (isOperandAbsent(mask)) { 1194 maskRank = -1; 1195 } else { 1196 maskRank = getDimCount(mask); 1197 mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType()); 1198 fir::LogicalType logicalFirType = { 1199 mlir::dyn_cast<fir::LogicalType>(maskElemTy)}; 1200 kind = logicalFirType.getFKind(); 1201 // Convert fir::LogicalType to mlir::Type 1202 logicalElemType = logicalFirType; 1203 } 1204 1205 mlir::Operation *outputDef = args[0].getDefiningOp(); 1206 mlir::Value outputAlloc = outputDef->getOperand(0); 1207 mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType()); 1208 1209 std::string fmfString{builder.getFastMathFlagsString()}; 1210 std::string funcName = 1211 (mlir::Twine{callee.getLeafReference().getValue(), "x"} + 1212 mlir::Twine{rank} + 1213 (maskRank >= 0 1214 ? 
"_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank} 1215 : "") + 1216 "_") 1217 .str(); 1218 1219 llvm::raw_string_ostream nameOS(funcName); 1220 outType.print(nameOS); 1221 if (isDim) 1222 nameOS << '_' << inputType; 1223 nameOS << '_' << fmfString; 1224 1225 auto typeGenerator = [rank](fir::FirOpBuilder &builder) { 1226 return genRuntimeMinlocType(builder, rank); 1227 }; 1228 auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType, 1229 isMax, isDim](fir::FirOpBuilder &builder, 1230 mlir::func::FuncOp &funcOp) { 1231 genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType, 1232 logicalElemType, outType, isDim); 1233 }; 1234 1235 mlir::func::FuncOp newFunc = 1236 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); 1237 builder.create<fir::CallOp>(loc, newFunc, 1238 mlir::ValueRange{args[0], args[1], mask}); 1239 call->dropAllReferences(); 1240 call->erase(); 1241 } 1242 1243 void SimplifyIntrinsicsPass::simplifyReductionBody( 1244 fir::CallOp call, const fir::KindMapping &kindMap, 1245 GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder, 1246 const mlir::StringRef &funcName, mlir::Type elementType) { 1247 1248 mlir::Operation::operand_range args = call.getArgs(); 1249 1250 mlir::Type resultType = call.getResult(0).getType(); 1251 unsigned rank = getDimCount(args[0]); 1252 1253 mlir::Location loc = call.getLoc(); 1254 1255 auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) { 1256 return genNoneBoxType(builder, resultType); 1257 }; 1258 auto bodyGenerator = [&rank, &genBodyFunc, 1259 &elementType](fir::FirOpBuilder &builder, 1260 mlir::func::FuncOp &funcOp) { 1261 genBodyFunc(builder, funcOp, rank, elementType); 1262 }; 1263 // Mangle the function name with the rank value as "x<rank>". 
1264 mlir::func::FuncOp newFunc = 1265 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); 1266 auto newCall = 1267 builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]}); 1268 call->replaceAllUsesWith(newCall.getResults()); 1269 call->dropAllReferences(); 1270 call->erase(); 1271 } 1272 1273 void SimplifyIntrinsicsPass::runOnOperation() { 1274 LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n"); 1275 mlir::ModuleOp module = getOperation(); 1276 fir::KindMapping kindMap = fir::getKindMapping(module); 1277 module.walk([&](mlir::Operation *op) { 1278 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) { 1279 if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) { 1280 mlir::StringRef funcName = callee.getLeafReference().getValue(); 1281 // Replace call to runtime function for SUM when it has single 1282 // argument (no dim or mask argument) for 1D arrays with either 1283 // Integer4 or Real8 types. Other forms are ignored. 1284 // The new function is added to the module. 1285 // 1286 // Prototype for runtime call (from sum.cpp): 1287 // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line, 1288 // int dim, const Descriptor *mask) 1289 // 1290 if (funcName.starts_with(RTNAME_STRING(Sum))) { 1291 simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody); 1292 return; 1293 } 1294 if (funcName.starts_with(RTNAME_STRING(DotProduct))) { 1295 LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n"); 1296 LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump(); 1297 llvm::dbgs() << "\n"); 1298 mlir::Operation::operand_range args = call.getArgs(); 1299 const mlir::Value &v1 = args[0]; 1300 const mlir::Value &v2 = args[1]; 1301 mlir::Location loc = call.getLoc(); 1302 fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)}; 1303 // Stringize the builder's FastMathFlags flags for mangling 1304 // the generated function name. 
1305 std::string fmfString{builder.getFastMathFlagsString()}; 1306 1307 mlir::Type type = call.getResult(0).getType(); 1308 if (!mlir::isa<mlir::FloatType>(type) && 1309 !mlir::isa<mlir::IntegerType>(type)) 1310 return; 1311 1312 // Try to find the element types of the boxed arguments. 1313 auto arg1Type = getArgElementType(v1); 1314 auto arg2Type = getArgElementType(v2); 1315 1316 if (!arg1Type || !arg2Type) 1317 return; 1318 1319 // Support only floating point and integer arguments 1320 // now (e.g. logical is skipped here). 1321 if (!arg1Type->isa<mlir::FloatType>() && 1322 !arg1Type->isa<mlir::IntegerType>()) 1323 return; 1324 if (!arg2Type->isa<mlir::FloatType>() && 1325 !arg2Type->isa<mlir::IntegerType>()) 1326 return; 1327 1328 auto typeGenerator = [&type](fir::FirOpBuilder &builder) { 1329 return genRuntimeDotType(builder, type); 1330 }; 1331 auto bodyGenerator = [&arg1Type, 1332 &arg2Type](fir::FirOpBuilder &builder, 1333 mlir::func::FuncOp &funcOp) { 1334 genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type); 1335 }; 1336 1337 // Suffix the function name with the element types 1338 // of the arguments. 1339 std::string typedFuncName(funcName); 1340 llvm::raw_string_ostream nameOS(typedFuncName); 1341 // We must mangle the generated function name with FastMathFlags 1342 // value. 
1343 if (!fmfString.empty()) 1344 nameOS << '_' << fmfString; 1345 nameOS << '_'; 1346 arg1Type->print(nameOS); 1347 nameOS << '_'; 1348 arg2Type->print(nameOS); 1349 1350 mlir::func::FuncOp newFunc = getOrCreateFunction( 1351 builder, typedFuncName, typeGenerator, bodyGenerator); 1352 auto newCall = builder.create<fir::CallOp>(loc, newFunc, 1353 mlir::ValueRange{v1, v2}); 1354 call->replaceAllUsesWith(newCall.getResults()); 1355 call->dropAllReferences(); 1356 call->erase(); 1357 1358 LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump(); 1359 llvm::dbgs() << "\n"); 1360 return; 1361 } 1362 if (funcName.starts_with(RTNAME_STRING(Maxval))) { 1363 simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody); 1364 return; 1365 } 1366 if (funcName.starts_with(RTNAME_STRING(Count))) { 1367 simplifyLogicalDim0Reduction(call, kindMap, genRuntimeCountBody); 1368 return; 1369 } 1370 if (funcName.starts_with(RTNAME_STRING(Any))) { 1371 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAnyBody); 1372 return; 1373 } 1374 if (funcName.ends_with(RTNAME_STRING(All))) { 1375 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAllBody); 1376 return; 1377 } 1378 if (funcName.starts_with(RTNAME_STRING(Minloc))) { 1379 simplifyMinMaxlocReduction(call, kindMap, false); 1380 return; 1381 } 1382 if (funcName.starts_with(RTNAME_STRING(Maxloc))) { 1383 simplifyMinMaxlocReduction(call, kindMap, true); 1384 return; 1385 } 1386 } 1387 } 1388 }); 1389 LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n"); 1390 } 1391 1392 void SimplifyIntrinsicsPass::getDependentDialects( 1393 mlir::DialectRegistry ®istry) const { 1394 // LLVM::LinkageAttr creation requires that LLVM dialect is loaded. 1395 registry.insert<mlir::LLVM::LLVMDialect>(); 1396 } 1397