//===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "ReductionProcessor.h"

#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"

// Debug/testing knob: when set, doReductionByRef() unconditionally answers
// "by reference" for every reduction variable.
static llvm::cl::opt<bool> forceByrefReduction(
    "force-byref-reduction",
    llvm::cl::desc("Pass all reduction arguments by reference"),
    llvm::cl::Hidden);

namespace Fortran {
namespace lower {
namespace omp {

/// Map a supported reduction intrinsic procedure name (max, min, iand, ior,
/// ieor) to its ReductionIdentifier. Asserts on any other name, so callers
/// must filter with supportedIntrinsicProcReduction() first.
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
    const omp::clause::ProcedureDesignator &pd) {
  auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
                     getRealName(pd.v.id()).ToString())
                     .Case("max", ReductionIdentifier::MAX)
                     .Case("min", ReductionIdentifier::MIN)
                     .Case("iand", ReductionIdentifier::IAND)
                     .Case("ior", ReductionIdentifier::IOR)
                     .Case("ieor", ReductionIdentifier::IEOR)
                     .Default(std::nullopt);
  assert(redType && "Invalid Reduction");
  return *redType;
}

/// Map a reduction defined-operator (+, -, *, .and., .or., .eqv., .neqv.) to
/// its ReductionIdentifier. Any other intrinsic operator is unreachable here.
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
    omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
  switch (intrinsicOp) {
  case omp::clause::DefinedOperator::IntrinsicOperator::Add:
    return ReductionIdentifier::ADD;
  case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
    return ReductionIdentifier::SUBTRACT;
  case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
    return ReductionIdentifier::MULTIPLY;
  case omp::clause::DefinedOperator::IntrinsicOperator::AND:
    return ReductionIdentifier::AND;
  case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
    return ReductionIdentifier::EQV;
  case omp::clause::DefinedOperator::IntrinsicOperator::OR:
    return ReductionIdentifier::OR;
  case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
    return ReductionIdentifier::NEQV;
  default:
    llvm_unreachable("unexpected intrinsic operator in reduction");
  }
}

/// Return true if the procedure designator names an INTRINSIC symbol that
/// this lowering knows how to reduce (max, min, iand, ior, ieor).
bool ReductionProcessor::supportedIntrinsicProcReduction(
    const omp::clause::ProcedureDesignator &pd) {
  Fortran::semantics::Symbol *sym = pd.v.id();
  // Only intrinsic procedures are candidates; user procedures are rejected.
  if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
    return false;
  auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
                     .Case("max", true)
                     .Case("min", true)
                     .Case("iand", true)
                     .Case("ior", true)
                     .Case("ieor", true)
                     .Default(false);
  return redType;
}

/// Build a unique symbol name for a reduction declaration from the operation
/// name and the (unwrapped) element type, e.g. "add_reduction_i_32".
std::string
ReductionProcessor::getReductionName(llvm::StringRef name,
                                     const fir::KindMapping &kindMap,
                                     mlir::Type ty, bool isByRef) {
  ty = fir::unwrapRefType(ty);

  // extra string to distinguish reduction functions for variables passed by
  // reference
  llvm::StringRef byrefAddition{""};
  if (isByRef)
    byrefAddition = "_byref";

  return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
}

/// Overload mapping a defined-operator to its spelled-out name before
/// delegating to the string-based getReductionName above. Note: the logical
/// operators (.and./.eqv./.or./.neqv.) return a fixed name with no type or
/// byref suffix.
std::string ReductionProcessor::getReductionName(
    omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
    const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
  std::string reductionName;

  switch (intrinsicOp) {
  case omp::clause::DefinedOperator::IntrinsicOperator::Add:
    reductionName = "add_reduction";
    break;
  case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
    reductionName = "multiply_reduction";
    break;
  case omp::clause::DefinedOperator::IntrinsicOperator::AND:
    return "and_reduction";
  case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
    return "eqv_reduction";
  case omp::clause::DefinedOperator::IntrinsicOperator::OR:
    return "or_reduction";
  case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
    return "neqv_reduction";
  default:
    reductionName = "other_reduction";
    break;
  }

  return getReductionName(reductionName, kindMap, ty, isByRef);
}

