//===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements MemorySlot-related interfaces for LLVM dialect
// operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "sroa"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Interfaces for AllocaOp
//===----------------------------------------------------------------------===//

llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
  // Only allocas in the entry block are reported as promotable slots.
  if (!getOperation()->getBlock()->isEntryBlock())
    return {};

  return {MemorySlot{getResult(), getElemType()}};
}

Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
                                      OpBuilder &builder) {
  // Memory that was never stored to reads as undef.
  return builder.create<LLVM::UndefOp>(getLoc(), slot.elemType);
}

void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                         BlockArgument argument,
                                         OpBuilder &builder) {
  // Preserve debug info across promotion: for every dbg.declare of this
  // alloca, emit a dbg.value tracking the value reaching this block argument.
  for (Operation *user : getOperation()->getUsers())
    if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
      builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
                                       declareOp.getVarInfo(),
                                       declareOp.getLocationExpr());
}

std::optional<PromotableAllocationOpInterface>
LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
                                        Value defaultValue,
                                        OpBuilder &builder) {
  // Drop the materialized default value if promotion ended up not using it.
  if (defaultValue && defaultValue.use_empty())
    defaultValue.getDefiningOp()->erase();
  this->erase();
  // Promotion does not create a replacement allocation.
  return std::nullopt;
}

SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
  // Only single-element allocas can be destructured.
  if (!mlir::matchPattern(getArraySize(), m_One()))
    return {};

  auto destructurable = dyn_cast<DestructurableTypeInterface>(getElemType());
  if (!destructurable)
    return {};

  std::optional<DenseMap<Attribute, Type>> destructuredType =
      destructurable.getSubelementIndexMap();
  if (!destructuredType)
    return {};

  return {DestructurableMemorySlot{{getResult(), getElemType()},
                                   *destructuredType}};
}

DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure(
    const DestructurableMemorySlot &slot,
    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
  assert(slot.ptr == getResult());
  builder.setInsertionPointAfter(*this);

  // Create one sub-alloca per subelement of the aggregate that is actually
  // used; unused subelements simply disappear.
  auto destructurableType = cast<DestructurableTypeInterface>(getElemType());
  DenseMap<Attribute, MemorySlot> slotMap;
  for (Attribute index : usedIndices) {
    Type elemType = destructurableType.getTypeAtIndex(index);
    assert(elemType && "used index must exist");
    auto subAlloca = builder.create<LLVM::AllocaOp>(
        getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
        getArraySize());
    newAllocators.push_back(subAlloca);
    slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
  }

  return slotMap;
}

std::optional<DestructurableAllocationOpInterface>
LLVM::AllocaOp::handleDestructuringComplete(
    const DestructurableMemorySlot &slot, OpBuilder &builder) {
  assert(slot.ptr == getResult());
  this->erase();
  // No new allocator needs further processing; the sub-allocas were already
  // reported through `destructure`.
  return std::nullopt;
}
//===----------------------------------------------------------------------===// 111 // Interfaces for LoadOp/StoreOp 112 //===----------------------------------------------------------------------===// 113 114 bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) { 115 return getAddr() == slot.ptr; 116 } 117 118 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; } 119 120 Value LLVM::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder, 121 Value reachingDef, const DataLayout &dataLayout) { 122 llvm_unreachable("getStored should not be called on LoadOp"); 123 } 124 125 bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; } 126 127 bool LLVM::StoreOp::storesTo(const MemorySlot &slot) { 128 return getAddr() == slot.ptr; 129 } 130 131 /// Checks if `type` can be used in any kind of conversion sequences. 132 static bool isSupportedTypeForConversion(Type type) { 133 // Aggregate types are not bitcastable. 134 if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type)) 135 return false; 136 137 // LLVM vector types are only used for either pointers or target specific 138 // types. These types cannot be casted in the general case, thus the memory 139 // optimizations do not support them. 140 if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type)) 141 return false; 142 143 // Scalable types are not supported. 144 if (auto vectorType = dyn_cast<VectorType>(type)) 145 return !vectorType.isScalable(); 146 return true; 147 } 148 149 /// Checks that `rhs` can be converted to `lhs` by a sequence of casts and 150 /// truncations. Checks for narrowing or widening conversion compatibility 151 /// depending on `narrowingConversion`. 
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
                                    Type srcType, bool narrowingConversion) {
  if (targetType == srcType)
    return true;

  if (!isSupportedTypeForConversion(targetType) ||
      !isSupportedTypeForConversion(srcType))
    return false;

  uint64_t targetSize = layout.getTypeSize(targetType);
  uint64_t srcSize = layout.getTypeSize(srcType);

  // Pointer casts will only be sane when the bitsize of both pointer types is
  // the same.
  if (isa<LLVM::LLVMPointerType>(targetType) &&
      isa<LLVM::LLVMPointerType>(srcType))
    return targetSize == srcSize;

  // A narrowing conversion (truncation) must not grow the value, a widening
  // conversion (extension) must not shrink it.
  if (narrowingConversion)
    return targetSize <= srcSize;
  return targetSize >= srcSize;
}

/// Checks if `dataLayout` describes a big endian layout.
static bool isBigEndian(const DataLayout &dataLayout) {
  auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
  return endiannessStr && endiannessStr == "big";
}

/// Converts a value to an integer type of the same size.
/// Assumes that the type can be converted.
static Value castToSameSizedInt(OpBuilder &builder, Location loc, Value val,
                                const DataLayout &dataLayout) {
  Type type = val.getType();
  assert(isSupportedTypeForConversion(type) &&
         "expected value to have a convertible type");

  // Already an integer, nothing to do.
  if (isa<IntegerType>(type))
    return val;

  uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
  IntegerType valueSizeInteger = builder.getIntegerType(typeBitSize);

  // Pointers need ptrtoint; all other supported types are bitcastable.
  if (isa<LLVM::LLVMPointerType>(type))
    return builder.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
  return builder.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
}

