xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Pass/PassManager.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace fir {
24 #define GEN_PASS_DEF_ABSTRACTRESULTOPT
25 #include "flang/Optimizer/Transforms/Passes.h.inc"
26 } // namespace fir
27 
28 #define DEBUG_TYPE "flang-abstract-result-opt"
29 
30 using namespace mlir;
31 
32 namespace fir {
33 namespace {
34 
35 static mlir::Type getResultArgumentType(mlir::Type resultType,
36                                         bool shouldBoxResult) {
37   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
38       .Case<fir::SequenceType, fir::RecordType>(
39           [&](mlir::Type type) -> mlir::Type {
40             if (shouldBoxResult)
41               return fir::BoxType::get(type);
42             return fir::ReferenceType::get(type);
43           })
44       .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
45         return fir::ReferenceType::get(type);
46       })
47       .Default([](mlir::Type) -> mlir::Type {
48         llvm_unreachable("bad abstract result type");
49       });
50 }
51 
52 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
53                                              bool shouldBoxResult) {
54   auto resultType = funcTy.getResult(0);
55   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
56   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
57   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
58   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
59                                  /*resultTypes=*/{});
60 }
61 
62 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
63 /// Follow the ABI for interoperability with C.
64 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
65   auto resultType = funcTy.getResult(0);
66   assert(fir::isa_builtin_cptr_type(resultType));
67   llvm::SmallVector<mlir::Type> outputTypes;
68   auto recTy = mlir::dyn_cast<fir::RecordType>(resultType);
69   outputTypes.emplace_back(recTy.getTypeList()[0].second);
70   return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
71                                  outputTypes);
72 }
73 
74 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
75   return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) &&
76          shouldBoxResult;
77 }
78 
79 template <typename Op>
80 class CallConversion : public mlir::OpRewritePattern<Op> {
81 public:
82   using mlir::OpRewritePattern<Op>::OpRewritePattern;
83 
84   CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
85       : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
86 
87   llvm::LogicalResult
88   matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
89     auto loc = op.getLoc();
90     auto result = op->getResult(0);
91     if (!result.hasOneUse()) {
92       mlir::emitError(loc,
93                       "calls with abstract result must have exactly one user");
94       return mlir::failure();
95     }
96     auto saveResult =
97         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
98     if (!saveResult) {
99       mlir::emitError(
100           loc, "calls with abstract result must be used in fir.save_result");
101       return mlir::failure();
102     }
103     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
104     auto buffer = saveResult.getMemref();
105     mlir::Value arg = buffer;
106     if (mustEmboxResult(result.getType(), shouldBoxResult))
107       arg = rewriter.create<fir::EmboxOp>(
108           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
109           saveResult.getTypeparams());
110 
111     llvm::SmallVector<mlir::Type> newResultTypes;
112     // TODO: This should be generalized for derived types, and it is
113     // architecture and OS dependent.
114     bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
115     Op newOp;
116     if (isResultBuiltinCPtr) {
117       auto recTy = mlir::dyn_cast<fir::RecordType>(result.getType());
118       newResultTypes.emplace_back(recTy.getTypeList()[0].second);
119     }
120 
121     // fir::CallOp specific handling.
122     if constexpr (std::is_same_v<Op, fir::CallOp>) {
123       if (op.getCallee()) {
124         llvm::SmallVector<mlir::Value> newOperands;
125         if (!isResultBuiltinCPtr)
126           newOperands.emplace_back(arg);
127         newOperands.append(op.getOperands().begin(), op.getOperands().end());
128         newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
129                                              newResultTypes, newOperands);
130       } else {
131         // Indirect calls.
132         llvm::SmallVector<mlir::Type> newInputTypes;
133         if (!isResultBuiltinCPtr)
134           newInputTypes.emplace_back(argType);
135         for (auto operand : op.getOperands().drop_front())
136           newInputTypes.push_back(operand.getType());
137         auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
138                                                  newResultTypes);
139 
140         llvm::SmallVector<mlir::Value> newOperands;
141         newOperands.push_back(
142             rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
143         if (!isResultBuiltinCPtr)
144           newOperands.push_back(arg);
145         newOperands.append(op.getOperands().begin() + 1,
146                            op.getOperands().end());
147         newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
148                                              newResultTypes, newOperands);
149       }
150     }
151 
152     // fir::DispatchOp specific handling.
153     if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
154       llvm::SmallVector<mlir::Value> newOperands;
155       if (!isResultBuiltinCPtr)
156         newOperands.emplace_back(arg);
157       unsigned passArgShift = newOperands.size();
158       newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
159 
160       fir::DispatchOp newDispatchOp;
161       if (op.getPassArgPos())
162         newOp = rewriter.create<fir::DispatchOp>(
163             loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
164             op.getOperands()[0], newOperands,
165             rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
166       else
167         newOp = rewriter.create<fir::DispatchOp>(
168             loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
169             op.getOperands()[0], newOperands, nullptr);
170     }
171 
172     if (isResultBuiltinCPtr) {
173       mlir::Value save = saveResult.getMemref();
174       auto module = op->template getParentOfType<mlir::ModuleOp>();
175       FirOpBuilder builder(rewriter, module);
176       mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
177           builder, loc, save, result.getType());
178       rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
179     }
180     op->dropAllReferences();
181     rewriter.eraseOp(op);
182     return mlir::success();
183   }
184 
185 private:
186   bool shouldBoxResult;
187 };
188 
189 class SaveResultOpConversion
190     : public mlir::OpRewritePattern<fir::SaveResultOp> {
191 public:
192   using OpRewritePattern::OpRewritePattern;
193   SaveResultOpConversion(mlir::MLIRContext *context)
194       : OpRewritePattern(context) {}
195   llvm::LogicalResult
196   matchAndRewrite(fir::SaveResultOp op,
197                   mlir::PatternRewriter &rewriter) const override {
198     rewriter.eraseOp(op);
199     return mlir::success();
200   }
201 };
202 
203 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
204 public:
205   using OpRewritePattern::OpRewritePattern;
206   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
207       : OpRewritePattern(context), newArg{newArg} {}
208   llvm::LogicalResult
209   matchAndRewrite(mlir::func::ReturnOp ret,
210                   mlir::PatternRewriter &rewriter) const override {
211     auto loc = ret.getLoc();
212     rewriter.setInsertionPoint(ret);
213     auto returnedValue = ret.getOperand(0);
214     bool replacedStorage = false;
215     if (auto *op = returnedValue.getDefiningOp())
216       if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
217         auto resultStorage = load.getMemref();
218         // The result alloca may be behind a fir.declare, if any.
219         if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>(
220                 resultStorage.getDefiningOp()))
221           resultStorage = declare.getMemref();
222         // TODO: This should be generalized for derived types, and it is
223         // architecture and OS dependent.
224         if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
225           rewriter.eraseOp(load);
226           auto module = ret->getParentOfType<mlir::ModuleOp>();
227           FirOpBuilder builder(rewriter, module);
228           mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
229               builder, loc, resultStorage, returnedValue.getType());
230           mlir::Value retValue = rewriter.create<fir::LoadOp>(
231               loc, fir::unwrapRefType(retAddr.getType()), retAddr);
232           rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
233               ret, mlir::ValueRange{retValue});
234           return mlir::success();
235         }
236         resultStorage.replaceAllUsesWith(newArg);
237         replacedStorage = true;
238         if (auto *alloc = resultStorage.getDefiningOp())
239           if (alloc->use_empty())
240             rewriter.eraseOp(alloc);
241       }
242     // The result storage may have been optimized out by a memory to
243     // register pass, this is possible for fir.box results, or fir.record
244     // with no length parameters. Simply store the result in the result storage.
245     // at the return point.
246     if (!replacedStorage)
247       rewriter.create<fir::StoreOp>(loc, returnedValue, newArg);
248     rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
249     return mlir::success();
250   }
251 
252 private:
253   mlir::Value newArg;
254 };
255 
256 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
257 public:
258   using OpRewritePattern::OpRewritePattern;
259   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
260       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
261   llvm::LogicalResult
262   matchAndRewrite(fir::AddrOfOp addrOf,
263                   mlir::PatternRewriter &rewriter) const override {
264     auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
265     mlir::FunctionType newFuncTy;
266     // TODO: This should be generalized for derived types, and it is
267     // architecture and OS dependent.
268     if (oldFuncTy.getNumResults() != 0 &&
269         fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
270       newFuncTy = getCPtrFunctionType(oldFuncTy);
271     else
272       newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
273     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
274                                                     addrOf.getSymbol());
275     // Rather than converting all op a function pointer might transit through
276     // (e.g calls, stores, loads, converts...), cast new type to the abstract
277     // type. A conversion will be added when calling indirect calls of abstract
278     // types.
279     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
280     return mlir::success();
281   }
282 
283 private:
284   bool shouldBoxResult;
285 };
286 
287 class AbstractResultOpt
288     : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
289 public:
290   using fir::impl::AbstractResultOptBase<
291       AbstractResultOpt>::AbstractResultOptBase;
292 
293   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
294                               mlir::RewritePatternSet &patterns,
295                               mlir::ConversionTarget &target) {
296     auto loc = func.getLoc();
297     auto *context = &getContext();
298     // Convert function type itself if it has an abstract result.
299     auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
300     if (hasAbstractResult(funcTy)) {
301       // TODO: This should be generalized for derived types, and it is
302       // architecture and OS dependent.
303       if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
304         func.setType(getCPtrFunctionType(funcTy));
305         patterns.insert<ReturnOpConversion>(context, mlir::Value{});
306         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
307             [](mlir::func::ReturnOp ret) {
308               mlir::Type retTy = ret.getOperand(0).getType();
309               return !fir::isa_builtin_cptr_type(retTy);
310             });
311         return;
312       }
313       if (!func.empty()) {
314         // Insert new argument.
315         mlir::OpBuilder rewriter(context);
316         auto resultType = funcTy.getResult(0);
317         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
318         func.insertArgument(0u, argTy, {}, loc);
319         func.eraseResult(0u);
320         mlir::Value newArg = func.getArgument(0u);
321         if (mustEmboxResult(resultType, shouldBoxResult)) {
322           auto bufferType = fir::ReferenceType::get(resultType);
323           rewriter.setInsertionPointToStart(&func.front());
324           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
325         }
326         patterns.insert<ReturnOpConversion>(context, newArg);
327         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
328             [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
329         assert(func.getFunctionType() ==
330                getNewFunctionType(funcTy, shouldBoxResult));
331       } else {
332         llvm::SmallVector<mlir::DictionaryAttr> allArgs;
333         func.getAllArgAttrs(allArgs);
334         allArgs.insert(allArgs.begin(),
335                        mlir::DictionaryAttr::get(func->getContext()));
336         func.setType(getNewFunctionType(funcTy, shouldBoxResult));
337         func.setAllArgAttrs(allArgs);
338       }
339     }
340   }
341 
342   inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
343     return mlir::TypeSwitch<mlir::Type, bool>(type)
344         .Case([](fir::BoxProcType boxProc) {
345           return fir::hasAbstractResult(
346               mlir::cast<mlir::FunctionType>(boxProc.getEleTy()));
347         })
348         .Case([](fir::PointerType pointer) {
349           return fir::hasAbstractResult(
350               mlir::cast<mlir::FunctionType>(pointer.getEleTy()));
351         })
352         .Default([](auto &&) { return false; });
353   }
354 
355   void runOnSpecificOperation(fir::GlobalOp global, bool,
356                               mlir::RewritePatternSet &,
357                               mlir::ConversionTarget &) {
358     if (containsFunctionTypeWithAbstractResult(global.getType())) {
359       TODO(global->getLoc(), "support for procedure pointers");
360     }
361   }
362 
363   /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
364   void runOnModule() {
365     mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
366 
367     auto pass = std::make_unique<AbstractResultOpt>();
368     pass->copyOptionValuesFrom(this);
369     mlir::OpPassManager pipeline;
370     pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});
371 
372     // Run the pass on all operations directly nested inside of the ModuleOp
373     // we can't just call runOnSpecificOperation here because the pass
374     // implementation only works when scoped to a particular func.func or
375     // fir.global
376     for (mlir::Region &region : mod->getRegions()) {
377       for (mlir::Block &block : region.getBlocks()) {
378         for (mlir::Operation &op : block.getOperations()) {
379           if (mlir::failed(runPipeline(pipeline, &op))) {
380             mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
381             signalPassFailure();
382             return;
383           }
384         }
385       }
386     }
387   }
388 
389   void runOnOperation() override {
390     auto *context = &this->getContext();
391     mlir::Operation *op = this->getOperation();
392     if (mlir::isa<mlir::ModuleOp>(op)) {
393       runOnModule();
394       return;
395     }
396 
397     mlir::RewritePatternSet patterns(context);
398     mlir::ConversionTarget target = *context;
399     const bool shouldBoxResult = this->passResultAsBox.getValue();
400 
401     mlir::TypeSwitch<mlir::Operation *, void>(op)
402         .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
403           runOnSpecificOperation(op, shouldBoxResult, patterns, target);
404         });
405 
406     // Convert the calls and, if needed,  the ReturnOp in the function body.
407     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
408                            mlir::func::FuncDialect>();
409     target.addIllegalOp<fir::SaveResultOp>();
410     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
411       return !hasAbstractResult(call.getFunctionType());
412     });
413     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
414       if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType()))
415         return !hasAbstractResult(funTy);
416       return true;
417     });
418     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
419       return !hasAbstractResult(dispatch.getFunctionType());
420     });
421 
422     patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
423     patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
424     patterns.insert<SaveResultOpConversion>(context);
425     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
426     if (mlir::failed(
427             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
428       mlir::emitError(op->getLoc(), "error in converting abstract results\n");
429       this->signalPassFailure();
430     }
431   }
432 };
433 
434 } // end anonymous namespace
435 } // namespace fir
436