1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/Todo.h" 11 #include "flang/Optimizer/Dialect/FIRDialect.h" 12 #include "flang/Optimizer/Dialect/FIROps.h" 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 15 #include "flang/Optimizer/Transforms/Passes.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/IR/Diagnostics.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Pass/PassManager.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace fir { 24 #define GEN_PASS_DEF_ABSTRACTRESULTOPT 25 #include "flang/Optimizer/Transforms/Passes.h.inc" 26 } // namespace fir 27 28 #define DEBUG_TYPE "flang-abstract-result-opt" 29 30 using namespace mlir; 31 32 namespace fir { 33 namespace { 34 35 // Helper to only build the symbol table if needed because its build time is 36 // linear on the number of symbols in the module. 37 struct LazySymbolTable { 38 LazySymbolTable(mlir::Operation *op) 39 : module{op->getParentOfType<mlir::ModuleOp>()} {} 40 void build() { 41 if (table) 42 return; 43 table = std::make_unique<mlir::SymbolTable>(module); 44 } 45 46 template <typename T> 47 T lookup(llvm::StringRef name) { 48 build(); 49 return table->lookup<T>(name); 50 } 51 52 private: 53 std::unique_ptr<mlir::SymbolTable> table; 54 mlir::ModuleOp module; 55 }; 56 57 bool hasScalarDerivedResult(mlir::FunctionType funTy) { 58 // C_PTR/C_FUNPTR are results to void* in this pass, do not consider 59 // them as normal derived types. 60 return funTy.getNumResults() == 1 && 61 mlir::isa<fir::RecordType>(funTy.getResult(0)) && 62 !fir::isa_builtin_cptr_type(funTy.getResult(0)); 63 } 64 65 static mlir::Type getResultArgumentType(mlir::Type resultType, 66 bool shouldBoxResult) { 67 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 68 .Case<fir::SequenceType, fir::RecordType>( 69 [&](mlir::Type type) -> mlir::Type { 70 if (shouldBoxResult) 71 return fir::BoxType::get(type); 72 return fir::ReferenceType::get(type); 73 }) 74 .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type { 75 return fir::ReferenceType::get(type); 76 }) 77 .Default([](mlir::Type) -> mlir::Type { 78 llvm_unreachable("bad abstract result type"); 79 }); 80 } 81 82 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 83 bool shouldBoxResult) { 84 auto resultType = funcTy.getResult(0); 85 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 86 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 87 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 88 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 89 /*resultTypes=*/{}); 90 } 91 92 static mlir::Type getVoidPtrType(mlir::MLIRContext *context) { 93 return fir::ReferenceType::get(mlir::NoneType::get(context)); 94 } 95 96 /// This is for function result types that are of type C_PTR from ISO_C_BINDING. 97 /// Follow the ABI for interoperability with C. 98 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) { 99 assert(fir::isa_builtin_cptr_type(funcTy.getResult(0))); 100 llvm::SmallVector<mlir::Type> outputTypes{ 101 getVoidPtrType(funcTy.getContext())}; 102 return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(), 103 outputTypes); 104 } 105 106 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { 107 return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) && 108 shouldBoxResult; 109 } 110 111 template <typename Op> 112 class CallConversion : public mlir::OpRewritePattern<Op> { 113 public: 114 using mlir::OpRewritePattern<Op>::OpRewritePattern; 115 116 CallConversion(mlir::MLIRContext *context, bool shouldBoxResult) 117 : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {} 118 119 llvm::LogicalResult 120 matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { 121 auto loc = op.getLoc(); 122 auto result = op->getResult(0); 123 if (!result.hasOneUse()) { 124 mlir::emitError(loc, 125 "calls with abstract result must have exactly one user"); 126 return mlir::failure(); 127 } 128 auto saveResult = 129 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 130 if (!saveResult) { 131 mlir::emitError( 132 loc, "calls with abstract result must be used in fir.save_result"); 133 return mlir::failure(); 134 } 135 auto argType = getResultArgumentType(result.getType(), shouldBoxResult); 136 auto buffer = saveResult.getMemref(); 137 mlir::Value arg = buffer; 138 if (mustEmboxResult(result.getType(), shouldBoxResult)) 139 arg = rewriter.create<fir::EmboxOp>( 140 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 141 saveResult.getTypeparams()); 142 143 llvm::SmallVector<mlir::Type> newResultTypes; 144 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); 145 if (isResultBuiltinCPtr) 146 newResultTypes.emplace_back(getVoidPtrType(result.getContext())); 147 148 Op newOp; 149 // fir::CallOp specific handling. 150 if constexpr (std::is_same_v<Op, fir::CallOp>) { 151 if (op.getCallee()) { 152 llvm::SmallVector<mlir::Value> newOperands; 153 if (!isResultBuiltinCPtr) 154 newOperands.emplace_back(arg); 155 newOperands.append(op.getOperands().begin(), op.getOperands().end()); 156 newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(), 157 newResultTypes, newOperands); 158 } else { 159 // Indirect calls. 160 llvm::SmallVector<mlir::Type> newInputTypes; 161 if (!isResultBuiltinCPtr) 162 newInputTypes.emplace_back(argType); 163 for (auto operand : op.getOperands().drop_front()) 164 newInputTypes.push_back(operand.getType()); 165 auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes, 166 newResultTypes); 167 168 llvm::SmallVector<mlir::Value> newOperands; 169 newOperands.push_back( 170 rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0))); 171 if (!isResultBuiltinCPtr) 172 newOperands.push_back(arg); 173 newOperands.append(op.getOperands().begin() + 1, 174 op.getOperands().end()); 175 newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, 176 newResultTypes, newOperands); 177 } 178 } 179 180 // fir::DispatchOp specific handling. 181 if constexpr (std::is_same_v<Op, fir::DispatchOp>) { 182 llvm::SmallVector<mlir::Value> newOperands; 183 if (!isResultBuiltinCPtr) 184 newOperands.emplace_back(arg); 185 unsigned passArgShift = newOperands.size(); 186 newOperands.append(op.getOperands().begin() + 1, op.getOperands().end()); 187 mlir::IntegerAttr passArgPos; 188 if (op.getPassArgPos()) 189 passArgPos = 190 rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift); 191 newOp = rewriter.create<fir::DispatchOp>( 192 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), 193 op.getOperands()[0], newOperands, passArgPos, 194 op.getProcedureAttrsAttr()); 195 } 196 197 if (isResultBuiltinCPtr) { 198 mlir::Value save = saveResult.getMemref(); 199 auto module = op->template getParentOfType<mlir::ModuleOp>(); 200 FirOpBuilder builder(rewriter, module); 201 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( 202 builder, loc, save, result.getType()); 203 builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr); 204 } 205 op->dropAllReferences(); 206 rewriter.eraseOp(op); 207 return mlir::success(); 208 } 209 210 private: 211 bool shouldBoxResult; 212 }; 213 214 class SaveResultOpConversion 215 : public mlir::OpRewritePattern<fir::SaveResultOp> { 216 public: 217 using OpRewritePattern::OpRewritePattern; 218 SaveResultOpConversion(mlir::MLIRContext *context) 219 : OpRewritePattern(context) {} 220 llvm::LogicalResult 221 matchAndRewrite(fir::SaveResultOp op, 222 mlir::PatternRewriter &rewriter) const override { 223 mlir::Operation *call = op.getValue().getDefiningOp(); 224 mlir::Type type = op.getValue().getType(); 225 if (mlir::isa<fir::RecordType>(type) && call && fir::hasBindcAttr(call) && 226 !fir::isa_builtin_cptr_type(type)) { 227 rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValue(), 228 op.getMemref()); 229 } else { 230 rewriter.eraseOp(op); 231 } 232 return mlir::success(); 233 } 234 }; 235 236 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 237 public: 238 using OpRewritePattern::OpRewritePattern; 239 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 240 : OpRewritePattern(context), newArg{newArg} {} 241 llvm::LogicalResult 242 matchAndRewrite(mlir::func::ReturnOp ret, 243 mlir::PatternRewriter &rewriter) const override { 244 auto loc = ret.getLoc(); 245 rewriter.setInsertionPoint(ret); 246 mlir::Value resultValue = ret.getOperand(0); 247 fir::LoadOp resultLoad; 248 mlir::Value resultStorage; 249 // Identify result local storage. 250 if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) { 251 resultLoad = load; 252 resultStorage = load.getMemref(); 253 // The result alloca may be behind a fir.declare, if any. 254 if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>()) 255 resultStorage = declare.getMemref(); 256 } 257 // Replace old local storage with new storage argument, unless 258 // the derived type is C_PTR/C_FUN_PTR, in which case the return 259 // type is updated to return void* (no new argument is passed). 260 if (fir::isa_builtin_cptr_type(resultValue.getType())) { 261 auto module = ret->getParentOfType<mlir::ModuleOp>(); 262 FirOpBuilder builder(rewriter, module); 263 mlir::Value cptr = resultValue; 264 if (resultLoad) { 265 // Replace whole derived type load by component load. 266 cptr = resultLoad.getMemref(); 267 rewriter.setInsertionPoint(resultLoad); 268 } 269 mlir::Value newResultValue = 270 fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); 271 newResultValue = builder.createConvert( 272 loc, getVoidPtrType(ret.getContext()), newResultValue); 273 rewriter.setInsertionPoint(ret); 274 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( 275 ret, mlir::ValueRange{newResultValue}); 276 } else if (resultStorage) { 277 resultStorage.replaceAllUsesWith(newArg); 278 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 279 } else { 280 // The result storage may have been optimized out by a memory to 281 // register pass, this is possible for fir.box results, or fir.record 282 // with no length parameters. Simply store the result in the result 283 // storage. at the return point. 284 rewriter.create<fir::StoreOp>(loc, resultValue, newArg); 285 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 286 } 287 // Delete result old local storage if unused. 288 if (resultStorage) 289 if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>()) 290 if (alloc->use_empty()) 291 rewriter.eraseOp(alloc); 292 return mlir::success(); 293 } 294 295 private: 296 mlir::Value newArg; 297 }; 298 299 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 300 public: 301 using OpRewritePattern::OpRewritePattern; 302 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 303 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 304 llvm::LogicalResult 305 matchAndRewrite(fir::AddrOfOp addrOf, 306 mlir::PatternRewriter &rewriter) const override { 307 auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType()); 308 mlir::FunctionType newFuncTy; 309 if (oldFuncTy.getNumResults() != 0 && 310 fir::isa_builtin_cptr_type(oldFuncTy.getResult(0))) 311 newFuncTy = getCPtrFunctionType(oldFuncTy); 312 else 313 newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); 314 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 315 addrOf.getSymbol()); 316 // Rather than converting all op a function pointer might transit through 317 // (e.g calls, stores, loads, converts...), cast new type to the abstract 318 // type. A conversion will be added when calling indirect calls of abstract 319 // types. 320 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 321 return mlir::success(); 322 } 323 324 private: 325 bool shouldBoxResult; 326 }; 327 328 class AbstractResultOpt 329 : public fir::impl::AbstractResultOptBase<AbstractResultOpt> { 330 public: 331 using fir::impl::AbstractResultOptBase< 332 AbstractResultOpt>::AbstractResultOptBase; 333 334 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, 335 mlir::RewritePatternSet &patterns, 336 mlir::ConversionTarget &target) { 337 auto loc = func.getLoc(); 338 auto *context = &getContext(); 339 // Convert function type itself if it has an abstract result. 340 auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType()); 341 // Scalar derived result of BIND(C) function must be returned according 342 // to the C struct return ABI which is target dependent and implemented in 343 // the target-rewrite pass. 344 if (hasScalarDerivedResult(funcTy) && 345 fir::hasBindcAttr(func.getOperation())) 346 return; 347 if (hasAbstractResult(funcTy)) { 348 if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) { 349 func.setType(getCPtrFunctionType(funcTy)); 350 patterns.insert<ReturnOpConversion>(context, mlir::Value{}); 351 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 352 [](mlir::func::ReturnOp ret) { 353 mlir::Type retTy = ret.getOperand(0).getType(); 354 return !fir::isa_builtin_cptr_type(retTy); 355 }); 356 return; 357 } 358 if (!func.empty()) { 359 // Insert new argument. 360 mlir::OpBuilder rewriter(context); 361 auto resultType = funcTy.getResult(0); 362 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 363 func.insertArgument(0u, argTy, {}, loc); 364 func.eraseResult(0u); 365 mlir::Value newArg = func.getArgument(0u); 366 if (mustEmboxResult(resultType, shouldBoxResult)) { 367 auto bufferType = fir::ReferenceType::get(resultType); 368 rewriter.setInsertionPointToStart(&func.front()); 369 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); 370 } 371 patterns.insert<ReturnOpConversion>(context, newArg); 372 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 373 [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); }); 374 assert(func.getFunctionType() == 375 getNewFunctionType(funcTy, shouldBoxResult)); 376 } else { 377 llvm::SmallVector<mlir::DictionaryAttr> allArgs; 378 func.getAllArgAttrs(allArgs); 379 allArgs.insert(allArgs.begin(), 380 mlir::DictionaryAttr::get(func->getContext())); 381 func.setType(getNewFunctionType(funcTy, shouldBoxResult)); 382 func.setAllArgAttrs(allArgs); 383 } 384 } 385 } 386 387 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) { 388 return mlir::TypeSwitch<mlir::Type, bool>(type) 389 .Case([](fir::BoxProcType boxProc) { 390 return fir::hasAbstractResult( 391 mlir::cast<mlir::FunctionType>(boxProc.getEleTy())); 392 }) 393 .Case([](fir::PointerType pointer) { 394 return fir::hasAbstractResult( 395 mlir::cast<mlir::FunctionType>(pointer.getEleTy())); 396 }) 397 .Default([](auto &&) { return false; }); 398 } 399 400 void runOnSpecificOperation(fir::GlobalOp global, bool, 401 mlir::RewritePatternSet &, 402 mlir::ConversionTarget &) { 403 if (containsFunctionTypeWithAbstractResult(global.getType())) { 404 TODO(global->getLoc(), "support for procedure pointers"); 405 } 406 } 407 408 /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work. 409 void runOnModule() { 410 mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation()); 411 412 auto pass = std::make_unique<AbstractResultOpt>(); 413 pass->copyOptionValuesFrom(this); 414 mlir::OpPassManager pipeline; 415 pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()}); 416 417 // Run the pass on all operations directly nested inside of the ModuleOp 418 // we can't just call runOnSpecificOperation here because the pass 419 // implementation only works when scoped to a particular func.func or 420 // fir.global 421 for (mlir::Region ®ion : mod->getRegions()) { 422 for (mlir::Block &block : region.getBlocks()) { 423 for (mlir::Operation &op : block.getOperations()) { 424 if (mlir::failed(runPipeline(pipeline, &op))) { 425 mlir::emitError(op.getLoc(), "Failed to run abstract result pass"); 426 signalPassFailure(); 427 return; 428 } 429 } 430 } 431 } 432 } 433 434 void runOnOperation() override { 435 auto *context = &this->getContext(); 436 mlir::Operation *op = this->getOperation(); 437 if (mlir::isa<mlir::ModuleOp>(op)) { 438 runOnModule(); 439 return; 440 } 441 442 LazySymbolTable symbolTable(op); 443 444 mlir::RewritePatternSet patterns(context); 445 mlir::ConversionTarget target = *context; 446 const bool shouldBoxResult = this->passResultAsBox.getValue(); 447 448 mlir::TypeSwitch<mlir::Operation *, void>(op) 449 .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) { 450 runOnSpecificOperation(op, shouldBoxResult, patterns, target); 451 }); 452 453 // Convert the calls and, if needed, the ReturnOp in the function body. 454 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, 455 mlir::func::FuncDialect>(); 456 target.addIllegalOp<fir::SaveResultOp>(); 457 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 458 mlir::FunctionType funTy = call.getFunctionType(); 459 if (hasScalarDerivedResult(funTy) && 460 fir::hasBindcAttr(call.getOperation())) 461 return true; 462 return !hasAbstractResult(funTy); 463 }); 464 target.addDynamicallyLegalOp<fir::AddrOfOp>([&symbolTable]( 465 fir::AddrOfOp addrOf) { 466 if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) { 467 if (hasScalarDerivedResult(funTy)) { 468 auto func = symbolTable.lookup<mlir::func::FuncOp>( 469 addrOf.getSymbol().getRootReference().getValue()); 470 return func && fir::hasBindcAttr(func.getOperation()); 471 } 472 return !hasAbstractResult(funTy); 473 } 474 return true; 475 }); 476 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 477 mlir::FunctionType funTy = dispatch.getFunctionType(); 478 if (hasScalarDerivedResult(funTy) && 479 fir::hasBindcAttr(dispatch.getOperation())) 480 return true; 481 return !hasAbstractResult(dispatch.getFunctionType()); 482 }); 483 484 patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult); 485 patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult); 486 patterns.insert<SaveResultOpConversion>(context); 487 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); 488 if (mlir::failed( 489 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 490 mlir::emitError(op->getLoc(), "error in converting abstract results\n"); 491 this->signalPassFailure(); 492 } 493 } 494 }; 495 496 } // end anonymous namespace 497 } // namespace fir 498