1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ReductionProcessor.h" 14 15 #include "flang/Lower/AbstractConverter.h" 16 #include "flang/Optimizer/Builder/HLFIRTools.h" 17 #include "flang/Optimizer/Builder/Todo.h" 18 #include "flang/Optimizer/Dialect/FIRType.h" 19 #include "flang/Optimizer/HLFIR/HLFIROps.h" 20 #include "flang/Parser/tools.h" 21 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 22 #include "llvm/Support/CommandLine.h" 23 24 static llvm::cl::opt<bool> forceByrefReduction( 25 "force-byref-reduction", 26 llvm::cl::desc("Pass all reduction arguments by reference"), 27 llvm::cl::Hidden); 28 29 namespace Fortran { 30 namespace lower { 31 namespace omp { 32 33 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 34 const omp::clause::ProcedureDesignator &pd) { 35 auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( 36 getRealName(pd.v.id()).ToString()) 37 .Case("max", ReductionIdentifier::MAX) 38 .Case("min", ReductionIdentifier::MIN) 39 .Case("iand", ReductionIdentifier::IAND) 40 .Case("ior", ReductionIdentifier::IOR) 41 .Case("ieor", ReductionIdentifier::IEOR) 42 .Default(std::nullopt); 43 assert(redType && "Invalid Reduction"); 44 return *redType; 45 } 46 47 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 48 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { 49 switch (intrinsicOp) { 50 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 51 return ReductionIdentifier::ADD; 52 case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: 53 return ReductionIdentifier::SUBTRACT; 54 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 55 return ReductionIdentifier::MULTIPLY; 56 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 57 return ReductionIdentifier::AND; 58 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 59 return ReductionIdentifier::EQV; 60 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 61 return ReductionIdentifier::OR; 62 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 63 return ReductionIdentifier::NEQV; 64 default: 65 llvm_unreachable("unexpected intrinsic operator in reduction"); 66 } 67 } 68 69 bool ReductionProcessor::supportedIntrinsicProcReduction( 70 const omp::clause::ProcedureDesignator &pd) { 71 Fortran::semantics::Symbol *sym = pd.v.id(); 72 if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) 73 return false; 74 auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) 75 .Case("max", true) 76 .Case("min", true) 77 .Case("iand", true) 78 .Case("ior", true) 79 .Case("ieor", true) 80 .Default(false); 81 return redType; 82 } 83 84 std::string ReductionProcessor::getReductionName(llvm::StringRef name, 85 mlir::Type ty, bool isByRef) { 86 ty = fir::unwrapRefType(ty); 87 88 // extra string to distinguish reduction functions for variables passed by 89 // reference 90 llvm::StringRef byrefAddition{""}; 91 if (isByRef) 92 byrefAddition = "_byref"; 93 94 if (fir::isa_trivial(ty)) 95 return (llvm::Twine(name) + 96 (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + 97 llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition) 98 .str(); 99 100 // creates a name like reduction_i_64_box_ux4x3 101 if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) { 102 // TODO: support for allocatable boxes: 103 // !fir.box<!fir.heap<!fir.array<...>>> 104 fir::SequenceType seqTy = fir::unwrapRefType(boxTy.getEleTy()) 105 .dyn_cast_or_null<fir::SequenceType>(); 106 if (!seqTy) 107 return {}; 108 109 std::string prefix = getReductionName( 110 name, fir::unwrapSeqOrBoxedSeqType(ty), /*isByRef=*/false); 111 if (prefix.empty()) 112 return {}; 113 std::stringstream tyStr; 114 tyStr << prefix << "_box_"; 115 bool first = true; 116 for (std::int64_t extent : seqTy.getShape()) { 117 if (first) 118 first = false; 119 else 120 tyStr << "x"; 121 if (extent == seqTy.getUnknownExtent()) 122 tyStr << 'u'; // I'm not sure that '?' is safe in symbol names 123 else 124 tyStr << extent; 125 } 126 return (tyStr.str() + byrefAddition).str(); 127 } 128 129 return {}; 130 } 131 132 std::string ReductionProcessor::getReductionName( 133 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, mlir::Type ty, 134 bool isByRef) { 135 std::string reductionName; 136 137 switch (intrinsicOp) { 138 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 139 reductionName = "add_reduction"; 140 break; 141 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 142 reductionName = "multiply_reduction"; 143 break; 144 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 145 return "and_reduction"; 146 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 147 return "eqv_reduction"; 148 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 149 return "or_reduction"; 150 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 151 return "neqv_reduction"; 152 default: 153 reductionName = "other_reduction"; 154 break; 155 } 156 157 return getReductionName(reductionName, ty, isByRef); 158 } 159 160 mlir::Value 161 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, 162 ReductionIdentifier redId, 163 fir::FirOpBuilder &builder) { 164 type = fir::unwrapRefType(type); 165 assert((fir::isa_integer(type) || fir::isa_real(type) || 166 type.isa<fir::LogicalType>()) && 167 "only integer, logical and real types are currently supported"); 168 switch (redId) { 169 case ReductionIdentifier::MAX: { 170 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 171 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 172 return builder.createRealConstant( 173 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); 174 } 175 unsigned bits = type.getIntOrFloatBitWidth(); 176 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 177 return builder.createIntegerConstant(loc, type, minInt); 178 } 179 case ReductionIdentifier::MIN: { 180 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 181 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 182 return builder.createRealConstant( 183 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); 184 } 185 unsigned bits = type.getIntOrFloatBitWidth(); 186 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); 187 return builder.createIntegerConstant(loc, type, maxInt); 188 } 189 case ReductionIdentifier::IOR: { 190 unsigned bits = type.getIntOrFloatBitWidth(); 191 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 192 return builder.createIntegerConstant(loc, type, zeroInt); 193 } 194 case ReductionIdentifier::IEOR: { 195 unsigned bits = type.getIntOrFloatBitWidth(); 196 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 197 return builder.createIntegerConstant(loc, type, zeroInt); 198 } 199 case ReductionIdentifier::IAND: { 200 unsigned bits = type.getIntOrFloatBitWidth(); 201 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); 202 return builder.createIntegerConstant(loc, type, allOnInt); 203 } 204 case ReductionIdentifier::ADD: 205 case ReductionIdentifier::MULTIPLY: 206 case ReductionIdentifier::AND: 207 case ReductionIdentifier::OR: 208 case ReductionIdentifier::EQV: 209 case ReductionIdentifier::NEQV: 210 if (type.isa<mlir::FloatType>()) 211 return builder.create<mlir::arith::ConstantOp>( 212 loc, type, 213 builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); 214 215 if (type.isa<fir::LogicalType>()) { 216 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( 217 loc, builder.getI1Type(), 218 builder.getIntegerAttr(builder.getI1Type(), 219 getOperationIdentity(redId, loc))); 220 return builder.createConvert(loc, type, intConst); 221 } 222 223 return builder.create<mlir::arith::ConstantOp>( 224 loc, type, 225 builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); 226 case ReductionIdentifier::ID: 227 case ReductionIdentifier::USER_DEF_OP: 228 case ReductionIdentifier::SUBTRACT: 229 TODO(loc, "Reduction of some identifier types is not supported"); 230 } 231 llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); 232 } 233 234 mlir::Value ReductionProcessor::createScalarCombiner( 235 fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, 236 mlir::Type type, mlir::Value op1, mlir::Value op2) { 237 mlir::Value reductionOp; 238 type = fir::unwrapRefType(type); 239 switch (redId) { 240 case ReductionIdentifier::MAX: 241 reductionOp = 242 getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>( 243 builder, type, loc, op1, op2); 244 break; 245 case ReductionIdentifier::MIN: 246 reductionOp = 247 getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>( 248 builder, type, loc, op1, op2); 249 break; 250 case ReductionIdentifier::IOR: 251 assert((type.isIntOrIndex()) && "only integer is expected"); 252 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); 253 break; 254 case ReductionIdentifier::IEOR: 255 assert((type.isIntOrIndex()) && "only integer is expected"); 256 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); 257 break; 258 case ReductionIdentifier::IAND: 259 assert((type.isIntOrIndex()) && "only integer is expected"); 260 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); 261 break; 262 case ReductionIdentifier::ADD: 263 reductionOp = 264 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>( 265 builder, type, loc, op1, op2); 266 break; 267 case ReductionIdentifier::MULTIPLY: 268 reductionOp = 269 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>( 270 builder, type, loc, op1, op2); 271 break; 272 case ReductionIdentifier::AND: { 273 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 274 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 275 276 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); 277 278 reductionOp = builder.createConvert(loc, type, andiOp); 279 break; 280 } 281 case ReductionIdentifier::OR: { 282 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 283 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 284 285 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); 286 287 reductionOp = builder.createConvert(loc, type, oriOp); 288 break; 289 } 290 case ReductionIdentifier::EQV: { 291 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 292 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 293 294 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 295 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); 296 297 reductionOp = builder.createConvert(loc, type, cmpiOp); 298 break; 299 } 300 case ReductionIdentifier::NEQV: { 301 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 302 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 303 304 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 305 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); 306 307 reductionOp = builder.createConvert(loc, type, cmpiOp); 308 break; 309 } 310 default: 311 TODO(loc, "Reduction of some intrinsic operators is not supported"); 312 } 313 314 return reductionOp; 315 } 316 317 /// Create reduction combiner region for reduction variables which are boxed 318 /// arrays 319 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 320 ReductionProcessor::ReductionIdentifier redId, 321 fir::BaseBoxType boxTy, mlir::Value lhs, 322 mlir::Value rhs) { 323 fir::SequenceType seqTy = 324 mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy()); 325 // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>> 326 if (!seqTy || seqTy.hasUnknownShape()) 327 TODO(loc, "Unsupported boxed type in OpenMP reduction"); 328 329 // load fir.ref<fir.box<...>> 330 mlir::Value lhsAddr = lhs; 331 lhs = builder.create<fir::LoadOp>(loc, lhs); 332 rhs = builder.create<fir::LoadOp>(loc, rhs); 333 334 const unsigned rank = seqTy.getDimension(); 335 llvm::SmallVector<mlir::Value> extents; 336 extents.reserve(rank); 337 llvm::SmallVector<mlir::Value> lbAndExtents; 338 lbAndExtents.reserve(rank * 2); 339 340 // Get box lowerbounds and extents: 341 mlir::Type idxTy = builder.getIndexType(); 342 for (unsigned i = 0; i < rank; ++i) { 343 // TODO: ideally we want to hoist box reads out of the critical section. 344 // We could do this by having box dimensions in block arguments like 345 // OpenACC does 346 mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i); 347 auto dimInfo = 348 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim); 349 extents.push_back(dimInfo.getExtent()); 350 lbAndExtents.push_back(dimInfo.getLowerBound()); 351 lbAndExtents.push_back(dimInfo.getExtent()); 352 } 353 354 auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank); 355 auto shapeShift = 356 builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents); 357 358 // Iterate over array elements, applying the equivalent scalar reduction: 359 360 // A hlfir::elemental here gets inlined with a temporary so create the 361 // loop nest directly. 362 // This function already controls all of the code in this region so we 363 // know this won't miss any opportuinties for clever elemental inlining 364 hlfir::LoopNest nest = 365 hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true); 366 builder.setInsertionPointToStart(nest.innerLoop.getBody()); 367 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); 368 auto lhsEleAddr = builder.create<fir::ArrayCoorOp>( 369 loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{}, 370 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 371 auto rhsEleAddr = builder.create<fir::ArrayCoorOp>( 372 loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{}, 373 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 374 auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr); 375 auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr); 376 mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner( 377 builder, loc, redId, refTy, lhsEle, rhsEle); 378 builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr); 379 380 builder.setInsertionPointAfter(nest.outerLoop); 381 builder.create<mlir::omp::YieldOp>(loc, lhsAddr); 382 } 383 384 // generate combiner region for reduction operations 385 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 386 ReductionProcessor::ReductionIdentifier redId, 387 mlir::Type ty, mlir::Value lhs, mlir::Value rhs, 388 bool isByRef) { 389 ty = fir::unwrapRefType(ty); 390 391 if (fir::isa_trivial(ty)) { 392 mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs); 393 mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs); 394 395 mlir::Value result = ReductionProcessor::createScalarCombiner( 396 builder, loc, redId, ty, lhsLoaded, rhsLoaded); 397 if (isByRef) { 398 builder.create<fir::StoreOp>(loc, result, lhs); 399 builder.create<mlir::omp::YieldOp>(loc, lhs); 400 } else { 401 builder.create<mlir::omp::YieldOp>(loc, result); 402 } 403 return; 404 } 405 // all arrays should have been boxed 406 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { 407 genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs); 408 return; 409 } 410 411 TODO(loc, "OpenMP genCombiner for unsupported reduction variable type"); 412 } 413 414 static mlir::Value 415 createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc, 416 const ReductionProcessor::ReductionIdentifier redId, 417 mlir::Type type, bool isByRef) { 418 mlir::Type ty = fir::unwrapRefType(type); 419 mlir::Value initValue = ReductionProcessor::getReductionInitValue( 420 loc, fir::unwrapSeqOrBoxedSeqType(ty), redId, builder); 421 422 if (fir::isa_trivial(ty)) { 423 if (isByRef) { 424 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty); 425 builder.createStoreWithConvert(loc, initValue, alloca); 426 return alloca; 427 } 428 // by val 429 return initValue; 430 } 431 432 // all arrays are boxed 433 if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) { 434 assert(isByRef && "passing arrays by value is unsupported"); 435 // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>> 436 mlir::Type innerTy = fir::extractSequenceType(boxTy); 437 if (!mlir::isa<fir::SequenceType>(innerTy)) 438 TODO(loc, "Unsupported boxed type for reduction"); 439 // Create the private copy from the initial fir.box: 440 hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)}; 441 442 // TODO: if the whole reduction is nested inside of a loop, this alloca 443 // could lead to a stack overflow (the memory is only freed at the end of 444 // the stack frame). The reduction declare operation needs a deallocation 445 // region to undo the init region. 446 hlfir::Entity temp = createStackTempFromMold(loc, builder, source); 447 448 // Put the temporary inside of a box: 449 hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp); 450 builder.create<hlfir::AssignOp>(loc, initValue, box); 451 mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty); 452 builder.create<fir::StoreOp>(loc, box, boxAlloca); 453 return boxAlloca; 454 } 455 456 TODO(loc, "createReductionInitRegion for unsupported type"); 457 } 458 459 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction( 460 fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, 461 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, 462 bool isByRef) { 463 mlir::OpBuilder::InsertionGuard guard(builder); 464 mlir::ModuleOp module = builder.getModule(); 465 466 if (reductionOpName.empty()) 467 TODO(loc, "Reduction of some types is not supported"); 468 469 auto decl = 470 module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName); 471 if (decl) 472 return decl; 473 474 mlir::OpBuilder modBuilder(module.getBodyRegion()); 475 mlir::Type valTy = fir::unwrapRefType(type); 476 if (!isByRef) 477 type = valTy; 478 479 decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName, 480 type); 481 builder.createBlock(&decl.getInitializerRegion(), 482 decl.getInitializerRegion().end(), {type}, {loc}); 483 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); 484 485 mlir::Value init = 486 createReductionInitRegion(builder, loc, redId, type, isByRef); 487 builder.create<mlir::omp::YieldOp>(loc, init); 488 489 builder.createBlock(&decl.getReductionRegion(), 490 decl.getReductionRegion().end(), {type, type}, 491 {loc, loc}); 492 493 builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); 494 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); 495 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); 496 genCombiner(builder, loc, redId, type, op1, op2, isByRef); 497 498 return decl; 499 } 500 501 // TODO: By-ref vs by-val reductions are currently toggled for the whole 502 // operation (possibly effecting multiple reduction variables). 503 // This could cause a problem with openmp target reductions because 504 // by-ref trivial types may not be supported. 505 bool ReductionProcessor::doReductionByRef( 506 const llvm::SmallVectorImpl<mlir::Value> &reductionVars) { 507 if (reductionVars.empty()) 508 return false; 509 if (forceByrefReduction) 510 return true; 511 512 for (mlir::Value reductionVar : reductionVars) { 513 if (auto declare = 514 mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) 515 reductionVar = declare.getMemref(); 516 517 if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) 518 return true; 519 } 520 return false; 521 } 522 523 void ReductionProcessor::addDeclareReduction( 524 mlir::Location currentLocation, 525 Fortran::lower::AbstractConverter &converter, 526 const omp::clause::Reduction &reduction, 527 llvm::SmallVectorImpl<mlir::Value> &reductionVars, 528 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, 529 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> 530 *reductionSymbols) { 531 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 532 mlir::omp::DeclareReductionOp decl; 533 const auto &redOperator{ 534 std::get<omp::clause::ReductionOperator>(reduction.t)}; 535 const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; 536 537 if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) { 538 if (const auto *reductionIntrinsic = 539 std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) { 540 if (!ReductionProcessor::supportedIntrinsicProcReduction( 541 *reductionIntrinsic)) { 542 return; 543 } 544 } else { 545 return; 546 } 547 } 548 549 // initial pass to collect all reduction vars so we can figure out if this 550 // should happen byref 551 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 552 for (const Object &object : objectList) { 553 const Fortran::semantics::Symbol *symbol = object.id(); 554 if (reductionSymbols) 555 reductionSymbols->push_back(symbol); 556 mlir::Value symVal = converter.getSymbolAddress(*symbol); 557 auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 558 559 // all arrays must be boxed so that we have convenient access to all the 560 // information needed to iterate over the array 561 if (mlir::isa<fir::SequenceType>(redType.getEleTy())) { 562 hlfir::Entity entity{symVal}; 563 entity = genVariableBox(currentLocation, builder, entity); 564 mlir::Value box = entity.getBase(); 565 566 // Always pass the box by reference so that the OpenMP dialect 567 // verifiers don't need to know anything about fir.box 568 auto alloca = 569 builder.create<fir::AllocaOp>(currentLocation, box.getType()); 570 builder.create<fir::StoreOp>(currentLocation, box, alloca); 571 572 symVal = alloca; 573 redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 574 } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) { 575 symVal = declOp.getBase(); 576 } 577 578 reductionVars.push_back(symVal); 579 } 580 const bool isByRef = doReductionByRef(reductionVars); 581 582 if (const auto &redDefinedOp = 583 std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { 584 const auto &intrinsicOp{ 585 std::get<omp::clause::DefinedOperator::IntrinsicOperator>( 586 redDefinedOp->u)}; 587 ReductionIdentifier redId = getReductionType(intrinsicOp); 588 switch (redId) { 589 case ReductionIdentifier::ADD: 590 case ReductionIdentifier::MULTIPLY: 591 case ReductionIdentifier::AND: 592 case ReductionIdentifier::EQV: 593 case ReductionIdentifier::OR: 594 case ReductionIdentifier::NEQV: 595 break; 596 default: 597 TODO(currentLocation, 598 "Reduction of some intrinsic operators is not supported"); 599 break; 600 } 601 602 for (mlir::Value symVal : reductionVars) { 603 auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 604 if (redType.getEleTy().isa<fir::LogicalType>()) 605 decl = createDeclareReduction( 606 firOpBuilder, 607 getReductionName(intrinsicOp, firOpBuilder.getI1Type(), isByRef), 608 redId, redType, currentLocation, isByRef); 609 else 610 decl = createDeclareReduction( 611 firOpBuilder, getReductionName(intrinsicOp, redType, isByRef), 612 redId, redType, currentLocation, isByRef); 613 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 614 firOpBuilder.getContext(), decl.getSymName())); 615 } 616 } else if (const auto *reductionIntrinsic = 617 std::get_if<omp::clause::ProcedureDesignator>( 618 &redOperator.u)) { 619 if (ReductionProcessor::supportedIntrinsicProcReduction( 620 *reductionIntrinsic)) { 621 ReductionProcessor::ReductionIdentifier redId = 622 ReductionProcessor::getReductionType(*reductionIntrinsic); 623 for (const Object &object : objectList) { 624 const Fortran::semantics::Symbol *symbol = object.id(); 625 mlir::Value symVal = converter.getSymbolAddress(*symbol); 626 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) 627 symVal = declOp.getBase(); 628 auto redType = symVal.getType().cast<fir::ReferenceType>(); 629 if (!redType.getEleTy().isIntOrIndexOrFloat()) 630 TODO(currentLocation, "User Defined Reduction on non-trivial type"); 631 decl = createDeclareReduction( 632 firOpBuilder, 633 getReductionName(getRealName(*reductionIntrinsic).ToString(), 634 redType, isByRef), 635 redId, redType, currentLocation, isByRef); 636 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 637 firOpBuilder.getContext(), decl.getSymName())); 638 } 639 } 640 } 641 } 642 643 const Fortran::semantics::SourceName 644 ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) { 645 return symbol->GetUltimate().name(); 646 } 647 648 const Fortran::semantics::SourceName 649 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { 650 return getRealName(pd.v.id()); 651 } 652 653 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, 654 mlir::Location loc) { 655 switch (redId) { 656 case ReductionIdentifier::ADD: 657 case ReductionIdentifier::OR: 658 case ReductionIdentifier::NEQV: 659 return 0; 660 case ReductionIdentifier::MULTIPLY: 661 case ReductionIdentifier::AND: 662 case ReductionIdentifier::EQV: 663 return 1; 664 default: 665 TODO(loc, "Reduction of some intrinsic operators is not supported"); 666 } 667 } 668 669 } // namespace omp 670 } // namespace lower 671 } // namespace Fortran 672