/// Converts a value with an integer type to `targetType`.
201 static Value castIntValueToSameSizedType(OpBuilder &builder, Location loc, 202 Value val, Type targetType) { 203 assert(isa<IntegerType>(val.getType()) && 204 "expected value to have an integer type"); 205 assert(isSupportedTypeForConversion(targetType) && 206 "expected the target type to be supported for conversions"); 207 if (val.getType() == targetType) 208 return val; 209 if (isa<LLVM::LLVMPointerType>(targetType)) 210 return builder.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val); 211 return builder.createOrFold<LLVM::BitcastOp>(loc, targetType, val); 212 } 213 214 /// Constructs operations that convert `srcValue` into a new value of type 215 /// `targetType`. Assumes the types have the same bitsize. 216 static Value castSameSizedTypes(OpBuilder &builder, Location loc, 217 Value srcValue, Type targetType, 218 const DataLayout &dataLayout) { 219 Type srcType = srcValue.getType(); 220 assert(areConversionCompatible(dataLayout, targetType, srcType, 221 /*narrowingConversion=*/true) && 222 "expected that the compatibility was checked before"); 223 224 // Nothing has to be done if the types are already the same. 225 if (srcType == targetType) 226 return srcValue; 227 228 // In the special case of casting one pointer to another, we want to generate 229 // an address space cast. Bitcasts of pointers are not allowed and using 230 // pointer to integer conversions are not equivalent due to the loss of 231 // provenance. 232 if (isa<LLVM::LLVMPointerType>(targetType) && 233 isa<LLVM::LLVMPointerType>(srcType)) 234 return builder.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType, 235 srcValue); 236 237 // For all other castable types, casting through integers is necessary. 238 Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout); 239 return castIntValueToSameSizedType(builder, loc, replacement, targetType); 240 } 241 242 /// Constructs operations that convert `srcValue` into a new value of type 243 /// `targetType`. 
Performs bit-level extraction if the source type is larger 244 /// than the target type. Assumes that this conversion is possible. 245 static Value createExtractAndCast(OpBuilder &builder, Location loc, 246 Value srcValue, Type targetType, 247 const DataLayout &dataLayout) { 248 // Get the types of the source and target values. 249 Type srcType = srcValue.getType(); 250 assert(areConversionCompatible(dataLayout, targetType, srcType, 251 /*narrowingConversion=*/true) && 252 "expected that the compatibility was checked before"); 253 254 uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType); 255 uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType); 256 if (srcTypeSize == targetTypeSize) 257 return castSameSizedTypes(builder, loc, srcValue, targetType, dataLayout); 258 259 // First, cast the value to a same-sized integer type. 260 Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout); 261 262 // Truncate the integer if the size of the target is less than the value. 263 if (isBigEndian(dataLayout)) { 264 uint64_t shiftAmount = srcTypeSize - targetTypeSize; 265 auto shiftConstant = builder.create<LLVM::ConstantOp>( 266 loc, builder.getIntegerAttr(srcType, shiftAmount)); 267 replacement = 268 builder.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant); 269 } 270 271 replacement = builder.create<LLVM::TruncOp>( 272 loc, builder.getIntegerType(targetTypeSize), replacement); 273 274 // Now cast the integer to the actual target type if required. 275 return castIntValueToSameSizedType(builder, loc, replacement, targetType); 276 } 277 278 /// Constructs operations that insert the bits of `srcValue` into the 279 /// "beginning" of `reachingDef` (beginning is endianness dependent). 280 /// Assumes that this conversion is possible. 
static Value createInsertAndCast(OpBuilder &builder, Location loc,
                                 Value srcValue, Value reachingDef,
                                 const DataLayout &dataLayout) {

  assert(areConversionCompatible(dataLayout, reachingDef.getType(),
                                 srcValue.getType(),
                                 /*narrowingConversion=*/false) &&
         "expected that the compatibility was checked before");
  uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
  uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
  // Same size: the store fully overwrites the slot, a plain cast suffices.
  if (slotTypeSize == valueTypeSize)
    return castSameSizedTypes(builder, loc, srcValue, reachingDef.getType(),
                              dataLayout);

  // In the case where the store only overwrites parts of the memory,
  // bit fiddling is required to construct the new value.

  // First convert both values to integers of the same size.
  Value defAsInt = castToSameSizedInt(builder, loc, reachingDef, dataLayout);
  Value valueAsInt = castToSameSizedInt(builder, loc, srcValue, dataLayout);
  // Extend the value to the size of the reaching definition.
  valueAsInt =
      builder.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
  uint64_t sizeDifference = slotTypeSize - valueTypeSize;
  if (isBigEndian(dataLayout)) {
    // On big endian systems, a store to the base pointer overwrites the most
    // significant bits. To accommodate for this, the stored value needs to be
    // shifted into the according position.
    Value bigEndianShift = builder.create<LLVM::ConstantOp>(
        loc, builder.getIntegerAttr(defAsInt.getType(), sizeDifference));
    valueAsInt =
        builder.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
  }

  // Construct the mask that is used to erase the bits that are overwritten by
  // the store.
  APInt maskValue;
  if (isBigEndian(dataLayout)) {
    // Build a mask that has the most significant bits set to zero.
    // Note: This is the same as 2^sizeDifference - 1
    maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
  } else {
    // Build a mask that has the least significant bits set to zero.
    // Note: This is the same as -(2^valueTypeSize)
    maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
    maskValue.flipAllBits();
  }

  // Mask out the affected bits ...
  Value mask = builder.create<LLVM::ConstantOp>(
      loc, builder.getIntegerAttr(defAsInt.getType(), maskValue));
  Value masked = builder.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);

  // ... and combine the result with the new value.
  Value combined = builder.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);

  // Finally cast the combined integer back to the slot's type.
  return castIntValueToSameSizedType(builder, loc, combined,
                                     reachingDef.getType());
}

