xref: /llvm-project/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (revision c67a87444f5585aafa659c52544ac912f283e7e3)
1 //===-- BoxedProcedure.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/FIRBuilder.h"
11 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
12 #include "flang/Optimizer/CodeGen/CodeGen.h"
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Dialect/FIROps.h"
15 #include "flang/Optimizer/Dialect/FIRType.h"
16 #include "flang/Optimizer/Support/FIRContext.h"
17 #include "flang/Optimizer/Support/FatalError.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 #define DEBUG_TYPE "flang-procedure-pointer"
23 
24 using namespace fir;
25 
26 namespace {
/// Configuration knobs for the boxed procedure pass.
struct BoxedProcedureOptions {
  /// When set, the `!fir.boxproc` abstraction is lowered to function pointers,
  /// with thunks created where they are required.
  bool useThunks{true};
};
33 
34 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types.
35 class BoxprocTypeRewriter : public mlir::TypeConverter {
36 public:
37   using mlir::TypeConverter::convertType;
38 
39   /// Does the type \p ty need to be converted?
40   /// Any type that is a `!fir.boxproc` in whole or in part will need to be
41   /// converted to a function type to lower the IR to function pointer form in
42   /// the default implementation performed in this pass. Other implementations
43   /// are possible, so those may convert `!fir.boxproc` to some other type or
44   /// not at all depending on the implementation target's characteristics and
45   /// preference.
46   bool needsConversion(mlir::Type ty) {
47     if (ty.isa<BoxProcType>())
48       return true;
49     if (auto funcTy = ty.dyn_cast<mlir::FunctionType>()) {
50       for (auto t : funcTy.getInputs())
51         if (needsConversion(t))
52           return true;
53       for (auto t : funcTy.getResults())
54         if (needsConversion(t))
55           return true;
56       return false;
57     }
58     if (auto tupleTy = ty.dyn_cast<mlir::TupleType>()) {
59       for (auto t : tupleTy.getTypes())
60         if (needsConversion(t))
61           return true;
62       return false;
63     }
64     if (auto recTy = ty.dyn_cast<RecordType>()) {
65       if (llvm::any_of(visitedTypes,
66                        [&](mlir::Type rt) { return rt == recTy; }))
67         return false;
68       bool result = false;
69       visitedTypes.push_back(recTy);
70       for (auto t : recTy.getTypeList()) {
71         if (needsConversion(t.second)) {
72           result = true;
73           break;
74         }
75       }
76       visitedTypes.pop_back();
77       return result;
78     }
79     if (auto boxTy = ty.dyn_cast<BoxType>())
80       return needsConversion(boxTy.getEleTy());
81     if (isa_ref_type(ty))
82       return needsConversion(unwrapRefType(ty));
83     if (auto t = ty.dyn_cast<SequenceType>())
84       return needsConversion(unwrapSequenceType(ty));
85     return false;
86   }
87 
88   BoxprocTypeRewriter(mlir::Location location) : loc{location} {
89     addConversion([](mlir::Type ty) { return ty; });
90     addConversion(
91         [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); });
92     addConversion([&](mlir::TupleType tupTy) {
93       llvm::SmallVector<mlir::Type> memTys;
94       for (auto ty : tupTy.getTypes())
95         memTys.push_back(convertType(ty));
96       return mlir::TupleType::get(tupTy.getContext(), memTys);
97     });
98     addConversion([&](mlir::FunctionType funcTy) {
99       llvm::SmallVector<mlir::Type> inTys;
100       llvm::SmallVector<mlir::Type> resTys;
101       for (auto ty : funcTy.getInputs())
102         inTys.push_back(convertType(ty));
103       for (auto ty : funcTy.getResults())
104         resTys.push_back(convertType(ty));
105       return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys);
106     });
107     addConversion([&](ReferenceType ty) {
108       return ReferenceType::get(convertType(ty.getEleTy()));
109     });
110     addConversion([&](PointerType ty) {
111       return PointerType::get(convertType(ty.getEleTy()));
112     });
113     addConversion(
114         [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); });
115     addConversion(
116         [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); });
117     addConversion([&](SequenceType ty) {
118       // TODO: add ty.getLayoutMap() as needed.
119       return SequenceType::get(ty.getShape(), convertType(ty.getEleTy()));
120     });
121     addConversion([&](RecordType ty) -> mlir::Type {
122       if (!needsConversion(ty))
123         return ty;
124       // FIR record types can have recursive references, so conversion is a bit
125       // more complex than the other types. This conversion is not needed
126       // presently, so just emit a TODO message. Need to consider the uniqued
127       // name of the record, etc. Also, fir::RecordType::get returns the
128       // existing type being translated. So finalize() will not change it, and
129       // the translation would not do anything. So the type needs to be mutated,
130       // and this might require special care to comply with MLIR infrastructure.
131 
132       // TODO: this will be needed to support derived type containing procedure
133       // pointer components.
134       fir::emitFatalError(
135           loc, "not yet implemented: record type with a boxproc type");
136       return RecordType::get(ty.getContext(), "*fixme*");
137     });
138     addArgumentMaterialization(materializeProcedure);
139     addSourceMaterialization(materializeProcedure);
140     addTargetMaterialization(materializeProcedure);
141   }
142 
143   static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
144                                           BoxProcType type,
145                                           mlir::ValueRange inputs,
146                                           mlir::Location loc) {
147     assert(inputs.size() == 1);
148     return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
149                                      inputs[0]);
150   }
151 
152   void setLocation(mlir::Location location) { loc = location; }
153 
154 private:
155   llvm::SmallVector<mlir::Type> visitedTypes;
156   mlir::Location loc;
157 };
158 
159 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically,
160 /// Fortran procedures can be referenced directly through a function pointer.
161 /// However, Fortran has one-level dynamic scoping between a host procedure and
162 /// its internal procedures. This allows internal procedures to directly access
163 /// and modify the state of the host procedure's variables.
164 ///
165 /// There are any number of possible implementations possible.
166 ///
167 /// The implementation used here is to convert `boxproc` values to function
168 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the
169 /// host procedure's data, then a thunk will be created at runtime to capture
170 /// the frame pointer during execution. In LLVM IR, the frame pointer is
171 /// designated with the `nest` attribute. The thunk's address will then be used
172 /// as the call target instead of the original function's address directly.
173 class BoxedProcedurePass : public BoxedProcedurePassBase<BoxedProcedurePass> {
174 public:
175   BoxedProcedurePass() { options = {true}; }
176   BoxedProcedurePass(bool useThunks) { options = {useThunks}; }
177 
178   inline mlir::ModuleOp getModule() { return getOperation(); }
179 
180   void runOnOperation() override final {
181     if (options.useThunks) {
182       auto *context = &getContext();
183       mlir::IRRewriter rewriter(context);
184       BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
185       mlir::Dialect *firDialect = context->getLoadedDialect("fir");
186       getModule().walk([&](mlir::Operation *op) {
187         typeConverter.setLocation(op->getLoc());
188         if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
189           auto ty = addr.getVal().getType();
190           if (typeConverter.needsConversion(ty) ||
191               ty.isa<mlir::FunctionType>()) {
192             // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
193             // or function type to be `fir.convert` ops.
194             rewriter.setInsertionPoint(addr);
195             rewriter.replaceOpWithNewOp<ConvertOp>(
196                 addr, typeConverter.convertType(addr.getType()), addr.getVal());
197           }
198         } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
199           mlir::FunctionType ty = func.getFunctionType();
200           if (typeConverter.needsConversion(ty)) {
201             rewriter.startRootUpdate(func);
202             auto toTy =
203                 typeConverter.convertType(ty).cast<mlir::FunctionType>();
204             if (!func.empty())
205               for (auto e : llvm::enumerate(toTy.getInputs())) {
206                 unsigned i = e.index();
207                 auto &block = func.front();
208                 block.insertArgument(i, e.value(), func.getLoc());
209                 block.getArgument(i + 1).replaceAllUsesWith(
210                     block.getArgument(i));
211                 block.eraseArgument(i + 1);
212               }
213             func.setType(toTy);
214             rewriter.finalizeRootUpdate(func);
215           }
216         } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
217           // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
218           // as required.
219           mlir::Type toTy = embox.getType().cast<BoxProcType>().getEleTy();
220           rewriter.setInsertionPoint(embox);
221           if (embox.getHost()) {
222             // Create the thunk.
223             auto module = embox->getParentOfType<mlir::ModuleOp>();
224             FirOpBuilder builder(rewriter, getKindMapping(module));
225             auto loc = embox.getLoc();
226             mlir::Type i8Ty = builder.getI8Type();
227             mlir::Type i8Ptr = builder.getRefType(i8Ty);
228             mlir::Type buffTy = SequenceType::get({32}, i8Ty);
229             auto buffer = builder.create<AllocaOp>(loc, buffTy);
230             mlir::Value closure =
231                 builder.createConvert(loc, i8Ptr, embox.getHost());
232             mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
233             mlir::Value func =
234                 builder.createConvert(loc, i8Ptr, embox.getFunc());
235             builder.create<fir::CallOp>(
236                 loc, factory::getLlvmInitTrampoline(builder),
237                 llvm::ArrayRef<mlir::Value>{tramp, func, closure});
238             auto adjustCall = builder.create<fir::CallOp>(
239                 loc, factory::getLlvmAdjustTrampoline(builder),
240                 llvm::ArrayRef<mlir::Value>{tramp});
241             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
242                                                    adjustCall.getResult(0));
243           } else {
244             // Just forward the function as a pointer.
245             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
246                                                    embox.getFunc());
247           }
248         } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
249           auto ty = mem.getType();
250           if (typeConverter.needsConversion(ty)) {
251             rewriter.setInsertionPoint(mem);
252             auto toTy = typeConverter.convertType(unwrapRefType(ty));
253             bool isPinned = mem.getPinned();
254             llvm::StringRef uniqName;
255             if (mem.getUniqName().hasValue())
256               uniqName = mem.getUniqName().getValue();
257             llvm::StringRef bindcName;
258             if (mem.getBindcName().hasValue())
259               bindcName = mem.getBindcName().getValue();
260             rewriter.replaceOpWithNewOp<AllocaOp>(
261                 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
262                 mem.getShape());
263           }
264         } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
265           auto ty = mem.getType();
266           if (typeConverter.needsConversion(ty)) {
267             rewriter.setInsertionPoint(mem);
268             auto toTy = typeConverter.convertType(unwrapRefType(ty));
269             llvm::StringRef uniqName;
270             if (mem.getUniqName().hasValue())
271               uniqName = mem.getUniqName().getValue();
272             llvm::StringRef bindcName;
273             if (mem.getBindcName().hasValue())
274               bindcName = mem.getBindcName().getValue();
275             rewriter.replaceOpWithNewOp<AllocMemOp>(
276                 mem, toTy, uniqName, bindcName, mem.getTypeparams(),
277                 mem.getShape());
278           }
279         } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
280           auto ty = coor.getType();
281           mlir::Type baseTy = coor.getBaseType();
282           if (typeConverter.needsConversion(ty) ||
283               typeConverter.needsConversion(baseTy)) {
284             rewriter.setInsertionPoint(coor);
285             auto toTy = typeConverter.convertType(ty);
286             auto toBaseTy = typeConverter.convertType(baseTy);
287             rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
288                                                       coor.getCoor(), toBaseTy);
289           }
290         } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
291           auto ty = index.getType();
292           mlir::Type onTy = index.getOnType();
293           if (typeConverter.needsConversion(ty) ||
294               typeConverter.needsConversion(onTy)) {
295             rewriter.setInsertionPoint(index);
296             auto toTy = typeConverter.convertType(ty);
297             auto toOnTy = typeConverter.convertType(onTy);
298             rewriter.replaceOpWithNewOp<FieldIndexOp>(
299                 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
300           }
301         } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
302           auto ty = index.getType();
303           mlir::Type onTy = index.getOnType();
304           if (typeConverter.needsConversion(ty) ||
305               typeConverter.needsConversion(onTy)) {
306             rewriter.setInsertionPoint(index);
307             auto toTy = typeConverter.convertType(ty);
308             auto toOnTy = typeConverter.convertType(onTy);
309             rewriter.replaceOpWithNewOp<LenParamIndexOp>(
310                 mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
311           }
312         } else if (op->getDialect() == firDialect) {
313           rewriter.startRootUpdate(op);
314           for (auto i : llvm::enumerate(op->getResultTypes()))
315             if (typeConverter.needsConversion(i.value())) {
316               auto toTy = typeConverter.convertType(i.value());
317               op->getResult(i.index()).setType(toTy);
318             }
319           rewriter.finalizeRootUpdate(op);
320         }
321       });
322     }
323     // TODO: any alternative implementation. Note: currently, the default code
324     // gen will not be able to handle boxproc and will give an error.
325   }
326 
327 private:
328   BoxedProcedureOptions options;
329 };
330 } // namespace
331 
332 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() {
333   return std::make_unique<BoxedProcedurePass>();
334 }
335 
336 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) {
337   return std::make_unique<BoxedProcedurePass>(useThunks);
338 }
339