1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/Todo.h" 11 #include "flang/Optimizer/Dialect/FIRDialect.h" 12 #include "flang/Optimizer/Dialect/FIROps.h" 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 15 #include "flang/Optimizer/Transforms/Passes.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/IR/Diagnostics.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Pass/PassManager.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace fir { 24 #define GEN_PASS_DEF_ABSTRACTRESULTOPT 25 #include "flang/Optimizer/Transforms/Passes.h.inc" 26 } // namespace fir 27 28 #define DEBUG_TYPE "flang-abstract-result-opt" 29 30 using namespace mlir; 31 32 namespace fir { 33 namespace { 34 35 static mlir::Type getResultArgumentType(mlir::Type resultType, 36 bool shouldBoxResult) { 37 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 38 .Case<fir::SequenceType, fir::RecordType>( 39 [&](mlir::Type type) -> mlir::Type { 40 if (shouldBoxResult) 41 return fir::BoxType::get(type); 42 return fir::ReferenceType::get(type); 43 }) 44 .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type { 45 return fir::ReferenceType::get(type); 46 }) 47 .Default([](mlir::Type) -> mlir::Type { 48 llvm_unreachable("bad abstract result type"); 49 }); 50 } 51 52 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 53 bool shouldBoxResult) { 54 auto resultType = funcTy.getResult(0); 55 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 56 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 57 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 58 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 59 /*resultTypes=*/{}); 60 } 61 62 static mlir::Type getVoidPtrType(mlir::MLIRContext *context) { 63 return fir::ReferenceType::get(mlir::NoneType::get(context)); 64 } 65 66 /// This is for function result types that are of type C_PTR from ISO_C_BINDING. 67 /// Follow the ABI for interoperability with C. 68 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) { 69 assert(fir::isa_builtin_cptr_type(funcTy.getResult(0))); 70 llvm::SmallVector<mlir::Type> outputTypes{ 71 getVoidPtrType(funcTy.getContext())}; 72 return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(), 73 outputTypes); 74 } 75 76 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { 77 return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) && 78 shouldBoxResult; 79 } 80 81 template <typename Op> 82 class CallConversion : public mlir::OpRewritePattern<Op> { 83 public: 84 using mlir::OpRewritePattern<Op>::OpRewritePattern; 85 86 CallConversion(mlir::MLIRContext *context, bool shouldBoxResult) 87 : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {} 88 89 llvm::LogicalResult 90 matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { 91 auto loc = op.getLoc(); 92 auto result = op->getResult(0); 93 if (!result.hasOneUse()) { 94 mlir::emitError(loc, 95 "calls with abstract result must have exactly one user"); 96 return mlir::failure(); 97 } 98 auto saveResult = 99 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 100 if (!saveResult) { 101 mlir::emitError( 102 loc, "calls with abstract result must be used in fir.save_result"); 103 return mlir::failure(); 104 } 105 auto argType = getResultArgumentType(result.getType(), shouldBoxResult); 106 auto buffer = saveResult.getMemref(); 107 mlir::Value arg = buffer; 108 if (mustEmboxResult(result.getType(), shouldBoxResult)) 109 arg = rewriter.create<fir::EmboxOp>( 110 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 111 saveResult.getTypeparams()); 112 113 llvm::SmallVector<mlir::Type> newResultTypes; 114 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); 115 if (isResultBuiltinCPtr) 116 newResultTypes.emplace_back(getVoidPtrType(result.getContext())); 117 118 Op newOp; 119 // fir::CallOp specific handling. 120 if constexpr (std::is_same_v<Op, fir::CallOp>) { 121 if (op.getCallee()) { 122 llvm::SmallVector<mlir::Value> newOperands; 123 if (!isResultBuiltinCPtr) 124 newOperands.emplace_back(arg); 125 newOperands.append(op.getOperands().begin(), op.getOperands().end()); 126 newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(), 127 newResultTypes, newOperands); 128 } else { 129 // Indirect calls. 130 llvm::SmallVector<mlir::Type> newInputTypes; 131 if (!isResultBuiltinCPtr) 132 newInputTypes.emplace_back(argType); 133 for (auto operand : op.getOperands().drop_front()) 134 newInputTypes.push_back(operand.getType()); 135 auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes, 136 newResultTypes); 137 138 llvm::SmallVector<mlir::Value> newOperands; 139 newOperands.push_back( 140 rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0))); 141 if (!isResultBuiltinCPtr) 142 newOperands.push_back(arg); 143 newOperands.append(op.getOperands().begin() + 1, 144 op.getOperands().end()); 145 newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, 146 newResultTypes, newOperands); 147 } 148 } 149 150 // fir::DispatchOp specific handling. 151 if constexpr (std::is_same_v<Op, fir::DispatchOp>) { 152 llvm::SmallVector<mlir::Value> newOperands; 153 if (!isResultBuiltinCPtr) 154 newOperands.emplace_back(arg); 155 unsigned passArgShift = newOperands.size(); 156 newOperands.append(op.getOperands().begin() + 1, op.getOperands().end()); 157 158 fir::DispatchOp newDispatchOp; 159 if (op.getPassArgPos()) 160 newOp = rewriter.create<fir::DispatchOp>( 161 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), 162 op.getOperands()[0], newOperands, 163 rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift)); 164 else 165 newOp = rewriter.create<fir::DispatchOp>( 166 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), 167 op.getOperands()[0], newOperands, nullptr); 168 } 169 170 if (isResultBuiltinCPtr) { 171 mlir::Value save = saveResult.getMemref(); 172 auto module = op->template getParentOfType<mlir::ModuleOp>(); 173 FirOpBuilder builder(rewriter, module); 174 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( 175 builder, loc, save, result.getType()); 176 builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr); 177 } 178 op->dropAllReferences(); 179 rewriter.eraseOp(op); 180 return mlir::success(); 181 } 182 183 private: 184 bool shouldBoxResult; 185 }; 186 187 class SaveResultOpConversion 188 : public mlir::OpRewritePattern<fir::SaveResultOp> { 189 public: 190 using OpRewritePattern::OpRewritePattern; 191 SaveResultOpConversion(mlir::MLIRContext *context) 192 : OpRewritePattern(context) {} 193 llvm::LogicalResult 194 matchAndRewrite(fir::SaveResultOp op, 195 mlir::PatternRewriter &rewriter) const override { 196 rewriter.eraseOp(op); 197 return mlir::success(); 198 } 199 }; 200 201 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 202 public: 203 using OpRewritePattern::OpRewritePattern; 204 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 205 : OpRewritePattern(context), newArg{newArg} {} 206 llvm::LogicalResult 207 matchAndRewrite(mlir::func::ReturnOp ret, 208 mlir::PatternRewriter &rewriter) const override { 209 auto loc = ret.getLoc(); 210 rewriter.setInsertionPoint(ret); 211 mlir::Value resultValue = ret.getOperand(0); 212 fir::LoadOp resultLoad; 213 mlir::Value resultStorage; 214 // Identify result local storage. 215 if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) { 216 resultLoad = load; 217 resultStorage = load.getMemref(); 218 // The result alloca may be behind a fir.declare, if any. 219 if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>()) 220 resultStorage = declare.getMemref(); 221 } 222 // Replace old local storage with new storage argument, unless 223 // the derived type is C_PTR/C_FUN_PTR, in which case the return 224 // type is updated to return void* (no new argument is passed). 225 if (fir::isa_builtin_cptr_type(resultValue.getType())) { 226 auto module = ret->getParentOfType<mlir::ModuleOp>(); 227 FirOpBuilder builder(rewriter, module); 228 mlir::Value cptr = resultValue; 229 if (resultLoad) { 230 // Replace whole derived type load by component load. 231 cptr = resultLoad.getMemref(); 232 rewriter.setInsertionPoint(resultLoad); 233 } 234 mlir::Value newResultValue = 235 fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); 236 newResultValue = builder.createConvert( 237 loc, getVoidPtrType(ret.getContext()), newResultValue); 238 rewriter.setInsertionPoint(ret); 239 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( 240 ret, mlir::ValueRange{newResultValue}); 241 } else if (resultStorage) { 242 resultStorage.replaceAllUsesWith(newArg); 243 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 244 } else { 245 // The result storage may have been optimized out by a memory to 246 // register pass, this is possible for fir.box results, or fir.record 247 // with no length parameters. Simply store the result in the result 248 // storage. at the return point. 249 rewriter.create<fir::StoreOp>(loc, resultValue, newArg); 250 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 251 } 252 // Delete result old local storage if unused. 253 if (resultStorage) 254 if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>()) 255 if (alloc->use_empty()) 256 rewriter.eraseOp(alloc); 257 return mlir::success(); 258 } 259 260 private: 261 mlir::Value newArg; 262 }; 263 264 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 265 public: 266 using OpRewritePattern::OpRewritePattern; 267 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 268 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 269 llvm::LogicalResult 270 matchAndRewrite(fir::AddrOfOp addrOf, 271 mlir::PatternRewriter &rewriter) const override { 272 auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType()); 273 mlir::FunctionType newFuncTy; 274 if (oldFuncTy.getNumResults() != 0 && 275 fir::isa_builtin_cptr_type(oldFuncTy.getResult(0))) 276 newFuncTy = getCPtrFunctionType(oldFuncTy); 277 else 278 newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); 279 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 280 addrOf.getSymbol()); 281 // Rather than converting all op a function pointer might transit through 282 // (e.g calls, stores, loads, converts...), cast new type to the abstract 283 // type. A conversion will be added when calling indirect calls of abstract 284 // types. 285 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 286 return mlir::success(); 287 } 288 289 private: 290 bool shouldBoxResult; 291 }; 292 293 class AbstractResultOpt 294 : public fir::impl::AbstractResultOptBase<AbstractResultOpt> { 295 public: 296 using fir::impl::AbstractResultOptBase< 297 AbstractResultOpt>::AbstractResultOptBase; 298 299 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, 300 mlir::RewritePatternSet &patterns, 301 mlir::ConversionTarget &target) { 302 auto loc = func.getLoc(); 303 auto *context = &getContext(); 304 // Convert function type itself if it has an abstract result. 305 auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType()); 306 if (hasAbstractResult(funcTy)) { 307 if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) { 308 func.setType(getCPtrFunctionType(funcTy)); 309 patterns.insert<ReturnOpConversion>(context, mlir::Value{}); 310 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 311 [](mlir::func::ReturnOp ret) { 312 mlir::Type retTy = ret.getOperand(0).getType(); 313 return !fir::isa_builtin_cptr_type(retTy); 314 }); 315 return; 316 } 317 if (!func.empty()) { 318 // Insert new argument. 319 mlir::OpBuilder rewriter(context); 320 auto resultType = funcTy.getResult(0); 321 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 322 func.insertArgument(0u, argTy, {}, loc); 323 func.eraseResult(0u); 324 mlir::Value newArg = func.getArgument(0u); 325 if (mustEmboxResult(resultType, shouldBoxResult)) { 326 auto bufferType = fir::ReferenceType::get(resultType); 327 rewriter.setInsertionPointToStart(&func.front()); 328 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); 329 } 330 patterns.insert<ReturnOpConversion>(context, newArg); 331 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 332 [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); }); 333 assert(func.getFunctionType() == 334 getNewFunctionType(funcTy, shouldBoxResult)); 335 } else { 336 llvm::SmallVector<mlir::DictionaryAttr> allArgs; 337 func.getAllArgAttrs(allArgs); 338 allArgs.insert(allArgs.begin(), 339 mlir::DictionaryAttr::get(func->getContext())); 340 func.setType(getNewFunctionType(funcTy, shouldBoxResult)); 341 func.setAllArgAttrs(allArgs); 342 } 343 } 344 } 345 346 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) { 347 return mlir::TypeSwitch<mlir::Type, bool>(type) 348 .Case([](fir::BoxProcType boxProc) { 349 return fir::hasAbstractResult( 350 mlir::cast<mlir::FunctionType>(boxProc.getEleTy())); 351 }) 352 .Case([](fir::PointerType pointer) { 353 return fir::hasAbstractResult( 354 mlir::cast<mlir::FunctionType>(pointer.getEleTy())); 355 }) 356 .Default([](auto &&) { return false; }); 357 } 358 359 void runOnSpecificOperation(fir::GlobalOp global, bool, 360 mlir::RewritePatternSet &, 361 mlir::ConversionTarget &) { 362 if (containsFunctionTypeWithAbstractResult(global.getType())) { 363 TODO(global->getLoc(), "support for procedure pointers"); 364 } 365 } 366 367 /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work. 368 void runOnModule() { 369 mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation()); 370 371 auto pass = std::make_unique<AbstractResultOpt>(); 372 pass->copyOptionValuesFrom(this); 373 mlir::OpPassManager pipeline; 374 pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()}); 375 376 // Run the pass on all operations directly nested inside of the ModuleOp 377 // we can't just call runOnSpecificOperation here because the pass 378 // implementation only works when scoped to a particular func.func or 379 // fir.global 380 for (mlir::Region ®ion : mod->getRegions()) { 381 for (mlir::Block &block : region.getBlocks()) { 382 for (mlir::Operation &op : block.getOperations()) { 383 if (mlir::failed(runPipeline(pipeline, &op))) { 384 mlir::emitError(op.getLoc(), "Failed to run abstract result pass"); 385 signalPassFailure(); 386 return; 387 } 388 } 389 } 390 } 391 } 392 393 void runOnOperation() override { 394 auto *context = &this->getContext(); 395 mlir::Operation *op = this->getOperation(); 396 if (mlir::isa<mlir::ModuleOp>(op)) { 397 runOnModule(); 398 return; 399 } 400 401 mlir::RewritePatternSet patterns(context); 402 mlir::ConversionTarget target = *context; 403 const bool shouldBoxResult = this->passResultAsBox.getValue(); 404 405 mlir::TypeSwitch<mlir::Operation *, void>(op) 406 .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) { 407 runOnSpecificOperation(op, shouldBoxResult, patterns, target); 408 }); 409 410 // Convert the calls and, if needed, the ReturnOp in the function body. 411 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, 412 mlir::func::FuncDialect>(); 413 target.addIllegalOp<fir::SaveResultOp>(); 414 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 415 return !hasAbstractResult(call.getFunctionType()); 416 }); 417 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { 418 if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) 419 return !hasAbstractResult(funTy); 420 return true; 421 }); 422 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 423 return !hasAbstractResult(dispatch.getFunctionType()); 424 }); 425 426 patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult); 427 patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult); 428 patterns.insert<SaveResultOpConversion>(context); 429 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); 430 if (mlir::failed( 431 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 432 mlir::emitError(op->getLoc(), "error in converting abstract results\n"); 433 this->signalPassFailure(); 434 } 435 } 436 }; 437 438 } // end anonymous namespace 439 } // namespace fir 440