xref: /llvm-project/flang/lib/Lower/OpenMP/ReductionProcessor.cpp (revision 8557a57c4b1a228ce63f2409dd5cc4c70a25e6fc)
1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ReductionProcessor.h"
14 
15 #include "PrivateReductionUtils.h"
16 #include "flang/Lower/AbstractConverter.h"
17 #include "flang/Lower/ConvertType.h"
18 #include "flang/Lower/SymbolMap.h"
19 #include "flang/Optimizer/Builder/Complex.h"
20 #include "flang/Optimizer/Builder/HLFIRTools.h"
21 #include "flang/Optimizer/Builder/Todo.h"
22 #include "flang/Optimizer/Dialect/FIRType.h"
23 #include "flang/Optimizer/HLFIR/HLFIROps.h"
24 #include "flang/Optimizer/Support/FatalError.h"
25 #include "flang/Parser/tools.h"
26 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
27 #include "llvm/Support/CommandLine.h"
28 
29 static llvm::cl::opt<bool> forceByrefReduction(
30     "force-byref-reduction",
31     llvm::cl::desc("Pass all reduction arguments by reference"),
32     llvm::cl::Hidden);
33 
34 namespace Fortran {
35 namespace lower {
36 namespace omp {
37 
38 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
39     const omp::clause::ProcedureDesignator &pd) {
40   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
41                      getRealName(pd.v.sym()).ToString())
42                      .Case("max", ReductionIdentifier::MAX)
43                      .Case("min", ReductionIdentifier::MIN)
44                      .Case("iand", ReductionIdentifier::IAND)
45                      .Case("ior", ReductionIdentifier::IOR)
46                      .Case("ieor", ReductionIdentifier::IEOR)
47                      .Default(std::nullopt);
48   assert(redType && "Invalid Reduction");
49   return *redType;
50 }
51 
52 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
53     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
54   switch (intrinsicOp) {
55   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
56     return ReductionIdentifier::ADD;
57   case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
58     return ReductionIdentifier::SUBTRACT;
59   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
60     return ReductionIdentifier::MULTIPLY;
61   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
62     return ReductionIdentifier::AND;
63   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
64     return ReductionIdentifier::EQV;
65   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
66     return ReductionIdentifier::OR;
67   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
68     return ReductionIdentifier::NEQV;
69   default:
70     llvm_unreachable("unexpected intrinsic operator in reduction");
71   }
72 }
73 
74 bool ReductionProcessor::supportedIntrinsicProcReduction(
75     const omp::clause::ProcedureDesignator &pd) {
76   semantics::Symbol *sym = pd.v.sym();
77   if (!sym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC))
78     return false;
79   auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
80                      .Case("max", true)
81                      .Case("min", true)
82                      .Case("iand", true)
83                      .Case("ior", true)
84                      .Case("ieor", true)
85                      .Default(false);
86   return redType;
87 }
88 
89 std::string
90 ReductionProcessor::getReductionName(llvm::StringRef name,
91                                      const fir::KindMapping &kindMap,
92                                      mlir::Type ty, bool isByRef) {
93   ty = fir::unwrapRefType(ty);
94 
95   // extra string to distinguish reduction functions for variables passed by
96   // reference
97   llvm::StringRef byrefAddition{""};
98   if (isByRef)
99     byrefAddition = "_byref";
100 
101   return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
102 }
103 
104 std::string ReductionProcessor::getReductionName(
105     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
106     const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
107   std::string reductionName;
108 
109   switch (intrinsicOp) {
110   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
111     reductionName = "add_reduction";
112     break;
113   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
114     reductionName = "multiply_reduction";
115     break;
116   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
117     return "and_reduction";
118   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
119     return "eqv_reduction";
120   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
121     return "or_reduction";
122   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
123     return "neqv_reduction";
124   default:
125     reductionName = "other_reduction";
126     break;
127   }
128 
129   return getReductionName(reductionName, kindMap, ty, isByRef);
130 }
131 
132 mlir::Value
133 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
134                                           ReductionIdentifier redId,
135                                           fir::FirOpBuilder &builder) {
136   type = fir::unwrapRefType(type);
137   if (!fir::isa_integer(type) && !fir::isa_real(type) &&
138       !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
139     TODO(loc, "Reduction of some types is not supported");
140   switch (redId) {
141   case ReductionIdentifier::MAX: {
142     if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
143       const llvm::fltSemantics &sem = ty.getFloatSemantics();
144       return builder.createRealConstant(
145           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
146     }
147     unsigned bits = type.getIntOrFloatBitWidth();
148     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
149     return builder.createIntegerConstant(loc, type, minInt);
150   }
151   case ReductionIdentifier::MIN: {
152     if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
153       const llvm::fltSemantics &sem = ty.getFloatSemantics();
154       return builder.createRealConstant(
155           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
156     }
157     unsigned bits = type.getIntOrFloatBitWidth();
158     int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
159     return builder.createIntegerConstant(loc, type, maxInt);
160   }
161   case ReductionIdentifier::IOR: {
162     unsigned bits = type.getIntOrFloatBitWidth();
163     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
164     return builder.createIntegerConstant(loc, type, zeroInt);
165   }
166   case ReductionIdentifier::IEOR: {
167     unsigned bits = type.getIntOrFloatBitWidth();
168     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
169     return builder.createIntegerConstant(loc, type, zeroInt);
170   }
171   case ReductionIdentifier::IAND: {
172     unsigned bits = type.getIntOrFloatBitWidth();
173     int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
174     return builder.createIntegerConstant(loc, type, allOnInt);
175   }
176   case ReductionIdentifier::ADD:
177   case ReductionIdentifier::MULTIPLY:
178   case ReductionIdentifier::AND:
179   case ReductionIdentifier::OR:
180   case ReductionIdentifier::EQV:
181   case ReductionIdentifier::NEQV:
182     if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
183       mlir::Type realTy = cplxTy.getElementType();
184       mlir::Value initRe = builder.createRealConstant(
185           loc, realTy, getOperationIdentity(redId, loc));
186       mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
187 
188       return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
189                                                                initIm);
190     }
191     if (mlir::isa<mlir::FloatType>(type))
192       return builder.create<mlir::arith::ConstantOp>(
193           loc, type,
194           builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
195 
196     if (mlir::isa<fir::LogicalType>(type)) {
197       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
198           loc, builder.getI1Type(),
199           builder.getIntegerAttr(builder.getI1Type(),
200                                  getOperationIdentity(redId, loc)));
201       return builder.createConvert(loc, type, intConst);
202     }
203 
204     return builder.create<mlir::arith::ConstantOp>(
205         loc, type,
206         builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
207   case ReductionIdentifier::ID:
208   case ReductionIdentifier::USER_DEF_OP:
209   case ReductionIdentifier::SUBTRACT:
210     TODO(loc, "Reduction of some identifier types is not supported");
211   }
212   llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
213 }
214 
215 mlir::Value ReductionProcessor::createScalarCombiner(
216     fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
217     mlir::Type type, mlir::Value op1, mlir::Value op2) {
218   mlir::Value reductionOp;
219   type = fir::unwrapRefType(type);
220   switch (redId) {
221   case ReductionIdentifier::MAX:
222     reductionOp =
223         getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>(
224             builder, type, loc, op1, op2);
225     break;
226   case ReductionIdentifier::MIN:
227     reductionOp =
228         getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>(
229             builder, type, loc, op1, op2);
230     break;
231   case ReductionIdentifier::IOR:
232     assert((type.isIntOrIndex()) && "only integer is expected");
233     reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
234     break;
235   case ReductionIdentifier::IEOR:
236     assert((type.isIntOrIndex()) && "only integer is expected");
237     reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
238     break;
239   case ReductionIdentifier::IAND:
240     assert((type.isIntOrIndex()) && "only integer is expected");
241     reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
242     break;
243   case ReductionIdentifier::ADD:
244     reductionOp =
245         getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
246                               fir::AddcOp>(builder, type, loc, op1, op2);
247     break;
248   case ReductionIdentifier::MULTIPLY:
249     reductionOp =
250         getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
251                               fir::MulcOp>(builder, type, loc, op1, op2);
252     break;
253   case ReductionIdentifier::AND: {
254     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
255     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
256 
257     mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
258 
259     reductionOp = builder.createConvert(loc, type, andiOp);
260     break;
261   }
262   case ReductionIdentifier::OR: {
263     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
264     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
265 
266     mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
267 
268     reductionOp = builder.createConvert(loc, type, oriOp);
269     break;
270   }
271   case ReductionIdentifier::EQV: {
272     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
273     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
274 
275     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
276         loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
277 
278     reductionOp = builder.createConvert(loc, type, cmpiOp);
279     break;
280   }
281   case ReductionIdentifier::NEQV: {
282     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
283     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
284 
285     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
286         loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
287 
288     reductionOp = builder.createConvert(loc, type, cmpiOp);
289     break;
290   }
291   default:
292     TODO(loc, "Reduction of some intrinsic operators is not supported");
293   }
294 
295   return reductionOp;
296 }
297 
298 /// Create reduction combiner region for reduction variables which are boxed
299 /// arrays
300 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
301                            ReductionProcessor::ReductionIdentifier redId,
302                            fir::BaseBoxType boxTy, mlir::Value lhs,
303                            mlir::Value rhs) {
304   fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
305       fir::unwrapRefType(boxTy.getEleTy()));
306   fir::HeapType heapTy =
307       mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
308   fir::PointerType ptrTy =
309       mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy());
310   if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy)
311     TODO(loc, "Unsupported boxed type in OpenMP reduction");
312 
313   // load fir.ref<fir.box<...>>
314   mlir::Value lhsAddr = lhs;
315   lhs = builder.create<fir::LoadOp>(loc, lhs);
316   rhs = builder.create<fir::LoadOp>(loc, rhs);
317 
318   if ((heapTy || ptrTy) && !seqTy) {
319     // get box contents (heap pointers)
320     lhs = builder.create<fir::BoxAddrOp>(loc, lhs);
321     rhs = builder.create<fir::BoxAddrOp>(loc, rhs);
322     mlir::Value lhsValAddr = lhs;
323 
324     // load heap pointers
325     lhs = builder.create<fir::LoadOp>(loc, lhs);
326     rhs = builder.create<fir::LoadOp>(loc, rhs);
327 
328     mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy();
329 
330     mlir::Value result = ReductionProcessor::createScalarCombiner(
331         builder, loc, redId, eleTy, lhs, rhs);
332     builder.create<fir::StoreOp>(loc, result, lhsValAddr);
333     builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
334     return;
335   }
336 
337   fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, lhs);
338 
339   // Iterate over array elements, applying the equivalent scalar reduction:
340 
341   // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
342   // and so no null check is needed here before indexing into the (possibly
343   // allocatable) arrays.
344 
345   // A hlfir::elemental here gets inlined with a temporary so create the
346   // loop nest directly.
347   // This function already controls all of the code in this region so we
348   // know this won't miss any opportuinties for clever elemental inlining
349   hlfir::LoopNest nest = hlfir::genLoopNest(
350       loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
351   builder.setInsertionPointToStart(nest.body);
352   mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
353   auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
354       loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
355       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
356   auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
357       loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
358       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
359   auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
360   auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
361   mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
362       builder, loc, redId, refTy, lhsEle, rhsEle);
363   builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
364 
365   builder.setInsertionPointAfter(nest.outerOp);
366   builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
367 }
368 
369 // generate combiner region for reduction operations
370 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
371                         ReductionProcessor::ReductionIdentifier redId,
372                         mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
373                         bool isByRef) {
374   ty = fir::unwrapRefType(ty);
375 
376   if (fir::isa_trivial(ty)) {
377     mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
378     mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
379 
380     mlir::Value result = ReductionProcessor::createScalarCombiner(
381         builder, loc, redId, ty, lhsLoaded, rhsLoaded);
382     if (isByRef) {
383       builder.create<fir::StoreOp>(loc, result, lhs);
384       builder.create<mlir::omp::YieldOp>(loc, lhs);
385     } else {
386       builder.create<mlir::omp::YieldOp>(loc, result);
387     }
388     return;
389   }
390   // all arrays should have been boxed
391   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
392     genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
393     return;
394   }
395 
396   TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
397 }
398 
399 // like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes
400 static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
401   if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
402     return seqTy.getEleTy();
403   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
404     auto eleTy = fir::unwrapRefType(boxTy.getEleTy());
405     if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
406       return seqTy.getEleTy();
407     return eleTy;
408   }
409   return ty;
410 }
411 
412 static void createReductionAllocAndInitRegions(
413     fir::FirOpBuilder &builder, mlir::Location loc,
414     mlir::omp::DeclareReductionOp &reductionDecl,
415     const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
416     bool isByRef) {
417   auto yield = [&](mlir::Value ret) {
418     builder.create<mlir::omp::YieldOp>(loc, ret);
419   };
420 
421   mlir::Block *allocBlock = nullptr;
422   mlir::Block *initBlock = nullptr;
423   if (isByRef) {
424     allocBlock =
425         builder.createBlock(&reductionDecl.getAllocRegion(),
426                             reductionDecl.getAllocRegion().end(), {}, {});
427     initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
428                                     reductionDecl.getInitializerRegion().end(),
429                                     {type, type}, {loc, loc});
430   } else {
431     initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
432                                     reductionDecl.getInitializerRegion().end(),
433                                     {type}, {loc});
434   }
435 
436   mlir::Type ty = fir::unwrapRefType(type);
437   builder.setInsertionPointToEnd(initBlock);
438   mlir::Value initValue = ReductionProcessor::getReductionInitValue(
439       loc, unwrapSeqOrBoxedType(ty), redId, builder);
440 
441   if (isByRef) {
442     populateByRefInitAndCleanupRegions(builder, loc, type, initValue, initBlock,
443                                        reductionDecl.getInitializerAllocArg(),
444                                        reductionDecl.getInitializerMoldArg(),
445                                        reductionDecl.getCleanupRegion());
446   }
447 
448   if (fir::isa_trivial(ty)) {
449     if (isByRef) {
450       // alloc region
451       builder.setInsertionPointToEnd(allocBlock);
452       mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
453       yield(alloca);
454       return;
455     }
456     // by val
457     yield(initValue);
458     return;
459   }
460   assert(isByRef && "passing non-trivial types by val is unsupported");
461 
462   // alloc region
463   builder.setInsertionPointToEnd(allocBlock);
464   mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
465   yield(boxAlloca);
466 }
467 
468 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
469     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
470     const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
471     bool isByRef) {
472   mlir::OpBuilder::InsertionGuard guard(builder);
473   mlir::ModuleOp module = builder.getModule();
474 
475   assert(!reductionOpName.empty());
476 
477   auto decl =
478       module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
479   if (decl)
480     return decl;
481 
482   mlir::OpBuilder modBuilder(module.getBodyRegion());
483   mlir::Type valTy = fir::unwrapRefType(type);
484   if (!isByRef)
485     type = valTy;
486 
487   decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
488                                                           type);
489   createReductionAllocAndInitRegions(builder, loc, decl, redId, type, isByRef);
490 
491   builder.createBlock(&decl.getReductionRegion(),
492                       decl.getReductionRegion().end(), {type, type},
493                       {loc, loc});
494 
495   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
496   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
497   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
498   genCombiner(builder, loc, redId, type, op1, op2, isByRef);
499 
500   return decl;
501 }
502 
503 static bool doReductionByRef(mlir::Value reductionVar) {
504   if (forceByrefReduction)
505     return true;
506 
507   if (auto declare =
508           mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
509     reductionVar = declare.getMemref();
510 
511   if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
512     return true;
513 
514   return false;
515 }
516 
517 void ReductionProcessor::addDeclareReduction(
518     mlir::Location currentLocation, lower::AbstractConverter &converter,
519     const omp::clause::Reduction &reduction,
520     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
521     llvm::SmallVectorImpl<bool> &reduceVarByRef,
522     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
523     llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
524   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
525 
526   if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
527           reduction.t))
528     TODO(currentLocation, "Reduction modifiers are not supported");
529 
530   mlir::omp::DeclareReductionOp decl;
531   const auto &redOperatorList{
532       std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
533   assert(redOperatorList.size() == 1 && "Expecting single operator");
534   const auto &redOperator = redOperatorList.front();
535   const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
536 
537   if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
538     if (const auto *reductionIntrinsic =
539             std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
540       if (!ReductionProcessor::supportedIntrinsicProcReduction(
541               *reductionIntrinsic)) {
542         return;
543       }
544     } else {
545       return;
546     }
547   }
548 
549   // Reduction variable processing common to both intrinsic operators and
550   // procedure designators
551   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
552   for (const Object &object : objectList) {
553     const semantics::Symbol *symbol = object.sym();
554     reductionSymbols.push_back(symbol);
555     mlir::Value symVal = converter.getSymbolAddress(*symbol);
556     mlir::Type eleType;
557     auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
558     if (refType)
559       eleType = refType.getEleTy();
560     else
561       eleType = symVal.getType();
562 
563     // all arrays must be boxed so that we have convenient access to all the
564     // information needed to iterate over the array
565     if (mlir::isa<fir::SequenceType>(eleType)) {
566       // For Host associated symbols, use `SymbolBox` instead
567       lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol);
568       hlfir::Entity entity{symBox.getAddr()};
569       entity = genVariableBox(currentLocation, builder, entity);
570       mlir::Value box = entity.getBase();
571 
572       // Always pass the box by reference so that the OpenMP dialect
573       // verifiers don't need to know anything about fir.box
574       auto alloca =
575           builder.create<fir::AllocaOp>(currentLocation, box.getType());
576       builder.create<fir::StoreOp>(currentLocation, box, alloca);
577 
578       symVal = alloca;
579     } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
580       // boxed arrays are passed as values not by reference. Unfortunately,
581       // we can't pass a box by value to omp.redution_declare, so turn it
582       // into a reference
583 
584       auto alloca =
585           builder.create<fir::AllocaOp>(currentLocation, symVal.getType());
586       builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
587       symVal = alloca;
588     } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
589       symVal = declOp.getBase();
590     }
591 
592     // this isn't the same as the by-val and by-ref passing later in the
593     // pipeline. Both styles assume that the variable is a reference at
594     // this point
595     assert(mlir::isa<fir::ReferenceType>(symVal.getType()) &&
596            "reduction input var is a reference");
597 
598     reductionVars.push_back(symVal);
599     reduceVarByRef.push_back(doReductionByRef(symVal));
600   }
601 
602   for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
603     auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
604     const auto &kindMap = firOpBuilder.getKindMap();
605     std::string reductionName;
606     ReductionIdentifier redId;
607     mlir::Type redNameTy = redType;
608     if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
609       redNameTy = builder.getI1Type();
610 
611     if (const auto &redDefinedOp =
612             std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
613       const auto &intrinsicOp{
614           std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
615               redDefinedOp->u)};
616       redId = getReductionType(intrinsicOp);
617       switch (redId) {
618       case ReductionIdentifier::ADD:
619       case ReductionIdentifier::MULTIPLY:
620       case ReductionIdentifier::AND:
621       case ReductionIdentifier::EQV:
622       case ReductionIdentifier::OR:
623       case ReductionIdentifier::NEQV:
624         break;
625       default:
626         TODO(currentLocation,
627              "Reduction of some intrinsic operators is not supported");
628         break;
629       }
630 
631       reductionName =
632           getReductionName(intrinsicOp, kindMap, redNameTy, isByRef);
633     } else if (const auto *reductionIntrinsic =
634                    std::get_if<omp::clause::ProcedureDesignator>(
635                        &redOperator.u)) {
636       if (!ReductionProcessor::supportedIntrinsicProcReduction(
637               *reductionIntrinsic)) {
638         TODO(currentLocation, "Unsupported intrinsic proc reduction");
639       }
640       redId = getReductionType(*reductionIntrinsic);
641       reductionName =
642           getReductionName(getRealName(*reductionIntrinsic).ToString(), kindMap,
643                            redNameTy, isByRef);
644     } else {
645       TODO(currentLocation, "Unexpected reduction type");
646     }
647 
648     decl = createDeclareReduction(firOpBuilder, reductionName, redId, redType,
649                                   currentLocation, isByRef);
650     reductionDeclSymbols.push_back(
651         mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
652   }
653 }
654 
655 const semantics::SourceName
656 ReductionProcessor::getRealName(const semantics::Symbol *symbol) {
657   return symbol->GetUltimate().name();
658 }
659 
660 const semantics::SourceName
661 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
662   return getRealName(pd.v.sym());
663 }
664 
665 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
666                                              mlir::Location loc) {
667   switch (redId) {
668   case ReductionIdentifier::ADD:
669   case ReductionIdentifier::OR:
670   case ReductionIdentifier::NEQV:
671     return 0;
672   case ReductionIdentifier::MULTIPLY:
673   case ReductionIdentifier::AND:
674   case ReductionIdentifier::EQV:
675     return 1;
676   default:
677     TODO(loc, "Reduction of some intrinsic operators is not supported");
678   }
679 }
680 
681 } // namespace omp
682 } // namespace lower
683 } // namespace Fortran
684