//===-- TargetRewrite.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Target rewrite: rewriting of ops to make target-specific lowerings manifest.
// LLVM expects different lowering idioms to be used for distinct target
// triples. These distinctions are handled by this pass.
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/CodeGen/CodeGen.h"

#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/CodeGen/Target.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <functional>
#include <optional>
#include <utility>

namespace fir {
#define GEN_PASS_DEF_TARGETREWRITEPASS
#include "flang/Optimizer/CodeGen/CGPasses.h.inc"
} // namespace fir

#define DEBUG_TYPE "flang-target-rewrite"

namespace {

/// Fixups for updating a FuncOp's arguments and return values.
48 struct FixupTy { 49 enum class Codes { 50 ArgumentAsLoad, 51 ArgumentType, 52 CharPair, 53 ReturnAsStore, 54 ReturnType, 55 Split, 56 Trailing, 57 TrailingCharProc 58 }; 59 60 FixupTy(Codes code, std::size_t index, std::size_t second = 0) 61 : code{code}, index{index}, second{second} {} 62 FixupTy(Codes code, std::size_t index, 63 std::function<void(mlir::func::FuncOp)> &&finalizer) 64 : code{code}, index{index}, finalizer{finalizer} {} 65 FixupTy(Codes code, std::size_t index, 66 std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer) 67 : code{code}, index{index}, gpuFinalizer{finalizer} {} 68 FixupTy(Codes code, std::size_t index, std::size_t second, 69 std::function<void(mlir::func::FuncOp)> &&finalizer) 70 : code{code}, index{index}, second{second}, finalizer{finalizer} {} 71 FixupTy(Codes code, std::size_t index, std::size_t second, 72 std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer) 73 : code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {} 74 75 Codes code; 76 std::size_t index; 77 std::size_t second{}; 78 std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{}; 79 std::optional<std::function<void(mlir::gpu::GPUFuncOp)>> gpuFinalizer{}; 80 }; // namespace 81 82 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code 83 /// generation that traverses the FIR and modifies types and operations to a 84 /// form that is appropriate for the specific target. LLVM IR has specific 85 /// idioms that are used for distinct target processor and ABI combinations. 
class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
public:
  using TargetRewritePassBase<TargetRewrite>::TargetRewritePassBase;

