xref: /llvm-project/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp (revision d45022b094a0a00b52057b464902693bc4e2db76)
1 //===- AMDGPULibCalls.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This file does AMD library function optimizations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AMDGPU.h"
15 #include "AMDGPULibFunc.h"
16 #include "GCNSubtarget.h"
17 #include "llvm/Analysis/AliasAnalysis.h"
18 #include "llvm/Analysis/Loads.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/IntrinsicInst.h"
21 #include "llvm/IR/IntrinsicsAMDGPU.h"
22 #include "llvm/InitializePasses.h"
23 #include <cmath>
24 
25 #define DEBUG_TYPE "amdgpu-simplifylib"
26 
27 using namespace llvm;
28 
// Pre-link mode: the device library has not been linked in yet, so this pass
// may create new external function prototypes (see
// AMDGPULibCalls::getFunction). Off by default.
static cl::opt<bool> EnablePreLink("amdgpu-prelink",
  cl::desc("Enable pre-link mode optimizations"),
  cl::init(false),
  cl::Hidden);

// -amdgpu-use-native[=f1,f2,...]: replace the listed math library calls with
// their native_* counterparts; an empty list or "all" enables every eligible
// function. Consumed by useNativeFunc()/initNativeFuncs().
static cl::list<std::string> UseNative("amdgpu-use-native",
  cl::desc("Comma separated list of functions to replace with native, or all"),
  cl::CommaSeparated, cl::ValueOptional,
  cl::Hidden);
38 
// Math-constant aliases for the llvm::numbers definitions, so the
// constant-fold tables below read like the math identities they encode.
#define MATH_PI      numbers::pi
#define MATH_E       numbers::e
#define MATH_SQRT2   numbers::sqrt2
#define MATH_SQRT1_2 numbers::inv_sqrt2
43 
namespace llvm {

/// Simplifies calls to the AMD device math library: table-driven constant
/// folding, replacement with native_* variants, and rewriting of known
/// libcalls into LLVM intrinsics. Per-function state is cached via
/// initFunction(); drive the pass through fold()/useNative().
class AMDGPULibCalls {
private:

  typedef llvm::AMDGPULibFunc FuncInfo;

  // Value of the enclosing function's "unsafe-fp-math" attribute; set by
  // initFunction().
  bool UnsafeFPMath = false;

  // -fuse-native: true when every eligible call should be made native.
  bool AllNative = false;

  // True if \p F was named by -amdgpu-use-native (or all functions were).
  bool useNativeFunc(const StringRef F) const;

  // Return a pointer (pointer expr) to the function if function definition with
  // "FuncName" exists. It may create a new function prototype in pre-link mode.
  FunctionCallee getFunction(Module *M, const FuncInfo &fInfo);

  // Decode a mangled library-function name into \p FInfo; false if it is not
  // a recognized library entry point.
  bool parseFunctionName(const StringRef &FMangledName, FuncInfo &FInfo);

  // Table-driven constant folding for calls with known-constant arguments.
  bool TDOFold(CallInst *CI, const FuncInfo &FInfo);

  /* Specialized optimizations */

  // pow/powr/pown
  bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // rootn
  bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // -fuse-native for sincos
  bool sincosUseNative(CallInst *aCI, const FuncInfo &FInfo);

  // evaluate calls if calls' arguments are constants.
  bool evaluateScalarMathFunc(const FuncInfo &FInfo, double &Res0, double &Res1,
                              Constant *copr0, Constant *copr1);
  bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);

  // sqrt
  bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  /// Insert a value to sincos function \p Fsincos. Returns (value of sin, value
  /// of cos, sincos call).
  std::tuple<Value *, Value *, Value *> insertSinCos(Value *Arg,
                                                     FastMathFlags FMF,
                                                     IRBuilder<> &B,
                                                     FunctionCallee Fsincos);

  // sin/cos
  bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // __read_pipe/__write_pipe
  bool fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
                            const FuncInfo &FInfo);

  // Get a scalar native builtin single argument FP function
  FunctionCallee getNativeFunction(Module *M, const FuncInfo &FInfo);

  /// Substitute a call to a known libcall with an intrinsic call. If \p
  /// AllowMinSize is true, allow the replacement in a minsize function.
  bool shouldReplaceLibcallWithIntrinsic(const CallInst *CI,
                                         bool AllowMinSizeF32 = false,
                                         bool AllowF64 = false,
                                         bool AllowStrictFP = false);
  void replaceLibCallWithSimpleIntrinsic(CallInst *CI, Intrinsic::ID IntrID);

  // shouldReplaceLibcallWithIntrinsic() + replaceLibCallWithSimpleIntrinsic()
  // in one step; returns true if the call was replaced.
  bool tryReplaceLibcallWithSimpleIntrinsic(CallInst *CI, Intrinsic::ID IntrID,
                                            bool AllowMinSizeF32 = false,
                                            bool AllowF64 = false,
                                            bool AllowStrictFP = false);

protected:
  // True when full fast-math applies to \p FPOp, either via the instruction's
  // own flags or the function-level unsafe-fp-math attribute.
  bool isUnsafeMath(const FPMathOperator *FPOp) const;

  // Whether constant folding may evaluate at a higher precision than the
  // source type (currently tied to isUnsafeMath).
  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;

  // RAUW \p I with \p With and delete \p I.
  static void replaceCall(Instruction *I, Value *With) {
    I->replaceAllUsesWith(With);
    I->eraseFromParent();
  }

  static void replaceCall(FPMathOperator *I, Value *With) {
    replaceCall(cast<Instruction>(I), With);
  }

