xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision 6c8d8d10455a3d38a92102ae3a94e1c06d3eb363)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/Todo.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Dialect/FIRType.h"
13 #include "flang/Optimizer/Transforms/Passes.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/Passes.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 namespace fir {
22 #define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
23 #define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
24 #include "flang/Optimizer/Transforms/Passes.h.inc"
25 } // namespace fir
26 
27 #define DEBUG_TYPE "flang-abstract-result-opt"
28 
29 namespace fir {
30 namespace {
31 
32 static mlir::Type getResultArgumentType(mlir::Type resultType,
33                                         bool shouldBoxResult) {
34   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
35       .Case<fir::SequenceType, fir::RecordType>(
36           [&](mlir::Type type) -> mlir::Type {
37             if (shouldBoxResult)
38               return fir::BoxType::get(type);
39             return fir::ReferenceType::get(type);
40           })
41       .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
42         return fir::ReferenceType::get(type);
43       })
44       .Default([](mlir::Type) -> mlir::Type {
45         llvm_unreachable("bad abstract result type");
46       });
47 }
48 
49 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
50                                              bool shouldBoxResult) {
51   auto resultType = funcTy.getResult(0);
52   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
53   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
54   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
55   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
56                                  /*resultTypes=*/{});
57 }
58 
59 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
60   return resultType.isa<fir::SequenceType, fir::RecordType>() &&
61          shouldBoxResult;
62 }
63 
64 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
65 public:
66   using OpRewritePattern::OpRewritePattern;
67   CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
68       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
69   mlir::LogicalResult
70   matchAndRewrite(fir::CallOp callOp,
71                   mlir::PatternRewriter &rewriter) const override {
72     auto loc = callOp.getLoc();
73     auto result = callOp->getResult(0);
74     if (!result.hasOneUse()) {
75       mlir::emitError(loc,
76                       "calls with abstract result must have exactly one user");
77       return mlir::failure();
78     }
79     auto saveResult =
80         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
81     if (!saveResult) {
82       mlir::emitError(
83           loc, "calls with abstract result must be used in fir.save_result");
84       return mlir::failure();
85     }
86     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
87     auto buffer = saveResult.getMemref();
88     mlir::Value arg = buffer;
89     if (mustEmboxResult(result.getType(), shouldBoxResult))
90       arg = rewriter.create<fir::EmboxOp>(
91           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
92           saveResult.getTypeparams());
93 
94     llvm::SmallVector<mlir::Type> newResultTypes;
95     if (callOp.getCallee()) {
96       llvm::SmallVector<mlir::Value> newOperands = {arg};
97       newOperands.append(callOp.getOperands().begin(),
98                          callOp.getOperands().end());
99       rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), newResultTypes,
100                                    newOperands);
101     } else {
102       // Indirect calls.
103       llvm::SmallVector<mlir::Type> newInputTypes = {argType};
104       for (auto operand : callOp.getOperands().drop_front())
105         newInputTypes.push_back(operand.getType());
106       auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes,
107                                            newResultTypes);
108 
109       llvm::SmallVector<mlir::Value> newOperands;
110       newOperands.push_back(
111           rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0)));
112       newOperands.push_back(arg);
113       newOperands.append(callOp.getOperands().begin() + 1,
114                          callOp.getOperands().end());
115       rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes,
116                                    newOperands);
117     }
118     callOp->dropAllReferences();
119     rewriter.eraseOp(callOp);
120     return mlir::success();
121   }
122 
123 private:
124   bool shouldBoxResult;
125 };
126 
127 class SaveResultOpConversion
128     : public mlir::OpRewritePattern<fir::SaveResultOp> {
129 public:
130   using OpRewritePattern::OpRewritePattern;
131   SaveResultOpConversion(mlir::MLIRContext *context)
132       : OpRewritePattern(context) {}
133   mlir::LogicalResult
134   matchAndRewrite(fir::SaveResultOp op,
135                   mlir::PatternRewriter &rewriter) const override {
136     rewriter.eraseOp(op);
137     return mlir::success();
138   }
139 };
140 
141 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
142 public:
143   using OpRewritePattern::OpRewritePattern;
144   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
145       : OpRewritePattern(context), newArg{newArg} {}
146   mlir::LogicalResult
147   matchAndRewrite(mlir::func::ReturnOp ret,
148                   mlir::PatternRewriter &rewriter) const override {
149     rewriter.setInsertionPoint(ret);
150     auto returnedValue = ret.getOperand(0);
151     bool replacedStorage = false;
152     if (auto *op = returnedValue.getDefiningOp())
153       if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
154         auto resultStorage = load.getMemref();
155         load.getMemref().replaceAllUsesWith(newArg);
156         replacedStorage = true;
157         if (auto *alloc = resultStorage.getDefiningOp())
158           if (alloc->use_empty())
159             rewriter.eraseOp(alloc);
160       }
161     // The result storage may have been optimized out by a memory to
162     // register pass, this is possible for fir.box results, or fir.record
163     // with no length parameters. Simply store the result in the result storage.
164     // at the return point.
165     if (!replacedStorage)
166       rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, newArg);
167     rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
168     return mlir::success();
169   }
170 
171 private:
172   mlir::Value newArg;
173 };
174 
175 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
176 public:
177   using OpRewritePattern::OpRewritePattern;
178   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
179       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
180   mlir::LogicalResult
181   matchAndRewrite(fir::AddrOfOp addrOf,
182                   mlir::PatternRewriter &rewriter) const override {
183     auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
184     auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
185     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
186                                                     addrOf.getSymbol());
187     // Rather than converting all op a function pointer might transit through
188     // (e.g calls, stores, loads, converts...), cast new type to the abstract
189     // type. A conversion will be added when calling indirect calls of abstract
190     // types.
191     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
192     return mlir::success();
193   }
194 
195 private:
196   bool shouldBoxResult;
197 };
198 
199 /// @brief Base CRTP class for AbstractResult pass family.
200 /// Contains common logic for abstract result conversion in a reusable fashion.
201 /// @tparam Pass target class that implements operation-specific logic.
202 /// @tparam PassBase base class template for the pass generated by TableGen.
203 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
204 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
205 /// This function should implement operation-specific functionality.
206 template <typename Pass, template <typename> class PassBase>
207 class AbstractResultOptTemplate : public PassBase<Pass> {
208 public:
209   void runOnOperation() override {
210     auto *context = &this->getContext();
211     auto op = this->getOperation();
212 
213     mlir::RewritePatternSet patterns(context);
214     mlir::ConversionTarget target = *context;
215     const bool shouldBoxResult = this->passResultAsBox.getValue();
216 
217     auto &self = static_cast<Pass &>(*this);
218     self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
219 
220     // Convert the calls and, if needed,  the ReturnOp in the function body.
221     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
222                            mlir::func::FuncDialect>();
223     target.addIllegalOp<fir::SaveResultOp>();
224     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
225       return !hasAbstractResult(call.getFunctionType());
226     });
227     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
228       if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
229         return !hasAbstractResult(funTy);
230       return true;
231     });
232     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
233       if (dispatch->getNumResults() != 1)
234         return true;
235       auto resultType = dispatch->getResult(0).getType();
236       if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
237         TODO(dispatch.getLoc(), "dispatchOp with abstract results");
238         return false;
239       }
240       return true;
241     });
242 
243     patterns.insert<CallOpConversion>(context, shouldBoxResult);
244     patterns.insert<SaveResultOpConversion>(context);
245     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
246     if (mlir::failed(
247             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
248       mlir::emitError(op.getLoc(), "error in converting abstract results\n");
249       this->signalPassFailure();
250     }
251   }
252 };
253 
254 class AbstractResultOnFuncOpt
255     : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
256                                        fir::impl::AbstractResultOnFuncOptBase> {
257 public:
258   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
259                               mlir::RewritePatternSet &patterns,
260                               mlir::ConversionTarget &target) {
261     auto loc = func.getLoc();
262     auto *context = &getContext();
263     // Convert function type itself if it has an abstract result.
264     auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
265     if (hasAbstractResult(funcTy)) {
266       func.setType(getNewFunctionType(funcTy, shouldBoxResult));
267       if (!func.empty()) {
268         // Insert new argument.
269         mlir::OpBuilder rewriter(context);
270         auto resultType = funcTy.getResult(0);
271         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
272         mlir::Value newArg = func.front().insertArgument(0u, argTy, loc);
273         if (mustEmboxResult(resultType, shouldBoxResult)) {
274           auto bufferType = fir::ReferenceType::get(resultType);
275           rewriter.setInsertionPointToStart(&func.front());
276           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
277         }
278         patterns.insert<ReturnOpConversion>(context, newArg);
279         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
280             [](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
281       }
282     }
283   }
284 };
285 
286 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
287   return mlir::TypeSwitch<mlir::Type, bool>(type)
288       .Case([](fir::BoxProcType boxProc) {
289         return fir::hasAbstractResult(
290             boxProc.getEleTy().cast<mlir::FunctionType>());
291       })
292       .Case([](fir::PointerType pointer) {
293         return fir::hasAbstractResult(
294             pointer.getEleTy().cast<mlir::FunctionType>());
295       })
296       .Default([](auto &&) { return false; });
297 }
298 
299 class AbstractResultOnGlobalOpt
300     : public AbstractResultOptTemplate<
301           AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
302 public:
303   void runOnSpecificOperation(fir::GlobalOp global, bool,
304                               mlir::RewritePatternSet &,
305                               mlir::ConversionTarget &) {
306     if (containsFunctionTypeWithAbstractResult(global.getType())) {
307       TODO(global->getLoc(), "support for procedure pointers");
308     }
309   }
310 };
311 } // end anonymous namespace
312 } // namespace fir
313 
314 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
315   return std::make_unique<AbstractResultOnFuncOpt>();
316 }
317 
318 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
319   return std::make_unique<AbstractResultOnGlobalOpt>();
320 }
321