xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision 75623bfe1b89fa84cf2b9e4fb4c9f7560e01d4a6)
1b0eef1eeSValentin Clement //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2b0eef1eeSValentin Clement //
3b0eef1eeSValentin Clement // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b0eef1eeSValentin Clement // See https://llvm.org/LICENSE.txt for license information.
5b0eef1eeSValentin Clement // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b0eef1eeSValentin Clement //
7b0eef1eeSValentin Clement //===----------------------------------------------------------------------===//
8b0eef1eeSValentin Clement 
9c336e72cSPeixin-Qiao #include "flang/Optimizer/Builder/FIRBuilder.h"
109de831aaSJean Perier #include "flang/Optimizer/Builder/Todo.h"
11b0eef1eeSValentin Clement #include "flang/Optimizer/Dialect/FIRDialect.h"
12b0eef1eeSValentin Clement #include "flang/Optimizer/Dialect/FIROps.h"
13b0eef1eeSValentin Clement #include "flang/Optimizer/Dialect/FIRType.h"
14b07ef9e7SRenaud-K #include "flang/Optimizer/Dialect/Support/FIRContext.h"
15b0eef1eeSValentin Clement #include "flang/Optimizer/Transforms/Passes.h"
1623aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
175522d246SValentin Clement (バレンタイン クレメン) #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18b0eef1eeSValentin Clement #include "mlir/IR/Diagnostics.h"
19b0eef1eeSValentin Clement #include "mlir/Pass/Pass.h"
20bfd19445STom Eccles #include "mlir/Pass/PassManager.h"
21b0eef1eeSValentin Clement #include "mlir/Transforms/DialectConversion.h"
22b0eef1eeSValentin Clement #include "llvm/ADT/TypeSwitch.h"
23b0eef1eeSValentin Clement 
2467d0d7acSMichele Scuttari namespace fir {
25bfd19445STom Eccles #define GEN_PASS_DEF_ABSTRACTRESULTOPT
2667d0d7acSMichele Scuttari #include "flang/Optimizer/Transforms/Passes.h.inc"
2767d0d7acSMichele Scuttari } // namespace fir
2867d0d7acSMichele Scuttari 
29b0eef1eeSValentin Clement #define DEBUG_TYPE "flang-abstract-result-opt"
30b0eef1eeSValentin Clement 
31afb34cf3SValentin Clement using namespace mlir;
32afb34cf3SValentin Clement 
33b0eef1eeSValentin Clement namespace fir {
34b0eef1eeSValentin Clement namespace {
35b0eef1eeSValentin Clement 
36367c3c96SjeanPerier // Helper to only build the symbol table if needed because its build time is
37367c3c96SjeanPerier // linear on the number of symbols in the module.
38367c3c96SjeanPerier struct LazySymbolTable {
39367c3c96SjeanPerier   LazySymbolTable(mlir::Operation *op)
40367c3c96SjeanPerier       : module{op->getParentOfType<mlir::ModuleOp>()} {}
41367c3c96SjeanPerier   void build() {
42367c3c96SjeanPerier     if (table)
43367c3c96SjeanPerier       return;
44367c3c96SjeanPerier     table = std::make_unique<mlir::SymbolTable>(module);
45367c3c96SjeanPerier   }
46367c3c96SjeanPerier 
47367c3c96SjeanPerier   template <typename T>
48367c3c96SjeanPerier   T lookup(llvm::StringRef name) {
49367c3c96SjeanPerier     build();
50367c3c96SjeanPerier     return table->lookup<T>(name);
51367c3c96SjeanPerier   }
52367c3c96SjeanPerier 
53367c3c96SjeanPerier private:
54367c3c96SjeanPerier   std::unique_ptr<mlir::SymbolTable> table;
55367c3c96SjeanPerier   mlir::ModuleOp module;
56367c3c96SjeanPerier };
57367c3c96SjeanPerier 
58367c3c96SjeanPerier bool hasScalarDerivedResult(mlir::FunctionType funTy) {
59367c3c96SjeanPerier   // C_PTR/C_FUNPTR are results to void* in this pass, do not consider
60367c3c96SjeanPerier   // them as normal derived types.
61367c3c96SjeanPerier   return funTy.getNumResults() == 1 &&
62367c3c96SjeanPerier          mlir::isa<fir::RecordType>(funTy.getResult(0)) &&
63367c3c96SjeanPerier          !fir::isa_builtin_cptr_type(funTy.getResult(0));
64367c3c96SjeanPerier }
65367c3c96SjeanPerier 
66b0eef1eeSValentin Clement static mlir::Type getResultArgumentType(mlir::Type resultType,
67ea1cdb58SDaniil Dudkin                                         bool shouldBoxResult) {
68b0eef1eeSValentin Clement   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
69b0eef1eeSValentin Clement       .Case<fir::SequenceType, fir::RecordType>(
70b0eef1eeSValentin Clement           [&](mlir::Type type) -> mlir::Type {
71ea1cdb58SDaniil Dudkin             if (shouldBoxResult)
72b0eef1eeSValentin Clement               return fir::BoxType::get(type);
73b0eef1eeSValentin Clement             return fir::ReferenceType::get(type);
74b0eef1eeSValentin Clement           })
75afb34cf3SValentin Clement       .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
76b0eef1eeSValentin Clement         return fir::ReferenceType::get(type);
77b0eef1eeSValentin Clement       })
78b0eef1eeSValentin Clement       .Default([](mlir::Type) -> mlir::Type {
79b0eef1eeSValentin Clement         llvm_unreachable("bad abstract result type");
80b0eef1eeSValentin Clement       });
81b0eef1eeSValentin Clement }
82b0eef1eeSValentin Clement 
83ea1cdb58SDaniil Dudkin static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
84ea1cdb58SDaniil Dudkin                                              bool shouldBoxResult) {
85b0eef1eeSValentin Clement   auto resultType = funcTy.getResult(0);
86ea1cdb58SDaniil Dudkin   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
87b0eef1eeSValentin Clement   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
88b0eef1eeSValentin Clement   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
89b0eef1eeSValentin Clement   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
90b0eef1eeSValentin Clement                                  /*resultTypes=*/{});
91b0eef1eeSValentin Clement }
92b0eef1eeSValentin Clement 
931ead51a8SjeanPerier static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
941ead51a8SjeanPerier   return fir::ReferenceType::get(mlir::NoneType::get(context));
951ead51a8SjeanPerier }
961ead51a8SjeanPerier 
97c336e72cSPeixin-Qiao /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
98c336e72cSPeixin-Qiao /// Follow the ABI for interoperability with C.
99c336e72cSPeixin-Qiao static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
1001ead51a8SjeanPerier   assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
1011ead51a8SjeanPerier   llvm::SmallVector<mlir::Type> outputTypes{
1021ead51a8SjeanPerier       getVoidPtrType(funcTy.getContext())};
103c336e72cSPeixin-Qiao   return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
104c336e72cSPeixin-Qiao                                  outputTypes);
105c336e72cSPeixin-Qiao }
106c336e72cSPeixin-Qiao 
107ea1cdb58SDaniil Dudkin static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
108fac349a1SChristian Sigg   return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) &&
109ea1cdb58SDaniil Dudkin          shouldBoxResult;
110b0eef1eeSValentin Clement }
111b0eef1eeSValentin Clement 
/// Rewrite a call-like operation whose result is abstract. `Op` is
/// fir::CallOp or fir::DispatchOp (see the `if constexpr` branches below).
///
/// The call result must have exactly one user, a fir.save_result; otherwise an
/// error is emitted and the match fails. The save_result buffer (wrapped in a
/// fir.embox when mustEmboxResult says so) is passed as a new leading call
/// argument, and the new call has no result. For C_PTR/C_FUNPTR results, no
/// argument is added: the new call returns a `!fir.ref<none>` value, which is
/// stored into the buffer at the address given by
/// fir::factory::genCPtrOrCFunptrAddr.
/// The original call is erased. The fir.save_result is not touched here; it
/// is handled by SaveResultOpConversion.
template <typename Op>
class CallConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  /// \p shouldBoxResult selects whether array/derived results are passed as a
  /// fir.box (true) or as a fir.ref (false). The pattern benefit is 1.
  CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
      : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}

  llvm::LogicalResult
  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto result = op->getResult(0);
    // The call result must flow into exactly one fir.save_result, whose
    // buffer becomes the storage for the result.
    if (!result.hasOneUse()) {
      mlir::emitError(loc,
                      "calls with abstract result must have exactly one user");
      return mlir::failure();
    }
    auto saveResult =
        mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
    if (!saveResult) {
      mlir::emitError(
          loc, "calls with abstract result must be used in fir.save_result");
      return mlir::failure();
    }
    auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
    auto buffer = saveResult.getMemref();
    mlir::Value arg = buffer;
    // When results are passed as boxes, describe the buffer with a fir.box
    // using the save_result shape and type parameters.
    if (mustEmboxResult(result.getType(), shouldBoxResult))
      arg = rewriter.create<fir::EmboxOp>(
          loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
          saveResult.getTypeparams());

    // The new call has no result, except for C_PTR/C_FUNPTR which is returned
    // as a void pointer instead of being passed as an argument.
    llvm::SmallVector<mlir::Type> newResultTypes;
    bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
    if (isResultBuiltinCPtr)
      newResultTypes.emplace_back(getVoidPtrType(result.getContext()));

    Op newOp;
    // fir::CallOp specific handling.
    // NOTE(review): the fir.call rebuilt below is created only from the
    // callee, result types and operands, so other attributes of `op` are not
    // carried over, whereas the fir.dispatch path forwards
    // op.getProcedureAttrsAttr() — confirm this asymmetry is intended.
    if constexpr (std::is_same_v<Op, fir::CallOp>) {
      if (op.getCallee()) {
        // Direct call: prepend the result argument to the call operands.
        llvm::SmallVector<mlir::Value> newOperands;
        if (!isResultBuiltinCPtr)
          newOperands.emplace_back(arg);
        newOperands.append(op.getOperands().begin(), op.getOperands().end());
        newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
                                             newResultTypes, newOperands);
      } else {
        // Indirect calls.
        // Operand 0 is the callee value: build the converted function type
        // from the remaining operands and cast the callee to it.
        llvm::SmallVector<mlir::Type> newInputTypes;
        if (!isResultBuiltinCPtr)
          newInputTypes.emplace_back(argType);
        for (auto operand : op.getOperands().drop_front())
          newInputTypes.push_back(operand.getType());
        auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
                                                 newResultTypes);

        llvm::SmallVector<mlir::Value> newOperands;
        newOperands.push_back(
            rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
        if (!isResultBuiltinCPtr)
          newOperands.push_back(arg);
        newOperands.append(op.getOperands().begin() + 1,
                           op.getOperands().end());
        newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
                                             newResultTypes, newOperands);
      }
    }

    // fir::DispatchOp specific handling.
    if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
      // Operand 0 (the dispatch object) is passed separately to the new
      // fir.dispatch; the remaining operands follow the result argument.
      llvm::SmallVector<mlir::Value> newOperands;
      if (!isResultBuiltinCPtr)
        newOperands.emplace_back(arg);
      unsigned passArgShift = newOperands.size();
      newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
      // The passed-object position moves by the number of prepended operands.
      mlir::IntegerAttr passArgPos;
      if (op.getPassArgPos())
        passArgPos =
            rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
      newOp = rewriter.create<fir::DispatchOp>(
          loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
          op.getOperands()[0], newOperands, passArgPos,
          op.getProcedureAttrsAttr());
    }

    // C_PTR/C_FUNPTR: store the returned void pointer into the save_result
    // buffer at the address computed by genCPtrOrCFunptrAddr.
    if (isResultBuiltinCPtr) {
      mlir::Value save = saveResult.getMemref();
      auto module = op->template getParentOfType<mlir::ModuleOp>();
      FirOpBuilder builder(rewriter, module);
      mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
          builder, loc, save, result.getType());
      builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
    }
    // Drop the old call's references (directly, not through the rewriter),
    // then erase it.
    op->dropAllReferences();
    rewriter.eraseOp(op);
    return mlir::success();
  }