Value LLVM::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                               Value reachingDef,
                               const DataLayout &dataLayout) {
  assert(reachingDef && reachingDef.getType() == slot.elemType &&
         "expected the reaching definition's type to match the slot's type");
  // The value the slot holds after this store is the stored value inserted
  // into the previous reaching definition.
  return createInsertAndCast(builder, getLoc(), getValue(), reachingDef,
                             dataLayout);
}

bool LLVM::LoadOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  if (blockingUses.size() != 1)
    return false;
  Value blockingUse = (*blockingUses.begin())->get();
  // If the blocking use is the slot ptr itself, there will be enough
  // context to reconstruct the result of the load at removal time, so it can
  // be removed (provided it is not volatile).
  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
         areConversionCompatible(dataLayout, getResult().getType(),
                                 slot.elemType, /*narrowingConversion=*/true) &&
         !getVolatile_();
}

DeletionKind LLVM::LoadOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
  // pointer.
  Value newResult = createExtractAndCast(builder, getLoc(), reachingDefinition,
                                         getResult().getType(), dataLayout);
  getResult().replaceAllUsesWith(newResult);
  return DeletionKind::Delete;
}

bool LLVM::StoreOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  if (blockingUses.size() != 1)
    return false;
  Value blockingUse = (*blockingUses.begin())->get();
  // If the blocking use is the slot ptr itself, dropping the store is
  // fine, provided we are currently promoting its target value. Don't allow a
  // store OF the slot pointer, only INTO the slot pointer.
  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
         getValue() != slot.ptr &&
         areConversionCompatible(dataLayout, slot.elemType,
                                 getValue().getType(),
                                 /*narrowingConversion=*/false) &&
         !getVolatile_();
}

DeletionKind LLVM::StoreOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  // The stored value was already forwarded via `getStored`; the store itself
  // can simply be deleted.
  return DeletionKind::Delete;
}

/// Checks if `slot` can be accessed through the provided access type.
static bool isValidAccessType(const MemorySlot &slot, Type accessType,
                              const DataLayout &dataLayout) {
  // An access is safe as long as it does not read or write past the slot.
  return dataLayout.getTypeSize(accessType) <=
         dataLayout.getTypeSize(slot.elemType);
}

LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  // Loads not touching the slot are irrelevant; loads from the slot must stay
  // in bounds.
  return success(getAddr() != slot.ptr ||
                 isValidAccessType(slot, getType(), dataLayout));
}

LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  // Stores not touching the slot are irrelevant; stores into the slot must
  // stay in bounds.
  return success(getAddr() != slot.ptr ||
                 isValidAccessType(slot, getValue().getType(), dataLayout));
}

/// Returns the subslot's type at the requested index.
static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
                           Attribute index) {
  auto subelementIndexMap =
      cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
  if (!subelementIndexMap)
    return {};
  assert(!subelementIndexMap->empty());

  // Note: Returns a null-type when no entry was found.
  return subelementIndexMap->lookup(index);
}

bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
                             SmallPtrSetImpl<Attribute> &usedIndices,
                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                             const DataLayout &dataLayout) {
  if (getVolatile_())
    return false;

  // A load always accesses the first element of the destructured slot.
  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
  Type subslotType = getTypeAtIndex(slot, index);
  if (!subslotType)
    return false;

  // The access can only be replaced when the subslot is read within its bounds.
  if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
    return false;

  usedIndices.insert(index);
  return true;
}

DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
                                  DenseMap<Attribute, MemorySlot> &subslots,
                                  OpBuilder &builder,
                                  const DataLayout &dataLayout) {
  // Redirect the load to the sub-alloca of the first element; the load itself
  // is kept.
  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
  auto it = subslots.find(index);
  assert(it != subslots.end());

  getAddrMutable().set(it->getSecond().ptr);
  return DeletionKind::Keep;
}

bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
                              SmallPtrSetImpl<Attribute> &usedIndices,
                              SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                              const DataLayout &dataLayout) {
  if (getVolatile_())
    return false;

  // Storing the pointer to memory cannot be dealt with.
  if (getValue() == slot.ptr)
    return false;

  // A store always accesses the first element of the destructured slot.
  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
  Type subslotType = getTypeAtIndex(slot, index);
  if (!subslotType)
    return false;

  // The access can only be replaced when the subslot is read within its bounds.
  if (dataLayout.getTypeSize(getValue().getType()) >
      dataLayout.getTypeSize(subslotType))
    return false;

  usedIndices.insert(index);
  return true;
}

DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
                                   DenseMap<Attribute, MemorySlot> &subslots,
                                   OpBuilder &builder,
                                   const DataLayout &dataLayout) {
  // Redirect the store to the sub-alloca of the first element; the store
  // itself is kept.
  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
  auto it = subslots.find(index);
  assert(it != subslots.end());

  getAddrMutable().set(it->getSecond().ptr);
  return DeletionKind::Keep;
}

//===----------------------------------------------------------------------===//
// Interfaces for discardable OPs
//===----------------------------------------------------------------------===//

/// Conditions the deletion of the operation to the removal of all its uses.
static bool forwardToUsers(Operation *op,
                           SmallVectorImpl<OpOperand *> &newBlockingUses) {
  for (Value result : op->getResults())
    for (OpOperand &use : result.getUses())
      newBlockingUses.push_back(&use);
  return true;
}

bool LLVM::BitcastOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return forwardToUsers(*this, newBlockingUses);
}

DeletionKind LLVM::BitcastOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return forwardToUsers(*this, newBlockingUses);
}

DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::LifetimeStartOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  // Lifetime markers have no result; they can always be dropped.
  return true;
}

DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::LifetimeEndOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  // Lifetime markers have no result; they can always be dropped.
  return true;
}

DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::InvariantStartOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  // Invariant markers can always be dropped.
  return true;
}

DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::InvariantEndOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  // Invariant markers can always be dropped.
  return true;
}

DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::LaunderInvariantGroupOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  // Removable only if all of its users are removable in turn.
  return forwardToUsers(*this, newBlockingUses);
}

