xref: /llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp (revision 8e514c572e44eda237417236b4c92176dfce9cd9)
1 //=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Replaces LLVM IR instructions with vector operands (i.e., the frem
10 // instruction or calls to LLVM intrinsics) with matching calls to functions
11 // from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/CodeGen/ReplaceWithVeclib.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/Analysis/DemandedBits.h"
20 #include "llvm/Analysis/GlobalsModRef.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/Analysis/TargetLibraryInfo.h"
23 #include "llvm/Analysis/VectorUtils.h"
24 #include "llvm/CodeGen/Passes.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/InstIterator.h"
28 #include "llvm/Support/TypeSize.h"
29 #include "llvm/Transforms/Utils/ModuleUtils.h"
30 
31 using namespace llvm;
32 
// Tag for this pass. It is printed as a prefix in this file's LLVM_DEBUG
// messages and is passed to INITIALIZE_PASS_BEGIN/END when the legacy pass is
// registered.
#define DEBUG_TYPE "replace-with-veclib"

// Incremented in replaceWithCallToVeclib once per instruction that was
// replaced. Note that frem instructions are counted here as well, even though
// the description only mentions calls to intrinsics.
STATISTIC(NumCallsReplaced,
          "Number of calls to intrinsics that have been replaced.");

// Incremented in getTLIFunction each time a declaration for a vector library
// function is created in the module.
STATISTIC(NumTLIFuncDeclAdded,
          "Number of vector library function declarations added.");

// Incremented in getTLIFunction right after a freshly created declaration is
// passed to appendToCompilerUsed.
STATISTIC(NumFuncUsedAdded,
          "Number of functions added to `llvm.compiler.used`");
