xref: /llvm-project/llvm/lib/Target/DirectX/DXILDataScalarization.cpp (revision 6457aee5b7da6bb6d7f556d14f42a6763b42e060)
1 //===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 
#include "DXILDataScalarization.h"
#include "DirectX.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Type.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
23 
24 #define DEBUG_TYPE "dxil-data-scalarization"
// Inline element capacity for the SmallVectors used by this pass.
static constexpr int MaxVecSize = 4;
26 
27 using namespace llvm;
28 
29 class DXILDataScalarizationLegacy : public ModulePass {
30 
31 public:
32   bool runOnModule(Module &M) override;
33   DXILDataScalarizationLegacy() : ModulePass(ID) {}
34 
35   static char ID; // Pass identification.
36 };
37 
38 static bool findAndReplaceVectors(Module &M);
39 
40 class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
41 public:
42   DataScalarizerVisitor() : GlobalMap() {}
43   bool visit(Instruction &I);
44   // InstVisitor methods.  They return true if the instruction was scalarized,
45   // false if nothing changed.
46   bool visitInstruction(Instruction &I) { return false; }
47   bool visitSelectInst(SelectInst &SI) { return false; }
48   bool visitICmpInst(ICmpInst &ICI) { return false; }
49   bool visitFCmpInst(FCmpInst &FCI) { return false; }
50   bool visitUnaryOperator(UnaryOperator &UO) { return false; }
51   bool visitBinaryOperator(BinaryOperator &BO) { return false; }
52   bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
53   bool visitCastInst(CastInst &CI) { return false; }
54   bool visitBitCastInst(BitCastInst &BCI) { return false; }
55   bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
56   bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
57   bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
58   bool visitPHINode(PHINode &PHI) { return false; }
59   bool visitLoadInst(LoadInst &LI);
60   bool visitStoreInst(StoreInst &SI);
61   bool visitCallInst(CallInst &ICI) { return false; }
62   bool visitFreezeInst(FreezeInst &FI) { return false; }
63   friend bool findAndReplaceVectors(llvm::Module &M);
64 
65 private:
66   GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
67   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
68 };
69 
70 bool DataScalarizerVisitor::visit(Instruction &I) {
71   assert(!GlobalMap.empty());
72   return InstVisitor::visit(I);
73 }
74 
75 GlobalVariable *
76 DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
77   if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
78     auto It = GlobalMap.find(OldGlobal);
79     if (It != GlobalMap.end()) {
80       return It->second; // Found, return the new global
81     }
82   }
83   return nullptr; // Not found
84 }
85 
86 bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
87   unsigned NumOperands = LI.getNumOperands();
88   for (unsigned I = 0; I < NumOperands; ++I) {
89     Value *CurrOpperand = LI.getOperand(I);
90     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
91     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
92       GetElementPtrInst *OldGEP =
93           cast<GetElementPtrInst>(CE->getAsInstruction());
94       OldGEP->insertBefore(&LI);
95       IRBuilder<> Builder(&LI);
96       LoadInst *NewLoad =
97           Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
98       NewLoad->setAlignment(LI.getAlign());
99       LI.replaceAllUsesWith(NewLoad);
100       LI.eraseFromParent();
101       visitGetElementPtrInst(*OldGEP);
102       return true;
103     }
104     if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
105       LI.setOperand(I, NewGlobal);
106   }
107   return false;
108 }
109 
110 bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
111   unsigned NumOperands = SI.getNumOperands();
112   for (unsigned I = 0; I < NumOperands; ++I) {
113     Value *CurrOpperand = SI.getOperand(I);
114     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
115     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
116       GetElementPtrInst *OldGEP =
117           cast<GetElementPtrInst>(CE->getAsInstruction());
118       OldGEP->insertBefore(&SI);
119       IRBuilder<> Builder(&SI);
120       StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
121       NewStore->setAlignment(SI.getAlign());
122       SI.replaceAllUsesWith(NewStore);
123       SI.eraseFromParent();
124       visitGetElementPtrInst(*OldGEP);
125       return true;
126     }
127     if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
128       SI.setOperand(I, NewGlobal);
129   }
130   return false;
131 }
132 
133 bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
134 
135   unsigned NumOperands = GEPI.getNumOperands();
136   GlobalVariable *NewGlobal = nullptr;
137   for (unsigned I = 0; I < NumOperands; ++I) {
138     Value *CurrOpperand = GEPI.getOperand(I);
139     NewGlobal = lookupReplacementGlobal(CurrOpperand);
140     if (NewGlobal)
141       break;
142   }
143   if (!NewGlobal)
144     return false;
145 
146   IRBuilder<> Builder(&GEPI);
147   SmallVector<Value *, MaxVecSize> Indices;
148   for (auto &Index : GEPI.indices())
149     Indices.push_back(Index);
150 
151   Value *NewGEP =
152       Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
153                         GEPI.getName(), GEPI.getNoWrapFlags());
154   GEPI.replaceAllUsesWith(NewGEP);
155   GEPI.eraseFromParent();
156   return true;
157 }
158 
159 // Recursively Creates and Array like version of the given vector like type.
160 static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
161   if (auto *VecTy = dyn_cast<VectorType>(T))
162     return ArrayType::get(VecTy->getElementType(),
163                           dyn_cast<FixedVectorType>(VecTy)->getNumElements());
164   if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
165     Type *NewElementType =
166         replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
167     return ArrayType::get(NewElementType, ArrayTy->getNumElements());
168   }
169   // If it's not a vector or array, return the original type.
170   return T;
171 }
172 
173 Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
174                                LLVMContext &Ctx) {
175   // Handle ConstantAggregateZero (zero-initialized constants)
176   if (isa<ConstantAggregateZero>(Init)) {
177     return ConstantAggregateZero::get(NewType);
178   }
179 
180   // Handle UndefValue (undefined constants)
181   if (isa<UndefValue>(Init)) {
182     return UndefValue::get(NewType);
183   }
184 
185   // Handle vector to array transformation
186   if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
187     // Convert vector initializer to array initializer
188     SmallVector<Constant *, MaxVecSize> ArrayElements;
189     if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
190       for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
191         ArrayElements.push_back(ConstVecInit->getOperand(I));
192     } else if (ConstantDataVector *ConstDataVecInit =
193                    llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
194       for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
195         ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
196     } else {
197       assert(false && "Expected a ConstantVector or ConstantDataVector for "
198                       "vector initializer!");
199     }
200 
201     return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
202   }
203 
204   // Handle array of vectors transformation
205   if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
206     auto *ArrayInit = dyn_cast<ConstantArray>(Init);
207     assert(ArrayInit && "Expected a ConstantArray for array initializer!");
208 
209     SmallVector<Constant *, MaxVecSize> NewArrayElements;
210     for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
211       // Recursively transform array elements
212       Constant *NewElemInit = transformInitializer(
213           ArrayInit->getOperand(I), ArrayTy->getElementType(),
214           cast<ArrayType>(NewType)->getElementType(), Ctx);
215       NewArrayElements.push_back(NewElemInit);
216     }
217 
218     return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements);
219   }
220 
221   // If not a vector or array, return the original initializer
222   return Init;
223 }
224 
225 static bool findAndReplaceVectors(Module &M) {
226   bool MadeChange = false;
227   LLVMContext &Ctx = M.getContext();
228   IRBuilder<> Builder(Ctx);
229   DataScalarizerVisitor Impl;
230   for (GlobalVariable &G : M.globals()) {
231     Type *OrigType = G.getValueType();
232 
233     Type *NewType = replaceVectorWithArray(OrigType, Ctx);
234     if (OrigType != NewType) {
235       // Create a new global variable with the updated type
236       // Note: Initializer is set via transformInitializer
237       GlobalVariable *NewGlobal = new GlobalVariable(
238           M, NewType, G.isConstant(), G.getLinkage(),
239           /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
240           G.getThreadLocalMode(), G.getAddressSpace(),
241           G.isExternallyInitialized());
242 
243       // Copy relevant attributes
244       NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
245       if (G.getAlignment() > 0) {
246         NewGlobal->setAlignment(G.getAlign());
247       }
248 
249       if (G.hasInitializer()) {
250         Constant *Init = G.getInitializer();
251         Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx);
252         NewGlobal->setInitializer(NewInit);
253       }
254 
255       // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
256       // type equality. Instead we will use the visitor pattern.
257       Impl.GlobalMap[&G] = NewGlobal;
258       for (User *U : make_early_inc_range(G.users())) {
259         if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
260           ConstantExpr *CE = cast<ConstantExpr>(U);
261           for (User *UCE : make_early_inc_range(CE->users())) {
262             if (Instruction *Inst = dyn_cast<Instruction>(UCE))
263               Impl.visit(*Inst);
264           }
265         }
266         if (Instruction *Inst = dyn_cast<Instruction>(U))
267           Impl.visit(*Inst);
268       }
269     }
270   }
271 
272   // Remove the old globals after the iteration
273   for (auto &[Old, New] : Impl.GlobalMap) {
274     Old->eraseFromParent();
275     MadeChange = true;
276   }
277   return MadeChange;
278 }
279 
280 PreservedAnalyses DXILDataScalarization::run(Module &M,
281                                              ModuleAnalysisManager &) {
282   bool MadeChanges = findAndReplaceVectors(M);
283   if (!MadeChanges)
284     return PreservedAnalyses::all();
285   PreservedAnalyses PA;
286   return PA;
287 }
288 
289 bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
290   return findAndReplaceVectors(M);
291 }
292 
293 char DXILDataScalarizationLegacy::ID = 0;
294 
295 INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE,
296                       "DXIL Data Scalarization", false, false)
297 INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
298                     "DXIL Data Scalarization", false, false)
299 
300 ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
301   return new DXILDataScalarizationLegacy();
302 }
303