xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision 75623bfe1b89fa84cf2b9e4fb4c9f7560e01d4a6)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Pass/PassManager.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 
24 namespace fir {
25 #define GEN_PASS_DEF_ABSTRACTRESULTOPT
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
27 } // namespace fir
28 
29 #define DEBUG_TYPE "flang-abstract-result-opt"
30 
31 using namespace mlir;
32 
33 namespace fir {
34 namespace {
35 
36 // Helper to only build the symbol table if needed because its build time is
37 // linear on the number of symbols in the module.
38 struct LazySymbolTable {
39   LazySymbolTable(mlir::Operation *op)
40       : module{op->getParentOfType<mlir::ModuleOp>()} {}
41   void build() {
42     if (table)
43       return;
44     table = std::make_unique<mlir::SymbolTable>(module);
45   }
46 
47   template <typename T>
48   T lookup(llvm::StringRef name) {
49     build();
50     return table->lookup<T>(name);
51   }
52 
53 private:
54   std::unique_ptr<mlir::SymbolTable> table;
55   mlir::ModuleOp module;
56 };
57 
58 bool hasScalarDerivedResult(mlir::FunctionType funTy) {
59   // C_PTR/C_FUNPTR are results to void* in this pass, do not consider
60   // them as normal derived types.
61   return funTy.getNumResults() == 1 &&
62          mlir::isa<fir::RecordType>(funTy.getResult(0)) &&
63          !fir::isa_builtin_cptr_type(funTy.getResult(0));
64 }
65 
66 static mlir::Type getResultArgumentType(mlir::Type resultType,
67                                         bool shouldBoxResult) {
68   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
69       .Case<fir::SequenceType, fir::RecordType>(
70           [&](mlir::Type type) -> mlir::Type {
71             if (shouldBoxResult)
72               return fir::BoxType::get(type);
73             return fir::ReferenceType::get(type);
74           })
75       .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
76         return fir::ReferenceType::get(type);
77       })
78       .Default([](mlir::Type) -> mlir::Type {
79         llvm_unreachable("bad abstract result type");
80       });
81 }
82 
83 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
84                                              bool shouldBoxResult) {
85   auto resultType = funcTy.getResult(0);
86   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
87   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
88   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
89   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
90                                  /*resultTypes=*/{});
91 }
92 
93 static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
94   return fir::ReferenceType::get(mlir::NoneType::get(context));
95 }
96 
97 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
98 /// Follow the ABI for interoperability with C.
99 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
100   assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
101   llvm::SmallVector<mlir::Type> outputTypes{
102       getVoidPtrType(funcTy.getContext())};
103   return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
104                                  outputTypes);
105 }
106 
107 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
108   return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) &&
109          shouldBoxResult;
110 }
111 
112 template <typename Op>
113 class CallConversion : public mlir::OpRewritePattern<Op> {
114 public:
115   using mlir::OpRewritePattern<Op>::OpRewritePattern;
116 
117   CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
118       : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
119 
120   llvm::LogicalResult
121   matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
122     auto loc = op.getLoc();
123     auto result = op->getResult(0);
124     if (!result.hasOneUse()) {
125       mlir::emitError(loc,
126                       "calls with abstract result must have exactly one user");
127       return mlir::failure();
128     }
129     auto saveResult =
130         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
131     if (!saveResult) {
132       mlir::emitError(
133           loc, "calls with abstract result must be used in fir.save_result");
134       return mlir::failure();
135     }
136     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
137     auto buffer = saveResult.getMemref();
138     mlir::Value arg = buffer;
139     if (mustEmboxResult(result.getType(), shouldBoxResult))
140       arg = rewriter.create<fir::EmboxOp>(
141           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
142           saveResult.getTypeparams());
143 
144     llvm::SmallVector<mlir::Type> newResultTypes;
145     bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
146     if (isResultBuiltinCPtr)
147       newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
148 
149     Op newOp;
150     // fir::CallOp specific handling.
151     if constexpr (std::is_same_v<Op, fir::CallOp>) {
152       if (op.getCallee()) {
153         llvm::SmallVector<mlir::Value> newOperands;
154         if (!isResultBuiltinCPtr)
155           newOperands.emplace_back(arg);
156         newOperands.append(op.getOperands().begin(), op.getOperands().end());
157         newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
158                                              newResultTypes, newOperands);
159       } else {
160         // Indirect calls.
161         llvm::SmallVector<mlir::Type> newInputTypes;
162         if (!isResultBuiltinCPtr)
163           newInputTypes.emplace_back(argType);
164         for (auto operand : op.getOperands().drop_front())
165           newInputTypes.push_back(operand.getType());
166         auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
167                                                  newResultTypes);
168 
169         llvm::SmallVector<mlir::Value> newOperands;
170         newOperands.push_back(
171             rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
172         if (!isResultBuiltinCPtr)
173           newOperands.push_back(arg);
174         newOperands.append(op.getOperands().begin() + 1,
175                            op.getOperands().end());
176         newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
177                                              newResultTypes, newOperands);
178       }
179     }
180 
181     // fir::DispatchOp specific handling.
182     if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
183       llvm::SmallVector<mlir::Value> newOperands;
184       if (!isResultBuiltinCPtr)
185         newOperands.emplace_back(arg);
186       unsigned passArgShift = newOperands.size();
187       newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
188       mlir::IntegerAttr passArgPos;
189       if (op.getPassArgPos())
190         passArgPos =
191             rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
192       newOp = rewriter.create<fir::DispatchOp>(
193           loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
194           op.getOperands()[0], newOperands, passArgPos,
195           op.getProcedureAttrsAttr());
196     }
197 
198     if (isResultBuiltinCPtr) {
199       mlir::Value save = saveResult.getMemref();
200       auto module = op->template getParentOfType<mlir::ModuleOp>();
201       FirOpBuilder builder(rewriter, module);
202       mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
203           builder, loc, save, result.getType());
204       builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
205     }
206     op->dropAllReferences();
207     rewriter.eraseOp(op);
208     return mlir::success();
209   }
210 
211 private:
212   bool shouldBoxResult;
213 };
214 
215 class SaveResultOpConversion
216     : public mlir::OpRewritePattern<fir::SaveResultOp> {
217 public:
218   using OpRewritePattern::OpRewritePattern;
219   SaveResultOpConversion(mlir::MLIRContext *context)
220       : OpRewritePattern(context) {}
221   llvm::LogicalResult
222   matchAndRewrite(fir::SaveResultOp op,
223                   mlir::PatternRewriter &rewriter) const override {
224     mlir::Operation *call = op.getValue().getDefiningOp();
225     mlir::Type type = op.getValue().getType();
226     if (mlir::isa<fir::RecordType>(type) && call && fir::hasBindcAttr(call) &&
227         !fir::isa_builtin_cptr_type(type)) {
228       rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValue(),
229                                                 op.getMemref());
230     } else {
231       rewriter.eraseOp(op);
232     }
233     return mlir::success();
234   }
235 };
236 
237 template <typename OpTy>
238 static mlir::LogicalResult
239 processReturnLikeOp(OpTy ret, mlir::Value newArg,
240                     mlir::PatternRewriter &rewriter) {
241   auto loc = ret.getLoc();
242   rewriter.setInsertionPoint(ret);
243   mlir::Value resultValue = ret.getOperand(0);
244   fir::LoadOp resultLoad;
245   mlir::Value resultStorage;
246   // Identify result local storage.
247   if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
248     resultLoad = load;
249     resultStorage = load.getMemref();
250     // The result alloca may be behind a fir.declare, if any.
251     if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
252       resultStorage = declare.getMemref();
253   }
254   // Replace old local storage with new storage argument, unless
255   // the derived type is C_PTR/C_FUN_PTR, in which case the return
256   // type is updated to return void* (no new argument is passed).
257   if (fir::isa_builtin_cptr_type(resultValue.getType())) {
258     auto module = ret->template getParentOfType<mlir::ModuleOp>();
259     FirOpBuilder builder(rewriter, module);
260     mlir::Value cptr = resultValue;
261     if (resultLoad) {
262       // Replace whole derived type load by component load.
263       cptr = resultLoad.getMemref();
264       rewriter.setInsertionPoint(resultLoad);
265     }
266     mlir::Value newResultValue =
267         fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
268     newResultValue = builder.createConvert(
269         loc, getVoidPtrType(ret.getContext()), newResultValue);
270     rewriter.setInsertionPoint(ret);
271     rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue});
272   } else if (resultStorage) {
273     resultStorage.replaceAllUsesWith(newArg);
274     rewriter.replaceOpWithNewOp<OpTy>(ret);
275   } else {
276     // The result storage may have been optimized out by a memory to
277     // register pass, this is possible for fir.box results, or fir.record
278     // with no length parameters. Simply store the result in the result
279     // storage. at the return point.
280     rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
281     rewriter.replaceOpWithNewOp<OpTy>(ret);
282   }
283   // Delete result old local storage if unused.
284   if (resultStorage)
285     if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
286       if (alloc->use_empty())
287         rewriter.eraseOp(alloc);
288   return mlir::success();
289 }
290 
291 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
292 public:
293   using OpRewritePattern::OpRewritePattern;
294   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
295       : OpRewritePattern(context), newArg{newArg} {}
296   llvm::LogicalResult
297   matchAndRewrite(mlir::func::ReturnOp ret,
298                   mlir::PatternRewriter &rewriter) const override {
299     return processReturnLikeOp(ret, newArg, rewriter);
300   }
301 
302 private:
303   mlir::Value newArg;
304 };
305 
306 class GPUReturnOpConversion
307     : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> {
308 public:
309   using OpRewritePattern::OpRewritePattern;
310   GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
311       : OpRewritePattern(context), newArg{newArg} {}
312   llvm::LogicalResult
313   matchAndRewrite(mlir::gpu::ReturnOp ret,
314                   mlir::PatternRewriter &rewriter) const override {
315     return processReturnLikeOp(ret, newArg, rewriter);
316   }
317 
318 private:
319   mlir::Value newArg;
320 };
321 
322 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
323 public:
324   using OpRewritePattern::OpRewritePattern;
325   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
326       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
327   llvm::LogicalResult
328   matchAndRewrite(fir::AddrOfOp addrOf,
329                   mlir::PatternRewriter &rewriter) const override {
330     auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
331     mlir::FunctionType newFuncTy;
332     if (oldFuncTy.getNumResults() != 0 &&
333         fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
334       newFuncTy = getCPtrFunctionType(oldFuncTy);
335     else
336       newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
337     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
338                                                     addrOf.getSymbol());
339     // Rather than converting all op a function pointer might transit through
340     // (e.g calls, stores, loads, converts...), cast new type to the abstract
341     // type. A conversion will be added when calling indirect calls of abstract
342     // types.
343     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
344     return mlir::success();
345   }
346 
347 private:
348   bool shouldBoxResult;
349 };
350 
351 class AbstractResultOpt
352     : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
353 public:
354   using fir::impl::AbstractResultOptBase<
355       AbstractResultOpt>::AbstractResultOptBase;
356 
357   template <typename OpTy>
358   void runOnFunctionLikeOperation(OpTy func, bool shouldBoxResult,
359                                   mlir::RewritePatternSet &patterns,
360                                   mlir::ConversionTarget &target) {
361     auto loc = func.getLoc();
362     auto *context = &getContext();
363     // Convert function type itself if it has an abstract result.
364     auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
365     // Scalar derived result of BIND(C) function must be returned according
366     // to the C struct return ABI which is target dependent and implemented in
367     // the target-rewrite pass.
368     if (hasScalarDerivedResult(funcTy) &&
369         fir::hasBindcAttr(func.getOperation()))
370       return;
371     if (hasAbstractResult(funcTy)) {
372       if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
373         func.setType(getCPtrFunctionType(funcTy));
374         patterns.insert<ReturnOpConversion>(context, mlir::Value{});
375         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
376             [](mlir::func::ReturnOp ret) {
377               mlir::Type retTy = ret.getOperand(0).getType();
378               return !fir::isa_builtin_cptr_type(retTy);
379             });
380         return;
381       }
382       if (!func.empty()) {
383         // Insert new argument.
384         mlir::OpBuilder rewriter(context);
385         auto resultType = funcTy.getResult(0);
386         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
387         func.insertArgument(0u, argTy, {}, loc);
388         func.eraseResult(0u);
389         mlir::Value newArg = func.getArgument(0u);
390         if (mustEmboxResult(resultType, shouldBoxResult)) {
391           auto bufferType = fir::ReferenceType::get(resultType);
392           rewriter.setInsertionPointToStart(&func.front());
393           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
394         }
395         patterns.insert<ReturnOpConversion>(context, newArg);
396         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
397             [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
398         patterns.insert<GPUReturnOpConversion>(context, newArg);
399         target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>(
400             [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); });
401         assert(func.getFunctionType() ==
402                getNewFunctionType(funcTy, shouldBoxResult));
403       } else {
404         llvm::SmallVector<mlir::DictionaryAttr> allArgs;
405         func.getAllArgAttrs(allArgs);
406         allArgs.insert(allArgs.begin(),
407                        mlir::DictionaryAttr::get(func->getContext()));
408         func.setType(getNewFunctionType(funcTy, shouldBoxResult));
409         func.setAllArgAttrs(allArgs);
410       }
411     }
412   }
413 
414   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
415                               mlir::RewritePatternSet &patterns,
416                               mlir::ConversionTarget &target) {
417     runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
418   }
419 
420   void runOnSpecificOperation(mlir::gpu::GPUFuncOp func, bool shouldBoxResult,
421                               mlir::RewritePatternSet &patterns,
422                               mlir::ConversionTarget &target) {
423     runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
424   }
425 
426   inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
427     return mlir::TypeSwitch<mlir::Type, bool>(type)
428         .Case([](fir::BoxProcType boxProc) {
429           return fir::hasAbstractResult(
430               mlir::cast<mlir::FunctionType>(boxProc.getEleTy()));
431         })
432         .Case([](fir::PointerType pointer) {
433           return fir::hasAbstractResult(
434               mlir::cast<mlir::FunctionType>(pointer.getEleTy()));
435         })
436         .Default([](auto &&) { return false; });
437   }
438 
439   void runOnSpecificOperation(fir::GlobalOp global, bool,
440                               mlir::RewritePatternSet &,
441                               mlir::ConversionTarget &) {
442     if (containsFunctionTypeWithAbstractResult(global.getType())) {
443       TODO(global->getLoc(), "support for procedure pointers");
444     }
445   }
446 
447   /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
448   void runOnModule() {
449     mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
450 
451     auto pass = std::make_unique<AbstractResultOpt>();
452     pass->copyOptionValuesFrom(this);
453     mlir::OpPassManager pipeline;
454     pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});
455 
456     // Run the pass on all operations directly nested inside of the ModuleOp
457     // we can't just call runOnSpecificOperation here because the pass
458     // implementation only works when scoped to a particular func.func or
459     // fir.global
460     for (mlir::Region &region : mod->getRegions()) {
461       for (mlir::Block &block : region.getBlocks()) {
462         for (mlir::Operation &op : block.getOperations()) {
463           if (mlir::failed(runPipeline(pipeline, &op))) {
464             mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
465             signalPassFailure();
466             return;
467           }
468         }
469       }
470     }
471   }
472 
473   void runOnOperation() override {
474     auto *context = &this->getContext();
475     mlir::Operation *op = this->getOperation();
476     if (mlir::isa<mlir::ModuleOp>(op)) {
477       runOnModule();
478       return;
479     }
480 
481     LazySymbolTable symbolTable(op);
482 
483     mlir::RewritePatternSet patterns(context);
484     mlir::ConversionTarget target = *context;
485     const bool shouldBoxResult = this->passResultAsBox.getValue();
486 
487     mlir::TypeSwitch<mlir::Operation *, void>(op)
488         .Case<mlir::func::FuncOp, fir::GlobalOp, mlir::gpu::GPUFuncOp>(
489             [&](auto op) {
490               runOnSpecificOperation(op, shouldBoxResult, patterns, target);
491             });
492 
493     // Convert the calls and, if needed,  the ReturnOp in the function body.
494     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
495                            mlir::func::FuncDialect>();
496     target.addIllegalOp<fir::SaveResultOp>();
497     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
498       mlir::FunctionType funTy = call.getFunctionType();
499       if (hasScalarDerivedResult(funTy) &&
500           fir::hasBindcAttr(call.getOperation()))
501         return true;
502       return !hasAbstractResult(funTy);
503     });
504     target.addDynamicallyLegalOp<fir::AddrOfOp>([&symbolTable](
505                                                     fir::AddrOfOp addrOf) {
506       if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) {
507         if (hasScalarDerivedResult(funTy)) {
508           auto func = symbolTable.lookup<mlir::func::FuncOp>(
509               addrOf.getSymbol().getRootReference().getValue());
510           return func && fir::hasBindcAttr(func.getOperation());
511         }
512         return !hasAbstractResult(funTy);
513       }
514       return true;
515     });
516     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
517       mlir::FunctionType funTy = dispatch.getFunctionType();
518       if (hasScalarDerivedResult(funTy) &&
519           fir::hasBindcAttr(dispatch.getOperation()))
520         return true;
521       return !hasAbstractResult(dispatch.getFunctionType());
522     });
523 
524     patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
525     patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
526     patterns.insert<SaveResultOpConversion>(context);
527     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
528     if (mlir::failed(
529             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
530       mlir::emitError(op->getLoc(), "error in converting abstract results\n");
531       this->signalPassFailure();
532     }
533   }
534 };
535 
536 } // end anonymous namespace
537 } // namespace fir
538