xref: /llvm-project/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (revision feb9d33a2aa5f0bec98f4da5c4330e36ea59ce96)
1 //===-- BoxedProcedure.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/CodeGen/CodeGen.h"
10 
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Dialect/FIROps.h"
15 #include "flang/Optimizer/Dialect/FIRType.h"
16 #include "flang/Optimizer/Support/FIRContext.h"
17 #include "flang/Optimizer/Support/FatalError.h"
18 #include "flang/Optimizer/Support/InternalNames.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 namespace fir {
24 #define GEN_PASS_DEF_BOXEDPROCEDUREPASS
25 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
26 } // namespace fir
27 
28 #define DEBUG_TYPE "flang-procedure-pointer"
29 
30 using namespace fir;
31 
32 namespace {
/// Configuration for the boxed procedure pass.
struct BoxedProcedureOptions {
  /// When true (the default), lower the `boxproc` abstraction to function
  /// pointers, building thunks wherever a host frame must be captured.
  bool useThunks{true};
};
38 };
39 
40 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types.
41 class BoxprocTypeRewriter : public mlir::TypeConverter {
42 public:
43   using mlir::TypeConverter::convertType;
44 
45   /// Does the type \p ty need to be converted?
46   /// Any type that is a `!fir.boxproc` in whole or in part will need to be
47   /// converted to a function type to lower the IR to function pointer form in
48   /// the default implementation performed in this pass. Other implementations
49   /// are possible, so those may convert `!fir.boxproc` to some other type or
50   /// not at all depending on the implementation target's characteristics and
51   /// preference.
52   bool needsConversion(mlir::Type ty) {
53     if (ty.isa<BoxProcType>())
54       return true;
55     if (auto funcTy = ty.dyn_cast<mlir::FunctionType>()) {
56       for (auto t : funcTy.getInputs())
57         if (needsConversion(t))
58           return true;
59       for (auto t : funcTy.getResults())
60         if (needsConversion(t))
61           return true;
62       return false;
63     }
64     if (auto tupleTy = ty.dyn_cast<mlir::TupleType>()) {
65       for (auto t : tupleTy.getTypes())
66         if (needsConversion(t))
67           return true;
68       return false;
69     }
70     if (auto recTy = ty.dyn_cast<RecordType>()) {
71       if (llvm::is_contained(visitedTypes, recTy))
72         return false;
73       bool result = false;
74       visitedTypes.push_back(recTy);
75       for (auto t : recTy.getTypeList()) {
76         if (needsConversion(t.second)) {
77           result = true;
78           break;
79         }
80       }
81       visitedTypes.pop_back();
82       return result;
83     }
84     if (auto boxTy = ty.dyn_cast<BoxType>())
85       return needsConversion(boxTy.getEleTy());
86     if (isa_ref_type(ty))
87       return needsConversion(unwrapRefType(ty));
88     if (auto t = ty.dyn_cast<SequenceType>())
89       return needsConversion(unwrapSequenceType(ty));
90     return false;
91   }
92 
93   BoxprocTypeRewriter(mlir::Location location) : loc{location} {
94     addConversion([](mlir::Type ty) { return ty; });
95     addConversion(
96         [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); });
97     addConversion([&](mlir::TupleType tupTy) {
98       llvm::SmallVector<mlir::Type> memTys;
99       for (auto ty : tupTy.getTypes())
100         memTys.push_back(convertType(ty));
101       return mlir::TupleType::get(tupTy.getContext(), memTys);
102     });
103     addConversion([&](mlir::FunctionType funcTy) {
104       llvm::SmallVector<mlir::Type> inTys;
105       llvm::SmallVector<mlir::Type> resTys;
106       for (auto ty : funcTy.getInputs())
107         inTys.push_back(convertType(ty));
108       for (auto ty : funcTy.getResults())
109         resTys.push_back(convertType(ty));
110       return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys);
111     });
112     addConversion([&](ReferenceType ty) {
113       return ReferenceType::get(convertType(ty.getEleTy()));
114     });
115     addConversion([&](PointerType ty) {
116       return PointerType::get(convertType(ty.getEleTy()));
117     });
118     addConversion(
119         [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); });
120     addConversion(
121         [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); });
122     addConversion([&](SequenceType ty) {
123       // TODO: add ty.getLayoutMap() as needed.
124       return SequenceType::get(ty.getShape(), convertType(ty.getEleTy()));
125     });
126     addConversion([&](RecordType ty) -> mlir::Type {
127       if (!needsConversion(ty))
128         return ty;
129       auto rec = RecordType::get(ty.getContext(),
130                                  ty.getName().str() + boxprocSuffix.str());
131       if (rec.isFinalized())
132         return rec;
133       std::vector<RecordType::TypePair> ps = ty.getLenParamList();
134       std::vector<RecordType::TypePair> cs;
135       for (auto t : ty.getTypeList()) {
136         if (needsConversion(t.second))
137           cs.emplace_back(t.first, convertType(t.second));
138         else
139           cs.emplace_back(t.first, t.second);
140       }
141       rec.finalize(ps, cs);
142       return rec;
143     });
144     addArgumentMaterialization(materializeProcedure);
145     addSourceMaterialization(materializeProcedure);
146     addTargetMaterialization(materializeProcedure);
147   }
148 
149   static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
150                                           BoxProcType type,
151                                           mlir::ValueRange inputs,
152                                           mlir::Location loc) {
153     assert(inputs.size() == 1);
154     return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
155                                      inputs[0]);
156   }
157 
158   void setLocation(mlir::Location location) { loc = location; }
159 
160 private:
161   llvm::SmallVector<mlir::Type> visitedTypes;
162   mlir::Location loc;
163 };
164 
165 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically,
166 /// Fortran procedures can be referenced directly through a function pointer.
167 /// However, Fortran has one-level dynamic scoping between a host procedure and
168 /// its internal procedures. This allows internal procedures to directly access
169 /// and modify the state of the host procedure's variables.
170 ///
171 /// There are any number of possible implementations possible.
172 ///
173 /// The implementation used here is to convert `boxproc` values to function
174 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the
175 /// host procedure's data, then a thunk will be created at runtime to capture
176 /// the frame pointer during execution. In LLVM IR, the frame pointer is
177 /// designated with the `nest` attribute. The thunk's address will then be used
178 /// as the call target instead of the original function's address directly.
179 class BoxedProcedurePass
180     : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> {
181 public:
182   BoxedProcedurePass() { options = {true}; }
183   BoxedProcedurePass(bool useThunks) { options = {useThunks}; }
184 
185   inline mlir::ModuleOp getModule() { return getOperation(); }
186 
187   void runOnOperation() override final {
188     if (options.useThunks) {
189       auto *context = &getContext();
190       mlir::IRRewriter rewriter(context);
191       BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
192       mlir::Dialect *firDialect = context->getLoadedDialect("fir");
193       getModule().walk([&](mlir::Operation *op) {
194         typeConverter.setLocation(op->getLoc());
195         if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
196           auto ty = addr.getVal().getType();
197           if (typeConverter.needsConversion(ty) ||
198               ty.isa<mlir::FunctionType>()) {
199             // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
200             // or function type to be `fir.convert` ops.
201             rewriter.setInsertionPoint(addr);
202             rewriter.replaceOpWithNewOp<ConvertOp>(
203                 addr, typeConverter.convertType(addr.getType()), addr.getVal());
204           }
205         } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
206           mlir::FunctionType ty = func.getFunctionType();
207           if (typeConverter.needsConversion(ty)) {
208             rewriter.startRootUpdate(func);
209             auto toTy =
210                 typeConverter.convertType(ty).cast<mlir::FunctionType>();
211             if (!func.empty())
212               for (auto e : llvm::enumerate(toTy.getInputs())) {
213                 unsigned i = e.index();
214                 auto &block = func.front();
215                 block.insertArgument(i, e.value(), func.getLoc());
216                 block.getArgument(i + 1).replaceAllUsesWith(
217                     block.getArgument(i));
218                 block.eraseArgument(i + 1);
219               }
220             func.setType(toTy);
221             rewriter.finalizeRootUpdate(func);
222           }
223         } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
224           // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
225           // as required.
226           mlir::Type toTy = embox.getType().cast<BoxProcType>().getEleTy();
227           rewriter.setInsertionPoint(embox);
228           if (embox.getHost()) {
229             // Create the thunk.
230             auto module = embox->getParentOfType<mlir::ModuleOp>();
231             fir::KindMapping kindMap = getKindMapping(module);
232             FirOpBuilder builder(rewriter, kindMap);
233             auto loc = embox.getLoc();
234             mlir::Type i8Ty = builder.getI8Type();
235             mlir::Type i8Ptr = builder.getRefType(i8Ty);
236             mlir::Type buffTy = SequenceType::get({32}, i8Ty);
237             auto buffer = builder.create<AllocaOp>(loc, buffTy);
238             mlir::Value closure =
239                 builder.createConvert(loc, i8Ptr, embox.getHost());
240             mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
241             mlir::Value func =
242                 builder.createConvert(loc, i8Ptr, embox.getFunc());
243             builder.create<fir::CallOp>(
244                 loc, factory::getLlvmInitTrampoline(builder),
245                 llvm::ArrayRef<mlir::Value>{tramp, func, closure});
246             auto adjustCall = builder.create<fir::CallOp>(
247                 loc, factory::getLlvmAdjustTrampoline(builder),
248                 llvm::ArrayRef<mlir::Value>{tramp});
249             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
250                                                    adjustCall.getResult(0));
251           } else {
252             // Just forward the function as a pointer.
253             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
254                                                    embox.getFunc());
255           }
256         } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
257           auto ty = global.getType();
258           if (typeConverter.needsConversion(ty)) {
259             rewriter.startRootUpdate(global);
260             auto toTy = typeConverter.convertType(ty);
261             global.setType(toTy);
262             rewriter.finalizeRootUpdate(global);
263           }
264         } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
265           auto ty = mem.getType();
266           if (typeConverter.needsConversion(ty)) {
267             rewriter.setInsertionPoint(mem);
268             auto toTy = typeConverter.convertType(unwrapRefType(ty));
269             bool isPinned = mem.getPinned();
270             llvm::StringRef uniqName =
271                 mem.getUniqName().value_or(llvm::StringRef());
272             llvm::StringRef bindcName =
273                 mem.getBindcName().value_or(llvm::StringRef());
274             rewriter.replaceOpWithNewOp<AllocaOp>(
275                 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
276                 mem.getShape());
277           }
278         } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
279           auto ty = mem.getType();
280           if (typeConverter.needsConversion(ty)) {
281             rewriter.setInsertionPoint(mem);
282             auto toTy = typeConverter.convertType(unwrapRefType(ty));
283             llvm::StringRef uniqName =
284                 mem.getUniqName().value_or(llvm::StringRef());
285             llvm::StringRef bindcName =
286                 mem.getBindcName().value_or(llvm::StringRef());
287             rewriter.replaceOpWithNewOp<AllocMemOp>(
288                 mem, toTy, uniqName, bindcName, mem.getTypeparams(),
289                 mem.getShape());
290           }
291         } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
292           auto ty = coor.getType();
293           mlir::Type baseTy = coor.getBaseType();
294           if (typeConverter.needsConversion(ty) ||
295               typeConverter.needsConversion(baseTy)) {
296             rewriter.setInsertionPoint(coor);
297             auto toTy = typeConverter.convertType(ty);
298             auto toBaseTy = typeConverter.convertType(baseTy);
299             rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
300                                                       coor.getCoor(), toBaseTy);
301           }
302         } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
303           auto ty = index.getType();
304           mlir::Type onTy = index.getOnType();
305           if (typeConverter.needsConversion(ty) ||
306               typeConverter.needsConversion(onTy)) {
307             rewriter.setInsertionPoint(index);
308             auto toTy = typeConverter.convertType(ty);
309             auto toOnTy = typeConverter.convertType(onTy);
310             rewriter.replaceOpWithNewOp<FieldIndexOp>(
311                 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
312           }
313         } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
314           auto ty = index.getType();
315           mlir::Type onTy = index.getOnType();
316           if (typeConverter.needsConversion(ty) ||
317               typeConverter.needsConversion(onTy)) {
318             rewriter.setInsertionPoint(index);
319             auto toTy = typeConverter.convertType(ty);
320             auto toOnTy = typeConverter.convertType(onTy);
321             rewriter.replaceOpWithNewOp<LenParamIndexOp>(
322                 mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
323           }
324         } else if (op->getDialect() == firDialect) {
325           rewriter.startRootUpdate(op);
326           for (auto i : llvm::enumerate(op->getResultTypes()))
327             if (typeConverter.needsConversion(i.value())) {
328               auto toTy = typeConverter.convertType(i.value());
329               op->getResult(i.index()).setType(toTy);
330             }
331           rewriter.finalizeRootUpdate(op);
332         }
333       });
334     }
335   }
336 
337 private:
338   BoxedProcedureOptions options;
339 };
340 } // namespace
341 
342 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() {
343   return std::make_unique<BoxedProcedurePass>();
344 }
345 
346 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) {
347   return std::make_unique<BoxedProcedurePass>(useThunks);
348 }
349