1 //===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass implements regularization of LLVM IR for SPIR-V. The prototype of 10 // the pass was taken from SPIRV-LLVM translator. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "SPIRV.h" 15 #include "SPIRVTargetMachine.h" 16 #include "llvm/Demangle/Demangle.h" 17 #include "llvm/IR/InstIterator.h" 18 #include "llvm/IR/InstVisitor.h" 19 #include "llvm/IR/PassManager.h" 20 #include "llvm/Transforms/Utils/Cloning.h" 21 22 #include <list> 23 24 #define DEBUG_TYPE "spirv-regularizer" 25 26 using namespace llvm; 27 28 namespace llvm { 29 void initializeSPIRVRegularizerPass(PassRegistry &); 30 } 31 32 namespace { 33 struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> { 34 DenseMap<Function *, Function *> Old2NewFuncs; 35 36 public: 37 static char ID; 38 SPIRVRegularizer() : FunctionPass(ID) { 39 initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry()); 40 } 41 bool runOnFunction(Function &F) override; 42 StringRef getPassName() const override { return "SPIR-V Regularizer"; } 43 44 void getAnalysisUsage(AnalysisUsage &AU) const override { 45 FunctionPass::getAnalysisUsage(AU); 46 } 47 void visitCallInst(CallInst &CI); 48 49 private: 50 void visitCallScalToVec(CallInst *CI, StringRef MangledName, 51 StringRef DemangledName); 52 void runLowerConstExpr(Function &F); 53 }; 54 } // namespace 55 56 char SPIRVRegularizer::ID = 0; 57 58 INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false, 59 false) 60 61 // Since SPIR-V cannot represent constant expression, constant expressions 62 // in LLVM IR need to be lowered to instructions. For each function, 63 // the constant expressions used by instructions of the function are replaced 64 // by instructions placed in the entry block since it dominates all other BBs. 65 // Each constant expression only needs to be lowered once in each function 66 // and all uses of it by instructions in that function are replaced by 67 // one instruction. 68 // TODO: remove redundant instructions for common subexpression. 69 void SPIRVRegularizer::runLowerConstExpr(Function &F) { 70 LLVMContext &Ctx = F.getContext(); 71 std::list<Instruction *> WorkList; 72 for (auto &II : instructions(F)) 73 WorkList.push_back(&II); 74 75 auto FBegin = F.begin(); 76 while (!WorkList.empty()) { 77 Instruction *II = WorkList.front(); 78 79 auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * { 80 if (isa<Function>(V)) 81 return V; 82 auto *CE = cast<ConstantExpr>(V); 83 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE); 84 auto ReplInst = CE->getAsInstruction(); 85 auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back(); 86 ReplInst->insertBefore(InsPoint->getIterator()); 87 LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n'); 88 std::vector<Instruction *> Users; 89 // Do not replace use during iteration of use. Do it in another loop. 90 for (auto U : CE->users()) { 91 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n'); 92 auto InstUser = dyn_cast<Instruction>(U); 93 // Only replace users in scope of current function. 94 if (InstUser && InstUser->getParent()->getParent() == &F) 95 Users.push_back(InstUser); 96 } 97 for (auto &User : Users) { 98 if (ReplInst->getParent() == User->getParent() && 99 User->comesBefore(ReplInst)) 100 ReplInst->moveBefore(User->getIterator()); 101 User->replaceUsesOfWith(CE, ReplInst); 102 } 103 return ReplInst; 104 }; 105 106 WorkList.pop_front(); 107 auto LowerConstantVec = [&II, &LowerOp, &WorkList, 108 &Ctx](ConstantVector *Vec, 109 unsigned NumOfOp) -> Value * { 110 if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) { 111 return isa<ConstantExpr>(V) || isa<Function>(V); 112 })) { 113 // Expand a vector of constexprs and construct it back with 114 // series of insertelement instructions. 115 std::list<Value *> OpList; 116 std::transform(Vec->op_begin(), Vec->op_end(), 117 std::back_inserter(OpList), 118 [LowerOp](Value *V) { return LowerOp(V); }); 119 Value *Repl = nullptr; 120 unsigned Idx = 0; 121 auto *PhiII = dyn_cast<PHINode>(II); 122 Instruction *InsPoint = 123 PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II; 124 std::list<Instruction *> ReplList; 125 for (auto V : OpList) { 126 if (auto *Inst = dyn_cast<Instruction>(V)) 127 ReplList.push_back(Inst); 128 Repl = InsertElementInst::Create( 129 (Repl ? Repl : PoisonValue::get(Vec->getType())), V, 130 ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "", 131 InsPoint->getIterator()); 132 } 133 WorkList.splice(WorkList.begin(), ReplList); 134 return Repl; 135 } 136 return nullptr; 137 }; 138 for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) { 139 auto *Op = II->getOperand(OI); 140 if (auto *Vec = dyn_cast<ConstantVector>(Op)) { 141 Value *ReplInst = LowerConstantVec(Vec, OI); 142 if (ReplInst) 143 II->replaceUsesOfWith(Op, ReplInst); 144 } else if (auto CE = dyn_cast<ConstantExpr>(Op)) { 145 WorkList.push_front(cast<Instruction>(LowerOp(CE))); 146 } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) { 147 auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata()); 148 if (!ConstMD) 149 continue; 150 Constant *C = ConstMD->getValue(); 151 Value *ReplInst = nullptr; 152 if (auto *Vec = dyn_cast<ConstantVector>(C)) 153 ReplInst = LowerConstantVec(Vec, OI); 154 if (auto *CE = dyn_cast<ConstantExpr>(C)) 155 ReplInst = LowerOp(CE); 156 if (!ReplInst) 157 continue; 158 Metadata *RepMD = ValueAsMetadata::get(ReplInst); 159 Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD); 160 II->setOperand(OI, RepMDVal); 161 WorkList.push_front(cast<Instruction>(ReplInst)); 162 } 163 } 164 } 165 } 166 167 // It fixes calls to OCL builtins that accept vector arguments and one of them 168 // is actually a scalar splat. 169 void SPIRVRegularizer::visitCallInst(CallInst &CI) { 170 auto F = CI.getCalledFunction(); 171 if (!F) 172 return; 173 174 auto MangledName = F->getName(); 175 char *NameStr = itaniumDemangle(F->getName().data()); 176 if (!NameStr) 177 return; 178 StringRef DemangledName(NameStr); 179 180 // TODO: add support for other builtins. 181 if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") || 182 DemangledName.starts_with("min") || DemangledName.starts_with("max")) 183 visitCallScalToVec(&CI, MangledName, DemangledName); 184 free(NameStr); 185 } 186 187 void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName, 188 StringRef DemangledName) { 189 // Check if all arguments have the same type - it's simple case. 190 auto Uniform = true; 191 Type *Arg0Ty = CI->getOperand(0)->getType(); 192 auto IsArg0Vector = isa<VectorType>(Arg0Ty); 193 for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I) 194 Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector; 195 if (Uniform) 196 return; 197 198 auto *OldF = CI->getCalledFunction(); 199 Function *NewF = nullptr; 200 if (!Old2NewFuncs.count(OldF)) { 201 AttributeList Attrs = CI->getCalledFunction()->getAttributes(); 202 SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty}; 203 auto *NewFTy = 204 FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg()); 205 NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(), 206 *OldF->getParent()); 207 ValueToValueMapTy VMap; 208 auto NewFArgIt = NewF->arg_begin(); 209 for (auto &Arg : OldF->args()) { 210 auto ArgName = Arg.getName(); 211 NewFArgIt->setName(ArgName); 212 VMap[&Arg] = &(*NewFArgIt++); 213 } 214 SmallVector<ReturnInst *, 8> Returns; 215 CloneFunctionInto(NewF, OldF, VMap, 216 CloneFunctionChangeType::LocalChangesOnly, Returns); 217 NewF->setAttributes(Attrs); 218 Old2NewFuncs[OldF] = NewF; 219 } else { 220 NewF = Old2NewFuncs[OldF]; 221 } 222 assert(NewF); 223 224 // This produces an instruction sequence that implements a splat of 225 // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst 226 // and ShuffleVectorInst to generate the same code as the SPIR-V translator. 227 // For instance (transcoding/OpMin.ll), this call 228 // call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5) 229 // is translated to 230 // %8 = OpUndef %v2uint 231 // %14 = OpConstantComposite %v2uint %uint_1 %uint_10 232 // ... 233 // %10 = OpCompositeInsert %v2uint %uint_5 %8 0 234 // %11 = OpVectorShuffle %v2uint %10 %8 0 0 235 // %call = OpExtInst %v2uint %1 s_min %14 %11 236 auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0); 237 PoisonValue *PVal = PoisonValue::get(Arg0Ty); 238 Instruction *Inst = InsertElementInst::Create( 239 PVal, CI->getOperand(1), ConstInt, "", CI->getIterator()); 240 ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount(); 241 Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt); 242 Value *NewVec = 243 new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI->getIterator()); 244 CI->setOperand(1, NewVec); 245 CI->replaceUsesOfWith(OldF, NewF); 246 CI->mutateFunctionType(NewF->getFunctionType()); 247 } 248 249 bool SPIRVRegularizer::runOnFunction(Function &F) { 250 runLowerConstExpr(F); 251 visit(F); 252 for (auto &OldNew : Old2NewFuncs) { 253 Function *OldF = OldNew.first; 254 Function *NewF = OldNew.second; 255 NewF->takeName(OldF); 256 OldF->eraseFromParent(); 257 } 258 return true; 259 } 260 261 FunctionPass *llvm::createSPIRVRegularizerPass() { 262 return new SPIRVRegularizer(); 263 } 264