//===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
/// \file
/// This pass looks for suitable calls to runtime library for intrinsics that
/// can be simplified/specialized and replaces them with a specialized function.
///
/// For example, SUM(arr) can be specialized as a simple function with one loop,
/// compared to the three arguments (plus file & line info) that the runtime
/// call has - when the argument is a 1D-array (multiple loops may be needed
/// for higher dimension arrays, of course)
///
/// The general idea is that besides making the call simpler, it can also be
/// inlined by other passes that run after this pass, which further improves
/// performance, particularly when the work done in the function is trivial
/// and small in size.
23 //===----------------------------------------------------------------------===// 24 25 #include "flang/Common/Fortran.h" 26 #include "flang/Optimizer/Builder/BoxValue.h" 27 #include "flang/Optimizer/Builder/CUFCommon.h" 28 #include "flang/Optimizer/Builder/FIRBuilder.h" 29 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 30 #include "flang/Optimizer/Builder/Todo.h" 31 #include "flang/Optimizer/Dialect/FIROps.h" 32 #include "flang/Optimizer/Dialect/FIRType.h" 33 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 34 #include "flang/Optimizer/HLFIR/HLFIRDialect.h" 35 #include "flang/Optimizer/Transforms/Passes.h" 36 #include "flang/Optimizer/Transforms/Utils.h" 37 #include "flang/Runtime/entry-names.h" 38 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 39 #include "mlir/IR/Matchers.h" 40 #include "mlir/IR/Operation.h" 41 #include "mlir/Pass/Pass.h" 42 #include "mlir/Transforms/DialectConversion.h" 43 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 44 #include "mlir/Transforms/RegionUtils.h" 45 #include "llvm/Support/Debug.h" 46 #include "llvm/Support/raw_ostream.h" 47 #include <llvm/Support/ErrorHandling.h> 48 #include <mlir/Dialect/Arith/IR/Arith.h> 49 #include <mlir/IR/BuiltinTypes.h> 50 #include <mlir/IR/Location.h> 51 #include <mlir/IR/MLIRContext.h> 52 #include <mlir/IR/Value.h> 53 #include <mlir/Support/LLVM.h> 54 #include <optional> 55 56 namespace fir { 57 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS 58 #include "flang/Optimizer/Transforms/Passes.h.inc" 59 } // namespace fir 60 61 #define DEBUG_TYPE "flang-simplify-intrinsics" 62 63 namespace { 64 65 class SimplifyIntrinsicsPass 66 : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> { 67 using FunctionTypeGeneratorTy = 68 llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>; 69 using FunctionBodyGeneratorTy = 70 llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>; 71 using GenReductionBodyTy = llvm::function_ref<void( 72 fir::FirOpBuilder &builder, 
mlir::func::FuncOp &funcOp, unsigned rank, 73 mlir::Type elementType)>; 74 75 public: 76 using fir::impl::SimplifyIntrinsicsBase< 77 SimplifyIntrinsicsPass>::SimplifyIntrinsicsBase; 78 79 /// Generate a new function implementing a simplified version 80 /// of a Fortran runtime function defined by \p basename name. 81 /// \p typeGenerator is a callback that generates the new function's type. 82 /// \p bodyGenerator is a callback that generates the new function's body. 83 /// The new function is created in the \p builder's Module. 84 mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder, 85 const mlir::StringRef &basename, 86 FunctionTypeGeneratorTy typeGenerator, 87 FunctionBodyGeneratorTy bodyGenerator); 88 void runOnOperation() override; 89 void getDependentDialects(mlir::DialectRegistry ®istry) const override; 90 91 private: 92 /// Helper functions to replace a reduction type of call with its 93 /// simplified form. The actual function is generated using a callback 94 /// function. 
95 /// \p call is the call to be replaced 96 /// \p kindMap is used to create FIROpBuilder 97 /// \p genBodyFunc is the callback that builds the replacement function 98 void simplifyIntOrFloatReduction(fir::CallOp call, 99 const fir::KindMapping &kindMap, 100 GenReductionBodyTy genBodyFunc); 101 void simplifyLogicalDim0Reduction(fir::CallOp call, 102 const fir::KindMapping &kindMap, 103 GenReductionBodyTy genBodyFunc); 104 void simplifyLogicalDim1Reduction(fir::CallOp call, 105 const fir::KindMapping &kindMap, 106 GenReductionBodyTy genBodyFunc); 107 void simplifyMinMaxlocReduction(fir::CallOp call, 108 const fir::KindMapping &kindMap, bool isMax); 109 void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap, 110 GenReductionBodyTy genBodyFunc, 111 fir::FirOpBuilder &builder, 112 const mlir::StringRef &basename, 113 mlir::Type elementType); 114 }; 115 116 } // namespace 117 118 /// Create FirOpBuilder with the provided \p op insertion point 119 /// and \p kindMap additionally inheriting FastMathFlags from \p op. 120 static fir::FirOpBuilder 121 getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) { 122 fir::FirOpBuilder builder{op, kindMap}; 123 auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op); 124 if (!fmi) 125 return builder; 126 127 // Regardless of what default FastMathFlags are used by FirOpBuilder, 128 // override them with FastMathFlags attached to the operation. 129 builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue()); 130 return builder; 131 } 132 133 /// Generate function type for the simplified version of RTNAME(Sum) and 134 /// similar functions with a fir.box<none> type returning \p elementType. 
135 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder, 136 const mlir::Type &elementType) { 137 mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); 138 return mlir::FunctionType::get(builder.getContext(), {boxType}, 139 {elementType}); 140 } 141 142 template <typename Op> 143 Op expectOp(mlir::Value val) { 144 if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp())) 145 return op; 146 LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName() 147 << '\n'); 148 return nullptr; 149 } 150 151 template <typename Op> 152 static mlir::Value findDefSingle(fir::ConvertOp op) { 153 if (auto defOp = expectOp<Op>(op->getOperand(0))) { 154 return defOp.getResult(); 155 } 156 return {}; 157 } 158 159 template <typename... Ops> 160 static mlir::Value findDef(fir::ConvertOp op) { 161 mlir::Value defOp; 162 // Loop over the operation types given to see if any match, exiting once 163 // a match is found. Cast to void is needed to avoid compiler complaining 164 // that the result of expression is unused 165 (void)((defOp = findDefSingle<Ops>(op), (defOp)) || ...); 166 return defOp; 167 } 168 169 static bool isOperandAbsent(mlir::Value val) { 170 if (auto op = expectOp<fir::ConvertOp>(val)) { 171 assert(op->getOperands().size() != 0); 172 return mlir::isa_and_nonnull<fir::AbsentOp>( 173 op->getOperand(0).getDefiningOp()); 174 } 175 return false; 176 } 177 178 static bool isTrueOrNotConstant(mlir::Value val) { 179 if (auto op = expectOp<mlir::arith::ConstantOp>(val)) { 180 return !mlir::matchPattern(val, mlir::m_Zero()); 181 } 182 return true; 183 } 184 185 static bool isZero(mlir::Value val) { 186 if (auto op = expectOp<fir::ConvertOp>(val)) { 187 assert(op->getOperands().size() != 0); 188 if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp()) 189 return mlir::matchPattern(defOp, mlir::m_Zero()); 190 } 191 return false; 192 } 193 194 static mlir::Value findBoxDef(mlir::Value val) { 195 if (auto op = 
expectOp<fir::ConvertOp>(val)) { 196 assert(op->getOperands().size() != 0); 197 return findDef<fir::EmboxOp, fir::ReboxOp>(op); 198 } 199 return {}; 200 } 201 202 static mlir::Value findMaskDef(mlir::Value val) { 203 if (auto op = expectOp<fir::ConvertOp>(val)) { 204 assert(op->getOperands().size() != 0); 205 return findDef<fir::EmboxOp, fir::ReboxOp, fir::AbsentOp>(op); 206 } 207 return {}; 208 } 209 210 static unsigned getDimCount(mlir::Value val) { 211 // In order to find the dimensions count, we look for EmboxOp/ReboxOp 212 // and take the count from its *result* type. Note that in case 213 // of sliced emboxing the operand and the result of EmboxOp/ReboxOp 214 // have different types. 215 // Actually, we can take the box type from the operand of 216 // the first ConvertOp that has non-opaque box type that we meet 217 // going through the ConvertOp chain. 218 if (mlir::Value emboxVal = findBoxDef(val)) 219 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(emboxVal.getType())) 220 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy())) 221 return seqTy.getDimension(); 222 return 0; 223 } 224 225 /// Given the call operation's box argument \p val, discover 226 /// the element type of the underlying array object. 227 /// \returns the element type or std::nullopt if the type cannot 228 /// be reliably found. 229 /// We expect that the argument is a result of fir.convert 230 /// with the destination type of !fir.box<none>. 231 static std::optional<mlir::Type> getArgElementType(mlir::Value val) { 232 mlir::Operation *defOp; 233 do { 234 defOp = val.getDefiningOp(); 235 // Analyze only sequences of convert operations. 236 if (!mlir::isa<fir::ConvertOp>(defOp)) 237 return std::nullopt; 238 val = defOp->getOperand(0); 239 // The convert operation is expected to convert from one 240 // box type to another box type. 
241 auto boxType = mlir::cast<fir::BoxType>(val.getType()); 242 auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType); 243 if (!mlir::isa<mlir::NoneType>(elementType)) 244 return elementType; 245 } while (true); 246 } 247 248 using BodyOpGeneratorTy = llvm::function_ref<mlir::Value( 249 fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value, 250 mlir::Value)>; 251 using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>( 252 fir::FirOpBuilder &, mlir::Location, mlir::Value)>; 253 254 /// Generate the reduction loop into \p funcOp. 255 /// 256 /// \p initVal is a function, called to get the initial value for 257 /// the reduction value 258 /// \p genBody is called to fill in the actual reduciton operation 259 /// for example add for SUM, MAX for MAXVAL, etc. 260 /// \p rank is the rank of the input argument. 261 /// \p elementType is the type of the elements in the input array, 262 /// which may be different to the return type. 263 /// \p loopCond is called to generate the condition to continue or 264 /// not for IterWhile loops 265 /// \p unorderedOrInitalLoopCond contains either a boolean or bool 266 /// mlir constant, and controls the inital value for while loops 267 /// or if DoLoop is ordered/unordered. 
268 269 template <typename OP, typename T, int resultIndex> 270 static void 271 genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, 272 fir::InitValGeneratorTy initVal, ContinueLoopGenTy loopCond, 273 T unorderedOrInitialLoopCond, BodyOpGeneratorTy genBody, 274 unsigned rank, mlir::Type elementType, mlir::Location loc) { 275 276 mlir::IndexType idxTy = builder.getIndexType(); 277 278 mlir::Block::BlockArgListType args = funcOp.front().getArguments(); 279 mlir::Value arg = args[0]; 280 281 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 282 283 fir::SequenceType::Shape flatShape(rank, 284 fir::SequenceType::getUnknownExtent()); 285 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType); 286 mlir::Type boxArrTy = fir::BoxType::get(arrTy); 287 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg); 288 mlir::Type resultType = funcOp.getResultTypes()[0]; 289 mlir::Value init = initVal(builder, loc, resultType); 290 291 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds; 292 293 assert(rank > 0 && "rank cannot be zero"); 294 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 295 296 // Compute all the upper bounds before the loop nest. 297 // It is not strictly necessary for performance, since the loop nest 298 // does not have any store operations and any LICM optimization 299 // should be able to optimize the redundancy. 300 for (unsigned i = 0; i < rank; ++i) { 301 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); 302 auto dims = 303 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx); 304 mlir::Value len = dims.getResult(1); 305 // We use C indexing here, so len-1 as loopcount 306 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 307 bounds.push_back(loopCount); 308 } 309 // Create a loop nest consisting of OP operations. 
310 // Collect the loops' induction variables into indices array, 311 // which will be used in the innermost loop to load the input 312 // array's element. 313 // The loops are generated such that the innermost loop processes 314 // the 0 dimension. 315 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices; 316 for (unsigned i = rank; 0 < i; --i) { 317 mlir::Value step = one; 318 mlir::Value loopCount = bounds[i - 1]; 319 auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step, 320 unorderedOrInitialLoopCond, 321 /*finalCountValue=*/false, init); 322 init = loop.getRegionIterArgs()[resultIndex]; 323 indices.push_back(loop.getInductionVar()); 324 // Set insertion point to the loop body so that the next loop 325 // is inserted inside the current one. 326 builder.setInsertionPointToStart(loop.getBody()); 327 } 328 329 // Reverse the indices such that they are ordered as: 330 // <dim-0-idx, dim-1-idx, ...> 331 std::reverse(indices.begin(), indices.end()); 332 // We are in the innermost loop: generate the reduction body. 333 mlir::Type eleRefTy = builder.getRefType(elementType); 334 mlir::Value addr = 335 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices); 336 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr); 337 mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init); 338 // Generate vector with condition to continue while loop at [0] and result 339 // from current loop at [1] for IterWhileOp loops, just result at [0] for 340 // DoLoopOp loops. 341 llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal); 342 343 // Unwind the loop nest and insert ResultOp on each level 344 // to return the updated value of the reduction to the enclosing 345 // loops. 346 for (unsigned i = 0; i < rank; ++i) { 347 auto result = builder.create<fir::ResultOp>(loc, results); 348 // Proceed to the outer loop. 
349 auto loop = mlir::cast<OP>(result->getParentOp()); 350 results = loop.getResults(); 351 // Set insertion point after the loop operation that we have 352 // just processed. 353 builder.setInsertionPointAfter(loop.getOperation()); 354 } 355 // End of loop nest. The insertion point is after the outermost loop. 356 // Return the reduction value from the function. 357 builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]); 358 } 359 360 static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder, 361 mlir::Location loc, 362 mlir::Value reductionVal) { 363 return {reductionVal}; 364 } 365 366 /// Generate function body of the simplified version of RTNAME(Sum) 367 /// with signature provided by \p funcOp. The caller is responsible 368 /// for saving/restoring the original insertion point of \p builder. 369 /// \p funcOp is expected to be empty on entry to this function. 370 /// \p rank specifies the rank of the input argument. 371 static void genRuntimeSumBody(fir::FirOpBuilder &builder, 372 mlir::func::FuncOp &funcOp, unsigned rank, 373 mlir::Type elementType) { 374 // function RTNAME(Sum)<T>x<rank>_simplified(arr) 375 // T, dimension(:) :: arr 376 // T sum = 0 377 // integer iter 378 // do iter = 0, extent(arr) 379 // sum = sum + arr[iter] 380 // end do 381 // RTNAME(Sum)<T>x<rank>_simplified = sum 382 // end function RTNAME(Sum)<T>x<rank>_simplified 383 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, 384 mlir::Type elementType) { 385 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) { 386 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 387 return builder.createRealConstant(loc, elementType, 388 llvm::APFloat::getZero(sem)); 389 } 390 return builder.createIntegerConstant(loc, elementType, 0); 391 }; 392 393 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 394 mlir::Type elementType, mlir::Value elem1, 395 mlir::Value elem2) -> mlir::Value { 396 if (mlir::isa<mlir::FloatType>(elementType)) 
397 return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2); 398 if (mlir::isa<mlir::IntegerType>(elementType)) 399 return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2); 400 401 llvm_unreachable("unsupported type"); 402 return {}; 403 }; 404 405 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 406 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 407 408 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond, 409 false, genBodyOp, rank, elementType, 410 loc); 411 } 412 413 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder, 414 mlir::func::FuncOp &funcOp, unsigned rank, 415 mlir::Type elementType) { 416 auto init = [](fir::FirOpBuilder builder, mlir::Location loc, 417 mlir::Type elementType) { 418 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) { 419 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 420 return builder.createRealConstant( 421 loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true)); 422 } 423 unsigned bits = elementType.getIntOrFloatBitWidth(); 424 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 425 return builder.createIntegerConstant(loc, elementType, minInt); 426 }; 427 428 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 429 mlir::Type elementType, mlir::Value elem1, 430 mlir::Value elem2) -> mlir::Value { 431 if (mlir::isa<mlir::FloatType>(elementType)) { 432 // arith.maxf later converted to llvm.intr.maxnum does not work 433 // correctly for NaNs and -0.0 (see maxnum/minnum pattern matching 434 // in LLVM's InstCombine pass). Moreover, llvm.intr.maxnum 435 // for F128 operands is lowered into fmaxl call by LLVM. 436 // This libm function may not work properly for F128 arguments 437 // on targets where long double is not F128. It is an LLVM issue, 438 // but we just use normal select here to resolve all the cases. 
439 auto compare = builder.create<mlir::arith::CmpFOp>( 440 loc, mlir::arith::CmpFPredicate::OGT, elem1, elem2); 441 return builder.create<mlir::arith::SelectOp>(loc, compare, elem1, elem2); 442 } 443 if (mlir::isa<mlir::IntegerType>(elementType)) 444 return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2); 445 446 llvm_unreachable("unsupported type"); 447 return {}; 448 }; 449 450 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 451 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 452 453 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, init, nopLoopCond, 454 false, genBodyOp, rank, elementType, 455 loc); 456 } 457 458 static void genRuntimeCountBody(fir::FirOpBuilder &builder, 459 mlir::func::FuncOp &funcOp, unsigned rank, 460 mlir::Type elementType) { 461 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, 462 mlir::Type elementType) { 463 unsigned bits = elementType.getIntOrFloatBitWidth(); 464 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 465 return builder.createIntegerConstant(loc, elementType, zeroInt); 466 }; 467 468 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 469 mlir::Type elementType, mlir::Value elem1, 470 mlir::Value elem2) -> mlir::Value { 471 auto zero32 = builder.createIntegerConstant(loc, elementType, 0); 472 auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0); 473 auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1); 474 475 auto compare = builder.create<mlir::arith::CmpIOp>( 476 loc, mlir::arith::CmpIPredicate::eq, elem1, zero32); 477 auto select = 478 builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64); 479 return builder.create<mlir::arith::AddIOp>(loc, select, elem2); 480 }; 481 482 // Count always gets I32 for elementType as it converts logical input to 483 // logical<4> before passing to the function. 
484 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 485 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 486 487 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond, 488 false, genBodyOp, rank, elementType, 489 loc); 490 } 491 492 static void genRuntimeAnyBody(fir::FirOpBuilder &builder, 493 mlir::func::FuncOp &funcOp, unsigned rank, 494 mlir::Type elementType) { 495 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, 496 mlir::Type elementType) { 497 return builder.createIntegerConstant(loc, elementType, 0); 498 }; 499 500 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 501 mlir::Type elementType, mlir::Value elem1, 502 mlir::Value elem2) -> mlir::Value { 503 auto zero = builder.createIntegerConstant(loc, elementType, 0); 504 return builder.create<mlir::arith::CmpIOp>( 505 loc, mlir::arith::CmpIPredicate::ne, elem1, zero); 506 }; 507 508 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc, 509 mlir::Value reductionVal) { 510 auto one1 = builder.createIntegerConstant(loc, builder.getI1Type(), 1); 511 auto eor = builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1); 512 llvm::SmallVector<mlir::Value> results = {eor, reductionVal}; 513 return results; 514 }; 515 516 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 517 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 518 mlir::Value ok = builder.createBool(loc, true); 519 520 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>( 521 builder, funcOp, zero, continueCond, ok, genBodyOp, rank, elementType, 522 loc); 523 } 524 525 static void genRuntimeAllBody(fir::FirOpBuilder &builder, 526 mlir::func::FuncOp &funcOp, unsigned rank, 527 mlir::Type elementType) { 528 auto one = [](fir::FirOpBuilder builder, mlir::Location loc, 529 mlir::Type elementType) { 530 return builder.createIntegerConstant(loc, elementType, 1); 531 }; 532 533 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 
534 mlir::Type elementType, mlir::Value elem1, 535 mlir::Value elem2) -> mlir::Value { 536 auto zero = builder.createIntegerConstant(loc, elementType, 0); 537 return builder.create<mlir::arith::CmpIOp>( 538 loc, mlir::arith::CmpIPredicate::ne, elem1, zero); 539 }; 540 541 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc, 542 mlir::Value reductionVal) { 543 llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal}; 544 return results; 545 }; 546 547 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 548 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 549 mlir::Value ok = builder.createBool(loc, true); 550 551 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>( 552 builder, funcOp, one, continueCond, ok, genBodyOp, rank, elementType, 553 loc); 554 } 555 556 static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder, 557 unsigned int rank) { 558 mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); 559 mlir::Type boxRefType = builder.getRefType(boxType); 560 561 return mlir::FunctionType::get(builder.getContext(), 562 {boxRefType, boxType, boxType}, {}); 563 } 564 565 // Produces a loop nest for a Minloc intrinsic. 
566 void fir::genMinMaxlocReductionLoop( 567 fir::FirOpBuilder &builder, mlir::Value array, 568 fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody, 569 fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType, 570 mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr, 571 bool maskMayBeLogicalScalar) { 572 mlir::IndexType idxTy = builder.getIndexType(); 573 574 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 575 576 fir::SequenceType::Shape flatShape(rank, 577 fir::SequenceType::getUnknownExtent()); 578 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType); 579 mlir::Type boxArrTy = fir::BoxType::get(arrTy); 580 array = builder.create<fir::ConvertOp>(loc, boxArrTy, array); 581 582 mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType()); 583 mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1); 584 mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0); 585 mlir::Value flagRef = builder.createTemporary(loc, resultElemType); 586 builder.create<fir::StoreOp>(loc, zero, flagRef); 587 588 mlir::Value init = initVal(builder, loc, elementType); 589 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds; 590 591 assert(rank > 0 && "rank cannot be zero"); 592 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 593 594 // Compute all the upper bounds before the loop nest. 595 // It is not strictly necessary for performance, since the loop nest 596 // does not have any store operations and any LICM optimization 597 // should be able to optimize the redundancy. 
598 for (unsigned i = 0; i < rank; ++i) { 599 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); 600 auto dims = 601 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx); 602 mlir::Value len = dims.getResult(1); 603 // We use C indexing here, so len-1 as loopcount 604 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 605 bounds.push_back(loopCount); 606 } 607 // Create a loop nest consisting of OP operations. 608 // Collect the loops' induction variables into indices array, 609 // which will be used in the innermost loop to load the input 610 // array's element. 611 // The loops are generated such that the innermost loop processes 612 // the 0 dimension. 613 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices; 614 for (unsigned i = rank; 0 < i; --i) { 615 mlir::Value step = one; 616 mlir::Value loopCount = bounds[i - 1]; 617 auto loop = 618 builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false, 619 /*finalCountValue=*/false, init); 620 init = loop.getRegionIterArgs()[0]; 621 indices.push_back(loop.getInductionVar()); 622 // Set insertion point to the loop body so that the next loop 623 // is inserted inside the current one. 624 builder.setInsertionPointToStart(loop.getBody()); 625 } 626 627 // Reverse the indices such that they are ordered as: 628 // <dim-0-idx, dim-1-idx, ...> 629 std::reverse(indices.begin(), indices.end()); 630 mlir::Value reductionVal = 631 genBody(builder, loc, elementType, array, flagRef, init, indices); 632 633 // Unwind the loop nest and insert ResultOp on each level 634 // to return the updated value of the reduction to the enclosing 635 // loops. 636 for (unsigned i = 0; i < rank; ++i) { 637 auto result = builder.create<fir::ResultOp>(loc, reductionVal); 638 // Proceed to the outer loop. 
639 auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp()); 640 reductionVal = loop.getResult(0); 641 // Set insertion point after the loop operation that we have 642 // just processed. 643 builder.setInsertionPointAfter(loop.getOperation()); 644 } 645 // End of loop nest. The insertion point is after the outermost loop. 646 if (maskMayBeLogicalScalar) { 647 if (fir::IfOp ifOp = 648 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) { 649 builder.create<fir::ResultOp>(loc, reductionVal); 650 builder.setInsertionPointAfter(ifOp); 651 // Redefine flagSet to escape scope of ifOp 652 flagSet = builder.createIntegerConstant(loc, resultElemType, 1); 653 reductionVal = ifOp.getResult(0); 654 } 655 } 656 } 657 658 static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder, 659 mlir::func::FuncOp &funcOp, bool isMax, 660 unsigned rank, int maskRank, 661 mlir::Type elementType, 662 mlir::Type maskElemType, 663 mlir::Type resultElemTy, bool isDim) { 664 auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc, 665 mlir::Type elementType) { 666 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) { 667 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 668 llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax); 669 return builder.createRealConstant(loc, elementType, limit); 670 } 671 unsigned bits = elementType.getIntOrFloatBitWidth(); 672 int64_t initValue = (isMax ? 
llvm::APInt::getSignedMinValue(bits) 673 : llvm::APInt::getSignedMaxValue(bits)) 674 .getSExtValue(); 675 return builder.createIntegerConstant(loc, elementType, initValue); 676 }; 677 678 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 679 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 680 681 mlir::Value mask = funcOp.front().getArgument(2); 682 683 // Set up result array in case of early exit / 0 length array 684 mlir::IndexType idxTy = builder.getIndexType(); 685 mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy); 686 mlir::Type resultHeapTy = fir::HeapType::get(resultTy); 687 mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy); 688 689 mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0); 690 mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank); 691 692 mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy); 693 mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize); 694 mlir::Value resultArr = builder.create<fir::EmboxOp>( 695 loc, resultBoxTy, resultArrInit, resultArrShape); 696 697 mlir::Type resultRefTy = builder.getRefType(resultElemTy); 698 699 if (maskRank > 0) { 700 fir::SequenceType::Shape flatShape(rank, 701 fir::SequenceType::getUnknownExtent()); 702 mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType); 703 mlir::Type boxMaskTy = fir::BoxType::get(maskTy); 704 mask = builder.create<fir::ConvertOp>(loc, boxMaskTy, mask); 705 } 706 707 for (unsigned int i = 0; i < rank; ++i) { 708 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i); 709 mlir::Value resultElemAddr = 710 builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index); 711 builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr); 712 } 713 714 auto genBodyOp = 715 [&rank, &resultArr, isMax, &mask, &maskElemType, &maskRank]( 716 fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType, 717 mlir::Value 
array, mlir::Value flagRef, mlir::Value reduction, 718 const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value { 719 // We are in the innermost loop: generate the reduction body. 720 if (maskRank > 0) { 721 mlir::Type logicalRef = builder.getRefType(maskElemType); 722 mlir::Value maskAddr = 723 builder.create<fir::CoordinateOp>(loc, logicalRef, mask, indices); 724 mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr); 725 726 // fir::IfOp requires argument to be I1 - won't accept logical or any 727 // other Integer. 728 mlir::Type ifCompatType = builder.getI1Type(); 729 mlir::Value ifCompatElem = 730 builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem); 731 732 llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType}; 733 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, ifCompatElem, 734 /*withElseRegion=*/true); 735 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 736 } 737 738 // Set flag that mask was true at some point 739 mlir::Value flagSet = builder.createIntegerConstant( 740 loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1); 741 mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef); 742 mlir::Type eleRefTy = builder.getRefType(elementType); 743 mlir::Value addr = 744 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices); 745 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr); 746 747 mlir::Value cmp; 748 if (mlir::isa<mlir::FloatType>(elementType)) { 749 // For FP reductions we want the first smallest value to be used, that 750 // is not NaN. A OGL/OLT condition will usually work for this unless all 751 // the values are Nan or Inf. This follows the same logic as 752 // NumericCompare for Minloc/Maxlox in extrema.cpp. 753 cmp = builder.create<mlir::arith::CmpFOp>( 754 loc, 755 isMax ? 
mlir::arith::CmpFPredicate::OGT 756 : mlir::arith::CmpFPredicate::OLT, 757 elem, reduction); 758 759 mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>( 760 loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction); 761 mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>( 762 loc, mlir::arith::CmpFPredicate::OEQ, elem, elem); 763 cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2); 764 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan); 765 } else if (mlir::isa<mlir::IntegerType>(elementType)) { 766 cmp = builder.create<mlir::arith::CmpIOp>( 767 loc, 768 isMax ? mlir::arith::CmpIPredicate::sgt 769 : mlir::arith::CmpIPredicate::slt, 770 elem, reduction); 771 } else { 772 llvm_unreachable("unsupported type"); 773 } 774 775 // The condition used for the loop is isFirst || <the condition above>. 776 isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst); 777 isFirst = builder.create<mlir::arith::XOrIOp>( 778 loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1)); 779 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst); 780 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp, 781 /*withElseRegion*/ true); 782 783 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 784 builder.create<fir::StoreOp>(loc, flagSet, flagRef); 785 mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType()); 786 mlir::Type returnRefTy = builder.getRefType(resultElemTy); 787 mlir::IndexType idxTy = builder.getIndexType(); 788 789 mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1); 790 791 for (unsigned int i = 0; i < rank; ++i) { 792 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i); 793 mlir::Value resultElemAddr = 794 builder.create<fir::CoordinateOp>(loc, returnRefTy, resultArr, index); 795 mlir::Value convert = 796 builder.create<fir::ConvertOp>(loc, resultElemTy, indices[i]); 797 mlir::Value fortranIndex = 798 
builder.create<mlir::arith::AddIOp>(loc, convert, one); 799 builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr); 800 } 801 builder.create<fir::ResultOp>(loc, elem); 802 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 803 builder.create<fir::ResultOp>(loc, reduction); 804 builder.setInsertionPointAfter(ifOp); 805 mlir::Value reductionVal = ifOp.getResult(0); 806 807 // Close the mask if needed 808 if (maskRank > 0) { 809 fir::IfOp ifOp = 810 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp()); 811 builder.create<fir::ResultOp>(loc, reductionVal); 812 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 813 builder.create<fir::ResultOp>(loc, reduction); 814 reductionVal = ifOp.getResult(0); 815 builder.setInsertionPointAfter(ifOp); 816 } 817 818 return reductionVal; 819 }; 820 821 // if mask is a logical scalar, we can check its value before the main loop 822 // and either ignore the fact it is there or exit early. 823 if (maskRank == 0) { 824 mlir::Type logical = builder.getI1Type(); 825 mlir::IndexType idxTy = builder.getIndexType(); 826 827 fir::SequenceType::Shape singleElement(1, 1); 828 mlir::Type arrTy = fir::SequenceType::get(singleElement, logical); 829 mlir::Type boxArrTy = fir::BoxType::get(arrTy); 830 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, mask); 831 832 mlir::Value indx = builder.createIntegerConstant(loc, idxTy, 0); 833 mlir::Type logicalRefTy = builder.getRefType(logical); 834 mlir::Value condAddr = 835 builder.create<fir::CoordinateOp>(loc, logicalRefTy, array, indx); 836 mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr); 837 838 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cond, 839 /*withElseRegion=*/true); 840 841 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 842 mlir::Value basicValue; 843 if (mlir::isa<mlir::IntegerType>(elementType)) { 844 basicValue = builder.createIntegerConstant(loc, elementType, 0); 845 } else { 846 
basicValue = builder.createRealConstant(loc, elementType, 0); 847 } 848 builder.create<fir::ResultOp>(loc, basicValue); 849 850 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 851 } 852 auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc, 853 const mlir::Type &resultElemType, mlir::Value resultArr, 854 mlir::Value index) { 855 mlir::Type resultRefTy = builder.getRefType(resultElemType); 856 return builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, 857 index); 858 }; 859 860 genMinMaxlocReductionLoop(builder, funcOp.front().getArgument(1), init, 861 genBodyOp, getAddrFn, rank, elementType, loc, 862 maskElemType, resultArr, maskRank == 0); 863 864 // Store newly created output array to the reference passed in 865 if (isDim) { 866 mlir::Type resultBoxTy = 867 fir::BoxType::get(fir::HeapType::get(resultElemTy)); 868 mlir::Value outputArr = builder.create<fir::ConvertOp>( 869 loc, builder.getRefType(resultBoxTy), funcOp.front().getArgument(0)); 870 mlir::Value resultArrScalar = builder.create<fir::ConvertOp>( 871 loc, fir::HeapType::get(resultElemTy), resultArrInit); 872 mlir::Value resultBox = 873 builder.create<fir::EmboxOp>(loc, resultBoxTy, resultArrScalar); 874 builder.create<fir::StoreOp>(loc, resultBox, outputArr); 875 } else { 876 fir::SequenceType::Shape resultShape(1, rank); 877 mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy); 878 mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy); 879 mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy); 880 mlir::Type outputRefTy = builder.getRefType(outputBoxTy); 881 mlir::Value outputArr = builder.create<fir::ConvertOp>( 882 loc, outputRefTy, funcOp.front().getArgument(0)); 883 builder.create<fir::StoreOp>(loc, resultArr, outputArr); 884 } 885 886 builder.create<mlir::func::ReturnOp>(loc); 887 } 888 889 /// Generate function type for the simplified version of RTNAME(DotProduct) 890 /// operating on the given \p elementType. 
891 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder, 892 const mlir::Type &elementType) { 893 mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); 894 return mlir::FunctionType::get(builder.getContext(), {boxType, boxType}, 895 {elementType}); 896 } 897 898 /// Generate function body of the simplified version of RTNAME(DotProduct) 899 /// with signature provided by \p funcOp. The caller is responsible 900 /// for saving/restoring the original insertion point of \p builder. 901 /// \p funcOp is expected to be empty on entry to this function. 902 /// \p arg1ElementTy and \p arg2ElementTy specify elements types 903 /// of the underlying array objects - they are used to generate proper 904 /// element accesses. 905 static void genRuntimeDotBody(fir::FirOpBuilder &builder, 906 mlir::func::FuncOp &funcOp, 907 mlir::Type arg1ElementTy, 908 mlir::Type arg2ElementTy) { 909 // function RTNAME(DotProduct)<T>_simplified(arr1, arr2) 910 // T, dimension(:) :: arr1, arr2 911 // T product = 0 912 // integer iter 913 // do iter = 0, extent(arr1) 914 // product = product + arr1[iter] * arr2[iter] 915 // end do 916 // RTNAME(ADotProduct)<T>_simplified = product 917 // end function RTNAME(DotProduct)<T>_simplified 918 auto loc = mlir::UnknownLoc::get(builder.getContext()); 919 mlir::Type resultElementType = funcOp.getResultTypes()[0]; 920 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 921 922 mlir::IndexType idxTy = builder.getIndexType(); 923 924 mlir::Value zero = 925 mlir::isa<mlir::FloatType>(resultElementType) 926 ? 
builder.createRealConstant(loc, resultElementType, 0.0) 927 : builder.createIntegerConstant(loc, resultElementType, 0); 928 929 mlir::Block::BlockArgListType args = funcOp.front().getArguments(); 930 mlir::Value arg1 = args[0]; 931 mlir::Value arg2 = args[1]; 932 933 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 934 935 fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()}; 936 mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy); 937 mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1); 938 mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1); 939 mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy); 940 mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2); 941 mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2); 942 // This version takes the loop trip count from the first argument. 943 // If the first argument's box has unknown (at compilation time) 944 // extent, then it may be better to take the extent from the second 945 // argument - so that after inlining the loop may be better optimized, e.g. 946 // fully unrolled. This requires generating two versions of the simplified 947 // function and some analysis at the call site to choose which version 948 // is more profitable to call. 949 // Note that we can assume that both arguments have the same extent. 
950 auto dims = 951 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx); 952 mlir::Value len = dims.getResult(1); 953 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 954 mlir::Value step = one; 955 956 // We use C indexing here, so len-1 as loopcount 957 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 958 auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, 959 /*unordered=*/false, 960 /*finalCountValue=*/false, zero); 961 mlir::Value sumVal = loop.getRegionIterArgs()[0]; 962 963 // Begin loop code 964 mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint(); 965 builder.setInsertionPointToStart(loop.getBody()); 966 967 mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy); 968 mlir::Value index = loop.getInductionVar(); 969 mlir::Value addr1 = 970 builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index); 971 mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1); 972 // Convert to the result type. 973 elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1); 974 975 mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy); 976 mlir::Value addr2 = 977 builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index); 978 mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2); 979 // Convert to the result type. 980 elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2); 981 982 if (mlir::isa<mlir::FloatType>(resultElementType)) 983 sumVal = builder.create<mlir::arith::AddFOp>( 984 loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal); 985 else if (mlir::isa<mlir::IntegerType>(resultElementType)) 986 sumVal = builder.create<mlir::arith::AddIOp>( 987 loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal); 988 else 989 llvm_unreachable("unsupported type"); 990 991 builder.create<fir::ResultOp>(loc, sumVal); 992 // End of loop. 
993 builder.restoreInsertionPoint(loopEndPt); 994 995 mlir::Value resultVal = loop.getResult(0); 996 builder.create<mlir::func::ReturnOp>(loc, resultVal); 997 } 998 999 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction( 1000 fir::FirOpBuilder &builder, const mlir::StringRef &baseName, 1001 FunctionTypeGeneratorTy typeGenerator, 1002 FunctionBodyGeneratorTy bodyGenerator) { 1003 // WARNING: if the function generated here changes its signature 1004 // or behavior (the body code), we should probably embed some 1005 // versioning information into its name, otherwise libraries 1006 // statically linked with older versions of Flang may stop 1007 // working with object files created with newer Flang. 1008 // We can also avoid this by using internal linkage, but 1009 // this may increase the size of final executable/shared library. 1010 std::string replacementName = mlir::Twine{baseName, "_simplified"}.str(); 1011 // If we already have a function, just return it. 1012 mlir::func::FuncOp newFunc = builder.getNamedFunction(replacementName); 1013 mlir::FunctionType fType = typeGenerator(builder); 1014 if (newFunc) { 1015 assert(newFunc.getFunctionType() == fType && 1016 "type mismatch for simplified function"); 1017 return newFunc; 1018 } 1019 1020 // Need to build the function! 1021 auto loc = mlir::UnknownLoc::get(builder.getContext()); 1022 newFunc = builder.createFunction(loc, replacementName, fType); 1023 auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR; 1024 auto linkage = 1025 mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); 1026 newFunc->setAttr("llvm.linkage", linkage); 1027 1028 // Save the position of the original call. 1029 mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint(); 1030 1031 bodyGenerator(builder, newFunc); 1032 1033 // Now back to where we were adding code earlier... 
1034 builder.restoreInsertionPoint(insertPt); 1035 1036 return newFunc; 1037 } 1038 1039 void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction( 1040 fir::CallOp call, const fir::KindMapping &kindMap, 1041 GenReductionBodyTy genBodyFunc) { 1042 // args[1] and args[2] are source filename and line number, ignored. 1043 mlir::Operation::operand_range args = call.getArgs(); 1044 1045 const mlir::Value &dim = args[3]; 1046 const mlir::Value &mask = args[4]; 1047 // dim is zero when it is absent, which is an implementation 1048 // detail in the runtime library. 1049 1050 bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask); 1051 unsigned rank = getDimCount(args[0]); 1052 1053 // Rank is set to 0 for assumed shape arrays, don't simplify 1054 // in these cases 1055 if (!(dimAndMaskAbsent && rank > 0)) 1056 return; 1057 1058 mlir::Type resultType = call.getResult(0).getType(); 1059 1060 if (!mlir::isa<mlir::FloatType>(resultType) && 1061 !mlir::isa<mlir::IntegerType>(resultType)) 1062 return; 1063 1064 auto argType = getArgElementType(args[0]); 1065 if (!argType) 1066 return; 1067 assert(*argType == resultType && 1068 "Argument/result types mismatch in reduction"); 1069 1070 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1071 1072 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1073 std::string fmfString{builder.getFastMathFlagsString()}; 1074 std::string funcName = 1075 (mlir::Twine{callee.getLeafReference().getValue(), "x"} + 1076 mlir::Twine{rank} + 1077 // We must mangle the generated function name with FastMathFlags 1078 // value. 1079 (fmfString.empty() ? 
mlir::Twine{} : mlir::Twine{"_", fmfString})) 1080 .str(); 1081 1082 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 1083 resultType); 1084 } 1085 1086 void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction( 1087 fir::CallOp call, const fir::KindMapping &kindMap, 1088 GenReductionBodyTy genBodyFunc) { 1089 1090 mlir::Operation::operand_range args = call.getArgs(); 1091 const mlir::Value &dim = args[3]; 1092 unsigned rank = getDimCount(args[0]); 1093 1094 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in 1095 // these cases. 1096 if (!(isZero(dim) && rank > 0)) 1097 return; 1098 1099 mlir::Value inputBox = findBoxDef(args[0]); 1100 1101 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType()); 1102 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1103 1104 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1105 1106 // Treating logicals as integers makes things a lot easier 1107 fir::LogicalType logicalType = { 1108 mlir::dyn_cast<fir::LogicalType>(elementType)}; 1109 fir::KindTy kind = logicalType.getFKind(); 1110 mlir::Type intElementType = builder.getIntegerType(kind * 8); 1111 1112 // Mangle kind into function name as it is not done by default 1113 std::string funcName = 1114 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} + 1115 mlir::Twine{kind} + "x" + mlir::Twine{rank}) 1116 .str(); 1117 1118 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 1119 intElementType); 1120 } 1121 1122 void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction( 1123 fir::CallOp call, const fir::KindMapping &kindMap, 1124 GenReductionBodyTy genBodyFunc) { 1125 1126 mlir::Operation::operand_range args = call.getArgs(); 1127 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1128 mlir::StringRef funcNameBase = callee.getLeafReference().getValue(); 1129 unsigned rank = getDimCount(args[0]); 1130 1131 // getDimCount returns a rank of 0 for assumed shape arrays, don't 
simplify in 1132 // these cases. We check for Dim at the end as some logical functions (Any, 1133 // All) set dim to 1 instead of 0 when the argument is not present. 1134 if (funcNameBase.ends_with("Dim") || !(rank > 0)) 1135 return; 1136 1137 mlir::Value inputBox = findBoxDef(args[0]); 1138 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType()); 1139 1140 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1141 1142 // Treating logicals as integers makes things a lot easier 1143 fir::LogicalType logicalType = { 1144 mlir::dyn_cast<fir::LogicalType>(elementType)}; 1145 fir::KindTy kind = logicalType.getFKind(); 1146 mlir::Type intElementType = builder.getIntegerType(kind * 8); 1147 1148 // Mangle kind into function name as it is not done by default 1149 std::string funcName = 1150 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} + 1151 mlir::Twine{kind} + "x" + mlir::Twine{rank}) 1152 .str(); 1153 1154 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 1155 intElementType); 1156 } 1157 1158 void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction( 1159 fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) { 1160 1161 mlir::Operation::operand_range args = call.getArgs(); 1162 1163 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1164 mlir::StringRef funcNameBase = callee.getLeafReference().getValue(); 1165 bool isDim = funcNameBase.ends_with("Dim"); 1166 mlir::Value back = args[isDim ? 7 : 6]; 1167 if (isTrueOrNotConstant(back)) 1168 return; 1169 1170 mlir::Value mask = args[isDim ? 6 : 5]; 1171 mlir::Value maskDef = findMaskDef(mask); 1172 1173 // maskDef is set to NULL when the defining op is not one we accept. 1174 // This tends to be because it is a selectOp, in which case let the 1175 // runtime deal with it. 
1176 if (maskDef == NULL) 1177 return; 1178 1179 unsigned rank = getDimCount(args[1]); 1180 if ((isDim && rank != 1) || !(rank > 0)) 1181 return; 1182 1183 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1184 mlir::Location loc = call.getLoc(); 1185 auto inputBox = findBoxDef(args[1]); 1186 mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType()); 1187 1188 if (mlir::isa<fir::CharacterType>(inputType)) 1189 return; 1190 1191 int maskRank; 1192 fir::KindTy kind = 0; 1193 mlir::Type logicalElemType = builder.getI1Type(); 1194 if (isOperandAbsent(mask)) { 1195 maskRank = -1; 1196 } else { 1197 maskRank = getDimCount(mask); 1198 mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType()); 1199 fir::LogicalType logicalFirType = { 1200 mlir::dyn_cast<fir::LogicalType>(maskElemTy)}; 1201 kind = logicalFirType.getFKind(); 1202 // Convert fir::LogicalType to mlir::Type 1203 logicalElemType = logicalFirType; 1204 } 1205 1206 mlir::Operation *outputDef = args[0].getDefiningOp(); 1207 mlir::Value outputAlloc = outputDef->getOperand(0); 1208 mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType()); 1209 1210 std::string fmfString{builder.getFastMathFlagsString()}; 1211 std::string funcName = 1212 (mlir::Twine{callee.getLeafReference().getValue(), "x"} + 1213 mlir::Twine{rank} + 1214 (maskRank >= 0 1215 ? 
"_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank} 1216 : "") + 1217 "_") 1218 .str(); 1219 1220 llvm::raw_string_ostream nameOS(funcName); 1221 outType.print(nameOS); 1222 if (isDim) 1223 nameOS << '_' << inputType; 1224 nameOS << '_' << fmfString; 1225 1226 auto typeGenerator = [rank](fir::FirOpBuilder &builder) { 1227 return genRuntimeMinlocType(builder, rank); 1228 }; 1229 auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType, 1230 isMax, isDim](fir::FirOpBuilder &builder, 1231 mlir::func::FuncOp &funcOp) { 1232 genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType, 1233 logicalElemType, outType, isDim); 1234 }; 1235 1236 mlir::func::FuncOp newFunc = 1237 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); 1238 builder.create<fir::CallOp>(loc, newFunc, 1239 mlir::ValueRange{args[0], args[1], mask}); 1240 call->dropAllReferences(); 1241 call->erase(); 1242 } 1243 1244 void SimplifyIntrinsicsPass::simplifyReductionBody( 1245 fir::CallOp call, const fir::KindMapping &kindMap, 1246 GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder, 1247 const mlir::StringRef &funcName, mlir::Type elementType) { 1248 1249 mlir::Operation::operand_range args = call.getArgs(); 1250 1251 mlir::Type resultType = call.getResult(0).getType(); 1252 unsigned rank = getDimCount(args[0]); 1253 1254 mlir::Location loc = call.getLoc(); 1255 1256 auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) { 1257 return genNoneBoxType(builder, resultType); 1258 }; 1259 auto bodyGenerator = [&rank, &genBodyFunc, 1260 &elementType](fir::FirOpBuilder &builder, 1261 mlir::func::FuncOp &funcOp) { 1262 genBodyFunc(builder, funcOp, rank, elementType); 1263 }; 1264 // Mangle the function name with the rank value as "x<rank>". 
1265 mlir::func::FuncOp newFunc = 1266 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); 1267 auto newCall = 1268 builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]}); 1269 call->replaceAllUsesWith(newCall.getResults()); 1270 call->dropAllReferences(); 1271 call->erase(); 1272 } 1273 1274 void SimplifyIntrinsicsPass::runOnOperation() { 1275 LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n"); 1276 mlir::ModuleOp module = getOperation(); 1277 fir::KindMapping kindMap = fir::getKindMapping(module); 1278 module.walk([&](mlir::Operation *op) { 1279 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) { 1280 if (cuf::isInCUDADeviceContext(op)) 1281 return; 1282 if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) { 1283 mlir::StringRef funcName = callee.getLeafReference().getValue(); 1284 // Replace call to runtime function for SUM when it has single 1285 // argument (no dim or mask argument) for 1D arrays with either 1286 // Integer4 or Real8 types. Other forms are ignored. 1287 // The new function is added to the module. 1288 // 1289 // Prototype for runtime call (from sum.cpp): 1290 // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line, 1291 // int dim, const Descriptor *mask) 1292 // 1293 if (funcName.starts_with(RTNAME_STRING(Sum))) { 1294 simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody); 1295 return; 1296 } 1297 if (funcName.starts_with(RTNAME_STRING(DotProduct))) { 1298 LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n"); 1299 LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump(); 1300 llvm::dbgs() << "\n"); 1301 mlir::Operation::operand_range args = call.getArgs(); 1302 const mlir::Value &v1 = args[0]; 1303 const mlir::Value &v2 = args[1]; 1304 mlir::Location loc = call.getLoc(); 1305 fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)}; 1306 // Stringize the builder's FastMathFlags flags for mangling 1307 // the generated function name. 
1308 std::string fmfString{builder.getFastMathFlagsString()}; 1309 1310 mlir::Type type = call.getResult(0).getType(); 1311 if (!mlir::isa<mlir::FloatType>(type) && 1312 !mlir::isa<mlir::IntegerType>(type)) 1313 return; 1314 1315 // Try to find the element types of the boxed arguments. 1316 auto arg1Type = getArgElementType(v1); 1317 auto arg2Type = getArgElementType(v2); 1318 1319 if (!arg1Type || !arg2Type) 1320 return; 1321 1322 // Support only floating point and integer arguments 1323 // now (e.g. logical is skipped here). 1324 if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg1Type)) 1325 return; 1326 if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg2Type)) 1327 return; 1328 1329 auto typeGenerator = [&type](fir::FirOpBuilder &builder) { 1330 return genRuntimeDotType(builder, type); 1331 }; 1332 auto bodyGenerator = [&arg1Type, 1333 &arg2Type](fir::FirOpBuilder &builder, 1334 mlir::func::FuncOp &funcOp) { 1335 genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type); 1336 }; 1337 1338 // Suffix the function name with the element types 1339 // of the arguments. 1340 std::string typedFuncName(funcName); 1341 llvm::raw_string_ostream nameOS(typedFuncName); 1342 // We must mangle the generated function name with FastMathFlags 1343 // value. 
1344 if (!fmfString.empty()) 1345 nameOS << '_' << fmfString; 1346 nameOS << '_'; 1347 arg1Type->print(nameOS); 1348 nameOS << '_'; 1349 arg2Type->print(nameOS); 1350 1351 mlir::func::FuncOp newFunc = getOrCreateFunction( 1352 builder, typedFuncName, typeGenerator, bodyGenerator); 1353 auto newCall = builder.create<fir::CallOp>(loc, newFunc, 1354 mlir::ValueRange{v1, v2}); 1355 call->replaceAllUsesWith(newCall.getResults()); 1356 call->dropAllReferences(); 1357 call->erase(); 1358 1359 LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump(); 1360 llvm::dbgs() << "\n"); 1361 return; 1362 } 1363 if (funcName.starts_with(RTNAME_STRING(Maxval))) { 1364 simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody); 1365 return; 1366 } 1367 if (funcName.starts_with(RTNAME_STRING(Count))) { 1368 simplifyLogicalDim0Reduction(call, kindMap, genRuntimeCountBody); 1369 return; 1370 } 1371 if (funcName.starts_with(RTNAME_STRING(Any))) { 1372 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAnyBody); 1373 return; 1374 } 1375 if (funcName.ends_with(RTNAME_STRING(All))) { 1376 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAllBody); 1377 return; 1378 } 1379 if (funcName.starts_with(RTNAME_STRING(Minloc))) { 1380 simplifyMinMaxlocReduction(call, kindMap, false); 1381 return; 1382 } 1383 if (funcName.starts_with(RTNAME_STRING(Maxloc))) { 1384 simplifyMinMaxlocReduction(call, kindMap, true); 1385 return; 1386 } 1387 } 1388 } 1389 }); 1390 LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n"); 1391 } 1392 1393 void SimplifyIntrinsicsPass::getDependentDialects( 1394 mlir::DialectRegistry ®istry) const { 1395 // LLVM::LinkageAttr creation requires that LLVM dialect is loaded. 1396 registry.insert<mlir::LLVM::LLVMDialect>(); 1397 } 1398