xref: /llvm-project/flang/lib/Lower/OpenMP/ReductionProcessor.cpp (revision 63e70c05537c54edae975c8b5449ff87444abec2)
1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ReductionProcessor.h"
14 
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Optimizer/Builder/Todo.h"
17 #include "flang/Optimizer/Dialect/FIRType.h"
18 #include "flang/Optimizer/HLFIR/HLFIROps.h"
19 #include "flang/Parser/tools.h"
20 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21 #include "llvm/Support/CommandLine.h"
22 
// Hidden command-line switch (`force-byref-reduction`). When it is set,
// ReductionProcessor::doReductionByRef answers "by reference" for every
// non-empty list of reduction variables, regardless of their types.
static llvm::cl::opt<bool> forceByrefReduction(
    "force-byref-reduction",
    llvm::cl::desc("Pass all reduction arguments by reference"),
    llvm::cl::Hidden);
27 
28 namespace Fortran {
29 namespace lower {
30 namespace omp {
31 
32 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
33     const omp::clause::ProcedureDesignator &pd) {
34   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
35                      getRealName(pd.v.id()).ToString())
36                      .Case("max", ReductionIdentifier::MAX)
37                      .Case("min", ReductionIdentifier::MIN)
38                      .Case("iand", ReductionIdentifier::IAND)
39                      .Case("ior", ReductionIdentifier::IOR)
40                      .Case("ieor", ReductionIdentifier::IEOR)
41                      .Default(std::nullopt);
42   assert(redType && "Invalid Reduction");
43   return *redType;
44 }
45 
46 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
47     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
48   switch (intrinsicOp) {
49   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
50     return ReductionIdentifier::ADD;
51   case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
52     return ReductionIdentifier::SUBTRACT;
53   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
54     return ReductionIdentifier::MULTIPLY;
55   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
56     return ReductionIdentifier::AND;
57   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
58     return ReductionIdentifier::EQV;
59   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
60     return ReductionIdentifier::OR;
61   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
62     return ReductionIdentifier::NEQV;
63   default:
64     llvm_unreachable("unexpected intrinsic operator in reduction");
65   }
66 }
67 
68 bool ReductionProcessor::supportedIntrinsicProcReduction(
69     const omp::clause::ProcedureDesignator &pd) {
70   Fortran::semantics::Symbol *sym = pd.v.id();
71   if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
72     return false;
73   auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
74                      .Case("max", true)
75                      .Case("min", true)
76                      .Case("iand", true)
77                      .Case("ior", true)
78                      .Case("ieor", true)
79                      .Default(false);
80   return redType;
81 }
82 
83 std::string ReductionProcessor::getReductionName(llvm::StringRef name,
84                                                  mlir::Type ty, bool isByRef) {
85   ty = fir::unwrapRefType(ty);
86 
87   // extra string to distinguish reduction functions for variables passed by
88   // reference
89   llvm::StringRef byrefAddition{""};
90   if (isByRef)
91     byrefAddition = "_byref";
92 
93   return (llvm::Twine(name) +
94           (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
95           llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
96       .str();
97 }
98 
99 std::string ReductionProcessor::getReductionName(
100     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, mlir::Type ty,
101     bool isByRef) {
102   std::string reductionName;
103 
104   switch (intrinsicOp) {
105   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
106     reductionName = "add_reduction";
107     break;
108   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
109     reductionName = "multiply_reduction";
110     break;
111   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
112     return "and_reduction";
113   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
114     return "eqv_reduction";
115   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
116     return "or_reduction";
117   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
118     return "neqv_reduction";
119   default:
120     reductionName = "other_reduction";
121     break;
122   }
123 
124   return getReductionName(reductionName, ty, isByRef);
125 }
126 
127 mlir::Value
128 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
129                                           ReductionIdentifier redId,
130                                           fir::FirOpBuilder &builder) {
131   type = fir::unwrapRefType(type);
132   assert((fir::isa_integer(type) || fir::isa_real(type) ||
133           type.isa<fir::LogicalType>()) &&
134          "only integer, logical and real types are currently supported");
135   switch (redId) {
136   case ReductionIdentifier::MAX: {
137     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
138       const llvm::fltSemantics &sem = ty.getFloatSemantics();
139       return builder.createRealConstant(
140           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
141     }
142     unsigned bits = type.getIntOrFloatBitWidth();
143     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
144     return builder.createIntegerConstant(loc, type, minInt);
145   }
146   case ReductionIdentifier::MIN: {
147     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
148       const llvm::fltSemantics &sem = ty.getFloatSemantics();
149       return builder.createRealConstant(
150           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
151     }
152     unsigned bits = type.getIntOrFloatBitWidth();
153     int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
154     return builder.createIntegerConstant(loc, type, maxInt);
155   }
156   case ReductionIdentifier::IOR: {
157     unsigned bits = type.getIntOrFloatBitWidth();
158     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
159     return builder.createIntegerConstant(loc, type, zeroInt);
160   }
161   case ReductionIdentifier::IEOR: {
162     unsigned bits = type.getIntOrFloatBitWidth();
163     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
164     return builder.createIntegerConstant(loc, type, zeroInt);
165   }
166   case ReductionIdentifier::IAND: {
167     unsigned bits = type.getIntOrFloatBitWidth();
168     int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
169     return builder.createIntegerConstant(loc, type, allOnInt);
170   }
171   case ReductionIdentifier::ADD:
172   case ReductionIdentifier::MULTIPLY:
173   case ReductionIdentifier::AND:
174   case ReductionIdentifier::OR:
175   case ReductionIdentifier::EQV:
176   case ReductionIdentifier::NEQV:
177     if (type.isa<mlir::FloatType>())
178       return builder.create<mlir::arith::ConstantOp>(
179           loc, type,
180           builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
181 
182     if (type.isa<fir::LogicalType>()) {
183       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
184           loc, builder.getI1Type(),
185           builder.getIntegerAttr(builder.getI1Type(),
186                                  getOperationIdentity(redId, loc)));
187       return builder.createConvert(loc, type, intConst);
188     }
189 
190     return builder.create<mlir::arith::ConstantOp>(
191         loc, type,
192         builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
193   case ReductionIdentifier::ID:
194   case ReductionIdentifier::USER_DEF_OP:
195   case ReductionIdentifier::SUBTRACT:
196     TODO(loc, "Reduction of some identifier types is not supported");
197   }
198   llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
199 }
200 
201 mlir::Value ReductionProcessor::createScalarCombiner(
202     fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
203     mlir::Type type, mlir::Value op1, mlir::Value op2) {
204   mlir::Value reductionOp;
205   type = fir::unwrapRefType(type);
206   switch (redId) {
207   case ReductionIdentifier::MAX:
208     reductionOp =
209         getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
210             builder, type, loc, op1, op2);
211     break;
212   case ReductionIdentifier::MIN:
213     reductionOp =
214         getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
215             builder, type, loc, op1, op2);
216     break;
217   case ReductionIdentifier::IOR:
218     assert((type.isIntOrIndex()) && "only integer is expected");
219     reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
220     break;
221   case ReductionIdentifier::IEOR:
222     assert((type.isIntOrIndex()) && "only integer is expected");
223     reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
224     break;
225   case ReductionIdentifier::IAND:
226     assert((type.isIntOrIndex()) && "only integer is expected");
227     reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
228     break;
229   case ReductionIdentifier::ADD:
230     reductionOp =
231         getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
232             builder, type, loc, op1, op2);
233     break;
234   case ReductionIdentifier::MULTIPLY:
235     reductionOp =
236         getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
237             builder, type, loc, op1, op2);
238     break;
239   case ReductionIdentifier::AND: {
240     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
241     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
242 
243     mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
244 
245     reductionOp = builder.createConvert(loc, type, andiOp);
246     break;
247   }
248   case ReductionIdentifier::OR: {
249     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
250     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
251 
252     mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
253 
254     reductionOp = builder.createConvert(loc, type, oriOp);
255     break;
256   }
257   case ReductionIdentifier::EQV: {
258     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
259     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
260 
261     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
262         loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
263 
264     reductionOp = builder.createConvert(loc, type, cmpiOp);
265     break;
266   }
267   case ReductionIdentifier::NEQV: {
268     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
269     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
270 
271     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
272         loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
273 
274     reductionOp = builder.createConvert(loc, type, cmpiOp);
275     break;
276   }
277   default:
278     TODO(loc, "Reduction of some intrinsic operators is not supported");
279   }
280 
281   return reductionOp;
282 }
283 
284 mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
285     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
286     const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
287     bool isByRef) {
288   mlir::OpBuilder::InsertionGuard guard(builder);
289   mlir::ModuleOp module = builder.getModule();
290 
291   auto decl =
292       module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
293   if (decl)
294     return decl;
295 
296   mlir::OpBuilder modBuilder(module.getBodyRegion());
297   mlir::Type valTy = fir::unwrapRefType(type);
298   if (!isByRef)
299     type = valTy;
300 
301   decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
302                                                           type);
303   builder.createBlock(&decl.getInitializerRegion(),
304                       decl.getInitializerRegion().end(), {type}, {loc});
305   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
306 
307   mlir::Value init = getReductionInitValue(loc, type, redId, builder);
308   if (isByRef) {
309     mlir::Value alloca = builder.create<fir::AllocaOp>(loc, valTy);
310     builder.createStoreWithConvert(loc, init, alloca);
311     builder.create<mlir::omp::YieldOp>(loc, alloca);
312   } else {
313     builder.create<mlir::omp::YieldOp>(loc, init);
314   }
315 
316   builder.createBlock(&decl.getReductionRegion(),
317                       decl.getReductionRegion().end(), {type, type},
318                       {loc, loc});
319 
320   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
321   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
322   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
323   mlir::Value outAddr = op1;
324 
325   op1 = builder.loadIfRef(loc, op1);
326   op2 = builder.loadIfRef(loc, op2);
327 
328   mlir::Value reductionOp =
329       createScalarCombiner(builder, loc, redId, type, op1, op2);
330   if (isByRef) {
331     builder.create<fir::StoreOp>(loc, reductionOp, outAddr);
332     builder.create<mlir::omp::YieldOp>(loc, outAddr);
333   } else {
334     builder.create<mlir::omp::YieldOp>(loc, reductionOp);
335   }
336 
337   return decl;
338 }
339 
// TODO: By-ref vs. by-val reduction is currently decided once for the whole
//       operation (possibly affecting multiple reduction variables).
//       This could cause a problem with OpenMP target reductions because
//       by-ref reductions of trivial types may not be supported.
344 bool ReductionProcessor::doReductionByRef(
345     const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
346   if (reductionVars.empty())
347     return false;
348   if (forceByrefReduction)
349     return true;
350 
351   for (mlir::Value reductionVar : reductionVars) {
352     if (auto declare =
353             mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
354       reductionVar = declare.getMemref();
355 
356     if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
357       return true;
358   }
359   return false;
360 }
361 
362 void ReductionProcessor::addReductionDecl(
363     mlir::Location currentLocation,
364     Fortran::lower::AbstractConverter &converter,
365     const omp::clause::Reduction &reduction,
366     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
367     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
368     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
369         *reductionSymbols) {
370   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
371   mlir::omp::ReductionDeclareOp decl;
372   const auto &redOperator{
373       std::get<omp::clause::ReductionOperator>(reduction.t)};
374   const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
375 
376   if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
377     if (const auto *reductionIntrinsic =
378             std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
379       if (!ReductionProcessor::supportedIntrinsicProcReduction(
380               *reductionIntrinsic)) {
381         return;
382       }
383     } else {
384       return;
385     }
386   }
387 
388   // initial pass to collect all reduction vars so we can figure out if this
389   // should happen byref
390   for (const Object &object : objectList) {
391     const Fortran::semantics::Symbol *symbol = object.id();
392     if (reductionSymbols)
393       reductionSymbols->push_back(symbol);
394     mlir::Value symVal = converter.getSymbolAddress(*symbol);
395     if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
396       symVal = declOp.getBase();
397     reductionVars.push_back(symVal);
398   }
399   const bool isByRef = doReductionByRef(reductionVars);
400 
401   if (const auto &redDefinedOp =
402           std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
403     const auto &intrinsicOp{
404         std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
405             redDefinedOp->u)};
406     ReductionIdentifier redId = getReductionType(intrinsicOp);
407     switch (redId) {
408     case ReductionIdentifier::ADD:
409     case ReductionIdentifier::MULTIPLY:
410     case ReductionIdentifier::AND:
411     case ReductionIdentifier::EQV:
412     case ReductionIdentifier::OR:
413     case ReductionIdentifier::NEQV:
414       break;
415     default:
416       TODO(currentLocation,
417            "Reduction of some intrinsic operators is not supported");
418       break;
419     }
420 
421     for (const Object &object : objectList) {
422       const Fortran::semantics::Symbol *symbol = object.id();
423       mlir::Value symVal = converter.getSymbolAddress(*symbol);
424       if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
425         symVal = declOp.getBase();
426       auto redType = symVal.getType().cast<fir::ReferenceType>();
427       if (redType.getEleTy().isa<fir::LogicalType>())
428         decl = createReductionDecl(
429             firOpBuilder,
430             getReductionName(intrinsicOp, firOpBuilder.getI1Type(), isByRef),
431             redId, redType, currentLocation, isByRef);
432       else if (redType.getEleTy().isIntOrIndexOrFloat()) {
433         decl = createReductionDecl(
434             firOpBuilder, getReductionName(intrinsicOp, redType, isByRef),
435             redId, redType, currentLocation, isByRef);
436       } else {
437         TODO(currentLocation, "Reduction of some types is not supported");
438       }
439       reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
440           firOpBuilder.getContext(), decl.getSymName()));
441     }
442   } else if (const auto *reductionIntrinsic =
443                  std::get_if<omp::clause::ProcedureDesignator>(
444                      &redOperator.u)) {
445     if (ReductionProcessor::supportedIntrinsicProcReduction(
446             *reductionIntrinsic)) {
447       ReductionProcessor::ReductionIdentifier redId =
448           ReductionProcessor::getReductionType(*reductionIntrinsic);
449       for (const Object &object : objectList) {
450         const Fortran::semantics::Symbol *symbol = object.id();
451         mlir::Value symVal = converter.getSymbolAddress(*symbol);
452         if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
453           symVal = declOp.getBase();
454         auto redType = symVal.getType().cast<fir::ReferenceType>();
455         assert(redType.getEleTy().isIntOrIndexOrFloat() &&
456                "Unsupported reduction type");
457         decl = createReductionDecl(
458             firOpBuilder,
459             getReductionName(getRealName(*reductionIntrinsic).ToString(),
460                              redType, isByRef),
461             redId, redType, currentLocation, isByRef);
462         reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
463             firOpBuilder.getContext(), decl.getSymName()));
464       }
465     }
466   }
467 }
468 
469 const Fortran::semantics::SourceName
470 ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) {
471   return symbol->GetUltimate().name();
472 }
473 
474 const Fortran::semantics::SourceName
475 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
476   return getRealName(pd.v.id());
477 }
478 
479 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
480                                              mlir::Location loc) {
481   switch (redId) {
482   case ReductionIdentifier::ADD:
483   case ReductionIdentifier::OR:
484   case ReductionIdentifier::NEQV:
485     return 0;
486   case ReductionIdentifier::MULTIPLY:
487   case ReductionIdentifier::AND:
488   case ReductionIdentifier::EQV:
489     return 1;
490   default:
491     TODO(loc, "Reduction of some intrinsic operators is not supported");
492   }
493 }
494 
495 } // namespace omp
496 } // namespace lower
497 } // namespace Fortran
498