1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ReductionProcessor.h" 14 15 #include "flang/Lower/AbstractConverter.h" 16 #include "flang/Optimizer/Builder/Todo.h" 17 #include "flang/Optimizer/Dialect/FIRType.h" 18 #include "flang/Optimizer/HLFIR/HLFIROps.h" 19 #include "flang/Parser/tools.h" 20 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 21 #include "llvm/Support/CommandLine.h" 22 23 static llvm::cl::opt<bool> forceByrefReduction( 24 "force-byref-reduction", 25 llvm::cl::desc("Pass all reduction arguments by reference"), 26 llvm::cl::Hidden); 27 28 namespace Fortran { 29 namespace lower { 30 namespace omp { 31 32 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 33 const omp::clause::ProcedureDesignator &pd) { 34 auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( 35 getRealName(pd.v.id()).ToString()) 36 .Case("max", ReductionIdentifier::MAX) 37 .Case("min", ReductionIdentifier::MIN) 38 .Case("iand", ReductionIdentifier::IAND) 39 .Case("ior", ReductionIdentifier::IOR) 40 .Case("ieor", ReductionIdentifier::IEOR) 41 .Default(std::nullopt); 42 assert(redType && "Invalid Reduction"); 43 return *redType; 44 } 45 46 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 47 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { 48 switch (intrinsicOp) { 49 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 50 return ReductionIdentifier::ADD; 51 case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: 52 return ReductionIdentifier::SUBTRACT; 53 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 54 return ReductionIdentifier::MULTIPLY; 55 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 56 return ReductionIdentifier::AND; 57 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 58 return ReductionIdentifier::EQV; 59 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 60 return ReductionIdentifier::OR; 61 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 62 return ReductionIdentifier::NEQV; 63 default: 64 llvm_unreachable("unexpected intrinsic operator in reduction"); 65 } 66 } 67 68 bool ReductionProcessor::supportedIntrinsicProcReduction( 69 const omp::clause::ProcedureDesignator &pd) { 70 Fortran::semantics::Symbol *sym = pd.v.id(); 71 if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) 72 return false; 73 auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) 74 .Case("max", true) 75 .Case("min", true) 76 .Case("iand", true) 77 .Case("ior", true) 78 .Case("ieor", true) 79 .Default(false); 80 return redType; 81 } 82 83 std::string ReductionProcessor::getReductionName(llvm::StringRef name, 84 mlir::Type ty, bool isByRef) { 85 ty = fir::unwrapRefType(ty); 86 87 // extra string to distinguish reduction functions for variables passed by 88 // reference 89 llvm::StringRef byrefAddition{""}; 90 if (isByRef) 91 byrefAddition = "_byref"; 92 93 return (llvm::Twine(name) + 94 (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + 95 llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition) 96 .str(); 97 } 98 99 std::string ReductionProcessor::getReductionName( 100 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, mlir::Type ty, 101 bool isByRef) { 102 std::string reductionName; 103 104 switch (intrinsicOp) { 105 case omp::clause::DefinedOperator::IntrinsicOperator::Add: 106 reductionName = "add_reduction"; 107 break; 108 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: 109 reductionName = "multiply_reduction"; 110 break; 111 case omp::clause::DefinedOperator::IntrinsicOperator::AND: 112 return "and_reduction"; 113 case omp::clause::DefinedOperator::IntrinsicOperator::EQV: 114 return "eqv_reduction"; 115 case omp::clause::DefinedOperator::IntrinsicOperator::OR: 116 return "or_reduction"; 117 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: 118 return "neqv_reduction"; 119 default: 120 reductionName = "other_reduction"; 121 break; 122 } 123 124 return getReductionName(reductionName, ty, isByRef); 125 } 126 127 mlir::Value 128 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, 129 ReductionIdentifier redId, 130 fir::FirOpBuilder &builder) { 131 type = fir::unwrapRefType(type); 132 assert((fir::isa_integer(type) || fir::isa_real(type) || 133 type.isa<fir::LogicalType>()) && 134 "only integer, logical and real types are currently supported"); 135 switch (redId) { 136 case ReductionIdentifier::MAX: { 137 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 138 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 139 return builder.createRealConstant( 140 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); 141 } 142 unsigned bits = type.getIntOrFloatBitWidth(); 143 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 144 return builder.createIntegerConstant(loc, type, minInt); 145 } 146 case ReductionIdentifier::MIN: { 147 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 148 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 149 return builder.createRealConstant( 150 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); 151 } 152 unsigned bits = type.getIntOrFloatBitWidth(); 153 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); 154 return builder.createIntegerConstant(loc, type, maxInt); 155 } 156 case ReductionIdentifier::IOR: { 157 unsigned bits = type.getIntOrFloatBitWidth(); 158 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 159 return builder.createIntegerConstant(loc, type, zeroInt); 160 } 161 case ReductionIdentifier::IEOR: { 162 unsigned bits = type.getIntOrFloatBitWidth(); 163 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 164 return builder.createIntegerConstant(loc, type, zeroInt); 165 } 166 case ReductionIdentifier::IAND: { 167 unsigned bits = type.getIntOrFloatBitWidth(); 168 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); 169 return builder.createIntegerConstant(loc, type, allOnInt); 170 } 171 case ReductionIdentifier::ADD: 172 case ReductionIdentifier::MULTIPLY: 173 case ReductionIdentifier::AND: 174 case ReductionIdentifier::OR: 175 case ReductionIdentifier::EQV: 176 case ReductionIdentifier::NEQV: 177 if (type.isa<mlir::FloatType>()) 178 return builder.create<mlir::arith::ConstantOp>( 179 loc, type, 180 builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); 181 182 if (type.isa<fir::LogicalType>()) { 183 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( 184 loc, builder.getI1Type(), 185 builder.getIntegerAttr(builder.getI1Type(), 186 getOperationIdentity(redId, loc))); 187 return builder.createConvert(loc, type, intConst); 188 } 189 190 return builder.create<mlir::arith::ConstantOp>( 191 loc, type, 192 builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); 193 case ReductionIdentifier::ID: 194 case ReductionIdentifier::USER_DEF_OP: 195 case ReductionIdentifier::SUBTRACT: 196 TODO(loc, "Reduction of some identifier types is not supported"); 197 } 198 llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); 199 } 200 201 mlir::Value ReductionProcessor::createScalarCombiner( 202 fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, 203 mlir::Type type, mlir::Value op1, mlir::Value op2) { 204 mlir::Value reductionOp; 205 type = fir::unwrapRefType(type); 206 switch (redId) { 207 case ReductionIdentifier::MAX: 208 reductionOp = 209 getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>( 210 builder, type, loc, op1, op2); 211 break; 212 case ReductionIdentifier::MIN: 213 reductionOp = 214 getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>( 215 builder, type, loc, op1, op2); 216 break; 217 case ReductionIdentifier::IOR: 218 assert((type.isIntOrIndex()) && "only integer is expected"); 219 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); 220 break; 221 case ReductionIdentifier::IEOR: 222 assert((type.isIntOrIndex()) && "only integer is expected"); 223 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); 224 break; 225 case ReductionIdentifier::IAND: 226 assert((type.isIntOrIndex()) && "only integer is expected"); 227 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); 228 break; 229 case ReductionIdentifier::ADD: 230 reductionOp = 231 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>( 232 builder, type, loc, op1, op2); 233 break; 234 case ReductionIdentifier::MULTIPLY: 235 reductionOp = 236 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>( 237 builder, type, loc, op1, op2); 238 break; 239 case ReductionIdentifier::AND: { 240 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 241 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 242 243 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); 244 245 reductionOp = builder.createConvert(loc, type, andiOp); 246 break; 247 } 248 case ReductionIdentifier::OR: { 249 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 250 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 251 252 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); 253 254 reductionOp = builder.createConvert(loc, type, oriOp); 255 break; 256 } 257 case ReductionIdentifier::EQV: { 258 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 259 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 260 261 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 262 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); 263 264 reductionOp = builder.createConvert(loc, type, cmpiOp); 265 break; 266 } 267 case ReductionIdentifier::NEQV: { 268 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 269 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 270 271 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 272 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); 273 274 reductionOp = builder.createConvert(loc, type, cmpiOp); 275 break; 276 } 277 default: 278 TODO(loc, "Reduction of some intrinsic operators is not supported"); 279 } 280 281 return reductionOp; 282 } 283 284 mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl( 285 fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, 286 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, 287 bool isByRef) { 288 mlir::OpBuilder::InsertionGuard guard(builder); 289 mlir::ModuleOp module = builder.getModule(); 290 291 auto decl = 292 module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName); 293 if (decl) 294 return decl; 295 296 mlir::OpBuilder modBuilder(module.getBodyRegion()); 297 mlir::Type valTy = fir::unwrapRefType(type); 298 if (!isByRef) 299 type = valTy; 300 301 decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName, 302 type); 303 builder.createBlock(&decl.getInitializerRegion(), 304 decl.getInitializerRegion().end(), {type}, {loc}); 305 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); 306 307 mlir::Value init = getReductionInitValue(loc, type, redId, builder); 308 if (isByRef) { 309 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, valTy); 310 builder.createStoreWithConvert(loc, init, alloca); 311 builder.create<mlir::omp::YieldOp>(loc, alloca); 312 } else { 313 builder.create<mlir::omp::YieldOp>(loc, init); 314 } 315 316 builder.createBlock(&decl.getReductionRegion(), 317 decl.getReductionRegion().end(), {type, type}, 318 {loc, loc}); 319 320 builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); 321 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); 322 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); 323 mlir::Value outAddr = op1; 324 325 op1 = builder.loadIfRef(loc, op1); 326 op2 = builder.loadIfRef(loc, op2); 327 328 mlir::Value reductionOp = 329 createScalarCombiner(builder, loc, redId, type, op1, op2); 330 if (isByRef) { 331 builder.create<fir::StoreOp>(loc, reductionOp, outAddr); 332 builder.create<mlir::omp::YieldOp>(loc, outAddr); 333 } else { 334 builder.create<mlir::omp::YieldOp>(loc, reductionOp); 335 } 336 337 return decl; 338 } 339 340 // TODO: By-ref vs by-val reductions are currently toggled for the whole 341 // operation (possibly effecting multiple reduction variables). 342 // This could cause a problem with openmp target reductions because 343 // by-ref trivial types may not be supported. 344 bool ReductionProcessor::doReductionByRef( 345 const llvm::SmallVectorImpl<mlir::Value> &reductionVars) { 346 if (reductionVars.empty()) 347 return false; 348 if (forceByrefReduction) 349 return true; 350 351 for (mlir::Value reductionVar : reductionVars) { 352 if (auto declare = 353 mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) 354 reductionVar = declare.getMemref(); 355 356 if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) 357 return true; 358 } 359 return false; 360 } 361 362 void ReductionProcessor::addReductionDecl( 363 mlir::Location currentLocation, 364 Fortran::lower::AbstractConverter &converter, 365 const omp::clause::Reduction &reduction, 366 llvm::SmallVectorImpl<mlir::Value> &reductionVars, 367 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, 368 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> 369 *reductionSymbols) { 370 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 371 mlir::omp::ReductionDeclareOp decl; 372 const auto &redOperator{ 373 std::get<omp::clause::ReductionOperator>(reduction.t)}; 374 const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; 375 376 if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) { 377 if (const auto *reductionIntrinsic = 378 std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) { 379 if (!ReductionProcessor::supportedIntrinsicProcReduction( 380 *reductionIntrinsic)) { 381 return; 382 } 383 } else { 384 return; 385 } 386 } 387 388 // initial pass to collect all reduction vars so we can figure out if this 389 // should happen byref 390 for (const Object &object : objectList) { 391 const Fortran::semantics::Symbol *symbol = object.id(); 392 if (reductionSymbols) 393 reductionSymbols->push_back(symbol); 394 mlir::Value symVal = converter.getSymbolAddress(*symbol); 395 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) 396 symVal = declOp.getBase(); 397 reductionVars.push_back(symVal); 398 } 399 const bool isByRef = doReductionByRef(reductionVars); 400 401 if (const auto &redDefinedOp = 402 std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { 403 const auto &intrinsicOp{ 404 std::get<omp::clause::DefinedOperator::IntrinsicOperator>( 405 redDefinedOp->u)}; 406 ReductionIdentifier redId = getReductionType(intrinsicOp); 407 switch (redId) { 408 case ReductionIdentifier::ADD: 409 case ReductionIdentifier::MULTIPLY: 410 case ReductionIdentifier::AND: 411 case ReductionIdentifier::EQV: 412 case ReductionIdentifier::OR: 413 case ReductionIdentifier::NEQV: 414 break; 415 default: 416 TODO(currentLocation, 417 "Reduction of some intrinsic operators is not supported"); 418 break; 419 } 420 421 for (const Object &object : objectList) { 422 const Fortran::semantics::Symbol *symbol = object.id(); 423 mlir::Value symVal = converter.getSymbolAddress(*symbol); 424 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) 425 symVal = declOp.getBase(); 426 auto redType = symVal.getType().cast<fir::ReferenceType>(); 427 if (redType.getEleTy().isa<fir::LogicalType>()) 428 decl = createReductionDecl( 429 firOpBuilder, 430 getReductionName(intrinsicOp, firOpBuilder.getI1Type(), isByRef), 431 redId, redType, currentLocation, isByRef); 432 else if (redType.getEleTy().isIntOrIndexOrFloat()) { 433 decl = createReductionDecl( 434 firOpBuilder, getReductionName(intrinsicOp, redType, isByRef), 435 redId, redType, currentLocation, isByRef); 436 } else { 437 TODO(currentLocation, "Reduction of some types is not supported"); 438 } 439 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 440 firOpBuilder.getContext(), decl.getSymName())); 441 } 442 } else if (const auto *reductionIntrinsic = 443 std::get_if<omp::clause::ProcedureDesignator>( 444 &redOperator.u)) { 445 if (ReductionProcessor::supportedIntrinsicProcReduction( 446 *reductionIntrinsic)) { 447 ReductionProcessor::ReductionIdentifier redId = 448 ReductionProcessor::getReductionType(*reductionIntrinsic); 449 for (const Object &object : objectList) { 450 const Fortran::semantics::Symbol *symbol = object.id(); 451 mlir::Value symVal = converter.getSymbolAddress(*symbol); 452 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) 453 symVal = declOp.getBase(); 454 auto redType = symVal.getType().cast<fir::ReferenceType>(); 455 assert(redType.getEleTy().isIntOrIndexOrFloat() && 456 "Unsupported reduction type"); 457 decl = createReductionDecl( 458 firOpBuilder, 459 getReductionName(getRealName(*reductionIntrinsic).ToString(), 460 redType, isByRef), 461 redId, redType, currentLocation, isByRef); 462 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 463 firOpBuilder.getContext(), decl.getSymName())); 464 } 465 } 466 } 467 } 468 469 const Fortran::semantics::SourceName 470 ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) { 471 return symbol->GetUltimate().name(); 472 } 473 474 const Fortran::semantics::SourceName 475 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { 476 return getRealName(pd.v.id()); 477 } 478 479 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, 480 mlir::Location loc) { 481 switch (redId) { 482 case ReductionIdentifier::ADD: 483 case ReductionIdentifier::OR: 484 case ReductionIdentifier::NEQV: 485 return 0; 486 case ReductionIdentifier::MULTIPLY: 487 case ReductionIdentifier::AND: 488 case ReductionIdentifier::EQV: 489 return 1; 490 default: 491 TODO(loc, "Reduction of some intrinsic operators is not supported"); 492 } 493 } 494 495 } // namespace omp 496 } // namespace lower 497 } // namespace Fortran 498