1 //===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 9 #include "DXILDataScalarization.h" 10 #include "DirectX.h" 11 #include "llvm/ADT/PostOrderIterator.h" 12 #include "llvm/ADT/STLExtras.h" 13 #include "llvm/IR/GlobalVariable.h" 14 #include "llvm/IR/IRBuilder.h" 15 #include "llvm/IR/InstVisitor.h" 16 #include "llvm/IR/Module.h" 17 #include "llvm/IR/Operator.h" 18 #include "llvm/IR/PassManager.h" 19 #include "llvm/IR/ReplaceConstant.h" 20 #include "llvm/IR/Type.h" 21 #include "llvm/Transforms/Utils/Cloning.h" 22 #include "llvm/Transforms/Utils/Local.h" 23 24 #define DEBUG_TYPE "dxil-data-scalarization" 25 static const int MaxVecSize = 4; 26 27 using namespace llvm; 28 29 class DXILDataScalarizationLegacy : public ModulePass { 30 31 public: 32 bool runOnModule(Module &M) override; 33 DXILDataScalarizationLegacy() : ModulePass(ID) {} 34 35 static char ID; // Pass identification. 36 }; 37 38 static bool findAndReplaceVectors(Module &M); 39 40 class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> { 41 public: 42 DataScalarizerVisitor() : GlobalMap() {} 43 bool visit(Instruction &I); 44 // InstVisitor methods. They return true if the instruction was scalarized, 45 // false if nothing changed. 46 bool visitInstruction(Instruction &I) { return false; } 47 bool visitSelectInst(SelectInst &SI) { return false; } 48 bool visitICmpInst(ICmpInst &ICI) { return false; } 49 bool visitFCmpInst(FCmpInst &FCI) { return false; } 50 bool visitUnaryOperator(UnaryOperator &UO) { return false; } 51 bool visitBinaryOperator(BinaryOperator &BO) { return false; } 52 bool visitGetElementPtrInst(GetElementPtrInst &GEPI); 53 bool visitCastInst(CastInst &CI) { return false; } 54 bool visitBitCastInst(BitCastInst &BCI) { return false; } 55 bool visitInsertElementInst(InsertElementInst &IEI) { return false; } 56 bool visitExtractElementInst(ExtractElementInst &EEI) { return false; } 57 bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; } 58 bool visitPHINode(PHINode &PHI) { return false; } 59 bool visitLoadInst(LoadInst &LI); 60 bool visitStoreInst(StoreInst &SI); 61 bool visitCallInst(CallInst &ICI) { return false; } 62 bool visitFreezeInst(FreezeInst &FI) { return false; } 63 friend bool findAndReplaceVectors(llvm::Module &M); 64 65 private: 66 GlobalVariable *lookupReplacementGlobal(Value *CurrOperand); 67 DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap; 68 }; 69 70 bool DataScalarizerVisitor::visit(Instruction &I) { 71 assert(!GlobalMap.empty()); 72 return InstVisitor::visit(I); 73 } 74 75 GlobalVariable * 76 DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) { 77 if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) { 78 auto It = GlobalMap.find(OldGlobal); 79 if (It != GlobalMap.end()) { 80 return It->second; // Found, return the new global 81 } 82 } 83 return nullptr; // Not found 84 } 85 86 bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) { 87 unsigned NumOperands = LI.getNumOperands(); 88 for (unsigned I = 0; I < NumOperands; ++I) { 89 Value *CurrOpperand = LI.getOperand(I); 90 ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); 91 if (CE && CE->getOpcode() == Instruction::GetElementPtr) { 92 GetElementPtrInst *OldGEP = 93 cast<GetElementPtrInst>(CE->getAsInstruction()); 94 OldGEP->insertBefore(&LI); 95 IRBuilder<> Builder(&LI); 96 LoadInst *NewLoad = 97 Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); 98 NewLoad->setAlignment(LI.getAlign()); 99 LI.replaceAllUsesWith(NewLoad); 100 LI.eraseFromParent(); 101 visitGetElementPtrInst(*OldGEP); 102 return true; 103 } 104 if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) 105 LI.setOperand(I, NewGlobal); 106 } 107 return false; 108 } 109 110 bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) { 111 unsigned NumOperands = SI.getNumOperands(); 112 for (unsigned I = 0; I < NumOperands; ++I) { 113 Value *CurrOpperand = SI.getOperand(I); 114 ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); 115 if (CE && CE->getOpcode() == Instruction::GetElementPtr) { 116 GetElementPtrInst *OldGEP = 117 cast<GetElementPtrInst>(CE->getAsInstruction()); 118 OldGEP->insertBefore(&SI); 119 IRBuilder<> Builder(&SI); 120 StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); 121 NewStore->setAlignment(SI.getAlign()); 122 SI.replaceAllUsesWith(NewStore); 123 SI.eraseFromParent(); 124 visitGetElementPtrInst(*OldGEP); 125 return true; 126 } 127 if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) 128 SI.setOperand(I, NewGlobal); 129 } 130 return false; 131 } 132 133 bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { 134 135 unsigned NumOperands = GEPI.getNumOperands(); 136 GlobalVariable *NewGlobal = nullptr; 137 for (unsigned I = 0; I < NumOperands; ++I) { 138 Value *CurrOpperand = GEPI.getOperand(I); 139 NewGlobal = lookupReplacementGlobal(CurrOpperand); 140 if (NewGlobal) 141 break; 142 } 143 if (!NewGlobal) 144 return false; 145 146 IRBuilder<> Builder(&GEPI); 147 SmallVector<Value *, MaxVecSize> Indices; 148 for (auto &Index : GEPI.indices()) 149 Indices.push_back(Index); 150 151 Value *NewGEP = 152 Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices, 153 GEPI.getName(), GEPI.getNoWrapFlags()); 154 GEPI.replaceAllUsesWith(NewGEP); 155 GEPI.eraseFromParent(); 156 return true; 157 } 158 159 // Recursively Creates and Array like version of the given vector like type. 160 static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) { 161 if (auto *VecTy = dyn_cast<VectorType>(T)) 162 return ArrayType::get(VecTy->getElementType(), 163 dyn_cast<FixedVectorType>(VecTy)->getNumElements()); 164 if (auto *ArrayTy = dyn_cast<ArrayType>(T)) { 165 Type *NewElementType = 166 replaceVectorWithArray(ArrayTy->getElementType(), Ctx); 167 return ArrayType::get(NewElementType, ArrayTy->getNumElements()); 168 } 169 // If it's not a vector or array, return the original type. 170 return T; 171 } 172 173 Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, 174 LLVMContext &Ctx) { 175 // Handle ConstantAggregateZero (zero-initialized constants) 176 if (isa<ConstantAggregateZero>(Init)) { 177 return ConstantAggregateZero::get(NewType); 178 } 179 180 // Handle UndefValue (undefined constants) 181 if (isa<UndefValue>(Init)) { 182 return UndefValue::get(NewType); 183 } 184 185 // Handle vector to array transformation 186 if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) { 187 // Convert vector initializer to array initializer 188 SmallVector<Constant *, MaxVecSize> ArrayElements; 189 if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) { 190 for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I) 191 ArrayElements.push_back(ConstVecInit->getOperand(I)); 192 } else if (ConstantDataVector *ConstDataVecInit = 193 llvm::dyn_cast<llvm::ConstantDataVector>(Init)) { 194 for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I) 195 ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I)); 196 } else { 197 assert(false && "Expected a ConstantVector or ConstantDataVector for " 198 "vector initializer!"); 199 } 200 201 return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements); 202 } 203 204 // Handle array of vectors transformation 205 if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) { 206 auto *ArrayInit = dyn_cast<ConstantArray>(Init); 207 assert(ArrayInit && "Expected a ConstantArray for array initializer!"); 208 209 SmallVector<Constant *, MaxVecSize> NewArrayElements; 210 for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) { 211 // Recursively transform array elements 212 Constant *NewElemInit = transformInitializer( 213 ArrayInit->getOperand(I), ArrayTy->getElementType(), 214 cast<ArrayType>(NewType)->getElementType(), Ctx); 215 NewArrayElements.push_back(NewElemInit); 216 } 217 218 return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements); 219 } 220 221 // If not a vector or array, return the original initializer 222 return Init; 223 } 224 225 static bool findAndReplaceVectors(Module &M) { 226 bool MadeChange = false; 227 LLVMContext &Ctx = M.getContext(); 228 IRBuilder<> Builder(Ctx); 229 DataScalarizerVisitor Impl; 230 for (GlobalVariable &G : M.globals()) { 231 Type *OrigType = G.getValueType(); 232 233 Type *NewType = replaceVectorWithArray(OrigType, Ctx); 234 if (OrigType != NewType) { 235 // Create a new global variable with the updated type 236 // Note: Initializer is set via transformInitializer 237 GlobalVariable *NewGlobal = new GlobalVariable( 238 M, NewType, G.isConstant(), G.getLinkage(), 239 /*Initializer=*/nullptr, G.getName() + ".scalarized", &G, 240 G.getThreadLocalMode(), G.getAddressSpace(), 241 G.isExternallyInitialized()); 242 243 // Copy relevant attributes 244 NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); 245 if (G.getAlignment() > 0) { 246 NewGlobal->setAlignment(G.getAlign()); 247 } 248 249 if (G.hasInitializer()) { 250 Constant *Init = G.getInitializer(); 251 Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx); 252 NewGlobal->setInitializer(NewInit); 253 } 254 255 // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes 256 // type equality. Instead we will use the visitor pattern. 257 Impl.GlobalMap[&G] = NewGlobal; 258 for (User *U : make_early_inc_range(G.users())) { 259 if (isa<ConstantExpr>(U) && isa<Operator>(U)) { 260 ConstantExpr *CE = cast<ConstantExpr>(U); 261 for (User *UCE : make_early_inc_range(CE->users())) { 262 if (Instruction *Inst = dyn_cast<Instruction>(UCE)) 263 Impl.visit(*Inst); 264 } 265 } 266 if (Instruction *Inst = dyn_cast<Instruction>(U)) 267 Impl.visit(*Inst); 268 } 269 } 270 } 271 272 // Remove the old globals after the iteration 273 for (auto &[Old, New] : Impl.GlobalMap) { 274 Old->eraseFromParent(); 275 MadeChange = true; 276 } 277 return MadeChange; 278 } 279 280 PreservedAnalyses DXILDataScalarization::run(Module &M, 281 ModuleAnalysisManager &) { 282 bool MadeChanges = findAndReplaceVectors(M); 283 if (!MadeChanges) 284 return PreservedAnalyses::all(); 285 PreservedAnalyses PA; 286 return PA; 287 } 288 289 bool DXILDataScalarizationLegacy::runOnModule(Module &M) { 290 return findAndReplaceVectors(M); 291 } 292 293 char DXILDataScalarizationLegacy::ID = 0; 294 295 INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE, 296 "DXIL Data Scalarization", false, false) 297 INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE, 298 "DXIL Data Scalarization", false, false) 299 300 ModulePass *llvm::createDXILDataScalarizationLegacyPass() { 301 return new DXILDataScalarizationLegacy(); 302 } 303