1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ReductionProcessor.h" 14 15 #include "flang/Lower/AbstractConverter.h" 16 #include "flang/Lower/ConvertType.h" 17 #include "flang/Lower/SymbolMap.h" 18 #include "flang/Optimizer/Builder/Complex.h" 19 #include "flang/Optimizer/Builder/HLFIRTools.h" 20 #include "flang/Optimizer/Builder/Todo.h" 21 #include "flang/Optimizer/Dialect/FIRType.h" 22 #include "flang/Optimizer/HLFIR/HLFIROps.h" 23 #include "flang/Optimizer/Support/FatalError.h" 24 #include "flang/Parser/tools.h" 25 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 26 #include "llvm/Support/CommandLine.h" 27 28 static llvm::cl::opt<bool> forceByrefReduction( 29 "force-byref-reduction", 30 llvm::cl::desc("Pass all reduction arguments by reference"), 31 llvm::cl::Hidden); 32 33 namespace Fortran { 34 namespace lower { 35 namespace omp { 36 37 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 38 const omp::clause::ProcedureDesignator &pd) { 39 auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( 40 getRealName(pd.v.sym()).ToString()) 41 .Case("max", ReductionIdentifier::MAX) 42 .Case("min", ReductionIdentifier::MIN) 43 .Case("iand", ReductionIdentifier::IAND) 44 .Case("ior", ReductionIdentifier::IOR) 45 .Case("ieor", ReductionIdentifier::IEOR) 46 .Default(std::nullopt); 47 assert(redType && "Invalid Reduction"); 48 return *redType; 49 } 50 51 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 52 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { 53 switch (intrinsicOp) { 54 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 55 return ReductionIdentifier::ADD; 56 case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: 57 return ReductionIdentifier::SUBTRACT; 58 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 59 return ReductionIdentifier::MULTIPLY; 60 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 61 return ReductionIdentifier::AND; 62 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 63 return ReductionIdentifier::EQV; 64 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 65 return ReductionIdentifier::OR; 66 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 67 return ReductionIdentifier::NEQV; 68 default: 69 llvm_unreachable("unexpected intrinsic operator in reduction"); 70 } 71 } 72 73 bool ReductionProcessor::supportedIntrinsicProcReduction( 74 const omp::clause::ProcedureDesignator &pd) { 75 semantics::Symbol *sym = pd.v.sym(); 76 if (!sym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC)) 77 return false; 78 auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) 79 .Case("max", true) 80 .Case("min", true) 81 .Case("iand", true) 82 .Case("ior", true) 83 .Case("ieor", true) 84 .Default(false); 85 return redType; 86 } 87 88 std::string 89 ReductionProcessor::getReductionName(llvm::StringRef name, 90 const fir::KindMapping &kindMap, 91 mlir::Type ty, bool isByRef) { 92 ty = fir::unwrapRefType(ty); 93 94 // extra string to distinguish reduction functions for variables passed by 95 // reference 96 llvm::StringRef byrefAddition{""}; 97 if (isByRef) 98 byrefAddition = "_byref"; 99 100 return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str()); 101 } 102 103 std::string ReductionProcessor::getReductionName( 104 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, 105 const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) { 106 std::string reductionName; 107 108 switch (intrinsicOp) { 109 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 110 reductionName = "add_reduction"; 111 break; 112 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 113 reductionName = "multiply_reduction"; 114 break; 115 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 116 return "and_reduction"; 117 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 118 return "eqv_reduction"; 119 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 120 return "or_reduction"; 121 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 122 return "neqv_reduction"; 123 default: 124 reductionName = "other_reduction"; 125 break; 126 } 127 128 return getReductionName(reductionName, kindMap, ty, isByRef); 129 } 130 131 mlir::Value 132 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, 133 ReductionIdentifier redId, 134 fir::FirOpBuilder &builder) { 135 type = fir::unwrapRefType(type); 136 if (!fir::isa_integer(type) && !fir::isa_real(type) && 137 !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type)) 138 TODO(loc, "Reduction of some types is not supported"); 139 switch (redId) { 140 case ReductionIdentifier::MAX: { 141 if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) { 142 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 143 return builder.createRealConstant( 144 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); 145 } 146 unsigned bits = type.getIntOrFloatBitWidth(); 147 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 148 return builder.createIntegerConstant(loc, type, minInt); 149 } 150 case ReductionIdentifier::MIN: { 151 if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) { 152 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 153 return builder.createRealConstant( 154 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); 155 } 156 unsigned bits = type.getIntOrFloatBitWidth(); 157 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); 158 return builder.createIntegerConstant(loc, type, maxInt); 159 } 160 case ReductionIdentifier::IOR: { 161 unsigned bits = type.getIntOrFloatBitWidth(); 162 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 163 return builder.createIntegerConstant(loc, type, zeroInt); 164 } 165 case ReductionIdentifier::IEOR: { 166 unsigned bits = type.getIntOrFloatBitWidth(); 167 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 168 return builder.createIntegerConstant(loc, type, zeroInt); 169 } 170 case ReductionIdentifier::IAND: { 171 unsigned bits = type.getIntOrFloatBitWidth(); 172 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); 173 return builder.createIntegerConstant(loc, type, allOnInt); 174 } 175 case ReductionIdentifier::ADD: 176 case ReductionIdentifier::MULTIPLY: 177 case ReductionIdentifier::AND: 178 case ReductionIdentifier::OR: 179 case ReductionIdentifier::EQV: 180 case ReductionIdentifier::NEQV: 181 if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) { 182 mlir::Type realTy = cplxTy.getElementType(); 183 mlir::Value initRe = builder.createRealConstant( 184 loc, realTy, getOperationIdentity(redId, loc)); 185 mlir::Value initIm = builder.createRealConstant(loc, realTy, 0); 186 187 return fir::factory::Complex{builder, loc}.createComplex(type, initRe, 188 initIm); 189 } 190 if (mlir::isa<mlir::FloatType>(type)) 191 return builder.create<mlir::arith::ConstantOp>( 192 loc, type, 193 builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); 194 195 if (mlir::isa<fir::LogicalType>(type)) { 196 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( 197 loc, builder.getI1Type(), 198 builder.getIntegerAttr(builder.getI1Type(), 199 getOperationIdentity(redId, loc))); 200 return builder.createConvert(loc, type, intConst); 201 } 202 203 return builder.create<mlir::arith::ConstantOp>( 204 loc, type, 205 builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); 206 case ReductionIdentifier::ID: 207 case ReductionIdentifier::USER_DEF_OP: 208 case ReductionIdentifier::SUBTRACT: 209 TODO(loc, "Reduction of some identifier types is not supported"); 210 } 211 llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); 212 } 213 214 mlir::Value ReductionProcessor::createScalarCombiner( 215 fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, 216 mlir::Type type, mlir::Value op1, mlir::Value op2) { 217 mlir::Value reductionOp; 218 type = fir::unwrapRefType(type); 219 switch (redId) { 220 case ReductionIdentifier::MAX: 221 reductionOp = 222 getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>( 223 builder, type, loc, op1, op2); 224 break; 225 case ReductionIdentifier::MIN: 226 reductionOp = 227 getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>( 228 builder, type, loc, op1, op2); 229 break; 230 case ReductionIdentifier::IOR: 231 assert((type.isIntOrIndex()) && "only integer is expected"); 232 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); 233 break; 234 case ReductionIdentifier::IEOR: 235 assert((type.isIntOrIndex()) && "only integer is expected"); 236 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); 237 break; 238 case ReductionIdentifier::IAND: 239 assert((type.isIntOrIndex()) && "only integer is expected"); 240 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); 241 break; 242 case ReductionIdentifier::ADD: 243 reductionOp = 244 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp, 245 fir::AddcOp>(builder, type, loc, op1, op2); 246 break; 247 case ReductionIdentifier::MULTIPLY: 248 reductionOp = 249 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp, 250 fir::MulcOp>(builder, type, loc, op1, op2); 251 break; 252 case ReductionIdentifier::AND: { 253 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 254 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 255 256 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); 257 258 reductionOp = builder.createConvert(loc, type, andiOp); 259 break; 260 } 261 case ReductionIdentifier::OR: { 262 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 263 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 264 265 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); 266 267 reductionOp = builder.createConvert(loc, type, oriOp); 268 break; 269 } 270 case ReductionIdentifier::EQV: { 271 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 272 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 273 274 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 275 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); 276 277 reductionOp = builder.createConvert(loc, type, cmpiOp); 278 break; 279 } 280 case ReductionIdentifier::NEQV: { 281 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 282 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 283 284 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 285 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); 286 287 reductionOp = builder.createConvert(loc, type, cmpiOp); 288 break; 289 } 290 default: 291 TODO(loc, "Reduction of some intrinsic operators is not supported"); 292 } 293 294 return reductionOp; 295 } 296 297 /// Generate a fir::ShapeShift op describing the provided boxed array. 298 static fir::ShapeShiftOp getShapeShift(fir::FirOpBuilder &builder, 299 mlir::Location loc, mlir::Value box) { 300 fir::SequenceType sequenceType = mlir::cast<fir::SequenceType>( 301 hlfir::getFortranElementOrSequenceType(box.getType())); 302 const unsigned rank = sequenceType.getDimension(); 303 llvm::SmallVector<mlir::Value> lbAndExtents; 304 lbAndExtents.reserve(rank * 2); 305 306 mlir::Type idxTy = builder.getIndexType(); 307 for (unsigned i = 0; i < rank; ++i) { 308 // TODO: ideally we want to hoist box reads out of the critical section. 309 // We could do this by having box dimensions in block arguments like 310 // OpenACC does 311 mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i); 312 auto dimInfo = 313 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, dim); 314 lbAndExtents.push_back(dimInfo.getLowerBound()); 315 lbAndExtents.push_back(dimInfo.getExtent()); 316 } 317 318 auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank); 319 auto shapeShift = 320 builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents); 321 return shapeShift; 322 } 323 324 /// Create reduction combiner region for reduction variables which are boxed 325 /// arrays 326 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 327 ReductionProcessor::ReductionIdentifier redId, 328 fir::BaseBoxType boxTy, mlir::Value lhs, 329 mlir::Value rhs) { 330 fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>( 331 fir::unwrapRefType(boxTy.getEleTy())); 332 fir::HeapType heapTy = 333 mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy()); 334 fir::PointerType ptrTy = 335 mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy()); 336 if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy) 337 TODO(loc, "Unsupported boxed type in OpenMP reduction"); 338 339 // load fir.ref<fir.box<...>> 340 mlir::Value lhsAddr = lhs; 341 lhs = builder.create<fir::LoadOp>(loc, lhs); 342 rhs = builder.create<fir::LoadOp>(loc, rhs); 343 344 if ((heapTy || ptrTy) && !seqTy) { 345 // get box contents (heap pointers) 346 lhs = builder.create<fir::BoxAddrOp>(loc, lhs); 347 rhs = builder.create<fir::BoxAddrOp>(loc, rhs); 348 mlir::Value lhsValAddr = lhs; 349 350 // load heap pointers 351 lhs = builder.create<fir::LoadOp>(loc, lhs); 352 rhs = builder.create<fir::LoadOp>(loc, rhs); 353 354 mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy(); 355 356 mlir::Value result = ReductionProcessor::createScalarCombiner( 357 builder, loc, redId, eleTy, lhs, rhs); 358 builder.create<fir::StoreOp>(loc, result, lhsValAddr); 359 builder.create<mlir::omp::YieldOp>(loc, lhsAddr); 360 return; 361 } 362 363 fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, lhs); 364 365 // Iterate over array elements, applying the equivalent scalar reduction: 366 367 // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced 368 // and so no null check is needed here before indexing into the (possibly 369 // allocatable) arrays. 370 371 // A hlfir::elemental here gets inlined with a temporary so create the 372 // loop nest directly. 373 // This function already controls all of the code in this region so we 374 // know this won't miss any opportuinties for clever elemental inlining 375 hlfir::LoopNest nest = hlfir::genLoopNest( 376 loc, builder, shapeShift.getExtents(), /*isUnordered=*/true); 377 builder.setInsertionPointToStart(nest.innerLoop.getBody()); 378 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); 379 auto lhsEleAddr = builder.create<fir::ArrayCoorOp>( 380 loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{}, 381 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 382 auto rhsEleAddr = builder.create<fir::ArrayCoorOp>( 383 loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{}, 384 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); 385 auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr); 386 auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr); 387 mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner( 388 builder, loc, redId, refTy, lhsEle, rhsEle); 389 builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr); 390 391 builder.setInsertionPointAfter(nest.outerLoop); 392 builder.create<mlir::omp::YieldOp>(loc, lhsAddr); 393 } 394 395 // generate combiner region for reduction operations 396 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, 397 ReductionProcessor::ReductionIdentifier redId, 398 mlir::Type ty, mlir::Value lhs, mlir::Value rhs, 399 bool isByRef) { 400 ty = fir::unwrapRefType(ty); 401 402 if (fir::isa_trivial(ty)) { 403 mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs); 404 mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs); 405 406 mlir::Value result = ReductionProcessor::createScalarCombiner( 407 builder, loc, redId, ty, lhsLoaded, rhsLoaded); 408 if (isByRef) { 409 builder.create<fir::StoreOp>(loc, result, lhs); 410 builder.create<mlir::omp::YieldOp>(loc, lhs); 411 } else { 412 builder.create<mlir::omp::YieldOp>(loc, result); 413 } 414 return; 415 } 416 // all arrays should have been boxed 417 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { 418 genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs); 419 return; 420 } 421 422 TODO(loc, "OpenMP genCombiner for unsupported reduction variable type"); 423 } 424 425 static void 426 createReductionCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc, 427 mlir::omp::DeclareReductionOp &reductionDecl) { 428 mlir::Type redTy = reductionDecl.getType(); 429 430 mlir::Region &cleanupRegion = reductionDecl.getCleanupRegion(); 431 assert(cleanupRegion.empty()); 432 mlir::Block *block = 433 builder.createBlock(&cleanupRegion, cleanupRegion.end(), {redTy}, {loc}); 434 builder.setInsertionPointToEnd(block); 435 436 auto typeError = [loc]() { 437 fir::emitFatalError(loc, 438 "Attempt to create an omp reduction cleanup region " 439 "for a type that wasn't allocated", 440 /*genCrashDiag=*/true); 441 }; 442 443 mlir::Type valTy = fir::unwrapRefType(redTy); 444 if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(valTy)) { 445 if (!mlir::isa<fir::HeapType, fir::PointerType>(boxTy.getEleTy())) { 446 mlir::Type innerTy = fir::extractSequenceType(boxTy); 447 if (!mlir::isa<fir::SequenceType>(innerTy)) 448 typeError(); 449 } 450 451 mlir::Value arg = block->getArgument(0); 452 arg = builder.loadIfRef(loc, arg); 453 assert(mlir::isa<fir::BaseBoxType>(arg.getType())); 454 455 // Deallocate box 456 // The FIR type system doesn't nesecarrily know that this is a mutable box 457 // if we allocated the thread local array on the heap to avoid looped stack 458 // allocations. 459 mlir::Value addr = 460 hlfir::genVariableRawAddress(loc, builder, hlfir::Entity{arg}); 461 mlir::Value isAllocated = builder.genIsNotNullAddr(loc, addr); 462 fir::IfOp ifOp = 463 builder.create<fir::IfOp>(loc, isAllocated, /*withElseRegion=*/false); 464 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 465 466 mlir::Value cast = builder.createConvert( 467 loc, fir::HeapType::get(fir::dyn_cast_ptrEleTy(addr.getType())), addr); 468 builder.create<fir::FreeMemOp>(loc, cast); 469 470 builder.setInsertionPointAfter(ifOp); 471 builder.create<mlir::omp::YieldOp>(loc); 472 return; 473 } 474 475 typeError(); 476 } 477 478 // like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes 479 static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) { 480 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) 481 return seqTy.getEleTy(); 482 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { 483 auto eleTy = fir::unwrapRefType(boxTy.getEleTy()); 484 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) 485 return seqTy.getEleTy(); 486 return eleTy; 487 } 488 return ty; 489 } 490 491 static void createReductionAllocAndInitRegions( 492 fir::FirOpBuilder &builder, mlir::Location loc, 493 mlir::omp::DeclareReductionOp &reductionDecl, 494 const ReductionProcessor::ReductionIdentifier redId, mlir::Type type, 495 bool isByRef) { 496 auto yield = [&](mlir::Value ret) { 497 builder.create<mlir::omp::YieldOp>(loc, ret); 498 }; 499 500 mlir::Block *allocBlock = nullptr; 501 mlir::Block *initBlock = nullptr; 502 if (isByRef) { 503 allocBlock = 504 builder.createBlock(&reductionDecl.getAllocRegion(), 505 reductionDecl.getAllocRegion().end(), {}, {}); 506 initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(), 507 reductionDecl.getInitializerRegion().end(), 508 {type, type}, {loc, loc}); 509 } else { 510 initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(), 511 reductionDecl.getInitializerRegion().end(), 512 {type}, {loc}); 513 } 514 515 mlir::Type ty = fir::unwrapRefType(type); 516 builder.setInsertionPointToEnd(initBlock); 517 mlir::Value initValue = ReductionProcessor::getReductionInitValue( 518 loc, unwrapSeqOrBoxedType(ty), redId, builder); 519 520 if (fir::isa_trivial(ty)) { 521 if (isByRef) { 522 // alloc region 523 { 524 builder.setInsertionPointToEnd(allocBlock); 525 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty); 526 yield(alloca); 527 } 528 529 // init region 530 { 531 builder.setInsertionPointToEnd(initBlock); 532 // block arg is mapped to the alloca yielded from the alloc region 533 mlir::Value alloc = reductionDecl.getInitializerAllocArg(); 534 builder.createStoreWithConvert(loc, initValue, alloc); 535 yield(alloc); 536 } 537 return; 538 } 539 // by val 540 yield(initValue); 541 return; 542 } 543 544 // check if an allocatable box is unallocated. If so, initialize the boxAlloca 545 // to be unallocated e.g. 546 // %box_alloca = fir.alloca !fir.box<!fir.heap<...>> 547 // %addr = fir.box_addr %box 548 // if (%addr == 0) { 549 // %nullbox = fir.embox %addr 550 // fir.store %nullbox to %box_alloca 551 // } else { 552 // // ... 553 // fir.store %something to %box_alloca 554 // } 555 // omp.yield %box_alloca 556 mlir::Value moldArg = 557 builder.loadIfRef(loc, reductionDecl.getInitializerMoldArg()); 558 auto handleNullAllocatable = [&](mlir::Value boxAlloca) -> fir::IfOp { 559 mlir::Value addr = builder.create<fir::BoxAddrOp>(loc, moldArg); 560 mlir::Value isNotAllocated = builder.genIsNullAddr(loc, addr); 561 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, isNotAllocated, 562 /*withElseRegion=*/true); 563 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 564 // just embox the null address and return 565 mlir::Value nullBox = builder.create<fir::EmboxOp>(loc, ty, addr); 566 builder.create<fir::StoreOp>(loc, nullBox, boxAlloca); 567 return ifOp; 568 }; 569 570 // all arrays are boxed 571 if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) { 572 assert(isByRef && "passing boxes by value is unsupported"); 573 bool isAllocatableOrPointer = 574 mlir::isa<fir::HeapType, fir::PointerType>(boxTy.getEleTy()); 575 576 // alloc region 577 { 578 builder.setInsertionPointToEnd(allocBlock); 579 mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty); 580 yield(boxAlloca); 581 } 582 583 // init region 584 builder.setInsertionPointToEnd(initBlock); 585 mlir::Value boxAlloca = reductionDecl.getInitializerAllocArg(); 586 mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy()); 587 if (fir::isa_trivial(innerTy)) { 588 // boxed non-sequence value e.g. !fir.box<!fir.heap<i32>> 589 if (!isAllocatableOrPointer) 590 TODO(loc, "Reduction of non-allocatable trivial typed box"); 591 592 fir::IfOp ifUnallocated = handleNullAllocatable(boxAlloca); 593 594 builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front()); 595 mlir::Value valAlloc = builder.create<fir::AllocMemOp>(loc, innerTy); 596 builder.createStoreWithConvert(loc, initValue, valAlloc); 597 mlir::Value box = builder.create<fir::EmboxOp>(loc, ty, valAlloc); 598 builder.create<fir::StoreOp>(loc, box, boxAlloca); 599 600 auto insPt = builder.saveInsertionPoint(); 601 createReductionCleanupRegion(builder, loc, reductionDecl); 602 builder.restoreInsertionPoint(insPt); 603 builder.setInsertionPointAfter(ifUnallocated); 604 yield(boxAlloca); 605 return; 606 } 607 innerTy = fir::extractSequenceType(boxTy); 608 if (!mlir::isa<fir::SequenceType>(innerTy)) 609 TODO(loc, "Unsupported boxed type for reduction"); 610 611 fir::IfOp ifUnallocated{nullptr}; 612 if (isAllocatableOrPointer) { 613 ifUnallocated = handleNullAllocatable(boxAlloca); 614 builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front()); 615 } 616 617 // Create the private copy from the initial fir.box: 618 mlir::Value loadedBox = builder.loadIfRef(loc, moldArg); 619 hlfir::Entity source = hlfir::Entity{loadedBox}; 620 621 // Allocating on the heap in case the whole reduction is nested inside of a 622 // loop 623 // TODO: compare performance here to using allocas - this could be made to 624 // work by inserting stacksave/stackrestore around the reduction in 625 // openmpirbuilder 626 auto [temp, needsDealloc] = createTempFromMold(loc, builder, source); 627 // if needsDealloc isn't statically false, add cleanup region. Always 628 // do this for allocatable boxes because they might have been re-allocated 629 // in the body of the loop/parallel region 630 631 std::optional<int64_t> cstNeedsDealloc = 632 fir::getIntIfConstant(needsDealloc); 633 assert(cstNeedsDealloc.has_value() && 634 "createTempFromMold decides this statically"); 635 if (cstNeedsDealloc.has_value() && *cstNeedsDealloc != false) { 636 mlir::OpBuilder::InsertionGuard guard(builder); 637 createReductionCleanupRegion(builder, loc, reductionDecl); 638 } else { 639 assert(!isAllocatableOrPointer && 640 "Pointer-like arrays must be heap allocated"); 641 } 642 643 // Put the temporary inside of a box: 644 // hlfir::genVariableBox doesn't handle non-default lower bounds 645 mlir::Value box; 646 fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, loadedBox); 647 mlir::Type boxType = loadedBox.getType(); 648 if (mlir::isa<fir::BaseBoxType>(temp.getType())) 649 // the box created by the declare form createTempFromMold is missing lower 650 // bounds info 651 box = builder.create<fir::ReboxOp>(loc, boxType, temp, shapeShift, 652 /*shift=*/mlir::Value{}); 653 else 654 box = builder.create<fir::EmboxOp>( 655 loc, boxType, temp, shapeShift, 656 /*slice=*/mlir::Value{}, 657 /*typeParams=*/llvm::ArrayRef<mlir::Value>{}); 658 659 builder.create<hlfir::AssignOp>(loc, initValue, box); 660 builder.create<fir::StoreOp>(loc, box, boxAlloca); 661 if (ifUnallocated) 662 builder.setInsertionPointAfter(ifUnallocated); 663 yield(boxAlloca); 664 return; 665 } 666 667 TODO(loc, "createReductionInitRegion for unsupported type"); 668 } 669 670 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction( 671 fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, 672 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, 673 bool isByRef) { 674 mlir::OpBuilder::InsertionGuard guard(builder); 675 mlir::ModuleOp module = builder.getModule(); 676 677 assert(!reductionOpName.empty()); 678 679 auto decl = 680 module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName); 681 if (decl) 682 return decl; 683 684 mlir::OpBuilder modBuilder(module.getBodyRegion()); 685 mlir::Type valTy = fir::unwrapRefType(type); 686 if (!isByRef) 687 type = valTy; 688 689 decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName, 690 type); 691 createReductionAllocAndInitRegions(builder, loc, decl, redId, type, isByRef); 692 693 builder.createBlock(&decl.getReductionRegion(), 694 decl.getReductionRegion().end(), {type, type}, 695 {loc, loc}); 696 697 builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); 698 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); 699 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); 700 genCombiner(builder, loc, redId, type, op1, op2, isByRef); 701 702 return decl; 703 } 704 705 static bool doReductionByRef(mlir::Value reductionVar) { 706 if (forceByrefReduction) 707 return true; 708 709 if (auto declare = 710 mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) 711 reductionVar = declare.getMemref(); 712 713 if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) 714 return true; 715 716 return false; 717 } 718 719 void ReductionProcessor::addDeclareReduction( 720 mlir::Location currentLocation, lower::AbstractConverter &converter, 721 const omp::clause::Reduction &reduction, 722 llvm::SmallVectorImpl<mlir::Value> &reductionVars, 723 llvm::SmallVectorImpl<bool> &reduceVarByRef, 724 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, 725 llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) { 726 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 727 728 if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>( 729 reduction.t)) 730 TODO(currentLocation, "Reduction modifiers are not supported"); 731 732 mlir::omp::DeclareReductionOp decl; 733 const auto &redOperatorList{ 734 std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)}; 735 assert(redOperatorList.size() == 1 && "Expecting single operator"); 736 const auto &redOperator = redOperatorList.front(); 737 const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; 738 739 if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) { 740 if (const auto *reductionIntrinsic = 741 std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) { 742 if (!ReductionProcessor::supportedIntrinsicProcReduction( 743 *reductionIntrinsic)) { 744 return; 745 } 746 } else { 747 return; 748 } 749 } 750 751 // Reduction variable processing common to both intrinsic operators and 752 // procedure designators 753 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 754 for (const Object &object : objectList) { 755 const semantics::Symbol *symbol = object.sym(); 756 reductionSymbols.push_back(symbol); 757 mlir::Value symVal = converter.getSymbolAddress(*symbol); 758 mlir::Type eleType; 759 auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType()); 760 if (refType) 761 eleType = refType.getEleTy(); 762 else 763 eleType = symVal.getType(); 764 765 // all arrays must be boxed so that we have convenient access to all the 766 // information needed to iterate over the array 767 if (mlir::isa<fir::SequenceType>(eleType)) { 768 // For Host associated symbols, use `SymbolBox` instead 769 lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol); 770 hlfir::Entity entity{symBox.getAddr()}; 771 entity = genVariableBox(currentLocation, builder, entity); 772 mlir::Value box = entity.getBase(); 773 774 // Always pass the box by reference so that the OpenMP dialect 775 // verifiers don't need to know anything about fir.box 776 auto alloca = 777 builder.create<fir::AllocaOp>(currentLocation, box.getType()); 778 builder.create<fir::StoreOp>(currentLocation, box, alloca); 779 780 symVal = alloca; 781 } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) { 782 // boxed arrays are passed as values not by reference. Unfortunately, 783 // we can't pass a box by value to omp.redution_declare, so turn it 784 // into a reference 785 786 auto alloca = 787 builder.create<fir::AllocaOp>(currentLocation, symVal.getType()); 788 builder.create<fir::StoreOp>(currentLocation, symVal, alloca); 789 symVal = alloca; 790 } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) { 791 symVal = declOp.getBase(); 792 } 793 794 // this isn't the same as the by-val and by-ref passing later in the 795 // pipeline. Both styles assume that the variable is a reference at 796 // this point 797 assert(mlir::isa<fir::ReferenceType>(symVal.getType()) && 798 "reduction input var is a reference"); 799 800 reductionVars.push_back(symVal); 801 reduceVarByRef.push_back(doReductionByRef(symVal)); 802 } 803 804 for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) { 805 auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); 806 const auto &kindMap = firOpBuilder.getKindMap(); 807 std::string reductionName; 808 ReductionIdentifier redId; 809 mlir::Type redNameTy = redType; 810 if (mlir::isa<fir::LogicalType>(redType.getEleTy())) 811 redNameTy = builder.getI1Type(); 812 813 if (const auto &redDefinedOp = 814 std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { 815 const auto &intrinsicOp{ 816 std::get<omp::clause::DefinedOperator::IntrinsicOperator>( 817 redDefinedOp->u)}; 818 redId = getReductionType(intrinsicOp); 819 switch (redId) { 820 case ReductionIdentifier::ADD: 821 case ReductionIdentifier::MULTIPLY: 822 case ReductionIdentifier::AND: 823 case ReductionIdentifier::EQV: 824 case ReductionIdentifier::OR: 825 case ReductionIdentifier::NEQV: 826 break; 827 default: 828 TODO(currentLocation, 829 "Reduction of some intrinsic operators is not supported"); 830 break; 831 } 832 833 reductionName = 834 getReductionName(intrinsicOp, kindMap, redNameTy, isByRef); 835 } else if (const auto *reductionIntrinsic = 836 std::get_if<omp::clause::ProcedureDesignator>( 837 &redOperator.u)) { 838 if (!ReductionProcessor::supportedIntrinsicProcReduction( 839 *reductionIntrinsic)) { 840 TODO(currentLocation, "Unsupported intrinsic proc reduction"); 841 } 842 redId = getReductionType(*reductionIntrinsic); 843 reductionName = 844 getReductionName(getRealName(*reductionIntrinsic).ToString(), kindMap, 845 redNameTy, isByRef); 846 } else { 847 TODO(currentLocation, "Unexpected reduction type"); 848 } 849 850 decl = createDeclareReduction(firOpBuilder, reductionName, redId, redType, 851 currentLocation, isByRef); 852 reductionDeclSymbols.push_back( 853 mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName())); 854 } 855 } 856 857 const semantics::SourceName 858 ReductionProcessor::getRealName(const semantics::Symbol *symbol) { 859 return symbol->GetUltimate().name(); 860 } 861 862 const semantics::SourceName 863 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { 864 return getRealName(pd.v.sym()); 865 } 866 867 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, 868 mlir::Location loc) { 869 switch (redId) { 870 case ReductionIdentifier::ADD: 871 case ReductionIdentifier::OR: 872 case ReductionIdentifier::NEQV: 873 return 0; 874 case ReductionIdentifier::MULTIPLY: 875 case ReductionIdentifier::AND: 876 case ReductionIdentifier::EQV: 877 return 1; 878 default: 879 TODO(loc, "Reduction of some intrinsic operators is not supported"); 880 } 881 } 882 883 } // namespace omp 884 } // namespace lower 885 } // namespace Fortran 886