private:
  bool shouldBoxResult;
};
214b0eef1eeSValentin Clement 
215b0eef1eeSValentin Clement class SaveResultOpConversion
216b0eef1eeSValentin Clement     : public mlir::OpRewritePattern<fir::SaveResultOp> {
217b0eef1eeSValentin Clement public:
218b0eef1eeSValentin Clement   using OpRewritePattern::OpRewritePattern;
219b0eef1eeSValentin Clement   SaveResultOpConversion(mlir::MLIRContext *context)
220b0eef1eeSValentin Clement       : OpRewritePattern(context) {}
221db791b27SRamkumar Ramachandra   llvm::LogicalResult
222b0eef1eeSValentin Clement   matchAndRewrite(fir::SaveResultOp op,
223b0eef1eeSValentin Clement                   mlir::PatternRewriter &rewriter) const override {
224367c3c96SjeanPerier     mlir::Operation *call = op.getValue().getDefiningOp();
225367c3c96SjeanPerier     mlir::Type type = op.getValue().getType();
226367c3c96SjeanPerier     if (mlir::isa<fir::RecordType>(type) && call && fir::hasBindcAttr(call) &&
227367c3c96SjeanPerier         !fir::isa_builtin_cptr_type(type)) {
228367c3c96SjeanPerier       rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValue(),
229367c3c96SjeanPerier                                                 op.getMemref());
230367c3c96SjeanPerier     } else {
231b0eef1eeSValentin Clement       rewriter.eraseOp(op);
232367c3c96SjeanPerier     }
233b0eef1eeSValentin Clement     return mlir::success();
234b0eef1eeSValentin Clement   }
235b0eef1eeSValentin Clement };
236b0eef1eeSValentin Clement 
/// Rewrite a return-like operation (in this file, func.return or gpu.return)
/// of a function whose abstract result has been converted.
///
/// Operand 0 of \p ret is the returned value. There are three cases.
/// - C_PTR/C_FUNPTR result: the return is rebuilt to return the value
///   produced by fir::factory::genCPtrOrCFunptrValue, converted to
///   `!fir.ref<none>`. \p newArg is not used.
/// - The returned value is a fir.load from some storage (possibly behind a
///   fir.declare): every use of that storage is redirected to \p newArg, and
///   the return is rebuilt without operands.
/// - Otherwise: the returned value is stored into \p newArg, and the return is
///   rebuilt without operands.
/// Afterwards, if the old storage is a fir.alloca that is now unused, it is
/// erased.
///
/// \p newArg is the value standing for the caller-provided result storage.
/// Always returns success.
template <typename OpTy>
static mlir::LogicalResult
processReturnLikeOp(OpTy ret, mlir::Value newArg,
                    mlir::PatternRewriter &rewriter) {
  auto loc = ret.getLoc();
  rewriter.setInsertionPoint(ret);
  mlir::Value resultValue = ret.getOperand(0);
  fir::LoadOp resultLoad;
  mlir::Value resultStorage;
  // Identify result local storage.
  if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
    resultLoad = load;
    resultStorage = load.getMemref();
    // The result alloca may be behind a fir.declare, if any.
    if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
      resultStorage = declare.getMemref();
  }
  // Replace old local storage with new storage argument, unless
  // the derived type is C_PTR/C_FUNPTR, in which case the return
  // type is updated to return void* (no new argument is passed).
  if (fir::isa_builtin_cptr_type(resultValue.getType())) {
    auto module = ret->template getParentOfType<mlir::ModuleOp>();
    // NOTE(review): `builder` is constructed from `rewriter` while the
    // rewriter insertion point is still before `ret`. If FirOpBuilder keeps
    // its own copy of the insertion point (as a copied mlir::OpBuilder would),
    // the rewriter.setInsertionPoint(resultLoad) below does not change where
    // `builder` creates the component load and convert. Confirm whether they
    // are meant to be emitted before `resultLoad` or before `ret`.
    FirOpBuilder builder(rewriter, module);
    mlir::Value cptr = resultValue;
    if (resultLoad) {
      // Replace whole derived type load by component load: read from the
      // memory the whole C_PTR value was loaded from.
      cptr = resultLoad.getMemref();
      rewriter.setInsertionPoint(resultLoad);
    }
    mlir::Value newResultValue =
        fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
    newResultValue = builder.createConvert(
        loc, getVoidPtrType(ret.getContext()), newResultValue);
    // Build the new return where the old one was.
    rewriter.setInsertionPoint(ret);
    rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue});
  } else if (resultStorage) {
    // NOTE(review): this redirects every use of `resultStorage`, not only the
    // returned load, and does so directly rather than through the rewriter.
    // It assumes `resultStorage` is the function's own result storage; confirm.
    resultStorage.replaceAllUsesWith(newArg);
    rewriter.replaceOpWithNewOp<OpTy>(ret);
  } else {
    // The result storage may have been optimized out by a memory to
    // register pass, this is possible for fir.box results, or fir.record
    // with no length parameters. Simply store the result in the result
    // storage at the return point.
    rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
    rewriter.replaceOpWithNewOp<OpTy>(ret);
  }
  // Delete result old local storage if unused.
  if (resultStorage)
    if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
      if (alloc->use_empty())
        rewriter.eraseOp(alloc);
  return mlir::success();
}
290b0eef1eeSValentin Clement 
291*75623bfeSValentin Clement (バレンタイン クレメン) class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
292*75623bfeSValentin Clement (バレンタイン クレメン) public:
293*75623bfeSValentin Clement (バレンタイン クレメン)   using OpRewritePattern::OpRewritePattern;
294*75623bfeSValentin Clement (バレンタイン クレメン)   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
295*75623bfeSValentin Clement (バレンタイン クレメン)       : OpRewritePattern(context), newArg{newArg} {}
296*75623bfeSValentin Clement (バレンタイン クレメン)   llvm::LogicalResult
297*75623bfeSValentin Clement (バレンタイン クレメン)   matchAndRewrite(mlir::func::ReturnOp ret,
298*75623bfeSValentin Clement (バレンタイン クレメン)                   mlir::PatternRewriter &rewriter) const override {
299*75623bfeSValentin Clement (バレンタイン クレメン)     return processReturnLikeOp(ret, newArg, rewriter);
300*75623bfeSValentin Clement (バレンタイン クレメン)   }
301*75623bfeSValentin Clement (バレンタイン クレメン) 
302*75623bfeSValentin Clement (バレンタイン クレメン) private:
303*75623bfeSValentin Clement (バレンタイン クレメン)   mlir::Value newArg;
304*75623bfeSValentin Clement (バレンタイン クレメン) };
305*75623bfeSValentin Clement (バレンタイン クレメン) 
306*75623bfeSValentin Clement (バレンタイン クレメン) class GPUReturnOpConversion
307*75623bfeSValentin Clement (バレンタイン クレメン)     : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> {
308*75623bfeSValentin Clement (バレンタイン クレメン) public:
309*75623bfeSValentin Clement (バレンタイン クレメン)   using OpRewritePattern::OpRewritePattern;
310*75623bfeSValentin Clement (バレンタイン クレメン)   GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
311*75623bfeSValentin Clement (バレンタイン クレメン)       : OpRewritePattern(context), newArg{newArg} {}
312*75623bfeSValentin Clement (バレンタイン クレメン)   llvm::LogicalResult
313*75623bfeSValentin Clement (バレンタイン クレメン)   matchAndRewrite(mlir::gpu::ReturnOp ret,
314*75623bfeSValentin Clement (バレンタイン クレメン)                   mlir::PatternRewriter &rewriter) const override {
315*75623bfeSValentin Clement (バレンタイン クレメン)     return processReturnLikeOp(ret, newArg, rewriter);
316*75623bfeSValentin Clement (バレンタイン クレメン)   }
317*75623bfeSValentin Clement (バレンタイン クレメン) 
318b0eef1eeSValentin Clement private:
319ea1cdb58SDaniil Dudkin   mlir::Value newArg;
320b0eef1eeSValentin Clement };
321b0eef1eeSValentin Clement 
322b0eef1eeSValentin Clement class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
323b0eef1eeSValentin Clement public:
324b0eef1eeSValentin Clement   using OpRewritePattern::OpRewritePattern;
325ea1cdb58SDaniil Dudkin   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
326ea1cdb58SDaniil Dudkin       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
327db791b27SRamkumar Ramachandra   llvm::LogicalResult
328b0eef1eeSValentin Clement   matchAndRewrite(fir::AddrOfOp addrOf,
329b0eef1eeSValentin Clement                   mlir::PatternRewriter &rewriter) const override {
330fac349a1SChristian Sigg     auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
331c336e72cSPeixin-Qiao     mlir::FunctionType newFuncTy;
332c336e72cSPeixin-Qiao     if (oldFuncTy.getNumResults() != 0 &&
333c336e72cSPeixin-Qiao         fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
334c336e72cSPeixin-Qiao       newFuncTy = getCPtrFunctionType(oldFuncTy);
335c336e72cSPeixin-Qiao     else
336c336e72cSPeixin-Qiao       newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
337b0eef1eeSValentin Clement     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
338149ad3d5SShraiysh Vaishay                                                     addrOf.getSymbol());
339b0eef1eeSValentin Clement     // Rather than converting all op a function pointer might transit through
340b0eef1eeSValentin Clement     // (e.g calls, stores, loads, converts...), cast new type to the abstract
341b0eef1eeSValentin Clement     // type. A conversion will be added when calling indirect calls of abstract
342b0eef1eeSValentin Clement     // types.
343b0eef1eeSValentin Clement     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
344b0eef1eeSValentin Clement     return mlir::success();
345b0eef1eeSValentin Clement   }
346b0eef1eeSValentin Clement 
347b0eef1eeSValentin Clement private:
348ea1cdb58SDaniil Dudkin   bool shouldBoxResult;
349b0eef1eeSValentin Clement };
350b0eef1eeSValentin Clement 
351bfd19445STom Eccles class AbstractResultOpt
352bfd19445STom Eccles     : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
353b0eef1eeSValentin Clement public:
354bfd19445STom Eccles   using fir::impl::AbstractResultOptBase<
355bfd19445STom Eccles       AbstractResultOpt>::AbstractResultOptBase;
356cb33e4abSDaniil Dudkin 
3575522d246SValentin Clement (バレンタイン クレメン)   template <typename OpTy>
3585522d246SValentin Clement (バレンタイン クレメン)   void runOnFunctionLikeOperation(OpTy func, bool shouldBoxResult,
359cb33e4abSDaniil Dudkin                                   mlir::RewritePatternSet &patterns,
360cb33e4abSDaniil Dudkin                                   mlir::ConversionTarget &target) {
361cb33e4abSDaniil Dudkin     auto loc = func.getLoc();
362cb33e4abSDaniil Dudkin     auto *context = &getContext();
363cb33e4abSDaniil Dudkin     // Convert function type itself if it has an abstract result.
364fac349a1SChristian Sigg     auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
365367c3c96SjeanPerier     // Scalar derived result of BIND(C) function must be returned according
366367c3c96SjeanPerier     // to the C struct return ABI which is target dependent and implemented in
367367c3c96SjeanPerier     // the target-rewrite pass.
368367c3c96SjeanPerier     if (hasScalarDerivedResult(funcTy) &&
369367c3c96SjeanPerier         fir::hasBindcAttr(func.getOperation()))
370367c3c96SjeanPerier       return;
371cb33e4abSDaniil Dudkin     if (hasAbstractResult(funcTy)) {
372c336e72cSPeixin-Qiao       if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
373c336e72cSPeixin-Qiao         func.setType(getCPtrFunctionType(funcTy));
374c336e72cSPeixin-Qiao         patterns.insert<ReturnOpConversion>(context, mlir::Value{});
375c336e72cSPeixin-Qiao         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
376c336e72cSPeixin-Qiao             [](mlir::func::ReturnOp ret) {
377c336e72cSPeixin-Qiao               mlir::Type retTy = ret.getOperand(0).getType();
378c336e72cSPeixin-Qiao               return !fir::isa_builtin_cptr_type(retTy);
379c336e72cSPeixin-Qiao             });
380c336e72cSPeixin-Qiao         return;
381c336e72cSPeixin-Qiao       }
382cb33e4abSDaniil Dudkin       if (!func.empty()) {
383cb33e4abSDaniil Dudkin         // Insert new argument.
384cb33e4abSDaniil Dudkin         mlir::OpBuilder rewriter(context);
385cb33e4abSDaniil Dudkin         auto resultType = funcTy.getResult(0);
386cb33e4abSDaniil Dudkin         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
3874d5a9c1dSRenaud-K         func.insertArgument(0u, argTy, {}, loc);
3884d5a9c1dSRenaud-K         func.eraseResult(0u);
3894d5a9c1dSRenaud-K         mlir::Value newArg = func.getArgument(0u);
390cb33e4abSDaniil Dudkin         if (mustEmboxResult(resultType, shouldBoxResult)) {
391cb33e4abSDaniil Dudkin           auto bufferType = fir::ReferenceType::get(resultType);
392cb33e4abSDaniil Dudkin           rewriter.setInsertionPointToStart(&func.front());
393cb33e4abSDaniil Dudkin           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
394cb33e4abSDaniil Dudkin         }
395cb33e4abSDaniil Dudkin         patterns.insert<ReturnOpConversion>(context, newArg);
396cb33e4abSDaniil Dudkin         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
397b74192b7SRiver Riddle             [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
398*75623bfeSValentin Clement (バレンタイン クレメン)         patterns.insert<GPUReturnOpConversion>(context, newArg);
399*75623bfeSValentin Clement (バレンタイン クレメン)         target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>(
400*75623bfeSValentin Clement (バレンタイン クレメン)             [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); });
4014d5a9c1dSRenaud-K         assert(func.getFunctionType() ==
4024d5a9c1dSRenaud-K                getNewFunctionType(funcTy, shouldBoxResult));
4034d5a9c1dSRenaud-K       } else {
4045f9e0491SRenaud-K         llvm::SmallVector<mlir::DictionaryAttr> allArgs;
4055f9e0491SRenaud-K         func.getAllArgAttrs(allArgs);
4065f9e0491SRenaud-K         allArgs.insert(allArgs.begin(),
4075f9e0491SRenaud-K                        mlir::DictionaryAttr::get(func->getContext()));
4084d5a9c1dSRenaud-K         func.setType(getNewFunctionType(funcTy, shouldBoxResult));
4095f9e0491SRenaud-K         func.setAllArgAttrs(allArgs);
410cb33e4abSDaniil Dudkin       }
411b0eef1eeSValentin Clement     }
412b0eef1eeSValentin Clement   }
413a6f2f44fSDaniil Dudkin 
4145522d246SValentin Clement (バレンタイン クレメン)   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
4155522d246SValentin Clement (バレンタイン クレメン)                               mlir::RewritePatternSet &patterns,
4165522d246SValentin Clement (バレンタイン クレメン)                               mlir::ConversionTarget &target) {
4175522d246SValentin Clement (バレンタイン クレメン)     runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
4185522d246SValentin Clement (バレンタイン クレメン)   }
4195522d246SValentin Clement (バレンタイン クレメン) 
4205522d246SValentin Clement (バレンタイン クレメン)   void runOnSpecificOperation(mlir::gpu::GPUFuncOp func, bool shouldBoxResult,
4215522d246SValentin Clement (バレンタイン クレメン)                               mlir::RewritePatternSet &patterns,
4225522d246SValentin Clement (バレンタイン クレメン)                               mlir::ConversionTarget &target) {
4235522d246SValentin Clement (バレンタイン クレメン)     runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
4245522d246SValentin Clement (バレンタイン クレメン)   }
4255522d246SValentin Clement (バレンタイン クレメン) 
426a6f2f44fSDaniil Dudkin   inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
427a6f2f44fSDaniil Dudkin     return mlir::TypeSwitch<mlir::Type, bool>(type)
428a6f2f44fSDaniil Dudkin         .Case([](fir::BoxProcType boxProc) {
429a6f2f44fSDaniil Dudkin           return fir::hasAbstractResult(
430fac349a1SChristian Sigg               mlir::cast<mlir::FunctionType>(boxProc.getEleTy()));
431a6f2f44fSDaniil Dudkin         })
432a6f2f44fSDaniil Dudkin         .Case([](fir::PointerType pointer) {
433a6f2f44fSDaniil Dudkin           return fir::hasAbstractResult(
434fac349a1SChristian Sigg               mlir::cast<mlir::FunctionType>(pointer.getEleTy()));
435a6f2f44fSDaniil Dudkin         })
436a6f2f44fSDaniil Dudkin         .Default([](auto &&) { return false; });
437a6f2f44fSDaniil Dudkin   }
438a6f2f44fSDaniil Dudkin 
439a6f2f44fSDaniil Dudkin   void runOnSpecificOperation(fir::GlobalOp global, bool,
440a6f2f44fSDaniil Dudkin                               mlir::RewritePatternSet &,
441a6f2f44fSDaniil Dudkin                               mlir::ConversionTarget &) {
442a6f2f44fSDaniil Dudkin     if (containsFunctionTypeWithAbstractResult(global.getType())) {
443a6f2f44fSDaniil Dudkin       TODO(global->getLoc(), "support for procedure pointers");
444a6f2f44fSDaniil Dudkin     }
445a6f2f44fSDaniil Dudkin   }
446bfd19445STom Eccles 
447bfd19445STom Eccles   /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
448bfd19445STom Eccles   void runOnModule() {
449bfd19445STom Eccles     mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
450bfd19445STom Eccles 
451bfd19445STom Eccles     auto pass = std::make_unique<AbstractResultOpt>();
452bfd19445STom Eccles     pass->copyOptionValuesFrom(this);
453bfd19445STom Eccles     mlir::OpPassManager pipeline;
454bfd19445STom Eccles     pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});
455bfd19445STom Eccles 
456bfd19445STom Eccles     // Run the pass on all operations directly nested inside of the ModuleOp
457bfd19445STom Eccles     // we can't just call runOnSpecificOperation here because the pass
458bfd19445STom Eccles     // implementation only works when scoped to a particular func.func or
459bfd19445STom Eccles     // fir.global
460bfd19445STom Eccles     for (mlir::Region &region : mod->getRegions()) {
461bfd19445STom Eccles       for (mlir::Block &block : region.getBlocks()) {
462bfd19445STom Eccles         for (mlir::Operation &op : block.getOperations()) {
463bfd19445STom Eccles           if (mlir::failed(runPipeline(pipeline, &op))) {
464bfd19445STom Eccles             mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
465bfd19445STom Eccles             signalPassFailure();
466bfd19445STom Eccles             return;
467bfd19445STom Eccles           }
468bfd19445STom Eccles         }
469bfd19445STom Eccles       }
470bfd19445STom Eccles     }
471bfd19445STom Eccles   }
472bfd19445STom Eccles 
473bfd19445STom Eccles   void runOnOperation() override {
474bfd19445STom Eccles     auto *context = &this->getContext();
475bfd19445STom Eccles     mlir::Operation *op = this->getOperation();
476bfd19445STom Eccles     if (mlir::isa<mlir::ModuleOp>(op)) {
477bfd19445STom Eccles       runOnModule();
478bfd19445STom Eccles       return;
479bfd19445STom Eccles     }
480bfd19445STom Eccles 
481367c3c96SjeanPerier     LazySymbolTable symbolTable(op);
482367c3c96SjeanPerier 
483bfd19445STom Eccles     mlir::RewritePatternSet patterns(context);
484bfd19445STom Eccles     mlir::ConversionTarget target = *context;
485bfd19445STom Eccles     const bool shouldBoxResult = this->passResultAsBox.getValue();
486bfd19445STom Eccles 
487bfd19445STom Eccles     mlir::TypeSwitch<mlir::Operation *, void>(op)
4881d4b5c16SValentin Clement (バレンタイン クレメン)         .Case<mlir::func::FuncOp, fir::GlobalOp, mlir::gpu::GPUFuncOp>(
4891d4b5c16SValentin Clement (バレンタイン クレメン)             [&](auto op) {
490bfd19445STom Eccles               runOnSpecificOperation(op, shouldBoxResult, patterns, target);
491bfd19445STom Eccles             });
492bfd19445STom Eccles 
493bfd19445STom Eccles     // Convert the calls and, if needed,  the ReturnOp in the function body.
494bfd19445STom Eccles     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
495bfd19445STom Eccles                            mlir::func::FuncDialect>();
496bfd19445STom Eccles     target.addIllegalOp<fir::SaveResultOp>();
497bfd19445STom Eccles     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
498367c3c96SjeanPerier       mlir::FunctionType funTy = call.getFunctionType();
499367c3c96SjeanPerier       if (hasScalarDerivedResult(funTy) &&
500367c3c96SjeanPerier           fir::hasBindcAttr(call.getOperation()))
501367c3c96SjeanPerier         return true;
502480e7f06SjeanPerier       return !hasAbstractResult(funTy);
503367c3c96SjeanPerier     });
504367c3c96SjeanPerier     target.addDynamicallyLegalOp<fir::AddrOfOp>([&symbolTable](
505367c3c96SjeanPerier                                                     fir::AddrOfOp addrOf) {
506367c3c96SjeanPerier       if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) {
507367c3c96SjeanPerier         if (hasScalarDerivedResult(funTy)) {
508367c3c96SjeanPerier           auto func = symbolTable.lookup<mlir::func::FuncOp>(
509367c3c96SjeanPerier               addrOf.getSymbol().getRootReference().getValue());
510367c3c96SjeanPerier           return func && fir::hasBindcAttr(func.getOperation());
511367c3c96SjeanPerier         }
512367c3c96SjeanPerier         return !hasAbstractResult(funTy);
513367c3c96SjeanPerier       }
514bfd19445STom Eccles       return true;
515bfd19445STom Eccles     });
516bfd19445STom Eccles     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
517367c3c96SjeanPerier       mlir::FunctionType funTy = dispatch.getFunctionType();
518367c3c96SjeanPerier       if (hasScalarDerivedResult(funTy) &&
519367c3c96SjeanPerier           fir::hasBindcAttr(dispatch.getOperation()))
520367c3c96SjeanPerier         return true;
521bfd19445STom Eccles       return !hasAbstractResult(dispatch.getFunctionType());
522bfd19445STom Eccles     });
523bfd19445STom Eccles 
524bfd19445STom Eccles     patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
525bfd19445STom Eccles     patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
526bfd19445STom Eccles     patterns.insert<SaveResultOpConversion>(context);
527bfd19445STom Eccles     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
528bfd19445STom Eccles     if (mlir::failed(
529bfd19445STom Eccles             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
530bfd19445STom Eccles       mlir::emitError(op->getLoc(), "error in converting abstract results\n");
531bfd19445STom Eccles       this->signalPassFailure();
532bfd19445STom Eccles     }
533bfd19445STom Eccles   }
534a6f2f44fSDaniil Dudkin };
535bfd19445STom Eccles 
536b0eef1eeSValentin Clement } // end anonymous namespace
537b0eef1eeSValentin Clement } // namespace fir
538