xref: /llvm-project/flang/lib/Lower/OpenMP/ReductionProcessor.cpp (revision d84252e064b3f35aa879c10e207f77e931f351d9)
1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ReductionProcessor.h"
14 
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Optimizer/Builder/HLFIRTools.h"
17 #include "flang/Optimizer/Builder/Todo.h"
18 #include "flang/Optimizer/Dialect/FIRType.h"
19 #include "flang/Optimizer/HLFIR/HLFIROps.h"
20 #include "flang/Parser/tools.h"
21 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
22 #include "llvm/Support/CommandLine.h"
23 
24 static llvm::cl::opt<bool> forceByrefReduction(
25     "force-byref-reduction",
26     llvm::cl::desc("Pass all reduction arguments by reference"),
27     llvm::cl::Hidden);
28 
29 namespace Fortran {
30 namespace lower {
31 namespace omp {
32 
33 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
34     const omp::clause::ProcedureDesignator &pd) {
35   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
36                      getRealName(pd.v.id()).ToString())
37                      .Case("max", ReductionIdentifier::MAX)
38                      .Case("min", ReductionIdentifier::MIN)
39                      .Case("iand", ReductionIdentifier::IAND)
40                      .Case("ior", ReductionIdentifier::IOR)
41                      .Case("ieor", ReductionIdentifier::IEOR)
42                      .Default(std::nullopt);
43   assert(redType && "Invalid Reduction");
44   return *redType;
45 }
46 
47 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
48     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
49   switch (intrinsicOp) {
50   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
51     return ReductionIdentifier::ADD;
52   case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
53     return ReductionIdentifier::SUBTRACT;
54   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
55     return ReductionIdentifier::MULTIPLY;
56   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
57     return ReductionIdentifier::AND;
58   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
59     return ReductionIdentifier::EQV;
60   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
61     return ReductionIdentifier::OR;
62   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
63     return ReductionIdentifier::NEQV;
64   default:
65     llvm_unreachable("unexpected intrinsic operator in reduction");
66   }
67 }
68 
69 bool ReductionProcessor::supportedIntrinsicProcReduction(
70     const omp::clause::ProcedureDesignator &pd) {
71   Fortran::semantics::Symbol *sym = pd.v.id();
72   if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
73     return false;
74   auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
75                      .Case("max", true)
76                      .Case("min", true)
77                      .Case("iand", true)
78                      .Case("ior", true)
79                      .Case("ieor", true)
80                      .Default(false);
81   return redType;
82 }
83 
84 std::string ReductionProcessor::getReductionName(llvm::StringRef name,
85                                                  mlir::Type ty, bool isByRef) {
86   ty = fir::unwrapRefType(ty);
87 
88   // extra string to distinguish reduction functions for variables passed by
89   // reference
90   llvm::StringRef byrefAddition{""};
91   if (isByRef)
92     byrefAddition = "_byref";
93 
94   if (fir::isa_trivial(ty))
95     return (llvm::Twine(name) +
96             (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
97             llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
98         .str();
99 
100   // creates a name like reduction_i_64_box_ux4x3
101   if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
102     // TODO: support for allocatable boxes:
103     // !fir.box<!fir.heap<!fir.array<...>>>
104     fir::SequenceType seqTy = fir::unwrapRefType(boxTy.getEleTy())
105                                   .dyn_cast_or_null<fir::SequenceType>();
106     if (!seqTy)
107       return {};
108 
109     std::string prefix = getReductionName(
110         name, fir::unwrapSeqOrBoxedSeqType(ty), /*isByRef=*/false);
111     if (prefix.empty())
112       return {};
113     std::stringstream tyStr;
114     tyStr << prefix << "_box_";
115     bool first = true;
116     for (std::int64_t extent : seqTy.getShape()) {
117       if (first)
118         first = false;
119       else
120         tyStr << "x";
121       if (extent == seqTy.getUnknownExtent())
122         tyStr << 'u'; // I'm not sure that '?' is safe in symbol names
123       else
124         tyStr << extent;
125     }
126     return (tyStr.str() + byrefAddition).str();
127   }
128 
129   return {};
130 }
131 
132 std::string ReductionProcessor::getReductionName(
133     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, mlir::Type ty,
134     bool isByRef) {
135   std::string reductionName;
136 
137   switch (intrinsicOp) {
138   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
139     reductionName = "add_reduction";
140     break;
141   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
142     reductionName = "multiply_reduction";
143     break;
144   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
145     return "and_reduction";
146   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
147     return "eqv_reduction";
148   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
149     return "or_reduction";
150   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
151     return "neqv_reduction";
152   default:
153     reductionName = "other_reduction";
154     break;
155   }
156 
157   return getReductionName(reductionName, ty, isByRef);
158 }
159 
160 mlir::Value
161 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
162                                           ReductionIdentifier redId,
163                                           fir::FirOpBuilder &builder) {
164   type = fir::unwrapRefType(type);
165   assert((fir::isa_integer(type) || fir::isa_real(type) ||
166           type.isa<fir::LogicalType>()) &&
167          "only integer, logical and real types are currently supported");
168   switch (redId) {
169   case ReductionIdentifier::MAX: {
170     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
171       const llvm::fltSemantics &sem = ty.getFloatSemantics();
172       return builder.createRealConstant(
173           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
174     }
175     unsigned bits = type.getIntOrFloatBitWidth();
176     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
177     return builder.createIntegerConstant(loc, type, minInt);
178   }
179   case ReductionIdentifier::MIN: {
180     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
181       const llvm::fltSemantics &sem = ty.getFloatSemantics();
182       return builder.createRealConstant(
183           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
184     }
185     unsigned bits = type.getIntOrFloatBitWidth();
186     int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
187     return builder.createIntegerConstant(loc, type, maxInt);
188   }
189   case ReductionIdentifier::IOR: {
190     unsigned bits = type.getIntOrFloatBitWidth();
191     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
192     return builder.createIntegerConstant(loc, type, zeroInt);
193   }
194   case ReductionIdentifier::IEOR: {
195     unsigned bits = type.getIntOrFloatBitWidth();
196     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
197     return builder.createIntegerConstant(loc, type, zeroInt);
198   }
199   case ReductionIdentifier::IAND: {
200     unsigned bits = type.getIntOrFloatBitWidth();
201     int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
202     return builder.createIntegerConstant(loc, type, allOnInt);
203   }
204   case ReductionIdentifier::ADD:
205   case ReductionIdentifier::MULTIPLY:
206   case ReductionIdentifier::AND:
207   case ReductionIdentifier::OR:
208   case ReductionIdentifier::EQV:
209   case ReductionIdentifier::NEQV:
210     if (type.isa<mlir::FloatType>())
211       return builder.create<mlir::arith::ConstantOp>(
212           loc, type,
213           builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
214 
215     if (type.isa<fir::LogicalType>()) {
216       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
217           loc, builder.getI1Type(),
218           builder.getIntegerAttr(builder.getI1Type(),
219                                  getOperationIdentity(redId, loc)));
220       return builder.createConvert(loc, type, intConst);
221     }
222 
223     return builder.create<mlir::arith::ConstantOp>(
224         loc, type,
225         builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
226   case ReductionIdentifier::ID:
227   case ReductionIdentifier::USER_DEF_OP:
228   case ReductionIdentifier::SUBTRACT:
229     TODO(loc, "Reduction of some identifier types is not supported");
230   }
231   llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
232 }
233 
234 mlir::Value ReductionProcessor::createScalarCombiner(
235     fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
236     mlir::Type type, mlir::Value op1, mlir::Value op2) {
237   mlir::Value reductionOp;
238   type = fir::unwrapRefType(type);
239   switch (redId) {
240   case ReductionIdentifier::MAX:
241     reductionOp =
242         getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
243             builder, type, loc, op1, op2);
244     break;
245   case ReductionIdentifier::MIN:
246     reductionOp =
247         getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
248             builder, type, loc, op1, op2);
249     break;
250   case ReductionIdentifier::IOR:
251     assert((type.isIntOrIndex()) && "only integer is expected");
252     reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
253     break;
254   case ReductionIdentifier::IEOR:
255     assert((type.isIntOrIndex()) && "only integer is expected");
256     reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
257     break;
258   case ReductionIdentifier::IAND:
259     assert((type.isIntOrIndex()) && "only integer is expected");
260     reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
261     break;
262   case ReductionIdentifier::ADD:
263     reductionOp =
264         getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
265             builder, type, loc, op1, op2);
266     break;
267   case ReductionIdentifier::MULTIPLY:
268     reductionOp =
269         getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
270             builder, type, loc, op1, op2);
271     break;
272   case ReductionIdentifier::AND: {
273     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
274     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
275 
276     mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
277 
278     reductionOp = builder.createConvert(loc, type, andiOp);
279     break;
280   }
281   case ReductionIdentifier::OR: {
282     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
283     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
284 
285     mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
286 
287     reductionOp = builder.createConvert(loc, type, oriOp);
288     break;
289   }
290   case ReductionIdentifier::EQV: {
291     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
292     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
293 
294     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
295         loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
296 
297     reductionOp = builder.createConvert(loc, type, cmpiOp);
298     break;
299   }
300   case ReductionIdentifier::NEQV: {
301     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
302     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
303 
304     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
305         loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
306 
307     reductionOp = builder.createConvert(loc, type, cmpiOp);
308     break;
309   }
310   default:
311     TODO(loc, "Reduction of some intrinsic operators is not supported");
312   }
313 
314   return reductionOp;
315 }
316 
317 /// Create reduction combiner region for reduction variables which are boxed
318 /// arrays
319 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
320                            ReductionProcessor::ReductionIdentifier redId,
321                            fir::BaseBoxType boxTy, mlir::Value lhs,
322                            mlir::Value rhs) {
323   fir::SequenceType seqTy =
324       mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy());
325   // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
326   if (!seqTy || seqTy.hasUnknownShape())
327     TODO(loc, "Unsupported boxed type in OpenMP reduction");
328 
329   // load fir.ref<fir.box<...>>
330   mlir::Value lhsAddr = lhs;
331   lhs = builder.create<fir::LoadOp>(loc, lhs);
332   rhs = builder.create<fir::LoadOp>(loc, rhs);
333 
334   const unsigned rank = seqTy.getDimension();
335   llvm::SmallVector<mlir::Value> extents;
336   extents.reserve(rank);
337   llvm::SmallVector<mlir::Value> lbAndExtents;
338   lbAndExtents.reserve(rank * 2);
339 
340   // Get box lowerbounds and extents:
341   mlir::Type idxTy = builder.getIndexType();
342   for (unsigned i = 0; i < rank; ++i) {
343     // TODO: ideally we want to hoist box reads out of the critical section.
344     // We could do this by having box dimensions in block arguments like
345     // OpenACC does
346     mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
347     auto dimInfo =
348         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim);
349     extents.push_back(dimInfo.getExtent());
350     lbAndExtents.push_back(dimInfo.getLowerBound());
351     lbAndExtents.push_back(dimInfo.getExtent());
352   }
353 
354   auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank);
355   auto shapeShift =
356       builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
357 
358   // Iterate over array elements, applying the equivalent scalar reduction:
359 
360   // A hlfir::elemental here gets inlined with a temporary so create the
361   // loop nest directly.
362   // This function already controls all of the code in this region so we
363   // know this won't miss any opportuinties for clever elemental inlining
364   hlfir::LoopNest nest =
365       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
366   builder.setInsertionPointToStart(nest.innerLoop.getBody());
367   mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
368   auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
369       loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
370       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
371   auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
372       loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
373       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
374   auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
375   auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
376   mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
377       builder, loc, redId, refTy, lhsEle, rhsEle);
378   builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
379 
380   builder.setInsertionPointAfter(nest.outerLoop);
381   builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
382 }
383 
384 // generate combiner region for reduction operations
385 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
386                         ReductionProcessor::ReductionIdentifier redId,
387                         mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
388                         bool isByRef) {
389   ty = fir::unwrapRefType(ty);
390 
391   if (fir::isa_trivial(ty)) {
392     mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
393     mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
394 
395     mlir::Value result = ReductionProcessor::createScalarCombiner(
396         builder, loc, redId, ty, lhsLoaded, rhsLoaded);
397     if (isByRef) {
398       builder.create<fir::StoreOp>(loc, result, lhs);
399       builder.create<mlir::omp::YieldOp>(loc, lhs);
400     } else {
401       builder.create<mlir::omp::YieldOp>(loc, result);
402     }
403     return;
404   }
405   // all arrays should have been boxed
406   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
407     genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
408     return;
409   }
410 
411   TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
412 }
413 
414 static mlir::Value
415 createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
416                           const ReductionProcessor::ReductionIdentifier redId,
417                           mlir::Type type, bool isByRef) {
418   mlir::Type ty = fir::unwrapRefType(type);
419   mlir::Value initValue = ReductionProcessor::getReductionInitValue(
420       loc, fir::unwrapSeqOrBoxedSeqType(ty), redId, builder);
421 
422   if (fir::isa_trivial(ty)) {
423     if (isByRef) {
424       mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
425       builder.createStoreWithConvert(loc, initValue, alloca);
426       return alloca;
427     }
428     // by val
429     return initValue;
430   }
431 
432   // all arrays are boxed
433   if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
434     assert(isByRef && "passing arrays by value is unsupported");
435     // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
436     mlir::Type innerTy = fir::extractSequenceType(boxTy);
437     if (!mlir::isa<fir::SequenceType>(innerTy))
438       TODO(loc, "Unsupported boxed type for reduction");
439     // Create the private copy from the initial fir.box:
440     hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
441 
442     // TODO: if the whole reduction is nested inside of a loop, this alloca
443     // could lead to a stack overflow (the memory is only freed at the end of
444     // the stack frame). The reduction declare operation needs a deallocation
445     // region to undo the init region.
446     hlfir::Entity temp = createStackTempFromMold(loc, builder, source);
447 
448     // Put the temporary inside of a box:
449     hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp);
450     builder.create<hlfir::AssignOp>(loc, initValue, box);
451     mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
452     builder.create<fir::StoreOp>(loc, box, boxAlloca);
453     return boxAlloca;
454   }
455 
456   TODO(loc, "createReductionInitRegion for unsupported type");
457 }
458 
459 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
460     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
461     const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
462     bool isByRef) {
463   mlir::OpBuilder::InsertionGuard guard(builder);
464   mlir::ModuleOp module = builder.getModule();
465 
466   if (reductionOpName.empty())
467     TODO(loc, "Reduction of some types is not supported");
468 
469   auto decl =
470       module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
471   if (decl)
472     return decl;
473 
474   mlir::OpBuilder modBuilder(module.getBodyRegion());
475   mlir::Type valTy = fir::unwrapRefType(type);
476   if (!isByRef)
477     type = valTy;
478 
479   decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
480                                                           type);
481   builder.createBlock(&decl.getInitializerRegion(),
482                       decl.getInitializerRegion().end(), {type}, {loc});
483   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
484 
485   mlir::Value init =
486       createReductionInitRegion(builder, loc, redId, type, isByRef);
487   builder.create<mlir::omp::YieldOp>(loc, init);
488 
489   builder.createBlock(&decl.getReductionRegion(),
490                       decl.getReductionRegion().end(), {type, type},
491                       {loc, loc});
492 
493   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
494   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
495   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
496   genCombiner(builder, loc, redId, type, op1, op2, isByRef);
497 
498   return decl;
499 }
500 
501 // TODO: By-ref vs by-val reductions are currently toggled for the whole
502 //       operation (possibly effecting multiple reduction variables).
503 //       This could cause a problem with openmp target reductions because
504 //       by-ref trivial types may not be supported.
505 bool ReductionProcessor::doReductionByRef(
506     const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
507   if (reductionVars.empty())
508     return false;
509   if (forceByrefReduction)
510     return true;
511 
512   for (mlir::Value reductionVar : reductionVars) {
513     if (auto declare =
514             mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
515       reductionVar = declare.getMemref();
516 
517     if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
518       return true;
519   }
520   return false;
521 }
522 
523 void ReductionProcessor::addDeclareReduction(
524     mlir::Location currentLocation,
525     Fortran::lower::AbstractConverter &converter,
526     const omp::clause::Reduction &reduction,
527     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
528     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
529     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
530         *reductionSymbols) {
531   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
532   mlir::omp::DeclareReductionOp decl;
533   const auto &redOperator{
534       std::get<omp::clause::ReductionOperator>(reduction.t)};
535   const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
536 
537   if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
538     if (const auto *reductionIntrinsic =
539             std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
540       if (!ReductionProcessor::supportedIntrinsicProcReduction(
541               *reductionIntrinsic)) {
542         return;
543       }
544     } else {
545       return;
546     }
547   }
548 
549   // initial pass to collect all reduction vars so we can figure out if this
550   // should happen byref
551   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
552   for (const Object &object : objectList) {
553     const Fortran::semantics::Symbol *symbol = object.id();
554     if (reductionSymbols)
555       reductionSymbols->push_back(symbol);
556     mlir::Value symVal = converter.getSymbolAddress(*symbol);
557     auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
558 
559     // all arrays must be boxed so that we have convenient access to all the
560     // information needed to iterate over the array
561     if (mlir::isa<fir::SequenceType>(redType.getEleTy())) {
562       hlfir::Entity entity{symVal};
563       entity = genVariableBox(currentLocation, builder, entity);
564       mlir::Value box = entity.getBase();
565 
566       // Always pass the box by reference so that the OpenMP dialect
567       // verifiers don't need to know anything about fir.box
568       auto alloca =
569           builder.create<fir::AllocaOp>(currentLocation, box.getType());
570       builder.create<fir::StoreOp>(currentLocation, box, alloca);
571 
572       symVal = alloca;
573       redType = mlir::cast<fir::ReferenceType>(symVal.getType());
574     } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
575       symVal = declOp.getBase();
576     }
577 
578     reductionVars.push_back(symVal);
579   }
580   const bool isByRef = doReductionByRef(reductionVars);
581 
582   if (const auto &redDefinedOp =
583           std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
584     const auto &intrinsicOp{
585         std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
586             redDefinedOp->u)};
587     ReductionIdentifier redId = getReductionType(intrinsicOp);
588     switch (redId) {
589     case ReductionIdentifier::ADD:
590     case ReductionIdentifier::MULTIPLY:
591     case ReductionIdentifier::AND:
592     case ReductionIdentifier::EQV:
593     case ReductionIdentifier::OR:
594     case ReductionIdentifier::NEQV:
595       break;
596     default:
597       TODO(currentLocation,
598            "Reduction of some intrinsic operators is not supported");
599       break;
600     }
601 
602     for (mlir::Value symVal : reductionVars) {
603       auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
604       if (redType.getEleTy().isa<fir::LogicalType>())
605         decl = createDeclareReduction(
606             firOpBuilder,
607             getReductionName(intrinsicOp, firOpBuilder.getI1Type(), isByRef),
608             redId, redType, currentLocation, isByRef);
609       else
610         decl = createDeclareReduction(
611             firOpBuilder, getReductionName(intrinsicOp, redType, isByRef),
612             redId, redType, currentLocation, isByRef);
613       reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
614           firOpBuilder.getContext(), decl.getSymName()));
615     }
616   } else if (const auto *reductionIntrinsic =
617                  std::get_if<omp::clause::ProcedureDesignator>(
618                      &redOperator.u)) {
619     if (ReductionProcessor::supportedIntrinsicProcReduction(
620             *reductionIntrinsic)) {
621       ReductionProcessor::ReductionIdentifier redId =
622           ReductionProcessor::getReductionType(*reductionIntrinsic);
623       for (const Object &object : objectList) {
624         const Fortran::semantics::Symbol *symbol = object.id();
625         mlir::Value symVal = converter.getSymbolAddress(*symbol);
626         if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
627           symVal = declOp.getBase();
628         auto redType = symVal.getType().cast<fir::ReferenceType>();
629         if (!redType.getEleTy().isIntOrIndexOrFloat())
630           TODO(currentLocation, "User Defined Reduction on non-trivial type");
631         decl = createDeclareReduction(
632             firOpBuilder,
633             getReductionName(getRealName(*reductionIntrinsic).ToString(),
634                              redType, isByRef),
635             redId, redType, currentLocation, isByRef);
636         reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
637             firOpBuilder.getContext(), decl.getSymName()));
638       }
639     }
640   }
641 }
642 
643 const Fortran::semantics::SourceName
644 ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) {
645   return symbol->GetUltimate().name();
646 }
647 
648 const Fortran::semantics::SourceName
649 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
650   return getRealName(pd.v.id());
651 }
652 
653 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
654                                              mlir::Location loc) {
655   switch (redId) {
656   case ReductionIdentifier::ADD:
657   case ReductionIdentifier::OR:
658   case ReductionIdentifier::NEQV:
659     return 0;
660   case ReductionIdentifier::MULTIPLY:
661   case ReductionIdentifier::AND:
662   case ReductionIdentifier::EQV:
663     return 1;
664   default:
665     TODO(loc, "Reduction of some intrinsic operators is not supported");
666   }
667 }
668 
669 } // namespace omp
670 } // namespace lower
671 } // namespace Fortran
672