xref: /llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp (revision 4d5a9c1d171b1b86deb2030e37b31ab2ec9d1dd0)
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/Passes.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace fir {
24 #define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
25 #define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
27 } // namespace fir
28 
29 #define DEBUG_TYPE "flang-abstract-result-opt"
30 
31 namespace fir {
32 namespace {
33 
34 static mlir::Type getResultArgumentType(mlir::Type resultType,
35                                         bool shouldBoxResult) {
36   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
37       .Case<fir::SequenceType, fir::RecordType>(
38           [&](mlir::Type type) -> mlir::Type {
39             if (shouldBoxResult)
40               return fir::BoxType::get(type);
41             return fir::ReferenceType::get(type);
42           })
43       .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
44         return fir::ReferenceType::get(type);
45       })
46       .Default([](mlir::Type) -> mlir::Type {
47         llvm_unreachable("bad abstract result type");
48       });
49 }
50 
51 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
52                                              bool shouldBoxResult) {
53   auto resultType = funcTy.getResult(0);
54   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
55   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
56   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
57   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
58                                  /*resultTypes=*/{});
59 }
60 
61 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
62 /// Follow the ABI for interoperability with C.
63 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
64   auto resultType = funcTy.getResult(0);
65   assert(fir::isa_builtin_cptr_type(resultType));
66   llvm::SmallVector<mlir::Type> outputTypes;
67   auto recTy = resultType.dyn_cast<fir::RecordType>();
68   outputTypes.emplace_back(recTy.getTypeList()[0].second);
69   return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
70                                  outputTypes);
71 }
72 
73 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
74   return resultType.isa<fir::SequenceType, fir::RecordType>() &&
75          shouldBoxResult;
76 }
77 
78 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
79 public:
80   using OpRewritePattern::OpRewritePattern;
81   CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
82       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
83   mlir::LogicalResult
84   matchAndRewrite(fir::CallOp callOp,
85                   mlir::PatternRewriter &rewriter) const override {
86     auto loc = callOp.getLoc();
87     auto result = callOp->getResult(0);
88     if (!result.hasOneUse()) {
89       mlir::emitError(loc,
90                       "calls with abstract result must have exactly one user");
91       return mlir::failure();
92     }
93     auto saveResult =
94         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
95     if (!saveResult) {
96       mlir::emitError(
97           loc, "calls with abstract result must be used in fir.save_result");
98       return mlir::failure();
99     }
100     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
101     auto buffer = saveResult.getMemref();
102     mlir::Value arg = buffer;
103     if (mustEmboxResult(result.getType(), shouldBoxResult))
104       arg = rewriter.create<fir::EmboxOp>(
105           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
106           saveResult.getTypeparams());
107 
108     llvm::SmallVector<mlir::Type> newResultTypes;
109     // TODO: This should be generalized for derived types, and it is
110     // architecture and OS dependent.
111     bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
112     fir::CallOp newCallOp;
113     if (isResultBuiltinCPtr) {
114       auto recTy = result.getType().dyn_cast<fir::RecordType>();
115       newResultTypes.emplace_back(recTy.getTypeList()[0].second);
116     }
117     if (callOp.getCallee()) {
118       llvm::SmallVector<mlir::Value> newOperands;
119       if (!isResultBuiltinCPtr)
120         newOperands.emplace_back(arg);
121       newOperands.append(callOp.getOperands().begin(),
122                          callOp.getOperands().end());
123       newCallOp = rewriter.create<fir::CallOp>(loc, *callOp.getCallee(),
124                                                newResultTypes, newOperands);
125     } else {
126       // Indirect calls.
127       llvm::SmallVector<mlir::Type> newInputTypes;
128       if (!isResultBuiltinCPtr)
129         newInputTypes.emplace_back(argType);
130       for (auto operand : callOp.getOperands().drop_front())
131         newInputTypes.push_back(operand.getType());
132       auto newFuncTy = mlir::FunctionType::get(callOp.getContext(),
133                                                newInputTypes, newResultTypes);
134 
135       llvm::SmallVector<mlir::Value> newOperands;
136       newOperands.push_back(rewriter.create<fir::ConvertOp>(
137           loc, newFuncTy, callOp.getOperand(0)));
138       if (!isResultBuiltinCPtr)
139         newOperands.push_back(arg);
140       newOperands.append(callOp.getOperands().begin() + 1,
141                          callOp.getOperands().end());
142       newCallOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
143                                                newResultTypes, newOperands);
144     }
145     if (isResultBuiltinCPtr) {
146       mlir::Value save = saveResult.getMemref();
147       auto module = callOp->getParentOfType<mlir::ModuleOp>();
148       fir::KindMapping kindMap = fir::getKindMapping(module);
149       FirOpBuilder builder(rewriter, kindMap);
150       mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
151           builder, loc, save, result.getType());
152       rewriter.create<fir::StoreOp>(loc, newCallOp->getResult(0), saveAddr);
153     }
154     callOp->dropAllReferences();
155     rewriter.eraseOp(callOp);
156     return mlir::success();
157   }
158 
159 private:
160   bool shouldBoxResult;
161 };
162 
163 class SaveResultOpConversion
164     : public mlir::OpRewritePattern<fir::SaveResultOp> {
165 public:
166   using OpRewritePattern::OpRewritePattern;
167   SaveResultOpConversion(mlir::MLIRContext *context)
168       : OpRewritePattern(context) {}
169   mlir::LogicalResult
170   matchAndRewrite(fir::SaveResultOp op,
171                   mlir::PatternRewriter &rewriter) const override {
172     rewriter.eraseOp(op);
173     return mlir::success();
174   }
175 };
176 
177 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
178 public:
179   using OpRewritePattern::OpRewritePattern;
180   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
181       : OpRewritePattern(context), newArg{newArg} {}
182   mlir::LogicalResult
183   matchAndRewrite(mlir::func::ReturnOp ret,
184                   mlir::PatternRewriter &rewriter) const override {
185     auto loc = ret.getLoc();
186     rewriter.setInsertionPoint(ret);
187     auto returnedValue = ret.getOperand(0);
188     bool replacedStorage = false;
189     if (auto *op = returnedValue.getDefiningOp())
190       if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
191         auto resultStorage = load.getMemref();
192         // TODO: This should be generalized for derived types, and it is
193         // architecture and OS dependent.
194         if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
195           rewriter.eraseOp(load);
196           auto module = ret->getParentOfType<mlir::ModuleOp>();
197           fir::KindMapping kindMap = fir::getKindMapping(module);
198           FirOpBuilder builder(rewriter, kindMap);
199           mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
200               builder, loc, resultStorage, returnedValue.getType());
201           mlir::Value retValue = rewriter.create<fir::LoadOp>(
202               loc, fir::unwrapRefType(retAddr.getType()), retAddr);
203           rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
204               ret, mlir::ValueRange{retValue});
205           return mlir::success();
206         }
207         load.getMemref().replaceAllUsesWith(newArg);
208         replacedStorage = true;
209         if (auto *alloc = resultStorage.getDefiningOp())
210           if (alloc->use_empty())
211             rewriter.eraseOp(alloc);
212       }
213     // The result storage may have been optimized out by a memory to
214     // register pass, this is possible for fir.box results, or fir.record
215     // with no length parameters. Simply store the result in the result storage.
216     // at the return point.
217     if (!replacedStorage)
218       rewriter.create<fir::StoreOp>(loc, returnedValue, newArg);
219     rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
220     return mlir::success();
221   }
222 
223 private:
224   mlir::Value newArg;
225 };
226 
227 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
228 public:
229   using OpRewritePattern::OpRewritePattern;
230   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
231       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
232   mlir::LogicalResult
233   matchAndRewrite(fir::AddrOfOp addrOf,
234                   mlir::PatternRewriter &rewriter) const override {
235     auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
236     mlir::FunctionType newFuncTy;
237     // TODO: This should be generalized for derived types, and it is
238     // architecture and OS dependent.
239     if (oldFuncTy.getNumResults() != 0 &&
240         fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
241       newFuncTy = getCPtrFunctionType(oldFuncTy);
242     else
243       newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
244     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
245                                                     addrOf.getSymbol());
246     // Rather than converting all op a function pointer might transit through
247     // (e.g calls, stores, loads, converts...), cast new type to the abstract
248     // type. A conversion will be added when calling indirect calls of abstract
249     // types.
250     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
251     return mlir::success();
252   }
253 
254 private:
255   bool shouldBoxResult;
256 };
257 
258 /// @brief Base CRTP class for AbstractResult pass family.
259 /// Contains common logic for abstract result conversion in a reusable fashion.
260 /// @tparam Pass target class that implements operation-specific logic.
261 /// @tparam PassBase base class template for the pass generated by TableGen.
262 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
263 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
264 /// This function should implement operation-specific functionality.
265 template <typename Pass, template <typename> class PassBase>
266 class AbstractResultOptTemplate : public PassBase<Pass> {
267 public:
268   void runOnOperation() override {
269     auto *context = &this->getContext();
270     auto op = this->getOperation();
271 
272     mlir::RewritePatternSet patterns(context);
273     mlir::ConversionTarget target = *context;
274     const bool shouldBoxResult = this->passResultAsBox.getValue();
275 
276     auto &self = static_cast<Pass &>(*this);
277     self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
278 
279     // Convert the calls and, if needed,  the ReturnOp in the function body.
280     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
281                            mlir::func::FuncDialect>();
282     target.addIllegalOp<fir::SaveResultOp>();
283     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
284       return !hasAbstractResult(call.getFunctionType());
285     });
286     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
287       if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
288         return !hasAbstractResult(funTy);
289       return true;
290     });
291     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
292       if (dispatch->getNumResults() != 1)
293         return true;
294       auto resultType = dispatch->getResult(0).getType();
295       if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
296         TODO(dispatch.getLoc(), "dispatchOp with abstract results");
297         return false;
298       }
299       return true;
300     });
301 
302     patterns.insert<CallOpConversion>(context, shouldBoxResult);
303     patterns.insert<SaveResultOpConversion>(context);
304     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
305     if (mlir::failed(
306             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
307       mlir::emitError(op.getLoc(), "error in converting abstract results\n");
308       this->signalPassFailure();
309     }
310   }
311 };
312 
313 class AbstractResultOnFuncOpt
314     : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
315                                        fir::impl::AbstractResultOnFuncOptBase> {
316 public:
317   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
318                               mlir::RewritePatternSet &patterns,
319                               mlir::ConversionTarget &target) {
320     auto loc = func.getLoc();
321     auto *context = &getContext();
322     // Convert function type itself if it has an abstract result.
323     auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
324     if (hasAbstractResult(funcTy)) {
325       // TODO: This should be generalized for derived types, and it is
326       // architecture and OS dependent.
327       if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
328         func.setType(getCPtrFunctionType(funcTy));
329         patterns.insert<ReturnOpConversion>(context, mlir::Value{});
330         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
331             [](mlir::func::ReturnOp ret) {
332               mlir::Type retTy = ret.getOperand(0).getType();
333               return !fir::isa_builtin_cptr_type(retTy);
334             });
335         return;
336       }
337       if (!func.empty()) {
338         // Insert new argument.
339         mlir::OpBuilder rewriter(context);
340         auto resultType = funcTy.getResult(0);
341         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
342         func.insertArgument(0u, argTy, {}, loc);
343         func.eraseResult(0u);
344         mlir::Value newArg = func.getArgument(0u);
345         if (mustEmboxResult(resultType, shouldBoxResult)) {
346           auto bufferType = fir::ReferenceType::get(resultType);
347           rewriter.setInsertionPointToStart(&func.front());
348           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
349         }
350         patterns.insert<ReturnOpConversion>(context, newArg);
351         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
352             [](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
353         assert(func.getFunctionType() ==
354                getNewFunctionType(funcTy, shouldBoxResult));
355       } else {
356         func.setType(getNewFunctionType(funcTy, shouldBoxResult));
357       }
358     }
359   }
360 };
361 
362 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
363   return mlir::TypeSwitch<mlir::Type, bool>(type)
364       .Case([](fir::BoxProcType boxProc) {
365         return fir::hasAbstractResult(
366             boxProc.getEleTy().cast<mlir::FunctionType>());
367       })
368       .Case([](fir::PointerType pointer) {
369         return fir::hasAbstractResult(
370             pointer.getEleTy().cast<mlir::FunctionType>());
371       })
372       .Default([](auto &&) { return false; });
373 }
374 
375 class AbstractResultOnGlobalOpt
376     : public AbstractResultOptTemplate<
377           AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
378 public:
379   void runOnSpecificOperation(fir::GlobalOp global, bool,
380                               mlir::RewritePatternSet &,
381                               mlir::ConversionTarget &) {
382     if (containsFunctionTypeWithAbstractResult(global.getType())) {
383       TODO(global->getLoc(), "support for procedure pointers");
384     }
385   }
386 };
387 } // end anonymous namespace
388 } // namespace fir
389 
390 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
391   return std::make_unique<AbstractResultOnFuncOpt>();
392 }
393 
394 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
395   return std::make_unique<AbstractResultOnGlobalOpt>();
396 }
397