xref: /llvm-project/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp (revision d0c9e70bcc40948821e83eb0ec32e6e15fb0dd4b)
1 //===- InlinerInterfaceImpl.cpp - Inlining for LLVM the dialect -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Logic for inlining LLVM functions and the definition of the
10 // LLVMInliningInterface.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
15 #include "mlir/Analysis/SliceWalk.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/Interfaces/DataLayoutInterfaces.h"
20 #include "mlir/Interfaces/ViewLikeInterface.h"
21 #include "mlir/Transforms/InliningUtils.h"
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "llvm-inliner"
26 
27 using namespace mlir;
28 
29 /// Check whether the given alloca is an input to a lifetime intrinsic,
30 /// optionally passing through one or more casts on the way. This is not
31 /// transitive through block arguments.
32 static bool hasLifetimeMarkers(LLVM::AllocaOp allocaOp) {
33   SmallVector<Operation *> stack(allocaOp->getUsers().begin(),
34                                  allocaOp->getUsers().end());
35   while (!stack.empty()) {
36     Operation *op = stack.pop_back_val();
37     if (isa<LLVM::LifetimeStartOp, LLVM::LifetimeEndOp>(op))
38       return true;
39     if (isa<LLVM::BitcastOp>(op))
40       stack.append(op->getUsers().begin(), op->getUsers().end());
41   }
42   return false;
43 }
44 
/// Handles alloca operations in the inlined blocks:
/// - Moves all alloca operations with a constant size in the former entry block
///   of the callee into the entry block of the caller, so they become part of
///   the function prologue/epilogue during code generation.
/// - Inserts lifetime intrinsics that limit the scope of inlined static allocas
///   to the inlined blocks.
/// - Inserts StackSave and StackRestore operations if dynamic allocas were
///   inlined.
static void
handleInlinedAllocas(Operation *call,
                     iterator_range<Region::iterator> inlinedBlocks) {
  // Locate the entry block of the closest callsite ancestor that has either the
  // IsolatedFromAbove or AutomaticAllocationScope trait. In pure LLVM dialect
  // programs, this is the LLVMFuncOp containing the call site. However, in
  // mixed-dialect programs, the callsite might be nested in another operation
  // that carries one of these traits. In such scenarios, this traversal stops
  // at the closest ancestor with either trait, ensuring visibility post
  // relocation and respecting allocation scopes.
  Block *callerEntryBlock = nullptr;
  Operation *currentOp = call;
  while (Operation *parentOp = currentOp->getParentOp()) {
    // mightHaveTrait is used (rather than hasTrait) so unregistered ops are
    // treated conservatively as potential allocation scopes.
    if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
        parentOp->mightHaveTrait<OpTrait::AutomaticAllocationScope>()) {
      callerEntryBlock = &currentOp->getParentRegion()->front();
      break;
    }
    currentOp = parentOp;
  }

  // Avoid relocating the alloca operations if the call has been inlined into
  // the entry block already, which is typically the encompassing
  // LLVM function, or if the relevant entry block cannot be identified.
  Block *calleeEntryBlock = &(*inlinedBlocks.begin());
  if (!callerEntryBlock || callerEntryBlock == calleeEntryBlock)
    return;

  // Each tuple is (alloca to relocate, its constant array size, whether
  // lifetime intrinsics should be created for it).
  SmallVector<std::tuple<LLVM::AllocaOp, IntegerAttr, bool>> allocasToMove;
  bool shouldInsertLifetimes = false;
  bool hasDynamicAlloca = false;
  // Conservatively only move static alloca operations that are part of the
  // entry block and do not inspect nested regions, since they may execute
  // conditionally or have other unknown semantics.
  for (auto allocaOp : calleeEntryBlock->getOps<LLVM::AllocaOp>()) {
    IntegerAttr arraySize;
    if (!matchPattern(allocaOp.getArraySize(), m_Constant(&arraySize))) {
      hasDynamicAlloca = true;
      continue;
    }
    // Skip lifetime markers for zero-sized allocas and for allocas that
    // already carry their own lifetime intrinsics.
    bool shouldInsertLifetime =
        arraySize.getValue() != 0 && !hasLifetimeMarkers(allocaOp);
    shouldInsertLifetimes |= shouldInsertLifetime;
    allocasToMove.emplace_back(allocaOp, arraySize, shouldInsertLifetime);
  }
  // Check the remaining inlined blocks for dynamic allocas as well.
  for (Block &block : llvm::drop_begin(inlinedBlocks)) {
    if (hasDynamicAlloca)
      break;
    hasDynamicAlloca =
        llvm::any_of(block.getOps<LLVM::AllocaOp>(), [](auto allocaOp) {
          return !matchPattern(allocaOp.getArraySize(), m_Constant());
        });
  }
  if (allocasToMove.empty() && !hasDynamicAlloca)
    return;
  // Builder initially points at the start of the former callee entry block;
  // the stacksave (if any) is emitted there, before any inlined code runs.
  OpBuilder builder(calleeEntryBlock, calleeEntryBlock->begin());
  Value stackPtr;
  if (hasDynamicAlloca) {
    // This may result in multiple stacksave/stackrestore intrinsics in the same
    // scope if some are already present in the body of the caller. This is not
    // invalid IR, but LLVM cleans these up in InstCombineCalls.cpp, along with
    // other cases where the stacksave/stackrestore is redundant.
    stackPtr = builder.create<LLVM::StackSaveOp>(
        call->getLoc(), LLVM::LLVMPointerType::get(call->getContext()));
  }
  builder.setInsertionPointToStart(callerEntryBlock);
  for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
    // Rematerialize the array size as a constant in the caller entry block so
    // the relocated alloca does not reference a value from the inlined region.
    auto newConstant = builder.create<LLVM::ConstantOp>(
        allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize);
    // Insert a lifetime start intrinsic where the alloca was before moving it.
    if (shouldInsertLifetime) {
      OpBuilder::InsertionGuard insertionGuard(builder);
      builder.setInsertionPoint(allocaOp);
      builder.create<LLVM::LifetimeStartOp>(
          allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
          allocaOp.getResult());
    }
    allocaOp->moveAfter(newConstant);
    allocaOp.getArraySizeMutable().assign(newConstant.getResult());
  }
  if (!shouldInsertLifetimes && !hasDynamicAlloca)
    return;
  // Insert a lifetime end intrinsic before each return in the callee function.
  for (Block &block : inlinedBlocks) {
    if (!block.getTerminator()->hasTrait<OpTrait::ReturnLike>())
      continue;
    builder.setInsertionPoint(block.getTerminator());
    // Restore the stack pointer first so dynamic allocas die before the
    // lifetime ends of the static allocas are recorded.
    if (hasDynamicAlloca)
      builder.create<LLVM::StackRestoreOp>(call->getLoc(), stackPtr);
    for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
      if (shouldInsertLifetime)
        builder.create<LLVM::LifetimeEndOp>(
            allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
            allocaOp.getResult());
    }
  }
}
151 
/// Maps all alias scopes in the inlined operations to deep clones of the scopes
/// and domain. This is required for code such as `foo(a, b); foo(a2, b2);` to
/// not incorrectly return `noalias` for e.g. operations on `a` and `a2`.
static void
deepCloneAliasScopes(iterator_range<Region::iterator> inlinedBlocks) {
  // Maps an original scope/domain attribute to its freshly created clone.
  DenseMap<Attribute, Attribute> mapping;

  // Register handles in the walker to create the deep clones.
  // The walker ensures that an attribute is only ever walked once and does a
  // post-order walk, ensuring the domain is visited prior to the scope.
  AttrTypeWalker walker;

  // Perform the deep clones while visiting. Builders create a distinct
  // attribute to make sure that new instances are always created by the
  // uniquer.
  walker.addWalk([&](LLVM::AliasScopeDomainAttr domainAttr) {
    mapping[domainAttr] = LLVM::AliasScopeDomainAttr::get(
        domainAttr.getContext(), domainAttr.getDescription());
  });

  walker.addWalk([&](LLVM::AliasScopeAttr scopeAttr) {
    // The domain clone already exists at this point thanks to the post-order
    // traversal, so the lookup below cannot miss.
    mapping[scopeAttr] = LLVM::AliasScopeAttr::get(
        cast<LLVM::AliasScopeDomainAttr>(mapping.lookup(scopeAttr.getDomain())),
        scopeAttr.getDescription());
  });

  // Map an array of scopes to an array of deep clones.
  auto convertScopeList = [&](ArrayAttr arrayAttr) -> ArrayAttr {
    if (!arrayAttr)
      return nullptr;

    // Create the deep clones if necessary.
    walker.walk(arrayAttr);

    return ArrayAttr::get(arrayAttr.getContext(),
                          llvm::map_to_vector(arrayAttr, [&](Attribute attr) {
                            return mapping.lookup(attr);
                          }));
  };

  // Rewrite both the scope lists on memory operations and the scope
  // declarations themselves.
  for (Block &block : inlinedBlocks) {
    block.walk([&](Operation *op) {
      if (auto aliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(op)) {
        aliasInterface.setAliasScopes(
            convertScopeList(aliasInterface.getAliasScopesOrNull()));
        aliasInterface.setNoAliasScopes(
            convertScopeList(aliasInterface.getNoAliasScopesOrNull()));
      }

      if (auto noAliasScope = dyn_cast<LLVM::NoAliasScopeDeclOp>(op)) {
        // Create the deep clones if necessary.
        walker.walk(noAliasScope.getScopeAttr());

        noAliasScope.setScopeAttr(cast<LLVM::AliasScopeAttr>(
            mapping.lookup(noAliasScope.getScopeAttr())));
      }
    });
  }
}
211 
212 /// Creates a new ArrayAttr by concatenating `lhs` with `rhs`.
213 /// Returns null if both parameters are null. If only one attribute is null,
214 /// return the other.
215 static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs) {
216   if (!lhs)
217     return rhs;
218   if (!rhs)
219     return lhs;
220 
221   SmallVector<Attribute> result;
222   llvm::append_range(result, lhs);
223   llvm::append_range(result, rhs);
224   return ArrayAttr::get(lhs.getContext(), result);
225 }
226 
/// Attempts to return the set of all underlying pointer values that
/// `pointerValue` is based on. This function traverses through select
/// operations and block arguments. Fails (returns failure) if the slice walk
/// reaches a block argument whose predecessors cannot be determined.
static FailureOr<SmallVector<Value>>
getUnderlyingObjectSet(Value pointerValue) {
  SmallVector<Value> result;
  WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) {
    // Attempt to advance to the source of the underlying view-like operation.
    // Examples of view-like operations include GEPOp and AddrSpaceCastOp.
    if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>())
      return WalkContinuation::advanceTo(viewOp.getViewSource());

    // Attempt to advance to control flow predecessors.
    std::optional<SmallVector<Value>> controlFlowPredecessors =
        getControlFlowPredecessors(val);
    if (controlFlowPredecessors)
      return WalkContinuation::advanceTo(*controlFlowPredecessors);

    // For all non-control flow results, consider `val` an underlying object.
    if (isa<OpResult>(val)) {
      result.push_back(val);
      return WalkContinuation::skip();
    }

    // If this place is reached, `val` is a block argument that is not
    // understood. Therefore, we conservatively interrupt.
    // Note: Dealing with function arguments is not necessary, as the slice
    // would have to go through an SSACopyOp first.
    return WalkContinuation::interrupt();
  });

  // An interrupted walk means at least one underlying object is unknown, so
  // no partial result can be returned safely.
  if (walkResult.wasInterrupted())
    return failure();

  return result;
}
263 
/// Creates a new AliasScopeAttr for every noalias parameter and attaches it to
/// the appropriate inlined memory operations in an attempt to preserve the
/// original semantics of the parameter attribute.
static void createNewAliasScopesFromNoAliasParameter(
    Operation *call, iterator_range<Region::iterator> inlinedBlocks) {

  // First, collect all ssa copy operations, which correspond to function
  // parameters, and additionally store the noalias parameters. All parameters
  // have been marked by the `handleArgument` implementation by using the
  // `ssa.copy` intrinsic. Additionally, noalias parameters have an attached
  // `noalias` attribute to the intrinsics. These intrinsics are only meant to
  // be temporary and should therefore be deleted after we're done using them
  // here.
  SetVector<LLVM::SSACopyOp> ssaCopies;
  SetVector<LLVM::SSACopyOp> noAliasParams;
  for (Value argument : cast<LLVM::CallOp>(call).getArgOperands()) {
    for (Operation *user : argument.getUsers()) {
      auto ssaCopy = llvm::dyn_cast<LLVM::SSACopyOp>(user);
      if (!ssaCopy)
        continue;
      ssaCopies.insert(ssaCopy);

      if (!ssaCopy->hasAttr(LLVM::LLVMDialect::getNoAliasAttrName()))
        continue;
      noAliasParams.insert(ssaCopy);
    }
  }

  // Scope exit block to make it impossible to forget to get rid of the
  // intrinsics.
  auto exit = llvm::make_scope_exit([&] {
    for (LLVM::SSACopyOp ssaCopyOp : ssaCopies) {
      ssaCopyOp.replaceAllUsesWith(ssaCopyOp.getOperand());
      ssaCopyOp->erase();
    }
  });

  // If there were no noalias parameters, we have nothing to do here.
  if (noAliasParams.empty())
    return;

  // Create a new domain for this specific inlining and a new scope for every
  // noalias parameter. The domain is named after the callee for readability of
  // the resulting IR.
  auto functionDomain = LLVM::AliasScopeDomainAttr::get(
      call->getContext(), cast<LLVM::CallOp>(call).getCalleeAttr().getAttr());
  DenseMap<Value, LLVM::AliasScopeAttr> pointerScopes;
  for (LLVM::SSACopyOp copyOp : noAliasParams) {
    auto scope = LLVM::AliasScopeAttr::get(functionDomain);
    pointerScopes[copyOp] = scope;

    // Declare the scope right before the (about to be erased) call.
    OpBuilder(call).create<LLVM::NoAliasScopeDeclOp>(call->getLoc(), scope);
  }

  // Go through every instruction and attempt to find which noalias parameters
  // it is definitely based on and definitely not based on.
  for (Block &inlinedBlock : inlinedBlocks) {
    inlinedBlock.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
      // Collect the pointer arguments affected by the alias scopes.
      SmallVector<Value> pointerArgs = aliasInterface.getAccessedOperands();

      // Find the set of underlying pointers that this pointer is based on.
      SmallPtrSet<Value, 4> basedOnPointers;
      for (Value pointer : pointerArgs) {
        FailureOr<SmallVector<Value>> underlyingObjectSet =
            getUnderlyingObjectSet(pointer);
        // Unknown underlying objects: conservatively attach no scopes at all
        // to this operation.
        if (failed(underlyingObjectSet))
          return;
        llvm::copy(*underlyingObjectSet,
                   std::inserter(basedOnPointers, basedOnPointers.begin()));
      }

      bool aliasesOtherKnownObject = false;
      // Go through the based on pointers and check that they are either:
      // * Constants that can be ignored (undef, poison, null pointer).
      // * Based on a pointer parameter.
      // * Other pointers that we know can't alias with our noalias parameter.
      //
      // Any other value might be a pointer based on any noalias parameter that
      // hasn't been identified. In that case conservatively don't add any
      // scopes to this operation indicating either aliasing or not aliasing
      // with any parameter.
      if (llvm::any_of(basedOnPointers, [&](Value object) {
            if (matchPattern(object, m_Constant()))
              return false;

            if (auto ssaCopy = object.getDefiningOp<LLVM::SSACopyOp>()) {
              // If that value is based on a noalias parameter, it is guaranteed
              // to not alias with any other object.
              aliasesOtherKnownObject |= !noAliasParams.contains(ssaCopy);
              return false;
            }

            if (isa_and_nonnull<LLVM::AllocaOp, LLVM::AddressOfOp>(
                    object.getDefiningOp())) {
              aliasesOtherKnownObject = true;
              return false;
            }
            return true;
          }))
        return;

      // Add all noalias parameter scopes to the noalias scope list that we are
      // not based on.
      SmallVector<Attribute> noAliasScopes;
      for (LLVM::SSACopyOp noAlias : noAliasParams) {
        if (basedOnPointers.contains(noAlias))
          continue;

        noAliasScopes.push_back(pointerScopes[noAlias]);
      }

      if (!noAliasScopes.empty())
        aliasInterface.setNoAliasScopes(
            concatArrayAttr(aliasInterface.getNoAliasScopesOrNull(),
                            ArrayAttr::get(call->getContext(), noAliasScopes)));

      // Don't add alias scopes to call operations or operations that might
      // operate on pointers not based on any noalias parameter.
      // Since we add all scopes to an operation's noalias list that it
      // definitely doesn't alias, we mustn't do the same for the alias.scope
      // list if other objects are involved.
      //
      // Consider the following case:
      // %0 = llvm.alloca
      // %1 = select %magic, %0, %noalias_param
      // store 5, %1  (1) noalias=[scope(...)]
      // ...
      // store 3, %0  (2) noalias=[scope(noalias_param), scope(...)]
      //
      // We can add the scopes of any noalias parameters that aren't
      // noalias_param's scope to (1) and add all of them to (2). We mustn't add
      // the scope of noalias_param to the alias.scope list of (1) since
      // that would mean (2) cannot alias with (1) which is wrong since both may
      // store to %0.
      //
      // In conclusion, only add scopes to the alias.scope list if all pointers
      // have a corresponding scope.
      // Call operations are included in this list since we do not know whether
      // the callee accesses any memory besides the ones passed as its
      // arguments.
      if (aliasesOtherKnownObject ||
          isa<LLVM::CallOp>(aliasInterface.getOperation()))
        return;

      SmallVector<Attribute> aliasScopes;
      for (LLVM::SSACopyOp noAlias : noAliasParams)
        if (basedOnPointers.contains(noAlias))
          aliasScopes.push_back(pointerScopes[noAlias]);

      if (!aliasScopes.empty())
        aliasInterface.setAliasScopes(
            concatArrayAttr(aliasInterface.getAliasScopesOrNull(),
                            ArrayAttr::get(call->getContext(), aliasScopes)));
    });
  }
}
420 
421 /// Appends any alias scopes of the call operation to any inlined memory
422 /// operation.
423 static void
424 appendCallOpAliasScopes(Operation *call,
425                         iterator_range<Region::iterator> inlinedBlocks) {
426   auto callAliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(call);
427   if (!callAliasInterface)
428     return;
429 
430   ArrayAttr aliasScopes = callAliasInterface.getAliasScopesOrNull();
431   ArrayAttr noAliasScopes = callAliasInterface.getNoAliasScopesOrNull();
432   // If the call has neither alias scopes or noalias scopes we have nothing to
433   // do here.
434   if (!aliasScopes && !noAliasScopes)
435     return;
436 
437   // Simply append the call op's alias and noalias scopes to any operation
438   // implementing AliasAnalysisOpInterface.
439   for (Block &block : inlinedBlocks) {
440     block.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
441       if (aliasScopes)
442         aliasInterface.setAliasScopes(concatArrayAttr(
443             aliasInterface.getAliasScopesOrNull(), aliasScopes));
444 
445       if (noAliasScopes)
446         aliasInterface.setNoAliasScopes(concatArrayAttr(
447             aliasInterface.getNoAliasScopesOrNull(), noAliasScopes));
448     });
449   }
450 }
451 
/// Handles all interactions with alias scopes during inlining.
static void handleAliasScopes(Operation *call,
                              iterator_range<Region::iterator> inlinedBlocks) {
  // Clone the callee's existing scopes first, before any new scopes for this
  // particular inlining are attached by the two steps below.
  deepCloneAliasScopes(inlinedBlocks);
  // Materialize fresh scopes from the callee's noalias parameters.
  createNewAliasScopesFromNoAliasParameter(call, inlinedBlocks);
  // Finally, propagate scopes attached to the call site itself.
  appendCallOpAliasScopes(call, inlinedBlocks);
}
459 
460 /// Appends any access groups of the call operation to any inlined memory
461 /// operation.
462 static void handleAccessGroups(Operation *call,
463                                iterator_range<Region::iterator> inlinedBlocks) {
464   auto callAccessGroupInterface = dyn_cast<LLVM::AccessGroupOpInterface>(call);
465   if (!callAccessGroupInterface)
466     return;
467 
468   auto accessGroups = callAccessGroupInterface.getAccessGroupsOrNull();
469   if (!accessGroups)
470     return;
471 
472   // Simply append the call op's access groups to any operation implementing
473   // AccessGroupOpInterface.
474   for (Block &block : inlinedBlocks)
475     for (auto accessGroupOpInterface :
476          block.getOps<LLVM::AccessGroupOpInterface>())
477       accessGroupOpInterface.setAccessGroups(concatArrayAttr(
478           accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups));
479 }
480 
/// Updates locations inside loop annotations to reflect that they were inlined.
static void
handleLoopAnnotations(Operation *call,
                      iterator_range<Region::iterator> inlinedBlocks) {
  // Attempt to extract a DISubprogram from the function containing the call
  // site (i.e. the function the blocks were inlined into). Without such a
  // scope, the inlined locations cannot be rewritten.
  auto func = call->getParentOfType<FunctionOpInterface>();
  if (!func)
    return;
  LocationAttr funcLoc = func->getLoc();
  auto fusedLoc = dyn_cast_if_present<FusedLoc>(funcLoc);
  if (!fusedLoc)
    return;
  auto scope =
      dyn_cast_if_present<LLVM::DISubprogramAttr>(fusedLoc.getMetadata());
  if (!scope)
    return;

  // Helper to build a new fused location that reflects the inlining of the loop
  // annotation. Returns a null FusedLoc when given one, so absent start/end
  // locations stay absent.
  auto updateLoc = [&](FusedLoc loc) -> FusedLoc {
    if (!loc)
      return {};
    Location callSiteLoc = CallSiteLoc::get(loc, call->getLoc());
    return FusedLoc::get(loc.getContext(), callSiteLoc, scope);
  };

  AttrTypeReplacer replacer;
  replacer.addReplacement([&](LLVM::LoopAnnotationAttr loopAnnotation)
                              -> std::pair<Attribute, WalkResult> {
    FusedLoc newStartLoc = updateLoc(loopAnnotation.getStartLoc());
    FusedLoc newEndLoc = updateLoc(loopAnnotation.getEndLoc());
    // No locations to rewrite: keep the annotation as-is.
    if (!newStartLoc && !newEndLoc)
      return {loopAnnotation, WalkResult::advance()};
    // Rebuild the annotation, copying every field and substituting only the
    // start/end locations.
    auto newLoopAnnotation = LLVM::LoopAnnotationAttr::get(
        loopAnnotation.getContext(), loopAnnotation.getDisableNonforced(),
        loopAnnotation.getVectorize(), loopAnnotation.getInterleave(),
        loopAnnotation.getUnroll(), loopAnnotation.getUnrollAndJam(),
        loopAnnotation.getLicm(), loopAnnotation.getDistribute(),
        loopAnnotation.getPipeline(), loopAnnotation.getPeeled(),
        loopAnnotation.getUnswitch(), loopAnnotation.getMustProgress(),
        loopAnnotation.getIsVectorized(), newStartLoc, newEndLoc,
        loopAnnotation.getParallelAccesses());
    // Needs to advance, as loop annotations can be nested.
    return {newLoopAnnotation, WalkResult::advance()};
  });

  for (Block &block : inlinedBlocks)
    for (Operation &op : block)
      replacer.recursivelyReplaceElementsIn(&op);
}
531 
532 /// If `requestedAlignment` is higher than the alignment specified on `alloca`,
533 /// realigns `alloca` if this does not exceed the natural stack alignment.
534 /// Returns the post-alignment of `alloca`, whether it was realigned or not.
535 static uint64_t tryToEnforceAllocaAlignment(LLVM::AllocaOp alloca,
536                                             uint64_t requestedAlignment,
537                                             DataLayout const &dataLayout) {
538   uint64_t allocaAlignment = alloca.getAlignment().value_or(1);
539   if (requestedAlignment <= allocaAlignment)
540     // No realignment necessary.
541     return allocaAlignment;
542   uint64_t naturalStackAlignmentBits = dataLayout.getStackAlignment();
543   // If the natural stack alignment is not specified, the data layout returns
544   // zero. Optimistically allow realignment in this case.
545   if (naturalStackAlignmentBits == 0 ||
546       // If the requested alignment exceeds the natural stack alignment, this
547       // will trigger a dynamic stack realignment, so we prefer to copy...
548       8 * requestedAlignment <= naturalStackAlignmentBits ||
549       // ...unless the alloca already triggers dynamic stack realignment. Then
550       // we might as well further increase the alignment to avoid a copy.
551       8 * allocaAlignment > naturalStackAlignmentBits) {
552     alloca.setAlignment(requestedAlignment);
553     allocaAlignment = requestedAlignment;
554   }
555   return allocaAlignment;
556 }
557 
558 /// Tries to find and return the alignment of the pointer `value` by looking for
559 /// an alignment attribute on the defining allocation op or function argument.
560 /// If the found alignment is lower than `requestedAlignment`, tries to realign
561 /// the pointer, then returns the resulting post-alignment, regardless of
562 /// whether it was realigned or not. If no existing alignment attribute is
563 /// found, returns 1 (i.e., assume that no alignment is guaranteed).
564 static uint64_t tryToEnforceAlignment(Value value, uint64_t requestedAlignment,
565                                       DataLayout const &dataLayout) {
566   if (Operation *definingOp = value.getDefiningOp()) {
567     if (auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
568       return tryToEnforceAllocaAlignment(alloca, requestedAlignment,
569                                          dataLayout);
570     if (auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
571       if (auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
572               definingOp, addressOf.getGlobalNameAttr()))
573         return global.getAlignment().value_or(1);
574     // We don't currently handle this operation; assume no alignment.
575     return 1;
576   }
577   // Since there is no defining op, this is a block argument. Probably this
578   // comes directly from a function argument, so check that this is the case.
579   Operation *parentOp = value.getParentBlock()->getParentOp();
580   if (auto func = dyn_cast<LLVM::LLVMFuncOp>(parentOp)) {
581     // Use the alignment attribute set for this argument in the parent function
582     // if it has been set.
583     auto blockArg = llvm::cast<BlockArgument>(value);
584     if (Attribute alignAttr = func.getArgAttr(
585             blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
586       return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
587   }
588   // We didn't find anything useful; assume no alignment.
589   return 1;
590 }
591 
592 /// Introduces a new alloca and copies the memory pointed to by `argument` to
593 /// the address of the new alloca, then returns the value of the new alloca.
594 static Value handleByValArgumentInit(OpBuilder &builder, Location loc,
595                                      Value argument, Type elementType,
596                                      uint64_t elementTypeSize,
597                                      uint64_t targetAlignment) {
598   // Allocate the new value on the stack.
599   Value allocaOp;
600   {
601     // Since this is a static alloca, we can put it directly in the entry block,
602     // so they can be absorbed into the prologue/epilogue at code generation.
603     OpBuilder::InsertionGuard insertionGuard(builder);
604     Block *entryBlock = &(*argument.getParentRegion()->begin());
605     builder.setInsertionPointToStart(entryBlock);
606     Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
607                                                  builder.getI64IntegerAttr(1));
608     allocaOp = builder.create<LLVM::AllocaOp>(
609         loc, argument.getType(), elementType, one, targetAlignment);
610   }
611   // Copy the pointee to the newly allocated value.
612   Value copySize = builder.create<LLVM::ConstantOp>(
613       loc, builder.getI64Type(), builder.getI64IntegerAttr(elementTypeSize));
614   builder.create<LLVM::MemcpyOp>(loc, allocaOp, argument, copySize,
615                                  /*isVolatile=*/false);
616   return allocaOp;
617 }
618 
619 /// Handles a function argument marked with the byval attribute by introducing a
620 /// memcpy or realigning the defining operation, if required either due to the
621 /// pointee being writeable in the callee, and/or due to an alignment mismatch.
622 /// `requestedAlignment` specifies the alignment set in the "align" argument
623 /// attribute (or 1 if no align attribute was set).
624 static Value handleByValArgument(OpBuilder &builder, Operation *callable,
625                                  Value argument, Type elementType,
626                                  uint64_t requestedAlignment) {
627   auto func = cast<LLVM::LLVMFuncOp>(callable);
628   LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryEffectsAttr();
629   // If there is no memory effects attribute, assume that the function is
630   // not read-only.
631   bool isReadOnly = memoryEffects &&
632                     memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef &&
633                     memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod;
634   // Check if there's an alignment mismatch requiring us to copy.
635   DataLayout dataLayout = DataLayout::closest(callable);
636   uint64_t minimumAlignment = dataLayout.getTypeABIAlignment(elementType);
637   if (isReadOnly) {
638     if (requestedAlignment <= minimumAlignment)
639       return argument;
640     uint64_t currentAlignment =
641         tryToEnforceAlignment(argument, requestedAlignment, dataLayout);
642     if (currentAlignment >= requestedAlignment)
643       return argument;
644   }
645   uint64_t targetAlignment = std::max(requestedAlignment, minimumAlignment);
646   return handleByValArgumentInit(
647       builder, argument.getLoc(), argument, elementType,
648       dataLayout.getTypeSize(elementType), targetAlignment);
649 }
650 
651 namespace {
652 struct LLVMInlinerInterface : public DialectInlinerInterface {
653   using DialectInlinerInterface::DialectInlinerInterface;
654 
655   LLVMInlinerInterface(Dialect *dialect)
656       : DialectInlinerInterface(dialect),
657         // Cache set of StringAttrs for fast lookup in `isLegalToInline`.
658         disallowedFunctionAttrs({
659             StringAttr::get(dialect->getContext(), "noduplicate"),
660             StringAttr::get(dialect->getContext(), "presplitcoroutine"),
661             StringAttr::get(dialect->getContext(), "returns_twice"),
662             StringAttr::get(dialect->getContext(), "strictfp"),
663         }) {}
664 
665   bool isLegalToInline(Operation *call, Operation *callable,
666                        bool wouldBeCloned) const final {
667     if (!isa<LLVM::CallOp>(call)) {
668       LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '"
669                               << LLVM::CallOp::getOperationName() << "' op\n");
670       return false;
671     }
672     auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
673     if (!funcOp) {
674       LLVM_DEBUG(llvm::dbgs()
675                  << "Cannot inline: callable is not an '"
676                  << LLVM::LLVMFuncOp::getOperationName() << "' op\n");
677       return false;
678     }
679     if (funcOp.isNoInline()) {
680       LLVM_DEBUG(llvm::dbgs()
681                  << "Cannot inline: function is marked no_inline\n");
682       return false;
683     }
684     if (funcOp.isVarArg()) {
685       LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n");
686       return false;
687     }
688     // TODO: Generate aliasing metadata from noalias result attributes.
689     if (auto attrs = funcOp.getArgAttrs()) {
690       for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
691         if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
692           LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
693                                   << ": inalloca arguments not supported\n");
694           return false;
695         }
696       }
697     }
698     // TODO: Handle exceptions.
699     if (funcOp.getPersonality()) {
700       LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
701                               << ": unhandled function personality\n");
702       return false;
703     }
704     if (funcOp.getPassthrough()) {
705       // TODO: Used attributes should not be passthrough.
706       if (llvm::any_of(*funcOp.getPassthrough(), [&](Attribute attr) {
707             auto stringAttr = dyn_cast<StringAttr>(attr);
708             if (!stringAttr)
709               return false;
710             if (disallowedFunctionAttrs.contains(stringAttr)) {
711               LLVM_DEBUG(llvm::dbgs()
712                          << "Cannot inline " << funcOp.getSymName()
713                          << ": found disallowed function attribute "
714                          << stringAttr << "\n");
715               return true;
716             }
717             return false;
718           }))
719         return false;
720     }
721     return true;
722   }
723 
724   bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
725     return true;
726   }
727 
728   bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final {
729     // The inliner cannot handle variadic function arguments.
730     return !isa<LLVM::VaStartOp>(op);
731   }
732 
733   /// Handle the given inlined return by replacing it with a branch. This
734   /// overload is called when the inlined region has more than one block.
735   void handleTerminator(Operation *op, Block *newDest) const final {
736     // Only return needs to be handled here.
737     auto returnOp = dyn_cast<LLVM::ReturnOp>(op);
738     if (!returnOp)
739       return;
740 
741     // Replace the return with a branch to the dest.
742     OpBuilder builder(op);
743     builder.create<LLVM::BrOp>(op->getLoc(), returnOp.getOperands(), newDest);
744     op->erase();
745   }
746 
747   bool allowSingleBlockOptimization(
748       iterator_range<Region::iterator> inlinedBlocks) const final {
749     if (!inlinedBlocks.empty() &&
750         isa<LLVM::UnreachableOp>(inlinedBlocks.begin()->getTerminator()))
751       return false;
752     return true;
753   }
754 
755   /// Handle the given inlined return by replacing the uses of the call with the
756   /// operands of the return. This overload is called when the inlined region
757   /// only contains one block.
758   void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
759     // Return will be the only terminator present.
760     auto returnOp = cast<LLVM::ReturnOp>(op);
761 
762     // Replace the values directly with the return operands.
763     assert(returnOp.getNumOperands() == valuesToRepl.size());
764     for (auto [dst, src] : llvm::zip(valuesToRepl, returnOp.getOperands()))
765       dst.replaceAllUsesWith(src);
766   }
767 
768   Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
769                        Value argument,
770                        DictionaryAttr argumentAttrs) const final {
771     if (std::optional<NamedAttribute> attr =
772             argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) {
773       Type elementType = cast<TypeAttr>(attr->getValue()).getValue();
774       uint64_t requestedAlignment = 1;
775       if (std::optional<NamedAttribute> alignAttr =
776               argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) {
777         requestedAlignment = cast<IntegerAttr>(alignAttr->getValue())
778                                  .getValue()
779                                  .getLimitedValue();
780       }
781       return handleByValArgument(builder, callable, argument, elementType,
782                                  requestedAlignment);
783     }
784 
785     // This code is essentially a workaround for deficiencies in the inliner
786     // interface: We need to transform operations *after* inlined based on the
787     // argument attributes of the parameters *before* inlining. This method runs
788     // prior to actual inlining and thus cannot transform the post-inlining
789     // code, while `processInlinedCallBlocks` does not have access to
790     // pre-inlining function arguments. Additionally, it is required to
791     // distinguish which parameter an SSA value originally came from. As a
792     // workaround until this is changed: Create an ssa.copy intrinsic with the
793     // noalias attribute (when it was present before) that can easily be found,
794     // and is extremely unlikely to exist in the code prior to inlining, using
795     // this to communicate between this method and `processInlinedCallBlocks`.
796     // TODO: Fix this by refactoring the inliner interface.
797     auto copyOp = builder.create<LLVM::SSACopyOp>(call->getLoc(), argument);
798     if (argumentAttrs.contains(LLVM::LLVMDialect::getNoAliasAttrName()))
799       copyOp->setDiscardableAttr(
800           builder.getStringAttr(LLVM::LLVMDialect::getNoAliasAttrName()),
801           builder.getUnitAttr());
802     return copyOp;
803   }
804 
805   void processInlinedCallBlocks(
806       Operation *call,
807       iterator_range<Region::iterator> inlinedBlocks) const override {
808     handleInlinedAllocas(call, inlinedBlocks);
809     handleAliasScopes(call, inlinedBlocks);
810     handleAccessGroups(call, inlinedBlocks);
811     handleLoopAnnotations(call, inlinedBlocks);
812   }
813 
814   // Keeping this (immutable) state on the interface allows us to look up
815   // StringAttrs instead of looking up strings, since StringAttrs are bound to
816   // the current context and thus cannot be initialized as static fields.
817   const DenseSet<StringAttr> disallowedFunctionAttrs;
818 };
819 
820 } // end anonymous namespace
821 
822 void mlir::LLVM::registerInlinerInterface(DialectRegistry &registry) {
823   registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
824     dialect->addInterfaces<LLVMInlinerInterface>();
825   });
826 }
827 
828 void mlir::NVVM::registerInlinerInterface(DialectRegistry &registry) {
829   registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
830     dialect->addInterfaces<LLVMInlinerInterface>();
831   });
832 }
833