//===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
/// \file
/// This pass looks for suitable calls to the runtime library for intrinsics
/// that can be simplified/specialized, and replaces them with calls to a
/// specialized function.
///
/// For example, SUM(arr) can be specialized as a simple function with one loop
/// when the argument is a 1D array (multiple loops may be needed for
/// higher-dimension arrays, of course), compared to the three arguments (plus
/// file & line info) that the runtime call takes.
///
/// The general idea is that besides making the call simpler, the specialized
/// function can also be inlined by other passes that run after this pass,
/// which further improves performance, particularly when the work done in the
/// function is trivial and small in size.
23 //===----------------------------------------------------------------------===// 24 25 #include "flang/Common/Fortran.h" 26 #include "flang/Optimizer/Builder/BoxValue.h" 27 #include "flang/Optimizer/Builder/FIRBuilder.h" 28 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" 29 #include "flang/Optimizer/Builder/Todo.h" 30 #include "flang/Optimizer/Dialect/FIROps.h" 31 #include "flang/Optimizer/Dialect/FIRType.h" 32 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 33 #include "flang/Optimizer/HLFIR/HLFIRDialect.h" 34 #include "flang/Optimizer/Transforms/Passes.h" 35 #include "flang/Runtime/entry-names.h" 36 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 37 #include "mlir/IR/Matchers.h" 38 #include "mlir/IR/Operation.h" 39 #include "mlir/Pass/Pass.h" 40 #include "mlir/Transforms/DialectConversion.h" 41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 42 #include "mlir/Transforms/RegionUtils.h" 43 #include "llvm/Support/Debug.h" 44 #include "llvm/Support/raw_ostream.h" 45 #include <llvm/Support/ErrorHandling.h> 46 #include <mlir/Dialect/Arith/IR/Arith.h> 47 #include <mlir/IR/BuiltinTypes.h> 48 #include <mlir/IR/Location.h> 49 #include <mlir/IR/MLIRContext.h> 50 #include <mlir/IR/Value.h> 51 #include <mlir/Support/LLVM.h> 52 #include <optional> 53 54 namespace fir { 55 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS 56 #include "flang/Optimizer/Transforms/Passes.h.inc" 57 } // namespace fir 58 59 #define DEBUG_TYPE "flang-simplify-intrinsics" 60 61 namespace { 62 63 class SimplifyIntrinsicsPass 64 : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> { 65 using FunctionTypeGeneratorTy = 66 llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>; 67 using FunctionBodyGeneratorTy = 68 llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>; 69 using GenReductionBodyTy = llvm::function_ref<void( 70 fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank, 71 mlir::Type elementType)>; 72 73 public: 74 /// Generate a new 
function implementing a simplified version 75 /// of a Fortran runtime function defined by \p basename name. 76 /// \p typeGenerator is a callback that generates the new function's type. 77 /// \p bodyGenerator is a callback that generates the new function's body. 78 /// The new function is created in the \p builder's Module. 79 mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder, 80 const mlir::StringRef &basename, 81 FunctionTypeGeneratorTy typeGenerator, 82 FunctionBodyGeneratorTy bodyGenerator); 83 void runOnOperation() override; 84 void getDependentDialects(mlir::DialectRegistry ®istry) const override; 85 86 private: 87 /// Helper functions to replace a reduction type of call with its 88 /// simplified form. The actual function is generated using a callback 89 /// function. 90 /// \p call is the call to be replaced 91 /// \p kindMap is used to create FIROpBuilder 92 /// \p genBodyFunc is the callback that builds the replacement function 93 void simplifyIntOrFloatReduction(fir::CallOp call, 94 const fir::KindMapping &kindMap, 95 GenReductionBodyTy genBodyFunc); 96 void simplifyLogicalDim0Reduction(fir::CallOp call, 97 const fir::KindMapping &kindMap, 98 GenReductionBodyTy genBodyFunc); 99 void simplifyLogicalDim1Reduction(fir::CallOp call, 100 const fir::KindMapping &kindMap, 101 GenReductionBodyTy genBodyFunc); 102 void simplifyMinMaxlocReduction(fir::CallOp call, 103 const fir::KindMapping &kindMap, bool isMax); 104 void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap, 105 GenReductionBodyTy genBodyFunc, 106 fir::FirOpBuilder &builder, 107 const mlir::StringRef &basename, 108 mlir::Type elementType); 109 }; 110 111 } // namespace 112 113 /// Create FirOpBuilder with the provided \p op insertion point 114 /// and \p kindMap additionally inheriting FastMathFlags from \p op. 
115 static fir::FirOpBuilder 116 getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) { 117 fir::FirOpBuilder builder{op, kindMap}; 118 auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op); 119 if (!fmi) 120 return builder; 121 122 // Regardless of what default FastMathFlags are used by FirOpBuilder, 123 // override them with FastMathFlags attached to the operation. 124 builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue()); 125 return builder; 126 } 127 128 /// Generate function type for the simplified version of RTNAME(Sum) and 129 /// similar functions with a fir.box<none> type returning \p elementType. 130 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder, 131 const mlir::Type &elementType) { 132 mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); 133 return mlir::FunctionType::get(builder.getContext(), {boxType}, 134 {elementType}); 135 } 136 137 template <typename Op> 138 Op expectOp(mlir::Value val) { 139 if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp())) 140 return op; 141 LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName() 142 << '\n'); 143 return nullptr; 144 } 145 146 template <typename Op> 147 static mlir::Value findDefSingle(fir::ConvertOp op) { 148 if (auto defOp = expectOp<Op>(op->getOperand(0))) { 149 return defOp.getResult(); 150 } 151 return {}; 152 } 153 154 template <typename... Ops> 155 static mlir::Value findDef(fir::ConvertOp op) { 156 mlir::Value defOp; 157 // Loop over the operation types given to see if any match, exiting once 158 // a match is found. 
Cast to void is needed to avoid compiler complaining 159 // that the result of expression is unused 160 (void)((defOp = findDefSingle<Ops>(op), (defOp)) || ...); 161 return defOp; 162 } 163 164 static bool isOperandAbsent(mlir::Value val) { 165 if (auto op = expectOp<fir::ConvertOp>(val)) { 166 assert(op->getOperands().size() != 0); 167 return mlir::isa_and_nonnull<fir::AbsentOp>( 168 op->getOperand(0).getDefiningOp()); 169 } 170 return false; 171 } 172 173 static bool isTrueOrNotConstant(mlir::Value val) { 174 if (auto op = expectOp<mlir::arith::ConstantOp>(val)) { 175 return !mlir::matchPattern(val, mlir::m_Zero()); 176 } 177 return true; 178 } 179 180 static bool isZero(mlir::Value val) { 181 if (auto op = expectOp<fir::ConvertOp>(val)) { 182 assert(op->getOperands().size() != 0); 183 if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp()) 184 return mlir::matchPattern(defOp, mlir::m_Zero()); 185 } 186 return false; 187 } 188 189 static mlir::Value findBoxDef(mlir::Value val) { 190 if (auto op = expectOp<fir::ConvertOp>(val)) { 191 assert(op->getOperands().size() != 0); 192 return findDef<fir::EmboxOp, fir::ReboxOp>(op); 193 } 194 return {}; 195 } 196 197 static mlir::Value findMaskDef(mlir::Value val) { 198 if (auto op = expectOp<fir::ConvertOp>(val)) { 199 assert(op->getOperands().size() != 0); 200 return findDef<fir::EmboxOp, fir::ReboxOp, fir::AbsentOp>(op); 201 } 202 return {}; 203 } 204 205 static unsigned getDimCount(mlir::Value val) { 206 // In order to find the dimensions count, we look for EmboxOp/ReboxOp 207 // and take the count from its *result* type. Note that in case 208 // of sliced emboxing the operand and the result of EmboxOp/ReboxOp 209 // have different types. 210 // Actually, we can take the box type from the operand of 211 // the first ConvertOp that has non-opaque box type that we meet 212 // going through the ConvertOp chain. 
213 if (mlir::Value emboxVal = findBoxDef(val)) 214 if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>()) 215 if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>()) 216 return seqTy.getDimension(); 217 return 0; 218 } 219 220 /// Given the call operation's box argument \p val, discover 221 /// the element type of the underlying array object. 222 /// \returns the element type or std::nullopt if the type cannot 223 /// be reliably found. 224 /// We expect that the argument is a result of fir.convert 225 /// with the destination type of !fir.box<none>. 226 static std::optional<mlir::Type> getArgElementType(mlir::Value val) { 227 mlir::Operation *defOp; 228 do { 229 defOp = val.getDefiningOp(); 230 // Analyze only sequences of convert operations. 231 if (!mlir::isa<fir::ConvertOp>(defOp)) 232 return std::nullopt; 233 val = defOp->getOperand(0); 234 // The convert operation is expected to convert from one 235 // box type to another box type. 236 auto boxType = val.getType().cast<fir::BoxType>(); 237 auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType); 238 if (!elementType.isa<mlir::NoneType>()) 239 return elementType; 240 } while (true); 241 } 242 243 using BodyOpGeneratorTy = llvm::function_ref<mlir::Value( 244 fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value, 245 mlir::Value)>; 246 using InitValGeneratorTy = llvm::function_ref<mlir::Value( 247 fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>; 248 using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>( 249 fir::FirOpBuilder &, mlir::Location, mlir::Value)>; 250 251 /// Generate the reduction loop into \p funcOp. 252 /// 253 /// \p initVal is a function, called to get the initial value for 254 /// the reduction value 255 /// \p genBody is called to fill in the actual reduciton operation 256 /// for example add for SUM, MAX for MAXVAL, etc. 257 /// \p rank is the rank of the input argument. 
258 /// \p elementType is the type of the elements in the input array, 259 /// which may be different to the return type. 260 /// \p loopCond is called to generate the condition to continue or 261 /// not for IterWhile loops 262 /// \p unorderedOrInitalLoopCond contains either a boolean or bool 263 /// mlir constant, and controls the inital value for while loops 264 /// or if DoLoop is ordered/unordered. 265 266 template <typename OP, typename T, int resultIndex> 267 static void 268 genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, 269 InitValGeneratorTy initVal, ContinueLoopGenTy loopCond, 270 T unorderedOrInitialLoopCond, BodyOpGeneratorTy genBody, 271 unsigned rank, mlir::Type elementType, mlir::Location loc) { 272 273 mlir::IndexType idxTy = builder.getIndexType(); 274 275 mlir::Block::BlockArgListType args = funcOp.front().getArguments(); 276 mlir::Value arg = args[0]; 277 278 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 279 280 fir::SequenceType::Shape flatShape(rank, 281 fir::SequenceType::getUnknownExtent()); 282 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType); 283 mlir::Type boxArrTy = fir::BoxType::get(arrTy); 284 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg); 285 mlir::Type resultType = funcOp.getResultTypes()[0]; 286 mlir::Value init = initVal(builder, loc, resultType); 287 288 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds; 289 290 assert(rank > 0 && "rank cannot be zero"); 291 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 292 293 // Compute all the upper bounds before the loop nest. 294 // It is not strictly necessary for performance, since the loop nest 295 // does not have any store operations and any LICM optimization 296 // should be able to optimize the redundancy. 
297 for (unsigned i = 0; i < rank; ++i) { 298 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); 299 auto dims = 300 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx); 301 mlir::Value len = dims.getResult(1); 302 // We use C indexing here, so len-1 as loopcount 303 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 304 bounds.push_back(loopCount); 305 } 306 // Create a loop nest consisting of OP operations. 307 // Collect the loops' induction variables into indices array, 308 // which will be used in the innermost loop to load the input 309 // array's element. 310 // The loops are generated such that the innermost loop processes 311 // the 0 dimension. 312 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices; 313 for (unsigned i = rank; 0 < i; --i) { 314 mlir::Value step = one; 315 mlir::Value loopCount = bounds[i - 1]; 316 auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step, 317 unorderedOrInitialLoopCond, 318 /*finalCountValue=*/false, init); 319 init = loop.getRegionIterArgs()[resultIndex]; 320 indices.push_back(loop.getInductionVar()); 321 // Set insertion point to the loop body so that the next loop 322 // is inserted inside the current one. 323 builder.setInsertionPointToStart(loop.getBody()); 324 } 325 326 // Reverse the indices such that they are ordered as: 327 // <dim-0-idx, dim-1-idx, ...> 328 std::reverse(indices.begin(), indices.end()); 329 // We are in the innermost loop: generate the reduction body. 330 mlir::Type eleRefTy = builder.getRefType(elementType); 331 mlir::Value addr = 332 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices); 333 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr); 334 mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init); 335 // Generate vector with condition to continue while loop at [0] and result 336 // from current loop at [1] for IterWhileOp loops, just result at [0] for 337 // DoLoopOp loops. 
338 llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal); 339 340 // Unwind the loop nest and insert ResultOp on each level 341 // to return the updated value of the reduction to the enclosing 342 // loops. 343 for (unsigned i = 0; i < rank; ++i) { 344 auto result = builder.create<fir::ResultOp>(loc, results); 345 // Proceed to the outer loop. 346 auto loop = mlir::cast<OP>(result->getParentOp()); 347 results = loop.getResults(); 348 // Set insertion point after the loop operation that we have 349 // just processed. 350 builder.setInsertionPointAfter(loop.getOperation()); 351 } 352 // End of loop nest. The insertion point is after the outermost loop. 353 // Return the reduction value from the function. 354 builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]); 355 } 356 using MinMaxlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value( 357 fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value, 358 mlir::Value, llvm::SmallVector<mlir::Value, Fortran::common::maxRank> &)>; 359 360 static void genMinMaxlocReductionLoop( 361 fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, 362 InitValGeneratorTy initVal, MinMaxlocBodyOpGeneratorTy genBody, 363 unsigned rank, mlir::Type elementType, mlir::Location loc, bool hasMask, 364 mlir::Type maskElemType, mlir::Value resultArr) { 365 366 mlir::IndexType idxTy = builder.getIndexType(); 367 368 mlir::Block::BlockArgListType args = funcOp.front().getArguments(); 369 mlir::Value arg = args[1]; 370 371 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 372 373 fir::SequenceType::Shape flatShape(rank, 374 fir::SequenceType::getUnknownExtent()); 375 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType); 376 mlir::Type boxArrTy = fir::BoxType::get(arrTy); 377 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg); 378 379 mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType()); 380 mlir::Value flagSet = 
builder.createIntegerConstant(loc, resultElemType, 1); 381 mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0); 382 mlir::Value flagRef = builder.createTemporary(loc, resultElemType); 383 builder.create<fir::StoreOp>(loc, zero, flagRef); 384 385 mlir::Value mask; 386 if (hasMask) { 387 mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType); 388 mlir::Type boxMaskTy = fir::BoxType::get(maskTy); 389 mask = builder.create<fir::ConvertOp>(loc, boxMaskTy, args[2]); 390 } 391 392 mlir::Value init = initVal(builder, loc, elementType); 393 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds; 394 395 assert(rank > 0 && "rank cannot be zero"); 396 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 397 398 // Compute all the upper bounds before the loop nest. 399 // It is not strictly necessary for performance, since the loop nest 400 // does not have any store operations and any LICM optimization 401 // should be able to optimize the redundancy. 402 for (unsigned i = 0; i < rank; ++i) { 403 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); 404 auto dims = 405 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx); 406 mlir::Value len = dims.getResult(1); 407 // We use C indexing here, so len-1 as loopcount 408 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 409 bounds.push_back(loopCount); 410 } 411 // Create a loop nest consisting of OP operations. 412 // Collect the loops' induction variables into indices array, 413 // which will be used in the innermost loop to load the input 414 // array's element. 415 // The loops are generated such that the innermost loop processes 416 // the 0 dimension. 
417 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices; 418 for (unsigned i = rank; 0 < i; --i) { 419 mlir::Value step = one; 420 mlir::Value loopCount = bounds[i - 1]; 421 auto loop = 422 builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false, 423 /*finalCountValue=*/false, init); 424 init = loop.getRegionIterArgs()[0]; 425 indices.push_back(loop.getInductionVar()); 426 // Set insertion point to the loop body so that the next loop 427 // is inserted inside the current one. 428 builder.setInsertionPointToStart(loop.getBody()); 429 } 430 431 // Reverse the indices such that they are ordered as: 432 // <dim-0-idx, dim-1-idx, ...> 433 std::reverse(indices.begin(), indices.end()); 434 // We are in the innermost loop: generate the reduction body. 435 if (hasMask) { 436 mlir::Type logicalRef = builder.getRefType(maskElemType); 437 mlir::Value maskAddr = 438 builder.create<fir::CoordinateOp>(loc, logicalRef, mask, indices); 439 mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr); 440 441 // fir::IfOp requires argument to be I1 - won't accept logical or any other 442 // Integer. 
443 mlir::Type ifCompatType = builder.getI1Type(); 444 mlir::Value ifCompatElem = 445 builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem); 446 447 llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType}; 448 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, ifCompatElem, 449 /*withElseRegion=*/true); 450 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 451 } 452 453 // Set flag that mask was true at some point 454 builder.create<fir::StoreOp>(loc, flagSet, flagRef); 455 mlir::Type eleRefTy = builder.getRefType(elementType); 456 mlir::Value addr = 457 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices); 458 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr); 459 460 mlir::Value reductionVal = 461 genBody(builder, loc, elementType, elem, init, indices); 462 463 if (hasMask) { 464 fir::IfOp ifOp = 465 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp()); 466 builder.create<fir::ResultOp>(loc, reductionVal); 467 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 468 builder.create<fir::ResultOp>(loc, init); 469 reductionVal = ifOp.getResult(0); 470 builder.setInsertionPointAfter(ifOp); 471 } 472 473 // Unwind the loop nest and insert ResultOp on each level 474 // to return the updated value of the reduction to the enclosing 475 // loops. 476 for (unsigned i = 0; i < rank; ++i) { 477 auto result = builder.create<fir::ResultOp>(loc, reductionVal); 478 // Proceed to the outer loop. 479 auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp()); 480 reductionVal = loop.getResult(0); 481 // Set insertion point after the loop operation that we have 482 // just processed. 483 builder.setInsertionPointAfter(loop.getOperation()); 484 } 485 // End of loop nest. The insertion point is after the outermost loop. 
486 if (fir::IfOp ifOp = 487 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) { 488 builder.create<fir::ResultOp>(loc, reductionVal); 489 builder.setInsertionPointAfter(ifOp); 490 // Redefine flagSet to escape scope of ifOp 491 flagSet = builder.createIntegerConstant(loc, resultElemType, 1); 492 reductionVal = ifOp.getResult(0); 493 } 494 495 // Check for case where array was full of max values. 496 // flag will be 0 if mask was never true, 1 if mask was true as some point, 497 // this is needed to avoid catching cases where we didn't access any elements 498 // e.g. mask=.FALSE. 499 mlir::Value flagValue = 500 builder.create<fir::LoadOp>(loc, resultElemType, flagRef); 501 mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>( 502 loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet); 503 fir::IfOp ifMaskTrueOp = 504 builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false); 505 builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front()); 506 507 mlir::Value testInit = initVal(builder, loc, elementType); 508 fir::IfOp ifMinSetOp; 509 if (elementType.isa<mlir::FloatType>()) { 510 mlir::Value cmp = builder.create<mlir::arith::CmpFOp>( 511 loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal); 512 ifMinSetOp = builder.create<fir::IfOp>(loc, cmp, 513 /*withElseRegion*/ false); 514 } else { 515 mlir::Value cmp = builder.create<mlir::arith::CmpIOp>( 516 loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal); 517 ifMinSetOp = builder.create<fir::IfOp>(loc, cmp, 518 /*withElseRegion*/ false); 519 } 520 builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front()); 521 522 // Load output array with 1s instead of 0s 523 for (unsigned int i = 0; i < rank; ++i) { 524 mlir::Type resultRefTy = builder.getRefType(resultElemType); 525 // mlir::Value one = builder.createIntegerConstant(loc, resultElemType, 1); 526 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i); 527 mlir::Value resultElemAddr = 528 
builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index); 529 builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr); 530 } 531 builder.setInsertionPointAfter(ifMaskTrueOp); 532 // Store newly created output array to the reference passed in 533 fir::SequenceType::Shape resultShape(1, rank); 534 mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemType); 535 mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy); 536 mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy); 537 mlir::Type outputRefTy = builder.getRefType(outputBoxTy); 538 539 mlir::Value outputArrNone = args[0]; 540 mlir::Value outputArr = 541 builder.create<fir::ConvertOp>(loc, outputRefTy, outputArrNone); 542 543 // Store nearly created array to output array 544 builder.create<fir::StoreOp>(loc, resultArr, outputArr); 545 builder.create<mlir::func::ReturnOp>(loc); 546 } 547 548 static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder, 549 mlir::Location loc, 550 mlir::Value reductionVal) { 551 return {reductionVal}; 552 } 553 554 /// Generate function body of the simplified version of RTNAME(Sum) 555 /// with signature provided by \p funcOp. The caller is responsible 556 /// for saving/restoring the original insertion point of \p builder. 557 /// \p funcOp is expected to be empty on entry to this function. 558 /// \p rank specifies the rank of the input argument. 
559 static void genRuntimeSumBody(fir::FirOpBuilder &builder, 560 mlir::func::FuncOp &funcOp, unsigned rank, 561 mlir::Type elementType) { 562 // function RTNAME(Sum)<T>x<rank>_simplified(arr) 563 // T, dimension(:) :: arr 564 // T sum = 0 565 // integer iter 566 // do iter = 0, extent(arr) 567 // sum = sum + arr[iter] 568 // end do 569 // RTNAME(Sum)<T>x<rank>_simplified = sum 570 // end function RTNAME(Sum)<T>x<rank>_simplified 571 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, 572 mlir::Type elementType) { 573 if (auto ty = elementType.dyn_cast<mlir::FloatType>()) { 574 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 575 return builder.createRealConstant(loc, elementType, 576 llvm::APFloat::getZero(sem)); 577 } 578 return builder.createIntegerConstant(loc, elementType, 0); 579 }; 580 581 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 582 mlir::Type elementType, mlir::Value elem1, 583 mlir::Value elem2) -> mlir::Value { 584 if (elementType.isa<mlir::FloatType>()) 585 return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2); 586 if (elementType.isa<mlir::IntegerType>()) 587 return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2); 588 589 llvm_unreachable("unsupported type"); 590 return {}; 591 }; 592 593 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 594 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 595 596 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond, 597 false, genBodyOp, rank, elementType, 598 loc); 599 } 600 601 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder, 602 mlir::func::FuncOp &funcOp, unsigned rank, 603 mlir::Type elementType) { 604 auto init = [](fir::FirOpBuilder builder, mlir::Location loc, 605 mlir::Type elementType) { 606 if (auto ty = elementType.dyn_cast<mlir::FloatType>()) { 607 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 608 return builder.createRealConstant( 609 loc, elementType, 
llvm::APFloat::getLargest(sem, /*Negative=*/true)); 610 } 611 unsigned bits = elementType.getIntOrFloatBitWidth(); 612 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 613 return builder.createIntegerConstant(loc, elementType, minInt); 614 }; 615 616 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, 617 mlir::Type elementType, mlir::Value elem1, 618 mlir::Value elem2) -> mlir::Value { 619 if (elementType.isa<mlir::FloatType>()) { 620 // arith.maxf later converted to llvm.intr.maxnum does not work 621 // correctly for NaNs and -0.0 (see maxnum/minnum pattern matching 622 // in LLVM's InstCombine pass). Moreover, llvm.intr.maxnum 623 // for F128 operands is lowered into fmaxl call by LLVM. 624 // This libm function may not work properly for F128 arguments 625 // on targets where long double is not F128. It is an LLVM issue, 626 // but we just use normal select here to resolve all the cases. 627 auto compare = builder.create<mlir::arith::CmpFOp>( 628 loc, mlir::arith::CmpFPredicate::OGT, elem1, elem2); 629 return builder.create<mlir::arith::SelectOp>(loc, compare, elem1, elem2); 630 } 631 if (elementType.isa<mlir::IntegerType>()) 632 return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2); 633 634 llvm_unreachable("unsupported type"); 635 return {}; 636 }; 637 638 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext()); 639 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 640 641 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, init, nopLoopCond, 642 false, genBodyOp, rank, elementType, 643 loc); 644 } 645 646 static void genRuntimeCountBody(fir::FirOpBuilder &builder, 647 mlir::func::FuncOp &funcOp, unsigned rank, 648 mlir::Type elementType) { 649 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, 650 mlir::Type elementType) { 651 unsigned bits = elementType.getIntOrFloatBitWidth(); 652 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 653 return 
builder.createIntegerConstant(loc, elementType, zeroInt);
  };

  // Per-element step: returns elem2 + (elem1 == 0 ? 0 : 1), so the running
  // count grows by one for every non-zero (.TRUE.) element. The 0/1
  // increments are created as i64 constants.
  auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
                      mlir::Type elementType, mlir::Value elem1,
                      mlir::Value elem2) -> mlir::Value {
    auto zero32 = builder.createIntegerConstant(loc, elementType, 0);
    auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0);
    auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1);

    auto compare = builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::eq, elem1, zero32);
    auto select =
        builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64);
    return builder.create<mlir::arith::AddIOp>(loc, select, elem2);
  };

  // Count always gets I32 for elementType as it converts logical input to
  // logical<4> before passing to the function.
  // NOTE(review): simplifyLogicalDim0Reduction passes an integer type of
  // (kind * 8) bits as elementType, so the I32 statement may be stale -
  // confirm.
  mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
  builder.setInsertionPointToEnd(funcOp.addEntryBlock());

  genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
                                           false, genBodyOp, rank, elementType,
                                           loc);
}