DeletionKind LLVM::LaunderInvariantGroupOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::StripInvariantGroupOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  // Removable only if all of its users are removable in turn.
  return forwardToUsers(*this, newBlockingUses);
}

DeletionKind LLVM::StripInvariantGroupOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::DbgDeclareOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  // Debug declarations never block promotion.
  return true;
}

DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::DbgValueOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  // There is only one operand that we can remove the use of.
  if (blockingUses.size() != 1)
    return false;

  return (*blockingUses.begin())->get() == getValue();
}

DeletionKind LLVM::DbgValueOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  // builder by default is after '*this', but we need it before '*this'.
  builder.setInsertionPoint(*this);

  // Rather than dropping the debug value, replace it with undef to preserve the
  // debug local variable info. This allows the debugger to inform the user that
  // the variable has been optimized out.
  auto undef =
      builder.create<UndefOp>(getValue().getLoc(), getValue().getType());
  getValueMutable().assign(undef);
  return DeletionKind::Keep;
}

bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }

void LLVM::DbgDeclareOp::visitReplacedValues(
    ArrayRef<std::pair<Operation *, Value>> definitions, OpBuilder &builder) {
  // Track each promoted definition with a dbg.value placed right after it.
  for (auto [op, value] : definitions) {
    builder.setInsertionPointAfter(op);
    builder.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
                                     getLocationExpr());
  }
}

//===----------------------------------------------------------------------===//
// Interfaces for GEPOp
//===----------------------------------------------------------------------===//

/// Returns true if every GEP index is a constant zero.
static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
  return llvm::all_of(gepOp.getIndices(), [](auto index) {
    auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
    return indexAttr && indexAttr.getValue() == 0;
  });
}

bool LLVM::GEPOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  // GEP can be removed as long as it is a no-op and its users can be removed.
  if (!hasAllZeroIndices(*this))
    return false;
  return forwardToUsers(*this, newBlockingUses);
}

DeletionKind LLVM::GEPOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

/// Returns the amount of bytes the provided GEP elements will offset the
/// pointer by. Returns nullopt if no constant offset could be computed.
static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
                                               LLVM::GEPOp gep) {
  // Collects all indices.
  SmallVector<uint64_t> indices;
  for (auto index : gep.getIndices()) {
    auto constIndex = dyn_cast<IntegerAttr>(index);
    if (!constIndex)
      return {};
    int64_t gepIndex = constIndex.getInt();
    // Negative indices are not supported.
    if (gepIndex < 0)
      return {};
    indices.push_back(gepIndex);
  }

  // The first index scales by the size of the GEP's element type.
  Type currentType = gep.getElemType();
  uint64_t offset = indices[0] * dataLayout.getTypeSize(currentType);

  for (uint64_t index : llvm::drop_begin(indices)) {
    bool shouldCancel =
        TypeSwitch<Type, bool>(currentType)
            .Case([&](LLVM::LLVMArrayType arrayType) {
              offset +=
                  index * dataLayout.getTypeSize(arrayType.getElementType());
              currentType = arrayType.getElementType();
              return false;
            })
            .Case([&](LLVM::LLVMStructType structType) {
              ArrayRef<Type> body = structType.getBody();
              assert(index < body.size() && "expected valid struct indexing");
              // Sum the sizes (with ABI alignment padding for non-packed
              // structs) of all fields preceding the indexed one.
              for (uint32_t i : llvm::seq(index)) {
                if (!structType.isPacked())
                  offset = llvm::alignTo(
                      offset, dataLayout.getTypeABIAlignment(body[i]));
                offset += dataLayout.getTypeSize(body[i]);
              }

              // Align for the current type as well.
              if (!structType.isPacked())
                offset = llvm::alignTo(
                    offset, dataLayout.getTypeABIAlignment(body[index]));
              currentType = body[index];
              return false;
            })
            .Default([&](Type type) {
              LLVM_DEBUG(llvm::dbgs()
                         << "[sroa] Unsupported type for offset computations"
                         << type << "\n");
              return true;
            });

    if (shouldCancel)
      return std::nullopt;
  }

  return offset;
}

namespace {
/// A struct that stores both the index into the aggregate type of the slot as
/// well as the corresponding byte offset in memory.
struct SubslotAccessInfo {
  /// The parent slot's index that the access falls into.
  uint32_t index;
  /// The offset into the subslot of the access.
  uint64_t subslotOffset;
};
} // namespace

/// Computes subslot access information for an access into `slot` with the given
/// offset.
/// Returns nullopt when the offset is out-of-bounds or when the access is into
/// the padding of `slot`.
static std::optional<SubslotAccessInfo>
getSubslotAccessInfo(const DestructurableMemorySlot &slot,
                     const DataLayout &dataLayout, LLVM::GEPOp gep) {
  std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep);
  if (!offset)
    return {};

  // Helper to check that a constant index is in the bounds of the GEP index
  // representation. LLVM dialects's GEP arguments have a limited bitwidth, thus
  // this additional check is necessary.
  auto isOutOfBoundsGEPIndex = [](uint64_t index) {
    return index >= (1 << LLVM::kGEPConstantBitWidth);
  };

  Type type = slot.elemType;
  if (*offset >= dataLayout.getTypeSize(type))
    return {};
  return TypeSwitch<Type, std::optional<SubslotAccessInfo>>(type)
      .Case([&](LLVM::LLVMArrayType arrayType)
                -> std::optional<SubslotAccessInfo> {
        // Find which element of the array contains the offset.
        uint64_t elemSize = dataLayout.getTypeSize(arrayType.getElementType());
        uint64_t index = *offset / elemSize;
        if (isOutOfBoundsGEPIndex(index))
          return {};
        return SubslotAccessInfo{static_cast<uint32_t>(index),
                                 *offset - (index * elemSize)};
      })
      .Case([&](LLVM::LLVMStructType structType)
                -> std::optional<SubslotAccessInfo> {
        uint64_t distanceToStart = 0;
        // Walk over the elements of the struct to find in which of
        // them the offset is.
        for (auto [index, elem] : llvm::enumerate(structType.getBody())) {
          uint64_t elemSize = dataLayout.getTypeSize(elem);
          if (!structType.isPacked()) {
            distanceToStart = llvm::alignTo(
                distanceToStart, dataLayout.getTypeABIAlignment(elem));
            // If the offset is in padding, cancel the rewrite.
            if (offset < distanceToStart)
              return {};
          }

          if (offset < distanceToStart + elemSize) {
            if (isOutOfBoundsGEPIndex(index))
              return {};
            // The offset is within this element, stop iterating the
            // struct and return the index.
            return SubslotAccessInfo{static_cast<uint32_t>(index),
                                     *offset - distanceToStart};
          }

          // The offset is not within this element, continue walking
          // over the struct.
          distanceToStart += elemSize;
        }

        return {};
      });
}

