xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp (revision 5f757f3ff9144b609b3c433dfd370cc6bdc191ad)
1bdd1243dSDimitry Andric //===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
2bdd1243dSDimitry Andric //
3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bdd1243dSDimitry Andric //
7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
8bdd1243dSDimitry Andric //
9bdd1243dSDimitry Andric // This pass implements regularization of LLVM IR for SPIR-V. The prototype of
10bdd1243dSDimitry Andric // the pass was taken from SPIRV-LLVM translator.
11bdd1243dSDimitry Andric //
12bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
13bdd1243dSDimitry Andric 
14bdd1243dSDimitry Andric #include "SPIRV.h"
15bdd1243dSDimitry Andric #include "SPIRVTargetMachine.h"
16bdd1243dSDimitry Andric #include "llvm/Demangle/Demangle.h"
17bdd1243dSDimitry Andric #include "llvm/IR/InstIterator.h"
18bdd1243dSDimitry Andric #include "llvm/IR/InstVisitor.h"
19bdd1243dSDimitry Andric #include "llvm/IR/PassManager.h"
20bdd1243dSDimitry Andric #include "llvm/Transforms/Utils/Cloning.h"
21bdd1243dSDimitry Andric 
22bdd1243dSDimitry Andric #include <list>
23bdd1243dSDimitry Andric 
24bdd1243dSDimitry Andric #define DEBUG_TYPE "spirv-regularizer"
25bdd1243dSDimitry Andric 
26bdd1243dSDimitry Andric using namespace llvm;
27bdd1243dSDimitry Andric 
28bdd1243dSDimitry Andric namespace llvm {
29bdd1243dSDimitry Andric void initializeSPIRVRegularizerPass(PassRegistry &);
30bdd1243dSDimitry Andric }
31bdd1243dSDimitry Andric 
32bdd1243dSDimitry Andric namespace {
33bdd1243dSDimitry Andric struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
34bdd1243dSDimitry Andric   DenseMap<Function *, Function *> Old2NewFuncs;
35bdd1243dSDimitry Andric 
36bdd1243dSDimitry Andric public:
37bdd1243dSDimitry Andric   static char ID;
SPIRVRegularizer__anon8e84787b0111::SPIRVRegularizer38bdd1243dSDimitry Andric   SPIRVRegularizer() : FunctionPass(ID) {
39bdd1243dSDimitry Andric     initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry());
40bdd1243dSDimitry Andric   }
41bdd1243dSDimitry Andric   bool runOnFunction(Function &F) override;
getPassName__anon8e84787b0111::SPIRVRegularizer42bdd1243dSDimitry Andric   StringRef getPassName() const override { return "SPIR-V Regularizer"; }
43bdd1243dSDimitry Andric 
getAnalysisUsage__anon8e84787b0111::SPIRVRegularizer44bdd1243dSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
45bdd1243dSDimitry Andric     FunctionPass::getAnalysisUsage(AU);
46bdd1243dSDimitry Andric   }
47bdd1243dSDimitry Andric   void visitCallInst(CallInst &CI);
48bdd1243dSDimitry Andric 
49bdd1243dSDimitry Andric private:
50bdd1243dSDimitry Andric   void visitCallScalToVec(CallInst *CI, StringRef MangledName,
51bdd1243dSDimitry Andric                           StringRef DemangledName);
52bdd1243dSDimitry Andric   void runLowerConstExpr(Function &F);
53bdd1243dSDimitry Andric };
54bdd1243dSDimitry Andric } // namespace
55bdd1243dSDimitry Andric 
56bdd1243dSDimitry Andric char SPIRVRegularizer::ID = 0;
57bdd1243dSDimitry Andric 
58bdd1243dSDimitry Andric INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
59bdd1243dSDimitry Andric                 false)
60bdd1243dSDimitry Andric 
61bdd1243dSDimitry Andric // Since SPIR-V cannot represent constant expression, constant expressions
62bdd1243dSDimitry Andric // in LLVM IR need to be lowered to instructions. For each function,
63bdd1243dSDimitry Andric // the constant expressions used by instructions of the function are replaced
64bdd1243dSDimitry Andric // by instructions placed in the entry block since it dominates all other BBs.
65bdd1243dSDimitry Andric // Each constant expression only needs to be lowered once in each function
66bdd1243dSDimitry Andric // and all uses of it by instructions in that function are replaced by
67bdd1243dSDimitry Andric // one instruction.
68bdd1243dSDimitry Andric // TODO: remove redundant instructions for common subexpression.
runLowerConstExpr(Function & F)69bdd1243dSDimitry Andric void SPIRVRegularizer::runLowerConstExpr(Function &F) {
70bdd1243dSDimitry Andric   LLVMContext &Ctx = F.getContext();
71bdd1243dSDimitry Andric   std::list<Instruction *> WorkList;
72bdd1243dSDimitry Andric   for (auto &II : instructions(F))
73bdd1243dSDimitry Andric     WorkList.push_back(&II);
74bdd1243dSDimitry Andric 
75bdd1243dSDimitry Andric   auto FBegin = F.begin();
76bdd1243dSDimitry Andric   while (!WorkList.empty()) {
77bdd1243dSDimitry Andric     Instruction *II = WorkList.front();
78bdd1243dSDimitry Andric 
79bdd1243dSDimitry Andric     auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
80bdd1243dSDimitry Andric       if (isa<Function>(V))
81bdd1243dSDimitry Andric         return V;
82bdd1243dSDimitry Andric       auto *CE = cast<ConstantExpr>(V);
83bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
84bdd1243dSDimitry Andric       auto ReplInst = CE->getAsInstruction();
85bdd1243dSDimitry Andric       auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
86bdd1243dSDimitry Andric       ReplInst->insertBefore(InsPoint);
87bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
88bdd1243dSDimitry Andric       std::vector<Instruction *> Users;
89bdd1243dSDimitry Andric       // Do not replace use during iteration of use. Do it in another loop.
90bdd1243dSDimitry Andric       for (auto U : CE->users()) {
91bdd1243dSDimitry Andric         LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
92bdd1243dSDimitry Andric         auto InstUser = dyn_cast<Instruction>(U);
93bdd1243dSDimitry Andric         // Only replace users in scope of current function.
94bdd1243dSDimitry Andric         if (InstUser && InstUser->getParent()->getParent() == &F)
95bdd1243dSDimitry Andric           Users.push_back(InstUser);
96bdd1243dSDimitry Andric       }
97bdd1243dSDimitry Andric       for (auto &User : Users) {
98bdd1243dSDimitry Andric         if (ReplInst->getParent() == User->getParent() &&
99bdd1243dSDimitry Andric             User->comesBefore(ReplInst))
100bdd1243dSDimitry Andric           ReplInst->moveBefore(User);
101bdd1243dSDimitry Andric         User->replaceUsesOfWith(CE, ReplInst);
102bdd1243dSDimitry Andric       }
103bdd1243dSDimitry Andric       return ReplInst;
104bdd1243dSDimitry Andric     };
105bdd1243dSDimitry Andric 
106bdd1243dSDimitry Andric     WorkList.pop_front();
107bdd1243dSDimitry Andric     auto LowerConstantVec = [&II, &LowerOp, &WorkList,
108bdd1243dSDimitry Andric                              &Ctx](ConstantVector *Vec,
109bdd1243dSDimitry Andric                                    unsigned NumOfOp) -> Value * {
110bdd1243dSDimitry Andric       if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
111bdd1243dSDimitry Andric             return isa<ConstantExpr>(V) || isa<Function>(V);
112bdd1243dSDimitry Andric           })) {
113bdd1243dSDimitry Andric         // Expand a vector of constexprs and construct it back with
114bdd1243dSDimitry Andric         // series of insertelement instructions.
115bdd1243dSDimitry Andric         std::list<Value *> OpList;
116bdd1243dSDimitry Andric         std::transform(Vec->op_begin(), Vec->op_end(),
117bdd1243dSDimitry Andric                        std::back_inserter(OpList),
118bdd1243dSDimitry Andric                        [LowerOp](Value *V) { return LowerOp(V); });
119bdd1243dSDimitry Andric         Value *Repl = nullptr;
120bdd1243dSDimitry Andric         unsigned Idx = 0;
121bdd1243dSDimitry Andric         auto *PhiII = dyn_cast<PHINode>(II);
122bdd1243dSDimitry Andric         Instruction *InsPoint =
123bdd1243dSDimitry Andric             PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
124bdd1243dSDimitry Andric         std::list<Instruction *> ReplList;
125bdd1243dSDimitry Andric         for (auto V : OpList) {
126bdd1243dSDimitry Andric           if (auto *Inst = dyn_cast<Instruction>(V))
127bdd1243dSDimitry Andric             ReplList.push_back(Inst);
128bdd1243dSDimitry Andric           Repl = InsertElementInst::Create(
129bdd1243dSDimitry Andric               (Repl ? Repl : PoisonValue::get(Vec->getType())), V,
130bdd1243dSDimitry Andric               ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "", InsPoint);
131bdd1243dSDimitry Andric         }
132bdd1243dSDimitry Andric         WorkList.splice(WorkList.begin(), ReplList);
133bdd1243dSDimitry Andric         return Repl;
134bdd1243dSDimitry Andric       }
135bdd1243dSDimitry Andric       return nullptr;
136bdd1243dSDimitry Andric     };
137bdd1243dSDimitry Andric     for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
138bdd1243dSDimitry Andric       auto *Op = II->getOperand(OI);
139bdd1243dSDimitry Andric       if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
140bdd1243dSDimitry Andric         Value *ReplInst = LowerConstantVec(Vec, OI);
141bdd1243dSDimitry Andric         if (ReplInst)
142bdd1243dSDimitry Andric           II->replaceUsesOfWith(Op, ReplInst);
143bdd1243dSDimitry Andric       } else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
144bdd1243dSDimitry Andric         WorkList.push_front(cast<Instruction>(LowerOp(CE)));
145bdd1243dSDimitry Andric       } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
146bdd1243dSDimitry Andric         auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
147bdd1243dSDimitry Andric         if (!ConstMD)
148bdd1243dSDimitry Andric           continue;
149bdd1243dSDimitry Andric         Constant *C = ConstMD->getValue();
150bdd1243dSDimitry Andric         Value *ReplInst = nullptr;
151bdd1243dSDimitry Andric         if (auto *Vec = dyn_cast<ConstantVector>(C))
152bdd1243dSDimitry Andric           ReplInst = LowerConstantVec(Vec, OI);
153bdd1243dSDimitry Andric         if (auto *CE = dyn_cast<ConstantExpr>(C))
154bdd1243dSDimitry Andric           ReplInst = LowerOp(CE);
155bdd1243dSDimitry Andric         if (!ReplInst)
156bdd1243dSDimitry Andric           continue;
157bdd1243dSDimitry Andric         Metadata *RepMD = ValueAsMetadata::get(ReplInst);
158bdd1243dSDimitry Andric         Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
159bdd1243dSDimitry Andric         II->setOperand(OI, RepMDVal);
160bdd1243dSDimitry Andric         WorkList.push_front(cast<Instruction>(ReplInst));
161bdd1243dSDimitry Andric       }
162bdd1243dSDimitry Andric     }
163bdd1243dSDimitry Andric   }
164bdd1243dSDimitry Andric }
165bdd1243dSDimitry Andric 
166bdd1243dSDimitry Andric // It fixes calls to OCL builtins that accept vector arguments and one of them
167bdd1243dSDimitry Andric // is actually a scalar splat.
visitCallInst(CallInst & CI)168bdd1243dSDimitry Andric void SPIRVRegularizer::visitCallInst(CallInst &CI) {
169bdd1243dSDimitry Andric   auto F = CI.getCalledFunction();
170bdd1243dSDimitry Andric   if (!F)
171bdd1243dSDimitry Andric     return;
172bdd1243dSDimitry Andric 
173bdd1243dSDimitry Andric   auto MangledName = F->getName();
17406c3fb27SDimitry Andric   char *NameStr = itaniumDemangle(F->getName().data());
17506c3fb27SDimitry Andric   if (!NameStr)
17606c3fb27SDimitry Andric     return;
177bdd1243dSDimitry Andric   StringRef DemangledName(NameStr);
178bdd1243dSDimitry Andric 
179bdd1243dSDimitry Andric   // TODO: add support for other builtins.
180*5f757f3fSDimitry Andric   if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||
181*5f757f3fSDimitry Andric       DemangledName.starts_with("min") || DemangledName.starts_with("max"))
182bdd1243dSDimitry Andric     visitCallScalToVec(&CI, MangledName, DemangledName);
183bdd1243dSDimitry Andric   free(NameStr);
184bdd1243dSDimitry Andric }
185bdd1243dSDimitry Andric 
visitCallScalToVec(CallInst * CI,StringRef MangledName,StringRef DemangledName)186bdd1243dSDimitry Andric void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
187bdd1243dSDimitry Andric                                           StringRef DemangledName) {
188bdd1243dSDimitry Andric   // Check if all arguments have the same type - it's simple case.
189bdd1243dSDimitry Andric   auto Uniform = true;
190bdd1243dSDimitry Andric   Type *Arg0Ty = CI->getOperand(0)->getType();
191bdd1243dSDimitry Andric   auto IsArg0Vector = isa<VectorType>(Arg0Ty);
192bdd1243dSDimitry Andric   for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
193bdd1243dSDimitry Andric     Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
194bdd1243dSDimitry Andric   if (Uniform)
195bdd1243dSDimitry Andric     return;
196bdd1243dSDimitry Andric 
197bdd1243dSDimitry Andric   auto *OldF = CI->getCalledFunction();
198bdd1243dSDimitry Andric   Function *NewF = nullptr;
199bdd1243dSDimitry Andric   if (!Old2NewFuncs.count(OldF)) {
200bdd1243dSDimitry Andric     AttributeList Attrs = CI->getCalledFunction()->getAttributes();
201bdd1243dSDimitry Andric     SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
202bdd1243dSDimitry Andric     auto *NewFTy =
203bdd1243dSDimitry Andric         FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
204bdd1243dSDimitry Andric     NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
205bdd1243dSDimitry Andric                             *OldF->getParent());
206bdd1243dSDimitry Andric     ValueToValueMapTy VMap;
207bdd1243dSDimitry Andric     auto NewFArgIt = NewF->arg_begin();
208bdd1243dSDimitry Andric     for (auto &Arg : OldF->args()) {
209bdd1243dSDimitry Andric       auto ArgName = Arg.getName();
210bdd1243dSDimitry Andric       NewFArgIt->setName(ArgName);
211bdd1243dSDimitry Andric       VMap[&Arg] = &(*NewFArgIt++);
212bdd1243dSDimitry Andric     }
213bdd1243dSDimitry Andric     SmallVector<ReturnInst *, 8> Returns;
214bdd1243dSDimitry Andric     CloneFunctionInto(NewF, OldF, VMap,
215bdd1243dSDimitry Andric                       CloneFunctionChangeType::LocalChangesOnly, Returns);
216bdd1243dSDimitry Andric     NewF->setAttributes(Attrs);
217bdd1243dSDimitry Andric     Old2NewFuncs[OldF] = NewF;
218bdd1243dSDimitry Andric   } else {
219bdd1243dSDimitry Andric     NewF = Old2NewFuncs[OldF];
220bdd1243dSDimitry Andric   }
221bdd1243dSDimitry Andric   assert(NewF);
222bdd1243dSDimitry Andric 
223bdd1243dSDimitry Andric   // This produces an instruction sequence that implements a splat of
224bdd1243dSDimitry Andric   // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
225bdd1243dSDimitry Andric   // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
226bdd1243dSDimitry Andric   // For instance (transcoding/OpMin.ll), this call
227bdd1243dSDimitry Andric   //   call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
228bdd1243dSDimitry Andric   // is translated to
229bdd1243dSDimitry Andric   //    %8 = OpUndef %v2uint
230bdd1243dSDimitry Andric   //   %14 = OpConstantComposite %v2uint %uint_1 %uint_10
231bdd1243dSDimitry Andric   //   ...
232bdd1243dSDimitry Andric   //   %10 = OpCompositeInsert %v2uint %uint_5 %8 0
233bdd1243dSDimitry Andric   //   %11 = OpVectorShuffle %v2uint %10 %8 0 0
234bdd1243dSDimitry Andric   // %call = OpExtInst %v2uint %1 s_min %14 %11
235bdd1243dSDimitry Andric   auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
236bdd1243dSDimitry Andric   PoisonValue *PVal = PoisonValue::get(Arg0Ty);
237bdd1243dSDimitry Andric   Instruction *Inst =
238bdd1243dSDimitry Andric       InsertElementInst::Create(PVal, CI->getOperand(1), ConstInt, "", CI);
239bdd1243dSDimitry Andric   ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
240bdd1243dSDimitry Andric   Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
241bdd1243dSDimitry Andric   Value *NewVec = new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI);
242bdd1243dSDimitry Andric   CI->setOperand(1, NewVec);
243bdd1243dSDimitry Andric   CI->replaceUsesOfWith(OldF, NewF);
244bdd1243dSDimitry Andric   CI->mutateFunctionType(NewF->getFunctionType());
245bdd1243dSDimitry Andric }
246bdd1243dSDimitry Andric 
runOnFunction(Function & F)247bdd1243dSDimitry Andric bool SPIRVRegularizer::runOnFunction(Function &F) {
248bdd1243dSDimitry Andric   runLowerConstExpr(F);
249bdd1243dSDimitry Andric   visit(F);
250bdd1243dSDimitry Andric   for (auto &OldNew : Old2NewFuncs) {
251bdd1243dSDimitry Andric     Function *OldF = OldNew.first;
252bdd1243dSDimitry Andric     Function *NewF = OldNew.second;
253bdd1243dSDimitry Andric     NewF->takeName(OldF);
254bdd1243dSDimitry Andric     OldF->eraseFromParent();
255bdd1243dSDimitry Andric   }
256bdd1243dSDimitry Andric   return true;
257bdd1243dSDimitry Andric }
258bdd1243dSDimitry Andric 
createSPIRVRegularizerPass()259bdd1243dSDimitry Andric FunctionPass *llvm::createSPIRVRegularizerPass() {
260bdd1243dSDimitry Andric   return new SPIRVRegularizer();
261bdd1243dSDimitry Andric }
262