/// Create the identity value for the given reduction over the given scalar
/// type, i.e. the value each thread's private copy is initialized to:
/// MAX -> most negative value, MIN -> largest value, IOR/IEOR -> 0,
/// IAND -> all bits set, and the arithmetic/logical operators use
/// getOperationIdentity() (0 or 1).
mlir::Value
ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
                                          ReductionIdentifier redId,
                                          fir::FirOpBuilder &builder) {
  type = fir::unwrapRefType(type);
  if (!fir::isa_integer(type) && !fir::isa_real(type) &&
      !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
    TODO(loc, "Reduction of some types is not supported");
  switch (redId) {
  case ReductionIdentifier::MAX: {
    // Identity of max is the most negative representable value.
    if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
      const llvm::fltSemantics &sem = ty.getFloatSemantics();
      return builder.createRealConstant(
          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
    }
    unsigned bits = type.getIntOrFloatBitWidth();
    int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
    return builder.createIntegerConstant(loc, type, minInt);
  }
  case ReductionIdentifier::MIN: {
    // Identity of min is the largest representable value.
    if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
      const llvm::fltSemantics &sem = ty.getFloatSemantics();
      return builder.createRealConstant(
          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
    }
    unsigned bits = type.getIntOrFloatBitWidth();
    int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
    return builder.createIntegerConstant(loc, type, maxInt);
  }
  case ReductionIdentifier::IOR: {
    // Identity of bitwise-or is zero.
    unsigned bits = type.getIntOrFloatBitWidth();
    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
    return builder.createIntegerConstant(loc, type, zeroInt);
  }
  case ReductionIdentifier::IEOR: {
    // Identity of bitwise-xor is zero.
    unsigned bits = type.getIntOrFloatBitWidth();
    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
    return builder.createIntegerConstant(loc, type, zeroInt);
  }
  case ReductionIdentifier::IAND: {
    // Identity of bitwise-and is all ones.
    unsigned bits = type.getIntOrFloatBitWidth();
    int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
    return builder.createIntegerConstant(loc, type, allOnInt);
  }
  case ReductionIdentifier::ADD:
  case ReductionIdentifier::MULTIPLY:
  case ReductionIdentifier::AND:
  case ReductionIdentifier::OR:
  case ReductionIdentifier::EQV:
  case ReductionIdentifier::NEQV:
    // Complex identity is (identity, 0.0) built from the real identity.
    if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
      mlir::Type realTy =
          Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
      mlir::Value initRe = builder.createRealConstant(
          loc, realTy, getOperationIdentity(redId, loc));
      mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);

      return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
                                                               initIm);
    }
    if (mlir::isa<mlir::FloatType>(type))
      return builder.create<mlir::arith::ConstantOp>(
          loc, type,
          builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));

    // Logical identities are built as i1 then converted to fir.logical.
    if (mlir::isa<fir::LogicalType>(type)) {
      mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
          loc, builder.getI1Type(),
          builder.getIntegerAttr(builder.getI1Type(),
                                 getOperationIdentity(redId, loc)));
      return builder.createConvert(loc, type, intConst);
    }

    return builder.create<mlir::arith::ConstantOp>(
        loc, type,
        builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
  case ReductionIdentifier::ID:
  case ReductionIdentifier::USER_DEF_OP:
  case ReductionIdentifier::SUBTRACT:
    TODO(loc, "Reduction of some identifier types is not supported");
  }
  llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
}

