xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision a6f2f44f9ca1540ce81655382015ab4bb78e1c0c)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Transforms/Passes.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "mlir/Transforms/Passes.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 #define DEBUG_TYPE "flang-abstract-result-opt"
23 
24 namespace fir {
25 namespace {
26 
27 static mlir::Type getResultArgumentType(mlir::Type resultType,
28                                         bool shouldBoxResult) {
29   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
30       .Case<fir::SequenceType, fir::RecordType>(
31           [&](mlir::Type type) -> mlir::Type {
32             if (shouldBoxResult)
33               return fir::BoxType::get(type);
34             return fir::ReferenceType::get(type);
35           })
36       .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
37         return fir::ReferenceType::get(type);
38       })
39       .Default([](mlir::Type) -> mlir::Type {
40         llvm_unreachable("bad abstract result type");
41       });
42 }
43 
44 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
45                                              bool shouldBoxResult) {
46   auto resultType = funcTy.getResult(0);
47   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
48   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
49   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
50   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
51                                  /*resultTypes=*/{});
52 }
53 
54 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
55   return resultType.isa<fir::SequenceType, fir::RecordType>() &&
56          shouldBoxResult;
57 }
58 
59 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
60 public:
61   using OpRewritePattern::OpRewritePattern;
62   CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
63       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
64   mlir::LogicalResult
65   matchAndRewrite(fir::CallOp callOp,
66                   mlir::PatternRewriter &rewriter) const override {
67     auto loc = callOp.getLoc();
68     auto result = callOp->getResult(0);
69     if (!result.hasOneUse()) {
70       mlir::emitError(loc,
71                       "calls with abstract result must have exactly one user");
72       return mlir::failure();
73     }
74     auto saveResult =
75         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
76     if (!saveResult) {
77       mlir::emitError(
78           loc, "calls with abstract result must be used in fir.save_result");
79       return mlir::failure();
80     }
81     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
82     auto buffer = saveResult.getMemref();
83     mlir::Value arg = buffer;
84     if (mustEmboxResult(result.getType(), shouldBoxResult))
85       arg = rewriter.create<fir::EmboxOp>(
86           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
87           saveResult.getTypeparams());
88 
89     llvm::SmallVector<mlir::Type> newResultTypes;
90     if (callOp.getCallee()) {
91       llvm::SmallVector<mlir::Value> newOperands = {arg};
92       newOperands.append(callOp.getOperands().begin(),
93                          callOp.getOperands().end());
94       rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), newResultTypes,
95                                    newOperands);
96     } else {
97       // Indirect calls.
98       llvm::SmallVector<mlir::Type> newInputTypes = {argType};
99       for (auto operand : callOp.getOperands().drop_front())
100         newInputTypes.push_back(operand.getType());
101       auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes,
102                                            newResultTypes);
103 
104       llvm::SmallVector<mlir::Value> newOperands;
105       newOperands.push_back(
106           rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0)));
107       newOperands.push_back(arg);
108       newOperands.append(callOp.getOperands().begin() + 1,
109                          callOp.getOperands().end());
110       rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes,
111                                    newOperands);
112     }
113     callOp->dropAllReferences();
114     rewriter.eraseOp(callOp);
115     return mlir::success();
116   }
117 
118 private:
119   bool shouldBoxResult;
120 };
121 
122 class SaveResultOpConversion
123     : public mlir::OpRewritePattern<fir::SaveResultOp> {
124 public:
125   using OpRewritePattern::OpRewritePattern;
126   SaveResultOpConversion(mlir::MLIRContext *context)
127       : OpRewritePattern(context) {}
128   mlir::LogicalResult
129   matchAndRewrite(fir::SaveResultOp op,
130                   mlir::PatternRewriter &rewriter) const override {
131     rewriter.eraseOp(op);
132     return mlir::success();
133   }
134 };
135 
136 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
137 public:
138   using OpRewritePattern::OpRewritePattern;
139   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
140       : OpRewritePattern(context), newArg{newArg} {}
141   mlir::LogicalResult
142   matchAndRewrite(mlir::func::ReturnOp ret,
143                   mlir::PatternRewriter &rewriter) const override {
144     rewriter.setInsertionPoint(ret);
145     auto returnedValue = ret.getOperand(0);
146     bool replacedStorage = false;
147     if (auto *op = returnedValue.getDefiningOp())
148       if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
149         auto resultStorage = load.getMemref();
150         load.getMemref().replaceAllUsesWith(newArg);
151         replacedStorage = true;
152         if (auto *alloc = resultStorage.getDefiningOp())
153           if (alloc->use_empty())
154             rewriter.eraseOp(alloc);
155       }
156     // The result storage may have been optimized out by a memory to
157     // register pass, this is possible for fir.box results, or fir.record
158     // with no length parameters. Simply store the result in the result storage.
159     // at the return point.
160     if (!replacedStorage)
161       rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, newArg);
162     rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
163     return mlir::success();
164   }
165 
166 private:
167   mlir::Value newArg;
168 };
169 
170 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
171 public:
172   using OpRewritePattern::OpRewritePattern;
173   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
174       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
175   mlir::LogicalResult
176   matchAndRewrite(fir::AddrOfOp addrOf,
177                   mlir::PatternRewriter &rewriter) const override {
178     auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
179     auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
180     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
181                                                     addrOf.getSymbol());
182     // Rather than converting all op a function pointer might transit through
183     // (e.g calls, stores, loads, converts...), cast new type to the abstract
184     // type. A conversion will be added when calling indirect calls of abstract
185     // types.
186     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
187     return mlir::success();
188   }
189 
190 private:
191   bool shouldBoxResult;
192 };
193 
194 /// @brief Base CRTP class for AbstractResult pass family.
195 /// Contains common logic for abstract result conversion in a reusable fashion.
196 /// @tparam Pass target class that implements operation-specific logic.
197 /// @tparam PassBase base class template for the pass generated by TableGen.
198 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
199 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
200 /// This function should implement operation-specific functionality.
201 template <typename Pass, template <typename> class PassBase>
202 class AbstractResultOptTemplate : public PassBase<Pass> {
203 public:
204   void runOnOperation() override {
205     auto *context = &this->getContext();
206     auto op = this->getOperation();
207 
208     mlir::RewritePatternSet patterns(context);
209     mlir::ConversionTarget target = *context;
210     const bool shouldBoxResult = this->passResultAsBox.getValue();
211 
212     auto &self = static_cast<Pass &>(*this);
213     self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
214 
215     // Convert the calls and, if needed,  the ReturnOp in the function body.
216     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect,
217                            mlir::func::FuncDialect>();
218     target.addIllegalOp<fir::SaveResultOp>();
219     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
220       return !hasAbstractResult(call.getFunctionType());
221     });
222     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
223       if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
224         return !hasAbstractResult(funTy);
225       return true;
226     });
227     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
228       if (dispatch->getNumResults() != 1)
229         return true;
230       auto resultType = dispatch->getResult(0).getType();
231       if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
232         TODO(dispatch.getLoc(), "dispatchOp with abstract results");
233         return false;
234       }
235       return true;
236     });
237 
238     patterns.insert<CallOpConversion>(context, shouldBoxResult);
239     patterns.insert<SaveResultOpConversion>(context);
240     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
241     if (mlir::failed(
242             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
243       mlir::emitError(op.getLoc(), "error in converting abstract results\n");
244       this->signalPassFailure();
245     }
246   }
247 };
248 
249 class AbstractResultOnFuncOpt
250     : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
251                                        fir::AbstractResultOnFuncOptBase> {
252 public:
253   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
254                               mlir::RewritePatternSet &patterns,
255                               mlir::ConversionTarget &target) {
256     auto loc = func.getLoc();
257     auto *context = &getContext();
258     // Convert function type itself if it has an abstract result.
259     auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
260     if (hasAbstractResult(funcTy)) {
261       func.setType(getNewFunctionType(funcTy, shouldBoxResult));
262       if (!func.empty()) {
263         // Insert new argument.
264         mlir::OpBuilder rewriter(context);
265         auto resultType = funcTy.getResult(0);
266         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
267         mlir::Value newArg = func.front().insertArgument(0u, argTy, loc);
268         if (mustEmboxResult(resultType, shouldBoxResult)) {
269           auto bufferType = fir::ReferenceType::get(resultType);
270           rewriter.setInsertionPointToStart(&func.front());
271           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
272         }
273         patterns.insert<ReturnOpConversion>(context, newArg);
274         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
275             [](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
276       }
277     }
278   }
279 };
280 
281 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
282   return mlir::TypeSwitch<mlir::Type, bool>(type)
283       .Case([](fir::BoxProcType boxProc) {
284         return fir::hasAbstractResult(
285             boxProc.getEleTy().cast<mlir::FunctionType>());
286       })
287       .Case([](fir::PointerType pointer) {
288         return fir::hasAbstractResult(
289             pointer.getEleTy().cast<mlir::FunctionType>());
290       })
291       .Default([](auto &&) { return false; });
292 }
293 
294 class AbstractResultOnGlobalOpt
295     : public AbstractResultOptTemplate<AbstractResultOnGlobalOpt,
296                                        fir::AbstractResultOnGlobalOptBase> {
297 public:
298   void runOnSpecificOperation(fir::GlobalOp global, bool,
299                               mlir::RewritePatternSet &,
300                               mlir::ConversionTarget &) {
301     if (containsFunctionTypeWithAbstractResult(global.getType())) {
302       TODO(global->getLoc(), "support for procedure pointers");
303     }
304   }
305 };
306 } // end anonymous namespace
307 } // namespace fir
308 
309 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
310   return std::make_unique<AbstractResultOnFuncOpt>();
311 }
312 
313 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
314   return std::make_unique<AbstractResultOnGlobalOpt>();
315 }
316