1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/Todo.h" 11 #include "flang/Optimizer/Dialect/FIRDialect.h" 12 #include "flang/Optimizer/Dialect/FIROps.h" 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 15 #include "flang/Optimizer/Transforms/Passes.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 18 #include "mlir/IR/Diagnostics.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Pass/PassManager.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 24 namespace fir { 25 #define GEN_PASS_DEF_ABSTRACTRESULTOPT 26 #include "flang/Optimizer/Transforms/Passes.h.inc" 27 } // namespace fir 28 29 #define DEBUG_TYPE "flang-abstract-result-opt" 30 31 using namespace mlir; 32 33 namespace fir { 34 namespace { 35 36 // Helper to only build the symbol table if needed because its build time is 37 // linear on the number of symbols in the module. 
38 struct LazySymbolTable { 39 LazySymbolTable(mlir::Operation *op) 40 : module{op->getParentOfType<mlir::ModuleOp>()} {} 41 void build() { 42 if (table) 43 return; 44 table = std::make_unique<mlir::SymbolTable>(module); 45 } 46 47 template <typename T> 48 T lookup(llvm::StringRef name) { 49 build(); 50 return table->lookup<T>(name); 51 } 52 53 private: 54 std::unique_ptr<mlir::SymbolTable> table; 55 mlir::ModuleOp module; 56 }; 57 58 bool hasScalarDerivedResult(mlir::FunctionType funTy) { 59 // C_PTR/C_FUNPTR are results to void* in this pass, do not consider 60 // them as normal derived types. 61 return funTy.getNumResults() == 1 && 62 mlir::isa<fir::RecordType>(funTy.getResult(0)) && 63 !fir::isa_builtin_cptr_type(funTy.getResult(0)); 64 } 65 66 static mlir::Type getResultArgumentType(mlir::Type resultType, 67 bool shouldBoxResult) { 68 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 69 .Case<fir::SequenceType, fir::RecordType>( 70 [&](mlir::Type type) -> mlir::Type { 71 if (shouldBoxResult) 72 return fir::BoxType::get(type); 73 return fir::ReferenceType::get(type); 74 }) 75 .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type { 76 return fir::ReferenceType::get(type); 77 }) 78 .Default([](mlir::Type) -> mlir::Type { 79 llvm_unreachable("bad abstract result type"); 80 }); 81 } 82 83 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 84 bool shouldBoxResult) { 85 auto resultType = funcTy.getResult(0); 86 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 87 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 88 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 89 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 90 /*resultTypes=*/{}); 91 } 92 93 static mlir::Type getVoidPtrType(mlir::MLIRContext *context) { 94 return fir::ReferenceType::get(mlir::NoneType::get(context)); 95 } 96 97 /// This is for function result types that are of type C_PTR from ISO_C_BINDING. 
/// Follow the ABI for interoperability with C.
static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
  assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
  llvm::SmallVector<mlir::Type> outputTypes{
      getVoidPtrType(funcTy.getContext())};
  return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
                                 outputTypes);
}

/// Must the abstract result be passed as a fir.box rather than a raw
/// reference? Only array/derived results, and only when the pass is
/// configured to box results.
static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
  return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) &&
         shouldBoxResult;
}