  void runOnOperation() override final {
    auto &context = getContext();
    mlir::OpBuilder rewriter(&context);

    auto mod = getModule();
    // Pass options may override the target attributes already attached to
    // the module.
    if (!forcedTargetTriple.empty())
      fir::setTargetTriple(mod, forcedTargetTriple);

    if (!forcedTargetCPU.empty())
      fir::setTargetCPU(mod, forcedTargetCPU);

    if (!forcedTuneCPU.empty())
      fir::setTuneCPU(mod, forcedTuneCPU);

    if (!forcedTargetFeatures.empty())
      fir::setTargetFeatures(mod, forcedTargetFeatures);

    // TargetRewrite will require querying the type storage sizes, if it was
    // not set already, create a DataLayoutSpec for the ModuleOp now.
    std::optional<mlir::DataLayout> dl =
        fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
    if (!dl) {
      mlir::emitError(mod.getLoc(),
                      "module operation must carry a data layout attribute "
                      "to perform target ABI rewrites on FIR");
      signalPassFailure();
      return;
    }

    // Build the target-specific ABI helper for this module's triple/CPU.
    auto specifics = fir::CodeGenSpecifics::get(
        mod.getContext(), fir::getTargetTriple(mod), fir::getKindMapping(mod),
        fir::getTargetCPU(mod), fir::getTargetFeatures(mod), *dl,
        fir::getTuneCPU(mod));

    setMembers(specifics.get(), &rewriter, &*dl);

    // Perform type conversion on signatures and call sites.
    if (mlir::failed(convertTypes(mod))) {
      mlir::emitError(mlir::UnknownLoc::get(&context),
                      "error in converting types to target abi");
      signalPassFailure();
    }

    // Convert ops in target-specific patterns.
    mod.walk([&](mlir::Operation *op) {
      if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
        if (!hasPortableSignature(call.getFunctionType(), op))
          convertCallOp(call, call.getFunctionType());
      } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
        if (!hasPortableSignature(dispatch.getFunctionType(), op))
          convertCallOp(dispatch, dispatch.getFunctionType());
      } else if (auto gpuLaunchFunc =
                     mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
        // gpu.launch_func carries no FunctionType: synthesize one from the
        // kernel operand types so it can be screened like other calls.
        llvm::SmallVector<mlir::Type> operandsTypes;
        for (auto arg : gpuLaunchFunc.getKernelOperands())
          operandsTypes.push_back(arg.getType());
        auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {});
        if (!hasPortableSignature(fctTy, op))
          convertCallOp(gpuLaunchFunc, fctTy);
      } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
        if (mlir::isa<mlir::FunctionType>(addr.getType()) &&
            !hasPortableSignature(addr.getType(), op))
          convertAddrOp(addr);
      }
    });

    clearMembers();
  }

  mlir::ModuleOp getModule() { return getOperation(); }

  /// Rewrite a call-like op's result according to the marshalling `m`
  /// computed by the target. Returns a callback that, given the new call,
  /// rebuilds a value of the original result type `originalResTy`. For an
  /// sret result the value is instead returned through a stack buffer that
  /// is appended to `newInTyAndAttrs`/`newOpers` as an extra argument.
  template <typename Ty, typename Callback>
  std::optional<std::function<mlir::Value(mlir::Operation *)>>
  rewriteCallResultType(mlir::Location loc, mlir::Type originalResTy,
                        Ty &newResTys,
                        fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                        Callback &newOpers, mlir::Value &savedStackPtr,
                        fir::CodeGenSpecifics::Marshalling &m) {
    // Currently, targets mandate COMPLEX or STRUCT is a single aggregate or
    // packed scalar, including the sret case.
    assert(m.size() == 1 && "return type not supported on this target");
    auto resTy = std::get<mlir::Type>(m[0]);
    auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
    if (attr.isSRet()) {
      assert(fir::isa_ref_type(resTy) && "must be a memory reference type");
      // Save the stack pointer, if it has not been saved for this call yet.
      // We will need to restore it after the call, because the alloca
      // needs to be deallocated.
      if (!savedStackPtr)
        savedStackPtr = genStackSave(loc);
      // Allocate the result buffer and pass its address as an argument.
      mlir::Value stack =
          rewriter->create<fir::AllocaOp>(loc, fir::dyn_cast_ptrEleTy(resTy));
      newInTyAndAttrs.push_back(m[0]);
      newOpers.push_back(stack);
      return [=](mlir::Operation *) -> mlir::Value {
        // After the call, reload the result from the sret buffer with its
        // original FIR type.
        auto memTy = fir::ReferenceType::get(originalResTy);
        auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, stack);
        return rewriter->create<fir::LoadOp>(loc, cast);
      };
    }
    newResTys.push_back(resTy);
    return [=, &savedStackPtr](mlir::Operation *call) -> mlir::Value {
      // We are going to generate an alloca, so save the stack pointer.
      if (!savedStackPtr)
        savedStackPtr = genStackSave(loc);
      return this->convertValueInMemory(loc, call->getResult(0), originalResTy,
                                        /*inputMayBeBigger=*/true);
    };
  }

  /// Result rewriting for a COMPLEX-returning call. Keeps the original type
  /// (and returns no wrapper) when complex conversion is disabled.
  template <typename Ty, typename Callback>
  std::optional<std::function<mlir::Value(mlir::Operation *)>>
  rewriteCallComplexResultType(
      mlir::Location loc, mlir::ComplexType ty, Ty &newResTys,
      fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
      mlir::Value &savedStackPtr) {
    if (noComplexConversion) {
      newResTys.push_back(ty);
      return std::nullopt;
    }
    auto m = specifics->complexReturnType(loc, ty.getElementType());
    return rewriteCallResultType(loc, ty, newResTys, newInTyAndAttrs, newOpers,
                                 savedStackPtr, m);
  }

  /// Result rewriting for a derived-type-returning call. Keeps the original
  /// type (and returns no wrapper) when struct conversion is disabled.
  template <typename Ty, typename Callback>
  std::optional<std::function<mlir::Value(mlir::Operation *)>>
  rewriteCallStructResultType(
      mlir::Location loc, fir::RecordType recTy, Ty &newResTys,
      fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
      mlir::Value &savedStackPtr) {
    if (noStructConversion) {
      newResTys.push_back(recTy);
      return std::nullopt;
    }
    auto m = specifics->structReturnType(loc, recTy);
    return rewriteCallResultType(loc, recTy, newResTys, newInTyAndAttrs,
                                 newOpers, savedStackPtr, m);
  }

  /// Pass a single argument either on the stack (byval: store to a temporary
  /// and pass its address) or bitcast to the register type chosen by the
  /// target.
  void passArgumentOnStackOrWithNewType(
      mlir::Location loc, fir::CodeGenSpecifics::TypeAndAttr newTypeAndAttr,
      mlir::Type oldType, mlir::Value oper,
      llvm::SmallVectorImpl<mlir::Value> &newOpers,
      mlir::Value &savedStackPtr) {
    auto resTy = std::get<mlir::Type>(newTypeAndAttr);
    auto attr = std::get<fir::CodeGenSpecifics::Attributes>(newTypeAndAttr);
    // We are going to generate an alloca, so save the stack pointer.
    if (!savedStackPtr)
      savedStackPtr = genStackSave(loc);
    if (attr.isByVal()) {
      mlir::Value mem = rewriter->create<fir::AllocaOp>(loc, oldType);
      rewriter->create<fir::StoreOp>(loc, oper, mem);
      if (mem.getType() != resTy)
        mem = rewriter->create<fir::ConvertOp>(loc, resTy, mem);
      newOpers.push_back(mem);
    } else {
      mlir::Value bitcast =
          convertValueInMemory(loc, oper, resTy, /*inputMayBeBigger=*/false);
      newOpers.push_back(bitcast);
    }
  }

  // Do a bitcast (convert a value via its memory representation).
  // The input and output types may have different storage sizes,
  // "inputMayBeBigger" should be set to indicate which of the input or
  // output type may be bigger in order for the load/store to be safe.
  // The mismatch comes from the fact that the LLVM register used for passing
  // may be bigger than the value being passed (e.g., passing
  // a `!fir.type<t{fir.array<3xi8>}>` into an i32 LLVM register).
  mlir::Value convertValueInMemory(mlir::Location loc, mlir::Value value,
                                   mlir::Type newType, bool inputMayBeBigger) {
    if (inputMayBeBigger) {
      // Store the input with its own (bigger or equal) type, then load it
      // back as `newType` through a cast pointer.
      auto newRefTy = fir::ReferenceType::get(newType);
      auto mem = rewriter->create<fir::AllocaOp>(loc, value.getType());
      rewriter->create<fir::StoreOp>(loc, value, mem);
      auto cast = rewriter->create<fir::ConvertOp>(loc, newRefTy, mem);
      return rewriter->create<fir::LoadOp>(loc, cast);
    } else {
      // Allocate with the (bigger or equal) new type, store the input
      // through a cast pointer, then load the whole buffer as `newType`.
      auto oldRefTy = fir::ReferenceType::get(value.getType());
      auto mem = rewriter->create<fir::AllocaOp>(loc, newType);
      auto cast = rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem);
      rewriter->create<fir::StoreOp>(loc, value, cast);
      return rewriter->create<fir::LoadOp>(loc, mem);
    }
  }

  /// Pass an argument that the target ABI splits into several distinct
  /// arguments (one operand per entry of `splitArgs`).
  void passSplitArgument(mlir::Location loc,
                         fir::CodeGenSpecifics::Marshalling splitArgs,
                         mlir::Type oldType, mlir::Value oper,
                         llvm::SmallVectorImpl<mlir::Value> &newOpers,
                         mlir::Value &savedStackPtr) {
    // COMPLEX or struct argument split into separate arguments
    if (!fir::isa_complex(oldType)) {
      // Cast original operand to a tuple of the new arguments
      // via memory.
      llvm::SmallVector<mlir::Type> partTypes;
      for (auto argPart : splitArgs)
        partTypes.push_back(std::get<mlir::Type>(argPart));
      mlir::Type tupleType =
          mlir::TupleType::get(oldType.getContext(), partTypes);
      // About to emit an alloca: save the stack pointer once per call site.
      if (!savedStackPtr)
        savedStackPtr = genStackSave(loc);
      oper = convertValueInMemory(loc, oper, tupleType,
                                  /*inputMayBeBigger=*/false);
    }
    // Extract each part of the (tuple or complex) value as its own operand.
    auto iTy = rewriter->getIntegerType(32);
    for (auto e : llvm::enumerate(splitArgs)) {
      auto &tup = e.value();
      auto ty = std::get<mlir::Type>(tup);
      auto index = e.index();
      auto idx = rewriter->getIntegerAttr(iTy, index);
      auto val = rewriter->create<fir::ExtractValueOp>(
          loc, ty, oper, rewriter->getArrayAttr(idx));
      newOpers.push_back(val);
    }
  }

  /// Lower one call operand according to the target marshalling `passArgAs`,
  /// appending the rewritten operand(s) to `newOpers` and their
  /// type/attribute pairs to `newInTyAndAttrs`.
  void rewriteCallOperands(
      mlir::Location loc, fir::CodeGenSpecifics::Marshalling passArgAs,
      mlir::Type originalArgTy, mlir::Value oper,
      llvm::SmallVectorImpl<mlir::Value> &newOpers, mlir::Value &savedStackPtr,
      fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
    if (passArgAs.size() == 1) {
      // COMPLEX or derived type is passed as a single argument.
      passArgumentOnStackOrWithNewType(loc, passArgAs[0], originalArgTy, oper,
                                       newOpers, savedStackPtr);
    } else {
      // COMPLEX or derived type is split into separate arguments
      passSplitArgument(loc, passArgAs, originalArgTy, oper, newOpers,
                        savedStackPtr);
    }
    newInTyAndAttrs.insert(newInTyAndAttrs.end(), passArgAs.begin(),
                           passArgAs.end());
  }

  /// Lower a COMPLEX call argument. Kept as-is when complex conversion is
  /// disabled.
  template <typename CPLX>
  void rewriteCallComplexInputType(
      mlir::Location loc, CPLX ty, mlir::Value oper,
      fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
      llvm::SmallVectorImpl<mlir::Value> &newOpers,
      mlir::Value &savedStackPtr) {
    if (noComplexConversion) {
      newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(ty));
      newOpers.push_back(oper);
      return;
    }
    auto m = specifics->complexArgumentType(loc, ty.getElementType());
    rewriteCallOperands(loc, m, ty, oper, newOpers, savedStackPtr,
                        newInTyAndAttrs);
  }

  /// Lower a derived-type call argument. Kept as-is when struct conversion
  /// is disabled.
  void rewriteCallStructInputType(
      mlir::Location loc, fir::RecordType recTy, mlir::Value oper,
      fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
      llvm::SmallVectorImpl<mlir::Value> &newOpers,
      mlir::Value &savedStackPtr) {
    if (noStructConversion) {
      newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(recTy));
      newOpers.push_back(oper);
      return;
    }
    auto structArgs =
        specifics->structArgumentType(loc, recTy, newInTyAndAttrs);
    rewriteCallOperands(loc, structArgs, recTy, oper, newOpers, savedStackPtr,
                        newInTyAndAttrs);
  }

  /// Returns true if any entry of the rewritten argument list is passed in
  /// memory: byval (argument on the stack) or sret (result buffer).
  static bool hasByValOrSRetArgs(
      const fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
    return llvm::any_of(newInTyAndAttrs, [](auto arg) {
      const auto &attr = std::get<fir::CodeGenSpecifics::Attributes>(arg);
      return attr.isByVal() || attr.isSRet();
    });
  }

  // Convert fir.call and fir.dispatch Ops.
  /// `A` is one of fir::CallOp, fir::DispatchOp, or mlir::gpu::LaunchFuncOp;
  /// `fnTy` is the (possibly synthesized) function type of the callee. The
  /// op is replaced by an equivalent op with the target-specific signature.
  template <typename A>
  void convertCallOp(A callOp, mlir::FunctionType fnTy) {
    auto loc = callOp.getLoc();
    rewriter->setInsertionPoint(callOp);
    llvm::SmallVector<mlir::Type> newResTys;
    fir::CodeGenSpecifics::Marshalling newInTyAndAttrs;
    llvm::SmallVector<mlir::Value> newOpers;
    // Lazily set to the saved stack pointer the first time an alloca is
    // emitted for this call; restored after the call below.
    mlir::Value savedStackPtr = nullptr;

    // If the call is indirect, the first argument must still be the function
    // to call.
    int dropFront = 0;
    if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
      if (!callOp.getCallee()) {
        newInTyAndAttrs.push_back(
            fir::CodeGenSpecifics::getTypeAndAttr(fnTy.getInput(0)));
        newOpers.push_back(callOp.getOperand(0));
        dropFront = 1;
      }
    } else if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
      dropFront = 1; // First operand is the polymorphic object.
    }

    // Determine the rewrite function, `wrap`, for the result value.
    std::optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
    if (fnTy.getResults().size() == 1) {
      mlir::Type ty = fnTy.getResult(0);
      llvm::TypeSwitch<mlir::Type>(ty)
          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            wrap = rewriteCallComplexResultType(loc, cmplx, newResTys,
                                                newInTyAndAttrs, newOpers,
                                                savedStackPtr);
          })
          .template Case<fir::RecordType>([&](fir::RecordType recTy) {
            wrap = rewriteCallStructResultType(loc, recTy, newResTys,
                                               newInTyAndAttrs, newOpers,
                                               savedStackPtr);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
    } else if (fnTy.getResults().size() > 1) {
      TODO(loc, "multiple results not supported yet");
    }

    // Rewrite the operands. Some CHARACTER-related pieces are deferred to
    // the end of the argument list (trailing lists below).
    llvm::SmallVector<mlir::Type> trailingInTys;
    llvm::SmallVector<mlir::Value> trailingOpers;
    llvm::SmallVector<mlir::Value> operands;
    unsigned passArgShift = 0;
    if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>)
      operands = callOp.getKernelOperands();
    else
      operands = callOp.getOperands().drop_front(dropFront);
    for (auto e : llvm::enumerate(
             llvm::zip(fnTy.getInputs().drop_front(dropFront), operands))) {
      mlir::Type ty = std::get<0>(e.value());
      mlir::Value oper = std::get<1>(e.value());
      unsigned index = e.index();
      llvm::TypeSwitch<mlir::Type>(ty)
          .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
            if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
              if (noCharacterConversion) {
                newInTyAndAttrs.push_back(
                    fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
                newOpers.push_back(oper);
                return;
              }
            } else {
              // TODO: dispatch case; it used to be a to-do because of sret,
              // but is not tested and maybe should be removed. This pass is
              // anyway ran after lowering fir.dispatch in flang, so maybe that
              // should just be a requirement of the pass.
              TODO(loc, "ABI of fir.dispatch with character arguments");
            }
            // Unbox the boxchar into (buffer, len) per the target ABI.
            auto m = specifics->boxcharArgumentType(boxTy.getEleTy());
            auto unbox = rewriter->create<fir::UnboxCharOp>(
                loc, std::get<mlir::Type>(m[0]), std::get<mlir::Type>(m[1]),
                oper);
            // unboxed CHARACTER arguments
            for (auto e : llvm::enumerate(m)) {
              unsigned idx = e.index();
              auto attr =
                  std::get<fir::CodeGenSpecifics::Attributes>(e.value());
              auto argTy = std::get<mlir::Type>(e.value());
              if (attr.isAppend()) {
                // Length is appended after all other arguments.
                trailingInTys.push_back(argTy);
                trailingOpers.push_back(unbox.getResult(idx));
              } else {
                newInTyAndAttrs.push_back(e.value());
                newOpers.push_back(unbox.getResult(idx));
              }
            }
          })
          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            rewriteCallComplexInputType(loc, cmplx, oper, newInTyAndAttrs,
                                        newOpers, savedStackPtr);
          })
          .template Case<fir::RecordType>([&](fir::RecordType recTy) {
            rewriteCallStructInputType(loc, recTy, oper, newInTyAndAttrs,
                                       newOpers, savedStackPtr);
          })
          .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
            if (fir::isCharacterProcedureTuple(tuple)) {
              mlir::ModuleOp module = getModule();
              if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
                if (callOp.getCallee()) {
                  llvm::StringRef charProcAttr =
                      fir::getCharacterProcedureDummyAttrName();
                  // The charProcAttr attribute is only used as a safety to
                  // confirm that this is a dummy procedure and should be split.
                  // It cannot be used to match because attributes are not
                  // available in case of indirect calls.
                  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(
                      *callOp.getCallee());
                  if (funcOp &&
                      !funcOp.template getArgAttrOfType<mlir::UnitAttr>(
                          index, charProcAttr))
                    mlir::emitError(loc, "tuple argument will be split even "
                                         "though it does not have the `" +
                                             charProcAttr + "` attribute");
                }
              }
              // Split (procedure pointer, len) tuple: the pointer stays in
              // place, the length is appended to the trailing arguments.
              mlir::Type funcPointerType = tuple.getType(0);
              mlir::Type lenType = tuple.getType(1);
              fir::FirOpBuilder builder(*rewriter, module);
              auto [funcPointer, len] =
                  fir::factory::extractCharacterProcedureTuple(builder, loc,
                                                               oper);
              newInTyAndAttrs.push_back(
                  fir::CodeGenSpecifics::getTypeAndAttr(funcPointerType));
              newOpers.push_back(funcPointer);
              trailingInTys.push_back(lenType);
              trailingOpers.push_back(len);
            } else {
              newInTyAndAttrs.push_back(
                  fir::CodeGenSpecifics::getTypeAndAttr(tuple));
              newOpers.push_back(oper);
            }
          })
          .Default([&](mlir::Type ty) {
            if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
              // Track how much the PASS argument position moved due to
              // arguments inserted before it.
              if (callOp.getPassArgPos() && *callOp.getPassArgPos() == index)
                passArgShift = newOpers.size() - *callOp.getPassArgPos();
            }
            newInTyAndAttrs.push_back(
                fir::CodeGenSpecifics::getTypeAndAttr(ty));
            newOpers.push_back(oper);
          });
    }

    llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs);
    newInTypes.insert(newInTypes.end(), trailingInTys.begin(),
                      trailingInTys.end());
    newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());

    // Build the replacement op for each supported call-like op kind.
    llvm::SmallVector<mlir::Value, 1> newCallResults;
    if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
      auto newCall = rewriter->create<A>(
          loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
          callOp.getBlockSizeOperandValues(),
          callOp.getDynamicSharedMemorySize(), newOpers);
      if (callOp.getClusterSizeX())
        newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX());
      if (callOp.getClusterSizeY())
        newCall.getClusterSizeYMutable().assign(callOp.getClusterSizeY());
      if (callOp.getClusterSizeZ())
        newCall.getClusterSizeZMutable().assign(callOp.getClusterSizeZ());
      newCallResults.append(newCall.result_begin(), newCall.result_end());
    } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
      fir::CallOp newCall;
      if (callOp.getCallee()) {
        newCall =
            rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
      } else {
        // TODO: llvm dialect must be updated to propagate argument on
        // attributes for indirect calls. See:
        // https://discourse.llvm.org/t/should-llvm-callop-be-able-to-carry-argument-attributes-for-indirect-calls/75431
        if (hasByValOrSRetArgs(newInTyAndAttrs))
          TODO(loc,
               "passing argument or result on the stack in indirect calls");
        // Retype the function pointer operand to the rewritten signature.
        newOpers[0].setType(mlir::FunctionType::get(
            callOp.getContext(),
            mlir::TypeRange{newInTypes}.drop_front(dropFront), newResTys));
        newCall = rewriter->create<A>(loc, newResTys, newOpers);
      }
      LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
      if (wrap)
        newCallResults.push_back((*wrap)(newCall.getOperation()));
      else
        newCallResults.append(newCall.result_begin(), newCall.result_end());
    } else {
      fir::DispatchOp dispatchOp = rewriter->create<A>(
          loc, newResTys, rewriter->getStringAttr(callOp.getMethod()),
          callOp.getOperands()[0], newOpers,
          rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift),
          callOp.getProcedureAttrsAttr());
      if (wrap)
        newCallResults.push_back((*wrap)(dispatchOp.getOperation()));
      else
        newCallResults.append(dispatchOp.result_begin(),
                              dispatchOp.result_end());
    }

    if (newCallResults.size() <= 1) {
      if (savedStackPtr) {
        if (newCallResults.size() == 1) {
          // We assume that all the allocas are inserted before
          // the operation that defines the new call result.
          rewriter->setInsertionPointAfterValue(newCallResults[0]);
        } else {
          // If the call does not have results, then insert
          // stack restore after the original call operation.
          rewriter->setInsertionPointAfter(callOp);
        }
        genStackRestore(loc, savedStackPtr);
      }
      replaceOp(callOp, newCallResults);
    } else {
      // The TODO is duplicated here to make sure this part
      // handles the stackrestore insertion properly, if
      // we add support for multiple call results.
      TODO(loc, "multiple results not supported yet");
    }
  }

  // Result type fixup for ComplexType.
