1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/Todo.h" 11 #include "flang/Optimizer/Dialect/FIRDialect.h" 12 #include "flang/Optimizer/Dialect/FIROps.h" 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/Optimizer/Support/FIRContext.h" 15 #include "flang/Optimizer/Transforms/Passes.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/IR/Diagnostics.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "mlir/Transforms/Passes.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace fir { 24 #define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT 25 #define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT 26 #include "flang/Optimizer/Transforms/Passes.h.inc" 27 } // namespace fir 28 29 #define DEBUG_TYPE "flang-abstract-result-opt" 30 31 using namespace mlir; 32 33 namespace fir { 34 namespace { 35 36 static mlir::Type getResultArgumentType(mlir::Type resultType, 37 bool shouldBoxResult) { 38 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 39 .Case<fir::SequenceType, fir::RecordType>( 40 [&](mlir::Type type) -> mlir::Type { 41 if (shouldBoxResult) 42 return fir::BoxType::get(type); 43 return fir::ReferenceType::get(type); 44 }) 45 .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type { 46 return fir::ReferenceType::get(type); 47 }) 48 .Default([](mlir::Type) -> mlir::Type { 49 llvm_unreachable("bad abstract result type"); 50 }); 51 } 52 53 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 54 bool shouldBoxResult) { 55 auto resultType = funcTy.getResult(0); 56 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 57 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 58 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 59 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 60 /*resultTypes=*/{}); 61 } 62 63 /// This is for function result types that are of type C_PTR from ISO_C_BINDING. 64 /// Follow the ABI for interoperability with C. 65 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) { 66 auto resultType = funcTy.getResult(0); 67 assert(fir::isa_builtin_cptr_type(resultType)); 68 llvm::SmallVector<mlir::Type> outputTypes; 69 auto recTy = resultType.dyn_cast<fir::RecordType>(); 70 outputTypes.emplace_back(recTy.getTypeList()[0].second); 71 return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(), 72 outputTypes); 73 } 74 75 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { 76 return resultType.isa<fir::SequenceType, fir::RecordType>() && 77 shouldBoxResult; 78 } 79 80 template <typename Op> 81 class CallConversion : public mlir::OpRewritePattern<Op> { 82 public: 83 using mlir::OpRewritePattern<Op>::OpRewritePattern; 84 85 CallConversion(mlir::MLIRContext *context, bool shouldBoxResult) 86 : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {} 87 88 mlir::LogicalResult 89 matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { 90 auto loc = op.getLoc(); 91 auto result = op->getResult(0); 92 if (!result.hasOneUse()) { 93 mlir::emitError(loc, 94 "calls with abstract result must have exactly one user"); 95 return mlir::failure(); 96 } 97 auto saveResult = 98 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 99 if (!saveResult) { 100 mlir::emitError( 101 loc, "calls with abstract result must be used in fir.save_result"); 102 return mlir::failure(); 103 } 104 auto argType = getResultArgumentType(result.getType(), shouldBoxResult); 105 auto buffer = saveResult.getMemref(); 106 mlir::Value arg = buffer; 107 if (mustEmboxResult(result.getType(), shouldBoxResult)) 108 arg = rewriter.create<fir::EmboxOp>( 109 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 110 saveResult.getTypeparams()); 111 112 llvm::SmallVector<mlir::Type> newResultTypes; 113 // TODO: This should be generalized for derived types, and it is 114 // architecture and OS dependent. 115 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); 116 Op newOp; 117 if (isResultBuiltinCPtr) { 118 auto recTy = result.getType().template dyn_cast<fir::RecordType>(); 119 newResultTypes.emplace_back(recTy.getTypeList()[0].second); 120 } 121 122 // fir::CallOp specific handling. 123 if constexpr (std::is_same_v<Op, fir::CallOp>) { 124 if (op.getCallee()) { 125 llvm::SmallVector<mlir::Value> newOperands; 126 if (!isResultBuiltinCPtr) 127 newOperands.emplace_back(arg); 128 newOperands.append(op.getOperands().begin(), op.getOperands().end()); 129 newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(), 130 newResultTypes, newOperands); 131 } else { 132 // Indirect calls. 133 llvm::SmallVector<mlir::Type> newInputTypes; 134 if (!isResultBuiltinCPtr) 135 newInputTypes.emplace_back(argType); 136 for (auto operand : op.getOperands().drop_front()) 137 newInputTypes.push_back(operand.getType()); 138 auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes, 139 newResultTypes); 140 141 llvm::SmallVector<mlir::Value> newOperands; 142 newOperands.push_back( 143 rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0))); 144 if (!isResultBuiltinCPtr) 145 newOperands.push_back(arg); 146 newOperands.append(op.getOperands().begin() + 1, 147 op.getOperands().end()); 148 newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, 149 newResultTypes, newOperands); 150 } 151 } 152 153 // fir::DispatchOp specific handling. 154 if constexpr (std::is_same_v<Op, fir::DispatchOp>) { 155 llvm::SmallVector<mlir::Value> newOperands; 156 if (!isResultBuiltinCPtr) 157 newOperands.emplace_back(arg); 158 unsigned passArgShift = newOperands.size(); 159 newOperands.append(op.getOperands().begin() + 1, op.getOperands().end()); 160 161 fir::DispatchOp newDispatchOp; 162 if (op.getPassArgPos()) 163 newOp = rewriter.create<fir::DispatchOp>( 164 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), 165 op.getOperands()[0], newOperands, 166 rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift)); 167 else 168 newOp = rewriter.create<fir::DispatchOp>( 169 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), 170 op.getOperands()[0], newOperands, nullptr); 171 } 172 173 if (isResultBuiltinCPtr) { 174 mlir::Value save = saveResult.getMemref(); 175 auto module = op->template getParentOfType<mlir::ModuleOp>(); 176 fir::KindMapping kindMap = fir::getKindMapping(module); 177 FirOpBuilder builder(rewriter, kindMap); 178 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( 179 builder, loc, save, result.getType()); 180 rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr); 181 } 182 op->dropAllReferences(); 183 rewriter.eraseOp(op); 184 return mlir::success(); 185 } 186 187 private: 188 bool shouldBoxResult; 189 }; 190 191 class SaveResultOpConversion 192 : public mlir::OpRewritePattern<fir::SaveResultOp> { 193 public: 194 using OpRewritePattern::OpRewritePattern; 195 SaveResultOpConversion(mlir::MLIRContext *context) 196 : OpRewritePattern(context) {} 197 mlir::LogicalResult 198 matchAndRewrite(fir::SaveResultOp op, 199 mlir::PatternRewriter &rewriter) const override { 200 rewriter.eraseOp(op); 201 return mlir::success(); 202 } 203 }; 204 205 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 206 public: 207 using OpRewritePattern::OpRewritePattern; 208 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 209 : OpRewritePattern(context), newArg{newArg} {} 210 mlir::LogicalResult 211 matchAndRewrite(mlir::func::ReturnOp ret, 212 mlir::PatternRewriter &rewriter) const override { 213 auto loc = ret.getLoc(); 214 rewriter.setInsertionPoint(ret); 215 auto returnedValue = ret.getOperand(0); 216 bool replacedStorage = false; 217 if (auto *op = returnedValue.getDefiningOp()) 218 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) { 219 auto resultStorage = load.getMemref(); 220 // TODO: This should be generalized for derived types, and it is 221 // architecture and OS dependent. 222 if (fir::isa_builtin_cptr_type(returnedValue.getType())) { 223 rewriter.eraseOp(load); 224 auto module = ret->getParentOfType<mlir::ModuleOp>(); 225 fir::KindMapping kindMap = fir::getKindMapping(module); 226 FirOpBuilder builder(rewriter, kindMap); 227 mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr( 228 builder, loc, resultStorage, returnedValue.getType()); 229 mlir::Value retValue = rewriter.create<fir::LoadOp>( 230 loc, fir::unwrapRefType(retAddr.getType()), retAddr); 231 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( 232 ret, mlir::ValueRange{retValue}); 233 return mlir::success(); 234 } 235 load.getMemref().replaceAllUsesWith(newArg); 236 replacedStorage = true; 237 if (auto *alloc = resultStorage.getDefiningOp()) 238 if (alloc->use_empty()) 239 rewriter.eraseOp(alloc); 240 } 241 // The result storage may have been optimized out by a memory to 242 // register pass, this is possible for fir.box results, or fir.record 243 // with no length parameters. Simply store the result in the result storage. 244 // at the return point. 245 if (!replacedStorage) 246 rewriter.create<fir::StoreOp>(loc, returnedValue, newArg); 247 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 248 return mlir::success(); 249 } 250 251 private: 252 mlir::Value newArg; 253 }; 254 255 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 256 public: 257 using OpRewritePattern::OpRewritePattern; 258 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 259 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 260 mlir::LogicalResult 261 matchAndRewrite(fir::AddrOfOp addrOf, 262 mlir::PatternRewriter &rewriter) const override { 263 auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>(); 264 mlir::FunctionType newFuncTy; 265 // TODO: This should be generalized for derived types, and it is 266 // architecture and OS dependent. 267 if (oldFuncTy.getNumResults() != 0 && 268 fir::isa_builtin_cptr_type(oldFuncTy.getResult(0))) 269 newFuncTy = getCPtrFunctionType(oldFuncTy); 270 else 271 newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); 272 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 273 addrOf.getSymbol()); 274 // Rather than converting all op a function pointer might transit through 275 // (e.g calls, stores, loads, converts...), cast new type to the abstract 276 // type. A conversion will be added when calling indirect calls of abstract 277 // types. 278 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 279 return mlir::success(); 280 } 281 282 private: 283 bool shouldBoxResult; 284 }; 285 286 /// @brief Base CRTP class for AbstractResult pass family. 287 /// Contains common logic for abstract result conversion in a reusable fashion. 288 /// @tparam Pass target class that implements operation-specific logic. 289 /// @tparam PassBase base class template for the pass generated by TableGen. 290 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool, 291 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function. 292 /// This function should implement operation-specific functionality. 293 template <typename Pass, template <typename> class PassBase> 294 class AbstractResultOptTemplate : public PassBase<Pass> { 295 public: 296 void runOnOperation() override { 297 auto *context = &this->getContext(); 298 auto op = this->getOperation(); 299 300 mlir::RewritePatternSet patterns(context); 301 mlir::ConversionTarget target = *context; 302 const bool shouldBoxResult = this->passResultAsBox.getValue(); 303 304 auto &self = static_cast<Pass &>(*this); 305 self.runOnSpecificOperation(op, shouldBoxResult, patterns, target); 306 307 // Convert the calls and, if needed, the ReturnOp in the function body. 308 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, 309 mlir::func::FuncDialect>(); 310 target.addIllegalOp<fir::SaveResultOp>(); 311 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 312 return !hasAbstractResult(call.getFunctionType()); 313 }); 314 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { 315 if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>()) 316 return !hasAbstractResult(funTy); 317 return true; 318 }); 319 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 320 return !hasAbstractResult(dispatch.getFunctionType()); 321 }); 322 323 patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult); 324 patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult); 325 patterns.insert<SaveResultOpConversion>(context); 326 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); 327 if (mlir::failed( 328 mlir::applyPartialConversion(op, target, std::move(patterns)))) { 329 mlir::emitError(op.getLoc(), "error in converting abstract results\n"); 330 this->signalPassFailure(); 331 } 332 } 333 }; 334 335 class AbstractResultOnFuncOpt 336 : public AbstractResultOptTemplate<AbstractResultOnFuncOpt, 337 fir::impl::AbstractResultOnFuncOptBase> { 338 public: 339 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, 340 mlir::RewritePatternSet &patterns, 341 mlir::ConversionTarget &target) { 342 auto loc = func.getLoc(); 343 auto *context = &getContext(); 344 // Convert function type itself if it has an abstract result. 345 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>(); 346 if (hasAbstractResult(funcTy)) { 347 // TODO: This should be generalized for derived types, and it is 348 // architecture and OS dependent. 349 if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) { 350 func.setType(getCPtrFunctionType(funcTy)); 351 patterns.insert<ReturnOpConversion>(context, mlir::Value{}); 352 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 353 [](mlir::func::ReturnOp ret) { 354 mlir::Type retTy = ret.getOperand(0).getType(); 355 return !fir::isa_builtin_cptr_type(retTy); 356 }); 357 return; 358 } 359 if (!func.empty()) { 360 // Insert new argument. 361 mlir::OpBuilder rewriter(context); 362 auto resultType = funcTy.getResult(0); 363 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 364 func.insertArgument(0u, argTy, {}, loc); 365 func.eraseResult(0u); 366 mlir::Value newArg = func.getArgument(0u); 367 if (mustEmboxResult(resultType, shouldBoxResult)) { 368 auto bufferType = fir::ReferenceType::get(resultType); 369 rewriter.setInsertionPointToStart(&func.front()); 370 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); 371 } 372 patterns.insert<ReturnOpConversion>(context, newArg); 373 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 374 [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); 375 assert(func.getFunctionType() == 376 getNewFunctionType(funcTy, shouldBoxResult)); 377 } else { 378 llvm::SmallVector<mlir::DictionaryAttr> allArgs; 379 func.getAllArgAttrs(allArgs); 380 allArgs.insert(allArgs.begin(), 381 mlir::DictionaryAttr::get(func->getContext())); 382 func.setType(getNewFunctionType(funcTy, shouldBoxResult)); 383 func.setAllArgAttrs(allArgs); 384 } 385 } 386 } 387 }; 388 389 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) { 390 return mlir::TypeSwitch<mlir::Type, bool>(type) 391 .Case([](fir::BoxProcType boxProc) { 392 return fir::hasAbstractResult( 393 boxProc.getEleTy().cast<mlir::FunctionType>()); 394 }) 395 .Case([](fir::PointerType pointer) { 396 return fir::hasAbstractResult( 397 pointer.getEleTy().cast<mlir::FunctionType>()); 398 }) 399 .Default([](auto &&) { return false; }); 400 } 401 402 class AbstractResultOnGlobalOpt 403 : public AbstractResultOptTemplate< 404 AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> { 405 public: 406 void runOnSpecificOperation(fir::GlobalOp global, bool, 407 mlir::RewritePatternSet &, 408 mlir::ConversionTarget &) { 409 if (containsFunctionTypeWithAbstractResult(global.getType())) { 410 TODO(global->getLoc(), "support for procedure pointers"); 411 } 412 } 413 }; 414 } // end anonymous namespace 415 } // namespace fir 416 417 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() { 418 return std::make_unique<AbstractResultOnFuncOpt>(); 419 } 420 421 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() { 422 return std::make_unique<AbstractResultOnGlobalOpt>(); 423 } 424