/// Rewrite fir.call / fir.dispatch operations returning an abstract result so
/// that the result is passed in memory instead: the buffer designated by the
/// mandatory fir.save_result user becomes the first operand. C_PTR/C_FUNPTR
/// results follow the C ABI: the call is rewritten to return void*, which is
/// then stored back into the save_result storage.
template <typename Op>
class CallConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
      : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}

  llvm::LogicalResult
  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto result = op->getResult(0);
    // Contract: the single use of the abstract result must be a
    // fir.save_result that designates its storage.
    if (!result.hasOneUse()) {
      mlir::emitError(loc,
                      "calls with abstract result must have exactly one user");
      return mlir::failure();
    }
    auto saveResult =
        mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
    if (!saveResult) {
      mlir::emitError(
          loc, "calls with abstract result must be used in fir.save_result");
      return mlir::failure();
    }
    auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
    auto buffer = saveResult.getMemref();
    mlir::Value arg = buffer;
    // Wrap the raw buffer in a box when the callee expects a boxed result.
    if (mustEmboxResult(result.getType(), shouldBoxResult))
      arg = rewriter.create<fir::EmboxOp>(
          loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
          saveResult.getTypeparams());

    llvm::SmallVector<mlir::Type> newResultTypes;
    // C_PTR/C_FUNPTR: the rewritten call returns void* directly instead of
    // taking a hidden result argument.
    bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
    if (isResultBuiltinCPtr)
      newResultTypes.emplace_back(getVoidPtrType(result.getContext()));

    Op newOp;
    // fir::CallOp specific handling.
    if constexpr (std::is_same_v<Op, fir::CallOp>) {
      if (op.getCallee()) {
        // Direct call: prepend the result buffer (unless C_PTR) to the
        // original operands.
        llvm::SmallVector<mlir::Value> newOperands;
        if (!isResultBuiltinCPtr)
          newOperands.emplace_back(arg);
        newOperands.append(op.getOperands().begin(), op.getOperands().end());
        newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
                                             newResultTypes, newOperands);
      } else {
        // Indirect calls.
        // Rebuild the callee function type with the hidden result argument,
        // then cast the function pointer (operand 0) to that type.
        llvm::SmallVector<mlir::Type> newInputTypes;
        if (!isResultBuiltinCPtr)
          newInputTypes.emplace_back(argType);
        for (auto operand : op.getOperands().drop_front())
          newInputTypes.push_back(operand.getType());
        auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
                                                 newResultTypes);

        llvm::SmallVector<mlir::Value> newOperands;
        newOperands.push_back(
            rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
        if (!isResultBuiltinCPtr)
          newOperands.push_back(arg);
        newOperands.append(op.getOperands().begin() + 1,
                           op.getOperands().end());
        newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
                                             newResultTypes, newOperands);
      }
    }

    // fir::DispatchOp specific handling.
    if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
      llvm::SmallVector<mlir::Value> newOperands;
      if (!isResultBuiltinCPtr)
        newOperands.emplace_back(arg);
      // Prepending the hidden result argument shifts the PASS argument
      // position by the number of values inserted before the originals.
      unsigned passArgShift = newOperands.size();
      newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
      mlir::IntegerAttr passArgPos;
      if (op.getPassArgPos())
        passArgPos =
            rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
      newOp = rewriter.create<fir::DispatchOp>(
          loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
          op.getOperands()[0], newOperands, passArgPos,
          op.getProcedureAttrsAttr());
    }

    if (isResultBuiltinCPtr) {
      // Store the returned void* into the C_PTR/C_FUNPTR address component
      // of the storage designated by the fir.save_result.
      mlir::Value save = saveResult.getMemref();
      auto module = op->template getParentOfType<mlir::ModuleOp>();
      FirOpBuilder builder(rewriter, module);
      mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
          builder, loc, save, result.getType());
      builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
    }
    op->dropAllReferences();
    rewriter.eraseOp(op);
    return mlir::success();
  }

private:
  bool shouldBoxResult;
};

/// Erase fir.save_result operations once the calls have been rewritten.
/// Exception: a non-C_PTR derived-type result of a BIND(C) call is returned
/// by value per the C struct ABI (handled in target-rewrite), so the save is
/// replaced by a plain store into the designated memory.
class SaveResultOpConversion
    : public mlir::OpRewritePattern<fir::SaveResultOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  SaveResultOpConversion(mlir::MLIRContext *context)
      : OpRewritePattern(context) {}
  llvm::LogicalResult
  matchAndRewrite(fir::SaveResultOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Operation *call = op.getValue().getDefiningOp();
    mlir::Type type = op.getValue().getType();
    if (mlir::isa<fir::RecordType>(type) && call && fir::hasBindcAttr(call) &&
        !fir::isa_builtin_cptr_type(type)) {
      rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValue(),
                                                op.getMemref());
    } else {
      rewriter.eraseOp(op);
    }
    return mlir::success();
  }
};

