xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp (revision 304a99091c84f303ff5037dc6bf5455e4cfde7a1)
13544d200SIlia Diachkov //===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
23544d200SIlia Diachkov //
33544d200SIlia Diachkov // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43544d200SIlia Diachkov // See https://llvm.org/LICENSE.txt for license information.
53544d200SIlia Diachkov // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63544d200SIlia Diachkov //
73544d200SIlia Diachkov //===----------------------------------------------------------------------===//
83544d200SIlia Diachkov //
93544d200SIlia Diachkov // This pass implements regularization of LLVM IR for SPIR-V. The prototype of
103544d200SIlia Diachkov // the pass was taken from SPIRV-LLVM translator.
113544d200SIlia Diachkov //
123544d200SIlia Diachkov //===----------------------------------------------------------------------===//
133544d200SIlia Diachkov 
143544d200SIlia Diachkov #include "SPIRV.h"
153544d200SIlia Diachkov #include "SPIRVTargetMachine.h"
163544d200SIlia Diachkov #include "llvm/Demangle/Demangle.h"
173544d200SIlia Diachkov #include "llvm/IR/InstIterator.h"
183544d200SIlia Diachkov #include "llvm/IR/InstVisitor.h"
193544d200SIlia Diachkov #include "llvm/IR/PassManager.h"
203544d200SIlia Diachkov #include "llvm/Transforms/Utils/Cloning.h"
213544d200SIlia Diachkov 
223544d200SIlia Diachkov #include <list>
233544d200SIlia Diachkov 
243544d200SIlia Diachkov #define DEBUG_TYPE "spirv-regularizer"
253544d200SIlia Diachkov 
263544d200SIlia Diachkov using namespace llvm;
273544d200SIlia Diachkov 
283544d200SIlia Diachkov namespace llvm {
293544d200SIlia Diachkov void initializeSPIRVRegularizerPass(PassRegistry &);
303544d200SIlia Diachkov }
313544d200SIlia Diachkov 
323544d200SIlia Diachkov namespace {
333544d200SIlia Diachkov struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
343544d200SIlia Diachkov   DenseMap<Function *, Function *> Old2NewFuncs;
353544d200SIlia Diachkov 
363544d200SIlia Diachkov public:
373544d200SIlia Diachkov   static char ID;
383544d200SIlia Diachkov   SPIRVRegularizer() : FunctionPass(ID) {
393544d200SIlia Diachkov     initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry());
403544d200SIlia Diachkov   }
413544d200SIlia Diachkov   bool runOnFunction(Function &F) override;
423544d200SIlia Diachkov   StringRef getPassName() const override { return "SPIR-V Regularizer"; }
433544d200SIlia Diachkov 
443544d200SIlia Diachkov   void getAnalysisUsage(AnalysisUsage &AU) const override {
453544d200SIlia Diachkov     FunctionPass::getAnalysisUsage(AU);
463544d200SIlia Diachkov   }
473544d200SIlia Diachkov   void visitCallInst(CallInst &CI);
483544d200SIlia Diachkov 
493544d200SIlia Diachkov private:
503544d200SIlia Diachkov   void visitCallScalToVec(CallInst *CI, StringRef MangledName,
513544d200SIlia Diachkov                           StringRef DemangledName);
523544d200SIlia Diachkov   void runLowerConstExpr(Function &F);
533544d200SIlia Diachkov };
543544d200SIlia Diachkov } // namespace
553544d200SIlia Diachkov 
563544d200SIlia Diachkov char SPIRVRegularizer::ID = 0;
573544d200SIlia Diachkov 
583544d200SIlia Diachkov INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
593544d200SIlia Diachkov                 false)
603544d200SIlia Diachkov 
613544d200SIlia Diachkov // Since SPIR-V cannot represent constant expression, constant expressions
623544d200SIlia Diachkov // in LLVM IR need to be lowered to instructions. For each function,
633544d200SIlia Diachkov // the constant expressions used by instructions of the function are replaced
643544d200SIlia Diachkov // by instructions placed in the entry block since it dominates all other BBs.
653544d200SIlia Diachkov // Each constant expression only needs to be lowered once in each function
663544d200SIlia Diachkov // and all uses of it by instructions in that function are replaced by
673544d200SIlia Diachkov // one instruction.
683544d200SIlia Diachkov // TODO: remove redundant instructions for common subexpression.
693544d200SIlia Diachkov void SPIRVRegularizer::runLowerConstExpr(Function &F) {
703544d200SIlia Diachkov   LLVMContext &Ctx = F.getContext();
713544d200SIlia Diachkov   std::list<Instruction *> WorkList;
723544d200SIlia Diachkov   for (auto &II : instructions(F))
733544d200SIlia Diachkov     WorkList.push_back(&II);
743544d200SIlia Diachkov 
753544d200SIlia Diachkov   auto FBegin = F.begin();
763544d200SIlia Diachkov   while (!WorkList.empty()) {
773544d200SIlia Diachkov     Instruction *II = WorkList.front();
783544d200SIlia Diachkov 
793544d200SIlia Diachkov     auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
803544d200SIlia Diachkov       if (isa<Function>(V))
813544d200SIlia Diachkov         return V;
823544d200SIlia Diachkov       auto *CE = cast<ConstantExpr>(V);
833544d200SIlia Diachkov       LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
843544d200SIlia Diachkov       auto ReplInst = CE->getAsInstruction();
853544d200SIlia Diachkov       auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
86*304a9909SJeremy Morse       ReplInst->insertBefore(InsPoint->getIterator());
873544d200SIlia Diachkov       LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
883544d200SIlia Diachkov       std::vector<Instruction *> Users;
893544d200SIlia Diachkov       // Do not replace use during iteration of use. Do it in another loop.
903544d200SIlia Diachkov       for (auto U : CE->users()) {
913544d200SIlia Diachkov         LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
923544d200SIlia Diachkov         auto InstUser = dyn_cast<Instruction>(U);
933544d200SIlia Diachkov         // Only replace users in scope of current function.
943544d200SIlia Diachkov         if (InstUser && InstUser->getParent()->getParent() == &F)
953544d200SIlia Diachkov           Users.push_back(InstUser);
963544d200SIlia Diachkov       }
973544d200SIlia Diachkov       for (auto &User : Users) {
983544d200SIlia Diachkov         if (ReplInst->getParent() == User->getParent() &&
993544d200SIlia Diachkov             User->comesBefore(ReplInst))
100*304a9909SJeremy Morse           ReplInst->moveBefore(User->getIterator());
1013544d200SIlia Diachkov         User->replaceUsesOfWith(CE, ReplInst);
1023544d200SIlia Diachkov       }
1033544d200SIlia Diachkov       return ReplInst;
1043544d200SIlia Diachkov     };
1053544d200SIlia Diachkov 
1063544d200SIlia Diachkov     WorkList.pop_front();
1073544d200SIlia Diachkov     auto LowerConstantVec = [&II, &LowerOp, &WorkList,
1083544d200SIlia Diachkov                              &Ctx](ConstantVector *Vec,
1093544d200SIlia Diachkov                                    unsigned NumOfOp) -> Value * {
1103544d200SIlia Diachkov       if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
1113544d200SIlia Diachkov             return isa<ConstantExpr>(V) || isa<Function>(V);
1123544d200SIlia Diachkov           })) {
1133544d200SIlia Diachkov         // Expand a vector of constexprs and construct it back with
1143544d200SIlia Diachkov         // series of insertelement instructions.
1153544d200SIlia Diachkov         std::list<Value *> OpList;
1163544d200SIlia Diachkov         std::transform(Vec->op_begin(), Vec->op_end(),
1173544d200SIlia Diachkov                        std::back_inserter(OpList),
1183544d200SIlia Diachkov                        [LowerOp](Value *V) { return LowerOp(V); });
1193544d200SIlia Diachkov         Value *Repl = nullptr;
1203544d200SIlia Diachkov         unsigned Idx = 0;
1213544d200SIlia Diachkov         auto *PhiII = dyn_cast<PHINode>(II);
1223544d200SIlia Diachkov         Instruction *InsPoint =
1233544d200SIlia Diachkov             PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
1243544d200SIlia Diachkov         std::list<Instruction *> ReplList;
1253544d200SIlia Diachkov         for (auto V : OpList) {
1263544d200SIlia Diachkov           if (auto *Inst = dyn_cast<Instruction>(V))
1273544d200SIlia Diachkov             ReplList.push_back(Inst);
1283544d200SIlia Diachkov           Repl = InsertElementInst::Create(
1293544d200SIlia Diachkov               (Repl ? Repl : PoisonValue::get(Vec->getType())), V,
130caf0897cSJustin Bogner               ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "",
131caf0897cSJustin Bogner               InsPoint->getIterator());
1323544d200SIlia Diachkov         }
1333544d200SIlia Diachkov         WorkList.splice(WorkList.begin(), ReplList);
1343544d200SIlia Diachkov         return Repl;
1353544d200SIlia Diachkov       }
1363544d200SIlia Diachkov       return nullptr;
1373544d200SIlia Diachkov     };
1383544d200SIlia Diachkov     for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
1393544d200SIlia Diachkov       auto *Op = II->getOperand(OI);
1403544d200SIlia Diachkov       if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
1413544d200SIlia Diachkov         Value *ReplInst = LowerConstantVec(Vec, OI);
1423544d200SIlia Diachkov         if (ReplInst)
1433544d200SIlia Diachkov           II->replaceUsesOfWith(Op, ReplInst);
1443544d200SIlia Diachkov       } else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
1453544d200SIlia Diachkov         WorkList.push_front(cast<Instruction>(LowerOp(CE)));
1463544d200SIlia Diachkov       } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
1473544d200SIlia Diachkov         auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
1483544d200SIlia Diachkov         if (!ConstMD)
1493544d200SIlia Diachkov           continue;
1503544d200SIlia Diachkov         Constant *C = ConstMD->getValue();
1513544d200SIlia Diachkov         Value *ReplInst = nullptr;
1523544d200SIlia Diachkov         if (auto *Vec = dyn_cast<ConstantVector>(C))
1533544d200SIlia Diachkov           ReplInst = LowerConstantVec(Vec, OI);
1543544d200SIlia Diachkov         if (auto *CE = dyn_cast<ConstantExpr>(C))
1553544d200SIlia Diachkov           ReplInst = LowerOp(CE);
1563544d200SIlia Diachkov         if (!ReplInst)
1573544d200SIlia Diachkov           continue;
1583544d200SIlia Diachkov         Metadata *RepMD = ValueAsMetadata::get(ReplInst);
1593544d200SIlia Diachkov         Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
1603544d200SIlia Diachkov         II->setOperand(OI, RepMDVal);
1613544d200SIlia Diachkov         WorkList.push_front(cast<Instruction>(ReplInst));
1623544d200SIlia Diachkov       }
1633544d200SIlia Diachkov     }
1643544d200SIlia Diachkov   }
1653544d200SIlia Diachkov }
1663544d200SIlia Diachkov 
1673544d200SIlia Diachkov // It fixes calls to OCL builtins that accept vector arguments and one of them
1683544d200SIlia Diachkov // is actually a scalar splat.
1693544d200SIlia Diachkov void SPIRVRegularizer::visitCallInst(CallInst &CI) {
1703544d200SIlia Diachkov   auto F = CI.getCalledFunction();
1713544d200SIlia Diachkov   if (!F)
1723544d200SIlia Diachkov     return;
1733544d200SIlia Diachkov 
1743544d200SIlia Diachkov   auto MangledName = F->getName();
1758e6b3cc4SFangrui Song   char *NameStr = itaniumDemangle(F->getName().data());
1768e6b3cc4SFangrui Song   if (!NameStr)
1778e6b3cc4SFangrui Song     return;
1783544d200SIlia Diachkov   StringRef DemangledName(NameStr);
1793544d200SIlia Diachkov 
1803544d200SIlia Diachkov   // TODO: add support for other builtins.
181395f9ce3SKazu Hirata   if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||
182395f9ce3SKazu Hirata       DemangledName.starts_with("min") || DemangledName.starts_with("max"))
1833544d200SIlia Diachkov     visitCallScalToVec(&CI, MangledName, DemangledName);
1843544d200SIlia Diachkov   free(NameStr);
1853544d200SIlia Diachkov }
1863544d200SIlia Diachkov 
1873544d200SIlia Diachkov void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
1883544d200SIlia Diachkov                                           StringRef DemangledName) {
1893544d200SIlia Diachkov   // Check if all arguments have the same type - it's simple case.
1903544d200SIlia Diachkov   auto Uniform = true;
1913544d200SIlia Diachkov   Type *Arg0Ty = CI->getOperand(0)->getType();
1923544d200SIlia Diachkov   auto IsArg0Vector = isa<VectorType>(Arg0Ty);
1933544d200SIlia Diachkov   for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
1943544d200SIlia Diachkov     Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
1953544d200SIlia Diachkov   if (Uniform)
1963544d200SIlia Diachkov     return;
1973544d200SIlia Diachkov 
1983544d200SIlia Diachkov   auto *OldF = CI->getCalledFunction();
1993544d200SIlia Diachkov   Function *NewF = nullptr;
2003544d200SIlia Diachkov   if (!Old2NewFuncs.count(OldF)) {
2013544d200SIlia Diachkov     AttributeList Attrs = CI->getCalledFunction()->getAttributes();
2023544d200SIlia Diachkov     SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
2033544d200SIlia Diachkov     auto *NewFTy =
2043544d200SIlia Diachkov         FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
2053544d200SIlia Diachkov     NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
2063544d200SIlia Diachkov                             *OldF->getParent());
2073544d200SIlia Diachkov     ValueToValueMapTy VMap;
2083544d200SIlia Diachkov     auto NewFArgIt = NewF->arg_begin();
2093544d200SIlia Diachkov     for (auto &Arg : OldF->args()) {
2103544d200SIlia Diachkov       auto ArgName = Arg.getName();
2113544d200SIlia Diachkov       NewFArgIt->setName(ArgName);
2123544d200SIlia Diachkov       VMap[&Arg] = &(*NewFArgIt++);
2133544d200SIlia Diachkov     }
2143544d200SIlia Diachkov     SmallVector<ReturnInst *, 8> Returns;
2153544d200SIlia Diachkov     CloneFunctionInto(NewF, OldF, VMap,
2163544d200SIlia Diachkov                       CloneFunctionChangeType::LocalChangesOnly, Returns);
2173544d200SIlia Diachkov     NewF->setAttributes(Attrs);
2183544d200SIlia Diachkov     Old2NewFuncs[OldF] = NewF;
2193544d200SIlia Diachkov   } else {
2203544d200SIlia Diachkov     NewF = Old2NewFuncs[OldF];
2213544d200SIlia Diachkov   }
2223544d200SIlia Diachkov   assert(NewF);
2233544d200SIlia Diachkov 
2244421b24fSIlia Diachkov   // This produces an instruction sequence that implements a splat of
2254421b24fSIlia Diachkov   // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
2264421b24fSIlia Diachkov   // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
2274421b24fSIlia Diachkov   // For instance (transcoding/OpMin.ll), this call
2284421b24fSIlia Diachkov   //   call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
2294421b24fSIlia Diachkov   // is translated to
2304421b24fSIlia Diachkov   //    %8 = OpUndef %v2uint
2314421b24fSIlia Diachkov   //   %14 = OpConstantComposite %v2uint %uint_1 %uint_10
2324421b24fSIlia Diachkov   //   ...
2334421b24fSIlia Diachkov   //   %10 = OpCompositeInsert %v2uint %uint_5 %8 0
2344421b24fSIlia Diachkov   //   %11 = OpVectorShuffle %v2uint %10 %8 0 0
2354421b24fSIlia Diachkov   // %call = OpExtInst %v2uint %1 s_min %14 %11
2363544d200SIlia Diachkov   auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
2374421b24fSIlia Diachkov   PoisonValue *PVal = PoisonValue::get(Arg0Ty);
238caf0897cSJustin Bogner   Instruction *Inst = InsertElementInst::Create(
239caf0897cSJustin Bogner       PVal, CI->getOperand(1), ConstInt, "", CI->getIterator());
2403544d200SIlia Diachkov   ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
2413544d200SIlia Diachkov   Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
242caf0897cSJustin Bogner   Value *NewVec =
243caf0897cSJustin Bogner       new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI->getIterator());
2443544d200SIlia Diachkov   CI->setOperand(1, NewVec);
2453544d200SIlia Diachkov   CI->replaceUsesOfWith(OldF, NewF);
2463544d200SIlia Diachkov   CI->mutateFunctionType(NewF->getFunctionType());
2473544d200SIlia Diachkov }
2483544d200SIlia Diachkov 
2493544d200SIlia Diachkov bool SPIRVRegularizer::runOnFunction(Function &F) {
2503544d200SIlia Diachkov   runLowerConstExpr(F);
2513544d200SIlia Diachkov   visit(F);
2523544d200SIlia Diachkov   for (auto &OldNew : Old2NewFuncs) {
2533544d200SIlia Diachkov     Function *OldF = OldNew.first;
2543544d200SIlia Diachkov     Function *NewF = OldNew.second;
2553544d200SIlia Diachkov     NewF->takeName(OldF);
2563544d200SIlia Diachkov     OldF->eraseFromParent();
2573544d200SIlia Diachkov   }
2583544d200SIlia Diachkov   return true;
2593544d200SIlia Diachkov }
2603544d200SIlia Diachkov 
2613544d200SIlia Diachkov FunctionPass *llvm::createSPIRVRegularizerPass() {
2623544d200SIlia Diachkov   return new SPIRVRegularizer();
2633544d200SIlia Diachkov }
264