//===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===---------------------------------------------------------------------===// /// /// \file This file contains a pass to flatten arrays for the DirectX Backend. /// //===----------------------------------------------------------------------===// #include "DXILFlattenArrays.h" #include "DirectX.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/ReplaceConstant.h" #include "llvm/Support/Casting.h" #include "llvm/Transforms/Utils/Local.h" #include #include #include #include #define DEBUG_TYPE "dxil-flatten-arrays" using namespace llvm; namespace { class DXILFlattenArraysLegacy : public ModulePass { public: bool runOnModule(Module &M) override; DXILFlattenArraysLegacy() : ModulePass(ID) {} static char ID; // Pass identification. }; struct GEPData { ArrayType *ParentArrayType; Value *ParendOperand; SmallVector Indices; SmallVector Dims; bool AllIndicesAreConstInt; }; class DXILFlattenArraysVisitor : public InstVisitor { public: DXILFlattenArraysVisitor() {} bool visit(Function &F); // InstVisitor methods. They return true if the instruction was scalarized, // false if nothing changed. bool visitGetElementPtrInst(GetElementPtrInst &GEPI); bool visitAllocaInst(AllocaInst &AI); bool visitInstruction(Instruction &I) { return false; } bool visitSelectInst(SelectInst &SI) { return false; } bool visitICmpInst(ICmpInst &ICI) { return false; } bool visitFCmpInst(FCmpInst &FCI) { return false; } bool visitUnaryOperator(UnaryOperator &UO) { return false; } bool visitBinaryOperator(BinaryOperator &BO) { return false; } bool visitCastInst(CastInst &CI) { return false; } bool visitBitCastInst(BitCastInst &BCI) { return false; } bool visitInsertElementInst(InsertElementInst &IEI) { return false; } bool visitExtractElementInst(ExtractElementInst &EEI) { return false; } bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; } bool visitPHINode(PHINode &PHI) { return false; } bool visitLoadInst(LoadInst &LI); bool visitStoreInst(StoreInst &SI); bool visitCallInst(CallInst &ICI) { return false; } bool visitFreezeInst(FreezeInst &FI) { return false; } static bool isMultiDimensionalArray(Type *T); static std::pair getElementCountAndType(Type *ArrayTy); private: SmallVector PotentiallyDeadInstrs; DenseMap GEPChainMap; bool finish(); ConstantInt *genConstFlattenIndices(ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder); Value *genInstructionFlattenIndices(ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder); void recursivelyCollectGEPs(GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType, Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector Indices = SmallVector(), SmallVector Dims = SmallVector(), bool AllIndicesAreConstInt = true); bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP); bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo, GetElementPtrInst &GEP); }; } // namespace bool DXILFlattenArraysVisitor::finish() { RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); return true; } bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) { if (ArrayType *ArrType = dyn_cast(T)) return isa(ArrType->getElementType()); return false; } std::pair DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) { unsigned TotalElements = 1; Type *CurrArrayTy = ArrayTy; while (auto *InnerArrayTy = dyn_cast(CurrArrayTy)) { TotalElements *= InnerArrayTy->getNumElements(); CurrArrayTy = InnerArrayTy->getElementType(); } return std::make_pair(TotalElements, CurrArrayTy); } ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices( ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder) { assert(Indices.size() == Dims.size() && "Indicies and dimmensions should be the same"); unsigned FlatIndex = 0; unsigned Multiplier = 1; for (int I = Indices.size() - 1; I >= 0; --I) { unsigned DimSize = Dims[I]; ConstantInt *CIndex = dyn_cast(Indices[I]); assert(CIndex && "This function expects all indicies to be ConstantInt"); FlatIndex += CIndex->getZExtValue() * Multiplier; Multiplier *= DimSize; } return Builder.getInt32(FlatIndex); } Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices( ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder) { if (Indices.size() == 1) return Indices[0]; Value *FlatIndex = Builder.getInt32(0); unsigned Multiplier = 1; for (int I = Indices.size() - 1; I >= 0; --I) { unsigned DimSize = Dims[I]; Value *VMultiplier = Builder.getInt32(Multiplier); Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier); FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex); Multiplier *= DimSize; } return FlatIndex; } bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) { unsigned NumOperands = LI.getNumOperands(); for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = LI.getOperand(I); ConstantExpr *CE = dyn_cast(CurrOpperand); if (CE && CE->getOpcode() == Instruction::GetElementPtr) { GetElementPtrInst *OldGEP = cast(CE->getAsInstruction()); OldGEP->insertBefore(&LI); IRBuilder<> Builder(&LI); LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); NewLoad->setAlignment(LI.getAlign()); LI.replaceAllUsesWith(NewLoad); LI.eraseFromParent(); visitGetElementPtrInst(*OldGEP); return true; } } return false; } bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) { unsigned NumOperands = SI.getNumOperands(); for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = SI.getOperand(I); ConstantExpr *CE = dyn_cast(CurrOpperand); if (CE && CE->getOpcode() == Instruction::GetElementPtr) { GetElementPtrInst *OldGEP = cast(CE->getAsInstruction()); OldGEP->insertBefore(&SI); IRBuilder<> Builder(&SI); StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); NewStore->setAlignment(SI.getAlign()); SI.replaceAllUsesWith(NewStore); SI.eraseFromParent(); visitGetElementPtrInst(*OldGEP); return true; } } return false; } bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) { if (!isMultiDimensionalArray(AI.getAllocatedType())) return false; ArrayType *ArrType = cast(AI.getAllocatedType()); IRBuilder<> Builder(&AI); auto [TotalElements, BaseType] = getElementCountAndType(ArrType); ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements); AllocaInst *FlatAlloca = Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat"); FlatAlloca->setAlignment(AI.getAlign()); AI.replaceAllUsesWith(FlatAlloca); AI.eraseFromParent(); return true; } void DXILFlattenArraysVisitor::recursivelyCollectGEPs( GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType, Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector Indices, SmallVector Dims, bool AllIndicesAreConstInt) { Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1); AllIndicesAreConstInt &= isa(LastIndex); Indices.push_back(LastIndex); assert(isa(CurrGEP.getSourceElementType())); Dims.push_back( cast(CurrGEP.getSourceElementType())->getNumElements()); bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType()); if (!IsMultiDimArr) { assert(GEPChainUseCount < FlattenedArrayType->getNumElements()); GEPChainMap.insert( {&CurrGEP, {std::move(FlattenedArrayType), PtrOperand, std::move(Indices), std::move(Dims), AllIndicesAreConstInt}}); return; } bool GepUses = false; for (auto *User : CurrGEP.users()) { if (GetElementPtrInst *NestedGEP = dyn_cast(User)) { recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand, ++GEPChainUseCount, Indices, Dims, AllIndicesAreConstInt); GepUses = true; } } // This case is just incase the gep chain doesn't end with a 1d array. if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) { GEPChainMap.insert( {&CurrGEP, {std::move(FlattenedArrayType), PtrOperand, std::move(Indices), std::move(Dims), AllIndicesAreConstInt}}); } } bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain( GetElementPtrInst &GEP) { GEPData GEPInfo = GEPChainMap.at(&GEP); return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP); } bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase( GEPData &GEPInfo, GetElementPtrInst &GEP) { IRBuilder<> Builder(&GEP); Value *FlatIndex; if (GEPInfo.AllIndicesAreConstInt) FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); else FlatIndex = genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType; Value *FlatGEP = Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex, GEP.getName() + ".flat", GEP.isInBounds()); GEP.replaceAllUsesWith(FlatGEP); GEP.eraseFromParent(); return true; } bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { auto It = GEPChainMap.find(&GEP); if (It != GEPChainMap.end()) return visitGetElementPtrInstInGEPChain(GEP); if (!isMultiDimensionalArray(GEP.getSourceElementType())) return false; ArrayType *ArrType = cast(GEP.getSourceElementType()); IRBuilder<> Builder(&GEP); auto [TotalElements, BaseType] = getElementCountAndType(ArrType); ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements); Value *PtrOperand = GEP.getPointerOperand(); unsigned GEPChainUseCount = 0; recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount); // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0. // Here recursion is used to get the length of the GEP chain. // Handle zero uses here because there won't be an update via // a child in the chain later. if (GEPChainUseCount == 0) { SmallVector Indices({GEP.getOperand(GEP.getNumOperands() - 1)}); SmallVector Dims({ArrType->getNumElements()}); bool AllIndicesAreConstInt = isa(Indices[0]); GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand, std::move(Indices), std::move(Dims), AllIndicesAreConstInt}; return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP); } PotentiallyDeadInstrs.emplace_back(&GEP); return false; } bool DXILFlattenArraysVisitor::visit(Function &F) { bool MadeChange = false; ReversePostOrderTraversal RPOT(&F); for (BasicBlock *BB : make_early_inc_range(RPOT)) { for (Instruction &I : make_early_inc_range(*BB)) MadeChange |= InstVisitor::visit(I); } finish(); return MadeChange; } static void collectElements(Constant *Init, SmallVectorImpl &Elements) { // Base case: If Init is not an array, add it directly to the vector. auto *ArrayTy = dyn_cast(Init->getType()); if (!ArrayTy) { Elements.push_back(Init); return; } unsigned ArrSize = ArrayTy->getNumElements(); if (isa(Init)) { for (unsigned I = 0; I < ArrSize; ++I) Elements.push_back(Constant::getNullValue(ArrayTy->getElementType())); return; } // Recursive case: Process each element in the array. if (auto *ArrayConstant = dyn_cast(Init)) { for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) { collectElements(ArrayConstant->getOperand(I), Elements); } } else if (auto *DataArrayConstant = dyn_cast(Init)) { for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) { collectElements(DataArrayConstant->getElementAsConstant(I), Elements); } } else { llvm_unreachable( "Expected a ConstantArray or ConstantDataArray for array initializer!"); } } static Constant *transformInitializer(Constant *Init, Type *OrigType, ArrayType *FlattenedType, LLVMContext &Ctx) { // Handle ConstantAggregateZero (zero-initialized constants) if (isa(Init)) return ConstantAggregateZero::get(FlattenedType); // Handle UndefValue (undefined constants) if (isa(Init)) return UndefValue::get(FlattenedType); if (!isa(OrigType)) return Init; SmallVector FlattenedElements; collectElements(Init, FlattenedElements); assert(FlattenedType->getNumElements() == FlattenedElements.size() && "The number of collected elements should match the FlattenedType"); return ConstantArray::get(FlattenedType, FlattenedElements); } static void flattenGlobalArrays(Module &M, DenseMap &GlobalMap) { LLVMContext &Ctx = M.getContext(); for (GlobalVariable &G : M.globals()) { Type *OrigType = G.getValueType(); if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType)) continue; ArrayType *ArrType = cast(OrigType); auto [TotalElements, BaseType] = DXILFlattenArraysVisitor::getElementCountAndType(ArrType); ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements); // Create a new global variable with the updated type // Note: Initializer is set via transformInitializer GlobalVariable *NewGlobal = new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(), /*Initializer=*/nullptr, G.getName() + ".1dim", &G, G.getThreadLocalMode(), G.getAddressSpace(), G.isExternallyInitialized()); // Copy relevant attributes NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); if (G.getAlignment() > 0) { NewGlobal->setAlignment(G.getAlign()); } if (G.hasInitializer()) { Constant *Init = G.getInitializer(); Constant *NewInit = transformInitializer(Init, OrigType, FattenedArrayType, Ctx); NewGlobal->setInitializer(NewInit); } GlobalMap[&G] = NewGlobal; } } static bool flattenArrays(Module &M) { bool MadeChange = false; DXILFlattenArraysVisitor Impl; DenseMap GlobalMap; flattenGlobalArrays(M, GlobalMap); for (auto &F : make_early_inc_range(M.functions())) { if (F.isDeclaration()) continue; MadeChange |= Impl.visit(F); } for (auto &[Old, New] : GlobalMap) { Old->replaceAllUsesWith(New); Old->eraseFromParent(); MadeChange = true; } return MadeChange; } PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) { bool MadeChanges = flattenArrays(M); if (!MadeChanges) return PreservedAnalyses::all(); PreservedAnalyses PA; return PA; } bool DXILFlattenArraysLegacy::runOnModule(Module &M) { return flattenArrays(M); } char DXILFlattenArraysLegacy::ID = 0; INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener", false, false) INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener", false, false) ModulePass *llvm::createDXILFlattenArraysLegacyPass() { return new DXILFlattenArraysLegacy(); }