xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision 149ad3d554c6e901a57c7e34e29fba334dcd283c)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Dialect/FIRType.h"
13 #include "flang/Optimizer/Transforms/Passes.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/Passes.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 #define DEBUG_TYPE "flang-abstract-result-opt"
22 
23 namespace fir {
24 namespace {
25 
26 struct AbstractResultOptions {
27   // Always pass result as a fir.box argument.
28   bool boxResult = false;
29   // New function block argument for the result if the current FuncOp had
30   // an abstract result.
31   mlir::Value newArg;
32 };
33 
34 static bool mustConvertCallOrFunc(mlir::FunctionType type) {
35   if (type.getNumResults() == 0)
36     return false;
37   auto resultType = type.getResult(0);
38   return resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>();
39 }
40 
41 static mlir::Type getResultArgumentType(mlir::Type resultType,
42                                         const AbstractResultOptions &options) {
43   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
44       .Case<fir::SequenceType, fir::RecordType>(
45           [&](mlir::Type type) -> mlir::Type {
46             if (options.boxResult)
47               return fir::BoxType::get(type);
48             return fir::ReferenceType::get(type);
49           })
50       .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
51         return fir::ReferenceType::get(type);
52       })
53       .Default([](mlir::Type) -> mlir::Type {
54         llvm_unreachable("bad abstract result type");
55       });
56 }
57 
58 static mlir::FunctionType
59 getNewFunctionType(mlir::FunctionType funcTy,
60                    const AbstractResultOptions &options) {
61   auto resultType = funcTy.getResult(0);
62   auto argTy = getResultArgumentType(resultType, options);
63   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
64   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
65   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
66                                  /*resultTypes=*/{});
67 }
68 
69 static bool mustEmboxResult(mlir::Type resultType,
70                             const AbstractResultOptions &options) {
71   return resultType.isa<fir::SequenceType, fir::RecordType>() &&
72          options.boxResult;
73 }
74 
75 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
76 public:
77   using OpRewritePattern::OpRewritePattern;
78   CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt)
79       : OpRewritePattern(context), options{opt} {}
80   mlir::LogicalResult
81   matchAndRewrite(fir::CallOp callOp,
82                   mlir::PatternRewriter &rewriter) const override {
83     auto loc = callOp.getLoc();
84     auto result = callOp->getResult(0);
85     if (!result.hasOneUse()) {
86       mlir::emitError(loc,
87                       "calls with abstract result must have exactly one user");
88       return mlir::failure();
89     }
90     auto saveResult =
91         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
92     if (!saveResult) {
93       mlir::emitError(
94           loc, "calls with abstract result must be used in fir.save_result");
95       return mlir::failure();
96     }
97     auto argType = getResultArgumentType(result.getType(), options);
98     auto buffer = saveResult.getMemref();
99     mlir::Value arg = buffer;
100     if (mustEmboxResult(result.getType(), options))
101       arg = rewriter.create<fir::EmboxOp>(
102           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
103           saveResult.getTypeparams());
104 
105     llvm::SmallVector<mlir::Type> newResultTypes;
106     if (callOp.getCallee()) {
107       llvm::SmallVector<mlir::Value> newOperands = {arg};
108       newOperands.append(callOp.getOperands().begin(),
109                          callOp.getOperands().end());
110       rewriter.create<fir::CallOp>(loc, callOp.getCallee().getValue(),
111                                    newResultTypes, newOperands);
112     } else {
113       // Indirect calls.
114       llvm::SmallVector<mlir::Type> newInputTypes = {argType};
115       for (auto operand : callOp.getOperands().drop_front())
116         newInputTypes.push_back(operand.getType());
117       auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes,
118                                            newResultTypes);
119 
120       llvm::SmallVector<mlir::Value> newOperands;
121       newOperands.push_back(
122           rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0)));
123       newOperands.push_back(arg);
124       newOperands.append(callOp.getOperands().begin() + 1,
125                          callOp.getOperands().end());
126       rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes,
127                                    newOperands);
128     }
129     callOp->dropAllReferences();
130     rewriter.eraseOp(callOp);
131     return mlir::success();
132   }
133 
134 private:
135   const AbstractResultOptions &options;
136 };
137 
138 class SaveResultOpConversion
139     : public mlir::OpRewritePattern<fir::SaveResultOp> {
140 public:
141   using OpRewritePattern::OpRewritePattern;
142   SaveResultOpConversion(mlir::MLIRContext *context)
143       : OpRewritePattern(context) {}
144   mlir::LogicalResult
145   matchAndRewrite(fir::SaveResultOp op,
146                   mlir::PatternRewriter &rewriter) const override {
147     rewriter.eraseOp(op);
148     return mlir::success();
149   }
150 };
151 
152 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::ReturnOp> {
153 public:
154   using OpRewritePattern::OpRewritePattern;
155   ReturnOpConversion(mlir::MLIRContext *context,
156                      const AbstractResultOptions &opt)
157       : OpRewritePattern(context), options{opt} {}
158   mlir::LogicalResult
159   matchAndRewrite(mlir::ReturnOp ret,
160                   mlir::PatternRewriter &rewriter) const override {
161     rewriter.setInsertionPoint(ret);
162     auto returnedValue = ret.getOperand(0);
163     bool replacedStorage = false;
164     if (auto *op = returnedValue.getDefiningOp())
165       if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
166         auto resultStorage = load.getMemref();
167         load.getMemref().replaceAllUsesWith(options.newArg);
168         replacedStorage = true;
169         if (auto *alloc = resultStorage.getDefiningOp())
170           if (alloc->use_empty())
171             rewriter.eraseOp(alloc);
172       }
173     // The result storage may have been optimized out by a memory to
174     // register pass, this is possible for fir.box results, or fir.record
175     // with no length parameters. Simply store the result in the result storage.
176     // at the return point.
177     if (!replacedStorage)
178       rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue,
179                                     options.newArg);
180     rewriter.replaceOpWithNewOp<mlir::ReturnOp>(ret);
181     return mlir::success();
182   }
183 
184 private:
185   const AbstractResultOptions &options;
186 };
187 
188 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
189 public:
190   using OpRewritePattern::OpRewritePattern;
191   AddrOfOpConversion(mlir::MLIRContext *context,
192                      const AbstractResultOptions &opt)
193       : OpRewritePattern(context), options{opt} {}
194   mlir::LogicalResult
195   matchAndRewrite(fir::AddrOfOp addrOf,
196                   mlir::PatternRewriter &rewriter) const override {
197     auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
198     auto newFuncTy = getNewFunctionType(oldFuncTy, options);
199     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
200                                                     addrOf.getSymbol());
201     // Rather than converting all op a function pointer might transit through
202     // (e.g calls, stores, loads, converts...), cast new type to the abstract
203     // type. A conversion will be added when calling indirect calls of abstract
204     // types.
205     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
206     return mlir::success();
207   }
208 
209 private:
210   const AbstractResultOptions &options;
211 };
212 
213 class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
214 public:
215   void runOnOperation() override {
216     auto *context = &getContext();
217     auto func = getOperation();
218     auto loc = func.getLoc();
219     mlir::RewritePatternSet patterns(context);
220     mlir::ConversionTarget target = *context;
221     AbstractResultOptions options{passResultAsBox.getValue(),
222                                   /*newArg=*/{}};
223 
224     // Convert function type itself if it has an abstract result
225     auto funcTy = func.getType().cast<mlir::FunctionType>();
226     if (mustConvertCallOrFunc(funcTy)) {
227       func.setType(getNewFunctionType(funcTy, options));
228       unsigned zero = 0;
229       if (!func.empty()) {
230         // Insert new argument
231         mlir::OpBuilder rewriter(context);
232         auto resultType = funcTy.getResult(0);
233         auto argTy = getResultArgumentType(resultType, options);
234         options.newArg = func.front().insertArgument(zero, argTy, loc);
235         if (mustEmboxResult(resultType, options)) {
236           auto bufferType = fir::ReferenceType::get(resultType);
237           rewriter.setInsertionPointToStart(&func.front());
238           options.newArg =
239               rewriter.create<fir::BoxAddrOp>(loc, bufferType, options.newArg);
240         }
241         patterns.insert<ReturnOpConversion>(context, options);
242         target.addDynamicallyLegalOp<mlir::ReturnOp>(
243             [](mlir::ReturnOp ret) { return ret.operands().empty(); });
244       }
245     }
246 
247     if (func.empty())
248       return;
249 
250     // Convert the calls and, if needed,  the ReturnOp in the function body.
251     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect,
252                            mlir::StandardOpsDialect>();
253     target.addIllegalOp<fir::SaveResultOp>();
254     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
255       return !mustConvertCallOrFunc(call.getFunctionType());
256     });
257     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
258       if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
259         return !mustConvertCallOrFunc(funTy);
260       return true;
261     });
262     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
263       if (dispatch->getNumResults() != 1)
264         return true;
265       auto resultType = dispatch->getResult(0).getType();
266       if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
267         mlir::emitError(dispatch.getLoc(),
268                         "TODO: dispatchOp with abstract results");
269         return false;
270       }
271       return true;
272     });
273 
274     patterns.insert<CallOpConversion>(context, options);
275     patterns.insert<SaveResultOpConversion>(context);
276     patterns.insert<AddrOfOpConversion>(context, options);
277     if (mlir::failed(
278             mlir::applyPartialConversion(func, target, std::move(patterns)))) {
279       mlir::emitError(func.getLoc(), "error in converting abstract results\n");
280       signalPassFailure();
281     }
282   }
283 };
284 } // end anonymous namespace
285 } // namespace fir
286 
287 std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
288   return std::make_unique<AbstractResultOpt>();
289 }
290