//===- SimplifyHLFIRIntrinsics.cpp - Simplify HLFIR Intrinsics ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Normally transformational intrinsics are lowered to calls to runtime
// functions. However, some cases of the intrinsics are faster when inlined
// into the calling function.
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Location.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace hlfir {
#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir

#define DEBUG_TYPE "simplify-hlfir-intrinsics"

static llvm::cl::opt<bool> forceMatmulAsElemental(
    "flang-inline-matmul-as-elemental",
    llvm::cl::desc("Expand hlfir.matmul as elemental operation"),
    llvm::cl::init(false));

namespace {

// Helper class to generate operations related to computing
// the product of values.
class ProductFactory {
public:
  ProductFactory(mlir::Location loc, fir::FirOpBuilder &builder)
      : loc(loc), builder(builder) {}

  // Generate an update of the inner product value:
  //   acc += v1 * v2, OR
  //   acc += CONJ(v1) * v2, OR
  //   acc ||= v1 && v2
  //
  // The CONJ template parameter specifies whether the first complex
  // product argument needs to be conjugated.
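  //
  // For example, DOT_PRODUCT of complex vectors is accumulated with
  // CONJ=true, because Fortran defines it as SUM(CONJG(LHS)*RHS).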
  template <bool CONJ = false>
  mlir::Value genAccumulateProduct(mlir::Value acc, mlir::Value v1,
                                   mlir::Value v2) {
    mlir::Type resultType = acc.getType();
    acc = castToProductType(acc, resultType);
    v1 = castToProductType(v1, resultType);
    v2 = castToProductType(v2, resultType);
    mlir::Value result;
    if (mlir::isa<mlir::FloatType>(resultType)) {
      result = builder.create<mlir::arith::AddFOp>(
          loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
    } else if (mlir::isa<mlir::ComplexType>(resultType)) {
      if constexpr (CONJ)
        result = fir::IntrinsicLibrary{builder, loc}.genConjg(resultType, v1);
      else
        result = v1;

      result = builder.create<fir::AddcOp>(
          loc, acc, builder.create<fir::MulcOp>(loc, result, v2));
    } else if (mlir::isa<mlir::IntegerType>(resultType)) {
      result = builder.create<mlir::arith::AddIOp>(
          loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
    } else if (mlir::isa<fir::LogicalType>(resultType)) {
      result = builder.create<mlir::arith::OrIOp>(
          loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
    } else {
      llvm_unreachable("unsupported type");
    }

    return builder.createConvert(loc, resultType, result);
  }

private:
  mlir::Location loc;
  fir::FirOpBuilder &builder;

  mlir::Value castToProductType(mlir::Value value, mlir::Type type) {
    if (mlir::isa<fir::LogicalType>(type))
      return builder.createConvert(loc, builder.getIntegerType(1), value);

    // TODO: the multiplications/additions by/of zero resulting from
    // complex * real are optimized by LLVM under -fno-signed-zeros
    // -fno-honor-nans.
    // We can make them disappear by default if we:
    //   * either expand the complex multiplication into real
    //     operations, OR
    //   * set nnan nsz fast-math flags to the complex operations.
    if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
      mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
      fir::factory::Complex helper(builder, loc);
      mlir::Type partType = helper.getComplexPartType(type);
      return helper.insertComplexPart(zeroCmplx,
                                      castToProductType(value, partType),
                                      /*isImagPart=*/false);
    }
    return builder.createConvert(loc, type, value);
  }
};
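
// Rewrite hlfir.transpose as an hlfir.elemental that addresses
// the argument with swapped indices, e.g. (schematically, with
// the details omitted):
//   %result = hlfir.elemental %transposedShape {
//   ^bb0(%i: index, %j: index):
//     %element = hlfir.designate %array (%j, %i)
//     ...
//   }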
class TransposeAsElementalConversion
    : public mlir::OpRewritePattern<hlfir::TransposeOp> {
public:
  using mlir::OpRewritePattern<hlfir::TransposeOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::TransposeOp transpose,
                  mlir::PatternRewriter &rewriter) const override {
    hlfir::ExprType expr = transpose.getType();
    // TODO: hlfir.elemental supports polymorphic data types now,
    // so this can be supported.
    if (expr.isPolymorphic())
      return rewriter.notifyMatchFailure(transpose,
                                         "TRANSPOSE of polymorphic type");

    mlir::Location loc = transpose.getLoc();
    fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
    mlir::Type elementType = expr.getElementType();
    hlfir::Entity array = hlfir::Entity{transpose.getArray()};
    mlir::Value resultShape = genResultShape(loc, builder, array);
    llvm::SmallVector<mlir::Value, 1> typeParams;
    hlfir::genLengthParameters(loc, builder, array, typeParams);

    auto genKernel = [&array](mlir::Location loc, fir::FirOpBuilder &builder,
                              mlir::ValueRange inputIndices) -> hlfir::Entity {
      assert(inputIndices.size() == 2 && "checked in TransposeOp::validate");
      const std::initializer_list<mlir::Value> initList = {inputIndices[1],
                                                           inputIndices[0]};
      mlir::ValueRange transposedIndices(initList);
      hlfir::Entity element =
          hlfir::getElementAt(loc, builder, array, transposedIndices);
      hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, element);
      return val;
    };
    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
        loc, builder, elementType, resultShape, typeParams, genKernel,
        /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
        transpose.getResult().getType());

    // It would not be safe to replace block arguments with a different
    // hlfir.expr type: the types can differ due to differing amounts
    // of shape information.
    assert(elementalOp.getResult().getType() ==
           transpose.getResult().getType());

    rewriter.replaceOp(transpose, elementalOp);
    return mlir::success();
  }

private:
  static mlir::Value genResultShape(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
                                    hlfir::Entity array) {
    llvm::SmallVector<mlir::Value, 2> inExtents =
        hlfir::genExtentsVector(loc, builder, array);

    // Swap the extents to get the transposed shape.
    assert(inExtents.size() == 2 && "checked in TransposeOp::validate");
    return builder.create<fir::ShapeOp>(
        loc, mlir::ValueRange{inExtents[1], inExtents[0]});
  }
};

// Expand the SUM(DIM=CONSTANT) operation into an hlfir.elemental whose
// kernel reduces the DIM dimension, or into a reduction loop nest when
// SUM computes a total reduction.
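// For example, for A(N, M) the partial reduction SUM(A, DIM=1) is
// expanded, schematically, into an hlfir.elemental over shape (M)
// whose kernel loops over i = 1..N accumulating A(i, j) into
// a zero-initialized scalar.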
class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
public:
  using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::SumOp sum,
                  mlir::PatternRewriter &rewriter) const override {
    hlfir::Entity array = hlfir::Entity{sum.getArray()};
    bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
    mlir::Value dim = sum.getDim();
    int64_t dimVal = 0;
    if (!isTotalReduction) {
      // In case of partial reduction we should ignore the operations
      // with invalid DIM values. They may appear in dead code
      // after constant propagation.
      auto constDim = fir::getIntIfConstant(dim);
      if (!constDim)
        return rewriter.notifyMatchFailure(sum, "Nonconstant DIM for SUM");
      dimVal = *constDim;

      if (dimVal <= 0 || dimVal > array.getRank())
        return rewriter.notifyMatchFailure(
            sum, "Invalid DIM for partial SUM reduction");
    }

    mlir::Location loc = sum.getLoc();
    fir::FirOpBuilder builder{rewriter, sum.getOperation()};
    mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
    mlir::Value mask = sum.getMask();

    mlir::Value resultShape, dimExtent;
    llvm::SmallVector<mlir::Value> arrayExtents;
    if (isTotalReduction)
      arrayExtents = hlfir::genExtentsVector(loc, builder, array);
    else
      std::tie(resultShape, dimExtent) =
          genResultShapeForPartialReduction(loc, builder, array, dimVal);

    // If the mask is present and is a scalar, load its value outside
    // of the reduction loop to make loop unswitching easier.
    mlir::Value isPresentPred, maskValue;
    if (mask) {
      if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
        // MASK represented by a box might be dynamically optional,
        // so we have to check for its presence before accessing it.
        isPresentPred =
            builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
      }

      if (hlfir::Entity{mask}.isScalar())
        maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
    }

    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                         mlir::ValueRange inputIndices) -> hlfir::Entity {
      // Loop over all indices in the DIM dimension, and reduce all values.
      // If DIM is not present, do a total reduction.

      // Initial value for the reduction.
      mlir::Value reductionInitValue =
          fir::factory::createZeroValue(builder, loc, elementType);

      // The reduction loop may be unordered if FastMathFlags::reassoc
      // transformations are allowed. The integer reduction is always
      // unordered.
      bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
                         static_cast<bool>(sum.getFastmath() &
                                           mlir::arith::FastMathFlags::reassoc);

      llvm::SmallVector<mlir::Value> extents;
      if (isTotalReduction)
        extents = arrayExtents;
      else
        extents.push_back(
            builder.createConvert(loc, builder.getIndexType(), dimExtent));

      auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                         mlir::ValueRange oneBasedIndices,
                         mlir::ValueRange reductionArgs)
          -> llvm::SmallVector<mlir::Value, 1> {
        // Generate the reduction loop-nest body.
        // The initial reduction value in the innermost loop
        // is passed via reductionArgs[0].
        llvm::SmallVector<mlir::Value> indices;
        if (isTotalReduction) {
          indices = oneBasedIndices;
        } else {
          indices = inputIndices;
          indices.insert(indices.begin() + dimVal - 1, oneBasedIndices[0]);
        }

        mlir::Value reductionValue = reductionArgs[0];
        fir::IfOp ifOp;
        if (mask) {
          // Make the reduction value update conditional on the value
          // of the mask.
          if (!maskValue) {
            // If the mask is an array, use the elemental and the loop
            // indices to address the proper mask element.
            maskValue =
                genMaskValue(loc, builder, mask, isPresentPred, indices);
          }
          mlir::Value isUnmasked = builder.create<fir::ConvertOp>(
              loc, builder.getI1Type(), maskValue);
          ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
                                           /*withElseRegion=*/true);
          // In the 'else' block return the current reduction value.
          builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
          builder.create<fir::ResultOp>(loc, reductionValue);

          // In the 'then' block do the actual addition.
          builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
        }

        hlfir::Entity element =
            hlfir::getElementAt(loc, builder, array, indices);
        hlfir::Entity elementValue =
            hlfir::loadTrivialScalar(loc, builder, element);
        // NOTE: we could use Kahan summation the same way the runtime
        // does (e.g. when fast-math is not allowed), but let's start
        // with the simple version.
        reductionValue =
            genScalarAdd(loc, builder, reductionValue, elementValue);

        if (ifOp) {
          builder.create<fir::ResultOp>(loc, reductionValue);
          builder.setInsertionPointAfter(ifOp);
          reductionValue = ifOp.getResult(0);
        }

        return {reductionValue};
      };

      llvm::SmallVector<mlir::Value, 1> reductionFinalValues =
          hlfir::genLoopNestWithReductions(loc, builder, extents,
                                           {reductionInitValue}, genBody,
                                           isUnordered);
      return hlfir::Entity{reductionFinalValues[0]};
    };

    if (isTotalReduction) {
      hlfir::Entity result = genKernel(loc, builder, mlir::ValueRange{});
      rewriter.replaceOp(sum, result);
      return mlir::success();
    }

    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
        loc, builder, elementType, resultShape, {}, genKernel,
        /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
        sum.getResult().getType());

    // It would not be safe to replace block arguments with a different
    // hlfir.expr type: the types can differ due to differing amounts
    // of shape information.
    assert(elementalOp.getResult().getType() == sum.getResult().getType());

    rewriter.replaceOp(sum, elementalOp);
    return mlir::success();
  }

private:
  // Return fir.shape specifying the shape of the result
  // of a SUM reduction with DIM=dimVal. The second return value
  // is the extent of the DIM dimension.
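  // For example, for A(2, 3, 4) and DIM=2 the resulting shape is (2, 4)
  // and the returned DIM extent is 3.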
  static std::tuple<mlir::Value, mlir::Value>
  genResultShapeForPartialReduction(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
                                    hlfir::Entity array, int64_t dimVal) {
    llvm::SmallVector<mlir::Value> inExtents =
        hlfir::genExtentsVector(loc, builder, array);
    assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
           "DIM must be present and a positive constant not exceeding "
           "the array's rank");

    mlir::Value dimExtent = inExtents[dimVal - 1];
    inExtents.erase(inExtents.begin() + dimVal - 1);
    return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
  }

  // Generate scalar addition of two values of the same data type.
  static mlir::Value genScalarAdd(mlir::Location loc,
                                  fir::FirOpBuilder &builder,
                                  mlir::Value value1, mlir::Value value2) {
    mlir::Type ty = value1.getType();
    assert(ty == value2.getType() && "reduction values' types do not match");
    if (mlir::isa<mlir::FloatType>(ty))
      return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
    else if (mlir::isa<mlir::ComplexType>(ty))
      return builder.create<fir::AddcOp>(loc, value1, value2);
    else if (mlir::isa<mlir::IntegerType>(ty))
      return builder.create<mlir::arith::AddIOp>(loc, value1, value2);

    llvm_unreachable("unsupported SUM reduction type");
  }

  static mlir::Value genMaskValue(mlir::Location loc,
                                  fir::FirOpBuilder &builder, mlir::Value mask,
                                  mlir::Value isPresentPred,
                                  mlir::ValueRange indices) {
    mlir::OpBuilder::InsertionGuard guard(builder);
    fir::IfOp ifOp;
    mlir::Type maskType =
        hlfir::getFortranElementType(fir::unwrapPassByRefType(mask.getType()));
    if (isPresentPred) {
      ifOp = builder.create<fir::IfOp>(loc, maskType, isPresentPred,
                                       /*withElseRegion=*/true);

      // Use 'true' if the mask is not present.
      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
      mlir::Value trueValue = builder.createBool(loc, true);
      trueValue = builder.createConvert(loc, maskType, trueValue);
      builder.create<fir::ResultOp>(loc, trueValue);

      // Load the mask value, if the mask is present.
      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    }

    hlfir::Entity maskVar{mask};
    if (maskVar.isScalar()) {
      if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
        // MASK may be a boxed scalar.
        mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, maskVar);
        mask = builder.create<fir::LoadOp>(loc, hlfir::Entity{addr});
      } else {
        mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
      }
    } else {
      // Load from the mask array.
      assert(!indices.empty() && "no indices for addressing the mask array");
      maskVar = hlfir::getElementAt(loc, builder, maskVar, indices);
      mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
    }

    if (!isPresentPred)
      return mask;

    builder.create<fir::ResultOp>(loc, mask);
    return ifOp.getResult(0);
  }
};
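
// Rewrite hlfir.cshift as an hlfir.elemental that computes the
// circularly shifted index. For example, CSHIFT([1, 2, 3], SHIFT=1)
// is [2, 3, 1]: element i of the result is element
// MODULO(i + SHIFT - 1, 3) + 1 of the argument.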
class CShiftAsElementalConversion
    : public mlir::OpRewritePattern<hlfir::CShiftOp> {
public:
  using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::CShiftOp cshift,
                  mlir::PatternRewriter &rewriter) const override {
    using Fortran::common::maxRank;

    hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType());
    assert(expr &&
           "expected an expression type for the result of hlfir.cshift");
    unsigned arrayRank = expr.getRank();
    // For a 1D CSHIFT, we may assume that the DIM argument
    // (whether it is present or absent) is equal to 1; otherwise,
    // the program is illegal.
    int64_t dimVal = 1;
    if (arrayRank != 1)
      if (mlir::Value dim = cshift.getDim()) {
        auto constDim = fir::getIntIfConstant(dim);
        if (!constDim)
          return rewriter.notifyMatchFailure(cshift,
                                             "Nonconstant DIM for CSHIFT");
        dimVal = *constDim;
      }

    if (dimVal <= 0 || dimVal > arrayRank)
      return rewriter.notifyMatchFailure(cshift, "Invalid DIM for CSHIFT");

    mlir::Location loc = cshift.getLoc();
    fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
    mlir::Type elementType = expr.getElementType();
    hlfir::Entity array = hlfir::Entity{cshift.getArray()};
    mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
    llvm::SmallVector<mlir::Value> arrayExtents =
        hlfir::getExplicitExtentsFromShape(arrayShape, builder);
    llvm::SmallVector<mlir::Value, 1> typeParams;
    hlfir::genLengthParameters(loc, builder, array, typeParams);
    hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
    // The new index computation involves MODULO, which is not implemented
    // for IndexType, so use I64 instead.
    mlir::Type calcType = builder.getI64Type();

    mlir::Value one = builder.createIntegerConstant(loc, calcType, 1);
    mlir::Value shiftVal;
    if (shift.isScalar()) {
      shiftVal = hlfir::loadTrivialScalar(loc, builder, shift);
      shiftVal = builder.createConvert(loc, calcType, shiftVal);
    }

    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                         mlir::ValueRange inputIndices) -> hlfir::Entity {
      llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
      if (!shift.isScalar()) {
        // When the array is not a vector, the section
        //   (s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n))
        // of the result has a value equal to:
        //   CSHIFT(ARRAY(s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)),
        //          SH, 1),
        // where SH is either SHIFT (if scalar) or
        // SHIFT(s(1), s(2), ..., s(dim-1), s(dim+1), ..., s(n)).
        llvm::SmallVector<mlir::Value, maxRank> shiftIndices{indices};
        shiftIndices.erase(shiftIndices.begin() + dimVal - 1);
        hlfir::Entity shiftElement =
            hlfir::getElementAt(loc, builder, shift, shiftIndices);
        shiftVal = hlfir::loadTrivialScalar(loc, builder, shiftElement);
        shiftVal = builder.createConvert(loc, calcType, shiftVal);
      }

      // Element i of the result (1-based) is element
      // 'MODULO(i + SH - 1, SIZE(ARRAY)) + 1' (1-based) of the original
      // ARRAY (or its section, when ARRAY is not a vector).
      mlir::Value index =
          builder.createConvert(loc, calcType, inputIndices[dimVal - 1]);
      mlir::Value extent = arrayExtents[dimVal - 1];
      mlir::Value newIndex =
          builder.create<mlir::arith::AddIOp>(loc, index, shiftVal);
      newIndex = builder.create<mlir::arith::SubIOp>(loc, newIndex, one);
      newIndex = fir::IntrinsicLibrary{builder, loc}.genModulo(
          calcType, {newIndex, builder.createConvert(loc, calcType, extent)});
      newIndex = builder.create<mlir::arith::AddIOp>(loc, newIndex, one);
      newIndex = builder.createConvert(loc, builder.getIndexType(), newIndex);

      indices[dimVal - 1] = newIndex;
      hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
      return hlfir::loadTrivialScalar(loc, builder, element);
    };

    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
        loc, builder, elementType, arrayShape, typeParams, genKernel,
        /*isUnordered=*/true,
        array.isPolymorphic() ? static_cast<mlir::Value>(array) : nullptr,
        cshift.getResult().getType());
    rewriter.replaceOp(cshift, elementalOp);
    return mlir::success();
  }
};
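
// Inline hlfir.matmul and hlfir.matmul_transpose. Unless the elemental
// form is used, hlfir.matmul is rewritten into an hlfir.eval_in_mem
// that initializes the result buffer to zero and accumulates the
// products into it (schematically, with the details omitted):
//   %result = hlfir.eval_in_mem shape %shape {
//   ^bb0(%tmp: !fir.ref<!fir.array<...>>):
//     <loop nest assigning zero to each element of %tmp>
//     <loop nest updating %tmp(i, j) += lhs(i, k) * rhs(k, j)>
//   }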
template <typename Op>
class MatmulConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(Op matmul, mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = matmul.getLoc();
    fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
    hlfir::Entity lhs = hlfir::Entity{matmul.getLhs()};
    hlfir::Entity rhs = hlfir::Entity{matmul.getRhs()};
    mlir::Value resultShape, innerProductExtent;
    std::tie(resultShape, innerProductExtent) =
        genResultShape(loc, builder, lhs, rhs);

    if (forceMatmulAsElemental || isMatmulTranspose) {
      // Generate hlfir.elemental that produces the result of
      // MATMUL/MATMUL(TRANSPOSE).
      // Note that this implementation is very suboptimal for MATMUL,
      // but is quite good for MATMUL(TRANSPOSE), e.g.:
      //   R(1:N) = R(1:N) + MATMUL(TRANSPOSE(X(1:N,1:N)), Y(1:N))
      // Inlining MATMUL(TRANSPOSE) as hlfir.elemental may result
      // in merging the inner product computation with the elemental
      // addition. Note that the inner product computation will
      // benefit from processing the lowermost dimensions of X and Y,
      // which may be the best when they are contiguous.
      //
      // This is why we always inline MATMUL(TRANSPOSE) as an elemental.
      // MATMUL is inlined below by default unless forceMatmulAsElemental.
      hlfir::ExprType resultType =
          mlir::cast<hlfir::ExprType>(matmul.getType());
      hlfir::ElementalOp newOp = genElementalMatmul(
          loc, builder, resultType, resultShape, lhs, rhs, innerProductExtent);
      rewriter.replaceOp(matmul, newOp);
      return mlir::success();
    }

    // Generate hlfir.eval_in_mem to mimic the MATMUL implementation
    // from the Fortran runtime. The implementation needs to operate
    // with the result array as an in-memory object.
    hlfir::EvaluateInMemoryOp evalOp =
        builder.create<hlfir::EvaluateInMemoryOp>(
            loc, mlir::cast<hlfir::ExprType>(matmul.getType()), resultShape);
    builder.setInsertionPointToStart(&evalOp.getBody().front());

    // Embox the raw array pointer to simplify designating it.
    // TODO: this currently results in redundant lower bounds
    // addition for the designator, but this should be fixed in
    // hlfir::Entity::mayHaveNonDefaultLowerBounds().
    mlir::Value resultArray = evalOp.getMemory();
    mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType());
    resultArray = builder.createBox(loc, fir::BoxType::get(arrayType),
                                    resultArray, resultShape, /*slice=*/nullptr,
                                    /*lengths=*/{}, /*tdesc=*/nullptr);

    // The contiguous MATMUL version is best for the cases
    // where the input arrays and (maybe) the result are contiguous
    // in their lowermost dimensions, especially when LLVM can
    // recognize the contiguity and vectorize the loops properly.
    // Note that the contiguous MATMUL inlining is correct
    // even when the input arrays are not contiguous.
    // TODO: we can try to recognize the cases when the contiguity
    // is not statically obvious and try to generate an explicitly
    // contiguous version under a dynamic check. This should allow
    // LLVM to vectorize the loops better. Note that this can
    // also be postponed up to the LoopVersioning pass.
    // The fallback implementation may use genElementalMatmul() with
    // an hlfir.assign into the result of eval_in_mem.
    mlir::LogicalResult rewriteResult =
        genContiguousMatmul(loc, builder, hlfir::Entity{resultArray},
                            resultShape, lhs, rhs, innerProductExtent);

    if (mlir::failed(rewriteResult)) {
      // Erase the unclaimed eval_in_mem op.
      rewriter.eraseOp(evalOp);
      return rewriter.notifyMatchFailure(matmul,
                                         "genContiguousMatmul() failed");
    }

    rewriter.replaceOp(matmul, evalOp);
    return mlir::success();
  }

private:
  static constexpr bool isMatmulTranspose =
      std::is_same_v<Op, hlfir::MatmulTransposeOp>;

  // Return a tuple of:
  //   * A fir.shape operation representing the shape of the result
  //     of MATMUL/MATMUL(TRANSPOSE).
  //   * The extent of the input dimensions that are processed
  //     during the inner product computation.
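  // For example, for LHS(NROWS, N) and RHS(N, NCOLS) the result shape
  // is (NROWS, NCOLS), and the inner product extent is N.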
  static std::tuple<mlir::Value, mlir::Value>
  genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
                 hlfir::Entity input1, hlfir::Entity input2) {
    llvm::SmallVector<mlir::Value, 2> input1Extents =
        hlfir::genExtentsVector(loc, builder, input1);
    llvm::SmallVector<mlir::Value, 2> input2Extents =
        hlfir::genExtentsVector(loc, builder, input2);

    llvm::SmallVector<mlir::Value, 2> newExtents;
    mlir::Value innerProduct1Extent, innerProduct2Extent;
    if (input1Extents.size() == 1) {
      assert(!isMatmulTranspose &&
             "hlfir.matmul_transpose's first operand must be rank-2 array");
      assert(input2Extents.size() == 2 &&
             "hlfir.matmul second argument must be rank-2 array");
      newExtents.push_back(input2Extents[1]);
      innerProduct1Extent = input1Extents[0];
      innerProduct2Extent = input2Extents[0];
    } else {
      if (input2Extents.size() == 1) {
        assert(input1Extents.size() == 2 &&
               "hlfir.matmul first argument must be rank-2 array");
        if constexpr (isMatmulTranspose)
          newExtents.push_back(input1Extents[1]);
        else
          newExtents.push_back(input1Extents[0]);
      } else {
        assert(input1Extents.size() == 2 && input2Extents.size() == 2 &&
               "hlfir.matmul arguments must be rank-2 arrays");
        if constexpr (isMatmulTranspose)
          newExtents.push_back(input1Extents[1]);
        else
          newExtents.push_back(input1Extents[0]);

        newExtents.push_back(input2Extents[1]);
      }
      if constexpr (isMatmulTranspose)
        innerProduct1Extent = input1Extents[0];
      else
        innerProduct1Extent = input1Extents[1];

      innerProduct2Extent = input2Extents[0];
    }
    // The inner product dimensions of the input arrays
    // must match. Pick the best (e.g. constant) out of them
    // so that the inner product loop bound can be used in
    // optimizations.
    llvm::SmallVector<mlir::Value> innerProductExtent =
        fir::factory::deduceOptimalExtents({innerProduct1Extent},
                                           {innerProduct2Extent});
    return {builder.create<fir::ShapeOp>(loc, newExtents),
            innerProductExtent[0]};
  }

  static mlir::LogicalResult
  genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
                      hlfir::Entity result, mlir::Value resultShape,
                      hlfir::Entity lhs, hlfir::Entity rhs,
                      mlir::Value innerProductExtent) {
    // This code does not support MATMUL(TRANSPOSE), which is supposed
    // to be inlined as hlfir.elemental.
    if constexpr (isMatmulTranspose)
      return mlir::failure();

    mlir::OpBuilder::InsertionGuard guard(builder);
    mlir::Type resultElementType = result.getFortranElementType();
    llvm::SmallVector<mlir::Value, 2> resultExtents =
        mlir::cast<fir::ShapeOp>(resultShape.getDefiningOp()).getExtents();

    // The inner product loop may be unordered if FastMathFlags::reassoc
    // transformations are allowed. The integer/logical inner product is
    // always unordered.
    // Note that isUnordered is currently applied to all loops
    // in the loop nests generated below, while it has to be applied
    // only to one.
    bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
                       mlir::isa<fir::LogicalType>(resultElementType) ||
                       static_cast<bool>(builder.getFastMathFlags() &
                                         mlir::arith::FastMathFlags::reassoc);

    // Insert the initialization loop nest that fills the whole result with
    // zeroes.
    mlir::Value initValue =
        fir::factory::createZeroValue(builder, loc, resultElementType);
    auto genInitBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                           mlir::ValueRange oneBasedIndices,
                           mlir::ValueRange reductionArgs)
        -> llvm::SmallVector<mlir::Value, 0> {
      hlfir::Entity resultElement =
          hlfir::getElementAt(loc, builder, result, oneBasedIndices);
      builder.create<hlfir::AssignOp>(loc, initValue, resultElement);
      return {};
    };

    hlfir::genLoopNestWithReductions(loc, builder, resultExtents,
                                     /*reductionInits=*/{}, genInitBody,
                                     /*isUnordered=*/true);

    if (lhs.getRank() == 2 && rhs.getRank() == 2) {
      // LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
      //
      // Insert the computation loop nest:
      //   DO 2 K = 1, N
      //    DO 2 J = 1, NCOLS
      //     DO 2 I = 1, NROWS
      //   2 RESULT(I,J) = RESULT(I,J) + LHS(I,K)*RHS(K,J)
      auto genMatrixMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                                 mlir::ValueRange oneBasedIndices,
                                 mlir::ValueRange reductionArgs)
          -> llvm::SmallVector<mlir::Value, 0> {
        mlir::Value I = oneBasedIndices[0];
        mlir::Value J = oneBasedIndices[1];
        mlir::Value K = oneBasedIndices[2];
        hlfir::Entity resultElement =
            hlfir::getElementAt(loc, builder, result, {I, J});
        hlfir::Entity resultElementValue =
            hlfir::loadTrivialScalar(loc, builder, resultElement);
        hlfir::Entity lhsElementValue =
            hlfir::loadElementAt(loc, builder, lhs, {I, K});
        hlfir::Entity rhsElementValue =
            hlfir::loadElementAt(loc, builder, rhs, {K, J});
        mlir::Value productValue =
            ProductFactory{loc, builder}.genAccumulateProduct(
                resultElementValue, lhsElementValue, rhsElementValue);
        builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
        return {};
      };

      // Note that the loops are inserted in reverse order,
      // so innerProductExtent should be passed as the last extent.
      hlfir::genLoopNestWithReductions(
          loc, builder,
          {resultExtents[0], resultExtents[1], innerProductExtent},
          /*reductionInits=*/{}, genMatrixMatrix, isUnordered);
      return mlir::success();
    }

    if (lhs.getRank() == 2 && rhs.getRank() == 1) {
      // LHS(NROWS,N) * RHS(N) -> RESULT(NROWS)
      //
      // Insert the computation loop nest:
      //   DO 2 K = 1, N
      //    DO 2 J = 1, NROWS
      //   2 RES(J) = RES(J) + LHS(J,K)*RHS(K)
      auto genMatrixVector = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                                 mlir::ValueRange oneBasedIndices,
                                 mlir::ValueRange reductionArgs)
          -> llvm::SmallVector<mlir::Value, 0> {
        mlir::Value J = oneBasedIndices[0];
        mlir::Value K = oneBasedIndices[1];
        hlfir::Entity resultElement =
            hlfir::getElementAt(loc, builder, result, {J});
        hlfir::Entity resultElementValue =
            hlfir::loadTrivialScalar(loc, builder, resultElement);
        hlfir::Entity lhsElementValue =
            hlfir::loadElementAt(loc, builder, lhs, {J, K});
        hlfir::Entity rhsElementValue =
            hlfir::loadElementAt(loc, builder, rhs, {K});
        mlir::Value productValue =
            ProductFactory{loc, builder}.genAccumulateProduct(
                resultElementValue, lhsElementValue, rhsElementValue);
        builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
        return {};
      };
      hlfir::genLoopNestWithReductions(
          loc, builder, {resultExtents[0], innerProductExtent},
          /*reductionInits=*/{}, genMatrixVector, isUnordered);
      return mlir::success();
    }
    if (lhs.getRank() == 1 && rhs.getRank() == 2) {
      // LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS)
      //
      // Insert the computation loop nest:
      //   DO 2 K = 1, N
      //    DO 2 J = 1, NCOLS
      //   2 RES(J) = RES(J) + LHS(K)*RHS(K,J)
      auto genVectorMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                                 mlir::ValueRange oneBasedIndices,
                                 mlir::ValueRange reductionArgs)
          -> llvm::SmallVector<mlir::Value, 0> {
        mlir::Value J = oneBasedIndices[0];
        mlir::Value K = oneBasedIndices[1];
        hlfir::Entity resultElement =
            hlfir::getElementAt(loc, builder, result, {J});
        hlfir::Entity resultElementValue =
            hlfir::loadTrivialScalar(loc, builder, resultElement);
        hlfir::Entity lhsElementValue =
            hlfir::loadElementAt(loc, builder, lhs, {K});
        hlfir::Entity rhsElementValue =
            hlfir::loadElementAt(loc, builder, rhs, {K, J});
        mlir::Value productValue =
            ProductFactory{loc, builder}.genAccumulateProduct(
                resultElementValue, lhsElementValue, rhsElementValue);
        builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
        return {};
      };
      hlfir::genLoopNestWithReductions(
          loc, builder, {resultExtents[0], innerProductExtent},
          /*reductionInits=*/{}, genVectorMatrix, isUnordered);
      return mlir::success();
    }

    llvm_unreachable("unsupported MATMUL arguments' ranks");
  }
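
  // Generate an hlfir.elemental computing MATMUL/MATMUL(TRANSPOSE):
  // the kernel for each result element runs an inner product reduction
  // loop, e.g. for matrix*matrix (schematically):
  //   RESULT(I,J) = SUM(LHS(I,1:N) * RHS(1:N,J))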
  static hlfir::ElementalOp
  genElementalMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
                     hlfir::ExprType resultType, mlir::Value resultShape,
                     hlfir::Entity lhs, hlfir::Entity rhs,
                     mlir::Value innerProductExtent) {
    mlir::OpBuilder::InsertionGuard guard(builder);
    mlir::Type resultElementType = resultType.getElementType();
    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                         mlir::ValueRange resultIndices) -> hlfir::Entity {
      mlir::Value initValue =
          fir::factory::createZeroValue(builder, loc, resultElementType);
      // The inner product loop may be unordered if FastMathFlags::reassoc
      // transformations are allowed. The integer/logical inner product is
      // always unordered.
      bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
                         mlir::isa<fir::LogicalType>(resultElementType) ||
                         static_cast<bool>(builder.getFastMathFlags() &
                                           mlir::arith::FastMathFlags::reassoc);

      auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                         mlir::ValueRange oneBasedIndices,
                         mlir::ValueRange reductionArgs)
          -> llvm::SmallVector<mlir::Value, 1> {
        llvm::SmallVector<mlir::Value, 2> lhsIndices;
        llvm::SmallVector<mlir::Value, 2> rhsIndices;
        // MATMUL:
        //   LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
        //   LHS(NROWS,N) * RHS(N) -> RESULT(NROWS)
        //   LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS)
        //
        // MATMUL(TRANSPOSE):
        //   TRANSPOSE(LHS(N,NROWS)) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
        //   TRANSPOSE(LHS(N,NROWS)) * RHS(N) -> RESULT(NROWS)
        //
        // The resultIndices iterate over (NROWS[,NCOLS]).
        // The oneBasedIndices iterate over (N).
        if (lhs.getRank() > 1)
          lhsIndices.push_back(resultIndices[0]);
        lhsIndices.push_back(oneBasedIndices[0]);

        if constexpr (isMatmulTranspose) {
          // Swap the LHS indices for TRANSPOSE.
          std::swap(lhsIndices[0], lhsIndices[1]);
        }

        rhsIndices.push_back(oneBasedIndices[0]);
        if (rhs.getRank() > 1)
          rhsIndices.push_back(resultIndices.back());

        hlfir::Entity lhsElementValue =
            hlfir::loadElementAt(loc, builder, lhs, lhsIndices);
        hlfir::Entity rhsElementValue =
            hlfir::loadElementAt(loc, builder, rhs, rhsIndices);
        mlir::Value productValue =
            ProductFactory{loc, builder}.genAccumulateProduct(
                reductionArgs[0], lhsElementValue, rhsElementValue);
        return {productValue};
      };
      llvm::SmallVector<mlir::Value, 1> innerProductValue =
          hlfir::genLoopNestWithReductions(loc, builder, {innerProductExtent},
                                           {initValue}, genBody, isUnordered);
      return hlfir::Entity{innerProductValue[0]};
    };
    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
        loc, builder, resultElementType, resultShape, /*typeParams=*/{},
        genKernel,
        /*isUnordered=*/true, /*polymorphicMold=*/nullptr, resultType);

    return elementalOp;
  }
};
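
// Inline hlfir.dot_product as a reduction loop. Note that for complex
// arguments Fortran defines DOT_PRODUCT(LHS, RHS) as
// SUM(CONJG(LHS) * RHS), hence the CONJ=true accumulation below.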
class DotProductConversion
    : public mlir::OpRewritePattern<hlfir::DotProductOp> {
public:
  using mlir::OpRewritePattern<hlfir::DotProductOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::DotProductOp product,
                  mlir::PatternRewriter &rewriter) const override {
    hlfir::Entity op = hlfir::Entity{product};
    if (!op.isScalar())
      return rewriter.notifyMatchFailure(product, "produces non-scalar result");

    mlir::Location loc = product.getLoc();
    fir::FirOpBuilder builder{rewriter, product.getOperation()};
    hlfir::Entity lhs = hlfir::Entity{product.getLhs()};
    hlfir::Entity rhs = hlfir::Entity{product.getRhs()};
    mlir::Type resultElementType = product.getType();
    bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
                       mlir::isa<fir::LogicalType>(resultElementType) ||
                       static_cast<bool>(builder.getFastMathFlags() &
                                         mlir::arith::FastMathFlags::reassoc);

    mlir::Value extent = genProductExtent(loc, builder, lhs, rhs);

    auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                       mlir::ValueRange oneBasedIndices,
                       mlir::ValueRange reductionArgs)
        -> llvm::SmallVector<mlir::Value, 1> {
      hlfir::Entity lhsElementValue =
          hlfir::loadElementAt(loc, builder, lhs, oneBasedIndices);
      hlfir::Entity rhsElementValue =
          hlfir::loadElementAt(loc, builder, rhs, oneBasedIndices);
      mlir::Value productValue =
          ProductFactory{loc, builder}.genAccumulateProduct</*CONJ=*/true>(
              reductionArgs[0], lhsElementValue, rhsElementValue);
      return {productValue};
    };

    mlir::Value initValue =
        fir::factory::createZeroValue(builder, loc, resultElementType);

    llvm::SmallVector<mlir::Value, 1> result = hlfir::genLoopNestWithReductions(
        loc, builder, {extent},
        /*reductionInits=*/{initValue}, genBody, isUnordered);

    rewriter.replaceOp(product, result[0]);
    return mlir::success();
  }

private:
  static mlir::Value genProductExtent(mlir::Location loc,
                                      fir::FirOpBuilder &builder,
                                      hlfir::Entity input1,
                                      hlfir::Entity input2) {
    llvm::SmallVector<mlir::Value, 1> input1Extents =
        hlfir::genExtentsVector(loc, builder, input1);
    llvm::SmallVector<mlir::Value, 1> input2Extents =
        hlfir::genExtentsVector(loc, builder, input2);

    assert(input1Extents.size() == 1 && input2Extents.size() == 1 &&
           "hlfir.dot_product arguments must be vectors");
    llvm::SmallVector<mlir::Value, 1> extent =
        fir::factory::deduceOptimalExtents(input1Extents, input2Extents);
    return extent[0];
  }
};

class SimplifyHLFIRIntrinsics
    : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
public:
  using SimplifyHLFIRIntrinsicsBase<
      SimplifyHLFIRIntrinsics>::SimplifyHLFIRIntrinsicsBase;

  void runOnOperation() override {
    mlir::MLIRContext *context = &getContext();

    mlir::GreedyRewriteConfig config;
    // Prevent the pattern driver from merging blocks.
    config.enableRegionSimplification =
        mlir::GreedySimplifyRegionLevel::Disabled;

    mlir::RewritePatternSet patterns(context);
    patterns.insert<TransposeAsElementalConversion>(context);
    patterns.insert<SumAsElementalConversion>(context);
    patterns.insert<CShiftAsElementalConversion>(context);
    patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);

    // If forceMatmulAsElemental is false, then hlfir.matmul inlining
    // will introduce an hlfir.eval_in_mem operation with new memory
    // side effects. This conflicts with CSE and optimized bufferization,
    // e.g.:
    //   A(1:N,1:N) = A(1:N,1:N) - MATMUL(...)
    // If we introduce hlfir.eval_in_mem before CSE, then the current
    // MLIR CSE won't be able to optimize the trivial loads of the 'N'
    // value that happen before and after hlfir.matmul.
    // If the 'N' loads are not optimized, then the optimized
    // bufferization won't be able to prove that the slices of A are
    // identical on both sides of the assignment.
    // This is actually a CSE problem, but we can work around it
    // for the time being.
    if (forceMatmulAsElemental || this->allowNewSideEffects)
      patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);

    patterns.insert<DotProductConversion>(context);

    if (mlir::failed(mlir::applyPatternsGreedily(
            getOperation(), std::move(patterns), config))) {
      mlir::emitError(getOperation()->getLoc(),
                      "failure in HLFIR intrinsic simplification");
      signalPassFailure();
    }
  }
};
} // namespace