1 //===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 /// 9 /// \file This file contains a pass to flatten arrays for the DirectX Backend. 10 /// 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILFlattenArrays.h" 14 #include "DirectX.h" 15 #include "llvm/ADT/PostOrderIterator.h" 16 #include "llvm/ADT/STLExtras.h" 17 #include "llvm/IR/BasicBlock.h" 18 #include "llvm/IR/DerivedTypes.h" 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/InstVisitor.h" 21 #include "llvm/IR/ReplaceConstant.h" 22 #include "llvm/Support/Casting.h" 23 #include "llvm/Transforms/Utils/Local.h" 24 #include <cassert> 25 #include <cstddef> 26 #include <cstdint> 27 #include <utility> 28 29 #define DEBUG_TYPE "dxil-flatten-arrays" 30 31 using namespace llvm; 32 namespace { 33 34 class DXILFlattenArraysLegacy : public ModulePass { 35 36 public: 37 bool runOnModule(Module &M) override; 38 DXILFlattenArraysLegacy() : ModulePass(ID) {} 39 40 static char ID; // Pass identification. 41 }; 42 43 struct GEPData { 44 ArrayType *ParentArrayType; 45 Value *ParendOperand; 46 SmallVector<Value *> Indices; 47 SmallVector<uint64_t> Dims; 48 bool AllIndicesAreConstInt; 49 }; 50 51 class DXILFlattenArraysVisitor 52 : public InstVisitor<DXILFlattenArraysVisitor, bool> { 53 public: 54 DXILFlattenArraysVisitor() {} 55 bool visit(Function &F); 56 // InstVisitor methods. They return true if the instruction was scalarized, 57 // false if nothing changed. 58 bool visitGetElementPtrInst(GetElementPtrInst &GEPI); 59 bool visitAllocaInst(AllocaInst &AI); 60 bool visitInstruction(Instruction &I) { return false; } 61 bool visitSelectInst(SelectInst &SI) { return false; } 62 bool visitICmpInst(ICmpInst &ICI) { return false; } 63 bool visitFCmpInst(FCmpInst &FCI) { return false; } 64 bool visitUnaryOperator(UnaryOperator &UO) { return false; } 65 bool visitBinaryOperator(BinaryOperator &BO) { return false; } 66 bool visitCastInst(CastInst &CI) { return false; } 67 bool visitBitCastInst(BitCastInst &BCI) { return false; } 68 bool visitInsertElementInst(InsertElementInst &IEI) { return false; } 69 bool visitExtractElementInst(ExtractElementInst &EEI) { return false; } 70 bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; } 71 bool visitPHINode(PHINode &PHI) { return false; } 72 bool visitLoadInst(LoadInst &LI); 73 bool visitStoreInst(StoreInst &SI); 74 bool visitCallInst(CallInst &ICI) { return false; } 75 bool visitFreezeInst(FreezeInst &FI) { return false; } 76 static bool isMultiDimensionalArray(Type *T); 77 static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy); 78 79 private: 80 SmallVector<WeakTrackingVH> PotentiallyDeadInstrs; 81 DenseMap<GetElementPtrInst *, GEPData> GEPChainMap; 82 bool finish(); 83 ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices, 84 ArrayRef<uint64_t> Dims, 85 IRBuilder<> &Builder); 86 Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices, 87 ArrayRef<uint64_t> Dims, 88 IRBuilder<> &Builder); 89 void 90 recursivelyCollectGEPs(GetElementPtrInst &CurrGEP, 91 ArrayType *FlattenedArrayType, Value *PtrOperand, 92 unsigned &GEPChainUseCount, 93 SmallVector<Value *> Indices = SmallVector<Value *>(), 94 SmallVector<uint64_t> Dims = SmallVector<uint64_t>(), 95 bool AllIndicesAreConstInt = true); 96 bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP); 97 bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo, 98 GetElementPtrInst &GEP); 99 }; 100 } // namespace 101 102 bool DXILFlattenArraysVisitor::finish() { 103 RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); 104 return true; 105 } 106 107 bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) { 108 if (ArrayType *ArrType = dyn_cast<ArrayType>(T)) 109 return isa<ArrayType>(ArrType->getElementType()); 110 return false; 111 } 112 113 std::pair<unsigned, Type *> 114 DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) { 115 unsigned TotalElements = 1; 116 Type *CurrArrayTy = ArrayTy; 117 while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) { 118 TotalElements *= InnerArrayTy->getNumElements(); 119 CurrArrayTy = InnerArrayTy->getElementType(); 120 } 121 return std::make_pair(TotalElements, CurrArrayTy); 122 } 123 124 ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices( 125 ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) { 126 assert(Indices.size() == Dims.size() && 127 "Indicies and dimmensions should be the same"); 128 unsigned FlatIndex = 0; 129 unsigned Multiplier = 1; 130 131 for (int I = Indices.size() - 1; I >= 0; --I) { 132 unsigned DimSize = Dims[I]; 133 ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[I]); 134 assert(CIndex && "This function expects all indicies to be ConstantInt"); 135 FlatIndex += CIndex->getZExtValue() * Multiplier; 136 Multiplier *= DimSize; 137 } 138 return Builder.getInt32(FlatIndex); 139 } 140 141 Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices( 142 ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) { 143 if (Indices.size() == 1) 144 return Indices[0]; 145 146 Value *FlatIndex = Builder.getInt32(0); 147 unsigned Multiplier = 1; 148 149 for (int I = Indices.size() - 1; I >= 0; --I) { 150 unsigned DimSize = Dims[I]; 151 Value *VMultiplier = Builder.getInt32(Multiplier); 152 Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier); 153 FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex); 154 Multiplier *= DimSize; 155 } 156 return FlatIndex; 157 } 158 159 bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) { 160 unsigned NumOperands = LI.getNumOperands(); 161 for (unsigned I = 0; I < NumOperands; ++I) { 162 Value *CurrOpperand = LI.getOperand(I); 163 ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); 164 if (CE && CE->getOpcode() == Instruction::GetElementPtr) { 165 GetElementPtrInst *OldGEP = 166 cast<GetElementPtrInst>(CE->getAsInstruction()); 167 OldGEP->insertBefore(&LI); 168 169 IRBuilder<> Builder(&LI); 170 LoadInst *NewLoad = 171 Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); 172 NewLoad->setAlignment(LI.getAlign()); 173 LI.replaceAllUsesWith(NewLoad); 174 LI.eraseFromParent(); 175 visitGetElementPtrInst(*OldGEP); 176 return true; 177 } 178 } 179 return false; 180 } 181 182 bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) { 183 unsigned NumOperands = SI.getNumOperands(); 184 for (unsigned I = 0; I < NumOperands; ++I) { 185 Value *CurrOpperand = SI.getOperand(I); 186 ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); 187 if (CE && CE->getOpcode() == Instruction::GetElementPtr) { 188 GetElementPtrInst *OldGEP = 189 cast<GetElementPtrInst>(CE->getAsInstruction()); 190 OldGEP->insertBefore(&SI); 191 192 IRBuilder<> Builder(&SI); 193 StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); 194 NewStore->setAlignment(SI.getAlign()); 195 SI.replaceAllUsesWith(NewStore); 196 SI.eraseFromParent(); 197 visitGetElementPtrInst(*OldGEP); 198 return true; 199 } 200 } 201 return false; 202 } 203 204 bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) { 205 if (!isMultiDimensionalArray(AI.getAllocatedType())) 206 return false; 207 208 ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType()); 209 IRBuilder<> Builder(&AI); 210 auto [TotalElements, BaseType] = getElementCountAndType(ArrType); 211 212 ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements); 213 AllocaInst *FlatAlloca = 214 Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat"); 215 FlatAlloca->setAlignment(AI.getAlign()); 216 AI.replaceAllUsesWith(FlatAlloca); 217 AI.eraseFromParent(); 218 return true; 219 } 220 221 void DXILFlattenArraysVisitor::recursivelyCollectGEPs( 222 GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType, 223 Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices, 224 SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) { 225 Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1); 226 AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex); 227 Indices.push_back(LastIndex); 228 assert(isa<ArrayType>(CurrGEP.getSourceElementType())); 229 Dims.push_back( 230 cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements()); 231 bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType()); 232 if (!IsMultiDimArr) { 233 assert(GEPChainUseCount < FlattenedArrayType->getNumElements()); 234 GEPChainMap.insert( 235 {&CurrGEP, 236 {std::move(FlattenedArrayType), PtrOperand, std::move(Indices), 237 std::move(Dims), AllIndicesAreConstInt}}); 238 return; 239 } 240 bool GepUses = false; 241 for (auto *User : CurrGEP.users()) { 242 if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) { 243 recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand, 244 ++GEPChainUseCount, Indices, Dims, 245 AllIndicesAreConstInt); 246 GepUses = true; 247 } 248 } 249 // This case is just incase the gep chain doesn't end with a 1d array. 250 if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) { 251 GEPChainMap.insert( 252 {&CurrGEP, 253 {std::move(FlattenedArrayType), PtrOperand, std::move(Indices), 254 std::move(Dims), AllIndicesAreConstInt}}); 255 } 256 } 257 258 bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain( 259 GetElementPtrInst &GEP) { 260 GEPData GEPInfo = GEPChainMap.at(&GEP); 261 return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP); 262 } 263 bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase( 264 GEPData &GEPInfo, GetElementPtrInst &GEP) { 265 IRBuilder<> Builder(&GEP); 266 Value *FlatIndex; 267 if (GEPInfo.AllIndicesAreConstInt) 268 FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); 269 else 270 FlatIndex = 271 genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); 272 273 ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType; 274 Value *FlatGEP = 275 Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex, 276 GEP.getName() + ".flat", GEP.isInBounds()); 277 278 GEP.replaceAllUsesWith(FlatGEP); 279 GEP.eraseFromParent(); 280 return true; 281 } 282 283 bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { 284 auto It = GEPChainMap.find(&GEP); 285 if (It != GEPChainMap.end()) 286 return visitGetElementPtrInstInGEPChain(GEP); 287 if (!isMultiDimensionalArray(GEP.getSourceElementType())) 288 return false; 289 290 ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType()); 291 IRBuilder<> Builder(&GEP); 292 auto [TotalElements, BaseType] = getElementCountAndType(ArrType); 293 ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements); 294 295 Value *PtrOperand = GEP.getPointerOperand(); 296 297 unsigned GEPChainUseCount = 0; 298 recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount); 299 300 // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0. 301 // Here recursion is used to get the length of the GEP chain. 302 // Handle zero uses here because there won't be an update via 303 // a child in the chain later. 304 if (GEPChainUseCount == 0) { 305 SmallVector<Value *> Indices({GEP.getOperand(GEP.getNumOperands() - 1)}); 306 SmallVector<uint64_t> Dims({ArrType->getNumElements()}); 307 bool AllIndicesAreConstInt = isa<ConstantInt>(Indices[0]); 308 GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand, 309 std::move(Indices), std::move(Dims), AllIndicesAreConstInt}; 310 return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP); 311 } 312 313 PotentiallyDeadInstrs.emplace_back(&GEP); 314 return false; 315 } 316 317 bool DXILFlattenArraysVisitor::visit(Function &F) { 318 bool MadeChange = false; 319 ReversePostOrderTraversal<Function *> RPOT(&F); 320 for (BasicBlock *BB : make_early_inc_range(RPOT)) { 321 for (Instruction &I : make_early_inc_range(*BB)) 322 MadeChange |= InstVisitor::visit(I); 323 } 324 finish(); 325 return MadeChange; 326 } 327 328 static void collectElements(Constant *Init, 329 SmallVectorImpl<Constant *> &Elements) { 330 // Base case: If Init is not an array, add it directly to the vector. 331 auto *ArrayTy = dyn_cast<ArrayType>(Init->getType()); 332 if (!ArrayTy) { 333 Elements.push_back(Init); 334 return; 335 } 336 unsigned ArrSize = ArrayTy->getNumElements(); 337 if (isa<ConstantAggregateZero>(Init)) { 338 for (unsigned I = 0; I < ArrSize; ++I) 339 Elements.push_back(Constant::getNullValue(ArrayTy->getElementType())); 340 return; 341 } 342 343 // Recursive case: Process each element in the array. 344 if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) { 345 for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) { 346 collectElements(ArrayConstant->getOperand(I), Elements); 347 } 348 } else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) { 349 for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) { 350 collectElements(DataArrayConstant->getElementAsConstant(I), Elements); 351 } 352 } else { 353 llvm_unreachable( 354 "Expected a ConstantArray or ConstantDataArray for array initializer!"); 355 } 356 } 357 358 static Constant *transformInitializer(Constant *Init, Type *OrigType, 359 ArrayType *FlattenedType, 360 LLVMContext &Ctx) { 361 // Handle ConstantAggregateZero (zero-initialized constants) 362 if (isa<ConstantAggregateZero>(Init)) 363 return ConstantAggregateZero::get(FlattenedType); 364 365 // Handle UndefValue (undefined constants) 366 if (isa<UndefValue>(Init)) 367 return UndefValue::get(FlattenedType); 368 369 if (!isa<ArrayType>(OrigType)) 370 return Init; 371 372 SmallVector<Constant *> FlattenedElements; 373 collectElements(Init, FlattenedElements); 374 assert(FlattenedType->getNumElements() == FlattenedElements.size() && 375 "The number of collected elements should match the FlattenedType"); 376 return ConstantArray::get(FlattenedType, FlattenedElements); 377 } 378 379 static void 380 flattenGlobalArrays(Module &M, 381 DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) { 382 LLVMContext &Ctx = M.getContext(); 383 for (GlobalVariable &G : M.globals()) { 384 Type *OrigType = G.getValueType(); 385 if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType)) 386 continue; 387 388 ArrayType *ArrType = cast<ArrayType>(OrigType); 389 auto [TotalElements, BaseType] = 390 DXILFlattenArraysVisitor::getElementCountAndType(ArrType); 391 ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements); 392 393 // Create a new global variable with the updated type 394 // Note: Initializer is set via transformInitializer 395 GlobalVariable *NewGlobal = 396 new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(), 397 /*Initializer=*/nullptr, G.getName() + ".1dim", &G, 398 G.getThreadLocalMode(), G.getAddressSpace(), 399 G.isExternallyInitialized()); 400 401 // Copy relevant attributes 402 NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); 403 if (G.getAlignment() > 0) { 404 NewGlobal->setAlignment(G.getAlign()); 405 } 406 407 if (G.hasInitializer()) { 408 Constant *Init = G.getInitializer(); 409 Constant *NewInit = 410 transformInitializer(Init, OrigType, FattenedArrayType, Ctx); 411 NewGlobal->setInitializer(NewInit); 412 } 413 GlobalMap[&G] = NewGlobal; 414 } 415 } 416 417 static bool flattenArrays(Module &M) { 418 bool MadeChange = false; 419 DXILFlattenArraysVisitor Impl; 420 DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap; 421 flattenGlobalArrays(M, GlobalMap); 422 for (auto &F : make_early_inc_range(M.functions())) { 423 if (F.isDeclaration()) 424 continue; 425 MadeChange |= Impl.visit(F); 426 } 427 for (auto &[Old, New] : GlobalMap) { 428 Old->replaceAllUsesWith(New); 429 Old->eraseFromParent(); 430 MadeChange = true; 431 } 432 return MadeChange; 433 } 434 435 PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) { 436 bool MadeChanges = flattenArrays(M); 437 if (!MadeChanges) 438 return PreservedAnalyses::all(); 439 PreservedAnalyses PA; 440 return PA; 441 } 442 443 bool DXILFlattenArraysLegacy::runOnModule(Module &M) { 444 return flattenArrays(M); 445 } 446 447 char DXILFlattenArraysLegacy::ID = 0; 448 449 INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE, 450 "DXIL Array Flattener", false, false) 451 INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener", 452 false, false) 453 454 ModulePass *llvm::createDXILFlattenArraysLegacyPass() { 455 return new DXILFlattenArraysLegacy(); 456 } 457