xref: /llvm-project/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp (revision 6457aee5b7da6bb6d7f556d14f42a6763b42e060)
15ac624c8SFarzon Lotfi //===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
25ac624c8SFarzon Lotfi //
35ac624c8SFarzon Lotfi // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45ac624c8SFarzon Lotfi // See https://llvm.org/LICENSE.txt for license information.
55ac624c8SFarzon Lotfi // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65ac624c8SFarzon Lotfi //
75ac624c8SFarzon Lotfi //===---------------------------------------------------------------------===//
85ac624c8SFarzon Lotfi ///
95ac624c8SFarzon Lotfi /// \file This file contains a pass to flatten arrays for the DirectX Backend.
105ac624c8SFarzon Lotfi ///
115ac624c8SFarzon Lotfi //===----------------------------------------------------------------------===//
125ac624c8SFarzon Lotfi 
135ac624c8SFarzon Lotfi #include "DXILFlattenArrays.h"
145ac624c8SFarzon Lotfi #include "DirectX.h"
155ac624c8SFarzon Lotfi #include "llvm/ADT/PostOrderIterator.h"
165ac624c8SFarzon Lotfi #include "llvm/ADT/STLExtras.h"
175ac624c8SFarzon Lotfi #include "llvm/IR/BasicBlock.h"
185ac624c8SFarzon Lotfi #include "llvm/IR/DerivedTypes.h"
195ac624c8SFarzon Lotfi #include "llvm/IR/IRBuilder.h"
205ac624c8SFarzon Lotfi #include "llvm/IR/InstVisitor.h"
215ac624c8SFarzon Lotfi #include "llvm/IR/ReplaceConstant.h"
225ac624c8SFarzon Lotfi #include "llvm/Support/Casting.h"
235ac624c8SFarzon Lotfi #include "llvm/Transforms/Utils/Local.h"
245ac624c8SFarzon Lotfi #include <cassert>
255ac624c8SFarzon Lotfi #include <cstddef>
265ac624c8SFarzon Lotfi #include <cstdint>
275ac624c8SFarzon Lotfi #include <utility>
285ac624c8SFarzon Lotfi 
295ac624c8SFarzon Lotfi #define DEBUG_TYPE "dxil-flatten-arrays"
305ac624c8SFarzon Lotfi 
315ac624c8SFarzon Lotfi using namespace llvm;
325ac624c8SFarzon Lotfi namespace {
335ac624c8SFarzon Lotfi 
345ac624c8SFarzon Lotfi class DXILFlattenArraysLegacy : public ModulePass {
355ac624c8SFarzon Lotfi 
365ac624c8SFarzon Lotfi public:
375ac624c8SFarzon Lotfi   bool runOnModule(Module &M) override;
385ac624c8SFarzon Lotfi   DXILFlattenArraysLegacy() : ModulePass(ID) {}
395ac624c8SFarzon Lotfi 
405ac624c8SFarzon Lotfi   static char ID; // Pass identification.
415ac624c8SFarzon Lotfi };
425ac624c8SFarzon Lotfi 
435ac624c8SFarzon Lotfi struct GEPData {
445ac624c8SFarzon Lotfi   ArrayType *ParentArrayType;
455ac624c8SFarzon Lotfi   Value *ParendOperand;
465ac624c8SFarzon Lotfi   SmallVector<Value *> Indices;
475ac624c8SFarzon Lotfi   SmallVector<uint64_t> Dims;
485ac624c8SFarzon Lotfi   bool AllIndicesAreConstInt;
495ac624c8SFarzon Lotfi };
505ac624c8SFarzon Lotfi 
515ac624c8SFarzon Lotfi class DXILFlattenArraysVisitor
525ac624c8SFarzon Lotfi     : public InstVisitor<DXILFlattenArraysVisitor, bool> {
535ac624c8SFarzon Lotfi public:
545ac624c8SFarzon Lotfi   DXILFlattenArraysVisitor() {}
555ac624c8SFarzon Lotfi   bool visit(Function &F);
565ac624c8SFarzon Lotfi   // InstVisitor methods.  They return true if the instruction was scalarized,
575ac624c8SFarzon Lotfi   // false if nothing changed.
585ac624c8SFarzon Lotfi   bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
595ac624c8SFarzon Lotfi   bool visitAllocaInst(AllocaInst &AI);
605ac624c8SFarzon Lotfi   bool visitInstruction(Instruction &I) { return false; }
615ac624c8SFarzon Lotfi   bool visitSelectInst(SelectInst &SI) { return false; }
625ac624c8SFarzon Lotfi   bool visitICmpInst(ICmpInst &ICI) { return false; }
635ac624c8SFarzon Lotfi   bool visitFCmpInst(FCmpInst &FCI) { return false; }
645ac624c8SFarzon Lotfi   bool visitUnaryOperator(UnaryOperator &UO) { return false; }
655ac624c8SFarzon Lotfi   bool visitBinaryOperator(BinaryOperator &BO) { return false; }
665ac624c8SFarzon Lotfi   bool visitCastInst(CastInst &CI) { return false; }
675ac624c8SFarzon Lotfi   bool visitBitCastInst(BitCastInst &BCI) { return false; }
685ac624c8SFarzon Lotfi   bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
695ac624c8SFarzon Lotfi   bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
705ac624c8SFarzon Lotfi   bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
715ac624c8SFarzon Lotfi   bool visitPHINode(PHINode &PHI) { return false; }
725ac624c8SFarzon Lotfi   bool visitLoadInst(LoadInst &LI);
735ac624c8SFarzon Lotfi   bool visitStoreInst(StoreInst &SI);
745ac624c8SFarzon Lotfi   bool visitCallInst(CallInst &ICI) { return false; }
755ac624c8SFarzon Lotfi   bool visitFreezeInst(FreezeInst &FI) { return false; }
765ac624c8SFarzon Lotfi   static bool isMultiDimensionalArray(Type *T);
775ac624c8SFarzon Lotfi   static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);
785ac624c8SFarzon Lotfi 
795ac624c8SFarzon Lotfi private:
805ac624c8SFarzon Lotfi   SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
815ac624c8SFarzon Lotfi   DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
825ac624c8SFarzon Lotfi   bool finish();
835ac624c8SFarzon Lotfi   ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
845ac624c8SFarzon Lotfi                                       ArrayRef<uint64_t> Dims,
855ac624c8SFarzon Lotfi                                       IRBuilder<> &Builder);
865ac624c8SFarzon Lotfi   Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
875ac624c8SFarzon Lotfi                                       ArrayRef<uint64_t> Dims,
885ac624c8SFarzon Lotfi                                       IRBuilder<> &Builder);
895ac624c8SFarzon Lotfi   void
905ac624c8SFarzon Lotfi   recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
915ac624c8SFarzon Lotfi                          ArrayType *FlattenedArrayType, Value *PtrOperand,
925ac624c8SFarzon Lotfi                          unsigned &GEPChainUseCount,
935ac624c8SFarzon Lotfi                          SmallVector<Value *> Indices = SmallVector<Value *>(),
945ac624c8SFarzon Lotfi                          SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
955ac624c8SFarzon Lotfi                          bool AllIndicesAreConstInt = true);
965ac624c8SFarzon Lotfi   bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
975ac624c8SFarzon Lotfi   bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
985ac624c8SFarzon Lotfi                                             GetElementPtrInst &GEP);
995ac624c8SFarzon Lotfi };
1005ac624c8SFarzon Lotfi } // namespace
1015ac624c8SFarzon Lotfi 
1025ac624c8SFarzon Lotfi bool DXILFlattenArraysVisitor::finish() {
1035ac624c8SFarzon Lotfi   RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
1045ac624c8SFarzon Lotfi   return true;
1055ac624c8SFarzon Lotfi }
1065ac624c8SFarzon Lotfi 
1075ac624c8SFarzon Lotfi bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
1085ac624c8SFarzon Lotfi   if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
1095ac624c8SFarzon Lotfi     return isa<ArrayType>(ArrType->getElementType());
1105ac624c8SFarzon Lotfi   return false;
1115ac624c8SFarzon Lotfi }
1125ac624c8SFarzon Lotfi 
1135ac624c8SFarzon Lotfi std::pair<unsigned, Type *>
1145ac624c8SFarzon Lotfi DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {
1155ac624c8SFarzon Lotfi   unsigned TotalElements = 1;
1165ac624c8SFarzon Lotfi   Type *CurrArrayTy = ArrayTy;
1175ac624c8SFarzon Lotfi   while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
1185ac624c8SFarzon Lotfi     TotalElements *= InnerArrayTy->getNumElements();
1195ac624c8SFarzon Lotfi     CurrArrayTy = InnerArrayTy->getElementType();
1205ac624c8SFarzon Lotfi   }
1215ac624c8SFarzon Lotfi   return std::make_pair(TotalElements, CurrArrayTy);
1225ac624c8SFarzon Lotfi }
1235ac624c8SFarzon Lotfi 
1245ac624c8SFarzon Lotfi ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
1255ac624c8SFarzon Lotfi     ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
1265ac624c8SFarzon Lotfi   assert(Indices.size() == Dims.size() &&
1275ac624c8SFarzon Lotfi          "Indicies and dimmensions should be the same");
1285ac624c8SFarzon Lotfi   unsigned FlatIndex = 0;
1295ac624c8SFarzon Lotfi   unsigned Multiplier = 1;
1305ac624c8SFarzon Lotfi 
1315ac624c8SFarzon Lotfi   for (int I = Indices.size() - 1; I >= 0; --I) {
1325ac624c8SFarzon Lotfi     unsigned DimSize = Dims[I];
1335ac624c8SFarzon Lotfi     ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[I]);
1345ac624c8SFarzon Lotfi     assert(CIndex && "This function expects all indicies to be ConstantInt");
1355ac624c8SFarzon Lotfi     FlatIndex += CIndex->getZExtValue() * Multiplier;
1365ac624c8SFarzon Lotfi     Multiplier *= DimSize;
1375ac624c8SFarzon Lotfi   }
1385ac624c8SFarzon Lotfi   return Builder.getInt32(FlatIndex);
1395ac624c8SFarzon Lotfi }
1405ac624c8SFarzon Lotfi 
1415ac624c8SFarzon Lotfi Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
1425ac624c8SFarzon Lotfi     ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
1435ac624c8SFarzon Lotfi   if (Indices.size() == 1)
1445ac624c8SFarzon Lotfi     return Indices[0];
1455ac624c8SFarzon Lotfi 
1465ac624c8SFarzon Lotfi   Value *FlatIndex = Builder.getInt32(0);
1475ac624c8SFarzon Lotfi   unsigned Multiplier = 1;
1485ac624c8SFarzon Lotfi 
1495ac624c8SFarzon Lotfi   for (int I = Indices.size() - 1; I >= 0; --I) {
1505ac624c8SFarzon Lotfi     unsigned DimSize = Dims[I];
1515ac624c8SFarzon Lotfi     Value *VMultiplier = Builder.getInt32(Multiplier);
1525ac624c8SFarzon Lotfi     Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);
1535ac624c8SFarzon Lotfi     FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
1545ac624c8SFarzon Lotfi     Multiplier *= DimSize;
1555ac624c8SFarzon Lotfi   }
1565ac624c8SFarzon Lotfi   return FlatIndex;
1575ac624c8SFarzon Lotfi }
1585ac624c8SFarzon Lotfi 
1595ac624c8SFarzon Lotfi bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
1605ac624c8SFarzon Lotfi   unsigned NumOperands = LI.getNumOperands();
1615ac624c8SFarzon Lotfi   for (unsigned I = 0; I < NumOperands; ++I) {
1625ac624c8SFarzon Lotfi     Value *CurrOpperand = LI.getOperand(I);
1635ac624c8SFarzon Lotfi     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
1645ac624c8SFarzon Lotfi     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
165*6457aee5SFarzon Lotfi       GetElementPtrInst *OldGEP =
166*6457aee5SFarzon Lotfi           cast<GetElementPtrInst>(CE->getAsInstruction());
167*6457aee5SFarzon Lotfi       OldGEP->insertBefore(&LI);
168*6457aee5SFarzon Lotfi 
169*6457aee5SFarzon Lotfi       IRBuilder<> Builder(&LI);
170*6457aee5SFarzon Lotfi       LoadInst *NewLoad =
171*6457aee5SFarzon Lotfi           Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
172*6457aee5SFarzon Lotfi       NewLoad->setAlignment(LI.getAlign());
173*6457aee5SFarzon Lotfi       LI.replaceAllUsesWith(NewLoad);
174*6457aee5SFarzon Lotfi       LI.eraseFromParent();
175*6457aee5SFarzon Lotfi       visitGetElementPtrInst(*OldGEP);
176*6457aee5SFarzon Lotfi       return true;
1775ac624c8SFarzon Lotfi     }
1785ac624c8SFarzon Lotfi   }
1795ac624c8SFarzon Lotfi   return false;
1805ac624c8SFarzon Lotfi }
1815ac624c8SFarzon Lotfi 
1825ac624c8SFarzon Lotfi bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
1835ac624c8SFarzon Lotfi   unsigned NumOperands = SI.getNumOperands();
1845ac624c8SFarzon Lotfi   for (unsigned I = 0; I < NumOperands; ++I) {
1855ac624c8SFarzon Lotfi     Value *CurrOpperand = SI.getOperand(I);
1865ac624c8SFarzon Lotfi     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
1875ac624c8SFarzon Lotfi     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
188*6457aee5SFarzon Lotfi       GetElementPtrInst *OldGEP =
189*6457aee5SFarzon Lotfi           cast<GetElementPtrInst>(CE->getAsInstruction());
190*6457aee5SFarzon Lotfi       OldGEP->insertBefore(&SI);
191*6457aee5SFarzon Lotfi 
192*6457aee5SFarzon Lotfi       IRBuilder<> Builder(&SI);
193*6457aee5SFarzon Lotfi       StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
194*6457aee5SFarzon Lotfi       NewStore->setAlignment(SI.getAlign());
195*6457aee5SFarzon Lotfi       SI.replaceAllUsesWith(NewStore);
196*6457aee5SFarzon Lotfi       SI.eraseFromParent();
197*6457aee5SFarzon Lotfi       visitGetElementPtrInst(*OldGEP);
198*6457aee5SFarzon Lotfi       return true;
1995ac624c8SFarzon Lotfi     }
2005ac624c8SFarzon Lotfi   }
2015ac624c8SFarzon Lotfi   return false;
2025ac624c8SFarzon Lotfi }
2035ac624c8SFarzon Lotfi 
2045ac624c8SFarzon Lotfi bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
2055ac624c8SFarzon Lotfi   if (!isMultiDimensionalArray(AI.getAllocatedType()))
2065ac624c8SFarzon Lotfi     return false;
2075ac624c8SFarzon Lotfi 
2085ac624c8SFarzon Lotfi   ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
2095ac624c8SFarzon Lotfi   IRBuilder<> Builder(&AI);
2105ac624c8SFarzon Lotfi   auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
2115ac624c8SFarzon Lotfi 
2125ac624c8SFarzon Lotfi   ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
2135ac624c8SFarzon Lotfi   AllocaInst *FlatAlloca =
2145ac624c8SFarzon Lotfi       Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat");
2155ac624c8SFarzon Lotfi   FlatAlloca->setAlignment(AI.getAlign());
2165ac624c8SFarzon Lotfi   AI.replaceAllUsesWith(FlatAlloca);
2175ac624c8SFarzon Lotfi   AI.eraseFromParent();
2185ac624c8SFarzon Lotfi   return true;
2195ac624c8SFarzon Lotfi }
2205ac624c8SFarzon Lotfi 
2215ac624c8SFarzon Lotfi void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
2225ac624c8SFarzon Lotfi     GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
2235ac624c8SFarzon Lotfi     Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
2245ac624c8SFarzon Lotfi     SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
2255ac624c8SFarzon Lotfi   Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
2265ac624c8SFarzon Lotfi   AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
2275ac624c8SFarzon Lotfi   Indices.push_back(LastIndex);
2285ac624c8SFarzon Lotfi   assert(isa<ArrayType>(CurrGEP.getSourceElementType()));
2295ac624c8SFarzon Lotfi   Dims.push_back(
2305ac624c8SFarzon Lotfi       cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
2315ac624c8SFarzon Lotfi   bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
2325ac624c8SFarzon Lotfi   if (!IsMultiDimArr) {
2335ac624c8SFarzon Lotfi     assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
2345ac624c8SFarzon Lotfi     GEPChainMap.insert(
2355ac624c8SFarzon Lotfi         {&CurrGEP,
2365ac624c8SFarzon Lotfi          {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
2375ac624c8SFarzon Lotfi           std::move(Dims), AllIndicesAreConstInt}});
2385ac624c8SFarzon Lotfi     return;
2395ac624c8SFarzon Lotfi   }
2405ac624c8SFarzon Lotfi   bool GepUses = false;
2415ac624c8SFarzon Lotfi   for (auto *User : CurrGEP.users()) {
2425ac624c8SFarzon Lotfi     if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
2435ac624c8SFarzon Lotfi       recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
2445ac624c8SFarzon Lotfi                              ++GEPChainUseCount, Indices, Dims,
2455ac624c8SFarzon Lotfi                              AllIndicesAreConstInt);
2465ac624c8SFarzon Lotfi       GepUses = true;
2475ac624c8SFarzon Lotfi     }
2485ac624c8SFarzon Lotfi   }
2495ac624c8SFarzon Lotfi   // This case is just incase the gep chain doesn't end with a 1d array.
2505ac624c8SFarzon Lotfi   if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
2515ac624c8SFarzon Lotfi     GEPChainMap.insert(
2525ac624c8SFarzon Lotfi         {&CurrGEP,
2535ac624c8SFarzon Lotfi          {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
2545ac624c8SFarzon Lotfi           std::move(Dims), AllIndicesAreConstInt}});
2555ac624c8SFarzon Lotfi   }
2565ac624c8SFarzon Lotfi }
2575ac624c8SFarzon Lotfi 
2585ac624c8SFarzon Lotfi bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
2595ac624c8SFarzon Lotfi     GetElementPtrInst &GEP) {
2605ac624c8SFarzon Lotfi   GEPData GEPInfo = GEPChainMap.at(&GEP);
2615ac624c8SFarzon Lotfi   return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
2625ac624c8SFarzon Lotfi }
2635ac624c8SFarzon Lotfi bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
2645ac624c8SFarzon Lotfi     GEPData &GEPInfo, GetElementPtrInst &GEP) {
2655ac624c8SFarzon Lotfi   IRBuilder<> Builder(&GEP);
2665ac624c8SFarzon Lotfi   Value *FlatIndex;
2675ac624c8SFarzon Lotfi   if (GEPInfo.AllIndicesAreConstInt)
2685ac624c8SFarzon Lotfi     FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
2695ac624c8SFarzon Lotfi   else
2705ac624c8SFarzon Lotfi     FlatIndex =
2715ac624c8SFarzon Lotfi         genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
2725ac624c8SFarzon Lotfi 
2735ac624c8SFarzon Lotfi   ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
2745ac624c8SFarzon Lotfi   Value *FlatGEP =
2755ac624c8SFarzon Lotfi       Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,
2765ac624c8SFarzon Lotfi                         GEP.getName() + ".flat", GEP.isInBounds());
2775ac624c8SFarzon Lotfi 
2785ac624c8SFarzon Lotfi   GEP.replaceAllUsesWith(FlatGEP);
2795ac624c8SFarzon Lotfi   GEP.eraseFromParent();
2805ac624c8SFarzon Lotfi   return true;
2815ac624c8SFarzon Lotfi }
2825ac624c8SFarzon Lotfi 
2835ac624c8SFarzon Lotfi bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
2845ac624c8SFarzon Lotfi   auto It = GEPChainMap.find(&GEP);
2855ac624c8SFarzon Lotfi   if (It != GEPChainMap.end())
2865ac624c8SFarzon Lotfi     return visitGetElementPtrInstInGEPChain(GEP);
2875ac624c8SFarzon Lotfi   if (!isMultiDimensionalArray(GEP.getSourceElementType()))
2885ac624c8SFarzon Lotfi     return false;
2895ac624c8SFarzon Lotfi 
2905ac624c8SFarzon Lotfi   ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
2915ac624c8SFarzon Lotfi   IRBuilder<> Builder(&GEP);
2925ac624c8SFarzon Lotfi   auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
2935ac624c8SFarzon Lotfi   ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);
2945ac624c8SFarzon Lotfi 
2955ac624c8SFarzon Lotfi   Value *PtrOperand = GEP.getPointerOperand();
2965ac624c8SFarzon Lotfi 
2975ac624c8SFarzon Lotfi   unsigned GEPChainUseCount = 0;
2985ac624c8SFarzon Lotfi   recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
2995ac624c8SFarzon Lotfi 
3005ac624c8SFarzon Lotfi   // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
3015ac624c8SFarzon Lotfi   // Here recursion is used to get the length of the GEP chain.
3025ac624c8SFarzon Lotfi   // Handle zero uses here because there won't be an update via
3035ac624c8SFarzon Lotfi   // a child in the chain later.
3045ac624c8SFarzon Lotfi   if (GEPChainUseCount == 0) {
3055ac624c8SFarzon Lotfi     SmallVector<Value *> Indices({GEP.getOperand(GEP.getNumOperands() - 1)});
3065ac624c8SFarzon Lotfi     SmallVector<uint64_t> Dims({ArrType->getNumElements()});
3075ac624c8SFarzon Lotfi     bool AllIndicesAreConstInt = isa<ConstantInt>(Indices[0]);
3085ac624c8SFarzon Lotfi     GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
3095ac624c8SFarzon Lotfi                     std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
3105ac624c8SFarzon Lotfi     return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
3115ac624c8SFarzon Lotfi   }
3125ac624c8SFarzon Lotfi 
3135ac624c8SFarzon Lotfi   PotentiallyDeadInstrs.emplace_back(&GEP);
3145ac624c8SFarzon Lotfi   return false;
3155ac624c8SFarzon Lotfi }
3165ac624c8SFarzon Lotfi 
3175ac624c8SFarzon Lotfi bool DXILFlattenArraysVisitor::visit(Function &F) {
3185ac624c8SFarzon Lotfi   bool MadeChange = false;
3195ac624c8SFarzon Lotfi   ReversePostOrderTraversal<Function *> RPOT(&F);
3205ac624c8SFarzon Lotfi   for (BasicBlock *BB : make_early_inc_range(RPOT)) {
3215ac624c8SFarzon Lotfi     for (Instruction &I : make_early_inc_range(*BB))
3225ac624c8SFarzon Lotfi       MadeChange |= InstVisitor::visit(I);
3235ac624c8SFarzon Lotfi   }
3245ac624c8SFarzon Lotfi   finish();
3255ac624c8SFarzon Lotfi   return MadeChange;
3265ac624c8SFarzon Lotfi }
3275ac624c8SFarzon Lotfi 
3285ac624c8SFarzon Lotfi static void collectElements(Constant *Init,
3295ac624c8SFarzon Lotfi                             SmallVectorImpl<Constant *> &Elements) {
3305ac624c8SFarzon Lotfi   // Base case: If Init is not an array, add it directly to the vector.
331*6457aee5SFarzon Lotfi   auto *ArrayTy = dyn_cast<ArrayType>(Init->getType());
332*6457aee5SFarzon Lotfi   if (!ArrayTy) {
3335ac624c8SFarzon Lotfi     Elements.push_back(Init);
3345ac624c8SFarzon Lotfi     return;
3355ac624c8SFarzon Lotfi   }
336*6457aee5SFarzon Lotfi   unsigned ArrSize = ArrayTy->getNumElements();
337*6457aee5SFarzon Lotfi   if (isa<ConstantAggregateZero>(Init)) {
338*6457aee5SFarzon Lotfi     for (unsigned I = 0; I < ArrSize; ++I)
339*6457aee5SFarzon Lotfi       Elements.push_back(Constant::getNullValue(ArrayTy->getElementType()));
340*6457aee5SFarzon Lotfi     return;
341*6457aee5SFarzon Lotfi   }
3425ac624c8SFarzon Lotfi 
3435ac624c8SFarzon Lotfi   // Recursive case: Process each element in the array.
3445ac624c8SFarzon Lotfi   if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
3455ac624c8SFarzon Lotfi     for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {
3465ac624c8SFarzon Lotfi       collectElements(ArrayConstant->getOperand(I), Elements);
3475ac624c8SFarzon Lotfi     }
3485ac624c8SFarzon Lotfi   } else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) {
3495ac624c8SFarzon Lotfi     for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {
3505ac624c8SFarzon Lotfi       collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
3515ac624c8SFarzon Lotfi     }
3525ac624c8SFarzon Lotfi   } else {
3535ac624c8SFarzon Lotfi     llvm_unreachable(
3545ac624c8SFarzon Lotfi         "Expected a ConstantArray or ConstantDataArray for array initializer!");
3555ac624c8SFarzon Lotfi   }
3565ac624c8SFarzon Lotfi }
3575ac624c8SFarzon Lotfi 
3585ac624c8SFarzon Lotfi static Constant *transformInitializer(Constant *Init, Type *OrigType,
3595ac624c8SFarzon Lotfi                                       ArrayType *FlattenedType,
3605ac624c8SFarzon Lotfi                                       LLVMContext &Ctx) {
3615ac624c8SFarzon Lotfi   // Handle ConstantAggregateZero (zero-initialized constants)
3625ac624c8SFarzon Lotfi   if (isa<ConstantAggregateZero>(Init))
3635ac624c8SFarzon Lotfi     return ConstantAggregateZero::get(FlattenedType);
3645ac624c8SFarzon Lotfi 
3655ac624c8SFarzon Lotfi   // Handle UndefValue (undefined constants)
3665ac624c8SFarzon Lotfi   if (isa<UndefValue>(Init))
3675ac624c8SFarzon Lotfi     return UndefValue::get(FlattenedType);
3685ac624c8SFarzon Lotfi 
3695ac624c8SFarzon Lotfi   if (!isa<ArrayType>(OrigType))
3705ac624c8SFarzon Lotfi     return Init;
3715ac624c8SFarzon Lotfi 
3725ac624c8SFarzon Lotfi   SmallVector<Constant *> FlattenedElements;
3735ac624c8SFarzon Lotfi   collectElements(Init, FlattenedElements);
3745ac624c8SFarzon Lotfi   assert(FlattenedType->getNumElements() == FlattenedElements.size() &&
3755ac624c8SFarzon Lotfi          "The number of collected elements should match the FlattenedType");
3765ac624c8SFarzon Lotfi   return ConstantArray::get(FlattenedType, FlattenedElements);
3775ac624c8SFarzon Lotfi }
3785ac624c8SFarzon Lotfi 
3795ac624c8SFarzon Lotfi static void
3805ac624c8SFarzon Lotfi flattenGlobalArrays(Module &M,
3815ac624c8SFarzon Lotfi                     DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
3825ac624c8SFarzon Lotfi   LLVMContext &Ctx = M.getContext();
3835ac624c8SFarzon Lotfi   for (GlobalVariable &G : M.globals()) {
3845ac624c8SFarzon Lotfi     Type *OrigType = G.getValueType();
3855ac624c8SFarzon Lotfi     if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
3865ac624c8SFarzon Lotfi       continue;
3875ac624c8SFarzon Lotfi 
3885ac624c8SFarzon Lotfi     ArrayType *ArrType = cast<ArrayType>(OrigType);
3895ac624c8SFarzon Lotfi     auto [TotalElements, BaseType] =
3905ac624c8SFarzon Lotfi         DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
3915ac624c8SFarzon Lotfi     ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
3925ac624c8SFarzon Lotfi 
3935ac624c8SFarzon Lotfi     // Create a new global variable with the updated type
3945ac624c8SFarzon Lotfi     // Note: Initializer is set via transformInitializer
3955ac624c8SFarzon Lotfi     GlobalVariable *NewGlobal =
3965ac624c8SFarzon Lotfi         new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),
3975ac624c8SFarzon Lotfi                            /*Initializer=*/nullptr, G.getName() + ".1dim", &G,
3985ac624c8SFarzon Lotfi                            G.getThreadLocalMode(), G.getAddressSpace(),
3995ac624c8SFarzon Lotfi                            G.isExternallyInitialized());
4005ac624c8SFarzon Lotfi 
4015ac624c8SFarzon Lotfi     // Copy relevant attributes
4025ac624c8SFarzon Lotfi     NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
4035ac624c8SFarzon Lotfi     if (G.getAlignment() > 0) {
4045ac624c8SFarzon Lotfi       NewGlobal->setAlignment(G.getAlign());
4055ac624c8SFarzon Lotfi     }
4065ac624c8SFarzon Lotfi 
4075ac624c8SFarzon Lotfi     if (G.hasInitializer()) {
4085ac624c8SFarzon Lotfi       Constant *Init = G.getInitializer();
4095ac624c8SFarzon Lotfi       Constant *NewInit =
4105ac624c8SFarzon Lotfi           transformInitializer(Init, OrigType, FattenedArrayType, Ctx);
4115ac624c8SFarzon Lotfi       NewGlobal->setInitializer(NewInit);
4125ac624c8SFarzon Lotfi     }
4135ac624c8SFarzon Lotfi     GlobalMap[&G] = NewGlobal;
4145ac624c8SFarzon Lotfi   }
4155ac624c8SFarzon Lotfi }
4165ac624c8SFarzon Lotfi 
4175ac624c8SFarzon Lotfi static bool flattenArrays(Module &M) {
4185ac624c8SFarzon Lotfi   bool MadeChange = false;
4195ac624c8SFarzon Lotfi   DXILFlattenArraysVisitor Impl;
4205ac624c8SFarzon Lotfi   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
4215ac624c8SFarzon Lotfi   flattenGlobalArrays(M, GlobalMap);
4225ac624c8SFarzon Lotfi   for (auto &F : make_early_inc_range(M.functions())) {
423e0b522ddSJustin Bogner     if (F.isDeclaration())
4245ac624c8SFarzon Lotfi       continue;
4255ac624c8SFarzon Lotfi     MadeChange |= Impl.visit(F);
4265ac624c8SFarzon Lotfi   }
4275ac624c8SFarzon Lotfi   for (auto &[Old, New] : GlobalMap) {
4285ac624c8SFarzon Lotfi     Old->replaceAllUsesWith(New);
4295ac624c8SFarzon Lotfi     Old->eraseFromParent();
4305ac624c8SFarzon Lotfi     MadeChange = true;
4315ac624c8SFarzon Lotfi   }
4325ac624c8SFarzon Lotfi   return MadeChange;
4335ac624c8SFarzon Lotfi }
4345ac624c8SFarzon Lotfi 
4355ac624c8SFarzon Lotfi PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
4365ac624c8SFarzon Lotfi   bool MadeChanges = flattenArrays(M);
4375ac624c8SFarzon Lotfi   if (!MadeChanges)
4385ac624c8SFarzon Lotfi     return PreservedAnalyses::all();
4395ac624c8SFarzon Lotfi   PreservedAnalyses PA;
4405ac624c8SFarzon Lotfi   return PA;
4415ac624c8SFarzon Lotfi }
4425ac624c8SFarzon Lotfi 
4435ac624c8SFarzon Lotfi bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
4445ac624c8SFarzon Lotfi   return flattenArrays(M);
4455ac624c8SFarzon Lotfi }
4465ac624c8SFarzon Lotfi 
4475ac624c8SFarzon Lotfi char DXILFlattenArraysLegacy::ID = 0;
4485ac624c8SFarzon Lotfi 
4495ac624c8SFarzon Lotfi INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,
4505ac624c8SFarzon Lotfi                       "DXIL Array Flattener", false, false)
4515ac624c8SFarzon Lotfi INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
4525ac624c8SFarzon Lotfi                     false, false)
4535ac624c8SFarzon Lotfi 
4545ac624c8SFarzon Lotfi ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
4555ac624c8SFarzon Lotfi   return new DXILFlattenArraysLegacy();
4565ac624c8SFarzon Lotfi }
457