/// Emit the scalar combiner operation `op1 <redId> op2` for a single element
/// of the given (unwrapped) type and return its result. Logical operators are
/// computed in i1 and converted back to the logical type.
mlir::Value ReductionProcessor::createScalarCombiner(
    fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
    mlir::Type type, mlir::Value op1, mlir::Value op2) {
  mlir::Value reductionOp;
  type = fir::unwrapRefType(type);
  switch (redId) {
  case ReductionIdentifier::MAX:
    reductionOp =
        getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>(
            builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::MIN:
    reductionOp =
        getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>(
            builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::IOR:
    assert((type.isIntOrIndex()) && "only integer is expected");
    reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
    break;
  case ReductionIdentifier::IEOR:
    assert((type.isIntOrIndex()) && "only integer is expected");
    reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
    break;
  case ReductionIdentifier::IAND:
    assert((type.isIntOrIndex()) && "only integer is expected");
    reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
    break;
  case ReductionIdentifier::ADD:
    reductionOp =
        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
                              fir::AddcOp>(builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::MULTIPLY:
    reductionOp =
        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
                              fir::MulcOp>(builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::AND: {
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, andiOp);
    break;
  }
  case ReductionIdentifier::OR: {
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, oriOp);
    break;
  }
  case ReductionIdentifier::EQV: {
    // .eqv. is equality of the two i1 operands.
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, cmpiOp);
    break;
  }
  case ReductionIdentifier::NEQV: {
    // .neqv. is inequality of the two i1 operands.
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, cmpiOp);
    break;
  }
  default:
    TODO(loc, "Reduction of some intrinsic operators is not supported");
  }

  return reductionOp;
}

/// Generate a fir::ShapeShift op describing the provided boxed array.
static fir::ShapeShiftOp getShapeShift(fir::FirOpBuilder &builder,
                                       mlir::Location loc, mlir::Value box) {
  fir::SequenceType sequenceType = mlir::cast<fir::SequenceType>(
      hlfir::getFortranElementOrSequenceType(box.getType()));
  const unsigned rank = sequenceType.getDimension();
  llvm::SmallVector<mlir::Value> lbAndExtents;
  lbAndExtents.reserve(rank * 2);

  mlir::Type idxTy = builder.getIndexType();
  // Read (lower bound, extent) for each dimension out of the box descriptor.
  for (unsigned i = 0; i < rank; ++i) {
    // TODO: ideally we want to hoist box reads out of the critical section.
    // We could do this by having box dimensions in block arguments like
    // OpenACC does
    mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
    auto dimInfo =
        builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, dim);
    lbAndExtents.push_back(dimInfo.getLowerBound());
    lbAndExtents.push_back(dimInfo.getExtent());
  }

  auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank);
  auto shapeShift =
      builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
  return shapeShift;
}

/// Create reduction combiner region for reduction variables which are boxed
/// arrays. `lhs` and `rhs` are references to boxes; the combined result is
/// stored through `lhs`, which is then yielded.
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                           ReductionProcessor::ReductionIdentifier redId,
                           fir::BaseBoxType boxTy, mlir::Value lhs,
                           mlir::Value rhs) {
  fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
      fir::unwrapRefType(boxTy.getEleTy()));
  fir::HeapType heapTy =
      mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
  if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy)
    TODO(loc, "Unsupported boxed type in OpenMP reduction");

  // load fir.ref<fir.box<...>>
  mlir::Value lhsAddr = lhs;
  lhs = builder.create<fir::LoadOp>(loc, lhs);
  rhs = builder.create<fir::LoadOp>(loc, rhs);

  // Boxed scalar (e.g. allocatable scalar): combine the single element
  // through the heap pointers.
  if (heapTy && !seqTy) {
    // get box contents (heap pointers)
    lhs = builder.create<fir::BoxAddrOp>(loc, lhs);
    rhs = builder.create<fir::BoxAddrOp>(loc, rhs);
    mlir::Value lhsValAddr = lhs;

    // load heap pointers
    lhs = builder.create<fir::LoadOp>(loc, lhs);
    rhs = builder.create<fir::LoadOp>(loc, rhs);

    mlir::Value result = ReductionProcessor::createScalarCombiner(
        builder, loc, redId, heapTy.getEleTy(), lhs, rhs);
    builder.create<fir::StoreOp>(loc, result, lhsValAddr);
    builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
    return;
  }

  fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, lhs);

  // Iterate over array elements, applying the equivalent scalar reduction:

  // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
  // and so no null check is needed here before indexing into the (possibly
  // allocatable) arrays.

  // A hlfir::elemental here gets inlined with a temporary so create the
  // loop nest directly.
  // This function already controls all of the code in this region so we
  // know this won't miss any opportunities for clever elemental inlining
  hlfir::LoopNest nest = hlfir::genLoopNest(
      loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
  builder.setInsertionPointToStart(nest.innerLoop.getBody());
  mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
  auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
  auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
  mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
      builder, loc, redId, refTy, lhsEle, rhsEle);
  builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);

  // Yield outside of the loop nest, after all elements are combined.
  builder.setInsertionPointAfter(nest.outerLoop);
  builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
}

