xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision 367c3c968eb8f29b55fb8019b2464c7ff6307ca8)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Pass/PassManager.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace fir {
24 #define GEN_PASS_DEF_ABSTRACTRESULTOPT
25 #include "flang/Optimizer/Transforms/Passes.h.inc"
26 } // namespace fir
27 
28 #define DEBUG_TYPE "flang-abstract-result-opt"
29 
30 using namespace mlir;
31 
32 namespace fir {
33 namespace {
34 
35 // Helper to only build the symbol table if needed because its build time is
36 // linear on the number of symbols in the module.
37 struct LazySymbolTable {
38   LazySymbolTable(mlir::Operation *op)
39       : module{op->getParentOfType<mlir::ModuleOp>()} {}
40   void build() {
41     if (table)
42       return;
43     table = std::make_unique<mlir::SymbolTable>(module);
44   }
45 
46   template <typename T>
47   T lookup(llvm::StringRef name) {
48     build();
49     return table->lookup<T>(name);
50   }
51 
52 private:
53   std::unique_ptr<mlir::SymbolTable> table;
54   mlir::ModuleOp module;
55 };
56 
57 bool hasScalarDerivedResult(mlir::FunctionType funTy) {
58   // C_PTR/C_FUNPTR are results to void* in this pass, do not consider
59   // them as normal derived types.
60   return funTy.getNumResults() == 1 &&
61          mlir::isa<fir::RecordType>(funTy.getResult(0)) &&
62          !fir::isa_builtin_cptr_type(funTy.getResult(0));
63 }
64 
65 static mlir::Type getResultArgumentType(mlir::Type resultType,
66                                         bool shouldBoxResult) {
67   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
68       .Case<fir::SequenceType, fir::RecordType>(
69           [&](mlir::Type type) -> mlir::Type {
70             if (shouldBoxResult)
71               return fir::BoxType::get(type);
72             return fir::ReferenceType::get(type);
73           })
74       .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
75         return fir::ReferenceType::get(type);
76       })
77       .Default([](mlir::Type) -> mlir::Type {
78         llvm_unreachable("bad abstract result type");
79       });
80 }
81 
82 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
83                                              bool shouldBoxResult) {
84   auto resultType = funcTy.getResult(0);
85   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
86   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
87   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
88   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
89                                  /*resultTypes=*/{});
90 }
91 
92 static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
93   return fir::ReferenceType::get(mlir::NoneType::get(context));
94 }
95 
96 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
97 /// Follow the ABI for interoperability with C.
98 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
99   assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
100   llvm::SmallVector<mlir::Type> outputTypes{
101       getVoidPtrType(funcTy.getContext())};
102   return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
103                                  outputTypes);
104 }
105 
106 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
107   return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) &&
108          shouldBoxResult;
109 }
110 
111 template <typename Op>
112 class CallConversion : public mlir::OpRewritePattern<Op> {
113 public:
114   using mlir::OpRewritePattern<Op>::OpRewritePattern;
115 
116   CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
117       : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
118 
119   llvm::LogicalResult
120   matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
121     auto loc = op.getLoc();
122     auto result = op->getResult(0);
123     if (!result.hasOneUse()) {
124       mlir::emitError(loc,
125                       "calls with abstract result must have exactly one user");
126       return mlir::failure();
127     }
128     auto saveResult =
129         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
130     if (!saveResult) {
131       mlir::emitError(
132           loc, "calls with abstract result must be used in fir.save_result");
133       return mlir::failure();
134     }
135     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
136     auto buffer = saveResult.getMemref();
137     mlir::Value arg = buffer;
138     if (mustEmboxResult(result.getType(), shouldBoxResult))
139       arg = rewriter.create<fir::EmboxOp>(
140           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
141           saveResult.getTypeparams());
142 
143     llvm::SmallVector<mlir::Type> newResultTypes;
144     bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
145     if (isResultBuiltinCPtr)
146       newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
147 
148     Op newOp;
149     // fir::CallOp specific handling.
150     if constexpr (std::is_same_v<Op, fir::CallOp>) {
151       if (op.getCallee()) {
152         llvm::SmallVector<mlir::Value> newOperands;
153         if (!isResultBuiltinCPtr)
154           newOperands.emplace_back(arg);
155         newOperands.append(op.getOperands().begin(), op.getOperands().end());
156         newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
157                                              newResultTypes, newOperands);
158       } else {
159         // Indirect calls.
160         llvm::SmallVector<mlir::Type> newInputTypes;
161         if (!isResultBuiltinCPtr)
162           newInputTypes.emplace_back(argType);
163         for (auto operand : op.getOperands().drop_front())
164           newInputTypes.push_back(operand.getType());
165         auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
166                                                  newResultTypes);
167 
168         llvm::SmallVector<mlir::Value> newOperands;
169         newOperands.push_back(
170             rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
171         if (!isResultBuiltinCPtr)
172           newOperands.push_back(arg);
173         newOperands.append(op.getOperands().begin() + 1,
174                            op.getOperands().end());
175         newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
176                                              newResultTypes, newOperands);
177       }
178     }
179 
180     // fir::DispatchOp specific handling.
181     if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
182       llvm::SmallVector<mlir::Value> newOperands;
183       if (!isResultBuiltinCPtr)
184         newOperands.emplace_back(arg);
185       unsigned passArgShift = newOperands.size();
186       newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
187       mlir::IntegerAttr passArgPos;
188       if (op.getPassArgPos())
189         passArgPos =
190             rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
191       newOp = rewriter.create<fir::DispatchOp>(
192           loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
193           op.getOperands()[0], newOperands, passArgPos,
194           op.getProcedureAttrsAttr());
195     }
196 
197     if (isResultBuiltinCPtr) {
198       mlir::Value save = saveResult.getMemref();
199       auto module = op->template getParentOfType<mlir::ModuleOp>();
200       FirOpBuilder builder(rewriter, module);
201       mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
202           builder, loc, save, result.getType());
203       builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
204     }
205     op->dropAllReferences();
206     rewriter.eraseOp(op);
207     return mlir::success();
208   }
209 
210 private:
211   bool shouldBoxResult;
212 };
213 
214 class SaveResultOpConversion
215     : public mlir::OpRewritePattern<fir::SaveResultOp> {
216 public:
217   using OpRewritePattern::OpRewritePattern;
218   SaveResultOpConversion(mlir::MLIRContext *context)
219       : OpRewritePattern(context) {}
220   llvm::LogicalResult
221   matchAndRewrite(fir::SaveResultOp op,
222                   mlir::PatternRewriter &rewriter) const override {
223     mlir::Operation *call = op.getValue().getDefiningOp();
224     mlir::Type type = op.getValue().getType();
225     if (mlir::isa<fir::RecordType>(type) && call && fir::hasBindcAttr(call) &&
226         !fir::isa_builtin_cptr_type(type)) {
227       rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValue(),
228                                                 op.getMemref());
229     } else {
230       rewriter.eraseOp(op);
231     }
232     return mlir::success();
233   }
234 };
235 
236 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
237 public:
238   using OpRewritePattern::OpRewritePattern;
239   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
240       : OpRewritePattern(context), newArg{newArg} {}
241   llvm::LogicalResult
242   matchAndRewrite(mlir::func::ReturnOp ret,
243                   mlir::PatternRewriter &rewriter) const override {
244     auto loc = ret.getLoc();
245     rewriter.setInsertionPoint(ret);
246     mlir::Value resultValue = ret.getOperand(0);
247     fir::LoadOp resultLoad;
248     mlir::Value resultStorage;
249     // Identify result local storage.
250     if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
251       resultLoad = load;
252       resultStorage = load.getMemref();
253       // The result alloca may be behind a fir.declare, if any.
254       if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
255         resultStorage = declare.getMemref();
256     }
257     // Replace old local storage with new storage argument, unless
258     // the derived type is C_PTR/C_FUN_PTR, in which case the return
259     // type is updated to return void* (no new argument is passed).
260     if (fir::isa_builtin_cptr_type(resultValue.getType())) {
261       auto module = ret->getParentOfType<mlir::ModuleOp>();
262       FirOpBuilder builder(rewriter, module);
263       mlir::Value cptr = resultValue;
264       if (resultLoad) {
265         // Replace whole derived type load by component load.
266         cptr = resultLoad.getMemref();
267         rewriter.setInsertionPoint(resultLoad);
268       }
269       mlir::Value newResultValue =
270           fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
271       newResultValue = builder.createConvert(
272           loc, getVoidPtrType(ret.getContext()), newResultValue);
273       rewriter.setInsertionPoint(ret);
274       rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
275           ret, mlir::ValueRange{newResultValue});
276     } else if (resultStorage) {
277       resultStorage.replaceAllUsesWith(newArg);
278       rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
279     } else {
280       // The result storage may have been optimized out by a memory to
281       // register pass, this is possible for fir.box results, or fir.record
282       // with no length parameters. Simply store the result in the result
283       // storage. at the return point.
284       rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
285       rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
286     }
287     // Delete result old local storage if unused.
288     if (resultStorage)
289       if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
290         if (alloc->use_empty())
291           rewriter.eraseOp(alloc);
292     return mlir::success();
293   }
294 
295 private:
296   mlir::Value newArg;
297 };
298 
299 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
300 public:
301   using OpRewritePattern::OpRewritePattern;
302   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
303       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
304   llvm::LogicalResult
305   matchAndRewrite(fir::AddrOfOp addrOf,
306                   mlir::PatternRewriter &rewriter) const override {
307     auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
308     mlir::FunctionType newFuncTy;
309     if (oldFuncTy.getNumResults() != 0 &&
310         fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
311       newFuncTy = getCPtrFunctionType(oldFuncTy);
312     else
313       newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
314     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
315                                                     addrOf.getSymbol());
316     // Rather than converting all op a function pointer might transit through
317     // (e.g calls, stores, loads, converts...), cast new type to the abstract
318     // type. A conversion will be added when calling indirect calls of abstract
319     // types.
320     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
321     return mlir::success();
322   }
323 
324 private:
325   bool shouldBoxResult;
326 };
327 
328 class AbstractResultOpt
329     : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
330 public:
331   using fir::impl::AbstractResultOptBase<
332       AbstractResultOpt>::AbstractResultOptBase;
333 
334   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
335                               mlir::RewritePatternSet &patterns,
336                               mlir::ConversionTarget &target) {
337     auto loc = func.getLoc();
338     auto *context = &getContext();
339     // Convert function type itself if it has an abstract result.
340     auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
341     // Scalar derived result of BIND(C) function must be returned according
342     // to the C struct return ABI which is target dependent and implemented in
343     // the target-rewrite pass.
344     if (hasScalarDerivedResult(funcTy) &&
345         fir::hasBindcAttr(func.getOperation()))
346       return;
347     if (hasAbstractResult(funcTy)) {
348       if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
349         func.setType(getCPtrFunctionType(funcTy));
350         patterns.insert<ReturnOpConversion>(context, mlir::Value{});
351         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
352             [](mlir::func::ReturnOp ret) {
353               mlir::Type retTy = ret.getOperand(0).getType();
354               return !fir::isa_builtin_cptr_type(retTy);
355             });
356         return;
357       }
358       if (!func.empty()) {
359         // Insert new argument.
360         mlir::OpBuilder rewriter(context);
361         auto resultType = funcTy.getResult(0);
362         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
363         func.insertArgument(0u, argTy, {}, loc);
364         func.eraseResult(0u);
365         mlir::Value newArg = func.getArgument(0u);
366         if (mustEmboxResult(resultType, shouldBoxResult)) {
367           auto bufferType = fir::ReferenceType::get(resultType);
368           rewriter.setInsertionPointToStart(&func.front());
369           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
370         }
371         patterns.insert<ReturnOpConversion>(context, newArg);
372         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
373             [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
374         assert(func.getFunctionType() ==
375                getNewFunctionType(funcTy, shouldBoxResult));
376       } else {
377         llvm::SmallVector<mlir::DictionaryAttr> allArgs;
378         func.getAllArgAttrs(allArgs);
379         allArgs.insert(allArgs.begin(),
380                        mlir::DictionaryAttr::get(func->getContext()));
381         func.setType(getNewFunctionType(funcTy, shouldBoxResult));
382         func.setAllArgAttrs(allArgs);
383       }
384     }
385   }
386 
387   inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
388     return mlir::TypeSwitch<mlir::Type, bool>(type)
389         .Case([](fir::BoxProcType boxProc) {
390           return fir::hasAbstractResult(
391               mlir::cast<mlir::FunctionType>(boxProc.getEleTy()));
392         })
393         .Case([](fir::PointerType pointer) {
394           return fir::hasAbstractResult(
395               mlir::cast<mlir::FunctionType>(pointer.getEleTy()));
396         })
397         .Default([](auto &&) { return false; });
398   }
399 
400   void runOnSpecificOperation(fir::GlobalOp global, bool,
401                               mlir::RewritePatternSet &,
402                               mlir::ConversionTarget &) {
403     if (containsFunctionTypeWithAbstractResult(global.getType())) {
404       TODO(global->getLoc(), "support for procedure pointers");
405     }
406   }
407 
408   /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
409   void runOnModule() {
410     mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
411 
412     auto pass = std::make_unique<AbstractResultOpt>();
413     pass->copyOptionValuesFrom(this);
414     mlir::OpPassManager pipeline;
415     pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});
416 
417     // Run the pass on all operations directly nested inside of the ModuleOp
418     // we can't just call runOnSpecificOperation here because the pass
419     // implementation only works when scoped to a particular func.func or
420     // fir.global
421     for (mlir::Region &region : mod->getRegions()) {
422       for (mlir::Block &block : region.getBlocks()) {
423         for (mlir::Operation &op : block.getOperations()) {
424           if (mlir::failed(runPipeline(pipeline, &op))) {
425             mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
426             signalPassFailure();
427             return;
428           }
429         }
430       }
431     }
432   }
433 
434   void runOnOperation() override {
435     auto *context = &this->getContext();
436     mlir::Operation *op = this->getOperation();
437     if (mlir::isa<mlir::ModuleOp>(op)) {
438       runOnModule();
439       return;
440     }
441 
442     LazySymbolTable symbolTable(op);
443 
444     mlir::RewritePatternSet patterns(context);
445     mlir::ConversionTarget target = *context;
446     const bool shouldBoxResult = this->passResultAsBox.getValue();
447 
448     mlir::TypeSwitch<mlir::Operation *, void>(op)
449         .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
450           runOnSpecificOperation(op, shouldBoxResult, patterns, target);
451         });
452 
453     // Convert the calls and, if needed,  the ReturnOp in the function body.
454     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
455                            mlir::func::FuncDialect>();
456     target.addIllegalOp<fir::SaveResultOp>();
457     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
458       mlir::FunctionType funTy = call.getFunctionType();
459       if (hasScalarDerivedResult(funTy) &&
460           fir::hasBindcAttr(call.getOperation()))
461         return true;
462       return !hasAbstractResult(funTy);
463     });
464     target.addDynamicallyLegalOp<fir::AddrOfOp>([&symbolTable](
465                                                     fir::AddrOfOp addrOf) {
466       if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) {
467         if (hasScalarDerivedResult(funTy)) {
468           auto func = symbolTable.lookup<mlir::func::FuncOp>(
469               addrOf.getSymbol().getRootReference().getValue());
470           return func && fir::hasBindcAttr(func.getOperation());
471         }
472         return !hasAbstractResult(funTy);
473       }
474       return true;
475     });
476     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
477       mlir::FunctionType funTy = dispatch.getFunctionType();
478       if (hasScalarDerivedResult(funTy) &&
479           fir::hasBindcAttr(dispatch.getOperation()))
480         return true;
481       return !hasAbstractResult(dispatch.getFunctionType());
482     });
483 
484     patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
485     patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
486     patterns.insert<SaveResultOpConversion>(context);
487     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
488     if (mlir::failed(
489             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
490       mlir::emitError(op->getLoc(), "error in converting abstract results\n");
491       this->signalPassFailure();
492     }
493   }
494 };
495 
496 } // end anonymous namespace
497 } // namespace fir
498