1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/Todo.h" 11 #include "flang/Optimizer/Dialect/FIRDialect.h" 12 #include "flang/Optimizer/Dialect/FIROps.h" 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 15 #include "flang/Optimizer/Transforms/Passes.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/IR/Diagnostics.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Pass/PassManager.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace fir { 24 #define GEN_PASS_DEF_ABSTRACTRESULTOPT 25 #include "flang/Optimizer/Transforms/Passes.h.inc" 26 } // namespace fir 27 28 #define DEBUG_TYPE "flang-abstract-result-opt" 29 30 using namespace mlir; 31 32 namespace fir { 33 namespace { 34 35 static mlir::Type getResultArgumentType(mlir::Type resultType, 36 bool shouldBoxResult) { 37 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 38 .Case<fir::SequenceType, fir::RecordType>( 39 [&](mlir::Type type) -> mlir::Type { 40 if (shouldBoxResult) 41 return fir::BoxType::get(type); 42 return fir::ReferenceType::get(type); 43 }) 44 .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type { 45 return fir::ReferenceType::get(type); 46 }) 47 .Default([](mlir::Type) -> mlir::Type { 48 llvm_unreachable("bad abstract result type"); 49 }); 50 } 51 52 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 53 bool shouldBoxResult) { 54 auto resultType = funcTy.getResult(0); 55 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 56 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 57 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 58 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 59 /*resultTypes=*/{}); 60 } 61 62 static mlir::Type getVoidPtrType(mlir::MLIRContext *context) { 63 return fir::ReferenceType::get(mlir::NoneType::get(context)); 64 } 65 66 /// This is for function result types that are of type C_PTR from ISO_C_BINDING. 67 /// Follow the ABI for interoperability with C. 68 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) { 69 assert(fir::isa_builtin_cptr_type(funcTy.getResult(0))); 70 llvm::SmallVector<mlir::Type> outputTypes{ 71 getVoidPtrType(funcTy.getContext())}; 72 return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(), 73 outputTypes); 74 } 75 76 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { 77 return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) && 78 shouldBoxResult; 79 } 80 81 template <typename Op> 82 class CallConversion : public mlir::OpRewritePattern<Op> { 83 public: 84 using mlir::OpRewritePattern<Op>::OpRewritePattern; 85 86 CallConversion(mlir::MLIRContext *context, bool shouldBoxResult) 87 : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {} 88 89 llvm::LogicalResult 90 matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { 91 auto loc = op.getLoc(); 92 auto result = op->getResult(0); 93 if (!result.hasOneUse()) { 94 mlir::emitError(loc, 95 "calls with abstract result must have exactly one user"); 96 return mlir::failure(); 97 } 98 auto saveResult = 99 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 100 if (!saveResult) { 101 mlir::emitError( 102 loc, "calls with abstract result must be used in fir.save_result"); 103 return mlir::failure(); 104 } 105 auto argType = getResultArgumentType(result.getType(), shouldBoxResult); 106 auto buffer = saveResult.getMemref(); 107 mlir::Value arg = buffer; 108 if (mustEmboxResult(result.getType(), shouldBoxResult)) 109 arg = rewriter.create<fir::EmboxOp>( 110 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 111 saveResult.getTypeparams()); 112 113 llvm::SmallVector<mlir::Type> newResultTypes; 114 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); 115 if (isResultBuiltinCPtr) 116 newResultTypes.emplace_back(getVoidPtrType(result.getContext())); 117 118 Op newOp; 119 // fir::CallOp specific handling. 120 if constexpr (std::is_same_v<Op, fir::CallOp>) { 121 if (op.getCallee()) { 122 llvm::SmallVector<mlir::Value> newOperands; 123 if (!isResultBuiltinCPtr) 124 newOperands.emplace_back(arg); 125 newOperands.append(op.getOperands().begin(), op.getOperands().end()); 126 newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(), 127 newResultTypes, newOperands); 128 } else { 129 // Indirect calls. 130 llvm::SmallVector<mlir::Type> newInputTypes; 131 if (!isResultBuiltinCPtr) 132 newInputTypes.emplace_back(argType); 133 for (auto operand : op.getOperands().drop_front()) 134 newInputTypes.push_back(operand.getType()); 135 auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes, 136 newResultTypes); 137 138 llvm::SmallVector<mlir::Value> newOperands; 139 newOperands.push_back( 140 rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0))); 141 if (!isResultBuiltinCPtr) 142 newOperands.push_back(arg); 143 newOperands.append(op.getOperands().begin() + 1, 144 op.getOperands().end()); 145 newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, 146 newResultTypes, newOperands); 147 } 148 } 149 150 // fir::DispatchOp specific handling. 151 if constexpr (std::is_same_v<Op, fir::DispatchOp>) { 152 llvm::SmallVector<mlir::Value> newOperands; 153 if (!isResultBuiltinCPtr) 154 newOperands.emplace_back(arg); 155 unsigned passArgShift = newOperands.size(); 156 newOperands.append(op.getOperands().begin() + 1, op.getOperands().end()); 157 mlir::IntegerAttr passArgPos; 158 if (op.getPassArgPos()) 159 passArgPos = 160 rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift); 161 newOp = rewriter.create<fir::DispatchOp>( 162 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), 163 op.getOperands()[0], newOperands, passArgPos, 164 op.getProcedureAttrsAttr()); 165 } 166 167 if (isResultBuiltinCPtr) { 168 mlir::Value save = saveResult.getMemref(); 169 auto module = op->template getParentOfType<mlir::ModuleOp>(); 170 FirOpBuilder builder(rewriter, module); 171 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( 172 builder, loc, save, result.getType()); 173 builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr); 174 } 175 op->dropAllReferences(); 176 rewriter.eraseOp(op); 177 return mlir::success(); 178 } 179 180 private: 181 bool shouldBoxResult; 182 }; 183 184 class SaveResultOpConversion 185 : public mlir::OpRewritePattern<fir::SaveResultOp> { 186 public: 187 using OpRewritePattern::OpRewritePattern; 188 SaveResultOpConversion(mlir::MLIRContext *context) 189 : OpRewritePattern(context) {} 190 llvm::LogicalResult 191 matchAndRewrite(fir::SaveResultOp op, 192 mlir::PatternRewriter &rewriter) const override { 193 rewriter.eraseOp(op); 194 return mlir::success(); 195 } 196 }; 197 198 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 199 public: 200 using OpRewritePattern::OpRewritePattern; 201 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 202 : OpRewritePattern(context), newArg{newArg} {} 203 llvm::LogicalResult 204 matchAndRewrite(mlir::func::ReturnOp ret, 205 mlir::PatternRewriter &rewriter) const override { 206 auto loc = ret.getLoc(); 207 rewriter.setInsertionPoint(ret); 208 mlir::Value resultValue = ret.getOperand(0); 209 fir::LoadOp resultLoad; 210 mlir::Value resultStorage; 211 // Identify result local storage. 212 if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) { 213 resultLoad = load; 214 resultStorage = load.getMemref(); 215 // The result alloca may be behind a fir.declare, if any. 216 if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>()) 217 resultStorage = declare.getMemref(); 218 } 219 // Replace old local storage with new storage argument, unless 220 // the derived type is C_PTR/C_FUN_PTR, in which case the return 221 // type is updated to return void* (no new argument is passed). 222 if (fir::isa_builtin_cptr_type(resultValue.getType())) { 223 auto module = ret->getParentOfType<mlir::ModuleOp>(); 224 FirOpBuilder builder(rewriter, module); 225 mlir::Value cptr = resultValue; 226 if (resultLoad) { 227 // Replace whole derived type load by component load. 228 cptr = resultLoad.getMemref(); 229 rewriter.setInsertionPoint(resultLoad); 230 } 231 mlir::Value newResultValue = 232 fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); 233 newResultValue = builder.createConvert( 234 loc, getVoidPtrType(ret.getContext()), newResultValue); 235 rewriter.setInsertionPoint(ret); 236 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( 237 ret, mlir::ValueRange{newResultValue}); 238 } else if (resultStorage) { 239 resultStorage.replaceAllUsesWith(newArg); 240 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 241 } else { 242 // The result storage may have been optimized out by a memory to 243 // register pass, this is possible for fir.box results, or fir.record 244 // with no length parameters. Simply store the result in the result 245 // storage. at the return point. 246 rewriter.create<fir::StoreOp>(loc, resultValue, newArg); 247 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 248 } 249 // Delete result old local storage if unused. 250 if (resultStorage) 251 if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>()) 252 if (alloc->use_empty()) 253 rewriter.eraseOp(alloc); 254 return mlir::success(); 255 } 256 257 private: 258 mlir::Value newArg; 259 }; 260 261 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 262 public: 263 using OpRewritePattern::OpRewritePattern; 264 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 265 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 266 llvm::LogicalResult 267 matchAndRewrite(fir::AddrOfOp addrOf, 268 mlir::PatternRewriter &rewriter) const override { 269 auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType()); 270 mlir::FunctionType newFuncTy; 271 if (oldFuncTy.getNumResults() != 0 && 272 fir::isa_builtin_cptr_type(oldFuncTy.getResult(0))) 273 newFuncTy = getCPtrFunctionType(oldFuncTy); 274 else 275 newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); 276 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 277 addrOf.getSymbol()); 278 // Rather than converting all op a function pointer might transit through 279 // (e.g calls, stores, loads, converts...), cast new type to the abstract 280 // type. A conversion will be added when calling indirect calls of abstract 281 // types. 282 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 283 return mlir::success(); 284 } 285 286 private: 287 bool shouldBoxResult; 288 }; 289 290 class AbstractResultOpt 291 : public fir::impl::AbstractResultOptBase<AbstractResultOpt> { 292 public: 293 using fir::impl::AbstractResultOptBase< 294 AbstractResultOpt>::AbstractResultOptBase; 295 296 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, 297 mlir::RewritePatternSet &patterns, 298 mlir::ConversionTarget &target) { 299 auto loc = func.getLoc(); 300 auto *context = &getContext(); 301 // Convert function type itself if it has an abstract result. 302 auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType()); 303 if (hasAbstractResult(funcTy)) { 304 if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) { 305 func.setType(getCPtrFunctionType(funcTy)); 306 patterns.insert<ReturnOpConversion>(context, mlir::Value{}); 307 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 308 [](mlir::func::ReturnOp ret) { 309 mlir::Type retTy = ret.getOperand(0).getType(); 310 return !fir::isa_builtin_cptr_type(retTy); 311 }); 312 return; 313 } 314 if (!func.empty()) { 315 // Insert new argument. 316 mlir::OpBuilder rewriter(context); 317 auto resultType = funcTy.getResult(0); 318 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 319 func.insertArgument(0u, argTy, {}, loc); 320 func.eraseResult(0u); 321 mlir::Value newArg = func.getArgument(0u); 322 if (mustEmboxResult(resultType, shouldBoxResult)) { 323 auto bufferType = fir::ReferenceType::get(resultType); 324 rewriter.setInsertionPointToStart(&func.front()); 325 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); 326 } 327 patterns.insert<ReturnOpConversion>(context, newArg); 328 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 329 [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); }); 330 assert(func.getFunctionType() == 331 getNewFunctionType(funcTy, shouldBoxResult)); 332 } else { 333 llvm::SmallVector<mlir::DictionaryAttr> allArgs; 334 func.getAllArgAttrs(allArgs); 335 allArgs.insert(allArgs.begin(), 336 mlir::DictionaryAttr::get(func->getContext())); 337 func.setType(getNewFunctionType(funcTy, shouldBoxResult)); 338 func.setAllArgAttrs(allArgs); 339 } 340 } 341 } 342 343 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) { 344 return mlir::TypeSwitch<mlir::Type, bool>(type) 345 .Case([](fir::BoxProcType boxProc) { 346 return fir::hasAbstractResult( 347 mlir::cast<mlir::FunctionType>(boxProc.getEleTy())); 348 }) 349 .Case([](fir::PointerType pointer) { 350 return fir::hasAbstractResult( 351 mlir::cast<mlir::FunctionType>(pointer.getEleTy())); 352 }) 353 .Default([](auto &&) { return false; }); 354 } 355 356 void runOnSpecificOperation(fir::GlobalOp global, bool, 357 mlir::RewritePatternSet &, 358 mlir::ConversionTarget &) { 359 if (containsFunctionTypeWithAbstractResult(global.getType())) { 360 TODO(global->getLoc(), "support for procedure pointers"); 361 } 362 } 363 364 /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work. 365 void runOnModule() { 366 mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation()); 367 368 auto pass = std::make_unique<AbstractResultOpt>(); 369 pass->copyOptionValuesFrom(this); 370 mlir::OpPassManager pipeline; 371 pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()}); 372 373 // Run the pass on all operations directly nested inside of the ModuleOp 374 // we can't just call runOnSpecificOperation here because the pass 375 // implementation only works when scoped to a particular func.func or 376 // fir.global 377 for (mlir::Region ®ion : mod->getRegions()) { 378 for (mlir::Block &block : region.getBlocks()) { 379 for (mlir::Operation &op : block.getOperations()) { 380 if (mlir::failed(runPipeline(pipeline, &op))) { 381 mlir::emitError(op.getLoc(), "Failed to run abstract result pass"); 382 signalPassFailure(); 383 return; 384 } 385 } 386 } 387 } 388 } 389 390 void runOnOperation() override { 391 auto *context = &this->getContext(); 392 mlir::Operation *op = this->getOperation(); 393 if (mlir::isa<mlir::ModuleOp>(op)) { 394 runOnModule(); 395 return; 396 } 397 398 mlir::RewritePatternSet patterns(context); 399 mlir::ConversionTarget target = *context; 400 const bool shouldBoxResult = this->passResultAsBox.getValue(); 401 402 mlir::TypeSwitch<mlir::Operation *, void>(op) 403 .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) { 404 runOnSpecificOperation(op, shouldBoxResult, patterns, target); 405 }); 406 407 // Convert the calls and, if needed, the ReturnOp in the function body. 408 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, 409 mlir::func::FuncDialect>(); 410 target.addIllegalOp<fir::SaveResultOp>(); 411 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 412 return !hasAbstractResult(call.getFunctionType()); 413 }); 414 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { 415 if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) 416 return !hasAbstractResult(funTy); 417 return true; 418 }); 419 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 420 return !hasAbstractResult(dispatch.getFunctionType()); 421 }); 422 423 patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult); 424 patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult); 425 patterns.insert<SaveResultOpConversion>(context); 426 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); 427 if (mlir::failed( 428 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 429 mlir::emitError(op->getLoc(), "error in converting abstract results\n"); 430 this->signalPassFailure(); 431 } 432 } 433 }; 434 435 } // end anonymous namespace 436 } // namespace fir 437