// generate combiner region for reduction operations: trivial scalars are
// combined (and stored back when by-ref); boxed values are delegated to
// genBoxCombiner.
static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                        ReductionProcessor::ReductionIdentifier redId,
                        mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
                        bool isByRef) {
  ty = fir::unwrapRefType(ty);

  if (fir::isa_trivial(ty)) {
    mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
    mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);

    mlir::Value result = ReductionProcessor::createScalarCombiner(
        builder, loc, redId, ty, lhsLoaded, rhsLoaded);
    if (isByRef) {
      // By-ref: store through the lhs reference and yield the reference.
      builder.create<fir::StoreOp>(loc, result, lhs);
      builder.create<mlir::omp::YieldOp>(loc, lhs);
    } else {
      // By-val: yield the combined value directly.
      builder.create<mlir::omp::YieldOp>(loc, result);
    }
    return;
  }
  // all arrays should have been boxed
  if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
    genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
    return;
  }

  TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
}

/// Build the cleanup region of the declare-reduction op, which frees the
/// heap allocation made for the private copy (if any). Emits a fatal error
/// for types that were never heap-allocated.
static void
createReductionCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc,
                             mlir::omp::DeclareReductionOp &reductionDecl) {
  mlir::Type redTy = reductionDecl.getType();

  mlir::Region &cleanupRegion = reductionDecl.getCleanupRegion();
  assert(cleanupRegion.empty());
  mlir::Block *block =
      builder.createBlock(&cleanupRegion, cleanupRegion.end(), {redTy}, {loc});
  builder.setInsertionPointToEnd(block);

  auto typeError = [loc]() {
    fir::emitFatalError(loc,
                        "Attempt to create an omp reduction cleanup region "
                        "for a type that wasn't allocated",
                        /*genCrashDiag=*/true);
  };

  mlir::Type valTy = fir::unwrapRefType(redTy);
  if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(valTy)) {
    if (!mlir::isa<fir::HeapType>(boxTy.getEleTy())) {
      mlir::Type innerTy = fir::extractSequenceType(boxTy);
      if (!mlir::isa<fir::SequenceType>(innerTy))
        typeError();
    }

    mlir::Value arg = block->getArgument(0);
    arg = builder.loadIfRef(loc, arg);
    assert(mlir::isa<fir::BaseBoxType>(arg.getType()));

    // Deallocate box
    // The FIR type system doesn't necessarily know that this is a mutable box
    // if we allocated the thread local array on the heap to avoid looped stack
    // allocations.
    mlir::Value addr =
        hlfir::genVariableRawAddress(loc, builder, hlfir::Entity{arg});
    mlir::Value isAllocated = builder.genIsNotNullAddr(loc, addr);
    // Guard the free: the private copy might be unallocated (null address).
    fir::IfOp ifOp =
        builder.create<fir::IfOp>(loc, isAllocated, /*withElseRegion=*/false);
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());

    mlir::Value cast = builder.createConvert(
        loc, fir::HeapType::get(fir::dyn_cast_ptrEleTy(addr.getType())), addr);
    builder.create<fir::FreeMemOp>(loc, cast);

    builder.setInsertionPointAfter(ifOp);
    builder.create<mlir::omp::YieldOp>(loc);
    return;
  }

  typeError();
}

