xref: /llvm-project/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (revision 65e00315c92f53895c1d88912de8838d7790c3f0)
1 //===-- TargetRewrite.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest.
10 // LLVM expects different lowering idioms to be used for distinct target
11 // triples. These distinctions are handled by this pass.
12 //
13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "flang/Optimizer/CodeGen/CodeGen.h"

#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/CodeGen/Target.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
#include <utility>
37 
38 namespace fir {
39 #define GEN_PASS_DEF_TARGETREWRITEPASS
40 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
41 } // namespace fir
42 
43 #define DEBUG_TYPE "flang-target-rewrite"
44 
45 namespace {
46 
47 /// Fixups for updating a FuncOp's arguments and return values.
48 struct FixupTy {
49   enum class Codes {
50     ArgumentAsLoad,
51     ArgumentType,
52     CharPair,
53     ReturnAsStore,
54     ReturnType,
55     Split,
56     Trailing,
57     TrailingCharProc
58   };
59 
60   FixupTy(Codes code, std::size_t index, std::size_t second = 0)
61       : code{code}, index{index}, second{second} {}
62   FixupTy(Codes code, std::size_t index,
63           std::function<void(mlir::func::FuncOp)> &&finalizer)
64       : code{code}, index{index}, finalizer{finalizer} {}
65   FixupTy(Codes code, std::size_t index,
66           std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
67       : code{code}, index{index}, gpuFinalizer{finalizer} {}
68   FixupTy(Codes code, std::size_t index, std::size_t second,
69           std::function<void(mlir::func::FuncOp)> &&finalizer)
70       : code{code}, index{index}, second{second}, finalizer{finalizer} {}
71   FixupTy(Codes code, std::size_t index, std::size_t second,
72           std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
73       : code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}
74 
75   Codes code;
76   std::size_t index;
77   std::size_t second{};
78   std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
79   std::optional<std::function<void(mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
80 }; // namespace
81 
82 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
83 /// generation that traverses the FIR and modifies types and operations to a
84 /// form that is appropriate for the specific target. LLVM IR has specific
85 /// idioms that are used for distinct target processor and ABI combinations.
86 class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
87 public:
88   using TargetRewritePassBase<TargetRewrite>::TargetRewritePassBase;
89 
90   void runOnOperation() override final {
91     auto &context = getContext();
92     mlir::OpBuilder rewriter(&context);
93 
94     auto mod = getModule();
95     if (!forcedTargetTriple.empty())
96       fir::setTargetTriple(mod, forcedTargetTriple);
97 
98     if (!forcedTargetCPU.empty())
99       fir::setTargetCPU(mod, forcedTargetCPU);
100 
101     if (!forcedTuneCPU.empty())
102       fir::setTuneCPU(mod, forcedTuneCPU);
103 
104     if (!forcedTargetFeatures.empty())
105       fir::setTargetFeatures(mod, forcedTargetFeatures);
106 
107     // TargetRewrite will require querying the type storage sizes, if it was
108     // not set already, create a DataLayoutSpec for the ModuleOp now.
109     std::optional<mlir::DataLayout> dl =
110         fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
111     if (!dl) {
112       mlir::emitError(mod.getLoc(),
113                       "module operation must carry a data layout attribute "
114                       "to perform target ABI rewrites on FIR");
115       signalPassFailure();
116       return;
117     }
118 
119     auto specifics = fir::CodeGenSpecifics::get(
120         mod.getContext(), fir::getTargetTriple(mod), fir::getKindMapping(mod),
121         fir::getTargetCPU(mod), fir::getTargetFeatures(mod), *dl,
122         fir::getTuneCPU(mod));
123 
124     setMembers(specifics.get(), &rewriter, &*dl);
125 
126     // Perform type conversion on signatures and call sites.
127     if (mlir::failed(convertTypes(mod))) {
128       mlir::emitError(mlir::UnknownLoc::get(&context),
129                       "error in converting types to target abi");
130       signalPassFailure();
131     }
132 
133     // Convert ops in target-specific patterns.
134     mod.walk([&](mlir::Operation *op) {
135       if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
136         if (!hasPortableSignature(call.getFunctionType(), op))
137           convertCallOp(call, call.getFunctionType());
138       } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
139         if (!hasPortableSignature(dispatch.getFunctionType(), op))
140           convertCallOp(dispatch, dispatch.getFunctionType());
141       } else if (auto gpuLaunchFunc =
142                      mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
143         llvm::SmallVector<mlir::Type> operandsTypes;
144         for (auto arg : gpuLaunchFunc.getKernelOperands())
145           operandsTypes.push_back(arg.getType());
146         auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {});
147         if (!hasPortableSignature(fctTy, op))
148           convertCallOp(gpuLaunchFunc, fctTy);
149       } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
150         if (mlir::isa<mlir::FunctionType>(addr.getType()) &&
151             !hasPortableSignature(addr.getType(), op))
152           convertAddrOp(addr);
153       }
154     });
155 
156     clearMembers();
157   }
158 
159   mlir::ModuleOp getModule() { return getOperation(); }
160 
161   template <typename Ty, typename Callback>
162   std::optional<std::function<mlir::Value(mlir::Operation *)>>
163   rewriteCallResultType(mlir::Location loc, mlir::Type originalResTy,
164                         Ty &newResTys,
165                         fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
166                         Callback &newOpers, mlir::Value &savedStackPtr,
167                         fir::CodeGenSpecifics::Marshalling &m) {
168     // Currently, targets mandate COMPLEX or STRUCT is a single aggregate or
169     // packed scalar, including the sret case.
170     assert(m.size() == 1 && "return type not supported on this target");
171     auto resTy = std::get<mlir::Type>(m[0]);
172     auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
173     if (attr.isSRet()) {
174       assert(fir::isa_ref_type(resTy) && "must be a memory reference type");
175       // Save the stack pointer, if it has not been saved for this call yet.
176       // We will need to restore it after the call, because the alloca
177       // needs to be deallocated.
178       if (!savedStackPtr)
179         savedStackPtr = genStackSave(loc);
180       mlir::Value stack =
181           rewriter->create<fir::AllocaOp>(loc, fir::dyn_cast_ptrEleTy(resTy));
182       newInTyAndAttrs.push_back(m[0]);
183       newOpers.push_back(stack);
184       return [=](mlir::Operation *) -> mlir::Value {
185         auto memTy = fir::ReferenceType::get(originalResTy);
186         auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, stack);
187         return rewriter->create<fir::LoadOp>(loc, cast);
188       };
189     }
190     newResTys.push_back(resTy);
191     return [=, &savedStackPtr](mlir::Operation *call) -> mlir::Value {
192       // We are going to generate an alloca, so save the stack pointer.
193       if (!savedStackPtr)
194         savedStackPtr = genStackSave(loc);
195       return this->convertValueInMemory(loc, call->getResult(0), originalResTy,
196                                         /*inputMayBeBigger=*/true);
197     };
198   }
199 
200   template <typename Ty, typename Callback>
201   std::optional<std::function<mlir::Value(mlir::Operation *)>>
202   rewriteCallComplexResultType(
203       mlir::Location loc, mlir::ComplexType ty, Ty &newResTys,
204       fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
205       mlir::Value &savedStackPtr) {
206     if (noComplexConversion) {
207       newResTys.push_back(ty);
208       return std::nullopt;
209     }
210     auto m = specifics->complexReturnType(loc, ty.getElementType());
211     return rewriteCallResultType(loc, ty, newResTys, newInTyAndAttrs, newOpers,
212                                  savedStackPtr, m);
213   }
214 
215   template <typename Ty, typename Callback>
216   std::optional<std::function<mlir::Value(mlir::Operation *)>>
217   rewriteCallStructResultType(
218       mlir::Location loc, fir::RecordType recTy, Ty &newResTys,
219       fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
220       mlir::Value &savedStackPtr) {
221     if (noStructConversion) {
222       newResTys.push_back(recTy);
223       return std::nullopt;
224     }
225     auto m = specifics->structReturnType(loc, recTy);
226     return rewriteCallResultType(loc, recTy, newResTys, newInTyAndAttrs,
227                                  newOpers, savedStackPtr, m);
228   }
229 
230   void passArgumentOnStackOrWithNewType(
231       mlir::Location loc, fir::CodeGenSpecifics::TypeAndAttr newTypeAndAttr,
232       mlir::Type oldType, mlir::Value oper,
233       llvm::SmallVectorImpl<mlir::Value> &newOpers,
234       mlir::Value &savedStackPtr) {
235     auto resTy = std::get<mlir::Type>(newTypeAndAttr);
236     auto attr = std::get<fir::CodeGenSpecifics::Attributes>(newTypeAndAttr);
237     // We are going to generate an alloca, so save the stack pointer.
238     if (!savedStackPtr)
239       savedStackPtr = genStackSave(loc);
240     if (attr.isByVal()) {
241       mlir::Value mem = rewriter->create<fir::AllocaOp>(loc, oldType);
242       rewriter->create<fir::StoreOp>(loc, oper, mem);
243       if (mem.getType() != resTy)
244         mem = rewriter->create<fir::ConvertOp>(loc, resTy, mem);
245       newOpers.push_back(mem);
246     } else {
247       mlir::Value bitcast =
248           convertValueInMemory(loc, oper, resTy, /*inputMayBeBigger=*/false);
249       newOpers.push_back(bitcast);
250     }
251   }
252 
253   // Do a bitcast (convert a value via its memory representation).
254   // The input and output types may have different storage sizes,
255   // "inputMayBeBigger" should be set to indicate which of the input or
256   // output type may be bigger in order for the load/store to be safe.
257   // The mismatch comes from the fact that the LLVM register used for passing
258   // may be bigger than the value being passed (e.g., passing
259   // a `!fir.type<t{fir.array<3xi8>}>` into an i32 LLVM register).
260   mlir::Value convertValueInMemory(mlir::Location loc, mlir::Value value,
261                                    mlir::Type newType, bool inputMayBeBigger) {
262     if (inputMayBeBigger) {
263       auto newRefTy = fir::ReferenceType::get(newType);
264       auto mem = rewriter->create<fir::AllocaOp>(loc, value.getType());
265       rewriter->create<fir::StoreOp>(loc, value, mem);
266       auto cast = rewriter->create<fir::ConvertOp>(loc, newRefTy, mem);
267       return rewriter->create<fir::LoadOp>(loc, cast);
268     } else {
269       auto oldRefTy = fir::ReferenceType::get(value.getType());
270       auto mem = rewriter->create<fir::AllocaOp>(loc, newType);
271       auto cast = rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem);
272       rewriter->create<fir::StoreOp>(loc, value, cast);
273       return rewriter->create<fir::LoadOp>(loc, mem);
274     }
275   }
276 
277   void passSplitArgument(mlir::Location loc,
278                          fir::CodeGenSpecifics::Marshalling splitArgs,
279                          mlir::Type oldType, mlir::Value oper,
280                          llvm::SmallVectorImpl<mlir::Value> &newOpers,
281                          mlir::Value &savedStackPtr) {
282     // COMPLEX or struct argument split into separate arguments
283     if (!fir::isa_complex(oldType)) {
284       // Cast original operand to a tuple of the new arguments
285       // via memory.
286       llvm::SmallVector<mlir::Type> partTypes;
287       for (auto argPart : splitArgs)
288         partTypes.push_back(std::get<mlir::Type>(argPart));
289       mlir::Type tupleType =
290           mlir::TupleType::get(oldType.getContext(), partTypes);
291       if (!savedStackPtr)
292         savedStackPtr = genStackSave(loc);
293       oper = convertValueInMemory(loc, oper, tupleType,
294                                   /*inputMayBeBigger=*/false);
295     }
296     auto iTy = rewriter->getIntegerType(32);
297     for (auto e : llvm::enumerate(splitArgs)) {
298       auto &tup = e.value();
299       auto ty = std::get<mlir::Type>(tup);
300       auto index = e.index();
301       auto idx = rewriter->getIntegerAttr(iTy, index);
302       auto val = rewriter->create<fir::ExtractValueOp>(
303           loc, ty, oper, rewriter->getArrayAttr(idx));
304       newOpers.push_back(val);
305     }
306   }
307 
308   void rewriteCallOperands(
309       mlir::Location loc, fir::CodeGenSpecifics::Marshalling passArgAs,
310       mlir::Type originalArgTy, mlir::Value oper,
311       llvm::SmallVectorImpl<mlir::Value> &newOpers, mlir::Value &savedStackPtr,
312       fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
313     if (passArgAs.size() == 1) {
314       // COMPLEX or derived type is passed as a single argument.
315       passArgumentOnStackOrWithNewType(loc, passArgAs[0], originalArgTy, oper,
316                                        newOpers, savedStackPtr);
317     } else {
318       // COMPLEX or derived type is split into separate arguments
319       passSplitArgument(loc, passArgAs, originalArgTy, oper, newOpers,
320                         savedStackPtr);
321     }
322     newInTyAndAttrs.insert(newInTyAndAttrs.end(), passArgAs.begin(),
323                            passArgAs.end());
324   }
325 
326   template <typename CPLX>
327   void rewriteCallComplexInputType(
328       mlir::Location loc, CPLX ty, mlir::Value oper,
329       fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
330       llvm::SmallVectorImpl<mlir::Value> &newOpers,
331       mlir::Value &savedStackPtr) {
332     if (noComplexConversion) {
333       newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(ty));
334       newOpers.push_back(oper);
335       return;
336     }
337     auto m = specifics->complexArgumentType(loc, ty.getElementType());
338     rewriteCallOperands(loc, m, ty, oper, newOpers, savedStackPtr,
339                         newInTyAndAttrs);
340   }
341 
342   void rewriteCallStructInputType(
343       mlir::Location loc, fir::RecordType recTy, mlir::Value oper,
344       fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
345       llvm::SmallVectorImpl<mlir::Value> &newOpers,
346       mlir::Value &savedStackPtr) {
347     if (noStructConversion) {
348       newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(recTy));
349       newOpers.push_back(oper);
350       return;
351     }
352     auto structArgs =
353         specifics->structArgumentType(loc, recTy, newInTyAndAttrs);
354     rewriteCallOperands(loc, structArgs, recTy, oper, newOpers, savedStackPtr,
355                         newInTyAndAttrs);
356   }
357 
358   static bool hasByValOrSRetArgs(
359       const fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
360     return llvm::any_of(newInTyAndAttrs, [](auto arg) {
361       const auto &attr = std::get<fir::CodeGenSpecifics::Attributes>(arg);
362       return attr.isByVal() || attr.isSRet();
363     });
364   }
365 
  // Convert fir.call, fir.dispatch and gpu.launch_func Ops.
  //
  // `callOp` is replaced by a new operation of the same kind whose operands
  // (and, for fir.call / fir.dispatch, result types) follow the target
  // conventions for `fnTy`:
  //   1. the result type is rewritten, possibly yielding a `wrap` callback
  //      that rebuilds the original result value from the new call;
  //   2. each operand is rewritten by type (boxchar, complex, record,
  //      character procedure tuple, or forwarded unchanged);
  //   3. "trailing" operands collected along the way are appended after all
  //      the others;
  //   4. the new operation is created, the stack is restored if a stack save
  //      was emitted, and the original operation is replaced.
  template <typename A>
  void convertCallOp(A callOp, mlir::FunctionType fnTy) {
    auto loc = callOp.getLoc();
    rewriter->setInsertionPoint(callOp);
    llvm::SmallVector<mlir::Type> newResTys;
    fir::CodeGenSpecifics::Marshalling newInTyAndAttrs;
    llvm::SmallVector<mlir::Value> newOpers;
    // Set lazily by the helpers the first time they are about to create a
    // temporary; null means no stack restore is needed.
    mlir::Value savedStackPtr = nullptr;

    // If the call is indirect, the first argument must still be the function
    // to call.
    int dropFront = 0;
    if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
      if (!callOp.getCallee()) {
        newInTyAndAttrs.push_back(
            fir::CodeGenSpecifics::getTypeAndAttr(fnTy.getInput(0)));
        newOpers.push_back(callOp.getOperand(0));
        dropFront = 1;
      }
    } else if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
      dropFront = 1; // First operand is the polymorphic object.
    }

    // Determine the rewrite function, `wrap`, for the result value.
    // Note that an sret result adds an entry to newInTyAndAttrs/newOpers
    // here, i.e. before any of the regular operands processed below.
    std::optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
    if (fnTy.getResults().size() == 1) {
      mlir::Type ty = fnTy.getResult(0);
      llvm::TypeSwitch<mlir::Type>(ty)
          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            wrap = rewriteCallComplexResultType(loc, cmplx, newResTys,
                                                newInTyAndAttrs, newOpers,
                                                savedStackPtr);
          })
          .template Case<fir::RecordType>([&](fir::RecordType recTy) {
            wrap = rewriteCallStructResultType(loc, recTy, newResTys,
                                               newInTyAndAttrs, newOpers,
                                               savedStackPtr);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
    } else if (fnTy.getResults().size() > 1) {
      TODO(loc, "multiple results not supported yet");
    }

    // Types and values that must come after every other argument; they are
    // appended once the operand loop is done.
    llvm::SmallVector<mlir::Type> trailingInTys;
    llvm::SmallVector<mlir::Value> trailingOpers;
    llvm::SmallVector<mlir::Value> operands;
    // For fir.dispatch: how far the pass-object argument moved because of
    // operands rewritten before it.
    unsigned passArgShift = 0;
    if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>)
      operands = callOp.getKernelOperands();
    else
      operands = callOp.getOperands().drop_front(dropFront);
    // `index` below counts positions after the first `dropFront` inputs have
    // been skipped.
    for (auto e : llvm::enumerate(
             llvm::zip(fnTy.getInputs().drop_front(dropFront), operands))) {
      mlir::Type ty = std::get<0>(e.value());
      mlir::Value oper = std::get<1>(e.value());
      unsigned index = e.index();
      llvm::TypeSwitch<mlir::Type>(ty)
          .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
            if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
              if (noCharacterConversion) {
                newInTyAndAttrs.push_back(
                    fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
                newOpers.push_back(oper);
                return;
              }
            } else {
              // TODO: dispatch case; it used to be a to-do because of sret,
              // but is not tested and maybe should be removed. This pass is
              // anyway ran after lowering fir.dispatch in flang, so maybe that
              // should just be a requirement of the pass.
              TODO(loc, "ABI of fir.dispatch with character arguments");
            }
            // Unbox the CHARACTER into the two values described by the
            // target; the first two marshalling entries give their types.
            auto m = specifics->boxcharArgumentType(boxTy.getEleTy());
            auto unbox = rewriter->create<fir::UnboxCharOp>(
                loc, std::get<mlir::Type>(m[0]), std::get<mlir::Type>(m[1]),
                oper);
            // unboxed CHARACTER arguments
            for (auto e : llvm::enumerate(m)) {
              unsigned idx = e.index();
              auto attr =
                  std::get<fir::CodeGenSpecifics::Attributes>(e.value());
              auto argTy = std::get<mlir::Type>(e.value());
              // Entries marked "append" are deferred to the trailing lists;
              // the others stay at the current position.
              if (attr.isAppend()) {
                trailingInTys.push_back(argTy);
                trailingOpers.push_back(unbox.getResult(idx));
              } else {
                newInTyAndAttrs.push_back(e.value());
                newOpers.push_back(unbox.getResult(idx));
              }
            }
          })
          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            rewriteCallComplexInputType(loc, cmplx, oper, newInTyAndAttrs,
                                        newOpers, savedStackPtr);
          })
          .template Case<fir::RecordType>([&](fir::RecordType recTy) {
            rewriteCallStructInputType(loc, recTy, oper, newInTyAndAttrs,
                                       newOpers, savedStackPtr);
          })
          .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
            if (fir::isCharacterProcedureTuple(tuple)) {
              mlir::ModuleOp module = getModule();
              if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
                if (callOp.getCallee()) {
                  llvm::StringRef charProcAttr =
                      fir::getCharacterProcedureDummyAttrName();
                  // The charProcAttr attribute is only used as a safety to
                  // confirm that this is a dummy procedure and should be split.
                  // It cannot be used to match because attributes are not
                  // available in case of indirect calls.
                  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(
                      *callOp.getCallee());
                  if (funcOp &&
                      !funcOp.template getArgAttrOfType<mlir::UnitAttr>(
                          index, charProcAttr))
                    mlir::emitError(loc, "tuple argument will be split even "
                                         "though it does not have the `" +
                                             charProcAttr + "` attribute");
                }
              }
              // Split the tuple: the procedure pointer stays in place, the
              // length becomes a trailing argument.
              mlir::Type funcPointerType = tuple.getType(0);
              mlir::Type lenType = tuple.getType(1);
              fir::FirOpBuilder builder(*rewriter, module);
              auto [funcPointer, len] =
                  fir::factory::extractCharacterProcedureTuple(builder, loc,
                                                               oper);
              newInTyAndAttrs.push_back(
                  fir::CodeGenSpecifics::getTypeAndAttr(funcPointerType));
              newOpers.push_back(funcPointer);
              trailingInTys.push_back(lenType);
              trailingOpers.push_back(len);
            } else {
              newInTyAndAttrs.push_back(
                  fir::CodeGenSpecifics::getTypeAndAttr(tuple));
              newOpers.push_back(oper);
            }
          })
          .Default([&](mlir::Type ty) {
            if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
              // When this operand is the pass object, record how many slots
              // it moved relative to its original position. Only operands
              // reaching this default case are checked.
              if (callOp.getPassArgPos() && *callOp.getPassArgPos() == index)
                passArgShift = newOpers.size() - *callOp.getPassArgPos();
            }
            newInTyAndAttrs.push_back(
                fir::CodeGenSpecifics::getTypeAndAttr(ty));
            newOpers.push_back(oper);
          });
    }

    // Final argument list: regular arguments followed by trailing ones.
    llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs);
    newInTypes.insert(newInTypes.end(), trailingInTys.begin(),
                      trailingInTys.end());
    newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());

    llvm::SmallVector<mlir::Value, 1> newCallResults;
    if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
      // Rebuild the launch with the new kernel operands and carry over the
      // cluster sizes when present. `wrap` is not consulted on this path.
      auto newCall = rewriter->create<A>(
          loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
          callOp.getBlockSizeOperandValues(),
          callOp.getDynamicSharedMemorySize(), newOpers);
      if (callOp.getClusterSizeX())
        newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX());
      if (callOp.getClusterSizeY())
        newCall.getClusterSizeYMutable().assign(callOp.getClusterSizeY());
      if (callOp.getClusterSizeZ())
        newCall.getClusterSizeZMutable().assign(callOp.getClusterSizeZ());
      newCallResults.append(newCall.result_begin(), newCall.result_end());
    } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
      fir::CallOp newCall;
      if (callOp.getCallee()) {
        newCall =
            rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
      } else {
        // TODO: llvm dialect must be updated to propagate argument on
        // attributes for indirect calls. See:
        // https://discourse.llvm.org/t/should-llvm-callop-be-able-to-carry-argument-attributes-for-indirect-calls/75431
        if (hasByValOrSRetArgs(newInTyAndAttrs))
          TODO(loc,
               "passing argument or result on the stack in indirect calls");
        // Indirect call: retype the callee value (operand 0) in place so it
        // matches the new signature, excluding the callee itself.
        newOpers[0].setType(mlir::FunctionType::get(
            callOp.getContext(),
            mlir::TypeRange{newInTypes}.drop_front(dropFront), newResTys));
        newCall = rewriter->create<A>(loc, newResTys, newOpers);
      }
      LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
      if (wrap)
        newCallResults.push_back((*wrap)(newCall.getOperation()));
      else
        newCallResults.append(newCall.result_begin(), newCall.result_end());
    } else {
      // NOTE(review): getPassArgPos() is dereferenced unconditionally here,
      // whereas the operand loop above checks it first. Presumably a
      // dispatch without a pass-object position never reaches this point --
      // confirm, otherwise this reads an empty optional.
      fir::DispatchOp dispatchOp = rewriter->create<A>(
          loc, newResTys, rewriter->getStringAttr(callOp.getMethod()),
          callOp.getOperands()[0], newOpers,
          rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift),
          callOp.getProcedureAttrsAttr());
      if (wrap)
        newCallResults.push_back((*wrap)(dispatchOp.getOperation()));
      else
        newCallResults.append(dispatchOp.result_begin(),
                              dispatchOp.result_end());
    }

    if (newCallResults.size() <= 1) {
      if (savedStackPtr) {
        if (newCallResults.size() == 1) {
          // We assume that all the allocas are inserted before
          // the operation that defines the new call result.
          rewriter->setInsertionPointAfterValue(newCallResults[0]);
        } else {
          // If the call does not have results, then insert
          // stack restore after the original call operation.
          rewriter->setInsertionPointAfter(callOp);
        }
        genStackRestore(loc, savedStackPtr);
      }
      replaceOp(callOp, newCallResults);
    } else {
      // The TODO is duplicated here to make sure this part
      // handles the stackrestore insertion properly, if
      // we add support for multiple call results.
      TODO(loc, "multiple results not supported yet");
    }
  }
