xref: /llvm-project/flang/lib/Lower/OpenMP/PrivateReductionUtils.cpp (revision 8557a57c4b1a228ce63f2409dd5cc4c70a25e6fc)
1*8557a57cSTom Eccles //===-- PrivateReductionUtils.cpp -------------------------------*- C++ -*-===//
2*8557a57cSTom Eccles //
3*8557a57cSTom Eccles // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*8557a57cSTom Eccles // See https://llvm.org/LICENSE.txt for license information.
5*8557a57cSTom Eccles // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*8557a57cSTom Eccles //
7*8557a57cSTom Eccles //===----------------------------------------------------------------------===//
8*8557a57cSTom Eccles //
9*8557a57cSTom Eccles // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10*8557a57cSTom Eccles //
11*8557a57cSTom Eccles //===----------------------------------------------------------------------===//
12*8557a57cSTom Eccles 
13*8557a57cSTom Eccles #include "PrivateReductionUtils.h"
14*8557a57cSTom Eccles 
15*8557a57cSTom Eccles #include "flang/Optimizer/Builder/FIRBuilder.h"
16*8557a57cSTom Eccles #include "flang/Optimizer/Builder/HLFIRTools.h"
17*8557a57cSTom Eccles #include "flang/Optimizer/Builder/Todo.h"
18*8557a57cSTom Eccles #include "flang/Optimizer/HLFIR/HLFIROps.h"
19*8557a57cSTom Eccles #include "flang/Optimizer/Support/FatalError.h"
20*8557a57cSTom Eccles #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21*8557a57cSTom Eccles #include "mlir/IR/Location.h"
22*8557a57cSTom Eccles 
23*8557a57cSTom Eccles static void createCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc,
24*8557a57cSTom Eccles                                 mlir::Type argType,
25*8557a57cSTom Eccles                                 mlir::Region &cleanupRegion) {
26*8557a57cSTom Eccles   assert(cleanupRegion.empty());
27*8557a57cSTom Eccles   mlir::Block *block = builder.createBlock(&cleanupRegion, cleanupRegion.end(),
28*8557a57cSTom Eccles                                            {argType}, {loc});
29*8557a57cSTom Eccles   builder.setInsertionPointToEnd(block);
30*8557a57cSTom Eccles 
31*8557a57cSTom Eccles   auto typeError = [loc]() {
32*8557a57cSTom Eccles     fir::emitFatalError(loc,
33*8557a57cSTom Eccles                         "Attempt to create an omp cleanup region "
34*8557a57cSTom Eccles                         "for a type that wasn't allocated",
35*8557a57cSTom Eccles                         /*genCrashDiag=*/true);
36*8557a57cSTom Eccles   };
37*8557a57cSTom Eccles 
38*8557a57cSTom Eccles   mlir::Type valTy = fir::unwrapRefType(argType);
39*8557a57cSTom Eccles   if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(valTy)) {
40*8557a57cSTom Eccles     if (!mlir::isa<fir::HeapType, fir::PointerType>(boxTy.getEleTy())) {
41*8557a57cSTom Eccles       mlir::Type innerTy = fir::extractSequenceType(boxTy);
42*8557a57cSTom Eccles       if (!mlir::isa<fir::SequenceType>(innerTy))
43*8557a57cSTom Eccles         typeError();
44*8557a57cSTom Eccles     }
45*8557a57cSTom Eccles 
46*8557a57cSTom Eccles     mlir::Value arg = builder.loadIfRef(loc, block->getArgument(0));
47*8557a57cSTom Eccles     assert(mlir::isa<fir::BaseBoxType>(arg.getType()));
48*8557a57cSTom Eccles 
49*8557a57cSTom Eccles     // Deallocate box
50*8557a57cSTom Eccles     // The FIR type system doesn't nesecarrily know that this is a mutable box
51*8557a57cSTom Eccles     // if we allocated the thread local array on the heap to avoid looped stack
52*8557a57cSTom Eccles     // allocations.
53*8557a57cSTom Eccles     mlir::Value addr =
54*8557a57cSTom Eccles         hlfir::genVariableRawAddress(loc, builder, hlfir::Entity{arg});
55*8557a57cSTom Eccles     mlir::Value isAllocated = builder.genIsNotNullAddr(loc, addr);
56*8557a57cSTom Eccles     fir::IfOp ifOp =
57*8557a57cSTom Eccles         builder.create<fir::IfOp>(loc, isAllocated, /*withElseRegion=*/false);
58*8557a57cSTom Eccles     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
59*8557a57cSTom Eccles 
60*8557a57cSTom Eccles     mlir::Value cast = builder.createConvert(
61*8557a57cSTom Eccles         loc, fir::HeapType::get(fir::dyn_cast_ptrEleTy(addr.getType())), addr);
62*8557a57cSTom Eccles     builder.create<fir::FreeMemOp>(loc, cast);
63*8557a57cSTom Eccles 
64*8557a57cSTom Eccles     builder.setInsertionPointAfter(ifOp);
65*8557a57cSTom Eccles     builder.create<mlir::omp::YieldOp>(loc);
66*8557a57cSTom Eccles     return;
67*8557a57cSTom Eccles   }
68*8557a57cSTom Eccles 
69*8557a57cSTom Eccles   typeError();
70*8557a57cSTom Eccles }
71*8557a57cSTom Eccles 
72*8557a57cSTom Eccles fir::ShapeShiftOp Fortran::lower::omp::getShapeShift(fir::FirOpBuilder &builder,
73*8557a57cSTom Eccles                                                      mlir::Location loc,
74*8557a57cSTom Eccles                                                      mlir::Value box) {
75*8557a57cSTom Eccles   fir::SequenceType sequenceType = mlir::cast<fir::SequenceType>(
76*8557a57cSTom Eccles       hlfir::getFortranElementOrSequenceType(box.getType()));
77*8557a57cSTom Eccles   const unsigned rank = sequenceType.getDimension();
78*8557a57cSTom Eccles   llvm::SmallVector<mlir::Value> lbAndExtents;
79*8557a57cSTom Eccles   lbAndExtents.reserve(rank * 2);
80*8557a57cSTom Eccles 
81*8557a57cSTom Eccles   mlir::Type idxTy = builder.getIndexType();
82*8557a57cSTom Eccles   for (unsigned i = 0; i < rank; ++i) {
83*8557a57cSTom Eccles     // TODO: ideally we want to hoist box reads out of the critical section.
84*8557a57cSTom Eccles     // We could do this by having box dimensions in block arguments like
85*8557a57cSTom Eccles     // OpenACC does
86*8557a57cSTom Eccles     mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
87*8557a57cSTom Eccles     auto dimInfo =
88*8557a57cSTom Eccles         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, dim);
89*8557a57cSTom Eccles     lbAndExtents.push_back(dimInfo.getLowerBound());
90*8557a57cSTom Eccles     lbAndExtents.push_back(dimInfo.getExtent());
91*8557a57cSTom Eccles   }
92*8557a57cSTom Eccles 
93*8557a57cSTom Eccles   auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank);
94*8557a57cSTom Eccles   auto shapeShift =
95*8557a57cSTom Eccles       builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
96*8557a57cSTom Eccles   return shapeShift;
97*8557a57cSTom Eccles }
98*8557a57cSTom Eccles 
99*8557a57cSTom Eccles void Fortran::lower::omp::populateByRefInitAndCleanupRegions(
100*8557a57cSTom Eccles     fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type argType,
101*8557a57cSTom Eccles     mlir::Value scalarInitValue, mlir::Block *initBlock,
102*8557a57cSTom Eccles     mlir::Value allocatedPrivVarArg, mlir::Value moldArg,
103*8557a57cSTom Eccles     mlir::Region &cleanupRegion) {
104*8557a57cSTom Eccles   mlir::Type ty = fir::unwrapRefType(argType);
105*8557a57cSTom Eccles   builder.setInsertionPointToEnd(initBlock);
106*8557a57cSTom Eccles   auto yield = [&](mlir::Value ret) {
107*8557a57cSTom Eccles     builder.create<mlir::omp::YieldOp>(loc, ret);
108*8557a57cSTom Eccles   };
109*8557a57cSTom Eccles 
110*8557a57cSTom Eccles   if (fir::isa_trivial(ty)) {
111*8557a57cSTom Eccles     builder.setInsertionPointToEnd(initBlock);
112*8557a57cSTom Eccles 
113*8557a57cSTom Eccles     if (scalarInitValue)
114*8557a57cSTom Eccles       builder.createStoreWithConvert(loc, scalarInitValue, allocatedPrivVarArg);
115*8557a57cSTom Eccles     yield(allocatedPrivVarArg);
116*8557a57cSTom Eccles     return;
117*8557a57cSTom Eccles   }
118*8557a57cSTom Eccles 
119*8557a57cSTom Eccles   // check if an allocatable box is unallocated. If so, initialize the boxAlloca
120*8557a57cSTom Eccles   // to be unallocated e.g.
121*8557a57cSTom Eccles   // %box_alloca = fir.alloca !fir.box<!fir.heap<...>>
122*8557a57cSTom Eccles   // %addr = fir.box_addr %box
123*8557a57cSTom Eccles   // if (%addr == 0) {
124*8557a57cSTom Eccles   //   %nullbox = fir.embox %addr
125*8557a57cSTom Eccles   //   fir.store %nullbox to %box_alloca
126*8557a57cSTom Eccles   // } else {
127*8557a57cSTom Eccles   //   // ...
128*8557a57cSTom Eccles   //   fir.store %something to %box_alloca
129*8557a57cSTom Eccles   // }
130*8557a57cSTom Eccles   // omp.yield %box_alloca
131*8557a57cSTom Eccles   moldArg = builder.loadIfRef(loc, moldArg);
132*8557a57cSTom Eccles   auto handleNullAllocatable = [&](mlir::Value boxAlloca) -> fir::IfOp {
133*8557a57cSTom Eccles     mlir::Value addr = builder.create<fir::BoxAddrOp>(loc, moldArg);
134*8557a57cSTom Eccles     mlir::Value isNotAllocated = builder.genIsNullAddr(loc, addr);
135*8557a57cSTom Eccles     fir::IfOp ifOp = builder.create<fir::IfOp>(loc, isNotAllocated,
136*8557a57cSTom Eccles                                                /*withElseRegion=*/true);
137*8557a57cSTom Eccles     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
138*8557a57cSTom Eccles     // just embox the null address and return
139*8557a57cSTom Eccles     mlir::Value nullBox = builder.create<fir::EmboxOp>(loc, ty, addr);
140*8557a57cSTom Eccles     builder.create<fir::StoreOp>(loc, nullBox, boxAlloca);
141*8557a57cSTom Eccles     return ifOp;
142*8557a57cSTom Eccles   };
143*8557a57cSTom Eccles 
144*8557a57cSTom Eccles   // all arrays are boxed
145*8557a57cSTom Eccles   if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
146*8557a57cSTom Eccles     bool isAllocatableOrPointer =
147*8557a57cSTom Eccles         mlir::isa<fir::HeapType, fir::PointerType>(boxTy.getEleTy());
148*8557a57cSTom Eccles 
149*8557a57cSTom Eccles     builder.setInsertionPointToEnd(initBlock);
150*8557a57cSTom Eccles     mlir::Value boxAlloca = allocatedPrivVarArg;
151*8557a57cSTom Eccles     mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
152*8557a57cSTom Eccles     if (fir::isa_trivial(innerTy)) {
153*8557a57cSTom Eccles       // boxed non-sequence value e.g. !fir.box<!fir.heap<i32>>
154*8557a57cSTom Eccles       if (!isAllocatableOrPointer)
155*8557a57cSTom Eccles         TODO(loc,
156*8557a57cSTom Eccles              "Reduction/Privatization of non-allocatable trivial typed box");
157*8557a57cSTom Eccles 
158*8557a57cSTom Eccles       fir::IfOp ifUnallocated = handleNullAllocatable(boxAlloca);
159*8557a57cSTom Eccles 
160*8557a57cSTom Eccles       builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front());
161*8557a57cSTom Eccles       mlir::Value valAlloc = builder.create<fir::AllocMemOp>(loc, innerTy);
162*8557a57cSTom Eccles       if (scalarInitValue)
163*8557a57cSTom Eccles         builder.createStoreWithConvert(loc, scalarInitValue, valAlloc);
164*8557a57cSTom Eccles       mlir::Value box = builder.create<fir::EmboxOp>(loc, ty, valAlloc);
165*8557a57cSTom Eccles       builder.create<fir::StoreOp>(loc, box, boxAlloca);
166*8557a57cSTom Eccles 
167*8557a57cSTom Eccles       createCleanupRegion(builder, loc, argType, cleanupRegion);
168*8557a57cSTom Eccles       builder.setInsertionPointAfter(ifUnallocated);
169*8557a57cSTom Eccles       yield(boxAlloca);
170*8557a57cSTom Eccles       return;
171*8557a57cSTom Eccles     }
172*8557a57cSTom Eccles     innerTy = fir::extractSequenceType(boxTy);
173*8557a57cSTom Eccles     if (!mlir::isa<fir::SequenceType>(innerTy))
174*8557a57cSTom Eccles       TODO(loc, "Unsupported boxed type for reduction/privatization");
175*8557a57cSTom Eccles 
176*8557a57cSTom Eccles     fir::IfOp ifUnallocated{nullptr};
177*8557a57cSTom Eccles     if (isAllocatableOrPointer) {
178*8557a57cSTom Eccles       ifUnallocated = handleNullAllocatable(boxAlloca);
179*8557a57cSTom Eccles       builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front());
180*8557a57cSTom Eccles     }
181*8557a57cSTom Eccles 
182*8557a57cSTom Eccles     // Create the private copy from the initial fir.box:
183*8557a57cSTom Eccles     mlir::Value loadedBox = builder.loadIfRef(loc, moldArg);
184*8557a57cSTom Eccles     hlfir::Entity source = hlfir::Entity{loadedBox};
185*8557a57cSTom Eccles 
186*8557a57cSTom Eccles     // Allocating on the heap in case the whole reduction is nested inside of a
187*8557a57cSTom Eccles     // loop
188*8557a57cSTom Eccles     // TODO: compare performance here to using allocas - this could be made to
189*8557a57cSTom Eccles     // work by inserting stacksave/stackrestore around the reduction in
190*8557a57cSTom Eccles     // openmpirbuilder
191*8557a57cSTom Eccles     auto [temp, needsDealloc] = createTempFromMold(loc, builder, source);
192*8557a57cSTom Eccles     // if needsDealloc isn't statically false, add cleanup region. Always
193*8557a57cSTom Eccles     // do this for allocatable boxes because they might have been re-allocated
194*8557a57cSTom Eccles     // in the body of the loop/parallel region
195*8557a57cSTom Eccles 
196*8557a57cSTom Eccles     std::optional<int64_t> cstNeedsDealloc =
197*8557a57cSTom Eccles         fir::getIntIfConstant(needsDealloc);
198*8557a57cSTom Eccles     assert(cstNeedsDealloc.has_value() &&
199*8557a57cSTom Eccles            "createTempFromMold decides this statically");
200*8557a57cSTom Eccles     if (cstNeedsDealloc.has_value() && *cstNeedsDealloc != false) {
201*8557a57cSTom Eccles       mlir::OpBuilder::InsertionGuard guard(builder);
202*8557a57cSTom Eccles       createCleanupRegion(builder, loc, argType, cleanupRegion);
203*8557a57cSTom Eccles     } else {
204*8557a57cSTom Eccles       assert(!isAllocatableOrPointer &&
205*8557a57cSTom Eccles              "Pointer-like arrays must be heap allocated");
206*8557a57cSTom Eccles     }
207*8557a57cSTom Eccles 
208*8557a57cSTom Eccles     // Put the temporary inside of a box:
209*8557a57cSTom Eccles     // hlfir::genVariableBox doesn't handle non-default lower bounds
210*8557a57cSTom Eccles     mlir::Value box;
211*8557a57cSTom Eccles     fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, loadedBox);
212*8557a57cSTom Eccles     mlir::Type boxType = loadedBox.getType();
213*8557a57cSTom Eccles     if (mlir::isa<fir::BaseBoxType>(temp.getType()))
214*8557a57cSTom Eccles       // the box created by the declare form createTempFromMold is missing lower
215*8557a57cSTom Eccles       // bounds info
216*8557a57cSTom Eccles       box = builder.create<fir::ReboxOp>(loc, boxType, temp, shapeShift,
217*8557a57cSTom Eccles                                          /*shift=*/mlir::Value{});
218*8557a57cSTom Eccles     else
219*8557a57cSTom Eccles       box = builder.create<fir::EmboxOp>(
220*8557a57cSTom Eccles           loc, boxType, temp, shapeShift,
221*8557a57cSTom Eccles           /*slice=*/mlir::Value{},
222*8557a57cSTom Eccles           /*typeParams=*/llvm::ArrayRef<mlir::Value>{});
223*8557a57cSTom Eccles 
224*8557a57cSTom Eccles     if (scalarInitValue)
225*8557a57cSTom Eccles       builder.create<hlfir::AssignOp>(loc, scalarInitValue, box);
226*8557a57cSTom Eccles     builder.create<fir::StoreOp>(loc, box, boxAlloca);
227*8557a57cSTom Eccles     if (ifUnallocated)
228*8557a57cSTom Eccles       builder.setInsertionPointAfter(ifUnallocated);
229*8557a57cSTom Eccles     yield(boxAlloca);
230*8557a57cSTom Eccles     return;
231*8557a57cSTom Eccles   }
232*8557a57cSTom Eccles 
233*8557a57cSTom Eccles   TODO(loc,
234*8557a57cSTom Eccles        "creating reduction/privatization init region for unsupported type");
235*8557a57cSTom Eccles   return;
236*8557a57cSTom Eccles }
237