xref: /llvm-project/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp (revision 6457aee5b7da6bb6d7f556d14f42a6763b42e060)
1 //===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 ///
9 /// \file This file contains a pass to flatten arrays for the DirectX Backend.
10 ///
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILFlattenArrays.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/PostOrderIterator.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/IR/BasicBlock.h"
18 #include "llvm/IR/DerivedTypes.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstVisitor.h"
21 #include "llvm/IR/ReplaceConstant.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Transforms/Utils/Local.h"
24 #include <cassert>
25 #include <cstddef>
26 #include <cstdint>
27 #include <utility>
28 
29 #define DEBUG_TYPE "dxil-flatten-arrays"
30 
31 using namespace llvm;
namespace {

/// Legacy pass-manager wrapper; the actual work is done by flattenArrays().
class DXILFlattenArraysLegacy : public ModulePass {

public:
  bool runOnModule(Module &M) override;
  DXILFlattenArraysLegacy() : ModulePass(ID) {}

  static char ID; // Pass identification.
};

/// Everything needed to rewrite the end of a chain of GEPs (that index
/// successive dimensions of one multi-dimensional array) as a single GEP into
/// the flattened one-dimensional array.
struct GEPData {
  // One-dimensional array type the replacement GEP indexes into.
  ArrayType *ParentArrayType;
  // Pointer operand of the root GEP of the chain; the replacement GEP is
  // based on it. (The field name is a typo of "ParentOperand".)
  Value *ParendOperand;
  // Last index operand of each GEP along the chain, outermost dimension first.
  SmallVector<Value *> Indices;
  // Element count of the array level each entry of Indices selects into.
  SmallVector<uint64_t> Dims;
  // True when every entry of Indices is a ConstantInt, so the flat index can
  // be computed at compile time.
  bool AllIndicesAreConstInt;
};

/// Instruction visitor that rewrites multi-dimensional array allocas and the
/// GEPs indexing them into one-dimensional equivalents. flattenArrays() uses
/// a single instance for every function in the module.
class DXILFlattenArraysVisitor
    : public InstVisitor<DXILFlattenArraysVisitor, bool> {
public:
  DXILFlattenArraysVisitor() {}
  bool visit(Function &F);
  // InstVisitor methods.  They return true if they rewrote the instruction,
  // false if nothing changed.
  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
  bool visitAllocaInst(AllocaInst &AI);
  bool visitInstruction(Instruction &I) { return false; }
  bool visitSelectInst(SelectInst &SI) { return false; }
  bool visitICmpInst(ICmpInst &ICI) { return false; }
  bool visitFCmpInst(FCmpInst &FCI) { return false; }
  bool visitUnaryOperator(UnaryOperator &UO) { return false; }
  bool visitBinaryOperator(BinaryOperator &BO) { return false; }
  bool visitCastInst(CastInst &CI) { return false; }
  bool visitBitCastInst(BitCastInst &BCI) { return false; }
  bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
  bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
  bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
  bool visitPHINode(PHINode &PHI) { return false; }
  // Loads/stores are only inspected for GEP constant-expression operands.
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitCallInst(CallInst &ICI) { return false; }
  bool visitFreezeInst(FreezeInst &FI) { return false; }
  // Type helpers, also used by the global-variable flattening code.
  static bool isMultiDimensionalArray(Type *T);
  static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);

private:
  // Root GEPs of flattened chains; deleted in finish() if trivially dead.
  SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
  // Chain-ending GEP -> data needed to emit its flattened replacement.
  DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
  bool finish();
  ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
                                      ArrayRef<uint64_t> Dims,
                                      IRBuilder<> &Builder);
  Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
                                      ArrayRef<uint64_t> Dims,
                                      IRBuilder<> &Builder);
  void
  recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
                         ArrayType *FlattenedArrayType, Value *PtrOperand,
                         unsigned &GEPChainUseCount,
                         SmallVector<Value *> Indices = SmallVector<Value *>(),
                         SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
                         bool AllIndicesAreConstInt = true);
  bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
  bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
                                            GetElementPtrInst &GEP);
};
} // namespace
100 } // namespace
101 
102 bool DXILFlattenArraysVisitor::finish() {
103   RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
104   return true;
105 }
106 
107 bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
108   if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
109     return isa<ArrayType>(ArrType->getElementType());
110   return false;
111 }
112 
113 std::pair<unsigned, Type *>
114 DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {
115   unsigned TotalElements = 1;
116   Type *CurrArrayTy = ArrayTy;
117   while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
118     TotalElements *= InnerArrayTy->getNumElements();
119     CurrArrayTy = InnerArrayTy->getElementType();
120   }
121   return std::make_pair(TotalElements, CurrArrayTy);
122 }
123 
124 ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
125     ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
126   assert(Indices.size() == Dims.size() &&
127          "Indicies and dimmensions should be the same");
128   unsigned FlatIndex = 0;
129   unsigned Multiplier = 1;
130 
131   for (int I = Indices.size() - 1; I >= 0; --I) {
132     unsigned DimSize = Dims[I];
133     ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[I]);
134     assert(CIndex && "This function expects all indicies to be ConstantInt");
135     FlatIndex += CIndex->getZExtValue() * Multiplier;
136     Multiplier *= DimSize;
137   }
138   return Builder.getInt32(FlatIndex);
139 }
140 
141 Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
142     ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
143   if (Indices.size() == 1)
144     return Indices[0];
145 
146   Value *FlatIndex = Builder.getInt32(0);
147   unsigned Multiplier = 1;
148 
149   for (int I = Indices.size() - 1; I >= 0; --I) {
150     unsigned DimSize = Dims[I];
151     Value *VMultiplier = Builder.getInt32(Multiplier);
152     Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);
153     FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
154     Multiplier *= DimSize;
155   }
156   return FlatIndex;
157 }
158 
159 bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
160   unsigned NumOperands = LI.getNumOperands();
161   for (unsigned I = 0; I < NumOperands; ++I) {
162     Value *CurrOpperand = LI.getOperand(I);
163     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
164     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
165       GetElementPtrInst *OldGEP =
166           cast<GetElementPtrInst>(CE->getAsInstruction());
167       OldGEP->insertBefore(&LI);
168 
169       IRBuilder<> Builder(&LI);
170       LoadInst *NewLoad =
171           Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
172       NewLoad->setAlignment(LI.getAlign());
173       LI.replaceAllUsesWith(NewLoad);
174       LI.eraseFromParent();
175       visitGetElementPtrInst(*OldGEP);
176       return true;
177     }
178   }
179   return false;
180 }
181 
182 bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
183   unsigned NumOperands = SI.getNumOperands();
184   for (unsigned I = 0; I < NumOperands; ++I) {
185     Value *CurrOpperand = SI.getOperand(I);
186     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
187     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
188       GetElementPtrInst *OldGEP =
189           cast<GetElementPtrInst>(CE->getAsInstruction());
190       OldGEP->insertBefore(&SI);
191 
192       IRBuilder<> Builder(&SI);
193       StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
194       NewStore->setAlignment(SI.getAlign());
195       SI.replaceAllUsesWith(NewStore);
196       SI.eraseFromParent();
197       visitGetElementPtrInst(*OldGEP);
198       return true;
199     }
200   }
201   return false;
202 }
203 
204 bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
205   if (!isMultiDimensionalArray(AI.getAllocatedType()))
206     return false;
207 
208   ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
209   IRBuilder<> Builder(&AI);
210   auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
211 
212   ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
213   AllocaInst *FlatAlloca =
214       Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat");
215   FlatAlloca->setAlignment(AI.getAlign());
216   AI.replaceAllUsesWith(FlatAlloca);
217   AI.eraseFromParent();
218   return true;
219 }
220 
221 void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
222     GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
223     Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
224     SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
225   Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
226   AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
227   Indices.push_back(LastIndex);
228   assert(isa<ArrayType>(CurrGEP.getSourceElementType()));
229   Dims.push_back(
230       cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
231   bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
232   if (!IsMultiDimArr) {
233     assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
234     GEPChainMap.insert(
235         {&CurrGEP,
236          {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
237           std::move(Dims), AllIndicesAreConstInt}});
238     return;
239   }
240   bool GepUses = false;
241   for (auto *User : CurrGEP.users()) {
242     if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
243       recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
244                              ++GEPChainUseCount, Indices, Dims,
245                              AllIndicesAreConstInt);
246       GepUses = true;
247     }
248   }
249   // This case is just incase the gep chain doesn't end with a 1d array.
250   if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
251     GEPChainMap.insert(
252         {&CurrGEP,
253          {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
254           std::move(Dims), AllIndicesAreConstInt}});
255   }
256 }
257 
258 bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
259     GetElementPtrInst &GEP) {
260   GEPData GEPInfo = GEPChainMap.at(&GEP);
261   return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
262 }
263 bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
264     GEPData &GEPInfo, GetElementPtrInst &GEP) {
265   IRBuilder<> Builder(&GEP);
266   Value *FlatIndex;
267   if (GEPInfo.AllIndicesAreConstInt)
268     FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
269   else
270     FlatIndex =
271         genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
272 
273   ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
274   Value *FlatGEP =
275       Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,
276                         GEP.getName() + ".flat", GEP.isInBounds());
277 
278   GEP.replaceAllUsesWith(FlatGEP);
279   GEP.eraseFromParent();
280   return true;
281 }
282 
283 bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
284   auto It = GEPChainMap.find(&GEP);
285   if (It != GEPChainMap.end())
286     return visitGetElementPtrInstInGEPChain(GEP);
287   if (!isMultiDimensionalArray(GEP.getSourceElementType()))
288     return false;
289 
290   ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
291   IRBuilder<> Builder(&GEP);
292   auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
293   ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);
294 
295   Value *PtrOperand = GEP.getPointerOperand();
296 
297   unsigned GEPChainUseCount = 0;
298   recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
299 
300   // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
301   // Here recursion is used to get the length of the GEP chain.
302   // Handle zero uses here because there won't be an update via
303   // a child in the chain later.
304   if (GEPChainUseCount == 0) {
305     SmallVector<Value *> Indices({GEP.getOperand(GEP.getNumOperands() - 1)});
306     SmallVector<uint64_t> Dims({ArrType->getNumElements()});
307     bool AllIndicesAreConstInt = isa<ConstantInt>(Indices[0]);
308     GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
309                     std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
310     return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
311   }
312 
313   PotentiallyDeadInstrs.emplace_back(&GEP);
314   return false;
315 }
316 
317 bool DXILFlattenArraysVisitor::visit(Function &F) {
318   bool MadeChange = false;
319   ReversePostOrderTraversal<Function *> RPOT(&F);
320   for (BasicBlock *BB : make_early_inc_range(RPOT)) {
321     for (Instruction &I : make_early_inc_range(*BB))
322       MadeChange |= InstVisitor::visit(I);
323   }
324   finish();
325   return MadeChange;
326 }
327 
328 static void collectElements(Constant *Init,
329                             SmallVectorImpl<Constant *> &Elements) {
330   // Base case: If Init is not an array, add it directly to the vector.
331   auto *ArrayTy = dyn_cast<ArrayType>(Init->getType());
332   if (!ArrayTy) {
333     Elements.push_back(Init);
334     return;
335   }
336   unsigned ArrSize = ArrayTy->getNumElements();
337   if (isa<ConstantAggregateZero>(Init)) {
338     for (unsigned I = 0; I < ArrSize; ++I)
339       Elements.push_back(Constant::getNullValue(ArrayTy->getElementType()));
340     return;
341   }
342 
343   // Recursive case: Process each element in the array.
344   if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
345     for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {
346       collectElements(ArrayConstant->getOperand(I), Elements);
347     }
348   } else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) {
349     for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {
350       collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
351     }
352   } else {
353     llvm_unreachable(
354         "Expected a ConstantArray or ConstantDataArray for array initializer!");
355   }
356 }
357 
358 static Constant *transformInitializer(Constant *Init, Type *OrigType,
359                                       ArrayType *FlattenedType,
360                                       LLVMContext &Ctx) {
361   // Handle ConstantAggregateZero (zero-initialized constants)
362   if (isa<ConstantAggregateZero>(Init))
363     return ConstantAggregateZero::get(FlattenedType);
364 
365   // Handle UndefValue (undefined constants)
366   if (isa<UndefValue>(Init))
367     return UndefValue::get(FlattenedType);
368 
369   if (!isa<ArrayType>(OrigType))
370     return Init;
371 
372   SmallVector<Constant *> FlattenedElements;
373   collectElements(Init, FlattenedElements);
374   assert(FlattenedType->getNumElements() == FlattenedElements.size() &&
375          "The number of collected elements should match the FlattenedType");
376   return ConstantArray::get(FlattenedType, FlattenedElements);
377 }
378 
379 static void
380 flattenGlobalArrays(Module &M,
381                     DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
382   LLVMContext &Ctx = M.getContext();
383   for (GlobalVariable &G : M.globals()) {
384     Type *OrigType = G.getValueType();
385     if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
386       continue;
387 
388     ArrayType *ArrType = cast<ArrayType>(OrigType);
389     auto [TotalElements, BaseType] =
390         DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
391     ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
392 
393     // Create a new global variable with the updated type
394     // Note: Initializer is set via transformInitializer
395     GlobalVariable *NewGlobal =
396         new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),
397                            /*Initializer=*/nullptr, G.getName() + ".1dim", &G,
398                            G.getThreadLocalMode(), G.getAddressSpace(),
399                            G.isExternallyInitialized());
400 
401     // Copy relevant attributes
402     NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
403     if (G.getAlignment() > 0) {
404       NewGlobal->setAlignment(G.getAlign());
405     }
406 
407     if (G.hasInitializer()) {
408       Constant *Init = G.getInitializer();
409       Constant *NewInit =
410           transformInitializer(Init, OrigType, FattenedArrayType, Ctx);
411       NewGlobal->setInitializer(NewInit);
412     }
413     GlobalMap[&G] = NewGlobal;
414   }
415 }
416 
417 static bool flattenArrays(Module &M) {
418   bool MadeChange = false;
419   DXILFlattenArraysVisitor Impl;
420   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
421   flattenGlobalArrays(M, GlobalMap);
422   for (auto &F : make_early_inc_range(M.functions())) {
423     if (F.isDeclaration())
424       continue;
425     MadeChange |= Impl.visit(F);
426   }
427   for (auto &[Old, New] : GlobalMap) {
428     Old->replaceAllUsesWith(New);
429     Old->eraseFromParent();
430     MadeChange = true;
431   }
432   return MadeChange;
433 }
434 
435 PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
436   bool MadeChanges = flattenArrays(M);
437   if (!MadeChanges)
438     return PreservedAnalyses::all();
439   PreservedAnalyses PA;
440   return PA;
441 }
442 
443 bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
444   return flattenArrays(M);
445 }
446 
447 char DXILFlattenArraysLegacy::ID = 0;
448 
449 INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,
450                       "DXIL Array Flattener", false, false)
451 INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
452                     false, false)
453 
454 ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
455   return new DXILFlattenArraysLegacy();
456 }
457