1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ReductionProcessor.h" 14 15 #include "flang/Lower/AbstractConverter.h" 16 #include "flang/Optimizer/Builder/Todo.h" 17 #include "flang/Optimizer/HLFIR/HLFIROps.h" 18 #include "flang/Parser/tools.h" 19 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 20 21 namespace Fortran { 22 namespace lower { 23 namespace omp { 24 25 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 26 const Fortran::parser::ProcedureDesignator &pd) { 27 auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( 28 ReductionProcessor::getRealName(pd).ToString()) 29 .Case("max", ReductionIdentifier::MAX) 30 .Case("min", ReductionIdentifier::MIN) 31 .Case("iand", ReductionIdentifier::IAND) 32 .Case("ior", ReductionIdentifier::IOR) 33 .Case("ieor", ReductionIdentifier::IEOR) 34 .Default(std::nullopt); 35 assert(redType && "Invalid Reduction"); 36 return *redType; 37 } 38 39 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( 40 Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) { 41 switch (intrinsicOp) { 42 case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: 43 return ReductionIdentifier::ADD; 44 case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract: 45 return ReductionIdentifier::SUBTRACT; 46 case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: 47 return ReductionIdentifier::MULTIPLY; 48 case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: 49 return ReductionIdentifier::AND; 50 case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: 51 return ReductionIdentifier::EQV; 52 case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: 53 return ReductionIdentifier::OR; 54 case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: 55 return ReductionIdentifier::NEQV; 56 default: 57 llvm_unreachable("unexpected intrinsic operator in reduction"); 58 } 59 } 60 61 bool ReductionProcessor::supportedIntrinsicProcReduction( 62 const Fortran::parser::ProcedureDesignator &pd) { 63 const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)}; 64 assert(name && "Invalid Reduction Intrinsic."); 65 if (!name->symbol->GetUltimate().attrs().test( 66 Fortran::semantics::Attr::INTRINSIC)) 67 return false; 68 auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString()) 69 .Case("max", true) 70 .Case("min", true) 71 .Case("iand", true) 72 .Case("ior", true) 73 .Case("ieor", true) 74 .Default(false); 75 return redType; 76 } 77 78 std::string ReductionProcessor::getReductionName(llvm::StringRef name, 79 mlir::Type ty) { 80 return (llvm::Twine(name) + 81 (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + 82 llvm::Twine(ty.getIntOrFloatBitWidth())) 83 .str(); 84 } 85 86 std::string ReductionProcessor::getReductionName( 87 Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, 88 mlir::Type ty) { 89 std::string reductionName; 90 91 switch (intrinsicOp) { 92 case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: 93 reductionName = "add_reduction"; 94 break; 95 case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: 96 reductionName = "multiply_reduction"; 97 break; 98 case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: 99 return "and_reduction"; 100 case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: 101 return "eqv_reduction"; 102 case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: 103 return "or_reduction"; 104 case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: 105 return "neqv_reduction"; 106 default: 107 reductionName = "other_reduction"; 108 break; 109 } 110 111 return getReductionName(reductionName, ty); 112 } 113 114 mlir::Value 115 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, 116 ReductionIdentifier redId, 117 fir::FirOpBuilder &builder) { 118 assert((fir::isa_integer(type) || fir::isa_real(type) || 119 type.isa<fir::LogicalType>()) && 120 "only integer, logical and real types are currently supported"); 121 switch (redId) { 122 case ReductionIdentifier::MAX: { 123 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 124 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 125 return builder.createRealConstant( 126 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); 127 } 128 unsigned bits = type.getIntOrFloatBitWidth(); 129 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); 130 return builder.createIntegerConstant(loc, type, minInt); 131 } 132 case ReductionIdentifier::MIN: { 133 if (auto ty = type.dyn_cast<mlir::FloatType>()) { 134 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 135 return builder.createRealConstant( 136 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); 137 } 138 unsigned bits = type.getIntOrFloatBitWidth(); 139 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); 140 return builder.createIntegerConstant(loc, type, maxInt); 141 } 142 case ReductionIdentifier::IOR: { 143 unsigned bits = type.getIntOrFloatBitWidth(); 144 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 145 return builder.createIntegerConstant(loc, type, zeroInt); 146 } 147 case ReductionIdentifier::IEOR: { 148 unsigned bits = type.getIntOrFloatBitWidth(); 149 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); 150 return builder.createIntegerConstant(loc, type, zeroInt); 151 } 152 case ReductionIdentifier::IAND: { 153 unsigned bits = type.getIntOrFloatBitWidth(); 154 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); 155 return builder.createIntegerConstant(loc, type, allOnInt); 156 } 157 case ReductionIdentifier::ADD: 158 case ReductionIdentifier::MULTIPLY: 159 case ReductionIdentifier::AND: 160 case ReductionIdentifier::OR: 161 case ReductionIdentifier::EQV: 162 case ReductionIdentifier::NEQV: 163 if (type.isa<mlir::FloatType>()) 164 return builder.create<mlir::arith::ConstantOp>( 165 loc, type, 166 builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); 167 168 if (type.isa<fir::LogicalType>()) { 169 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( 170 loc, builder.getI1Type(), 171 builder.getIntegerAttr(builder.getI1Type(), 172 getOperationIdentity(redId, loc))); 173 return builder.createConvert(loc, type, intConst); 174 } 175 176 return builder.create<mlir::arith::ConstantOp>( 177 loc, type, 178 builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); 179 case ReductionIdentifier::ID: 180 case ReductionIdentifier::USER_DEF_OP: 181 case ReductionIdentifier::SUBTRACT: 182 TODO(loc, "Reduction of some identifier types is not supported"); 183 } 184 llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); 185 } 186 187 mlir::Value ReductionProcessor::createScalarCombiner( 188 fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, 189 mlir::Type type, mlir::Value op1, mlir::Value op2) { 190 mlir::Value reductionOp; 191 switch (redId) { 192 case ReductionIdentifier::MAX: 193 reductionOp = 194 getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>( 195 builder, type, loc, op1, op2); 196 break; 197 case ReductionIdentifier::MIN: 198 reductionOp = 199 getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>( 200 builder, type, loc, op1, op2); 201 break; 202 case ReductionIdentifier::IOR: 203 assert((type.isIntOrIndex()) && "only integer is expected"); 204 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); 205 break; 206 case ReductionIdentifier::IEOR: 207 assert((type.isIntOrIndex()) && "only integer is expected"); 208 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); 209 break; 210 case ReductionIdentifier::IAND: 211 assert((type.isIntOrIndex()) && "only integer is expected"); 212 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); 213 break; 214 case ReductionIdentifier::ADD: 215 reductionOp = 216 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>( 217 builder, type, loc, op1, op2); 218 break; 219 case ReductionIdentifier::MULTIPLY: 220 reductionOp = 221 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>( 222 builder, type, loc, op1, op2); 223 break; 224 case ReductionIdentifier::AND: { 225 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 226 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 227 228 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); 229 230 reductionOp = builder.createConvert(loc, type, andiOp); 231 break; 232 } 233 case ReductionIdentifier::OR: { 234 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 235 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 236 237 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); 238 239 reductionOp = builder.createConvert(loc, type, oriOp); 240 break; 241 } 242 case ReductionIdentifier::EQV: { 243 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 244 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 245 246 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 247 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); 248 249 reductionOp = builder.createConvert(loc, type, cmpiOp); 250 break; 251 } 252 case ReductionIdentifier::NEQV: { 253 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); 254 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); 255 256 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( 257 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); 258 259 reductionOp = builder.createConvert(loc, type, cmpiOp); 260 break; 261 } 262 default: 263 TODO(loc, "Reduction of some intrinsic operators is not supported"); 264 } 265 266 return reductionOp; 267 } 268 269 mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl( 270 fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, 271 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) { 272 mlir::OpBuilder::InsertionGuard guard(builder); 273 mlir::ModuleOp module = builder.getModule(); 274 275 auto decl = 276 module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName); 277 if (decl) 278 return decl; 279 280 mlir::OpBuilder modBuilder(module.getBodyRegion()); 281 282 decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName, 283 type); 284 builder.createBlock(&decl.getInitializerRegion(), 285 decl.getInitializerRegion().end(), {type}, {loc}); 286 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); 287 mlir::Value init = getReductionInitValue(loc, type, redId, builder); 288 builder.create<mlir::omp::YieldOp>(loc, init); 289 290 builder.createBlock(&decl.getReductionRegion(), 291 decl.getReductionRegion().end(), {type, type}, 292 {loc, loc}); 293 294 builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); 295 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); 296 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); 297 298 mlir::Value reductionOp = 299 createScalarCombiner(builder, loc, redId, type, op1, op2); 300 builder.create<mlir::omp::YieldOp>(loc, reductionOp); 301 302 return decl; 303 } 304 305 void ReductionProcessor::addReductionDecl( 306 mlir::Location currentLocation, 307 Fortran::lower::AbstractConverter &converter, 308 const Fortran::parser::OmpReductionClause &reduction, 309 llvm::SmallVectorImpl<mlir::Value> &reductionVars, 310 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, 311 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> 312 *reductionSymbols) { 313 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 314 mlir::omp::ReductionDeclareOp decl; 315 const auto &redOperator{ 316 std::get<Fortran::parser::OmpReductionOperator>(reduction.t)}; 317 const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)}; 318 if (const auto &redDefinedOp = 319 std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { 320 const auto &intrinsicOp{ 321 std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( 322 redDefinedOp->u)}; 323 ReductionIdentifier redId = getReductionType(intrinsicOp); 324 switch (redId) { 325 case ReductionIdentifier::ADD: 326 case ReductionIdentifier::MULTIPLY: 327 case ReductionIdentifier::AND: 328 case ReductionIdentifier::EQV: 329 case ReductionIdentifier::OR: 330 case ReductionIdentifier::NEQV: 331 break; 332 default: 333 TODO(currentLocation, 334 "Reduction of some intrinsic operators is not supported"); 335 break; 336 } 337 for (const Fortran::parser::OmpObject &ompObject : objectList.v) { 338 if (const auto *name{ 339 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { 340 if (const Fortran::semantics::Symbol * symbol{name->symbol}) { 341 if (reductionSymbols) 342 reductionSymbols->push_back(symbol); 343 mlir::Value symVal = converter.getSymbolAddress(*symbol); 344 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) 345 symVal = declOp.getBase(); 346 mlir::Type redType = 347 symVal.getType().cast<fir::ReferenceType>().getEleTy(); 348 reductionVars.push_back(symVal); 349 if (redType.isa<fir::LogicalType>()) 350 decl = createReductionDecl( 351 firOpBuilder, 352 getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId, 353 redType, currentLocation); 354 else if (redType.isIntOrIndexOrFloat()) { 355 decl = createReductionDecl(firOpBuilder, 356 getReductionName(intrinsicOp, redType), 357 redId, redType, currentLocation); 358 } else { 359 TODO(currentLocation, "Reduction of some types is not supported"); 360 } 361 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 362 firOpBuilder.getContext(), decl.getSymName())); 363 } 364 } 365 } 366 } else if (const auto *reductionIntrinsic = 367 std::get_if<Fortran::parser::ProcedureDesignator>( 368 &redOperator.u)) { 369 if (ReductionProcessor::supportedIntrinsicProcReduction( 370 *reductionIntrinsic)) { 371 ReductionProcessor::ReductionIdentifier redId = 372 ReductionProcessor::getReductionType(*reductionIntrinsic); 373 for (const Fortran::parser::OmpObject &ompObject : objectList.v) { 374 if (const auto *name{ 375 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { 376 if (const Fortran::semantics::Symbol * symbol{name->symbol}) { 377 if (reductionSymbols) 378 reductionSymbols->push_back(symbol); 379 mlir::Value symVal = converter.getSymbolAddress(*symbol); 380 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) 381 symVal = declOp.getBase(); 382 mlir::Type redType = 383 symVal.getType().cast<fir::ReferenceType>().getEleTy(); 384 reductionVars.push_back(symVal); 385 assert(redType.isIntOrIndexOrFloat() && 386 "Unsupported reduction type"); 387 decl = createReductionDecl( 388 firOpBuilder, 389 getReductionName(getRealName(*reductionIntrinsic).ToString(), 390 redType), 391 redId, redType, currentLocation); 392 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( 393 firOpBuilder.getContext(), decl.getSymName())); 394 } 395 } 396 } 397 } 398 } 399 } 400 401 const Fortran::semantics::SourceName 402 ReductionProcessor::getRealName(const Fortran::parser::Name *name) { 403 return name->symbol->GetUltimate().name(); 404 } 405 406 const Fortran::semantics::SourceName ReductionProcessor::getRealName( 407 const Fortran::parser::ProcedureDesignator &pd) { 408 const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)}; 409 assert(name && "Invalid Reduction Intrinsic."); 410 return getRealName(name); 411 } 412 413 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, 414 mlir::Location loc) { 415 switch (redId) { 416 case ReductionIdentifier::ADD: 417 case ReductionIdentifier::OR: 418 case ReductionIdentifier::NEQV: 419 return 0; 420 case ReductionIdentifier::MULTIPLY: 421 case ReductionIdentifier::AND: 422 case ReductionIdentifier::EQV: 423 return 1; 424 default: 425 TODO(loc, "Reduction of some intrinsic operators is not supported"); 426 } 427 } 428 429 } // namespace omp 430 } // namespace lower 431 } // namespace Fortran 432