xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision aa8feeefd3ac6c78ee8f67bf033976fc7d68bc6d)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Transforms/Passes.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "mlir/Transforms/Passes.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 #define DEBUG_TYPE "flang-abstract-result-opt"
23 
24 namespace fir {
25 namespace {
26 
27 struct AbstractResultOptions {
28   // Always pass result as a fir.box argument.
29   bool boxResult = false;
30   // New function block argument for the result if the current FuncOp had
31   // an abstract result.
32   mlir::Value newArg;
33 };
34 
35 static mlir::Type getResultArgumentType(mlir::Type resultType,
36                                         const AbstractResultOptions &options) {
37   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
38       .Case<fir::SequenceType, fir::RecordType>(
39           [&](mlir::Type type) -> mlir::Type {
40             if (options.boxResult)
41               return fir::BoxType::get(type);
42             return fir::ReferenceType::get(type);
43           })
44       .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
45         return fir::ReferenceType::get(type);
46       })
47       .Default([](mlir::Type) -> mlir::Type {
48         llvm_unreachable("bad abstract result type");
49       });
50 }
51 
52 static mlir::FunctionType
53 getNewFunctionType(mlir::FunctionType funcTy,
54                    const AbstractResultOptions &options) {
55   auto resultType = funcTy.getResult(0);
56   auto argTy = getResultArgumentType(resultType, options);
57   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
58   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
59   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
60                                  /*resultTypes=*/{});
61 }
62 
63 static bool mustEmboxResult(mlir::Type resultType,
64                             const AbstractResultOptions &options) {
65   return resultType.isa<fir::SequenceType, fir::RecordType>() &&
66          options.boxResult;
67 }
68 
69 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
70 public:
71   using OpRewritePattern::OpRewritePattern;
72   CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt)
73       : OpRewritePattern(context), options{opt} {}
74   mlir::LogicalResult
75   matchAndRewrite(fir::CallOp callOp,
76                   mlir::PatternRewriter &rewriter) const override {
77     auto loc = callOp.getLoc();
78     auto result = callOp->getResult(0);
79     if (!result.hasOneUse()) {
80       mlir::emitError(loc,
81                       "calls with abstract result must have exactly one user");
82       return mlir::failure();
83     }
84     auto saveResult =
85         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
86     if (!saveResult) {
87       mlir::emitError(
88           loc, "calls with abstract result must be used in fir.save_result");
89       return mlir::failure();
90     }
91     auto argType = getResultArgumentType(result.getType(), options);
92     auto buffer = saveResult.getMemref();
93     mlir::Value arg = buffer;
94     if (mustEmboxResult(result.getType(), options))
95       arg = rewriter.create<fir::EmboxOp>(
96           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
97           saveResult.getTypeparams());
98 
99     llvm::SmallVector<mlir::Type> newResultTypes;
100     if (callOp.getCallee()) {
101       llvm::SmallVector<mlir::Value> newOperands = {arg};
102       newOperands.append(callOp.getOperands().begin(),
103                          callOp.getOperands().end());
104       rewriter.create<fir::CallOp>(loc, callOp.getCallee().value(),
105                                    newResultTypes, newOperands);
106     } else {
107       // Indirect calls.
108       llvm::SmallVector<mlir::Type> newInputTypes = {argType};
109       for (auto operand : callOp.getOperands().drop_front())
110         newInputTypes.push_back(operand.getType());
111       auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes,
112                                            newResultTypes);
113 
114       llvm::SmallVector<mlir::Value> newOperands;
115       newOperands.push_back(
116           rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0)));
117       newOperands.push_back(arg);
118       newOperands.append(callOp.getOperands().begin() + 1,
119                          callOp.getOperands().end());
120       rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes,
121                                    newOperands);
122     }
123     callOp->dropAllReferences();
124     rewriter.eraseOp(callOp);
125     return mlir::success();
126   }
127 
128 private:
129   const AbstractResultOptions &options;
130 };
131 
132 class SaveResultOpConversion
133     : public mlir::OpRewritePattern<fir::SaveResultOp> {
134 public:
135   using OpRewritePattern::OpRewritePattern;
136   SaveResultOpConversion(mlir::MLIRContext *context)
137       : OpRewritePattern(context) {}
138   mlir::LogicalResult
139   matchAndRewrite(fir::SaveResultOp op,
140                   mlir::PatternRewriter &rewriter) const override {
141     rewriter.eraseOp(op);
142     return mlir::success();
143   }
144 };
145 
146 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
147 public:
148   using OpRewritePattern::OpRewritePattern;
149   ReturnOpConversion(mlir::MLIRContext *context,
150                      const AbstractResultOptions &opt)
151       : OpRewritePattern(context), options{opt} {}
152   mlir::LogicalResult
153   matchAndRewrite(mlir::func::ReturnOp ret,
154                   mlir::PatternRewriter &rewriter) const override {
155     rewriter.setInsertionPoint(ret);
156     auto returnedValue = ret.getOperand(0);
157     bool replacedStorage = false;
158     if (auto *op = returnedValue.getDefiningOp())
159       if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
160         auto resultStorage = load.getMemref();
161         load.getMemref().replaceAllUsesWith(options.newArg);
162         replacedStorage = true;
163         if (auto *alloc = resultStorage.getDefiningOp())
164           if (alloc->use_empty())
165             rewriter.eraseOp(alloc);
166       }
167     // The result storage may have been optimized out by a memory to
168     // register pass, this is possible for fir.box results, or fir.record
169     // with no length parameters. Simply store the result in the result storage.
170     // at the return point.
171     if (!replacedStorage)
172       rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue,
173                                     options.newArg);
174     rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
175     return mlir::success();
176   }
177 
178 private:
179   const AbstractResultOptions &options;
180 };
181 
182 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
183 public:
184   using OpRewritePattern::OpRewritePattern;
185   AddrOfOpConversion(mlir::MLIRContext *context,
186                      const AbstractResultOptions &opt)
187       : OpRewritePattern(context), options{opt} {}
188   mlir::LogicalResult
189   matchAndRewrite(fir::AddrOfOp addrOf,
190                   mlir::PatternRewriter &rewriter) const override {
191     auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
192     auto newFuncTy = getNewFunctionType(oldFuncTy, options);
193     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
194                                                     addrOf.getSymbol());
195     // Rather than converting all op a function pointer might transit through
196     // (e.g calls, stores, loads, converts...), cast new type to the abstract
197     // type. A conversion will be added when calling indirect calls of abstract
198     // types.
199     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
200     return mlir::success();
201   }
202 
203 private:
204   const AbstractResultOptions &options;
205 };
206 
207 class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
208 public:
209   void runOnOperation() override {
210     auto *context = &getContext();
211     auto func = getOperation();
212     auto loc = func.getLoc();
213     mlir::RewritePatternSet patterns(context);
214     mlir::ConversionTarget target = *context;
215     AbstractResultOptions options{passResultAsBox.getValue(),
216                                   /*newArg=*/{}};
217 
218     // Convert function type itself if it has an abstract result
219     auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
220     if (hasAbstractResult(funcTy)) {
221       func.setType(getNewFunctionType(funcTy, options));
222       unsigned zero = 0;
223       if (!func.empty()) {
224         // Insert new argument
225         mlir::OpBuilder rewriter(context);
226         auto resultType = funcTy.getResult(0);
227         auto argTy = getResultArgumentType(resultType, options);
228         options.newArg = func.front().insertArgument(zero, argTy, loc);
229         if (mustEmboxResult(resultType, options)) {
230           auto bufferType = fir::ReferenceType::get(resultType);
231           rewriter.setInsertionPointToStart(&func.front());
232           options.newArg =
233               rewriter.create<fir::BoxAddrOp>(loc, bufferType, options.newArg);
234         }
235         patterns.insert<ReturnOpConversion>(context, options);
236         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
237             [](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
238       }
239     }
240 
241     if (func.empty())
242       return;
243 
244     // Convert the calls and, if needed,  the ReturnOp in the function body.
245     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect,
246                            mlir::func::FuncDialect>();
247     target.addIllegalOp<fir::SaveResultOp>();
248     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
249       return !hasAbstractResult(call.getFunctionType());
250     });
251     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
252       if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
253         return !hasAbstractResult(funTy);
254       return true;
255     });
256     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
257       if (dispatch->getNumResults() != 1)
258         return true;
259       auto resultType = dispatch->getResult(0).getType();
260       if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
261         TODO(dispatch.getLoc(), "dispatchOp with abstract results");
262         return false;
263       }
264       return true;
265     });
266 
267     patterns.insert<CallOpConversion>(context, options);
268     patterns.insert<SaveResultOpConversion>(context);
269     patterns.insert<AddrOfOpConversion>(context, options);
270     if (mlir::failed(
271             mlir::applyPartialConversion(func, target, std::move(patterns)))) {
272       mlir::emitError(func.getLoc(), "error in converting abstract results\n");
273       signalPassFailure();
274     }
275   }
276 };
277 } // end anonymous namespace
278 } // namespace fir
279 
280 std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
281   return std::make_unique<AbstractResultOpt>();
282 }
283