xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision c203850ad55de6e1396d5735e4d9b56b66db9220)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/Passes.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace fir {
24 #define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
25 #define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
27 } // namespace fir
28 
29 #define DEBUG_TYPE "flang-abstract-result-opt"
30 
31 using namespace mlir;
32 
33 namespace fir {
34 namespace {
35 
36 static mlir::Type getResultArgumentType(mlir::Type resultType,
37                                         bool shouldBoxResult) {
38   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
39       .Case<fir::SequenceType, fir::RecordType>(
40           [&](mlir::Type type) -> mlir::Type {
41             if (shouldBoxResult)
42               return fir::BoxType::get(type);
43             return fir::ReferenceType::get(type);
44           })
45       .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
46         return fir::ReferenceType::get(type);
47       })
48       .Default([](mlir::Type) -> mlir::Type {
49         llvm_unreachable("bad abstract result type");
50       });
51 }
52 
53 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
54                                              bool shouldBoxResult) {
55   auto resultType = funcTy.getResult(0);
56   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
57   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
58   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
59   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
60                                  /*resultTypes=*/{});
61 }
62 
63 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
64 /// Follow the ABI for interoperability with C.
65 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
66   auto resultType = funcTy.getResult(0);
67   assert(fir::isa_builtin_cptr_type(resultType));
68   llvm::SmallVector<mlir::Type> outputTypes;
69   auto recTy = resultType.dyn_cast<fir::RecordType>();
70   outputTypes.emplace_back(recTy.getTypeList()[0].second);
71   return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
72                                  outputTypes);
73 }
74 
75 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
76   return resultType.isa<fir::SequenceType, fir::RecordType>() &&
77          shouldBoxResult;
78 }
79 
80 template <typename Op>
81 class CallConversion : public mlir::OpRewritePattern<Op> {
82 public:
83   using mlir::OpRewritePattern<Op>::OpRewritePattern;
84 
85   CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
86       : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
87 
88   mlir::LogicalResult
89   matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
90     auto loc = op.getLoc();
91     auto result = op->getResult(0);
92     if (!result.hasOneUse()) {
93       mlir::emitError(loc,
94                       "calls with abstract result must have exactly one user");
95       return mlir::failure();
96     }
97     auto saveResult =
98         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
99     if (!saveResult) {
100       mlir::emitError(
101           loc, "calls with abstract result must be used in fir.save_result");
102       return mlir::failure();
103     }
104     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
105     auto buffer = saveResult.getMemref();
106     mlir::Value arg = buffer;
107     if (mustEmboxResult(result.getType(), shouldBoxResult))
108       arg = rewriter.create<fir::EmboxOp>(
109           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
110           saveResult.getTypeparams());
111 
112     llvm::SmallVector<mlir::Type> newResultTypes;
113     // TODO: This should be generalized for derived types, and it is
114     // architecture and OS dependent.
115     bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
116     Op newOp;
117     if (isResultBuiltinCPtr) {
118       auto recTy = result.getType().template dyn_cast<fir::RecordType>();
119       newResultTypes.emplace_back(recTy.getTypeList()[0].second);
120     }
121 
122     // fir::CallOp specific handling.
123     if constexpr (std::is_same_v<Op, fir::CallOp>) {
124       if (op.getCallee()) {
125         llvm::SmallVector<mlir::Value> newOperands;
126         if (!isResultBuiltinCPtr)
127           newOperands.emplace_back(arg);
128         newOperands.append(op.getOperands().begin(), op.getOperands().end());
129         newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
130                                              newResultTypes, newOperands);
131       } else {
132         // Indirect calls.
133         llvm::SmallVector<mlir::Type> newInputTypes;
134         if (!isResultBuiltinCPtr)
135           newInputTypes.emplace_back(argType);
136         for (auto operand : op.getOperands().drop_front())
137           newInputTypes.push_back(operand.getType());
138         auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
139                                                  newResultTypes);
140 
141         llvm::SmallVector<mlir::Value> newOperands;
142         newOperands.push_back(
143             rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
144         if (!isResultBuiltinCPtr)
145           newOperands.push_back(arg);
146         newOperands.append(op.getOperands().begin() + 1,
147                            op.getOperands().end());
148         newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
149                                              newResultTypes, newOperands);
150       }
151     }
152 
153     // fir::DispatchOp specific handling.
154     if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
155       llvm::SmallVector<mlir::Value> newOperands;
156       if (!isResultBuiltinCPtr)
157         newOperands.emplace_back(arg);
158       unsigned passArgShift = newOperands.size();
159       newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
160 
161       fir::DispatchOp newDispatchOp;
162       if (op.getPassArgPos())
163         newOp = rewriter.create<fir::DispatchOp>(
164             loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
165             op.getOperands()[0], newOperands,
166             rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
167       else
168         newOp = rewriter.create<fir::DispatchOp>(
169             loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
170             op.getOperands()[0], newOperands, nullptr);
171     }
172 
173     if (isResultBuiltinCPtr) {
174       mlir::Value save = saveResult.getMemref();
175       auto module = op->template getParentOfType<mlir::ModuleOp>();
176       fir::KindMapping kindMap = fir::getKindMapping(module);
177       FirOpBuilder builder(rewriter, kindMap);
178       mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
179           builder, loc, save, result.getType());
180       rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
181     }
182     op->dropAllReferences();
183     rewriter.eraseOp(op);
184     return mlir::success();
185   }
186 
187 private:
188   bool shouldBoxResult;
189 };
190 
191 class SaveResultOpConversion
192     : public mlir::OpRewritePattern<fir::SaveResultOp> {
193 public:
194   using OpRewritePattern::OpRewritePattern;
195   SaveResultOpConversion(mlir::MLIRContext *context)
196       : OpRewritePattern(context) {}
197   mlir::LogicalResult
198   matchAndRewrite(fir::SaveResultOp op,
199                   mlir::PatternRewriter &rewriter) const override {
200     rewriter.eraseOp(op);
201     return mlir::success();
202   }
203 };
204 
205 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
206 public:
207   using OpRewritePattern::OpRewritePattern;
208   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
209       : OpRewritePattern(context), newArg{newArg} {}
210   mlir::LogicalResult
211   matchAndRewrite(mlir::func::ReturnOp ret,
212                   mlir::PatternRewriter &rewriter) const override {
213     auto loc = ret.getLoc();
214     rewriter.setInsertionPoint(ret);
215     auto returnedValue = ret.getOperand(0);
216     bool replacedStorage = false;
217     if (auto *op = returnedValue.getDefiningOp())
218       if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
219         auto resultStorage = load.getMemref();
220         // The result alloca may be behind a fir.declare, if any.
221         if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>(
222                 resultStorage.getDefiningOp()))
223           resultStorage = declare.getMemref();
224         // TODO: This should be generalized for derived types, and it is
225         // architecture and OS dependent.
226         if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
227           rewriter.eraseOp(load);
228           auto module = ret->getParentOfType<mlir::ModuleOp>();
229           fir::KindMapping kindMap = fir::getKindMapping(module);
230           FirOpBuilder builder(rewriter, kindMap);
231           mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
232               builder, loc, resultStorage, returnedValue.getType());
233           mlir::Value retValue = rewriter.create<fir::LoadOp>(
234               loc, fir::unwrapRefType(retAddr.getType()), retAddr);
235           rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
236               ret, mlir::ValueRange{retValue});
237           return mlir::success();
238         }
239         resultStorage.replaceAllUsesWith(newArg);
240         replacedStorage = true;
241         if (auto *alloc = resultStorage.getDefiningOp())
242           if (alloc->use_empty())
243             rewriter.eraseOp(alloc);
244       }
245     // The result storage may have been optimized out by a memory to
246     // register pass, this is possible for fir.box results, or fir.record
247     // with no length parameters. Simply store the result in the result storage.
248     // at the return point.
249     if (!replacedStorage)
250       rewriter.create<fir::StoreOp>(loc, returnedValue, newArg);
251     rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
252     return mlir::success();
253   }
254 
255 private:
256   mlir::Value newArg;
257 };
258 
259 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
260 public:
261   using OpRewritePattern::OpRewritePattern;
262   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
263       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
264   mlir::LogicalResult
265   matchAndRewrite(fir::AddrOfOp addrOf,
266                   mlir::PatternRewriter &rewriter) const override {
267     auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
268     mlir::FunctionType newFuncTy;
269     // TODO: This should be generalized for derived types, and it is
270     // architecture and OS dependent.
271     if (oldFuncTy.getNumResults() != 0 &&
272         fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
273       newFuncTy = getCPtrFunctionType(oldFuncTy);
274     else
275       newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
276     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
277                                                     addrOf.getSymbol());
278     // Rather than converting all op a function pointer might transit through
279     // (e.g calls, stores, loads, converts...), cast new type to the abstract
280     // type. A conversion will be added when calling indirect calls of abstract
281     // types.
282     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
283     return mlir::success();
284   }
285 
286 private:
287   bool shouldBoxResult;
288 };
289 
290 /// @brief Base CRTP class for AbstractResult pass family.
291 /// Contains common logic for abstract result conversion in a reusable fashion.
292 /// @tparam Pass target class that implements operation-specific logic.
293 /// @tparam PassBase base class template for the pass generated by TableGen.
294 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
295 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
296 /// This function should implement operation-specific functionality.
297 template <typename Pass, template <typename> class PassBase>
298 class AbstractResultOptTemplate : public PassBase<Pass> {
299 public:
300   void runOnOperation() override {
301     auto *context = &this->getContext();
302     auto op = this->getOperation();
303 
304     mlir::RewritePatternSet patterns(context);
305     mlir::ConversionTarget target = *context;
306     const bool shouldBoxResult = this->passResultAsBox.getValue();
307 
308     auto &self = static_cast<Pass &>(*this);
309     self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
310 
311     // Convert the calls and, if needed,  the ReturnOp in the function body.
312     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
313                            mlir::func::FuncDialect>();
314     target.addIllegalOp<fir::SaveResultOp>();
315     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
316       return !hasAbstractResult(call.getFunctionType());
317     });
318     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
319       if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
320         return !hasAbstractResult(funTy);
321       return true;
322     });
323     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
324       return !hasAbstractResult(dispatch.getFunctionType());
325     });
326 
327     patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
328     patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
329     patterns.insert<SaveResultOpConversion>(context);
330     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
331     if (mlir::failed(
332             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
333       mlir::emitError(op.getLoc(), "error in converting abstract results\n");
334       this->signalPassFailure();
335     }
336   }
337 };
338 
339 class AbstractResultOnFuncOpt
340     : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
341                                        fir::impl::AbstractResultOnFuncOptBase> {
342 public:
343   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
344                               mlir::RewritePatternSet &patterns,
345                               mlir::ConversionTarget &target) {
346     auto loc = func.getLoc();
347     auto *context = &getContext();
348     // Convert function type itself if it has an abstract result.
349     auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
350     if (hasAbstractResult(funcTy)) {
351       // TODO: This should be generalized for derived types, and it is
352       // architecture and OS dependent.
353       if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
354         func.setType(getCPtrFunctionType(funcTy));
355         patterns.insert<ReturnOpConversion>(context, mlir::Value{});
356         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
357             [](mlir::func::ReturnOp ret) {
358               mlir::Type retTy = ret.getOperand(0).getType();
359               return !fir::isa_builtin_cptr_type(retTy);
360             });
361         return;
362       }
363       if (!func.empty()) {
364         // Insert new argument.
365         mlir::OpBuilder rewriter(context);
366         auto resultType = funcTy.getResult(0);
367         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
368         func.insertArgument(0u, argTy, {}, loc);
369         func.eraseResult(0u);
370         mlir::Value newArg = func.getArgument(0u);
371         if (mustEmboxResult(resultType, shouldBoxResult)) {
372           auto bufferType = fir::ReferenceType::get(resultType);
373           rewriter.setInsertionPointToStart(&func.front());
374           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
375         }
376         patterns.insert<ReturnOpConversion>(context, newArg);
377         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
378             [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
379         assert(func.getFunctionType() ==
380                getNewFunctionType(funcTy, shouldBoxResult));
381       } else {
382         llvm::SmallVector<mlir::DictionaryAttr> allArgs;
383         func.getAllArgAttrs(allArgs);
384         allArgs.insert(allArgs.begin(),
385                        mlir::DictionaryAttr::get(func->getContext()));
386         func.setType(getNewFunctionType(funcTy, shouldBoxResult));
387         func.setAllArgAttrs(allArgs);
388       }
389     }
390   }
391 };
392 
393 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
394   return mlir::TypeSwitch<mlir::Type, bool>(type)
395       .Case([](fir::BoxProcType boxProc) {
396         return fir::hasAbstractResult(
397             boxProc.getEleTy().cast<mlir::FunctionType>());
398       })
399       .Case([](fir::PointerType pointer) {
400         return fir::hasAbstractResult(
401             pointer.getEleTy().cast<mlir::FunctionType>());
402       })
403       .Default([](auto &&) { return false; });
404 }
405 
406 class AbstractResultOnGlobalOpt
407     : public AbstractResultOptTemplate<
408           AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
409 public:
410   void runOnSpecificOperation(fir::GlobalOp global, bool,
411                               mlir::RewritePatternSet &,
412                               mlir::ConversionTarget &) {
413     if (containsFunctionTypeWithAbstractResult(global.getType())) {
414       TODO(global->getLoc(), "support for procedure pointers");
415     }
416   }
417 };
418 } // end anonymous namespace
419 } // namespace fir
420 
421 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
422   return std::make_unique<AbstractResultOnFuncOpt>();
423 }
424 
425 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
426   return std::make_unique<AbstractResultOnGlobalOpt>();
427 }
428