1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ReductionProcessor.h" 14 15 #include "flang/Lower/AbstractConverter.h" 16 #include "flang/Optimizer/Builder/HLFIRTools.h" 17 #include "flang/Optimizer/Builder/Todo.h" 18 #include "flang/Optimizer/Dialect/FIRType.h" 19 #include "flang/Optimizer/HLFIR/HLFIROps.h" 20 #include "flang/Parser/tools.h" 21 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 22 #include "llvm/Support/CommandLine.h" 23 24 static llvm::cl::opt<bool> forceByrefReduction( 25 "force-byref-reduction", 26 llvm::cl::desc("Pass all reduction arguments by reference"), 27 llvm::cl::Hidden); 28 29 namespace Fortran { 30 namespace lower { 31 namespace omp { 32 33 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 34 const omp::clause::ProcedureDesignator &pd) { 35 auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( 36 getRealName(pd.v.id()).ToString()) 37 .Case("max", ReductionIdentifier::MAX) 38 .Case("min", ReductionIdentifier::MIN) 39 .Case("iand", ReductionIdentifier::IAND) 40 .Case("ior", ReductionIdentifier::IOR) 41 .Case("ieor", ReductionIdentifier::IEOR) 42 .Default(std::nullopt); 43 assert(redType && "Invalid Reduction"); 44 return *redType; 45 } 46 47 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 48 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { 49 switch (intrinsicOp) { 50 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 51 return ReductionIdentifier::ADD; 52 case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: 53 return ReductionIdentifier::SUBTRACT; 54 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 55 return ReductionIdentifier::MULTIPLY; 56 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 57 return ReductionIdentifier::AND; 58 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 59 return ReductionIdentifier::EQV; 60 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 61 return ReductionIdentifier::OR; 62 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 63 return ReductionIdentifier::NEQV; 64 default: 65 llvm_unreachable("unexpected intrinsic operator in reduction"); 66 } 67 } 68 69 bool ReductionProcessor::supportedIntrinsicProcReduction( 70 const omp::clause::ProcedureDesignator &pd) { 71 Fortran::semantics::Symbol *sym = pd.v.id(); 72 if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) 73 return false; 74 auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) 75 .Case("max", true) 76 .Case("min", true) 77 .Case("iand", true) 78 .Case("ior", true) 79 .Case("ieor", true) 80 .Default(false); 81 return redType; 82 } 83 84 std::string 85 ReductionProcessor::getReductionName(llvm::StringRef name, 86 const fir::KindMapping &kindMap, 87 mlir::Type ty, bool isByRef) { 88 ty = fir::unwrapRefType(ty); 89 90 // extra string to distinguish reduction functions for variables passed by 91 // reference 92 llvm::StringRef byrefAddition{""}; 93 if (isByRef) 94 byrefAddition = "_byref"; 95 96 return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str()); 97 } 98 99 std::string ReductionProcessor::getReductionName( 100 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, 101 const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) { 102 std::string reductionName; 103 104 switch (intrinsicOp) { 105 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 106 reductionName = "add_reduction"; 107 break; 108 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 109 reductionName = "multiply_reduction"; 110 break; 111 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 112 return "and_reduction"; 113 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 114 return "eqv_reduction"; 115 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 116 return "or_reduction"; 117 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 118 return "neqv_reduction"; 119 default: 120 reductionName = "other_reduction"; 121 break; 122 } 123 124 return getReductionName(reductionName, kindMap, ty, isByRef); 125 } 126 127 mlir::Value 128 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, 129 ReductionIdentifier redId, 130 fir::FirOpBuilder &builder) { 131 type = fir::unwrapRefType(type); 132 if (!fir::isa_integer(type) && !fir::isa_real(type) && 133 !mlir::isa<fir::LogicalType>(type)) 134 TODO(loc, "Reduction of some types is not supported"); 135 switch (redId) { 136 case ReductionIdentifier::MAX: { 137 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 138 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 139 return builder.createRealConstant( 140 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); 141 } 142 unsigned bits = type.getIntOrFloatBitWidth(); 143 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 144 return builder.createIntegerConstant(loc, type, minInt); 145 } 146 case ReductionIdentifier::MIN: { 147 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 148 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 149 return builder.createRealConstant( 150 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); 151 } 152 unsigned bits = type.getIntOrFloatBitWidth(); 153 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); 154 return builder.createIntegerConstant(loc, type, maxInt); 155 } 156 case ReductionIdentifier::IOR: { 157 unsigned bits = type.getIntOrFloatBitWidth(); 158 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 159 return builder.createIntegerConstant(loc, type, zeroInt); 160 } 161 case ReductionIdentifier::IEOR: { 162 unsigned bits = type.getIntOrFloatBitWidth(); 163 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 164 return builder.createIntegerConstant(loc, type, zeroInt); 165 } 166 case ReductionIdentifier::IAND: { 167 unsigned bits = type.getIntOrFloatBitWidth(); 168 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); 169 return builder.createIntegerConstant(loc, type, allOnInt); 170 } 171 case ReductionIdentifier::ADD: 172 case ReductionIdentifier::MULTIPLY: 173 case ReductionIdentifier::AND: 174 case ReductionIdentifier::OR: 175 case ReductionIdentifier::EQV: 176 case ReductionIdentifier::NEQV: 177 if (type.isa<mlir::FloatType>()) 178 return builder.create<mlir::arith::ConstantOp>( 179 loc, type, 180 builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); 181 182 if (type.isa<fir::LogicalType>()) { 183 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( 184 loc, builder.getI1Type(), 185 builder.getIntegerAttr(builder.getI1Type(), 186 getOperationIdentity(redId, loc))); 187 return builder.createConvert(loc, type, intConst); 188 } 189 190 return builder.create<mlir::arith::ConstantOp>( 191 loc, type, 192 builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); 193 case ReductionIdentifier::ID: 194 case ReductionIdentifier::USER_DEF_OP: 195 case ReductionIdentifier::SUBTRACT: 196 TODO(loc, "Reduction of some identifier types is not supported"); 197 } 198 llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); 199 } 200 201 mlir::Value ReductionProcessor::createScalarCombiner( 202 fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, 203 mlir::Type type, mlir::Value op1, mlir::Value op2) { 204 mlir::Value reductionOp; 205 type = fir::unwrapRefType(type); 206 switch (redId) { 207 case ReductionIdentifier::MAX: 208 reductionOp = 209 getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>( 210 builder, type, loc, op1, op2); 211 break; 212 case ReductionIdentifier::MIN: 213 reductionOp = 214 getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>( 215 builder, type, loc, op1, op2); 216 break; 217 case ReductionIdentifier::IOR: 218 assert((type.isIntOrIndex()) && "only integer is expected"); 219 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); 220 break; 221 case ReductionIdentifier::IEOR: 222 assert((type.isIntOrIndex()) && "only integer is expected"); 223 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); 224 break; 225 case ReductionIdentifier::IAND: 226 assert((type.isIntOrIndex()) && "only integer is expected"); 227 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); 228 break; 229 case ReductionIdentifier::ADD: 230 reductionOp = 231 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>( 232 builder, type, loc, op1, op2); 233 break; 234 case ReductionIdentifier::MULTIPLY: 235 reductionOp = 236 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>( 237 builder, type, loc, op1, op2); 238 break; 239 case ReductionIdentifier::AND: { 240 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 241 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 242 243 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); 244 245 reductionOp = builder.createConvert(loc, type, andiOp); 246 break; 247 } 248 case ReductionIdentifier::OR: { 249 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 250 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 251 252 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); 253 254 reductionOp = builder.createConvert(loc, type, oriOp); 255 break; 256 } 257 case ReductionIdentifier::EQV: { 258 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 259 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 260 261 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 262 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); 263 264 reductionOp = builder.createConvert(loc, type, cmpiOp); 265 break; 266 } 267 case ReductionIdentifier::NEQV: { 268 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 269 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 270 271 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 272 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); 273 274 reductionOp = builder.createConvert(loc, type, cmpiOp); 275 break; 276 } 277 default: 278 TODO(loc, "Reduction of some intrinsic operators is not supported"); 279 } 280 281 return reductionOp; 282 } 283 284 /// Create reduction combiner region for reduction variables which are boxed 285 /// arrays 286 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 287 ReductionProcessor::ReductionIdentifier redId, 288 fir::BaseBoxType boxTy, mlir::Value lhs, 289 mlir::Value rhs) { 290 fir::SequenceType seqTy = 291 mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy()); 292 // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>> 293 if (!seqTy || seqTy.hasUnknownShape()) 294 TODO(loc, "Unsupported boxed type in OpenMP reduction"); 295 296 // load fir.ref<fir.box<...>> 297 mlir::Value lhsAddr = lhs; 298 lhs = builder.create<fir::LoadOp>(loc, lhs); 299 rhs = builder.create<fir::LoadOp>(loc, rhs); 300 301 const unsigned rank = seqTy.getDimension(); 302 llvm::SmallVector<mlir::Value> extents; 303 extents.reserve(rank); 304 llvm::SmallVector<mlir::Value> lbAndExtents; 305 lbAndExtents.reserve(rank * 2); 306 307 // Get box lowerbounds and extents: 308 mlir::Type idxTy = builder.getIndexType(); 309 for (unsigned i = 0; i < rank; ++i) { 310 // TODO: ideally we want to hoist box reads out of the critical section. 311 // We could do this by having box dimensions in block arguments like 312 // OpenACC does 313 mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i); 314 auto dimInfo = 315 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim); 316 extents.push_back(dimInfo.getExtent()); 317 lbAndExtents.push_back(dimInfo.getLowerBound()); 318 lbAndExtents.push_back(dimInfo.getExtent()); 319 } 320 321 auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank); 322 auto shapeShift = 323 builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents); 324 325 // Iterate over array elements, applying the equivalent scalar reduction: 326 327 // A hlfir::elemental here gets inlined with a temporary so create the 328 // loop nest directly. 329 // This function already controls all of the code in this region so we 330 // know this won't miss any opportuinties for clever elemental inlining 331 hlfir::LoopNest nest = 332 hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true); 333 builder.setInsertionPointToStart(nest.innerLoop.getBody()); 334 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); 335 auto lhsEleAddr = builder.create<fir::ArrayCoorOp>( 336 loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{}, 337 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 338 auto rhsEleAddr = builder.create<fir::ArrayCoorOp>( 339 loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{}, 340 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 341 auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr); 342 auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr); 343 mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner( 344 builder, loc, redId, refTy, lhsEle, rhsEle); 345 builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr); 346 347 builder.setInsertionPointAfter(nest.outerLoop); 348 builder.create<mlir::omp::YieldOp>(loc, lhsAddr); 349 } 350 351 // generate combiner region for reduction operations 352 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 353 ReductionProcessor::ReductionIdentifier redId, 354 mlir::Type ty, mlir::Value lhs, mlir::Value rhs, 355 bool isByRef) { 356 ty = fir::unwrapRefType(ty); 357 358 if (fir::isa_trivial(ty)) { 359 mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs); 360 mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs); 361 362 mlir::Value result = ReductionProcessor::createScalarCombiner( 363 builder, loc, redId, ty, lhsLoaded, rhsLoaded); 364 if (isByRef) { 365 builder.create<fir::StoreOp>(loc, result, lhs); 366 builder.create<mlir::omp::YieldOp>(loc, lhs); 367 } else { 368 builder.create<mlir::omp::YieldOp>(loc, result); 369 } 370 return; 371 } 372 // all arrays should have been boxed 373 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { 374 genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs); 375 return; 376 } 377 378 TODO(loc, "OpenMP genCombiner for unsupported reduction variable type"); 379 } 380 381 static mlir::Value 382 createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc, 383 const ReductionProcessor::ReductionIdentifier redId, 384 mlir::Type type, bool isByRef) { 385 mlir::Type ty = fir::unwrapRefType(type); 386 mlir::Value initValue = ReductionProcessor::getReductionInitValue( 387 loc, fir::unwrapSeqOrBoxedSeqType(ty), redId, builder); 388 389 if (fir::isa_trivial(ty)) { 390 if (isByRef) { 391 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty); 392 builder.createStoreWithConvert(loc, initValue, alloca); 393 return alloca; 394 } 395 // by val 396 return initValue; 397 } 398 399 // all arrays are boxed 400 if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) { 401 assert(isByRef && "passing arrays by value is unsupported"); 402 // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>> 403 mlir::Type innerTy = fir::extractSequenceType(boxTy); 404 if (!mlir::isa<fir::SequenceType>(innerTy)) 405 TODO(loc, "Unsupported boxed type for reduction"); 406 // Create the private copy from the initial fir.box: 407 hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)}; 408 409 // TODO: if the whole reduction is nested inside of a loop, this alloca 410 // could lead to a stack overflow (the memory is only freed at the end of 411 // the stack frame). The reduction declare operation needs a deallocation 412 // region to undo the init region. 413 hlfir::Entity temp = createStackTempFromMold(loc, builder, source); 414 415 // Put the temporary inside of a box: 416 hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp); 417 builder.create<hlfir::AssignOp>(loc, initValue, box); 418 mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty); 419 builder.create<fir::StoreOp>(loc, box, boxAlloca); 420 return boxAlloca; 421 } 422 423 TODO(loc, "createReductionInitRegion for unsupported type"); 424 } 425 426 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction( 427 fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, 428 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, 429 bool isByRef) { 430 mlir::OpBuilder::InsertionGuard guard(builder); 431 mlir::ModuleOp module = builder.getModule(); 432 433 assert(!reductionOpName.empty()); 434 435 auto decl = 436 module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName); 437 if (decl) 438 return decl; 439 440 mlir::OpBuilder modBuilder(module.getBodyRegion()); 441 mlir::Type valTy = fir::unwrapRefType(type); 442 if (!isByRef) 443 type = valTy; 444 445 decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName, 446 type); 447 builder.createBlock(&decl.getInitializerRegion(), 448 decl.getInitializerRegion().end(), {type}, {loc}); 449 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); 450 451 mlir::Value init = 452 createReductionInitRegion(builder, loc, redId, type, isByRef); 453 builder.create<mlir::omp::YieldOp>(loc, init); 454 455 builder.createBlock(&decl.getReductionRegion(), 456 decl.getReductionRegion().end(), {type, type}, 457 {loc, loc}); 458 459 builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); 460 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); 461 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); 462 genCombiner(builder, loc, redId, type, op1, op2, isByRef); 463 464 return decl; 465 } 466 467 // TODO: By-ref vs by-val reductions are currently toggled for the whole 468 // operation (possibly effecting multiple reduction variables). 469 // This could cause a problem with openmp target reductions because 470 // by-ref trivial types may not be supported. 471 bool ReductionProcessor::doReductionByRef( 472 const llvm::SmallVectorImpl<mlir::Value> &reductionVars) { 473 if (reductionVars.empty()) 474 return false; 475 if (forceByrefReduction) 476 return true; 477 478 for (mlir::Value reductionVar : reductionVars) { 479 if (auto declare = 480 mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) 481 reductionVar = declare.getMemref(); 482 483 if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) 484 return true; 485 } 486 return false; 487 } 488 489 void ReductionProcessor::addDeclareReduction( 490 mlir::Location currentLocation, 491 Fortran::lower::AbstractConverter &converter, 492 const omp::clause::Reduction &reduction, 493 llvm::SmallVectorImpl<mlir::Value> &reductionVars, 494 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, 495 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> 496 *reductionSymbols) { 497 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 498 mlir::omp::DeclareReductionOp decl; 499 const auto &redOperator{ 500 std::get<omp::clause::ReductionOperator>(reduction.t)}; 501 const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; 502 503 if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) { 504 if (const auto *reductionIntrinsic = 505 std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) { 506 if (!ReductionProcessor::supportedIntrinsicProcReduction( 507 *reductionIntrinsic)) { 508 return; 509 } 510 } else { 511 return; 512 } 513 } 514 515 // initial pass to collect all reduction vars so we can figure out if this 516 // should happen byref 517 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 518 for (const Object &object : objectList) { 519 const Fortran::semantics::Symbol *symbol = object.id(); 520 if (reductionSymbols) 521 reductionSymbols->push_back(symbol); 522 mlir::Value symVal = converter.getSymbolAddress(*symbol); 523 auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 524 525 // all arrays must be boxed so that we have convenient access to all the 526 // information needed to iterate over the array 527 if (mlir::isa<fir::SequenceType>(redType.getEleTy())) { 528 hlfir::Entity entity{symVal}; 529 entity = genVariableBox(currentLocation, builder, entity); 530 mlir::Value box = entity.getBase(); 531 532 // Always pass the box by reference so that the OpenMP dialect 533 // verifiers don't need to know anything about fir.box 534 auto alloca = 535 builder.create<fir::AllocaOp>(currentLocation, box.getType()); 536 builder.create<fir::StoreOp>(currentLocation, box, alloca); 537 538 symVal = alloca; 539 redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 540 } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) { 541 symVal = declOp.getBase(); 542 } 543 544 reductionVars.push_back(symVal); 545 } 546 const bool isByRef = doReductionByRef(reductionVars); 547 548 if (const auto &redDefinedOp = 549 std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { 550 const auto &intrinsicOp{ 551 std::get<omp::clause::DefinedOperator::IntrinsicOperator>( 552 redDefinedOp->u)}; 553 ReductionIdentifier redId = getReductionType(intrinsicOp); 554 switch (redId) { 555 case ReductionIdentifier::ADD: 556 case ReductionIdentifier::MULTIPLY: 557 case ReductionIdentifier::AND: 558 case ReductionIdentifier::EQV: 559 case ReductionIdentifier::OR: 560 case ReductionIdentifier::NEQV: 561 break; 562 default: 563 TODO(currentLocation, 564 "Reduction of some intrinsic operators is not supported"); 565 break; 566 } 567 568 for (mlir::Value symVal : reductionVars) { 569 auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 570 const auto &kindMap = firOpBuilder.getKindMap(); 571 if (redType.getEleTy().isa<fir::LogicalType>()) 572 decl = createDeclareReduction(firOpBuilder, 573 getReductionName(intrinsicOp, kindMap, 574 firOpBuilder.getI1Type(), 575 isByRef), 576 redId, redType, currentLocation, isByRef); 577 else 578 decl = createDeclareReduction( 579 firOpBuilder, 580 getReductionName(intrinsicOp, kindMap, redType, isByRef), redId, 581 redType, currentLocation, isByRef); 582 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 583 firOpBuilder.getContext(), decl.getSymName())); 584 } 585 } else if (const auto *reductionIntrinsic = 586 std::get_if<omp::clause::ProcedureDesignator>( 587 &redOperator.u)) { 588 if (ReductionProcessor::supportedIntrinsicProcReduction( 589 *reductionIntrinsic)) { 590 ReductionProcessor::ReductionIdentifier redId = 591 ReductionProcessor::getReductionType(*reductionIntrinsic); 592 for (const Object &object : objectList) { 593 const Fortran::semantics::Symbol *symbol = object.id(); 594 mlir::Value symVal = converter.getSymbolAddress(*symbol); 595 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) 596 symVal = declOp.getBase(); 597 auto redType = symVal.getType().cast<fir::ReferenceType>(); 598 if (!redType.getEleTy().isIntOrIndexOrFloat()) 599 TODO(currentLocation, "User Defined Reduction on non-trivial type"); 600 decl = createDeclareReduction( 601 firOpBuilder, 602 getReductionName(getRealName(*reductionIntrinsic).ToString(), 603 firOpBuilder.getKindMap(), redType, isByRef), 604 redId, redType, currentLocation, isByRef); 605 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 606 firOpBuilder.getContext(), decl.getSymName())); 607 } 608 } 609 } 610 } 611 612 const Fortran::semantics::SourceName 613 ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) { 614 return symbol->GetUltimate().name(); 615 } 616 617 const Fortran::semantics::SourceName 618 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { 619 return getRealName(pd.v.id()); 620 } 621 622 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, 623 mlir::Location loc) { 624 switch (redId) { 625 case ReductionIdentifier::ADD: 626 case ReductionIdentifier::OR: 627 case ReductionIdentifier::NEQV: 628 return 0; 629 case ReductionIdentifier::MULTIPLY: 630 case ReductionIdentifier::AND: 631 case ReductionIdentifier::EQV: 632 return 1; 633 default: 634 TODO(loc, "Reduction of some intrinsic operators is not supported"); 635 } 636 } 637 638 } // namespace omp 639 } // namespace lower 640 } // namespace Fortran 641