xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision 4ddc756bccb34f3d07e30c9ca96bba32cb0cf4f9)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Pass/PassManager.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace fir {
24 #define GEN_PASS_DEF_ABSTRACTRESULTOPT
25 #include "flang/Optimizer/Transforms/Passes.h.inc"
26 } // namespace fir
27 
28 #define DEBUG_TYPE "flang-abstract-result-opt"
29 
30 using namespace mlir;
31 
32 namespace fir {
33 namespace {
34 
35 static mlir::Type getResultArgumentType(mlir::Type resultType,
36                                         bool shouldBoxResult) {
37   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
38       .Case<fir::SequenceType, fir::RecordType>(
39           [&](mlir::Type type) -> mlir::Type {
40             if (shouldBoxResult)
41               return fir::BoxType::get(type);
42             return fir::ReferenceType::get(type);
43           })
44       .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
45         return fir::ReferenceType::get(type);
46       })
47       .Default([](mlir::Type) -> mlir::Type {
48         llvm_unreachable("bad abstract result type");
49       });
50 }
51 
52 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
53                                              bool shouldBoxResult) {
54   auto resultType = funcTy.getResult(0);
55   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
56   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
57   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
58   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
59                                  /*resultTypes=*/{});
60 }
61 
62 static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
63   return fir::ReferenceType::get(mlir::NoneType::get(context));
64 }
65 
66 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
67 /// Follow the ABI for interoperability with C.
68 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
69   assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
70   llvm::SmallVector<mlir::Type> outputTypes{
71       getVoidPtrType(funcTy.getContext())};
72   return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
73                                  outputTypes);
74 }
75 
76 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
77   return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) &&
78          shouldBoxResult;
79 }
80 
81 template <typename Op>
82 class CallConversion : public mlir::OpRewritePattern<Op> {
83 public:
84   using mlir::OpRewritePattern<Op>::OpRewritePattern;
85 
86   CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
87       : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
88 
89   llvm::LogicalResult
90   matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
91     auto loc = op.getLoc();
92     auto result = op->getResult(0);
93     if (!result.hasOneUse()) {
94       mlir::emitError(loc,
95                       "calls with abstract result must have exactly one user");
96       return mlir::failure();
97     }
98     auto saveResult =
99         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
100     if (!saveResult) {
101       mlir::emitError(
102           loc, "calls with abstract result must be used in fir.save_result");
103       return mlir::failure();
104     }
105     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
106     auto buffer = saveResult.getMemref();
107     mlir::Value arg = buffer;
108     if (mustEmboxResult(result.getType(), shouldBoxResult))
109       arg = rewriter.create<fir::EmboxOp>(
110           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
111           saveResult.getTypeparams());
112 
113     llvm::SmallVector<mlir::Type> newResultTypes;
114     bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
115     if (isResultBuiltinCPtr)
116       newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
117 
118     Op newOp;
119     // fir::CallOp specific handling.
120     if constexpr (std::is_same_v<Op, fir::CallOp>) {
121       if (op.getCallee()) {
122         llvm::SmallVector<mlir::Value> newOperands;
123         if (!isResultBuiltinCPtr)
124           newOperands.emplace_back(arg);
125         newOperands.append(op.getOperands().begin(), op.getOperands().end());
126         newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
127                                              newResultTypes, newOperands);
128       } else {
129         // Indirect calls.
130         llvm::SmallVector<mlir::Type> newInputTypes;
131         if (!isResultBuiltinCPtr)
132           newInputTypes.emplace_back(argType);
133         for (auto operand : op.getOperands().drop_front())
134           newInputTypes.push_back(operand.getType());
135         auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
136                                                  newResultTypes);
137 
138         llvm::SmallVector<mlir::Value> newOperands;
139         newOperands.push_back(
140             rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
141         if (!isResultBuiltinCPtr)
142           newOperands.push_back(arg);
143         newOperands.append(op.getOperands().begin() + 1,
144                            op.getOperands().end());
145         newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
146                                              newResultTypes, newOperands);
147       }
148     }
149 
150     // fir::DispatchOp specific handling.
151     if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
152       llvm::SmallVector<mlir::Value> newOperands;
153       if (!isResultBuiltinCPtr)
154         newOperands.emplace_back(arg);
155       unsigned passArgShift = newOperands.size();
156       newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
157       mlir::IntegerAttr passArgPos;
158       if (op.getPassArgPos())
159         passArgPos =
160             rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
161       newOp = rewriter.create<fir::DispatchOp>(
162           loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
163           op.getOperands()[0], newOperands, passArgPos,
164           op.getProcedureAttrsAttr());
165     }
166 
167     if (isResultBuiltinCPtr) {
168       mlir::Value save = saveResult.getMemref();
169       auto module = op->template getParentOfType<mlir::ModuleOp>();
170       FirOpBuilder builder(rewriter, module);
171       mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
172           builder, loc, save, result.getType());
173       builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
174     }
175     op->dropAllReferences();
176     rewriter.eraseOp(op);
177     return mlir::success();
178   }
179 
180 private:
181   bool shouldBoxResult;
182 };
183 
184 class SaveResultOpConversion
185     : public mlir::OpRewritePattern<fir::SaveResultOp> {
186 public:
187   using OpRewritePattern::OpRewritePattern;
188   SaveResultOpConversion(mlir::MLIRContext *context)
189       : OpRewritePattern(context) {}
190   llvm::LogicalResult
191   matchAndRewrite(fir::SaveResultOp op,
192                   mlir::PatternRewriter &rewriter) const override {
193     rewriter.eraseOp(op);
194     return mlir::success();
195   }
196 };
197 
198 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
199 public:
200   using OpRewritePattern::OpRewritePattern;
201   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
202       : OpRewritePattern(context), newArg{newArg} {}
203   llvm::LogicalResult
204   matchAndRewrite(mlir::func::ReturnOp ret,
205                   mlir::PatternRewriter &rewriter) const override {
206     auto loc = ret.getLoc();
207     rewriter.setInsertionPoint(ret);
208     mlir::Value resultValue = ret.getOperand(0);
209     fir::LoadOp resultLoad;
210     mlir::Value resultStorage;
211     // Identify result local storage.
212     if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
213       resultLoad = load;
214       resultStorage = load.getMemref();
215       // The result alloca may be behind a fir.declare, if any.
216       if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
217         resultStorage = declare.getMemref();
218     }
219     // Replace old local storage with new storage argument, unless
220     // the derived type is C_PTR/C_FUN_PTR, in which case the return
221     // type is updated to return void* (no new argument is passed).
222     if (fir::isa_builtin_cptr_type(resultValue.getType())) {
223       auto module = ret->getParentOfType<mlir::ModuleOp>();
224       FirOpBuilder builder(rewriter, module);
225       mlir::Value cptr = resultValue;
226       if (resultLoad) {
227         // Replace whole derived type load by component load.
228         cptr = resultLoad.getMemref();
229         rewriter.setInsertionPoint(resultLoad);
230       }
231       mlir::Value newResultValue =
232           fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
233       newResultValue = builder.createConvert(
234           loc, getVoidPtrType(ret.getContext()), newResultValue);
235       rewriter.setInsertionPoint(ret);
236       rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
237           ret, mlir::ValueRange{newResultValue});
238     } else if (resultStorage) {
239       resultStorage.replaceAllUsesWith(newArg);
240       rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
241     } else {
242       // The result storage may have been optimized out by a memory to
243       // register pass, this is possible for fir.box results, or fir.record
244       // with no length parameters. Simply store the result in the result
245       // storage. at the return point.
246       rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
247       rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
248     }
249     // Delete result old local storage if unused.
250     if (resultStorage)
251       if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
252         if (alloc->use_empty())
253           rewriter.eraseOp(alloc);
254     return mlir::success();
255   }
256 
257 private:
258   mlir::Value newArg;
259 };
260 
261 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
262 public:
263   using OpRewritePattern::OpRewritePattern;
264   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
265       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
266   llvm::LogicalResult
267   matchAndRewrite(fir::AddrOfOp addrOf,
268                   mlir::PatternRewriter &rewriter) const override {
269     auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
270     mlir::FunctionType newFuncTy;
271     if (oldFuncTy.getNumResults() != 0 &&
272         fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
273       newFuncTy = getCPtrFunctionType(oldFuncTy);
274     else
275       newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
276     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
277                                                     addrOf.getSymbol());
278     // Rather than converting all op a function pointer might transit through
279     // (e.g calls, stores, loads, converts...), cast new type to the abstract
280     // type. A conversion will be added when calling indirect calls of abstract
281     // types.
282     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
283     return mlir::success();
284   }
285 
286 private:
287   bool shouldBoxResult;
288 };
289 
290 class AbstractResultOpt
291     : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
292 public:
293   using fir::impl::AbstractResultOptBase<
294       AbstractResultOpt>::AbstractResultOptBase;
295 
296   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
297                               mlir::RewritePatternSet &patterns,
298                               mlir::ConversionTarget &target) {
299     auto loc = func.getLoc();
300     auto *context = &getContext();
301     // Convert function type itself if it has an abstract result.
302     auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
303     if (hasAbstractResult(funcTy)) {
304       if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
305         func.setType(getCPtrFunctionType(funcTy));
306         patterns.insert<ReturnOpConversion>(context, mlir::Value{});
307         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
308             [](mlir::func::ReturnOp ret) {
309               mlir::Type retTy = ret.getOperand(0).getType();
310               return !fir::isa_builtin_cptr_type(retTy);
311             });
312         return;
313       }
314       if (!func.empty()) {
315         // Insert new argument.
316         mlir::OpBuilder rewriter(context);
317         auto resultType = funcTy.getResult(0);
318         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
319         func.insertArgument(0u, argTy, {}, loc);
320         func.eraseResult(0u);
321         mlir::Value newArg = func.getArgument(0u);
322         if (mustEmboxResult(resultType, shouldBoxResult)) {
323           auto bufferType = fir::ReferenceType::get(resultType);
324           rewriter.setInsertionPointToStart(&func.front());
325           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
326         }
327         patterns.insert<ReturnOpConversion>(context, newArg);
328         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
329             [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
330         assert(func.getFunctionType() ==
331                getNewFunctionType(funcTy, shouldBoxResult));
332       } else {
333         llvm::SmallVector<mlir::DictionaryAttr> allArgs;
334         func.getAllArgAttrs(allArgs);
335         allArgs.insert(allArgs.begin(),
336                        mlir::DictionaryAttr::get(func->getContext()));
337         func.setType(getNewFunctionType(funcTy, shouldBoxResult));
338         func.setAllArgAttrs(allArgs);
339       }
340     }
341   }
342 
343   inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
344     return mlir::TypeSwitch<mlir::Type, bool>(type)
345         .Case([](fir::BoxProcType boxProc) {
346           return fir::hasAbstractResult(
347               mlir::cast<mlir::FunctionType>(boxProc.getEleTy()));
348         })
349         .Case([](fir::PointerType pointer) {
350           return fir::hasAbstractResult(
351               mlir::cast<mlir::FunctionType>(pointer.getEleTy()));
352         })
353         .Default([](auto &&) { return false; });
354   }
355 
356   void runOnSpecificOperation(fir::GlobalOp global, bool,
357                               mlir::RewritePatternSet &,
358                               mlir::ConversionTarget &) {
359     if (containsFunctionTypeWithAbstractResult(global.getType())) {
360       TODO(global->getLoc(), "support for procedure pointers");
361     }
362   }
363 
364   /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
365   void runOnModule() {
366     mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
367 
368     auto pass = std::make_unique<AbstractResultOpt>();
369     pass->copyOptionValuesFrom(this);
370     mlir::OpPassManager pipeline;
371     pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});
372 
373     // Run the pass on all operations directly nested inside of the ModuleOp
374     // we can't just call runOnSpecificOperation here because the pass
375     // implementation only works when scoped to a particular func.func or
376     // fir.global
377     for (mlir::Region &region : mod->getRegions()) {
378       for (mlir::Block &block : region.getBlocks()) {
379         for (mlir::Operation &op : block.getOperations()) {
380           if (mlir::failed(runPipeline(pipeline, &op))) {
381             mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
382             signalPassFailure();
383             return;
384           }
385         }
386       }
387     }
388   }
389 
390   void runOnOperation() override {
391     auto *context = &this->getContext();
392     mlir::Operation *op = this->getOperation();
393     if (mlir::isa<mlir::ModuleOp>(op)) {
394       runOnModule();
395       return;
396     }
397 
398     mlir::RewritePatternSet patterns(context);
399     mlir::ConversionTarget target = *context;
400     const bool shouldBoxResult = this->passResultAsBox.getValue();
401 
402     mlir::TypeSwitch<mlir::Operation *, void>(op)
403         .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
404           runOnSpecificOperation(op, shouldBoxResult, patterns, target);
405         });
406 
407     // Convert the calls and, if needed,  the ReturnOp in the function body.
408     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
409                            mlir::func::FuncDialect>();
410     target.addIllegalOp<fir::SaveResultOp>();
411     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
412       return !hasAbstractResult(call.getFunctionType());
413     });
414     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
415       if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType()))
416         return !hasAbstractResult(funTy);
417       return true;
418     });
419     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
420       return !hasAbstractResult(dispatch.getFunctionType());
421     });
422 
423     patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
424     patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
425     patterns.insert<SaveResultOpConversion>(context);
426     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
427     if (mlir::failed(
428             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
429       mlir::emitError(op->getLoc(), "error in converting abstract results\n");
430       this->signalPassFailure();
431     }
432   }
433 };
434 
435 } // end anonymous namespace
436 } // namespace fir
437