/// Shared rewrite for func.return / gpu.return: store the returned value into
/// the hidden result argument `newArg` (or return a void* for C_PTR results,
/// in which case `newArg` is null) and clean up any now-unused local result
/// storage.
template <typename OpTy>
static mlir::LogicalResult
processReturnLikeOp(OpTy ret, mlir::Value newArg,
                    mlir::PatternRewriter &rewriter) {
  auto loc = ret.getLoc();
  rewriter.setInsertionPoint(ret);
  mlir::Value resultValue = ret.getOperand(0);
  fir::LoadOp resultLoad;
  mlir::Value resultStorage;
  // Identify result local storage.
  if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
    resultLoad = load;
    resultStorage = load.getMemref();
    // The result alloca may be behind a fir.declare, if any.
    if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
      resultStorage = declare.getMemref();
  }
  // Replace old local storage with new storage argument, unless
  // the derived type is C_PTR/C_FUN_PTR, in which case the return
  // type is updated to return void* (no new argument is passed).
  if (fir::isa_builtin_cptr_type(resultValue.getType())) {
    auto module = ret->template getParentOfType<mlir::ModuleOp>();
    FirOpBuilder builder(rewriter, module);
    mlir::Value cptr = resultValue;
    if (resultLoad) {
      // Replace whole derived type load by component load.
      cptr = resultLoad.getMemref();
      rewriter.setInsertionPoint(resultLoad);
    }
    // Extract the address component of the C_PTR and return it as void*.
    mlir::Value newResultValue =
        fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
    newResultValue = builder.createConvert(
        loc, getVoidPtrType(ret.getContext()), newResultValue);
    rewriter.setInsertionPoint(ret);
    rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue});
  } else if (resultStorage) {
    // Redirect all uses of the local result storage to the hidden result
    // argument, and return nothing.
    resultStorage.replaceAllUsesWith(newArg);
    rewriter.replaceOpWithNewOp<OpTy>(ret);
  } else {
    // The result storage may have been optimized out by a memory to
    // register pass, this is possible for fir.box results, or fir.record
    // with no length parameters. Simply store the result in the result
    // storage at the return point.
    rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
    rewriter.replaceOpWithNewOp<OpTy>(ret);
  }
  // Delete result old local storage if unused.
  if (resultStorage)
    if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
      if (alloc->use_empty())
        rewriter.eraseOp(alloc);
  return mlir::success();
}

/// Rewrite func.return in a function with an abstract result: the returned
/// value is stored into the hidden result argument `newArg` (null for the
/// C_PTR-to-void* case, where the return value itself is rewritten).
class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
      : OpRewritePattern(context), newArg{newArg} {}
  llvm::LogicalResult
  matchAndRewrite(mlir::func::ReturnOp ret,
                  mlir::PatternRewriter &rewriter) const override {
    return processReturnLikeOp(ret, newArg, rewriter);
  }

private:
  mlir::Value newArg;
};

/// Same rewrite as ReturnOpConversion, but for gpu.return inside gpu.func.
class GPUReturnOpConversion
    : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
      : OpRewritePattern(context), newArg{newArg} {}
  llvm::LogicalResult
  matchAndRewrite(mlir::gpu::ReturnOp ret,
                  mlir::PatternRewriter &rewriter) const override {
    return processReturnLikeOp(ret, newArg, rewriter);
  }

private:
  mlir::Value newArg;
};

/// Rewrite fir.address_of of a function symbol with an abstract result: take
/// the address at the rewritten function type, then convert the result back
/// to the abstract type so existing users are unaffected.
class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
      : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
  llvm::LogicalResult
  matchAndRewrite(fir::AddrOfOp addrOf,
                  mlir::PatternRewriter &rewriter) const override {
    auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
    mlir::FunctionType newFuncTy;
    // C_PTR results follow the C ABI and are rewritten to return void*.
    if (oldFuncTy.getNumResults() != 0 &&
        fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
      newFuncTy =
          getCPtrFunctionType(oldFuncTy);
    else
      newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
    auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
                                                    addrOf.getSymbol());
    // Rather than converting all op a function pointer might transit through
    // (e.g calls, stores, loads, converts...), cast new type to the abstract
    // type. A conversion will be added when calling indirect calls of abstract
    // types.
    rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
    return mlir::success();
  }