591 template <typename Ty> 592 void lowerComplexSignatureRes( 593 mlir::Location loc, mlir::ComplexType cmplx, Ty &newResTys, 594 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) { 595 if (noComplexConversion) { 596 newResTys.push_back(cmplx); 597 return; 598 } 599 for (auto &tup : 600 specifics->complexReturnType(loc, cmplx.getElementType())) { 601 auto argTy = std::get<mlir::Type>(tup); 602 if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet()) 603 newInTyAndAttrs.push_back(tup); 604 else 605 newResTys.push_back(argTy); 606 } 607 } 608 609 // Argument type fixup for ComplexType. 610 void lowerComplexSignatureArg( 611 mlir::Location loc, mlir::ComplexType cmplx, 612 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) { 613 if (noComplexConversion) { 614 newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(cmplx)); 615 } else { 616 auto cplxArgs = 617 specifics->complexArgumentType(loc, cmplx.getElementType()); 618 newInTyAndAttrs.insert(newInTyAndAttrs.end(), cplxArgs.begin(), 619 cplxArgs.end()); 620 } 621 } 622 623 template <typename Ty> 624 void 625 lowerStructSignatureRes(mlir::Location loc, fir::RecordType recTy, 626 Ty &newResTys, 627 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) { 628 if (noComplexConversion) { 629 newResTys.push_back(recTy); 630 return; 631 } else { 632 for (auto &tup : specifics->structReturnType(loc, recTy)) { 633 if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet()) 634 newInTyAndAttrs.push_back(tup); 635 else 636 newResTys.push_back(std::get<mlir::Type>(tup)); 637 } 638 } 639 } 640 641 void 642 lowerStructSignatureArg(mlir::Location loc, fir::RecordType recTy, 643 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) { 644 if (noStructConversion) { 645 newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(recTy)); 646 return; 647 } 648 auto structArgs = 649 specifics->structArgumentType(loc, recTy, newInTyAndAttrs); 650 newInTyAndAttrs.insert(newInTyAndAttrs.end(), structArgs.begin(), 
651 structArgs.end()); 652 } 653 654 llvm::SmallVector<mlir::Type> 655 toTypeList(const fir::CodeGenSpecifics::Marshalling &marshalled) { 656 llvm::SmallVector<mlir::Type> typeList; 657 for (auto &typeAndAttr : marshalled) 658 typeList.emplace_back(std::get<mlir::Type>(typeAndAttr)); 659 return typeList; 660 } 661 662 /// Taking the address of a function. Modify the signature as needed. 663 void convertAddrOp(fir::AddrOfOp addrOp) { 664 rewriter->setInsertionPoint(addrOp); 665 auto addrTy = mlir::cast<mlir::FunctionType>(addrOp.getType()); 666 fir::CodeGenSpecifics::Marshalling newInTyAndAttrs; 667 llvm::SmallVector<mlir::Type> newResTys; 668 auto loc = addrOp.getLoc(); 669 for (mlir::Type ty : addrTy.getResults()) { 670 llvm::TypeSwitch<mlir::Type>(ty) 671 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 672 lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs); 673 }) 674 .Case<fir::RecordType>([&](fir::RecordType ty) { 675 lowerStructSignatureRes(loc, ty, newResTys, newInTyAndAttrs); 676 }) 677 .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); 678 } 679 llvm::SmallVector<mlir::Type> trailingInTys; 680 for (mlir::Type ty : addrTy.getInputs()) { 681 llvm::TypeSwitch<mlir::Type>(ty) 682 .Case<fir::BoxCharType>([&](auto box) { 683 if (noCharacterConversion) { 684 newInTyAndAttrs.push_back( 685 fir::CodeGenSpecifics::getTypeAndAttr(box)); 686 } else { 687 for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) { 688 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 689 auto argTy = std::get<mlir::Type>(tup); 690 if (attr.isAppend()) 691 trailingInTys.push_back(argTy); 692 else 693 newInTyAndAttrs.push_back(tup); 694 } 695 } 696 }) 697 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 698 lowerComplexSignatureArg(loc, ty, newInTyAndAttrs); 699 }) 700 .Case<mlir::TupleType>([&](mlir::TupleType tuple) { 701 if (fir::isCharacterProcedureTuple(tuple)) { 702 newInTyAndAttrs.push_back( 703 
fir::CodeGenSpecifics::getTypeAndAttr(tuple.getType(0))); 704 trailingInTys.push_back(tuple.getType(1)); 705 } else { 706 newInTyAndAttrs.push_back( 707 fir::CodeGenSpecifics::getTypeAndAttr(ty)); 708 } 709 }) 710 .template Case<fir::RecordType>([&](fir::RecordType recTy) { 711 lowerStructSignatureArg(loc, recTy, newInTyAndAttrs); 712 }) 713 .Default([&](mlir::Type ty) { 714 newInTyAndAttrs.push_back( 715 fir::CodeGenSpecifics::getTypeAndAttr(ty)); 716 }); 717 } 718 llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs); 719 // append trailing input types 720 newInTypes.insert(newInTypes.end(), trailingInTys.begin(), 721 trailingInTys.end()); 722 // replace this op with a new one with the updated signature 723 auto newTy = rewriter->getFunctionType(newInTypes, newResTys); 724 auto newOp = rewriter->create<fir::AddrOfOp>(addrOp.getLoc(), newTy, 725 addrOp.getSymbol()); 726 replaceOp(addrOp, newOp.getResult()); 727 } 728 729 /// Convert the type signatures on all the functions present in the module. 730 /// As the type signature is being changed, this must also update the 731 /// function itself to use any new arguments, etc. 732 llvm::LogicalResult convertTypes(mlir::ModuleOp mod) { 733 mlir::MLIRContext *ctx = mod->getContext(); 734 auto targetCPU = specifics->getTargetCPU(); 735 mlir::StringAttr targetCPUAttr = 736 targetCPU.empty() ? nullptr : mlir::StringAttr::get(ctx, targetCPU); 737 auto tuneCPU = specifics->getTuneCPU(); 738 mlir::StringAttr tuneCPUAttr = 739 tuneCPU.empty() ? 
nullptr : mlir::StringAttr::get(ctx, tuneCPU); 740 auto targetFeaturesAttr = specifics->getTargetFeatures(); 741 742 for (auto fn : mod.getOps<mlir::func::FuncOp>()) { 743 if (targetCPUAttr) 744 fn->setAttr("target_cpu", targetCPUAttr); 745 746 if (tuneCPUAttr) 747 fn->setAttr("tune_cpu", tuneCPUAttr); 748 749 if (targetFeaturesAttr) 750 fn->setAttr("target_features", targetFeaturesAttr); 751 752 convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn); 753 } 754 755 for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) { 756 for (auto fn : gpuMod.getOps<mlir::func::FuncOp>()) 757 convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn); 758 for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) 759 convertSignature<mlir::gpu::ReturnOp, mlir::gpu::GPUFuncOp>(fn); 760 } 761 762 return mlir::success(); 763 } 764 765 // Returns true if the function should be interoperable with C. 766 static bool isFuncWithCCallingConvention(mlir::Operation *op) { 767 auto funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op); 768 if (!funcOp) 769 return false; 770 return op->hasAttrOfType<mlir::UnitAttr>( 771 fir::FIROpsDialect::getFirRuntimeAttrName()) || 772 op->hasAttrOfType<mlir::StringAttr>(fir::getSymbolAttrName()); 773 } 774 775 /// If the signature does not need any special target-specific conversions, 776 /// then it is considered portable for any target, and this function will 777 /// return `true`. Otherwise, the signature is not portable and `false` is 778 /// returned. 
779 bool hasPortableSignature(mlir::Type signature, mlir::Operation *op) { 780 assert(mlir::isa<mlir::FunctionType>(signature)); 781 auto func = mlir::dyn_cast<mlir::FunctionType>(signature); 782 bool hasCCallingConv = isFuncWithCCallingConvention(op); 783 for (auto ty : func.getResults()) 784 if ((mlir::isa<fir::BoxCharType>(ty) && !noCharacterConversion) || 785 (fir::isa_complex(ty) && !noComplexConversion) || 786 (mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv) || 787 (mlir::isa<fir::RecordType>(ty) && !noStructConversion)) { 788 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 789 return false; 790 } 791 for (auto ty : func.getInputs()) 792 if (((mlir::isa<fir::BoxCharType>(ty) || 793 fir::isCharacterProcedureTuple(ty)) && 794 !noCharacterConversion) || 795 (fir::isa_complex(ty) && !noComplexConversion) || 796 (mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv) || 797 (mlir::isa<fir::RecordType>(ty) && !noStructConversion)) { 798 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 799 return false; 800 } 801 return true; 802 } 803 804 /// Determine if the signature has host associations. The host association 805 /// argument may need special target specific rewriting. 806 template <typename OpTy> 807 static bool hasHostAssociations(OpTy func) { 808 std::size_t end = func.getFunctionType().getInputs().size(); 809 for (std::size_t i = 0; i < end; ++i) 810 if (func.template getArgAttrOfType<mlir::UnitAttr>( 811 i, fir::getHostAssocAttrName())) 812 return true; 813 return false; 814 } 815 816 /// Rewrite the signatures and body of the `FuncOp`s in the module for 817 /// the immediately subsequent target code gen. 
template <typename ReturnOpTy, typename FuncOpTy>
void convertSignature(FuncOpTy func) {
  auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
  // Nothing to do when the signature is already target-portable and carries
  // no host-association argument.
  if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
    return;
  llvm::SmallVector<mlir::Type> newResTys;
  fir::CodeGenSpecifics::Marshalling newInTyAndAttrs;
  llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> savedAttrs;
  llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> extraAttrs;
  llvm::SmallVector<FixupTy> fixups;
  llvm::SmallVector<std::pair<unsigned, mlir::NamedAttrList>, 1> resultAttrs;

  // Save argument attributes in case there is a shift so we can replace them
  // correctly.
  for (auto e : llvm::enumerate(funcTy.getInputs())) {
    unsigned index = e.index();
    llvm::ArrayRef<mlir::NamedAttribute> attrs =
        mlir::function_interface_impl::getArgAttrs(func, index);
    for (mlir::NamedAttribute attr : attrs) {
      savedAttrs.push_back({index, attr});
    }
  }

  // Convert return value(s)
  for (auto ty : funcTy.getResults())
    llvm::TypeSwitch<mlir::Type>(ty)
        .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
          if (noComplexConversion)
            newResTys.push_back(cmplx);
          else
            doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
        })
        .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
          auto m = specifics->integerArgumentType(func.getLoc(), intTy);
          assert(m.size() == 1);
          auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
          auto retTy = std::get<mlir::Type>(m[0]);
          std::size_t resId = newResTys.size();
          // Record the integer extension attribute (e.g. sign/zero extension)
          // as a result attribute when the target requests one and the
          // function is C-interoperable.
          llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName();
          if (!extensionAttrName.empty() &&
              isFuncWithCCallingConvention(func))
            resultAttrs.emplace_back(
                resId, rewriter->getNamedAttr(extensionAttrName,
                                              rewriter->getUnitAttr()));
          newResTys.push_back(retTy);
        })
        .template Case<fir::RecordType>([&](fir::RecordType recTy) {
          doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
        })
        .Default([&](mlir::Type ty) { newResTys.push_back(ty); });

  // Saved potential shift in argument. Handling of result can add arguments
  // at the beginning of the function signature.
  unsigned argumentShift = newInTyAndAttrs.size();

  // Convert arguments
  llvm::SmallVector<mlir::Type> trailingTys;
  for (auto e : llvm::enumerate(funcTy.getInputs())) {
    auto ty = e.value();
    unsigned index = e.index();
    llvm::TypeSwitch<mlir::Type>(ty)
        .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
          if (noCharacterConversion) {
            newInTyAndAttrs.push_back(
                fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
          } else {
            // Convert a CHARACTER argument type. This can involve separating
            // the pointer and the LEN into two arguments and moving the LEN
            // argument to the end of the arg list.
            for (auto &tup :
                 specifics->boxcharArgumentType(boxTy.getEleTy())) {
              auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
              auto argTy = std::get<mlir::Type>(tup);
              if (attr.isAppend()) {
                trailingTys.push_back(argTy);
              } else {
                fixups.emplace_back(FixupTy::Codes::Trailing,
                                    newInTyAndAttrs.size(),
                                    trailingTys.size());
                newInTyAndAttrs.push_back(tup);
              }
            }
          }
        })
        .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
          doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
        })
        .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
          if (fir::isCharacterProcedureTuple(tuple)) {
            // Split the (procedure pointer, LEN) tuple; the LEN is appended
            // after all the original arguments.
            fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
                                newInTyAndAttrs.size(), trailingTys.size());
            newInTyAndAttrs.push_back(
                fir::CodeGenSpecifics::getTypeAndAttr(tuple.getType(0)));
            trailingTys.push_back(tuple.getType(1));
          } else {
            newInTyAndAttrs.push_back(
                fir::CodeGenSpecifics::getTypeAndAttr(ty));
          }
        })
        .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
          auto m = specifics->integerArgumentType(func.getLoc(), intTy);
          assert(m.size() == 1);
          auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
          auto argNo = newInTyAndAttrs.size();
          // Attach the target's integer extension attribute to the argument
          // via a finalizer, once the new signature is in place.
          llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName();
          if (!extensionAttrName.empty() &&
              isFuncWithCCallingConvention(func))
            fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
                                [=](FuncOpTy func) {
                                  func.setArgAttr(
                                      argNo, extensionAttrName,
                                      mlir::UnitAttr::get(func.getContext()));
                                });

          newInTyAndAttrs.push_back(m[0]);
        })
        .template Case<fir::RecordType>([&](fir::RecordType recTy) {
          doStructArg(func, recTy, newInTyAndAttrs, fixups);
        })
        .Default([&](mlir::Type ty) {
          newInTyAndAttrs.push_back(
              fir::CodeGenSpecifics::getTypeAndAttr(ty));
        });

    // Host-association arguments are lowered with the llvm.nest attribute.
    if (func.template getArgAttrOfType<mlir::UnitAttr>(
            index, fir::getHostAssocAttrName())) {
      extraAttrs.push_back(
          {newInTyAndAttrs.size() - 1,
           rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
    }
  }

  if (!func.empty()) {
    // If the function has a body, then apply the fixups to the arguments and
    // return ops as required. These fixups are done in place.
    auto loc = func.getLoc();
    const auto fixupSize = fixups.size();
    // Note: the function type has not been updated yet, so these are still
    // the pre-rewrite argument types.
    const auto oldArgTys = func.getFunctionType().getInputs();
    // Running count of extra argument slots consumed so far; `fixup.index -
    // offset` maps a new argument index back into `oldArgTys`.
    int offset = 0;
    for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
      const auto &fixup = fixups[i];
      // Type of the fixup's new argument when its index names an argument
      // slot; null for result-only fixups.
      mlir::Type fixupType =
          fixup.index < newInTyAndAttrs.size()
              ? std::get<mlir::Type>(newInTyAndAttrs[fixup.index])
              : mlir::Type{};
      switch (fixup.code) {
      case FixupTy::Codes::ArgumentAsLoad: {
        // Argument was pass-by-value, but is now pass-by-reference and
        // possibly with a different element type.
        auto newArg =
            func.front().insertArgument(fixup.index, fixupType, loc);
        rewriter->setInsertionPointToStart(&func.front());
        auto oldArgTy =
            fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
        auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, newArg);
        auto load = rewriter->create<fir::LoadOp>(loc, cast);
        func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
        func.front().eraseArgument(fixup.index + 1);
      } break;
      case FixupTy::Codes::ArgumentType: {
        // Argument is pass-by-value, but its type has likely been modified to
        // suit the target ABI convention.
        auto oldArgTy = oldArgTys[fixup.index - offset];
        // If type did not change, keep the original argument.
        if (fixupType == oldArgTy)
          break;

        auto newArg =
            func.front().insertArgument(fixup.index, fixupType, loc);
        rewriter->setInsertionPointToStart(&func.front());
        mlir::Value bitcast = convertValueInMemory(loc, newArg, oldArgTy,
                                                   /*inputMayBeBigger=*/true);
        func.getArgument(fixup.index + 1).replaceAllUsesWith(bitcast);
        func.front().eraseArgument(fixup.index + 1);
        LLVM_DEBUG(llvm::dbgs()
                   << "old argument: " << oldArgTy << ", repl: " << bitcast
                   << ", new argument: "
                   << func.getArgument(fixup.index).getType() << '\n');
      } break;
      case FixupTy::Codes::CharPair: {
        // The FIR boxchar argument has been split into a pair of distinct
        // arguments that are in juxtaposition to each other.
        auto newArg =
            func.front().insertArgument(fixup.index, fixupType, loc);
        if (fixup.second == 1) {
          // Second half of the pair: rebuild the boxchar from both parts.
          rewriter->setInsertionPointToStart(&func.front());
          auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
          auto box = rewriter->create<fir::EmboxCharOp>(
              loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
          func.front().eraseArgument(fixup.index + 1);
          offset++;
        }
      } break;
      case FixupTy::Codes::ReturnAsStore: {
        // The value being returned is now being returned in memory (callee
        // stack space) through a hidden reference argument.
        auto newArg =
            func.front().insertArgument(fixup.index, fixupType, loc);
        offset++;
        func.walk([&](ReturnOpTy ret) {
          rewriter->setInsertionPoint(ret);
          auto oldOper = ret.getOperand(0);
          auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
          auto cast =
              rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
          rewriter->create<fir::StoreOp>(loc, oldOper, cast);
          rewriter->create<ReturnOpTy>(loc);
          ret.erase();
        });
      } break;
      case FixupTy::Codes::ReturnType: {
        // The function is still returning a value, but its type has likely
        // changed to suit the target ABI convention.
        func.walk([&](ReturnOpTy ret) {
          rewriter->setInsertionPoint(ret);
          auto oldOper = ret.getOperand(0);
          mlir::Value bitcast =
              convertValueInMemory(loc, oldOper, newResTys[fixup.index],
                                   /*inputMayBeBigger=*/false);
          rewriter->create<ReturnOpTy>(loc, bitcast);
          ret.erase();
        });
      } break;
      case FixupTy::Codes::Split: {
        // The FIR argument has been split into a pair of distinct arguments
        // that are in juxtaposition to each other. (For COMPLEX value or
        // derived type passed with VALUE in BIND(C) context).
        auto newArg =
            func.front().insertArgument(fixup.index, fixupType, loc);
        if (fixup.second == 1) {
          // Second half of the pair: reassemble the original value.
          rewriter->setInsertionPointToStart(&func.front());
          mlir::Value firstArg = func.front().getArgument(fixup.index - 1);
          mlir::Type originalTy =
              oldArgTys[fixup.index - offset - fixup.second];
          mlir::Type pairTy = originalTy;
          if (!fir::isa_complex(originalTy)) {
            pairTy = mlir::TupleType::get(
                originalTy.getContext(),
                mlir::TypeRange{firstArg.getType(), newArg.getType()});
          }
          auto undef = rewriter->create<fir::UndefOp>(loc, pairTy);
          auto iTy = rewriter->getIntegerType(32);
          auto zero = rewriter->getIntegerAttr(iTy, 0);
          auto one = rewriter->getIntegerAttr(iTy, 1);
          mlir::Value pair1 = rewriter->create<fir::InsertValueOp>(
              loc, pairTy, undef, firstArg, rewriter->getArrayAttr(zero));
          mlir::Value pair = rewriter->create<fir::InsertValueOp>(
              loc, pairTy, pair1, newArg, rewriter->getArrayAttr(one));
          // Cast local argument tuple to original type via memory if needed.
          if (pairTy != originalTy)
            pair = convertValueInMemory(loc, pair, originalTy,
                                        /*inputMayBeBigger=*/true);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(pair);
          func.front().eraseArgument(fixup.index + 1);
          offset++;
        }
      } break;
      case FixupTy::Codes::Trailing: {
        // The FIR argument has been split into a pair of distinct arguments.
        // The first part of the pair appears in the original argument
        // position. The second part of the pair is appended after all the
        // original arguments. (Boxchar arguments.)
        auto newBufArg =
            func.front().insertArgument(fixup.index, fixupType, loc);
        auto newLenArg =
            func.front().addArgument(trailingTys[fixup.second], loc);
        auto boxTy = oldArgTys[fixup.index - offset];
        rewriter->setInsertionPointToStart(&func.front());
        auto box = rewriter->create<fir::EmboxCharOp>(loc, boxTy, newBufArg,
                                                      newLenArg);
        func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
        func.front().eraseArgument(fixup.index + 1);
      } break;
      case FixupTy::Codes::TrailingCharProc: {
        // The FIR character procedure argument tuple must be split into a
        // pair of distinct arguments. The first part of the pair appears in
        // the original argument position. The second part of the pair is
        // appended after all the original arguments.
        auto newProcPointerArg =
            func.front().insertArgument(fixup.index, fixupType, loc);
        auto newLenArg =
            func.front().addArgument(trailingTys[fixup.second], loc);
        auto tupleType = oldArgTys[fixup.index - offset];
        rewriter->setInsertionPointToStart(&func.front());
        fir::FirOpBuilder builder(*rewriter, getModule());
        auto tuple = fir::factory::createCharacterProcedureTuple(
            builder, loc, tupleType, newProcPointerArg, newLenArg);
        func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple);
        func.front().eraseArgument(fixup.index + 1);
      } break;
      }
    }
  }

  llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs);
  // Set the new type and finalize the arguments, etc.
  newInTypes.insert(newInTypes.end(), trailingTys.begin(), trailingTys.end());
  auto newFuncTy =
      mlir::FunctionType::get(func.getContext(), newInTypes, newResTys);
  LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
  func.setType(newFuncTy);

  for (std::pair<unsigned, mlir::NamedAttribute> extraAttr : extraAttrs)
    func.setArgAttr(extraAttr.first, extraAttr.second.getName(),
                    extraAttr.second.getValue());

  for (auto [resId, resAttrList] : resultAttrs)
    for (mlir::NamedAttribute resAttr : resAttrList)
      func.setResultAttr(resId, resAttr.getName(), resAttr.getValue());

  // Replace attributes to the correct argument if there was an argument shift
  // to the right.
  if (argumentShift > 0) {
    for (std::pair<unsigned, mlir::NamedAttribute> savedAttr : savedAttrs) {
      func.removeArgAttr(savedAttr.first, savedAttr.second.getName());
      func.setArgAttr(savedAttr.first + argumentShift,
                      savedAttr.second.getName(),
                      savedAttr.second.getValue());
    }
  }

  // Run the finalizers matching this function op flavor now that the new
  // signature is installed.
  for (auto &fixup : fixups) {
    if constexpr (std::is_same_v<FuncOpTy, mlir::func::FuncOp>)
      if (fixup.finalizer)
        (*fixup.finalizer)(func);
    if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>)
      if (fixup.gpuFinalizer)
        (*fixup.gpuFinalizer)(func);
  }
}