/// Generate the body of the simplified ANY reduction over an array of the
/// given \p rank whose logical elements are accessed as integers of type
/// \p elementType. The reduction starts at 0 (.FALSE.), each element is mapped
/// to (element != 0), and an iterate-while loop is used so that the scan can
/// stop as soon as the reduction value becomes true.
static void genRuntimeAnyBody(fir::FirOpBuilder &builder,
                              mlir::func::FuncOp &funcOp, unsigned rank,
                              mlir::Type elementType) {
  // Initial reduction value: 0 (.FALSE.).
  auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
                 mlir::Type elementType) {
    return builder.createIntegerConstant(loc, elementType, 0);
  };

  // Per-element step: the new reduction value is (elem1 != 0). The previous
  // reduction value elem2 is deliberately unused; early exit (below) makes
  // "the last visited element was true" equivalent to "any element was true".
  auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
                      mlir::Type elementType, mlir::Value elem1,
                      mlir::Value elem2) -> mlir::Value {
    auto zero = builder.createIntegerConstant(loc, elementType, 0);
    return builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
  };

  // Returns {!reductionVal, reductionVal} (XOR with i1 1 is logical NOT).
  // The first value is presumably the iterate-while "keep going" flag
  // consumed by genReductionLoop, so iteration continues only while no true
  // element has been found yet.
  auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
                         mlir::Value reductionVal) {
    auto one1 = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
    auto eor = builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1);
    llvm::SmallVector<mlir::Value> results = {eor, reductionVal};
    return results;
  };

  mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
  builder.setInsertionPointToEnd(funcOp.addEntryBlock());
  // Initial value of the iterate-while condition: start iterating.
  mlir::Value ok = builder.createBool(loc, true);

  genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
      builder, funcOp, zero, continueCond, ok, genBodyOp, rank, elementType,
      loc);
}

