xref: /llvm-project/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (revision 546f32df26f58fdfe02d99e6d91d681dd9ed6839)
1 //===-- BoxedProcedure.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/CodeGen/CodeGen.h"
10 
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Dialect/FIROps.h"
15 #include "flang/Optimizer/Dialect/FIRType.h"
16 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
17 #include "flang/Optimizer/Support/FatalError.h"
18 #include "flang/Optimizer/Support/InternalNames.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/DenseMap.h"
23 
24 namespace fir {
25 #define GEN_PASS_DEF_BOXEDPROCEDUREPASS
26 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
27 } // namespace fir
28 
29 #define DEBUG_TYPE "flang-procedure-pointer"
30 
31 using namespace fir;
32 
33 namespace {
/// Configuration knobs for the boxed procedure (procedure pointer) pass.
///
/// This is kept as a plain aggregate so it can be brace-initialized from a
/// single boolean.
struct BoxedProcedureOptions {
  /// When set, the boxproc abstraction is lowered to function pointers, with
  /// thunks generated where they are required.
  bool useThunks{true};
};
40 
41 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types.
42 class BoxprocTypeRewriter : public mlir::TypeConverter {
43 public:
44   using mlir::TypeConverter::convertType;
45 
46   /// Does the type \p ty need to be converted?
47   /// Any type that is a `!fir.boxproc` in whole or in part will need to be
48   /// converted to a function type to lower the IR to function pointer form in
49   /// the default implementation performed in this pass. Other implementations
50   /// are possible, so those may convert `!fir.boxproc` to some other type or
51   /// not at all depending on the implementation target's characteristics and
52   /// preference.
53   bool needsConversion(mlir::Type ty) {
54     if (ty.isa<BoxProcType>())
55       return true;
56     if (auto funcTy = ty.dyn_cast<mlir::FunctionType>()) {
57       for (auto t : funcTy.getInputs())
58         if (needsConversion(t))
59           return true;
60       for (auto t : funcTy.getResults())
61         if (needsConversion(t))
62           return true;
63       return false;
64     }
65     if (auto tupleTy = ty.dyn_cast<mlir::TupleType>()) {
66       for (auto t : tupleTy.getTypes())
67         if (needsConversion(t))
68           return true;
69       return false;
70     }
71     if (auto recTy = ty.dyn_cast<RecordType>()) {
72       if (llvm::is_contained(visitedTypes, recTy))
73         return false;
74       bool result = false;
75       visitedTypes.push_back(recTy);
76       for (auto t : recTy.getTypeList()) {
77         if (needsConversion(t.second)) {
78           result = true;
79           break;
80         }
81       }
82       visitedTypes.pop_back();
83       return result;
84     }
85     if (auto boxTy = ty.dyn_cast<BaseBoxType>())
86       return needsConversion(boxTy.getEleTy());
87     if (isa_ref_type(ty))
88       return needsConversion(unwrapRefType(ty));
89     if (auto t = ty.dyn_cast<SequenceType>())
90       return needsConversion(unwrapSequenceType(ty));
91     return false;
92   }
93 
94   BoxprocTypeRewriter(mlir::Location location) : loc{location} {
95     addConversion([](mlir::Type ty) { return ty; });
96     addConversion(
97         [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); });
98     addConversion([&](mlir::TupleType tupTy) {
99       llvm::SmallVector<mlir::Type> memTys;
100       for (auto ty : tupTy.getTypes())
101         memTys.push_back(convertType(ty));
102       return mlir::TupleType::get(tupTy.getContext(), memTys);
103     });
104     addConversion([&](mlir::FunctionType funcTy) {
105       llvm::SmallVector<mlir::Type> inTys;
106       llvm::SmallVector<mlir::Type> resTys;
107       for (auto ty : funcTy.getInputs())
108         inTys.push_back(convertType(ty));
109       for (auto ty : funcTy.getResults())
110         resTys.push_back(convertType(ty));
111       return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys);
112     });
113     addConversion([&](ReferenceType ty) {
114       return ReferenceType::get(convertType(ty.getEleTy()));
115     });
116     addConversion([&](PointerType ty) {
117       return PointerType::get(convertType(ty.getEleTy()));
118     });
119     addConversion(
120         [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); });
121     addConversion([&](fir::LLVMPointerType ty) {
122       return fir::LLVMPointerType::get(convertType(ty.getEleTy()));
123     });
124     addConversion(
125         [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); });
126     addConversion([&](ClassType ty) {
127       return ClassType::get(convertType(ty.getEleTy()));
128     });
129     addConversion([&](SequenceType ty) {
130       // TODO: add ty.getLayoutMap() as needed.
131       return SequenceType::get(ty.getShape(), convertType(ty.getEleTy()));
132     });
133     addConversion([&](RecordType ty) -> mlir::Type {
134       if (!needsConversion(ty))
135         return ty;
136       if (auto converted = convertedTypes.lookup(ty))
137         return converted;
138       auto rec = RecordType::get(ty.getContext(),
139                                  ty.getName().str() + boxprocSuffix.str());
140       if (rec.isFinalized())
141         return rec;
142       auto it = convertedTypes.try_emplace(ty, rec);
143       if (!it.second) {
144         llvm::errs() << "failed\n" << ty << "\n";
145       }
146       std::vector<RecordType::TypePair> ps = ty.getLenParamList();
147       std::vector<RecordType::TypePair> cs;
148       for (auto t : ty.getTypeList()) {
149         if (needsConversion(t.second))
150           cs.emplace_back(t.first, convertType(t.second));
151         else
152           cs.emplace_back(t.first, t.second);
153       }
154       rec.finalize(ps, cs);
155       return rec;
156     });
157     addArgumentMaterialization(materializeProcedure);
158     addSourceMaterialization(materializeProcedure);
159     addTargetMaterialization(materializeProcedure);
160   }
161 
162   static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
163                                           BoxProcType type,
164                                           mlir::ValueRange inputs,
165                                           mlir::Location loc) {
166     assert(inputs.size() == 1);
167     return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
168                                      inputs[0]);
169   }
170 
171   void setLocation(mlir::Location location) { loc = location; }
172 
173 private:
174   llvm::SmallVector<mlir::Type> visitedTypes;
175   // Map to deal with recursive derived types (avoid infinite loops).
176   // Caching is also beneficial for apps with big types (dozens of
177   // components and or parent types), so the lifetime of the cache
178   // is the whole pass.
179   llvm::DenseMap<mlir::Type, mlir::Type> convertedTypes;
180   mlir::Location loc;
181 };
182 
183 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically,
184 /// Fortran procedures can be referenced directly through a function pointer.
185 /// However, Fortran has one-level dynamic scoping between a host procedure and
186 /// its internal procedures. This allows internal procedures to directly access
187 /// and modify the state of the host procedure's variables.
188 ///
189 /// There are any number of possible implementations possible.
190 ///
191 /// The implementation used here is to convert `boxproc` values to function
192 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the
193 /// host procedure's data, then a thunk will be created at runtime to capture
194 /// the frame pointer during execution. In LLVM IR, the frame pointer is
195 /// designated with the `nest` attribute. The thunk's address will then be used
196 /// as the call target instead of the original function's address directly.
197 class BoxedProcedurePass
198     : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> {
199 public:
200   BoxedProcedurePass() { options = {true}; }
201   BoxedProcedurePass(bool useThunks) { options = {useThunks}; }
202 
203   inline mlir::ModuleOp getModule() { return getOperation(); }
204 
205   void runOnOperation() override final {
206     if (options.useThunks) {
207       auto *context = &getContext();
208       mlir::IRRewriter rewriter(context);
209       BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
210       mlir::Dialect *firDialect = context->getLoadedDialect("fir");
211       getModule().walk([&](mlir::Operation *op) {
212         bool opIsValid = true;
213         typeConverter.setLocation(op->getLoc());
214         if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
215           mlir::Type ty = addr.getVal().getType();
216           mlir::Type resTy = addr.getResult().getType();
217           if (llvm::isa<mlir::FunctionType>(ty) ||
218               llvm::isa<fir::BoxProcType>(ty)) {
219             // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
220             // or function type to be `fir.convert` ops.
221             rewriter.setInsertionPoint(addr);
222             rewriter.replaceOpWithNewOp<ConvertOp>(
223                 addr, typeConverter.convertType(addr.getType()), addr.getVal());
224             opIsValid = false;
225           } else if (typeConverter.needsConversion(resTy)) {
226             rewriter.startOpModification(op);
227             op->getResult(0).setType(typeConverter.convertType(resTy));
228             rewriter.finalizeOpModification(op);
229           }
230         } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
231           mlir::FunctionType ty = func.getFunctionType();
232           if (typeConverter.needsConversion(ty)) {
233             rewriter.startOpModification(func);
234             auto toTy =
235                 typeConverter.convertType(ty).cast<mlir::FunctionType>();
236             if (!func.empty())
237               for (auto e : llvm::enumerate(toTy.getInputs())) {
238                 unsigned i = e.index();
239                 auto &block = func.front();
240                 block.insertArgument(i, e.value(), func.getLoc());
241                 block.getArgument(i + 1).replaceAllUsesWith(
242                     block.getArgument(i));
243                 block.eraseArgument(i + 1);
244               }
245             func.setType(toTy);
246             rewriter.finalizeOpModification(func);
247           }
248         } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
249           // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
250           // as required.
251           mlir::Type toTy = typeConverter.convertType(
252               embox.getType().cast<BoxProcType>().getEleTy());
253           rewriter.setInsertionPoint(embox);
254           if (embox.getHost()) {
255             // Create the thunk.
256             auto module = embox->getParentOfType<mlir::ModuleOp>();
257             FirOpBuilder builder(rewriter, module);
258             auto loc = embox.getLoc();
259             mlir::Type i8Ty = builder.getI8Type();
260             mlir::Type i8Ptr = builder.getRefType(i8Ty);
261             mlir::Type buffTy = SequenceType::get({32}, i8Ty);
262             auto buffer = builder.create<AllocaOp>(loc, buffTy);
263             mlir::Value closure =
264                 builder.createConvert(loc, i8Ptr, embox.getHost());
265             mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
266             mlir::Value func =
267                 builder.createConvert(loc, i8Ptr, embox.getFunc());
268             builder.create<fir::CallOp>(
269                 loc, factory::getLlvmInitTrampoline(builder),
270                 llvm::ArrayRef<mlir::Value>{tramp, func, closure});
271             auto adjustCall = builder.create<fir::CallOp>(
272                 loc, factory::getLlvmAdjustTrampoline(builder),
273                 llvm::ArrayRef<mlir::Value>{tramp});
274             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
275                                                    adjustCall.getResult(0));
276             opIsValid = false;
277           } else {
278             // Just forward the function as a pointer.
279             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
280                                                    embox.getFunc());
281             opIsValid = false;
282           }
283         } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
284           auto ty = global.getType();
285           if (typeConverter.needsConversion(ty)) {
286             rewriter.startOpModification(global);
287             auto toTy = typeConverter.convertType(ty);
288             global.setType(toTy);
289             rewriter.finalizeOpModification(global);
290           }
291         } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
292           auto ty = mem.getType();
293           if (typeConverter.needsConversion(ty)) {
294             rewriter.setInsertionPoint(mem);
295             auto toTy = typeConverter.convertType(unwrapRefType(ty));
296             bool isPinned = mem.getPinned();
297             llvm::StringRef uniqName =
298                 mem.getUniqName().value_or(llvm::StringRef());
299             llvm::StringRef bindcName =
300                 mem.getBindcName().value_or(llvm::StringRef());
301             rewriter.replaceOpWithNewOp<AllocaOp>(
302                 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
303                 mem.getShape());
304             opIsValid = false;
305           }
306         } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
307           auto ty = mem.getType();
308           if (typeConverter.needsConversion(ty)) {
309             rewriter.setInsertionPoint(mem);
310             auto toTy = typeConverter.convertType(unwrapRefType(ty));
311             llvm::StringRef uniqName =
312                 mem.getUniqName().value_or(llvm::StringRef());
313             llvm::StringRef bindcName =
314                 mem.getBindcName().value_or(llvm::StringRef());
315             rewriter.replaceOpWithNewOp<AllocMemOp>(
316                 mem, toTy, uniqName, bindcName, mem.getTypeparams(),
317                 mem.getShape());
318             opIsValid = false;
319           }
320         } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
321           auto ty = coor.getType();
322           mlir::Type baseTy = coor.getBaseType();
323           if (typeConverter.needsConversion(ty) ||
324               typeConverter.needsConversion(baseTy)) {
325             rewriter.setInsertionPoint(coor);
326             auto toTy = typeConverter.convertType(ty);
327             auto toBaseTy = typeConverter.convertType(baseTy);
328             rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
329                                                       coor.getCoor(), toBaseTy);
330             opIsValid = false;
331           }
332         } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
333           auto ty = index.getType();
334           mlir::Type onTy = index.getOnType();
335           if (typeConverter.needsConversion(ty) ||
336               typeConverter.needsConversion(onTy)) {
337             rewriter.setInsertionPoint(index);
338             auto toTy = typeConverter.convertType(ty);
339             auto toOnTy = typeConverter.convertType(onTy);
340             rewriter.replaceOpWithNewOp<FieldIndexOp>(
341                 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
342             opIsValid = false;
343           }
344         } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
345           auto ty = index.getType();
346           mlir::Type onTy = index.getOnType();
347           if (typeConverter.needsConversion(ty) ||
348               typeConverter.needsConversion(onTy)) {
349             rewriter.setInsertionPoint(index);
350             auto toTy = typeConverter.convertType(ty);
351             auto toOnTy = typeConverter.convertType(onTy);
352             rewriter.replaceOpWithNewOp<LenParamIndexOp>(
353                 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
354             opIsValid = false;
355           }
356         } else if (op->getDialect() == firDialect) {
357           rewriter.startOpModification(op);
358           for (auto i : llvm::enumerate(op->getResultTypes()))
359             if (typeConverter.needsConversion(i.value())) {
360               auto toTy = typeConverter.convertType(i.value());
361               op->getResult(i.index()).setType(toTy);
362             }
363           rewriter.finalizeOpModification(op);
364         }
365         // Ensure block arguments are updated if needed.
366         if (opIsValid && op->getNumRegions() != 0) {
367           rewriter.startOpModification(op);
368           for (mlir::Region &region : op->getRegions())
369             for (mlir::Block &block : region.getBlocks())
370               for (mlir::BlockArgument blockArg : block.getArguments())
371                 if (typeConverter.needsConversion(blockArg.getType())) {
372                   mlir::Type toTy =
373                       typeConverter.convertType(blockArg.getType());
374                   blockArg.setType(toTy);
375                 }
376           rewriter.finalizeOpModification(op);
377         }
378       });
379     }
380   }
381 
382 private:
383   BoxedProcedureOptions options;
384 };
385 } // namespace
386 
387 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() {
388   return std::make_unique<BoxedProcedurePass>();
389 }
390 
391 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) {
392   return std::make_unique<BoxedProcedurePass>(useThunks);
393 }
394