/// Marshal a function result described by `m` (a single marshalled entry).
/// An sret result becomes a hidden reference argument (with a ReturnAsStore
/// fixup and llvm.sret/llvm.align attributes); otherwise the result type is
/// replaced and a ReturnType fixup is recorded.
template <typename OpTy, typename Ty, typename FIXUPS>
void doReturn(OpTy func, Ty &newResTys,
              fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
              FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
  assert(m.size() == 1 &&
         "expect result to be turned into single argument or result so far");
  auto &tup = m[0];
  auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
  auto argTy = std::get<mlir::Type>(tup);
  if (attr.isSRet()) {
    unsigned argNo = newInTyAndAttrs.size();
    if (auto align = attr.getAlignment())
      fixups.emplace_back(
          FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
            auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                func.getFunctionType().getInput(argNo));
            func.setArgAttr(argNo, "llvm.sret",
                            mlir::TypeAttr::get(elemType));
            func.setArgAttr(argNo, "llvm.align",
                            rewriter->getIntegerAttr(
                                rewriter->getIntegerType(32), align));
          });
    else
      fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
                          [=](OpTy func) {
                            auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                                func.getFunctionType().getInput(argNo));
                            func.setArgAttr(argNo, "llvm.sret",
                                            mlir::TypeAttr::get(elemType));
                          });
    newInTyAndAttrs.push_back(tup);
    return;
  }
  if (auto align = attr.getAlignment())
    fixups.emplace_back(
        FixupTy::Codes::ReturnType, newResTys.size(), [=](OpTy func) {
          func.setArgAttr(
              newResTys.size(), "llvm.align",
              rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
        });
  else
    fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
  newResTys.push_back(argTy);
}

