xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision 1ead51a86c6c746a1b9948ca1ee142df223ffebd)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Pass/PassManager.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace fir {
24 #define GEN_PASS_DEF_ABSTRACTRESULTOPT
25 #include "flang/Optimizer/Transforms/Passes.h.inc"
26 } // namespace fir
27 
28 #define DEBUG_TYPE "flang-abstract-result-opt"
29 
30 using namespace mlir;
31 
32 namespace fir {
33 namespace {
34 
35 static mlir::Type getResultArgumentType(mlir::Type resultType,
36                                         bool shouldBoxResult) {
37   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
38       .Case<fir::SequenceType, fir::RecordType>(
39           [&](mlir::Type type) -> mlir::Type {
40             if (shouldBoxResult)
41               return fir::BoxType::get(type);
42             return fir::ReferenceType::get(type);
43           })
44       .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
45         return fir::ReferenceType::get(type);
46       })
47       .Default([](mlir::Type) -> mlir::Type {
48         llvm_unreachable("bad abstract result type");
49       });
50 }
51 
52 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
53                                              bool shouldBoxResult) {
54   auto resultType = funcTy.getResult(0);
55   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
56   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
57   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
58   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
59                                  /*resultTypes=*/{});
60 }
61 
62 static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
63   return fir::ReferenceType::get(mlir::NoneType::get(context));
64 }
65 
66 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
67 /// Follow the ABI for interoperability with C.
68 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
69   assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
70   llvm::SmallVector<mlir::Type> outputTypes{
71       getVoidPtrType(funcTy.getContext())};
72   return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
73                                  outputTypes);
74 }
75 
76 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
77   return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) &&
78          shouldBoxResult;
79 }
80 
81 template <typename Op>
82 class CallConversion : public mlir::OpRewritePattern<Op> {
83 public:
84   using mlir::OpRewritePattern<Op>::OpRewritePattern;
85 
86   CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
87       : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
88 
89   llvm::LogicalResult
90   matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
91     auto loc = op.getLoc();
92     auto result = op->getResult(0);
93     if (!result.hasOneUse()) {
94       mlir::emitError(loc,
95                       "calls with abstract result must have exactly one user");
96       return mlir::failure();
97     }
98     auto saveResult =
99         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
100     if (!saveResult) {
101       mlir::emitError(
102           loc, "calls with abstract result must be used in fir.save_result");
103       return mlir::failure();
104     }
105     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
106     auto buffer = saveResult.getMemref();
107     mlir::Value arg = buffer;
108     if (mustEmboxResult(result.getType(), shouldBoxResult))
109       arg = rewriter.create<fir::EmboxOp>(
110           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
111           saveResult.getTypeparams());
112 
113     llvm::SmallVector<mlir::Type> newResultTypes;
114     bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
115     if (isResultBuiltinCPtr)
116       newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
117 
118     Op newOp;
119     // fir::CallOp specific handling.
120     if constexpr (std::is_same_v<Op, fir::CallOp>) {
121       if (op.getCallee()) {
122         llvm::SmallVector<mlir::Value> newOperands;
123         if (!isResultBuiltinCPtr)
124           newOperands.emplace_back(arg);
125         newOperands.append(op.getOperands().begin(), op.getOperands().end());
126         newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
127                                              newResultTypes, newOperands);
128       } else {
129         // Indirect calls.
130         llvm::SmallVector<mlir::Type> newInputTypes;
131         if (!isResultBuiltinCPtr)
132           newInputTypes.emplace_back(argType);
133         for (auto operand : op.getOperands().drop_front())
134           newInputTypes.push_back(operand.getType());
135         auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
136                                                  newResultTypes);
137 
138         llvm::SmallVector<mlir::Value> newOperands;
139         newOperands.push_back(
140             rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
141         if (!isResultBuiltinCPtr)
142           newOperands.push_back(arg);
143         newOperands.append(op.getOperands().begin() + 1,
144                            op.getOperands().end());
145         newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
146                                              newResultTypes, newOperands);
147       }
148     }
149 
150     // fir::DispatchOp specific handling.
151     if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
152       llvm::SmallVector<mlir::Value> newOperands;
153       if (!isResultBuiltinCPtr)
154         newOperands.emplace_back(arg);
155       unsigned passArgShift = newOperands.size();
156       newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
157 
158       fir::DispatchOp newDispatchOp;
159       if (op.getPassArgPos())
160         newOp = rewriter.create<fir::DispatchOp>(
161             loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
162             op.getOperands()[0], newOperands,
163             rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
164       else
165         newOp = rewriter.create<fir::DispatchOp>(
166             loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
167             op.getOperands()[0], newOperands, nullptr);
168     }
169 
170     if (isResultBuiltinCPtr) {
171       mlir::Value save = saveResult.getMemref();
172       auto module = op->template getParentOfType<mlir::ModuleOp>();
173       FirOpBuilder builder(rewriter, module);
174       mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
175           builder, loc, save, result.getType());
176       builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
177     }
178     op->dropAllReferences();
179     rewriter.eraseOp(op);
180     return mlir::success();
181   }
182 
183 private:
184   bool shouldBoxResult;
185 };
186 
187 class SaveResultOpConversion
188     : public mlir::OpRewritePattern<fir::SaveResultOp> {
189 public:
190   using OpRewritePattern::OpRewritePattern;
191   SaveResultOpConversion(mlir::MLIRContext *context)
192       : OpRewritePattern(context) {}
193   llvm::LogicalResult
194   matchAndRewrite(fir::SaveResultOp op,
195                   mlir::PatternRewriter &rewriter) const override {
196     rewriter.eraseOp(op);
197     return mlir::success();
198   }
199 };
200 
201 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
202 public:
203   using OpRewritePattern::OpRewritePattern;
204   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
205       : OpRewritePattern(context), newArg{newArg} {}
206   llvm::LogicalResult
207   matchAndRewrite(mlir::func::ReturnOp ret,
208                   mlir::PatternRewriter &rewriter) const override {
209     auto loc = ret.getLoc();
210     rewriter.setInsertionPoint(ret);
211     mlir::Value resultValue = ret.getOperand(0);
212     fir::LoadOp resultLoad;
213     mlir::Value resultStorage;
214     // Identify result local storage.
215     if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
216       resultLoad = load;
217       resultStorage = load.getMemref();
218       // The result alloca may be behind a fir.declare, if any.
219       if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
220         resultStorage = declare.getMemref();
221     }
222     // Replace old local storage with new storage argument, unless
223     // the derived type is C_PTR/C_FUN_PTR, in which case the return
224     // type is updated to return void* (no new argument is passed).
225     if (fir::isa_builtin_cptr_type(resultValue.getType())) {
226       auto module = ret->getParentOfType<mlir::ModuleOp>();
227       FirOpBuilder builder(rewriter, module);
228       mlir::Value cptr = resultValue;
229       if (resultLoad) {
230         // Replace whole derived type load by component load.
231         cptr = resultLoad.getMemref();
232         rewriter.setInsertionPoint(resultLoad);
233       }
234       mlir::Value newResultValue =
235           fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
236       newResultValue = builder.createConvert(
237           loc, getVoidPtrType(ret.getContext()), newResultValue);
238       rewriter.setInsertionPoint(ret);
239       rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
240           ret, mlir::ValueRange{newResultValue});
241     } else if (resultStorage) {
242       resultStorage.replaceAllUsesWith(newArg);
243       rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
244     } else {
245       // The result storage may have been optimized out by a memory to
246       // register pass, this is possible for fir.box results, or fir.record
247       // with no length parameters. Simply store the result in the result
248       // storage. at the return point.
249       rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
250       rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
251     }
252     // Delete result old local storage if unused.
253     if (resultStorage)
254       if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
255         if (alloc->use_empty())
256           rewriter.eraseOp(alloc);
257     return mlir::success();
258   }
259 
260 private:
261   mlir::Value newArg;
262 };
263 
264 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
265 public:
266   using OpRewritePattern::OpRewritePattern;
267   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
268       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
269   llvm::LogicalResult
270   matchAndRewrite(fir::AddrOfOp addrOf,
271                   mlir::PatternRewriter &rewriter) const override {
272     auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
273     mlir::FunctionType newFuncTy;
274     if (oldFuncTy.getNumResults() != 0 &&
275         fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
276       newFuncTy = getCPtrFunctionType(oldFuncTy);
277     else
278       newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
279     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
280                                                     addrOf.getSymbol());
281     // Rather than converting all op a function pointer might transit through
282     // (e.g calls, stores, loads, converts...), cast new type to the abstract
283     // type. A conversion will be added when calling indirect calls of abstract
284     // types.
285     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
286     return mlir::success();
287   }
288 
289 private:
290   bool shouldBoxResult;
291 };
292 
293 class AbstractResultOpt
294     : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
295 public:
296   using fir::impl::AbstractResultOptBase<
297       AbstractResultOpt>::AbstractResultOptBase;
298 
299   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
300                               mlir::RewritePatternSet &patterns,
301                               mlir::ConversionTarget &target) {
302     auto loc = func.getLoc();
303     auto *context = &getContext();
304     // Convert function type itself if it has an abstract result.
305     auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
306     if (hasAbstractResult(funcTy)) {
307       if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
308         func.setType(getCPtrFunctionType(funcTy));
309         patterns.insert<ReturnOpConversion>(context, mlir::Value{});
310         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
311             [](mlir::func::ReturnOp ret) {
312               mlir::Type retTy = ret.getOperand(0).getType();
313               return !fir::isa_builtin_cptr_type(retTy);
314             });
315         return;
316       }
317       if (!func.empty()) {
318         // Insert new argument.
319         mlir::OpBuilder rewriter(context);
320         auto resultType = funcTy.getResult(0);
321         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
322         func.insertArgument(0u, argTy, {}, loc);
323         func.eraseResult(0u);
324         mlir::Value newArg = func.getArgument(0u);
325         if (mustEmboxResult(resultType, shouldBoxResult)) {
326           auto bufferType = fir::ReferenceType::get(resultType);
327           rewriter.setInsertionPointToStart(&func.front());
328           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
329         }
330         patterns.insert<ReturnOpConversion>(context, newArg);
331         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
332             [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
333         assert(func.getFunctionType() ==
334                getNewFunctionType(funcTy, shouldBoxResult));
335       } else {
336         llvm::SmallVector<mlir::DictionaryAttr> allArgs;
337         func.getAllArgAttrs(allArgs);
338         allArgs.insert(allArgs.begin(),
339                        mlir::DictionaryAttr::get(func->getContext()));
340         func.setType(getNewFunctionType(funcTy, shouldBoxResult));
341         func.setAllArgAttrs(allArgs);
342       }
343     }
344   }
345 
346   inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
347     return mlir::TypeSwitch<mlir::Type, bool>(type)
348         .Case([](fir::BoxProcType boxProc) {
349           return fir::hasAbstractResult(
350               mlir::cast<mlir::FunctionType>(boxProc.getEleTy()));
351         })
352         .Case([](fir::PointerType pointer) {
353           return fir::hasAbstractResult(
354               mlir::cast<mlir::FunctionType>(pointer.getEleTy()));
355         })
356         .Default([](auto &&) { return false; });
357   }
358 
359   void runOnSpecificOperation(fir::GlobalOp global, bool,
360                               mlir::RewritePatternSet &,
361                               mlir::ConversionTarget &) {
362     if (containsFunctionTypeWithAbstractResult(global.getType())) {
363       TODO(global->getLoc(), "support for procedure pointers");
364     }
365   }
366 
367   /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
368   void runOnModule() {
369     mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
370 
371     auto pass = std::make_unique<AbstractResultOpt>();
372     pass->copyOptionValuesFrom(this);
373     mlir::OpPassManager pipeline;
374     pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});
375 
376     // Run the pass on all operations directly nested inside of the ModuleOp
377     // we can't just call runOnSpecificOperation here because the pass
378     // implementation only works when scoped to a particular func.func or
379     // fir.global
380     for (mlir::Region &region : mod->getRegions()) {
381       for (mlir::Block &block : region.getBlocks()) {
382         for (mlir::Operation &op : block.getOperations()) {
383           if (mlir::failed(runPipeline(pipeline, &op))) {
384             mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
385             signalPassFailure();
386             return;
387           }
388         }
389       }
390     }
391   }
392 
393   void runOnOperation() override {
394     auto *context = &this->getContext();
395     mlir::Operation *op = this->getOperation();
396     if (mlir::isa<mlir::ModuleOp>(op)) {
397       runOnModule();
398       return;
399     }
400 
401     mlir::RewritePatternSet patterns(context);
402     mlir::ConversionTarget target = *context;
403     const bool shouldBoxResult = this->passResultAsBox.getValue();
404 
405     mlir::TypeSwitch<mlir::Operation *, void>(op)
406         .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
407           runOnSpecificOperation(op, shouldBoxResult, patterns, target);
408         });
409 
410     // Convert the calls and, if needed,  the ReturnOp in the function body.
411     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
412                            mlir::func::FuncDialect>();
413     target.addIllegalOp<fir::SaveResultOp>();
414     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
415       return !hasAbstractResult(call.getFunctionType());
416     });
417     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
418       if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType()))
419         return !hasAbstractResult(funTy);
420       return true;
421     });
422     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
423       return !hasAbstractResult(dispatch.getFunctionType());
424     });
425 
426     patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
427     patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
428     patterns.insert<SaveResultOpConversion>(context);
429     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
430     if (mlir::failed(
431             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
432       mlir::emitError(op->getLoc(), "error in converting abstract results\n");
433       this->signalPassFailure();
434     }
435   }
436 };
437 
438 } // end anonymous namespace
439 } // namespace fir
440