//===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ // //===----------------------------------------------------------------------===// #include "ReductionProcessor.h" #include "flang/Lower/AbstractConverter.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Parser/tools.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" namespace Fortran { namespace lower { namespace omp { ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( const Fortran::parser::ProcedureDesignator &pd) { auto redType = llvm::StringSwitch>( ReductionProcessor::getRealName(pd).ToString()) .Case("max", ReductionIdentifier::MAX) .Case("min", ReductionIdentifier::MIN) .Case("iand", ReductionIdentifier::IAND) .Case("ior", ReductionIdentifier::IOR) .Case("ieor", ReductionIdentifier::IEOR) .Default(std::nullopt); assert(redType && "Invalid Reduction"); return *redType; } ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) { switch (intrinsicOp) { case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: return ReductionIdentifier::ADD; case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract: return ReductionIdentifier::SUBTRACT; case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: return ReductionIdentifier::MULTIPLY; case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: return ReductionIdentifier::AND; case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: return ReductionIdentifier::EQV; case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: return ReductionIdentifier::OR; case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: return ReductionIdentifier::NEQV; default: llvm_unreachable("unexpected intrinsic operator in reduction"); } } bool ReductionProcessor::supportedIntrinsicProcReduction( const Fortran::parser::ProcedureDesignator &pd) { const auto *name{Fortran::parser::Unwrap(pd)}; assert(name && "Invalid Reduction Intrinsic."); if (!name->symbol->GetUltimate().attrs().test( Fortran::semantics::Attr::INTRINSIC)) return false; auto redType = llvm::StringSwitch(getRealName(name).ToString()) .Case("max", true) .Case("min", true) .Case("iand", true) .Case("ior", true) .Case("ieor", true) .Default(false); return redType; } std::string ReductionProcessor::getReductionName(llvm::StringRef name, mlir::Type ty) { return (llvm::Twine(name) + (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + llvm::Twine(ty.getIntOrFloatBitWidth())) .str(); } std::string ReductionProcessor::getReductionName( Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, mlir::Type ty) { std::string reductionName; switch (intrinsicOp) { case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: reductionName = "add_reduction"; break; case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: reductionName = "multiply_reduction"; break; case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: return "and_reduction"; case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: return "eqv_reduction"; case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: return "or_reduction"; case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: return "neqv_reduction"; default: reductionName = "other_reduction"; break; } return getReductionName(reductionName, ty); } mlir::Value ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, ReductionIdentifier redId, fir::FirOpBuilder &builder) { assert((fir::isa_integer(type) || fir::isa_real(type) || type.isa()) && "only integer, logical and real types are currently supported"); switch (redId) { case ReductionIdentifier::MAX: { if (auto ty = type.dyn_cast()) { const llvm::fltSemantics &sem = ty.getFloatSemantics(); return builder.createRealConstant( loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); } unsigned bits = type.getIntOrFloatBitWidth(); int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); return builder.createIntegerConstant(loc, type, minInt); } case ReductionIdentifier::MIN: { if (auto ty = type.dyn_cast()) { const llvm::fltSemantics &sem = ty.getFloatSemantics(); return builder.createRealConstant( loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); } unsigned bits = type.getIntOrFloatBitWidth(); int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); return builder.createIntegerConstant(loc, type, maxInt); } case ReductionIdentifier::IOR: { unsigned bits = type.getIntOrFloatBitWidth(); int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); return builder.createIntegerConstant(loc, type, zeroInt); } case ReductionIdentifier::IEOR: { unsigned bits = type.getIntOrFloatBitWidth(); int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); return builder.createIntegerConstant(loc, type, zeroInt); } case ReductionIdentifier::IAND: { unsigned bits = type.getIntOrFloatBitWidth(); int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); return builder.createIntegerConstant(loc, type, allOnInt); } case ReductionIdentifier::ADD: case ReductionIdentifier::MULTIPLY: case ReductionIdentifier::AND: case ReductionIdentifier::OR: case ReductionIdentifier::EQV: case ReductionIdentifier::NEQV: if (type.isa()) return builder.create( loc, type, builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); if (type.isa()) { mlir::Value intConst = builder.create( loc, builder.getI1Type(), builder.getIntegerAttr(builder.getI1Type(), getOperationIdentity(redId, loc))); return builder.createConvert(loc, type, intConst); } return builder.create( loc, type, builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); case ReductionIdentifier::ID: case ReductionIdentifier::USER_DEF_OP: case ReductionIdentifier::SUBTRACT: TODO(loc, "Reduction of some identifier types is not supported"); } llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); } mlir::Value ReductionProcessor::createScalarCombiner( fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, mlir::Type type, mlir::Value op1, mlir::Value op2) { mlir::Value reductionOp; switch (redId) { case ReductionIdentifier::MAX: reductionOp = getReductionOperation( builder, type, loc, op1, op2); break; case ReductionIdentifier::MIN: reductionOp = getReductionOperation( builder, type, loc, op1, op2); break; case ReductionIdentifier::IOR: assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create(loc, op1, op2); break; case ReductionIdentifier::IEOR: assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create(loc, op1, op2); break; case ReductionIdentifier::IAND: assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create(loc, op1, op2); break; case ReductionIdentifier::ADD: reductionOp = getReductionOperation( builder, type, loc, op1, op2); break; case ReductionIdentifier::MULTIPLY: reductionOp = getReductionOperation( builder, type, loc, op1, op2); break; case ReductionIdentifier::AND: { mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); mlir::Value andiOp = builder.create(loc, op1I1, op2I1); reductionOp = builder.createConvert(loc, type, andiOp); break; } case ReductionIdentifier::OR: { mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); mlir::Value oriOp = builder.create(loc, op1I1, op2I1); reductionOp = builder.createConvert(loc, type, oriOp); break; } case ReductionIdentifier::EQV: { mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); mlir::Value cmpiOp = builder.create( loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); reductionOp = builder.createConvert(loc, type, cmpiOp); break; } case ReductionIdentifier::NEQV: { mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); mlir::Value cmpiOp = builder.create( loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); reductionOp = builder.createConvert(loc, type, cmpiOp); break; } default: TODO(loc, "Reduction of some intrinsic operators is not supported"); } return reductionOp; } mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl( fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) { mlir::OpBuilder::InsertionGuard guard(builder); mlir::ModuleOp module = builder.getModule(); auto decl = module.lookupSymbol(reductionOpName); if (decl) return decl; mlir::OpBuilder modBuilder(module.getBodyRegion()); decl = modBuilder.create(loc, reductionOpName, type); builder.createBlock(&decl.getInitializerRegion(), decl.getInitializerRegion().end(), {type}, {loc}); builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); mlir::Value init = getReductionInitValue(loc, type, redId, builder); builder.create(loc, init); builder.createBlock(&decl.getReductionRegion(), decl.getReductionRegion().end(), {type, type}, {loc, loc}); builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); mlir::Value reductionOp = createScalarCombiner(builder, loc, redId, type, op1, op2); builder.create(loc, reductionOp); return decl; } void ReductionProcessor::addReductionDecl( mlir::Location currentLocation, Fortran::lower::AbstractConverter &converter, const Fortran::parser::OmpReductionClause &reduction, llvm::SmallVectorImpl &reductionVars, llvm::SmallVectorImpl &reductionDeclSymbols, llvm::SmallVectorImpl *reductionSymbols) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::omp::ReductionDeclareOp decl; const auto &redOperator{ std::get(reduction.t)}; const auto &objectList{std::get(reduction.t)}; if (const auto &redDefinedOp = std::get_if(&redOperator.u)) { const auto &intrinsicOp{ std::get( redDefinedOp->u)}; ReductionIdentifier redId = getReductionType(intrinsicOp); switch (redId) { case ReductionIdentifier::ADD: case ReductionIdentifier::MULTIPLY: case ReductionIdentifier::AND: case ReductionIdentifier::EQV: case ReductionIdentifier::OR: case ReductionIdentifier::NEQV: break; default: TODO(currentLocation, "Reduction of some intrinsic operators is not supported"); break; } for (const Fortran::parser::OmpObject &ompObject : objectList.v) { if (const auto *name{ Fortran::parser::Unwrap(ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { if (reductionSymbols) reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); if (auto declOp = symVal.getDefiningOp()) symVal = declOp.getBase(); mlir::Type redType = symVal.getType().cast().getEleTy(); reductionVars.push_back(symVal); if (redType.isa()) decl = createReductionDecl( firOpBuilder, getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId, redType, currentLocation); else if (redType.isIntOrIndexOrFloat()) { decl = createReductionDecl(firOpBuilder, getReductionName(intrinsicOp, redType), redId, redType, currentLocation); } else { TODO(currentLocation, "Reduction of some types is not supported"); } reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( firOpBuilder.getContext(), decl.getSymName())); } } } } else if (const auto *reductionIntrinsic = std::get_if( &redOperator.u)) { if (ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) { ReductionProcessor::ReductionIdentifier redId = ReductionProcessor::getReductionType(*reductionIntrinsic); for (const Fortran::parser::OmpObject &ompObject : objectList.v) { if (const auto *name{ Fortran::parser::Unwrap(ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { if (reductionSymbols) reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); if (auto declOp = symVal.getDefiningOp()) symVal = declOp.getBase(); mlir::Type redType = symVal.getType().cast().getEleTy(); reductionVars.push_back(symVal); assert(redType.isIntOrIndexOrFloat() && "Unsupported reduction type"); decl = createReductionDecl( firOpBuilder, getReductionName(getRealName(*reductionIntrinsic).ToString(), redType), redId, redType, currentLocation); reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( firOpBuilder.getContext(), decl.getSymName())); } } } } } } const Fortran::semantics::SourceName ReductionProcessor::getRealName(const Fortran::parser::Name *name) { return name->symbol->GetUltimate().name(); } const Fortran::semantics::SourceName ReductionProcessor::getRealName( const Fortran::parser::ProcedureDesignator &pd) { const auto *name{Fortran::parser::Unwrap(pd)}; assert(name && "Invalid Reduction Intrinsic."); return getRealName(name); } int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, mlir::Location loc) { switch (redId) { case ReductionIdentifier::ADD: case ReductionIdentifier::OR: case ReductionIdentifier::NEQV: return 0; case ReductionIdentifier::MULTIPLY: case ReductionIdentifier::AND: case ReductionIdentifier::EQV: return 1; default: TODO(loc, "Reduction of some intrinsic operators is not supported"); } } } // namespace omp } // namespace lower } // namespace Fortran