xref: /llvm-project/flang/lib/Lower/OpenMP/ReductionProcessor.h (revision 88478a89cd85adcc32f2a321ef9e9906c5fdbe26)
1 //===-- Lower/OpenMP/ReductionProcessor.h -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
14 #define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
15 
16 #include "Clauses.h"
17 #include "flang/Optimizer/Builder/FIRBuilder.h"
18 #include "flang/Optimizer/Dialect/FIRType.h"
19 #include "flang/Parser/parse-tree.h"
20 #include "flang/Semantics/symbol.h"
21 #include "flang/Semantics/type.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/Types.h"
24 
// Forward declarations of types that this header only names in declarations;
// their full definitions are not needed to parse it.
namespace mlir::omp {
class DeclareReductionOp;
} // namespace mlir::omp

namespace Fortran::lower {
class AbstractConverter;
} // namespace Fortran::lower
36 
37 namespace Fortran {
38 namespace lower {
39 namespace omp {
40 
41 class ReductionProcessor {
42 public:
43   // TODO: Move this enumeration to the OpenMP dialect
44   enum ReductionIdentifier {
45     ID,
46     USER_DEF_OP,
47     ADD,
48     SUBTRACT,
49     MULTIPLY,
50     AND,
51     OR,
52     EQV,
53     NEQV,
54     MAX,
55     MIN,
56     IAND,
57     IOR,
58     IEOR
59   };
60 
61   static ReductionIdentifier
62   getReductionType(const omp::clause::ProcedureDesignator &pd);
63 
64   static ReductionIdentifier
65   getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp);
66 
67   static bool
68   supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd);
69 
70   static const semantics::SourceName
71   getRealName(const semantics::Symbol *symbol);
72 
73   static const semantics::SourceName
74   getRealName(const omp::clause::ProcedureDesignator &pd);
75 
76   static std::string getReductionName(llvm::StringRef name,
77                                       const fir::KindMapping &kindMap,
78                                       mlir::Type ty, bool isByRef);
79 
80   static std::string
81   getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
82                    const fir::KindMapping &kindMap, mlir::Type ty,
83                    bool isByRef);
84 
85   /// This function returns the identity value of the operator \p
86   /// reductionOpName. For example:
87   ///    0 + x = x,
88   ///    1 * x = x
89   static int getOperationIdentity(ReductionIdentifier redId,
90                                   mlir::Location loc);
91 
92   static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
93                                            ReductionIdentifier redId,
94                                            fir::FirOpBuilder &builder);
95 
96   template <typename FloatOp, typename IntegerOp>
97   static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
98                                            mlir::Type type, mlir::Location loc,
99                                            mlir::Value op1, mlir::Value op2);
100   template <typename FloatOp, typename IntegerOp, typename ComplexOp>
101   static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
102                                            mlir::Type type, mlir::Location loc,
103                                            mlir::Value op1, mlir::Value op2);
104 
105   static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
106                                           mlir::Location loc,
107                                           ReductionIdentifier redId,
108                                           mlir::Type type, mlir::Value op1,
109                                           mlir::Value op2);
110 
111   /// Creates an OpenMP reduction declaration and inserts it into the provided
112   /// symbol table. The declaration has a constant initializer with the neutral
113   /// value `initValue`, and the reduction combiner carried over from `reduce`.
114   /// TODO: add atomic region.
115   static mlir::omp::DeclareReductionOp
116   createDeclareReduction(fir::FirOpBuilder &builder,
117                          llvm::StringRef reductionOpName,
118                          const ReductionIdentifier redId, mlir::Type type,
119                          mlir::Location loc, bool isByRef);
120 
121   /// Creates a reduction declaration and associates it with an OpenMP block
122   /// directive.
123   static void addDeclareReduction(
124       mlir::Location currentLocation, lower::AbstractConverter &converter,
125       const omp::clause::Reduction &reduction,
126       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
127       llvm::SmallVectorImpl<bool> &reduceVarByRef,
128       llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
129       llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
130 };
131 
132 template <typename FloatOp, typename IntegerOp>
133 mlir::Value
134 ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
135                                           mlir::Type type, mlir::Location loc,
136                                           mlir::Value op1, mlir::Value op2) {
137   type = fir::unwrapRefType(type);
138   assert(type.isIntOrIndexOrFloat() &&
139          "only integer, float and complex types are currently supported");
140   if (type.isIntOrIndex())
141     return builder.create<IntegerOp>(loc, op1, op2);
142   return builder.create<FloatOp>(loc, op1, op2);
143 }
144 
145 template <typename FloatOp, typename IntegerOp, typename ComplexOp>
146 mlir::Value
147 ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
148                                           mlir::Type type, mlir::Location loc,
149                                           mlir::Value op1, mlir::Value op2) {
150   assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) &&
151          "only integer, float and complex types are currently supported");
152   if (type.isIntOrIndex())
153     return builder.create<IntegerOp>(loc, op1, op2);
154   if (fir::isa_real(type))
155     return builder.create<FloatOp>(loc, op1, op2);
156   return builder.create<ComplexOp>(loc, op1, op2);
157 }
158 
159 } // namespace omp
160 } // namespace lower
161 } // namespace Fortran
162 
163 #endif // FORTRAN_LOWER_REDUCTIONPROCESSOR_H
164