1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ReductionProcessor.h" 14 15 #include "flang/Lower/AbstractConverter.h" 16 #include "flang/Lower/SymbolMap.h" 17 #include "flang/Optimizer/Builder/HLFIRTools.h" 18 #include "flang/Optimizer/Builder/Todo.h" 19 #include "flang/Optimizer/Dialect/FIRType.h" 20 #include "flang/Optimizer/HLFIR/HLFIROps.h" 21 #include "flang/Parser/tools.h" 22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 23 #include "llvm/Support/CommandLine.h" 24 25 static llvm::cl::opt<bool> forceByrefReduction( 26 "force-byref-reduction", 27 llvm::cl::desc("Pass all reduction arguments by reference"), 28 llvm::cl::Hidden); 29 30 namespace Fortran { 31 namespace lower { 32 namespace omp { 33 34 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 35 const omp::clause::ProcedureDesignator &pd) { 36 auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( 37 getRealName(pd.v.id()).ToString()) 38 .Case("max", ReductionIdentifier::MAX) 39 .Case("min", ReductionIdentifier::MIN) 40 .Case("iand", ReductionIdentifier::IAND) 41 .Case("ior", ReductionIdentifier::IOR) 42 .Case("ieor", ReductionIdentifier::IEOR) 43 .Default(std::nullopt); 44 assert(redType && "Invalid Reduction"); 45 return *redType; 46 } 47 48 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 49 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { 50 switch (intrinsicOp) { 51 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 52 return ReductionIdentifier::ADD; 53 case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: 54 return ReductionIdentifier::SUBTRACT; 55 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 56 return ReductionIdentifier::MULTIPLY; 57 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 58 return ReductionIdentifier::AND; 59 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 60 return ReductionIdentifier::EQV; 61 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 62 return ReductionIdentifier::OR; 63 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 64 return ReductionIdentifier::NEQV; 65 default: 66 llvm_unreachable("unexpected intrinsic operator in reduction"); 67 } 68 } 69 70 bool ReductionProcessor::supportedIntrinsicProcReduction( 71 const omp::clause::ProcedureDesignator &pd) { 72 Fortran::semantics::Symbol *sym = pd.v.id(); 73 if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) 74 return false; 75 auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) 76 .Case("max", true) 77 .Case("min", true) 78 .Case("iand", true) 79 .Case("ior", true) 80 .Case("ieor", true) 81 .Default(false); 82 return redType; 83 } 84 85 std::string 86 ReductionProcessor::getReductionName(llvm::StringRef name, 87 const fir::KindMapping &kindMap, 88 mlir::Type ty, bool isByRef) { 89 ty = fir::unwrapRefType(ty); 90 91 // extra string to distinguish reduction functions for variables passed by 92 // reference 93 llvm::StringRef byrefAddition{""}; 94 if (isByRef) 95 byrefAddition = "_byref"; 96 97 return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str()); 98 } 99 100 std::string ReductionProcessor::getReductionName( 101 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, 102 const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) { 103 std::string reductionName; 104 105 switch (intrinsicOp) { 106 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 107 reductionName = "add_reduction"; 108 break; 109 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 110 reductionName = "multiply_reduction"; 111 break; 112 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 113 return "and_reduction"; 114 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 115 return "eqv_reduction"; 116 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 117 return "or_reduction"; 118 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 119 return "neqv_reduction"; 120 default: 121 reductionName = "other_reduction"; 122 break; 123 } 124 125 return getReductionName(reductionName, kindMap, ty, isByRef); 126 } 127 128 mlir::Value 129 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, 130 ReductionIdentifier redId, 131 fir::FirOpBuilder &builder) { 132 type = fir::unwrapRefType(type); 133 if (!fir::isa_integer(type) && !fir::isa_real(type) && 134 !mlir::isa<fir::LogicalType>(type)) 135 TODO(loc, "Reduction of some types is not supported"); 136 switch (redId) { 137 case ReductionIdentifier::MAX: { 138 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 139 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 140 return builder.createRealConstant( 141 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); 142 } 143 unsigned bits = type.getIntOrFloatBitWidth(); 144 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 145 return builder.createIntegerConstant(loc, type, minInt); 146 } 147 case ReductionIdentifier::MIN: { 148 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 149 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 150 return builder.createRealConstant( 151 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); 152 } 153 unsigned bits = type.getIntOrFloatBitWidth(); 154 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); 155 return builder.createIntegerConstant(loc, type, maxInt); 156 } 157 case ReductionIdentifier::IOR: { 158 unsigned bits = type.getIntOrFloatBitWidth(); 159 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 160 return builder.createIntegerConstant(loc, type, zeroInt); 161 } 162 case ReductionIdentifier::IEOR: { 163 unsigned bits = type.getIntOrFloatBitWidth(); 164 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 165 return builder.createIntegerConstant(loc, type, zeroInt); 166 } 167 case ReductionIdentifier::IAND: { 168 unsigned bits = type.getIntOrFloatBitWidth(); 169 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); 170 return builder.createIntegerConstant(loc, type, allOnInt); 171 } 172 case ReductionIdentifier::ADD: 173 case ReductionIdentifier::MULTIPLY: 174 case ReductionIdentifier::AND: 175 case ReductionIdentifier::OR: 176 case ReductionIdentifier::EQV: 177 case ReductionIdentifier::NEQV: 178 if (type.isa<mlir::FloatType>()) 179 return builder.create<mlir::arith::ConstantOp>( 180 loc, type, 181 builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); 182 183 if (type.isa<fir::LogicalType>()) { 184 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( 185 loc, builder.getI1Type(), 186 builder.getIntegerAttr(builder.getI1Type(), 187 getOperationIdentity(redId, loc))); 188 return builder.createConvert(loc, type, intConst); 189 } 190 191 return builder.create<mlir::arith::ConstantOp>( 192 loc, type, 193 builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); 194 case ReductionIdentifier::ID: 195 case ReductionIdentifier::USER_DEF_OP: 196 case ReductionIdentifier::SUBTRACT: 197 TODO(loc, "Reduction of some identifier types is not supported"); 198 } 199 llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); 200 } 201 202 mlir::Value ReductionProcessor::createScalarCombiner( 203 fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, 204 mlir::Type type, mlir::Value op1, mlir::Value op2) { 205 mlir::Value reductionOp; 206 type = fir::unwrapRefType(type); 207 switch (redId) { 208 case ReductionIdentifier::MAX: 209 reductionOp = 210 getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>( 211 builder, type, loc, op1, op2); 212 break; 213 case ReductionIdentifier::MIN: 214 reductionOp = 215 getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>( 216 builder, type, loc, op1, op2); 217 break; 218 case ReductionIdentifier::IOR: 219 assert((type.isIntOrIndex()) && "only integer is expected"); 220 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); 221 break; 222 case ReductionIdentifier::IEOR: 223 assert((type.isIntOrIndex()) && "only integer is expected"); 224 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); 225 break; 226 case ReductionIdentifier::IAND: 227 assert((type.isIntOrIndex()) && "only integer is expected"); 228 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); 229 break; 230 case ReductionIdentifier::ADD: 231 reductionOp = 232 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>( 233 builder, type, loc, op1, op2); 234 break; 235 case ReductionIdentifier::MULTIPLY: 236 reductionOp = 237 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>( 238 builder, type, loc, op1, op2); 239 break; 240 case ReductionIdentifier::AND: { 241 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 242 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 243 244 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); 245 246 reductionOp = builder.createConvert(loc, type, andiOp); 247 break; 248 } 249 case ReductionIdentifier::OR: { 250 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 251 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 252 253 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); 254 255 reductionOp = builder.createConvert(loc, type, oriOp); 256 break; 257 } 258 case ReductionIdentifier::EQV: { 259 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 260 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 261 262 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 263 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); 264 265 reductionOp = builder.createConvert(loc, type, cmpiOp); 266 break; 267 } 268 case ReductionIdentifier::NEQV: { 269 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 270 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 271 272 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 273 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); 274 275 reductionOp = builder.createConvert(loc, type, cmpiOp); 276 break; 277 } 278 default: 279 TODO(loc, "Reduction of some intrinsic operators is not supported"); 280 } 281 282 return reductionOp; 283 } 284 285 /// Create reduction combiner region for reduction variables which are boxed 286 /// arrays 287 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 288 ReductionProcessor::ReductionIdentifier redId, 289 fir::BaseBoxType boxTy, mlir::Value lhs, 290 mlir::Value rhs) { 291 fir::SequenceType seqTy = 292 mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy()); 293 // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>> 294 if (!seqTy || seqTy.hasUnknownShape()) 295 TODO(loc, "Unsupported boxed type in OpenMP reduction"); 296 297 // load fir.ref<fir.box<...>> 298 mlir::Value lhsAddr = lhs; 299 lhs = builder.create<fir::LoadOp>(loc, lhs); 300 rhs = builder.create<fir::LoadOp>(loc, rhs); 301 302 const unsigned rank = seqTy.getDimension(); 303 llvm::SmallVector<mlir::Value> extents; 304 extents.reserve(rank); 305 llvm::SmallVector<mlir::Value> lbAndExtents; 306 lbAndExtents.reserve(rank * 2); 307 308 // Get box lowerbounds and extents: 309 mlir::Type idxTy = builder.getIndexType(); 310 for (unsigned i = 0; i < rank; ++i) { 311 // TODO: ideally we want to hoist box reads out of the critical section. 312 // We could do this by having box dimensions in block arguments like 313 // OpenACC does 314 mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i); 315 auto dimInfo = 316 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim); 317 extents.push_back(dimInfo.getExtent()); 318 lbAndExtents.push_back(dimInfo.getLowerBound()); 319 lbAndExtents.push_back(dimInfo.getExtent()); 320 } 321 322 auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank); 323 auto shapeShift = 324 builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents); 325 326 // Iterate over array elements, applying the equivalent scalar reduction: 327 328 // A hlfir::elemental here gets inlined with a temporary so create the 329 // loop nest directly. 330 // This function already controls all of the code in this region so we 331 // know this won't miss any opportuinties for clever elemental inlining 332 hlfir::LoopNest nest = 333 hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true); 334 builder.setInsertionPointToStart(nest.innerLoop.getBody()); 335 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); 336 auto lhsEleAddr = builder.create<fir::ArrayCoorOp>( 337 loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{}, 338 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 339 auto rhsEleAddr = builder.create<fir::ArrayCoorOp>( 340 loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{}, 341 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 342 auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr); 343 auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr); 344 mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner( 345 builder, loc, redId, refTy, lhsEle, rhsEle); 346 builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr); 347 348 builder.setInsertionPointAfter(nest.outerLoop); 349 builder.create<mlir::omp::YieldOp>(loc, lhsAddr); 350 } 351 352 // generate combiner region for reduction operations 353 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 354 ReductionProcessor::ReductionIdentifier redId, 355 mlir::Type ty, mlir::Value lhs, mlir::Value rhs, 356 bool isByRef) { 357 ty = fir::unwrapRefType(ty); 358 359 if (fir::isa_trivial(ty)) { 360 mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs); 361 mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs); 362 363 mlir::Value result = ReductionProcessor::createScalarCombiner( 364 builder, loc, redId, ty, lhsLoaded, rhsLoaded); 365 if (isByRef) { 366 builder.create<fir::StoreOp>(loc, result, lhs); 367 builder.create<mlir::omp::YieldOp>(loc, lhs); 368 } else { 369 builder.create<mlir::omp::YieldOp>(loc, result); 370 } 371 return; 372 } 373 // all arrays should have been boxed 374 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { 375 genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs); 376 return; 377 } 378 379 TODO(loc, "OpenMP genCombiner for unsupported reduction variable type"); 380 } 381 382 static mlir::Value 383 createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc, 384 const ReductionProcessor::ReductionIdentifier redId, 385 mlir::Type type, bool isByRef) { 386 mlir::Type ty = fir::unwrapRefType(type); 387 mlir::Value initValue = ReductionProcessor::getReductionInitValue( 388 loc, fir::unwrapSeqOrBoxedSeqType(ty), redId, builder); 389 390 if (fir::isa_trivial(ty)) { 391 if (isByRef) { 392 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty); 393 builder.createStoreWithConvert(loc, initValue, alloca); 394 return alloca; 395 } 396 // by val 397 return initValue; 398 } 399 400 // all arrays are boxed 401 if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) { 402 assert(isByRef && "passing arrays by value is unsupported"); 403 // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>> 404 mlir::Type innerTy = fir::extractSequenceType(boxTy); 405 if (!mlir::isa<fir::SequenceType>(innerTy)) 406 TODO(loc, "Unsupported boxed type for reduction"); 407 // Create the private copy from the initial fir.box: 408 hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)}; 409 410 // TODO: if the whole reduction is nested inside of a loop, this alloca 411 // could lead to a stack overflow (the memory is only freed at the end of 412 // the stack frame). The reduction declare operation needs a deallocation 413 // region to undo the init region. 414 hlfir::Entity temp = createStackTempFromMold(loc, builder, source); 415 416 // Put the temporary inside of a box: 417 hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp); 418 builder.create<hlfir::AssignOp>(loc, initValue, box); 419 mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty); 420 builder.create<fir::StoreOp>(loc, box, boxAlloca); 421 return boxAlloca; 422 } 423 424 TODO(loc, "createReductionInitRegion for unsupported type"); 425 } 426 427 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction( 428 fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, 429 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, 430 bool isByRef) { 431 mlir::OpBuilder::InsertionGuard guard(builder); 432 mlir::ModuleOp module = builder.getModule(); 433 434 assert(!reductionOpName.empty()); 435 436 auto decl = 437 module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName); 438 if (decl) 439 return decl; 440 441 mlir::OpBuilder modBuilder(module.getBodyRegion()); 442 mlir::Type valTy = fir::unwrapRefType(type); 443 if (!isByRef) 444 type = valTy; 445 446 decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName, 447 type); 448 builder.createBlock(&decl.getInitializerRegion(), 449 decl.getInitializerRegion().end(), {type}, {loc}); 450 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); 451 452 mlir::Value init = 453 createReductionInitRegion(builder, loc, redId, type, isByRef); 454 builder.create<mlir::omp::YieldOp>(loc, init); 455 456 builder.createBlock(&decl.getReductionRegion(), 457 decl.getReductionRegion().end(), {type, type}, 458 {loc, loc}); 459 460 builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); 461 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); 462 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); 463 genCombiner(builder, loc, redId, type, op1, op2, isByRef); 464 465 return decl; 466 } 467 468 // TODO: By-ref vs by-val reductions are currently toggled for the whole 469 // operation (possibly effecting multiple reduction variables). 470 // This could cause a problem with openmp target reductions because 471 // by-ref trivial types may not be supported. 472 bool ReductionProcessor::doReductionByRef( 473 const llvm::SmallVectorImpl<mlir::Value> &reductionVars) { 474 if (reductionVars.empty()) 475 return false; 476 if (forceByrefReduction) 477 return true; 478 479 for (mlir::Value reductionVar : reductionVars) { 480 if (auto declare = 481 mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) 482 reductionVar = declare.getMemref(); 483 484 if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) 485 return true; 486 } 487 return false; 488 } 489 490 void ReductionProcessor::addDeclareReduction( 491 mlir::Location currentLocation, 492 Fortran::lower::AbstractConverter &converter, 493 const omp::clause::Reduction &reduction, 494 llvm::SmallVectorImpl<mlir::Value> &reductionVars, 495 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, 496 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> 497 *reductionSymbols) { 498 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 499 mlir::omp::DeclareReductionOp decl; 500 const auto &redOperatorList{ 501 std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)}; 502 assert(redOperatorList.size() == 1 && "Expecting single operator"); 503 const auto &redOperator = redOperatorList.front(); 504 const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; 505 506 if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) { 507 if (const auto *reductionIntrinsic = 508 std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) { 509 if (!ReductionProcessor::supportedIntrinsicProcReduction( 510 *reductionIntrinsic)) { 511 return; 512 } 513 } else { 514 return; 515 } 516 } 517 518 // initial pass to collect all reduction vars so we can figure out if this 519 // should happen byref 520 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 521 for (const Object &object : objectList) { 522 const Fortran::semantics::Symbol *symbol = object.id(); 523 if (reductionSymbols) 524 reductionSymbols->push_back(symbol); 525 mlir::Value symVal = converter.getSymbolAddress(*symbol); 526 auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 527 528 // all arrays must be boxed so that we have convenient access to all the 529 // information needed to iterate over the array 530 if (mlir::isa<fir::SequenceType>(redType.getEleTy())) { 531 // For Host associated symbols, use `SymbolBox` instead 532 Fortran::lower::SymbolBox symBox = 533 converter.lookupOneLevelUpSymbol(*symbol); 534 hlfir::Entity entity{symBox.getAddr()}; 535 entity = genVariableBox(currentLocation, builder, entity); 536 mlir::Value box = entity.getBase(); 537 538 // Always pass the box by reference so that the OpenMP dialect 539 // verifiers don't need to know anything about fir.box 540 auto alloca = 541 builder.create<fir::AllocaOp>(currentLocation, box.getType()); 542 builder.create<fir::StoreOp>(currentLocation, box, alloca); 543 544 symVal = alloca; 545 redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 546 } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) { 547 symVal = declOp.getBase(); 548 } 549 550 reductionVars.push_back(symVal); 551 } 552 const bool isByRef = doReductionByRef(reductionVars); 553 554 if (const auto &redDefinedOp = 555 std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { 556 const auto &intrinsicOp{ 557 std::get<omp::clause::DefinedOperator::IntrinsicOperator>( 558 redDefinedOp->u)}; 559 ReductionIdentifier redId = getReductionType(intrinsicOp); 560 switch (redId) { 561 case ReductionIdentifier::ADD: 562 case ReductionIdentifier::MULTIPLY: 563 case ReductionIdentifier::AND: 564 case ReductionIdentifier::EQV: 565 case ReductionIdentifier::OR: 566 case ReductionIdentifier::NEQV: 567 break; 568 default: 569 TODO(currentLocation, 570 "Reduction of some intrinsic operators is not supported"); 571 break; 572 } 573 574 for (mlir::Value symVal : reductionVars) { 575 auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 576 const auto &kindMap = firOpBuilder.getKindMap(); 577 if (redType.getEleTy().isa<fir::LogicalType>()) 578 decl = createDeclareReduction(firOpBuilder, 579 getReductionName(intrinsicOp, kindMap, 580 firOpBuilder.getI1Type(), 581 isByRef), 582 redId, redType, currentLocation, isByRef); 583 else 584 decl = createDeclareReduction( 585 firOpBuilder, 586 getReductionName(intrinsicOp, kindMap, redType, isByRef), redId, 587 redType, currentLocation, isByRef); 588 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 589 firOpBuilder.getContext(), decl.getSymName())); 590 } 591 } else if (const auto *reductionIntrinsic = 592 std::get_if<omp::clause::ProcedureDesignator>( 593 &redOperator.u)) { 594 if (ReductionProcessor::supportedIntrinsicProcReduction( 595 *reductionIntrinsic)) { 596 ReductionProcessor::ReductionIdentifier redId = 597 ReductionProcessor::getReductionType(*reductionIntrinsic); 598 for (const Object &object : objectList) { 599 const Fortran::semantics::Symbol *symbol = object.id(); 600 mlir::Value symVal = converter.getSymbolAddress(*symbol); 601 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) 602 symVal = declOp.getBase(); 603 auto redType = symVal.getType().cast<fir::ReferenceType>(); 604 if (!redType.getEleTy().isIntOrIndexOrFloat()) 605 TODO(currentLocation, "User Defined Reduction on non-trivial type"); 606 decl = createDeclareReduction( 607 firOpBuilder, 608 getReductionName(getRealName(*reductionIntrinsic).ToString(), 609 firOpBuilder.getKindMap(), redType, isByRef), 610 redId, redType, currentLocation, isByRef); 611 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 612 firOpBuilder.getContext(), decl.getSymName())); 613 } 614 } 615 } 616 } 617 618 const Fortran::semantics::SourceName 619 ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) { 620 return symbol->GetUltimate().name(); 621 } 622 623 const Fortran::semantics::SourceName 624 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { 625 return getRealName(pd.v.id()); 626 } 627 628 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, 629 mlir::Location loc) { 630 switch (redId) { 631 case ReductionIdentifier::ADD: 632 case ReductionIdentifier::OR: 633 case ReductionIdentifier::NEQV: 634 return 0; 635 case ReductionIdentifier::MULTIPLY: 636 case ReductionIdentifier::AND: 637 case ReductionIdentifier::EQV: 638 return 1; 639 default: 640 TODO(loc, "Reduction of some intrinsic operators is not supported"); 641 } 642 } 643 644 } // namespace omp 645 } // namespace lower 646 } // namespace Fortran 647