xref: /llvm-project/flang/lib/Lower/OpenMP/ReductionProcessor.cpp (revision 4d4af15c3fb671ed9f7eef9f29ebd6fde15618df)
1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ReductionProcessor.h"
14 
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Optimizer/Builder/Todo.h"
17 #include "flang/Optimizer/HLFIR/HLFIROps.h"
18 #include "flang/Parser/tools.h"
19 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
20 
21 namespace Fortran {
22 namespace lower {
23 namespace omp {
24 
25 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
26     const Fortran::parser::ProcedureDesignator &pd) {
27   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
28                      ReductionProcessor::getRealName(pd).ToString())
29                      .Case("max", ReductionIdentifier::MAX)
30                      .Case("min", ReductionIdentifier::MIN)
31                      .Case("iand", ReductionIdentifier::IAND)
32                      .Case("ior", ReductionIdentifier::IOR)
33                      .Case("ieor", ReductionIdentifier::IEOR)
34                      .Default(std::nullopt);
35   assert(redType && "Invalid Reduction");
36   return *redType;
37 }
38 
39 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
40     Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
41   switch (intrinsicOp) {
42   case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
43     return ReductionIdentifier::ADD;
44   case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
45     return ReductionIdentifier::SUBTRACT;
46   case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
47     return ReductionIdentifier::MULTIPLY;
48   case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
49     return ReductionIdentifier::AND;
50   case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
51     return ReductionIdentifier::EQV;
52   case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
53     return ReductionIdentifier::OR;
54   case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
55     return ReductionIdentifier::NEQV;
56   default:
57     llvm_unreachable("unexpected intrinsic operator in reduction");
58   }
59 }
60 
61 bool ReductionProcessor::supportedIntrinsicProcReduction(
62     const Fortran::parser::ProcedureDesignator &pd) {
63   const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
64   assert(name && "Invalid Reduction Intrinsic.");
65   if (!name->symbol->GetUltimate().attrs().test(
66           Fortran::semantics::Attr::INTRINSIC))
67     return false;
68   auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString())
69                      .Case("max", true)
70                      .Case("min", true)
71                      .Case("iand", true)
72                      .Case("ior", true)
73                      .Case("ieor", true)
74                      .Default(false);
75   return redType;
76 }
77 
78 std::string ReductionProcessor::getReductionName(llvm::StringRef name,
79                                                  mlir::Type ty) {
80   return (llvm::Twine(name) +
81           (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
82           llvm::Twine(ty.getIntOrFloatBitWidth()))
83       .str();
84 }
85 
86 std::string ReductionProcessor::getReductionName(
87     Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
88     mlir::Type ty) {
89   std::string reductionName;
90 
91   switch (intrinsicOp) {
92   case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
93     reductionName = "add_reduction";
94     break;
95   case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
96     reductionName = "multiply_reduction";
97     break;
98   case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
99     return "and_reduction";
100   case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
101     return "eqv_reduction";
102   case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
103     return "or_reduction";
104   case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
105     return "neqv_reduction";
106   default:
107     reductionName = "other_reduction";
108     break;
109   }
110 
111   return getReductionName(reductionName, ty);
112 }
113 
114 mlir::Value
115 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
116                                           ReductionIdentifier redId,
117                                           fir::FirOpBuilder &builder) {
118   assert((fir::isa_integer(type) || fir::isa_real(type) ||
119           type.isa<fir::LogicalType>()) &&
120          "only integer, logical and real types are currently supported");
121   switch (redId) {
122   case ReductionIdentifier::MAX: {
123     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
124       const llvm::fltSemantics &sem = ty.getFloatSemantics();
125       return builder.createRealConstant(
126           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
127     }
128     unsigned bits = type.getIntOrFloatBitWidth();
129     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
130     return builder.createIntegerConstant(loc, type, minInt);
131   }
132   case ReductionIdentifier::MIN: {
133     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
134       const llvm::fltSemantics &sem = ty.getFloatSemantics();
135       return builder.createRealConstant(
136           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
137     }
138     unsigned bits = type.getIntOrFloatBitWidth();
139     int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
140     return builder.createIntegerConstant(loc, type, maxInt);
141   }
142   case ReductionIdentifier::IOR: {
143     unsigned bits = type.getIntOrFloatBitWidth();
144     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
145     return builder.createIntegerConstant(loc, type, zeroInt);
146   }
147   case ReductionIdentifier::IEOR: {
148     unsigned bits = type.getIntOrFloatBitWidth();
149     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
150     return builder.createIntegerConstant(loc, type, zeroInt);
151   }
152   case ReductionIdentifier::IAND: {
153     unsigned bits = type.getIntOrFloatBitWidth();
154     int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
155     return builder.createIntegerConstant(loc, type, allOnInt);
156   }
157   case ReductionIdentifier::ADD:
158   case ReductionIdentifier::MULTIPLY:
159   case ReductionIdentifier::AND:
160   case ReductionIdentifier::OR:
161   case ReductionIdentifier::EQV:
162   case ReductionIdentifier::NEQV:
163     if (type.isa<mlir::FloatType>())
164       return builder.create<mlir::arith::ConstantOp>(
165           loc, type,
166           builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
167 
168     if (type.isa<fir::LogicalType>()) {
169       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
170           loc, builder.getI1Type(),
171           builder.getIntegerAttr(builder.getI1Type(),
172                                  getOperationIdentity(redId, loc)));
173       return builder.createConvert(loc, type, intConst);
174     }
175 
176     return builder.create<mlir::arith::ConstantOp>(
177         loc, type,
178         builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
179   case ReductionIdentifier::ID:
180   case ReductionIdentifier::USER_DEF_OP:
181   case ReductionIdentifier::SUBTRACT:
182     TODO(loc, "Reduction of some identifier types is not supported");
183   }
184   llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
185 }
186 
187 mlir::Value ReductionProcessor::createScalarCombiner(
188     fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
189     mlir::Type type, mlir::Value op1, mlir::Value op2) {
190   mlir::Value reductionOp;
191   switch (redId) {
192   case ReductionIdentifier::MAX:
193     reductionOp =
194         getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
195             builder, type, loc, op1, op2);
196     break;
197   case ReductionIdentifier::MIN:
198     reductionOp =
199         getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
200             builder, type, loc, op1, op2);
201     break;
202   case ReductionIdentifier::IOR:
203     assert((type.isIntOrIndex()) && "only integer is expected");
204     reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
205     break;
206   case ReductionIdentifier::IEOR:
207     assert((type.isIntOrIndex()) && "only integer is expected");
208     reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
209     break;
210   case ReductionIdentifier::IAND:
211     assert((type.isIntOrIndex()) && "only integer is expected");
212     reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
213     break;
214   case ReductionIdentifier::ADD:
215     reductionOp =
216         getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
217             builder, type, loc, op1, op2);
218     break;
219   case ReductionIdentifier::MULTIPLY:
220     reductionOp =
221         getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
222             builder, type, loc, op1, op2);
223     break;
224   case ReductionIdentifier::AND: {
225     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
226     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
227 
228     mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
229 
230     reductionOp = builder.createConvert(loc, type, andiOp);
231     break;
232   }
233   case ReductionIdentifier::OR: {
234     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
235     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
236 
237     mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
238 
239     reductionOp = builder.createConvert(loc, type, oriOp);
240     break;
241   }
242   case ReductionIdentifier::EQV: {
243     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
244     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
245 
246     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
247         loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
248 
249     reductionOp = builder.createConvert(loc, type, cmpiOp);
250     break;
251   }
252   case ReductionIdentifier::NEQV: {
253     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
254     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
255 
256     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
257         loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
258 
259     reductionOp = builder.createConvert(loc, type, cmpiOp);
260     break;
261   }
262   default:
263     TODO(loc, "Reduction of some intrinsic operators is not supported");
264   }
265 
266   return reductionOp;
267 }
268 
269 mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
270     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
271     const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
272   mlir::OpBuilder::InsertionGuard guard(builder);
273   mlir::ModuleOp module = builder.getModule();
274 
275   auto decl =
276       module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
277   if (decl)
278     return decl;
279 
280   mlir::OpBuilder modBuilder(module.getBodyRegion());
281 
282   decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
283                                                           type);
284   builder.createBlock(&decl.getInitializerRegion(),
285                       decl.getInitializerRegion().end(), {type}, {loc});
286   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
287   mlir::Value init = getReductionInitValue(loc, type, redId, builder);
288   builder.create<mlir::omp::YieldOp>(loc, init);
289 
290   builder.createBlock(&decl.getReductionRegion(),
291                       decl.getReductionRegion().end(), {type, type},
292                       {loc, loc});
293 
294   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
295   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
296   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
297 
298   mlir::Value reductionOp =
299       createScalarCombiner(builder, loc, redId, type, op1, op2);
300   builder.create<mlir::omp::YieldOp>(loc, reductionOp);
301 
302   return decl;
303 }
304 
305 void ReductionProcessor::addReductionDecl(
306     mlir::Location currentLocation,
307     Fortran::lower::AbstractConverter &converter,
308     const Fortran::parser::OmpReductionClause &reduction,
309     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
310     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
311     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
312         *reductionSymbols) {
313   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
314   mlir::omp::ReductionDeclareOp decl;
315   const auto &redOperator{
316       std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
317   const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
318   if (const auto &redDefinedOp =
319           std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
320     const auto &intrinsicOp{
321         std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
322             redDefinedOp->u)};
323     ReductionIdentifier redId = getReductionType(intrinsicOp);
324     switch (redId) {
325     case ReductionIdentifier::ADD:
326     case ReductionIdentifier::MULTIPLY:
327     case ReductionIdentifier::AND:
328     case ReductionIdentifier::EQV:
329     case ReductionIdentifier::OR:
330     case ReductionIdentifier::NEQV:
331       break;
332     default:
333       TODO(currentLocation,
334            "Reduction of some intrinsic operators is not supported");
335       break;
336     }
337     for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
338       if (const auto *name{
339               Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
340         if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
341           if (reductionSymbols)
342             reductionSymbols->push_back(symbol);
343           mlir::Value symVal = converter.getSymbolAddress(*symbol);
344           if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
345             symVal = declOp.getBase();
346           mlir::Type redType =
347               symVal.getType().cast<fir::ReferenceType>().getEleTy();
348           reductionVars.push_back(symVal);
349           if (redType.isa<fir::LogicalType>())
350             decl = createReductionDecl(
351                 firOpBuilder,
352                 getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
353                 redType, currentLocation);
354           else if (redType.isIntOrIndexOrFloat()) {
355             decl = createReductionDecl(firOpBuilder,
356                                        getReductionName(intrinsicOp, redType),
357                                        redId, redType, currentLocation);
358           } else {
359             TODO(currentLocation, "Reduction of some types is not supported");
360           }
361           reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
362               firOpBuilder.getContext(), decl.getSymName()));
363         }
364       }
365     }
366   } else if (const auto *reductionIntrinsic =
367                  std::get_if<Fortran::parser::ProcedureDesignator>(
368                      &redOperator.u)) {
369     if (ReductionProcessor::supportedIntrinsicProcReduction(
370             *reductionIntrinsic)) {
371       ReductionProcessor::ReductionIdentifier redId =
372           ReductionProcessor::getReductionType(*reductionIntrinsic);
373       for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
374         if (const auto *name{
375                 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
376           if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
377             if (reductionSymbols)
378               reductionSymbols->push_back(symbol);
379             mlir::Value symVal = converter.getSymbolAddress(*symbol);
380             if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
381               symVal = declOp.getBase();
382             mlir::Type redType =
383                 symVal.getType().cast<fir::ReferenceType>().getEleTy();
384             reductionVars.push_back(symVal);
385             assert(redType.isIntOrIndexOrFloat() &&
386                    "Unsupported reduction type");
387             decl = createReductionDecl(
388                 firOpBuilder,
389                 getReductionName(getRealName(*reductionIntrinsic).ToString(),
390                                  redType),
391                 redId, redType, currentLocation);
392             reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
393                 firOpBuilder.getContext(), decl.getSymName()));
394           }
395         }
396       }
397     }
398   }
399 }
400 
401 const Fortran::semantics::SourceName
402 ReductionProcessor::getRealName(const Fortran::parser::Name *name) {
403   return name->symbol->GetUltimate().name();
404 }
405 
406 const Fortran::semantics::SourceName ReductionProcessor::getRealName(
407     const Fortran::parser::ProcedureDesignator &pd) {
408   const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
409   assert(name && "Invalid Reduction Intrinsic.");
410   return getRealName(name);
411 }
412 
413 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
414                                              mlir::Location loc) {
415   switch (redId) {
416   case ReductionIdentifier::ADD:
417   case ReductionIdentifier::OR:
418   case ReductionIdentifier::NEQV:
419     return 0;
420   case ReductionIdentifier::MULTIPLY:
421   case ReductionIdentifier::AND:
422   case ReductionIdentifier::EQV:
423     return 1;
424   default:
425     TODO(loc, "Reduction of some intrinsic operators is not supported");
426   }
427 }
428 
429 } // namespace omp
430 } // namespace lower
431 } // namespace Fortran
432