1bdd1243dSDimitry Andric //===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===// 2bdd1243dSDimitry Andric // 3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6bdd1243dSDimitry Andric // 7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===// 8bdd1243dSDimitry Andric // 9bdd1243dSDimitry Andric // This pass implements regularization of LLVM IR for SPIR-V. The prototype of 10bdd1243dSDimitry Andric // the pass was taken from SPIRV-LLVM translator. 11bdd1243dSDimitry Andric // 12bdd1243dSDimitry Andric //===----------------------------------------------------------------------===// 13bdd1243dSDimitry Andric 14bdd1243dSDimitry Andric #include "SPIRV.h" 15bdd1243dSDimitry Andric #include "SPIRVTargetMachine.h" 16bdd1243dSDimitry Andric #include "llvm/Demangle/Demangle.h" 17bdd1243dSDimitry Andric #include "llvm/IR/InstIterator.h" 18bdd1243dSDimitry Andric #include "llvm/IR/InstVisitor.h" 19bdd1243dSDimitry Andric #include "llvm/IR/PassManager.h" 20bdd1243dSDimitry Andric #include "llvm/Transforms/Utils/Cloning.h" 21bdd1243dSDimitry Andric 22bdd1243dSDimitry Andric #include <list> 23bdd1243dSDimitry Andric 24bdd1243dSDimitry Andric #define DEBUG_TYPE "spirv-regularizer" 25bdd1243dSDimitry Andric 26bdd1243dSDimitry Andric using namespace llvm; 27bdd1243dSDimitry Andric 28bdd1243dSDimitry Andric namespace llvm { 29bdd1243dSDimitry Andric void initializeSPIRVRegularizerPass(PassRegistry &); 30bdd1243dSDimitry Andric } 31bdd1243dSDimitry Andric 32bdd1243dSDimitry Andric namespace { 33bdd1243dSDimitry Andric struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> { 34bdd1243dSDimitry Andric DenseMap<Function *, Function *> Old2NewFuncs; 35bdd1243dSDimitry Andric 36bdd1243dSDimitry Andric public: 37bdd1243dSDimitry Andric static char ID; 38bdd1243dSDimitry Andric SPIRVRegularizer() : FunctionPass(ID) { 39bdd1243dSDimitry Andric initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry()); 40bdd1243dSDimitry Andric } 41bdd1243dSDimitry Andric bool runOnFunction(Function &F) override; 42bdd1243dSDimitry Andric StringRef getPassName() const override { return "SPIR-V Regularizer"; } 43bdd1243dSDimitry Andric 44bdd1243dSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 45bdd1243dSDimitry Andric FunctionPass::getAnalysisUsage(AU); 46bdd1243dSDimitry Andric } 47bdd1243dSDimitry Andric void visitCallInst(CallInst &CI); 48bdd1243dSDimitry Andric 49bdd1243dSDimitry Andric private: 50bdd1243dSDimitry Andric void visitCallScalToVec(CallInst *CI, StringRef MangledName, 51bdd1243dSDimitry Andric StringRef DemangledName); 52bdd1243dSDimitry Andric void runLowerConstExpr(Function &F); 53bdd1243dSDimitry Andric }; 54bdd1243dSDimitry Andric } // namespace 55bdd1243dSDimitry Andric 56bdd1243dSDimitry Andric char SPIRVRegularizer::ID = 0; 57bdd1243dSDimitry Andric 58bdd1243dSDimitry Andric INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false, 59bdd1243dSDimitry Andric false) 60bdd1243dSDimitry Andric 61bdd1243dSDimitry Andric // Since SPIR-V cannot represent constant expression, constant expressions 62bdd1243dSDimitry Andric // in LLVM IR need to be lowered to instructions. For each function, 63bdd1243dSDimitry Andric // the constant expressions used by instructions of the function are replaced 64bdd1243dSDimitry Andric // by instructions placed in the entry block since it dominates all other BBs. 65bdd1243dSDimitry Andric // Each constant expression only needs to be lowered once in each function 66bdd1243dSDimitry Andric // and all uses of it by instructions in that function are replaced by 67bdd1243dSDimitry Andric // one instruction. 68bdd1243dSDimitry Andric // TODO: remove redundant instructions for common subexpression. 69bdd1243dSDimitry Andric void SPIRVRegularizer::runLowerConstExpr(Function &F) { 70bdd1243dSDimitry Andric LLVMContext &Ctx = F.getContext(); 71bdd1243dSDimitry Andric std::list<Instruction *> WorkList; 72bdd1243dSDimitry Andric for (auto &II : instructions(F)) 73bdd1243dSDimitry Andric WorkList.push_back(&II); 74bdd1243dSDimitry Andric 75bdd1243dSDimitry Andric auto FBegin = F.begin(); 76bdd1243dSDimitry Andric while (!WorkList.empty()) { 77bdd1243dSDimitry Andric Instruction *II = WorkList.front(); 78bdd1243dSDimitry Andric 79bdd1243dSDimitry Andric auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * { 80bdd1243dSDimitry Andric if (isa<Function>(V)) 81bdd1243dSDimitry Andric return V; 82bdd1243dSDimitry Andric auto *CE = cast<ConstantExpr>(V); 83bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE); 84bdd1243dSDimitry Andric auto ReplInst = CE->getAsInstruction(); 85bdd1243dSDimitry Andric auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back(); 86bdd1243dSDimitry Andric ReplInst->insertBefore(InsPoint); 87bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n'); 88bdd1243dSDimitry Andric std::vector<Instruction *> Users; 89bdd1243dSDimitry Andric // Do not replace use during iteration of use. Do it in another loop. 90bdd1243dSDimitry Andric for (auto U : CE->users()) { 91bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n'); 92bdd1243dSDimitry Andric auto InstUser = dyn_cast<Instruction>(U); 93bdd1243dSDimitry Andric // Only replace users in scope of current function. 94bdd1243dSDimitry Andric if (InstUser && InstUser->getParent()->getParent() == &F) 95bdd1243dSDimitry Andric Users.push_back(InstUser); 96bdd1243dSDimitry Andric } 97bdd1243dSDimitry Andric for (auto &User : Users) { 98bdd1243dSDimitry Andric if (ReplInst->getParent() == User->getParent() && 99bdd1243dSDimitry Andric User->comesBefore(ReplInst)) 100bdd1243dSDimitry Andric ReplInst->moveBefore(User); 101bdd1243dSDimitry Andric User->replaceUsesOfWith(CE, ReplInst); 102bdd1243dSDimitry Andric } 103bdd1243dSDimitry Andric return ReplInst; 104bdd1243dSDimitry Andric }; 105bdd1243dSDimitry Andric 106bdd1243dSDimitry Andric WorkList.pop_front(); 107bdd1243dSDimitry Andric auto LowerConstantVec = [&II, &LowerOp, &WorkList, 108bdd1243dSDimitry Andric &Ctx](ConstantVector *Vec, 109bdd1243dSDimitry Andric unsigned NumOfOp) -> Value * { 110bdd1243dSDimitry Andric if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) { 111bdd1243dSDimitry Andric return isa<ConstantExpr>(V) || isa<Function>(V); 112bdd1243dSDimitry Andric })) { 113bdd1243dSDimitry Andric // Expand a vector of constexprs and construct it back with 114bdd1243dSDimitry Andric // series of insertelement instructions. 115bdd1243dSDimitry Andric std::list<Value *> OpList; 116bdd1243dSDimitry Andric std::transform(Vec->op_begin(), Vec->op_end(), 117bdd1243dSDimitry Andric std::back_inserter(OpList), 118bdd1243dSDimitry Andric [LowerOp](Value *V) { return LowerOp(V); }); 119bdd1243dSDimitry Andric Value *Repl = nullptr; 120bdd1243dSDimitry Andric unsigned Idx = 0; 121bdd1243dSDimitry Andric auto *PhiII = dyn_cast<PHINode>(II); 122bdd1243dSDimitry Andric Instruction *InsPoint = 123bdd1243dSDimitry Andric PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II; 124bdd1243dSDimitry Andric std::list<Instruction *> ReplList; 125bdd1243dSDimitry Andric for (auto V : OpList) { 126bdd1243dSDimitry Andric if (auto *Inst = dyn_cast<Instruction>(V)) 127bdd1243dSDimitry Andric ReplList.push_back(Inst); 128bdd1243dSDimitry Andric Repl = InsertElementInst::Create( 129bdd1243dSDimitry Andric (Repl ? Repl : PoisonValue::get(Vec->getType())), V, 130bdd1243dSDimitry Andric ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "", InsPoint); 131bdd1243dSDimitry Andric } 132bdd1243dSDimitry Andric WorkList.splice(WorkList.begin(), ReplList); 133bdd1243dSDimitry Andric return Repl; 134bdd1243dSDimitry Andric } 135bdd1243dSDimitry Andric return nullptr; 136bdd1243dSDimitry Andric }; 137bdd1243dSDimitry Andric for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) { 138bdd1243dSDimitry Andric auto *Op = II->getOperand(OI); 139bdd1243dSDimitry Andric if (auto *Vec = dyn_cast<ConstantVector>(Op)) { 140bdd1243dSDimitry Andric Value *ReplInst = LowerConstantVec(Vec, OI); 141bdd1243dSDimitry Andric if (ReplInst) 142bdd1243dSDimitry Andric II->replaceUsesOfWith(Op, ReplInst); 143bdd1243dSDimitry Andric } else if (auto CE = dyn_cast<ConstantExpr>(Op)) { 144bdd1243dSDimitry Andric WorkList.push_front(cast<Instruction>(LowerOp(CE))); 145bdd1243dSDimitry Andric } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) { 146bdd1243dSDimitry Andric auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata()); 147bdd1243dSDimitry Andric if (!ConstMD) 148bdd1243dSDimitry Andric continue; 149bdd1243dSDimitry Andric Constant *C = ConstMD->getValue(); 150bdd1243dSDimitry Andric Value *ReplInst = nullptr; 151bdd1243dSDimitry Andric if (auto *Vec = dyn_cast<ConstantVector>(C)) 152bdd1243dSDimitry Andric ReplInst = LowerConstantVec(Vec, OI); 153bdd1243dSDimitry Andric if (auto *CE = dyn_cast<ConstantExpr>(C)) 154bdd1243dSDimitry Andric ReplInst = LowerOp(CE); 155bdd1243dSDimitry Andric if (!ReplInst) 156bdd1243dSDimitry Andric continue; 157bdd1243dSDimitry Andric Metadata *RepMD = ValueAsMetadata::get(ReplInst); 158bdd1243dSDimitry Andric Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD); 159bdd1243dSDimitry Andric II->setOperand(OI, RepMDVal); 160bdd1243dSDimitry Andric WorkList.push_front(cast<Instruction>(ReplInst)); 161bdd1243dSDimitry Andric } 162bdd1243dSDimitry Andric } 163bdd1243dSDimitry Andric } 164bdd1243dSDimitry Andric } 165bdd1243dSDimitry Andric 166bdd1243dSDimitry Andric // It fixes calls to OCL builtins that accept vector arguments and one of them 167bdd1243dSDimitry Andric // is actually a scalar splat. 168bdd1243dSDimitry Andric void SPIRVRegularizer::visitCallInst(CallInst &CI) { 169bdd1243dSDimitry Andric auto F = CI.getCalledFunction(); 170bdd1243dSDimitry Andric if (!F) 171bdd1243dSDimitry Andric return; 172bdd1243dSDimitry Andric 173bdd1243dSDimitry Andric auto MangledName = F->getName(); 174*06c3fb27SDimitry Andric char *NameStr = itaniumDemangle(F->getName().data()); 175*06c3fb27SDimitry Andric if (!NameStr) 176*06c3fb27SDimitry Andric return; 177bdd1243dSDimitry Andric StringRef DemangledName(NameStr); 178bdd1243dSDimitry Andric 179bdd1243dSDimitry Andric // TODO: add support for other builtins. 180bdd1243dSDimitry Andric if (DemangledName.startswith("fmin") || DemangledName.startswith("fmax") || 181bdd1243dSDimitry Andric DemangledName.startswith("min") || DemangledName.startswith("max")) 182bdd1243dSDimitry Andric visitCallScalToVec(&CI, MangledName, DemangledName); 183bdd1243dSDimitry Andric free(NameStr); 184bdd1243dSDimitry Andric } 185bdd1243dSDimitry Andric 186bdd1243dSDimitry Andric void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName, 187bdd1243dSDimitry Andric StringRef DemangledName) { 188bdd1243dSDimitry Andric // Check if all arguments have the same type - it's simple case. 189bdd1243dSDimitry Andric auto Uniform = true; 190bdd1243dSDimitry Andric Type *Arg0Ty = CI->getOperand(0)->getType(); 191bdd1243dSDimitry Andric auto IsArg0Vector = isa<VectorType>(Arg0Ty); 192bdd1243dSDimitry Andric for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I) 193bdd1243dSDimitry Andric Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector; 194bdd1243dSDimitry Andric if (Uniform) 195bdd1243dSDimitry Andric return; 196bdd1243dSDimitry Andric 197bdd1243dSDimitry Andric auto *OldF = CI->getCalledFunction(); 198bdd1243dSDimitry Andric Function *NewF = nullptr; 199bdd1243dSDimitry Andric if (!Old2NewFuncs.count(OldF)) { 200bdd1243dSDimitry Andric AttributeList Attrs = CI->getCalledFunction()->getAttributes(); 201bdd1243dSDimitry Andric SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty}; 202bdd1243dSDimitry Andric auto *NewFTy = 203bdd1243dSDimitry Andric FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg()); 204bdd1243dSDimitry Andric NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(), 205bdd1243dSDimitry Andric *OldF->getParent()); 206bdd1243dSDimitry Andric ValueToValueMapTy VMap; 207bdd1243dSDimitry Andric auto NewFArgIt = NewF->arg_begin(); 208bdd1243dSDimitry Andric for (auto &Arg : OldF->args()) { 209bdd1243dSDimitry Andric auto ArgName = Arg.getName(); 210bdd1243dSDimitry Andric NewFArgIt->setName(ArgName); 211bdd1243dSDimitry Andric VMap[&Arg] = &(*NewFArgIt++); 212bdd1243dSDimitry Andric } 213bdd1243dSDimitry Andric SmallVector<ReturnInst *, 8> Returns; 214bdd1243dSDimitry Andric CloneFunctionInto(NewF, OldF, VMap, 215bdd1243dSDimitry Andric CloneFunctionChangeType::LocalChangesOnly, Returns); 216bdd1243dSDimitry Andric NewF->setAttributes(Attrs); 217bdd1243dSDimitry Andric Old2NewFuncs[OldF] = NewF; 218bdd1243dSDimitry Andric } else { 219bdd1243dSDimitry Andric NewF = Old2NewFuncs[OldF]; 220bdd1243dSDimitry Andric } 221bdd1243dSDimitry Andric assert(NewF); 222bdd1243dSDimitry Andric 223bdd1243dSDimitry Andric // This produces an instruction sequence that implements a splat of 224bdd1243dSDimitry Andric // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst 225bdd1243dSDimitry Andric // and ShuffleVectorInst to generate the same code as the SPIR-V translator. 226bdd1243dSDimitry Andric // For instance (transcoding/OpMin.ll), this call 227bdd1243dSDimitry Andric // call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5) 228bdd1243dSDimitry Andric // is translated to 229bdd1243dSDimitry Andric // %8 = OpUndef %v2uint 230bdd1243dSDimitry Andric // %14 = OpConstantComposite %v2uint %uint_1 %uint_10 231bdd1243dSDimitry Andric // ... 232bdd1243dSDimitry Andric // %10 = OpCompositeInsert %v2uint %uint_5 %8 0 233bdd1243dSDimitry Andric // %11 = OpVectorShuffle %v2uint %10 %8 0 0 234bdd1243dSDimitry Andric // %call = OpExtInst %v2uint %1 s_min %14 %11 235bdd1243dSDimitry Andric auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0); 236bdd1243dSDimitry Andric PoisonValue *PVal = PoisonValue::get(Arg0Ty); 237bdd1243dSDimitry Andric Instruction *Inst = 238bdd1243dSDimitry Andric InsertElementInst::Create(PVal, CI->getOperand(1), ConstInt, "", CI); 239bdd1243dSDimitry Andric ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount(); 240bdd1243dSDimitry Andric Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt); 241bdd1243dSDimitry Andric Value *NewVec = new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI); 242bdd1243dSDimitry Andric CI->setOperand(1, NewVec); 243bdd1243dSDimitry Andric CI->replaceUsesOfWith(OldF, NewF); 244bdd1243dSDimitry Andric CI->mutateFunctionType(NewF->getFunctionType()); 245bdd1243dSDimitry Andric } 246bdd1243dSDimitry Andric 247bdd1243dSDimitry Andric bool SPIRVRegularizer::runOnFunction(Function &F) { 248bdd1243dSDimitry Andric runLowerConstExpr(F); 249bdd1243dSDimitry Andric visit(F); 250bdd1243dSDimitry Andric for (auto &OldNew : Old2NewFuncs) { 251bdd1243dSDimitry Andric Function *OldF = OldNew.first; 252bdd1243dSDimitry Andric Function *NewF = OldNew.second; 253bdd1243dSDimitry Andric NewF->takeName(OldF); 254bdd1243dSDimitry Andric OldF->eraseFromParent(); 255bdd1243dSDimitry Andric } 256bdd1243dSDimitry Andric return true; 257bdd1243dSDimitry Andric } 258bdd1243dSDimitry Andric 259bdd1243dSDimitry Andric FunctionPass *llvm::createSPIRVRegularizerPass() { 260bdd1243dSDimitry Andric return new SPIRVRegularizer(); 261bdd1243dSDimitry Andric } 262