1 //===- InlinerInterfaceImpl.cpp - Inlining for LLVM the dialect -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Logic for inlining LLVM functions and the definition of the 10 // LLVMInliningInterface. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" 15 #include "mlir/Analysis/SliceWalk.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 18 #include "mlir/IR/Matchers.h" 19 #include "mlir/Interfaces/DataLayoutInterfaces.h" 20 #include "mlir/Interfaces/ViewLikeInterface.h" 21 #include "mlir/Transforms/InliningUtils.h" 22 #include "llvm/ADT/ScopeExit.h" 23 #include "llvm/Support/Debug.h" 24 25 #define DEBUG_TYPE "llvm-inliner" 26 27 using namespace mlir; 28 29 /// Check whether the given alloca is an input to a lifetime intrinsic, 30 /// optionally passing through one or more casts on the way. This is not 31 /// transitive through block arguments. 32 static bool hasLifetimeMarkers(LLVM::AllocaOp allocaOp) { 33 SmallVector<Operation *> stack(allocaOp->getUsers().begin(), 34 allocaOp->getUsers().end()); 35 while (!stack.empty()) { 36 Operation *op = stack.pop_back_val(); 37 if (isa<LLVM::LifetimeStartOp, LLVM::LifetimeEndOp>(op)) 38 return true; 39 if (isa<LLVM::BitcastOp>(op)) 40 stack.append(op->getUsers().begin(), op->getUsers().end()); 41 } 42 return false; 43 } 44 45 /// Handles alloca operations in the inlined blocks: 46 /// - Moves all alloca operations with a constant size in the former entry block 47 /// of the callee into the entry block of the caller, so they become part of 48 /// the function prologue/epilogue during code generation. 49 /// - Inserts lifetime intrinsics that limit the scope of inlined static allocas 50 /// to the inlined blocks. 51 /// - Inserts StackSave and StackRestore operations if dynamic allocas were 52 /// inlined. 53 static void 54 handleInlinedAllocas(Operation *call, 55 iterator_range<Region::iterator> inlinedBlocks) { 56 // Locate the entry block of the closest callsite ancestor that has either the 57 // IsolatedFromAbove or AutomaticAllocationScope trait. In pure LLVM dialect 58 // programs, this is the LLVMFuncOp containing the call site. However, in 59 // mixed-dialect programs, the callsite might be nested in another operation 60 // that carries one of these traits. In such scenarios, this traversal stops 61 // at the closest ancestor with either trait, ensuring visibility post 62 // relocation and respecting allocation scopes. 63 Block *callerEntryBlock = nullptr; 64 Operation *currentOp = call; 65 while (Operation *parentOp = currentOp->getParentOp()) { 66 if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || 67 parentOp->mightHaveTrait<OpTrait::AutomaticAllocationScope>()) { 68 callerEntryBlock = ¤tOp->getParentRegion()->front(); 69 break; 70 } 71 currentOp = parentOp; 72 } 73 74 // Avoid relocating the alloca operations if the call has been inlined into 75 // the entry block already, which is typically the encompassing 76 // LLVM function, or if the relevant entry block cannot be identified. 77 Block *calleeEntryBlock = &(*inlinedBlocks.begin()); 78 if (!callerEntryBlock || callerEntryBlock == calleeEntryBlock) 79 return; 80 81 SmallVector<std::tuple<LLVM::AllocaOp, IntegerAttr, bool>> allocasToMove; 82 bool shouldInsertLifetimes = false; 83 bool hasDynamicAlloca = false; 84 // Conservatively only move static alloca operations that are part of the 85 // entry block and do not inspect nested regions, since they may execute 86 // conditionally or have other unknown semantics. 87 for (auto allocaOp : calleeEntryBlock->getOps<LLVM::AllocaOp>()) { 88 IntegerAttr arraySize; 89 if (!matchPattern(allocaOp.getArraySize(), m_Constant(&arraySize))) { 90 hasDynamicAlloca = true; 91 continue; 92 } 93 bool shouldInsertLifetime = 94 arraySize.getValue() != 0 && !hasLifetimeMarkers(allocaOp); 95 shouldInsertLifetimes |= shouldInsertLifetime; 96 allocasToMove.emplace_back(allocaOp, arraySize, shouldInsertLifetime); 97 } 98 // Check the remaining inlined blocks for dynamic allocas as well. 99 for (Block &block : llvm::drop_begin(inlinedBlocks)) { 100 if (hasDynamicAlloca) 101 break; 102 hasDynamicAlloca = 103 llvm::any_of(block.getOps<LLVM::AllocaOp>(), [](auto allocaOp) { 104 return !matchPattern(allocaOp.getArraySize(), m_Constant()); 105 }); 106 } 107 if (allocasToMove.empty() && !hasDynamicAlloca) 108 return; 109 OpBuilder builder(calleeEntryBlock, calleeEntryBlock->begin()); 110 Value stackPtr; 111 if (hasDynamicAlloca) { 112 // This may result in multiple stacksave/stackrestore intrinsics in the same 113 // scope if some are already present in the body of the caller. This is not 114 // invalid IR, but LLVM cleans these up in InstCombineCalls.cpp, along with 115 // other cases where the stacksave/stackrestore is redundant. 116 stackPtr = builder.create<LLVM::StackSaveOp>( 117 call->getLoc(), LLVM::LLVMPointerType::get(call->getContext())); 118 } 119 builder.setInsertionPointToStart(callerEntryBlock); 120 for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) { 121 auto newConstant = builder.create<LLVM::ConstantOp>( 122 allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize); 123 // Insert a lifetime start intrinsic where the alloca was before moving it. 124 if (shouldInsertLifetime) { 125 OpBuilder::InsertionGuard insertionGuard(builder); 126 builder.setInsertionPoint(allocaOp); 127 builder.create<LLVM::LifetimeStartOp>( 128 allocaOp.getLoc(), arraySize.getValue().getLimitedValue(), 129 allocaOp.getResult()); 130 } 131 allocaOp->moveAfter(newConstant); 132 allocaOp.getArraySizeMutable().assign(newConstant.getResult()); 133 } 134 if (!shouldInsertLifetimes && !hasDynamicAlloca) 135 return; 136 // Insert a lifetime end intrinsic before each return in the callee function. 137 for (Block &block : inlinedBlocks) { 138 if (!block.getTerminator()->hasTrait<OpTrait::ReturnLike>()) 139 continue; 140 builder.setInsertionPoint(block.getTerminator()); 141 if (hasDynamicAlloca) 142 builder.create<LLVM::StackRestoreOp>(call->getLoc(), stackPtr); 143 for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) { 144 if (shouldInsertLifetime) 145 builder.create<LLVM::LifetimeEndOp>( 146 allocaOp.getLoc(), arraySize.getValue().getLimitedValue(), 147 allocaOp.getResult()); 148 } 149 } 150 } 151 152 /// Maps all alias scopes in the inlined operations to deep clones of the scopes 153 /// and domain. This is required for code such as `foo(a, b); foo(a2, b2);` to 154 /// not incorrectly return `noalias` for e.g. operations on `a` and `a2`. 155 static void 156 deepCloneAliasScopes(iterator_range<Region::iterator> inlinedBlocks) { 157 DenseMap<Attribute, Attribute> mapping; 158 159 // Register handles in the walker to create the deep clones. 160 // The walker ensures that an attribute is only ever walked once and does a 161 // post-order walk, ensuring the domain is visited prior to the scope. 162 AttrTypeWalker walker; 163 164 // Perform the deep clones while visiting. Builders create a distinct 165 // attribute to make sure that new instances are always created by the 166 // uniquer. 167 walker.addWalk([&](LLVM::AliasScopeDomainAttr domainAttr) { 168 mapping[domainAttr] = LLVM::AliasScopeDomainAttr::get( 169 domainAttr.getContext(), domainAttr.getDescription()); 170 }); 171 172 walker.addWalk([&](LLVM::AliasScopeAttr scopeAttr) { 173 mapping[scopeAttr] = LLVM::AliasScopeAttr::get( 174 cast<LLVM::AliasScopeDomainAttr>(mapping.lookup(scopeAttr.getDomain())), 175 scopeAttr.getDescription()); 176 }); 177 178 // Map an array of scopes to an array of deep clones. 179 auto convertScopeList = [&](ArrayAttr arrayAttr) -> ArrayAttr { 180 if (!arrayAttr) 181 return nullptr; 182 183 // Create the deep clones if necessary. 184 walker.walk(arrayAttr); 185 186 return ArrayAttr::get(arrayAttr.getContext(), 187 llvm::map_to_vector(arrayAttr, [&](Attribute attr) { 188 return mapping.lookup(attr); 189 })); 190 }; 191 192 for (Block &block : inlinedBlocks) { 193 block.walk([&](Operation *op) { 194 if (auto aliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(op)) { 195 aliasInterface.setAliasScopes( 196 convertScopeList(aliasInterface.getAliasScopesOrNull())); 197 aliasInterface.setNoAliasScopes( 198 convertScopeList(aliasInterface.getNoAliasScopesOrNull())); 199 } 200 201 if (auto noAliasScope = dyn_cast<LLVM::NoAliasScopeDeclOp>(op)) { 202 // Create the deep clones if necessary. 203 walker.walk(noAliasScope.getScopeAttr()); 204 205 noAliasScope.setScopeAttr(cast<LLVM::AliasScopeAttr>( 206 mapping.lookup(noAliasScope.getScopeAttr()))); 207 } 208 }); 209 } 210 } 211 212 /// Creates a new ArrayAttr by concatenating `lhs` with `rhs`. 213 /// Returns null if both parameters are null. If only one attribute is null, 214 /// return the other. 215 static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs) { 216 if (!lhs) 217 return rhs; 218 if (!rhs) 219 return lhs; 220 221 SmallVector<Attribute> result; 222 llvm::append_range(result, lhs); 223 llvm::append_range(result, rhs); 224 return ArrayAttr::get(lhs.getContext(), result); 225 } 226 227 /// Attempts to return the set of all underlying pointer values that 228 /// `pointerValue` is based on. This function traverses through select 229 /// operations and block arguments. 230 static FailureOr<SmallVector<Value>> 231 getUnderlyingObjectSet(Value pointerValue) { 232 SmallVector<Value> result; 233 WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) { 234 // Attempt to advance to the source of the underlying view-like operation. 235 // Examples of view-like operations include GEPOp and AddrSpaceCastOp. 236 if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>()) 237 return WalkContinuation::advanceTo(viewOp.getViewSource()); 238 239 // Attempt to advance to control flow predecessors. 240 std::optional<SmallVector<Value>> controlFlowPredecessors = 241 getControlFlowPredecessors(val); 242 if (controlFlowPredecessors) 243 return WalkContinuation::advanceTo(*controlFlowPredecessors); 244 245 // For all non-control flow results, consider `val` an underlying object. 246 if (isa<OpResult>(val)) { 247 result.push_back(val); 248 return WalkContinuation::skip(); 249 } 250 251 // If this place is reached, `val` is a block argument that is not 252 // understood. Therefore, we conservatively interrupt. 253 // Note: Dealing with function arguments is not necessary, as the slice 254 // would have to go through an SSACopyOp first. 255 return WalkContinuation::interrupt(); 256 }); 257 258 if (walkResult.wasInterrupted()) 259 return failure(); 260 261 return result; 262 } 263 264 /// Creates a new AliasScopeAttr for every noalias parameter and attaches it to 265 /// the appropriate inlined memory operations in an attempt to preserve the 266 /// original semantics of the parameter attribute. 267 static void createNewAliasScopesFromNoAliasParameter( 268 Operation *call, iterator_range<Region::iterator> inlinedBlocks) { 269 270 // First, collect all ssa copy operations, which correspond to function 271 // parameters, and additionally store the noalias parameters. All parameters 272 // have been marked by the `handleArgument` implementation by using the 273 // `ssa.copy` intrinsic. Additionally, noalias parameters have an attached 274 // `noalias` attribute to the intrinsics. These intrinsics are only meant to 275 // be temporary and should therefore be deleted after we're done using them 276 // here. 277 SetVector<LLVM::SSACopyOp> ssaCopies; 278 SetVector<LLVM::SSACopyOp> noAliasParams; 279 for (Value argument : cast<LLVM::CallOp>(call).getArgOperands()) { 280 for (Operation *user : argument.getUsers()) { 281 auto ssaCopy = llvm::dyn_cast<LLVM::SSACopyOp>(user); 282 if (!ssaCopy) 283 continue; 284 ssaCopies.insert(ssaCopy); 285 286 if (!ssaCopy->hasAttr(LLVM::LLVMDialect::getNoAliasAttrName())) 287 continue; 288 noAliasParams.insert(ssaCopy); 289 } 290 } 291 292 // Scope exit block to make it impossible to forget to get rid of the 293 // intrinsics. 294 auto exit = llvm::make_scope_exit([&] { 295 for (LLVM::SSACopyOp ssaCopyOp : ssaCopies) { 296 ssaCopyOp.replaceAllUsesWith(ssaCopyOp.getOperand()); 297 ssaCopyOp->erase(); 298 } 299 }); 300 301 // If there were no noalias parameters, we have nothing to do here. 302 if (noAliasParams.empty()) 303 return; 304 305 // Create a new domain for this specific inlining and a new scope for every 306 // noalias parameter. 307 auto functionDomain = LLVM::AliasScopeDomainAttr::get( 308 call->getContext(), cast<LLVM::CallOp>(call).getCalleeAttr().getAttr()); 309 DenseMap<Value, LLVM::AliasScopeAttr> pointerScopes; 310 for (LLVM::SSACopyOp copyOp : noAliasParams) { 311 auto scope = LLVM::AliasScopeAttr::get(functionDomain); 312 pointerScopes[copyOp] = scope; 313 314 OpBuilder(call).create<LLVM::NoAliasScopeDeclOp>(call->getLoc(), scope); 315 } 316 317 // Go through every instruction and attempt to find which noalias parameters 318 // it is definitely based on and definitely not based on. 319 for (Block &inlinedBlock : inlinedBlocks) { 320 inlinedBlock.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) { 321 // Collect the pointer arguments affected by the alias scopes. 322 SmallVector<Value> pointerArgs = aliasInterface.getAccessedOperands(); 323 324 // Find the set of underlying pointers that this pointer is based on. 325 SmallPtrSet<Value, 4> basedOnPointers; 326 for (Value pointer : pointerArgs) { 327 FailureOr<SmallVector<Value>> underlyingObjectSet = 328 getUnderlyingObjectSet(pointer); 329 if (failed(underlyingObjectSet)) 330 return; 331 llvm::copy(*underlyingObjectSet, 332 std::inserter(basedOnPointers, basedOnPointers.begin())); 333 } 334 335 bool aliasesOtherKnownObject = false; 336 // Go through the based on pointers and check that they are either: 337 // * Constants that can be ignored (undef, poison, null pointer). 338 // * Based on a pointer parameter. 339 // * Other pointers that we know can't alias with our noalias parameter. 340 // 341 // Any other value might be a pointer based on any noalias parameter that 342 // hasn't been identified. In that case conservatively don't add any 343 // scopes to this operation indicating either aliasing or not aliasing 344 // with any parameter. 345 if (llvm::any_of(basedOnPointers, [&](Value object) { 346 if (matchPattern(object, m_Constant())) 347 return false; 348 349 if (auto ssaCopy = object.getDefiningOp<LLVM::SSACopyOp>()) { 350 // If that value is based on a noalias parameter, it is guaranteed 351 // to not alias with any other object. 352 aliasesOtherKnownObject |= !noAliasParams.contains(ssaCopy); 353 return false; 354 } 355 356 if (isa_and_nonnull<LLVM::AllocaOp, LLVM::AddressOfOp>( 357 object.getDefiningOp())) { 358 aliasesOtherKnownObject = true; 359 return false; 360 } 361 return true; 362 })) 363 return; 364 365 // Add all noalias parameter scopes to the noalias scope list that we are 366 // not based on. 367 SmallVector<Attribute> noAliasScopes; 368 for (LLVM::SSACopyOp noAlias : noAliasParams) { 369 if (basedOnPointers.contains(noAlias)) 370 continue; 371 372 noAliasScopes.push_back(pointerScopes[noAlias]); 373 } 374 375 if (!noAliasScopes.empty()) 376 aliasInterface.setNoAliasScopes( 377 concatArrayAttr(aliasInterface.getNoAliasScopesOrNull(), 378 ArrayAttr::get(call->getContext(), noAliasScopes))); 379 380 // Don't add alias scopes to call operations or operations that might 381 // operate on pointers not based on any noalias parameter. 382 // Since we add all scopes to an operation's noalias list that it 383 // definitely doesn't alias, we mustn't do the same for the alias.scope 384 // list if other objects are involved. 385 // 386 // Consider the following case: 387 // %0 = llvm.alloca 388 // %1 = select %magic, %0, %noalias_param 389 // store 5, %1 (1) noalias=[scope(...)] 390 // ... 391 // store 3, %0 (2) noalias=[scope(noalias_param), scope(...)] 392 // 393 // We can add the scopes of any noalias parameters that aren't 394 // noalias_param's scope to (1) and add all of them to (2). We mustn't add 395 // the scope of noalias_param to the alias.scope list of (1) since 396 // that would mean (2) cannot alias with (1) which is wrong since both may 397 // store to %0. 398 // 399 // In conclusion, only add scopes to the alias.scope list if all pointers 400 // have a corresponding scope. 401 // Call operations are included in this list since we do not know whether 402 // the callee accesses any memory besides the ones passed as its 403 // arguments. 404 if (aliasesOtherKnownObject || 405 isa<LLVM::CallOp>(aliasInterface.getOperation())) 406 return; 407 408 SmallVector<Attribute> aliasScopes; 409 for (LLVM::SSACopyOp noAlias : noAliasParams) 410 if (basedOnPointers.contains(noAlias)) 411 aliasScopes.push_back(pointerScopes[noAlias]); 412 413 if (!aliasScopes.empty()) 414 aliasInterface.setAliasScopes( 415 concatArrayAttr(aliasInterface.getAliasScopesOrNull(), 416 ArrayAttr::get(call->getContext(), aliasScopes))); 417 }); 418 } 419 } 420 421 /// Appends any alias scopes of the call operation to any inlined memory 422 /// operation. 423 static void 424 appendCallOpAliasScopes(Operation *call, 425 iterator_range<Region::iterator> inlinedBlocks) { 426 auto callAliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(call); 427 if (!callAliasInterface) 428 return; 429 430 ArrayAttr aliasScopes = callAliasInterface.getAliasScopesOrNull(); 431 ArrayAttr noAliasScopes = callAliasInterface.getNoAliasScopesOrNull(); 432 // If the call has neither alias scopes or noalias scopes we have nothing to 433 // do here. 434 if (!aliasScopes && !noAliasScopes) 435 return; 436 437 // Simply append the call op's alias and noalias scopes to any operation 438 // implementing AliasAnalysisOpInterface. 439 for (Block &block : inlinedBlocks) { 440 block.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) { 441 if (aliasScopes) 442 aliasInterface.setAliasScopes(concatArrayAttr( 443 aliasInterface.getAliasScopesOrNull(), aliasScopes)); 444 445 if (noAliasScopes) 446 aliasInterface.setNoAliasScopes(concatArrayAttr( 447 aliasInterface.getNoAliasScopesOrNull(), noAliasScopes)); 448 }); 449 } 450 } 451 452 /// Handles all interactions with alias scopes during inlining. 453 static void handleAliasScopes(Operation *call, 454 iterator_range<Region::iterator> inlinedBlocks) { 455 deepCloneAliasScopes(inlinedBlocks); 456 createNewAliasScopesFromNoAliasParameter(call, inlinedBlocks); 457 appendCallOpAliasScopes(call, inlinedBlocks); 458 } 459 460 /// Appends any access groups of the call operation to any inlined memory 461 /// operation. 462 static void handleAccessGroups(Operation *call, 463 iterator_range<Region::iterator> inlinedBlocks) { 464 auto callAccessGroupInterface = dyn_cast<LLVM::AccessGroupOpInterface>(call); 465 if (!callAccessGroupInterface) 466 return; 467 468 auto accessGroups = callAccessGroupInterface.getAccessGroupsOrNull(); 469 if (!accessGroups) 470 return; 471 472 // Simply append the call op's access groups to any operation implementing 473 // AccessGroupOpInterface. 474 for (Block &block : inlinedBlocks) 475 for (auto accessGroupOpInterface : 476 block.getOps<LLVM::AccessGroupOpInterface>()) 477 accessGroupOpInterface.setAccessGroups(concatArrayAttr( 478 accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups)); 479 } 480 481 /// Updates locations inside loop annotations to reflect that they were inlined. 482 static void 483 handleLoopAnnotations(Operation *call, 484 iterator_range<Region::iterator> inlinedBlocks) { 485 // Attempt to extract a DISubprogram from the callee. 486 auto func = call->getParentOfType<FunctionOpInterface>(); 487 if (!func) 488 return; 489 LocationAttr funcLoc = func->getLoc(); 490 auto fusedLoc = dyn_cast_if_present<FusedLoc>(funcLoc); 491 if (!fusedLoc) 492 return; 493 auto scope = 494 dyn_cast_if_present<LLVM::DISubprogramAttr>(fusedLoc.getMetadata()); 495 if (!scope) 496 return; 497 498 // Helper to build a new fused location that reflects the inlining of the loop 499 // annotation. 500 auto updateLoc = [&](FusedLoc loc) -> FusedLoc { 501 if (!loc) 502 return {}; 503 Location callSiteLoc = CallSiteLoc::get(loc, call->getLoc()); 504 return FusedLoc::get(loc.getContext(), callSiteLoc, scope); 505 }; 506 507 AttrTypeReplacer replacer; 508 replacer.addReplacement([&](LLVM::LoopAnnotationAttr loopAnnotation) 509 -> std::pair<Attribute, WalkResult> { 510 FusedLoc newStartLoc = updateLoc(loopAnnotation.getStartLoc()); 511 FusedLoc newEndLoc = updateLoc(loopAnnotation.getEndLoc()); 512 if (!newStartLoc && !newEndLoc) 513 return {loopAnnotation, WalkResult::advance()}; 514 auto newLoopAnnotation = LLVM::LoopAnnotationAttr::get( 515 loopAnnotation.getContext(), loopAnnotation.getDisableNonforced(), 516 loopAnnotation.getVectorize(), loopAnnotation.getInterleave(), 517 loopAnnotation.getUnroll(), loopAnnotation.getUnrollAndJam(), 518 loopAnnotation.getLicm(), loopAnnotation.getDistribute(), 519 loopAnnotation.getPipeline(), loopAnnotation.getPeeled(), 520 loopAnnotation.getUnswitch(), loopAnnotation.getMustProgress(), 521 loopAnnotation.getIsVectorized(), newStartLoc, newEndLoc, 522 loopAnnotation.getParallelAccesses()); 523 // Needs to advance, as loop annotations can be nested. 524 return {newLoopAnnotation, WalkResult::advance()}; 525 }); 526 527 for (Block &block : inlinedBlocks) 528 for (Operation &op : block) 529 replacer.recursivelyReplaceElementsIn(&op); 530 } 531 532 /// If `requestedAlignment` is higher than the alignment specified on `alloca`, 533 /// realigns `alloca` if this does not exceed the natural stack alignment. 534 /// Returns the post-alignment of `alloca`, whether it was realigned or not. 535 static uint64_t tryToEnforceAllocaAlignment(LLVM::AllocaOp alloca, 536 uint64_t requestedAlignment, 537 DataLayout const &dataLayout) { 538 uint64_t allocaAlignment = alloca.getAlignment().value_or(1); 539 if (requestedAlignment <= allocaAlignment) 540 // No realignment necessary. 541 return allocaAlignment; 542 uint64_t naturalStackAlignmentBits = dataLayout.getStackAlignment(); 543 // If the natural stack alignment is not specified, the data layout returns 544 // zero. Optimistically allow realignment in this case. 545 if (naturalStackAlignmentBits == 0 || 546 // If the requested alignment exceeds the natural stack alignment, this 547 // will trigger a dynamic stack realignment, so we prefer to copy... 548 8 * requestedAlignment <= naturalStackAlignmentBits || 549 // ...unless the alloca already triggers dynamic stack realignment. Then 550 // we might as well further increase the alignment to avoid a copy. 551 8 * allocaAlignment > naturalStackAlignmentBits) { 552 alloca.setAlignment(requestedAlignment); 553 allocaAlignment = requestedAlignment; 554 } 555 return allocaAlignment; 556 } 557 558 /// Tries to find and return the alignment of the pointer `value` by looking for 559 /// an alignment attribute on the defining allocation op or function argument. 560 /// If the found alignment is lower than `requestedAlignment`, tries to realign 561 /// the pointer, then returns the resulting post-alignment, regardless of 562 /// whether it was realigned or not. If no existing alignment attribute is 563 /// found, returns 1 (i.e., assume that no alignment is guaranteed). 564 static uint64_t tryToEnforceAlignment(Value value, uint64_t requestedAlignment, 565 DataLayout const &dataLayout) { 566 if (Operation *definingOp = value.getDefiningOp()) { 567 if (auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp)) 568 return tryToEnforceAllocaAlignment(alloca, requestedAlignment, 569 dataLayout); 570 if (auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp)) 571 if (auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>( 572 definingOp, addressOf.getGlobalNameAttr())) 573 return global.getAlignment().value_or(1); 574 // We don't currently handle this operation; assume no alignment. 575 return 1; 576 } 577 // Since there is no defining op, this is a block argument. Probably this 578 // comes directly from a function argument, so check that this is the case. 579 Operation *parentOp = value.getParentBlock()->getParentOp(); 580 if (auto func = dyn_cast<LLVM::LLVMFuncOp>(parentOp)) { 581 // Use the alignment attribute set for this argument in the parent function 582 // if it has been set. 583 auto blockArg = llvm::cast<BlockArgument>(value); 584 if (Attribute alignAttr = func.getArgAttr( 585 blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName())) 586 return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue(); 587 } 588 // We didn't find anything useful; assume no alignment. 589 return 1; 590 } 591 592 /// Introduces a new alloca and copies the memory pointed to by `argument` to 593 /// the address of the new alloca, then returns the value of the new alloca. 594 static Value handleByValArgumentInit(OpBuilder &builder, Location loc, 595 Value argument, Type elementType, 596 uint64_t elementTypeSize, 597 uint64_t targetAlignment) { 598 // Allocate the new value on the stack. 599 Value allocaOp; 600 { 601 // Since this is a static alloca, we can put it directly in the entry block, 602 // so they can be absorbed into the prologue/epilogue at code generation. 603 OpBuilder::InsertionGuard insertionGuard(builder); 604 Block *entryBlock = &(*argument.getParentRegion()->begin()); 605 builder.setInsertionPointToStart(entryBlock); 606 Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), 607 builder.getI64IntegerAttr(1)); 608 allocaOp = builder.create<LLVM::AllocaOp>( 609 loc, argument.getType(), elementType, one, targetAlignment); 610 } 611 // Copy the pointee to the newly allocated value. 612 Value copySize = builder.create<LLVM::ConstantOp>( 613 loc, builder.getI64Type(), builder.getI64IntegerAttr(elementTypeSize)); 614 builder.create<LLVM::MemcpyOp>(loc, allocaOp, argument, copySize, 615 /*isVolatile=*/false); 616 return allocaOp; 617 } 618 619 /// Handles a function argument marked with the byval attribute by introducing a 620 /// memcpy or realigning the defining operation, if required either due to the 621 /// pointee being writeable in the callee, and/or due to an alignment mismatch. 622 /// `requestedAlignment` specifies the alignment set in the "align" argument 623 /// attribute (or 1 if no align attribute was set). 624 static Value handleByValArgument(OpBuilder &builder, Operation *callable, 625 Value argument, Type elementType, 626 uint64_t requestedAlignment) { 627 auto func = cast<LLVM::LLVMFuncOp>(callable); 628 LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryEffectsAttr(); 629 // If there is no memory effects attribute, assume that the function is 630 // not read-only. 631 bool isReadOnly = memoryEffects && 632 memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef && 633 memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod; 634 // Check if there's an alignment mismatch requiring us to copy. 635 DataLayout dataLayout = DataLayout::closest(callable); 636 uint64_t minimumAlignment = dataLayout.getTypeABIAlignment(elementType); 637 if (isReadOnly) { 638 if (requestedAlignment <= minimumAlignment) 639 return argument; 640 uint64_t currentAlignment = 641 tryToEnforceAlignment(argument, requestedAlignment, dataLayout); 642 if (currentAlignment >= requestedAlignment) 643 return argument; 644 } 645 uint64_t targetAlignment = std::max(requestedAlignment, minimumAlignment); 646 return handleByValArgumentInit( 647 builder, argument.getLoc(), argument, elementType, 648 dataLayout.getTypeSize(elementType), targetAlignment); 649 } 650 651 namespace { 652 struct LLVMInlinerInterface : public DialectInlinerInterface { 653 using DialectInlinerInterface::DialectInlinerInterface; 654 655 LLVMInlinerInterface(Dialect *dialect) 656 : DialectInlinerInterface(dialect), 657 // Cache set of StringAttrs for fast lookup in `isLegalToInline`. 658 disallowedFunctionAttrs({ 659 StringAttr::get(dialect->getContext(), "noduplicate"), 660 StringAttr::get(dialect->getContext(), "presplitcoroutine"), 661 StringAttr::get(dialect->getContext(), "returns_twice"), 662 StringAttr::get(dialect->getContext(), "strictfp"), 663 }) {} 664 665 bool isLegalToInline(Operation *call, Operation *callable, 666 bool wouldBeCloned) const final { 667 if (!isa<LLVM::CallOp>(call)) { 668 LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '" 669 << LLVM::CallOp::getOperationName() << "' op\n"); 670 return false; 671 } 672 auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable); 673 if (!funcOp) { 674 LLVM_DEBUG(llvm::dbgs() 675 << "Cannot inline: callable is not an '" 676 << LLVM::LLVMFuncOp::getOperationName() << "' op\n"); 677 return false; 678 } 679 if (funcOp.isNoInline()) { 680 LLVM_DEBUG(llvm::dbgs() 681 << "Cannot inline: function is marked no_inline\n"); 682 return false; 683 } 684 if (funcOp.isVarArg()) { 685 LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n"); 686 return false; 687 } 688 // TODO: Generate aliasing metadata from noalias result attributes. 689 if (auto attrs = funcOp.getArgAttrs()) { 690 for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) { 691 if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) { 692 LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName() 693 << ": inalloca arguments not supported\n"); 694 return false; 695 } 696 } 697 } 698 // TODO: Handle exceptions. 699 if (funcOp.getPersonality()) { 700 LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName() 701 << ": unhandled function personality\n"); 702 return false; 703 } 704 if (funcOp.getPassthrough()) { 705 // TODO: Used attributes should not be passthrough. 706 if (llvm::any_of(*funcOp.getPassthrough(), [&](Attribute attr) { 707 auto stringAttr = dyn_cast<StringAttr>(attr); 708 if (!stringAttr) 709 return false; 710 if (disallowedFunctionAttrs.contains(stringAttr)) { 711 LLVM_DEBUG(llvm::dbgs() 712 << "Cannot inline " << funcOp.getSymName() 713 << ": found disallowed function attribute " 714 << stringAttr << "\n"); 715 return true; 716 } 717 return false; 718 })) 719 return false; 720 } 721 return true; 722 } 723 724 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { 725 return true; 726 } 727 728 bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final { 729 // The inliner cannot handle variadic function arguments. 730 return !isa<LLVM::VaStartOp>(op); 731 } 732 733 /// Handle the given inlined return by replacing it with a branch. This 734 /// overload is called when the inlined region has more than one block. 735 void handleTerminator(Operation *op, Block *newDest) const final { 736 // Only return needs to be handled here. 737 auto returnOp = dyn_cast<LLVM::ReturnOp>(op); 738 if (!returnOp) 739 return; 740 741 // Replace the return with a branch to the dest. 742 OpBuilder builder(op); 743 builder.create<LLVM::BrOp>(op->getLoc(), returnOp.getOperands(), newDest); 744 op->erase(); 745 } 746 747 bool allowSingleBlockOptimization( 748 iterator_range<Region::iterator> inlinedBlocks) const final { 749 if (!inlinedBlocks.empty() && 750 isa<LLVM::UnreachableOp>(inlinedBlocks.begin()->getTerminator())) 751 return false; 752 return true; 753 } 754 755 /// Handle the given inlined return by replacing the uses of the call with the 756 /// operands of the return. This overload is called when the inlined region 757 /// only contains one block. 758 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { 759 // Return will be the only terminator present. 760 auto returnOp = cast<LLVM::ReturnOp>(op); 761 762 // Replace the values directly with the return operands. 763 assert(returnOp.getNumOperands() == valuesToRepl.size()); 764 for (auto [dst, src] : llvm::zip(valuesToRepl, returnOp.getOperands())) 765 dst.replaceAllUsesWith(src); 766 } 767 768 Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, 769 Value argument, 770 DictionaryAttr argumentAttrs) const final { 771 if (std::optional<NamedAttribute> attr = 772 argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) { 773 Type elementType = cast<TypeAttr>(attr->getValue()).getValue(); 774 uint64_t requestedAlignment = 1; 775 if (std::optional<NamedAttribute> alignAttr = 776 argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) { 777 requestedAlignment = cast<IntegerAttr>(alignAttr->getValue()) 778 .getValue() 779 .getLimitedValue(); 780 } 781 return handleByValArgument(builder, callable, argument, elementType, 782 requestedAlignment); 783 } 784 785 // This code is essentially a workaround for deficiencies in the inliner 786 // interface: We need to transform operations *after* inlined based on the 787 // argument attributes of the parameters *before* inlining. This method runs 788 // prior to actual inlining and thus cannot transform the post-inlining 789 // code, while `processInlinedCallBlocks` does not have access to 790 // pre-inlining function arguments. Additionally, it is required to 791 // distinguish which parameter an SSA value originally came from. As a 792 // workaround until this is changed: Create an ssa.copy intrinsic with the 793 // noalias attribute (when it was present before) that can easily be found, 794 // and is extremely unlikely to exist in the code prior to inlining, using 795 // this to communicate between this method and `processInlinedCallBlocks`. 796 // TODO: Fix this by refactoring the inliner interface. 797 auto copyOp = builder.create<LLVM::SSACopyOp>(call->getLoc(), argument); 798 if (argumentAttrs.contains(LLVM::LLVMDialect::getNoAliasAttrName())) 799 copyOp->setDiscardableAttr( 800 builder.getStringAttr(LLVM::LLVMDialect::getNoAliasAttrName()), 801 builder.getUnitAttr()); 802 return copyOp; 803 } 804 805 void processInlinedCallBlocks( 806 Operation *call, 807 iterator_range<Region::iterator> inlinedBlocks) const override { 808 handleInlinedAllocas(call, inlinedBlocks); 809 handleAliasScopes(call, inlinedBlocks); 810 handleAccessGroups(call, inlinedBlocks); 811 handleLoopAnnotations(call, inlinedBlocks); 812 } 813 814 // Keeping this (immutable) state on the interface allows us to look up 815 // StringAttrs instead of looking up strings, since StringAttrs are bound to 816 // the current context and thus cannot be initialized as static fields. 817 const DenseSet<StringAttr> disallowedFunctionAttrs; 818 }; 819 820 } // end anonymous namespace 821 822 void mlir::LLVM::registerInlinerInterface(DialectRegistry ®istry) { 823 registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { 824 dialect->addInterfaces<LLVMInlinerInterface>(); 825 }); 826 } 827 828 void mlir::NVVM::registerInlinerInterface(DialectRegistry ®istry) { 829 registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) { 830 dialect->addInterfaces<LLVMInlinerInterface>(); 831 }); 832 } 833