1 //===-- HlfirIntrinsics.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Lower/HlfirIntrinsics.h" 14 15 #include "flang/Optimizer/Builder/BoxValue.h" 16 #include "flang/Optimizer/Builder/FIRBuilder.h" 17 #include "flang/Optimizer/Builder/HLFIRTools.h" 18 #include "flang/Optimizer/Builder/IntrinsicCall.h" 19 #include "flang/Optimizer/Builder/MutableBox.h" 20 #include "flang/Optimizer/Builder/Todo.h" 21 #include "flang/Optimizer/Dialect/FIRType.h" 22 #include "flang/Optimizer/HLFIR/HLFIRDialect.h" 23 #include "flang/Optimizer/HLFIR/HLFIROps.h" 24 #include "mlir/IR/Value.h" 25 #include "llvm/ADT/SmallVector.h" 26 #include <mlir/IR/ValueRange.h> 27 28 namespace { 29 30 class HlfirTransformationalIntrinsic { 31 public: 32 explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder &builder, 33 mlir::Location loc) 34 : builder(builder), loc(loc) {} 35 36 virtual ~HlfirTransformationalIntrinsic() = default; 37 38 hlfir::EntityWithAttributes 39 lower(const Fortran::lower::PreparedActualArguments &loweredActuals, 40 const fir::IntrinsicArgumentLoweringRules *argLowering, 41 mlir::Type stmtResultType) { 42 mlir::Value res = lowerImpl(loweredActuals, argLowering, stmtResultType); 43 for (const hlfir::CleanupFunction &fn : cleanupFns) 44 fn(); 45 return {hlfir::EntityWithAttributes{res}}; 46 } 47 48 protected: 49 fir::FirOpBuilder &builder; 50 mlir::Location loc; 51 llvm::SmallVector<hlfir::CleanupFunction, 3> cleanupFns; 52 53 virtual mlir::Value 54 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, 55 const fir::IntrinsicArgumentLoweringRules *argLowering, 56 mlir::Type stmtResultType) = 0; 57 58 llvm::SmallVector<mlir::Value> getOperandVector( 59 const Fortran::lower::PreparedActualArguments &loweredActuals, 60 const fir::IntrinsicArgumentLoweringRules *argLowering); 61 62 mlir::Type computeResultType(mlir::Value argArray, mlir::Type stmtResultType); 63 64 template <typename OP, typename... BUILD_ARGS> 65 inline OP createOp(BUILD_ARGS... args) { 66 return builder.create<OP>(loc, args...); 67 } 68 69 mlir::Value loadBoxAddress( 70 const std::optional<Fortran::lower::PreparedActualArgument> &arg); 71 72 void addCleanup(std::optional<hlfir::CleanupFunction> cleanup) { 73 if (cleanup) 74 cleanupFns.emplace_back(std::move(*cleanup)); 75 } 76 }; 77 78 template <typename OP, bool HAS_MASK> 79 class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic { 80 public: 81 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; 82 83 protected: 84 mlir::Value 85 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, 86 const fir::IntrinsicArgumentLoweringRules *argLowering, 87 mlir::Type stmtResultType) override; 88 }; 89 using HlfirSumLowering = HlfirReductionIntrinsic<hlfir::SumOp, true>; 90 using HlfirProductLowering = HlfirReductionIntrinsic<hlfir::ProductOp, true>; 91 using HlfirMaxvalLowering = HlfirReductionIntrinsic<hlfir::MaxvalOp, true>; 92 using HlfirMinvalLowering = HlfirReductionIntrinsic<hlfir::MinvalOp, true>; 93 using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>; 94 using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>; 95 96 template <typename OP> 97 class HlfirMinMaxLocIntrinsic : public HlfirTransformationalIntrinsic { 98 public: 99 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; 100 101 protected: 102 mlir::Value 103 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, 104 const fir::IntrinsicArgumentLoweringRules *argLowering, 105 mlir::Type stmtResultType) override; 106 }; 107 using HlfirMinlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MinlocOp>; 108 using HlfirMaxlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MaxlocOp>; 109 110 template <typename OP> 111 class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic { 112 public: 113 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; 114 115 protected: 116 mlir::Value 117 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, 118 const fir::IntrinsicArgumentLoweringRules *argLowering, 119 mlir::Type stmtResultType) override; 120 }; 121 using HlfirMatmulLowering = HlfirProductIntrinsic<hlfir::MatmulOp>; 122 using HlfirDotProductLowering = HlfirProductIntrinsic<hlfir::DotProductOp>; 123 124 class HlfirTransposeLowering : public HlfirTransformationalIntrinsic { 125 public: 126 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; 127 128 protected: 129 mlir::Value 130 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, 131 const fir::IntrinsicArgumentLoweringRules *argLowering, 132 mlir::Type stmtResultType) override; 133 }; 134 135 class HlfirCountLowering : public HlfirTransformationalIntrinsic { 136 public: 137 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; 138 139 protected: 140 mlir::Value 141 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, 142 const fir::IntrinsicArgumentLoweringRules *argLowering, 143 mlir::Type stmtResultType) override; 144 }; 145 146 class HlfirCharExtremumLowering : public HlfirTransformationalIntrinsic { 147 public: 148 HlfirCharExtremumLowering(fir::FirOpBuilder &builder, mlir::Location loc, 149 hlfir::CharExtremumPredicate pred) 150 : HlfirTransformationalIntrinsic(builder, loc), pred{pred} {} 151 152 protected: 153 mlir::Value 154 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, 155 const fir::IntrinsicArgumentLoweringRules *argLowering, 156 mlir::Type stmtResultType) override; 157 158 protected: 159 hlfir::CharExtremumPredicate pred; 160 }; 161 162 class HlfirCShiftLowering : public HlfirTransformationalIntrinsic { 163 public: 164 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; 165 166 protected: 167 mlir::Value 168 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, 169 const fir::IntrinsicArgumentLoweringRules *argLowering, 170 mlir::Type stmtResultType) override; 171 }; 172 173 class HlfirReshapeLowering : public HlfirTransformationalIntrinsic { 174 public: 175 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; 176 177 protected: 178 mlir::Value 179 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, 180 const fir::IntrinsicArgumentLoweringRules *argLowering, 181 mlir::Type stmtResultType) override; 182 }; 183 184 } // namespace 185 186 mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress( 187 const std::optional<Fortran::lower::PreparedActualArgument> &arg) { 188 if (!arg) 189 return mlir::Value{}; 190 191 hlfir::Entity actual = arg->getActual(loc, builder); 192 193 if (!arg->handleDynamicOptional()) { 194 if (actual.isMutableBox()) { 195 // this is a box address type but is not dynamically optional. Just load 196 // the box, assuming it is well formed (!fir.ref<!fir.box<...>> -> 197 // !fir.box<...>) 198 return builder.create<fir::LoadOp>(loc, actual.getBase()); 199 } 200 return actual; 201 } 202 203 auto [exv, cleanup] = hlfir::translateToExtendedValue(loc, builder, actual); 204 addCleanup(cleanup); 205 206 mlir::Value isPresent = arg->getIsPresent(); 207 // createBox will not do create any invalid memory dereferences if exv is 208 // absent. The created fir.box will not be usable, but the SelectOp below 209 // ensures it won't be. 210 mlir::Value box = builder.createBox(loc, exv); 211 mlir::Type boxType = box.getType(); 212 auto absent = builder.create<fir::AbsentOp>(loc, boxType); 213 auto boxOrAbsent = builder.create<mlir::arith::SelectOp>( 214 loc, boxType, isPresent, box, absent); 215 216 return boxOrAbsent; 217 } 218 219 static mlir::Value loadOptionalValue( 220 mlir::Location loc, fir::FirOpBuilder &builder, 221 const std::optional<Fortran::lower::PreparedActualArgument> &arg, 222 hlfir::Entity actual) { 223 if (!arg->handleDynamicOptional()) 224 return hlfir::loadTrivialScalar(loc, builder, actual); 225 226 mlir::Value isPresent = arg->getIsPresent(); 227 mlir::Type eleType = hlfir::getFortranElementType(actual.getType()); 228 return builder 229 .genIfOp(loc, {eleType}, isPresent, 230 /*withElseRegion=*/true) 231 .genThen([&]() { 232 assert(actual.isScalar() && fir::isa_trivial(eleType) && 233 "must be a numerical or logical scalar"); 234 hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, actual); 235 builder.create<fir::ResultOp>(loc, val); 236 }) 237 .genElse([&]() { 238 mlir::Value zero = fir::factory::createZeroValue(builder, loc, eleType); 239 builder.create<fir::ResultOp>(loc, zero); 240 }) 241 .getResults()[0]; 242 } 243 244 llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector( 245 const Fortran::lower::PreparedActualArguments &loweredActuals, 246 const fir::IntrinsicArgumentLoweringRules *argLowering) { 247 llvm::SmallVector<mlir::Value> operands; 248 operands.reserve(loweredActuals.size()); 249 250 for (size_t i = 0; i < loweredActuals.size(); ++i) { 251 std::optional<Fortran::lower::PreparedActualArgument> arg = 252 loweredActuals[i]; 253 if (!arg) { 254 operands.emplace_back(); 255 continue; 256 } 257 hlfir::Entity actual = arg->getActual(loc, builder); 258 mlir::Value valArg; 259 260 if (!argLowering) { 261 valArg = hlfir::loadTrivialScalar(loc, builder, actual); 262 } else { 263 fir::ArgLoweringRule argRules = 264 fir::lowerIntrinsicArgumentAs(*argLowering, i); 265 if (argRules.lowerAs == fir::LowerIntrinsicArgAs::Box) 266 valArg = loadBoxAddress(arg); 267 else if (!argRules.handleDynamicOptional && 268 argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired) 269 valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual); 270 else if (argRules.handleDynamicOptional && 271 argRules.lowerAs == fir::LowerIntrinsicArgAs::Value) 272 valArg = loadOptionalValue(loc, builder, arg, actual); 273 else if (argRules.handleDynamicOptional) 274 TODO(loc, "hlfir transformational intrinsic dynamically optional " 275 "argument without box lowering"); 276 else 277 valArg = actual.getBase(); 278 } 279 280 operands.emplace_back(valArg); 281 } 282 return operands; 283 } 284 285 mlir::Type 286 HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray, 287 mlir::Type stmtResultType) { 288 mlir::Type normalisedResult = 289 hlfir::getFortranElementOrSequenceType(stmtResultType); 290 if (auto array = mlir::dyn_cast<fir::SequenceType>(normalisedResult)) { 291 hlfir::ExprType::Shape resultShape = 292 hlfir::ExprType::Shape{array.getShape()}; 293 mlir::Type elementType = array.getEleTy(); 294 return hlfir::ExprType::get(builder.getContext(), resultShape, elementType, 295 fir::isPolymorphicType(stmtResultType)); 296 } else if (auto resCharType = 297 mlir::dyn_cast<fir::CharacterType>(stmtResultType)) { 298 normalisedResult = hlfir::ExprType::get( 299 builder.getContext(), hlfir::ExprType::Shape{}, resCharType, 300 /*polymorphic=*/false); 301 } 302 return normalisedResult; 303 } 304 305 template <typename OP, bool HAS_MASK> 306 mlir::Value HlfirReductionIntrinsic<OP, HAS_MASK>::lowerImpl( 307 const Fortran::lower::PreparedActualArguments &loweredActuals, 308 const fir::IntrinsicArgumentLoweringRules *argLowering, 309 mlir::Type stmtResultType) { 310 auto operands = getOperandVector(loweredActuals, argLowering); 311 mlir::Value array = operands[0]; 312 mlir::Value dim = operands[1]; 313 // dim, mask can be NULL if these arguments are not given 314 if (dim) 315 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); 316 317 mlir::Type resultTy = computeResultType(array, stmtResultType); 318 319 OP op; 320 if constexpr (HAS_MASK) 321 op = createOp<OP>(resultTy, array, dim, 322 /*mask=*/operands[2]); 323 else 324 op = createOp<OP>(resultTy, array, dim); 325 return op; 326 } 327 328 template <typename OP> 329 mlir::Value HlfirMinMaxLocIntrinsic<OP>::lowerImpl( 330 const Fortran::lower::PreparedActualArguments &loweredActuals, 331 const fir::IntrinsicArgumentLoweringRules *argLowering, 332 mlir::Type stmtResultType) { 333 auto operands = getOperandVector(loweredActuals, argLowering); 334 mlir::Value array = operands[0]; 335 mlir::Value dim = operands[1]; 336 mlir::Value mask = operands[2]; 337 mlir::Value back = operands[4]; 338 // dim, mask and back can be NULL if these arguments are not given. 339 if (dim) 340 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); 341 if (back) 342 back = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{back}); 343 344 mlir::Type resultTy = computeResultType(array, stmtResultType); 345 346 return createOp<OP>(resultTy, array, dim, mask, back); 347 } 348 349 template <typename OP> 350 mlir::Value HlfirProductIntrinsic<OP>::lowerImpl( 351 const Fortran::lower::PreparedActualArguments &loweredActuals, 352 const fir::IntrinsicArgumentLoweringRules *argLowering, 353 mlir::Type stmtResultType) { 354 auto operands = getOperandVector(loweredActuals, argLowering); 355 mlir::Type resultType = computeResultType(operands[0], stmtResultType); 356 return createOp<OP>(resultType, operands[0], operands[1]); 357 } 358 359 mlir::Value HlfirTransposeLowering::lowerImpl( 360 const Fortran::lower::PreparedActualArguments &loweredActuals, 361 const fir::IntrinsicArgumentLoweringRules *argLowering, 362 mlir::Type stmtResultType) { 363 auto operands = getOperandVector(loweredActuals, argLowering); 364 hlfir::ExprType::Shape resultShape; 365 mlir::Type normalisedResult = 366 hlfir::getFortranElementOrSequenceType(stmtResultType); 367 auto array = mlir::cast<fir::SequenceType>(normalisedResult); 368 llvm::ArrayRef<int64_t> arrayShape = array.getShape(); 369 assert(arrayShape.size() == 2 && "arguments to transpose have a rank of 2"); 370 mlir::Type elementType = array.getEleTy(); 371 resultShape.push_back(arrayShape[0]); 372 resultShape.push_back(arrayShape[1]); 373 if (auto resCharType = mlir::dyn_cast<fir::CharacterType>(elementType)) 374 if (!resCharType.hasConstantLen()) { 375 // The FunctionRef expression might have imprecise character 376 // type at this point, and we can improve it by propagating 377 // the constant length from the argument. 378 auto argCharType = mlir::dyn_cast<fir::CharacterType>( 379 hlfir::getFortranElementType(operands[0].getType())); 380 if (argCharType && argCharType.hasConstantLen()) 381 elementType = fir::CharacterType::get( 382 builder.getContext(), resCharType.getFKind(), argCharType.getLen()); 383 } 384 385 mlir::Type resultTy = 386 hlfir::ExprType::get(builder.getContext(), resultShape, elementType, 387 fir::isPolymorphicType(stmtResultType)); 388 return createOp<hlfir::TransposeOp>(resultTy, operands[0]); 389 } 390 391 mlir::Value HlfirCountLowering::lowerImpl( 392 const Fortran::lower::PreparedActualArguments &loweredActuals, 393 const fir::IntrinsicArgumentLoweringRules *argLowering, 394 mlir::Type stmtResultType) { 395 auto operands = getOperandVector(loweredActuals, argLowering); 396 mlir::Value array = operands[0]; 397 mlir::Value dim = operands[1]; 398 if (dim) 399 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); 400 mlir::Type resultType = computeResultType(array, stmtResultType); 401 return createOp<hlfir::CountOp>(resultType, array, dim); 402 } 403 404 mlir::Value HlfirCharExtremumLowering::lowerImpl( 405 const Fortran::lower::PreparedActualArguments &loweredActuals, 406 const fir::IntrinsicArgumentLoweringRules *argLowering, 407 mlir::Type stmtResultType) { 408 auto operands = getOperandVector(loweredActuals, argLowering); 409 assert(operands.size() >= 2); 410 return createOp<hlfir::CharExtremumOp>(pred, mlir::ValueRange{operands}); 411 } 412 413 mlir::Value HlfirCShiftLowering::lowerImpl( 414 const Fortran::lower::PreparedActualArguments &loweredActuals, 415 const fir::IntrinsicArgumentLoweringRules *argLowering, 416 mlir::Type stmtResultType) { 417 auto operands = getOperandVector(loweredActuals, argLowering); 418 assert(operands.size() == 3); 419 mlir::Value dim = operands[2]; 420 if (!dim) { 421 // If DIM is not present, drop the last element which is a null Value. 422 operands.truncate(2); 423 } else { 424 // If DIM is present, then dereference it if it is a ref. 425 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); 426 operands[2] = dim; 427 } 428 429 mlir::Type resultType = computeResultType(operands[0], stmtResultType); 430 return createOp<hlfir::CShiftOp>(resultType, operands); 431 } 432 433 mlir::Value HlfirReshapeLowering::lowerImpl( 434 const Fortran::lower::PreparedActualArguments &loweredActuals, 435 const fir::IntrinsicArgumentLoweringRules *argLowering, 436 mlir::Type stmtResultType) { 437 auto operands = getOperandVector(loweredActuals, argLowering); 438 assert(operands.size() == 4); 439 mlir::Type resultType = computeResultType(operands[0], stmtResultType); 440 return createOp<hlfir::ReshapeOp>(resultType, operands[0], operands[1], 441 operands[2], operands[3]); 442 } 443 444 std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic( 445 fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name, 446 const Fortran::lower::PreparedActualArguments &loweredActuals, 447 const fir::IntrinsicArgumentLoweringRules *argLowering, 448 mlir::Type stmtResultType) { 449 // If the result is of a derived type that may need finalization, 450 // we have to use DestroyOp with 'finalize' attribute for the result 451 // of the intrinsic operation. 452 if (name == "sum") 453 return HlfirSumLowering{builder, loc}.lower(loweredActuals, argLowering, 454 stmtResultType); 455 if (name == "product") 456 return HlfirProductLowering{builder, loc}.lower(loweredActuals, argLowering, 457 stmtResultType); 458 if (name == "any") 459 return HlfirAnyLowering{builder, loc}.lower(loweredActuals, argLowering, 460 stmtResultType); 461 if (name == "all") 462 return HlfirAllLowering{builder, loc}.lower(loweredActuals, argLowering, 463 stmtResultType); 464 if (name == "matmul") 465 return HlfirMatmulLowering{builder, loc}.lower(loweredActuals, argLowering, 466 stmtResultType); 467 if (name == "dot_product") 468 return HlfirDotProductLowering{builder, loc}.lower( 469 loweredActuals, argLowering, stmtResultType); 470 // FIXME: the result may need finalization. 471 if (name == "transpose") 472 return HlfirTransposeLowering{builder, loc}.lower( 473 loweredActuals, argLowering, stmtResultType); 474 if (name == "count") 475 return HlfirCountLowering{builder, loc}.lower(loweredActuals, argLowering, 476 stmtResultType); 477 if (name == "maxval") 478 return HlfirMaxvalLowering{builder, loc}.lower(loweredActuals, argLowering, 479 stmtResultType); 480 if (name == "minval") 481 return HlfirMinvalLowering{builder, loc}.lower(loweredActuals, argLowering, 482 stmtResultType); 483 if (name == "minloc") 484 return HlfirMinlocLowering{builder, loc}.lower(loweredActuals, argLowering, 485 stmtResultType); 486 if (name == "maxloc") 487 return HlfirMaxlocLowering{builder, loc}.lower(loweredActuals, argLowering, 488 stmtResultType); 489 if (name == "cshift") 490 return HlfirCShiftLowering{builder, loc}.lower(loweredActuals, argLowering, 491 stmtResultType); 492 if (name == "reshape") 493 return HlfirReshapeLowering{builder, loc}.lower(loweredActuals, argLowering, 494 stmtResultType); 495 if (mlir::isa<fir::CharacterType>(stmtResultType)) { 496 if (name == "min") 497 return HlfirCharExtremumLowering{builder, loc, 498 hlfir::CharExtremumPredicate::min} 499 .lower(loweredActuals, argLowering, stmtResultType); 500 if (name == "max") 501 return HlfirCharExtremumLowering{builder, loc, 502 hlfir::CharExtremumPredicate::max} 503 .lower(loweredActuals, argLowering, stmtResultType); 504 } 505 return std::nullopt; 506 } 507