xref: /llvm-project/llvm/lib/Target/DirectX/DXILDataScalarization.cpp (revision 6457aee5b7da6bb6d7f556d14f42a6763b42e060)
1324bdd66SFarzon Lotfi //===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//
2324bdd66SFarzon Lotfi //
3324bdd66SFarzon Lotfi // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4324bdd66SFarzon Lotfi // See https://llvm.org/LICENSE.txt for license information.
5324bdd66SFarzon Lotfi // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6324bdd66SFarzon Lotfi //
7324bdd66SFarzon Lotfi //===---------------------------------------------------------------------===//
8324bdd66SFarzon Lotfi 
9324bdd66SFarzon Lotfi #include "DXILDataScalarization.h"
10324bdd66SFarzon Lotfi #include "DirectX.h"
11324bdd66SFarzon Lotfi #include "llvm/ADT/PostOrderIterator.h"
12324bdd66SFarzon Lotfi #include "llvm/ADT/STLExtras.h"
13324bdd66SFarzon Lotfi #include "llvm/IR/GlobalVariable.h"
14324bdd66SFarzon Lotfi #include "llvm/IR/IRBuilder.h"
15324bdd66SFarzon Lotfi #include "llvm/IR/InstVisitor.h"
16324bdd66SFarzon Lotfi #include "llvm/IR/Module.h"
17324bdd66SFarzon Lotfi #include "llvm/IR/Operator.h"
18324bdd66SFarzon Lotfi #include "llvm/IR/PassManager.h"
19324bdd66SFarzon Lotfi #include "llvm/IR/ReplaceConstant.h"
20324bdd66SFarzon Lotfi #include "llvm/IR/Type.h"
21324bdd66SFarzon Lotfi #include "llvm/Transforms/Utils/Cloning.h"
22324bdd66SFarzon Lotfi #include "llvm/Transforms/Utils/Local.h"
23324bdd66SFarzon Lotfi 
24324bdd66SFarzon Lotfi #define DEBUG_TYPE "dxil-data-scalarization"
25324bdd66SFarzon Lotfi static const int MaxVecSize = 4;
26324bdd66SFarzon Lotfi 
27324bdd66SFarzon Lotfi using namespace llvm;
28324bdd66SFarzon Lotfi 
29324bdd66SFarzon Lotfi class DXILDataScalarizationLegacy : public ModulePass {
30324bdd66SFarzon Lotfi 
31324bdd66SFarzon Lotfi public:
32324bdd66SFarzon Lotfi   bool runOnModule(Module &M) override;
33324bdd66SFarzon Lotfi   DXILDataScalarizationLegacy() : ModulePass(ID) {}
34324bdd66SFarzon Lotfi 
35324bdd66SFarzon Lotfi   static char ID; // Pass identification.
36324bdd66SFarzon Lotfi };
37324bdd66SFarzon Lotfi 
38324bdd66SFarzon Lotfi static bool findAndReplaceVectors(Module &M);
39324bdd66SFarzon Lotfi 
40324bdd66SFarzon Lotfi class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
41324bdd66SFarzon Lotfi public:
42324bdd66SFarzon Lotfi   DataScalarizerVisitor() : GlobalMap() {}
43*6457aee5SFarzon Lotfi   bool visit(Instruction &I);
44324bdd66SFarzon Lotfi   // InstVisitor methods.  They return true if the instruction was scalarized,
45324bdd66SFarzon Lotfi   // false if nothing changed.
46324bdd66SFarzon Lotfi   bool visitInstruction(Instruction &I) { return false; }
47324bdd66SFarzon Lotfi   bool visitSelectInst(SelectInst &SI) { return false; }
48324bdd66SFarzon Lotfi   bool visitICmpInst(ICmpInst &ICI) { return false; }
49324bdd66SFarzon Lotfi   bool visitFCmpInst(FCmpInst &FCI) { return false; }
50324bdd66SFarzon Lotfi   bool visitUnaryOperator(UnaryOperator &UO) { return false; }
51324bdd66SFarzon Lotfi   bool visitBinaryOperator(BinaryOperator &BO) { return false; }
52324bdd66SFarzon Lotfi   bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
53324bdd66SFarzon Lotfi   bool visitCastInst(CastInst &CI) { return false; }
54324bdd66SFarzon Lotfi   bool visitBitCastInst(BitCastInst &BCI) { return false; }
55324bdd66SFarzon Lotfi   bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
56324bdd66SFarzon Lotfi   bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
57324bdd66SFarzon Lotfi   bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
58324bdd66SFarzon Lotfi   bool visitPHINode(PHINode &PHI) { return false; }
59324bdd66SFarzon Lotfi   bool visitLoadInst(LoadInst &LI);
60324bdd66SFarzon Lotfi   bool visitStoreInst(StoreInst &SI);
61324bdd66SFarzon Lotfi   bool visitCallInst(CallInst &ICI) { return false; }
62324bdd66SFarzon Lotfi   bool visitFreezeInst(FreezeInst &FI) { return false; }
63324bdd66SFarzon Lotfi   friend bool findAndReplaceVectors(llvm::Module &M);
64324bdd66SFarzon Lotfi 
65324bdd66SFarzon Lotfi private:
66324bdd66SFarzon Lotfi   GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
67324bdd66SFarzon Lotfi   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
68324bdd66SFarzon Lotfi };
69324bdd66SFarzon Lotfi 
70*6457aee5SFarzon Lotfi bool DataScalarizerVisitor::visit(Instruction &I) {
71324bdd66SFarzon Lotfi   assert(!GlobalMap.empty());
72*6457aee5SFarzon Lotfi   return InstVisitor::visit(I);
73324bdd66SFarzon Lotfi }
74324bdd66SFarzon Lotfi 
75324bdd66SFarzon Lotfi GlobalVariable *
76324bdd66SFarzon Lotfi DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
77324bdd66SFarzon Lotfi   if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
78324bdd66SFarzon Lotfi     auto It = GlobalMap.find(OldGlobal);
79324bdd66SFarzon Lotfi     if (It != GlobalMap.end()) {
80324bdd66SFarzon Lotfi       return It->second; // Found, return the new global
81324bdd66SFarzon Lotfi     }
82324bdd66SFarzon Lotfi   }
83324bdd66SFarzon Lotfi   return nullptr; // Not found
84324bdd66SFarzon Lotfi }
85324bdd66SFarzon Lotfi 
86324bdd66SFarzon Lotfi bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
87324bdd66SFarzon Lotfi   unsigned NumOperands = LI.getNumOperands();
88324bdd66SFarzon Lotfi   for (unsigned I = 0; I < NumOperands; ++I) {
89324bdd66SFarzon Lotfi     Value *CurrOpperand = LI.getOperand(I);
90*6457aee5SFarzon Lotfi     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
91*6457aee5SFarzon Lotfi     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
92*6457aee5SFarzon Lotfi       GetElementPtrInst *OldGEP =
93*6457aee5SFarzon Lotfi           cast<GetElementPtrInst>(CE->getAsInstruction());
94*6457aee5SFarzon Lotfi       OldGEP->insertBefore(&LI);
95*6457aee5SFarzon Lotfi       IRBuilder<> Builder(&LI);
96*6457aee5SFarzon Lotfi       LoadInst *NewLoad =
97*6457aee5SFarzon Lotfi           Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
98*6457aee5SFarzon Lotfi       NewLoad->setAlignment(LI.getAlign());
99*6457aee5SFarzon Lotfi       LI.replaceAllUsesWith(NewLoad);
100*6457aee5SFarzon Lotfi       LI.eraseFromParent();
101*6457aee5SFarzon Lotfi       visitGetElementPtrInst(*OldGEP);
102*6457aee5SFarzon Lotfi       return true;
103*6457aee5SFarzon Lotfi     }
104324bdd66SFarzon Lotfi     if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
105324bdd66SFarzon Lotfi       LI.setOperand(I, NewGlobal);
106324bdd66SFarzon Lotfi   }
107324bdd66SFarzon Lotfi   return false;
108324bdd66SFarzon Lotfi }
109324bdd66SFarzon Lotfi 
110324bdd66SFarzon Lotfi bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
111324bdd66SFarzon Lotfi   unsigned NumOperands = SI.getNumOperands();
112324bdd66SFarzon Lotfi   for (unsigned I = 0; I < NumOperands; ++I) {
113324bdd66SFarzon Lotfi     Value *CurrOpperand = SI.getOperand(I);
114*6457aee5SFarzon Lotfi     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
115*6457aee5SFarzon Lotfi     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
116*6457aee5SFarzon Lotfi       GetElementPtrInst *OldGEP =
117*6457aee5SFarzon Lotfi           cast<GetElementPtrInst>(CE->getAsInstruction());
118*6457aee5SFarzon Lotfi       OldGEP->insertBefore(&SI);
119*6457aee5SFarzon Lotfi       IRBuilder<> Builder(&SI);
120*6457aee5SFarzon Lotfi       StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
121*6457aee5SFarzon Lotfi       NewStore->setAlignment(SI.getAlign());
122*6457aee5SFarzon Lotfi       SI.replaceAllUsesWith(NewStore);
123*6457aee5SFarzon Lotfi       SI.eraseFromParent();
124*6457aee5SFarzon Lotfi       visitGetElementPtrInst(*OldGEP);
125*6457aee5SFarzon Lotfi       return true;
126324bdd66SFarzon Lotfi     }
127*6457aee5SFarzon Lotfi     if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
128*6457aee5SFarzon Lotfi       SI.setOperand(I, NewGlobal);
129324bdd66SFarzon Lotfi   }
130324bdd66SFarzon Lotfi   return false;
131324bdd66SFarzon Lotfi }
132324bdd66SFarzon Lotfi 
133324bdd66SFarzon Lotfi bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
134*6457aee5SFarzon Lotfi 
135324bdd66SFarzon Lotfi   unsigned NumOperands = GEPI.getNumOperands();
136*6457aee5SFarzon Lotfi   GlobalVariable *NewGlobal = nullptr;
137324bdd66SFarzon Lotfi   for (unsigned I = 0; I < NumOperands; ++I) {
138324bdd66SFarzon Lotfi     Value *CurrOpperand = GEPI.getOperand(I);
139*6457aee5SFarzon Lotfi     NewGlobal = lookupReplacementGlobal(CurrOpperand);
140*6457aee5SFarzon Lotfi     if (NewGlobal)
141*6457aee5SFarzon Lotfi       break;
142*6457aee5SFarzon Lotfi   }
143324bdd66SFarzon Lotfi   if (!NewGlobal)
144*6457aee5SFarzon Lotfi     return false;
145324bdd66SFarzon Lotfi 
146*6457aee5SFarzon Lotfi   IRBuilder<> Builder(&GEPI);
147324bdd66SFarzon Lotfi   SmallVector<Value *, MaxVecSize> Indices;
148324bdd66SFarzon Lotfi   for (auto &Index : GEPI.indices())
149324bdd66SFarzon Lotfi     Indices.push_back(Index);
150324bdd66SFarzon Lotfi 
151324bdd66SFarzon Lotfi   Value *NewGEP =
152*6457aee5SFarzon Lotfi       Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
153*6457aee5SFarzon Lotfi                         GEPI.getName(), GEPI.getNoWrapFlags());
154324bdd66SFarzon Lotfi   GEPI.replaceAllUsesWith(NewGEP);
155*6457aee5SFarzon Lotfi   GEPI.eraseFromParent();
156324bdd66SFarzon Lotfi   return true;
157324bdd66SFarzon Lotfi }
158324bdd66SFarzon Lotfi 
159324bdd66SFarzon Lotfi // Recursively Creates and Array like version of the given vector like type.
160324bdd66SFarzon Lotfi static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
161324bdd66SFarzon Lotfi   if (auto *VecTy = dyn_cast<VectorType>(T))
162324bdd66SFarzon Lotfi     return ArrayType::get(VecTy->getElementType(),
163324bdd66SFarzon Lotfi                           dyn_cast<FixedVectorType>(VecTy)->getNumElements());
164324bdd66SFarzon Lotfi   if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
165324bdd66SFarzon Lotfi     Type *NewElementType =
166324bdd66SFarzon Lotfi         replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
167324bdd66SFarzon Lotfi     return ArrayType::get(NewElementType, ArrayTy->getNumElements());
168324bdd66SFarzon Lotfi   }
169324bdd66SFarzon Lotfi   // If it's not a vector or array, return the original type.
170324bdd66SFarzon Lotfi   return T;
171324bdd66SFarzon Lotfi }
172324bdd66SFarzon Lotfi 
173324bdd66SFarzon Lotfi Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
174324bdd66SFarzon Lotfi                                LLVMContext &Ctx) {
175324bdd66SFarzon Lotfi   // Handle ConstantAggregateZero (zero-initialized constants)
176324bdd66SFarzon Lotfi   if (isa<ConstantAggregateZero>(Init)) {
177324bdd66SFarzon Lotfi     return ConstantAggregateZero::get(NewType);
178324bdd66SFarzon Lotfi   }
179324bdd66SFarzon Lotfi 
180324bdd66SFarzon Lotfi   // Handle UndefValue (undefined constants)
181324bdd66SFarzon Lotfi   if (isa<UndefValue>(Init)) {
182324bdd66SFarzon Lotfi     return UndefValue::get(NewType);
183324bdd66SFarzon Lotfi   }
184324bdd66SFarzon Lotfi 
185324bdd66SFarzon Lotfi   // Handle vector to array transformation
186324bdd66SFarzon Lotfi   if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
187324bdd66SFarzon Lotfi     // Convert vector initializer to array initializer
188324bdd66SFarzon Lotfi     SmallVector<Constant *, MaxVecSize> ArrayElements;
189324bdd66SFarzon Lotfi     if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
190324bdd66SFarzon Lotfi       for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
191324bdd66SFarzon Lotfi         ArrayElements.push_back(ConstVecInit->getOperand(I));
192324bdd66SFarzon Lotfi     } else if (ConstantDataVector *ConstDataVecInit =
193324bdd66SFarzon Lotfi                    llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
194324bdd66SFarzon Lotfi       for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
195324bdd66SFarzon Lotfi         ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
196324bdd66SFarzon Lotfi     } else {
197324bdd66SFarzon Lotfi       assert(false && "Expected a ConstantVector or ConstantDataVector for "
198324bdd66SFarzon Lotfi                       "vector initializer!");
199324bdd66SFarzon Lotfi     }
200324bdd66SFarzon Lotfi 
201324bdd66SFarzon Lotfi     return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
202324bdd66SFarzon Lotfi   }
203324bdd66SFarzon Lotfi 
204324bdd66SFarzon Lotfi   // Handle array of vectors transformation
205324bdd66SFarzon Lotfi   if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
206324bdd66SFarzon Lotfi     auto *ArrayInit = dyn_cast<ConstantArray>(Init);
207324bdd66SFarzon Lotfi     assert(ArrayInit && "Expected a ConstantArray for array initializer!");
208324bdd66SFarzon Lotfi 
209324bdd66SFarzon Lotfi     SmallVector<Constant *, MaxVecSize> NewArrayElements;
210324bdd66SFarzon Lotfi     for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
211324bdd66SFarzon Lotfi       // Recursively transform array elements
212324bdd66SFarzon Lotfi       Constant *NewElemInit = transformInitializer(
213324bdd66SFarzon Lotfi           ArrayInit->getOperand(I), ArrayTy->getElementType(),
214324bdd66SFarzon Lotfi           cast<ArrayType>(NewType)->getElementType(), Ctx);
215324bdd66SFarzon Lotfi       NewArrayElements.push_back(NewElemInit);
216324bdd66SFarzon Lotfi     }
217324bdd66SFarzon Lotfi 
218324bdd66SFarzon Lotfi     return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements);
219324bdd66SFarzon Lotfi   }
220324bdd66SFarzon Lotfi 
221324bdd66SFarzon Lotfi   // If not a vector or array, return the original initializer
222324bdd66SFarzon Lotfi   return Init;
223324bdd66SFarzon Lotfi }
224324bdd66SFarzon Lotfi 
225324bdd66SFarzon Lotfi static bool findAndReplaceVectors(Module &M) {
226324bdd66SFarzon Lotfi   bool MadeChange = false;
227324bdd66SFarzon Lotfi   LLVMContext &Ctx = M.getContext();
228324bdd66SFarzon Lotfi   IRBuilder<> Builder(Ctx);
229324bdd66SFarzon Lotfi   DataScalarizerVisitor Impl;
230324bdd66SFarzon Lotfi   for (GlobalVariable &G : M.globals()) {
231324bdd66SFarzon Lotfi     Type *OrigType = G.getValueType();
232324bdd66SFarzon Lotfi 
233324bdd66SFarzon Lotfi     Type *NewType = replaceVectorWithArray(OrigType, Ctx);
234324bdd66SFarzon Lotfi     if (OrigType != NewType) {
235324bdd66SFarzon Lotfi       // Create a new global variable with the updated type
236324bdd66SFarzon Lotfi       // Note: Initializer is set via transformInitializer
237324bdd66SFarzon Lotfi       GlobalVariable *NewGlobal = new GlobalVariable(
238324bdd66SFarzon Lotfi           M, NewType, G.isConstant(), G.getLinkage(),
239324bdd66SFarzon Lotfi           /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
240324bdd66SFarzon Lotfi           G.getThreadLocalMode(), G.getAddressSpace(),
241324bdd66SFarzon Lotfi           G.isExternallyInitialized());
242324bdd66SFarzon Lotfi 
243324bdd66SFarzon Lotfi       // Copy relevant attributes
244324bdd66SFarzon Lotfi       NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
245324bdd66SFarzon Lotfi       if (G.getAlignment() > 0) {
246324bdd66SFarzon Lotfi         NewGlobal->setAlignment(G.getAlign());
247324bdd66SFarzon Lotfi       }
248324bdd66SFarzon Lotfi 
249324bdd66SFarzon Lotfi       if (G.hasInitializer()) {
250324bdd66SFarzon Lotfi         Constant *Init = G.getInitializer();
251324bdd66SFarzon Lotfi         Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx);
252324bdd66SFarzon Lotfi         NewGlobal->setInitializer(NewInit);
253324bdd66SFarzon Lotfi       }
254324bdd66SFarzon Lotfi 
255324bdd66SFarzon Lotfi       // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
256324bdd66SFarzon Lotfi       // type equality. Instead we will use the visitor pattern.
257324bdd66SFarzon Lotfi       Impl.GlobalMap[&G] = NewGlobal;
258324bdd66SFarzon Lotfi       for (User *U : make_early_inc_range(G.users())) {
259324bdd66SFarzon Lotfi         if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
260324bdd66SFarzon Lotfi           ConstantExpr *CE = cast<ConstantExpr>(U);
261*6457aee5SFarzon Lotfi           for (User *UCE : make_early_inc_range(CE->users())) {
262*6457aee5SFarzon Lotfi             if (Instruction *Inst = dyn_cast<Instruction>(UCE))
263*6457aee5SFarzon Lotfi               Impl.visit(*Inst);
264324bdd66SFarzon Lotfi           }
265324bdd66SFarzon Lotfi         }
266*6457aee5SFarzon Lotfi         if (Instruction *Inst = dyn_cast<Instruction>(U))
267*6457aee5SFarzon Lotfi           Impl.visit(*Inst);
268324bdd66SFarzon Lotfi       }
269324bdd66SFarzon Lotfi     }
270324bdd66SFarzon Lotfi   }
271324bdd66SFarzon Lotfi 
272324bdd66SFarzon Lotfi   // Remove the old globals after the iteration
273324bdd66SFarzon Lotfi   for (auto &[Old, New] : Impl.GlobalMap) {
274324bdd66SFarzon Lotfi     Old->eraseFromParent();
275324bdd66SFarzon Lotfi     MadeChange = true;
276324bdd66SFarzon Lotfi   }
277324bdd66SFarzon Lotfi   return MadeChange;
278324bdd66SFarzon Lotfi }
279324bdd66SFarzon Lotfi 
280324bdd66SFarzon Lotfi PreservedAnalyses DXILDataScalarization::run(Module &M,
281324bdd66SFarzon Lotfi                                              ModuleAnalysisManager &) {
282324bdd66SFarzon Lotfi   bool MadeChanges = findAndReplaceVectors(M);
283324bdd66SFarzon Lotfi   if (!MadeChanges)
284324bdd66SFarzon Lotfi     return PreservedAnalyses::all();
285324bdd66SFarzon Lotfi   PreservedAnalyses PA;
286324bdd66SFarzon Lotfi   return PA;
287324bdd66SFarzon Lotfi }
288324bdd66SFarzon Lotfi 
289324bdd66SFarzon Lotfi bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
290324bdd66SFarzon Lotfi   return findAndReplaceVectors(M);
291324bdd66SFarzon Lotfi }
292324bdd66SFarzon Lotfi 
293324bdd66SFarzon Lotfi char DXILDataScalarizationLegacy::ID = 0;
294324bdd66SFarzon Lotfi 
295324bdd66SFarzon Lotfi INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE,
296324bdd66SFarzon Lotfi                       "DXIL Data Scalarization", false, false)
297324bdd66SFarzon Lotfi INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
298324bdd66SFarzon Lotfi                     "DXIL Data Scalarization", false, false)
299324bdd66SFarzon Lotfi 
300324bdd66SFarzon Lotfi ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
301324bdd66SFarzon Lotfi   return new DXILDataScalarizationLegacy();
302324bdd66SFarzon Lotfi }
303