// like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes
static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
    return seqTy.getEleTy();
  if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
    auto eleTy = fir::unwrapRefType(boxTy.getEleTy());
    if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
      return seqTy.getEleTy();
    return eleTy;
  }
  return ty;
}

/// Build the initializer region of the declare-reduction op: allocate (when
/// by-ref) and initialize the private copy with the reduction identity, and
/// return the value to be yielded. Also attaches a cleanup region whenever a
/// heap allocation is created here.
static mlir::Value
createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
                          mlir::omp::DeclareReductionOp &reductionDecl,
                          const ReductionProcessor::ReductionIdentifier redId,
                          mlir::Type type, bool isByRef) {
  mlir::Type ty = fir::unwrapRefType(type);
  mlir::Value initValue = ReductionProcessor::getReductionInitValue(
      loc, unwrapSeqOrBoxedType(ty), redId, builder);

  if (fir::isa_trivial(ty)) {
    if (isByRef) {
      // Trivial by-ref: stack-allocate and store the identity value.
      mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
      builder.createStoreWithConvert(loc, initValue, alloca);
      return alloca;
    }
    // by val
    return initValue;
  }

  // check if an allocatable box is unallocated. If so, initialize the boxAlloca
  // to be unallocated e.g.
  // %box_alloca = fir.alloca !fir.box<!fir.heap<...>>
  // %addr = fir.box_addr %box
  // if (%addr == 0) {
  //   %nullbox = fir.embox %addr
  //   fir.store %nullbox to %box_alloca
  // } else {
  //   // ...
  //   fir.store %something to %box_alloca
  // }
  // omp.yield %box_alloca
  mlir::Value blockArg =
      builder.loadIfRef(loc, builder.getBlock()->getArgument(0));
  auto handleNullAllocatable = [&](mlir::Value boxAlloca) -> fir::IfOp {
    mlir::Value addr = builder.create<fir::BoxAddrOp>(loc, blockArg);
    mlir::Value isNotAllocated = builder.genIsNullAddr(loc, addr);
    fir::IfOp ifOp = builder.create<fir::IfOp>(loc, isNotAllocated,
                                               /*withElseRegion=*/true);
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    // just embox the null address and return
    mlir::Value nullBox = builder.create<fir::EmboxOp>(loc, ty, addr);
    builder.create<fir::StoreOp>(loc, nullBox, boxAlloca);
    return ifOp;
  };

  // all arrays are boxed
  if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
    assert(isByRef && "passing boxes by value is unsupported");
    bool isAllocatable = mlir::isa<fir::HeapType>(boxTy.getEleTy());
    mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
    mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
    if (fir::isa_trivial(innerTy)) {
      // boxed non-sequence value e.g. !fir.box<!fir.heap<i32>>
      if (!isAllocatable)
        TODO(loc, "Reduction of non-allocatable trivial typed box");

      fir::IfOp ifUnallocated = handleNullAllocatable(boxAlloca);

      // Allocated case: heap-allocate a private scalar, store the identity,
      // and box it.
      builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front());
      mlir::Value valAlloc = builder.create<fir::AllocMemOp>(loc, innerTy);
      builder.createStoreWithConvert(loc, initValue, valAlloc);
      mlir::Value box = builder.create<fir::EmboxOp>(loc, ty, valAlloc);
      builder.create<fir::StoreOp>(loc, box, boxAlloca);

      auto insPt = builder.saveInsertionPoint();
      createReductionCleanupRegion(builder, loc, reductionDecl);
      builder.restoreInsertionPoint(insPt);
      builder.setInsertionPointAfter(ifUnallocated);
      return boxAlloca;
    }
    innerTy = fir::extractSequenceType(boxTy);
    if (!mlir::isa<fir::SequenceType>(innerTy))
      TODO(loc, "Unsupported boxed type for reduction");

    fir::IfOp ifUnallocated{nullptr};
    if (isAllocatable) {
      ifUnallocated = handleNullAllocatable(boxAlloca);
      builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front());
    }

    // Create the private copy from the initial fir.box:
    mlir::Value loadedBox = builder.loadIfRef(loc, blockArg);
    hlfir::Entity source = hlfir::Entity{loadedBox};

    // Allocating on the heap in case the whole reduction is nested inside of a
    // loop
    // TODO: compare performance here to using allocas - this could be made to
    // work by inserting stacksave/stackrestore around the reduction in
    // openmpirbuilder
    auto [temp, needsDealloc] = createTempFromMold(loc, builder, source);
    // if needsDealloc isn't statically false, add cleanup region. Always
    // do this for allocatable boxes because they might have been re-allocated
    // in the body of the loop/parallel region

    std::optional<int64_t> cstNeedsDealloc =
        fir::getIntIfConstant(needsDealloc);
    assert(cstNeedsDealloc.has_value() &&
           "createTempFromMold decides this statically");
    if (cstNeedsDealloc.has_value() && *cstNeedsDealloc != false) {
      mlir::OpBuilder::InsertionGuard guard(builder);
      createReductionCleanupRegion(builder, loc, reductionDecl);
    } else {
      assert(!isAllocatable && "Allocatable arrays must be heap allocated");
    }

    // Put the temporary inside of a box:
    // hlfir::genVariableBox doesn't handle non-default lower bounds
    mlir::Value box;
    fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, loadedBox);
    mlir::Type boxType = loadedBox.getType();
    if (mlir::isa<fir::BaseBoxType>(temp.getType()))
      // the box created by the declare form createTempFromMold is missing lower
      // bounds info
      box = builder.create<fir::ReboxOp>(loc, boxType, temp, shapeShift,
                                         /*shift=*/mlir::Value{});
    else
      box = builder.create<fir::EmboxOp>(
          loc, boxType, temp, shapeShift,
          /*slice=*/mlir::Value{},
          /*typeParams=*/llvm::ArrayRef<mlir::Value>{});

    // Fill the private copy with the reduction identity and publish the box.
    builder.create<hlfir::AssignOp>(loc, initValue, box);
    builder.create<fir::StoreOp>(loc, box, boxAlloca);
    if (ifUnallocated)
      builder.setInsertionPointAfter(ifUnallocated);
    return boxAlloca;
  }

  TODO(loc, "createReductionInitRegion for unsupported type");
}