43 
44 /// Returns a vector Function that it adds to the Module \p M. When an \p
45 /// ScalarFunc is not null, it copies its attributes to the newly created
46 /// Function.
47 Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
48                          const StringRef TLIName,
49                          Function *ScalarFunc = nullptr) {
50   Function *TLIFunc = M->getFunction(TLIName);
51   if (!TLIFunc) {
52     TLIFunc =
53         Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
54     if (ScalarFunc)
55       TLIFunc->copyAttributesFrom(ScalarFunc);
56 
57     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
58                       << TLIName << "` of type `" << *(TLIFunc->getType())
59                       << "` to module.\n");
60 
61     ++NumTLIFuncDeclAdded;
62     // Add the freshly created function to llvm.compiler.used, similar to as it
63     // is done in InjectTLIMappings.
64     appendToCompilerUsed(*M, {TLIFunc});
65     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
66                       << "` to `@llvm.compiler.used`.\n");
67     ++NumFuncUsedAdded;
68   }
69   return TLIFunc;
70 }
71 
72 /// Replace the instruction \p I with a call to the corresponding function from
73 /// the vector library (\p TLIVecFunc).
74 static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
75                                    Function *TLIVecFunc) {
76   IRBuilder<> IRBuilder(&I);
77   auto *CI = dyn_cast<CallInst>(&I);
78   SmallVector<Value *> Args(CI ? CI->args() : I.operands());
79   if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
80     auto *MaskTy =
81         VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
82     Args.insert(Args.begin() + OptMaskpos.value(),
83                 Constant::getAllOnesValue(MaskTy));
84   }
85 
86   // If it is a call instruction, preserve the operand bundles.
87   SmallVector<OperandBundleDef, 1> OpBundles;
88   if (CI)
89     CI->getOperandBundlesAsDefs(OpBundles);
90 
91   auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
92   I.replaceAllUsesWith(Replacement);
93   // Preserve fast math flags for FP math.
94   if (isa<FPMathOperator>(Replacement))
95     Replacement->copyFastMathFlags(&I);
96 }
97 
98 /// Returns true when successfully replaced \p I with a suitable function taking
99 /// vector arguments, based on available mappings in the \p TLI. Currently only
100 /// works when \p I is a call to vectorized intrinsic or the frem instruction.
101 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
102                                     Instruction &I) {
103   // At the moment VFABI assumes the return type is always widened unless it is
104   // a void type.
105   auto *VTy = dyn_cast<VectorType>(I.getType());
106   ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
107 
108   // Compute the argument types of the corresponding scalar call and the scalar
109   // function name. For calls, it additionally finds the function to replace
110   // and checks that all vector operands match the previously found EC.
111   SmallVector<Type *, 8> ScalarArgTypes;
112   std::string ScalarName;
113   Function *FuncToReplace = nullptr;
114   auto *CI = dyn_cast<CallInst>(&I);
115   if (CI) {
116     FuncToReplace = CI->getCalledFunction();
117     Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
118     assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
119     for (auto Arg : enumerate(CI->args())) {
120       auto *ArgTy = Arg.value()->getType();
121       if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
122         ScalarArgTypes.push_back(ArgTy);
123       } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
124         ScalarArgTypes.push_back(VectorArgTy->getElementType());
125         // When return type is void, set EC to the first vector argument, and
126         // disallow vector arguments with different ECs.
127         if (EC.isZero())
128           EC = VectorArgTy->getElementCount();
129         else if (EC != VectorArgTy->getElementCount())
130           return false;
131       } else
132         // Exit when it is supposed to be a vector argument but it isn't.
133         return false;
134     }
135     // Try to reconstruct the name for the scalar version of the instruction,
136     // using scalar argument types.
137     ScalarName = Intrinsic::isOverloaded(IID)
138                      ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
139                      : Intrinsic::getName(IID).str();
140   } else {
141     assert(VTy && "Return type must be a vector");
142     auto *ScalarTy = VTy->getScalarType();
143     LibFunc Func;
144     if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
145       return false;
146     ScalarName = TLI.getName(Func);
147     ScalarArgTypes = {ScalarTy, ScalarTy};
148   }
149 
150   // Try to find the mapping for the scalar version of this intrinsic and the
151   // exact vector width of the call operands in the TargetLibraryInfo. First,
152   // check with a non-masked variant, and if that fails try with a masked one.
153   const VecDesc *VD =
154       TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false);
155   if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true)))
156     return false;
157 
158   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
159                     << "` and vector width " << EC << " to: `"
160                     << VD->getVectorFnName() << "`.\n");
161 
162   // Replace the call to the intrinsic with a call to the vector library
163   // function.
164   Type *ScalarRetTy = I.getType()->getScalarType();
165   FunctionType *ScalarFTy =
166       FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
167   const std::string MangledName = VD->getVectorFunctionABIVariantString();
168   auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
169   if (!OptInfo)
170     return false;
171 
172   // There is no guarantee that the vectorized instructions followed the VFABI
173   // specification when being created, this is why we need to add extra check to
174   // make sure that the operands of the vector function obtained via VFABI match
175   // the operands of the original vector instruction.
176   if (CI) {
177     for (auto VFParam : OptInfo->Shape.Parameters) {
178       if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
179         continue;
180 
181       // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
182       // a bug in the VFABI parser.
183       assert(VFParam.ParamPos < CI->arg_size() &&
184              "ParamPos has invalid range.");
185       Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
186       if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
187         LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
188                           << ". Wrong type at index " << VFParam.ParamPos
189                           << ": " << *OrigTy << "\n");
190         return false;
191       }
192     }
193   }
194 
195   FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
196   if (!VectorFTy)
197     return false;
198 
199   Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
200                                      VD->getVectorFnName(), FuncToReplace);
201 
202   replaceWithTLIFunction(I, *OptInfo, TLIFunc);
203   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
204                     << "` with call to `" << TLIFunc->getName() << "`.\n");
205   ++NumCallsReplaced;
206   return true;
207 }
208 
209 /// Supported instruction \p I must be a vectorized frem or a call to an
210 /// intrinsic that returns either void or a vector.
211 static bool isSupportedInstruction(Instruction *I) {
212   Type *Ty = I->getType();
213   if (auto *CI = dyn_cast<CallInst>(I))
214     return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() &&
215            CI->getCalledFunction()->getIntrinsicID() !=
216                Intrinsic::not_intrinsic;
217   if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy())
218     return true;
219   return false;
220 }
221 
222 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
223   bool Changed = false;
224   SmallVector<Instruction *> ReplacedCalls;
225   for (auto &I : instructions(F)) {
226     if (!isSupportedInstruction(&I))
227       continue;
228     if (replaceWithCallToVeclib(TLI, I)) {
229       ReplacedCalls.push_back(&I);
230       Changed = true;
231     }
232   }
233   // Erase the calls to the intrinsics that have been replaced
234   // with calls to the vector library.
235   for (auto *CI : ReplacedCalls)
236     CI->eraseFromParent();
237   return Changed;
238 }
239 
240 ////////////////////////////////////////////////////////////////////////////////
241 // New pass manager implementation.
242 ////////////////////////////////////////////////////////////////////////////////
243 PreservedAnalyses ReplaceWithVeclib::run(Function &F,
244                                          FunctionAnalysisManager &AM) {
245   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
246   auto Changed = runImpl(TLI, F);
247   if (Changed) {
248     LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
249                       << NumCallsReplaced << "\n");
250 
251     PreservedAnalyses PA;
252     PA.preserveSet<CFGAnalyses>();
253     PA.preserve<TargetLibraryAnalysis>();
254     PA.preserve<ScalarEvolutionAnalysis>();
255     PA.preserve<LoopAccessAnalysis>();
256     PA.preserve<DemandedBitsAnalysis>();
257     PA.preserve<OptimizationRemarkEmitterAnalysis>();
258     return PA;
259   }
260 
261   // The pass did not replace any calls, hence it preserves all analyses.
262   return PreservedAnalyses::all();
263 }
264 
265 ////////////////////////////////////////////////////////////////////////////////
266 // Legacy PM Implementation.
267 ////////////////////////////////////////////////////////////////////////////////
268 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
269   const TargetLibraryInfo &TLI =
270       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
271   return runImpl(TLI, F);
272 }
273 
274 void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
275   AU.setPreservesCFG();
276   AU.addRequired<TargetLibraryInfoWrapperPass>();
277   AU.addPreserved<TargetLibraryInfoWrapperPass>();
278   AU.addPreserved<ScalarEvolutionWrapperPass>();
279   AU.addPreserved<AAResultsWrapperPass>();
280   AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
281   AU.addPreserved<GlobalsAAWrapperPass>();
282 }
283 
284 ////////////////////////////////////////////////////////////////////////////////
285 // Legacy Pass manager initialization
286 ////////////////////////////////////////////////////////////////////////////////
// Definition of the legacy pass's static ID member, initialized to 0.
char ReplaceWithVeclibLegacy::ID = 0;

// Registration of the legacy pass under DEBUG_TYPE ("replace-with-veclib")
// with the description string given below. TargetLibraryInfoWrapperPass is
// declared as a dependency; it is the analysis that runOnFunction queries.
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags of the INITIALIZE_PASS macros - confirm
// against the macro definitions.
INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                      "Replace intrinsics with calls to vector library", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                    "Replace intrinsics with calls to vector library", false,
                    false)
296 
297 FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
298   return new ReplaceWithVeclibLegacy();
299 }
300