1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ReductionProcessor.h" 14 15 #include "flang/Lower/AbstractConverter.h" 16 #include "flang/Lower/ConvertType.h" 17 #include "flang/Lower/SymbolMap.h" 18 #include "flang/Optimizer/Builder/Complex.h" 19 #include "flang/Optimizer/Builder/HLFIRTools.h" 20 #include "flang/Optimizer/Builder/Todo.h" 21 #include "flang/Optimizer/Dialect/FIRType.h" 22 #include "flang/Optimizer/HLFIR/HLFIROps.h" 23 #include "flang/Parser/tools.h" 24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 25 #include "llvm/Support/CommandLine.h" 26 27 static llvm::cl::opt<bool> forceByrefReduction( 28 "force-byref-reduction", 29 llvm::cl::desc("Pass all reduction arguments by reference"), 30 llvm::cl::Hidden); 31 32 namespace Fortran { 33 namespace lower { 34 namespace omp { 35 36 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 37 const omp::clause::ProcedureDesignator &pd) { 38 auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( 39 getRealName(pd.v.id()).ToString()) 40 .Case("max", ReductionIdentifier::MAX) 41 .Case("min", ReductionIdentifier::MIN) 42 .Case("iand", ReductionIdentifier::IAND) 43 .Case("ior", ReductionIdentifier::IOR) 44 .Case("ieor", ReductionIdentifier::IEOR) 45 .Default(std::nullopt); 46 assert(redType && "Invalid Reduction"); 47 return *redType; 48 } 49 50 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 51 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { 52 switch (intrinsicOp) { 53 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 54 return ReductionIdentifier::ADD; 55 case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: 56 return ReductionIdentifier::SUBTRACT; 57 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 58 return ReductionIdentifier::MULTIPLY; 59 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 60 return ReductionIdentifier::AND; 61 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 62 return ReductionIdentifier::EQV; 63 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 64 return ReductionIdentifier::OR; 65 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 66 return ReductionIdentifier::NEQV; 67 default: 68 llvm_unreachable("unexpected intrinsic operator in reduction"); 69 } 70 } 71 72 bool ReductionProcessor::supportedIntrinsicProcReduction( 73 const omp::clause::ProcedureDesignator &pd) { 74 Fortran::semantics::Symbol *sym = pd.v.id(); 75 if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) 76 return false; 77 auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) 78 .Case("max", true) 79 .Case("min", true) 80 .Case("iand", true) 81 .Case("ior", true) 82 .Case("ieor", true) 83 .Default(false); 84 return redType; 85 } 86 87 std::string 88 ReductionProcessor::getReductionName(llvm::StringRef name, 89 const fir::KindMapping &kindMap, 90 mlir::Type ty, bool isByRef) { 91 ty = fir::unwrapRefType(ty); 92 93 // extra string to distinguish reduction functions for variables passed by 94 // reference 95 llvm::StringRef byrefAddition{""}; 96 if (isByRef) 97 byrefAddition = "_byref"; 98 99 return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str()); 100 } 101 102 std::string ReductionProcessor::getReductionName( 103 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, 104 const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) { 105 std::string reductionName; 106 107 switch (intrinsicOp) { 108 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 109 reductionName = "add_reduction"; 110 break; 111 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 112 reductionName = "multiply_reduction"; 113 break; 114 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 115 return "and_reduction"; 116 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 117 return "eqv_reduction"; 118 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 119 return "or_reduction"; 120 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 121 return "neqv_reduction"; 122 default: 123 reductionName = "other_reduction"; 124 break; 125 } 126 127 return getReductionName(reductionName, kindMap, ty, isByRef); 128 } 129 130 mlir::Value 131 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, 132 ReductionIdentifier redId, 133 fir::FirOpBuilder &builder) { 134 type = fir::unwrapRefType(type); 135 if (!fir::isa_integer(type) && !fir::isa_real(type) && 136 !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type)) 137 TODO(loc, "Reduction of some types is not supported"); 138 switch (redId) { 139 case ReductionIdentifier::MAX: { 140 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 141 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 142 return builder.createRealConstant( 143 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); 144 } 145 unsigned bits = type.getIntOrFloatBitWidth(); 146 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 147 return builder.createIntegerConstant(loc, type, minInt); 148 } 149 case ReductionIdentifier::MIN: { 150 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 151 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 152 return builder.createRealConstant( 153 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); 154 } 155 unsigned bits = type.getIntOrFloatBitWidth(); 156 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); 157 return builder.createIntegerConstant(loc, type, maxInt); 158 } 159 case ReductionIdentifier::IOR: { 160 unsigned bits = type.getIntOrFloatBitWidth(); 161 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 162 return builder.createIntegerConstant(loc, type, zeroInt); 163 } 164 case ReductionIdentifier::IEOR: { 165 unsigned bits = type.getIntOrFloatBitWidth(); 166 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 167 return builder.createIntegerConstant(loc, type, zeroInt); 168 } 169 case ReductionIdentifier::IAND: { 170 unsigned bits = type.getIntOrFloatBitWidth(); 171 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); 172 return builder.createIntegerConstant(loc, type, allOnInt); 173 } 174 case ReductionIdentifier::ADD: 175 case ReductionIdentifier::MULTIPLY: 176 case ReductionIdentifier::AND: 177 case ReductionIdentifier::OR: 178 case ReductionIdentifier::EQV: 179 case ReductionIdentifier::NEQV: 180 if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) { 181 mlir::Type realTy = 182 Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind()); 183 mlir::Value initRe = builder.createRealConstant( 184 loc, realTy, getOperationIdentity(redId, loc)); 185 mlir::Value initIm = builder.createRealConstant(loc, realTy, 0); 186 187 return fir::factory::Complex{builder, loc}.createComplex(type, initRe, 188 initIm); 189 } 190 if (type.isa<mlir::FloatType>()) 191 return builder.create<mlir::arith::ConstantOp>( 192 loc, type, 193 builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); 194 195 if (type.isa<fir::LogicalType>()) { 196 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( 197 loc, builder.getI1Type(), 198 builder.getIntegerAttr(builder.getI1Type(), 199 getOperationIdentity(redId, loc))); 200 return builder.createConvert(loc, type, intConst); 201 } 202 203 return builder.create<mlir::arith::ConstantOp>( 204 loc, type, 205 builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); 206 case ReductionIdentifier::ID: 207 case ReductionIdentifier::USER_DEF_OP: 208 case ReductionIdentifier::SUBTRACT: 209 TODO(loc, "Reduction of some identifier types is not supported"); 210 } 211 llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); 212 } 213 214 mlir::Value ReductionProcessor::createScalarCombiner( 215 fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, 216 mlir::Type type, mlir::Value op1, mlir::Value op2) { 217 mlir::Value reductionOp; 218 type = fir::unwrapRefType(type); 219 switch (redId) { 220 case ReductionIdentifier::MAX: 221 reductionOp = 222 getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>( 223 builder, type, loc, op1, op2); 224 break; 225 case ReductionIdentifier::MIN: 226 reductionOp = 227 getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>( 228 builder, type, loc, op1, op2); 229 break; 230 case ReductionIdentifier::IOR: 231 assert((type.isIntOrIndex()) && "only integer is expected"); 232 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); 233 break; 234 case ReductionIdentifier::IEOR: 235 assert((type.isIntOrIndex()) && "only integer is expected"); 236 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); 237 break; 238 case ReductionIdentifier::IAND: 239 assert((type.isIntOrIndex()) && "only integer is expected"); 240 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); 241 break; 242 case ReductionIdentifier::ADD: 243 reductionOp = 244 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp, 245 fir::AddcOp>(builder, type, loc, op1, op2); 246 break; 247 case ReductionIdentifier::MULTIPLY: 248 reductionOp = 249 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp, 250 fir::MulcOp>(builder, type, loc, op1, op2); 251 break; 252 case ReductionIdentifier::AND: { 253 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 254 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 255 256 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); 257 258 reductionOp = builder.createConvert(loc, type, andiOp); 259 break; 260 } 261 case ReductionIdentifier::OR: { 262 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 263 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 264 265 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); 266 267 reductionOp = builder.createConvert(loc, type, oriOp); 268 break; 269 } 270 case ReductionIdentifier::EQV: { 271 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 272 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 273 274 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 275 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); 276 277 reductionOp = builder.createConvert(loc, type, cmpiOp); 278 break; 279 } 280 case ReductionIdentifier::NEQV: { 281 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 282 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 283 284 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 285 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); 286 287 reductionOp = builder.createConvert(loc, type, cmpiOp); 288 break; 289 } 290 default: 291 TODO(loc, "Reduction of some intrinsic operators is not supported"); 292 } 293 294 return reductionOp; 295 } 296 297 /// Create reduction combiner region for reduction variables which are boxed 298 /// arrays 299 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 300 ReductionProcessor::ReductionIdentifier redId, 301 fir::BaseBoxType boxTy, mlir::Value lhs, 302 mlir::Value rhs) { 303 fir::SequenceType seqTy = 304 mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy()); 305 // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>> 306 if (!seqTy || seqTy.hasUnknownShape()) 307 TODO(loc, "Unsupported boxed type in OpenMP reduction"); 308 309 // load fir.ref<fir.box<...>> 310 mlir::Value lhsAddr = lhs; 311 lhs = builder.create<fir::LoadOp>(loc, lhs); 312 rhs = builder.create<fir::LoadOp>(loc, rhs); 313 314 const unsigned rank = seqTy.getDimension(); 315 llvm::SmallVector<mlir::Value> extents; 316 extents.reserve(rank); 317 llvm::SmallVector<mlir::Value> lbAndExtents; 318 lbAndExtents.reserve(rank * 2); 319 320 // Get box lowerbounds and extents: 321 mlir::Type idxTy = builder.getIndexType(); 322 for (unsigned i = 0; i < rank; ++i) { 323 // TODO: ideally we want to hoist box reads out of the critical section. 324 // We could do this by having box dimensions in block arguments like 325 // OpenACC does 326 mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i); 327 auto dimInfo = 328 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim); 329 extents.push_back(dimInfo.getExtent()); 330 lbAndExtents.push_back(dimInfo.getLowerBound()); 331 lbAndExtents.push_back(dimInfo.getExtent()); 332 } 333 334 auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank); 335 auto shapeShift = 336 builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents); 337 338 // Iterate over array elements, applying the equivalent scalar reduction: 339 340 // A hlfir::elemental here gets inlined with a temporary so create the 341 // loop nest directly. 342 // This function already controls all of the code in this region so we 343 // know this won't miss any opportuinties for clever elemental inlining 344 hlfir::LoopNest nest = 345 hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true); 346 builder.setInsertionPointToStart(nest.innerLoop.getBody()); 347 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); 348 auto lhsEleAddr = builder.create<fir::ArrayCoorOp>( 349 loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{}, 350 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 351 auto rhsEleAddr = builder.create<fir::ArrayCoorOp>( 352 loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{}, 353 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 354 auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr); 355 auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr); 356 mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner( 357 builder, loc, redId, refTy, lhsEle, rhsEle); 358 builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr); 359 360 builder.setInsertionPointAfter(nest.outerLoop); 361 builder.create<mlir::omp::YieldOp>(loc, lhsAddr); 362 } 363 364 // generate combiner region for reduction operations 365 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 366 ReductionProcessor::ReductionIdentifier redId, 367 mlir::Type ty, mlir::Value lhs, mlir::Value rhs, 368 bool isByRef) { 369 ty = fir::unwrapRefType(ty); 370 371 if (fir::isa_trivial(ty)) { 372 mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs); 373 mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs); 374 375 mlir::Value result = ReductionProcessor::createScalarCombiner( 376 builder, loc, redId, ty, lhsLoaded, rhsLoaded); 377 if (isByRef) { 378 builder.create<fir::StoreOp>(loc, result, lhs); 379 builder.create<mlir::omp::YieldOp>(loc, lhs); 380 } else { 381 builder.create<mlir::omp::YieldOp>(loc, result); 382 } 383 return; 384 } 385 // all arrays should have been boxed 386 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { 387 genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs); 388 return; 389 } 390 391 TODO(loc, "OpenMP genCombiner for unsupported reduction variable type"); 392 } 393 394 static mlir::Value 395 createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc, 396 const ReductionProcessor::ReductionIdentifier redId, 397 mlir::Type type, bool isByRef) { 398 mlir::Type ty = fir::unwrapRefType(type); 399 mlir::Value initValue = ReductionProcessor::getReductionInitValue( 400 loc, fir::unwrapSeqOrBoxedSeqType(ty), redId, builder); 401 402 if (fir::isa_trivial(ty)) { 403 if (isByRef) { 404 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty); 405 builder.createStoreWithConvert(loc, initValue, alloca); 406 return alloca; 407 } 408 // by val 409 return initValue; 410 } 411 412 // all arrays are boxed 413 if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) { 414 assert(isByRef && "passing arrays by value is unsupported"); 415 // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>> 416 mlir::Type innerTy = fir::extractSequenceType(boxTy); 417 if (!mlir::isa<fir::SequenceType>(innerTy)) 418 TODO(loc, "Unsupported boxed type for reduction"); 419 // Create the private copy from the initial fir.box: 420 hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)}; 421 422 // TODO: if the whole reduction is nested inside of a loop, this alloca 423 // could lead to a stack overflow (the memory is only freed at the end of 424 // the stack frame). The reduction declare operation needs a deallocation 425 // region to undo the init region. 426 hlfir::Entity temp = createStackTempFromMold(loc, builder, source); 427 428 // Put the temporary inside of a box: 429 hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp); 430 builder.create<hlfir::AssignOp>(loc, initValue, box); 431 mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty); 432 builder.create<fir::StoreOp>(loc, box, boxAlloca); 433 return boxAlloca; 434 } 435 436 TODO(loc, "createReductionInitRegion for unsupported type"); 437 } 438 439 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction( 440 fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, 441 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, 442 bool isByRef) { 443 mlir::OpBuilder::InsertionGuard guard(builder); 444 mlir::ModuleOp module = builder.getModule(); 445 446 assert(!reductionOpName.empty()); 447 448 auto decl = 449 module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName); 450 if (decl) 451 return decl; 452 453 mlir::OpBuilder modBuilder(module.getBodyRegion()); 454 mlir::Type valTy = fir::unwrapRefType(type); 455 if (!isByRef) 456 type = valTy; 457 458 decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName, 459 type); 460 builder.createBlock(&decl.getInitializerRegion(), 461 decl.getInitializerRegion().end(), {type}, {loc}); 462 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); 463 464 mlir::Value init = 465 createReductionInitRegion(builder, loc, redId, type, isByRef); 466 builder.create<mlir::omp::YieldOp>(loc, init); 467 468 builder.createBlock(&decl.getReductionRegion(), 469 decl.getReductionRegion().end(), {type, type}, 470 {loc, loc}); 471 472 builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); 473 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); 474 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); 475 genCombiner(builder, loc, redId, type, op1, op2, isByRef); 476 477 return decl; 478 } 479 480 // TODO: By-ref vs by-val reductions are currently toggled for the whole 481 // operation (possibly effecting multiple reduction variables). 482 // This could cause a problem with openmp target reductions because 483 // by-ref trivial types may not be supported. 484 bool ReductionProcessor::doReductionByRef( 485 const llvm::SmallVectorImpl<mlir::Value> &reductionVars) { 486 if (reductionVars.empty()) 487 return false; 488 if (forceByrefReduction) 489 return true; 490 491 for (mlir::Value reductionVar : reductionVars) { 492 if (auto declare = 493 mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) 494 reductionVar = declare.getMemref(); 495 496 if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) 497 return true; 498 } 499 return false; 500 } 501 502 void ReductionProcessor::addDeclareReduction( 503 mlir::Location currentLocation, 504 Fortran::lower::AbstractConverter &converter, 505 const omp::clause::Reduction &reduction, 506 llvm::SmallVectorImpl<mlir::Value> &reductionVars, 507 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, 508 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> 509 *reductionSymbols) { 510 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 511 mlir::omp::DeclareReductionOp decl; 512 const auto &redOperatorList{ 513 std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)}; 514 assert(redOperatorList.size() == 1 && "Expecting single operator"); 515 const auto &redOperator = redOperatorList.front(); 516 const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; 517 518 if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) { 519 if (const auto *reductionIntrinsic = 520 std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) { 521 if (!ReductionProcessor::supportedIntrinsicProcReduction( 522 *reductionIntrinsic)) { 523 return; 524 } 525 } else { 526 return; 527 } 528 } 529 530 // initial pass to collect all reduction vars so we can figure out if this 531 // should happen byref 532 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 533 for (const Object &object : objectList) { 534 const Fortran::semantics::Symbol *symbol = object.id(); 535 if (reductionSymbols) 536 reductionSymbols->push_back(symbol); 537 mlir::Value symVal = converter.getSymbolAddress(*symbol); 538 mlir::Type eleType; 539 auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType()); 540 if (refType) 541 eleType = refType.getEleTy(); 542 else 543 eleType = symVal.getType(); 544 545 // all arrays must be boxed so that we have convenient access to all the 546 // information needed to iterate over the array 547 if (mlir::isa<fir::SequenceType>(eleType)) { 548 // For Host associated symbols, use `SymbolBox` instead 549 Fortran::lower::SymbolBox symBox = 550 converter.lookupOneLevelUpSymbol(*symbol); 551 hlfir::Entity entity{symBox.getAddr()}; 552 entity = genVariableBox(currentLocation, builder, entity); 553 mlir::Value box = entity.getBase(); 554 555 // Always pass the box by reference so that the OpenMP dialect 556 // verifiers don't need to know anything about fir.box 557 auto alloca = 558 builder.create<fir::AllocaOp>(currentLocation, box.getType()); 559 builder.create<fir::StoreOp>(currentLocation, box, alloca); 560 561 symVal = alloca; 562 } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) { 563 // boxed arrays are passed as values not by reference. Unfortunately, 564 // we can't pass a box by value to omp.redution_declare, so turn it 565 // into a reference 566 567 auto alloca = 568 builder.create<fir::AllocaOp>(currentLocation, symVal.getType()); 569 builder.create<fir::StoreOp>(currentLocation, symVal, alloca); 570 symVal = alloca; 571 } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) { 572 symVal = declOp.getBase(); 573 } 574 575 // this isn't the same as the by-val and by-ref passing later in the 576 // pipeline. Both styles assume that the variable is a reference at 577 // this point 578 assert(mlir::isa<fir::ReferenceType>(symVal.getType()) && 579 "reduction input var is a reference"); 580 581 reductionVars.push_back(symVal); 582 } 583 const bool isByRef = doReductionByRef(reductionVars); 584 585 if (const auto &redDefinedOp = 586 std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { 587 const auto &intrinsicOp{ 588 std::get<omp::clause::DefinedOperator::IntrinsicOperator>( 589 redDefinedOp->u)}; 590 ReductionIdentifier redId = getReductionType(intrinsicOp); 591 switch (redId) { 592 case ReductionIdentifier::ADD: 593 case ReductionIdentifier::MULTIPLY: 594 case ReductionIdentifier::AND: 595 case ReductionIdentifier::EQV: 596 case ReductionIdentifier::OR: 597 case ReductionIdentifier::NEQV: 598 break; 599 default: 600 TODO(currentLocation, 601 "Reduction of some intrinsic operators is not supported"); 602 break; 603 } 604 605 for (mlir::Value symVal : reductionVars) { 606 auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 607 const auto &kindMap = firOpBuilder.getKindMap(); 608 if (redType.getEleTy().isa<fir::LogicalType>()) 609 decl = createDeclareReduction(firOpBuilder, 610 getReductionName(intrinsicOp, kindMap, 611 firOpBuilder.getI1Type(), 612 isByRef), 613 redId, redType, currentLocation, isByRef); 614 else 615 decl = createDeclareReduction( 616 firOpBuilder, 617 getReductionName(intrinsicOp, kindMap, redType, isByRef), redId, 618 redType, currentLocation, isByRef); 619 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 620 firOpBuilder.getContext(), decl.getSymName())); 621 } 622 } else if (const auto *reductionIntrinsic = 623 std::get_if<omp::clause::ProcedureDesignator>( 624 &redOperator.u)) { 625 if (ReductionProcessor::supportedIntrinsicProcReduction( 626 *reductionIntrinsic)) { 627 ReductionProcessor::ReductionIdentifier redId = 628 ReductionProcessor::getReductionType(*reductionIntrinsic); 629 for (const Object &object : objectList) { 630 const Fortran::semantics::Symbol *symbol = object.id(); 631 mlir::Value symVal = converter.getSymbolAddress(*symbol); 632 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) 633 symVal = declOp.getBase(); 634 auto redType = symVal.getType().cast<fir::ReferenceType>(); 635 if (!redType.getEleTy().isIntOrIndexOrFloat()) 636 TODO(currentLocation, "User Defined Reduction on non-trivial type"); 637 decl = createDeclareReduction( 638 firOpBuilder, 639 getReductionName(getRealName(*reductionIntrinsic).ToString(), 640 firOpBuilder.getKindMap(), redType, isByRef), 641 redId, redType, currentLocation, isByRef); 642 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 643 firOpBuilder.getContext(), decl.getSymName())); 644 } 645 } 646 } 647 } 648 649 const Fortran::semantics::SourceName 650 ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) { 651 return symbol->GetUltimate().name(); 652 } 653 654 const Fortran::semantics::SourceName 655 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { 656 return getRealName(pd.v.id()); 657 } 658 659 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, 660 mlir::Location loc) { 661 switch (redId) { 662 case ReductionIdentifier::ADD: 663 case ReductionIdentifier::OR: 664 case ReductionIdentifier::NEQV: 665 return 0; 666 case ReductionIdentifier::MULTIPLY: 667 case ReductionIdentifier::AND: 668 case ReductionIdentifier::EQV: 669 return 1; 670 default: 671 TODO(loc, "Reduction of some intrinsic operators is not supported"); 672 } 673 } 674 675 } // namespace omp 676 } // namespace lower 677 } // namespace Fortran 678