/// Convert a complex return value. This can involve converting the return
/// value to a "hidden" first argument or packing the complex into a wide
/// GPR.
1198 template <typename OpTy, typename Ty, typename FIXUPS> 1199 void doComplexReturn(OpTy func, mlir::ComplexType cmplx, Ty &newResTys, 1200 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, 1201 FIXUPS &fixups) { 1202 if (noComplexConversion) { 1203 newResTys.push_back(cmplx); 1204 return; 1205 } 1206 auto m = 1207 specifics->complexReturnType(func.getLoc(), cmplx.getElementType()); 1208 doReturn(func, newResTys, newInTyAndAttrs, fixups, m); 1209 } 1210 1211 template <typename OpTy, typename Ty, typename FIXUPS> 1212 void doStructReturn(OpTy func, fir::RecordType recTy, Ty &newResTys, 1213 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, 1214 FIXUPS &fixups) { 1215 if (noStructConversion) { 1216 newResTys.push_back(recTy); 1217 return; 1218 } 1219 auto m = specifics->structReturnType(func.getLoc(), recTy); 1220 doReturn(func, newResTys, newInTyAndAttrs, fixups, m); 1221 } 1222 1223 template <typename OpTy, typename FIXUPS> 1224 void createFuncOpArgFixups( 1225 OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, 1226 fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) { 1227 const auto fixupCode = argsInTys.size() > 1 ? 
FixupTy::Codes::Split 1228 : FixupTy::Codes::ArgumentType; 1229 for (auto e : llvm::enumerate(argsInTys)) { 1230 auto &tup = e.value(); 1231 auto index = e.index(); 1232 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup); 1233 auto argNo = newInTyAndAttrs.size(); 1234 if (attr.isByVal()) { 1235 if (auto align = attr.getAlignment()) 1236 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo, 1237 [=](OpTy func) { 1238 auto elemType = fir::dyn_cast_ptrOrBoxEleTy( 1239 func.getFunctionType().getInput(argNo)); 1240 func.setArgAttr(argNo, "llvm.byval", 1241 mlir::TypeAttr::get(elemType)); 1242 func.setArgAttr( 1243 argNo, "llvm.align", 1244 rewriter->getIntegerAttr( 1245 rewriter->getIntegerType(32), align)); 1246 }); 1247 else 1248 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, 1249 newInTyAndAttrs.size(), [=](OpTy func) { 1250 auto elemType = fir::dyn_cast_ptrOrBoxEleTy( 1251 func.getFunctionType().getInput(argNo)); 1252 func.setArgAttr(argNo, "llvm.byval", 1253 mlir::TypeAttr::get(elemType)); 1254 }); 1255 } else { 1256 if (auto align = attr.getAlignment()) 1257 fixups.emplace_back( 1258 fixupCode, argNo, index, [=](OpTy func) { 1259 func.setArgAttr(argNo, "llvm.align", 1260 rewriter->getIntegerAttr( 1261 rewriter->getIntegerType(32), align)); 1262 }); 1263 else 1264 fixups.emplace_back(fixupCode, argNo, index); 1265 } 1266 newInTyAndAttrs.push_back(tup); 1267 } 1268 } 1269 1270 /// Convert a complex argument value. This can involve storing the value to 1271 /// a temporary memory location or factoring the value into two distinct 1272 /// arguments. 
1273 template <typename OpTy, typename FIXUPS> 1274 void doComplexArg(OpTy func, mlir::ComplexType cmplx, 1275 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, 1276 FIXUPS &fixups) { 1277 if (noComplexConversion) { 1278 newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(cmplx)); 1279 return; 1280 } 1281 auto cplxArgs = 1282 specifics->complexArgumentType(func.getLoc(), cmplx.getElementType()); 1283 createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups); 1284 } 1285 1286 template <typename OpTy, typename FIXUPS> 1287 void doStructArg(OpTy func, fir::RecordType recTy, 1288 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, 1289 FIXUPS &fixups) { 1290 if (noStructConversion) { 1291 newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(recTy)); 1292 return; 1293 } 1294 auto structArgs = 1295 specifics->structArgumentType(func.getLoc(), recTy, newInTyAndAttrs); 1296 createFuncOpArgFixups(func, newInTyAndAttrs, structArgs, fixups); 1297 } 1298 1299 private: 1300 // Replace `op` and remove it. 1301 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { 1302 op->replaceAllUsesWith(newValues); 1303 op->dropAllReferences(); 1304 op->erase(); 1305 } 1306 1307 inline void setMembers(fir::CodeGenSpecifics *s, mlir::OpBuilder *r, 1308 mlir::DataLayout *dl) { 1309 specifics = s; 1310 rewriter = r; 1311 dataLayout = dl; 1312 } 1313 1314 inline void clearMembers() { setMembers(nullptr, nullptr, nullptr); } 1315 1316 // Inserts a call to llvm.stacksave at the current insertion 1317 // point and the given location. Returns the call's result Value. 1318 inline mlir::Value genStackSave(mlir::Location loc) { 1319 fir::FirOpBuilder builder(*rewriter, getModule()); 1320 return builder.genStackSave(loc); 1321 } 1322 1323 // Inserts a call to llvm.stackrestore at the current insertion 1324 // point and the given location and argument. 
inline void genStackRestore(mlir::Location loc, mlir::Value sp) {
  fir::FirOpBuilder builder(*rewriter, getModule());
  return builder.genStackRestore(loc, sp);
}

// Per-run state installed by setMembers() and cleared by clearMembers().
fir::CodeGenSpecifics *specifics = nullptr; // target ABI marshalling rules
mlir::OpBuilder *rewriter = nullptr; // builder used for in-place IR edits
mlir::DataLayout *dataLayout = nullptr; // module data layout, if set
};
} // namespace