/// Constructs a byte array type of the given size.
static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context,
                                            unsigned size) {
  auto byteType = IntegerType::get(context, 8);
  return LLVM::LLVMArrayType::get(context, byteType, size);
}

LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  if (getBase() != slot.ptr)
    return success();
  std::optional<uint64_t> gepOffset = gepToByteOffset(dataLayout, *this);
  if (!gepOffset)
    return failure();
  uint64_t slotSize = dataLayout.getTypeSize(slot.elemType);
  // Check that the access is strictly inside the slot.
  if (*gepOffset >= slotSize)
    return failure();
  // Every access that remains in bounds of the remaining slot is considered
  // legal.
  mustBeSafelyUsed.emplace_back<MemorySlot>(
      {getRes(), getByteArrayType(getContext(), slotSize - *gepOffset)});
  return success();
}

bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
                            SmallPtrSetImpl<Attribute> &usedIndices,
                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                            const DataLayout &dataLayout) {
  if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
    return false;

  if (getBase() != slot.ptr)
    return false;
  std::optional<SubslotAccessInfo> accessInfo =
      getSubslotAccessInfo(slot, dataLayout, *this);
  if (!accessInfo)
    return false;
  auto indexAttr =
      IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
  assert(slot.subelementTypes.contains(indexAttr));
  usedIndices.insert(indexAttr);

  // The remainder of the subslot should be accesses in-bounds. Thus, we create
  // a dummy slot with the size of the remainder.
  Type subslotType = slot.subelementTypes.lookup(indexAttr);
  uint64_t slotSize = dataLayout.getTypeSize(subslotType);
  LLVM::LLVMArrayType remainingSlotType =
      getByteArrayType(getContext(), slotSize - accessInfo->subslotOffset);
  mustBeSafelyUsed.emplace_back<MemorySlot>({getRes(), remainingSlotType});

  return true;
}

DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
                                 DenseMap<Attribute, MemorySlot> &subslots,
                                 OpBuilder &builder,
                                 const DataLayout &dataLayout) {
  std::optional<SubslotAccessInfo> accessInfo =
      getSubslotAccessInfo(slot, dataLayout, *this);
  assert(accessInfo && "expected access info to be checked before");
  auto indexAttr =
      IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
  const MemorySlot &newSlot = subslots.at(indexAttr);

  // Replace this GEP with a byte-offset GEP based on the sub-slot's pointer.
  auto byteType = IntegerType::get(builder.getContext(), 8);
  auto newPtr = builder.createOrFold<LLVM::GEPOp>(
      getLoc(), getResult().getType(), byteType, newSlot.ptr,
      ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
  getResult().replaceAllUsesWith(newPtr);
  return DeletionKind::Delete;
}

//===----------------------------------------------------------------------===//
// Utilities for memory intrinsics
//===----------------------------------------------------------------------===//

namespace {

/// Returns the length of the given memory intrinsic in bytes if it can be known
/// at compile-time on a best-effort basis, nothing otherwise.
template <class MemIntr>
std::optional<uint64_t> getStaticMemIntrLen(MemIntr op) {
  APInt memIntrLen;
  if (!matchPattern(op.getLen(), m_ConstantInt(&memIntrLen)))
    return {};
  if (memIntrLen.getBitWidth() > 64)
    return {};
  return memIntrLen.getZExtValue();
}

/// Returns the length of the given memory intrinsic in bytes if it can be known
/// at compile-time on a best-effort basis, nothing otherwise.
/// Because MemcpyInlineOp has its length encoded as an attribute, this requires
/// specialized handling.
template <>
std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
  APInt memIntrLen = op.getLen();
  if (memIntrLen.getBitWidth() > 64)
    return {};
  return memIntrLen.getZExtValue();
}

/// Returns the length of the given memory intrinsic in bytes if it can be known
/// at compile-time on a best-effort basis, nothing otherwise.
/// Because MemsetInlineOp has its length encoded as an attribute, this requires
/// specialized handling.
935 template <> 936 std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemsetInlineOp op) { 937 APInt memIntrLen = op.getLen(); 938 if (memIntrLen.getBitWidth() > 64) 939 return {}; 940 return memIntrLen.getZExtValue(); 941 } 942 943 /// Returns an integer attribute representing the length of a memset intrinsic 944 template <class MemsetIntr> 945 IntegerAttr createMemsetLenAttr(MemsetIntr op) { 946 IntegerAttr memsetLenAttr; 947 bool successfulMatch = 948 matchPattern(op.getLen(), m_Constant<IntegerAttr>(&memsetLenAttr)); 949 (void)successfulMatch; 950 assert(successfulMatch); 951 return memsetLenAttr; 952 } 953 954 /// Returns an integer attribute representing the length of a memset intrinsic 955 /// Because MemsetInlineOp has its length encoded as an attribute, this requires 956 /// specialized handling. 957 template <> 958 IntegerAttr createMemsetLenAttr(LLVM::MemsetInlineOp op) { 959 return op.getLenAttr(); 960 } 961 962 /// Creates a memset intrinsic of that matches the `toReplace` intrinsic 963 /// using the provided parameters. There are template specializations for 964 /// MemsetOp and MemsetInlineOp. 
965 template <class MemsetIntr> 966 void createMemsetIntr(OpBuilder &builder, MemsetIntr toReplace, 967 IntegerAttr memsetLenAttr, uint64_t newMemsetSize, 968 DenseMap<Attribute, MemorySlot> &subslots, 969 Attribute index); 970 971 template <> 972 void createMemsetIntr(OpBuilder &builder, LLVM::MemsetOp toReplace, 973 IntegerAttr memsetLenAttr, uint64_t newMemsetSize, 974 DenseMap<Attribute, MemorySlot> &subslots, 975 Attribute index) { 976 Value newMemsetSizeValue = 977 builder 978 .create<LLVM::ConstantOp>( 979 toReplace.getLen().getLoc(), 980 IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize)) 981 .getResult(); 982 983 builder.create<LLVM::MemsetOp>(toReplace.getLoc(), subslots.at(index).ptr, 984 toReplace.getVal(), newMemsetSizeValue, 985 toReplace.getIsVolatile()); 986 } 987 988 template <> 989 void createMemsetIntr(OpBuilder &builder, LLVM::MemsetInlineOp toReplace, 990 IntegerAttr memsetLenAttr, uint64_t newMemsetSize, 991 DenseMap<Attribute, MemorySlot> &subslots, 992 Attribute index) { 993 auto newMemsetSizeValue = 994 IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize); 995 996 builder.create<LLVM::MemsetInlineOp>( 997 toReplace.getLoc(), subslots.at(index).ptr, toReplace.getVal(), 998 newMemsetSizeValue, toReplace.getIsVolatile()); 999 } 1000 1001 } // namespace 1002 1003 /// Returns whether one can be sure the memory intrinsic does not write outside 1004 /// of the bounds of the given slot, on a best-effort basis. 1005 template <class MemIntr> 1006 static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot, 1007 const DataLayout &dataLayout) { 1008 if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) || 1009 op.getDst() != slot.ptr) 1010 return false; 1011 1012 std::optional<uint64_t> memIntrLen = getStaticMemIntrLen(op); 1013 return memIntrLen && *memIntrLen <= dataLayout.getTypeSize(slot.elemType); 1014 } 1015 1016 /// Checks whether all indices are i32. This is used to check GEPs can index 1017 /// into them. 
static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
  Type i32 = IntegerType::get(slot.ptr.getContext(), 32);
  return llvm::all_of(llvm::make_first_range(slot.subelementTypes),
                      [&](Attribute index) {
                        auto intIndex = dyn_cast<IntegerAttr>(index);
                        return intIndex && intIndex.getType() == i32;
                      });
}

