xref: /llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp (revision 82b40fd4fd2f3f723d30b666f8766973da4166db)
1 //=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Replaces calls to LLVM Intrinsics with matching calls to functions from a
10 // vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/ReplaceWithVeclib.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/Analysis/DemandedBits.h"
19 #include "llvm/Analysis/GlobalsModRef.h"
20 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
21 #include "llvm/Analysis/TargetLibraryInfo.h"
22 #include "llvm/Analysis/VectorUtils.h"
23 #include "llvm/CodeGen/Passes.h"
24 #include "llvm/IR/DerivedTypes.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/InstIterator.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/VFABIDemangler.h"
29 #include "llvm/Support/TypeSize.h"
30 #include "llvm/Transforms/Utils/ModuleUtils.h"
31 
32 using namespace llvm;
33 
34 #define DEBUG_TYPE "replace-with-veclib"
35 
36 STATISTIC(NumCallsReplaced,
37           "Number of calls to intrinsics that have been replaced.");
38 
39 STATISTIC(NumTLIFuncDeclAdded,
40           "Number of vector library function declarations added.");
41 
42 STATISTIC(NumFuncUsedAdded,
43           "Number of functions added to `llvm.compiler.used`");
44 
45 /// Returns a vector Function that it adds to the Module \p M. When an \p
46 /// ScalarFunc is not null, it copies its attributes to the newly created
47 /// Function.
48 Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
49                          const StringRef TLIName,
50                          Function *ScalarFunc = nullptr) {
51   Function *TLIFunc = M->getFunction(TLIName);
52   if (!TLIFunc) {
53     TLIFunc =
54         Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
55     if (ScalarFunc)
56       TLIFunc->copyAttributesFrom(ScalarFunc);
57 
58     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
59                       << TLIName << "` of type `" << *(TLIFunc->getType())
60                       << "` to module.\n");
61 
62     ++NumTLIFuncDeclAdded;
63     // Add the freshly created function to llvm.compiler.used, similar to as it
64     // is done in InjectTLIMappings.
65     appendToCompilerUsed(*M, {TLIFunc});
66     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
67                       << "` to `@llvm.compiler.used`.\n");
68     ++NumFuncUsedAdded;
69   }
70   return TLIFunc;
71 }
72 
73 /// Replace the intrinsic call \p II to \p TLIVecFunc, which is the
74 /// corresponding function from the vector library.
75 static void replaceWithTLIFunction(IntrinsicInst *II, VFInfo &Info,
76                                    Function *TLIVecFunc) {
77   IRBuilder<> IRBuilder(II);
78   SmallVector<Value *> Args(II->args());
79   if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
80     auto *MaskTy =
81         VectorType::get(Type::getInt1Ty(II->getContext()), Info.Shape.VF);
82     Args.insert(Args.begin() + OptMaskpos.value(),
83                 Constant::getAllOnesValue(MaskTy));
84   }
85 
86   // Preserve the operand bundles.
87   SmallVector<OperandBundleDef, 1> OpBundles;
88   II->getOperandBundlesAsDefs(OpBundles);
89 
90   auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
91   II->replaceAllUsesWith(Replacement);
92   // Preserve fast math flags for FP math.
93   if (isa<FPMathOperator>(Replacement))
94     Replacement->copyFastMathFlags(II);
95 }
96 
97 /// Returns true when successfully replaced \p II, which is a call to a
98 /// vectorized intrinsic, with a suitable function taking vector arguments,
99 /// based on available mappings in the \p TLI.
100 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
101                                     IntrinsicInst *II) {
102   assert(II != nullptr && "Intrinsic cannot be null");
103   Intrinsic::ID IID = II->getIntrinsicID();
104   Type *RetTy = II->getType();
105   Type *ScalarRetTy = RetTy->getScalarType();
106   // At the moment VFABI assumes the return type is always widened unless it is
107   // a void type.
108   auto *VTy = dyn_cast<VectorType>(RetTy);
109   ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
110 
111   // OloadTys collects types used in scalar intrinsic overload name.
112   SmallVector<Type *, 3> OloadTys;
113   if (!RetTy->isVoidTy() && isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
114     OloadTys.push_back(ScalarRetTy);
115 
116   // Compute the argument types of the corresponding scalar call and check that
117   // all vector operands match the previously found EC.
118   SmallVector<Type *, 8> ScalarArgTypes;
119   for (auto Arg : enumerate(II->args())) {
120     auto *ArgTy = Arg.value()->getType();
121     bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index());
122     if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
123       ScalarArgTypes.push_back(ArgTy);
124       if (IsOloadTy)
125         OloadTys.push_back(ArgTy);
126     } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
127       auto *ScalarArgTy = VectorArgTy->getElementType();
128       ScalarArgTypes.push_back(ScalarArgTy);
129       if (IsOloadTy)
130         OloadTys.push_back(ScalarArgTy);
131       // When return type is void, set EC to the first vector argument, and
132       // disallow vector arguments with different ECs.
133       if (EC.isZero())
134         EC = VectorArgTy->getElementCount();
135       else if (EC != VectorArgTy->getElementCount())
136         return false;
137     } else
138       // Exit when it is supposed to be a vector argument but it isn't.
139       return false;
140   }
141 
142   // Try to reconstruct the name for the scalar version of the instruction,
143   // using scalar argument types.
144   std::string ScalarName =
145       Intrinsic::isOverloaded(IID)
146           ? Intrinsic::getName(IID, OloadTys, II->getModule())
147           : Intrinsic::getName(IID).str();
148 
149   // Try to find the mapping for the scalar version of this intrinsic and the
150   // exact vector width of the call operands in the TargetLibraryInfo. First,
151   // check with a non-masked variant, and if that fails try with a masked one.
152   const VecDesc *VD =
153       TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false);
154   if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true)))
155     return false;
156 
157   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
158                     << "` and vector width " << EC << " to: `"
159                     << VD->getVectorFnName() << "`.\n");
160 
161   // Replace the call to the intrinsic with a call to the vector library
162   // function.
163   FunctionType *ScalarFTy =
164       FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
165   const std::string MangledName = VD->getVectorFunctionABIVariantString();
166   auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
167   if (!OptInfo)
168     return false;
169 
170   // There is no guarantee that the vectorized instructions followed the VFABI
171   // specification when being created, this is why we need to add extra check to
172   // make sure that the operands of the vector function obtained via VFABI match
173   // the operands of the original vector instruction.
174   for (auto &VFParam : OptInfo->Shape.Parameters) {
175     if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
176       continue;
177 
178     // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
179     // a bug in the VFABI parser.
180     assert(VFParam.ParamPos < II->arg_size() && "ParamPos has invalid range");
181     Type *OrigTy = II->getArgOperand(VFParam.ParamPos)->getType();
182     if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
183       LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
184                         << ". Wrong type at index " << VFParam.ParamPos << ": "
185                         << *OrigTy << "\n");
186       return false;
187     }
188   }
189 
190   FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
191   if (!VectorFTy)
192     return false;
193 
194   Function *TLIFunc =
195       getTLIFunction(II->getModule(), VectorFTy, VD->getVectorFnName(),
196                      II->getCalledFunction());
197   replaceWithTLIFunction(II, *OptInfo, TLIFunc);
198   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
199                     << "` with call to `" << TLIFunc->getName() << "`.\n");
200   ++NumCallsReplaced;
201   return true;
202 }
203 
204 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
205   SmallVector<Instruction *> ReplacedCalls;
206   for (auto &I : instructions(F)) {
207     // Process only intrinsic calls that return void or a vector.
208     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
209       if (II->getIntrinsicID() == Intrinsic::not_intrinsic)
210         continue;
211       if (!II->getType()->isVectorTy() && !II->getType()->isVoidTy())
212         continue;
213 
214       if (replaceWithCallToVeclib(TLI, II))
215         ReplacedCalls.push_back(&I);
216     }
217   }
218   // Erase any intrinsic calls that were replaced with vector library calls.
219   for (auto *I : ReplacedCalls)
220     I->eraseFromParent();
221   return !ReplacedCalls.empty();
222 }
223 
224 ////////////////////////////////////////////////////////////////////////////////
225 // New pass manager implementation.
226 ////////////////////////////////////////////////////////////////////////////////
227 PreservedAnalyses ReplaceWithVeclib::run(Function &F,
228                                          FunctionAnalysisManager &AM) {
229   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
230   auto Changed = runImpl(TLI, F);
231   if (Changed) {
232     LLVM_DEBUG(dbgs() << "Intrinsic calls replaced with vector libraries: "
233                       << NumCallsReplaced << "\n");
234 
235     PreservedAnalyses PA;
236     PA.preserveSet<CFGAnalyses>();
237     PA.preserve<TargetLibraryAnalysis>();
238     PA.preserve<ScalarEvolutionAnalysis>();
239     PA.preserve<LoopAccessAnalysis>();
240     PA.preserve<DemandedBitsAnalysis>();
241     PA.preserve<OptimizationRemarkEmitterAnalysis>();
242     return PA;
243   }
244 
245   // The pass did not replace any calls, hence it preserves all analyses.
246   return PreservedAnalyses::all();
247 }
248 
249 ////////////////////////////////////////////////////////////////////////////////
250 // Legacy PM Implementation.
251 ////////////////////////////////////////////////////////////////////////////////
252 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
253   const TargetLibraryInfo &TLI =
254       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
255   return runImpl(TLI, F);
256 }
257 
258 void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
259   AU.setPreservesCFG();
260   AU.addRequired<TargetLibraryInfoWrapperPass>();
261   AU.addPreserved<TargetLibraryInfoWrapperPass>();
262   AU.addPreserved<ScalarEvolutionWrapperPass>();
263   AU.addPreserved<AAResultsWrapperPass>();
264   AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
265   AU.addPreserved<GlobalsAAWrapperPass>();
266 }
267 
268 ////////////////////////////////////////////////////////////////////////////////
269 // Legacy Pass manager initialization
270 ////////////////////////////////////////////////////////////////////////////////
271 char ReplaceWithVeclibLegacy::ID = 0;
272 
273 INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
274                       "Replace intrinsics with calls to vector library", false,
275                       false)
276 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
277 INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
278                     "Replace intrinsics with calls to vector library", false,
279                     false)
280 
281 FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
282   return new ReplaceWithVeclibLegacy();
283 }
284