// xref: /llvm-project/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (revision 79e788d02eefdacb08af365389b9055518f3fad6)
1 //===-- BoxedProcedure.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/CodeGen/CodeGen.h"
10 
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Dialect/FIROps.h"
15 #include "flang/Optimizer/Dialect/FIRType.h"
16 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
17 #include "flang/Optimizer/Support/FatalError.h"
18 #include "flang/Optimizer/Support/InternalNames.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/DenseMap.h"
23 
24 namespace fir {
25 #define GEN_PASS_DEF_BOXEDPROCEDUREPASS
26 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
27 } // namespace fir
28 
29 #define DEBUG_TYPE "flang-procedure-pointer"
30 
31 using namespace fir;
32 
33 namespace {
/// Configuration knobs for the boxed-procedure lowering pass.
struct BoxedProcedureOptions {
  /// When set (the default), `!fir.boxproc` values are lowered to plain
  /// function pointers, with thunks created wherever they are required.
  bool useThunks{true};
};
40 
41 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types.
42 class BoxprocTypeRewriter : public mlir::TypeConverter {
43 public:
44   using mlir::TypeConverter::convertType;
45 
46   /// Does the type \p ty need to be converted?
47   /// Any type that is a `!fir.boxproc` in whole or in part will need to be
48   /// converted to a function type to lower the IR to function pointer form in
49   /// the default implementation performed in this pass. Other implementations
50   /// are possible, so those may convert `!fir.boxproc` to some other type or
51   /// not at all depending on the implementation target's characteristics and
52   /// preference.
53   bool needsConversion(mlir::Type ty) {
54     if (mlir::isa<BoxProcType>(ty))
55       return true;
56     if (auto funcTy = mlir::dyn_cast<mlir::FunctionType>(ty)) {
57       for (auto t : funcTy.getInputs())
58         if (needsConversion(t))
59           return true;
60       for (auto t : funcTy.getResults())
61         if (needsConversion(t))
62           return true;
63       return false;
64     }
65     if (auto tupleTy = mlir::dyn_cast<mlir::TupleType>(ty)) {
66       for (auto t : tupleTy.getTypes())
67         if (needsConversion(t))
68           return true;
69       return false;
70     }
71     if (auto recTy = mlir::dyn_cast<RecordType>(ty)) {
72       auto visited = visitedTypes.find(ty);
73       if (visited != visitedTypes.end())
74         return visited->second;
75       [[maybe_unused]] auto newIt = visitedTypes.try_emplace(ty, false);
76       assert(newIt.second && "expected ty to not be in the map");
77       bool wasAlreadyVisitingRecordType = needConversionIsVisitingRecordType;
78       needConversionIsVisitingRecordType = true;
79       bool result = false;
80       for (auto t : recTy.getTypeList()) {
81         if (needsConversion(t.second)) {
82           result = true;
83           break;
84         }
85       }
86       // Only keep the result cached if the fir.type visited was a "top-level
87       // type". Nested types with a recursive reference to the "top-level type"
88       // may incorrectly have been resolved as not needed conversions because it
89       // had not been determined yet if the "top-level type" needed conversion.
90       // This is not an issue to determine the "top-level type" need of
91       // conversion, but the result should not be kept and later used in other
92       // contexts.
93       needConversionIsVisitingRecordType = wasAlreadyVisitingRecordType;
94       if (needConversionIsVisitingRecordType)
95         visitedTypes.erase(ty);
96       else
97         visitedTypes.find(ty)->second = result;
98       return result;
99     }
100     if (auto boxTy = mlir::dyn_cast<BaseBoxType>(ty))
101       return needsConversion(boxTy.getEleTy());
102     if (isa_ref_type(ty))
103       return needsConversion(unwrapRefType(ty));
104     if (auto t = mlir::dyn_cast<SequenceType>(ty))
105       return needsConversion(unwrapSequenceType(ty));
106     if (auto t = mlir::dyn_cast<TypeDescType>(ty))
107       return needsConversion(t.getOfTy());
108     return false;
109   }
110 
111   BoxprocTypeRewriter(mlir::Location location) : loc{location} {
112     addConversion([](mlir::Type ty) { return ty; });
113     addConversion(
114         [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); });
115     addConversion([&](mlir::TupleType tupTy) {
116       llvm::SmallVector<mlir::Type> memTys;
117       for (auto ty : tupTy.getTypes())
118         memTys.push_back(convertType(ty));
119       return mlir::TupleType::get(tupTy.getContext(), memTys);
120     });
121     addConversion([&](mlir::FunctionType funcTy) {
122       llvm::SmallVector<mlir::Type> inTys;
123       llvm::SmallVector<mlir::Type> resTys;
124       for (auto ty : funcTy.getInputs())
125         inTys.push_back(convertType(ty));
126       for (auto ty : funcTy.getResults())
127         resTys.push_back(convertType(ty));
128       return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys);
129     });
130     addConversion([&](ReferenceType ty) {
131       return ReferenceType::get(convertType(ty.getEleTy()));
132     });
133     addConversion([&](PointerType ty) {
134       return PointerType::get(convertType(ty.getEleTy()));
135     });
136     addConversion(
137         [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); });
138     addConversion([&](fir::LLVMPointerType ty) {
139       return fir::LLVMPointerType::get(convertType(ty.getEleTy()));
140     });
141     addConversion(
142         [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); });
143     addConversion([&](ClassType ty) {
144       return ClassType::get(convertType(ty.getEleTy()));
145     });
146     addConversion([&](SequenceType ty) {
147       // TODO: add ty.getLayoutMap() as needed.
148       return SequenceType::get(ty.getShape(), convertType(ty.getEleTy()));
149     });
150     addConversion([&](RecordType ty) -> mlir::Type {
151       if (!needsConversion(ty))
152         return ty;
153       if (auto converted = convertedTypes.lookup(ty))
154         return converted;
155       auto rec = RecordType::get(ty.getContext(),
156                                  ty.getName().str() + boxprocSuffix.str());
157       if (rec.isFinalized())
158         return rec;
159       [[maybe_unused]] auto it = convertedTypes.try_emplace(ty, rec);
160       assert(it.second && "expected ty to not be in the map");
161       std::vector<RecordType::TypePair> ps = ty.getLenParamList();
162       std::vector<RecordType::TypePair> cs;
163       for (auto t : ty.getTypeList()) {
164         if (needsConversion(t.second))
165           cs.emplace_back(t.first, convertType(t.second));
166         else
167           cs.emplace_back(t.first, t.second);
168       }
169       rec.finalize(ps, cs);
170       rec.pack(ty.isPacked());
171       return rec;
172     });
173     addConversion([&](TypeDescType ty) {
174       return TypeDescType::get(convertType(ty.getOfTy()));
175     });
176     addSourceMaterialization(materializeProcedure);
177     addTargetMaterialization(materializeProcedure);
178   }
179 
180   static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
181                                           BoxProcType type,
182                                           mlir::ValueRange inputs,
183                                           mlir::Location loc) {
184     assert(inputs.size() == 1);
185     return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
186                                      inputs[0]);
187   }
188 
189   void setLocation(mlir::Location location) { loc = location; }
190 
191 private:
192   // Maps to deal with recursive derived types (avoid infinite loops).
193   // Caching is also beneficial for apps with big types (dozens of
194   // components and or parent types), so the lifetime of the cache
195   // is the whole pass.
196   llvm::DenseMap<mlir::Type, bool> visitedTypes;
197   bool needConversionIsVisitingRecordType = false;
198   llvm::DenseMap<mlir::Type, mlir::Type> convertedTypes;
199   mlir::Location loc;
200 };
201 
202 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically,
203 /// Fortran procedures can be referenced directly through a function pointer.
204 /// However, Fortran has one-level dynamic scoping between a host procedure and
205 /// its internal procedures. This allows internal procedures to directly access
206 /// and modify the state of the host procedure's variables.
207 ///
208 /// There are any number of possible implementations possible.
209 ///
210 /// The implementation used here is to convert `boxproc` values to function
211 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the
212 /// host procedure's data, then a thunk will be created at runtime to capture
213 /// the frame pointer during execution. In LLVM IR, the frame pointer is
214 /// designated with the `nest` attribute. The thunk's address will then be used
215 /// as the call target instead of the original function's address directly.
216 class BoxedProcedurePass
217     : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> {
218 public:
219   using BoxedProcedurePassBase<BoxedProcedurePass>::BoxedProcedurePassBase;
220 
221   inline mlir::ModuleOp getModule() { return getOperation(); }
222 
223   void runOnOperation() override final {
224     if (options.useThunks) {
225       auto *context = &getContext();
226       mlir::IRRewriter rewriter(context);
227       BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
228       getModule().walk([&](mlir::Operation *op) {
229         bool opIsValid = true;
230         typeConverter.setLocation(op->getLoc());
231         if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
232           mlir::Type ty = addr.getVal().getType();
233           mlir::Type resTy = addr.getResult().getType();
234           if (llvm::isa<mlir::FunctionType>(ty) ||
235               llvm::isa<fir::BoxProcType>(ty)) {
236             // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
237             // or function type to be `fir.convert` ops.
238             rewriter.setInsertionPoint(addr);
239             rewriter.replaceOpWithNewOp<ConvertOp>(
240                 addr, typeConverter.convertType(addr.getType()), addr.getVal());
241             opIsValid = false;
242           } else if (typeConverter.needsConversion(resTy)) {
243             rewriter.startOpModification(op);
244             op->getResult(0).setType(typeConverter.convertType(resTy));
245             rewriter.finalizeOpModification(op);
246           }
247         } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
248           mlir::FunctionType ty = func.getFunctionType();
249           if (typeConverter.needsConversion(ty)) {
250             rewriter.startOpModification(func);
251             auto toTy =
252                 mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty));
253             if (!func.empty())
254               for (auto e : llvm::enumerate(toTy.getInputs())) {
255                 unsigned i = e.index();
256                 auto &block = func.front();
257                 block.insertArgument(i, e.value(), func.getLoc());
258                 block.getArgument(i + 1).replaceAllUsesWith(
259                     block.getArgument(i));
260                 block.eraseArgument(i + 1);
261               }
262             func.setType(toTy);
263             rewriter.finalizeOpModification(func);
264           }
265         } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
266           // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
267           // as required.
268           mlir::Type toTy = typeConverter.convertType(
269               mlir::cast<BoxProcType>(embox.getType()).getEleTy());
270           rewriter.setInsertionPoint(embox);
271           if (embox.getHost()) {
272             // Create the thunk.
273             auto module = embox->getParentOfType<mlir::ModuleOp>();
274             FirOpBuilder builder(rewriter, module);
275             auto loc = embox.getLoc();
276             mlir::Type i8Ty = builder.getI8Type();
277             mlir::Type i8Ptr = builder.getRefType(i8Ty);
278             mlir::Type buffTy = SequenceType::get({32}, i8Ty);
279             auto buffer = builder.create<AllocaOp>(loc, buffTy);
280             mlir::Value closure =
281                 builder.createConvert(loc, i8Ptr, embox.getHost());
282             mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
283             mlir::Value func =
284                 builder.createConvert(loc, i8Ptr, embox.getFunc());
285             builder.create<fir::CallOp>(
286                 loc, factory::getLlvmInitTrampoline(builder),
287                 llvm::ArrayRef<mlir::Value>{tramp, func, closure});
288             auto adjustCall = builder.create<fir::CallOp>(
289                 loc, factory::getLlvmAdjustTrampoline(builder),
290                 llvm::ArrayRef<mlir::Value>{tramp});
291             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
292                                                    adjustCall.getResult(0));
293             opIsValid = false;
294           } else {
295             // Just forward the function as a pointer.
296             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
297                                                    embox.getFunc());
298             opIsValid = false;
299           }
300         } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
301           auto ty = global.getType();
302           if (typeConverter.needsConversion(ty)) {
303             rewriter.startOpModification(global);
304             auto toTy = typeConverter.convertType(ty);
305             global.setType(toTy);
306             rewriter.finalizeOpModification(global);
307           }
308         } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
309           auto ty = mem.getType();
310           if (typeConverter.needsConversion(ty)) {
311             rewriter.setInsertionPoint(mem);
312             auto toTy = typeConverter.convertType(unwrapRefType(ty));
313             bool isPinned = mem.getPinned();
314             llvm::StringRef uniqName =
315                 mem.getUniqName().value_or(llvm::StringRef());
316             llvm::StringRef bindcName =
317                 mem.getBindcName().value_or(llvm::StringRef());
318             rewriter.replaceOpWithNewOp<AllocaOp>(
319                 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
320                 mem.getShape());
321             opIsValid = false;
322           }
323         } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
324           auto ty = mem.getType();
325           if (typeConverter.needsConversion(ty)) {
326             rewriter.setInsertionPoint(mem);
327             auto toTy = typeConverter.convertType(unwrapRefType(ty));
328             llvm::StringRef uniqName =
329                 mem.getUniqName().value_or(llvm::StringRef());
330             llvm::StringRef bindcName =
331                 mem.getBindcName().value_or(llvm::StringRef());
332             rewriter.replaceOpWithNewOp<AllocMemOp>(
333                 mem, toTy, uniqName, bindcName, mem.getTypeparams(),
334                 mem.getShape());
335             opIsValid = false;
336           }
337         } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
338           auto ty = coor.getType();
339           mlir::Type baseTy = coor.getBaseType();
340           if (typeConverter.needsConversion(ty) ||
341               typeConverter.needsConversion(baseTy)) {
342             rewriter.setInsertionPoint(coor);
343             auto toTy = typeConverter.convertType(ty);
344             auto toBaseTy = typeConverter.convertType(baseTy);
345             rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
346                                                       coor.getCoor(), toBaseTy);
347             opIsValid = false;
348           }
349         } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
350           auto ty = index.getType();
351           mlir::Type onTy = index.getOnType();
352           if (typeConverter.needsConversion(ty) ||
353               typeConverter.needsConversion(onTy)) {
354             rewriter.setInsertionPoint(index);
355             auto toTy = typeConverter.convertType(ty);
356             auto toOnTy = typeConverter.convertType(onTy);
357             rewriter.replaceOpWithNewOp<FieldIndexOp>(
358                 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
359             opIsValid = false;
360           }
361         } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
362           auto ty = index.getType();
363           mlir::Type onTy = index.getOnType();
364           if (typeConverter.needsConversion(ty) ||
365               typeConverter.needsConversion(onTy)) {
366             rewriter.setInsertionPoint(index);
367             auto toTy = typeConverter.convertType(ty);
368             auto toOnTy = typeConverter.convertType(onTy);
369             rewriter.replaceOpWithNewOp<LenParamIndexOp>(
370                 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
371             opIsValid = false;
372           }
373         } else {
374           rewriter.startOpModification(op);
375           // Convert the operands if needed
376           for (auto i : llvm::enumerate(op->getResultTypes()))
377             if (typeConverter.needsConversion(i.value())) {
378               auto toTy = typeConverter.convertType(i.value());
379               op->getResult(i.index()).setType(toTy);
380             }
381 
382           // Convert the type attributes if needed
383           for (const mlir::NamedAttribute &attr : op->getAttrDictionary())
384             if (auto tyAttr = llvm::dyn_cast<mlir::TypeAttr>(attr.getValue()))
385               if (typeConverter.needsConversion(tyAttr.getValue())) {
386                 auto toTy = typeConverter.convertType(tyAttr.getValue());
387                 op->setAttr(attr.getName(), mlir::TypeAttr::get(toTy));
388               }
389           rewriter.finalizeOpModification(op);
390         }
391         // Ensure block arguments are updated if needed.
392         if (opIsValid && op->getNumRegions() != 0) {
393           rewriter.startOpModification(op);
394           for (mlir::Region &region : op->getRegions())
395             for (mlir::Block &block : region.getBlocks())
396               for (mlir::BlockArgument blockArg : block.getArguments())
397                 if (typeConverter.needsConversion(blockArg.getType())) {
398                   mlir::Type toTy =
399                       typeConverter.convertType(blockArg.getType());
400                   blockArg.setType(toTy);
401                 }
402           rewriter.finalizeOpModification(op);
403         }
404       });
405     }
406   }
407 
408 private:
409   BoxedProcedureOptions options;
410 };
411 } // namespace
412