1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/Todo.h" 11 #include "flang/Optimizer/Dialect/FIRDialect.h" 12 #include "flang/Optimizer/Dialect/FIROps.h" 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 15 #include "flang/Optimizer/Transforms/Passes.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/IR/Diagnostics.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Pass/PassManager.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace fir { 24 #define GEN_PASS_DEF_ABSTRACTRESULTOPT 25 #include "flang/Optimizer/Transforms/Passes.h.inc" 26 } // namespace fir 27 28 #define DEBUG_TYPE "flang-abstract-result-opt" 29 30 using namespace mlir; 31 32 namespace fir { 33 namespace { 34 35 static mlir::Type getResultArgumentType(mlir::Type resultType, 36 bool shouldBoxResult) { 37 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 38 .Case<fir::SequenceType, fir::RecordType>( 39 [&](mlir::Type type) -> mlir::Type { 40 if (shouldBoxResult) 41 return fir::BoxType::get(type); 42 return fir::ReferenceType::get(type); 43 }) 44 .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type { 45 return fir::ReferenceType::get(type); 46 }) 47 .Default([](mlir::Type) -> mlir::Type { 48 llvm_unreachable("bad abstract result type"); 49 }); 50 } 51 52 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 53 bool shouldBoxResult) { 54 auto resultType = funcTy.getResult(0); 55 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 56 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 57 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 58 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 59 /*resultTypes=*/{}); 60 } 61 62 /// This is for function result types that are of type C_PTR from ISO_C_BINDING. 63 /// Follow the ABI for interoperability with C. 64 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) { 65 auto resultType = funcTy.getResult(0); 66 assert(fir::isa_builtin_cptr_type(resultType)); 67 llvm::SmallVector<mlir::Type> outputTypes; 68 auto recTy = mlir::dyn_cast<fir::RecordType>(resultType); 69 outputTypes.emplace_back(recTy.getTypeList()[0].second); 70 return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(), 71 outputTypes); 72 } 73 74 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { 75 return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) && 76 shouldBoxResult; 77 } 78 79 template <typename Op> 80 class CallConversion : public mlir::OpRewritePattern<Op> { 81 public: 82 using mlir::OpRewritePattern<Op>::OpRewritePattern; 83 84 CallConversion(mlir::MLIRContext *context, bool shouldBoxResult) 85 : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {} 86 87 llvm::LogicalResult 88 matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { 89 auto loc = op.getLoc(); 90 auto result = op->getResult(0); 91 if (!result.hasOneUse()) { 92 mlir::emitError(loc, 93 "calls with abstract result must have exactly one user"); 94 return mlir::failure(); 95 } 96 auto saveResult = 97 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 98 if (!saveResult) { 99 mlir::emitError( 100 loc, "calls with abstract result must be used in fir.save_result"); 101 return mlir::failure(); 102 } 103 auto argType = getResultArgumentType(result.getType(), shouldBoxResult); 104 auto buffer = saveResult.getMemref(); 105 mlir::Value arg = buffer; 106 if (mustEmboxResult(result.getType(), shouldBoxResult)) 107 arg = rewriter.create<fir::EmboxOp>( 108 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 109 saveResult.getTypeparams()); 110 111 llvm::SmallVector<mlir::Type> newResultTypes; 112 // TODO: This should be generalized for derived types, and it is 113 // architecture and OS dependent. 114 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); 115 Op newOp; 116 if (isResultBuiltinCPtr) { 117 auto recTy = mlir::dyn_cast<fir::RecordType>(result.getType()); 118 newResultTypes.emplace_back(recTy.getTypeList()[0].second); 119 } 120 121 // fir::CallOp specific handling. 122 if constexpr (std::is_same_v<Op, fir::CallOp>) { 123 if (op.getCallee()) { 124 llvm::SmallVector<mlir::Value> newOperands; 125 if (!isResultBuiltinCPtr) 126 newOperands.emplace_back(arg); 127 newOperands.append(op.getOperands().begin(), op.getOperands().end()); 128 newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(), 129 newResultTypes, newOperands); 130 } else { 131 // Indirect calls. 132 llvm::SmallVector<mlir::Type> newInputTypes; 133 if (!isResultBuiltinCPtr) 134 newInputTypes.emplace_back(argType); 135 for (auto operand : op.getOperands().drop_front()) 136 newInputTypes.push_back(operand.getType()); 137 auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes, 138 newResultTypes); 139 140 llvm::SmallVector<mlir::Value> newOperands; 141 newOperands.push_back( 142 rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0))); 143 if (!isResultBuiltinCPtr) 144 newOperands.push_back(arg); 145 newOperands.append(op.getOperands().begin() + 1, 146 op.getOperands().end()); 147 newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, 148 newResultTypes, newOperands); 149 } 150 } 151 152 // fir::DispatchOp specific handling. 153 if constexpr (std::is_same_v<Op, fir::DispatchOp>) { 154 llvm::SmallVector<mlir::Value> newOperands; 155 if (!isResultBuiltinCPtr) 156 newOperands.emplace_back(arg); 157 unsigned passArgShift = newOperands.size(); 158 newOperands.append(op.getOperands().begin() + 1, op.getOperands().end()); 159 160 fir::DispatchOp newDispatchOp; 161 if (op.getPassArgPos()) 162 newOp = rewriter.create<fir::DispatchOp>( 163 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), 164 op.getOperands()[0], newOperands, 165 rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift)); 166 else 167 newOp = rewriter.create<fir::DispatchOp>( 168 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), 169 op.getOperands()[0], newOperands, nullptr); 170 } 171 172 if (isResultBuiltinCPtr) { 173 mlir::Value save = saveResult.getMemref(); 174 auto module = op->template getParentOfType<mlir::ModuleOp>(); 175 FirOpBuilder builder(rewriter, module); 176 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( 177 builder, loc, save, result.getType()); 178 rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr); 179 } 180 op->dropAllReferences(); 181 rewriter.eraseOp(op); 182 return mlir::success(); 183 } 184 185 private: 186 bool shouldBoxResult; 187 }; 188 189 class SaveResultOpConversion 190 : public mlir::OpRewritePattern<fir::SaveResultOp> { 191 public: 192 using OpRewritePattern::OpRewritePattern; 193 SaveResultOpConversion(mlir::MLIRContext *context) 194 : OpRewritePattern(context) {} 195 llvm::LogicalResult 196 matchAndRewrite(fir::SaveResultOp op, 197 mlir::PatternRewriter &rewriter) const override { 198 rewriter.eraseOp(op); 199 return mlir::success(); 200 } 201 }; 202 203 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 204 public: 205 using OpRewritePattern::OpRewritePattern; 206 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 207 : OpRewritePattern(context), newArg{newArg} {} 208 llvm::LogicalResult 209 matchAndRewrite(mlir::func::ReturnOp ret, 210 mlir::PatternRewriter &rewriter) const override { 211 auto loc = ret.getLoc(); 212 rewriter.setInsertionPoint(ret); 213 auto returnedValue = ret.getOperand(0); 214 bool replacedStorage = false; 215 if (auto *op = returnedValue.getDefiningOp()) 216 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) { 217 auto resultStorage = load.getMemref(); 218 // The result alloca may be behind a fir.declare, if any. 219 if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>( 220 resultStorage.getDefiningOp())) 221 resultStorage = declare.getMemref(); 222 // TODO: This should be generalized for derived types, and it is 223 // architecture and OS dependent. 224 if (fir::isa_builtin_cptr_type(returnedValue.getType())) { 225 rewriter.eraseOp(load); 226 auto module = ret->getParentOfType<mlir::ModuleOp>(); 227 FirOpBuilder builder(rewriter, module); 228 mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr( 229 builder, loc, resultStorage, returnedValue.getType()); 230 mlir::Value retValue = rewriter.create<fir::LoadOp>( 231 loc, fir::unwrapRefType(retAddr.getType()), retAddr); 232 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( 233 ret, mlir::ValueRange{retValue}); 234 return mlir::success(); 235 } 236 resultStorage.replaceAllUsesWith(newArg); 237 replacedStorage = true; 238 if (auto *alloc = resultStorage.getDefiningOp()) 239 if (alloc->use_empty()) 240 rewriter.eraseOp(alloc); 241 } 242 // The result storage may have been optimized out by a memory to 243 // register pass, this is possible for fir.box results, or fir.record 244 // with no length parameters. Simply store the result in the result storage. 245 // at the return point. 246 if (!replacedStorage) 247 rewriter.create<fir::StoreOp>(loc, returnedValue, newArg); 248 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 249 return mlir::success(); 250 } 251 252 private: 253 mlir::Value newArg; 254 }; 255 256 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 257 public: 258 using OpRewritePattern::OpRewritePattern; 259 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 260 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 261 llvm::LogicalResult 262 matchAndRewrite(fir::AddrOfOp addrOf, 263 mlir::PatternRewriter &rewriter) const override { 264 auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType()); 265 mlir::FunctionType newFuncTy; 266 // TODO: This should be generalized for derived types, and it is 267 // architecture and OS dependent. 268 if (oldFuncTy.getNumResults() != 0 && 269 fir::isa_builtin_cptr_type(oldFuncTy.getResult(0))) 270 newFuncTy = getCPtrFunctionType(oldFuncTy); 271 else 272 newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); 273 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 274 addrOf.getSymbol()); 275 // Rather than converting all op a function pointer might transit through 276 // (e.g calls, stores, loads, converts...), cast new type to the abstract 277 // type. A conversion will be added when calling indirect calls of abstract 278 // types. 279 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 280 return mlir::success(); 281 } 282 283 private: 284 bool shouldBoxResult; 285 }; 286 287 class AbstractResultOpt 288 : public fir::impl::AbstractResultOptBase<AbstractResultOpt> { 289 public: 290 using fir::impl::AbstractResultOptBase< 291 AbstractResultOpt>::AbstractResultOptBase; 292 293 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, 294 mlir::RewritePatternSet &patterns, 295 mlir::ConversionTarget &target) { 296 auto loc = func.getLoc(); 297 auto *context = &getContext(); 298 // Convert function type itself if it has an abstract result. 299 auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType()); 300 if (hasAbstractResult(funcTy)) { 301 // TODO: This should be generalized for derived types, and it is 302 // architecture and OS dependent. 303 if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) { 304 func.setType(getCPtrFunctionType(funcTy)); 305 patterns.insert<ReturnOpConversion>(context, mlir::Value{}); 306 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 307 [](mlir::func::ReturnOp ret) { 308 mlir::Type retTy = ret.getOperand(0).getType(); 309 return !fir::isa_builtin_cptr_type(retTy); 310 }); 311 return; 312 } 313 if (!func.empty()) { 314 // Insert new argument. 315 mlir::OpBuilder rewriter(context); 316 auto resultType = funcTy.getResult(0); 317 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 318 func.insertArgument(0u, argTy, {}, loc); 319 func.eraseResult(0u); 320 mlir::Value newArg = func.getArgument(0u); 321 if (mustEmboxResult(resultType, shouldBoxResult)) { 322 auto bufferType = fir::ReferenceType::get(resultType); 323 rewriter.setInsertionPointToStart(&func.front()); 324 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); 325 } 326 patterns.insert<ReturnOpConversion>(context, newArg); 327 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 328 [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); }); 329 assert(func.getFunctionType() == 330 getNewFunctionType(funcTy, shouldBoxResult)); 331 } else { 332 llvm::SmallVector<mlir::DictionaryAttr> allArgs; 333 func.getAllArgAttrs(allArgs); 334 allArgs.insert(allArgs.begin(), 335 mlir::DictionaryAttr::get(func->getContext())); 336 func.setType(getNewFunctionType(funcTy, shouldBoxResult)); 337 func.setAllArgAttrs(allArgs); 338 } 339 } 340 } 341 342 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) { 343 return mlir::TypeSwitch<mlir::Type, bool>(type) 344 .Case([](fir::BoxProcType boxProc) { 345 return fir::hasAbstractResult( 346 mlir::cast<mlir::FunctionType>(boxProc.getEleTy())); 347 }) 348 .Case([](fir::PointerType pointer) { 349 return fir::hasAbstractResult( 350 mlir::cast<mlir::FunctionType>(pointer.getEleTy())); 351 }) 352 .Default([](auto &&) { return false; }); 353 } 354 355 void runOnSpecificOperation(fir::GlobalOp global, bool, 356 mlir::RewritePatternSet &, 357 mlir::ConversionTarget &) { 358 if (containsFunctionTypeWithAbstractResult(global.getType())) { 359 TODO(global->getLoc(), "support for procedure pointers"); 360 } 361 } 362 363 /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work. 364 void runOnModule() { 365 mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation()); 366 367 auto pass = std::make_unique<AbstractResultOpt>(); 368 pass->copyOptionValuesFrom(this); 369 mlir::OpPassManager pipeline; 370 pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()}); 371 372 // Run the pass on all operations directly nested inside of the ModuleOp 373 // we can't just call runOnSpecificOperation here because the pass 374 // implementation only works when scoped to a particular func.func or 375 // fir.global 376 for (mlir::Region ®ion : mod->getRegions()) { 377 for (mlir::Block &block : region.getBlocks()) { 378 for (mlir::Operation &op : block.getOperations()) { 379 if (mlir::failed(runPipeline(pipeline, &op))) { 380 mlir::emitError(op.getLoc(), "Failed to run abstract result pass"); 381 signalPassFailure(); 382 return; 383 } 384 } 385 } 386 } 387 } 388 389 void runOnOperation() override { 390 auto *context = &this->getContext(); 391 mlir::Operation *op = this->getOperation(); 392 if (mlir::isa<mlir::ModuleOp>(op)) { 393 runOnModule(); 394 return; 395 } 396 397 mlir::RewritePatternSet patterns(context); 398 mlir::ConversionTarget target = *context; 399 const bool shouldBoxResult = this->passResultAsBox.getValue(); 400 401 mlir::TypeSwitch<mlir::Operation *, void>(op) 402 .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) { 403 runOnSpecificOperation(op, shouldBoxResult, patterns, target); 404 }); 405 406 // Convert the calls and, if needed, the ReturnOp in the function body. 407 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, 408 mlir::func::FuncDialect>(); 409 target.addIllegalOp<fir::SaveResultOp>(); 410 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 411 return !hasAbstractResult(call.getFunctionType()); 412 }); 413 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { 414 if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) 415 return !hasAbstractResult(funTy); 416 return true; 417 }); 418 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 419 return !hasAbstractResult(dispatch.getFunctionType()); 420 }); 421 422 patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult); 423 patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult); 424 patterns.insert<SaveResultOpConversion>(context); 425 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); 426 if (mlir::failed( 427 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 428 mlir::emitError(op->getLoc(), "error in converting abstract results\n"); 429 this->signalPassFailure(); 430 } 431 } 432 }; 433 434 } // end anonymous namespace 435 } // namespace fir 436