/// Generate the body of the simplified ALL reduction over an array of the
/// given \p rank whose logical elements are accessed as integers of type
/// \p elementType. The reduction starts at 1 (.TRUE.), each element is mapped
/// to (element != 0), and an iterate-while loop is used so that the scan can
/// stop as soon as the reduction value becomes false.
static void genRuntimeAllBody(fir::FirOpBuilder &builder,
                              mlir::func::FuncOp &funcOp, unsigned rank,
                              mlir::Type elementType) {
  // Initial reduction value: 1 (.TRUE.).
  auto one = [](fir::FirOpBuilder builder, mlir::Location loc,
                mlir::Type elementType) {
    return builder.createIntegerConstant(loc, elementType, 1);
  };

  // Per-element step: the new reduction value is (elem1 != 0); elem2 is
  // deliberately unused because iteration stops at the first false element.
  auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
                      mlir::Type elementType, mlir::Value elem1,
                      mlir::Value elem2) -> mlir::Value {
    auto zero = builder.createIntegerConstant(loc, elementType, 0);
    return builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
  };

  // Returns {reductionVal, reductionVal}: presumably the first value is the
  // iterate-while "keep going" flag, so iteration continues while every
  // element seen so far has been true.
  auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
                         mlir::Value reductionVal) {
    llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal};
    return results;
  };

  mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
  builder.setInsertionPointToEnd(funcOp.addEntryBlock());
  // Initial value of the iterate-while condition: start iterating.
  mlir::Value ok = builder.createBool(loc, true);

  genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
      builder, funcOp, one, continueCond, ok, genBodyOp, rank, elementType,
      loc);
}

