//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===---------------------------------------------------------------------===// #include "DXILDataScalarization.h" #include "DirectX.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/ReplaceConstant.h" #include "llvm/IR/Type.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #define DEBUG_TYPE "dxil-data-scalarization" static const int MaxVecSize = 4; using namespace llvm; class DXILDataScalarizationLegacy : public ModulePass { public: bool runOnModule(Module &M) override; DXILDataScalarizationLegacy() : ModulePass(ID) {} static char ID; // Pass identification. }; static bool findAndReplaceVectors(Module &M); class DataScalarizerVisitor : public InstVisitor { public: DataScalarizerVisitor() : GlobalMap() {} bool visit(Instruction &I); // InstVisitor methods. They return true if the instruction was scalarized, // false if nothing changed. bool visitInstruction(Instruction &I) { return false; } bool visitSelectInst(SelectInst &SI) { return false; } bool visitICmpInst(ICmpInst &ICI) { return false; } bool visitFCmpInst(FCmpInst &FCI) { return false; } bool visitUnaryOperator(UnaryOperator &UO) { return false; } bool visitBinaryOperator(BinaryOperator &BO) { return false; } bool visitGetElementPtrInst(GetElementPtrInst &GEPI); bool visitCastInst(CastInst &CI) { return false; } bool visitBitCastInst(BitCastInst &BCI) { return false; } bool visitInsertElementInst(InsertElementInst &IEI) { return false; } bool visitExtractElementInst(ExtractElementInst &EEI) { return false; } bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; } bool visitPHINode(PHINode &PHI) { return false; } bool visitLoadInst(LoadInst &LI); bool visitStoreInst(StoreInst &SI); bool visitCallInst(CallInst &ICI) { return false; } bool visitFreezeInst(FreezeInst &FI) { return false; } friend bool findAndReplaceVectors(llvm::Module &M); private: GlobalVariable *lookupReplacementGlobal(Value *CurrOperand); DenseMap GlobalMap; }; bool DataScalarizerVisitor::visit(Instruction &I) { assert(!GlobalMap.empty()); return InstVisitor::visit(I); } GlobalVariable * DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) { if (GlobalVariable *OldGlobal = dyn_cast(CurrOperand)) { auto It = GlobalMap.find(OldGlobal); if (It != GlobalMap.end()) { return It->second; // Found, return the new global } } return nullptr; // Not found } bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) { unsigned NumOperands = LI.getNumOperands(); for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = LI.getOperand(I); ConstantExpr *CE = dyn_cast(CurrOpperand); if (CE && CE->getOpcode() == Instruction::GetElementPtr) { GetElementPtrInst *OldGEP = cast(CE->getAsInstruction()); OldGEP->insertBefore(&LI); IRBuilder<> Builder(&LI); LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); NewLoad->setAlignment(LI.getAlign()); LI.replaceAllUsesWith(NewLoad); LI.eraseFromParent(); visitGetElementPtrInst(*OldGEP); return true; } if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) LI.setOperand(I, NewGlobal); } return false; } bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) { unsigned NumOperands = SI.getNumOperands(); for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = SI.getOperand(I); ConstantExpr *CE = dyn_cast(CurrOpperand); if (CE && CE->getOpcode() == Instruction::GetElementPtr) { GetElementPtrInst *OldGEP = cast(CE->getAsInstruction()); OldGEP->insertBefore(&SI); IRBuilder<> Builder(&SI); StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); NewStore->setAlignment(SI.getAlign()); SI.replaceAllUsesWith(NewStore); SI.eraseFromParent(); visitGetElementPtrInst(*OldGEP); return true; } if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) SI.setOperand(I, NewGlobal); } return false; } bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { unsigned NumOperands = GEPI.getNumOperands(); GlobalVariable *NewGlobal = nullptr; for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = GEPI.getOperand(I); NewGlobal = lookupReplacementGlobal(CurrOpperand); if (NewGlobal) break; } if (!NewGlobal) return false; IRBuilder<> Builder(&GEPI); SmallVector Indices; for (auto &Index : GEPI.indices()) Indices.push_back(Index); Value *NewGEP = Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices, GEPI.getName(), GEPI.getNoWrapFlags()); GEPI.replaceAllUsesWith(NewGEP); GEPI.eraseFromParent(); return true; } // Recursively Creates and Array like version of the given vector like type. static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) { if (auto *VecTy = dyn_cast(T)) return ArrayType::get(VecTy->getElementType(), dyn_cast(VecTy)->getNumElements()); if (auto *ArrayTy = dyn_cast(T)) { Type *NewElementType = replaceVectorWithArray(ArrayTy->getElementType(), Ctx); return ArrayType::get(NewElementType, ArrayTy->getNumElements()); } // If it's not a vector or array, return the original type. return T; } Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, LLVMContext &Ctx) { // Handle ConstantAggregateZero (zero-initialized constants) if (isa(Init)) { return ConstantAggregateZero::get(NewType); } // Handle UndefValue (undefined constants) if (isa(Init)) { return UndefValue::get(NewType); } // Handle vector to array transformation if (isa(OrigType) && isa(NewType)) { // Convert vector initializer to array initializer SmallVector ArrayElements; if (ConstantVector *ConstVecInit = dyn_cast(Init)) { for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I) ArrayElements.push_back(ConstVecInit->getOperand(I)); } else if (ConstantDataVector *ConstDataVecInit = llvm::dyn_cast(Init)) { for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I) ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I)); } else { assert(false && "Expected a ConstantVector or ConstantDataVector for " "vector initializer!"); } return ConstantArray::get(cast(NewType), ArrayElements); } // Handle array of vectors transformation if (auto *ArrayTy = dyn_cast(OrigType)) { auto *ArrayInit = dyn_cast(Init); assert(ArrayInit && "Expected a ConstantArray for array initializer!"); SmallVector NewArrayElements; for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) { // Recursively transform array elements Constant *NewElemInit = transformInitializer( ArrayInit->getOperand(I), ArrayTy->getElementType(), cast(NewType)->getElementType(), Ctx); NewArrayElements.push_back(NewElemInit); } return ConstantArray::get(cast(NewType), NewArrayElements); } // If not a vector or array, return the original initializer return Init; } static bool findAndReplaceVectors(Module &M) { bool MadeChange = false; LLVMContext &Ctx = M.getContext(); IRBuilder<> Builder(Ctx); DataScalarizerVisitor Impl; for (GlobalVariable &G : M.globals()) { Type *OrigType = G.getValueType(); Type *NewType = replaceVectorWithArray(OrigType, Ctx); if (OrigType != NewType) { // Create a new global variable with the updated type // Note: Initializer is set via transformInitializer GlobalVariable *NewGlobal = new GlobalVariable( M, NewType, G.isConstant(), G.getLinkage(), /*Initializer=*/nullptr, G.getName() + ".scalarized", &G, G.getThreadLocalMode(), G.getAddressSpace(), G.isExternallyInitialized()); // Copy relevant attributes NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); if (G.getAlignment() > 0) { NewGlobal->setAlignment(G.getAlign()); } if (G.hasInitializer()) { Constant *Init = G.getInitializer(); Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx); NewGlobal->setInitializer(NewInit); } // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes // type equality. Instead we will use the visitor pattern. Impl.GlobalMap[&G] = NewGlobal; for (User *U : make_early_inc_range(G.users())) { if (isa(U) && isa(U)) { ConstantExpr *CE = cast(U); for (User *UCE : make_early_inc_range(CE->users())) { if (Instruction *Inst = dyn_cast(UCE)) Impl.visit(*Inst); } } if (Instruction *Inst = dyn_cast(U)) Impl.visit(*Inst); } } } // Remove the old globals after the iteration for (auto &[Old, New] : Impl.GlobalMap) { Old->eraseFromParent(); MadeChange = true; } return MadeChange; } PreservedAnalyses DXILDataScalarization::run(Module &M, ModuleAnalysisManager &) { bool MadeChanges = findAndReplaceVectors(M); if (!MadeChanges) return PreservedAnalyses::all(); PreservedAnalyses PA; return PA; } bool DXILDataScalarizationLegacy::runOnModule(Module &M) { return findAndReplaceVectors(M); } char DXILDataScalarizationLegacy::ID = 0; INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE, "DXIL Data Scalarization", false, false) INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE, "DXIL Data Scalarization", false, false) ModulePass *llvm::createDXILDataScalarizationLegacyPass() { return new DXILDataScalarizationLegacy(); }