13544d200SIlia Diachkov //===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===// 23544d200SIlia Diachkov // 33544d200SIlia Diachkov // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43544d200SIlia Diachkov // See https://llvm.org/LICENSE.txt for license information. 53544d200SIlia Diachkov // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63544d200SIlia Diachkov // 73544d200SIlia Diachkov //===----------------------------------------------------------------------===// 83544d200SIlia Diachkov // 93544d200SIlia Diachkov // This pass implements regularization of LLVM IR for SPIR-V. The prototype of 103544d200SIlia Diachkov // the pass was taken from SPIRV-LLVM translator. 113544d200SIlia Diachkov // 123544d200SIlia Diachkov //===----------------------------------------------------------------------===// 133544d200SIlia Diachkov 143544d200SIlia Diachkov #include "SPIRV.h" 153544d200SIlia Diachkov #include "SPIRVTargetMachine.h" 163544d200SIlia Diachkov #include "llvm/Demangle/Demangle.h" 173544d200SIlia Diachkov #include "llvm/IR/InstIterator.h" 183544d200SIlia Diachkov #include "llvm/IR/InstVisitor.h" 193544d200SIlia Diachkov #include "llvm/IR/PassManager.h" 203544d200SIlia Diachkov #include "llvm/Transforms/Utils/Cloning.h" 213544d200SIlia Diachkov 223544d200SIlia Diachkov #include <list> 233544d200SIlia Diachkov 243544d200SIlia Diachkov #define DEBUG_TYPE "spirv-regularizer" 253544d200SIlia Diachkov 263544d200SIlia Diachkov using namespace llvm; 273544d200SIlia Diachkov 283544d200SIlia Diachkov namespace llvm { 293544d200SIlia Diachkov void initializeSPIRVRegularizerPass(PassRegistry &); 303544d200SIlia Diachkov } 313544d200SIlia Diachkov 323544d200SIlia Diachkov namespace { 333544d200SIlia Diachkov struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> { 343544d200SIlia Diachkov DenseMap<Function *, Function *> Old2NewFuncs; 353544d200SIlia Diachkov 363544d200SIlia Diachkov public: 373544d200SIlia Diachkov static char ID; 383544d200SIlia Diachkov SPIRVRegularizer() : FunctionPass(ID) { 393544d200SIlia Diachkov initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry()); 403544d200SIlia Diachkov } 413544d200SIlia Diachkov bool runOnFunction(Function &F) override; 423544d200SIlia Diachkov StringRef getPassName() const override { return "SPIR-V Regularizer"; } 433544d200SIlia Diachkov 443544d200SIlia Diachkov void getAnalysisUsage(AnalysisUsage &AU) const override { 453544d200SIlia Diachkov FunctionPass::getAnalysisUsage(AU); 463544d200SIlia Diachkov } 473544d200SIlia Diachkov void visitCallInst(CallInst &CI); 483544d200SIlia Diachkov 493544d200SIlia Diachkov private: 503544d200SIlia Diachkov void visitCallScalToVec(CallInst *CI, StringRef MangledName, 513544d200SIlia Diachkov StringRef DemangledName); 523544d200SIlia Diachkov void runLowerConstExpr(Function &F); 533544d200SIlia Diachkov }; 543544d200SIlia Diachkov } // namespace 553544d200SIlia Diachkov 563544d200SIlia Diachkov char SPIRVRegularizer::ID = 0; 573544d200SIlia Diachkov 583544d200SIlia Diachkov INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false, 593544d200SIlia Diachkov false) 603544d200SIlia Diachkov 613544d200SIlia Diachkov // Since SPIR-V cannot represent constant expression, constant expressions 623544d200SIlia Diachkov // in LLVM IR need to be lowered to instructions. For each function, 633544d200SIlia Diachkov // the constant expressions used by instructions of the function are replaced 643544d200SIlia Diachkov // by instructions placed in the entry block since it dominates all other BBs. 653544d200SIlia Diachkov // Each constant expression only needs to be lowered once in each function 663544d200SIlia Diachkov // and all uses of it by instructions in that function are replaced by 673544d200SIlia Diachkov // one instruction. 683544d200SIlia Diachkov // TODO: remove redundant instructions for common subexpression. 693544d200SIlia Diachkov void SPIRVRegularizer::runLowerConstExpr(Function &F) { 703544d200SIlia Diachkov LLVMContext &Ctx = F.getContext(); 713544d200SIlia Diachkov std::list<Instruction *> WorkList; 723544d200SIlia Diachkov for (auto &II : instructions(F)) 733544d200SIlia Diachkov WorkList.push_back(&II); 743544d200SIlia Diachkov 753544d200SIlia Diachkov auto FBegin = F.begin(); 763544d200SIlia Diachkov while (!WorkList.empty()) { 773544d200SIlia Diachkov Instruction *II = WorkList.front(); 783544d200SIlia Diachkov 793544d200SIlia Diachkov auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * { 803544d200SIlia Diachkov if (isa<Function>(V)) 813544d200SIlia Diachkov return V; 823544d200SIlia Diachkov auto *CE = cast<ConstantExpr>(V); 833544d200SIlia Diachkov LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE); 843544d200SIlia Diachkov auto ReplInst = CE->getAsInstruction(); 853544d200SIlia Diachkov auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back(); 86*304a9909SJeremy Morse ReplInst->insertBefore(InsPoint->getIterator()); 873544d200SIlia Diachkov LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n'); 883544d200SIlia Diachkov std::vector<Instruction *> Users; 893544d200SIlia Diachkov // Do not replace use during iteration of use. Do it in another loop. 903544d200SIlia Diachkov for (auto U : CE->users()) { 913544d200SIlia Diachkov LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n'); 923544d200SIlia Diachkov auto InstUser = dyn_cast<Instruction>(U); 933544d200SIlia Diachkov // Only replace users in scope of current function. 943544d200SIlia Diachkov if (InstUser && InstUser->getParent()->getParent() == &F) 953544d200SIlia Diachkov Users.push_back(InstUser); 963544d200SIlia Diachkov } 973544d200SIlia Diachkov for (auto &User : Users) { 983544d200SIlia Diachkov if (ReplInst->getParent() == User->getParent() && 993544d200SIlia Diachkov User->comesBefore(ReplInst)) 100*304a9909SJeremy Morse ReplInst->moveBefore(User->getIterator()); 1013544d200SIlia Diachkov User->replaceUsesOfWith(CE, ReplInst); 1023544d200SIlia Diachkov } 1033544d200SIlia Diachkov return ReplInst; 1043544d200SIlia Diachkov }; 1053544d200SIlia Diachkov 1063544d200SIlia Diachkov WorkList.pop_front(); 1073544d200SIlia Diachkov auto LowerConstantVec = [&II, &LowerOp, &WorkList, 1083544d200SIlia Diachkov &Ctx](ConstantVector *Vec, 1093544d200SIlia Diachkov unsigned NumOfOp) -> Value * { 1103544d200SIlia Diachkov if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) { 1113544d200SIlia Diachkov return isa<ConstantExpr>(V) || isa<Function>(V); 1123544d200SIlia Diachkov })) { 1133544d200SIlia Diachkov // Expand a vector of constexprs and construct it back with 1143544d200SIlia Diachkov // series of insertelement instructions. 1153544d200SIlia Diachkov std::list<Value *> OpList; 1163544d200SIlia Diachkov std::transform(Vec->op_begin(), Vec->op_end(), 1173544d200SIlia Diachkov std::back_inserter(OpList), 1183544d200SIlia Diachkov [LowerOp](Value *V) { return LowerOp(V); }); 1193544d200SIlia Diachkov Value *Repl = nullptr; 1203544d200SIlia Diachkov unsigned Idx = 0; 1213544d200SIlia Diachkov auto *PhiII = dyn_cast<PHINode>(II); 1223544d200SIlia Diachkov Instruction *InsPoint = 1233544d200SIlia Diachkov PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II; 1243544d200SIlia Diachkov std::list<Instruction *> ReplList; 1253544d200SIlia Diachkov for (auto V : OpList) { 1263544d200SIlia Diachkov if (auto *Inst = dyn_cast<Instruction>(V)) 1273544d200SIlia Diachkov ReplList.push_back(Inst); 1283544d200SIlia Diachkov Repl = InsertElementInst::Create( 1293544d200SIlia Diachkov (Repl ? Repl : PoisonValue::get(Vec->getType())), V, 130caf0897cSJustin Bogner ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "", 131caf0897cSJustin Bogner InsPoint->getIterator()); 1323544d200SIlia Diachkov } 1333544d200SIlia Diachkov WorkList.splice(WorkList.begin(), ReplList); 1343544d200SIlia Diachkov return Repl; 1353544d200SIlia Diachkov } 1363544d200SIlia Diachkov return nullptr; 1373544d200SIlia Diachkov }; 1383544d200SIlia Diachkov for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) { 1393544d200SIlia Diachkov auto *Op = II->getOperand(OI); 1403544d200SIlia Diachkov if (auto *Vec = dyn_cast<ConstantVector>(Op)) { 1413544d200SIlia Diachkov Value *ReplInst = LowerConstantVec(Vec, OI); 1423544d200SIlia Diachkov if (ReplInst) 1433544d200SIlia Diachkov II->replaceUsesOfWith(Op, ReplInst); 1443544d200SIlia Diachkov } else if (auto CE = dyn_cast<ConstantExpr>(Op)) { 1453544d200SIlia Diachkov WorkList.push_front(cast<Instruction>(LowerOp(CE))); 1463544d200SIlia Diachkov } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) { 1473544d200SIlia Diachkov auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata()); 1483544d200SIlia Diachkov if (!ConstMD) 1493544d200SIlia Diachkov continue; 1503544d200SIlia Diachkov Constant *C = ConstMD->getValue(); 1513544d200SIlia Diachkov Value *ReplInst = nullptr; 1523544d200SIlia Diachkov if (auto *Vec = dyn_cast<ConstantVector>(C)) 1533544d200SIlia Diachkov ReplInst = LowerConstantVec(Vec, OI); 1543544d200SIlia Diachkov if (auto *CE = dyn_cast<ConstantExpr>(C)) 1553544d200SIlia Diachkov ReplInst = LowerOp(CE); 1563544d200SIlia Diachkov if (!ReplInst) 1573544d200SIlia Diachkov continue; 1583544d200SIlia Diachkov Metadata *RepMD = ValueAsMetadata::get(ReplInst); 1593544d200SIlia Diachkov Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD); 1603544d200SIlia Diachkov II->setOperand(OI, RepMDVal); 1613544d200SIlia Diachkov WorkList.push_front(cast<Instruction>(ReplInst)); 1623544d200SIlia Diachkov } 1633544d200SIlia Diachkov } 1643544d200SIlia Diachkov } 1653544d200SIlia Diachkov } 1663544d200SIlia Diachkov 1673544d200SIlia Diachkov // It fixes calls to OCL builtins that accept vector arguments and one of them 1683544d200SIlia Diachkov // is actually a scalar splat. 1693544d200SIlia Diachkov void SPIRVRegularizer::visitCallInst(CallInst &CI) { 1703544d200SIlia Diachkov auto F = CI.getCalledFunction(); 1713544d200SIlia Diachkov if (!F) 1723544d200SIlia Diachkov return; 1733544d200SIlia Diachkov 1743544d200SIlia Diachkov auto MangledName = F->getName(); 1758e6b3cc4SFangrui Song char *NameStr = itaniumDemangle(F->getName().data()); 1768e6b3cc4SFangrui Song if (!NameStr) 1778e6b3cc4SFangrui Song return; 1783544d200SIlia Diachkov StringRef DemangledName(NameStr); 1793544d200SIlia Diachkov 1803544d200SIlia Diachkov // TODO: add support for other builtins. 181395f9ce3SKazu Hirata if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") || 182395f9ce3SKazu Hirata DemangledName.starts_with("min") || DemangledName.starts_with("max")) 1833544d200SIlia Diachkov visitCallScalToVec(&CI, MangledName, DemangledName); 1843544d200SIlia Diachkov free(NameStr); 1853544d200SIlia Diachkov } 1863544d200SIlia Diachkov 1873544d200SIlia Diachkov void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName, 1883544d200SIlia Diachkov StringRef DemangledName) { 1893544d200SIlia Diachkov // Check if all arguments have the same type - it's simple case. 1903544d200SIlia Diachkov auto Uniform = true; 1913544d200SIlia Diachkov Type *Arg0Ty = CI->getOperand(0)->getType(); 1923544d200SIlia Diachkov auto IsArg0Vector = isa<VectorType>(Arg0Ty); 1933544d200SIlia Diachkov for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I) 1943544d200SIlia Diachkov Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector; 1953544d200SIlia Diachkov if (Uniform) 1963544d200SIlia Diachkov return; 1973544d200SIlia Diachkov 1983544d200SIlia Diachkov auto *OldF = CI->getCalledFunction(); 1993544d200SIlia Diachkov Function *NewF = nullptr; 2003544d200SIlia Diachkov if (!Old2NewFuncs.count(OldF)) { 2013544d200SIlia Diachkov AttributeList Attrs = CI->getCalledFunction()->getAttributes(); 2023544d200SIlia Diachkov SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty}; 2033544d200SIlia Diachkov auto *NewFTy = 2043544d200SIlia Diachkov FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg()); 2053544d200SIlia Diachkov NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(), 2063544d200SIlia Diachkov *OldF->getParent()); 2073544d200SIlia Diachkov ValueToValueMapTy VMap; 2083544d200SIlia Diachkov auto NewFArgIt = NewF->arg_begin(); 2093544d200SIlia Diachkov for (auto &Arg : OldF->args()) { 2103544d200SIlia Diachkov auto ArgName = Arg.getName(); 2113544d200SIlia Diachkov NewFArgIt->setName(ArgName); 2123544d200SIlia Diachkov VMap[&Arg] = &(*NewFArgIt++); 2133544d200SIlia Diachkov } 2143544d200SIlia Diachkov SmallVector<ReturnInst *, 8> Returns; 2153544d200SIlia Diachkov CloneFunctionInto(NewF, OldF, VMap, 2163544d200SIlia Diachkov CloneFunctionChangeType::LocalChangesOnly, Returns); 2173544d200SIlia Diachkov NewF->setAttributes(Attrs); 2183544d200SIlia Diachkov Old2NewFuncs[OldF] = NewF; 2193544d200SIlia Diachkov } else { 2203544d200SIlia Diachkov NewF = Old2NewFuncs[OldF]; 2213544d200SIlia Diachkov } 2223544d200SIlia Diachkov assert(NewF); 2233544d200SIlia Diachkov 2244421b24fSIlia Diachkov // This produces an instruction sequence that implements a splat of 2254421b24fSIlia Diachkov // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst 2264421b24fSIlia Diachkov // and ShuffleVectorInst to generate the same code as the SPIR-V translator. 2274421b24fSIlia Diachkov // For instance (transcoding/OpMin.ll), this call 2284421b24fSIlia Diachkov // call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5) 2294421b24fSIlia Diachkov // is translated to 2304421b24fSIlia Diachkov // %8 = OpUndef %v2uint 2314421b24fSIlia Diachkov // %14 = OpConstantComposite %v2uint %uint_1 %uint_10 2324421b24fSIlia Diachkov // ... 2334421b24fSIlia Diachkov // %10 = OpCompositeInsert %v2uint %uint_5 %8 0 2344421b24fSIlia Diachkov // %11 = OpVectorShuffle %v2uint %10 %8 0 0 2354421b24fSIlia Diachkov // %call = OpExtInst %v2uint %1 s_min %14 %11 2363544d200SIlia Diachkov auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0); 2374421b24fSIlia Diachkov PoisonValue *PVal = PoisonValue::get(Arg0Ty); 238caf0897cSJustin Bogner Instruction *Inst = InsertElementInst::Create( 239caf0897cSJustin Bogner PVal, CI->getOperand(1), ConstInt, "", CI->getIterator()); 2403544d200SIlia Diachkov ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount(); 2413544d200SIlia Diachkov Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt); 242caf0897cSJustin Bogner Value *NewVec = 243caf0897cSJustin Bogner new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI->getIterator()); 2443544d200SIlia Diachkov CI->setOperand(1, NewVec); 2453544d200SIlia Diachkov CI->replaceUsesOfWith(OldF, NewF); 2463544d200SIlia Diachkov CI->mutateFunctionType(NewF->getFunctionType()); 2473544d200SIlia Diachkov } 2483544d200SIlia Diachkov 2493544d200SIlia Diachkov bool SPIRVRegularizer::runOnFunction(Function &F) { 2503544d200SIlia Diachkov runLowerConstExpr(F); 2513544d200SIlia Diachkov visit(F); 2523544d200SIlia Diachkov for (auto &OldNew : Old2NewFuncs) { 2533544d200SIlia Diachkov Function *OldF = OldNew.first; 2543544d200SIlia Diachkov Function *NewF = OldNew.second; 2553544d200SIlia Diachkov NewF->takeName(OldF); 2563544d200SIlia Diachkov OldF->eraseFromParent(); 2573544d200SIlia Diachkov } 2583544d200SIlia Diachkov return true; 2593544d200SIlia Diachkov } 2603544d200SIlia Diachkov 2613544d200SIlia Diachkov FunctionPass *llvm::createSPIRVRegularizerPass() { 2623544d200SIlia Diachkov return new SPIRVRegularizer(); 2633544d200SIlia Diachkov } 264