xref: /llvm-project/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (revision 67d0d7ac0acb0665d6a09f61278fbcf51f0114c2)
1 //===-- BoxedProcedure.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/CodeGen/CodeGen.h"
10 
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Dialect/FIROps.h"
15 #include "flang/Optimizer/Dialect/FIRType.h"
16 #include "flang/Optimizer/Support/FIRContext.h"
17 #include "flang/Optimizer/Support/FatalError.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 namespace fir {
23 #define GEN_PASS_DEF_BOXEDPROCEDUREPASS
24 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
25 } // namespace fir
26 
27 #define DEBUG_TYPE "flang-procedure-pointer"
28 
29 using namespace fir;
30 
31 namespace {
/// Configuration knobs for the boxed procedure pass.
struct BoxedProcedureOptions {
  /// When set, `!fir.boxproc` values are lowered to plain function pointers,
  /// with a runtime trampoline (thunk) created for any boxproc that carries a
  /// host frame pointer.
  bool useThunks{true};
};
38 
39 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types.
40 class BoxprocTypeRewriter : public mlir::TypeConverter {
41 public:
42   using mlir::TypeConverter::convertType;
43 
44   /// Does the type \p ty need to be converted?
45   /// Any type that is a `!fir.boxproc` in whole or in part will need to be
46   /// converted to a function type to lower the IR to function pointer form in
47   /// the default implementation performed in this pass. Other implementations
48   /// are possible, so those may convert `!fir.boxproc` to some other type or
49   /// not at all depending on the implementation target's characteristics and
50   /// preference.
51   bool needsConversion(mlir::Type ty) {
52     if (ty.isa<BoxProcType>())
53       return true;
54     if (auto funcTy = ty.dyn_cast<mlir::FunctionType>()) {
55       for (auto t : funcTy.getInputs())
56         if (needsConversion(t))
57           return true;
58       for (auto t : funcTy.getResults())
59         if (needsConversion(t))
60           return true;
61       return false;
62     }
63     if (auto tupleTy = ty.dyn_cast<mlir::TupleType>()) {
64       for (auto t : tupleTy.getTypes())
65         if (needsConversion(t))
66           return true;
67       return false;
68     }
69     if (auto recTy = ty.dyn_cast<RecordType>()) {
70       if (llvm::is_contained(visitedTypes, recTy))
71         return false;
72       bool result = false;
73       visitedTypes.push_back(recTy);
74       for (auto t : recTy.getTypeList()) {
75         if (needsConversion(t.second)) {
76           result = true;
77           break;
78         }
79       }
80       visitedTypes.pop_back();
81       return result;
82     }
83     if (auto boxTy = ty.dyn_cast<BoxType>())
84       return needsConversion(boxTy.getEleTy());
85     if (isa_ref_type(ty))
86       return needsConversion(unwrapRefType(ty));
87     if (auto t = ty.dyn_cast<SequenceType>())
88       return needsConversion(unwrapSequenceType(ty));
89     return false;
90   }
91 
92   BoxprocTypeRewriter(mlir::Location location) : loc{location} {
93     addConversion([](mlir::Type ty) { return ty; });
94     addConversion(
95         [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); });
96     addConversion([&](mlir::TupleType tupTy) {
97       llvm::SmallVector<mlir::Type> memTys;
98       for (auto ty : tupTy.getTypes())
99         memTys.push_back(convertType(ty));
100       return mlir::TupleType::get(tupTy.getContext(), memTys);
101     });
102     addConversion([&](mlir::FunctionType funcTy) {
103       llvm::SmallVector<mlir::Type> inTys;
104       llvm::SmallVector<mlir::Type> resTys;
105       for (auto ty : funcTy.getInputs())
106         inTys.push_back(convertType(ty));
107       for (auto ty : funcTy.getResults())
108         resTys.push_back(convertType(ty));
109       return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys);
110     });
111     addConversion([&](ReferenceType ty) {
112       return ReferenceType::get(convertType(ty.getEleTy()));
113     });
114     addConversion([&](PointerType ty) {
115       return PointerType::get(convertType(ty.getEleTy()));
116     });
117     addConversion(
118         [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); });
119     addConversion(
120         [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); });
121     addConversion([&](SequenceType ty) {
122       // TODO: add ty.getLayoutMap() as needed.
123       return SequenceType::get(ty.getShape(), convertType(ty.getEleTy()));
124     });
125     addConversion([&](RecordType ty) -> mlir::Type {
126       if (!needsConversion(ty))
127         return ty;
128       // FIR record types can have recursive references, so conversion is a bit
129       // more complex than the other types. This conversion is not needed
130       // presently, so just emit a TODO message. Need to consider the uniqued
131       // name of the record, etc. Also, fir::RecordType::get returns the
132       // existing type being translated. So finalize() will not change it, and
133       // the translation would not do anything. So the type needs to be mutated,
134       // and this might require special care to comply with MLIR infrastructure.
135 
136       // TODO: this will be needed to support derived type containing procedure
137       // pointer components.
138       fir::emitFatalError(
139           loc, "not yet implemented: record type with a boxproc type");
140       return RecordType::get(ty.getContext(), "*fixme*");
141     });
142     addArgumentMaterialization(materializeProcedure);
143     addSourceMaterialization(materializeProcedure);
144     addTargetMaterialization(materializeProcedure);
145   }
146 
147   static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
148                                           BoxProcType type,
149                                           mlir::ValueRange inputs,
150                                           mlir::Location loc) {
151     assert(inputs.size() == 1);
152     return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
153                                      inputs[0]);
154   }
155 
156   void setLocation(mlir::Location location) { loc = location; }
157 
158 private:
159   llvm::SmallVector<mlir::Type> visitedTypes;
160   mlir::Location loc;
161 };
162 
163 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically,
164 /// Fortran procedures can be referenced directly through a function pointer.
165 /// However, Fortran has one-level dynamic scoping between a host procedure and
166 /// its internal procedures. This allows internal procedures to directly access
167 /// and modify the state of the host procedure's variables.
168 ///
169 /// There are any number of possible implementations possible.
170 ///
171 /// The implementation used here is to convert `boxproc` values to function
172 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the
173 /// host procedure's data, then a thunk will be created at runtime to capture
174 /// the frame pointer during execution. In LLVM IR, the frame pointer is
175 /// designated with the `nest` attribute. The thunk's address will then be used
176 /// as the call target instead of the original function's address directly.
177 class BoxedProcedurePass
178     : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> {
179 public:
180   BoxedProcedurePass() { options = {true}; }
181   BoxedProcedurePass(bool useThunks) { options = {useThunks}; }
182 
183   inline mlir::ModuleOp getModule() { return getOperation(); }
184 
185   void runOnOperation() override final {
186     if (options.useThunks) {
187       auto *context = &getContext();
188       mlir::IRRewriter rewriter(context);
189       BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
190       mlir::Dialect *firDialect = context->getLoadedDialect("fir");
191       getModule().walk([&](mlir::Operation *op) {
192         typeConverter.setLocation(op->getLoc());
193         if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
194           auto ty = addr.getVal().getType();
195           if (typeConverter.needsConversion(ty) ||
196               ty.isa<mlir::FunctionType>()) {
197             // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
198             // or function type to be `fir.convert` ops.
199             rewriter.setInsertionPoint(addr);
200             rewriter.replaceOpWithNewOp<ConvertOp>(
201                 addr, typeConverter.convertType(addr.getType()), addr.getVal());
202           }
203         } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
204           mlir::FunctionType ty = func.getFunctionType();
205           if (typeConverter.needsConversion(ty)) {
206             rewriter.startRootUpdate(func);
207             auto toTy =
208                 typeConverter.convertType(ty).cast<mlir::FunctionType>();
209             if (!func.empty())
210               for (auto e : llvm::enumerate(toTy.getInputs())) {
211                 unsigned i = e.index();
212                 auto &block = func.front();
213                 block.insertArgument(i, e.value(), func.getLoc());
214                 block.getArgument(i + 1).replaceAllUsesWith(
215                     block.getArgument(i));
216                 block.eraseArgument(i + 1);
217               }
218             func.setType(toTy);
219             rewriter.finalizeRootUpdate(func);
220           }
221         } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
222           // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
223           // as required.
224           mlir::Type toTy = embox.getType().cast<BoxProcType>().getEleTy();
225           rewriter.setInsertionPoint(embox);
226           if (embox.getHost()) {
227             // Create the thunk.
228             auto module = embox->getParentOfType<mlir::ModuleOp>();
229             fir::KindMapping kindMap = getKindMapping(module);
230             FirOpBuilder builder(rewriter, kindMap);
231             auto loc = embox.getLoc();
232             mlir::Type i8Ty = builder.getI8Type();
233             mlir::Type i8Ptr = builder.getRefType(i8Ty);
234             mlir::Type buffTy = SequenceType::get({32}, i8Ty);
235             auto buffer = builder.create<AllocaOp>(loc, buffTy);
236             mlir::Value closure =
237                 builder.createConvert(loc, i8Ptr, embox.getHost());
238             mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
239             mlir::Value func =
240                 builder.createConvert(loc, i8Ptr, embox.getFunc());
241             builder.create<fir::CallOp>(
242                 loc, factory::getLlvmInitTrampoline(builder),
243                 llvm::ArrayRef<mlir::Value>{tramp, func, closure});
244             auto adjustCall = builder.create<fir::CallOp>(
245                 loc, factory::getLlvmAdjustTrampoline(builder),
246                 llvm::ArrayRef<mlir::Value>{tramp});
247             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
248                                                    adjustCall.getResult(0));
249           } else {
250             // Just forward the function as a pointer.
251             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
252                                                    embox.getFunc());
253           }
254         } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
255           auto ty = mem.getType();
256           if (typeConverter.needsConversion(ty)) {
257             rewriter.setInsertionPoint(mem);
258             auto toTy = typeConverter.convertType(unwrapRefType(ty));
259             bool isPinned = mem.getPinned();
260             llvm::StringRef uniqName =
261                 mem.getUniqName().value_or(llvm::StringRef());
262             llvm::StringRef bindcName =
263                 mem.getBindcName().value_or(llvm::StringRef());
264             rewriter.replaceOpWithNewOp<AllocaOp>(
265                 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
266                 mem.getShape());
267           }
268         } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
269           auto ty = mem.getType();
270           if (typeConverter.needsConversion(ty)) {
271             rewriter.setInsertionPoint(mem);
272             auto toTy = typeConverter.convertType(unwrapRefType(ty));
273             llvm::StringRef uniqName =
274                 mem.getUniqName().value_or(llvm::StringRef());
275             llvm::StringRef bindcName =
276                 mem.getBindcName().value_or(llvm::StringRef());
277             rewriter.replaceOpWithNewOp<AllocMemOp>(
278                 mem, toTy, uniqName, bindcName, mem.getTypeparams(),
279                 mem.getShape());
280           }
281         } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
282           auto ty = coor.getType();
283           mlir::Type baseTy = coor.getBaseType();
284           if (typeConverter.needsConversion(ty) ||
285               typeConverter.needsConversion(baseTy)) {
286             rewriter.setInsertionPoint(coor);
287             auto toTy = typeConverter.convertType(ty);
288             auto toBaseTy = typeConverter.convertType(baseTy);
289             rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
290                                                       coor.getCoor(), toBaseTy);
291           }
292         } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
293           auto ty = index.getType();
294           mlir::Type onTy = index.getOnType();
295           if (typeConverter.needsConversion(ty) ||
296               typeConverter.needsConversion(onTy)) {
297             rewriter.setInsertionPoint(index);
298             auto toTy = typeConverter.convertType(ty);
299             auto toOnTy = typeConverter.convertType(onTy);
300             rewriter.replaceOpWithNewOp<FieldIndexOp>(
301                 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
302           }
303         } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
304           auto ty = index.getType();
305           mlir::Type onTy = index.getOnType();
306           if (typeConverter.needsConversion(ty) ||
307               typeConverter.needsConversion(onTy)) {
308             rewriter.setInsertionPoint(index);
309             auto toTy = typeConverter.convertType(ty);
310             auto toOnTy = typeConverter.convertType(onTy);
311             rewriter.replaceOpWithNewOp<LenParamIndexOp>(
312                 mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
313           }
314         } else if (op->getDialect() == firDialect) {
315           rewriter.startRootUpdate(op);
316           for (auto i : llvm::enumerate(op->getResultTypes()))
317             if (typeConverter.needsConversion(i.value())) {
318               auto toTy = typeConverter.convertType(i.value());
319               op->getResult(i.index()).setType(toTy);
320             }
321           rewriter.finalizeRootUpdate(op);
322         }
323       });
324     }
325     // TODO: any alternative implementation. Note: currently, the default code
326     // gen will not be able to handle boxproc and will give an error.
327   }
328 
329 private:
330   BoxedProcedureOptions options;
331 };
332 } // namespace
333 
334 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() {
335   return std::make_unique<BoxedProcedurePass>();
336 }
337 
338 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) {
339   return std::make_unique<BoxedProcedurePass>(useThunks);
340 }
341