/// Generate the function type shared by the simplified MINLOC/MAXLOC
/// functions: (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ().
/// The call site passes the result descriptor reference, the searched array
/// and the MASK descriptor, in that order. \p rank does not affect the type.
static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
                                               unsigned int rank) {
  mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
  mlir::Type boxRefType = builder.getRefType(boxType);

  return mlir::FunctionType::get(builder.getContext(),
                                 {boxRefType, boxType, boxType}, {});
}

/// Generate the body of a simplified MINLOC (\p isMax == false) or MAXLOC
/// (\p isMax == true) function whose signature was built by
/// genRuntimeMinlocType.
/// \p rank         rank of the searched array.
/// \p maskRank     rank of MASK: 0 for a scalar mask, -1 when MASK is absent.
/// \p elementType  element type of the searched array (integer or float).
/// \p maskElemType element type of MASK (forwarded to the reduction loop).
/// \p resultElemTy integer element type of the result; the result is a
///                 one-dimensional array of \p rank elements holding one
///                 1-based index per dimension.
static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
                                    mlir::func::FuncOp &funcOp, bool isMax,
                                    unsigned rank, int maskRank,
                                    mlir::Type elementType,
                                    mlir::Type maskElemType,
                                    mlir::Type resultElemTy) {
  // Initial "best so far" value. MAXLOC starts from the most negative finite
  // value (-largest float, signed minimum integer); MINLOC starts from the
  // largest positive finite value (largest float, signed maximum integer).
  auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
                      mlir::Type elementType) {
    if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
      const llvm::fltSemantics &sem = ty.getFloatSemantics();
      return builder.createRealConstant(
          loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/isMax));
    }
    unsigned bits = elementType.getIntOrFloatBitWidth();
    int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
                               : llvm::APInt::getSignedMaxValue(bits))
                            .getSExtValue();
    return builder.createIntegerConstant(loc, elementType, initValue);
  };

  mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
  builder.setInsertionPointToEnd(funcOp.addEntryBlock());

  // Third function argument: the MASK descriptor.
  mlir::Value mask = funcOp.front().getArgument(2);

  // Set up result array in case of early exit / 0 length array
  mlir::IndexType idxTy = builder.getIndexType();
  mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy);
  mlir::Type resultHeapTy = fir::HeapType::get(resultTy);
  mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy);

  mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0);
  mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank);

  // Heap-allocate the result storage (fir.allocmem) and wrap it in a
  // descriptor with a one-dimensional shape of `rank` elements.
  mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy);
  mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize);
  mlir::Value resultArr = builder.create<fir::EmboxOp>(
      loc, resultBoxTy, resultArrInit, resultArrShape);

  mlir::Type resultRefTy = builder.getRefType(resultElemTy);

  // Zero every result element; these zeros remain if no element is selected.
  for (unsigned int i = 0; i < rank; ++i) {
    mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
    mlir::Value resultElemAddr =
        builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index);
    builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
  }

  // Per-element step: elem1 is the candidate element, elem2 the current best,
  // and indices its zero-based position. If the candidate is strictly better
  // (greater for MAXLOC, less for MINLOC), its 1-based position is stored
  // into the result array and elem1 becomes the new best; otherwise elem2 is
  // kept. Because the comparison is strict, ties keep the earlier best, and
  // with the ordered float predicates (OGT/OLT) a NaN candidate never
  // replaces the best.
  auto genBodyOp =
      [&rank, &resultArr,
       isMax](fir::FirOpBuilder builder, mlir::Location loc,
              mlir::Type elementType, mlir::Value elem1, mlir::Value elem2,
              llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices)
      -> mlir::Value {
    mlir::Value cmp;
    if (elementType.isa<mlir::FloatType>()) {
      cmp = builder.create<mlir::arith::CmpFOp>(
          loc,
          isMax ? mlir::arith::CmpFPredicate::OGT
                : mlir::arith::CmpFPredicate::OLT,
          elem1, elem2);
    } else if (elementType.isa<mlir::IntegerType>()) {
      cmp = builder.create<mlir::arith::CmpIOp>(
          loc,
          isMax ? mlir::arith::CmpIPredicate::sgt
                : mlir::arith::CmpIPredicate::slt,
          elem1, elem2);
    } else {
      llvm_unreachable("unsupported type");
    }

    fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
                                               /*withElseRegion*/ true);

    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
    mlir::Type returnRefTy = builder.getRefType(resultElemTy);
    mlir::IndexType idxTy = builder.getIndexType();

    mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1);

    for (unsigned int i = 0; i < rank; ++i) {
      mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
      mlir::Value resultElemAddr =
          builder.create<fir::CoordinateOp>(loc, returnRefTy, resultArr, index);
      // Convert the zero-based loop index to the result kind and make it
      // 1-based, as Fortran MINLOC/MAXLOC report.
      mlir::Value convert =
          builder.create<fir::ConvertOp>(loc, resultElemTy, indices[i]);
      mlir::Value fortranIndex =
          builder.create<mlir::arith::AddIOp>(loc, convert, one);
      builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
    }
    builder.create<fir::ResultOp>(loc, elem1);
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    builder.create<fir::ResultOp>(loc, elem2);
    builder.setInsertionPointAfter(ifOp);
    return ifOp.getResult(0);
  };

  // if mask is a logical scalar, we can check its value before the main loop
  // and either ignore the fact it is there or exit early.
  if (maskRank == 0) {
    mlir::Type logical = builder.getI1Type();
    mlir::IndexType idxTy = builder.getIndexType();

    // View the scalar mask as a one-element i1 array and load element 0.
    // NOTE(review): the mask is read as i1 whatever its logical kind is -
    // confirm this is correct for non-default kinds.
    fir::SequenceType::Shape singleElement(1, 1);
    mlir::Type arrTy = fir::SequenceType::get(singleElement, logical);
    mlir::Type boxArrTy = fir::BoxType::get(arrTy);
    mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, mask);

    mlir::Value indx = builder.createIntegerConstant(loc, idxTy, 0);
    mlir::Type logicalRefTy = builder.getRefType(logical);
    mlir::Value condAddr =
        builder.create<fir::CoordinateOp>(loc, logicalRefTy, array, indx);
    mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr);

    fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cond,
                                               /*withElseRegion=*/true);

    // Mask is false: the else branch just yields a zero of elementType.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    mlir::Value basicValue;
    if (elementType.isa<mlir::IntegerType>()) {
      basicValue = builder.createIntegerConstant(loc, elementType, 0);
    } else {
      basicValue = builder.createRealConstant(loc, elementType, 0);
    }
    builder.create<fir::ResultOp>(loc, basicValue);

    // Mask is true: the main reduction loop is emitted in the then branch.
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  }

  // bit of a hack - maskRank is set to -1 for absent mask arg, so don't
  // generate high level mask or element by element mask.
  bool hasMask = maskRank > 0;
  genMinMaxlocReductionLoop(builder, funcOp, init, genBodyOp, rank, elementType,
                            loc, hasMask, maskElemType, resultArr);
}

