1 //=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Replaces calls to LLVM Intrinsics with matching calls to functions from a 10 // vector library (e.g libmvec, SVML) using TargetLibraryInfo interface. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/CodeGen/ReplaceWithVeclib.h" 15 #include "llvm/ADT/STLExtras.h" 16 #include "llvm/ADT/Statistic.h" 17 #include "llvm/ADT/StringRef.h" 18 #include "llvm/Analysis/DemandedBits.h" 19 #include "llvm/Analysis/GlobalsModRef.h" 20 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 21 #include "llvm/Analysis/TargetLibraryInfo.h" 22 #include "llvm/Analysis/VectorUtils.h" 23 #include "llvm/CodeGen/Passes.h" 24 #include "llvm/IR/DerivedTypes.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/InstIterator.h" 27 #include "llvm/IR/IntrinsicInst.h" 28 #include "llvm/IR/VFABIDemangler.h" 29 #include "llvm/Support/TypeSize.h" 30 #include "llvm/Transforms/Utils/ModuleUtils.h" 31 32 using namespace llvm; 33 34 #define DEBUG_TYPE "replace-with-veclib" 35 36 STATISTIC(NumCallsReplaced, 37 "Number of calls to intrinsics that have been replaced."); 38 39 STATISTIC(NumTLIFuncDeclAdded, 40 "Number of vector library function declarations added."); 41 42 STATISTIC(NumFuncUsedAdded, 43 "Number of functions added to `llvm.compiler.used`"); 44 45 /// Returns a vector Function that it adds to the Module \p M. When an \p 46 /// ScalarFunc is not null, it copies its attributes to the newly created 47 /// Function. 48 Function *getTLIFunction(Module *M, FunctionType *VectorFTy, 49 const StringRef TLIName, 50 Function *ScalarFunc = nullptr) { 51 Function *TLIFunc = M->getFunction(TLIName); 52 if (!TLIFunc) { 53 TLIFunc = 54 Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M); 55 if (ScalarFunc) 56 TLIFunc->copyAttributesFrom(ScalarFunc); 57 58 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `" 59 << TLIName << "` of type `" << *(TLIFunc->getType()) 60 << "` to module.\n"); 61 62 ++NumTLIFuncDeclAdded; 63 // Add the freshly created function to llvm.compiler.used, similar to as it 64 // is done in InjectTLIMappings. 65 appendToCompilerUsed(*M, {TLIFunc}); 66 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName 67 << "` to `@llvm.compiler.used`.\n"); 68 ++NumFuncUsedAdded; 69 } 70 return TLIFunc; 71 } 72 73 /// Replace the intrinsic call \p II to \p TLIVecFunc, which is the 74 /// corresponding function from the vector library. 75 static void replaceWithTLIFunction(IntrinsicInst *II, VFInfo &Info, 76 Function *TLIVecFunc) { 77 IRBuilder<> IRBuilder(II); 78 SmallVector<Value *> Args(II->args()); 79 if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) { 80 auto *MaskTy = 81 VectorType::get(Type::getInt1Ty(II->getContext()), Info.Shape.VF); 82 Args.insert(Args.begin() + OptMaskpos.value(), 83 Constant::getAllOnesValue(MaskTy)); 84 } 85 86 // Preserve the operand bundles. 87 SmallVector<OperandBundleDef, 1> OpBundles; 88 II->getOperandBundlesAsDefs(OpBundles); 89 90 auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles); 91 II->replaceAllUsesWith(Replacement); 92 // Preserve fast math flags for FP math. 93 if (isa<FPMathOperator>(Replacement)) 94 Replacement->copyFastMathFlags(II); 95 } 96 97 /// Returns true when successfully replaced \p II, which is a call to a 98 /// vectorized intrinsic, with a suitable function taking vector arguments, 99 /// based on available mappings in the \p TLI. 100 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI, 101 IntrinsicInst *II) { 102 assert(II != nullptr && "Intrinsic cannot be null"); 103 Intrinsic::ID IID = II->getIntrinsicID(); 104 Type *RetTy = II->getType(); 105 Type *ScalarRetTy = RetTy->getScalarType(); 106 // At the moment VFABI assumes the return type is always widened unless it is 107 // a void type. 108 auto *VTy = dyn_cast<VectorType>(RetTy); 109 ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0)); 110 111 // OloadTys collects types used in scalar intrinsic overload name. 112 SmallVector<Type *, 3> OloadTys; 113 if (!RetTy->isVoidTy() && 114 isVectorIntrinsicWithOverloadTypeAtArg(IID, -1, /*TTI=*/nullptr)) 115 OloadTys.push_back(ScalarRetTy); 116 117 // Compute the argument types of the corresponding scalar call and check that 118 // all vector operands match the previously found EC. 119 SmallVector<Type *, 8> ScalarArgTypes; 120 for (auto Arg : enumerate(II->args())) { 121 auto *ArgTy = Arg.value()->getType(); 122 bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(), 123 /*TTI=*/nullptr); 124 if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index(), /*TTI=*/nullptr)) { 125 ScalarArgTypes.push_back(ArgTy); 126 if (IsOloadTy) 127 OloadTys.push_back(ArgTy); 128 } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) { 129 auto *ScalarArgTy = VectorArgTy->getElementType(); 130 ScalarArgTypes.push_back(ScalarArgTy); 131 if (IsOloadTy) 132 OloadTys.push_back(ScalarArgTy); 133 // When return type is void, set EC to the first vector argument, and 134 // disallow vector arguments with different ECs. 135 if (EC.isZero()) 136 EC = VectorArgTy->getElementCount(); 137 else if (EC != VectorArgTy->getElementCount()) 138 return false; 139 } else 140 // Exit when it is supposed to be a vector argument but it isn't. 141 return false; 142 } 143 144 // Try to reconstruct the name for the scalar version of the instruction, 145 // using scalar argument types. 146 std::string ScalarName = 147 Intrinsic::isOverloaded(IID) 148 ? Intrinsic::getName(IID, OloadTys, II->getModule()) 149 : Intrinsic::getName(IID).str(); 150 151 // Try to find the mapping for the scalar version of this intrinsic and the 152 // exact vector width of the call operands in the TargetLibraryInfo. First, 153 // check with a non-masked variant, and if that fails try with a masked one. 154 const VecDesc *VD = 155 TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false); 156 if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true))) 157 return false; 158 159 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName 160 << "` and vector width " << EC << " to: `" 161 << VD->getVectorFnName() << "`.\n"); 162 163 // Replace the call to the intrinsic with a call to the vector library 164 // function. 165 FunctionType *ScalarFTy = 166 FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false); 167 const std::string MangledName = VD->getVectorFunctionABIVariantString(); 168 auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy); 169 if (!OptInfo) 170 return false; 171 172 // There is no guarantee that the vectorized instructions followed the VFABI 173 // specification when being created, this is why we need to add extra check to 174 // make sure that the operands of the vector function obtained via VFABI match 175 // the operands of the original vector instruction. 176 for (auto &VFParam : OptInfo->Shape.Parameters) { 177 if (VFParam.ParamKind == VFParamKind::GlobalPredicate) 178 continue; 179 180 // tryDemangleForVFABI must return valid ParamPos, otherwise it could be 181 // a bug in the VFABI parser. 182 assert(VFParam.ParamPos < II->arg_size() && "ParamPos has invalid range"); 183 Type *OrigTy = II->getArgOperand(VFParam.ParamPos)->getType(); 184 if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) { 185 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName 186 << ". Wrong type at index " << VFParam.ParamPos << ": " 187 << *OrigTy << "\n"); 188 return false; 189 } 190 } 191 192 FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy); 193 if (!VectorFTy) 194 return false; 195 196 Function *TLIFunc = 197 getTLIFunction(II->getModule(), VectorFTy, VD->getVectorFnName(), 198 II->getCalledFunction()); 199 replaceWithTLIFunction(II, *OptInfo, TLIFunc); 200 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName 201 << "` with call to `" << TLIFunc->getName() << "`.\n"); 202 ++NumCallsReplaced; 203 return true; 204 } 205 206 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) { 207 SmallVector<Instruction *> ReplacedCalls; 208 for (auto &I : instructions(F)) { 209 // Process only intrinsic calls that return void or a vector. 210 if (auto *II = dyn_cast<IntrinsicInst>(&I)) { 211 if (II->getIntrinsicID() == Intrinsic::not_intrinsic) 212 continue; 213 if (!II->getType()->isVectorTy() && !II->getType()->isVoidTy()) 214 continue; 215 216 if (replaceWithCallToVeclib(TLI, II)) 217 ReplacedCalls.push_back(&I); 218 } 219 } 220 // Erase any intrinsic calls that were replaced with vector library calls. 221 for (auto *I : ReplacedCalls) 222 I->eraseFromParent(); 223 return !ReplacedCalls.empty(); 224 } 225 226 //////////////////////////////////////////////////////////////////////////////// 227 // New pass manager implementation. 228 //////////////////////////////////////////////////////////////////////////////// 229 PreservedAnalyses ReplaceWithVeclib::run(Function &F, 230 FunctionAnalysisManager &AM) { 231 const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F); 232 auto Changed = runImpl(TLI, F); 233 if (Changed) { 234 LLVM_DEBUG(dbgs() << "Intrinsic calls replaced with vector libraries: " 235 << NumCallsReplaced << "\n"); 236 237 PreservedAnalyses PA; 238 PA.preserveSet<CFGAnalyses>(); 239 PA.preserve<TargetLibraryAnalysis>(); 240 PA.preserve<ScalarEvolutionAnalysis>(); 241 PA.preserve<LoopAccessAnalysis>(); 242 PA.preserve<DemandedBitsAnalysis>(); 243 PA.preserve<OptimizationRemarkEmitterAnalysis>(); 244 return PA; 245 } 246 247 // The pass did not replace any calls, hence it preserves all analyses. 248 return PreservedAnalyses::all(); 249 } 250 251 //////////////////////////////////////////////////////////////////////////////// 252 // Legacy PM Implementation. 253 //////////////////////////////////////////////////////////////////////////////// 254 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) { 255 const TargetLibraryInfo &TLI = 256 getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 257 return runImpl(TLI, F); 258 } 259 260 void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const { 261 AU.setPreservesCFG(); 262 AU.addRequired<TargetLibraryInfoWrapperPass>(); 263 AU.addPreserved<TargetLibraryInfoWrapperPass>(); 264 AU.addPreserved<ScalarEvolutionWrapperPass>(); 265 AU.addPreserved<AAResultsWrapperPass>(); 266 AU.addPreserved<OptimizationRemarkEmitterWrapperPass>(); 267 AU.addPreserved<GlobalsAAWrapperPass>(); 268 } 269 270 //////////////////////////////////////////////////////////////////////////////// 271 // Legacy Pass manager initialization 272 //////////////////////////////////////////////////////////////////////////////// 273 char ReplaceWithVeclibLegacy::ID = 0; 274 275 INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE, 276 "Replace intrinsics with calls to vector library", false, 277 false) 278 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 279 INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE, 280 "Replace intrinsics with calls to vector library", false, 281 false) 282 283 FunctionPass *llvm::createReplaceWithVeclibLegacyPass() { 284 return new ReplaceWithVeclibLegacy(); 285 } 286