xref: /llvm-project/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp (revision d0c9e70bcc40948821e83eb0ec32e6e15fb0dd4b)
14c2f90f3SChristian Ulmann //===- InlinerInterfaceImpl.cpp - Inlining for LLVM the dialect -----------===//
24c2f90f3SChristian Ulmann //
34c2f90f3SChristian Ulmann // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44c2f90f3SChristian Ulmann // See https://llvm.org/LICENSE.txt for license information.
54c2f90f3SChristian Ulmann // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64c2f90f3SChristian Ulmann //
74c2f90f3SChristian Ulmann //===----------------------------------------------------------------------===//
84c2f90f3SChristian Ulmann //
94c2f90f3SChristian Ulmann // Logic for inlining LLVM functions and the definition of the
104c2f90f3SChristian Ulmann // LLVMInliningInterface.
114c2f90f3SChristian Ulmann //
124c2f90f3SChristian Ulmann //===----------------------------------------------------------------------===//
134c2f90f3SChristian Ulmann 
144c2f90f3SChristian Ulmann #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
1514153654SChristian Ulmann #include "mlir/Analysis/SliceWalk.h"
164c2f90f3SChristian Ulmann #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
178d306ccdSWilliam Moses #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
184c2f90f3SChristian Ulmann #include "mlir/IR/Matchers.h"
194c2f90f3SChristian Ulmann #include "mlir/Interfaces/DataLayoutInterfaces.h"
208e2ccdc4STobias Gysi #include "mlir/Interfaces/ViewLikeInterface.h"
214c2f90f3SChristian Ulmann #include "mlir/Transforms/InliningUtils.h"
224c2f90f3SChristian Ulmann #include "llvm/ADT/ScopeExit.h"
234c2f90f3SChristian Ulmann #include "llvm/Support/Debug.h"
244c2f90f3SChristian Ulmann 
254c2f90f3SChristian Ulmann #define DEBUG_TYPE "llvm-inliner"
264c2f90f3SChristian Ulmann 
274c2f90f3SChristian Ulmann using namespace mlir;
284c2f90f3SChristian Ulmann 
294c2f90f3SChristian Ulmann /// Check whether the given alloca is an input to a lifetime intrinsic,
304c2f90f3SChristian Ulmann /// optionally passing through one or more casts on the way. This is not
314c2f90f3SChristian Ulmann /// transitive through block arguments.
324c2f90f3SChristian Ulmann static bool hasLifetimeMarkers(LLVM::AllocaOp allocaOp) {
334c2f90f3SChristian Ulmann   SmallVector<Operation *> stack(allocaOp->getUsers().begin(),
344c2f90f3SChristian Ulmann                                  allocaOp->getUsers().end());
354c2f90f3SChristian Ulmann   while (!stack.empty()) {
364c2f90f3SChristian Ulmann     Operation *op = stack.pop_back_val();
374c2f90f3SChristian Ulmann     if (isa<LLVM::LifetimeStartOp, LLVM::LifetimeEndOp>(op))
384c2f90f3SChristian Ulmann       return true;
394c2f90f3SChristian Ulmann     if (isa<LLVM::BitcastOp>(op))
404c2f90f3SChristian Ulmann       stack.append(op->getUsers().begin(), op->getUsers().end());
414c2f90f3SChristian Ulmann   }
424c2f90f3SChristian Ulmann   return false;
434c2f90f3SChristian Ulmann }
444c2f90f3SChristian Ulmann 
454c2f90f3SChristian Ulmann /// Handles alloca operations in the inlined blocks:
464c2f90f3SChristian Ulmann /// - Moves all alloca operations with a constant size in the former entry block
474c2f90f3SChristian Ulmann ///   of the callee into the entry block of the caller, so they become part of
484c2f90f3SChristian Ulmann ///   the function prologue/epilogue during code generation.
494c2f90f3SChristian Ulmann /// - Inserts lifetime intrinsics that limit the scope of inlined static allocas
504c2f90f3SChristian Ulmann ///   to the inlined blocks.
514c2f90f3SChristian Ulmann /// - Inserts StackSave and StackRestore operations if dynamic allocas were
524c2f90f3SChristian Ulmann ///   inlined.
534c2f90f3SChristian Ulmann static void
544c2f90f3SChristian Ulmann handleInlinedAllocas(Operation *call,
554c2f90f3SChristian Ulmann                      iterator_range<Region::iterator> inlinedBlocks) {
564c2f90f3SChristian Ulmann   // Locate the entry block of the closest callsite ancestor that has either the
574c2f90f3SChristian Ulmann   // IsolatedFromAbove or AutomaticAllocationScope trait. In pure LLVM dialect
584c2f90f3SChristian Ulmann   // programs, this is the LLVMFuncOp containing the call site. However, in
594c2f90f3SChristian Ulmann   // mixed-dialect programs, the callsite might be nested in another operation
604c2f90f3SChristian Ulmann   // that carries one of these traits. In such scenarios, this traversal stops
614c2f90f3SChristian Ulmann   // at the closest ancestor with either trait, ensuring visibility post
624c2f90f3SChristian Ulmann   // relocation and respecting allocation scopes.
634c2f90f3SChristian Ulmann   Block *callerEntryBlock = nullptr;
644c2f90f3SChristian Ulmann   Operation *currentOp = call;
654c2f90f3SChristian Ulmann   while (Operation *parentOp = currentOp->getParentOp()) {
664c2f90f3SChristian Ulmann     if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
674c2f90f3SChristian Ulmann         parentOp->mightHaveTrait<OpTrait::AutomaticAllocationScope>()) {
684c2f90f3SChristian Ulmann       callerEntryBlock = &currentOp->getParentRegion()->front();
694c2f90f3SChristian Ulmann       break;
704c2f90f3SChristian Ulmann     }
714c2f90f3SChristian Ulmann     currentOp = parentOp;
724c2f90f3SChristian Ulmann   }
734c2f90f3SChristian Ulmann 
744c2f90f3SChristian Ulmann   // Avoid relocating the alloca operations if the call has been inlined into
754c2f90f3SChristian Ulmann   // the entry block already, which is typically the encompassing
764c2f90f3SChristian Ulmann   // LLVM function, or if the relevant entry block cannot be identified.
774c2f90f3SChristian Ulmann   Block *calleeEntryBlock = &(*inlinedBlocks.begin());
784c2f90f3SChristian Ulmann   if (!callerEntryBlock || callerEntryBlock == calleeEntryBlock)
794c2f90f3SChristian Ulmann     return;
804c2f90f3SChristian Ulmann 
814c2f90f3SChristian Ulmann   SmallVector<std::tuple<LLVM::AllocaOp, IntegerAttr, bool>> allocasToMove;
824c2f90f3SChristian Ulmann   bool shouldInsertLifetimes = false;
834c2f90f3SChristian Ulmann   bool hasDynamicAlloca = false;
844c2f90f3SChristian Ulmann   // Conservatively only move static alloca operations that are part of the
854c2f90f3SChristian Ulmann   // entry block and do not inspect nested regions, since they may execute
864c2f90f3SChristian Ulmann   // conditionally or have other unknown semantics.
874c2f90f3SChristian Ulmann   for (auto allocaOp : calleeEntryBlock->getOps<LLVM::AllocaOp>()) {
884c2f90f3SChristian Ulmann     IntegerAttr arraySize;
894c2f90f3SChristian Ulmann     if (!matchPattern(allocaOp.getArraySize(), m_Constant(&arraySize))) {
904c2f90f3SChristian Ulmann       hasDynamicAlloca = true;
914c2f90f3SChristian Ulmann       continue;
924c2f90f3SChristian Ulmann     }
934c2f90f3SChristian Ulmann     bool shouldInsertLifetime =
944c2f90f3SChristian Ulmann         arraySize.getValue() != 0 && !hasLifetimeMarkers(allocaOp);
954c2f90f3SChristian Ulmann     shouldInsertLifetimes |= shouldInsertLifetime;
964c2f90f3SChristian Ulmann     allocasToMove.emplace_back(allocaOp, arraySize, shouldInsertLifetime);
974c2f90f3SChristian Ulmann   }
984c2f90f3SChristian Ulmann   // Check the remaining inlined blocks for dynamic allocas as well.
994c2f90f3SChristian Ulmann   for (Block &block : llvm::drop_begin(inlinedBlocks)) {
1004c2f90f3SChristian Ulmann     if (hasDynamicAlloca)
1014c2f90f3SChristian Ulmann       break;
1024c2f90f3SChristian Ulmann     hasDynamicAlloca =
1034c2f90f3SChristian Ulmann         llvm::any_of(block.getOps<LLVM::AllocaOp>(), [](auto allocaOp) {
1044c2f90f3SChristian Ulmann           return !matchPattern(allocaOp.getArraySize(), m_Constant());
1054c2f90f3SChristian Ulmann         });
1064c2f90f3SChristian Ulmann   }
1074c2f90f3SChristian Ulmann   if (allocasToMove.empty() && !hasDynamicAlloca)
1084c2f90f3SChristian Ulmann     return;
1094c2f90f3SChristian Ulmann   OpBuilder builder(calleeEntryBlock, calleeEntryBlock->begin());
1104c2f90f3SChristian Ulmann   Value stackPtr;
1114c2f90f3SChristian Ulmann   if (hasDynamicAlloca) {
1124c2f90f3SChristian Ulmann     // This may result in multiple stacksave/stackrestore intrinsics in the same
1134c2f90f3SChristian Ulmann     // scope if some are already present in the body of the caller. This is not
1144c2f90f3SChristian Ulmann     // invalid IR, but LLVM cleans these up in InstCombineCalls.cpp, along with
1154c2f90f3SChristian Ulmann     // other cases where the stacksave/stackrestore is redundant.
1164c2f90f3SChristian Ulmann     stackPtr = builder.create<LLVM::StackSaveOp>(
1174c2f90f3SChristian Ulmann         call->getLoc(), LLVM::LLVMPointerType::get(call->getContext()));
1184c2f90f3SChristian Ulmann   }
119b613a540SMatthias Springer   builder.setInsertionPointToStart(callerEntryBlock);
1204c2f90f3SChristian Ulmann   for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
1214c2f90f3SChristian Ulmann     auto newConstant = builder.create<LLVM::ConstantOp>(
1224c2f90f3SChristian Ulmann         allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize);
1234c2f90f3SChristian Ulmann     // Insert a lifetime start intrinsic where the alloca was before moving it.
1244c2f90f3SChristian Ulmann     if (shouldInsertLifetime) {
1254c2f90f3SChristian Ulmann       OpBuilder::InsertionGuard insertionGuard(builder);
1264c2f90f3SChristian Ulmann       builder.setInsertionPoint(allocaOp);
1274c2f90f3SChristian Ulmann       builder.create<LLVM::LifetimeStartOp>(
1284c2f90f3SChristian Ulmann           allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
1294c2f90f3SChristian Ulmann           allocaOp.getResult());
1304c2f90f3SChristian Ulmann     }
1314c2f90f3SChristian Ulmann     allocaOp->moveAfter(newConstant);
1324c2f90f3SChristian Ulmann     allocaOp.getArraySizeMutable().assign(newConstant.getResult());
1334c2f90f3SChristian Ulmann   }
1344c2f90f3SChristian Ulmann   if (!shouldInsertLifetimes && !hasDynamicAlloca)
1354c2f90f3SChristian Ulmann     return;
1364c2f90f3SChristian Ulmann   // Insert a lifetime end intrinsic before each return in the callee function.
1374c2f90f3SChristian Ulmann   for (Block &block : inlinedBlocks) {
1384c2f90f3SChristian Ulmann     if (!block.getTerminator()->hasTrait<OpTrait::ReturnLike>())
1394c2f90f3SChristian Ulmann       continue;
1404c2f90f3SChristian Ulmann     builder.setInsertionPoint(block.getTerminator());
1414c2f90f3SChristian Ulmann     if (hasDynamicAlloca)
1424c2f90f3SChristian Ulmann       builder.create<LLVM::StackRestoreOp>(call->getLoc(), stackPtr);
1434c2f90f3SChristian Ulmann     for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
1444c2f90f3SChristian Ulmann       if (shouldInsertLifetime)
1454c2f90f3SChristian Ulmann         builder.create<LLVM::LifetimeEndOp>(
1464c2f90f3SChristian Ulmann             allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
1474c2f90f3SChristian Ulmann             allocaOp.getResult());
1484c2f90f3SChristian Ulmann     }
1494c2f90f3SChristian Ulmann   }
1504c2f90f3SChristian Ulmann }
1514c2f90f3SChristian Ulmann 
1524c2f90f3SChristian Ulmann /// Maps all alias scopes in the inlined operations to deep clones of the scopes
1534c2f90f3SChristian Ulmann /// and domain. This is required for code such as `foo(a, b); foo(a2, b2);` to
1544c2f90f3SChristian Ulmann /// not incorrectly return `noalias` for e.g. operations on `a` and `a2`.
1554c2f90f3SChristian Ulmann static void
1564c2f90f3SChristian Ulmann deepCloneAliasScopes(iterator_range<Region::iterator> inlinedBlocks) {
1574c2f90f3SChristian Ulmann   DenseMap<Attribute, Attribute> mapping;
1584c2f90f3SChristian Ulmann 
1594c2f90f3SChristian Ulmann   // Register handles in the walker to create the deep clones.
1604c2f90f3SChristian Ulmann   // The walker ensures that an attribute is only ever walked once and does a
1614c2f90f3SChristian Ulmann   // post-order walk, ensuring the domain is visited prior to the scope.
1624c2f90f3SChristian Ulmann   AttrTypeWalker walker;
1634c2f90f3SChristian Ulmann 
1644c2f90f3SChristian Ulmann   // Perform the deep clones while visiting. Builders create a distinct
1654c2f90f3SChristian Ulmann   // attribute to make sure that new instances are always created by the
1664c2f90f3SChristian Ulmann   // uniquer.
1674c2f90f3SChristian Ulmann   walker.addWalk([&](LLVM::AliasScopeDomainAttr domainAttr) {
1684c2f90f3SChristian Ulmann     mapping[domainAttr] = LLVM::AliasScopeDomainAttr::get(
1694c2f90f3SChristian Ulmann         domainAttr.getContext(), domainAttr.getDescription());
1704c2f90f3SChristian Ulmann   });
1714c2f90f3SChristian Ulmann 
1724c2f90f3SChristian Ulmann   walker.addWalk([&](LLVM::AliasScopeAttr scopeAttr) {
1734c2f90f3SChristian Ulmann     mapping[scopeAttr] = LLVM::AliasScopeAttr::get(
1744c2f90f3SChristian Ulmann         cast<LLVM::AliasScopeDomainAttr>(mapping.lookup(scopeAttr.getDomain())),
1754c2f90f3SChristian Ulmann         scopeAttr.getDescription());
1764c2f90f3SChristian Ulmann   });
1774c2f90f3SChristian Ulmann 
1784c2f90f3SChristian Ulmann   // Map an array of scopes to an array of deep clones.
1794c2f90f3SChristian Ulmann   auto convertScopeList = [&](ArrayAttr arrayAttr) -> ArrayAttr {
1804c2f90f3SChristian Ulmann     if (!arrayAttr)
1814c2f90f3SChristian Ulmann       return nullptr;
1824c2f90f3SChristian Ulmann 
1834c2f90f3SChristian Ulmann     // Create the deep clones if necessary.
1844c2f90f3SChristian Ulmann     walker.walk(arrayAttr);
1854c2f90f3SChristian Ulmann 
1864c2f90f3SChristian Ulmann     return ArrayAttr::get(arrayAttr.getContext(),
1874c2f90f3SChristian Ulmann                           llvm::map_to_vector(arrayAttr, [&](Attribute attr) {
1884c2f90f3SChristian Ulmann                             return mapping.lookup(attr);
1894c2f90f3SChristian Ulmann                           }));
1904c2f90f3SChristian Ulmann   };
1914c2f90f3SChristian Ulmann 
1924c2f90f3SChristian Ulmann   for (Block &block : inlinedBlocks) {
1934c2f90f3SChristian Ulmann     block.walk([&](Operation *op) {
1944c2f90f3SChristian Ulmann       if (auto aliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(op)) {
1954c2f90f3SChristian Ulmann         aliasInterface.setAliasScopes(
1964c2f90f3SChristian Ulmann             convertScopeList(aliasInterface.getAliasScopesOrNull()));
1974c2f90f3SChristian Ulmann         aliasInterface.setNoAliasScopes(
1984c2f90f3SChristian Ulmann             convertScopeList(aliasInterface.getNoAliasScopesOrNull()));
1994c2f90f3SChristian Ulmann       }
2004c2f90f3SChristian Ulmann 
2014c2f90f3SChristian Ulmann       if (auto noAliasScope = dyn_cast<LLVM::NoAliasScopeDeclOp>(op)) {
2024c2f90f3SChristian Ulmann         // Create the deep clones if necessary.
2034c2f90f3SChristian Ulmann         walker.walk(noAliasScope.getScopeAttr());
2044c2f90f3SChristian Ulmann 
2054c2f90f3SChristian Ulmann         noAliasScope.setScopeAttr(cast<LLVM::AliasScopeAttr>(
2064c2f90f3SChristian Ulmann             mapping.lookup(noAliasScope.getScopeAttr())));
2074c2f90f3SChristian Ulmann       }
2084c2f90f3SChristian Ulmann     });
2094c2f90f3SChristian Ulmann   }
2104c2f90f3SChristian Ulmann }
2114c2f90f3SChristian Ulmann 
2124c2f90f3SChristian Ulmann /// Creates a new ArrayAttr by concatenating `lhs` with `rhs`.
2134c2f90f3SChristian Ulmann /// Returns null if both parameters are null. If only one attribute is null,
2144c2f90f3SChristian Ulmann /// return the other.
2154c2f90f3SChristian Ulmann static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs) {
2164c2f90f3SChristian Ulmann   if (!lhs)
2174c2f90f3SChristian Ulmann     return rhs;
2184c2f90f3SChristian Ulmann   if (!rhs)
2194c2f90f3SChristian Ulmann     return lhs;
2204c2f90f3SChristian Ulmann 
2214c2f90f3SChristian Ulmann   SmallVector<Attribute> result;
2224c2f90f3SChristian Ulmann   llvm::append_range(result, lhs);
2234c2f90f3SChristian Ulmann   llvm::append_range(result, rhs);
2244c2f90f3SChristian Ulmann   return ArrayAttr::get(lhs.getContext(), result);
2254c2f90f3SChristian Ulmann }
2264c2f90f3SChristian Ulmann 
2274c2f90f3SChristian Ulmann /// Attempts to return the set of all underlying pointer values that
2284c2f90f3SChristian Ulmann /// `pointerValue` is based on. This function traverses through select
22914153654SChristian Ulmann /// operations and block arguments.
23014153654SChristian Ulmann static FailureOr<SmallVector<Value>>
23114153654SChristian Ulmann getUnderlyingObjectSet(Value pointerValue) {
2324c2f90f3SChristian Ulmann   SmallVector<Value> result;
23314153654SChristian Ulmann   WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) {
2348e2ccdc4STobias Gysi     // Attempt to advance to the source of the underlying view-like operation.
2358e2ccdc4STobias Gysi     // Examples of view-like operations include GEPOp and AddrSpaceCastOp.
2368e2ccdc4STobias Gysi     if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>())
2378e2ccdc4STobias Gysi       return WalkContinuation::advanceTo(viewOp.getViewSource());
2384c2f90f3SChristian Ulmann 
23914153654SChristian Ulmann     // Attempt to advance to control flow predecessors.
24014153654SChristian Ulmann     std::optional<SmallVector<Value>> controlFlowPredecessors =
24114153654SChristian Ulmann         getControlFlowPredecessors(val);
24214153654SChristian Ulmann     if (controlFlowPredecessors)
24314153654SChristian Ulmann       return WalkContinuation::advanceTo(*controlFlowPredecessors);
24414153654SChristian Ulmann 
24514153654SChristian Ulmann     // For all non-control flow results, consider `val` an underlying object.
24614153654SChristian Ulmann     if (isa<OpResult>(val)) {
24714153654SChristian Ulmann       result.push_back(val);
24814153654SChristian Ulmann       return WalkContinuation::skip();
2494c2f90f3SChristian Ulmann     }
2504c2f90f3SChristian Ulmann 
25114153654SChristian Ulmann     // If this place is reached, `val` is a block argument that is not
25214153654SChristian Ulmann     // understood. Therefore, we conservatively interrupt.
25314153654SChristian Ulmann     // Note: Dealing with function arguments is not necessary, as the slice
25414153654SChristian Ulmann     // would have to go through an SSACopyOp first.
25514153654SChristian Ulmann     return WalkContinuation::interrupt();
25614153654SChristian Ulmann   });
2574c2f90f3SChristian Ulmann 
25814153654SChristian Ulmann   if (walkResult.wasInterrupted())
25914153654SChristian Ulmann     return failure();
2604c2f90f3SChristian Ulmann 
2614c2f90f3SChristian Ulmann   return result;
2624c2f90f3SChristian Ulmann }
2634c2f90f3SChristian Ulmann 
2644c2f90f3SChristian Ulmann /// Creates a new AliasScopeAttr for every noalias parameter and attaches it to
2654c2f90f3SChristian Ulmann /// the appropriate inlined memory operations in an attempt to preserve the
2664c2f90f3SChristian Ulmann /// original semantics of the parameter attribute.
2674c2f90f3SChristian Ulmann static void createNewAliasScopesFromNoAliasParameter(
2684c2f90f3SChristian Ulmann     Operation *call, iterator_range<Region::iterator> inlinedBlocks) {
2694c2f90f3SChristian Ulmann 
270065d2d9cSChristian Ulmann   // First, collect all ssa copy operations, which correspond to function
271065d2d9cSChristian Ulmann   // parameters, and additionally store the noalias parameters. All parameters
272065d2d9cSChristian Ulmann   // have been marked by the `handleArgument` implementation by using the
273065d2d9cSChristian Ulmann   // `ssa.copy` intrinsic. Additionally, noalias parameters have an attached
274065d2d9cSChristian Ulmann   // `noalias` attribute to the intrinsics. These intrinsics are only meant to
275065d2d9cSChristian Ulmann   // be temporary and should therefore be deleted after we're done using them
276065d2d9cSChristian Ulmann   // here.
277065d2d9cSChristian Ulmann   SetVector<LLVM::SSACopyOp> ssaCopies;
2784c2f90f3SChristian Ulmann   SetVector<LLVM::SSACopyOp> noAliasParams;
2794c2f90f3SChristian Ulmann   for (Value argument : cast<LLVM::CallOp>(call).getArgOperands()) {
2804c2f90f3SChristian Ulmann     for (Operation *user : argument.getUsers()) {
2814c2f90f3SChristian Ulmann       auto ssaCopy = llvm::dyn_cast<LLVM::SSACopyOp>(user);
2824c2f90f3SChristian Ulmann       if (!ssaCopy)
2834c2f90f3SChristian Ulmann         continue;
284065d2d9cSChristian Ulmann       ssaCopies.insert(ssaCopy);
285065d2d9cSChristian Ulmann 
2864c2f90f3SChristian Ulmann       if (!ssaCopy->hasAttr(LLVM::LLVMDialect::getNoAliasAttrName()))
2874c2f90f3SChristian Ulmann         continue;
2884c2f90f3SChristian Ulmann       noAliasParams.insert(ssaCopy);
2894c2f90f3SChristian Ulmann     }
2904c2f90f3SChristian Ulmann   }
2914c2f90f3SChristian Ulmann 
2924c2f90f3SChristian Ulmann   // Scope exit block to make it impossible to forget to get rid of the
2934c2f90f3SChristian Ulmann   // intrinsics.
2944c2f90f3SChristian Ulmann   auto exit = llvm::make_scope_exit([&] {
295065d2d9cSChristian Ulmann     for (LLVM::SSACopyOp ssaCopyOp : ssaCopies) {
2964c2f90f3SChristian Ulmann       ssaCopyOp.replaceAllUsesWith(ssaCopyOp.getOperand());
2974c2f90f3SChristian Ulmann       ssaCopyOp->erase();
2984c2f90f3SChristian Ulmann     }
2994c2f90f3SChristian Ulmann   });
3004c2f90f3SChristian Ulmann 
301065d2d9cSChristian Ulmann   // If there were no noalias parameters, we have nothing to do here.
302065d2d9cSChristian Ulmann   if (noAliasParams.empty())
303065d2d9cSChristian Ulmann     return;
304065d2d9cSChristian Ulmann 
3054c2f90f3SChristian Ulmann   // Create a new domain for this specific inlining and a new scope for every
3064c2f90f3SChristian Ulmann   // noalias parameter.
3074c2f90f3SChristian Ulmann   auto functionDomain = LLVM::AliasScopeDomainAttr::get(
3084c2f90f3SChristian Ulmann       call->getContext(), cast<LLVM::CallOp>(call).getCalleeAttr().getAttr());
3094c2f90f3SChristian Ulmann   DenseMap<Value, LLVM::AliasScopeAttr> pointerScopes;
3104c2f90f3SChristian Ulmann   for (LLVM::SSACopyOp copyOp : noAliasParams) {
3114c2f90f3SChristian Ulmann     auto scope = LLVM::AliasScopeAttr::get(functionDomain);
3124c2f90f3SChristian Ulmann     pointerScopes[copyOp] = scope;
3134c2f90f3SChristian Ulmann 
3144c2f90f3SChristian Ulmann     OpBuilder(call).create<LLVM::NoAliasScopeDeclOp>(call->getLoc(), scope);
3154c2f90f3SChristian Ulmann   }
3164c2f90f3SChristian Ulmann 
3174c2f90f3SChristian Ulmann   // Go through every instruction and attempt to find which noalias parameters
3184c2f90f3SChristian Ulmann   // it is definitely based on and definitely not based on.
3194c2f90f3SChristian Ulmann   for (Block &inlinedBlock : inlinedBlocks) {
3204c2f90f3SChristian Ulmann     inlinedBlock.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
3214c2f90f3SChristian Ulmann       // Collect the pointer arguments affected by the alias scopes.
3224c2f90f3SChristian Ulmann       SmallVector<Value> pointerArgs = aliasInterface.getAccessedOperands();
3234c2f90f3SChristian Ulmann 
3244c2f90f3SChristian Ulmann       // Find the set of underlying pointers that this pointer is based on.
3254c2f90f3SChristian Ulmann       SmallPtrSet<Value, 4> basedOnPointers;
32614153654SChristian Ulmann       for (Value pointer : pointerArgs) {
32714153654SChristian Ulmann         FailureOr<SmallVector<Value>> underlyingObjectSet =
32814153654SChristian Ulmann             getUnderlyingObjectSet(pointer);
32914153654SChristian Ulmann         if (failed(underlyingObjectSet))
33014153654SChristian Ulmann           return;
33114153654SChristian Ulmann         llvm::copy(*underlyingObjectSet,
3324c2f90f3SChristian Ulmann                    std::inserter(basedOnPointers, basedOnPointers.begin()));
33314153654SChristian Ulmann       }
3344c2f90f3SChristian Ulmann 
3354c2f90f3SChristian Ulmann       bool aliasesOtherKnownObject = false;
3364c2f90f3SChristian Ulmann       // Go through the based on pointers and check that they are either:
3374c2f90f3SChristian Ulmann       // * Constants that can be ignored (undef, poison, null pointer).
338065d2d9cSChristian Ulmann       // * Based on a pointer parameter.
3394c2f90f3SChristian Ulmann       // * Other pointers that we know can't alias with our noalias parameter.
3404c2f90f3SChristian Ulmann       //
3414c2f90f3SChristian Ulmann       // Any other value might be a pointer based on any noalias parameter that
3424c2f90f3SChristian Ulmann       // hasn't been identified. In that case conservatively don't add any
3434c2f90f3SChristian Ulmann       // scopes to this operation indicating either aliasing or not aliasing
3444c2f90f3SChristian Ulmann       // with any parameter.
3454c2f90f3SChristian Ulmann       if (llvm::any_of(basedOnPointers, [&](Value object) {
3464c2f90f3SChristian Ulmann             if (matchPattern(object, m_Constant()))
3474c2f90f3SChristian Ulmann               return false;
3484c2f90f3SChristian Ulmann 
349065d2d9cSChristian Ulmann             if (auto ssaCopy = object.getDefiningOp<LLVM::SSACopyOp>()) {
350065d2d9cSChristian Ulmann               // If that value is based on a noalias parameter, it is guaranteed
351065d2d9cSChristian Ulmann               // to not alias with any other object.
352065d2d9cSChristian Ulmann               aliasesOtherKnownObject |= !noAliasParams.contains(ssaCopy);
3534c2f90f3SChristian Ulmann               return false;
354065d2d9cSChristian Ulmann             }
3554c2f90f3SChristian Ulmann 
3564c2f90f3SChristian Ulmann             if (isa_and_nonnull<LLVM::AllocaOp, LLVM::AddressOfOp>(
3574c2f90f3SChristian Ulmann                     object.getDefiningOp())) {
3584c2f90f3SChristian Ulmann               aliasesOtherKnownObject = true;
3594c2f90f3SChristian Ulmann               return false;
3604c2f90f3SChristian Ulmann             }
3614c2f90f3SChristian Ulmann             return true;
3624c2f90f3SChristian Ulmann           }))
3634c2f90f3SChristian Ulmann         return;
3644c2f90f3SChristian Ulmann 
3654c2f90f3SChristian Ulmann       // Add all noalias parameter scopes to the noalias scope list that we are
3664c2f90f3SChristian Ulmann       // not based on.
3674c2f90f3SChristian Ulmann       SmallVector<Attribute> noAliasScopes;
3684c2f90f3SChristian Ulmann       for (LLVM::SSACopyOp noAlias : noAliasParams) {
3694c2f90f3SChristian Ulmann         if (basedOnPointers.contains(noAlias))
3704c2f90f3SChristian Ulmann           continue;
3714c2f90f3SChristian Ulmann 
3724c2f90f3SChristian Ulmann         noAliasScopes.push_back(pointerScopes[noAlias]);
3734c2f90f3SChristian Ulmann       }
3744c2f90f3SChristian Ulmann 
3754c2f90f3SChristian Ulmann       if (!noAliasScopes.empty())
3764c2f90f3SChristian Ulmann         aliasInterface.setNoAliasScopes(
3774c2f90f3SChristian Ulmann             concatArrayAttr(aliasInterface.getNoAliasScopesOrNull(),
3784c2f90f3SChristian Ulmann                             ArrayAttr::get(call->getContext(), noAliasScopes)));
3794c2f90f3SChristian Ulmann 
3804c2f90f3SChristian Ulmann       // Don't add alias scopes to call operations or operations that might
3814c2f90f3SChristian Ulmann       // operate on pointers not based on any noalias parameter.
3824c2f90f3SChristian Ulmann       // Since we add all scopes to an operation's noalias list that it
3834c2f90f3SChristian Ulmann       // definitely doesn't alias, we mustn't do the same for the alias.scope
3844c2f90f3SChristian Ulmann       // list if other objects are involved.
3854c2f90f3SChristian Ulmann       //
3864c2f90f3SChristian Ulmann       // Consider the following case:
3874c2f90f3SChristian Ulmann       // %0 = llvm.alloca
3884c2f90f3SChristian Ulmann       // %1 = select %magic, %0, %noalias_param
3894c2f90f3SChristian Ulmann       // store 5, %1  (1) noalias=[scope(...)]
3904c2f90f3SChristian Ulmann       // ...
3914c2f90f3SChristian Ulmann       // store 3, %0  (2) noalias=[scope(noalias_param), scope(...)]
3924c2f90f3SChristian Ulmann       //
3934c2f90f3SChristian Ulmann       // We can add the scopes of any noalias parameters that aren't
3944c2f90f3SChristian Ulmann       // noalias_param's scope to (1) and add all of them to (2). We mustn't add
3954c2f90f3SChristian Ulmann       // the scope of noalias_param to the alias.scope list of (1) since
3964c2f90f3SChristian Ulmann       // that would mean (2) cannot alias with (1) which is wrong since both may
3974c2f90f3SChristian Ulmann       // store to %0.
3984c2f90f3SChristian Ulmann       //
3994c2f90f3SChristian Ulmann       // In conclusion, only add scopes to the alias.scope list if all pointers
4004c2f90f3SChristian Ulmann       // have a corresponding scope.
4014c2f90f3SChristian Ulmann       // Call operations are included in this list since we do not know whether
4024c2f90f3SChristian Ulmann       // the callee accesses any memory besides the ones passed as its
4034c2f90f3SChristian Ulmann       // arguments.
4044c2f90f3SChristian Ulmann       if (aliasesOtherKnownObject ||
4054c2f90f3SChristian Ulmann           isa<LLVM::CallOp>(aliasInterface.getOperation()))
4064c2f90f3SChristian Ulmann         return;
4074c2f90f3SChristian Ulmann 
4084c2f90f3SChristian Ulmann       SmallVector<Attribute> aliasScopes;
4094c2f90f3SChristian Ulmann       for (LLVM::SSACopyOp noAlias : noAliasParams)
4104c2f90f3SChristian Ulmann         if (basedOnPointers.contains(noAlias))
4114c2f90f3SChristian Ulmann           aliasScopes.push_back(pointerScopes[noAlias]);
4124c2f90f3SChristian Ulmann 
4134c2f90f3SChristian Ulmann       if (!aliasScopes.empty())
4144c2f90f3SChristian Ulmann         aliasInterface.setAliasScopes(
4154c2f90f3SChristian Ulmann             concatArrayAttr(aliasInterface.getAliasScopesOrNull(),
4164c2f90f3SChristian Ulmann                             ArrayAttr::get(call->getContext(), aliasScopes)));
4174c2f90f3SChristian Ulmann     });
4184c2f90f3SChristian Ulmann   }
4194c2f90f3SChristian Ulmann }
4204c2f90f3SChristian Ulmann 
4214c2f90f3SChristian Ulmann /// Appends any alias scopes of the call operation to any inlined memory
4224c2f90f3SChristian Ulmann /// operation.
4234c2f90f3SChristian Ulmann static void
4244c2f90f3SChristian Ulmann appendCallOpAliasScopes(Operation *call,
4254c2f90f3SChristian Ulmann                         iterator_range<Region::iterator> inlinedBlocks) {
4264c2f90f3SChristian Ulmann   auto callAliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(call);
4274c2f90f3SChristian Ulmann   if (!callAliasInterface)
4284c2f90f3SChristian Ulmann     return;
4294c2f90f3SChristian Ulmann 
4304c2f90f3SChristian Ulmann   ArrayAttr aliasScopes = callAliasInterface.getAliasScopesOrNull();
4314c2f90f3SChristian Ulmann   ArrayAttr noAliasScopes = callAliasInterface.getNoAliasScopesOrNull();
4324c2f90f3SChristian Ulmann   // If the call has neither alias scopes or noalias scopes we have nothing to
4334c2f90f3SChristian Ulmann   // do here.
4344c2f90f3SChristian Ulmann   if (!aliasScopes && !noAliasScopes)
4354c2f90f3SChristian Ulmann     return;
4364c2f90f3SChristian Ulmann 
4374c2f90f3SChristian Ulmann   // Simply append the call op's alias and noalias scopes to any operation
4384c2f90f3SChristian Ulmann   // implementing AliasAnalysisOpInterface.
4394c2f90f3SChristian Ulmann   for (Block &block : inlinedBlocks) {
4404c2f90f3SChristian Ulmann     block.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
4414c2f90f3SChristian Ulmann       if (aliasScopes)
4424c2f90f3SChristian Ulmann         aliasInterface.setAliasScopes(concatArrayAttr(
4434c2f90f3SChristian Ulmann             aliasInterface.getAliasScopesOrNull(), aliasScopes));
4444c2f90f3SChristian Ulmann 
4454c2f90f3SChristian Ulmann       if (noAliasScopes)
4464c2f90f3SChristian Ulmann         aliasInterface.setNoAliasScopes(concatArrayAttr(
4474c2f90f3SChristian Ulmann             aliasInterface.getNoAliasScopesOrNull(), noAliasScopes));
4484c2f90f3SChristian Ulmann     });
4494c2f90f3SChristian Ulmann   }
4504c2f90f3SChristian Ulmann }
4514c2f90f3SChristian Ulmann 
4524c2f90f3SChristian Ulmann /// Handles all interactions with alias scopes during inlining.
4534c2f90f3SChristian Ulmann static void handleAliasScopes(Operation *call,
4544c2f90f3SChristian Ulmann                               iterator_range<Region::iterator> inlinedBlocks) {
4554c2f90f3SChristian Ulmann   deepCloneAliasScopes(inlinedBlocks);
4564c2f90f3SChristian Ulmann   createNewAliasScopesFromNoAliasParameter(call, inlinedBlocks);
4574c2f90f3SChristian Ulmann   appendCallOpAliasScopes(call, inlinedBlocks);
4584c2f90f3SChristian Ulmann }
4594c2f90f3SChristian Ulmann 
4604c2f90f3SChristian Ulmann /// Appends any access groups of the call operation to any inlined memory
4614c2f90f3SChristian Ulmann /// operation.
4624c2f90f3SChristian Ulmann static void handleAccessGroups(Operation *call,
4634c2f90f3SChristian Ulmann                                iterator_range<Region::iterator> inlinedBlocks) {
4644c2f90f3SChristian Ulmann   auto callAccessGroupInterface = dyn_cast<LLVM::AccessGroupOpInterface>(call);
4654c2f90f3SChristian Ulmann   if (!callAccessGroupInterface)
4664c2f90f3SChristian Ulmann     return;
4674c2f90f3SChristian Ulmann 
4684c2f90f3SChristian Ulmann   auto accessGroups = callAccessGroupInterface.getAccessGroupsOrNull();
4694c2f90f3SChristian Ulmann   if (!accessGroups)
4704c2f90f3SChristian Ulmann     return;
4714c2f90f3SChristian Ulmann 
4724c2f90f3SChristian Ulmann   // Simply append the call op's access groups to any operation implementing
4734c2f90f3SChristian Ulmann   // AccessGroupOpInterface.
4744c2f90f3SChristian Ulmann   for (Block &block : inlinedBlocks)
4754c2f90f3SChristian Ulmann     for (auto accessGroupOpInterface :
4764c2f90f3SChristian Ulmann          block.getOps<LLVM::AccessGroupOpInterface>())
4774c2f90f3SChristian Ulmann       accessGroupOpInterface.setAccessGroups(concatArrayAttr(
4784c2f90f3SChristian Ulmann           accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups));
4794c2f90f3SChristian Ulmann }
4804c2f90f3SChristian Ulmann 
4814c2f90f3SChristian Ulmann /// Updates locations inside loop annotations to reflect that they were inlined.
4824c2f90f3SChristian Ulmann static void
4834c2f90f3SChristian Ulmann handleLoopAnnotations(Operation *call,
4844c2f90f3SChristian Ulmann                       iterator_range<Region::iterator> inlinedBlocks) {
4854c2f90f3SChristian Ulmann   // Attempt to extract a DISubprogram from the callee.
4864c2f90f3SChristian Ulmann   auto func = call->getParentOfType<FunctionOpInterface>();
4874c2f90f3SChristian Ulmann   if (!func)
4884c2f90f3SChristian Ulmann     return;
4894c2f90f3SChristian Ulmann   LocationAttr funcLoc = func->getLoc();
4904c2f90f3SChristian Ulmann   auto fusedLoc = dyn_cast_if_present<FusedLoc>(funcLoc);
4914c2f90f3SChristian Ulmann   if (!fusedLoc)
4924c2f90f3SChristian Ulmann     return;
4934c2f90f3SChristian Ulmann   auto scope =
4944c2f90f3SChristian Ulmann       dyn_cast_if_present<LLVM::DISubprogramAttr>(fusedLoc.getMetadata());
4954c2f90f3SChristian Ulmann   if (!scope)
4964c2f90f3SChristian Ulmann     return;
4974c2f90f3SChristian Ulmann 
4984c2f90f3SChristian Ulmann   // Helper to build a new fused location that reflects the inlining of the loop
4994c2f90f3SChristian Ulmann   // annotation.
5004c2f90f3SChristian Ulmann   auto updateLoc = [&](FusedLoc loc) -> FusedLoc {
5014c2f90f3SChristian Ulmann     if (!loc)
5024c2f90f3SChristian Ulmann       return {};
5034c2f90f3SChristian Ulmann     Location callSiteLoc = CallSiteLoc::get(loc, call->getLoc());
5044c2f90f3SChristian Ulmann     return FusedLoc::get(loc.getContext(), callSiteLoc, scope);
5054c2f90f3SChristian Ulmann   };
5064c2f90f3SChristian Ulmann 
5074c2f90f3SChristian Ulmann   AttrTypeReplacer replacer;
5084c2f90f3SChristian Ulmann   replacer.addReplacement([&](LLVM::LoopAnnotationAttr loopAnnotation)
5094c2f90f3SChristian Ulmann                               -> std::pair<Attribute, WalkResult> {
5104c2f90f3SChristian Ulmann     FusedLoc newStartLoc = updateLoc(loopAnnotation.getStartLoc());
5114c2f90f3SChristian Ulmann     FusedLoc newEndLoc = updateLoc(loopAnnotation.getEndLoc());
5124c2f90f3SChristian Ulmann     if (!newStartLoc && !newEndLoc)
5134c2f90f3SChristian Ulmann       return {loopAnnotation, WalkResult::advance()};
5144c2f90f3SChristian Ulmann     auto newLoopAnnotation = LLVM::LoopAnnotationAttr::get(
5154c2f90f3SChristian Ulmann         loopAnnotation.getContext(), loopAnnotation.getDisableNonforced(),
5164c2f90f3SChristian Ulmann         loopAnnotation.getVectorize(), loopAnnotation.getInterleave(),
5174c2f90f3SChristian Ulmann         loopAnnotation.getUnroll(), loopAnnotation.getUnrollAndJam(),
5184c2f90f3SChristian Ulmann         loopAnnotation.getLicm(), loopAnnotation.getDistribute(),
5194c2f90f3SChristian Ulmann         loopAnnotation.getPipeline(), loopAnnotation.getPeeled(),
5204c2f90f3SChristian Ulmann         loopAnnotation.getUnswitch(), loopAnnotation.getMustProgress(),
5214c2f90f3SChristian Ulmann         loopAnnotation.getIsVectorized(), newStartLoc, newEndLoc,
5224c2f90f3SChristian Ulmann         loopAnnotation.getParallelAccesses());
5234c2f90f3SChristian Ulmann     // Needs to advance, as loop annotations can be nested.
5244c2f90f3SChristian Ulmann     return {newLoopAnnotation, WalkResult::advance()};
5254c2f90f3SChristian Ulmann   });
5264c2f90f3SChristian Ulmann 
5274c2f90f3SChristian Ulmann   for (Block &block : inlinedBlocks)
5284c2f90f3SChristian Ulmann     for (Operation &op : block)
5294c2f90f3SChristian Ulmann       replacer.recursivelyReplaceElementsIn(&op);
5304c2f90f3SChristian Ulmann }
5314c2f90f3SChristian Ulmann 
5324c2f90f3SChristian Ulmann /// If `requestedAlignment` is higher than the alignment specified on `alloca`,
5334c2f90f3SChristian Ulmann /// realigns `alloca` if this does not exceed the natural stack alignment.
5344c2f90f3SChristian Ulmann /// Returns the post-alignment of `alloca`, whether it was realigned or not.
5354c2f90f3SChristian Ulmann static uint64_t tryToEnforceAllocaAlignment(LLVM::AllocaOp alloca,
5364c2f90f3SChristian Ulmann                                             uint64_t requestedAlignment,
5374c2f90f3SChristian Ulmann                                             DataLayout const &dataLayout) {
5384c2f90f3SChristian Ulmann   uint64_t allocaAlignment = alloca.getAlignment().value_or(1);
5394c2f90f3SChristian Ulmann   if (requestedAlignment <= allocaAlignment)
5404c2f90f3SChristian Ulmann     // No realignment necessary.
5414c2f90f3SChristian Ulmann     return allocaAlignment;
5424c2f90f3SChristian Ulmann   uint64_t naturalStackAlignmentBits = dataLayout.getStackAlignment();
5434c2f90f3SChristian Ulmann   // If the natural stack alignment is not specified, the data layout returns
5444c2f90f3SChristian Ulmann   // zero. Optimistically allow realignment in this case.
5454c2f90f3SChristian Ulmann   if (naturalStackAlignmentBits == 0 ||
5464c2f90f3SChristian Ulmann       // If the requested alignment exceeds the natural stack alignment, this
5474c2f90f3SChristian Ulmann       // will trigger a dynamic stack realignment, so we prefer to copy...
5484c2f90f3SChristian Ulmann       8 * requestedAlignment <= naturalStackAlignmentBits ||
5494c2f90f3SChristian Ulmann       // ...unless the alloca already triggers dynamic stack realignment. Then
5504c2f90f3SChristian Ulmann       // we might as well further increase the alignment to avoid a copy.
5514c2f90f3SChristian Ulmann       8 * allocaAlignment > naturalStackAlignmentBits) {
5524c2f90f3SChristian Ulmann     alloca.setAlignment(requestedAlignment);
5534c2f90f3SChristian Ulmann     allocaAlignment = requestedAlignment;
5544c2f90f3SChristian Ulmann   }
5554c2f90f3SChristian Ulmann   return allocaAlignment;
5564c2f90f3SChristian Ulmann }
5574c2f90f3SChristian Ulmann 
5584c2f90f3SChristian Ulmann /// Tries to find and return the alignment of the pointer `value` by looking for
5594c2f90f3SChristian Ulmann /// an alignment attribute on the defining allocation op or function argument.
5604c2f90f3SChristian Ulmann /// If the found alignment is lower than `requestedAlignment`, tries to realign
5614c2f90f3SChristian Ulmann /// the pointer, then returns the resulting post-alignment, regardless of
5624c2f90f3SChristian Ulmann /// whether it was realigned or not. If no existing alignment attribute is
5634c2f90f3SChristian Ulmann /// found, returns 1 (i.e., assume that no alignment is guaranteed).
5644c2f90f3SChristian Ulmann static uint64_t tryToEnforceAlignment(Value value, uint64_t requestedAlignment,
5654c2f90f3SChristian Ulmann                                       DataLayout const &dataLayout) {
5664c2f90f3SChristian Ulmann   if (Operation *definingOp = value.getDefiningOp()) {
5674c2f90f3SChristian Ulmann     if (auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
5684c2f90f3SChristian Ulmann       return tryToEnforceAllocaAlignment(alloca, requestedAlignment,
5694c2f90f3SChristian Ulmann                                          dataLayout);
5704c2f90f3SChristian Ulmann     if (auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
5714c2f90f3SChristian Ulmann       if (auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
5724c2f90f3SChristian Ulmann               definingOp, addressOf.getGlobalNameAttr()))
5734c2f90f3SChristian Ulmann         return global.getAlignment().value_or(1);
5744c2f90f3SChristian Ulmann     // We don't currently handle this operation; assume no alignment.
5754c2f90f3SChristian Ulmann     return 1;
5764c2f90f3SChristian Ulmann   }
5774c2f90f3SChristian Ulmann   // Since there is no defining op, this is a block argument. Probably this
5784c2f90f3SChristian Ulmann   // comes directly from a function argument, so check that this is the case.
5794c2f90f3SChristian Ulmann   Operation *parentOp = value.getParentBlock()->getParentOp();
5804c2f90f3SChristian Ulmann   if (auto func = dyn_cast<LLVM::LLVMFuncOp>(parentOp)) {
5814c2f90f3SChristian Ulmann     // Use the alignment attribute set for this argument in the parent function
5824c2f90f3SChristian Ulmann     // if it has been set.
5834c2f90f3SChristian Ulmann     auto blockArg = llvm::cast<BlockArgument>(value);
5844c2f90f3SChristian Ulmann     if (Attribute alignAttr = func.getArgAttr(
5854c2f90f3SChristian Ulmann             blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
5864c2f90f3SChristian Ulmann       return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
5874c2f90f3SChristian Ulmann   }
5884c2f90f3SChristian Ulmann   // We didn't find anything useful; assume no alignment.
5894c2f90f3SChristian Ulmann   return 1;
5904c2f90f3SChristian Ulmann }
5914c2f90f3SChristian Ulmann 
5924c2f90f3SChristian Ulmann /// Introduces a new alloca and copies the memory pointed to by `argument` to
5934c2f90f3SChristian Ulmann /// the address of the new alloca, then returns the value of the new alloca.
5944c2f90f3SChristian Ulmann static Value handleByValArgumentInit(OpBuilder &builder, Location loc,
5954c2f90f3SChristian Ulmann                                      Value argument, Type elementType,
5964c2f90f3SChristian Ulmann                                      uint64_t elementTypeSize,
5974c2f90f3SChristian Ulmann                                      uint64_t targetAlignment) {
5984c2f90f3SChristian Ulmann   // Allocate the new value on the stack.
5994c2f90f3SChristian Ulmann   Value allocaOp;
6004c2f90f3SChristian Ulmann   {
6014c2f90f3SChristian Ulmann     // Since this is a static alloca, we can put it directly in the entry block,
6024c2f90f3SChristian Ulmann     // so they can be absorbed into the prologue/epilogue at code generation.
6034c2f90f3SChristian Ulmann     OpBuilder::InsertionGuard insertionGuard(builder);
6044c2f90f3SChristian Ulmann     Block *entryBlock = &(*argument.getParentRegion()->begin());
6054c2f90f3SChristian Ulmann     builder.setInsertionPointToStart(entryBlock);
6064c2f90f3SChristian Ulmann     Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
6074c2f90f3SChristian Ulmann                                                  builder.getI64IntegerAttr(1));
6084c2f90f3SChristian Ulmann     allocaOp = builder.create<LLVM::AllocaOp>(
6094c2f90f3SChristian Ulmann         loc, argument.getType(), elementType, one, targetAlignment);
6104c2f90f3SChristian Ulmann   }
6114c2f90f3SChristian Ulmann   // Copy the pointee to the newly allocated value.
6124c2f90f3SChristian Ulmann   Value copySize = builder.create<LLVM::ConstantOp>(
6134c2f90f3SChristian Ulmann       loc, builder.getI64Type(), builder.getI64IntegerAttr(elementTypeSize));
6144c2f90f3SChristian Ulmann   builder.create<LLVM::MemcpyOp>(loc, allocaOp, argument, copySize,
6154c2f90f3SChristian Ulmann                                  /*isVolatile=*/false);
6164c2f90f3SChristian Ulmann   return allocaOp;
6174c2f90f3SChristian Ulmann }
6184c2f90f3SChristian Ulmann 
6194c2f90f3SChristian Ulmann /// Handles a function argument marked with the byval attribute by introducing a
6204c2f90f3SChristian Ulmann /// memcpy or realigning the defining operation, if required either due to the
6214c2f90f3SChristian Ulmann /// pointee being writeable in the callee, and/or due to an alignment mismatch.
6224c2f90f3SChristian Ulmann /// `requestedAlignment` specifies the alignment set in the "align" argument
6234c2f90f3SChristian Ulmann /// attribute (or 1 if no align attribute was set).
6244c2f90f3SChristian Ulmann static Value handleByValArgument(OpBuilder &builder, Operation *callable,
6254c2f90f3SChristian Ulmann                                  Value argument, Type elementType,
6264c2f90f3SChristian Ulmann                                  uint64_t requestedAlignment) {
6274c2f90f3SChristian Ulmann   auto func = cast<LLVM::LLVMFuncOp>(callable);
6284c2f90f3SChristian Ulmann   LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryEffectsAttr();
6294c2f90f3SChristian Ulmann   // If there is no memory effects attribute, assume that the function is
6304c2f90f3SChristian Ulmann   // not read-only.
6314c2f90f3SChristian Ulmann   bool isReadOnly = memoryEffects &&
6324c2f90f3SChristian Ulmann                     memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef &&
6334c2f90f3SChristian Ulmann                     memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod;
6344c2f90f3SChristian Ulmann   // Check if there's an alignment mismatch requiring us to copy.
6354c2f90f3SChristian Ulmann   DataLayout dataLayout = DataLayout::closest(callable);
6364c2f90f3SChristian Ulmann   uint64_t minimumAlignment = dataLayout.getTypeABIAlignment(elementType);
6374c2f90f3SChristian Ulmann   if (isReadOnly) {
6384c2f90f3SChristian Ulmann     if (requestedAlignment <= minimumAlignment)
6394c2f90f3SChristian Ulmann       return argument;
6404c2f90f3SChristian Ulmann     uint64_t currentAlignment =
6414c2f90f3SChristian Ulmann         tryToEnforceAlignment(argument, requestedAlignment, dataLayout);
6424c2f90f3SChristian Ulmann     if (currentAlignment >= requestedAlignment)
6434c2f90f3SChristian Ulmann       return argument;
6444c2f90f3SChristian Ulmann   }
6454c2f90f3SChristian Ulmann   uint64_t targetAlignment = std::max(requestedAlignment, minimumAlignment);
646*d0c9e70bSTobias Gysi   return handleByValArgumentInit(
647*d0c9e70bSTobias Gysi       builder, argument.getLoc(), argument, elementType,
648*d0c9e70bSTobias Gysi       dataLayout.getTypeSize(elementType), targetAlignment);
6494c2f90f3SChristian Ulmann }
6504c2f90f3SChristian Ulmann 
6514c2f90f3SChristian Ulmann namespace {
6524c2f90f3SChristian Ulmann struct LLVMInlinerInterface : public DialectInlinerInterface {
6534c2f90f3SChristian Ulmann   using DialectInlinerInterface::DialectInlinerInterface;
6544c2f90f3SChristian Ulmann 
6554c2f90f3SChristian Ulmann   LLVMInlinerInterface(Dialect *dialect)
6564c2f90f3SChristian Ulmann       : DialectInlinerInterface(dialect),
6574c2f90f3SChristian Ulmann         // Cache set of StringAttrs for fast lookup in `isLegalToInline`.
6584c2f90f3SChristian Ulmann         disallowedFunctionAttrs({
6594c2f90f3SChristian Ulmann             StringAttr::get(dialect->getContext(), "noduplicate"),
6604c2f90f3SChristian Ulmann             StringAttr::get(dialect->getContext(), "presplitcoroutine"),
6614c2f90f3SChristian Ulmann             StringAttr::get(dialect->getContext(), "returns_twice"),
6624c2f90f3SChristian Ulmann             StringAttr::get(dialect->getContext(), "strictfp"),
6634c2f90f3SChristian Ulmann         }) {}
6644c2f90f3SChristian Ulmann 
6654c2f90f3SChristian Ulmann   bool isLegalToInline(Operation *call, Operation *callable,
6664c2f90f3SChristian Ulmann                        bool wouldBeCloned) const final {
6674c2f90f3SChristian Ulmann     if (!isa<LLVM::CallOp>(call)) {
6684c2f90f3SChristian Ulmann       LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '"
6694c2f90f3SChristian Ulmann                               << LLVM::CallOp::getOperationName() << "' op\n");
6704c2f90f3SChristian Ulmann       return false;
6714c2f90f3SChristian Ulmann     }
6724c2f90f3SChristian Ulmann     auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
6734c2f90f3SChristian Ulmann     if (!funcOp) {
6744c2f90f3SChristian Ulmann       LLVM_DEBUG(llvm::dbgs()
6754c2f90f3SChristian Ulmann                  << "Cannot inline: callable is not an '"
6764c2f90f3SChristian Ulmann                  << LLVM::LLVMFuncOp::getOperationName() << "' op\n");
6774c2f90f3SChristian Ulmann       return false;
6784c2f90f3SChristian Ulmann     }
6794c2f90f3SChristian Ulmann     if (funcOp.isNoInline()) {
6804c2f90f3SChristian Ulmann       LLVM_DEBUG(llvm::dbgs()
6814c2f90f3SChristian Ulmann                  << "Cannot inline: function is marked no_inline\n");
6824c2f90f3SChristian Ulmann       return false;
6834c2f90f3SChristian Ulmann     }
6844c2f90f3SChristian Ulmann     if (funcOp.isVarArg()) {
6854c2f90f3SChristian Ulmann       LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n");
6864c2f90f3SChristian Ulmann       return false;
6874c2f90f3SChristian Ulmann     }
6884c2f90f3SChristian Ulmann     // TODO: Generate aliasing metadata from noalias result attributes.
6894c2f90f3SChristian Ulmann     if (auto attrs = funcOp.getArgAttrs()) {
6904c2f90f3SChristian Ulmann       for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
6914c2f90f3SChristian Ulmann         if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
6924c2f90f3SChristian Ulmann           LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
6934c2f90f3SChristian Ulmann                                   << ": inalloca arguments not supported\n");
6944c2f90f3SChristian Ulmann           return false;
6954c2f90f3SChristian Ulmann         }
6964c2f90f3SChristian Ulmann       }
6974c2f90f3SChristian Ulmann     }
6984c2f90f3SChristian Ulmann     // TODO: Handle exceptions.
6994c2f90f3SChristian Ulmann     if (funcOp.getPersonality()) {
7004c2f90f3SChristian Ulmann       LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
7014c2f90f3SChristian Ulmann                               << ": unhandled function personality\n");
7024c2f90f3SChristian Ulmann       return false;
7034c2f90f3SChristian Ulmann     }
7044c2f90f3SChristian Ulmann     if (funcOp.getPassthrough()) {
7054c2f90f3SChristian Ulmann       // TODO: Used attributes should not be passthrough.
7064c2f90f3SChristian Ulmann       if (llvm::any_of(*funcOp.getPassthrough(), [&](Attribute attr) {
7074c2f90f3SChristian Ulmann             auto stringAttr = dyn_cast<StringAttr>(attr);
7084c2f90f3SChristian Ulmann             if (!stringAttr)
7094c2f90f3SChristian Ulmann               return false;
7104c2f90f3SChristian Ulmann             if (disallowedFunctionAttrs.contains(stringAttr)) {
7114c2f90f3SChristian Ulmann               LLVM_DEBUG(llvm::dbgs()
7124c2f90f3SChristian Ulmann                          << "Cannot inline " << funcOp.getSymName()
7134c2f90f3SChristian Ulmann                          << ": found disallowed function attribute "
7144c2f90f3SChristian Ulmann                          << stringAttr << "\n");
7154c2f90f3SChristian Ulmann               return true;
7164c2f90f3SChristian Ulmann             }
7174c2f90f3SChristian Ulmann             return false;
7184c2f90f3SChristian Ulmann           }))
7194c2f90f3SChristian Ulmann         return false;
7204c2f90f3SChristian Ulmann     }
7214c2f90f3SChristian Ulmann     return true;
7224c2f90f3SChristian Ulmann   }
7234c2f90f3SChristian Ulmann 
7244c2f90f3SChristian Ulmann   bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
7254c2f90f3SChristian Ulmann     return true;
7264c2f90f3SChristian Ulmann   }
7274c2f90f3SChristian Ulmann 
7284c2f90f3SChristian Ulmann   bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final {
7294c2f90f3SChristian Ulmann     // The inliner cannot handle variadic function arguments.
7304c2f90f3SChristian Ulmann     return !isa<LLVM::VaStartOp>(op);
7314c2f90f3SChristian Ulmann   }
7324c2f90f3SChristian Ulmann 
7334c2f90f3SChristian Ulmann   /// Handle the given inlined return by replacing it with a branch. This
7344c2f90f3SChristian Ulmann   /// overload is called when the inlined region has more than one block.
7354c2f90f3SChristian Ulmann   void handleTerminator(Operation *op, Block *newDest) const final {
7364c2f90f3SChristian Ulmann     // Only return needs to be handled here.
7374c2f90f3SChristian Ulmann     auto returnOp = dyn_cast<LLVM::ReturnOp>(op);
7384c2f90f3SChristian Ulmann     if (!returnOp)
7394c2f90f3SChristian Ulmann       return;
7404c2f90f3SChristian Ulmann 
7414c2f90f3SChristian Ulmann     // Replace the return with a branch to the dest.
7424c2f90f3SChristian Ulmann     OpBuilder builder(op);
7434c2f90f3SChristian Ulmann     builder.create<LLVM::BrOp>(op->getLoc(), returnOp.getOperands(), newDest);
7444c2f90f3SChristian Ulmann     op->erase();
7454c2f90f3SChristian Ulmann   }
7464c2f90f3SChristian Ulmann 
747b39c5cb6SWilliam Moses   bool allowSingleBlockOptimization(
748b39c5cb6SWilliam Moses       iterator_range<Region::iterator> inlinedBlocks) const final {
749b39c5cb6SWilliam Moses     if (!inlinedBlocks.empty() &&
750b39c5cb6SWilliam Moses         isa<LLVM::UnreachableOp>(inlinedBlocks.begin()->getTerminator()))
751b39c5cb6SWilliam Moses       return false;
752b39c5cb6SWilliam Moses     return true;
753b39c5cb6SWilliam Moses   }
754b39c5cb6SWilliam Moses 
7554c2f90f3SChristian Ulmann   /// Handle the given inlined return by replacing the uses of the call with the
7564c2f90f3SChristian Ulmann   /// operands of the return. This overload is called when the inlined region
7574c2f90f3SChristian Ulmann   /// only contains one block.
7584c2f90f3SChristian Ulmann   void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
7594c2f90f3SChristian Ulmann     // Return will be the only terminator present.
7604c2f90f3SChristian Ulmann     auto returnOp = cast<LLVM::ReturnOp>(op);
7614c2f90f3SChristian Ulmann 
7624c2f90f3SChristian Ulmann     // Replace the values directly with the return operands.
7634c2f90f3SChristian Ulmann     assert(returnOp.getNumOperands() == valuesToRepl.size());
7644c2f90f3SChristian Ulmann     for (auto [dst, src] : llvm::zip(valuesToRepl, returnOp.getOperands()))
7654c2f90f3SChristian Ulmann       dst.replaceAllUsesWith(src);
7664c2f90f3SChristian Ulmann   }
7674c2f90f3SChristian Ulmann 
7684c2f90f3SChristian Ulmann   Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
7694c2f90f3SChristian Ulmann                        Value argument,
7704c2f90f3SChristian Ulmann                        DictionaryAttr argumentAttrs) const final {
7714c2f90f3SChristian Ulmann     if (std::optional<NamedAttribute> attr =
7724c2f90f3SChristian Ulmann             argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) {
7734c2f90f3SChristian Ulmann       Type elementType = cast<TypeAttr>(attr->getValue()).getValue();
7744c2f90f3SChristian Ulmann       uint64_t requestedAlignment = 1;
7754c2f90f3SChristian Ulmann       if (std::optional<NamedAttribute> alignAttr =
7764c2f90f3SChristian Ulmann               argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) {
7774c2f90f3SChristian Ulmann         requestedAlignment = cast<IntegerAttr>(alignAttr->getValue())
7784c2f90f3SChristian Ulmann                                  .getValue()
7794c2f90f3SChristian Ulmann                                  .getLimitedValue();
7804c2f90f3SChristian Ulmann       }
7814c2f90f3SChristian Ulmann       return handleByValArgument(builder, callable, argument, elementType,
7824c2f90f3SChristian Ulmann                                  requestedAlignment);
7834c2f90f3SChristian Ulmann     }
7844c2f90f3SChristian Ulmann 
785065d2d9cSChristian Ulmann     // This code is essentially a workaround for deficiencies in the inliner
786065d2d9cSChristian Ulmann     // interface: We need to transform operations *after* inlined based on the
787065d2d9cSChristian Ulmann     // argument attributes of the parameters *before* inlining. This method runs
788065d2d9cSChristian Ulmann     // prior to actual inlining and thus cannot transform the post-inlining
789065d2d9cSChristian Ulmann     // code, while `processInlinedCallBlocks` does not have access to
790065d2d9cSChristian Ulmann     // pre-inlining function arguments. Additionally, it is required to
791065d2d9cSChristian Ulmann     // distinguish which parameter an SSA value originally came from. As a
792065d2d9cSChristian Ulmann     // workaround until this is changed: Create an ssa.copy intrinsic with the
793065d2d9cSChristian Ulmann     // noalias attribute (when it was present before) that can easily be found,
794065d2d9cSChristian Ulmann     // and is extremely unlikely to exist in the code prior to inlining, using
795065d2d9cSChristian Ulmann     // this to communicate between this method and `processInlinedCallBlocks`.
7964c2f90f3SChristian Ulmann     // TODO: Fix this by refactoring the inliner interface.
7974c2f90f3SChristian Ulmann     auto copyOp = builder.create<LLVM::SSACopyOp>(call->getLoc(), argument);
798065d2d9cSChristian Ulmann     if (argumentAttrs.contains(LLVM::LLVMDialect::getNoAliasAttrName()))
7994c2f90f3SChristian Ulmann       copyOp->setDiscardableAttr(
8004c2f90f3SChristian Ulmann           builder.getStringAttr(LLVM::LLVMDialect::getNoAliasAttrName()),
8014c2f90f3SChristian Ulmann           builder.getUnitAttr());
8024c2f90f3SChristian Ulmann     return copyOp;
8034c2f90f3SChristian Ulmann   }
8044c2f90f3SChristian Ulmann 
8054c2f90f3SChristian Ulmann   void processInlinedCallBlocks(
8064c2f90f3SChristian Ulmann       Operation *call,
8074c2f90f3SChristian Ulmann       iterator_range<Region::iterator> inlinedBlocks) const override {
8084c2f90f3SChristian Ulmann     handleInlinedAllocas(call, inlinedBlocks);
8094c2f90f3SChristian Ulmann     handleAliasScopes(call, inlinedBlocks);
8104c2f90f3SChristian Ulmann     handleAccessGroups(call, inlinedBlocks);
8114c2f90f3SChristian Ulmann     handleLoopAnnotations(call, inlinedBlocks);
8124c2f90f3SChristian Ulmann   }
8134c2f90f3SChristian Ulmann 
8144c2f90f3SChristian Ulmann   // Keeping this (immutable) state on the interface allows us to look up
8154c2f90f3SChristian Ulmann   // StringAttrs instead of looking up strings, since StringAttrs are bound to
8164c2f90f3SChristian Ulmann   // the current context and thus cannot be initialized as static fields.
8174c2f90f3SChristian Ulmann   const DenseSet<StringAttr> disallowedFunctionAttrs;
8184c2f90f3SChristian Ulmann };
8194c2f90f3SChristian Ulmann 
8204c2f90f3SChristian Ulmann } // end anonymous namespace
8214c2f90f3SChristian Ulmann 
8224c2f90f3SChristian Ulmann void mlir::LLVM::registerInlinerInterface(DialectRegistry &registry) {
8234c2f90f3SChristian Ulmann   registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
8244c2f90f3SChristian Ulmann     dialect->addInterfaces<LLVMInlinerInterface>();
8254c2f90f3SChristian Ulmann   });
8264c2f90f3SChristian Ulmann }
8278d306ccdSWilliam Moses 
8288d306ccdSWilliam Moses void mlir::NVVM::registerInlinerInterface(DialectRegistry &registry) {
8298d306ccdSWilliam Moses   registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
8308d306ccdSWilliam Moses     dialect->addInterfaces<LLVMInlinerInterface>();
8318d306ccdSWilliam Moses   });
8328d306ccdSWilliam Moses }
833