public:
  AMDGPULibCalls() {}

  // Try to simplify/replace the call; returns true iff the IR changed.
  bool fold(CallInst *CI);

  // Cache per-function state (fast-math attributes).
  void initFunction(const Function &F);
  void initNativeFuncs();

  // Replace a normal math function call with that native version
  bool useNative(CallInst *CI);
};

} // end llvm namespace
142 
143 template <typename IRB>
144 static CallInst *CreateCallEx(IRB &B, FunctionCallee Callee, Value *Arg,
145                               const Twine &Name = "") {
146   CallInst *R = B.CreateCall(Callee, Arg, Name);
147   if (Function *F = dyn_cast<Function>(Callee.getCallee()))
148     R->setCallingConv(F->getCallingConv());
149   return R;
150 }
151 
152 template <typename IRB>
153 static CallInst *CreateCallEx2(IRB &B, FunctionCallee Callee, Value *Arg1,
154                                Value *Arg2, const Twine &Name = "") {
155   CallInst *R = B.CreateCall(Callee, {Arg1, Arg2}, Name);
156   if (Function *F = dyn_cast<Function>(Callee.getCallee()))
157     R->setCallingConv(F->getCallingConv());
158   return R;
159 }
160 
161 //  Data structures for table-driven optimizations.
162 //  FuncTbl works for both f32 and f64 functions with 1 input argument
163 
// One row of a constant-fold table: calling the library function with
// `input` yields exactly `result` (consumed by getOptTable()/TDOFold()).
struct TableEntry {
  double   result;
  double   input;
};
168 
/* a list of {result, input} */
// Each table lists the exact argument values a function folds at, together
// with the exact result; TDOFold matches arguments with isExactlyValue(), so
// only these precise inputs are folded.
static const TableEntry tbl_acos[] = {
  {MATH_PI / 2.0, 0.0},
  {MATH_PI / 2.0, -0.0},
  {0.0, 1.0},
  {MATH_PI, -1.0}
};
static const TableEntry tbl_acosh[] = {
  {0.0, 1.0}
};
static const TableEntry tbl_acospi[] = {
  {0.5, 0.0},
  {0.5, -0.0},
  {0.0, 1.0},
  {1.0, -1.0}
};
static const TableEntry tbl_asin[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 2.0, 1.0},
  {-MATH_PI / 2.0, -1.0}
};
static const TableEntry tbl_asinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_asinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.5, 1.0},
  {-0.5, -1.0}
};
static const TableEntry tbl_atan[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 4.0, 1.0},
  {-MATH_PI / 4.0, -1.0}
};
static const TableEntry tbl_atanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_atanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.25, 1.0},
  {-0.25, -1.0}
};
static const TableEntry tbl_cbrt[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {1.0, 1.0},
  {-1.0, -1.0},
};
static const TableEntry tbl_cos[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cosh[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cospi[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erfc[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erf[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_exp[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {MATH_E, 1.0}
};
static const TableEntry tbl_exp2[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {2.0, 1.0}
};
static const TableEntry tbl_exp10[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {10.0, 1.0}
};
static const TableEntry tbl_expm1[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_log[] = {
  {0.0, 1.0},
  {1.0, MATH_E}
};
static const TableEntry tbl_log2[] = {
  {0.0, 1.0},
  {1.0, 2.0}
};
static const TableEntry tbl_log10[] = {
  {0.0, 1.0},
  {1.0, 10.0}
};
static const TableEntry tbl_rsqrt[] = {
  {1.0, 1.0},
  {MATH_SQRT1_2, 2.0}
};
static const TableEntry tbl_sin[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sqrt[] = {
  {0.0, 0.0},
  {1.0, 1.0},
  {MATH_SQRT2, 2.0}
};
static const TableEntry tbl_tan[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
// tgamma(n) = (n-1)! at the small integer points.
static const TableEntry tbl_tgamma[] = {
  {1.0, 1.0},
  {1.0, 2.0},
  {2.0, 3.0},
  {6.0, 4.0}
};
313 
314 static bool HasNative(AMDGPULibFunc::EFuncId id) {
315   switch(id) {
316   case AMDGPULibFunc::EI_DIVIDE:
317   case AMDGPULibFunc::EI_COS:
318   case AMDGPULibFunc::EI_EXP:
319   case AMDGPULibFunc::EI_EXP2:
320   case AMDGPULibFunc::EI_EXP10:
321   case AMDGPULibFunc::EI_LOG:
322   case AMDGPULibFunc::EI_LOG2:
323   case AMDGPULibFunc::EI_LOG10:
324   case AMDGPULibFunc::EI_POWR:
325   case AMDGPULibFunc::EI_RECIP:
326   case AMDGPULibFunc::EI_RSQRT:
327   case AMDGPULibFunc::EI_SIN:
328   case AMDGPULibFunc::EI_SINCOS:
329   case AMDGPULibFunc::EI_SQRT:
330   case AMDGPULibFunc::EI_TAN:
331     return true;
332   default:;
333   }
334   return false;
335 }
336 
337 using TableRef = ArrayRef<TableEntry>;
338 
339 static TableRef getOptTable(AMDGPULibFunc::EFuncId id) {
340   switch(id) {
341   case AMDGPULibFunc::EI_ACOS:    return TableRef(tbl_acos);
342   case AMDGPULibFunc::EI_ACOSH:   return TableRef(tbl_acosh);
343   case AMDGPULibFunc::EI_ACOSPI:  return TableRef(tbl_acospi);
344   case AMDGPULibFunc::EI_ASIN:    return TableRef(tbl_asin);
345   case AMDGPULibFunc::EI_ASINH:   return TableRef(tbl_asinh);
346   case AMDGPULibFunc::EI_ASINPI:  return TableRef(tbl_asinpi);
347   case AMDGPULibFunc::EI_ATAN:    return TableRef(tbl_atan);
348   case AMDGPULibFunc::EI_ATANH:   return TableRef(tbl_atanh);
349   case AMDGPULibFunc::EI_ATANPI:  return TableRef(tbl_atanpi);
350   case AMDGPULibFunc::EI_CBRT:    return TableRef(tbl_cbrt);
351   case AMDGPULibFunc::EI_NCOS:
352   case AMDGPULibFunc::EI_COS:     return TableRef(tbl_cos);
353   case AMDGPULibFunc::EI_COSH:    return TableRef(tbl_cosh);
354   case AMDGPULibFunc::EI_COSPI:   return TableRef(tbl_cospi);
355   case AMDGPULibFunc::EI_ERFC:    return TableRef(tbl_erfc);
356   case AMDGPULibFunc::EI_ERF:     return TableRef(tbl_erf);
357   case AMDGPULibFunc::EI_EXP:     return TableRef(tbl_exp);
358   case AMDGPULibFunc::EI_NEXP2:
359   case AMDGPULibFunc::EI_EXP2:    return TableRef(tbl_exp2);
360   case AMDGPULibFunc::EI_EXP10:   return TableRef(tbl_exp10);
361   case AMDGPULibFunc::EI_EXPM1:   return TableRef(tbl_expm1);
362   case AMDGPULibFunc::EI_LOG:     return TableRef(tbl_log);
363   case AMDGPULibFunc::EI_NLOG2:
364   case AMDGPULibFunc::EI_LOG2:    return TableRef(tbl_log2);
365   case AMDGPULibFunc::EI_LOG10:   return TableRef(tbl_log10);
366   case AMDGPULibFunc::EI_NRSQRT:
367   case AMDGPULibFunc::EI_RSQRT:   return TableRef(tbl_rsqrt);
368   case AMDGPULibFunc::EI_NSIN:
369   case AMDGPULibFunc::EI_SIN:     return TableRef(tbl_sin);
370   case AMDGPULibFunc::EI_SINH:    return TableRef(tbl_sinh);
371   case AMDGPULibFunc::EI_SINPI:   return TableRef(tbl_sinpi);
372   case AMDGPULibFunc::EI_NSQRT:
373   case AMDGPULibFunc::EI_SQRT:    return TableRef(tbl_sqrt);
374   case AMDGPULibFunc::EI_TAN:     return TableRef(tbl_tan);
375   case AMDGPULibFunc::EI_TANH:    return TableRef(tbl_tanh);
376   case AMDGPULibFunc::EI_TANPI:   return TableRef(tbl_tanpi);
377   case AMDGPULibFunc::EI_TGAMMA:  return TableRef(tbl_tgamma);
378   default:;
379   }
380   return TableRef();
381 }
382 
383 static inline int getVecSize(const AMDGPULibFunc& FInfo) {
384   return FInfo.getLeads()[0].VectorSize;
385 }
386 
387 static inline AMDGPULibFunc::EType getArgType(const AMDGPULibFunc& FInfo) {
388   return (AMDGPULibFunc::EType)FInfo.getLeads()[0].ArgType;
389 }
390 
391 FunctionCallee AMDGPULibCalls::getFunction(Module *M, const FuncInfo &fInfo) {
392   // If we are doing PreLinkOpt, the function is external. So it is safe to
393   // use getOrInsertFunction() at this stage.
394 
395   return EnablePreLink ? AMDGPULibFunc::getOrInsertFunction(M, fInfo)
396                        : AMDGPULibFunc::getFunction(M, fInfo);
397 }
398 
// Decode a mangled library-function name into \p FInfo. Returns false when
// the name is not a recognized device-library entry point.
bool AMDGPULibCalls::parseFunctionName(const StringRef &FMangledName,
                                       FuncInfo &FInfo) {
  return AMDGPULibFunc::parse(FMangledName, FInfo);
}
403 
// An operation counts as "unsafe math" when the enclosing function carries
// the unsafe-fp-math attribute or the instruction itself has the full set of
// fast-math flags.
bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const {
  return UnsafeFPMath || FPOp->isFast();
}
407 
// Whether host-side constant folding may evaluate at a higher precision than
// the source type (e.g. folding f32 calls with host doubles). Currently this
// simply piggybacks on the unsafe-math check.
bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
    const FPMathOperator *FPOp) const {
  // TODO: Refine to approxFunc or contract
  return isUnsafeMath(FPOp);
}
413 
// Cache the per-function fast-math state; must be called before fold() is
// used on calls inside \p F.
void AMDGPULibCalls::initFunction(const Function &F) {
  UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
}
417 
418 bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
419   return AllNative || llvm::is_contained(UseNative, F);
420 }
421 
422 void AMDGPULibCalls::initNativeFuncs() {
423   AllNative = useNativeFunc("all") ||
424               (UseNative.getNumOccurrences() && UseNative.size() == 1 &&
425                UseNative.begin()->empty());
426 }
427 
428 bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
429   bool native_sin = useNativeFunc("sin");
430   bool native_cos = useNativeFunc("cos");
431 
432   if (native_sin && native_cos) {
433     Module *M = aCI->getModule();
434     Value *opr0 = aCI->getArgOperand(0);
435 
436     AMDGPULibFunc nf;
437     nf.getLeads()[0].ArgType = FInfo.getLeads()[0].ArgType;
438     nf.getLeads()[0].VectorSize = FInfo.getLeads()[0].VectorSize;
439 
440     nf.setPrefix(AMDGPULibFunc::NATIVE);
441     nf.setId(AMDGPULibFunc::EI_SIN);
442     FunctionCallee sinExpr = getFunction(M, nf);
443 
444     nf.setPrefix(AMDGPULibFunc::NATIVE);
445     nf.setId(AMDGPULibFunc::EI_COS);
446     FunctionCallee cosExpr = getFunction(M, nf);
447     if (sinExpr && cosExpr) {
448       Value *sinval = CallInst::Create(sinExpr, opr0, "splitsin", aCI);
449       Value *cosval = CallInst::Create(cosExpr, opr0, "splitcos", aCI);
450       new StoreInst(cosval, aCI->getArgOperand(1), aCI);
451 
452       DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
453                                           << " with native version of sin/cos");
454 
455       replaceCall(aCI, sinval);
456       return true;
457     }
458   }
459   return false;
460 }
461 
462 bool AMDGPULibCalls::useNative(CallInst *aCI) {
463   Function *Callee = aCI->getCalledFunction();
464   if (!Callee || aCI->isNoBuiltin())
465     return false;
466 
467   FuncInfo FInfo;
468   if (!parseFunctionName(Callee->getName(), FInfo) || !FInfo.isMangled() ||
469       FInfo.getPrefix() != AMDGPULibFunc::NOPFX ||
470       getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(FInfo.getId()) ||
471       !(AllNative || useNativeFunc(FInfo.getName()))) {
472     return false;
473   }
474 
475   if (FInfo.getId() == AMDGPULibFunc::EI_SINCOS)
476     return sincosUseNative(aCI, FInfo);
477 
478   FInfo.setPrefix(AMDGPULibFunc::NATIVE);
479   FunctionCallee F = getFunction(aCI->getModule(), FInfo);
480   if (!F)
481     return false;
482 
483   aCI->setCalledFunction(F);
484   DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
485                                       << " with native version");
486   return true;
487 }
488 
// Clang emits a call of __read_pipe_2 or __read_pipe_4 for the OpenCL
// read_pipe builtin, with appended type size and alignment arguments, where 2
// or 4 indicates the original number of arguments. The library has an
// optimized version of __read_pipe_2/__read_pipe_4 when the type size and
// alignment have the same power-of-2 value. This function transforms
// __read_pipe_2 to __read_pipe_2_N for such cases where N is the size in
// bytes of the type (N = 1, 2, 4, 8, ..., 128). The same applies to
// __read_pipe_4, __write_pipe_2, and __write_pipe_4.
496 bool AMDGPULibCalls::fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
497                                           const FuncInfo &FInfo) {
498   auto *Callee = CI->getCalledFunction();
499   if (!Callee->isDeclaration())
500     return false;
501 
502   assert(Callee->hasName() && "Invalid read_pipe/write_pipe function");
503   auto *M = Callee->getParent();
504   std::string Name = std::string(Callee->getName());
505   auto NumArg = CI->arg_size();
506   if (NumArg != 4 && NumArg != 6)
507     return false;
508   ConstantInt *PacketSize =
509       dyn_cast<ConstantInt>(CI->getArgOperand(NumArg - 2));
510   ConstantInt *PacketAlign =
511       dyn_cast<ConstantInt>(CI->getArgOperand(NumArg - 1));
512   if (!PacketSize || !PacketAlign)
513     return false;
514 
515   unsigned Size = PacketSize->getZExtValue();
516   Align Alignment = PacketAlign->getAlignValue();
517   if (Alignment != Size)
518     return false;
519 
520   unsigned PtrArgLoc = CI->arg_size() - 3;
521   Value *PtrArg = CI->getArgOperand(PtrArgLoc);
522   Type *PtrTy = PtrArg->getType();
523 
524   SmallVector<llvm::Type *, 6> ArgTys;
525   for (unsigned I = 0; I != PtrArgLoc; ++I)
526     ArgTys.push_back(CI->getArgOperand(I)->getType());
527   ArgTys.push_back(PtrTy);
528 
529   Name = Name + "_" + std::to_string(Size);
530   auto *FTy = FunctionType::get(Callee->getReturnType(),
531                                 ArrayRef<Type *>(ArgTys), false);
532   AMDGPULibFunc NewLibFunc(Name, FTy);
533   FunctionCallee F = AMDGPULibFunc::getOrInsertFunction(M, NewLibFunc);
534   if (!F)
535     return false;
536 
537   auto *BCast = B.CreatePointerCast(PtrArg, PtrTy);
538   SmallVector<Value *, 6> Args;
539   for (unsigned I = 0; I != PtrArgLoc; ++I)
540     Args.push_back(CI->getArgOperand(I));
541   Args.push_back(BCast);
542 
543   auto *NCI = B.CreateCall(F, Args);
544   NCI->setAttributes(CI->getAttributes());
545   CI->replaceAllUsesWith(NCI);
546   CI->dropAllReferences();
547   CI->eraseFromParent();
548 
549   return true;
550 }
551 
// This function returns false if no change; return true otherwise.
bool AMDGPULibCalls::fold(CallInst *CI) {
  Function *Callee = CI->getCalledFunction();
  // Ignore indirect calls.
  if (!Callee || Callee->isIntrinsic() || CI->isNoBuiltin())
    return false;

  FuncInfo FInfo;
  if (!parseFunctionName(Callee->getName(), FInfo))
    return false;

  // Further check the number of arguments to see if they match.
  // TODO: Check calling convention matches too
  if (!FInfo.isCompatibleSignature(CI->getFunctionType()))
    return false;

  LLVM_DEBUG(dbgs() << "AMDIC: try folding " << *CI << '\n');

  // Constant-argument folding via the tables above comes first.
  if (TDOFold(CI, FInfo))
    return true;

  IRBuilder<> B(CI);

  if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) {
    // Under unsafe-math, evaluate calls if possible.
    // According to Brian Sumner, we can do this for all f32 function calls
    // using host's double function calls.
    if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo))
      return true;

    // Copy fast flags from the original call.
    FastMathFlags FMF = FPOp->getFastMathFlags();
    B.setFastMathFlags(FMF);

    // Specialized optimizations for each function call.
    //
    // TODO: Handle other simple intrinsic wrappers. Sqrt, log.
    //
    // TODO: Handle native functions
    switch (FInfo.getId()) {
    case AMDGPULibFunc::EI_EXP:
      // exp/exp2 require at least one fast-math flag; approxFunc additionally
      // permits the f32 replacement inside minsize functions.
      if (FMF.none())
        return false;
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::exp,
                                                  FMF.approxFunc());
    case AMDGPULibFunc::EI_EXP2:
      if (FMF.none())
        return false;
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::exp2,
                                                  FMF.approxFunc());
    // The cases below map directly onto intrinsics with matching semantics;
    // the bool arguments are AllowMinSizeF32 / AllowF64 / AllowStrictFP.
    case AMDGPULibFunc::EI_FMIN:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::minnum, true,
                                                  true);
    case AMDGPULibFunc::EI_FMAX:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::maxnum, true,
                                                  true);
    case AMDGPULibFunc::EI_FMA:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::fma, true,
                                                  true);
    case AMDGPULibFunc::EI_MAD:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::fmuladd, true,
                                                  true);
    case AMDGPULibFunc::EI_FABS:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::fabs, true,
                                                  true, true);
    case AMDGPULibFunc::EI_COPYSIGN:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::copysign, true,
                                                  true, true);
    case AMDGPULibFunc::EI_FLOOR:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::floor, true,
                                                  true);
    case AMDGPULibFunc::EI_CEIL:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::ceil, true,
                                                  true);
    case AMDGPULibFunc::EI_TRUNC:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::trunc, true,
                                                  true);
    case AMDGPULibFunc::EI_RINT:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::rint, true,
                                                  true);
    case AMDGPULibFunc::EI_ROUND:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::round, true,
                                                  true);
    case AMDGPULibFunc::EI_LDEXP: {
      // ldexp's intrinsic is overloaded on both the FP type and the exponent
      // type, so it cannot go through the "simple" single-overload path.
      if (!shouldReplaceLibcallWithIntrinsic(CI, true, true))
        return false;
      CI->setCalledFunction(Intrinsic::getDeclaration(
          CI->getModule(), Intrinsic::ldexp,
          {CI->getType(), CI->getArgOperand(1)->getType()}));
      return true;
    }
    case AMDGPULibFunc::EI_POW:
    case AMDGPULibFunc::EI_POWR:
    case AMDGPULibFunc::EI_POWN:
      return fold_pow(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_ROOTN:
      return fold_rootn(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_SQRT:
      return fold_sqrt(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_COS:
    case AMDGPULibFunc::EI_SIN:
      return fold_sincos(FPOp, B, FInfo);
    default:
      break;
    }
  } else {
    // Specialized optimizations for each function call
    switch (FInfo.getId()) {
    case AMDGPULibFunc::EI_READ_PIPE_2:
    case AMDGPULibFunc::EI_READ_PIPE_4:
    case AMDGPULibFunc::EI_WRITE_PIPE_2:
    case AMDGPULibFunc::EI_WRITE_PIPE_4:
      return fold_read_write_pipe(CI, B, FInfo);
    default:
      break;
    }
  }

  return false;
}
672 
// Table-driven folding: when the call's constant argument appears in the
// function's fold table, replace the call with the tabulated exact result.
// Handles both scalar arguments and constant vectors whose every lane is in
// the table.
bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
  // Table-Driven optimization
  const TableRef tr = getOptTable(FInfo.getId());
  if (tr.empty())
    return false;

  int const sz = (int)tr.size();
  Value *opr0 = CI->getArgOperand(0);

  if (getVecSize(FInfo) > 1) {
    if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(opr0)) {
      // Collect the folded value for every lane; bail out if any lane is not
      // in the table.
      SmallVector<double, 0> DVal;
      for (int eltNo = 0; eltNo < getVecSize(FInfo); ++eltNo) {
        ConstantFP *eltval = dyn_cast<ConstantFP>(
                               CV->getElementAsConstant((unsigned)eltNo));
        assert(eltval && "Non-FP arguments in math function!");
        bool found = false;
        for (int i=0; i < sz; ++i) {
          if (eltval->isExactlyValue(tr[i].input)) {
            DVal.push_back(tr[i].result);
            found = true;
            break;
          }
        }
        if (!found) {
          // This vector constants not handled yet.
          return false;
        }
      }
      LLVMContext &context = CI->getParent()->getParent()->getContext();
      Constant *nval;
      if (getArgType(FInfo) == AMDGPULibFunc::F32) {
        // Narrow the double table values back to float lanes for f32 calls.
        SmallVector<float, 0> FVal;
        for (unsigned i = 0; i < DVal.size(); ++i) {
          FVal.push_back((float)DVal[i]);
        }
        ArrayRef<float> tmp(FVal);
        nval = ConstantDataVector::get(context, tmp);
      } else { // F64
        ArrayRef<double> tmp(DVal);
        nval = ConstantDataVector::get(context, tmp);
      }
      LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
      replaceCall(CI, nval);
      return true;
    }
  } else {
    // Scalar version
    if (ConstantFP *CF = dyn_cast<ConstantFP>(opr0)) {
      for (int i = 0; i < sz; ++i) {
        if (CF->isExactlyValue(tr[i].input)) {
          Value *nval = ConstantFP::get(CF->getType(), tr[i].result);
          LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
          replaceCall(CI, nval);
          return true;
        }
      }
    }
  }

  return false;
}
735 
namespace llvm {
// Host-side log2 used during constant folding of pow-like calls. ::log2 is
// only guaranteed available under these POSIX/C99 feature-test macros;
// otherwise fall back to the identity log2(V) == log(V) / ln(2).
static double log2(double V) {
#if _XOPEN_SOURCE >= 600 || defined(_ISOC99_SOURCE) || _POSIX_C_SOURCE >= 200112L
  return ::log2(V);
#else
  return log(V) / numbers::ln2;
#endif
}
}
745 
746 bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
747                               const FuncInfo &FInfo) {
748   assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
749           FInfo.getId() == AMDGPULibFunc::EI_POWR ||
750           FInfo.getId() == AMDGPULibFunc::EI_POWN) &&
751          "fold_pow: encounter a wrong function call");
752 
753   Module *M = B.GetInsertBlock()->getModule();
754   ConstantFP *CF;
755   ConstantInt *CINT;
756   Type *eltType;
757   Value *opr0 = FPOp->getOperand(0);
758   Value *opr1 = FPOp->getOperand(1);
759   ConstantAggregateZero *CZero = dyn_cast<ConstantAggregateZero>(opr1);
760 
761   if (getVecSize(FInfo) == 1) {
762     eltType = opr0->getType();
763     CF = dyn_cast<ConstantFP>(opr1);
764     CINT = dyn_cast<ConstantInt>(opr1);
765   } else {
766     VectorType *VTy = dyn_cast<VectorType>(opr0->getType());
767     assert(VTy && "Oprand of vector function should be of vectortype");
768     eltType = VTy->getElementType();
769     ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1);
770 
771     // Now, only Handle vector const whose elements have the same value.
772     CF = CDV ? dyn_cast_or_null<ConstantFP>(CDV->getSplatValue()) : nullptr;
773     CINT = CDV ? dyn_cast_or_null<ConstantInt>(CDV->getSplatValue()) : nullptr;
774   }
775 
776   // No unsafe math , no constant argument, do nothing
777   if (!isUnsafeMath(FPOp) && !CF && !CINT && !CZero)
778     return false;
779 
780   // 0x1111111 means that we don't do anything for this call.
781   int ci_opr1 = (CINT ? (int)CINT->getSExtValue() : 0x1111111);
782 
783   if ((CF && CF->isZero()) || (CINT && ci_opr1 == 0) || CZero) {
784     //  pow/powr/pown(x, 0) == 1
785     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1\n");
786     Constant *cnval = ConstantFP::get(eltType, 1.0);
787     if (getVecSize(FInfo) > 1) {
788       cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
789     }
790     replaceCall(FPOp, cnval);
791     return true;
792   }
793   if ((CF && CF->isExactlyValue(1.0)) || (CINT && ci_opr1 == 1)) {
794     // pow/powr/pown(x, 1.0) = x
795     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
796     replaceCall(FPOp, opr0);
797     return true;
798   }
799   if ((CF && CF->isExactlyValue(2.0)) || (CINT && ci_opr1 == 2)) {
800     // pow/powr/pown(x, 2.0) = x*x
801     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << " * "
802                       << *opr0 << "\n");
803     Value *nval = B.CreateFMul(opr0, opr0, "__pow2");
804     replaceCall(FPOp, nval);
805     return true;
806   }
807   if ((CF && CF->isExactlyValue(-1.0)) || (CINT && ci_opr1 == -1)) {
808     // pow/powr/pown(x, -1.0) = 1.0/x
809     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1 / " << *opr0 << "\n");
810     Constant *cnval = ConstantFP::get(eltType, 1.0);
811     if (getVecSize(FInfo) > 1) {
812       cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
813     }
814     Value *nval = B.CreateFDiv(cnval, opr0, "__powrecip");
815     replaceCall(FPOp, nval);
816     return true;
817   }
818 
819   if (CF && (CF->isExactlyValue(0.5) || CF->isExactlyValue(-0.5))) {
820     // pow[r](x, [-]0.5) = sqrt(x)
821     bool issqrt = CF->isExactlyValue(0.5);
822     if (FunctionCallee FPExpr =
823             getFunction(M, AMDGPULibFunc(issqrt ? AMDGPULibFunc::EI_SQRT
824                                                 : AMDGPULibFunc::EI_RSQRT,
825                                          FInfo))) {
826       LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << FInfo.getName()
827                         << '(' << *opr0 << ")\n");
828       Value *nval = CreateCallEx(B,FPExpr, opr0, issqrt ? "__pow2sqrt"
829                                                         : "__pow2rsqrt");
830       replaceCall(FPOp, nval);
831       return true;
832     }
833   }
834 
835   if (!isUnsafeMath(FPOp))
836     return false;
837 
838   // Unsafe Math optimization
839 
840   // Remember that ci_opr1 is set if opr1 is integral
841   if (CF) {
842     double dval = (getArgType(FInfo) == AMDGPULibFunc::F32)
843                     ? (double)CF->getValueAPF().convertToFloat()
844                     : CF->getValueAPF().convertToDouble();
845     int ival = (int)dval;
846     if ((double)ival == dval) {
847       ci_opr1 = ival;
848     } else
849       ci_opr1 = 0x11111111;
850   }
851 
852   // pow/powr/pown(x, c) = [1/](x*x*..x); where
853   //   trunc(c) == c && the number of x == c && |c| <= 12
854   unsigned abs_opr1 = (ci_opr1 < 0) ? -ci_opr1 : ci_opr1;
855   if (abs_opr1 <= 12) {
856     Constant *cnval;
857     Value *nval;
858     if (abs_opr1 == 0) {
859       cnval = ConstantFP::get(eltType, 1.0);
860       if (getVecSize(FInfo) > 1) {
861         cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
862       }
863       nval = cnval;
864     } else {
865       Value *valx2 = nullptr;
866       nval = nullptr;
867       while (abs_opr1 > 0) {
868         valx2 = valx2 ? B.CreateFMul(valx2, valx2, "__powx2") : opr0;
869         if (abs_opr1 & 1) {
870           nval = nval ? B.CreateFMul(nval, valx2, "__powprod") : valx2;
871         }
872         abs_opr1 >>= 1;
873       }
874     }
875 
876     if (ci_opr1 < 0) {
877       cnval = ConstantFP::get(eltType, 1.0);
878       if (getVecSize(FInfo) > 1) {
879         cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
880       }
881       nval = B.CreateFDiv(cnval, nval, "__1powprod");
882     }
883     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
884                       << ((ci_opr1 < 0) ? "1/prod(" : "prod(") << *opr0
885                       << ")\n");
886     replaceCall(FPOp, nval);
887     return true;
888   }
889 
890   // powr ---> exp2(y * log2(x))
891   // pown/pow ---> powr(fabs(x), y) | (x & ((int)y << 31))
892   FunctionCallee ExpExpr =
893       getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_EXP2, FInfo));
894   if (!ExpExpr)
895     return false;
896 
897   bool needlog = false;
898   bool needabs = false;
899   bool needcopysign = false;
900   Constant *cnval = nullptr;
901   if (getVecSize(FInfo) == 1) {
902     CF = dyn_cast<ConstantFP>(opr0);
903 
904     if (CF) {
905       double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
906                    ? (double)CF->getValueAPF().convertToFloat()
907                    : CF->getValueAPF().convertToDouble();
908 
909       V = log2(std::abs(V));
910       cnval = ConstantFP::get(eltType, V);
911       needcopysign = (FInfo.getId() != AMDGPULibFunc::EI_POWR) &&
912                      CF->isNegative();
913     } else {
914       needlog = true;
915       needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR &&
916                                (!CF || CF->isNegative());
917     }
918   } else {
919     ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr0);
920 
921     if (!CDV) {
922       needlog = true;
923       needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR;
924     } else {
925       assert ((int)CDV->getNumElements() == getVecSize(FInfo) &&
926               "Wrong vector size detected");
927 
928       SmallVector<double, 0> DVal;
929       for (int i=0; i < getVecSize(FInfo); ++i) {
930         double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
931                      ? (double)CDV->getElementAsFloat(i)
932                      : CDV->getElementAsDouble(i);
933         if (V < 0.0) needcopysign = true;
934         V = log2(std::abs(V));
935         DVal.push_back(V);
936       }
937       if (getArgType(FInfo) == AMDGPULibFunc::F32) {
938         SmallVector<float, 0> FVal;
939         for (unsigned i=0; i < DVal.size(); ++i) {
940           FVal.push_back((float)DVal[i]);
941         }
942         ArrayRef<float> tmp(FVal);
943         cnval = ConstantDataVector::get(M->getContext(), tmp);
944       } else {
945         ArrayRef<double> tmp(DVal);
946         cnval = ConstantDataVector::get(M->getContext(), tmp);
947       }
948     }
949   }
950 
951   if (needcopysign && (FInfo.getId() == AMDGPULibFunc::EI_POW)) {
952     // We cannot handle corner cases for a general pow() function, give up
953     // unless y is a constant integral value. Then proceed as if it were pown.
954     if (getVecSize(FInfo) == 1) {
955       if (const ConstantFP *CF = dyn_cast<ConstantFP>(opr1)) {
956         double y = (getArgType(FInfo) == AMDGPULibFunc::F32)
957                    ? (double)CF->getValueAPF().convertToFloat()
958                    : CF->getValueAPF().convertToDouble();
959         if (y != (double)(int64_t)y)
960           return false;
961       } else
962         return false;
963     } else {
964       if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1)) {
965         for (int i=0; i < getVecSize(FInfo); ++i) {
966           double y = (getArgType(FInfo) == AMDGPULibFunc::F32)
967                      ? (double)CDV->getElementAsFloat(i)
968                      : CDV->getElementAsDouble(i);
969           if (y != (double)(int64_t)y)
970             return false;
971         }
972       } else
973         return false;
974     }
975   }
976 
977   Value *nval;
978   if (needabs) {
979     nval = B.CreateUnaryIntrinsic(Intrinsic::fabs, opr0, nullptr, "__fabs");
980   } else {
981     nval = cnval ? cnval : opr0;
982   }
983   if (needlog) {
984     FunctionCallee LogExpr =
985         getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_LOG2, FInfo));
986     if (!LogExpr)
987       return false;
988     nval = CreateCallEx(B,LogExpr, nval, "__log2");
989   }
990 
991   if (FInfo.getId() == AMDGPULibFunc::EI_POWN) {
992     // convert int(32) to fp(f32 or f64)
993     opr1 = B.CreateSIToFP(opr1, nval->getType(), "pownI2F");
994   }
995   nval = B.CreateFMul(opr1, nval, "__ylogx");
996   nval = CreateCallEx(B,ExpExpr, nval, "__exp2");
997 
998   if (needcopysign) {
999     Value *opr_n;
1000     Type* rTy = opr0->getType();
1001     Type* nTyS = eltType->isDoubleTy() ? B.getInt64Ty() : B.getInt32Ty();
1002     Type *nTy = nTyS;
1003     if (const auto *vTy = dyn_cast<FixedVectorType>(rTy))
1004       nTy = FixedVectorType::get(nTyS, vTy);
1005     unsigned size = nTy->getScalarSizeInBits();
1006     opr_n = FPOp->getOperand(1);
1007     if (opr_n->getType()->isIntegerTy())
1008       opr_n = B.CreateZExtOrBitCast(opr_n, nTy, "__ytou");
1009     else
1010       opr_n = B.CreateFPToSI(opr1, nTy, "__ytou");
1011 
1012     Value *sign = B.CreateShl(opr_n, size-1, "__yeven");
1013     sign = B.CreateAnd(B.CreateBitCast(opr0, nTy), sign, "__pow_sign");
1014     nval = B.CreateOr(B.CreateBitCast(nval, nTy), sign);
1015     nval = B.CreateBitCast(nval, opr0->getType());
1016   }
1017 
1018   LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
1019                     << "exp2(" << *opr1 << " * log2(" << *opr0 << "))\n");
1020   replaceCall(FPOp, nval);
1021 
1022   return true;
1023 }
1024 
1025 bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
1026                                 const FuncInfo &FInfo) {
1027   // skip vector function
1028   if (getVecSize(FInfo) != 1)
1029     return false;
1030 
1031   Value *opr0 = FPOp->getOperand(0);
1032   Value *opr1 = FPOp->getOperand(1);
1033 
1034   ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
1035   if (!CINT) {
1036     return false;
1037   }
1038   int ci_opr1 = (int)CINT->getSExtValue();
1039   if (ci_opr1 == 1) {  // rootn(x, 1) = x
1040     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
1041     replaceCall(FPOp, opr0);
1042     return true;
1043   }
1044 
1045   Module *M = B.GetInsertBlock()->getModule();
1046   if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x)
1047     if (FunctionCallee FPExpr =
1048             getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
1049       LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0
1050                         << ")\n");
1051       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt");
1052       replaceCall(FPOp, nval);
1053       return true;
1054     }
1055   } else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
1056     if (FunctionCallee FPExpr =
1057             getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) {
1058       LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> cbrt(" << *opr0
1059                         << ")\n");
1060       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2cbrt");
1061       replaceCall(FPOp, nval);
1062       return true;
1063     }
1064   } else if (ci_opr1 == -1) { // rootn(x, -1) = 1.0/x
1065     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1.0 / " << *opr0 << "\n");
1066     Value *nval = B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0),
1067                                opr0,
1068                                "__rootn2div");
1069     replaceCall(FPOp, nval);
1070     return true;
1071   } else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
1072     if (FunctionCallee FPExpr =
1073             getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT, FInfo))) {
1074       LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
1075                         << ")\n");
1076       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt");
1077       replaceCall(FPOp, nval);
1078       return true;
1079     }
1080   }
1081   return false;
1082 }
1083 
1084 // Get a scalar native builtin single argument FP function
1085 FunctionCallee AMDGPULibCalls::getNativeFunction(Module *M,
1086                                                  const FuncInfo &FInfo) {
1087   if (getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(FInfo.getId()))
1088     return nullptr;
1089   FuncInfo nf = FInfo;
1090   nf.setPrefix(AMDGPULibFunc::NATIVE);
1091   return getFunction(M, nf);
1092 }
1093 
1094 // Some library calls are just wrappers around llvm intrinsics, but compiled
1095 // conservatively. Preserve the flags from the original call site by
1096 // substituting them with direct calls with all the flags.
1097 bool AMDGPULibCalls::shouldReplaceLibcallWithIntrinsic(const CallInst *CI,
1098                                                        bool AllowMinSizeF32,
1099                                                        bool AllowF64,
1100                                                        bool AllowStrictFP) {
1101   Type *FltTy = CI->getType()->getScalarType();
1102   const bool IsF32 = FltTy->isFloatTy();
1103 
1104   // f64 intrinsics aren't implemented for most operations.
1105   if (!IsF32 && !FltTy->isHalfTy() && (!AllowF64 || !FltTy->isDoubleTy()))
1106     return false;
1107 
1108   // We're implicitly inlining by replacing the libcall with the intrinsic, so
1109   // don't do it for noinline call sites.
1110   if (CI->isNoInline())
1111     return false;
1112 
1113   const Function *ParentF = CI->getFunction();
1114   // TODO: Handle strictfp
1115   if (!AllowStrictFP && ParentF->hasFnAttribute(Attribute::StrictFP))
1116     return false;
1117 
1118   if (IsF32 && !AllowMinSizeF32 && ParentF->hasMinSize())
1119     return false;
1120   return true;
1121 }
1122 
1123 void AMDGPULibCalls::replaceLibCallWithSimpleIntrinsic(CallInst *CI,
1124                                                        Intrinsic::ID IntrID) {
1125   CI->setCalledFunction(
1126       Intrinsic::getDeclaration(CI->getModule(), IntrID, {CI->getType()}));
1127 }
1128 
1129 bool AMDGPULibCalls::tryReplaceLibcallWithSimpleIntrinsic(CallInst *CI,
1130                                                           Intrinsic::ID IntrID,
1131                                                           bool AllowMinSizeF32,
1132                                                           bool AllowF64,
1133                                                           bool AllowStrictFP) {
1134   if (!shouldReplaceLibcallWithIntrinsic(CI, AllowMinSizeF32, AllowF64,
1135                                          AllowStrictFP))
1136     return false;
1137   replaceLibCallWithSimpleIntrinsic(CI, IntrID);
1138   return true;
1139 }
1140 
1141 // fold sqrt -> native_sqrt (x)
1142 bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
1143                                const FuncInfo &FInfo) {
1144   if (!isUnsafeMath(FPOp))
1145     return false;
1146 
1147   if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
1148       (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
1149     Module *M = B.GetInsertBlock()->getModule();
1150 
1151     if (FunctionCallee FPExpr = getNativeFunction(
1152             M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
1153       Value *opr0 = FPOp->getOperand(0);
1154       LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
1155                         << "sqrt(" << *opr0 << ")\n");
1156       Value *nval = CreateCallEx(B,FPExpr, opr0, "__sqrt");
1157       replaceCall(FPOp, nval);
1158       return true;
1159     }
1160   }
1161   return false;
1162 }
1163 
// Emit a single sincos(Arg, &cosOut) call and return, in order: the sin value
// (the call's return value), the cos value (loaded back from the out-pointer
// alloca), and the sincos call instruction itself.
std::tuple<Value *, Value *, Value *>
AMDGPULibCalls::insertSinCos(Value *Arg, FastMathFlags FMF, IRBuilder<> &B,
                             FunctionCallee Fsincos) {
  DebugLoc DL = B.getCurrentDebugLocation();
  Function *F = B.GetInsertBlock()->getParent();
  // The scratch slot for the cos result must live in the entry block, after
  // any existing allocas.
  B.SetInsertPointPastAllocas(F);

  AllocaInst *Alloc = B.CreateAlloca(Arg->getType(), nullptr, "__sincos_");

  if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) {
    // If the argument is an instruction, it must dominate all uses so put our
    // sincos call there. Otherwise, right after the allocas works well enough
    // if it's an argument or constant.

    B.SetInsertPoint(ArgInst->getParent(), ++ArgInst->getIterator());

    // SetInsertPoint unwelcomely always tries to set the debug loc.
    B.SetCurrentDebugLocation(DL);
  }

  Type *CosPtrTy = Fsincos.getFunctionType()->getParamType(1);

  // The allocaInst allocates the memory in private address space. This needs
  // to be addrspacecasted to point to the address space of the cos pointer
  // type. In OpenCL 2.0 this is generic, while in 1.2 that is private.
  Value *CastAlloc = B.CreateAddrSpaceCast(Alloc, CosPtrTy);

  CallInst *SinCos = CreateCallEx2(B, Fsincos, Arg, CastAlloc);

  // TODO: Is it worth trying to preserve the location for the cos calls for the
  // load?

  LoadInst *LoadCos = B.CreateLoad(Alloc->getAllocatedType(), Alloc);
  return {SinCos, LoadCos, SinCos};
}
1199 
// fold sin, cos -> sincos.
//
// Scans the other users of this call's argument for matching sin/cos/sincos
// calls in the same function, emits one sincos, and redirects them all to it.
bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
                                 const FuncInfo &fInfo) {
  assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
         fInfo.getId() == AMDGPULibFunc::EI_COS);

  // Only plain (unprefixed) f32/f64 sin/cos are handled.
  if ((getArgType(fInfo) != AMDGPULibFunc::F32 &&
       getArgType(fInfo) != AMDGPULibFunc::F64) ||
      fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
    return false;

  bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;

  Value *CArgVal = FPOp->getOperand(0);
  CallInst *CI = cast<CallInst>(FPOp);

  Function *F = B.GetInsertBlock()->getParent();
  Module *M = F->getParent();

  // Merge the sin and cos. For OpenCL 2.0, there may only be a generic pointer
  // implementation. Prefer the private form if available.
  AMDGPULibFunc SinCosLibFuncPrivate(AMDGPULibFunc::EI_SINCOS, fInfo);
  SinCosLibFuncPrivate.getLeads()[0].PtrKind =
      AMDGPULibFunc::getEPtrKindFromAddrSpace(AMDGPUAS::PRIVATE_ADDRESS);

  AMDGPULibFunc SinCosLibFuncGeneric(AMDGPULibFunc::EI_SINCOS, fInfo);
  SinCosLibFuncGeneric.getLeads()[0].PtrKind =
      AMDGPULibFunc::getEPtrKindFromAddrSpace(AMDGPUAS::FLAT_ADDRESS);

  FunctionCallee FSinCosPrivate = getFunction(M, SinCosLibFuncPrivate);
  FunctionCallee FSinCosGeneric = getFunction(M, SinCosLibFuncGeneric);
  FunctionCallee FSinCos = FSinCosPrivate ? FSinCosPrivate : FSinCosGeneric;
  if (!FSinCos)
    return false;

  // Partition the argument's other call users by which function they invoke.
  SmallVector<CallInst *> SinCalls;
  SmallVector<CallInst *> CosCalls;
  SmallVector<CallInst *> SinCosCalls;
  FuncInfo PartnerInfo(isSin ? AMDGPULibFunc::EI_COS : AMDGPULibFunc::EI_SIN,
                       fInfo);
  const std::string PairName = PartnerInfo.mangle();

  StringRef SinName = isSin ? CI->getCalledFunction()->getName() : PairName;
  StringRef CosName = isSin ? PairName : CI->getCalledFunction()->getName();
  const std::string SinCosPrivateName = SinCosLibFuncPrivate.mangle();
  const std::string SinCosGenericName = SinCosLibFuncGeneric.mangle();

  // Intersect the two sets of flags.
  FastMathFlags FMF = FPOp->getFastMathFlags();
  MDNode *FPMath = CI->getMetadata(LLVMContext::MD_fpmath);

  SmallVector<DILocation *> MergeDbgLocs = {CI->getDebugLoc()};

  for (User* U : CArgVal->users()) {
    CallInst *XI = dyn_cast<CallInst>(U);
    // Skip users that aren't calls in this function, or that opted out of
    // builtin treatment.
    if (!XI || XI->getFunction() != F || XI->isNoBuiltin())
      continue;

    Function *UCallee = XI->getCalledFunction();
    if (!UCallee)
      continue;

    bool Handled = true;

    if (UCallee->getName() == SinName)
      SinCalls.push_back(XI);
    else if (UCallee->getName() == CosName)
      CosCalls.push_back(XI);
    else if (UCallee->getName() == SinCosPrivateName ||
             UCallee->getName() == SinCosGenericName)
      SinCosCalls.push_back(XI);
    else
      Handled = false;

    if (Handled) {
      // Accumulate the weakest common fast-math flags / fpmath metadata and
      // collect debug locations so the merged call is conservative.
      MergeDbgLocs.push_back(XI->getDebugLoc());
      auto *OtherOp = cast<FPMathOperator>(XI);
      FMF &= OtherOp->getFastMathFlags();
      FPMath = MDNode::getMostGenericFPMath(
          FPMath, XI->getMetadata(LLVMContext::MD_fpmath));
    }
  }

  // Merging only pays off if both a sin and a cos of the argument exist.
  if (SinCalls.empty() || CosCalls.empty())
    return false;

  B.setFastMathFlags(FMF);
  B.setDefaultFPMathTag(FPMath);
  DILocation *DbgLoc = DILocation::getMergedLocations(MergeDbgLocs);
  B.SetCurrentDebugLocation(DbgLoc);

  auto [Sin, Cos, SinCos] = insertSinCos(CArgVal, FMF, B, FSinCos);

  auto replaceTrigInsts = [](ArrayRef<CallInst *> Calls, Value *Res) {
    for (CallInst *C : Calls)
      C->replaceAllUsesWith(Res);

    // Leave the other dead instructions to avoid clobbering iterators.
  };

  replaceTrigInsts(SinCalls, Sin);
  replaceTrigInsts(CosCalls, Cos);
  replaceTrigInsts(SinCosCalls, SinCos);

  // It's safe to delete the original now.
  CI->eraseFromParent();
  return true;
}
1308 
1309 bool AMDGPULibCalls::evaluateScalarMathFunc(const FuncInfo &FInfo, double &Res0,
1310                                             double &Res1, Constant *copr0,
1311                                             Constant *copr1) {
1312   // By default, opr0/opr1/opr3 holds values of float/double type.
1313   // If they are not float/double, each function has to its
1314   // operand separately.
1315   double opr0 = 0.0, opr1 = 0.0;
1316   ConstantFP *fpopr0 = dyn_cast_or_null<ConstantFP>(copr0);
1317   ConstantFP *fpopr1 = dyn_cast_or_null<ConstantFP>(copr1);
1318   if (fpopr0) {
1319     opr0 = (getArgType(FInfo) == AMDGPULibFunc::F64)
1320              ? fpopr0->getValueAPF().convertToDouble()
1321              : (double)fpopr0->getValueAPF().convertToFloat();
1322   }
1323 
1324   if (fpopr1) {
1325     opr1 = (getArgType(FInfo) == AMDGPULibFunc::F64)
1326              ? fpopr1->getValueAPF().convertToDouble()
1327              : (double)fpopr1->getValueAPF().convertToFloat();
1328   }
1329 
1330   switch (FInfo.getId()) {
1331   default : return false;
1332 
1333   case AMDGPULibFunc::EI_ACOS:
1334     Res0 = acos(opr0);
1335     return true;
1336 
1337   case AMDGPULibFunc::EI_ACOSH:
1338     // acosh(x) == log(x + sqrt(x*x - 1))
1339     Res0 = log(opr0 + sqrt(opr0*opr0 - 1.0));
1340     return true;
1341 
1342   case AMDGPULibFunc::EI_ACOSPI:
1343     Res0 = acos(opr0) / MATH_PI;
1344     return true;
1345 
1346   case AMDGPULibFunc::EI_ASIN:
1347     Res0 = asin(opr0);
1348     return true;
1349 
1350   case AMDGPULibFunc::EI_ASINH:
1351     // asinh(x) == log(x + sqrt(x*x + 1))
1352     Res0 = log(opr0 + sqrt(opr0*opr0 + 1.0));
1353     return true;
1354 
1355   case AMDGPULibFunc::EI_ASINPI:
1356     Res0 = asin(opr0) / MATH_PI;
1357     return true;
1358 
1359   case AMDGPULibFunc::EI_ATAN:
1360     Res0 = atan(opr0);
1361     return true;
1362 
1363   case AMDGPULibFunc::EI_ATANH:
1364     // atanh(x) == (log(x+1) - log(x-1))/2;
1365     Res0 = (log(opr0 + 1.0) - log(opr0 - 1.0))/2.0;
1366     return true;
1367 
1368   case AMDGPULibFunc::EI_ATANPI:
1369     Res0 = atan(opr0) / MATH_PI;
1370     return true;
1371 
1372   case AMDGPULibFunc::EI_CBRT:
1373     Res0 = (opr0 < 0.0) ? -pow(-opr0, 1.0/3.0) : pow(opr0, 1.0/3.0);
1374     return true;
1375 
1376   case AMDGPULibFunc::EI_COS:
1377     Res0 = cos(opr0);
1378     return true;
1379 
1380   case AMDGPULibFunc::EI_COSH:
1381     Res0 = cosh(opr0);
1382     return true;
1383 
1384   case AMDGPULibFunc::EI_COSPI:
1385     Res0 = cos(MATH_PI * opr0);
1386     return true;
1387 
1388   case AMDGPULibFunc::EI_EXP:
1389     Res0 = exp(opr0);
1390     return true;
1391 
1392   case AMDGPULibFunc::EI_EXP2:
1393     Res0 = pow(2.0, opr0);
1394     return true;
1395 
1396   case AMDGPULibFunc::EI_EXP10:
1397     Res0 = pow(10.0, opr0);
1398     return true;
1399 
1400   case AMDGPULibFunc::EI_LOG:
1401     Res0 = log(opr0);
1402     return true;
1403 
1404   case AMDGPULibFunc::EI_LOG2:
1405     Res0 = log(opr0) / log(2.0);
1406     return true;
1407 
1408   case AMDGPULibFunc::EI_LOG10:
1409     Res0 = log(opr0) / log(10.0);
1410     return true;
1411 
1412   case AMDGPULibFunc::EI_RSQRT:
1413     Res0 = 1.0 / sqrt(opr0);
1414     return true;
1415 
1416   case AMDGPULibFunc::EI_SIN:
1417     Res0 = sin(opr0);
1418     return true;
1419 
1420   case AMDGPULibFunc::EI_SINH:
1421     Res0 = sinh(opr0);
1422     return true;
1423 
1424   case AMDGPULibFunc::EI_SINPI:
1425     Res0 = sin(MATH_PI * opr0);
1426     return true;
1427 
1428   case AMDGPULibFunc::EI_TAN:
1429     Res0 = tan(opr0);
1430     return true;
1431 
1432   case AMDGPULibFunc::EI_TANH:
1433     Res0 = tanh(opr0);
1434     return true;
1435 
1436   case AMDGPULibFunc::EI_TANPI:
1437     Res0 = tan(MATH_PI * opr0);
1438     return true;
1439 
1440   // two-arg functions
1441   case AMDGPULibFunc::EI_POW:
1442   case AMDGPULibFunc::EI_POWR:
1443     Res0 = pow(opr0, opr1);
1444     return true;
1445 
1446   case AMDGPULibFunc::EI_POWN: {
1447     if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(copr1)) {
1448       double val = (double)iopr1->getSExtValue();
1449       Res0 = pow(opr0, val);
1450       return true;
1451     }
1452     return false;
1453   }
1454 
1455   case AMDGPULibFunc::EI_ROOTN: {
1456     if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(copr1)) {
1457       double val = (double)iopr1->getSExtValue();
1458       Res0 = pow(opr0, 1.0 / val);
1459       return true;
1460     }
1461     return false;
1462   }
1463 
1464   // with ptr arg
1465   case AMDGPULibFunc::EI_SINCOS:
1466     Res0 = sin(opr0);
1467     Res1 = cos(opr0);
1468     return true;
1469   }
1470 
1471   return false;
1472 }
1473 
1474 bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) {
1475   int numArgs = (int)aCI->arg_size();
1476   if (numArgs > 3)
1477     return false;
1478 
1479   Constant *copr0 = nullptr;
1480   Constant *copr1 = nullptr;
1481   if (numArgs > 0) {
1482     if ((copr0 = dyn_cast<Constant>(aCI->getArgOperand(0))) == nullptr)
1483       return false;
1484   }
1485 
1486   if (numArgs > 1) {
1487     if ((copr1 = dyn_cast<Constant>(aCI->getArgOperand(1))) == nullptr) {
1488       if (FInfo.getId() != AMDGPULibFunc::EI_SINCOS)
1489         return false;
1490     }
1491   }
1492 
1493   // At this point, all arguments to aCI are constants.
1494 
1495   // max vector size is 16, and sincos will generate two results.
1496   double DVal0[16], DVal1[16];
1497   int FuncVecSize = getVecSize(FInfo);
1498   bool hasTwoResults = (FInfo.getId() == AMDGPULibFunc::EI_SINCOS);
1499   if (FuncVecSize == 1) {
1500     if (!evaluateScalarMathFunc(FInfo, DVal0[0], DVal1[0], copr0, copr1)) {
1501       return false;
1502     }
1503   } else {
1504     ConstantDataVector *CDV0 = dyn_cast_or_null<ConstantDataVector>(copr0);
1505     ConstantDataVector *CDV1 = dyn_cast_or_null<ConstantDataVector>(copr1);
1506     for (int i = 0; i < FuncVecSize; ++i) {
1507       Constant *celt0 = CDV0 ? CDV0->getElementAsConstant(i) : nullptr;
1508       Constant *celt1 = CDV1 ? CDV1->getElementAsConstant(i) : nullptr;
1509       if (!evaluateScalarMathFunc(FInfo, DVal0[i], DVal1[i], celt0, celt1)) {
1510         return false;
1511       }
1512     }
1513   }
1514 
1515   LLVMContext &context = aCI->getContext();
1516   Constant *nval0, *nval1;
1517   if (FuncVecSize == 1) {
1518     nval0 = ConstantFP::get(aCI->getType(), DVal0[0]);
1519     if (hasTwoResults)
1520       nval1 = ConstantFP::get(aCI->getType(), DVal1[0]);
1521   } else {
1522     if (getArgType(FInfo) == AMDGPULibFunc::F32) {
1523       SmallVector <float, 0> FVal0, FVal1;
1524       for (int i = 0; i < FuncVecSize; ++i)
1525         FVal0.push_back((float)DVal0[i]);
1526       ArrayRef<float> tmp0(FVal0);
1527       nval0 = ConstantDataVector::get(context, tmp0);
1528       if (hasTwoResults) {
1529         for (int i = 0; i < FuncVecSize; ++i)
1530           FVal1.push_back((float)DVal1[i]);
1531         ArrayRef<float> tmp1(FVal1);
1532         nval1 = ConstantDataVector::get(context, tmp1);
1533       }
1534     } else {
1535       ArrayRef<double> tmp0(DVal0);
1536       nval0 = ConstantDataVector::get(context, tmp0);
1537       if (hasTwoResults) {
1538         ArrayRef<double> tmp1(DVal1);
1539         nval1 = ConstantDataVector::get(context, tmp1);
1540       }
1541     }
1542   }
1543 
1544   if (hasTwoResults) {
1545     // sincos
1546     assert(FInfo.getId() == AMDGPULibFunc::EI_SINCOS &&
1547            "math function with ptr arg not supported yet");
1548     new StoreInst(nval1, aCI->getArgOperand(1), aCI);
1549   }
1550 
1551   replaceCall(aCI, nval0);
1552   return true;
1553 }
1554 
1555 PreservedAnalyses AMDGPUSimplifyLibCallsPass::run(Function &F,
1556                                                   FunctionAnalysisManager &AM) {
1557   AMDGPULibCalls Simplifier;
1558   Simplifier.initNativeFuncs();
1559   Simplifier.initFunction(F);
1560 
1561   bool Changed = false;
1562 
1563   LLVM_DEBUG(dbgs() << "AMDIC: process function ";
1564              F.printAsOperand(dbgs(), false, F.getParent()); dbgs() << '\n';);
1565 
1566   for (auto &BB : F) {
1567     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) {
1568       // Ignore non-calls.
1569       CallInst *CI = dyn_cast<CallInst>(I);
1570       ++I;
1571 
1572       if (CI) {
1573         if (Simplifier.fold(CI))
1574           Changed = true;
1575       }
1576     }
1577   }
1578   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
1579 }
1580 
1581 PreservedAnalyses AMDGPUUseNativeCallsPass::run(Function &F,
1582                                                 FunctionAnalysisManager &AM) {
1583   if (UseNative.empty())
1584     return PreservedAnalyses::all();
1585 
1586   AMDGPULibCalls Simplifier;
1587   Simplifier.initNativeFuncs();
1588   Simplifier.initFunction(F);
1589 
1590   bool Changed = false;
1591   for (auto &BB : F) {
1592     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) {
1593       // Ignore non-calls.
1594       CallInst *CI = dyn_cast<CallInst>(I);
1595       ++I;
1596       if (CI && Simplifier.useNative(CI))
1597         Changed = true;
1598     }
1599   }
1600   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
1601 }
1602