xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision afb34cf3077a38007fcebe17dc384532207283fa)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/Passes.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace fir {
24 #define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
25 #define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
27 } // namespace fir
28 
29 #define DEBUG_TYPE "flang-abstract-result-opt"
30 
31 using namespace mlir;
32 
33 namespace fir {
34 namespace {
35 
36 static mlir::Type getResultArgumentType(mlir::Type resultType,
37                                         bool shouldBoxResult) {
38   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
39       .Case<fir::SequenceType, fir::RecordType>(
40           [&](mlir::Type type) -> mlir::Type {
41             if (shouldBoxResult)
42               return fir::BoxType::get(type);
43             return fir::ReferenceType::get(type);
44           })
45       .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
46         return fir::ReferenceType::get(type);
47       })
48       .Default([](mlir::Type) -> mlir::Type {
49         llvm_unreachable("bad abstract result type");
50       });
51 }
52 
53 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
54                                              bool shouldBoxResult) {
55   auto resultType = funcTy.getResult(0);
56   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
57   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
58   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
59   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
60                                  /*resultTypes=*/{});
61 }
62 
63 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
64 /// Follow the ABI for interoperability with C.
65 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
66   auto resultType = funcTy.getResult(0);
67   assert(fir::isa_builtin_cptr_type(resultType));
68   llvm::SmallVector<mlir::Type> outputTypes;
69   auto recTy = resultType.dyn_cast<fir::RecordType>();
70   outputTypes.emplace_back(recTy.getTypeList()[0].second);
71   return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
72                                  outputTypes);
73 }
74 
75 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
76   return resultType.isa<fir::SequenceType, fir::RecordType>() &&
77          shouldBoxResult;
78 }
79 
80 template <typename Op>
81 class CallConversion : public mlir::OpRewritePattern<Op> {
82 public:
83   using mlir::OpRewritePattern<Op>::OpRewritePattern;
84 
85   CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
86       : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
87 
88   mlir::LogicalResult
89   matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
90     auto loc = op.getLoc();
91     auto result = op->getResult(0);
92     if (!result.hasOneUse()) {
93       mlir::emitError(loc,
94                       "calls with abstract result must have exactly one user");
95       return mlir::failure();
96     }
97     auto saveResult =
98         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
99     if (!saveResult) {
100       mlir::emitError(
101           loc, "calls with abstract result must be used in fir.save_result");
102       return mlir::failure();
103     }
104     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
105     auto buffer = saveResult.getMemref();
106     mlir::Value arg = buffer;
107     if (mustEmboxResult(result.getType(), shouldBoxResult))
108       arg = rewriter.create<fir::EmboxOp>(
109           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
110           saveResult.getTypeparams());
111 
112     llvm::SmallVector<mlir::Type> newResultTypes;
113     // TODO: This should be generalized for derived types, and it is
114     // architecture and OS dependent.
115     bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
116     Op newOp;
117     if (isResultBuiltinCPtr) {
118       auto recTy = result.getType().template dyn_cast<fir::RecordType>();
119       newResultTypes.emplace_back(recTy.getTypeList()[0].second);
120     }
121 
122     // fir::CallOp specific handling.
123     if constexpr (std::is_same_v<Op, fir::CallOp>) {
124       if (op.getCallee()) {
125         llvm::SmallVector<mlir::Value> newOperands;
126         if (!isResultBuiltinCPtr)
127           newOperands.emplace_back(arg);
128         newOperands.append(op.getOperands().begin(), op.getOperands().end());
129         newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
130                                              newResultTypes, newOperands);
131       } else {
132         // Indirect calls.
133         llvm::SmallVector<mlir::Type> newInputTypes;
134         if (!isResultBuiltinCPtr)
135           newInputTypes.emplace_back(argType);
136         for (auto operand : op.getOperands().drop_front())
137           newInputTypes.push_back(operand.getType());
138         auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
139                                                  newResultTypes);
140 
141         llvm::SmallVector<mlir::Value> newOperands;
142         newOperands.push_back(
143             rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
144         if (!isResultBuiltinCPtr)
145           newOperands.push_back(arg);
146         newOperands.append(op.getOperands().begin() + 1,
147                            op.getOperands().end());
148         newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
149                                              newResultTypes, newOperands);
150       }
151     }
152 
153     // fir::DispatchOp specific handling.
154     if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
155       llvm::SmallVector<mlir::Value> newOperands;
156       if (!isResultBuiltinCPtr)
157         newOperands.emplace_back(arg);
158       unsigned passArgShift = newOperands.size();
159       newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
160 
161       fir::DispatchOp newDispatchOp;
162       if (op.getPassArgPos())
163         newOp = rewriter.create<fir::DispatchOp>(
164             loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
165             op.getOperands()[0], newOperands,
166             rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
167       else
168         newOp = rewriter.create<fir::DispatchOp>(
169             loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
170             op.getOperands()[0], newOperands, nullptr);
171     }
172 
173     if (isResultBuiltinCPtr) {
174       mlir::Value save = saveResult.getMemref();
175       auto module = op->template getParentOfType<mlir::ModuleOp>();
176       fir::KindMapping kindMap = fir::getKindMapping(module);
177       FirOpBuilder builder(rewriter, kindMap);
178       mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
179           builder, loc, save, result.getType());
180       rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
181     }
182     op->dropAllReferences();
183     rewriter.eraseOp(op);
184     return mlir::success();
185   }
186 
187 private:
188   bool shouldBoxResult;
189 };
190 
191 class SaveResultOpConversion
192     : public mlir::OpRewritePattern<fir::SaveResultOp> {
193 public:
194   using OpRewritePattern::OpRewritePattern;
195   SaveResultOpConversion(mlir::MLIRContext *context)
196       : OpRewritePattern(context) {}
197   mlir::LogicalResult
198   matchAndRewrite(fir::SaveResultOp op,
199                   mlir::PatternRewriter &rewriter) const override {
200     rewriter.eraseOp(op);
201     return mlir::success();
202   }
203 };
204 
205 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
206 public:
207   using OpRewritePattern::OpRewritePattern;
208   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
209       : OpRewritePattern(context), newArg{newArg} {}
210   mlir::LogicalResult
211   matchAndRewrite(mlir::func::ReturnOp ret,
212                   mlir::PatternRewriter &rewriter) const override {
213     auto loc = ret.getLoc();
214     rewriter.setInsertionPoint(ret);
215     auto returnedValue = ret.getOperand(0);
216     bool replacedStorage = false;
217     if (auto *op = returnedValue.getDefiningOp())
218       if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
219         auto resultStorage = load.getMemref();
220         // TODO: This should be generalized for derived types, and it is
221         // architecture and OS dependent.
222         if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
223           rewriter.eraseOp(load);
224           auto module = ret->getParentOfType<mlir::ModuleOp>();
225           fir::KindMapping kindMap = fir::getKindMapping(module);
226           FirOpBuilder builder(rewriter, kindMap);
227           mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
228               builder, loc, resultStorage, returnedValue.getType());
229           mlir::Value retValue = rewriter.create<fir::LoadOp>(
230               loc, fir::unwrapRefType(retAddr.getType()), retAddr);
231           rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
232               ret, mlir::ValueRange{retValue});
233           return mlir::success();
234         }
235         load.getMemref().replaceAllUsesWith(newArg);
236         replacedStorage = true;
237         if (auto *alloc = resultStorage.getDefiningOp())
238           if (alloc->use_empty())
239             rewriter.eraseOp(alloc);
240       }
241     // The result storage may have been optimized out by a memory to
242     // register pass, this is possible for fir.box results, or fir.record
243     // with no length parameters. Simply store the result in the result storage.
244     // at the return point.
245     if (!replacedStorage)
246       rewriter.create<fir::StoreOp>(loc, returnedValue, newArg);
247     rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
248     return mlir::success();
249   }
250 
251 private:
252   mlir::Value newArg;
253 };
254 
255 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
256 public:
257   using OpRewritePattern::OpRewritePattern;
258   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
259       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
260   mlir::LogicalResult
261   matchAndRewrite(fir::AddrOfOp addrOf,
262                   mlir::PatternRewriter &rewriter) const override {
263     auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
264     mlir::FunctionType newFuncTy;
265     // TODO: This should be generalized for derived types, and it is
266     // architecture and OS dependent.
267     if (oldFuncTy.getNumResults() != 0 &&
268         fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
269       newFuncTy = getCPtrFunctionType(oldFuncTy);
270     else
271       newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
272     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
273                                                     addrOf.getSymbol());
274     // Rather than converting all op a function pointer might transit through
275     // (e.g calls, stores, loads, converts...), cast new type to the abstract
276     // type. A conversion will be added when calling indirect calls of abstract
277     // types.
278     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
279     return mlir::success();
280   }
281 
282 private:
283   bool shouldBoxResult;
284 };
285 
286 /// @brief Base CRTP class for AbstractResult pass family.
287 /// Contains common logic for abstract result conversion in a reusable fashion.
288 /// @tparam Pass target class that implements operation-specific logic.
289 /// @tparam PassBase base class template for the pass generated by TableGen.
290 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
291 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
292 /// This function should implement operation-specific functionality.
293 template <typename Pass, template <typename> class PassBase>
294 class AbstractResultOptTemplate : public PassBase<Pass> {
295 public:
296   void runOnOperation() override {
297     auto *context = &this->getContext();
298     auto op = this->getOperation();
299 
300     mlir::RewritePatternSet patterns(context);
301     mlir::ConversionTarget target = *context;
302     const bool shouldBoxResult = this->passResultAsBox.getValue();
303 
304     auto &self = static_cast<Pass &>(*this);
305     self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
306 
307     // Convert the calls and, if needed,  the ReturnOp in the function body.
308     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
309                            mlir::func::FuncDialect>();
310     target.addIllegalOp<fir::SaveResultOp>();
311     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
312       return !hasAbstractResult(call.getFunctionType());
313     });
314     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
315       if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
316         return !hasAbstractResult(funTy);
317       return true;
318     });
319     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
320       return !hasAbstractResult(dispatch.getFunctionType());
321     });
322 
323     patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
324     patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
325     patterns.insert<SaveResultOpConversion>(context);
326     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
327     if (mlir::failed(
328             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
329       mlir::emitError(op.getLoc(), "error in converting abstract results\n");
330       this->signalPassFailure();
331     }
332   }
333 };
334 
335 class AbstractResultOnFuncOpt
336     : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
337                                        fir::impl::AbstractResultOnFuncOptBase> {
338 public:
339   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
340                               mlir::RewritePatternSet &patterns,
341                               mlir::ConversionTarget &target) {
342     auto loc = func.getLoc();
343     auto *context = &getContext();
344     // Convert function type itself if it has an abstract result.
345     auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
346     if (hasAbstractResult(funcTy)) {
347       // TODO: This should be generalized for derived types, and it is
348       // architecture and OS dependent.
349       if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
350         func.setType(getCPtrFunctionType(funcTy));
351         patterns.insert<ReturnOpConversion>(context, mlir::Value{});
352         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
353             [](mlir::func::ReturnOp ret) {
354               mlir::Type retTy = ret.getOperand(0).getType();
355               return !fir::isa_builtin_cptr_type(retTy);
356             });
357         return;
358       }
359       if (!func.empty()) {
360         // Insert new argument.
361         mlir::OpBuilder rewriter(context);
362         auto resultType = funcTy.getResult(0);
363         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
364         func.insertArgument(0u, argTy, {}, loc);
365         func.eraseResult(0u);
366         mlir::Value newArg = func.getArgument(0u);
367         if (mustEmboxResult(resultType, shouldBoxResult)) {
368           auto bufferType = fir::ReferenceType::get(resultType);
369           rewriter.setInsertionPointToStart(&func.front());
370           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
371         }
372         patterns.insert<ReturnOpConversion>(context, newArg);
373         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
374             [](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
375         assert(func.getFunctionType() ==
376                getNewFunctionType(funcTy, shouldBoxResult));
377       } else {
378         llvm::SmallVector<mlir::DictionaryAttr> allArgs;
379         func.getAllArgAttrs(allArgs);
380         allArgs.insert(allArgs.begin(),
381                        mlir::DictionaryAttr::get(func->getContext()));
382         func.setType(getNewFunctionType(funcTy, shouldBoxResult));
383         func.setAllArgAttrs(allArgs);
384       }
385     }
386   }
387 };
388 
389 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
390   return mlir::TypeSwitch<mlir::Type, bool>(type)
391       .Case([](fir::BoxProcType boxProc) {
392         return fir::hasAbstractResult(
393             boxProc.getEleTy().cast<mlir::FunctionType>());
394       })
395       .Case([](fir::PointerType pointer) {
396         return fir::hasAbstractResult(
397             pointer.getEleTy().cast<mlir::FunctionType>());
398       })
399       .Default([](auto &&) { return false; });
400 }
401 
402 class AbstractResultOnGlobalOpt
403     : public AbstractResultOptTemplate<
404           AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
405 public:
406   void runOnSpecificOperation(fir::GlobalOp global, bool,
407                               mlir::RewritePatternSet &,
408                               mlir::ConversionTarget &) {
409     if (containsFunctionTypeWithAbstractResult(global.getType())) {
410       TODO(global->getLoc(), "support for procedure pointers");
411     }
412   }
413 };
414 } // end anonymous namespace
415 } // namespace fir
416 
417 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
418   return std::make_unique<AbstractResultOnFuncOpt>();
419 }
420 
421 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
422   return std::make_unique<AbstractResultOnGlobalOpt>();
423 }
424