xref: /llvm-project/flang/lib/Optimizer/Builder/TemporaryStorage.cpp (revision 0a8d5f4e599fca394610a690e026c0460fc43270)
1 //===-- Optimizer/Builder/TemporaryStorage.cpp ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Implementation of utility data structures to create and manipulate temporary
9 // storages to stack Fortran values or pointers in HLFIR.
10 //===----------------------------------------------------------------------===//
11 
12 #include "flang/Optimizer/Builder/TemporaryStorage.h"
13 #include "flang/Optimizer/Builder/FIRBuilder.h"
14 #include "flang/Optimizer/Builder/HLFIRTools.h"
15 #include "flang/Optimizer/Builder/Runtime/TemporaryStack.h"
16 #include "flang/Optimizer/Builder/Todo.h"
17 #include "flang/Optimizer/HLFIR/HLFIROps.h"
18 
19 //===----------------------------------------------------------------------===//
20 // fir::factory::Counter implementation.
21 //===----------------------------------------------------------------------===//
22 
Counter(mlir::Location loc,fir::FirOpBuilder & builder,mlir::Value initialValue,bool canCountThroughLoops)23 fir::factory::Counter::Counter(mlir::Location loc, fir::FirOpBuilder &builder,
24                                mlir::Value initialValue,
25                                bool canCountThroughLoops)
26     : canCountThroughLoops{canCountThroughLoops}, initialValue{initialValue} {
27   mlir::Type type = initialValue.getType();
28   one = builder.createIntegerConstant(loc, type, 1);
29   if (canCountThroughLoops) {
30     index = builder.createTemporary(loc, type);
31     builder.create<fir::StoreOp>(loc, initialValue, index);
32   } else {
33     index = initialValue;
34   }
35 }
36 
37 mlir::Value
getAndIncrementIndex(mlir::Location loc,fir::FirOpBuilder & builder)38 fir::factory::Counter::getAndIncrementIndex(mlir::Location loc,
39                                             fir::FirOpBuilder &builder) {
40   if (canCountThroughLoops) {
41     mlir::Value indexValue = builder.create<fir::LoadOp>(loc, index);
42     mlir::Value newValue =
43         builder.create<mlir::arith::AddIOp>(loc, indexValue, one);
44     builder.create<fir::StoreOp>(loc, newValue, index);
45     return indexValue;
46   }
47   mlir::Value indexValue = index;
48   index = builder.create<mlir::arith::AddIOp>(loc, indexValue, one);
49   return indexValue;
50 }
51 
reset(mlir::Location loc,fir::FirOpBuilder & builder)52 void fir::factory::Counter::reset(mlir::Location loc,
53                                   fir::FirOpBuilder &builder) {
54   if (canCountThroughLoops)
55     builder.create<fir::StoreOp>(loc, initialValue, index);
56   else
57     index = initialValue;
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // fir::factory::HomogeneousScalarStack implementation.
62 //===----------------------------------------------------------------------===//
63 
HomogeneousScalarStack(mlir::Location loc,fir::FirOpBuilder & builder,fir::SequenceType declaredType,mlir::Value extent,llvm::ArrayRef<mlir::Value> lengths,bool allocateOnHeap,bool stackThroughLoops,llvm::StringRef tempName)64 fir::factory::HomogeneousScalarStack::HomogeneousScalarStack(
65     mlir::Location loc, fir::FirOpBuilder &builder,
66     fir::SequenceType declaredType, mlir::Value extent,
67     llvm::ArrayRef<mlir::Value> lengths, bool allocateOnHeap,
68     bool stackThroughLoops, llvm::StringRef tempName)
69     : allocateOnHeap{allocateOnHeap},
70       counter{loc, builder,
71               builder.createIntegerConstant(loc, builder.getIndexType(), 1),
72               stackThroughLoops} {
73   // Allocate the temporary storage.
74   llvm::SmallVector<mlir::Value, 1> extents{extent};
75   mlir::Value tempStorage;
76   if (allocateOnHeap)
77     tempStorage = builder.createHeapTemporary(loc, declaredType, tempName,
78                                               extents, lengths);
79   else
80     tempStorage =
81         builder.createTemporary(loc, declaredType, tempName, extents, lengths);
82 
83   mlir::Value shape = builder.genShape(loc, extents);
84   temp = builder
85              .create<hlfir::DeclareOp>(loc, tempStorage, tempName, shape,
86                                        lengths, /*dummy_scope=*/nullptr,
87                                        fir::FortranVariableFlagsAttr{})
88              .getBase();
89 }
90 
pushValue(mlir::Location loc,fir::FirOpBuilder & builder,mlir::Value value)91 void fir::factory::HomogeneousScalarStack::pushValue(mlir::Location loc,
92                                                      fir::FirOpBuilder &builder,
93                                                      mlir::Value value) {
94   hlfir::Entity entity{value};
95   assert(entity.isScalar() && "cannot use inlined temp with array");
96   mlir::Value indexValue = counter.getAndIncrementIndex(loc, builder);
97   hlfir::Entity tempElement = hlfir::getElementAt(
98       loc, builder, hlfir::Entity{temp}, mlir::ValueRange{indexValue});
99   // TODO: "copy" would probably be better than assign to ensure there are no
100   // side effects (user assignments, temp, lhs finalization)?
101   // This only makes a difference for derived types, and for now derived types
102   // will use the runtime strategy to avoid any bad behaviors. So the todo
103   // below should not get hit but is added as a remainder/safety.
104   if (!entity.hasIntrinsicType())
105     TODO(loc, "creating inlined temporary stack for derived types");
106   builder.create<hlfir::AssignOp>(loc, value, tempElement);
107 }
108 
resetFetchPosition(mlir::Location loc,fir::FirOpBuilder & builder)109 void fir::factory::HomogeneousScalarStack::resetFetchPosition(
110     mlir::Location loc, fir::FirOpBuilder &builder) {
111   counter.reset(loc, builder);
112 }
113 
114 mlir::Value
fetch(mlir::Location loc,fir::FirOpBuilder & builder)115 fir::factory::HomogeneousScalarStack::fetch(mlir::Location loc,
116                                             fir::FirOpBuilder &builder) {
117   mlir::Value indexValue = counter.getAndIncrementIndex(loc, builder);
118   hlfir::Entity tempElement = hlfir::getElementAt(
119       loc, builder, hlfir::Entity{temp}, mlir::ValueRange{indexValue});
120   return hlfir::loadTrivialScalar(loc, builder, tempElement);
121 }
122 
destroy(mlir::Location loc,fir::FirOpBuilder & builder)123 void fir::factory::HomogeneousScalarStack::destroy(mlir::Location loc,
124                                                    fir::FirOpBuilder &builder) {
125   if (allocateOnHeap) {
126     auto declare = temp.getDefiningOp<hlfir::DeclareOp>();
127     assert(declare && "temp must have been declared");
128     builder.create<fir::FreeMemOp>(loc, declare.getMemref());
129   }
130 }
131 
moveStackAsArrayExpr(mlir::Location loc,fir::FirOpBuilder & builder)132 hlfir::Entity fir::factory::HomogeneousScalarStack::moveStackAsArrayExpr(
133     mlir::Location loc, fir::FirOpBuilder &builder) {
134   mlir::Value mustFree = builder.createBool(loc, allocateOnHeap);
135   auto hlfirExpr = builder.create<hlfir::AsExprOp>(loc, temp, mustFree);
136   return hlfir::Entity{hlfirExpr};
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // fir::factory::SimpleCopy implementation.
141 //===----------------------------------------------------------------------===//
142 
SimpleCopy(mlir::Location loc,fir::FirOpBuilder & builder,hlfir::Entity source,llvm::StringRef tempName)143 fir::factory::SimpleCopy::SimpleCopy(mlir::Location loc,
144                                      fir::FirOpBuilder &builder,
145                                      hlfir::Entity source,
146                                      llvm::StringRef tempName) {
147   // Use hlfir.as_expr and hlfir.associate to create a copy and leave
148   // bufferization deals with how best to make the copy.
149   if (source.isVariable())
150     source = hlfir::Entity{builder.create<hlfir::AsExprOp>(loc, source)};
151   copy = hlfir::genAssociateExpr(loc, builder, source,
152                                  source.getFortranElementType(), tempName);
153 }
154 
destroy(mlir::Location loc,fir::FirOpBuilder & builder)155 void fir::factory::SimpleCopy::destroy(mlir::Location loc,
156                                        fir::FirOpBuilder &builder) {
157   builder.create<hlfir::EndAssociateOp>(loc, copy);
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // fir::factory::AnyValueStack implementation.
162 //===----------------------------------------------------------------------===//
163 
AnyValueStack(mlir::Location loc,fir::FirOpBuilder & builder,mlir::Type valueStaticType)164 fir::factory::AnyValueStack::AnyValueStack(mlir::Location loc,
165                                            fir::FirOpBuilder &builder,
166                                            mlir::Type valueStaticType)
167     : valueStaticType{valueStaticType},
168       counter{loc, builder,
169               builder.createIntegerConstant(loc, builder.getI64Type(), 0),
170               /*stackThroughLoops=*/true} {
171   opaquePtr = fir::runtime::genCreateValueStack(loc, builder);
172   // Compute the storage type. I1 are stored as fir.logical<1>. This is required
173   // to use descriptor.
174   mlir::Type storageType =
175       hlfir::getFortranElementOrSequenceType(valueStaticType);
176   mlir::Type i1Type = builder.getI1Type();
177   if (storageType == i1Type)
178     storageType = fir::LogicalType::get(builder.getContext(), 1);
179   assert(hlfir::getFortranElementType(storageType) != i1Type &&
180          "array of i1 should not be used");
181   mlir::Type heapType = fir::HeapType::get(storageType);
182   mlir::Type boxType;
183   if (hlfir::isPolymorphicType(valueStaticType))
184     boxType = fir::ClassType::get(heapType);
185   else
186     boxType = fir::BoxType::get(heapType);
187   retValueBox = builder.createTemporary(loc, boxType);
188 }
189 
pushValue(mlir::Location loc,fir::FirOpBuilder & builder,mlir::Value value)190 void fir::factory::AnyValueStack::pushValue(mlir::Location loc,
191                                             fir::FirOpBuilder &builder,
192                                             mlir::Value value) {
193   hlfir::Entity entity{value};
194   mlir::Type storageElementType =
195       hlfir::getFortranElementType(retValueBox.getType());
196   auto [box, maybeCleanUp] =
197       hlfir::convertToBox(loc, builder, entity, storageElementType);
198   fir::runtime::genPushValue(loc, builder, opaquePtr, fir::getBase(box));
199   if (maybeCleanUp)
200     (*maybeCleanUp)();
201 }
202 
resetFetchPosition(mlir::Location loc,fir::FirOpBuilder & builder)203 void fir::factory::AnyValueStack::resetFetchPosition(
204     mlir::Location loc, fir::FirOpBuilder &builder) {
205   counter.reset(loc, builder);
206 }
207 
fetch(mlir::Location loc,fir::FirOpBuilder & builder)208 mlir::Value fir::factory::AnyValueStack::fetch(mlir::Location loc,
209                                                fir::FirOpBuilder &builder) {
210   mlir::Value indexValue = counter.getAndIncrementIndex(loc, builder);
211   fir::runtime::genValueAt(loc, builder, opaquePtr, indexValue, retValueBox);
212   // Dereference the allocatable "retValueBox", and load if trivial scalar
213   // value.
214   mlir::Value result =
215       hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{retValueBox});
216   if (valueStaticType != result.getType()) {
217     // Cast back saved simple scalars stored with another type to their original
218     // type (like i1).
219     if (fir::isa_trivial(valueStaticType))
220       return builder.createConvert(loc, valueStaticType, result);
221     // Memory type mismatches (e.g. fir.ref vs fir.heap) or hlfir.expr vs
222     // variable type mismatches are OK, but the base Fortran type must be the
223     // same.
224     assert(hlfir::getFortranElementOrSequenceType(valueStaticType) ==
225                hlfir::getFortranElementOrSequenceType(result.getType()) &&
226            "non trivial values must be saved with their original type");
227   }
228   return result;
229 }
230 
destroy(mlir::Location loc,fir::FirOpBuilder & builder)231 void fir::factory::AnyValueStack::destroy(mlir::Location loc,
232                                           fir::FirOpBuilder &builder) {
233   fir::runtime::genDestroyValueStack(loc, builder, opaquePtr);
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // fir::factory::AnyVariableStack implementation.
238 //===----------------------------------------------------------------------===//
239 
AnyVariableStack(mlir::Location loc,fir::FirOpBuilder & builder,mlir::Type variableStaticType)240 fir::factory::AnyVariableStack::AnyVariableStack(mlir::Location loc,
241                                                  fir::FirOpBuilder &builder,
242                                                  mlir::Type variableStaticType)
243     : variableStaticType{variableStaticType},
244       counter{loc, builder,
245               builder.createIntegerConstant(loc, builder.getI64Type(), 0),
246               /*stackThroughLoops=*/true} {
247   opaquePtr = fir::runtime::genCreateDescriptorStack(loc, builder);
248   mlir::Type storageType =
249       hlfir::getFortranElementOrSequenceType(variableStaticType);
250   mlir::Type ptrType = fir::PointerType::get(storageType);
251   mlir::Type boxType;
252   if (hlfir::isPolymorphicType(variableStaticType))
253     boxType = fir::ClassType::get(ptrType);
254   else
255     boxType = fir::BoxType::get(ptrType);
256   retValueBox = builder.createTemporary(loc, boxType);
257 }
258 
pushValue(mlir::Location loc,fir::FirOpBuilder & builder,mlir::Value variable)259 void fir::factory::AnyVariableStack::pushValue(mlir::Location loc,
260                                                fir::FirOpBuilder &builder,
261                                                mlir::Value variable) {
262   hlfir::Entity entity{variable};
263   mlir::Type storageElementType =
264       hlfir::getFortranElementType(retValueBox.getType());
265   auto [box, maybeCleanUp] =
266       hlfir::convertToBox(loc, builder, entity, storageElementType);
267   fir::runtime::genPushDescriptor(loc, builder, opaquePtr, fir::getBase(box));
268   if (maybeCleanUp)
269     (*maybeCleanUp)();
270 }
271 
resetFetchPosition(mlir::Location loc,fir::FirOpBuilder & builder)272 void fir::factory::AnyVariableStack::resetFetchPosition(
273     mlir::Location loc, fir::FirOpBuilder &builder) {
274   counter.reset(loc, builder);
275 }
276 
fetch(mlir::Location loc,fir::FirOpBuilder & builder)277 mlir::Value fir::factory::AnyVariableStack::fetch(mlir::Location loc,
278                                                   fir::FirOpBuilder &builder) {
279   mlir::Value indexValue = counter.getAndIncrementIndex(loc, builder);
280   fir::runtime::genDescriptorAt(loc, builder, opaquePtr, indexValue,
281                                 retValueBox);
282   hlfir::Entity retBox{builder.create<fir::LoadOp>(loc, retValueBox)};
283   // The runtime always tracks variable as address, but the form of the variable
284   // that was saved may be different (raw address, fir.boxchar), ensure
285   // the returned variable has the same form of the one that was saved.
286   if (mlir::isa<fir::BaseBoxType>(variableStaticType))
287     return builder.createConvert(loc, variableStaticType, retBox);
288   if (mlir::isa<fir::BoxCharType>(variableStaticType))
289     return hlfir::genVariableBoxChar(loc, builder, retBox);
290   mlir::Value rawAddr = genVariableRawAddress(loc, builder, retBox);
291   return builder.createConvert(loc, variableStaticType, rawAddr);
292 }
293 
destroy(mlir::Location loc,fir::FirOpBuilder & builder)294 void fir::factory::AnyVariableStack::destroy(mlir::Location loc,
295                                              fir::FirOpBuilder &builder) {
296   fir::runtime::genDestroyDescriptorStack(loc, builder, opaquePtr);
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // fir::factory::AnyVectorSubscriptStack implementation.
301 //===----------------------------------------------------------------------===//
302 
AnyVectorSubscriptStack(mlir::Location loc,fir::FirOpBuilder & builder,mlir::Type variableStaticType,bool shapeCanBeSavedAsRegister,int rank)303 fir::factory::AnyVectorSubscriptStack::AnyVectorSubscriptStack(
304     mlir::Location loc, fir::FirOpBuilder &builder,
305     mlir::Type variableStaticType, bool shapeCanBeSavedAsRegister, int rank)
306     : AnyVariableStack{loc, builder, variableStaticType} {
307   if (shapeCanBeSavedAsRegister) {
308     shapeTemp = std::make_unique<TemporaryStorage>(SSARegister{});
309     return;
310   }
311   // The shape will be tracked as the dimension inside a descriptor because
312   // that is the easiest from a lowering point of view, and this is an
313   // edge case situation that will probably not very well be exercised.
314   mlir::Type type =
315       fir::BoxType::get(builder.getVarLenSeqTy(builder.getI32Type(), rank));
316   boxType = type;
317   shapeTemp =
318       std::make_unique<TemporaryStorage>(AnyVariableStack{loc, builder, type});
319 }
320 
pushShape(mlir::Location loc,fir::FirOpBuilder & builder,mlir::Value shape)321 void fir::factory::AnyVectorSubscriptStack::pushShape(
322     mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value shape) {
323   if (boxType) {
324     // The shape is saved as a dimensions inside a descriptors.
325     mlir::Type refType = fir::ReferenceType::get(
326         hlfir::getFortranElementOrSequenceType(*boxType));
327     mlir::Value null = builder.createNullConstant(loc, refType);
328     mlir::Value descriptor =
329         builder.create<fir::EmboxOp>(loc, *boxType, null, shape);
330     shapeTemp->pushValue(loc, builder, descriptor);
331     return;
332   }
333   // Otherwise, simply keep track of the fir.shape itself, it is invariant.
334   shapeTemp->cast<SSARegister>().pushValue(loc, builder, shape);
335 }
336 
resetFetchPosition(mlir::Location loc,fir::FirOpBuilder & builder)337 void fir::factory::AnyVectorSubscriptStack::resetFetchPosition(
338     mlir::Location loc, fir::FirOpBuilder &builder) {
339   static_cast<AnyVariableStack *>(this)->resetFetchPosition(loc, builder);
340   shapeTemp->resetFetchPosition(loc, builder);
341 }
342 
343 mlir::Value
fetchShape(mlir::Location loc,fir::FirOpBuilder & builder)344 fir::factory::AnyVectorSubscriptStack::fetchShape(mlir::Location loc,
345                                                   fir::FirOpBuilder &builder) {
346   if (boxType) {
347     hlfir::Entity descriptor{shapeTemp->fetch(loc, builder)};
348     return hlfir::genShape(loc, builder, descriptor);
349   }
350   return shapeTemp->cast<SSARegister>().fetch(loc, builder);
351 }
352 
destroy(mlir::Location loc,fir::FirOpBuilder & builder)353 void fir::factory::AnyVectorSubscriptStack::destroy(
354     mlir::Location loc, fir::FirOpBuilder &builder) {
355   static_cast<AnyVariableStack *>(this)->destroy(loc, builder);
356   shapeTemp->destroy(loc, builder);
357 }
358