xref: /llvm-project/flang/lib/Lower/OpenMP/ReductionProcessor.cpp (revision 0c455ee34823cb991a35e33ff020bb7cc4e44c8a)
1 //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ReductionProcessor.h"
14 
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Lower/ConvertType.h"
17 #include "flang/Lower/SymbolMap.h"
18 #include "flang/Optimizer/Builder/Complex.h"
19 #include "flang/Optimizer/Builder/HLFIRTools.h"
20 #include "flang/Optimizer/Builder/Todo.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/HLFIR/HLFIROps.h"
23 #include "flang/Optimizer/Support/FatalError.h"
24 #include "flang/Parser/tools.h"
25 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
26 #include "llvm/Support/CommandLine.h"
27 
28 static llvm::cl::opt<bool> forceByrefReduction(
29     "force-byref-reduction",
30     llvm::cl::desc("Pass all reduction arguments by reference"),
31     llvm::cl::Hidden);
32 
33 namespace Fortran {
34 namespace lower {
35 namespace omp {
36 
37 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
38     const omp::clause::ProcedureDesignator &pd) {
39   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
40                      getRealName(pd.v.id()).ToString())
41                      .Case("max", ReductionIdentifier::MAX)
42                      .Case("min", ReductionIdentifier::MIN)
43                      .Case("iand", ReductionIdentifier::IAND)
44                      .Case("ior", ReductionIdentifier::IOR)
45                      .Case("ieor", ReductionIdentifier::IEOR)
46                      .Default(std::nullopt);
47   assert(redType && "Invalid Reduction");
48   return *redType;
49 }
50 
51 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
52     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
53   switch (intrinsicOp) {
54   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
55     return ReductionIdentifier::ADD;
56   case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
57     return ReductionIdentifier::SUBTRACT;
58   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
59     return ReductionIdentifier::MULTIPLY;
60   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
61     return ReductionIdentifier::AND;
62   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
63     return ReductionIdentifier::EQV;
64   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
65     return ReductionIdentifier::OR;
66   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
67     return ReductionIdentifier::NEQV;
68   default:
69     llvm_unreachable("unexpected intrinsic operator in reduction");
70   }
71 }
72 
73 bool ReductionProcessor::supportedIntrinsicProcReduction(
74     const omp::clause::ProcedureDesignator &pd) {
75   Fortran::semantics::Symbol *sym = pd.v.id();
76   if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
77     return false;
78   auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
79                      .Case("max", true)
80                      .Case("min", true)
81                      .Case("iand", true)
82                      .Case("ior", true)
83                      .Case("ieor", true)
84                      .Default(false);
85   return redType;
86 }
87 
88 std::string
89 ReductionProcessor::getReductionName(llvm::StringRef name,
90                                      const fir::KindMapping &kindMap,
91                                      mlir::Type ty, bool isByRef) {
92   ty = fir::unwrapRefType(ty);
93 
94   // extra string to distinguish reduction functions for variables passed by
95   // reference
96   llvm::StringRef byrefAddition{""};
97   if (isByRef)
98     byrefAddition = "_byref";
99 
100   return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
101 }
102 
103 std::string ReductionProcessor::getReductionName(
104     omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
105     const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
106   std::string reductionName;
107 
108   switch (intrinsicOp) {
109   case omp::clause::DefinedOperator::IntrinsicOperator::Add:
110     reductionName = "add_reduction";
111     break;
112   case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
113     reductionName = "multiply_reduction";
114     break;
115   case omp::clause::DefinedOperator::IntrinsicOperator::AND:
116     return "and_reduction";
117   case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
118     return "eqv_reduction";
119   case omp::clause::DefinedOperator::IntrinsicOperator::OR:
120     return "or_reduction";
121   case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
122     return "neqv_reduction";
123   default:
124     reductionName = "other_reduction";
125     break;
126   }
127 
128   return getReductionName(reductionName, kindMap, ty, isByRef);
129 }
130 
131 mlir::Value
132 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
133                                           ReductionIdentifier redId,
134                                           fir::FirOpBuilder &builder) {
135   type = fir::unwrapRefType(type);
136   if (!fir::isa_integer(type) && !fir::isa_real(type) &&
137       !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
138     TODO(loc, "Reduction of some types is not supported");
139   switch (redId) {
140   case ReductionIdentifier::MAX: {
141     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
142       const llvm::fltSemantics &sem = ty.getFloatSemantics();
143       return builder.createRealConstant(
144           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
145     }
146     unsigned bits = type.getIntOrFloatBitWidth();
147     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
148     return builder.createIntegerConstant(loc, type, minInt);
149   }
150   case ReductionIdentifier::MIN: {
151     if (auto ty = type.dyn_cast<mlir::FloatType>()) {
152       const llvm::fltSemantics &sem = ty.getFloatSemantics();
153       return builder.createRealConstant(
154           loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
155     }
156     unsigned bits = type.getIntOrFloatBitWidth();
157     int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
158     return builder.createIntegerConstant(loc, type, maxInt);
159   }
160   case ReductionIdentifier::IOR: {
161     unsigned bits = type.getIntOrFloatBitWidth();
162     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
163     return builder.createIntegerConstant(loc, type, zeroInt);
164   }
165   case ReductionIdentifier::IEOR: {
166     unsigned bits = type.getIntOrFloatBitWidth();
167     int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
168     return builder.createIntegerConstant(loc, type, zeroInt);
169   }
170   case ReductionIdentifier::IAND: {
171     unsigned bits = type.getIntOrFloatBitWidth();
172     int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
173     return builder.createIntegerConstant(loc, type, allOnInt);
174   }
175   case ReductionIdentifier::ADD:
176   case ReductionIdentifier::MULTIPLY:
177   case ReductionIdentifier::AND:
178   case ReductionIdentifier::OR:
179   case ReductionIdentifier::EQV:
180   case ReductionIdentifier::NEQV:
181     if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
182       mlir::Type realTy =
183           Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
184       mlir::Value initRe = builder.createRealConstant(
185           loc, realTy, getOperationIdentity(redId, loc));
186       mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
187 
188       return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
189                                                                initIm);
190     }
191     if (type.isa<mlir::FloatType>())
192       return builder.create<mlir::arith::ConstantOp>(
193           loc, type,
194           builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
195 
196     if (type.isa<fir::LogicalType>()) {
197       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
198           loc, builder.getI1Type(),
199           builder.getIntegerAttr(builder.getI1Type(),
200                                  getOperationIdentity(redId, loc)));
201       return builder.createConvert(loc, type, intConst);
202     }
203 
204     return builder.create<mlir::arith::ConstantOp>(
205         loc, type,
206         builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
207   case ReductionIdentifier::ID:
208   case ReductionIdentifier::USER_DEF_OP:
209   case ReductionIdentifier::SUBTRACT:
210     TODO(loc, "Reduction of some identifier types is not supported");
211   }
212   llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
213 }
214 
215 mlir::Value ReductionProcessor::createScalarCombiner(
216     fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
217     mlir::Type type, mlir::Value op1, mlir::Value op2) {
218   mlir::Value reductionOp;
219   type = fir::unwrapRefType(type);
220   switch (redId) {
221   case ReductionIdentifier::MAX:
222     reductionOp =
223         getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>(
224             builder, type, loc, op1, op2);
225     break;
226   case ReductionIdentifier::MIN:
227     reductionOp =
228         getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>(
229             builder, type, loc, op1, op2);
230     break;
231   case ReductionIdentifier::IOR:
232     assert((type.isIntOrIndex()) && "only integer is expected");
233     reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
234     break;
235   case ReductionIdentifier::IEOR:
236     assert((type.isIntOrIndex()) && "only integer is expected");
237     reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
238     break;
239   case ReductionIdentifier::IAND:
240     assert((type.isIntOrIndex()) && "only integer is expected");
241     reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
242     break;
243   case ReductionIdentifier::ADD:
244     reductionOp =
245         getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
246                               fir::AddcOp>(builder, type, loc, op1, op2);
247     break;
248   case ReductionIdentifier::MULTIPLY:
249     reductionOp =
250         getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
251                               fir::MulcOp>(builder, type, loc, op1, op2);
252     break;
253   case ReductionIdentifier::AND: {
254     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
255     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
256 
257     mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
258 
259     reductionOp = builder.createConvert(loc, type, andiOp);
260     break;
261   }
262   case ReductionIdentifier::OR: {
263     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
264     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
265 
266     mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
267 
268     reductionOp = builder.createConvert(loc, type, oriOp);
269     break;
270   }
271   case ReductionIdentifier::EQV: {
272     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
273     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
274 
275     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
276         loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
277 
278     reductionOp = builder.createConvert(loc, type, cmpiOp);
279     break;
280   }
281   case ReductionIdentifier::NEQV: {
282     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
283     mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
284 
285     mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
286         loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
287 
288     reductionOp = builder.createConvert(loc, type, cmpiOp);
289     break;
290   }
291   default:
292     TODO(loc, "Reduction of some intrinsic operators is not supported");
293   }
294 
295   return reductionOp;
296 }
297 
298 /// Create reduction combiner region for reduction variables which are boxed
299 /// arrays
300 static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
301                            ReductionProcessor::ReductionIdentifier redId,
302                            fir::BaseBoxType boxTy, mlir::Value lhs,
303                            mlir::Value rhs) {
304   fir::SequenceType seqTy =
305       mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy());
306   // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
307   if (!seqTy || seqTy.hasUnknownShape())
308     TODO(loc, "Unsupported boxed type in OpenMP reduction");
309 
310   // load fir.ref<fir.box<...>>
311   mlir::Value lhsAddr = lhs;
312   lhs = builder.create<fir::LoadOp>(loc, lhs);
313   rhs = builder.create<fir::LoadOp>(loc, rhs);
314 
315   const unsigned rank = seqTy.getDimension();
316   llvm::SmallVector<mlir::Value> extents;
317   extents.reserve(rank);
318   llvm::SmallVector<mlir::Value> lbAndExtents;
319   lbAndExtents.reserve(rank * 2);
320 
321   // Get box lowerbounds and extents:
322   mlir::Type idxTy = builder.getIndexType();
323   for (unsigned i = 0; i < rank; ++i) {
324     // TODO: ideally we want to hoist box reads out of the critical section.
325     // We could do this by having box dimensions in block arguments like
326     // OpenACC does
327     mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
328     auto dimInfo =
329         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim);
330     extents.push_back(dimInfo.getExtent());
331     lbAndExtents.push_back(dimInfo.getLowerBound());
332     lbAndExtents.push_back(dimInfo.getExtent());
333   }
334 
335   auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank);
336   auto shapeShift =
337       builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
338 
339   // Iterate over array elements, applying the equivalent scalar reduction:
340 
341   // A hlfir::elemental here gets inlined with a temporary so create the
342   // loop nest directly.
343   // This function already controls all of the code in this region so we
344   // know this won't miss any opportuinties for clever elemental inlining
345   hlfir::LoopNest nest =
346       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
347   builder.setInsertionPointToStart(nest.innerLoop.getBody());
348   mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
349   auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
350       loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
351       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
352   auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
353       loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
354       nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
355   auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
356   auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
357   mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
358       builder, loc, redId, refTy, lhsEle, rhsEle);
359   builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
360 
361   builder.setInsertionPointAfter(nest.outerLoop);
362   builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
363 }
364 
365 // generate combiner region for reduction operations
366 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
367                         ReductionProcessor::ReductionIdentifier redId,
368                         mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
369                         bool isByRef) {
370   ty = fir::unwrapRefType(ty);
371 
372   if (fir::isa_trivial(ty)) {
373     mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
374     mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
375 
376     mlir::Value result = ReductionProcessor::createScalarCombiner(
377         builder, loc, redId, ty, lhsLoaded, rhsLoaded);
378     if (isByRef) {
379       builder.create<fir::StoreOp>(loc, result, lhs);
380       builder.create<mlir::omp::YieldOp>(loc, lhs);
381     } else {
382       builder.create<mlir::omp::YieldOp>(loc, result);
383     }
384     return;
385   }
386   // all arrays should have been boxed
387   if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
388     genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
389     return;
390   }
391 
392   TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
393 }
394 
395 static void
396 createReductionCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc,
397                              mlir::omp::DeclareReductionOp &reductionDecl) {
398   mlir::Type redTy = reductionDecl.getType();
399 
400   mlir::Region &cleanupRegion = reductionDecl.getCleanupRegion();
401   assert(cleanupRegion.empty());
402   mlir::Block *block =
403       builder.createBlock(&cleanupRegion, cleanupRegion.end(), {redTy}, {loc});
404   builder.setInsertionPointToEnd(block);
405 
406   auto typeError = [loc]() {
407     fir::emitFatalError(loc,
408                         "Attempt to create an omp reduction cleanup region "
409                         "for a type that wasn't allocated",
410                         /*genCrashDiag=*/true);
411   };
412 
413   mlir::Type valTy = fir::unwrapRefType(redTy);
414   if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(valTy)) {
415     mlir::Type innerTy = fir::extractSequenceType(boxTy);
416     if (!mlir::isa<fir::SequenceType>(innerTy))
417       typeError();
418 
419     mlir::Value arg = block->getArgument(0);
420     arg = builder.loadIfRef(loc, arg);
421     assert(mlir::isa<fir::BaseBoxType>(arg.getType()));
422 
423     // Deallocate box
424     // The FIR type system doesn't nesecarrily know that this is a mutable box
425     // if we allocated the thread local array on the heap to avoid looped stack
426     // allocations.
427     mlir::Value addr =
428         hlfir::genVariableRawAddress(loc, builder, hlfir::Entity{arg});
429     mlir::Value isAllocated = builder.genIsNotNullAddr(loc, addr);
430     fir::IfOp ifOp =
431         builder.create<fir::IfOp>(loc, isAllocated, /*withElseRegion=*/false);
432     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
433 
434     mlir::Value cast = builder.createConvert(
435         loc, fir::HeapType::get(fir::dyn_cast_ptrEleTy(addr.getType())), addr);
436     builder.create<fir::FreeMemOp>(loc, cast);
437 
438     builder.setInsertionPointAfter(ifOp);
439     builder.create<mlir::omp::YieldOp>(loc);
440     return;
441   }
442 
443   typeError();
444 }
445 
446 static mlir::Value
447 createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
448                           mlir::omp::DeclareReductionOp &reductionDecl,
449                           const ReductionProcessor::ReductionIdentifier redId,
450                           mlir::Type type, bool isByRef) {
451   mlir::Type ty = fir::unwrapRefType(type);
452   mlir::Value initValue = ReductionProcessor::getReductionInitValue(
453       loc, fir::unwrapSeqOrBoxedSeqType(ty), redId, builder);
454 
455   if (fir::isa_trivial(ty)) {
456     if (isByRef) {
457       mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
458       builder.createStoreWithConvert(loc, initValue, alloca);
459       return alloca;
460     }
461     // by val
462     return initValue;
463   }
464 
465   // all arrays are boxed
466   if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
467     assert(isByRef && "passing arrays by value is unsupported");
468     // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
469     mlir::Type innerTy = fir::extractSequenceType(boxTy);
470     if (!mlir::isa<fir::SequenceType>(innerTy))
471       TODO(loc, "Unsupported boxed type for reduction");
472     // Create the private copy from the initial fir.box:
473     hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
474 
475     // Allocating on the heap in case the whole reduction is nested inside of a
476     // loop
477     // TODO: compare performance here to using allocas - this could be made to
478     // work by inserting stacksave/stackrestore around the reduction in
479     // openmpirbuilder
480     auto [temp, needsDealloc] = createTempFromMold(loc, builder, source);
481     // if needsDealloc isn't statically false, add cleanup region. TODO: always
482     // do this for allocatable boxes because they might have been re-allocated
483     // in the body of the loop/parallel region
484     std::optional<int64_t> cstNeedsDealloc =
485         fir::getIntIfConstant(needsDealloc);
486     assert(cstNeedsDealloc.has_value() &&
487            "createTempFromMold decides this statically");
488     if (cstNeedsDealloc.has_value() && *cstNeedsDealloc != false) {
489       mlir::OpBuilder::InsertionGuard guard(builder);
490       createReductionCleanupRegion(builder, loc, reductionDecl);
491     }
492 
493     // Put the temporary inside of a box:
494     hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp);
495     builder.create<hlfir::AssignOp>(loc, initValue, box);
496     mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
497     builder.create<fir::StoreOp>(loc, box, boxAlloca);
498     return boxAlloca;
499   }
500 
501   TODO(loc, "createReductionInitRegion for unsupported type");
502 }
503 
504 mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
505     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
506     const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
507     bool isByRef) {
508   mlir::OpBuilder::InsertionGuard guard(builder);
509   mlir::ModuleOp module = builder.getModule();
510 
511   assert(!reductionOpName.empty());
512 
513   auto decl =
514       module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
515   if (decl)
516     return decl;
517 
518   mlir::OpBuilder modBuilder(module.getBodyRegion());
519   mlir::Type valTy = fir::unwrapRefType(type);
520   if (!isByRef)
521     type = valTy;
522 
523   decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
524                                                           type);
525   builder.createBlock(&decl.getInitializerRegion(),
526                       decl.getInitializerRegion().end(), {type}, {loc});
527   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
528 
529   mlir::Value init =
530       createReductionInitRegion(builder, loc, decl, redId, type, isByRef);
531   builder.create<mlir::omp::YieldOp>(loc, init);
532 
533   builder.createBlock(&decl.getReductionRegion(),
534                       decl.getReductionRegion().end(), {type, type},
535                       {loc, loc});
536 
537   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
538   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
539   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
540   genCombiner(builder, loc, redId, type, op1, op2, isByRef);
541 
542   return decl;
543 }
544 
545 // TODO: By-ref vs by-val reductions are currently toggled for the whole
546 //       operation (possibly effecting multiple reduction variables).
547 //       This could cause a problem with openmp target reductions because
548 //       by-ref trivial types may not be supported.
549 bool ReductionProcessor::doReductionByRef(
550     const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
551   if (reductionVars.empty())
552     return false;
553   if (forceByrefReduction)
554     return true;
555 
556   for (mlir::Value reductionVar : reductionVars) {
557     if (auto declare =
558             mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
559       reductionVar = declare.getMemref();
560 
561     if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
562       return true;
563   }
564   return false;
565 }
566 
567 void ReductionProcessor::addDeclareReduction(
568     mlir::Location currentLocation,
569     Fortran::lower::AbstractConverter &converter,
570     const omp::clause::Reduction &reduction,
571     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
572     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
573     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
574         *reductionSymbols) {
575   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
576 
577   if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
578           reduction.t))
579     TODO(currentLocation, "Reduction modifiers are not supported");
580 
581   mlir::omp::DeclareReductionOp decl;
582   const auto &redOperatorList{
583       std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
584   assert(redOperatorList.size() == 1 && "Expecting single operator");
585   const auto &redOperator = redOperatorList.front();
586   const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
587 
588   if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
589     if (const auto *reductionIntrinsic =
590             std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
591       if (!ReductionProcessor::supportedIntrinsicProcReduction(
592               *reductionIntrinsic)) {
593         return;
594       }
595     } else {
596       return;
597     }
598   }
599 
600   // initial pass to collect all reduction vars so we can figure out if this
601   // should happen byref
602   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
603   for (const Object &object : objectList) {
604     const Fortran::semantics::Symbol *symbol = object.id();
605     if (reductionSymbols)
606       reductionSymbols->push_back(symbol);
607     mlir::Value symVal = converter.getSymbolAddress(*symbol);
608     mlir::Type eleType;
609     auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
610     if (refType)
611       eleType = refType.getEleTy();
612     else
613       eleType = symVal.getType();
614 
615     // all arrays must be boxed so that we have convenient access to all the
616     // information needed to iterate over the array
617     if (mlir::isa<fir::SequenceType>(eleType)) {
618       // For Host associated symbols, use `SymbolBox` instead
619       Fortran::lower::SymbolBox symBox =
620           converter.lookupOneLevelUpSymbol(*symbol);
621       hlfir::Entity entity{symBox.getAddr()};
622       entity = genVariableBox(currentLocation, builder, entity);
623       mlir::Value box = entity.getBase();
624 
625       // Always pass the box by reference so that the OpenMP dialect
626       // verifiers don't need to know anything about fir.box
627       auto alloca =
628           builder.create<fir::AllocaOp>(currentLocation, box.getType());
629       builder.create<fir::StoreOp>(currentLocation, box, alloca);
630 
631       symVal = alloca;
632     } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
633       // boxed arrays are passed as values not by reference. Unfortunately,
634       // we can't pass a box by value to omp.redution_declare, so turn it
635       // into a reference
636 
637       auto alloca =
638           builder.create<fir::AllocaOp>(currentLocation, symVal.getType());
639       builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
640       symVal = alloca;
641     } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
642       symVal = declOp.getBase();
643     }
644 
645     // this isn't the same as the by-val and by-ref passing later in the
646     // pipeline. Both styles assume that the variable is a reference at
647     // this point
648     assert(mlir::isa<fir::ReferenceType>(symVal.getType()) &&
649            "reduction input var is a reference");
650 
651     reductionVars.push_back(symVal);
652   }
653   const bool isByRef = doReductionByRef(reductionVars);
654 
655   if (const auto &redDefinedOp =
656           std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
657     const auto &intrinsicOp{
658         std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
659             redDefinedOp->u)};
660     ReductionIdentifier redId = getReductionType(intrinsicOp);
661     switch (redId) {
662     case ReductionIdentifier::ADD:
663     case ReductionIdentifier::MULTIPLY:
664     case ReductionIdentifier::AND:
665     case ReductionIdentifier::EQV:
666     case ReductionIdentifier::OR:
667     case ReductionIdentifier::NEQV:
668       break;
669     default:
670       TODO(currentLocation,
671            "Reduction of some intrinsic operators is not supported");
672       break;
673     }
674 
675     for (mlir::Value symVal : reductionVars) {
676       auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
677       const auto &kindMap = firOpBuilder.getKindMap();
678       if (redType.getEleTy().isa<fir::LogicalType>())
679         decl = createDeclareReduction(firOpBuilder,
680                                       getReductionName(intrinsicOp, kindMap,
681                                                        firOpBuilder.getI1Type(),
682                                                        isByRef),
683                                       redId, redType, currentLocation, isByRef);
684       else
685         decl = createDeclareReduction(
686             firOpBuilder,
687             getReductionName(intrinsicOp, kindMap, redType, isByRef), redId,
688             redType, currentLocation, isByRef);
689       reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
690           firOpBuilder.getContext(), decl.getSymName()));
691     }
692   } else if (const auto *reductionIntrinsic =
693                  std::get_if<omp::clause::ProcedureDesignator>(
694                      &redOperator.u)) {
695     if (ReductionProcessor::supportedIntrinsicProcReduction(
696             *reductionIntrinsic)) {
697       ReductionProcessor::ReductionIdentifier redId =
698           ReductionProcessor::getReductionType(*reductionIntrinsic);
699       for (const Object &object : objectList) {
700         const Fortran::semantics::Symbol *symbol = object.id();
701         mlir::Value symVal = converter.getSymbolAddress(*symbol);
702         if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
703           symVal = declOp.getBase();
704         auto redType = symVal.getType().cast<fir::ReferenceType>();
705         if (!redType.getEleTy().isIntOrIndexOrFloat())
706           TODO(currentLocation, "User Defined Reduction on non-trivial type");
707         decl = createDeclareReduction(
708             firOpBuilder,
709             getReductionName(getRealName(*reductionIntrinsic).ToString(),
710                              firOpBuilder.getKindMap(), redType, isByRef),
711             redId, redType, currentLocation, isByRef);
712         reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
713             firOpBuilder.getContext(), decl.getSymName()));
714       }
715     }
716   }
717 }
718 
719 const Fortran::semantics::SourceName
720 ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) {
721   return symbol->GetUltimate().name();
722 }
723 
724 const Fortran::semantics::SourceName
725 ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
726   return getRealName(pd.v.id());
727 }
728 
729 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
730                                              mlir::Location loc) {
731   switch (redId) {
732   case ReductionIdentifier::ADD:
733   case ReductionIdentifier::OR:
734   case ReductionIdentifier::NEQV:
735     return 0;
736   case ReductionIdentifier::MULTIPLY:
737   case ReductionIdentifier::AND:
738   case ReductionIdentifier::EQV:
739     return 1;
740   default:
741     TODO(loc, "Reduction of some intrinsic operators is not supported");
742   }
743 }
744 
745 } // namespace omp
746 } // namespace lower
747 } // namespace Fortran
748