//===----------------------------------------------------------------------===//
// Interfaces for memset and memset.inline
//===----------------------------------------------------------------------===//

/// Returns whether the given memset-like intrinsic can be rewired onto the
/// subslots of `slot`: non-volatile, same dialect as the slot's element type,
/// destructurable element type with i32 indices, and a statically in-bounds
/// write.
template <class MemsetIntr>
static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
                            SmallPtrSetImpl<Attribute> &usedIndices,
                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                            const DataLayout &dataLayout) {
  if (&slot.elemType.getDialect() != op.getOperation()->getDialect())
    return false;

  if (op.getIsVolatile())
    return false;

  if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
    return false;

  if (!areAllIndicesI32(slot))
    return false;

  return definitelyWritesOnlyWithinSlot(op, slot, dataLayout);
}

/// Materializes the value the memset stores into `slot` as a single integer
/// of the slot's element type, by zero-extending the byte value and repeatedly
/// shift-or-ing it with itself.
template <class MemsetIntr>
static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
                             OpBuilder &builder) {
  // TODO: Support non-integer types.
  return TypeSwitch<Type, Value>(slot.elemType)
      .Case([&](IntegerType intType) -> Value {
        // A single byte is exactly the memset value itself.
        if (intType.getWidth() == 8)
          return op.getVal();

        assert(intType.getWidth() % 8 == 0);

        // Build the memset integer by repeatedly shifting the value and
        // or-ing it with the previous value. Each iteration doubles the
        // number of covered bits.
        uint64_t coveredBits = 8;
        Value currentValue =
            builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
        while (coveredBits < intType.getWidth()) {
          Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
                                                           coveredBits);
          Value shifted =
              builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
          currentValue =
              builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
          coveredBits *= 2;
        }

        return currentValue;
      })
      .Default([](Type) -> Value {
        llvm_unreachable(
            "getStored should not be called on memset to unsupported type");
      });
}

/// Returns whether all blocking uses of the memset-like intrinsic can be
/// removed: the slot must be a byte-multiple integer and the memset must be a
/// non-volatile write of exactly the slot's size.
template <class MemsetIntr>
static bool
memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
                       const SmallPtrSetImpl<OpOperand *> &blockingUses,
                       SmallVectorImpl<OpOperand *> &newBlockingUses,
                       const DataLayout &dataLayout) {
  // TODO: Support non-integer types.
  bool canConvertType =
      TypeSwitch<Type, bool>(slot.elemType)
          .Case([](IntegerType intType) {
            return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
          })
          .Default([](Type) { return false; });
  if (!canConvertType)
    return false;

  if (op.getIsVolatile())
    return false;

  return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
}