private:
  bool shouldBoxResult;
};

/// Pass rewriting functions with "abstract results" (array, derived type, or
/// box results) so the result is instead communicated through a hidden first
/// argument (or a void* return for C_PTR/C_FUNPTR), together with all related
/// calls, dispatches, returns, and address-of operations.
class AbstractResultOpt
    : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
public:
  using fir::impl::AbstractResultOptBase<
      AbstractResultOpt>::AbstractResultOptBase;

  /// Rewrite the signature (and entry block, when a body is present) of a
  /// function-like op with an abstract result, and register the return-op
  /// patterns and legality rules needed to convert its body.
  template <typename OpTy>
  void runOnFunctionLikeOperation(OpTy func, bool shouldBoxResult,
                                  mlir::RewritePatternSet &patterns,
                                  mlir::ConversionTarget &target) {
    auto loc = func.getLoc();
    auto *context = &getContext();
    // Convert function type itself if it has an abstract result.
    auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
    // Scalar derived result of BIND(C) function must be returned according
    // to the C struct return ABI which is target dependent and implemented in
    // the target-rewrite pass.
    if (hasScalarDerivedResult(funcTy) &&
        fir::hasBindcAttr(func.getOperation()))
      return;
    if (hasAbstractResult(funcTy)) {
      if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
        // C_PTR result: the function now returns void*; no hidden argument
        // is added, so the ReturnOpConversion gets a null newArg.
        func.setType(getCPtrFunctionType(funcTy));
        patterns.insert<ReturnOpConversion>(context, mlir::Value{});
        target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
            [](mlir::func::ReturnOp ret) {
              mlir::Type retTy = ret.getOperand(0).getType();
              return !fir::isa_builtin_cptr_type(retTy);
            });
        return;
      }
      if (!func.empty()) {
        // Insert new argument.
        mlir::OpBuilder rewriter(context);
        auto resultType = funcTy.getResult(0);
        auto argTy = getResultArgumentType(resultType, shouldBoxResult);
        func.insertArgument(0u, argTy, {}, loc);
        func.eraseResult(0u);
        mlir::Value newArg = func.getArgument(0u);
        if (mustEmboxResult(resultType, shouldBoxResult)) {
          // Boxed result argument: extract the raw buffer address at the
          // function entry for use by the return rewrites.
          auto bufferType = fir::ReferenceType::get(resultType);
          rewriter.setInsertionPointToStart(&func.front());
          newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
        }
        patterns.insert<ReturnOpConversion>(context, newArg);
        target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
            [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
        patterns.insert<GPUReturnOpConversion>(context, newArg);
        target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>(
            [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); });
        assert(func.getFunctionType() ==
               getNewFunctionType(funcTy, shouldBoxResult));
      } else {
        // Declaration only: rewrite the type and shift the argument
        // attributes to make room for the new hidden result argument.
        llvm::SmallVector<mlir::DictionaryAttr> allArgs;
        func.getAllArgAttrs(allArgs);
        allArgs.insert(allArgs.begin(),
                       mlir::DictionaryAttr::get(func->getContext()));
        func.setType(getNewFunctionType(funcTy, shouldBoxResult));
        func.setAllArgAttrs(allArgs);
      }
    }
  }

  void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
                              mlir::RewritePatternSet &patterns,
                              mlir::ConversionTarget &target) {
    runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
  }

  void runOnSpecificOperation(mlir::gpu::GPUFuncOp func, bool shouldBoxResult,
                              mlir::RewritePatternSet &patterns,
                              mlir::ConversionTarget &target) {
    runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
  }

  /// Does \p type wrap a procedure type (boxproc or pointer-to-function)
  /// whose function type has an abstract result?
  inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
    return mlir::TypeSwitch<mlir::Type, bool>(type)
        .Case([](fir::BoxProcType boxProc) {
          return fir::hasAbstractResult(
              mlir::cast<mlir::FunctionType>(boxProc.getEleTy()));
        })
        .Case([](fir::PointerType pointer) {
          return fir::hasAbstractResult(
              mlir::cast<mlir::FunctionType>(pointer.getEleTy()));
        })
        .Default([](auto &&) { return false; });
  }

  void runOnSpecificOperation(fir::GlobalOp global, bool,
                              mlir::RewritePatternSet &,
                              mlir::ConversionTarget &) {
    if (containsFunctionTypeWithAbstractResult(global.getType())) {
      TODO(global->getLoc(), "support for procedure pointers");
    }
  }

  /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
  void runOnModule() {
    mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());

    // Re-run this same pass, scoped to each top-level op, via a nested
    // pipeline carrying the current option values.
    auto pass = std::make_unique<AbstractResultOpt>();
    pass->copyOptionValuesFrom(this);
    mlir::OpPassManager pipeline;
    pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});

    // Run the pass on all operations directly nested inside of the ModuleOp
    // we can't just call runOnSpecificOperation here because the pass
    // implementation only works when scoped to a particular func.func or
    // fir.global
    for (mlir::Region &region : mod->getRegions()) {
      for (mlir::Block &block : region.getBlocks()) {
        for (mlir::Operation &op : block.getOperations()) {
          if (mlir::failed(runPipeline(pipeline, &op))) {
            mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
            signalPassFailure();
            return;
          }
        }
      }
    }
  }

  void runOnOperation() override {
    auto *context = &this->getContext();
    mlir::Operation *op = this->getOperation();
    if (mlir::isa<mlir::ModuleOp>(op)) {
      runOnModule();
      return;
    }

    LazySymbolTable symbolTable(op);

    mlir::RewritePatternSet patterns(context);
    mlir::ConversionTarget target = *context;
    const bool shouldBoxResult = this->passResultAsBox.getValue();

    // First rewrite the scoped operation's own signature/return handling.
    mlir::TypeSwitch<mlir::Operation *, void>(op)
        .Case<mlir::func::FuncOp, fir::GlobalOp, mlir::gpu::GPUFuncOp>(
            [&](auto op) {
              runOnSpecificOperation(op, shouldBoxResult, patterns, target);
            });

    // Convert the calls and, if needed, the ReturnOp in the function body.
    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
                           mlir::func::FuncDialect>();
    target.addIllegalOp<fir::SaveResultOp>();
    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
      mlir::FunctionType funTy = call.getFunctionType();
      // BIND(C) scalar derived results are left for the target-rewrite pass.
      if (hasScalarDerivedResult(funTy) &&
          fir::hasBindcAttr(call.getOperation()))
        return true;
      return !hasAbstractResult(funTy);
    });
    target.addDynamicallyLegalOp<fir::AddrOfOp>([&symbolTable](
                                                    fir::AddrOfOp addrOf) {
      if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) {
        if (hasScalarDerivedResult(funTy)) {
          // Legal only when the referenced function is BIND(C); its ABI is
          // handled by target-rewrite, not this pass.
          auto func = symbolTable.lookup<mlir::func::FuncOp>(
              addrOf.getSymbol().getRootReference().getValue());
          return func && fir::hasBindcAttr(func.getOperation());
        }
        return !hasAbstractResult(funTy);
      }
      return true;
    });
    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
      mlir::FunctionType funTy = dispatch.getFunctionType();
      if (hasScalarDerivedResult(funTy) &&
          fir::hasBindcAttr(dispatch.getOperation()))
        return true;
      return !hasAbstractResult(dispatch.getFunctionType());
    });

    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
    patterns.insert<SaveResultOpConversion>(context);
    patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
    if (mlir::failed(
            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
      mlir::emitError(op->getLoc(), "error in converting abstract results\n");
      this->signalPassFailure();
    }
  }
};

} // end anonymous namespace
} // namespace fir