/// Generate function type for the simplified version of RTNAME(DotProduct)
/// operating on the given \p elementType.
892 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder, 893 const mlir::Type &elementType) { 894 mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); 895 return mlir::FunctionType::get(builder.getContext(), {boxType, boxType}, 896 {elementType}); 897 } 898 899 /// Generate function body of the simplified version of RTNAME(DotProduct) 900 /// with signature provided by \p funcOp. The caller is responsible 901 /// for saving/restoring the original insertion point of \p builder. 902 /// \p funcOp is expected to be empty on entry to this function. 903 /// \p arg1ElementTy and \p arg2ElementTy specify elements types 904 /// of the underlying array objects - they are used to generate proper 905 /// element accesses. 906 static void genRuntimeDotBody(fir::FirOpBuilder &builder, 907 mlir::func::FuncOp &funcOp, 908 mlir::Type arg1ElementTy, 909 mlir::Type arg2ElementTy) { 910 // function RTNAME(DotProduct)<T>_simplified(arr1, arr2) 911 // T, dimension(:) :: arr1, arr2 912 // T product = 0 913 // integer iter 914 // do iter = 0, extent(arr1) 915 // product = product + arr1[iter] * arr2[iter] 916 // end do 917 // RTNAME(ADotProduct)<T>_simplified = product 918 // end function RTNAME(DotProduct)<T>_simplified 919 auto loc = mlir::UnknownLoc::get(builder.getContext()); 920 mlir::Type resultElementType = funcOp.getResultTypes()[0]; 921 builder.setInsertionPointToEnd(funcOp.addEntryBlock()); 922 923 mlir::IndexType idxTy = builder.getIndexType(); 924 925 mlir::Value zero = 926 resultElementType.isa<mlir::FloatType>() 927 ? 
builder.createRealConstant(loc, resultElementType, 0.0) 928 : builder.createIntegerConstant(loc, resultElementType, 0); 929 930 mlir::Block::BlockArgListType args = funcOp.front().getArguments(); 931 mlir::Value arg1 = args[0]; 932 mlir::Value arg2 = args[1]; 933 934 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); 935 936 fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()}; 937 mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy); 938 mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1); 939 mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1); 940 mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy); 941 mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2); 942 mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2); 943 // This version takes the loop trip count from the first argument. 944 // If the first argument's box has unknown (at compilation time) 945 // extent, then it may be better to take the extent from the second 946 // argument - so that after inlining the loop may be better optimized, e.g. 947 // fully unrolled. This requires generating two versions of the simplified 948 // function and some analysis at the call site to choose which version 949 // is more profitable to call. 950 // Note that we can assume that both arguments have the same extent. 
951 auto dims = 952 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx); 953 mlir::Value len = dims.getResult(1); 954 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 955 mlir::Value step = one; 956 957 // We use C indexing here, so len-1 as loopcount 958 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); 959 auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, 960 /*unordered=*/false, 961 /*finalCountValue=*/false, zero); 962 mlir::Value sumVal = loop.getRegionIterArgs()[0]; 963 964 // Begin loop code 965 mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint(); 966 builder.setInsertionPointToStart(loop.getBody()); 967 968 mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy); 969 mlir::Value index = loop.getInductionVar(); 970 mlir::Value addr1 = 971 builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index); 972 mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1); 973 // Convert to the result type. 974 elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1); 975 976 mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy); 977 mlir::Value addr2 = 978 builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index); 979 mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2); 980 // Convert to the result type. 981 elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2); 982 983 if (resultElementType.isa<mlir::FloatType>()) 984 sumVal = builder.create<mlir::arith::AddFOp>( 985 loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal); 986 else if (resultElementType.isa<mlir::IntegerType>()) 987 sumVal = builder.create<mlir::arith::AddIOp>( 988 loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal); 989 else 990 llvm_unreachable("unsupported type"); 991 992 builder.create<fir::ResultOp>(loc, sumVal); 993 // End of loop. 
994 builder.restoreInsertionPoint(loopEndPt); 995 996 mlir::Value resultVal = loop.getResult(0); 997 builder.create<mlir::func::ReturnOp>(loc, resultVal); 998 } 999 1000 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction( 1001 fir::FirOpBuilder &builder, const mlir::StringRef &baseName, 1002 FunctionTypeGeneratorTy typeGenerator, 1003 FunctionBodyGeneratorTy bodyGenerator) { 1004 // WARNING: if the function generated here changes its signature 1005 // or behavior (the body code), we should probably embed some 1006 // versioning information into its name, otherwise libraries 1007 // statically linked with older versions of Flang may stop 1008 // working with object files created with newer Flang. 1009 // We can also avoid this by using internal linkage, but 1010 // this may increase the size of final executable/shared library. 1011 std::string replacementName = mlir::Twine{baseName, "_simplified"}.str(); 1012 mlir::ModuleOp module = builder.getModule(); 1013 // If we already have a function, just return it. 1014 mlir::func::FuncOp newFunc = 1015 fir::FirOpBuilder::getNamedFunction(module, replacementName); 1016 mlir::FunctionType fType = typeGenerator(builder); 1017 if (newFunc) { 1018 assert(newFunc.getFunctionType() == fType && 1019 "type mismatch for simplified function"); 1020 return newFunc; 1021 } 1022 1023 // Need to build the function! 1024 auto loc = mlir::UnknownLoc::get(builder.getContext()); 1025 newFunc = 1026 fir::FirOpBuilder::createFunction(loc, module, replacementName, fType); 1027 auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR; 1028 auto linkage = 1029 mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); 1030 newFunc->setAttr("llvm.linkage", linkage); 1031 1032 // Save the position of the original call. 1033 mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint(); 1034 1035 bodyGenerator(builder, newFunc); 1036 1037 // Now back to where we were adding code earlier... 
1038 builder.restoreInsertionPoint(insertPt); 1039 1040 return newFunc; 1041 } 1042 1043 void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction( 1044 fir::CallOp call, const fir::KindMapping &kindMap, 1045 GenReductionBodyTy genBodyFunc) { 1046 // args[1] and args[2] are source filename and line number, ignored. 1047 mlir::Operation::operand_range args = call.getArgs(); 1048 1049 const mlir::Value &dim = args[3]; 1050 const mlir::Value &mask = args[4]; 1051 // dim is zero when it is absent, which is an implementation 1052 // detail in the runtime library. 1053 1054 bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask); 1055 unsigned rank = getDimCount(args[0]); 1056 1057 // Rank is set to 0 for assumed shape arrays, don't simplify 1058 // in these cases 1059 if (!(dimAndMaskAbsent && rank > 0)) 1060 return; 1061 1062 mlir::Type resultType = call.getResult(0).getType(); 1063 1064 if (!resultType.isa<mlir::FloatType>() && 1065 !resultType.isa<mlir::IntegerType>()) 1066 return; 1067 1068 auto argType = getArgElementType(args[0]); 1069 if (!argType) 1070 return; 1071 assert(*argType == resultType && 1072 "Argument/result types mismatch in reduction"); 1073 1074 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1075 1076 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1077 std::string fmfString{builder.getFastMathFlagsString()}; 1078 std::string funcName = 1079 (mlir::Twine{callee.getLeafReference().getValue(), "x"} + 1080 mlir::Twine{rank} + 1081 // We must mangle the generated function name with FastMathFlags 1082 // value. 1083 (fmfString.empty() ? 
mlir::Twine{} : mlir::Twine{"_", fmfString})) 1084 .str(); 1085 1086 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 1087 resultType); 1088 } 1089 1090 void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction( 1091 fir::CallOp call, const fir::KindMapping &kindMap, 1092 GenReductionBodyTy genBodyFunc) { 1093 1094 mlir::Operation::operand_range args = call.getArgs(); 1095 const mlir::Value &dim = args[3]; 1096 unsigned rank = getDimCount(args[0]); 1097 1098 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in 1099 // these cases. 1100 if (!(isZero(dim) && rank > 0)) 1101 return; 1102 1103 mlir::Value inputBox = findBoxDef(args[0]); 1104 1105 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType()); 1106 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1107 1108 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1109 1110 // Treating logicals as integers makes things a lot easier 1111 fir::LogicalType logicalType = {elementType.dyn_cast<fir::LogicalType>()}; 1112 fir::KindTy kind = logicalType.getFKind(); 1113 mlir::Type intElementType = builder.getIntegerType(kind * 8); 1114 1115 // Mangle kind into function name as it is not done by default 1116 std::string funcName = 1117 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} + 1118 mlir::Twine{kind} + "x" + mlir::Twine{rank}) 1119 .str(); 1120 1121 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 1122 intElementType); 1123 } 1124 1125 void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction( 1126 fir::CallOp call, const fir::KindMapping &kindMap, 1127 GenReductionBodyTy genBodyFunc) { 1128 1129 mlir::Operation::operand_range args = call.getArgs(); 1130 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1131 mlir::StringRef funcNameBase = callee.getLeafReference().getValue(); 1132 unsigned rank = getDimCount(args[0]); 1133 1134 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify 
in 1135 // these cases. We check for Dim at the end as some logical functions (Any, 1136 // All) set dim to 1 instead of 0 when the argument is not present. 1137 if (funcNameBase.ends_with("Dim") || !(rank > 0)) 1138 return; 1139 1140 mlir::Value inputBox = findBoxDef(args[0]); 1141 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType()); 1142 1143 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1144 1145 // Treating logicals as integers makes things a lot easier 1146 fir::LogicalType logicalType = {elementType.dyn_cast<fir::LogicalType>()}; 1147 fir::KindTy kind = logicalType.getFKind(); 1148 mlir::Type intElementType = builder.getIntegerType(kind * 8); 1149 1150 // Mangle kind into function name as it is not done by default 1151 std::string funcName = 1152 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} + 1153 mlir::Twine{kind} + "x" + mlir::Twine{rank}) 1154 .str(); 1155 1156 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName, 1157 intElementType); 1158 } 1159 1160 void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction( 1161 fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) { 1162 1163 mlir::Operation::operand_range args = call.getArgs(); 1164 1165 mlir::Value back = args[6]; 1166 if (isTrueOrNotConstant(back)) 1167 return; 1168 1169 mlir::Value mask = args[5]; 1170 mlir::Value maskDef = findMaskDef(mask); 1171 1172 // maskDef is set to NULL when the defining op is not one we accept. 1173 // This tends to be because it is a selectOp, in which case let the 1174 // runtime deal with it. 
1175 if (maskDef == NULL) 1176 return; 1177 1178 mlir::SymbolRefAttr callee = call.getCalleeAttr(); 1179 mlir::StringRef funcNameBase = callee.getLeafReference().getValue(); 1180 unsigned rank = getDimCount(args[1]); 1181 if (funcNameBase.ends_with("Dim") || !(rank > 0)) 1182 return; 1183 1184 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; 1185 mlir::Location loc = call.getLoc(); 1186 auto inputBox = findBoxDef(args[1]); 1187 mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType()); 1188 1189 if (inputType.isa<fir::CharacterType>()) 1190 return; 1191 1192 int maskRank; 1193 fir::KindTy kind = 0; 1194 mlir::Type logicalElemType = builder.getI1Type(); 1195 if (isOperandAbsent(mask)) { 1196 maskRank = -1; 1197 } else { 1198 maskRank = getDimCount(mask); 1199 mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType()); 1200 fir::LogicalType logicalFirType = {maskElemTy.dyn_cast<fir::LogicalType>()}; 1201 kind = logicalFirType.getFKind(); 1202 // Convert fir::LogicalType to mlir::Type 1203 logicalElemType = logicalFirType; 1204 } 1205 1206 mlir::Operation *outputDef = args[0].getDefiningOp(); 1207 mlir::Value outputAlloc = outputDef->getOperand(0); 1208 mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType()); 1209 1210 std::string fmfString{builder.getFastMathFlagsString()}; 1211 std::string funcName = 1212 (mlir::Twine{callee.getLeafReference().getValue(), "x"} + 1213 mlir::Twine{rank} + 1214 (maskRank >= 0 1215 ? 
"_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank} 1216 : "") + 1217 "_") 1218 .str(); 1219 1220 llvm::raw_string_ostream nameOS(funcName); 1221 outType.print(nameOS); 1222 nameOS << '_' << fmfString; 1223 1224 auto typeGenerator = [rank](fir::FirOpBuilder &builder) { 1225 return genRuntimeMinlocType(builder, rank); 1226 }; 1227 auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType, 1228 isMax](fir::FirOpBuilder &builder, 1229 mlir::func::FuncOp &funcOp) { 1230 genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType, 1231 logicalElemType, outType); 1232 }; 1233 1234 mlir::func::FuncOp newFunc = 1235 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); 1236 builder.create<fir::CallOp>(loc, newFunc, 1237 mlir::ValueRange{args[0], args[1], args[5]}); 1238 call->dropAllReferences(); 1239 call->erase(); 1240 } 1241 1242 void SimplifyIntrinsicsPass::simplifyReductionBody( 1243 fir::CallOp call, const fir::KindMapping &kindMap, 1244 GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder, 1245 const mlir::StringRef &funcName, mlir::Type elementType) { 1246 1247 mlir::Operation::operand_range args = call.getArgs(); 1248 1249 mlir::Type resultType = call.getResult(0).getType(); 1250 unsigned rank = getDimCount(args[0]); 1251 1252 mlir::Location loc = call.getLoc(); 1253 1254 auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) { 1255 return genNoneBoxType(builder, resultType); 1256 }; 1257 auto bodyGenerator = [&rank, &genBodyFunc, 1258 &elementType](fir::FirOpBuilder &builder, 1259 mlir::func::FuncOp &funcOp) { 1260 genBodyFunc(builder, funcOp, rank, elementType); 1261 }; 1262 // Mangle the function name with the rank value as "x<rank>". 
1263 mlir::func::FuncOp newFunc = 1264 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); 1265 auto newCall = 1266 builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]}); 1267 call->replaceAllUsesWith(newCall.getResults()); 1268 call->dropAllReferences(); 1269 call->erase(); 1270 } 1271 1272 void SimplifyIntrinsicsPass::runOnOperation() { 1273 LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n"); 1274 mlir::ModuleOp module = getOperation(); 1275 fir::KindMapping kindMap = fir::getKindMapping(module); 1276 module.walk([&](mlir::Operation *op) { 1277 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) { 1278 if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) { 1279 mlir::StringRef funcName = callee.getLeafReference().getValue(); 1280 // Replace call to runtime function for SUM when it has single 1281 // argument (no dim or mask argument) for 1D arrays with either 1282 // Integer4 or Real8 types. Other forms are ignored. 1283 // The new function is added to the module. 1284 // 1285 // Prototype for runtime call (from sum.cpp): 1286 // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line, 1287 // int dim, const Descriptor *mask) 1288 // 1289 if (funcName.starts_with(RTNAME_STRING(Sum))) { 1290 simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody); 1291 return; 1292 } 1293 if (funcName.starts_with(RTNAME_STRING(DotProduct))) { 1294 LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n"); 1295 LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump(); 1296 llvm::dbgs() << "\n"); 1297 mlir::Operation::operand_range args = call.getArgs(); 1298 const mlir::Value &v1 = args[0]; 1299 const mlir::Value &v2 = args[1]; 1300 mlir::Location loc = call.getLoc(); 1301 fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)}; 1302 // Stringize the builder's FastMathFlags flags for mangling 1303 // the generated function name. 
1304 std::string fmfString{builder.getFastMathFlagsString()}; 1305 1306 mlir::Type type = call.getResult(0).getType(); 1307 if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>()) 1308 return; 1309 1310 // Try to find the element types of the boxed arguments. 1311 auto arg1Type = getArgElementType(v1); 1312 auto arg2Type = getArgElementType(v2); 1313 1314 if (!arg1Type || !arg2Type) 1315 return; 1316 1317 // Support only floating point and integer arguments 1318 // now (e.g. logical is skipped here). 1319 if (!arg1Type->isa<mlir::FloatType>() && 1320 !arg1Type->isa<mlir::IntegerType>()) 1321 return; 1322 if (!arg2Type->isa<mlir::FloatType>() && 1323 !arg2Type->isa<mlir::IntegerType>()) 1324 return; 1325 1326 auto typeGenerator = [&type](fir::FirOpBuilder &builder) { 1327 return genRuntimeDotType(builder, type); 1328 }; 1329 auto bodyGenerator = [&arg1Type, 1330 &arg2Type](fir::FirOpBuilder &builder, 1331 mlir::func::FuncOp &funcOp) { 1332 genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type); 1333 }; 1334 1335 // Suffix the function name with the element types 1336 // of the arguments. 1337 std::string typedFuncName(funcName); 1338 llvm::raw_string_ostream nameOS(typedFuncName); 1339 // We must mangle the generated function name with FastMathFlags 1340 // value. 
1341 if (!fmfString.empty()) 1342 nameOS << '_' << fmfString; 1343 nameOS << '_'; 1344 arg1Type->print(nameOS); 1345 nameOS << '_'; 1346 arg2Type->print(nameOS); 1347 1348 mlir::func::FuncOp newFunc = getOrCreateFunction( 1349 builder, typedFuncName, typeGenerator, bodyGenerator); 1350 auto newCall = builder.create<fir::CallOp>(loc, newFunc, 1351 mlir::ValueRange{v1, v2}); 1352 call->replaceAllUsesWith(newCall.getResults()); 1353 call->dropAllReferences(); 1354 call->erase(); 1355 1356 LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump(); 1357 llvm::dbgs() << "\n"); 1358 return; 1359 } 1360 if (funcName.starts_with(RTNAME_STRING(Maxval))) { 1361 simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody); 1362 return; 1363 } 1364 if (funcName.starts_with(RTNAME_STRING(Count))) { 1365 simplifyLogicalDim0Reduction(call, kindMap, genRuntimeCountBody); 1366 return; 1367 } 1368 if (funcName.starts_with(RTNAME_STRING(Any))) { 1369 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAnyBody); 1370 return; 1371 } 1372 if (funcName.ends_with(RTNAME_STRING(All))) { 1373 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAllBody); 1374 return; 1375 } 1376 if (funcName.starts_with(RTNAME_STRING(Minloc))) { 1377 simplifyMinMaxlocReduction(call, kindMap, false); 1378 return; 1379 } 1380 if (funcName.starts_with(RTNAME_STRING(Maxloc))) { 1381 simplifyMinMaxlocReduction(call, kindMap, true); 1382 return; 1383 } 1384 } 1385 } 1386 }); 1387 LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n"); 1388 } 1389 1390 void SimplifyIntrinsicsPass::getDependentDialects( 1391 mlir::DialectRegistry ®istry) const { 1392 // LLVM::LinkageAttr creation requires that LLVM dialect is loaded. 1393 registry.insert<mlir::LLVM::LLVMDialect>(); 1394 } 1395 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() { 1396 return std::make_unique<SimplifyIntrinsicsPass>(); 1397 } 1398