1324bdd66SFarzon Lotfi //===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===// 2324bdd66SFarzon Lotfi // 3324bdd66SFarzon Lotfi // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4324bdd66SFarzon Lotfi // See https://llvm.org/LICENSE.txt for license information. 5324bdd66SFarzon Lotfi // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6324bdd66SFarzon Lotfi // 7324bdd66SFarzon Lotfi //===---------------------------------------------------------------------===// 8324bdd66SFarzon Lotfi 9324bdd66SFarzon Lotfi #include "DXILDataScalarization.h" 10324bdd66SFarzon Lotfi #include "DirectX.h" 11324bdd66SFarzon Lotfi #include "llvm/ADT/PostOrderIterator.h" 12324bdd66SFarzon Lotfi #include "llvm/ADT/STLExtras.h" 13324bdd66SFarzon Lotfi #include "llvm/IR/GlobalVariable.h" 14324bdd66SFarzon Lotfi #include "llvm/IR/IRBuilder.h" 15324bdd66SFarzon Lotfi #include "llvm/IR/InstVisitor.h" 16324bdd66SFarzon Lotfi #include "llvm/IR/Module.h" 17324bdd66SFarzon Lotfi #include "llvm/IR/Operator.h" 18324bdd66SFarzon Lotfi #include "llvm/IR/PassManager.h" 19324bdd66SFarzon Lotfi #include "llvm/IR/ReplaceConstant.h" 20324bdd66SFarzon Lotfi #include "llvm/IR/Type.h" 21324bdd66SFarzon Lotfi #include "llvm/Transforms/Utils/Cloning.h" 22324bdd66SFarzon Lotfi #include "llvm/Transforms/Utils/Local.h" 23324bdd66SFarzon Lotfi 24324bdd66SFarzon Lotfi #define DEBUG_TYPE "dxil-data-scalarization" 25324bdd66SFarzon Lotfi static const int MaxVecSize = 4; 26324bdd66SFarzon Lotfi 27324bdd66SFarzon Lotfi using namespace llvm; 28324bdd66SFarzon Lotfi 29324bdd66SFarzon Lotfi class DXILDataScalarizationLegacy : public ModulePass { 30324bdd66SFarzon Lotfi 31324bdd66SFarzon Lotfi public: 32324bdd66SFarzon Lotfi bool runOnModule(Module &M) override; 33324bdd66SFarzon Lotfi DXILDataScalarizationLegacy() : ModulePass(ID) {} 34324bdd66SFarzon Lotfi 35324bdd66SFarzon Lotfi static char ID; // Pass identification. 36324bdd66SFarzon Lotfi }; 37324bdd66SFarzon Lotfi 38324bdd66SFarzon Lotfi static bool findAndReplaceVectors(Module &M); 39324bdd66SFarzon Lotfi 40324bdd66SFarzon Lotfi class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> { 41324bdd66SFarzon Lotfi public: 42324bdd66SFarzon Lotfi DataScalarizerVisitor() : GlobalMap() {} 43*6457aee5SFarzon Lotfi bool visit(Instruction &I); 44324bdd66SFarzon Lotfi // InstVisitor methods. They return true if the instruction was scalarized, 45324bdd66SFarzon Lotfi // false if nothing changed. 46324bdd66SFarzon Lotfi bool visitInstruction(Instruction &I) { return false; } 47324bdd66SFarzon Lotfi bool visitSelectInst(SelectInst &SI) { return false; } 48324bdd66SFarzon Lotfi bool visitICmpInst(ICmpInst &ICI) { return false; } 49324bdd66SFarzon Lotfi bool visitFCmpInst(FCmpInst &FCI) { return false; } 50324bdd66SFarzon Lotfi bool visitUnaryOperator(UnaryOperator &UO) { return false; } 51324bdd66SFarzon Lotfi bool visitBinaryOperator(BinaryOperator &BO) { return false; } 52324bdd66SFarzon Lotfi bool visitGetElementPtrInst(GetElementPtrInst &GEPI); 53324bdd66SFarzon Lotfi bool visitCastInst(CastInst &CI) { return false; } 54324bdd66SFarzon Lotfi bool visitBitCastInst(BitCastInst &BCI) { return false; } 55324bdd66SFarzon Lotfi bool visitInsertElementInst(InsertElementInst &IEI) { return false; } 56324bdd66SFarzon Lotfi bool visitExtractElementInst(ExtractElementInst &EEI) { return false; } 57324bdd66SFarzon Lotfi bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; } 58324bdd66SFarzon Lotfi bool visitPHINode(PHINode &PHI) { return false; } 59324bdd66SFarzon Lotfi bool visitLoadInst(LoadInst &LI); 60324bdd66SFarzon Lotfi bool visitStoreInst(StoreInst &SI); 61324bdd66SFarzon Lotfi bool visitCallInst(CallInst &ICI) { return false; } 62324bdd66SFarzon Lotfi bool visitFreezeInst(FreezeInst &FI) { return false; } 63324bdd66SFarzon Lotfi friend bool findAndReplaceVectors(llvm::Module &M); 64324bdd66SFarzon Lotfi 65324bdd66SFarzon Lotfi private: 66324bdd66SFarzon Lotfi GlobalVariable *lookupReplacementGlobal(Value *CurrOperand); 67324bdd66SFarzon Lotfi DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap; 68324bdd66SFarzon Lotfi }; 69324bdd66SFarzon Lotfi 70*6457aee5SFarzon Lotfi bool DataScalarizerVisitor::visit(Instruction &I) { 71324bdd66SFarzon Lotfi assert(!GlobalMap.empty()); 72*6457aee5SFarzon Lotfi return InstVisitor::visit(I); 73324bdd66SFarzon Lotfi } 74324bdd66SFarzon Lotfi 75324bdd66SFarzon Lotfi GlobalVariable * 76324bdd66SFarzon Lotfi DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) { 77324bdd66SFarzon Lotfi if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) { 78324bdd66SFarzon Lotfi auto It = GlobalMap.find(OldGlobal); 79324bdd66SFarzon Lotfi if (It != GlobalMap.end()) { 80324bdd66SFarzon Lotfi return It->second; // Found, return the new global 81324bdd66SFarzon Lotfi } 82324bdd66SFarzon Lotfi } 83324bdd66SFarzon Lotfi return nullptr; // Not found 84324bdd66SFarzon Lotfi } 85324bdd66SFarzon Lotfi 86324bdd66SFarzon Lotfi bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) { 87324bdd66SFarzon Lotfi unsigned NumOperands = LI.getNumOperands(); 88324bdd66SFarzon Lotfi for (unsigned I = 0; I < NumOperands; ++I) { 89324bdd66SFarzon Lotfi Value *CurrOpperand = LI.getOperand(I); 90*6457aee5SFarzon Lotfi ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); 91*6457aee5SFarzon Lotfi if (CE && CE->getOpcode() == Instruction::GetElementPtr) { 92*6457aee5SFarzon Lotfi GetElementPtrInst *OldGEP = 93*6457aee5SFarzon Lotfi cast<GetElementPtrInst>(CE->getAsInstruction()); 94*6457aee5SFarzon Lotfi OldGEP->insertBefore(&LI); 95*6457aee5SFarzon Lotfi IRBuilder<> Builder(&LI); 96*6457aee5SFarzon Lotfi LoadInst *NewLoad = 97*6457aee5SFarzon Lotfi Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); 98*6457aee5SFarzon Lotfi NewLoad->setAlignment(LI.getAlign()); 99*6457aee5SFarzon Lotfi LI.replaceAllUsesWith(NewLoad); 100*6457aee5SFarzon Lotfi LI.eraseFromParent(); 101*6457aee5SFarzon Lotfi visitGetElementPtrInst(*OldGEP); 102*6457aee5SFarzon Lotfi return true; 103*6457aee5SFarzon Lotfi } 104324bdd66SFarzon Lotfi if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) 105324bdd66SFarzon Lotfi LI.setOperand(I, NewGlobal); 106324bdd66SFarzon Lotfi } 107324bdd66SFarzon Lotfi return false; 108324bdd66SFarzon Lotfi } 109324bdd66SFarzon Lotfi 110324bdd66SFarzon Lotfi bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) { 111324bdd66SFarzon Lotfi unsigned NumOperands = SI.getNumOperands(); 112324bdd66SFarzon Lotfi for (unsigned I = 0; I < NumOperands; ++I) { 113324bdd66SFarzon Lotfi Value *CurrOpperand = SI.getOperand(I); 114*6457aee5SFarzon Lotfi ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); 115*6457aee5SFarzon Lotfi if (CE && CE->getOpcode() == Instruction::GetElementPtr) { 116*6457aee5SFarzon Lotfi GetElementPtrInst *OldGEP = 117*6457aee5SFarzon Lotfi cast<GetElementPtrInst>(CE->getAsInstruction()); 118*6457aee5SFarzon Lotfi OldGEP->insertBefore(&SI); 119*6457aee5SFarzon Lotfi IRBuilder<> Builder(&SI); 120*6457aee5SFarzon Lotfi StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); 121*6457aee5SFarzon Lotfi NewStore->setAlignment(SI.getAlign()); 122*6457aee5SFarzon Lotfi SI.replaceAllUsesWith(NewStore); 123*6457aee5SFarzon Lotfi SI.eraseFromParent(); 124*6457aee5SFarzon Lotfi visitGetElementPtrInst(*OldGEP); 125*6457aee5SFarzon Lotfi return true; 126324bdd66SFarzon Lotfi } 127*6457aee5SFarzon Lotfi if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) 128*6457aee5SFarzon Lotfi SI.setOperand(I, NewGlobal); 129324bdd66SFarzon Lotfi } 130324bdd66SFarzon Lotfi return false; 131324bdd66SFarzon Lotfi } 132324bdd66SFarzon Lotfi 133324bdd66SFarzon Lotfi bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { 134*6457aee5SFarzon Lotfi 135324bdd66SFarzon Lotfi unsigned NumOperands = GEPI.getNumOperands(); 136*6457aee5SFarzon Lotfi GlobalVariable *NewGlobal = nullptr; 137324bdd66SFarzon Lotfi for (unsigned I = 0; I < NumOperands; ++I) { 138324bdd66SFarzon Lotfi Value *CurrOpperand = GEPI.getOperand(I); 139*6457aee5SFarzon Lotfi NewGlobal = lookupReplacementGlobal(CurrOpperand); 140*6457aee5SFarzon Lotfi if (NewGlobal) 141*6457aee5SFarzon Lotfi break; 142*6457aee5SFarzon Lotfi } 143324bdd66SFarzon Lotfi if (!NewGlobal) 144*6457aee5SFarzon Lotfi return false; 145324bdd66SFarzon Lotfi 146*6457aee5SFarzon Lotfi IRBuilder<> Builder(&GEPI); 147324bdd66SFarzon Lotfi SmallVector<Value *, MaxVecSize> Indices; 148324bdd66SFarzon Lotfi for (auto &Index : GEPI.indices()) 149324bdd66SFarzon Lotfi Indices.push_back(Index); 150324bdd66SFarzon Lotfi 151324bdd66SFarzon Lotfi Value *NewGEP = 152*6457aee5SFarzon Lotfi Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices, 153*6457aee5SFarzon Lotfi GEPI.getName(), GEPI.getNoWrapFlags()); 154324bdd66SFarzon Lotfi GEPI.replaceAllUsesWith(NewGEP); 155*6457aee5SFarzon Lotfi GEPI.eraseFromParent(); 156324bdd66SFarzon Lotfi return true; 157324bdd66SFarzon Lotfi } 158324bdd66SFarzon Lotfi 159324bdd66SFarzon Lotfi // Recursively Creates and Array like version of the given vector like type. 160324bdd66SFarzon Lotfi static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) { 161324bdd66SFarzon Lotfi if (auto *VecTy = dyn_cast<VectorType>(T)) 162324bdd66SFarzon Lotfi return ArrayType::get(VecTy->getElementType(), 163324bdd66SFarzon Lotfi dyn_cast<FixedVectorType>(VecTy)->getNumElements()); 164324bdd66SFarzon Lotfi if (auto *ArrayTy = dyn_cast<ArrayType>(T)) { 165324bdd66SFarzon Lotfi Type *NewElementType = 166324bdd66SFarzon Lotfi replaceVectorWithArray(ArrayTy->getElementType(), Ctx); 167324bdd66SFarzon Lotfi return ArrayType::get(NewElementType, ArrayTy->getNumElements()); 168324bdd66SFarzon Lotfi } 169324bdd66SFarzon Lotfi // If it's not a vector or array, return the original type. 170324bdd66SFarzon Lotfi return T; 171324bdd66SFarzon Lotfi } 172324bdd66SFarzon Lotfi 173324bdd66SFarzon Lotfi Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, 174324bdd66SFarzon Lotfi LLVMContext &Ctx) { 175324bdd66SFarzon Lotfi // Handle ConstantAggregateZero (zero-initialized constants) 176324bdd66SFarzon Lotfi if (isa<ConstantAggregateZero>(Init)) { 177324bdd66SFarzon Lotfi return ConstantAggregateZero::get(NewType); 178324bdd66SFarzon Lotfi } 179324bdd66SFarzon Lotfi 180324bdd66SFarzon Lotfi // Handle UndefValue (undefined constants) 181324bdd66SFarzon Lotfi if (isa<UndefValue>(Init)) { 182324bdd66SFarzon Lotfi return UndefValue::get(NewType); 183324bdd66SFarzon Lotfi } 184324bdd66SFarzon Lotfi 185324bdd66SFarzon Lotfi // Handle vector to array transformation 186324bdd66SFarzon Lotfi if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) { 187324bdd66SFarzon Lotfi // Convert vector initializer to array initializer 188324bdd66SFarzon Lotfi SmallVector<Constant *, MaxVecSize> ArrayElements; 189324bdd66SFarzon Lotfi if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) { 190324bdd66SFarzon Lotfi for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I) 191324bdd66SFarzon Lotfi ArrayElements.push_back(ConstVecInit->getOperand(I)); 192324bdd66SFarzon Lotfi } else if (ConstantDataVector *ConstDataVecInit = 193324bdd66SFarzon Lotfi llvm::dyn_cast<llvm::ConstantDataVector>(Init)) { 194324bdd66SFarzon Lotfi for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I) 195324bdd66SFarzon Lotfi ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I)); 196324bdd66SFarzon Lotfi } else { 197324bdd66SFarzon Lotfi assert(false && "Expected a ConstantVector or ConstantDataVector for " 198324bdd66SFarzon Lotfi "vector initializer!"); 199324bdd66SFarzon Lotfi } 200324bdd66SFarzon Lotfi 201324bdd66SFarzon Lotfi return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements); 202324bdd66SFarzon Lotfi } 203324bdd66SFarzon Lotfi 204324bdd66SFarzon Lotfi // Handle array of vectors transformation 205324bdd66SFarzon Lotfi if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) { 206324bdd66SFarzon Lotfi auto *ArrayInit = dyn_cast<ConstantArray>(Init); 207324bdd66SFarzon Lotfi assert(ArrayInit && "Expected a ConstantArray for array initializer!"); 208324bdd66SFarzon Lotfi 209324bdd66SFarzon Lotfi SmallVector<Constant *, MaxVecSize> NewArrayElements; 210324bdd66SFarzon Lotfi for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) { 211324bdd66SFarzon Lotfi // Recursively transform array elements 212324bdd66SFarzon Lotfi Constant *NewElemInit = transformInitializer( 213324bdd66SFarzon Lotfi ArrayInit->getOperand(I), ArrayTy->getElementType(), 214324bdd66SFarzon Lotfi cast<ArrayType>(NewType)->getElementType(), Ctx); 215324bdd66SFarzon Lotfi NewArrayElements.push_back(NewElemInit); 216324bdd66SFarzon Lotfi } 217324bdd66SFarzon Lotfi 218324bdd66SFarzon Lotfi return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements); 219324bdd66SFarzon Lotfi } 220324bdd66SFarzon Lotfi 221324bdd66SFarzon Lotfi // If not a vector or array, return the original initializer 222324bdd66SFarzon Lotfi return Init; 223324bdd66SFarzon Lotfi } 224324bdd66SFarzon Lotfi 225324bdd66SFarzon Lotfi static bool findAndReplaceVectors(Module &M) { 226324bdd66SFarzon Lotfi bool MadeChange = false; 227324bdd66SFarzon Lotfi LLVMContext &Ctx = M.getContext(); 228324bdd66SFarzon Lotfi IRBuilder<> Builder(Ctx); 229324bdd66SFarzon Lotfi DataScalarizerVisitor Impl; 230324bdd66SFarzon Lotfi for (GlobalVariable &G : M.globals()) { 231324bdd66SFarzon Lotfi Type *OrigType = G.getValueType(); 232324bdd66SFarzon Lotfi 233324bdd66SFarzon Lotfi Type *NewType = replaceVectorWithArray(OrigType, Ctx); 234324bdd66SFarzon Lotfi if (OrigType != NewType) { 235324bdd66SFarzon Lotfi // Create a new global variable with the updated type 236324bdd66SFarzon Lotfi // Note: Initializer is set via transformInitializer 237324bdd66SFarzon Lotfi GlobalVariable *NewGlobal = new GlobalVariable( 238324bdd66SFarzon Lotfi M, NewType, G.isConstant(), G.getLinkage(), 239324bdd66SFarzon Lotfi /*Initializer=*/nullptr, G.getName() + ".scalarized", &G, 240324bdd66SFarzon Lotfi G.getThreadLocalMode(), G.getAddressSpace(), 241324bdd66SFarzon Lotfi G.isExternallyInitialized()); 242324bdd66SFarzon Lotfi 243324bdd66SFarzon Lotfi // Copy relevant attributes 244324bdd66SFarzon Lotfi NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); 245324bdd66SFarzon Lotfi if (G.getAlignment() > 0) { 246324bdd66SFarzon Lotfi NewGlobal->setAlignment(G.getAlign()); 247324bdd66SFarzon Lotfi } 248324bdd66SFarzon Lotfi 249324bdd66SFarzon Lotfi if (G.hasInitializer()) { 250324bdd66SFarzon Lotfi Constant *Init = G.getInitializer(); 251324bdd66SFarzon Lotfi Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx); 252324bdd66SFarzon Lotfi NewGlobal->setInitializer(NewInit); 253324bdd66SFarzon Lotfi } 254324bdd66SFarzon Lotfi 255324bdd66SFarzon Lotfi // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes 256324bdd66SFarzon Lotfi // type equality. Instead we will use the visitor pattern. 257324bdd66SFarzon Lotfi Impl.GlobalMap[&G] = NewGlobal; 258324bdd66SFarzon Lotfi for (User *U : make_early_inc_range(G.users())) { 259324bdd66SFarzon Lotfi if (isa<ConstantExpr>(U) && isa<Operator>(U)) { 260324bdd66SFarzon Lotfi ConstantExpr *CE = cast<ConstantExpr>(U); 261*6457aee5SFarzon Lotfi for (User *UCE : make_early_inc_range(CE->users())) { 262*6457aee5SFarzon Lotfi if (Instruction *Inst = dyn_cast<Instruction>(UCE)) 263*6457aee5SFarzon Lotfi Impl.visit(*Inst); 264324bdd66SFarzon Lotfi } 265324bdd66SFarzon Lotfi } 266*6457aee5SFarzon Lotfi if (Instruction *Inst = dyn_cast<Instruction>(U)) 267*6457aee5SFarzon Lotfi Impl.visit(*Inst); 268324bdd66SFarzon Lotfi } 269324bdd66SFarzon Lotfi } 270324bdd66SFarzon Lotfi } 271324bdd66SFarzon Lotfi 272324bdd66SFarzon Lotfi // Remove the old globals after the iteration 273324bdd66SFarzon Lotfi for (auto &[Old, New] : Impl.GlobalMap) { 274324bdd66SFarzon Lotfi Old->eraseFromParent(); 275324bdd66SFarzon Lotfi MadeChange = true; 276324bdd66SFarzon Lotfi } 277324bdd66SFarzon Lotfi return MadeChange; 278324bdd66SFarzon Lotfi } 279324bdd66SFarzon Lotfi 280324bdd66SFarzon Lotfi PreservedAnalyses DXILDataScalarization::run(Module &M, 281324bdd66SFarzon Lotfi ModuleAnalysisManager &) { 282324bdd66SFarzon Lotfi bool MadeChanges = findAndReplaceVectors(M); 283324bdd66SFarzon Lotfi if (!MadeChanges) 284324bdd66SFarzon Lotfi return PreservedAnalyses::all(); 285324bdd66SFarzon Lotfi PreservedAnalyses PA; 286324bdd66SFarzon Lotfi return PA; 287324bdd66SFarzon Lotfi } 288324bdd66SFarzon Lotfi 289324bdd66SFarzon Lotfi bool DXILDataScalarizationLegacy::runOnModule(Module &M) { 290324bdd66SFarzon Lotfi return findAndReplaceVectors(M); 291324bdd66SFarzon Lotfi } 292324bdd66SFarzon Lotfi 293324bdd66SFarzon Lotfi char DXILDataScalarizationLegacy::ID = 0; 294324bdd66SFarzon Lotfi 295324bdd66SFarzon Lotfi INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE, 296324bdd66SFarzon Lotfi "DXIL Data Scalarization", false, false) 297324bdd66SFarzon Lotfi INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE, 298324bdd66SFarzon Lotfi "DXIL Data Scalarization", false, false) 299324bdd66SFarzon Lotfi 300324bdd66SFarzon Lotfi ModulePass *llvm::createDXILDataScalarizationLegacyPass() { 301324bdd66SFarzon Lotfi return new DXILDataScalarizationLegacy(); 302324bdd66SFarzon Lotfi } 303