xref: /llvm-project/flang/lib/Lower/OpenMP/ReductionProcessor.cpp (revision 3deaa77f1a25f0cdfcf23c34fac0b51293f32f9c)
1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ReductionProcessor.h"
14 
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Optimizer/Builder/HLFIRTools.h"
17 #include "flang/Optimizer/Builder/Todo.h"
18 #include "flang/Optimizer/Dialect/FIRType.h"
19 #include "flang/Optimizer/HLFIR/HLFIROps.h"
20 #include "flang/Parser/tools.h"
21 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
22 #include "llvm/Support/CommandLine.h"
23 
24 static llvm::cl::opt<bool> forceByrefReduction(
25     "force-byref-reduction",
26     llvm::cl::desc("Pass all reduction arguments by reference"),
27     llvm::cl::Hidden);
28 
29 namespace Fortran {
30 namespace lower {
31 namespace omp {
32 
33 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
34     const omp::clause::ProcedureDesignator &pd) {
35   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
36                      getRealName(pd.v.id()).ToString())
37                      .Case("max", ReductionIdentifier::MAX)
38                      .Case("min", ReductionIdentifier::MIN)
39                      .Case("iand", ReductionIdentifier::IAND)
40                      .Case("ior", ReductionIdentifier::IOR)
41                      .Case("ieor", ReductionIdentifier::IEOR)
42                      .Default(std::nullopt);
43   assert(redType && "Invalid Reduction");
44   return *redType;
45 }
46 
47 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
48     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
49   switch (intrinsicOp) {
50   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
51     return ReductionIdentifier::ADD;
52   case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
53     return ReductionIdentifier::SUBTRACT;
54   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
55     return ReductionIdentifier::MULTIPLY;
56   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
57     return ReductionIdentifier::AND;
58   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
59     return ReductionIdentifier::EQV;
60   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
61     return ReductionIdentifier::OR;
62   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
63     return ReductionIdentifier::NEQV;
64   default:
65     llvm_unreachable("unexpected intrinsic operator in reduction");
66   }
67 }
68 
69 bool ReductionProcessor::supportedIntrinsicProcReduction(
70     const omp::clause::ProcedureDesignator &pd) {
71   Fortran::semantics::Symbol *sym = pd.v.id();
72   if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
73     return false;
74   auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
75                      .Case("max", true)
76                      .Case("min", true)
77                      .Case("iand", true)
78                      .Case("ior", true)
79                      .Case("ieor", true)
80                      .Default(false);
81   return redType;
82 }
83 
84 std::string
85 ReductionProcessor::getReductionName(llvm::StringRef name,
86                                      const fir::KindMapping &kindMap,
87                                      mlir::Type ty, bool isByRef) {
88   ty = fir::unwrapRefType(ty);
89 
90   // extra string to distinguish reduction functions for variables passed by
91   // reference
92   llvm::StringRef byrefAddition{""};
93   if (isByRef)
94     byrefAddition = "_byref";
95 
96   return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
97 }
98 
99 std::string ReductionProcessor::getReductionName(
100     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
101     const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
102   std::string reductionName;
103 
104   switch (intrinsicOp) {
105   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
106     reductionName = "add_reduction";
107     break;
108   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
109     reductionName = "multiply_reduction";
110     break;
111   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
112     return "and_reduction";
113   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
114     return "eqv_reduction";
115   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
116     return "or_reduction";
117   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
118     return "neqv_reduction";
119   default:
120     reductionName = "other_reduction";
121     break;
122   }
123 
124   return getReductionName(reductionName, kindMap, ty, isByRef);
125 }
126 
127 mlir::Value
128 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
129                                           ReductionIdentifier redId,
130                                           fir::FirOpBuilder &builder) {
131   type = fir::unwrapRefType(type);
132   if (!fir::isa_integer(type) && !fir::isa_real(type) &&
133       !mlir::isa<fir::LogicalType>(type))
134     TODO(loc, "Reduction of some types is not supported");
135   switch (redId) {
136   case ReductionIdentifier::MAX: {
137     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
138       const llvm::fltSemantics &sem = ty.getFloatSemantics();
139       return builder.createRealConstant(
140           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
141     }
142     unsigned bits = type.getIntOrFloatBitWidth();
143     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
144     return builder.createIntegerConstant(loc, type, minInt);
145   }
146   case ReductionIdentifier::MIN: {
147     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
148       const llvm::fltSemantics &sem = ty.getFloatSemantics();
149       return builder.createRealConstant(
150           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
151     }
152     unsigned bits = type.getIntOrFloatBitWidth();
153     int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
154     return builder.createIntegerConstant(loc, type, maxInt);
155   }
156   case ReductionIdentifier::IOR: {
157     unsigned bits = type.getIntOrFloatBitWidth();
158     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
159     return builder.createIntegerConstant(loc, type, zeroInt);
160   }
161   case ReductionIdentifier::IEOR: {
162     unsigned bits = type.getIntOrFloatBitWidth();
163     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
164     return builder.createIntegerConstant(loc, type, zeroInt);
165   }
166   case ReductionIdentifier::IAND: {
167     unsigned bits = type.getIntOrFloatBitWidth();
168     int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
169     return builder.createIntegerConstant(loc, type, allOnInt);
170   }
171   case ReductionIdentifier::ADD:
172   case ReductionIdentifier::MULTIPLY:
173   case ReductionIdentifier::AND:
174   case ReductionIdentifier::OR:
175   case ReductionIdentifier::EQV:
176   case ReductionIdentifier::NEQV:
177     if (type.isa<mlir::FloatType>())
178       return builder.create<mlir::arith::ConstantOp>(
179           loc, type,
180           builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
181 
182     if (type.isa<fir::LogicalType>()) {
183       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
184           loc, builder.getI1Type(),
185           builder.getIntegerAttr(builder.getI1Type(),
186                                  getOperationIdentity(redId, loc)));
187       return builder.createConvert(loc, type, intConst);
188     }
189 
190     return builder.create<mlir::arith::ConstantOp>(
191         loc, type,
192         builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
193   case ReductionIdentifier::ID:
194   case ReductionIdentifier::USER_DEF_OP:
195   case ReductionIdentifier::SUBTRACT:
196     TODO(loc, "Reduction of some identifier types is not supported");
197   }
198   llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
199 }
200 
201 mlir::Value ReductionProcessor::createScalarCombiner(
202     fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
203     mlir::Type type, mlir::Value op1, mlir::Value op2) {
204   mlir::Value reductionOp;
205   type = fir::unwrapRefType(type);
206   switch (redId) {
207   case ReductionIdentifier::MAX:
208     reductionOp =
209         getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
210             builder, type, loc, op1, op2);
211     break;
212   case ReductionIdentifier::MIN:
213     reductionOp =
214         getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
215             builder, type, loc, op1, op2);
216     break;
217   case ReductionIdentifier::IOR:
218     assert((type.isIntOrIndex()) && "only integer is expected");
219     reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
220     break;
221   case ReductionIdentifier::IEOR:
222     assert((type.isIntOrIndex()) && "only integer is expected");
223     reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
224     break;
225   case ReductionIdentifier::IAND:
226     assert((type.isIntOrIndex()) && "only integer is expected");
227     reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
228     break;
229   case ReductionIdentifier::ADD:
230     reductionOp =
231         getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
232             builder, type, loc, op1, op2);
233     break;
234   case ReductionIdentifier::MULTIPLY:
235     reductionOp =
236         getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
237             builder, type, loc, op1, op2);
238     break;
239   case ReductionIdentifier::AND: {
240     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
241     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
242 
243     mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
244 
245     reductionOp = builder.createConvert(loc, type, andiOp);
246     break;
247   }
248   case ReductionIdentifier::OR: {
249     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
250     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
251 
252     mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
253 
254     reductionOp = builder.createConvert(loc, type, oriOp);
255     break;
256   }
257   case ReductionIdentifier::EQV: {
258     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
259     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
260 
261     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
262         loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
263 
264     reductionOp = builder.createConvert(loc, type, cmpiOp);
265     break;
266   }
267   case ReductionIdentifier::NEQV: {
268     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
269     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
270 
271     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
272         loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
273 
274     reductionOp = builder.createConvert(loc, type, cmpiOp);
275     break;
276   }
277   default:
278     TODO(loc, "Reduction of some intrinsic operators is not supported");
279   }
280 
281   return reductionOp;
282 }
283 
284 /// Create reduction combiner region for reduction variables which are boxed
285 /// arrays
286 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
287                            ReductionProcessor::ReductionIdentifier redId,
288                            fir::BaseBoxType boxTy, mlir::Value lhs,
289                            mlir::Value rhs) {
290   fir::SequenceType seqTy =
291       mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy());
292   // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
293   if (!seqTy || seqTy.hasUnknownShape())
294     TODO(loc, "Unsupported boxed type in OpenMP reduction");
295 
296   // load fir.ref<fir.box<...>>
297   mlir::Value lhsAddr = lhs;
298   lhs = builder.create<fir::LoadOp>(loc, lhs);
299   rhs = builder.create<fir::LoadOp>(loc, rhs);
300 
301   const unsigned rank = seqTy.getDimension();
302   llvm::SmallVector<mlir::Value> extents;
303   extents.reserve(rank);
304   llvm::SmallVector<mlir::Value> lbAndExtents;
305   lbAndExtents.reserve(rank * 2);
306 
307   // Get box lowerbounds and extents:
308   mlir::Type idxTy = builder.getIndexType();
309   for (unsigned i = 0; i < rank; ++i) {
310     // TODO: ideally we want to hoist box reads out of the critical section.
311     // We could do this by having box dimensions in block arguments like
312     // OpenACC does
313     mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
314     auto dimInfo =
315         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim);
316     extents.push_back(dimInfo.getExtent());
317     lbAndExtents.push_back(dimInfo.getLowerBound());
318     lbAndExtents.push_back(dimInfo.getExtent());
319   }
320 
321   auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank);
322   auto shapeShift =
323       builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
324 
325   // Iterate over array elements, applying the equivalent scalar reduction:
326 
327   // A hlfir::elemental here gets inlined with a temporary so create the
328   // loop nest directly.
329   // This function already controls all of the code in this region so we
330   // know this won't miss any opportuinties for clever elemental inlining
331   hlfir::LoopNest nest =
332       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
333   builder.setInsertionPointToStart(nest.innerLoop.getBody());
334   mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
335   auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
336       loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
337       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
338   auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
339       loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
340       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
341   auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
342   auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
343   mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
344       builder, loc, redId, refTy, lhsEle, rhsEle);
345   builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
346 
347   builder.setInsertionPointAfter(nest.outerLoop);
348   builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
349 }
350 
351 // generate combiner region for reduction operations
352 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
353                         ReductionProcessor::ReductionIdentifier redId,
354                         mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
355                         bool isByRef) {
356   ty = fir::unwrapRefType(ty);
357 
358   if (fir::isa_trivial(ty)) {
359     mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
360     mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
361 
362     mlir::Value result = ReductionProcessor::createScalarCombiner(
363         builder, loc, redId, ty, lhsLoaded, rhsLoaded);
364     if (isByRef) {
365       builder.create<fir::StoreOp>(loc, result, lhs);
366       builder.create<mlir::omp::YieldOp>(loc, lhs);
367     } else {
368       builder.create<mlir::omp::YieldOp>(loc, result);
369     }
370     return;
371   }
372   // all arrays should have been boxed
373   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
374     genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
375     return;
376   }
377 
378   TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
379 }
380 
381 static mlir::Value
382 createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
383                           const ReductionProcessor::ReductionIdentifier redId,
384                           mlir::Type type, bool isByRef) {
385   mlir::Type ty = fir::unwrapRefType(type);
386   mlir::Value initValue = ReductionProcessor::getReductionInitValue(
387       loc, fir::unwrapSeqOrBoxedSeqType(ty), redId, builder);
388 
389   if (fir::isa_trivial(ty)) {
390     if (isByRef) {
391       mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
392       builder.createStoreWithConvert(loc, initValue, alloca);
393       return alloca;
394     }
395     // by val
396     return initValue;
397   }
398 
399   // all arrays are boxed
400   if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
401     assert(isByRef && "passing arrays by value is unsupported");
402     // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
403     mlir::Type innerTy = fir::extractSequenceType(boxTy);
404     if (!mlir::isa<fir::SequenceType>(innerTy))
405       TODO(loc, "Unsupported boxed type for reduction");
406     // Create the private copy from the initial fir.box:
407     hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
408 
409     // TODO: if the whole reduction is nested inside of a loop, this alloca
410     // could lead to a stack overflow (the memory is only freed at the end of
411     // the stack frame). The reduction declare operation needs a deallocation
412     // region to undo the init region.
413     hlfir::Entity temp = createStackTempFromMold(loc, builder, source);
414 
415     // Put the temporary inside of a box:
416     hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp);
417     builder.create<hlfir::AssignOp>(loc, initValue, box);
418     mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
419     builder.create<fir::StoreOp>(loc, box, boxAlloca);
420     return boxAlloca;
421   }
422 
423   TODO(loc, "createReductionInitRegion for unsupported type");
424 }
425 
426 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
427     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
428     const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
429     bool isByRef) {
430   mlir::OpBuilder::InsertionGuard guard(builder);
431   mlir::ModuleOp module = builder.getModule();
432 
433   assert(!reductionOpName.empty());
434 
435   auto decl =
436       module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
437   if (decl)
438     return decl;
439 
440   mlir::OpBuilder modBuilder(module.getBodyRegion());
441   mlir::Type valTy = fir::unwrapRefType(type);
442   if (!isByRef)
443     type = valTy;
444 
445   decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
446                                                           type);
447   builder.createBlock(&decl.getInitializerRegion(),
448                       decl.getInitializerRegion().end(), {type}, {loc});
449   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
450 
451   mlir::Value init =
452       createReductionInitRegion(builder, loc, redId, type, isByRef);
453   builder.create<mlir::omp::YieldOp>(loc, init);
454 
455   builder.createBlock(&decl.getReductionRegion(),
456                       decl.getReductionRegion().end(), {type, type},
457                       {loc, loc});
458 
459   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
460   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
461   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
462   genCombiner(builder, loc, redId, type, op1, op2, isByRef);
463 
464   return decl;
465 }
466 
467 // TODO: By-ref vs by-val reductions are currently toggled for the whole
468 //       operation (possibly effecting multiple reduction variables).
469 //       This could cause a problem with openmp target reductions because
470 //       by-ref trivial types may not be supported.
471 bool ReductionProcessor::doReductionByRef(
472     const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
473   if (reductionVars.empty())
474     return false;
475   if (forceByrefReduction)
476     return true;
477 
478   for (mlir::Value reductionVar : reductionVars) {
479     if (auto declare =
480             mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
481       reductionVar = declare.getMemref();
482 
483     if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
484       return true;
485   }
486   return false;
487 }
488 
489 void ReductionProcessor::addDeclareReduction(
490     mlir::Location currentLocation,
491     Fortran::lower::AbstractConverter &converter,
492     const omp::clause::Reduction &reduction,
493     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
494     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
495     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
496         *reductionSymbols) {
497   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
498   mlir::omp::DeclareReductionOp decl;
499   const auto &redOperator{
500       std::get<omp::clause::ReductionOperator>(reduction.t)};
501   const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
502 
503   if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
504     if (const auto *reductionIntrinsic =
505             std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
506       if (!ReductionProcessor::supportedIntrinsicProcReduction(
507               *reductionIntrinsic)) {
508         return;
509       }
510     } else {
511       return;
512     }
513   }
514 
515   // initial pass to collect all reduction vars so we can figure out if this
516   // should happen byref
517   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
518   for (const Object &object : objectList) {
519     const Fortran::semantics::Symbol *symbol = object.id();
520     if (reductionSymbols)
521       reductionSymbols->push_back(symbol);
522     mlir::Value symVal = converter.getSymbolAddress(*symbol);
523     auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
524 
525     // all arrays must be boxed so that we have convenient access to all the
526     // information needed to iterate over the array
527     if (mlir::isa<fir::SequenceType>(redType.getEleTy())) {
528       hlfir::Entity entity{symVal};
529       entity = genVariableBox(currentLocation, builder, entity);
530       mlir::Value box = entity.getBase();
531 
532       // Always pass the box by reference so that the OpenMP dialect
533       // verifiers don't need to know anything about fir.box
534       auto alloca =
535           builder.create<fir::AllocaOp>(currentLocation, box.getType());
536       builder.create<fir::StoreOp>(currentLocation, box, alloca);
537 
538       symVal = alloca;
539       redType = mlir::cast<fir::ReferenceType>(symVal.getType());
540     } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
541       symVal = declOp.getBase();
542     }
543 
544     reductionVars.push_back(symVal);
545   }
546   const bool isByRef = doReductionByRef(reductionVars);
547 
548   if (const auto &redDefinedOp =
549           std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
550     const auto &intrinsicOp{
551         std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
552             redDefinedOp->u)};
553     ReductionIdentifier redId = getReductionType(intrinsicOp);
554     switch (redId) {
555     case ReductionIdentifier::ADD:
556     case ReductionIdentifier::MULTIPLY:
557     case ReductionIdentifier::AND:
558     case ReductionIdentifier::EQV:
559     case ReductionIdentifier::OR:
560     case ReductionIdentifier::NEQV:
561       break;
562     default:
563       TODO(currentLocation,
564            "Reduction of some intrinsic operators is not supported");
565       break;
566     }
567 
568     for (mlir::Value symVal : reductionVars) {
569       auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
570       const auto &kindMap = firOpBuilder.getKindMap();
571       if (redType.getEleTy().isa<fir::LogicalType>())
572         decl = createDeclareReduction(firOpBuilder,
573                                       getReductionName(intrinsicOp, kindMap,
574                                                        firOpBuilder.getI1Type(),
575                                                        isByRef),
576                                       redId, redType, currentLocation, isByRef);
577       else
578         decl = createDeclareReduction(
579             firOpBuilder,
580             getReductionName(intrinsicOp, kindMap, redType, isByRef), redId,
581             redType, currentLocation, isByRef);
582       reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
583           firOpBuilder.getContext(), decl.getSymName()));
584     }
585   } else if (const auto *reductionIntrinsic =
586                  std::get_if<omp::clause::ProcedureDesignator>(
587                      &redOperator.u)) {
588     if (ReductionProcessor::supportedIntrinsicProcReduction(
589             *reductionIntrinsic)) {
590       ReductionProcessor::ReductionIdentifier redId =
591           ReductionProcessor::getReductionType(*reductionIntrinsic);
592       for (const Object &object : objectList) {
593         const Fortran::semantics::Symbol *symbol = object.id();
594         mlir::Value symVal = converter.getSymbolAddress(*symbol);
595         if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
596           symVal = declOp.getBase();
597         auto redType = symVal.getType().cast<fir::ReferenceType>();
598         if (!redType.getEleTy().isIntOrIndexOrFloat())
599           TODO(currentLocation, "User Defined Reduction on non-trivial type");
600         decl = createDeclareReduction(
601             firOpBuilder,
602             getReductionName(getRealName(*reductionIntrinsic).ToString(),
603                              firOpBuilder.getKindMap(), redType, isByRef),
604             redId, redType, currentLocation, isByRef);
605         reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
606             firOpBuilder.getContext(), decl.getSymName()));
607       }
608     }
609   }
610 }
611 
612 const Fortran::semantics::SourceName
613 ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) {
614   return symbol->GetUltimate().name();
615 }
616 
617 const Fortran::semantics::SourceName
618 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
619   return getRealName(pd.v.id());
620 }
621 
622 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
623                                              mlir::Location loc) {
624   switch (redId) {
625   case ReductionIdentifier::ADD:
626   case ReductionIdentifier::OR:
627   case ReductionIdentifier::NEQV:
628     return 0;
629   case ReductionIdentifier::MULTIPLY:
630   case ReductionIdentifier::AND:
631   case ReductionIdentifier::EQV:
632     return 1;
633   default:
634     TODO(loc, "Reduction of some intrinsic operators is not supported");
635   }
636 }
637 
638 } // namespace omp
639 } // namespace lower
640 } // namespace Fortran
641