xref: /llvm-project/flang/lib/Lower/OpenMP/ReductionProcessor.cpp (revision 698bf3dafcc0dfa15540ae7f1f9b72208a578bd2)
1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ReductionProcessor.h"
14 
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Lower/SymbolMap.h"
17 #include "flang/Optimizer/Builder/HLFIRTools.h"
18 #include "flang/Optimizer/Builder/Todo.h"
19 #include "flang/Optimizer/Dialect/FIRType.h"
20 #include "flang/Optimizer/HLFIR/HLFIROps.h"
21 #include "flang/Parser/tools.h"
22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23 #include "llvm/Support/CommandLine.h"
24 
25 static llvm::cl::opt<bool> forceByrefReduction(
26     "force-byref-reduction",
27     llvm::cl::desc("Pass all reduction arguments by reference"),
28     llvm::cl::Hidden);
29 
30 namespace Fortran {
31 namespace lower {
32 namespace omp {
33 
34 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
35     const omp::clause::ProcedureDesignator &pd) {
36   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
37                      getRealName(pd.v.id()).ToString())
38                      .Case("max", ReductionIdentifier::MAX)
39                      .Case("min", ReductionIdentifier::MIN)
40                      .Case("iand", ReductionIdentifier::IAND)
41                      .Case("ior", ReductionIdentifier::IOR)
42                      .Case("ieor", ReductionIdentifier::IEOR)
43                      .Default(std::nullopt);
44   assert(redType && "Invalid Reduction");
45   return *redType;
46 }
47 
48 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
49     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
50   switch (intrinsicOp) {
51   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
52     return ReductionIdentifier::ADD;
53   case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
54     return ReductionIdentifier::SUBTRACT;
55   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
56     return ReductionIdentifier::MULTIPLY;
57   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
58     return ReductionIdentifier::AND;
59   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
60     return ReductionIdentifier::EQV;
61   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
62     return ReductionIdentifier::OR;
63   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
64     return ReductionIdentifier::NEQV;
65   default:
66     llvm_unreachable("unexpected intrinsic operator in reduction");
67   }
68 }
69 
70 bool ReductionProcessor::supportedIntrinsicProcReduction(
71     const omp::clause::ProcedureDesignator &pd) {
72   Fortran::semantics::Symbol *sym = pd.v.id();
73   if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
74     return false;
75   auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
76                      .Case("max", true)
77                      .Case("min", true)
78                      .Case("iand", true)
79                      .Case("ior", true)
80                      .Case("ieor", true)
81                      .Default(false);
82   return redType;
83 }
84 
85 std::string
86 ReductionProcessor::getReductionName(llvm::StringRef name,
87                                      const fir::KindMapping &kindMap,
88                                      mlir::Type ty, bool isByRef) {
89   ty = fir::unwrapRefType(ty);
90 
91   // extra string to distinguish reduction functions for variables passed by
92   // reference
93   llvm::StringRef byrefAddition{""};
94   if (isByRef)
95     byrefAddition = "_byref";
96 
97   return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
98 }
99 
100 std::string ReductionProcessor::getReductionName(
101     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
102     const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
103   std::string reductionName;
104 
105   switch (intrinsicOp) {
106   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
107     reductionName = "add_reduction";
108     break;
109   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
110     reductionName = "multiply_reduction";
111     break;
112   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
113     return "and_reduction";
114   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
115     return "eqv_reduction";
116   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
117     return "or_reduction";
118   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
119     return "neqv_reduction";
120   default:
121     reductionName = "other_reduction";
122     break;
123   }
124 
125   return getReductionName(reductionName, kindMap, ty, isByRef);
126 }
127 
128 mlir::Value
129 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
130                                           ReductionIdentifier redId,
131                                           fir::FirOpBuilder &builder) {
132   type = fir::unwrapRefType(type);
133   if (!fir::isa_integer(type) && !fir::isa_real(type) &&
134       !mlir::isa<fir::LogicalType>(type))
135     TODO(loc, "Reduction of some types is not supported");
136   switch (redId) {
137   case ReductionIdentifier::MAX: {
138     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
139       const llvm::fltSemantics &sem = ty.getFloatSemantics();
140       return builder.createRealConstant(
141           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
142     }
143     unsigned bits = type.getIntOrFloatBitWidth();
144     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
145     return builder.createIntegerConstant(loc, type, minInt);
146   }
147   case ReductionIdentifier::MIN: {
148     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
149       const llvm::fltSemantics &sem = ty.getFloatSemantics();
150       return builder.createRealConstant(
151           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
152     }
153     unsigned bits = type.getIntOrFloatBitWidth();
154     int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
155     return builder.createIntegerConstant(loc, type, maxInt);
156   }
157   case ReductionIdentifier::IOR: {
158     unsigned bits = type.getIntOrFloatBitWidth();
159     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
160     return builder.createIntegerConstant(loc, type, zeroInt);
161   }
162   case ReductionIdentifier::IEOR: {
163     unsigned bits = type.getIntOrFloatBitWidth();
164     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
165     return builder.createIntegerConstant(loc, type, zeroInt);
166   }
167   case ReductionIdentifier::IAND: {
168     unsigned bits = type.getIntOrFloatBitWidth();
169     int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
170     return builder.createIntegerConstant(loc, type, allOnInt);
171   }
172   case ReductionIdentifier::ADD:
173   case ReductionIdentifier::MULTIPLY:
174   case ReductionIdentifier::AND:
175   case ReductionIdentifier::OR:
176   case ReductionIdentifier::EQV:
177   case ReductionIdentifier::NEQV:
178     if (type.isa<mlir::FloatType>())
179       return builder.create<mlir::arith::ConstantOp>(
180           loc, type,
181           builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
182 
183     if (type.isa<fir::LogicalType>()) {
184       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
185           loc, builder.getI1Type(),
186           builder.getIntegerAttr(builder.getI1Type(),
187                                  getOperationIdentity(redId, loc)));
188       return builder.createConvert(loc, type, intConst);
189     }
190 
191     return builder.create<mlir::arith::ConstantOp>(
192         loc, type,
193         builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
194   case ReductionIdentifier::ID:
195   case ReductionIdentifier::USER_DEF_OP:
196   case ReductionIdentifier::SUBTRACT:
197     TODO(loc, "Reduction of some identifier types is not supported");
198   }
199   llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
200 }
201 
202 mlir::Value ReductionProcessor::createScalarCombiner(
203     fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
204     mlir::Type type, mlir::Value op1, mlir::Value op2) {
205   mlir::Value reductionOp;
206   type = fir::unwrapRefType(type);
207   switch (redId) {
208   case ReductionIdentifier::MAX:
209     reductionOp =
210         getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
211             builder, type, loc, op1, op2);
212     break;
213   case ReductionIdentifier::MIN:
214     reductionOp =
215         getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
216             builder, type, loc, op1, op2);
217     break;
218   case ReductionIdentifier::IOR:
219     assert((type.isIntOrIndex()) && "only integer is expected");
220     reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
221     break;
222   case ReductionIdentifier::IEOR:
223     assert((type.isIntOrIndex()) && "only integer is expected");
224     reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
225     break;
226   case ReductionIdentifier::IAND:
227     assert((type.isIntOrIndex()) && "only integer is expected");
228     reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
229     break;
230   case ReductionIdentifier::ADD:
231     reductionOp =
232         getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
233             builder, type, loc, op1, op2);
234     break;
235   case ReductionIdentifier::MULTIPLY:
236     reductionOp =
237         getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
238             builder, type, loc, op1, op2);
239     break;
240   case ReductionIdentifier::AND: {
241     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
242     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
243 
244     mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
245 
246     reductionOp = builder.createConvert(loc, type, andiOp);
247     break;
248   }
249   case ReductionIdentifier::OR: {
250     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
251     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
252 
253     mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
254 
255     reductionOp = builder.createConvert(loc, type, oriOp);
256     break;
257   }
258   case ReductionIdentifier::EQV: {
259     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
260     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
261 
262     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
263         loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
264 
265     reductionOp = builder.createConvert(loc, type, cmpiOp);
266     break;
267   }
268   case ReductionIdentifier::NEQV: {
269     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
270     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
271 
272     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
273         loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
274 
275     reductionOp = builder.createConvert(loc, type, cmpiOp);
276     break;
277   }
278   default:
279     TODO(loc, "Reduction of some intrinsic operators is not supported");
280   }
281 
282   return reductionOp;
283 }
284 
285 /// Create reduction combiner region for reduction variables which are boxed
286 /// arrays
287 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
288                            ReductionProcessor::ReductionIdentifier redId,
289                            fir::BaseBoxType boxTy, mlir::Value lhs,
290                            mlir::Value rhs) {
291   fir::SequenceType seqTy =
292       mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy());
293   // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
294   if (!seqTy || seqTy.hasUnknownShape())
295     TODO(loc, "Unsupported boxed type in OpenMP reduction");
296 
297   // load fir.ref<fir.box<...>>
298   mlir::Value lhsAddr = lhs;
299   lhs = builder.create<fir::LoadOp>(loc, lhs);
300   rhs = builder.create<fir::LoadOp>(loc, rhs);
301 
302   const unsigned rank = seqTy.getDimension();
303   llvm::SmallVector<mlir::Value> extents;
304   extents.reserve(rank);
305   llvm::SmallVector<mlir::Value> lbAndExtents;
306   lbAndExtents.reserve(rank * 2);
307 
308   // Get box lowerbounds and extents:
309   mlir::Type idxTy = builder.getIndexType();
310   for (unsigned i = 0; i < rank; ++i) {
311     // TODO: ideally we want to hoist box reads out of the critical section.
312     // We could do this by having box dimensions in block arguments like
313     // OpenACC does
314     mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
315     auto dimInfo =
316         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim);
317     extents.push_back(dimInfo.getExtent());
318     lbAndExtents.push_back(dimInfo.getLowerBound());
319     lbAndExtents.push_back(dimInfo.getExtent());
320   }
321 
322   auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank);
323   auto shapeShift =
324       builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
325 
326   // Iterate over array elements, applying the equivalent scalar reduction:
327 
328   // A hlfir::elemental here gets inlined with a temporary so create the
329   // loop nest directly.
330   // This function already controls all of the code in this region so we
331   // know this won't miss any opportuinties for clever elemental inlining
332   hlfir::LoopNest nest =
333       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
334   builder.setInsertionPointToStart(nest.innerLoop.getBody());
335   mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
336   auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
337       loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
338       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
339   auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
340       loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
341       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
342   auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
343   auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
344   mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
345       builder, loc, redId, refTy, lhsEle, rhsEle);
346   builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
347 
348   builder.setInsertionPointAfter(nest.outerLoop);
349   builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
350 }
351 
352 // generate combiner region for reduction operations
353 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
354                         ReductionProcessor::ReductionIdentifier redId,
355                         mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
356                         bool isByRef) {
357   ty = fir::unwrapRefType(ty);
358 
359   if (fir::isa_trivial(ty)) {
360     mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
361     mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
362 
363     mlir::Value result = ReductionProcessor::createScalarCombiner(
364         builder, loc, redId, ty, lhsLoaded, rhsLoaded);
365     if (isByRef) {
366       builder.create<fir::StoreOp>(loc, result, lhs);
367       builder.create<mlir::omp::YieldOp>(loc, lhs);
368     } else {
369       builder.create<mlir::omp::YieldOp>(loc, result);
370     }
371     return;
372   }
373   // all arrays should have been boxed
374   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
375     genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
376     return;
377   }
378 
379   TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
380 }
381 
382 static mlir::Value
383 createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
384                           const ReductionProcessor::ReductionIdentifier redId,
385                           mlir::Type type, bool isByRef) {
386   mlir::Type ty = fir::unwrapRefType(type);
387   mlir::Value initValue = ReductionProcessor::getReductionInitValue(
388       loc, fir::unwrapSeqOrBoxedSeqType(ty), redId, builder);
389 
390   if (fir::isa_trivial(ty)) {
391     if (isByRef) {
392       mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
393       builder.createStoreWithConvert(loc, initValue, alloca);
394       return alloca;
395     }
396     // by val
397     return initValue;
398   }
399 
400   // all arrays are boxed
401   if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
402     assert(isByRef && "passing arrays by value is unsupported");
403     // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
404     mlir::Type innerTy = fir::extractSequenceType(boxTy);
405     if (!mlir::isa<fir::SequenceType>(innerTy))
406       TODO(loc, "Unsupported boxed type for reduction");
407     // Create the private copy from the initial fir.box:
408     hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
409 
410     // TODO: if the whole reduction is nested inside of a loop, this alloca
411     // could lead to a stack overflow (the memory is only freed at the end of
412     // the stack frame). The reduction declare operation needs a deallocation
413     // region to undo the init region.
414     hlfir::Entity temp = createStackTempFromMold(loc, builder, source);
415 
416     // Put the temporary inside of a box:
417     hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp);
418     builder.create<hlfir::AssignOp>(loc, initValue, box);
419     mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
420     builder.create<fir::StoreOp>(loc, box, boxAlloca);
421     return boxAlloca;
422   }
423 
424   TODO(loc, "createReductionInitRegion for unsupported type");
425 }
426 
427 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
428     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
429     const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
430     bool isByRef) {
431   mlir::OpBuilder::InsertionGuard guard(builder);
432   mlir::ModuleOp module = builder.getModule();
433 
434   assert(!reductionOpName.empty());
435 
436   auto decl =
437       module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
438   if (decl)
439     return decl;
440 
441   mlir::OpBuilder modBuilder(module.getBodyRegion());
442   mlir::Type valTy = fir::unwrapRefType(type);
443   if (!isByRef)
444     type = valTy;
445 
446   decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
447                                                           type);
448   builder.createBlock(&decl.getInitializerRegion(),
449                       decl.getInitializerRegion().end(), {type}, {loc});
450   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
451 
452   mlir::Value init =
453       createReductionInitRegion(builder, loc, redId, type, isByRef);
454   builder.create<mlir::omp::YieldOp>(loc, init);
455 
456   builder.createBlock(&decl.getReductionRegion(),
457                       decl.getReductionRegion().end(), {type, type},
458                       {loc, loc});
459 
460   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
461   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
462   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
463   genCombiner(builder, loc, redId, type, op1, op2, isByRef);
464 
465   return decl;
466 }
467 
468 // TODO: By-ref vs by-val reductions are currently toggled for the whole
469 //       operation (possibly effecting multiple reduction variables).
470 //       This could cause a problem with openmp target reductions because
471 //       by-ref trivial types may not be supported.
472 bool ReductionProcessor::doReductionByRef(
473     const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
474   if (reductionVars.empty())
475     return false;
476   if (forceByrefReduction)
477     return true;
478 
479   for (mlir::Value reductionVar : reductionVars) {
480     if (auto declare =
481             mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
482       reductionVar = declare.getMemref();
483 
484     if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
485       return true;
486   }
487   return false;
488 }
489 
490 void ReductionProcessor::addDeclareReduction(
491     mlir::Location currentLocation,
492     Fortran::lower::AbstractConverter &converter,
493     const omp::clause::Reduction &reduction,
494     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
495     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
496     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
497         *reductionSymbols) {
498   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
499   mlir::omp::DeclareReductionOp decl;
500   const auto &redOperatorList{
501       std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
502   assert(redOperatorList.size() == 1 && "Expecting single operator");
503   const auto &redOperator = redOperatorList.front();
504   const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
505 
506   if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
507     if (const auto *reductionIntrinsic =
508             std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
509       if (!ReductionProcessor::supportedIntrinsicProcReduction(
510               *reductionIntrinsic)) {
511         return;
512       }
513     } else {
514       return;
515     }
516   }
517 
518   // initial pass to collect all reduction vars so we can figure out if this
519   // should happen byref
520   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
521   for (const Object &object : objectList) {
522     const Fortran::semantics::Symbol *symbol = object.id();
523     if (reductionSymbols)
524       reductionSymbols->push_back(symbol);
525     mlir::Value symVal = converter.getSymbolAddress(*symbol);
526     auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
527 
528     // all arrays must be boxed so that we have convenient access to all the
529     // information needed to iterate over the array
530     if (mlir::isa<fir::SequenceType>(redType.getEleTy())) {
531       // For Host associated symbols, use `SymbolBox` instead
532       Fortran::lower::SymbolBox symBox =
533           converter.lookupOneLevelUpSymbol(*symbol);
534       hlfir::Entity entity{symBox.getAddr()};
535       entity = genVariableBox(currentLocation, builder, entity);
536       mlir::Value box = entity.getBase();
537 
538       // Always pass the box by reference so that the OpenMP dialect
539       // verifiers don't need to know anything about fir.box
540       auto alloca =
541           builder.create<fir::AllocaOp>(currentLocation, box.getType());
542       builder.create<fir::StoreOp>(currentLocation, box, alloca);
543 
544       symVal = alloca;
545       redType = mlir::cast<fir::ReferenceType>(symVal.getType());
546     } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
547       symVal = declOp.getBase();
548     }
549 
550     reductionVars.push_back(symVal);
551   }
552   const bool isByRef = doReductionByRef(reductionVars);
553 
554   if (const auto &redDefinedOp =
555           std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
556     const auto &intrinsicOp{
557         std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
558             redDefinedOp->u)};
559     ReductionIdentifier redId = getReductionType(intrinsicOp);
560     switch (redId) {
561     case ReductionIdentifier::ADD:
562     case ReductionIdentifier::MULTIPLY:
563     case ReductionIdentifier::AND:
564     case ReductionIdentifier::EQV:
565     case ReductionIdentifier::OR:
566     case ReductionIdentifier::NEQV:
567       break;
568     default:
569       TODO(currentLocation,
570            "Reduction of some intrinsic operators is not supported");
571       break;
572     }
573 
574     for (mlir::Value symVal : reductionVars) {
575       auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
576       const auto &kindMap = firOpBuilder.getKindMap();
577       if (redType.getEleTy().isa<fir::LogicalType>())
578         decl = createDeclareReduction(firOpBuilder,
579                                       getReductionName(intrinsicOp, kindMap,
580                                                        firOpBuilder.getI1Type(),
581                                                        isByRef),
582                                       redId, redType, currentLocation, isByRef);
583       else
584         decl = createDeclareReduction(
585             firOpBuilder,
586             getReductionName(intrinsicOp, kindMap, redType, isByRef), redId,
587             redType, currentLocation, isByRef);
588       reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
589           firOpBuilder.getContext(), decl.getSymName()));
590     }
591   } else if (const auto *reductionIntrinsic =
592                  std::get_if<omp::clause::ProcedureDesignator>(
593                      &redOperator.u)) {
594     if (ReductionProcessor::supportedIntrinsicProcReduction(
595             *reductionIntrinsic)) {
596       ReductionProcessor::ReductionIdentifier redId =
597           ReductionProcessor::getReductionType(*reductionIntrinsic);
598       for (const Object &object : objectList) {
599         const Fortran::semantics::Symbol *symbol = object.id();
600         mlir::Value symVal = converter.getSymbolAddress(*symbol);
601         if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
602           symVal = declOp.getBase();
603         auto redType = symVal.getType().cast<fir::ReferenceType>();
604         if (!redType.getEleTy().isIntOrIndexOrFloat())
605           TODO(currentLocation, "User Defined Reduction on non-trivial type");
606         decl = createDeclareReduction(
607             firOpBuilder,
608             getReductionName(getRealName(*reductionIntrinsic).ToString(),
609                              firOpBuilder.getKindMap(), redType, isByRef),
610             redId, redType, currentLocation, isByRef);
611         reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
612             firOpBuilder.getContext(), decl.getSymName()));
613       }
614     }
615   }
616 }
617 
618 const Fortran::semantics::SourceName
619 ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) {
620   return symbol->GetUltimate().name();
621 }
622 
623 const Fortran::semantics::SourceName
624 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
625   return getRealName(pd.v.id());
626 }
627 
628 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
629                                              mlir::Location loc) {
630   switch (redId) {
631   case ReductionIdentifier::ADD:
632   case ReductionIdentifier::OR:
633   case ReductionIdentifier::NEQV:
634     return 0;
635   case ReductionIdentifier::MULTIPLY:
636   case ReductionIdentifier::AND:
637   case ReductionIdentifier::EQV:
638     return 1;
639   default:
640     TODO(loc, "Reduction of some intrinsic operators is not supported");
641   }
642 }
643 
644 } // namespace omp
645 } // namespace lower
646 } // namespace Fortran
647