589 
590   // Result type fixup for ComplexType.
591   template <typename Ty>
592   void lowerComplexSignatureRes(
593       mlir::Location loc, mlir::ComplexType cmplx, Ty &newResTys,
594       fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
595     if (noComplexConversion) {
596       newResTys.push_back(cmplx);
597       return;
598     }
599     for (auto &tup :
600          specifics->complexReturnType(loc, cmplx.getElementType())) {
601       auto argTy = std::get<mlir::Type>(tup);
602       if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet())
603         newInTyAndAttrs.push_back(tup);
604       else
605         newResTys.push_back(argTy);
606     }
607   }
608 
609   // Argument type fixup for ComplexType.
610   void lowerComplexSignatureArg(
611       mlir::Location loc, mlir::ComplexType cmplx,
612       fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
613     if (noComplexConversion) {
614       newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(cmplx));
615     } else {
616       auto cplxArgs =
617           specifics->complexArgumentType(loc, cmplx.getElementType());
618       newInTyAndAttrs.insert(newInTyAndAttrs.end(), cplxArgs.begin(),
619                              cplxArgs.end());
620     }
621   }
622 
623   template <typename Ty>
624   void
625   lowerStructSignatureRes(mlir::Location loc, fir::RecordType recTy,
626                           Ty &newResTys,
627                           fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
628     if (noComplexConversion) {
629       newResTys.push_back(recTy);
630       return;
631     } else {
632       for (auto &tup : specifics->structReturnType(loc, recTy)) {
633         if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet())
634           newInTyAndAttrs.push_back(tup);
635         else
636           newResTys.push_back(std::get<mlir::Type>(tup));
637       }
638     }
639   }
640 
641   void
642   lowerStructSignatureArg(mlir::Location loc, fir::RecordType recTy,
643                           fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
644     if (noStructConversion) {
645       newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(recTy));
646       return;
647     }
648     auto structArgs =
649         specifics->structArgumentType(loc, recTy, newInTyAndAttrs);
650     newInTyAndAttrs.insert(newInTyAndAttrs.end(), structArgs.begin(),
651                            structArgs.end());
652   }
653 
654   llvm::SmallVector<mlir::Type>
655   toTypeList(const fir::CodeGenSpecifics::Marshalling &marshalled) {
656     llvm::SmallVector<mlir::Type> typeList;
657     for (auto &typeAndAttr : marshalled)
658       typeList.emplace_back(std::get<mlir::Type>(typeAndAttr));
659     return typeList;
660   }
661 
662   /// Taking the address of a function. Modify the signature as needed.
663   void convertAddrOp(fir::AddrOfOp addrOp) {
664     rewriter->setInsertionPoint(addrOp);
665     auto addrTy = mlir::cast<mlir::FunctionType>(addrOp.getType());
666     fir::CodeGenSpecifics::Marshalling newInTyAndAttrs;
667     llvm::SmallVector<mlir::Type> newResTys;
668     auto loc = addrOp.getLoc();
669     for (mlir::Type ty : addrTy.getResults()) {
670       llvm::TypeSwitch<mlir::Type>(ty)
671           .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
672             lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
673           })
674           .Case<fir::RecordType>([&](fir::RecordType ty) {
675             lowerStructSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
676           })
677           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
678     }
679     llvm::SmallVector<mlir::Type> trailingInTys;
680     for (mlir::Type ty : addrTy.getInputs()) {
681       llvm::TypeSwitch<mlir::Type>(ty)
682           .Case<fir::BoxCharType>([&](auto box) {
683             if (noCharacterConversion) {
684               newInTyAndAttrs.push_back(
685                   fir::CodeGenSpecifics::getTypeAndAttr(box));
686             } else {
687               for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
688                 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
689                 auto argTy = std::get<mlir::Type>(tup);
690                 if (attr.isAppend())
691                   trailingInTys.push_back(argTy);
692                 else
693                   newInTyAndAttrs.push_back(tup);
694               }
695             }
696           })
697           .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
698             lowerComplexSignatureArg(loc, ty, newInTyAndAttrs);
699           })
700           .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
701             if (fir::isCharacterProcedureTuple(tuple)) {
702               newInTyAndAttrs.push_back(
703                   fir::CodeGenSpecifics::getTypeAndAttr(tuple.getType(0)));
704               trailingInTys.push_back(tuple.getType(1));
705             } else {
706               newInTyAndAttrs.push_back(
707                   fir::CodeGenSpecifics::getTypeAndAttr(ty));
708             }
709           })
710           .template Case<fir::RecordType>([&](fir::RecordType recTy) {
711             lowerStructSignatureArg(loc, recTy, newInTyAndAttrs);
712           })
713           .Default([&](mlir::Type ty) {
714             newInTyAndAttrs.push_back(
715                 fir::CodeGenSpecifics::getTypeAndAttr(ty));
716           });
717     }
718     llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs);
719     // append trailing input types
720     newInTypes.insert(newInTypes.end(), trailingInTys.begin(),
721                       trailingInTys.end());
722     // replace this op with a new one with the updated signature
723     auto newTy = rewriter->getFunctionType(newInTypes, newResTys);
724     auto newOp = rewriter->create<fir::AddrOfOp>(addrOp.getLoc(), newTy,
725                                                  addrOp.getSymbol());
726     replaceOp(addrOp, newOp.getResult());
727   }
728 
729   /// Convert the type signatures on all the functions present in the module.
730   /// As the type signature is being changed, this must also update the
731   /// function itself to use any new arguments, etc.
732   llvm::LogicalResult convertTypes(mlir::ModuleOp mod) {
733     mlir::MLIRContext *ctx = mod->getContext();
734     auto targetCPU = specifics->getTargetCPU();
735     mlir::StringAttr targetCPUAttr =
736         targetCPU.empty() ? nullptr : mlir::StringAttr::get(ctx, targetCPU);
737     auto tuneCPU = specifics->getTuneCPU();
738     mlir::StringAttr tuneCPUAttr =
739         tuneCPU.empty() ? nullptr : mlir::StringAttr::get(ctx, tuneCPU);
740     auto targetFeaturesAttr = specifics->getTargetFeatures();
741 
742     for (auto fn : mod.getOps<mlir::func::FuncOp>()) {
743       if (targetCPUAttr)
744         fn->setAttr("target_cpu", targetCPUAttr);
745 
746       if (tuneCPUAttr)
747         fn->setAttr("tune_cpu", tuneCPUAttr);
748 
749       if (targetFeaturesAttr)
750         fn->setAttr("target_features", targetFeaturesAttr);
751 
752       convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn);
753     }
754 
755     for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) {
756       for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
757         convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn);
758       for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
759         convertSignature<mlir::gpu::ReturnOp, mlir::gpu::GPUFuncOp>(fn);
760     }
761 
762     return mlir::success();
763   }
764 
765   // Returns true if the function should be interoperable with C.
766   static bool isFuncWithCCallingConvention(mlir::Operation *op) {
767     auto funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op);
768     if (!funcOp)
769       return false;
770     return op->hasAttrOfType<mlir::UnitAttr>(
771                fir::FIROpsDialect::getFirRuntimeAttrName()) ||
772            op->hasAttrOfType<mlir::StringAttr>(fir::getSymbolAttrName());
773   }
774 
775   /// If the signature does not need any special target-specific conversions,
776   /// then it is considered portable for any target, and this function will
777   /// return `true`. Otherwise, the signature is not portable and `false` is
778   /// returned.
779   bool hasPortableSignature(mlir::Type signature, mlir::Operation *op) {
780     assert(mlir::isa<mlir::FunctionType>(signature));
781     auto func = mlir::dyn_cast<mlir::FunctionType>(signature);
782     bool hasCCallingConv = isFuncWithCCallingConvention(op);
783     for (auto ty : func.getResults())
784       if ((mlir::isa<fir::BoxCharType>(ty) && !noCharacterConversion) ||
785           (fir::isa_complex(ty) && !noComplexConversion) ||
786           (mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv) ||
787           (mlir::isa<fir::RecordType>(ty) && !noStructConversion)) {
788         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
789         return false;
790       }
791     for (auto ty : func.getInputs())
792       if (((mlir::isa<fir::BoxCharType>(ty) ||
793             fir::isCharacterProcedureTuple(ty)) &&
794            !noCharacterConversion) ||
795           (fir::isa_complex(ty) && !noComplexConversion) ||
796           (mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv) ||
797           (mlir::isa<fir::RecordType>(ty) && !noStructConversion)) {
798         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
799         return false;
800       }
801     return true;
802   }
803 
804   /// Determine if the signature has host associations. The host association
805   /// argument may need special target specific rewriting.
806   template <typename OpTy>
807   static bool hasHostAssociations(OpTy func) {
808     std::size_t end = func.getFunctionType().getInputs().size();
809     for (std::size_t i = 0; i < end; ++i)
810       if (func.template getArgAttrOfType<mlir::UnitAttr>(
811               i, fir::getHostAssocAttrName()))
812         return true;
813     return false;
814   }
815 
816   /// Rewrite the signatures and body of the `FuncOp`s in the module for
817   /// the immediately subsequent target code gen.
818   template <typename ReturnOpTy, typename FuncOpTy>
819   void convertSignature(FuncOpTy func) {
820     auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
821     if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
822       return;
823     llvm::SmallVector<mlir::Type> newResTys;
824     fir::CodeGenSpecifics::Marshalling newInTyAndAttrs;
825     llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> savedAttrs;
826     llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> extraAttrs;
827     llvm::SmallVector<FixupTy> fixups;
828     llvm::SmallVector<std::pair<unsigned, mlir::NamedAttrList>, 1> resultAttrs;
829 
830     // Save argument attributes in case there is a shift so we can replace them
831     // correctly.
832     for (auto e : llvm::enumerate(funcTy.getInputs())) {
833       unsigned index = e.index();
834       llvm::ArrayRef<mlir::NamedAttribute> attrs =
835           mlir::function_interface_impl::getArgAttrs(func, index);
836       for (mlir::NamedAttribute attr : attrs) {
837         savedAttrs.push_back({index, attr});
838       }
839     }
840 
841     // Convert return value(s)
842     for (auto ty : funcTy.getResults())
843       llvm::TypeSwitch<mlir::Type>(ty)
844           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
845             if (noComplexConversion)
846               newResTys.push_back(cmplx);
847             else
848               doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
849           })
850           .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
851             auto m = specifics->integerArgumentType(func.getLoc(), intTy);
852             assert(m.size() == 1);
853             auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
854             auto retTy = std::get<mlir::Type>(m[0]);
855             std::size_t resId = newResTys.size();
856             llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName();
857             if (!extensionAttrName.empty() &&
858                 isFuncWithCCallingConvention(func))
859               resultAttrs.emplace_back(
860                   resId, rewriter->getNamedAttr(extensionAttrName,
861                                                 rewriter->getUnitAttr()));
862             newResTys.push_back(retTy);
863           })
864           .template Case<fir::RecordType>([&](fir::RecordType recTy) {
865             doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
866           })
867           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
868 
869     // Saved potential shift in argument. Handling of result can add arguments
870     // at the beginning of the function signature.
871     unsigned argumentShift = newInTyAndAttrs.size();
872 
873     // Convert arguments
874     llvm::SmallVector<mlir::Type> trailingTys;
875     for (auto e : llvm::enumerate(funcTy.getInputs())) {
876       auto ty = e.value();
877       unsigned index = e.index();
878       llvm::TypeSwitch<mlir::Type>(ty)
879           .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
880             if (noCharacterConversion) {
881               newInTyAndAttrs.push_back(
882                   fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
883             } else {
884               // Convert a CHARACTER argument type. This can involve separating
885               // the pointer and the LEN into two arguments and moving the LEN
886               // argument to the end of the arg list.
887               for (auto &tup :
888                    specifics->boxcharArgumentType(boxTy.getEleTy())) {
889                 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
890                 auto argTy = std::get<mlir::Type>(tup);
891                 if (attr.isAppend()) {
892                   trailingTys.push_back(argTy);
893                 } else {
894                   fixups.emplace_back(FixupTy::Codes::Trailing,
895                                       newInTyAndAttrs.size(),
896                                       trailingTys.size());
897                   newInTyAndAttrs.push_back(tup);
898                 }
899               }
900             }
901           })
902           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
903             doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
904           })
905           .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
906             if (fir::isCharacterProcedureTuple(tuple)) {
907               fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
908                                   newInTyAndAttrs.size(), trailingTys.size());
909               newInTyAndAttrs.push_back(
910                   fir::CodeGenSpecifics::getTypeAndAttr(tuple.getType(0)));
911               trailingTys.push_back(tuple.getType(1));
912             } else {
913               newInTyAndAttrs.push_back(
914                   fir::CodeGenSpecifics::getTypeAndAttr(ty));
915             }
916           })
917           .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
918             auto m = specifics->integerArgumentType(func.getLoc(), intTy);
919             assert(m.size() == 1);
920             auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
921             auto argNo = newInTyAndAttrs.size();
922             llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName();
923             if (!extensionAttrName.empty() &&
924                 isFuncWithCCallingConvention(func))
925               fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
926                                   [=](FuncOpTy func) {
927                                     func.setArgAttr(
928                                         argNo, extensionAttrName,
929                                         mlir::UnitAttr::get(func.getContext()));
930                                   });
931 
932             newInTyAndAttrs.push_back(m[0]);
933           })
934           .template Case<fir::RecordType>([&](fir::RecordType recTy) {
935             doStructArg(func, recTy, newInTyAndAttrs, fixups);
936           })
937           .Default([&](mlir::Type ty) {
938             newInTyAndAttrs.push_back(
939                 fir::CodeGenSpecifics::getTypeAndAttr(ty));
940           });
941 
942       if (func.template getArgAttrOfType<mlir::UnitAttr>(
943               index, fir::getHostAssocAttrName())) {
944         extraAttrs.push_back(
945             {newInTyAndAttrs.size() - 1,
946              rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
947       }
948     }
949 
950     if (!func.empty()) {
951       // If the function has a body, then apply the fixups to the arguments and
952       // return ops as required. These fixups are done in place.
953       auto loc = func.getLoc();
954       const auto fixupSize = fixups.size();
955       const auto oldArgTys = func.getFunctionType().getInputs();
956       int offset = 0;
957       for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
958         const auto &fixup = fixups[i];
959         mlir::Type fixupType =
960             fixup.index < newInTyAndAttrs.size()
961                 ? std::get<mlir::Type>(newInTyAndAttrs[fixup.index])
962                 : mlir::Type{};
963         switch (fixup.code) {
964         case FixupTy::Codes::ArgumentAsLoad: {
965           // Argument was pass-by-value, but is now pass-by-reference and
966           // possibly with a different element type.
967           auto newArg =
968               func.front().insertArgument(fixup.index, fixupType, loc);
969           rewriter->setInsertionPointToStart(&func.front());
970           auto oldArgTy =
971               fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
972           auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, newArg);
973           auto load = rewriter->create<fir::LoadOp>(loc, cast);
974           func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
975           func.front().eraseArgument(fixup.index + 1);
976         } break;
977         case FixupTy::Codes::ArgumentType: {
978           // Argument is pass-by-value, but its type has likely been modified to
979           // suit the target ABI convention.
980           auto oldArgTy = oldArgTys[fixup.index - offset];
981           // If type did not change, keep the original argument.
982           if (fixupType == oldArgTy)
983             break;
984 
985           auto newArg =
986               func.front().insertArgument(fixup.index, fixupType, loc);
987           rewriter->setInsertionPointToStart(&func.front());
988           mlir::Value bitcast = convertValueInMemory(loc, newArg, oldArgTy,
989                                                      /*inputMayBeBigger=*/true);
990           func.getArgument(fixup.index + 1).replaceAllUsesWith(bitcast);
991           func.front().eraseArgument(fixup.index + 1);
992           LLVM_DEBUG(llvm::dbgs()
993                      << "old argument: " << oldArgTy << ", repl: " << bitcast
994                      << ", new argument: "
995                      << func.getArgument(fixup.index).getType() << '\n');
996         } break;
997         case FixupTy::Codes::CharPair: {
998           // The FIR boxchar argument has been split into a pair of distinct
999           // arguments that are in juxtaposition to each other.
1000           auto newArg =
1001               func.front().insertArgument(fixup.index, fixupType, loc);
1002           if (fixup.second == 1) {
1003             rewriter->setInsertionPointToStart(&func.front());
1004             auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
1005             auto box = rewriter->create<fir::EmboxCharOp>(
1006                 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
1007             func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
1008             func.front().eraseArgument(fixup.index + 1);
1009             offset++;
1010           }
1011         } break;
1012         case FixupTy::Codes::ReturnAsStore: {
1013           // The value being returned is now being returned in memory (callee
1014           // stack space) through a hidden reference argument.
1015           auto newArg =
1016               func.front().insertArgument(fixup.index, fixupType, loc);
1017           offset++;
1018           func.walk([&](ReturnOpTy ret) {
1019             rewriter->setInsertionPoint(ret);
1020             auto oldOper = ret.getOperand(0);
1021             auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
1022             auto cast =
1023                 rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
1024             rewriter->create<fir::StoreOp>(loc, oldOper, cast);
1025             rewriter->create<ReturnOpTy>(loc);
1026             ret.erase();
1027           });
1028         } break;
1029         case FixupTy::Codes::ReturnType: {
1030           // The function is still returning a value, but its type has likely
1031           // changed to suit the target ABI convention.
1032           func.walk([&](ReturnOpTy ret) {
1033             rewriter->setInsertionPoint(ret);
1034             auto oldOper = ret.getOperand(0);
1035             mlir::Value bitcast =
1036                 convertValueInMemory(loc, oldOper, newResTys[fixup.index],
1037                                      /*inputMayBeBigger=*/false);
1038             rewriter->create<ReturnOpTy>(loc, bitcast);
1039             ret.erase();
1040           });
1041         } break;
1042         case FixupTy::Codes::Split: {
1043           // The FIR argument has been split into a pair of distinct arguments
1044           // that are in juxtaposition to each other. (For COMPLEX value or
1045           // derived type passed with VALUE in BIND(C) context).
1046           auto newArg =
1047               func.front().insertArgument(fixup.index, fixupType, loc);
1048           if (fixup.second == 1) {
1049             rewriter->setInsertionPointToStart(&func.front());
1050             mlir::Value firstArg = func.front().getArgument(fixup.index - 1);
1051             mlir::Type originalTy =
1052                 oldArgTys[fixup.index - offset - fixup.second];
1053             mlir::Type pairTy = originalTy;
1054             if (!fir::isa_complex(originalTy)) {
1055               pairTy = mlir::TupleType::get(
1056                   originalTy.getContext(),
1057                   mlir::TypeRange{firstArg.getType(), newArg.getType()});
1058             }
1059             auto undef = rewriter->create<fir::UndefOp>(loc, pairTy);
1060             auto iTy = rewriter->getIntegerType(32);
1061             auto zero = rewriter->getIntegerAttr(iTy, 0);
1062             auto one = rewriter->getIntegerAttr(iTy, 1);
1063             mlir::Value pair1 = rewriter->create<fir::InsertValueOp>(
1064                 loc, pairTy, undef, firstArg, rewriter->getArrayAttr(zero));
1065             mlir::Value pair = rewriter->create<fir::InsertValueOp>(
1066                 loc, pairTy, pair1, newArg, rewriter->getArrayAttr(one));
1067             // Cast local argument tuple to original type via memory if needed.
1068             if (pairTy != originalTy)
1069               pair = convertValueInMemory(loc, pair, originalTy,
1070                                           /*inputMayBeBigger=*/true);
1071             func.getArgument(fixup.index + 1).replaceAllUsesWith(pair);
1072             func.front().eraseArgument(fixup.index + 1);
1073             offset++;
1074           }
1075         } break;
1076         case FixupTy::Codes::Trailing: {
1077           // The FIR argument has been split into a pair of distinct arguments.
1078           // The first part of the pair appears in the original argument
1079           // position. The second part of the pair is appended after all the
1080           // original arguments. (Boxchar arguments.)
1081           auto newBufArg =
1082               func.front().insertArgument(fixup.index, fixupType, loc);
1083           auto newLenArg =
1084               func.front().addArgument(trailingTys[fixup.second], loc);
1085           auto boxTy = oldArgTys[fixup.index - offset];
1086           rewriter->setInsertionPointToStart(&func.front());
1087           auto box = rewriter->create<fir::EmboxCharOp>(loc, boxTy, newBufArg,
1088                                                         newLenArg);
1089           func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
1090           func.front().eraseArgument(fixup.index + 1);
1091         } break;
1092         case FixupTy::Codes::TrailingCharProc: {
1093           // The FIR character procedure argument tuple must be split into a
1094           // pair of distinct arguments. The first part of the pair appears in
1095           // the original argument position. The second part of the pair is
1096           // appended after all the original arguments.
1097           auto newProcPointerArg =
1098               func.front().insertArgument(fixup.index, fixupType, loc);
1099           auto newLenArg =
1100               func.front().addArgument(trailingTys[fixup.second], loc);
1101           auto tupleType = oldArgTys[fixup.index - offset];
1102           rewriter->setInsertionPointToStart(&func.front());
1103           fir::FirOpBuilder builder(*rewriter, getModule());
1104           auto tuple = fir::factory::createCharacterProcedureTuple(
1105               builder, loc, tupleType, newProcPointerArg, newLenArg);
1106           func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple);
1107           func.front().eraseArgument(fixup.index + 1);
1108         } break;
1109         }
1110       }
1111     }
1112 
1113     llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs);
1114     // Set the new type and finalize the arguments, etc.
1115     newInTypes.insert(newInTypes.end(), trailingTys.begin(), trailingTys.end());
1116     auto newFuncTy =
1117         mlir::FunctionType::get(func.getContext(), newInTypes, newResTys);
1118     LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
1119     func.setType(newFuncTy);
1120 
1121     for (std::pair<unsigned, mlir::NamedAttribute> extraAttr : extraAttrs)
1122       func.setArgAttr(extraAttr.first, extraAttr.second.getName(),
1123                       extraAttr.second.getValue());
1124 
1125     for (auto [resId, resAttrList] : resultAttrs)
1126       for (mlir::NamedAttribute resAttr : resAttrList)
1127         func.setResultAttr(resId, resAttr.getName(), resAttr.getValue());
1128 
1129     // Replace attributes to the correct argument if there was an argument shift
1130     // to the right.
1131     if (argumentShift > 0) {
1132       for (std::pair<unsigned, mlir::NamedAttribute> savedAttr : savedAttrs) {
1133         func.removeArgAttr(savedAttr.first, savedAttr.second.getName());
1134         func.setArgAttr(savedAttr.first + argumentShift,
1135                         savedAttr.second.getName(),
1136                         savedAttr.second.getValue());
1137       }
1138     }
1139 
1140     for (auto &fixup : fixups) {
1141       if constexpr (std::is_same_v<FuncOpTy, mlir::func::FuncOp>)
1142         if (fixup.finalizer)
1143           (*fixup.finalizer)(func);
1144       if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>)
1145         if (fixup.gpuFinalizer)
1146           (*fixup.gpuFinalizer)(func);
1147     }
1148   }
1149 
1150   template <typename OpTy, typename Ty, typename FIXUPS>
1151   void doReturn(OpTy func, Ty &newResTys,
1152                 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1153                 FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
1154     assert(m.size() == 1 &&
1155            "expect result to be turned into single argument or result so far");
1156     auto &tup = m[0];
1157     auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
1158     auto argTy = std::get<mlir::Type>(tup);
1159     if (attr.isSRet()) {
1160       unsigned argNo = newInTyAndAttrs.size();
1161       if (auto align = attr.getAlignment())
1162         fixups.emplace_back(
1163             FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
1164               auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
1165                   func.getFunctionType().getInput(argNo));
1166               func.setArgAttr(argNo, "llvm.sret",
1167                               mlir::TypeAttr::get(elemType));
1168               func.setArgAttr(argNo, "llvm.align",
1169                               rewriter->getIntegerAttr(
1170                                   rewriter->getIntegerType(32), align));
1171             });
1172       else
1173         fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
1174                             [=](OpTy func) {
1175                               auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
1176                                   func.getFunctionType().getInput(argNo));
1177                               func.setArgAttr(argNo, "llvm.sret",
1178                                               mlir::TypeAttr::get(elemType));
1179                             });
1180       newInTyAndAttrs.push_back(tup);
1181       return;
1182     }
1183     if (auto align = attr.getAlignment())
1184       fixups.emplace_back(
1185           FixupTy::Codes::ReturnType, newResTys.size(), [=](OpTy func) {
1186             func.setArgAttr(
1187                 newResTys.size(), "llvm.align",
1188                 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
1189           });
1190     else
1191       fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
1192     newResTys.push_back(argTy);
1193   }
1194 
1195   /// Convert a complex return value. This can involve converting the return
1196   /// value to a "hidden" first argument or packing the complex into a wide
1197   /// GPR.
1198   template <typename OpTy, typename Ty, typename FIXUPS>
1199   void doComplexReturn(OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
1200                        fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1201                        FIXUPS &fixups) {
1202     if (noComplexConversion) {
1203       newResTys.push_back(cmplx);
1204       return;
1205     }
1206     auto m =
1207         specifics->complexReturnType(func.getLoc(), cmplx.getElementType());
1208     doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
1209   }
1210 
1211   template <typename OpTy, typename Ty, typename FIXUPS>
1212   void doStructReturn(OpTy func, fir::RecordType recTy, Ty &newResTys,
1213                       fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1214                       FIXUPS &fixups) {
1215     if (noStructConversion) {
1216       newResTys.push_back(recTy);
1217       return;
1218     }
1219     auto m = specifics->structReturnType(func.getLoc(), recTy);
1220     doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
1221   }
1222 
1223   template <typename OpTy, typename FIXUPS>
1224   void createFuncOpArgFixups(
1225       OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1226       fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
1227     const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split
1228                                                 : FixupTy::Codes::ArgumentType;
1229     for (auto e : llvm::enumerate(argsInTys)) {
1230       auto &tup = e.value();
1231       auto index = e.index();
1232       auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
1233       auto argNo = newInTyAndAttrs.size();
1234       if (attr.isByVal()) {
1235         if (auto align = attr.getAlignment())
1236           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
1237                               [=](OpTy func) {
1238                                 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
1239                                     func.getFunctionType().getInput(argNo));
1240                                 func.setArgAttr(argNo, "llvm.byval",
1241                                                 mlir::TypeAttr::get(elemType));
1242                                 func.setArgAttr(
1243                                     argNo, "llvm.align",
1244                                     rewriter->getIntegerAttr(
1245                                         rewriter->getIntegerType(32), align));
1246                               });
1247         else
1248           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad,
1249                               newInTyAndAttrs.size(), [=](OpTy func) {
1250                                 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
1251                                     func.getFunctionType().getInput(argNo));
1252                                 func.setArgAttr(argNo, "llvm.byval",
1253                                                 mlir::TypeAttr::get(elemType));
1254                               });
1255       } else {
1256         if (auto align = attr.getAlignment())
1257           fixups.emplace_back(
1258               fixupCode, argNo, index, [=](OpTy func) {
1259                 func.setArgAttr(argNo, "llvm.align",
1260                                 rewriter->getIntegerAttr(
1261                                     rewriter->getIntegerType(32), align));
1262               });
1263         else
1264           fixups.emplace_back(fixupCode, argNo, index);
1265       }
1266       newInTyAndAttrs.push_back(tup);
1267     }
1268   }
1269 
1270   /// Convert a complex argument value. This can involve storing the value to
1271   /// a temporary memory location or factoring the value into two distinct
1272   /// arguments.
1273   template <typename OpTy, typename FIXUPS>
1274   void doComplexArg(OpTy func, mlir::ComplexType cmplx,
1275                     fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1276                     FIXUPS &fixups) {
1277     if (noComplexConversion) {
1278       newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(cmplx));
1279       return;
1280     }
1281     auto cplxArgs =
1282         specifics->complexArgumentType(func.getLoc(), cmplx.getElementType());
1283     createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups);
1284   }
1285 
1286   template <typename OpTy, typename FIXUPS>
1287   void doStructArg(OpTy func, fir::RecordType recTy,
1288                    fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1289                    FIXUPS &fixups) {
1290     if (noStructConversion) {
1291       newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(recTy));
1292       return;
1293     }
1294     auto structArgs =
1295         specifics->structArgumentType(func.getLoc(), recTy, newInTyAndAttrs);
1296     createFuncOpArgFixups(func, newInTyAndAttrs, structArgs, fixups);
1297   }
1298 
1299 private:
1300   // Replace `op` and remove it.
1301   void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
1302     op->replaceAllUsesWith(newValues);
1303     op->dropAllReferences();
1304     op->erase();
1305   }
1306 
1307   inline void setMembers(fir::CodeGenSpecifics *s, mlir::OpBuilder *r,
1308                          mlir::DataLayout *dl) {
1309     specifics = s;
1310     rewriter = r;
1311     dataLayout = dl;
1312   }
1313 
1314   inline void clearMembers() { setMembers(nullptr, nullptr, nullptr); }
1315 
1316   // Inserts a call to llvm.stacksave at the current insertion
1317   // point and the given location. Returns the call's result Value.
1318   inline mlir::Value genStackSave(mlir::Location loc) {
1319     fir::FirOpBuilder builder(*rewriter, getModule());
1320     return builder.genStackSave(loc);
1321   }
1322 
1323   // Inserts a call to llvm.stackrestore at the current insertion
1324   // point and the given location and argument.
1325   inline void genStackRestore(mlir::Location loc, mlir::Value sp) {
1326     fir::FirOpBuilder builder(*rewriter, getModule());
1327     return builder.genStackRestore(loc, sp);
1328   }
1329 
1330   fir::CodeGenSpecifics *specifics = nullptr;
1331   mlir::OpBuilder *rewriter = nullptr;
1332   mlir::DataLayout *dataLayout = nullptr;
1333 };
1334 } // namespace
1335