xref: /llvm-project/flang/lib/Lower/OpenMP/ReductionProcessor.cpp (revision 221f438af1c1292d787b58da99a5a7b371888456)
1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ReductionProcessor.h"
14 
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Lower/ConvertType.h"
17 #include "flang/Lower/SymbolMap.h"
18 #include "flang/Optimizer/Builder/Complex.h"
19 #include "flang/Optimizer/Builder/HLFIRTools.h"
20 #include "flang/Optimizer/Builder/Todo.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/HLFIR/HLFIROps.h"
23 #include "flang/Parser/tools.h"
24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
25 #include "llvm/Support/CommandLine.h"
26 
27 static llvm::cl::opt<bool> forceByrefReduction(
28     "force-byref-reduction",
29     llvm::cl::desc("Pass all reduction arguments by reference"),
30     llvm::cl::Hidden);
31 
32 namespace Fortran {
33 namespace lower {
34 namespace omp {
35 
36 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
37     const omp::clause::ProcedureDesignator &pd) {
38   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
39                      getRealName(pd.v.id()).ToString())
40                      .Case("max", ReductionIdentifier::MAX)
41                      .Case("min", ReductionIdentifier::MIN)
42                      .Case("iand", ReductionIdentifier::IAND)
43                      .Case("ior", ReductionIdentifier::IOR)
44                      .Case("ieor", ReductionIdentifier::IEOR)
45                      .Default(std::nullopt);
46   assert(redType && "Invalid Reduction");
47   return *redType;
48 }
49 
50 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
51     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
52   switch (intrinsicOp) {
53   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
54     return ReductionIdentifier::ADD;
55   case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
56     return ReductionIdentifier::SUBTRACT;
57   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
58     return ReductionIdentifier::MULTIPLY;
59   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
60     return ReductionIdentifier::AND;
61   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
62     return ReductionIdentifier::EQV;
63   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
64     return ReductionIdentifier::OR;
65   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
66     return ReductionIdentifier::NEQV;
67   default:
68     llvm_unreachable("unexpected intrinsic operator in reduction");
69   }
70 }
71 
72 bool ReductionProcessor::supportedIntrinsicProcReduction(
73     const omp::clause::ProcedureDesignator &pd) {
74   Fortran::semantics::Symbol *sym = pd.v.id();
75   if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
76     return false;
77   auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
78                      .Case("max", true)
79                      .Case("min", true)
80                      .Case("iand", true)
81                      .Case("ior", true)
82                      .Case("ieor", true)
83                      .Default(false);
84   return redType;
85 }
86 
87 std::string
88 ReductionProcessor::getReductionName(llvm::StringRef name,
89                                      const fir::KindMapping &kindMap,
90                                      mlir::Type ty, bool isByRef) {
91   ty = fir::unwrapRefType(ty);
92 
93   // extra string to distinguish reduction functions for variables passed by
94   // reference
95   llvm::StringRef byrefAddition{""};
96   if (isByRef)
97     byrefAddition = "_byref";
98 
99   return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
100 }
101 
102 std::string ReductionProcessor::getReductionName(
103     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
104     const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
105   std::string reductionName;
106 
107   switch (intrinsicOp) {
108   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
109     reductionName = "add_reduction";
110     break;
111   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
112     reductionName = "multiply_reduction";
113     break;
114   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
115     return "and_reduction";
116   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
117     return "eqv_reduction";
118   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
119     return "or_reduction";
120   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
121     return "neqv_reduction";
122   default:
123     reductionName = "other_reduction";
124     break;
125   }
126 
127   return getReductionName(reductionName, kindMap, ty, isByRef);
128 }
129 
130 mlir::Value
131 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
132                                           ReductionIdentifier redId,
133                                           fir::FirOpBuilder &builder) {
134   type = fir::unwrapRefType(type);
135   if (!fir::isa_integer(type) && !fir::isa_real(type) &&
136       !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
137     TODO(loc, "Reduction of some types is not supported");
138   switch (redId) {
139   case ReductionIdentifier::MAX: {
140     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
141       const llvm::fltSemantics &sem = ty.getFloatSemantics();
142       return builder.createRealConstant(
143           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
144     }
145     unsigned bits = type.getIntOrFloatBitWidth();
146     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
147     return builder.createIntegerConstant(loc, type, minInt);
148   }
149   case ReductionIdentifier::MIN: {
150     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
151       const llvm::fltSemantics &sem = ty.getFloatSemantics();
152       return builder.createRealConstant(
153           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
154     }
155     unsigned bits = type.getIntOrFloatBitWidth();
156     int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
157     return builder.createIntegerConstant(loc, type, maxInt);
158   }
159   case ReductionIdentifier::IOR: {
160     unsigned bits = type.getIntOrFloatBitWidth();
161     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
162     return builder.createIntegerConstant(loc, type, zeroInt);
163   }
164   case ReductionIdentifier::IEOR: {
165     unsigned bits = type.getIntOrFloatBitWidth();
166     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
167     return builder.createIntegerConstant(loc, type, zeroInt);
168   }
169   case ReductionIdentifier::IAND: {
170     unsigned bits = type.getIntOrFloatBitWidth();
171     int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
172     return builder.createIntegerConstant(loc, type, allOnInt);
173   }
174   case ReductionIdentifier::ADD:
175   case ReductionIdentifier::MULTIPLY:
176   case ReductionIdentifier::AND:
177   case ReductionIdentifier::OR:
178   case ReductionIdentifier::EQV:
179   case ReductionIdentifier::NEQV:
180     if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
181       mlir::Type realTy =
182           Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
183       mlir::Value initRe = builder.createRealConstant(
184           loc, realTy, getOperationIdentity(redId, loc));
185       mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
186 
187       return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
188                                                                initIm);
189     }
190     if (type.isa<mlir::FloatType>())
191       return builder.create<mlir::arith::ConstantOp>(
192           loc, type,
193           builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
194 
195     if (type.isa<fir::LogicalType>()) {
196       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
197           loc, builder.getI1Type(),
198           builder.getIntegerAttr(builder.getI1Type(),
199                                  getOperationIdentity(redId, loc)));
200       return builder.createConvert(loc, type, intConst);
201     }
202 
203     return builder.create<mlir::arith::ConstantOp>(
204         loc, type,
205         builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
206   case ReductionIdentifier::ID:
207   case ReductionIdentifier::USER_DEF_OP:
208   case ReductionIdentifier::SUBTRACT:
209     TODO(loc, "Reduction of some identifier types is not supported");
210   }
211   llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
212 }
213 
214 mlir::Value ReductionProcessor::createScalarCombiner(
215     fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
216     mlir::Type type, mlir::Value op1, mlir::Value op2) {
217   mlir::Value reductionOp;
218   type = fir::unwrapRefType(type);
219   switch (redId) {
220   case ReductionIdentifier::MAX:
221     reductionOp =
222         getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
223             builder, type, loc, op1, op2);
224     break;
225   case ReductionIdentifier::MIN:
226     reductionOp =
227         getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
228             builder, type, loc, op1, op2);
229     break;
230   case ReductionIdentifier::IOR:
231     assert((type.isIntOrIndex()) && "only integer is expected");
232     reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
233     break;
234   case ReductionIdentifier::IEOR:
235     assert((type.isIntOrIndex()) && "only integer is expected");
236     reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
237     break;
238   case ReductionIdentifier::IAND:
239     assert((type.isIntOrIndex()) && "only integer is expected");
240     reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
241     break;
242   case ReductionIdentifier::ADD:
243     reductionOp =
244         getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
245                               fir::AddcOp>(builder, type, loc, op1, op2);
246     break;
247   case ReductionIdentifier::MULTIPLY:
248     reductionOp =
249         getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
250                               fir::MulcOp>(builder, type, loc, op1, op2);
251     break;
252   case ReductionIdentifier::AND: {
253     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
254     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
255 
256     mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
257 
258     reductionOp = builder.createConvert(loc, type, andiOp);
259     break;
260   }
261   case ReductionIdentifier::OR: {
262     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
263     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
264 
265     mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
266 
267     reductionOp = builder.createConvert(loc, type, oriOp);
268     break;
269   }
270   case ReductionIdentifier::EQV: {
271     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
272     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
273 
274     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
275         loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
276 
277     reductionOp = builder.createConvert(loc, type, cmpiOp);
278     break;
279   }
280   case ReductionIdentifier::NEQV: {
281     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
282     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
283 
284     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
285         loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
286 
287     reductionOp = builder.createConvert(loc, type, cmpiOp);
288     break;
289   }
290   default:
291     TODO(loc, "Reduction of some intrinsic operators is not supported");
292   }
293 
294   return reductionOp;
295 }
296 
297 /// Create reduction combiner region for reduction variables which are boxed
298 /// arrays
299 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
300                            ReductionProcessor::ReductionIdentifier redId,
301                            fir::BaseBoxType boxTy, mlir::Value lhs,
302                            mlir::Value rhs) {
303   fir::SequenceType seqTy =
304       mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy());
305   // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
306   if (!seqTy || seqTy.hasUnknownShape())
307     TODO(loc, "Unsupported boxed type in OpenMP reduction");
308 
309   // load fir.ref<fir.box<...>>
310   mlir::Value lhsAddr = lhs;
311   lhs = builder.create<fir::LoadOp>(loc, lhs);
312   rhs = builder.create<fir::LoadOp>(loc, rhs);
313 
314   const unsigned rank = seqTy.getDimension();
315   llvm::SmallVector<mlir::Value> extents;
316   extents.reserve(rank);
317   llvm::SmallVector<mlir::Value> lbAndExtents;
318   lbAndExtents.reserve(rank * 2);
319 
320   // Get box lowerbounds and extents:
321   mlir::Type idxTy = builder.getIndexType();
322   for (unsigned i = 0; i < rank; ++i) {
323     // TODO: ideally we want to hoist box reads out of the critical section.
324     // We could do this by having box dimensions in block arguments like
325     // OpenACC does
326     mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
327     auto dimInfo =
328         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim);
329     extents.push_back(dimInfo.getExtent());
330     lbAndExtents.push_back(dimInfo.getLowerBound());
331     lbAndExtents.push_back(dimInfo.getExtent());
332   }
333 
334   auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank);
335   auto shapeShift =
336       builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
337 
338   // Iterate over array elements, applying the equivalent scalar reduction:
339 
340   // A hlfir::elemental here gets inlined with a temporary so create the
341   // loop nest directly.
342   // This function already controls all of the code in this region so we
343   // know this won't miss any opportuinties for clever elemental inlining
344   hlfir::LoopNest nest =
345       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
346   builder.setInsertionPointToStart(nest.innerLoop.getBody());
347   mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
348   auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
349       loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
350       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
351   auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
352       loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
353       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
354   auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
355   auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
356   mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
357       builder, loc, redId, refTy, lhsEle, rhsEle);
358   builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
359 
360   builder.setInsertionPointAfter(nest.outerLoop);
361   builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
362 }
363 
364 // generate combiner region for reduction operations
365 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
366                         ReductionProcessor::ReductionIdentifier redId,
367                         mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
368                         bool isByRef) {
369   ty = fir::unwrapRefType(ty);
370 
371   if (fir::isa_trivial(ty)) {
372     mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
373     mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
374 
375     mlir::Value result = ReductionProcessor::createScalarCombiner(
376         builder, loc, redId, ty, lhsLoaded, rhsLoaded);
377     if (isByRef) {
378       builder.create<fir::StoreOp>(loc, result, lhs);
379       builder.create<mlir::omp::YieldOp>(loc, lhs);
380     } else {
381       builder.create<mlir::omp::YieldOp>(loc, result);
382     }
383     return;
384   }
385   // all arrays should have been boxed
386   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
387     genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
388     return;
389   }
390 
391   TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
392 }
393 
394 static mlir::Value
395 createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
396                           const ReductionProcessor::ReductionIdentifier redId,
397                           mlir::Type type, bool isByRef) {
398   mlir::Type ty = fir::unwrapRefType(type);
399   mlir::Value initValue = ReductionProcessor::getReductionInitValue(
400       loc, fir::unwrapSeqOrBoxedSeqType(ty), redId, builder);
401 
402   if (fir::isa_trivial(ty)) {
403     if (isByRef) {
404       mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
405       builder.createStoreWithConvert(loc, initValue, alloca);
406       return alloca;
407     }
408     // by val
409     return initValue;
410   }
411 
412   // all arrays are boxed
413   if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
414     assert(isByRef && "passing arrays by value is unsupported");
415     // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
416     mlir::Type innerTy = fir::extractSequenceType(boxTy);
417     if (!mlir::isa<fir::SequenceType>(innerTy))
418       TODO(loc, "Unsupported boxed type for reduction");
419     // Create the private copy from the initial fir.box:
420     hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
421 
422     // TODO: if the whole reduction is nested inside of a loop, this alloca
423     // could lead to a stack overflow (the memory is only freed at the end of
424     // the stack frame). The reduction declare operation needs a deallocation
425     // region to undo the init region.
426     hlfir::Entity temp = createStackTempFromMold(loc, builder, source);
427 
428     // Put the temporary inside of a box:
429     hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp);
430     builder.create<hlfir::AssignOp>(loc, initValue, box);
431     mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
432     builder.create<fir::StoreOp>(loc, box, boxAlloca);
433     return boxAlloca;
434   }
435 
436   TODO(loc, "createReductionInitRegion for unsupported type");
437 }
438 
439 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
440     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
441     const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
442     bool isByRef) {
443   mlir::OpBuilder::InsertionGuard guard(builder);
444   mlir::ModuleOp module = builder.getModule();
445 
446   assert(!reductionOpName.empty());
447 
448   auto decl =
449       module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
450   if (decl)
451     return decl;
452 
453   mlir::OpBuilder modBuilder(module.getBodyRegion());
454   mlir::Type valTy = fir::unwrapRefType(type);
455   if (!isByRef)
456     type = valTy;
457 
458   decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
459                                                           type);
460   builder.createBlock(&decl.getInitializerRegion(),
461                       decl.getInitializerRegion().end(), {type}, {loc});
462   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
463 
464   mlir::Value init =
465       createReductionInitRegion(builder, loc, redId, type, isByRef);
466   builder.create<mlir::omp::YieldOp>(loc, init);
467 
468   builder.createBlock(&decl.getReductionRegion(),
469                       decl.getReductionRegion().end(), {type, type},
470                       {loc, loc});
471 
472   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
473   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
474   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
475   genCombiner(builder, loc, redId, type, op1, op2, isByRef);
476 
477   return decl;
478 }
479 
480 // TODO: By-ref vs by-val reductions are currently toggled for the whole
481 //       operation (possibly effecting multiple reduction variables).
482 //       This could cause a problem with openmp target reductions because
483 //       by-ref trivial types may not be supported.
484 bool ReductionProcessor::doReductionByRef(
485     const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
486   if (reductionVars.empty())
487     return false;
488   if (forceByrefReduction)
489     return true;
490 
491   for (mlir::Value reductionVar : reductionVars) {
492     if (auto declare =
493             mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
494       reductionVar = declare.getMemref();
495 
496     if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
497       return true;
498   }
499   return false;
500 }
501 
502 void ReductionProcessor::addDeclareReduction(
503     mlir::Location currentLocation,
504     Fortran::lower::AbstractConverter &converter,
505     const omp::clause::Reduction &reduction,
506     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
507     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
508     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
509         *reductionSymbols) {
510   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
511   mlir::omp::DeclareReductionOp decl;
512   const auto &redOperatorList{
513       std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
514   assert(redOperatorList.size() == 1 && "Expecting single operator");
515   const auto &redOperator = redOperatorList.front();
516   const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
517 
518   if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
519     if (const auto *reductionIntrinsic =
520             std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
521       if (!ReductionProcessor::supportedIntrinsicProcReduction(
522               *reductionIntrinsic)) {
523         return;
524       }
525     } else {
526       return;
527     }
528   }
529 
530   // initial pass to collect all reduction vars so we can figure out if this
531   // should happen byref
532   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
533   for (const Object &object : objectList) {
534     const Fortran::semantics::Symbol *symbol = object.id();
535     if (reductionSymbols)
536       reductionSymbols->push_back(symbol);
537     mlir::Value symVal = converter.getSymbolAddress(*symbol);
538     mlir::Type eleType;
539     auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
540     if (refType)
541       eleType = refType.getEleTy();
542     else
543       eleType = symVal.getType();
544 
545     // all arrays must be boxed so that we have convenient access to all the
546     // information needed to iterate over the array
547     if (mlir::isa<fir::SequenceType>(eleType)) {
548       // For Host associated symbols, use `SymbolBox` instead
549       Fortran::lower::SymbolBox symBox =
550           converter.lookupOneLevelUpSymbol(*symbol);
551       hlfir::Entity entity{symBox.getAddr()};
552       entity = genVariableBox(currentLocation, builder, entity);
553       mlir::Value box = entity.getBase();
554 
555       // Always pass the box by reference so that the OpenMP dialect
556       // verifiers don't need to know anything about fir.box
557       auto alloca =
558           builder.create<fir::AllocaOp>(currentLocation, box.getType());
559       builder.create<fir::StoreOp>(currentLocation, box, alloca);
560 
561       symVal = alloca;
562     } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
563       // boxed arrays are passed as values not by reference. Unfortunately,
564       // we can't pass a box by value to omp.redution_declare, so turn it
565       // into a reference
566 
567       auto alloca =
568           builder.create<fir::AllocaOp>(currentLocation, symVal.getType());
569       builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
570       symVal = alloca;
571     } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
572       symVal = declOp.getBase();
573     }
574 
575     // this isn't the same as the by-val and by-ref passing later in the
576     // pipeline. Both styles assume that the variable is a reference at
577     // this point
578     assert(mlir::isa<fir::ReferenceType>(symVal.getType()) &&
579            "reduction input var is a reference");
580 
581     reductionVars.push_back(symVal);
582   }
583   const bool isByRef = doReductionByRef(reductionVars);
584 
585   if (const auto &redDefinedOp =
586           std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
587     const auto &intrinsicOp{
588         std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
589             redDefinedOp->u)};
590     ReductionIdentifier redId = getReductionType(intrinsicOp);
591     switch (redId) {
592     case ReductionIdentifier::ADD:
593     case ReductionIdentifier::MULTIPLY:
594     case ReductionIdentifier::AND:
595     case ReductionIdentifier::EQV:
596     case ReductionIdentifier::OR:
597     case ReductionIdentifier::NEQV:
598       break;
599     default:
600       TODO(currentLocation,
601            "Reduction of some intrinsic operators is not supported");
602       break;
603     }
604 
605     for (mlir::Value symVal : reductionVars) {
606       auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
607       const auto &kindMap = firOpBuilder.getKindMap();
608       if (redType.getEleTy().isa<fir::LogicalType>())
609         decl = createDeclareReduction(firOpBuilder,
610                                       getReductionName(intrinsicOp, kindMap,
611                                                        firOpBuilder.getI1Type(),
612                                                        isByRef),
613                                       redId, redType, currentLocation, isByRef);
614       else
615         decl = createDeclareReduction(
616             firOpBuilder,
617             getReductionName(intrinsicOp, kindMap, redType, isByRef), redId,
618             redType, currentLocation, isByRef);
619       reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
620           firOpBuilder.getContext(), decl.getSymName()));
621     }
622   } else if (const auto *reductionIntrinsic =
623                  std::get_if<omp::clause::ProcedureDesignator>(
624                      &redOperator.u)) {
625     if (ReductionProcessor::supportedIntrinsicProcReduction(
626             *reductionIntrinsic)) {
627       ReductionProcessor::ReductionIdentifier redId =
628           ReductionProcessor::getReductionType(*reductionIntrinsic);
629       for (const Object &object : objectList) {
630         const Fortran::semantics::Symbol *symbol = object.id();
631         mlir::Value symVal = converter.getSymbolAddress(*symbol);
632         if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
633           symVal = declOp.getBase();
634         auto redType = symVal.getType().cast<fir::ReferenceType>();
635         if (!redType.getEleTy().isIntOrIndexOrFloat())
636           TODO(currentLocation, "User Defined Reduction on non-trivial type");
637         decl = createDeclareReduction(
638             firOpBuilder,
639             getReductionName(getRealName(*reductionIntrinsic).ToString(),
640                              firOpBuilder.getKindMap(), redType, isByRef),
641             redId, redType, currentLocation, isByRef);
642         reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
643             firOpBuilder.getContext(), decl.getSymName()));
644       }
645     }
646   }
647 }
648 
649 const Fortran::semantics::SourceName
650 ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) {
651   return symbol->GetUltimate().name();
652 }
653 
654 const Fortran::semantics::SourceName
655 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
656   return getRealName(pd.v.id());
657 }
658 
659 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
660                                              mlir::Location loc) {
661   switch (redId) {
662   case ReductionIdentifier::ADD:
663   case ReductionIdentifier::OR:
664   case ReductionIdentifier::NEQV:
665     return 0;
666   case ReductionIdentifier::MULTIPLY:
667   case ReductionIdentifier::AND:
668   case ReductionIdentifier::EQV:
669     return 1;
670   default:
671     TODO(loc, "Reduction of some intrinsic operators is not supported");
672   }
673 }
674 
675 } // namespace omp
676 } // namespace lower
677 } // namespace Fortran
678