1b0eef1eeSValentin Clement //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2b0eef1eeSValentin Clement // 3b0eef1eeSValentin Clement // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b0eef1eeSValentin Clement // See https://llvm.org/LICENSE.txt for license information. 5b0eef1eeSValentin Clement // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b0eef1eeSValentin Clement // 7b0eef1eeSValentin Clement //===----------------------------------------------------------------------===// 8b0eef1eeSValentin Clement 9c336e72cSPeixin-Qiao #include "flang/Optimizer/Builder/FIRBuilder.h" 109de831aaSJean Perier #include "flang/Optimizer/Builder/Todo.h" 11b0eef1eeSValentin Clement #include "flang/Optimizer/Dialect/FIRDialect.h" 12b0eef1eeSValentin Clement #include "flang/Optimizer/Dialect/FIROps.h" 13b0eef1eeSValentin Clement #include "flang/Optimizer/Dialect/FIRType.h" 14b07ef9e7SRenaud-K #include "flang/Optimizer/Dialect/Support/FIRContext.h" 15b0eef1eeSValentin Clement #include "flang/Optimizer/Transforms/Passes.h" 1623aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 175522d246SValentin Clement (バレンタイン クレメン) #include "mlir/Dialect/GPU/IR/GPUDialect.h" 18b0eef1eeSValentin Clement #include "mlir/IR/Diagnostics.h" 19b0eef1eeSValentin Clement #include "mlir/Pass/Pass.h" 20bfd19445STom Eccles #include "mlir/Pass/PassManager.h" 21b0eef1eeSValentin Clement #include "mlir/Transforms/DialectConversion.h" 22b0eef1eeSValentin Clement #include "llvm/ADT/TypeSwitch.h" 23b0eef1eeSValentin Clement 2467d0d7acSMichele Scuttari namespace fir { 25bfd19445STom Eccles #define GEN_PASS_DEF_ABSTRACTRESULTOPT 2667d0d7acSMichele Scuttari #include "flang/Optimizer/Transforms/Passes.h.inc" 2767d0d7acSMichele Scuttari } // namespace fir 2867d0d7acSMichele Scuttari 29b0eef1eeSValentin Clement #define DEBUG_TYPE "flang-abstract-result-opt" 30b0eef1eeSValentin Clement 31afb34cf3SValentin Clement using namespace mlir; 32afb34cf3SValentin Clement 33b0eef1eeSValentin Clement namespace fir { 34b0eef1eeSValentin Clement namespace { 35b0eef1eeSValentin Clement 36367c3c96SjeanPerier // Helper to only build the symbol table if needed because its build time is 37367c3c96SjeanPerier // linear on the number of symbols in the module. 38367c3c96SjeanPerier struct LazySymbolTable { 39367c3c96SjeanPerier LazySymbolTable(mlir::Operation *op) 40367c3c96SjeanPerier : module{op->getParentOfType<mlir::ModuleOp>()} {} 41367c3c96SjeanPerier void build() { 42367c3c96SjeanPerier if (table) 43367c3c96SjeanPerier return; 44367c3c96SjeanPerier table = std::make_unique<mlir::SymbolTable>(module); 45367c3c96SjeanPerier } 46367c3c96SjeanPerier 47367c3c96SjeanPerier template <typename T> 48367c3c96SjeanPerier T lookup(llvm::StringRef name) { 49367c3c96SjeanPerier build(); 50367c3c96SjeanPerier return table->lookup<T>(name); 51367c3c96SjeanPerier } 52367c3c96SjeanPerier 53367c3c96SjeanPerier private: 54367c3c96SjeanPerier std::unique_ptr<mlir::SymbolTable> table; 55367c3c96SjeanPerier mlir::ModuleOp module; 56367c3c96SjeanPerier }; 57367c3c96SjeanPerier 58367c3c96SjeanPerier bool hasScalarDerivedResult(mlir::FunctionType funTy) { 59367c3c96SjeanPerier // C_PTR/C_FUNPTR are results to void* in this pass, do not consider 60367c3c96SjeanPerier // them as normal derived types. 61367c3c96SjeanPerier return funTy.getNumResults() == 1 && 62367c3c96SjeanPerier mlir::isa<fir::RecordType>(funTy.getResult(0)) && 63367c3c96SjeanPerier !fir::isa_builtin_cptr_type(funTy.getResult(0)); 64367c3c96SjeanPerier } 65367c3c96SjeanPerier 66b0eef1eeSValentin Clement static mlir::Type getResultArgumentType(mlir::Type resultType, 67ea1cdb58SDaniil Dudkin bool shouldBoxResult) { 68b0eef1eeSValentin Clement return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 69b0eef1eeSValentin Clement .Case<fir::SequenceType, fir::RecordType>( 70b0eef1eeSValentin Clement [&](mlir::Type type) -> mlir::Type { 71ea1cdb58SDaniil Dudkin if (shouldBoxResult) 72b0eef1eeSValentin Clement return fir::BoxType::get(type); 73b0eef1eeSValentin Clement return fir::ReferenceType::get(type); 74b0eef1eeSValentin Clement }) 75afb34cf3SValentin Clement .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type { 76b0eef1eeSValentin Clement return fir::ReferenceType::get(type); 77b0eef1eeSValentin Clement }) 78b0eef1eeSValentin Clement .Default([](mlir::Type) -> mlir::Type { 79b0eef1eeSValentin Clement llvm_unreachable("bad abstract result type"); 80b0eef1eeSValentin Clement }); 81b0eef1eeSValentin Clement } 82b0eef1eeSValentin Clement 83ea1cdb58SDaniil Dudkin static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 84ea1cdb58SDaniil Dudkin bool shouldBoxResult) { 85b0eef1eeSValentin Clement auto resultType = funcTy.getResult(0); 86ea1cdb58SDaniil Dudkin auto argTy = getResultArgumentType(resultType, shouldBoxResult); 87b0eef1eeSValentin Clement llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 88b0eef1eeSValentin Clement newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 89b0eef1eeSValentin Clement return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 90b0eef1eeSValentin Clement /*resultTypes=*/{}); 91b0eef1eeSValentin Clement } 92b0eef1eeSValentin Clement 931ead51a8SjeanPerier static mlir::Type getVoidPtrType(mlir::MLIRContext *context) { 941ead51a8SjeanPerier return fir::ReferenceType::get(mlir::NoneType::get(context)); 951ead51a8SjeanPerier } 961ead51a8SjeanPerier 97c336e72cSPeixin-Qiao /// This is for function result types that are of type C_PTR from ISO_C_BINDING. 98c336e72cSPeixin-Qiao /// Follow the ABI for interoperability with C. 99c336e72cSPeixin-Qiao static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) { 1001ead51a8SjeanPerier assert(fir::isa_builtin_cptr_type(funcTy.getResult(0))); 1011ead51a8SjeanPerier llvm::SmallVector<mlir::Type> outputTypes{ 1021ead51a8SjeanPerier getVoidPtrType(funcTy.getContext())}; 103c336e72cSPeixin-Qiao return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(), 104c336e72cSPeixin-Qiao outputTypes); 105c336e72cSPeixin-Qiao } 106c336e72cSPeixin-Qiao 107ea1cdb58SDaniil Dudkin static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { 108fac349a1SChristian Sigg return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) && 109ea1cdb58SDaniil Dudkin shouldBoxResult; 110b0eef1eeSValentin Clement } 111b0eef1eeSValentin Clement 112afb34cf3SValentin Clement template <typename Op> 113afb34cf3SValentin Clement class CallConversion : public mlir::OpRewritePattern<Op> { 114b0eef1eeSValentin Clement public: 115afb34cf3SValentin Clement using mlir::OpRewritePattern<Op>::OpRewritePattern; 116afb34cf3SValentin Clement 117afb34cf3SValentin Clement CallConversion(mlir::MLIRContext *context, bool shouldBoxResult) 118afb34cf3SValentin Clement : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {} 119afb34cf3SValentin Clement 120db791b27SRamkumar Ramachandra llvm::LogicalResult 121afb34cf3SValentin Clement matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { 122afb34cf3SValentin Clement auto loc = op.getLoc(); 123afb34cf3SValentin Clement auto result = op->getResult(0); 124b0eef1eeSValentin Clement if (!result.hasOneUse()) { 125b0eef1eeSValentin Clement mlir::emitError(loc, 126b0eef1eeSValentin Clement "calls with abstract result must have exactly one user"); 127b0eef1eeSValentin Clement return mlir::failure(); 128b0eef1eeSValentin Clement } 129b0eef1eeSValentin Clement auto saveResult = 130b0eef1eeSValentin Clement mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 131b0eef1eeSValentin Clement if (!saveResult) { 132b0eef1eeSValentin Clement mlir::emitError( 133b0eef1eeSValentin Clement loc, "calls with abstract result must be used in fir.save_result"); 134b0eef1eeSValentin Clement return mlir::failure(); 135b0eef1eeSValentin Clement } 136ea1cdb58SDaniil Dudkin auto argType = getResultArgumentType(result.getType(), shouldBoxResult); 137149ad3d5SShraiysh Vaishay auto buffer = saveResult.getMemref(); 138b0eef1eeSValentin Clement mlir::Value arg = buffer; 139ea1cdb58SDaniil Dudkin if (mustEmboxResult(result.getType(), shouldBoxResult)) 140b0eef1eeSValentin Clement arg = rewriter.create<fir::EmboxOp>( 141149ad3d5SShraiysh Vaishay loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 142149ad3d5SShraiysh Vaishay saveResult.getTypeparams()); 143b0eef1eeSValentin Clement 144b0eef1eeSValentin Clement llvm::SmallVector<mlir::Type> newResultTypes; 145c336e72cSPeixin-Qiao bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); 1461ead51a8SjeanPerier if (isResultBuiltinCPtr) 1471ead51a8SjeanPerier newResultTypes.emplace_back(getVoidPtrType(result.getContext())); 148afb34cf3SValentin Clement 1491ead51a8SjeanPerier Op newOp; 150afb34cf3SValentin Clement // fir::CallOp specific handling. 151afb34cf3SValentin Clement if constexpr (std::is_same_v<Op, fir::CallOp>) { 152afb34cf3SValentin Clement if (op.getCallee()) { 153c336e72cSPeixin-Qiao llvm::SmallVector<mlir::Value> newOperands; 154c336e72cSPeixin-Qiao if (!isResultBuiltinCPtr) 155c336e72cSPeixin-Qiao newOperands.emplace_back(arg); 156afb34cf3SValentin Clement newOperands.append(op.getOperands().begin(), op.getOperands().end()); 157afb34cf3SValentin Clement newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(), 158c336e72cSPeixin-Qiao newResultTypes, newOperands); 159b0eef1eeSValentin Clement } else { 160b0eef1eeSValentin Clement // Indirect calls. 161c336e72cSPeixin-Qiao llvm::SmallVector<mlir::Type> newInputTypes; 162c336e72cSPeixin-Qiao if (!isResultBuiltinCPtr) 163c336e72cSPeixin-Qiao newInputTypes.emplace_back(argType); 164afb34cf3SValentin Clement for (auto operand : op.getOperands().drop_front()) 165b0eef1eeSValentin Clement newInputTypes.push_back(operand.getType()); 166afb34cf3SValentin Clement auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes, 167afb34cf3SValentin Clement newResultTypes); 168b0eef1eeSValentin Clement 169b0eef1eeSValentin Clement llvm::SmallVector<mlir::Value> newOperands; 170afb34cf3SValentin Clement newOperands.push_back( 171afb34cf3SValentin Clement rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0))); 172c336e72cSPeixin-Qiao if (!isResultBuiltinCPtr) 173b0eef1eeSValentin Clement newOperands.push_back(arg); 174afb34cf3SValentin Clement newOperands.append(op.getOperands().begin() + 1, 175afb34cf3SValentin Clement op.getOperands().end()); 176afb34cf3SValentin Clement newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, 177c336e72cSPeixin-Qiao newResultTypes, newOperands); 178c336e72cSPeixin-Qiao } 179afb34cf3SValentin Clement } 180afb34cf3SValentin Clement 181afb34cf3SValentin Clement // fir::DispatchOp specific handling. 182afb34cf3SValentin Clement if constexpr (std::is_same_v<Op, fir::DispatchOp>) { 183afb34cf3SValentin Clement llvm::SmallVector<mlir::Value> newOperands; 184afb34cf3SValentin Clement if (!isResultBuiltinCPtr) 185afb34cf3SValentin Clement newOperands.emplace_back(arg); 186afb34cf3SValentin Clement unsigned passArgShift = newOperands.size(); 187afb34cf3SValentin Clement newOperands.append(op.getOperands().begin() + 1, op.getOperands().end()); 188a78359c2SjeanPerier mlir::IntegerAttr passArgPos; 189afb34cf3SValentin Clement if (op.getPassArgPos()) 190a78359c2SjeanPerier passArgPos = 191a78359c2SjeanPerier rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift); 192afb34cf3SValentin Clement newOp = rewriter.create<fir::DispatchOp>( 193afb34cf3SValentin Clement loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), 194a78359c2SjeanPerier op.getOperands()[0], newOperands, passArgPos, 195a78359c2SjeanPerier op.getProcedureAttrsAttr()); 196afb34cf3SValentin Clement } 197afb34cf3SValentin Clement 198c336e72cSPeixin-Qiao if (isResultBuiltinCPtr) { 199c336e72cSPeixin-Qiao mlir::Value save = saveResult.getMemref(); 200afb34cf3SValentin Clement auto module = op->template getParentOfType<mlir::ModuleOp>(); 20153cc33b0STom Eccles FirOpBuilder builder(rewriter, module); 202c336e72cSPeixin-Qiao mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( 203c336e72cSPeixin-Qiao builder, loc, save, result.getType()); 2041ead51a8SjeanPerier builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr); 205b0eef1eeSValentin Clement } 206afb34cf3SValentin Clement op->dropAllReferences(); 207afb34cf3SValentin Clement rewriter.eraseOp(op); 208b0eef1eeSValentin Clement return mlir::success(); 209b0eef1eeSValentin Clement } 210b0eef1eeSValentin Clement 211b0eef1eeSValentin Clement private: 212ea1cdb58SDaniil Dudkin bool shouldBoxResult; 213b0eef1eeSValentin Clement }; 214b0eef1eeSValentin Clement 215b0eef1eeSValentin Clement class SaveResultOpConversion 216b0eef1eeSValentin Clement : public mlir::OpRewritePattern<fir::SaveResultOp> { 217b0eef1eeSValentin Clement public: 218b0eef1eeSValentin Clement using OpRewritePattern::OpRewritePattern; 219b0eef1eeSValentin Clement SaveResultOpConversion(mlir::MLIRContext *context) 220b0eef1eeSValentin Clement : OpRewritePattern(context) {} 221db791b27SRamkumar Ramachandra llvm::LogicalResult 222b0eef1eeSValentin Clement matchAndRewrite(fir::SaveResultOp op, 223b0eef1eeSValentin Clement mlir::PatternRewriter &rewriter) const override { 224367c3c96SjeanPerier mlir::Operation *call = op.getValue().getDefiningOp(); 225367c3c96SjeanPerier mlir::Type type = op.getValue().getType(); 226367c3c96SjeanPerier if (mlir::isa<fir::RecordType>(type) && call && fir::hasBindcAttr(call) && 227367c3c96SjeanPerier !fir::isa_builtin_cptr_type(type)) { 228367c3c96SjeanPerier rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValue(), 229367c3c96SjeanPerier op.getMemref()); 230367c3c96SjeanPerier } else { 231b0eef1eeSValentin Clement rewriter.eraseOp(op); 232367c3c96SjeanPerier } 233b0eef1eeSValentin Clement return mlir::success(); 234b0eef1eeSValentin Clement } 235b0eef1eeSValentin Clement }; 236b0eef1eeSValentin Clement 237*75623bfeSValentin Clement (バレンタイン クレメン) template <typename OpTy> 238*75623bfeSValentin Clement (バレンタイン クレメン) static mlir::LogicalResult 239*75623bfeSValentin Clement (バレンタイン クレメン) processReturnLikeOp(OpTy ret, mlir::Value newArg, 240*75623bfeSValentin Clement (バレンタイン クレメン) mlir::PatternRewriter &rewriter) { 241c336e72cSPeixin-Qiao auto loc = ret.getLoc(); 242b0eef1eeSValentin Clement rewriter.setInsertionPoint(ret); 2431ead51a8SjeanPerier mlir::Value resultValue = ret.getOperand(0); 2441ead51a8SjeanPerier fir::LoadOp resultLoad; 2451ead51a8SjeanPerier mlir::Value resultStorage; 2461ead51a8SjeanPerier // Identify result local storage. 2471ead51a8SjeanPerier if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) { 2481ead51a8SjeanPerier resultLoad = load; 2491ead51a8SjeanPerier resultStorage = load.getMemref(); 250c203850aSJean Perier // The result alloca may be behind a fir.declare, if any. 2511ead51a8SjeanPerier if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>()) 252c203850aSJean Perier resultStorage = declare.getMemref(); 2531ead51a8SjeanPerier } 2541ead51a8SjeanPerier // Replace old local storage with new storage argument, unless 2551ead51a8SjeanPerier // the derived type is C_PTR/C_FUN_PTR, in which case the return 2561ead51a8SjeanPerier // type is updated to return void* (no new argument is passed). 2571ead51a8SjeanPerier if (fir::isa_builtin_cptr_type(resultValue.getType())) { 258*75623bfeSValentin Clement (バレンタイン クレメン) auto module = ret->template getParentOfType<mlir::ModuleOp>(); 25953cc33b0STom Eccles FirOpBuilder builder(rewriter, module); 2601ead51a8SjeanPerier mlir::Value cptr = resultValue; 2611ead51a8SjeanPerier if (resultLoad) { 2621ead51a8SjeanPerier // Replace whole derived type load by component load. 2631ead51a8SjeanPerier cptr = resultLoad.getMemref(); 2641ead51a8SjeanPerier rewriter.setInsertionPoint(resultLoad); 2651ead51a8SjeanPerier } 2661ead51a8SjeanPerier mlir::Value newResultValue = 2671ead51a8SjeanPerier fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); 2681ead51a8SjeanPerier newResultValue = builder.createConvert( 2691ead51a8SjeanPerier loc, getVoidPtrType(ret.getContext()), newResultValue); 2701ead51a8SjeanPerier rewriter.setInsertionPoint(ret); 271*75623bfeSValentin Clement (バレンタイン クレメン) rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue}); 2721ead51a8SjeanPerier } else if (resultStorage) { 273c203850aSJean Perier resultStorage.replaceAllUsesWith(newArg); 274*75623bfeSValentin Clement (バレンタイン クレメン) rewriter.replaceOpWithNewOp<OpTy>(ret); 2751ead51a8SjeanPerier } else { 276b0eef1eeSValentin Clement // The result storage may have been optimized out by a memory to 277b0eef1eeSValentin Clement // register pass, this is possible for fir.box results, or fir.record 2781ead51a8SjeanPerier // with no length parameters. Simply store the result in the result 2791ead51a8SjeanPerier // storage. at the return point. 2801ead51a8SjeanPerier rewriter.create<fir::StoreOp>(loc, resultValue, newArg); 281*75623bfeSValentin Clement (バレンタイン クレメン) rewriter.replaceOpWithNewOp<OpTy>(ret); 2821ead51a8SjeanPerier } 2831ead51a8SjeanPerier // Delete result old local storage if unused. 2841ead51a8SjeanPerier if (resultStorage) 2851ead51a8SjeanPerier if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>()) 2861ead51a8SjeanPerier if (alloc->use_empty()) 2871ead51a8SjeanPerier rewriter.eraseOp(alloc); 288b0eef1eeSValentin Clement return mlir::success(); 289b0eef1eeSValentin Clement } 290b0eef1eeSValentin Clement 291*75623bfeSValentin Clement (バレンタイン クレメン) class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 292*75623bfeSValentin Clement (バレンタイン クレメン) public: 293*75623bfeSValentin Clement (バレンタイン クレメン) using OpRewritePattern::OpRewritePattern; 294*75623bfeSValentin Clement (バレンタイン クレメン) ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 295*75623bfeSValentin Clement (バレンタイン クレメン) : OpRewritePattern(context), newArg{newArg} {} 296*75623bfeSValentin Clement (バレンタイン クレメン) llvm::LogicalResult 297*75623bfeSValentin Clement (バレンタイン クレメン) matchAndRewrite(mlir::func::ReturnOp ret, 298*75623bfeSValentin Clement (バレンタイン クレメン) mlir::PatternRewriter &rewriter) const override { 299*75623bfeSValentin Clement (バレンタイン クレメン) return processReturnLikeOp(ret, newArg, rewriter); 300*75623bfeSValentin Clement (バレンタイン クレメン) } 301*75623bfeSValentin Clement (バレンタイン クレメン) 302*75623bfeSValentin Clement (バレンタイン クレメン) private: 303*75623bfeSValentin Clement (バレンタイン クレメン) mlir::Value newArg; 304*75623bfeSValentin Clement (バレンタイン クレメン) }; 305*75623bfeSValentin Clement (バレンタイン クレメン) 306*75623bfeSValentin Clement (バレンタイン クレメン) class GPUReturnOpConversion 307*75623bfeSValentin Clement (バレンタイン クレメン) : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> { 308*75623bfeSValentin Clement (バレンタイン クレメン) public: 309*75623bfeSValentin Clement (バレンタイン クレメン) using OpRewritePattern::OpRewritePattern; 310*75623bfeSValentin Clement (バレンタイン クレメン) GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 311*75623bfeSValentin Clement (バレンタイン クレメン) : OpRewritePattern(context), newArg{newArg} {} 312*75623bfeSValentin Clement (バレンタイン クレメン) llvm::LogicalResult 313*75623bfeSValentin Clement (バレンタイン クレメン) matchAndRewrite(mlir::gpu::ReturnOp ret, 314*75623bfeSValentin Clement (バレンタイン クレメン) mlir::PatternRewriter &rewriter) const override { 315*75623bfeSValentin Clement (バレンタイン クレメン) return processReturnLikeOp(ret, newArg, rewriter); 316*75623bfeSValentin Clement (バレンタイン クレメン) } 317*75623bfeSValentin Clement (バレンタイン クレメン) 318b0eef1eeSValentin Clement private: 319ea1cdb58SDaniil Dudkin mlir::Value newArg; 320b0eef1eeSValentin Clement }; 321b0eef1eeSValentin Clement 322b0eef1eeSValentin Clement class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 323b0eef1eeSValentin Clement public: 324b0eef1eeSValentin Clement using OpRewritePattern::OpRewritePattern; 325ea1cdb58SDaniil Dudkin AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 326ea1cdb58SDaniil Dudkin : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 327db791b27SRamkumar Ramachandra llvm::LogicalResult 328b0eef1eeSValentin Clement matchAndRewrite(fir::AddrOfOp addrOf, 329b0eef1eeSValentin Clement mlir::PatternRewriter &rewriter) const override { 330fac349a1SChristian Sigg auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType()); 331c336e72cSPeixin-Qiao mlir::FunctionType newFuncTy; 332c336e72cSPeixin-Qiao if (oldFuncTy.getNumResults() != 0 && 333c336e72cSPeixin-Qiao fir::isa_builtin_cptr_type(oldFuncTy.getResult(0))) 334c336e72cSPeixin-Qiao newFuncTy = getCPtrFunctionType(oldFuncTy); 335c336e72cSPeixin-Qiao else 336c336e72cSPeixin-Qiao newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); 337b0eef1eeSValentin Clement auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 338149ad3d5SShraiysh Vaishay addrOf.getSymbol()); 339b0eef1eeSValentin Clement // Rather than converting all op a function pointer might transit through 340b0eef1eeSValentin Clement // (e.g calls, stores, loads, converts...), cast new type to the abstract 341b0eef1eeSValentin Clement // type. A conversion will be added when calling indirect calls of abstract 342b0eef1eeSValentin Clement // types. 343b0eef1eeSValentin Clement rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 344b0eef1eeSValentin Clement return mlir::success(); 345b0eef1eeSValentin Clement } 346b0eef1eeSValentin Clement 347b0eef1eeSValentin Clement private: 348ea1cdb58SDaniil Dudkin bool shouldBoxResult; 349b0eef1eeSValentin Clement }; 350b0eef1eeSValentin Clement 351bfd19445STom Eccles class AbstractResultOpt 352bfd19445STom Eccles : public fir::impl::AbstractResultOptBase<AbstractResultOpt> { 353b0eef1eeSValentin Clement public: 354bfd19445STom Eccles using fir::impl::AbstractResultOptBase< 355bfd19445STom Eccles AbstractResultOpt>::AbstractResultOptBase; 356cb33e4abSDaniil Dudkin 3575522d246SValentin Clement (バレンタイン クレメン) template <typename OpTy> 3585522d246SValentin Clement (バレンタイン クレメン) void runOnFunctionLikeOperation(OpTy func, bool shouldBoxResult, 359cb33e4abSDaniil Dudkin mlir::RewritePatternSet &patterns, 360cb33e4abSDaniil Dudkin mlir::ConversionTarget &target) { 361cb33e4abSDaniil Dudkin auto loc = func.getLoc(); 362cb33e4abSDaniil Dudkin auto *context = &getContext(); 363cb33e4abSDaniil Dudkin // Convert function type itself if it has an abstract result. 364fac349a1SChristian Sigg auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType()); 365367c3c96SjeanPerier // Scalar derived result of BIND(C) function must be returned according 366367c3c96SjeanPerier // to the C struct return ABI which is target dependent and implemented in 367367c3c96SjeanPerier // the target-rewrite pass. 368367c3c96SjeanPerier if (hasScalarDerivedResult(funcTy) && 369367c3c96SjeanPerier fir::hasBindcAttr(func.getOperation())) 370367c3c96SjeanPerier return; 371cb33e4abSDaniil Dudkin if (hasAbstractResult(funcTy)) { 372c336e72cSPeixin-Qiao if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) { 373c336e72cSPeixin-Qiao func.setType(getCPtrFunctionType(funcTy)); 374c336e72cSPeixin-Qiao patterns.insert<ReturnOpConversion>(context, mlir::Value{}); 375c336e72cSPeixin-Qiao target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 376c336e72cSPeixin-Qiao [](mlir::func::ReturnOp ret) { 377c336e72cSPeixin-Qiao mlir::Type retTy = ret.getOperand(0).getType(); 378c336e72cSPeixin-Qiao return !fir::isa_builtin_cptr_type(retTy); 379c336e72cSPeixin-Qiao }); 380c336e72cSPeixin-Qiao return; 381c336e72cSPeixin-Qiao } 382cb33e4abSDaniil Dudkin if (!func.empty()) { 383cb33e4abSDaniil Dudkin // Insert new argument. 384cb33e4abSDaniil Dudkin mlir::OpBuilder rewriter(context); 385cb33e4abSDaniil Dudkin auto resultType = funcTy.getResult(0); 386cb33e4abSDaniil Dudkin auto argTy = getResultArgumentType(resultType, shouldBoxResult); 3874d5a9c1dSRenaud-K func.insertArgument(0u, argTy, {}, loc); 3884d5a9c1dSRenaud-K func.eraseResult(0u); 3894d5a9c1dSRenaud-K mlir::Value newArg = func.getArgument(0u); 390cb33e4abSDaniil Dudkin if (mustEmboxResult(resultType, shouldBoxResult)) { 391cb33e4abSDaniil Dudkin auto bufferType = fir::ReferenceType::get(resultType); 392cb33e4abSDaniil Dudkin rewriter.setInsertionPointToStart(&func.front()); 393cb33e4abSDaniil Dudkin newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); 394cb33e4abSDaniil Dudkin } 395cb33e4abSDaniil Dudkin patterns.insert<ReturnOpConversion>(context, newArg); 396cb33e4abSDaniil Dudkin target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 397b74192b7SRiver Riddle [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); }); 398*75623bfeSValentin Clement (バレンタイン クレメン) patterns.insert<GPUReturnOpConversion>(context, newArg); 399*75623bfeSValentin Clement (バレンタイン クレメン) target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>( 400*75623bfeSValentin Clement (バレンタイン クレメン) [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); }); 4014d5a9c1dSRenaud-K assert(func.getFunctionType() == 4024d5a9c1dSRenaud-K getNewFunctionType(funcTy, shouldBoxResult)); 4034d5a9c1dSRenaud-K } else { 4045f9e0491SRenaud-K llvm::SmallVector<mlir::DictionaryAttr> allArgs; 4055f9e0491SRenaud-K func.getAllArgAttrs(allArgs); 4065f9e0491SRenaud-K allArgs.insert(allArgs.begin(), 4075f9e0491SRenaud-K mlir::DictionaryAttr::get(func->getContext())); 4084d5a9c1dSRenaud-K func.setType(getNewFunctionType(funcTy, shouldBoxResult)); 4095f9e0491SRenaud-K func.setAllArgAttrs(allArgs); 410cb33e4abSDaniil Dudkin } 411b0eef1eeSValentin Clement } 412b0eef1eeSValentin Clement } 413a6f2f44fSDaniil Dudkin 4145522d246SValentin Clement (バレンタイン クレメン) void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, 4155522d246SValentin Clement (バレンタイン クレメン) mlir::RewritePatternSet &patterns, 4165522d246SValentin Clement (バレンタイン クレメン) mlir::ConversionTarget &target) { 4175522d246SValentin Clement (バレンタイン クレメン) runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target); 4185522d246SValentin Clement (バレンタイン クレメン) } 4195522d246SValentin Clement (バレンタイン クレメン) 4205522d246SValentin Clement (バレンタイン クレメン) void runOnSpecificOperation(mlir::gpu::GPUFuncOp func, bool shouldBoxResult, 4215522d246SValentin Clement (バレンタイン クレメン) mlir::RewritePatternSet &patterns, 4225522d246SValentin Clement (バレンタイン クレメン) mlir::ConversionTarget &target) { 4235522d246SValentin Clement (バレンタイン クレメン) runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target); 4245522d246SValentin Clement (バレンタイン クレメン) } 4255522d246SValentin Clement (バレンタイン クレメン) 426a6f2f44fSDaniil Dudkin inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) { 427a6f2f44fSDaniil Dudkin return mlir::TypeSwitch<mlir::Type, bool>(type) 428a6f2f44fSDaniil Dudkin .Case([](fir::BoxProcType boxProc) { 429a6f2f44fSDaniil Dudkin return fir::hasAbstractResult( 430fac349a1SChristian Sigg mlir::cast<mlir::FunctionType>(boxProc.getEleTy())); 431a6f2f44fSDaniil Dudkin }) 432a6f2f44fSDaniil Dudkin .Case([](fir::PointerType pointer) { 433a6f2f44fSDaniil Dudkin return fir::hasAbstractResult( 434fac349a1SChristian Sigg mlir::cast<mlir::FunctionType>(pointer.getEleTy())); 435a6f2f44fSDaniil Dudkin }) 436a6f2f44fSDaniil Dudkin .Default([](auto &&) { return false; }); 437a6f2f44fSDaniil Dudkin } 438a6f2f44fSDaniil Dudkin 439a6f2f44fSDaniil Dudkin void runOnSpecificOperation(fir::GlobalOp global, bool, 440a6f2f44fSDaniil Dudkin mlir::RewritePatternSet &, 441a6f2f44fSDaniil Dudkin mlir::ConversionTarget &) { 442a6f2f44fSDaniil Dudkin if (containsFunctionTypeWithAbstractResult(global.getType())) { 443a6f2f44fSDaniil Dudkin TODO(global->getLoc(), "support for procedure pointers"); 444a6f2f44fSDaniil Dudkin } 445a6f2f44fSDaniil Dudkin } 446bfd19445STom Eccles 447bfd19445STom Eccles /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work. 448bfd19445STom Eccles void runOnModule() { 449bfd19445STom Eccles mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation()); 450bfd19445STom Eccles 451bfd19445STom Eccles auto pass = std::make_unique<AbstractResultOpt>(); 452bfd19445STom Eccles pass->copyOptionValuesFrom(this); 453bfd19445STom Eccles mlir::OpPassManager pipeline; 454bfd19445STom Eccles pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()}); 455bfd19445STom Eccles 456bfd19445STom Eccles // Run the pass on all operations directly nested inside of the ModuleOp 457bfd19445STom Eccles // we can't just call runOnSpecificOperation here because the pass 458bfd19445STom Eccles // implementation only works when scoped to a particular func.func or 459bfd19445STom Eccles // fir.global 460bfd19445STom Eccles for (mlir::Region ®ion : mod->getRegions()) { 461bfd19445STom Eccles for (mlir::Block &block : region.getBlocks()) { 462bfd19445STom Eccles for (mlir::Operation &op : block.getOperations()) { 463bfd19445STom Eccles if (mlir::failed(runPipeline(pipeline, &op))) { 464bfd19445STom Eccles mlir::emitError(op.getLoc(), "Failed to run abstract result pass"); 465bfd19445STom Eccles signalPassFailure(); 466bfd19445STom Eccles return; 467bfd19445STom Eccles } 468bfd19445STom Eccles } 469bfd19445STom Eccles } 470bfd19445STom Eccles } 471bfd19445STom Eccles } 472bfd19445STom Eccles 473bfd19445STom Eccles void runOnOperation() override { 474bfd19445STom Eccles auto *context = &this->getContext(); 475bfd19445STom Eccles mlir::Operation *op = this->getOperation(); 476bfd19445STom Eccles if (mlir::isa<mlir::ModuleOp>(op)) { 477bfd19445STom Eccles runOnModule(); 478bfd19445STom Eccles return; 479bfd19445STom Eccles } 480bfd19445STom Eccles 481367c3c96SjeanPerier LazySymbolTable symbolTable(op); 482367c3c96SjeanPerier 483bfd19445STom Eccles mlir::RewritePatternSet patterns(context); 484bfd19445STom Eccles mlir::ConversionTarget target = *context; 485bfd19445STom Eccles const bool shouldBoxResult = this->passResultAsBox.getValue(); 486bfd19445STom Eccles 487bfd19445STom Eccles mlir::TypeSwitch<mlir::Operation *, void>(op) 4881d4b5c16SValentin Clement (バレンタイン クレメン) .Case<mlir::func::FuncOp, fir::GlobalOp, mlir::gpu::GPUFuncOp>( 4891d4b5c16SValentin Clement (バレンタイン クレメン) [&](auto op) { 490bfd19445STom Eccles runOnSpecificOperation(op, shouldBoxResult, patterns, target); 491bfd19445STom Eccles }); 492bfd19445STom Eccles 493bfd19445STom Eccles // Convert the calls and, if needed, the ReturnOp in the function body. 494bfd19445STom Eccles target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, 495bfd19445STom Eccles mlir::func::FuncDialect>(); 496bfd19445STom Eccles target.addIllegalOp<fir::SaveResultOp>(); 497bfd19445STom Eccles target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 498367c3c96SjeanPerier mlir::FunctionType funTy = call.getFunctionType(); 499367c3c96SjeanPerier if (hasScalarDerivedResult(funTy) && 500367c3c96SjeanPerier fir::hasBindcAttr(call.getOperation())) 501367c3c96SjeanPerier return true; 502480e7f06SjeanPerier return !hasAbstractResult(funTy); 503367c3c96SjeanPerier }); 504367c3c96SjeanPerier target.addDynamicallyLegalOp<fir::AddrOfOp>([&symbolTable]( 505367c3c96SjeanPerier fir::AddrOfOp addrOf) { 506367c3c96SjeanPerier if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) { 507367c3c96SjeanPerier if (hasScalarDerivedResult(funTy)) { 508367c3c96SjeanPerier auto func = symbolTable.lookup<mlir::func::FuncOp>( 509367c3c96SjeanPerier addrOf.getSymbol().getRootReference().getValue()); 510367c3c96SjeanPerier return func && fir::hasBindcAttr(func.getOperation()); 511367c3c96SjeanPerier } 512367c3c96SjeanPerier return !hasAbstractResult(funTy); 513367c3c96SjeanPerier } 514bfd19445STom Eccles return true; 515bfd19445STom Eccles }); 516bfd19445STom Eccles target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 517367c3c96SjeanPerier mlir::FunctionType funTy = dispatch.getFunctionType(); 518367c3c96SjeanPerier if (hasScalarDerivedResult(funTy) && 519367c3c96SjeanPerier fir::hasBindcAttr(dispatch.getOperation())) 520367c3c96SjeanPerier return true; 521bfd19445STom Eccles return !hasAbstractResult(dispatch.getFunctionType()); 522bfd19445STom Eccles }); 523bfd19445STom Eccles 524bfd19445STom Eccles patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult); 525bfd19445STom Eccles patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult); 526bfd19445STom Eccles patterns.insert<SaveResultOpConversion>(context); 527bfd19445STom Eccles patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); 528bfd19445STom Eccles if (mlir::failed( 529bfd19445STom Eccles mlir::applyPartialConversion(op, target, std::move(patterns)))) { 530bfd19445STom Eccles mlir::emitError(op->getLoc(), "error in converting abstract results\n"); 531bfd19445STom Eccles this->signalPassFailure(); 532bfd19445STom Eccles } 533bfd19445STom Eccles } 534a6f2f44fSDaniil Dudkin }; 535bfd19445STom Eccles 536b0eef1eeSValentin Clement } // end anonymous namespace 537b0eef1eeSValentin Clement } // namespace fir 538