/// Look up or create the module-level omp.declare_reduction op named
/// `reductionOpName` for the given type, populating its initializer and
/// combiner regions. Returns an existing declaration when one with the same
/// name is already present.
mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
    fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
    bool isByRef) {
  mlir::OpBuilder::InsertionGuard guard(builder);
  mlir::ModuleOp module = builder.getModule();

  assert(!reductionOpName.empty());

  // Reuse a previously created declaration with the same symbol name.
  auto decl =
      module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
  if (decl)
    return decl;

  mlir::OpBuilder modBuilder(module.getBodyRegion());
  mlir::Type valTy = fir::unwrapRefType(type);
  if (!isByRef)
    type = valTy;

  decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
                                                          type);
  builder.createBlock(&decl.getInitializerRegion(),
                      decl.getInitializerRegion().end(), {type}, {loc});
  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());

  mlir::Value init =
      createReductionInitRegion(builder, loc, decl, redId, type, isByRef);
  builder.create<mlir::omp::YieldOp>(loc, init);

  builder.createBlock(&decl.getReductionRegion(),
                      decl.getReductionRegion().end(), {type, type},
                      {loc, loc});

  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
  genCombiner(builder, loc, redId, type, op1, op2, isByRef);

  return decl;
}

// TODO: By-ref vs by-val reductions are currently toggled for the whole
//       operation (possibly affecting multiple reduction variables).
//       This could cause a problem with openmp target reductions because
//       by-ref trivial types may not be supported.
bool ReductionProcessor::doReductionByRef(
    const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
  if (reductionVars.empty())
    return false;
  // Command-line override: -force-byref-reduction.
  if (forceByrefReduction)
    return true;

  // Use by-ref if any reduction variable has a non-trivial type (e.g. a
  // boxed array); otherwise reduce by value.
  for (mlir::Value reductionVar : reductionVars) {
    if (auto declare =
            mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
      reductionVar = declare.getMemref();

    if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
      return true;
  }
  return false;
}

/// Lower one OpenMP reduction clause: collect the reduction variables
/// (boxing arrays as needed), decide by-ref vs by-val, create (or reuse) the
/// matching omp.declare_reduction ops, and append the variables and symbol
/// references to the caller-provided output vectors.
void ReductionProcessor::addDeclareReduction(
    mlir::Location currentLocation,
    Fortran::lower::AbstractConverter &converter,
    const omp::clause::Reduction &reduction,
    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
        *reductionSymbols) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
          reduction.t))
    TODO(currentLocation, "Reduction modifiers are not supported");

  mlir::omp::DeclareReductionOp decl;
  const auto &redOperatorList{
      std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
  assert(redOperatorList.size() == 1 && "Expecting single operator");
  const auto &redOperator = redOperatorList.front();
  const auto &objectList{std::get<omp::ObjectList>(reduction.t)};

  // Bail out early on reduction operators we cannot lower: anything that is
  // neither a defined operator nor a supported intrinsic procedure.
  if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
    if (const auto *reductionIntrinsic =
            std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
      if (!ReductionProcessor::supportedIntrinsicProcReduction(
              *reductionIntrinsic)) {
        return;
      }
    } else {
      return;
    }
  }

  // initial pass to collect all reduction vars so we can figure out if this
  // should happen byref
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  for (const Object &object : objectList) {
    const Fortran::semantics::Symbol *symbol = object.id();
    if (reductionSymbols)
      reductionSymbols->push_back(symbol);
    mlir::Value symVal = converter.getSymbolAddress(*symbol);
    mlir::Type eleType;
    auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
    if (refType)
      eleType = refType.getEleTy();
    else
      eleType = symVal.getType();

    // all arrays must be boxed so that we have convenient access to all the
    // information needed to iterate over the array
    if (mlir::isa<fir::SequenceType>(eleType)) {
      // For Host associated symbols, use `SymbolBox` instead
      Fortran::lower::SymbolBox symBox =
          converter.lookupOneLevelUpSymbol(*symbol);
      hlfir::Entity entity{symBox.getAddr()};
      entity = genVariableBox(currentLocation, builder, entity);
      mlir::Value box = entity.getBase();

      // Always pass the box by reference so that the OpenMP dialect
      // verifiers don't need to know anything about fir.box
      auto alloca =
          builder.create<fir::AllocaOp>(currentLocation, box.getType());
      builder.create<fir::StoreOp>(currentLocation, box, alloca);

      symVal = alloca;
    } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
      // boxed arrays are passed as values not by reference. Unfortunately,
      // we can't pass a box by value to omp.declare_reduction, so turn it
      // into a reference

      auto alloca =
          builder.create<fir::AllocaOp>(currentLocation, symVal.getType());
      builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
      symVal = alloca;
    } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
      symVal = declOp.getBase();
    }

    // this isn't the same as the by-val and by-ref passing later in the
    // pipeline. Both styles assume that the variable is a reference at
    // this point
    assert(mlir::isa<fir::ReferenceType>(symVal.getType()) &&
           "reduction input var is a reference");

    reductionVars.push_back(symVal);
  }
  const bool isByRef = doReductionByRef(reductionVars);

  if (const auto &redDefinedOp =
          std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
    const auto &intrinsicOp{
        std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
            redDefinedOp->u)};
    ReductionIdentifier redId = getReductionType(intrinsicOp);
    // Only these defined operators are currently lowered; e.g. SUBTRACT is
    // rejected here.
    switch (redId) {
    case ReductionIdentifier::ADD:
    case ReductionIdentifier::MULTIPLY:
    case ReductionIdentifier::AND:
    case ReductionIdentifier::EQV:
    case ReductionIdentifier::OR:
    case ReductionIdentifier::NEQV:
      break;
    default:
      TODO(currentLocation,
           "Reduction of some intrinsic operators is not supported");
      break;
    }

    for (mlir::Value symVal : reductionVars) {
      auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
      const auto &kindMap = firOpBuilder.getKindMap();
      // Logical reductions are named after i1 since the combiner computes in
      // i1; all other types use the variable's own type in the name.
      if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
        decl = createDeclareReduction(firOpBuilder,
                                      getReductionName(intrinsicOp, kindMap,
                                                       firOpBuilder.getI1Type(),
                                                       isByRef),
                                      redId, redType, currentLocation, isByRef);
      else
        decl = createDeclareReduction(
            firOpBuilder,
            getReductionName(intrinsicOp, kindMap, redType, isByRef), redId,
            redType, currentLocation, isByRef);
      reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
          firOpBuilder.getContext(), decl.getSymName()));
    }
  } else if (const auto *reductionIntrinsic =
                 std::get_if<omp::clause::ProcedureDesignator>(
                     &redOperator.u)) {
    if (ReductionProcessor::supportedIntrinsicProcReduction(
            *reductionIntrinsic)) {
      ReductionProcessor::ReductionIdentifier redId =
          ReductionProcessor::getReductionType(*reductionIntrinsic);
      for (const Object &object : objectList) {
        const Fortran::semantics::Symbol *symbol = object.id();
        mlir::Value symVal = converter.getSymbolAddress(*symbol);
        if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
          symVal = declOp.getBase();
        auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
        if (!redType.getEleTy().isIntOrIndexOrFloat())
          TODO(currentLocation, "User Defined Reduction on non-trivial type");
        decl = createDeclareReduction(
            firOpBuilder,
            getReductionName(getRealName(*reductionIntrinsic).ToString(),
                             firOpBuilder.getKindMap(), redType, isByRef),
            redId, redType, currentLocation, isByRef);
        reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
            firOpBuilder.getContext(), decl.getSymName()));
      }
    }
  }
}

/// Return the ultimate (host-association resolved) source name of a symbol.
const Fortran::semantics::SourceName
ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) {
  return symbol->GetUltimate().name();
}

/// Convenience overload taking the procedure designator from the clause.
const Fortran::semantics::SourceName
ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
  return getRealName(pd.v.id());
}

/// Identity element for the arithmetic/logical reduction operators:
/// 0 for ADD/OR/NEQV, 1 for MULTIPLY/AND/EQV. Other identifiers have their
/// identities built directly in getReductionInitValue.
int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
                                             mlir::Location loc) {
  switch (redId) {
  case ReductionIdentifier::ADD:
  case ReductionIdentifier::OR:
  case ReductionIdentifier::NEQV:
    return 0;
  case ReductionIdentifier::MULTIPLY:
  case ReductionIdentifier::AND:
  case ReductionIdentifier::EQV:
    return 1;
  default:
    TODO(loc, "Reduction of some intrinsic operators is not supported");
  }
}

} // namespace omp
} // namespace lower
} // namespace Fortran