1 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
/// This pass performs the following type substitutions on all
/// non-compute shaders:
///
/// v16i8 => v4i32
///   - v16i8 is used for constant memory resource descriptors. This type is
///     legal for some compute APIs, but we don't want to declare it as legal
///     in the backend, because we want the legalizer to expand all v16i8
///     operations. Loads of v16i8 and v16i8 call arguments are rewritten.
/// v1i32 => i32
///   - Having v1* types complicates the legalizer, and we can easily replace
///     them with the element type. Currently only <1 x i32> call arguments
///     are rewritten.
22 //===----------------------------------------------------------------------===//
23
24 #include "AMDGPU.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/InstVisitor.h"
27
28 using namespace llvm;
29
30 namespace {
31
32 class SITypeRewriter : public FunctionPass,
33 public InstVisitor<SITypeRewriter> {
34
35 static char ID;
36 Module *Mod;
37 Type *v16i8;
38 Type *v4i32;
39
40 public:
SITypeRewriter()41 SITypeRewriter() : FunctionPass(ID) { }
42 bool doInitialization(Module &M) override;
43 bool runOnFunction(Function &F) override;
getPassName() const44 const char *getPassName() const override {
45 return "SI Type Rewriter";
46 }
47 void visitLoadInst(LoadInst &I);
48 void visitCallInst(CallInst &I);
49 void visitBitCast(BitCastInst &I);
50 };
51
52 } // End anonymous namespace
53
54 char SITypeRewriter::ID = 0;
55
doInitialization(Module & M)56 bool SITypeRewriter::doInitialization(Module &M) {
57 Mod = &M;
58 v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
59 v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4);
60 return false;
61 }
62
runOnFunction(Function & F)63 bool SITypeRewriter::runOnFunction(Function &F) {
64 AttributeSet Set = F.getAttributes();
65 Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType");
66
67 unsigned ShaderType = ShaderType::COMPUTE;
68 if (A.isStringAttribute()) {
69 StringRef Str = A.getValueAsString();
70 Str.getAsInteger(0, ShaderType);
71 }
72 if (ShaderType == ShaderType::COMPUTE)
73 return false;
74
75 visit(F);
76 visit(F);
77
78 return false;
79 }
80
visitLoadInst(LoadInst & I)81 void SITypeRewriter::visitLoadInst(LoadInst &I) {
82 Value *Ptr = I.getPointerOperand();
83 Type *PtrTy = Ptr->getType();
84 Type *ElemTy = PtrTy->getPointerElementType();
85 IRBuilder<> Builder(&I);
86 if (ElemTy == v16i8) {
87 Value *BitCast = Builder.CreateBitCast(Ptr,
88 PointerType::get(v4i32,PtrTy->getPointerAddressSpace()));
89 LoadInst *Load = Builder.CreateLoad(BitCast);
90 SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
91 I.getAllMetadataOtherThanDebugLoc(MD);
92 for (unsigned i = 0, e = MD.size(); i != e; ++i) {
93 Load->setMetadata(MD[i].first, MD[i].second);
94 }
95 Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
96 I.replaceAllUsesWith(BitCastLoad);
97 I.eraseFromParent();
98 }
99 }
100
visitCallInst(CallInst & I)101 void SITypeRewriter::visitCallInst(CallInst &I) {
102 IRBuilder<> Builder(&I);
103
104 SmallVector <Value*, 8> Args;
105 SmallVector <Type*, 8> Types;
106 bool NeedToReplace = false;
107 Function *F = I.getCalledFunction();
108 std::string Name = F->getName().str();
109 for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
110 Value *Arg = I.getArgOperand(i);
111 if (Arg->getType() == v16i8) {
112 Args.push_back(Builder.CreateBitCast(Arg, v4i32));
113 Types.push_back(v4i32);
114 NeedToReplace = true;
115 Name = Name + ".v4i32";
116 } else if (Arg->getType()->isVectorTy() &&
117 Arg->getType()->getVectorNumElements() == 1 &&
118 Arg->getType()->getVectorElementType() ==
119 Type::getInt32Ty(I.getContext())){
120 Type *ElementTy = Arg->getType()->getVectorElementType();
121 std::string TypeName = "i32";
122 InsertElementInst *Def = cast<InsertElementInst>(Arg);
123 Args.push_back(Def->getOperand(1));
124 Types.push_back(ElementTy);
125 std::string VecTypeName = "v1" + TypeName;
126 Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
127 NeedToReplace = true;
128 } else {
129 Args.push_back(Arg);
130 Types.push_back(Arg->getType());
131 }
132 }
133
134 if (!NeedToReplace) {
135 return;
136 }
137 Function *NewF = Mod->getFunction(Name);
138 if (!NewF) {
139 NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
140 NewF->setAttributes(F->getAttributes());
141 }
142 I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
143 I.eraseFromParent();
144 }
145
visitBitCast(BitCastInst & I)146 void SITypeRewriter::visitBitCast(BitCastInst &I) {
147 IRBuilder<> Builder(&I);
148 if (I.getDestTy() != v4i32) {
149 return;
150 }
151
152 if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
153 if (Op->getSrcTy() == v4i32) {
154 I.replaceAllUsesWith(Op->getOperand(0));
155 I.eraseFromParent();
156 }
157 }
158 }
159
createSITypeRewriter()160 FunctionPass *llvm::createSITypeRewriter() {
161 return new SITypeRewriter();
162 }
163