1 //===-- Lower/OpenMP/ReductionProcessor.h -----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H 14 #define FORTRAN_LOWER_REDUCTIONPROCESSOR_H 15 16 #include "Clauses.h" 17 #include "flang/Optimizer/Builder/FIRBuilder.h" 18 #include "flang/Optimizer/Dialect/FIRType.h" 19 #include "flang/Parser/parse-tree.h" 20 #include "flang/Semantics/symbol.h" 21 #include "flang/Semantics/type.h" 22 #include "mlir/IR/Location.h" 23 #include "mlir/IR/Types.h" 24 25 namespace mlir { 26 namespace omp { 27 class DeclareReductionOp; 28 } // namespace omp 29 } // namespace mlir 30 31 namespace Fortran { 32 namespace lower { 33 class AbstractConverter; 34 } // namespace lower 35 } // namespace Fortran 36 37 namespace Fortran { 38 namespace lower { 39 namespace omp { 40 41 class ReductionProcessor { 42 public: 43 // TODO: Move this enumeration to the OpenMP dialect 44 enum ReductionIdentifier { 45 ID, 46 USER_DEF_OP, 47 ADD, 48 SUBTRACT, 49 MULTIPLY, 50 AND, 51 OR, 52 EQV, 53 NEQV, 54 MAX, 55 MIN, 56 IAND, 57 IOR, 58 IEOR 59 }; 60 61 static ReductionIdentifier 62 getReductionType(const omp::clause::ProcedureDesignator &pd); 63 64 static ReductionIdentifier 65 getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp); 66 67 static bool 68 supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd); 69 70 static const semantics::SourceName 71 getRealName(const semantics::Symbol *symbol); 72 73 static const semantics::SourceName 74 getRealName(const omp::clause::ProcedureDesignator &pd); 75 76 static std::string getReductionName(llvm::StringRef name, 77 const fir::KindMapping &kindMap, 78 mlir::Type ty, bool isByRef); 79 80 static std::string 81 getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, 82 const fir::KindMapping &kindMap, mlir::Type ty, 83 bool isByRef); 84 85 /// This function returns the identity value of the operator \p 86 /// reductionOpName. For example: 87 /// 0 + x = x, 88 /// 1 * x = x 89 static int getOperationIdentity(ReductionIdentifier redId, 90 mlir::Location loc); 91 92 static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type, 93 ReductionIdentifier redId, 94 fir::FirOpBuilder &builder); 95 96 template <typename FloatOp, typename IntegerOp> 97 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, 98 mlir::Type type, mlir::Location loc, 99 mlir::Value op1, mlir::Value op2); 100 template <typename FloatOp, typename IntegerOp, typename ComplexOp> 101 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, 102 mlir::Type type, mlir::Location loc, 103 mlir::Value op1, mlir::Value op2); 104 105 static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder, 106 mlir::Location loc, 107 ReductionIdentifier redId, 108 mlir::Type type, mlir::Value op1, 109 mlir::Value op2); 110 111 /// Creates an OpenMP reduction declaration and inserts it into the provided 112 /// symbol table. The declaration has a constant initializer with the neutral 113 /// value `initValue`, and the reduction combiner carried over from `reduce`. 114 /// TODO: add atomic region. 115 static mlir::omp::DeclareReductionOp 116 createDeclareReduction(fir::FirOpBuilder &builder, 117 llvm::StringRef reductionOpName, 118 const ReductionIdentifier redId, mlir::Type type, 119 mlir::Location loc, bool isByRef); 120 121 /// Creates a reduction declaration and associates it with an OpenMP block 122 /// directive. 123 static void addDeclareReduction( 124 mlir::Location currentLocation, lower::AbstractConverter &converter, 125 const omp::clause::Reduction &reduction, 126 llvm::SmallVectorImpl<mlir::Value> &reductionVars, 127 llvm::SmallVectorImpl<bool> &reduceVarByRef, 128 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, 129 llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols); 130 }; 131 132 template <typename FloatOp, typename IntegerOp> 133 mlir::Value 134 ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, 135 mlir::Type type, mlir::Location loc, 136 mlir::Value op1, mlir::Value op2) { 137 type = fir::unwrapRefType(type); 138 assert(type.isIntOrIndexOrFloat() && 139 "only integer, float and complex types are currently supported"); 140 if (type.isIntOrIndex()) 141 return builder.create<IntegerOp>(loc, op1, op2); 142 return builder.create<FloatOp>(loc, op1, op2); 143 } 144 145 template <typename FloatOp, typename IntegerOp, typename ComplexOp> 146 mlir::Value 147 ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, 148 mlir::Type type, mlir::Location loc, 149 mlir::Value op1, mlir::Value op2) { 150 assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) && 151 "only integer, float and complex types are currently supported"); 152 if (type.isIntOrIndex()) 153 return builder.create<IntegerOp>(loc, op1, op2); 154 if (fir::isa_real(type)) 155 return builder.create<FloatOp>(loc, op1, op2); 156 return builder.create<ComplexOp>(loc, op1, op2); 157 } 158 159 } // namespace omp 160 } // namespace lower 161 } // namespace Fortran 162 163 #endif // FORTRAN_LOWER_REDUCTIONPROCESSOR_H 164