1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ReductionProcessor.h" 14 15 #include "PrivateReductionUtils.h" 16 #include "flang/Lower/AbstractConverter.h" 17 #include "flang/Lower/ConvertType.h" 18 #include "flang/Lower/SymbolMap.h" 19 #include "flang/Optimizer/Builder/Complex.h" 20 #include "flang/Optimizer/Builder/HLFIRTools.h" 21 #include "flang/Optimizer/Builder/Todo.h" 22 #include "flang/Optimizer/Dialect/FIRType.h" 23 #include "flang/Optimizer/HLFIR/HLFIROps.h" 24 #include "flang/Optimizer/Support/FatalError.h" 25 #include "flang/Parser/tools.h" 26 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 27 #include "llvm/Support/CommandLine.h" 28 29 static llvm::cl::opt<bool> forceByrefReduction( 30 "force-byref-reduction", 31 llvm::cl::desc("Pass all reduction arguments by reference"), 32 llvm::cl::Hidden); 33 34 namespace Fortran { 35 namespace lower { 36 namespace omp { 37 38 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 39 const omp::clause::ProcedureDesignator &pd) { 40 auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( 41 getRealName(pd.v.sym()).ToString()) 42 .Case("max", ReductionIdentifier::MAX) 43 .Case("min", ReductionIdentifier::MIN) 44 .Case("iand", ReductionIdentifier::IAND) 45 .Case("ior", ReductionIdentifier::IOR) 46 .Case("ieor", ReductionIdentifier::IEOR) 47 .Default(std::nullopt); 48 assert(redType && "Invalid Reduction"); 49 return *redType; 50 } 51 52 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 53 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { 54 switch (intrinsicOp) { 55 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 56 return ReductionIdentifier::ADD; 57 case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: 58 return ReductionIdentifier::SUBTRACT; 59 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 60 return ReductionIdentifier::MULTIPLY; 61 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 62 return ReductionIdentifier::AND; 63 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 64 return ReductionIdentifier::EQV; 65 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 66 return ReductionIdentifier::OR; 67 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 68 return ReductionIdentifier::NEQV; 69 default: 70 llvm_unreachable("unexpected intrinsic operator in reduction"); 71 } 72 } 73 74 bool ReductionProcessor::supportedIntrinsicProcReduction( 75 const omp::clause::ProcedureDesignator &pd) { 76 semantics::Symbol *sym = pd.v.sym(); 77 if (!sym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC)) 78 return false; 79 auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) 80 .Case("max", true) 81 .Case("min", true) 82 .Case("iand", true) 83 .Case("ior", true) 84 .Case("ieor", true) 85 .Default(false); 86 return redType; 87 } 88 89 std::string 90 ReductionProcessor::getReductionName(llvm::StringRef name, 91 const fir::KindMapping &kindMap, 92 mlir::Type ty, bool isByRef) { 93 ty = fir::unwrapRefType(ty); 94 95 // extra string to distinguish reduction functions for variables passed by 96 // reference 97 llvm::StringRef byrefAddition{""}; 98 if (isByRef) 99 byrefAddition = "_byref"; 100 101 return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str()); 102 } 103 104 std::string ReductionProcessor::getReductionName( 105 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, 106 const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) { 107 std::string reductionName; 108 109 switch (intrinsicOp) { 110 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 111 reductionName = "add_reduction"; 112 break; 113 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 114 reductionName = "multiply_reduction"; 115 break; 116 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 117 return "and_reduction"; 118 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 119 return "eqv_reduction"; 120 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 121 return "or_reduction"; 122 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 123 return "neqv_reduction"; 124 default: 125 reductionName = "other_reduction"; 126 break; 127 } 128 129 return getReductionName(reductionName, kindMap, ty, isByRef); 130 } 131 132 mlir::Value 133 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, 134 ReductionIdentifier redId, 135 fir::FirOpBuilder &builder) { 136 type = fir::unwrapRefType(type); 137 if (!fir::isa_integer(type) && !fir::isa_real(type) && 138 !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type)) 139 TODO(loc, "Reduction of some types is not supported"); 140 switch (redId) { 141 case ReductionIdentifier::MAX: { 142 if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) { 143 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 144 return builder.createRealConstant( 145 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); 146 } 147 unsigned bits = type.getIntOrFloatBitWidth(); 148 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 149 return builder.createIntegerConstant(loc, type, minInt); 150 } 151 case ReductionIdentifier::MIN: { 152 if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) { 153 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 154 return builder.createRealConstant( 155 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); 156 } 157 unsigned bits = type.getIntOrFloatBitWidth(); 158 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); 159 return builder.createIntegerConstant(loc, type, maxInt); 160 } 161 case ReductionIdentifier::IOR: { 162 unsigned bits = type.getIntOrFloatBitWidth(); 163 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 164 return builder.createIntegerConstant(loc, type, zeroInt); 165 } 166 case ReductionIdentifier::IEOR: { 167 unsigned bits = type.getIntOrFloatBitWidth(); 168 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 169 return builder.createIntegerConstant(loc, type, zeroInt); 170 } 171 case ReductionIdentifier::IAND: { 172 unsigned bits = type.getIntOrFloatBitWidth(); 173 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); 174 return builder.createIntegerConstant(loc, type, allOnInt); 175 } 176 case ReductionIdentifier::ADD: 177 case ReductionIdentifier::MULTIPLY: 178 case ReductionIdentifier::AND: 179 case ReductionIdentifier::OR: 180 case ReductionIdentifier::EQV: 181 case ReductionIdentifier::NEQV: 182 if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) { 183 mlir::Type realTy = cplxTy.getElementType(); 184 mlir::Value initRe = builder.createRealConstant( 185 loc, realTy, getOperationIdentity(redId, loc)); 186 mlir::Value initIm = builder.createRealConstant(loc, realTy, 0); 187 188 return fir::factory::Complex{builder, loc}.createComplex(type, initRe, 189 initIm); 190 } 191 if (mlir::isa<mlir::FloatType>(type)) 192 return builder.create<mlir::arith::ConstantOp>( 193 loc, type, 194 builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); 195 196 if (mlir::isa<fir::LogicalType>(type)) { 197 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( 198 loc, builder.getI1Type(), 199 builder.getIntegerAttr(builder.getI1Type(), 200 getOperationIdentity(redId, loc))); 201 return builder.createConvert(loc, type, intConst); 202 } 203 204 return builder.create<mlir::arith::ConstantOp>( 205 loc, type, 206 builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); 207 case ReductionIdentifier::ID: 208 case ReductionIdentifier::USER_DEF_OP: 209 case ReductionIdentifier::SUBTRACT: 210 TODO(loc, "Reduction of some identifier types is not supported"); 211 } 212 llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); 213 } 214 215 mlir::Value ReductionProcessor::createScalarCombiner( 216 fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, 217 mlir::Type type, mlir::Value op1, mlir::Value op2) { 218 mlir::Value reductionOp; 219 type = fir::unwrapRefType(type); 220 switch (redId) { 221 case ReductionIdentifier::MAX: 222 reductionOp = 223 getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>( 224 builder, type, loc, op1, op2); 225 break; 226 case ReductionIdentifier::MIN: 227 reductionOp = 228 getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>( 229 builder, type, loc, op1, op2); 230 break; 231 case ReductionIdentifier::IOR: 232 assert((type.isIntOrIndex()) && "only integer is expected"); 233 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); 234 break; 235 case ReductionIdentifier::IEOR: 236 assert((type.isIntOrIndex()) && "only integer is expected"); 237 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); 238 break; 239 case ReductionIdentifier::IAND: 240 assert((type.isIntOrIndex()) && "only integer is expected"); 241 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); 242 break; 243 case ReductionIdentifier::ADD: 244 reductionOp = 245 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp, 246 fir::AddcOp>(builder, type, loc, op1, op2); 247 break; 248 case ReductionIdentifier::MULTIPLY: 249 reductionOp = 250 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp, 251 fir::MulcOp>(builder, type, loc, op1, op2); 252 break; 253 case ReductionIdentifier::AND: { 254 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 255 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 256 257 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); 258 259 reductionOp = builder.createConvert(loc, type, andiOp); 260 break; 261 } 262 case ReductionIdentifier::OR: { 263 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 264 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 265 266 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); 267 268 reductionOp = builder.createConvert(loc, type, oriOp); 269 break; 270 } 271 case ReductionIdentifier::EQV: { 272 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 273 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 274 275 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 276 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); 277 278 reductionOp = builder.createConvert(loc, type, cmpiOp); 279 break; 280 } 281 case ReductionIdentifier::NEQV: { 282 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 283 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 284 285 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 286 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); 287 288 reductionOp = builder.createConvert(loc, type, cmpiOp); 289 break; 290 } 291 default: 292 TODO(loc, "Reduction of some intrinsic operators is not supported"); 293 } 294 295 return reductionOp; 296 } 297 298 /// Create reduction combiner region for reduction variables which are boxed 299 /// arrays 300 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 301 ReductionProcessor::ReductionIdentifier redId, 302 fir::BaseBoxType boxTy, mlir::Value lhs, 303 mlir::Value rhs) { 304 fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>( 305 fir::unwrapRefType(boxTy.getEleTy())); 306 fir::HeapType heapTy = 307 mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy()); 308 fir::PointerType ptrTy = 309 mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy()); 310 if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy) 311 TODO(loc, "Unsupported boxed type in OpenMP reduction"); 312 313 // load fir.ref<fir.box<...>> 314 mlir::Value lhsAddr = lhs; 315 lhs = builder.create<fir::LoadOp>(loc, lhs); 316 rhs = builder.create<fir::LoadOp>(loc, rhs); 317 318 if ((heapTy || ptrTy) && !seqTy) { 319 // get box contents (heap pointers) 320 lhs = builder.create<fir::BoxAddrOp>(loc, lhs); 321 rhs = builder.create<fir::BoxAddrOp>(loc, rhs); 322 mlir::Value lhsValAddr = lhs; 323 324 // load heap pointers 325 lhs = builder.create<fir::LoadOp>(loc, lhs); 326 rhs = builder.create<fir::LoadOp>(loc, rhs); 327 328 mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy(); 329 330 mlir::Value result = ReductionProcessor::createScalarCombiner( 331 builder, loc, redId, eleTy, lhs, rhs); 332 builder.create<fir::StoreOp>(loc, result, lhsValAddr); 333 builder.create<mlir::omp::YieldOp>(loc, lhsAddr); 334 return; 335 } 336 337 fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, lhs); 338 339 // Iterate over array elements, applying the equivalent scalar reduction: 340 341 // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced 342 // and so no null check is needed here before indexing into the (possibly 343 // allocatable) arrays. 344 345 // A hlfir::elemental here gets inlined with a temporary so create the 346 // loop nest directly. 347 // This function already controls all of the code in this region so we 348 // know this won't miss any opportuinties for clever elemental inlining 349 hlfir::LoopNest nest = hlfir::genLoopNest( 350 loc, builder, shapeShift.getExtents(), /*isUnordered=*/true); 351 builder.setInsertionPointToStart(nest.body); 352 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); 353 auto lhsEleAddr = builder.create<fir::ArrayCoorOp>( 354 loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{}, 355 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 356 auto rhsEleAddr = builder.create<fir::ArrayCoorOp>( 357 loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{}, 358 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 359 auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr); 360 auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr); 361 mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner( 362 builder, loc, redId, refTy, lhsEle, rhsEle); 363 builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr); 364 365 builder.setInsertionPointAfter(nest.outerOp); 366 builder.create<mlir::omp::YieldOp>(loc, lhsAddr); 367 } 368 369 // generate combiner region for reduction operations 370 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 371 ReductionProcessor::ReductionIdentifier redId, 372 mlir::Type ty, mlir::Value lhs, mlir::Value rhs, 373 bool isByRef) { 374 ty = fir::unwrapRefType(ty); 375 376 if (fir::isa_trivial(ty)) { 377 mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs); 378 mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs); 379 380 mlir::Value result = ReductionProcessor::createScalarCombiner( 381 builder, loc, redId, ty, lhsLoaded, rhsLoaded); 382 if (isByRef) { 383 builder.create<fir::StoreOp>(loc, result, lhs); 384 builder.create<mlir::omp::YieldOp>(loc, lhs); 385 } else { 386 builder.create<mlir::omp::YieldOp>(loc, result); 387 } 388 return; 389 } 390 // all arrays should have been boxed 391 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { 392 genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs); 393 return; 394 } 395 396 TODO(loc, "OpenMP genCombiner for unsupported reduction variable type"); 397 } 398 399 // like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes 400 static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) { 401 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) 402 return seqTy.getEleTy(); 403 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { 404 auto eleTy = fir::unwrapRefType(boxTy.getEleTy()); 405 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) 406 return seqTy.getEleTy(); 407 return eleTy; 408 } 409 return ty; 410 } 411 412 static void createReductionAllocAndInitRegions( 413 fir::FirOpBuilder &builder, mlir::Location loc, 414 mlir::omp::DeclareReductionOp &reductionDecl, 415 const ReductionProcessor::ReductionIdentifier redId, mlir::Type type, 416 bool isByRef) { 417 auto yield = [&](mlir::Value ret) { 418 builder.create<mlir::omp::YieldOp>(loc, ret); 419 }; 420 421 mlir::Block *allocBlock = nullptr; 422 mlir::Block *initBlock = nullptr; 423 if (isByRef) { 424 allocBlock = 425 builder.createBlock(&reductionDecl.getAllocRegion(), 426 reductionDecl.getAllocRegion().end(), {}, {}); 427 initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(), 428 reductionDecl.getInitializerRegion().end(), 429 {type, type}, {loc, loc}); 430 } else { 431 initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(), 432 reductionDecl.getInitializerRegion().end(), 433 {type}, {loc}); 434 } 435 436 mlir::Type ty = fir::unwrapRefType(type); 437 builder.setInsertionPointToEnd(initBlock); 438 mlir::Value initValue = ReductionProcessor::getReductionInitValue( 439 loc, unwrapSeqOrBoxedType(ty), redId, builder); 440 441 if (isByRef) { 442 populateByRefInitAndCleanupRegions(builder, loc, type, initValue, initBlock, 443 reductionDecl.getInitializerAllocArg(), 444 reductionDecl.getInitializerMoldArg(), 445 reductionDecl.getCleanupRegion()); 446 } 447 448 if (fir::isa_trivial(ty)) { 449 if (isByRef) { 450 // alloc region 451 builder.setInsertionPointToEnd(allocBlock); 452 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty); 453 yield(alloca); 454 return; 455 } 456 // by val 457 yield(initValue); 458 return; 459 } 460 assert(isByRef && "passing non-trivial types by val is unsupported"); 461 462 // alloc region 463 builder.setInsertionPointToEnd(allocBlock); 464 mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty); 465 yield(boxAlloca); 466 } 467 468 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction( 469 fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, 470 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, 471 bool isByRef) { 472 mlir::OpBuilder::InsertionGuard guard(builder); 473 mlir::ModuleOp module = builder.getModule(); 474 475 assert(!reductionOpName.empty()); 476 477 auto decl = 478 module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName); 479 if (decl) 480 return decl; 481 482 mlir::OpBuilder modBuilder(module.getBodyRegion()); 483 mlir::Type valTy = fir::unwrapRefType(type); 484 if (!isByRef) 485 type = valTy; 486 487 decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName, 488 type); 489 createReductionAllocAndInitRegions(builder, loc, decl, redId, type, isByRef); 490 491 builder.createBlock(&decl.getReductionRegion(), 492 decl.getReductionRegion().end(), {type, type}, 493 {loc, loc}); 494 495 builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); 496 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); 497 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); 498 genCombiner(builder, loc, redId, type, op1, op2, isByRef); 499 500 return decl; 501 } 502 503 static bool doReductionByRef(mlir::Value reductionVar) { 504 if (forceByrefReduction) 505 return true; 506 507 if (auto declare = 508 mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) 509 reductionVar = declare.getMemref(); 510 511 if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) 512 return true; 513 514 return false; 515 } 516 517 void ReductionProcessor::addDeclareReduction( 518 mlir::Location currentLocation, lower::AbstractConverter &converter, 519 const omp::clause::Reduction &reduction, 520 llvm::SmallVectorImpl<mlir::Value> &reductionVars, 521 llvm::SmallVectorImpl<bool> &reduceVarByRef, 522 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, 523 llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) { 524 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 525 526 if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>( 527 reduction.t)) 528 TODO(currentLocation, "Reduction modifiers are not supported"); 529 530 mlir::omp::DeclareReductionOp decl; 531 const auto &redOperatorList{ 532 std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)}; 533 assert(redOperatorList.size() == 1 && "Expecting single operator"); 534 const auto &redOperator = redOperatorList.front(); 535 const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; 536 537 if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) { 538 if (const auto *reductionIntrinsic = 539 std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) { 540 if (!ReductionProcessor::supportedIntrinsicProcReduction( 541 *reductionIntrinsic)) { 542 return; 543 } 544 } else { 545 return; 546 } 547 } 548 549 // Reduction variable processing common to both intrinsic operators and 550 // procedure designators 551 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 552 for (const Object &object : objectList) { 553 const semantics::Symbol *symbol = object.sym(); 554 reductionSymbols.push_back(symbol); 555 mlir::Value symVal = converter.getSymbolAddress(*symbol); 556 mlir::Type eleType; 557 auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType()); 558 if (refType) 559 eleType = refType.getEleTy(); 560 else 561 eleType = symVal.getType(); 562 563 // all arrays must be boxed so that we have convenient access to all the 564 // information needed to iterate over the array 565 if (mlir::isa<fir::SequenceType>(eleType)) { 566 // For Host associated symbols, use `SymbolBox` instead 567 lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol); 568 hlfir::Entity entity{symBox.getAddr()}; 569 entity = genVariableBox(currentLocation, builder, entity); 570 mlir::Value box = entity.getBase(); 571 572 // Always pass the box by reference so that the OpenMP dialect 573 // verifiers don't need to know anything about fir.box 574 auto alloca = 575 builder.create<fir::AllocaOp>(currentLocation, box.getType()); 576 builder.create<fir::StoreOp>(currentLocation, box, alloca); 577 578 symVal = alloca; 579 } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) { 580 // boxed arrays are passed as values not by reference. Unfortunately, 581 // we can't pass a box by value to omp.redution_declare, so turn it 582 // into a reference 583 584 auto alloca = 585 builder.create<fir::AllocaOp>(currentLocation, symVal.getType()); 586 builder.create<fir::StoreOp>(currentLocation, symVal, alloca); 587 symVal = alloca; 588 } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) { 589 symVal = declOp.getBase(); 590 } 591 592 // this isn't the same as the by-val and by-ref passing later in the 593 // pipeline. Both styles assume that the variable is a reference at 594 // this point 595 assert(mlir::isa<fir::ReferenceType>(symVal.getType()) && 596 "reduction input var is a reference"); 597 598 reductionVars.push_back(symVal); 599 reduceVarByRef.push_back(doReductionByRef(symVal)); 600 } 601 602 for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) { 603 auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 604 const auto &kindMap = firOpBuilder.getKindMap(); 605 std::string reductionName; 606 ReductionIdentifier redId; 607 mlir::Type redNameTy = redType; 608 if (mlir::isa<fir::LogicalType>(redType.getEleTy())) 609 redNameTy = builder.getI1Type(); 610 611 if (const auto &redDefinedOp = 612 std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { 613 const auto &intrinsicOp{ 614 std::get<omp::clause::DefinedOperator::IntrinsicOperator>( 615 redDefinedOp->u)}; 616 redId = getReductionType(intrinsicOp); 617 switch (redId) { 618 case ReductionIdentifier::ADD: 619 case ReductionIdentifier::MULTIPLY: 620 case ReductionIdentifier::AND: 621 case ReductionIdentifier::EQV: 622 case ReductionIdentifier::OR: 623 case ReductionIdentifier::NEQV: 624 break; 625 default: 626 TODO(currentLocation, 627 "Reduction of some intrinsic operators is not supported"); 628 break; 629 } 630 631 reductionName = 632 getReductionName(intrinsicOp, kindMap, redNameTy, isByRef); 633 } else if (const auto *reductionIntrinsic = 634 std::get_if<omp::clause::ProcedureDesignator>( 635 &redOperator.u)) { 636 if (!ReductionProcessor::supportedIntrinsicProcReduction( 637 *reductionIntrinsic)) { 638 TODO(currentLocation, "Unsupported intrinsic proc reduction"); 639 } 640 redId = getReductionType(*reductionIntrinsic); 641 reductionName = 642 getReductionName(getRealName(*reductionIntrinsic).ToString(), kindMap, 643 redNameTy, isByRef); 644 } else { 645 TODO(currentLocation, "Unexpected reduction type"); 646 } 647 648 decl = createDeclareReduction(firOpBuilder, reductionName, redId, redType, 649 currentLocation, isByRef); 650 reductionDeclSymbols.push_back( 651 mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName())); 652 } 653 } 654 655 const semantics::SourceName 656 ReductionProcessor::getRealName(const semantics::Symbol *symbol) { 657 return symbol->GetUltimate().name(); 658 } 659 660 const semantics::SourceName 661 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { 662 return getRealName(pd.v.sym()); 663 } 664 665 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, 666 mlir::Location loc) { 667 switch (redId) { 668 case ReductionIdentifier::ADD: 669 case ReductionIdentifier::OR: 670 case ReductionIdentifier::NEQV: 671 return 0; 672 case ReductionIdentifier::MULTIPLY: 673 case ReductionIdentifier::AND: 674 case ReductionIdentifier::EQV: 675 return 1; 676 default: 677 TODO(loc, "Reduction of some intrinsic operators is not supported"); 678 } 679 } 680 681 } // namespace omp 682 } // namespace lower 683 } // namespace Fortran 684