xref: /llvm-project/flang/lib/Lower/OpenMP/ReductionProcessor.cpp (revision 88478a89cd85adcc32f2a321ef9e9906c5fdbe26)
1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ReductionProcessor.h"
14 
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Lower/ConvertType.h"
17 #include "flang/Lower/SymbolMap.h"
18 #include "flang/Optimizer/Builder/Complex.h"
19 #include "flang/Optimizer/Builder/HLFIRTools.h"
20 #include "flang/Optimizer/Builder/Todo.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/HLFIR/HLFIROps.h"
23 #include "flang/Optimizer/Support/FatalError.h"
24 #include "flang/Parser/tools.h"
25 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
26 #include "llvm/Support/CommandLine.h"
27 
// Debug/testing escape hatch: when set, every reduction argument is lowered
// as a by-reference reduction, even trivial scalars that would normally be
// passed by value (see doReductionByRef below).
static llvm::cl::opt<bool> forceByrefReduction(
    "force-byref-reduction",
    llvm::cl::desc("Pass all reduction arguments by reference"),
    llvm::cl::Hidden);
32 
33 namespace Fortran {
34 namespace lower {
35 namespace omp {
36 
37 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
38     const omp::clause::ProcedureDesignator &pd) {
39   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
40                      getRealName(pd.v.sym()).ToString())
41                      .Case("max", ReductionIdentifier::MAX)
42                      .Case("min", ReductionIdentifier::MIN)
43                      .Case("iand", ReductionIdentifier::IAND)
44                      .Case("ior", ReductionIdentifier::IOR)
45                      .Case("ieor", ReductionIdentifier::IEOR)
46                      .Default(std::nullopt);
47   assert(redType && "Invalid Reduction");
48   return *redType;
49 }
50 
51 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
52     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
53   switch (intrinsicOp) {
54   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
55     return ReductionIdentifier::ADD;
56   case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
57     return ReductionIdentifier::SUBTRACT;
58   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
59     return ReductionIdentifier::MULTIPLY;
60   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
61     return ReductionIdentifier::AND;
62   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
63     return ReductionIdentifier::EQV;
64   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
65     return ReductionIdentifier::OR;
66   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
67     return ReductionIdentifier::NEQV;
68   default:
69     llvm_unreachable("unexpected intrinsic operator in reduction");
70   }
71 }
72 
73 bool ReductionProcessor::supportedIntrinsicProcReduction(
74     const omp::clause::ProcedureDesignator &pd) {
75   semantics::Symbol *sym = pd.v.sym();
76   if (!sym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC))
77     return false;
78   auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
79                      .Case("max", true)
80                      .Case("min", true)
81                      .Case("iand", true)
82                      .Case("ior", true)
83                      .Case("ieor", true)
84                      .Default(false);
85   return redType;
86 }
87 
88 std::string
89 ReductionProcessor::getReductionName(llvm::StringRef name,
90                                      const fir::KindMapping &kindMap,
91                                      mlir::Type ty, bool isByRef) {
92   ty = fir::unwrapRefType(ty);
93 
94   // extra string to distinguish reduction functions for variables passed by
95   // reference
96   llvm::StringRef byrefAddition{""};
97   if (isByRef)
98     byrefAddition = "_byref";
99 
100   return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
101 }
102 
103 std::string ReductionProcessor::getReductionName(
104     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
105     const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
106   std::string reductionName;
107 
108   switch (intrinsicOp) {
109   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
110     reductionName = "add_reduction";
111     break;
112   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
113     reductionName = "multiply_reduction";
114     break;
115   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
116     return "and_reduction";
117   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
118     return "eqv_reduction";
119   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
120     return "or_reduction";
121   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
122     return "neqv_reduction";
123   default:
124     reductionName = "other_reduction";
125     break;
126   }
127 
128   return getReductionName(reductionName, kindMap, ty, isByRef);
129 }
130 
131 mlir::Value
132 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
133                                           ReductionIdentifier redId,
134                                           fir::FirOpBuilder &builder) {
135   type = fir::unwrapRefType(type);
136   if (!fir::isa_integer(type) && !fir::isa_real(type) &&
137       !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
138     TODO(loc, "Reduction of some types is not supported");
139   switch (redId) {
140   case ReductionIdentifier::MAX: {
141     if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
142       const llvm::fltSemantics &sem = ty.getFloatSemantics();
143       return builder.createRealConstant(
144           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
145     }
146     unsigned bits = type.getIntOrFloatBitWidth();
147     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
148     return builder.createIntegerConstant(loc, type, minInt);
149   }
150   case ReductionIdentifier::MIN: {
151     if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
152       const llvm::fltSemantics &sem = ty.getFloatSemantics();
153       return builder.createRealConstant(
154           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
155     }
156     unsigned bits = type.getIntOrFloatBitWidth();
157     int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
158     return builder.createIntegerConstant(loc, type, maxInt);
159   }
160   case ReductionIdentifier::IOR: {
161     unsigned bits = type.getIntOrFloatBitWidth();
162     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
163     return builder.createIntegerConstant(loc, type, zeroInt);
164   }
165   case ReductionIdentifier::IEOR: {
166     unsigned bits = type.getIntOrFloatBitWidth();
167     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
168     return builder.createIntegerConstant(loc, type, zeroInt);
169   }
170   case ReductionIdentifier::IAND: {
171     unsigned bits = type.getIntOrFloatBitWidth();
172     int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
173     return builder.createIntegerConstant(loc, type, allOnInt);
174   }
175   case ReductionIdentifier::ADD:
176   case ReductionIdentifier::MULTIPLY:
177   case ReductionIdentifier::AND:
178   case ReductionIdentifier::OR:
179   case ReductionIdentifier::EQV:
180   case ReductionIdentifier::NEQV:
181     if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
182       mlir::Type realTy = cplxTy.getElementType();
183       mlir::Value initRe = builder.createRealConstant(
184           loc, realTy, getOperationIdentity(redId, loc));
185       mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
186 
187       return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
188                                                                initIm);
189     }
190     if (mlir::isa<mlir::FloatType>(type))
191       return builder.create<mlir::arith::ConstantOp>(
192           loc, type,
193           builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
194 
195     if (mlir::isa<fir::LogicalType>(type)) {
196       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
197           loc, builder.getI1Type(),
198           builder.getIntegerAttr(builder.getI1Type(),
199                                  getOperationIdentity(redId, loc)));
200       return builder.createConvert(loc, type, intConst);
201     }
202 
203     return builder.create<mlir::arith::ConstantOp>(
204         loc, type,
205         builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
206   case ReductionIdentifier::ID:
207   case ReductionIdentifier::USER_DEF_OP:
208   case ReductionIdentifier::SUBTRACT:
209     TODO(loc, "Reduction of some identifier types is not supported");
210   }
211   llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
212 }
213 
/// Emit the scalar combining operation for \p redId, returning the SSA value
/// of op1 <combine> op2. \p op1 and \p op2 must already be loaded values (not
/// addresses); \p type may be reference-wrapped and is unwrapped first.
mlir::Value ReductionProcessor::createScalarCombiner(
    fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
    mlir::Type type, mlir::Value op1, mlir::Value op2) {
  mlir::Value reductionOp;
  type = fir::unwrapRefType(type);
  switch (redId) {
  case ReductionIdentifier::MAX:
    // getReductionOperation presumably selects between the float and
    // signed-integer op based on `type` — TODO confirm against its definition.
    reductionOp =
        getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>(
            builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::MIN:
    reductionOp =
        getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>(
            builder, type, loc, op1, op2);
    break;
  // IOR/IEOR/IAND are integer-only by Fortran semantics; assert that here.
  case ReductionIdentifier::IOR:
    assert((type.isIntOrIndex()) && "only integer is expected");
    reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
    break;
  case ReductionIdentifier::IEOR:
    assert((type.isIntOrIndex()) && "only integer is expected");
    reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
    break;
  case ReductionIdentifier::IAND:
    assert((type.isIntOrIndex()) && "only integer is expected");
    reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
    break;
  case ReductionIdentifier::ADD:
    // The third template argument supplies the complex-typed variant.
    reductionOp =
        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
                              fir::AddcOp>(builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::MULTIPLY:
    reductionOp =
        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
                              fir::MulcOp>(builder, type, loc, op1, op2);
    break;
  // Logical operators: convert both operands to i1, combine, then convert
  // the i1 result back to the original (fir.logical) type.
  case ReductionIdentifier::AND: {
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, andiOp);
    break;
  }
  case ReductionIdentifier::OR: {
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, oriOp);
    break;
  }
  case ReductionIdentifier::EQV: {
    // EQV is logical equivalence: true when both operands agree (icmp eq).
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, cmpiOp);
    break;
  }
  case ReductionIdentifier::NEQV: {
    // NEQV is logical non-equivalence (icmp ne).
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, cmpiOp);
    break;
  }
  default:
    TODO(loc, "Reduction of some intrinsic operators is not supported");
  }

  return reductionOp;
}
296 
297 /// Generate a fir::ShapeShift op describing the provided boxed array.
298 static fir::ShapeShiftOp getShapeShift(fir::FirOpBuilder &builder,
299                                        mlir::Location loc, mlir::Value box) {
300   fir::SequenceType sequenceType = mlir::cast<fir::SequenceType>(
301       hlfir::getFortranElementOrSequenceType(box.getType()));
302   const unsigned rank = sequenceType.getDimension();
303   llvm::SmallVector<mlir::Value> lbAndExtents;
304   lbAndExtents.reserve(rank * 2);
305 
306   mlir::Type idxTy = builder.getIndexType();
307   for (unsigned i = 0; i < rank; ++i) {
308     // TODO: ideally we want to hoist box reads out of the critical section.
309     // We could do this by having box dimensions in block arguments like
310     // OpenACC does
311     mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
312     auto dimInfo =
313         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, dim);
314     lbAndExtents.push_back(dimInfo.getLowerBound());
315     lbAndExtents.push_back(dimInfo.getExtent());
316   }
317 
318   auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank);
319   auto shapeShift =
320       builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
321   return shapeShift;
322 }
323 
/// Create reduction combiner region for reduction variables which are boxed
/// arrays. \p lhs and \p rhs are fir.ref<fir.box<...>> values; the combined
/// result is written in place into the lhs storage and the lhs address is
/// yielded.
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                           ReductionProcessor::ReductionIdentifier redId,
                           fir::BaseBoxType boxTy, mlir::Value lhs,
                           mlir::Value rhs) {
  // Classify what the box wraps: a (possibly heap/pointer) sequence, or a
  // bare heap/pointer scalar.
  fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
      fir::unwrapRefType(boxTy.getEleTy()));
  fir::HeapType heapTy =
      mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
  fir::PointerType ptrTy =
      mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy());
  if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy)
    TODO(loc, "Unsupported boxed type in OpenMP reduction");

  // load fir.ref<fir.box<...>>
  mlir::Value lhsAddr = lhs;
  lhs = builder.create<fir::LoadOp>(loc, lhs);
  rhs = builder.create<fir::LoadOp>(loc, rhs);

  if ((heapTy || ptrTy) && !seqTy) {
    // Boxed scalar allocatable/pointer: combine the pointees directly and
    // store the result through the lhs heap pointer.
    // get box contents (heap pointers)
    lhs = builder.create<fir::BoxAddrOp>(loc, lhs);
    rhs = builder.create<fir::BoxAddrOp>(loc, rhs);
    mlir::Value lhsValAddr = lhs;

    // load heap pointers
    lhs = builder.create<fir::LoadOp>(loc, lhs);
    rhs = builder.create<fir::LoadOp>(loc, rhs);

    mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy();

    mlir::Value result = ReductionProcessor::createScalarCombiner(
        builder, loc, redId, eleTy, lhs, rhs);
    builder.create<fir::StoreOp>(loc, result, lhsValAddr);
    builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
    return;
  }

  fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, lhs);

  // Iterate over array elements, applying the equivalent scalar reduction:

  // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
  // and so no null check is needed here before indexing into the (possibly
  // allocatable) arrays.

  // A hlfir::elemental here gets inlined with a temporary so create the
  // loop nest directly.
  // This function already controls all of the code in this region so we
  // know this won't miss any opportunities for clever elemental inlining
  hlfir::LoopNest nest = hlfir::genLoopNest(
      loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
  builder.setInsertionPointToStart(nest.innerLoop.getBody());
  mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
  // Address of the current element in each array (shared loop indices).
  auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
  auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
  mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
      builder, loc, redId, refTy, lhsEle, rhsEle);
  builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);

  builder.setInsertionPointAfter(nest.outerLoop);
  builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
}
394 
395 // generate combiner region for reduction operations
396 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
397                         ReductionProcessor::ReductionIdentifier redId,
398                         mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
399                         bool isByRef) {
400   ty = fir::unwrapRefType(ty);
401 
402   if (fir::isa_trivial(ty)) {
403     mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
404     mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
405 
406     mlir::Value result = ReductionProcessor::createScalarCombiner(
407         builder, loc, redId, ty, lhsLoaded, rhsLoaded);
408     if (isByRef) {
409       builder.create<fir::StoreOp>(loc, result, lhs);
410       builder.create<mlir::omp::YieldOp>(loc, lhs);
411     } else {
412       builder.create<mlir::omp::YieldOp>(loc, result);
413     }
414     return;
415   }
416   // all arrays should have been boxed
417   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
418     genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
419     return;
420   }
421 
422   TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
423 }
424 
/// Build the cleanup region of \p reductionDecl: free the heap allocation
/// made for the thread-private copy of a boxed reduction variable. Fatal
/// error if the reduction type is not one the init region heap-allocates.
static void
createReductionCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc,
                             mlir::omp::DeclareReductionOp &reductionDecl) {
  mlir::Type redTy = reductionDecl.getType();

  mlir::Region &cleanupRegion = reductionDecl.getCleanupRegion();
  assert(cleanupRegion.empty());
  // Single-block region; the block argument is the private copy to clean up.
  mlir::Block *block =
      builder.createBlock(&cleanupRegion, cleanupRegion.end(), {redTy}, {loc});
  builder.setInsertionPointToEnd(block);

  auto typeError = [loc]() {
    fir::emitFatalError(loc,
                        "Attempt to create an omp reduction cleanup region "
                        "for a type that wasn't allocated",
                        /*genCrashDiag=*/true);
  };

  mlir::Type valTy = fir::unwrapRefType(redTy);
  if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(valTy)) {
    if (!mlir::isa<fir::HeapType, fir::PointerType>(boxTy.getEleTy())) {
      // A non-allocatable, non-pointer box must wrap an array (sequence);
      // anything else was never heap-allocated by the init region.
      mlir::Type innerTy = fir::extractSequenceType(boxTy);
      if (!mlir::isa<fir::SequenceType>(innerTy))
        typeError();
    }

    mlir::Value arg = block->getArgument(0);
    arg = builder.loadIfRef(loc, arg);
    assert(mlir::isa<fir::BaseBoxType>(arg.getType()));

    // Deallocate box
    // The FIR type system doesn't necessarily know that this is a mutable box
    // if we allocated the thread local array on the heap to avoid looped stack
    // allocations.
    mlir::Value addr =
        hlfir::genVariableRawAddress(loc, builder, hlfir::Entity{arg});
    mlir::Value isAllocated = builder.genIsNotNullAddr(loc, addr);
    // Guard the free: only deallocate when the box holds a non-null address.
    fir::IfOp ifOp =
        builder.create<fir::IfOp>(loc, isAllocated, /*withElseRegion=*/false);
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());

    // Convert to a heap pointer type so fir.freemem accepts the address.
    mlir::Value cast = builder.createConvert(
        loc, fir::HeapType::get(fir::dyn_cast_ptrEleTy(addr.getType())), addr);
    builder.create<fir::FreeMemOp>(loc, cast);

    builder.setInsertionPointAfter(ifOp);
    builder.create<mlir::omp::YieldOp>(loc);
    return;
  }

  typeError();
}
477 
478 // like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes
479 static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
480   if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
481     return seqTy.getEleTy();
482   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
483     auto eleTy = fir::unwrapRefType(boxTy.getEleTy());
484     if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
485       return seqTy.getEleTy();
486     return eleTy;
487   }
488   return ty;
489 }
490 
/// Build the alloc region (by-ref only) and init region of \p reductionDecl,
/// producing the thread-private copy initialized with the identity value for
/// \p redId. For boxed types this also creates the cleanup region that frees
/// the heap allocation made here.
static void createReductionAllocAndInitRegions(
    fir::FirOpBuilder &builder, mlir::Location loc,
    mlir::omp::DeclareReductionOp &reductionDecl,
    const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
    bool isByRef) {
  auto yield = [&](mlir::Value ret) {
    builder.create<mlir::omp::YieldOp>(loc, ret);
  };

  mlir::Block *allocBlock = nullptr;
  mlir::Block *initBlock = nullptr;
  if (isByRef) {
    // By-ref: an alloc region plus an init region taking two block args
    // (the mold value and the alloca produced by the alloc region).
    allocBlock =
        builder.createBlock(&reductionDecl.getAllocRegion(),
                            reductionDecl.getAllocRegion().end(), {}, {});
    initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
                                    reductionDecl.getInitializerRegion().end(),
                                    {type, type}, {loc, loc});
  } else {
    initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
                                    reductionDecl.getInitializerRegion().end(),
                                    {type}, {loc});
  }

  mlir::Type ty = fir::unwrapRefType(type);
  builder.setInsertionPointToEnd(initBlock);
  // Identity value for this reduction over the underlying element type.
  mlir::Value initValue = ReductionProcessor::getReductionInitValue(
      loc, unwrapSeqOrBoxedType(ty), redId, builder);

  if (fir::isa_trivial(ty)) {
    if (isByRef) {
      // alloc region
      {
        builder.setInsertionPointToEnd(allocBlock);
        mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
        yield(alloca);
      }

      // init region
      {
        builder.setInsertionPointToEnd(initBlock);
        // block arg is mapped to the alloca yielded from the alloc region
        mlir::Value alloc = reductionDecl.getInitializerAllocArg();
        builder.createStoreWithConvert(loc, initValue, alloc);
        yield(alloc);
      }
      return;
    }
    // by val
    yield(initValue);
    return;
  }

  // check if an allocatable box is unallocated. If so, initialize the boxAlloca
  // to be unallocated e.g.
  // %box_alloca = fir.alloca !fir.box<!fir.heap<...>>
  // %addr = fir.box_addr %box
  // if (%addr == 0) {
  //   %nullbox = fir.embox %addr
  //   fir.store %nullbox to %box_alloca
  // } else {
  //   // ...
  //   fir.store %something to %box_alloca
  // }
  // omp.yield %box_alloca
  mlir::Value moldArg =
      builder.loadIfRef(loc, reductionDecl.getInitializerMoldArg());
  auto handleNullAllocatable = [&](mlir::Value boxAlloca) -> fir::IfOp {
    mlir::Value addr = builder.create<fir::BoxAddrOp>(loc, moldArg);
    mlir::Value isNotAllocated = builder.genIsNullAddr(loc, addr);
    fir::IfOp ifOp = builder.create<fir::IfOp>(loc, isNotAllocated,
                                               /*withElseRegion=*/true);
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    // just embox the null address and return
    mlir::Value nullBox = builder.create<fir::EmboxOp>(loc, ty, addr);
    builder.create<fir::StoreOp>(loc, nullBox, boxAlloca);
    // Caller is left to fill in the else region (the allocated case).
    return ifOp;
  };

  // all arrays are boxed
  if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
    assert(isByRef && "passing boxes by value is unsupported");
    bool isAllocatableOrPointer =
        mlir::isa<fir::HeapType, fir::PointerType>(boxTy.getEleTy());

    // alloc region
    {
      builder.setInsertionPointToEnd(allocBlock);
      mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
      yield(boxAlloca);
    }

    // init region
    builder.setInsertionPointToEnd(initBlock);
    mlir::Value boxAlloca = reductionDecl.getInitializerAllocArg();
    mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
    if (fir::isa_trivial(innerTy)) {
      // boxed non-sequence value e.g. !fir.box<!fir.heap<i32>>
      if (!isAllocatableOrPointer)
        TODO(loc, "Reduction of non-allocatable trivial typed box");

      fir::IfOp ifUnallocated = handleNullAllocatable(boxAlloca);

      // Allocated case: heap-allocate a scalar, store the identity value
      // into it, box it, and store the box into the alloca.
      builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front());
      mlir::Value valAlloc = builder.create<fir::AllocMemOp>(loc, innerTy);
      builder.createStoreWithConvert(loc, initValue, valAlloc);
      mlir::Value box = builder.create<fir::EmboxOp>(loc, ty, valAlloc);
      builder.create<fir::StoreOp>(loc, box, boxAlloca);

      auto insPt = builder.saveInsertionPoint();
      createReductionCleanupRegion(builder, loc, reductionDecl);
      builder.restoreInsertionPoint(insPt);
      builder.setInsertionPointAfter(ifUnallocated);
      yield(boxAlloca);
      return;
    }
    innerTy = fir::extractSequenceType(boxTy);
    if (!mlir::isa<fir::SequenceType>(innerTy))
      TODO(loc, "Unsupported boxed type for reduction");

    fir::IfOp ifUnallocated{nullptr};
    if (isAllocatableOrPointer) {
      ifUnallocated = handleNullAllocatable(boxAlloca);
      builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front());
    }

    // Create the private copy from the initial fir.box:
    mlir::Value loadedBox = builder.loadIfRef(loc, moldArg);
    hlfir::Entity source = hlfir::Entity{loadedBox};

    // Allocating on the heap in case the whole reduction is nested inside of a
    // loop
    // TODO: compare performance here to using allocas - this could be made to
    // work by inserting stacksave/stackrestore around the reduction in
    // openmpirbuilder
    auto [temp, needsDealloc] = createTempFromMold(loc, builder, source);
    // if needsDealloc isn't statically false, add cleanup region. Always
    // do this for allocatable boxes because they might have been re-allocated
    // in the body of the loop/parallel region

    std::optional<int64_t> cstNeedsDealloc =
        fir::getIntIfConstant(needsDealloc);
    assert(cstNeedsDealloc.has_value() &&
           "createTempFromMold decides this statically");
    if (cstNeedsDealloc.has_value() && *cstNeedsDealloc != false) {
      mlir::OpBuilder::InsertionGuard guard(builder);
      createReductionCleanupRegion(builder, loc, reductionDecl);
    } else {
      assert(!isAllocatableOrPointer &&
             "Pointer-like arrays must be heap allocated");
    }

    // Put the temporary inside of a box:
    // hlfir::genVariableBox doesn't handle non-default lower bounds
    mlir::Value box;
    fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, loadedBox);
    mlir::Type boxType = loadedBox.getType();
    if (mlir::isa<fir::BaseBoxType>(temp.getType()))
      // the box created by the declare from createTempFromMold is missing
      // lower bounds info
      box = builder.create<fir::ReboxOp>(loc, boxType, temp, shapeShift,
                                         /*shift=*/mlir::Value{});
    else
      box = builder.create<fir::EmboxOp>(
          loc, boxType, temp, shapeShift,
          /*slice=*/mlir::Value{},
          /*typeParams=*/llvm::ArrayRef<mlir::Value>{});

    // Fill the private copy with the identity value and publish the box.
    builder.create<hlfir::AssignOp>(loc, initValue, box);
    builder.create<fir::StoreOp>(loc, box, boxAlloca);
    if (ifUnallocated)
      builder.setInsertionPointAfter(ifUnallocated);
    yield(boxAlloca);
    return;
  }

  TODO(loc, "createReductionInitRegion for unsupported type");
}
669 
670 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
671     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
672     const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
673     bool isByRef) {
674   mlir::OpBuilder::InsertionGuard guard(builder);
675   mlir::ModuleOp module = builder.getModule();
676 
677   assert(!reductionOpName.empty());
678 
679   auto decl =
680       module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
681   if (decl)
682     return decl;
683 
684   mlir::OpBuilder modBuilder(module.getBodyRegion());
685   mlir::Type valTy = fir::unwrapRefType(type);
686   if (!isByRef)
687     type = valTy;
688 
689   decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
690                                                           type);
691   createReductionAllocAndInitRegions(builder, loc, decl, redId, type, isByRef);
692 
693   builder.createBlock(&decl.getReductionRegion(),
694                       decl.getReductionRegion().end(), {type, type},
695                       {loc, loc});
696 
697   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
698   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
699   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
700   genCombiner(builder, loc, redId, type, op1, op2, isByRef);
701 
702   return decl;
703 }
704 
705 static bool doReductionByRef(mlir::Value reductionVar) {
706   if (forceByrefReduction)
707     return true;
708 
709   if (auto declare =
710           mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
711     reductionVar = declare.getMemref();
712 
713   if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
714     return true;
715 
716   return false;
717 }
718 
719 void ReductionProcessor::addDeclareReduction(
720     mlir::Location currentLocation, lower::AbstractConverter &converter,
721     const omp::clause::Reduction &reduction,
722     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
723     llvm::SmallVectorImpl<bool> &reduceVarByRef,
724     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
725     llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
726   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
727 
728   if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
729           reduction.t))
730     TODO(currentLocation, "Reduction modifiers are not supported");
731 
732   mlir::omp::DeclareReductionOp decl;
733   const auto &redOperatorList{
734       std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
735   assert(redOperatorList.size() == 1 && "Expecting single operator");
736   const auto &redOperator = redOperatorList.front();
737   const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
738 
739   if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
740     if (const auto *reductionIntrinsic =
741             std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
742       if (!ReductionProcessor::supportedIntrinsicProcReduction(
743               *reductionIntrinsic)) {
744         return;
745       }
746     } else {
747       return;
748     }
749   }
750 
751   // Reduction variable processing common to both intrinsic operators and
752   // procedure designators
753   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
754   for (const Object &object : objectList) {
755     const semantics::Symbol *symbol = object.sym();
756     reductionSymbols.push_back(symbol);
757     mlir::Value symVal = converter.getSymbolAddress(*symbol);
758     mlir::Type eleType;
759     auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
760     if (refType)
761       eleType = refType.getEleTy();
762     else
763       eleType = symVal.getType();
764 
765     // all arrays must be boxed so that we have convenient access to all the
766     // information needed to iterate over the array
767     if (mlir::isa<fir::SequenceType>(eleType)) {
768       // For Host associated symbols, use `SymbolBox` instead
769       lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol);
770       hlfir::Entity entity{symBox.getAddr()};
771       entity = genVariableBox(currentLocation, builder, entity);
772       mlir::Value box = entity.getBase();
773 
774       // Always pass the box by reference so that the OpenMP dialect
775       // verifiers don't need to know anything about fir.box
776       auto alloca =
777           builder.create<fir::AllocaOp>(currentLocation, box.getType());
778       builder.create<fir::StoreOp>(currentLocation, box, alloca);
779 
780       symVal = alloca;
781     } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
782       // boxed arrays are passed as values not by reference. Unfortunately,
783       // we can't pass a box by value to omp.redution_declare, so turn it
784       // into a reference
785 
786       auto alloca =
787           builder.create<fir::AllocaOp>(currentLocation, symVal.getType());
788       builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
789       symVal = alloca;
790     } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
791       symVal = declOp.getBase();
792     }
793 
794     // this isn't the same as the by-val and by-ref passing later in the
795     // pipeline. Both styles assume that the variable is a reference at
796     // this point
797     assert(mlir::isa<fir::ReferenceType>(symVal.getType()) &&
798            "reduction input var is a reference");
799 
800     reductionVars.push_back(symVal);
801     reduceVarByRef.push_back(doReductionByRef(symVal));
802   }
803 
804   for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
805     auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
806     const auto &kindMap = firOpBuilder.getKindMap();
807     std::string reductionName;
808     ReductionIdentifier redId;
809     mlir::Type redNameTy = redType;
810     if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
811       redNameTy = builder.getI1Type();
812 
813     if (const auto &redDefinedOp =
814             std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
815       const auto &intrinsicOp{
816           std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
817               redDefinedOp->u)};
818       redId = getReductionType(intrinsicOp);
819       switch (redId) {
820       case ReductionIdentifier::ADD:
821       case ReductionIdentifier::MULTIPLY:
822       case ReductionIdentifier::AND:
823       case ReductionIdentifier::EQV:
824       case ReductionIdentifier::OR:
825       case ReductionIdentifier::NEQV:
826         break;
827       default:
828         TODO(currentLocation,
829              "Reduction of some intrinsic operators is not supported");
830         break;
831       }
832 
833       reductionName =
834           getReductionName(intrinsicOp, kindMap, redNameTy, isByRef);
835     } else if (const auto *reductionIntrinsic =
836                    std::get_if<omp::clause::ProcedureDesignator>(
837                        &redOperator.u)) {
838       if (!ReductionProcessor::supportedIntrinsicProcReduction(
839               *reductionIntrinsic)) {
840         TODO(currentLocation, "Unsupported intrinsic proc reduction");
841       }
842       redId = getReductionType(*reductionIntrinsic);
843       reductionName =
844           getReductionName(getRealName(*reductionIntrinsic).ToString(), kindMap,
845                            redNameTy, isByRef);
846     } else {
847       TODO(currentLocation, "Unexpected reduction type");
848     }
849 
850     decl = createDeclareReduction(firOpBuilder, reductionName, redId, redType,
851                                   currentLocation, isByRef);
852     reductionDeclSymbols.push_back(
853         mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
854   }
855 }
856 
857 const semantics::SourceName
858 ReductionProcessor::getRealName(const semantics::Symbol *symbol) {
859   return symbol->GetUltimate().name();
860 }
861 
862 const semantics::SourceName
863 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
864   return getRealName(pd.v.sym());
865 }
866 
867 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
868                                              mlir::Location loc) {
869   switch (redId) {
870   case ReductionIdentifier::ADD:
871   case ReductionIdentifier::OR:
872   case ReductionIdentifier::NEQV:
873     return 0;
874   case ReductionIdentifier::MULTIPLY:
875   case ReductionIdentifier::AND:
876   case ReductionIdentifier::EQV:
877     return 1;
878   default:
879     TODO(loc, "Reduction of some intrinsic operators is not supported");
880   }
881 }
882 
883 } // namespace omp
884 } // namespace lower
885 } // namespace Fortran
886