/// Splits the memset over the used subslots: every used subelement covered by
/// the original memset receives its own (possibly truncated) memset; the
/// original intrinsic is then deleted.
template <class MemsetIntr>
static DeletionKind
memsetRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
             DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
             const DataLayout &dataLayout) {

  std::optional<DenseMap<Attribute, Type>> types =
      cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();

  IntegerAttr memsetLenAttr = createMemsetLenAttr(op);

  // Packed structs have no inter-element padding, which changes the offset
  // computation below.
  bool packed = false;
  if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
    packed = structType.isPacked();

  Type i32 = IntegerType::get(op.getContext(), 32);
  uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
  // `covered` tracks the byte offset of the current subelement within the
  // original slot.
  uint64_t covered = 0;
  for (size_t i = 0; i < types->size(); i++) {
    // Create indices on the fly to get elements in the right order.
    Attribute index = IntegerAttr::get(i32, i);
    Type elemType = types->at(index);
    uint64_t typeSize = dataLayout.getTypeSize(elemType);

    if (!packed)
      covered =
          llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));

    // Subelements past the end of the memset are untouched.
    if (covered >= memsetLen)
      break;

    // If this subslot is used, apply a new memset to it.
    // Otherwise, only compute its offset within the original memset.
    if (subslots.contains(index)) {
      uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
      createMemsetIntr(builder, op, memsetLenAttr, newMemsetSize, subslots,
                       index);
    }

    covered += typeSize;
  }

  return DeletionKind::Delete;
}

// The MemsetOp interface methods below delegate to the templated helpers
// shared with MemsetInlineOp.

bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }

bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
  return getDst() == slot.ptr;
}

Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  return memsetGetStored(*this, slot, builder);
}

bool LLVM::MemsetOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}

DeletionKind LLVM::MemsetOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  // A memset that only writes to the slot can simply be deleted; its effect
  // is captured by `getStored`.
  return DeletionKind::Delete;
}

LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
}

bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}

DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    OpBuilder &builder,
                                    const DataLayout &dataLayout) {
  return memsetRewire(*this, slot, subslots, builder, dataLayout);
}

// The MemsetInlineOp interface methods mirror the MemsetOp ones above.

bool LLVM::MemsetInlineOp::loadsFrom(const MemorySlot &slot) { return false; }

bool LLVM::MemsetInlineOp::storesTo(const MemorySlot &slot) {
  return getDst() == slot.ptr;
}

Value LLVM::MemsetInlineOp::getStored(const MemorySlot &slot,
                                      OpBuilder &builder, Value reachingDef,
                                      const DataLayout &dataLayout) {
  return memsetGetStored(*this, slot, builder);
}

bool LLVM::MemsetInlineOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}

DeletionKind LLVM::MemsetInlineOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}

LogicalResult LLVM::MemsetInlineOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
}

bool LLVM::MemsetInlineOp::canRewire(
    const DestructurableMemorySlot &slot,
    SmallPtrSetImpl<Attribute> &usedIndices,
    SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}

DeletionKind
LLVM::MemsetInlineOp::rewire(const DestructurableMemorySlot &slot,
                             DenseMap<Attribute, MemorySlot> &subslots,
                             OpBuilder &builder, const DataLayout &dataLayout) {
  return memsetRewire(*this, slot, subslots, builder, dataLayout);
}

//===----------------------------------------------------------------------===//
// Interfaces for memcpy/memmove
//===----------------------------------------------------------------------===//

/// A memcpy-like op loads from the slot iff its source is the slot pointer.
template <class MemcpyLike>
static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot) {
  return op.getSrc() == slot.ptr;
}

/// A memcpy-like op stores to the slot iff its destination is the slot pointer.
template <class MemcpyLike>
static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) {
  return op.getDst() == slot.ptr;
}

/// The value a full-slot memcpy stores is the value loaded from its source.
template <class MemcpyLike>
static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot,
                             OpBuilder &builder) {
  return builder.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
}

/// Blocking uses of a memcpy-like op can only be removed for non-volatile,
/// full-slot copies with distinct source and destination.
template <class MemcpyLike>
static bool
memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
                       const SmallPtrSetImpl<OpOperand *> &blockingUses,
                       SmallVectorImpl<OpOperand *> &newBlockingUses,
                       const DataLayout &dataLayout) {
  // If source and destination are the same, memcpy behavior is undefined and
  // memmove is a no-op. Because there is no memory change happening here,
  // simplifying such operations is left to canonicalization.
  if (op.getDst() == op.getSrc())
    return false;

  if (op.getIsVolatile())
    return false;

  return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
}

/// Deletes the memcpy-like op; if it read from the slot, the reaching
/// definition is first stored out to its destination to preserve the copy's
/// effect.
template <class MemcpyLike>
static DeletionKind
memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot,
                         const SmallPtrSetImpl<OpOperand *> &blockingUses,
                         OpBuilder &builder, Value reachingDefinition) {
  if (op.loadsFrom(slot))
    builder.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition, op.getDst());
  return DeletionKind::Delete;
}

template <class MemcpyLike>
static LogicalResult
memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
  // NOTE(review): reconstructs the closest DataLayout instead of receiving
  // one, unlike the other helpers — confirm whether callers could pass their
  // `dataLayout` parameter through instead.
  DataLayout dataLayout = DataLayout::closest(op);
  // While rewiring memcpy-like intrinsics only supports full copies, partial
  // copies are still safe accesses so it is enough to only check for writes
  // within bounds.
  return success(definitelyWritesOnlyWithinSlot(op, slot, dataLayout));
}

/// Returns whether the memcpy-like op can be rewired onto the subslots of
/// `slot`. When the op reads the slot, every subelement index is marked used
/// since a full copy reads all of them.
template <class MemcpyLike>
static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
                            SmallPtrSetImpl<Attribute> &usedIndices,
                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                            const DataLayout &dataLayout) {
  if (op.getIsVolatile())
    return false;

  if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
    return false;

  if (!areAllIndicesI32(slot))
    return false;

  // Only full copies are supported.
  if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(slot.elemType))
    return false;

  if (op.getSrc() == slot.ptr)
    for (Attribute index : llvm::make_first_range(slot.subelementTypes))
      usedIndices.insert(index);

  return true;
}

namespace {

/// Creates a new memcpy-like operation of the same kind as `toReplace`, with
/// the given destination/source and a length equal to the size of `toCpy`.
template <class MemcpyLike>
void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
                               MemcpyLike toReplace, Value dst, Value src,
                               Type toCpy, bool isVolatile) {
  Value memcpySize = builder.create<LLVM::ConstantOp>(
      toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(),
                                           layout.getTypeSize(toCpy)));
  builder.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
                             isVolatile);
}

