xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp (revision 304a99091c84f303ff5037dc6bf5455e4cfde7a1)
1 //===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements regularization of LLVM IR for SPIR-V. The prototype of
10 // the pass was taken from SPIRV-LLVM translator.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "SPIRV.h"
15 #include "SPIRVTargetMachine.h"
16 #include "llvm/Demangle/Demangle.h"
17 #include "llvm/IR/InstIterator.h"
18 #include "llvm/IR/InstVisitor.h"
19 #include "llvm/IR/PassManager.h"
20 #include "llvm/Transforms/Utils/Cloning.h"
21 
22 #include <list>
23 
24 #define DEBUG_TYPE "spirv-regularizer"
25 
26 using namespace llvm;
27 
28 namespace llvm {
29 void initializeSPIRVRegularizerPass(PassRegistry &);
30 }
31 
32 namespace {
33 struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
34   DenseMap<Function *, Function *> Old2NewFuncs;
35 
36 public:
37   static char ID;
38   SPIRVRegularizer() : FunctionPass(ID) {
39     initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry());
40   }
41   bool runOnFunction(Function &F) override;
42   StringRef getPassName() const override { return "SPIR-V Regularizer"; }
43 
44   void getAnalysisUsage(AnalysisUsage &AU) const override {
45     FunctionPass::getAnalysisUsage(AU);
46   }
47   void visitCallInst(CallInst &CI);
48 
49 private:
50   void visitCallScalToVec(CallInst *CI, StringRef MangledName,
51                           StringRef DemangledName);
52   void runLowerConstExpr(Function &F);
53 };
54 } // namespace
55 
56 char SPIRVRegularizer::ID = 0;
57 
58 INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
59                 false)
60 
61 // Since SPIR-V cannot represent constant expression, constant expressions
62 // in LLVM IR need to be lowered to instructions. For each function,
63 // the constant expressions used by instructions of the function are replaced
64 // by instructions placed in the entry block since it dominates all other BBs.
65 // Each constant expression only needs to be lowered once in each function
66 // and all uses of it by instructions in that function are replaced by
67 // one instruction.
68 // TODO: remove redundant instructions for common subexpression.
69 void SPIRVRegularizer::runLowerConstExpr(Function &F) {
70   LLVMContext &Ctx = F.getContext();
71   std::list<Instruction *> WorkList;
72   for (auto &II : instructions(F))
73     WorkList.push_back(&II);
74 
75   auto FBegin = F.begin();
76   while (!WorkList.empty()) {
77     Instruction *II = WorkList.front();
78 
79     auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
80       if (isa<Function>(V))
81         return V;
82       auto *CE = cast<ConstantExpr>(V);
83       LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
84       auto ReplInst = CE->getAsInstruction();
85       auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
86       ReplInst->insertBefore(InsPoint->getIterator());
87       LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
88       std::vector<Instruction *> Users;
89       // Do not replace use during iteration of use. Do it in another loop.
90       for (auto U : CE->users()) {
91         LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
92         auto InstUser = dyn_cast<Instruction>(U);
93         // Only replace users in scope of current function.
94         if (InstUser && InstUser->getParent()->getParent() == &F)
95           Users.push_back(InstUser);
96       }
97       for (auto &User : Users) {
98         if (ReplInst->getParent() == User->getParent() &&
99             User->comesBefore(ReplInst))
100           ReplInst->moveBefore(User->getIterator());
101         User->replaceUsesOfWith(CE, ReplInst);
102       }
103       return ReplInst;
104     };
105 
106     WorkList.pop_front();
107     auto LowerConstantVec = [&II, &LowerOp, &WorkList,
108                              &Ctx](ConstantVector *Vec,
109                                    unsigned NumOfOp) -> Value * {
110       if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
111             return isa<ConstantExpr>(V) || isa<Function>(V);
112           })) {
113         // Expand a vector of constexprs and construct it back with
114         // series of insertelement instructions.
115         std::list<Value *> OpList;
116         std::transform(Vec->op_begin(), Vec->op_end(),
117                        std::back_inserter(OpList),
118                        [LowerOp](Value *V) { return LowerOp(V); });
119         Value *Repl = nullptr;
120         unsigned Idx = 0;
121         auto *PhiII = dyn_cast<PHINode>(II);
122         Instruction *InsPoint =
123             PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
124         std::list<Instruction *> ReplList;
125         for (auto V : OpList) {
126           if (auto *Inst = dyn_cast<Instruction>(V))
127             ReplList.push_back(Inst);
128           Repl = InsertElementInst::Create(
129               (Repl ? Repl : PoisonValue::get(Vec->getType())), V,
130               ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "",
131               InsPoint->getIterator());
132         }
133         WorkList.splice(WorkList.begin(), ReplList);
134         return Repl;
135       }
136       return nullptr;
137     };
138     for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
139       auto *Op = II->getOperand(OI);
140       if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
141         Value *ReplInst = LowerConstantVec(Vec, OI);
142         if (ReplInst)
143           II->replaceUsesOfWith(Op, ReplInst);
144       } else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
145         WorkList.push_front(cast<Instruction>(LowerOp(CE)));
146       } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
147         auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
148         if (!ConstMD)
149           continue;
150         Constant *C = ConstMD->getValue();
151         Value *ReplInst = nullptr;
152         if (auto *Vec = dyn_cast<ConstantVector>(C))
153           ReplInst = LowerConstantVec(Vec, OI);
154         if (auto *CE = dyn_cast<ConstantExpr>(C))
155           ReplInst = LowerOp(CE);
156         if (!ReplInst)
157           continue;
158         Metadata *RepMD = ValueAsMetadata::get(ReplInst);
159         Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
160         II->setOperand(OI, RepMDVal);
161         WorkList.push_front(cast<Instruction>(ReplInst));
162       }
163     }
164   }
165 }
166 
167 // It fixes calls to OCL builtins that accept vector arguments and one of them
168 // is actually a scalar splat.
169 void SPIRVRegularizer::visitCallInst(CallInst &CI) {
170   auto F = CI.getCalledFunction();
171   if (!F)
172     return;
173 
174   auto MangledName = F->getName();
175   char *NameStr = itaniumDemangle(F->getName().data());
176   if (!NameStr)
177     return;
178   StringRef DemangledName(NameStr);
179 
180   // TODO: add support for other builtins.
181   if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||
182       DemangledName.starts_with("min") || DemangledName.starts_with("max"))
183     visitCallScalToVec(&CI, MangledName, DemangledName);
184   free(NameStr);
185 }
186 
187 void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
188                                           StringRef DemangledName) {
189   // Check if all arguments have the same type - it's simple case.
190   auto Uniform = true;
191   Type *Arg0Ty = CI->getOperand(0)->getType();
192   auto IsArg0Vector = isa<VectorType>(Arg0Ty);
193   for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
194     Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
195   if (Uniform)
196     return;
197 
198   auto *OldF = CI->getCalledFunction();
199   Function *NewF = nullptr;
200   if (!Old2NewFuncs.count(OldF)) {
201     AttributeList Attrs = CI->getCalledFunction()->getAttributes();
202     SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
203     auto *NewFTy =
204         FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
205     NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
206                             *OldF->getParent());
207     ValueToValueMapTy VMap;
208     auto NewFArgIt = NewF->arg_begin();
209     for (auto &Arg : OldF->args()) {
210       auto ArgName = Arg.getName();
211       NewFArgIt->setName(ArgName);
212       VMap[&Arg] = &(*NewFArgIt++);
213     }
214     SmallVector<ReturnInst *, 8> Returns;
215     CloneFunctionInto(NewF, OldF, VMap,
216                       CloneFunctionChangeType::LocalChangesOnly, Returns);
217     NewF->setAttributes(Attrs);
218     Old2NewFuncs[OldF] = NewF;
219   } else {
220     NewF = Old2NewFuncs[OldF];
221   }
222   assert(NewF);
223 
224   // This produces an instruction sequence that implements a splat of
225   // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
226   // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
227   // For instance (transcoding/OpMin.ll), this call
228   //   call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
229   // is translated to
230   //    %8 = OpUndef %v2uint
231   //   %14 = OpConstantComposite %v2uint %uint_1 %uint_10
232   //   ...
233   //   %10 = OpCompositeInsert %v2uint %uint_5 %8 0
234   //   %11 = OpVectorShuffle %v2uint %10 %8 0 0
235   // %call = OpExtInst %v2uint %1 s_min %14 %11
236   auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
237   PoisonValue *PVal = PoisonValue::get(Arg0Ty);
238   Instruction *Inst = InsertElementInst::Create(
239       PVal, CI->getOperand(1), ConstInt, "", CI->getIterator());
240   ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
241   Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
242   Value *NewVec =
243       new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI->getIterator());
244   CI->setOperand(1, NewVec);
245   CI->replaceUsesOfWith(OldF, NewF);
246   CI->mutateFunctionType(NewF->getFunctionType());
247 }
248 
249 bool SPIRVRegularizer::runOnFunction(Function &F) {
250   runLowerConstExpr(F);
251   visit(F);
252   for (auto &OldNew : Old2NewFuncs) {
253     Function *OldF = OldNew.first;
254     Function *NewF = OldNew.second;
255     NewF->takeName(OldF);
256     OldF->eraseFromParent();
257   }
258   return true;
259 }
260 
261 FunctionPass *llvm::createSPIRVRegularizerPass() {
262   return new SPIRVRegularizer();
263 }
264