/// MemcpyInlineOp stores its length as an attribute rather than an SSA value,
/// so no constant op is created for the new length.
template <>
void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
                               LLVM::MemcpyInlineOp toReplace, Value dst,
                               Value src, Type toCpy, bool isVolatile) {
  Type lenType = IntegerType::get(toReplace->getContext(),
                                  toReplace.getLen().getBitWidth());
  builder.create<LLVM::MemcpyInlineOp>(
      toReplace.getLoc(), dst, src,
      IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile);
}

} // namespace

/// Rewires a memcpy-like operation. Only copies to or from the full slot are
/// supported.
template <class MemcpyLike>
static DeletionKind
memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
             DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
             const DataLayout &dataLayout) {
  if (subslots.empty())
    return DeletionKind::Delete;

  // The slot must be exactly one of the copy's endpoints, never both.
  assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc()));
  bool isDst = slot.ptr == op.getDst();

#ifndef NDEBUG
  size_t slotsTreated = 0;
#endif

  // It was previously checked that index types are consistent, so this type can
  // be fetched now.
  Type indexType = cast<IntegerAttr>(subslots.begin()->first).getType();
  for (size_t i = 0, e = slot.subelementTypes.size(); i != e; i++) {
    Attribute index = IntegerAttr::get(indexType, i);
    if (!subslots.contains(index))
      continue;
    const MemorySlot &subslot = subslots.at(index);

#ifndef NDEBUG
    slotsTreated++;
#endif

    // First get a pointer to the equivalent of this subslot from the source
    // pointer.
    SmallVector<LLVM::GEPArg> gepIndices{
        0, static_cast<int32_t>(
               cast<IntegerAttr>(index).getValue().getZExtValue())};
    Value subslotPtrInOther = builder.create<LLVM::GEPOp>(
        op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), slot.elemType,
        isDst ? op.getSrc() : op.getDst(), gepIndices);

    // Then create a new memcpy out of this source pointer.
    createMemcpyLikeToReplace(builder, dataLayout, op,
                              isDst ? subslot.ptr : subslotPtrInOther,
                              isDst ? subslotPtrInOther : subslot.ptr,
                              subslot.elemType, op.getIsVolatile());
  }

  // Every used subslot must have been copied.
  assert(subslots.size() == slotsTreated);

  return DeletionKind::Delete;
}

// The MemcpyOp interface methods below delegate to the templated
// memcpy-like helpers.

bool LLVM::MemcpyOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}

bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}

Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, builder);
}

bool LLVM::MemcpyOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}

DeletionKind LLVM::MemcpyOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
                                  reachingDefinition);
}

LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}

bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}

DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    OpBuilder &builder,
                                    const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
}

// The MemcpyInlineOp interface methods mirror the MemcpyOp ones above.

bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}

bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}

Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
                                      OpBuilder &builder, Value reachingDef,
                                      const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, builder);
}

bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}

DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
                                  reachingDefinition);
}

LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}

bool LLVM::MemcpyInlineOp::canRewire(
    const DestructurableMemorySlot &slot,
    SmallPtrSetImpl<Attribute> &usedIndices,
    SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}

DeletionKind
LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
                             DenseMap<Attribute, MemorySlot> &subslots,
                             OpBuilder &builder, const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
}

// MemmoveOp reuses the memcpy-like helpers as well; the only semantic
// difference (overlap handling) is irrelevant here because overlapping
// src/dst is rejected by memcpyCanUsesBeRemoved.

bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}

bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}

Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                 Value reachingDef,
                                 const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, builder);
}

bool LLVM::MemmoveOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}

DeletionKind LLVM::MemmoveOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
                                  reachingDefinition);
}

LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}

bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
                                SmallPtrSetImpl<Attribute> &usedIndices,
                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                                const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}

DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
                                     OpBuilder &builder,
                                     const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
}

//===----------------------------------------------------------------------===//
// Interfaces for destructurable types
//===----------------------------------------------------------------------===//

/// Maps every struct field position to its type, keyed by an i32 attribute.
std::optional<DenseMap<Attribute, Type>>
LLVM::LLVMStructType::getSubelementIndexMap() const {
  Type i32 = IntegerType::get(getContext(), 32);
  DenseMap<Attribute, Type> destructured;
  for (const auto &[index, elemType] : llvm::enumerate(getBody()))
    destructured.insert({IntegerAttr::get(i32, index), elemType});
  return destructured;
}

/// Returns the struct field type at `index`, or null for non-i32 or
/// out-of-range indices.
Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) const {
  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
  if (!indexAttr || !indexAttr.getType().isInteger(32))
    return {};
  int32_t indexInt = indexAttr.getInt();
  ArrayRef<Type> body = getBody();
  if (indexInt < 0 || body.size() <= static_cast<uint32_t>(indexInt))
    return {};
  return body[indexInt];
}

/// Maps every array position to the element type, keyed by an i32 attribute.
/// Large arrays are not destructured to avoid creating too many slots.
std::optional<DenseMap<Attribute, Type>>
LLVM::LLVMArrayType::getSubelementIndexMap() const {
  constexpr size_t maxArraySizeForDestructuring = 16;
  if (getNumElements() > maxArraySizeForDestructuring)
    return {};
  int32_t numElements = getNumElements();

  Type i32 = IntegerType::get(getContext(), 32);
  DenseMap<Attribute, Type> destructured;
  for (int32_t index = 0; index < numElements; ++index)
    destructured.insert({IntegerAttr::get(i32, index), getElementType()});
  return destructured;
}

/// Returns the array element type for any in-range i32 `index`, null
/// otherwise.
Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
  if (!indexAttr || !indexAttr.getType().isInteger(32))
    return {};
  int32_t indexInt = indexAttr.getInt();
  if (indexInt < 0 || getNumElements() <= static_cast<uint32_t>(indexInt))
    return {};
  return getElementType();
}