xref: /llvm-project/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp (revision f337a77c99aa31b37c60dd2ecbd96f8317426fad)
1 //===- AMDGPULibCalls.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This file does AMD library function optimizations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AMDGPU.h"
15 #include "AMDGPULibFunc.h"
16 #include "GCNSubtarget.h"
17 #include "llvm/Analysis/AliasAnalysis.h"
18 #include "llvm/Analysis/Loads.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/IntrinsicInst.h"
21 #include "llvm/IR/IntrinsicsAMDGPU.h"
22 #include "llvm/InitializePasses.h"
23 #include <cmath>
24 
25 #define DEBUG_TYPE "amdgpu-simplifylib"
26 
27 using namespace llvm;
28 
// -amdgpu-prelink: the pass runs before the device libraries are linked in,
// so library functions may be external; getFunction() is then allowed to
// create missing declarations on demand.
static cl::opt<bool> EnablePreLink("amdgpu-prelink",
  cl::desc("Enable pre-link mode optimizations"),
  cl::init(false),
  cl::Hidden);

// -amdgpu-use-native: comma-separated list of math functions to replace with
// their faster native_* counterparts. "all" (or the bare flag with an empty
// value) enables the replacement for every function that has a native form.
static cl::list<std::string> UseNative("amdgpu-use-native",
  cl::desc("Comma separated list of functions to replace with native, or all"),
  cl::CommaSeparated, cl::ValueOptional,
  cl::Hidden);
38 
39 #define MATH_PI      numbers::pi
40 #define MATH_E       numbers::e
41 #define MATH_SQRT2   numbers::sqrt2
42 #define MATH_SQRT1_2 numbers::inv_sqrt2
43 
44 namespace llvm {
45 
/// Folds and simplifies calls to the AMD device math library: table-driven
/// constant folding, replacement with LLVM intrinsics, pow/rootn/sqrt/sincos
/// rewrites, and optional substitution of native_* variants.
class AMDGPULibCalls {
private:

  typedef llvm::AMDGPULibFunc FuncInfo;

  // Value of the "unsafe-fp-math" function attribute for the function
  // currently being processed; refreshed by initFunction().
  bool UnsafeFPMath = false;

  // -fuse-native.
  bool AllNative = false;

  // True if function \p F was named by -amdgpu-use-native (or "all" given).
  bool useNativeFunc(const StringRef F) const;

  // Return a pointer (pointer expr) to the function if function definition with
  // "FuncName" exists. It may create a new function prototype in pre-link mode.
  FunctionCallee getFunction(Module *M, const FuncInfo &fInfo);

  // Parse a mangled library function name into \p FInfo; false if \p
  // FMangledName is not a recognized library function.
  bool parseFunctionName(const StringRef &FMangledName, FuncInfo &FInfo);

  // Table-driven constant folding (e.g. cos(0) -> 1) for scalar and
  // constant-vector arguments; tables are defined below (tbl_*).
  bool TDOFold(CallInst *CI, const FuncInfo &FInfo);

  /* Specialized optimizations */

  // pow/powr/pown
  bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // rootn
  bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // -fuse-native for sincos
  bool sincosUseNative(CallInst *aCI, const FuncInfo &FInfo);

  // evaluate calls if calls' arguments are constants.
  bool evaluateScalarMathFunc(const FuncInfo &FInfo, double& Res0,
    double& Res1, Constant *copr0, Constant *copr1, Constant *copr2);
  bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);

  // sqrt
  bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  /// Insert a value to sincos function \p Fsincos. Returns (value of sin, value
  /// of cos, sincos call).
  std::tuple<Value *, Value *, Value *> insertSinCos(Value *Arg,
                                                     FastMathFlags FMF,
                                                     IRBuilder<> &B,
                                                     FunctionCallee Fsincos);

  // sin/cos
  bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // __read_pipe/__write_pipe
  bool fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
                            const FuncInfo &FInfo);

  // Get a scalar native builtin single argument FP function
  FunctionCallee getNativeFunction(Module *M, const FuncInfo &FInfo);

  /// Substitute a call to a known libcall with an intrinsic call. If \p
  /// AllowMinSize is true, allow the replacement in a minsize function.
  bool shouldReplaceLibcallWithIntrinsic(const CallInst *CI,
                                         bool AllowMinSizeF32 = false,
                                         bool AllowF64 = false,
                                         bool AllowStrictFP = false);
  void replaceLibCallWithSimpleIntrinsic(CallInst *CI, Intrinsic::ID IntrID);

  // shouldReplaceLibcallWithIntrinsic + replaceLibCallWithSimpleIntrinsic in
  // one step; returns true if the replacement was performed.
  bool tryReplaceLibcallWithSimpleIntrinsic(CallInst *CI, Intrinsic::ID IntrID,
                                            bool AllowMinSizeF32 = false,
                                            bool AllowF64 = false,
                                            bool AllowStrictFP = false);

protected:
  bool isUnsafeMath(const FPMathOperator *FPOp) const;

  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;

  // RAUW \p I with \p With and delete the original instruction.
  static void replaceCall(Instruction *I, Value *With) {
    I->replaceAllUsesWith(With);
    I->eraseFromParent();
  }

  static void replaceCall(FPMathOperator *I, Value *With) {
    replaceCall(cast<Instruction>(I), With);
  }

public:
  AMDGPULibCalls() {}

  // Try all folds on \p CI; returns true if the call was changed.
  bool fold(CallInst *CI);

  // Per-function setup (reads the "unsafe-fp-math" attribute).
  void initFunction(const Function &F);
  void initNativeFuncs();

  // Replace a normal math function call with that native version
  bool useNative(CallInst *CI);
};
140 
141 } // end llvm namespace
142 
143 template <typename IRB>
144 static CallInst *CreateCallEx(IRB &B, FunctionCallee Callee, Value *Arg,
145                               const Twine &Name = "") {
146   CallInst *R = B.CreateCall(Callee, Arg, Name);
147   if (Function *F = dyn_cast<Function>(Callee.getCallee()))
148     R->setCallingConv(F->getCallingConv());
149   return R;
150 }
151 
152 template <typename IRB>
153 static CallInst *CreateCallEx2(IRB &B, FunctionCallee Callee, Value *Arg1,
154                                Value *Arg2, const Twine &Name = "") {
155   CallInst *R = B.CreateCall(Callee, {Arg1, Arg2}, Name);
156   if (Function *F = dyn_cast<Function>(Callee.getCallee()))
157     R->setCallingConv(F->getCallingConv());
158   return R;
159 }
160 
161 //  Data structures for table-driven optimizations.
162 //  FuncTbl works for both f32 and f64 functions with 1 input argument
163 
// One folding-table row: a call whose argument compares exactly equal to
// `input` (via ConstantFP::isExactlyValue) folds to the constant `result`.
struct TableEntry {
  double   result; // value the call folds to
  double   input;  // argument that must match exactly
};
168 
/* a list of {result, input} pairs per function, consumed by TDOFold; the
   argument is matched bit-exactly, so -0.0 and 0.0 need separate rows. */
static const TableEntry tbl_acos[] = {
  {MATH_PI / 2.0, 0.0},
  {MATH_PI / 2.0, -0.0},
  {0.0, 1.0},
  {MATH_PI, -1.0}
};
static const TableEntry tbl_acosh[] = {
  {0.0, 1.0}
};
static const TableEntry tbl_acospi[] = {
  {0.5, 0.0},
  {0.5, -0.0},
  {0.0, 1.0},
  {1.0, -1.0}
};
static const TableEntry tbl_asin[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 2.0, 1.0},
  {-MATH_PI / 2.0, -1.0}
};
static const TableEntry tbl_asinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_asinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.5, 1.0},
  {-0.5, -1.0}
};
static const TableEntry tbl_atan[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 4.0, 1.0},
  {-MATH_PI / 4.0, -1.0}
};
static const TableEntry tbl_atanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_atanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.25, 1.0},
  {-0.25, -1.0}
};
static const TableEntry tbl_cbrt[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {1.0, 1.0},
  {-1.0, -1.0},
};
static const TableEntry tbl_cos[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cosh[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cospi[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erfc[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erf[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_exp[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {MATH_E, 1.0}
};
static const TableEntry tbl_exp2[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {2.0, 1.0}
};
static const TableEntry tbl_exp10[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {10.0, 1.0}
};
static const TableEntry tbl_expm1[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_log[] = {
  {0.0, 1.0},
  {1.0, MATH_E}
};
static const TableEntry tbl_log2[] = {
  {0.0, 1.0},
  {1.0, 2.0}
};
static const TableEntry tbl_log10[] = {
  {0.0, 1.0},
  {1.0, 10.0}
};
static const TableEntry tbl_rsqrt[] = {
  {1.0, 1.0},
  {MATH_SQRT1_2, 2.0}
};
static const TableEntry tbl_sin[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sqrt[] = {
  {0.0, 0.0},
  {1.0, 1.0},
  {MATH_SQRT2, 2.0}
};
static const TableEntry tbl_tan[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tgamma[] = {
  {1.0, 1.0},
  {1.0, 2.0},
  {2.0, 3.0},
  {6.0, 4.0}
};
313 
314 static bool HasNative(AMDGPULibFunc::EFuncId id) {
315   switch(id) {
316   case AMDGPULibFunc::EI_DIVIDE:
317   case AMDGPULibFunc::EI_COS:
318   case AMDGPULibFunc::EI_EXP:
319   case AMDGPULibFunc::EI_EXP2:
320   case AMDGPULibFunc::EI_EXP10:
321   case AMDGPULibFunc::EI_LOG:
322   case AMDGPULibFunc::EI_LOG2:
323   case AMDGPULibFunc::EI_LOG10:
324   case AMDGPULibFunc::EI_POWR:
325   case AMDGPULibFunc::EI_RECIP:
326   case AMDGPULibFunc::EI_RSQRT:
327   case AMDGPULibFunc::EI_SIN:
328   case AMDGPULibFunc::EI_SINCOS:
329   case AMDGPULibFunc::EI_SQRT:
330   case AMDGPULibFunc::EI_TAN:
331     return true;
332   default:;
333   }
334   return false;
335 }
336 
337 using TableRef = ArrayRef<TableEntry>;
338 
339 static TableRef getOptTable(AMDGPULibFunc::EFuncId id) {
340   switch(id) {
341   case AMDGPULibFunc::EI_ACOS:    return TableRef(tbl_acos);
342   case AMDGPULibFunc::EI_ACOSH:   return TableRef(tbl_acosh);
343   case AMDGPULibFunc::EI_ACOSPI:  return TableRef(tbl_acospi);
344   case AMDGPULibFunc::EI_ASIN:    return TableRef(tbl_asin);
345   case AMDGPULibFunc::EI_ASINH:   return TableRef(tbl_asinh);
346   case AMDGPULibFunc::EI_ASINPI:  return TableRef(tbl_asinpi);
347   case AMDGPULibFunc::EI_ATAN:    return TableRef(tbl_atan);
348   case AMDGPULibFunc::EI_ATANH:   return TableRef(tbl_atanh);
349   case AMDGPULibFunc::EI_ATANPI:  return TableRef(tbl_atanpi);
350   case AMDGPULibFunc::EI_CBRT:    return TableRef(tbl_cbrt);
351   case AMDGPULibFunc::EI_NCOS:
352   case AMDGPULibFunc::EI_COS:     return TableRef(tbl_cos);
353   case AMDGPULibFunc::EI_COSH:    return TableRef(tbl_cosh);
354   case AMDGPULibFunc::EI_COSPI:   return TableRef(tbl_cospi);
355   case AMDGPULibFunc::EI_ERFC:    return TableRef(tbl_erfc);
356   case AMDGPULibFunc::EI_ERF:     return TableRef(tbl_erf);
357   case AMDGPULibFunc::EI_EXP:     return TableRef(tbl_exp);
358   case AMDGPULibFunc::EI_NEXP2:
359   case AMDGPULibFunc::EI_EXP2:    return TableRef(tbl_exp2);
360   case AMDGPULibFunc::EI_EXP10:   return TableRef(tbl_exp10);
361   case AMDGPULibFunc::EI_EXPM1:   return TableRef(tbl_expm1);
362   case AMDGPULibFunc::EI_LOG:     return TableRef(tbl_log);
363   case AMDGPULibFunc::EI_NLOG2:
364   case AMDGPULibFunc::EI_LOG2:    return TableRef(tbl_log2);
365   case AMDGPULibFunc::EI_LOG10:   return TableRef(tbl_log10);
366   case AMDGPULibFunc::EI_NRSQRT:
367   case AMDGPULibFunc::EI_RSQRT:   return TableRef(tbl_rsqrt);
368   case AMDGPULibFunc::EI_NSIN:
369   case AMDGPULibFunc::EI_SIN:     return TableRef(tbl_sin);
370   case AMDGPULibFunc::EI_SINH:    return TableRef(tbl_sinh);
371   case AMDGPULibFunc::EI_SINPI:   return TableRef(tbl_sinpi);
372   case AMDGPULibFunc::EI_NSQRT:
373   case AMDGPULibFunc::EI_SQRT:    return TableRef(tbl_sqrt);
374   case AMDGPULibFunc::EI_TAN:     return TableRef(tbl_tan);
375   case AMDGPULibFunc::EI_TANH:    return TableRef(tbl_tanh);
376   case AMDGPULibFunc::EI_TANPI:   return TableRef(tbl_tanpi);
377   case AMDGPULibFunc::EI_TGAMMA:  return TableRef(tbl_tgamma);
378   default:;
379   }
380   return TableRef();
381 }
382 
383 static inline int getVecSize(const AMDGPULibFunc& FInfo) {
384   return FInfo.getLeads()[0].VectorSize;
385 }
386 
387 static inline AMDGPULibFunc::EType getArgType(const AMDGPULibFunc& FInfo) {
388   return (AMDGPULibFunc::EType)FInfo.getLeads()[0].ArgType;
389 }
390 
391 FunctionCallee AMDGPULibCalls::getFunction(Module *M, const FuncInfo &fInfo) {
392   // If we are doing PreLinkOpt, the function is external. So it is safe to
393   // use getOrInsertFunction() at this stage.
394 
395   return EnablePreLink ? AMDGPULibFunc::getOrInsertFunction(M, fInfo)
396                        : AMDGPULibFunc::getFunction(M, fInfo);
397 }
398 
// Decode a mangled library function name into \p FInfo; returns false when
// the name is not a recognized AMD library function.
bool AMDGPULibCalls::parseFunctionName(const StringRef &FMangledName,
                                       FuncInfo &FInfo) {
  return AMDGPULibFunc::parse(FMangledName, FInfo);
}

// Unsafe either globally (function attribute "unsafe-fp-math") or locally
// (the operation itself carries the 'fast' flag).
bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const {
  return UnsafeFPMath || FPOp->isFast();
}

// Gate for host-side constant evaluation (evaluateCall), which may produce a
// differently-rounded result than the device library.
bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
    const FPMathOperator *FPOp) const {
  // TODO: Refine to approxFunc or contract
  return isUnsafeMath(FPOp);
}

// Cache per-function state; must be called before folding calls in \p F.
void AMDGPULibCalls::initFunction(const Function &F) {
  UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
}
417 
418 bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
419   return AllNative || llvm::is_contained(UseNative, F);
420 }
421 
422 void AMDGPULibCalls::initNativeFuncs() {
423   AllNative = useNativeFunc("all") ||
424               (UseNative.getNumOccurrences() && UseNative.size() == 1 &&
425                UseNative.begin()->empty());
426 }
427 
// sincos has no native_* variant of its own: split it into separate
// native_sin and native_cos calls. Only done when BOTH sin and cos were
// requested via -amdgpu-use-native.
bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
  bool native_sin = useNativeFunc("sin");
  bool native_cos = useNativeFunc("cos");

  if (native_sin && native_cos) {
    Module *M = aCI->getModule();
    Value *opr0 = aCI->getArgOperand(0);

    AMDGPULibFunc nf;
    // Match the element type and vector width of the original sincos call.
    nf.getLeads()[0].ArgType = FInfo.getLeads()[0].ArgType;
    nf.getLeads()[0].VectorSize = FInfo.getLeads()[0].VectorSize;

    nf.setPrefix(AMDGPULibFunc::NATIVE);
    nf.setId(AMDGPULibFunc::EI_SIN);
    FunctionCallee sinExpr = getFunction(M, nf);

    nf.setPrefix(AMDGPULibFunc::NATIVE);
    nf.setId(AMDGPULibFunc::EI_COS);
    FunctionCallee cosExpr = getFunction(M, nf);
    // Only transform when both declarations are available (getFunction may
    // fail outside pre-link mode).
    if (sinExpr && cosExpr) {
      // Insert both calls before the original, store the cosine through the
      // original output pointer (arg 1), and let the sine value replace the
      // call's own result.
      Value *sinval = CallInst::Create(sinExpr, opr0, "splitsin", aCI);
      Value *cosval = CallInst::Create(cosExpr, opr0, "splitcos", aCI);
      new StoreInst(cosval, aCI->getArgOperand(1), aCI);

      DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
                                          << " with native version of sin/cos");

      replaceCall(aCI, sinval);
      return true;
    }
  }
  return false;
}
461 
// Replace a library math call with its native_* variant when requested via
// -amdgpu-use-native. Returns true if the call was rewritten.
bool AMDGPULibCalls::useNative(CallInst *aCI) {
  Function *Callee = aCI->getCalledFunction();
  // Skip indirect calls and calls marked nobuiltin.
  if (!Callee || aCI->isNoBuiltin())
    return false;

  FuncInfo FInfo;
  // Candidates must be recognized, mangled, unprefixed (not already native_/
  // half_) library calls, must not be f64 (no native f64 forms), must have a
  // native counterpart, and must have been requested on the command line.
  if (!parseFunctionName(Callee->getName(), FInfo) || !FInfo.isMangled() ||
      FInfo.getPrefix() != AMDGPULibFunc::NOPFX ||
      getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(FInfo.getId()) ||
      !(AllNative || useNativeFunc(FInfo.getName()))) {
    return false;
  }

  // sincos has no direct native form; it is split into native sin + cos.
  if (FInfo.getId() == AMDGPULibFunc::EI_SINCOS)
    return sincosUseNative(aCI, FInfo);

  // Otherwise simply retarget the call at the native_* declaration with the
  // same signature.
  FInfo.setPrefix(AMDGPULibFunc::NATIVE);
  FunctionCallee F = getFunction(aCI->getModule(), FInfo);
  if (!F)
    return false;

  aCI->setCalledFunction(F);
  DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
                                      << " with native version");
  return true;
}
488 
489 // Clang emits call of __read_pipe_2 or __read_pipe_4 for OpenCL read_pipe
490 // builtin, with appended type size and alignment arguments, where 2 or 4
491 // indicates the original number of arguments. The library has optimized version
492 // of __read_pipe_2/__read_pipe_4 when the type size and alignment has the same
493 // power of 2 value. This function transforms __read_pipe_2 to __read_pipe_2_N
494 // for such cases where N is the size in bytes of the type (N = 1, 2, 4, 8, ...,
495 // 128). The same for __read_pipe_4, write_pipe_2, and write_pipe_4.
496 bool AMDGPULibCalls::fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
497                                           const FuncInfo &FInfo) {
498   auto *Callee = CI->getCalledFunction();
499   if (!Callee->isDeclaration())
500     return false;
501 
502   assert(Callee->hasName() && "Invalid read_pipe/write_pipe function");
503   auto *M = Callee->getParent();
504   std::string Name = std::string(Callee->getName());
505   auto NumArg = CI->arg_size();
506   if (NumArg != 4 && NumArg != 6)
507     return false;
508   ConstantInt *PacketSize =
509       dyn_cast<ConstantInt>(CI->getArgOperand(NumArg - 2));
510   ConstantInt *PacketAlign =
511       dyn_cast<ConstantInt>(CI->getArgOperand(NumArg - 1));
512   if (!PacketSize || !PacketAlign)
513     return false;
514 
515   unsigned Size = PacketSize->getZExtValue();
516   Align Alignment = PacketAlign->getAlignValue();
517   if (Alignment != Size)
518     return false;
519 
520   unsigned PtrArgLoc = CI->arg_size() - 3;
521   Value *PtrArg = CI->getArgOperand(PtrArgLoc);
522   Type *PtrTy = PtrArg->getType();
523 
524   SmallVector<llvm::Type *, 6> ArgTys;
525   for (unsigned I = 0; I != PtrArgLoc; ++I)
526     ArgTys.push_back(CI->getArgOperand(I)->getType());
527   ArgTys.push_back(PtrTy);
528 
529   Name = Name + "_" + std::to_string(Size);
530   auto *FTy = FunctionType::get(Callee->getReturnType(),
531                                 ArrayRef<Type *>(ArgTys), false);
532   AMDGPULibFunc NewLibFunc(Name, FTy);
533   FunctionCallee F = AMDGPULibFunc::getOrInsertFunction(M, NewLibFunc);
534   if (!F)
535     return false;
536 
537   auto *BCast = B.CreatePointerCast(PtrArg, PtrTy);
538   SmallVector<Value *, 6> Args;
539   for (unsigned I = 0; I != PtrArgLoc; ++I)
540     Args.push_back(CI->getArgOperand(I));
541   Args.push_back(BCast);
542 
543   auto *NCI = B.CreateCall(F, Args);
544   NCI->setAttributes(CI->getAttributes());
545   CI->replaceAllUsesWith(NCI);
546   CI->dropAllReferences();
547   CI->eraseFromParent();
548 
549   return true;
550 }
551 
552 // This function returns false if no change; return true otherwise.
// Top-level dispatcher: try every applicable fold on \p CI.
// Returns false if no change; return true otherwise.
bool AMDGPULibCalls::fold(CallInst *CI) {
  Function *Callee = CI->getCalledFunction();
  // Ignore indirect calls.
  if (!Callee || Callee->isIntrinsic() || CI->isNoBuiltin())
    return false;

  FuncInfo FInfo;
  if (!parseFunctionName(Callee->getName(), FInfo))
    return false;

  // Further check the number of arguments to see if they match.
  // TODO: Check calling convention matches too
  if (!FInfo.isCompatibleSignature(CI->getFunctionType()))
    return false;

  LLVM_DEBUG(dbgs() << "AMDIC: try folding " << *CI << '\n');

  // Exact-constant folding via the {result, input} tables.
  if (TDOFold(CI, FInfo))
    return true;

  IRBuilder<> B(CI);

  if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) {
    // Under unsafe-math, evaluate calls if possible.
    // According to Brian Sumner, we can do this for all f32 function calls
    // using host's double function calls.
    if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo))
      return true;

    // Copy fast flags from the original call.
    FastMathFlags FMF = FPOp->getFastMathFlags();
    B.setFastMathFlags(FMF);

    // Specialized optimizations for each function call.
    //
    // TODO: Handle other simple intrinsic wrappers. Sqrt, ldexp log.
    //
    // TODO: Handle native functions
    switch (FInfo.getId()) {
    case AMDGPULibFunc::EI_EXP:
      // exp/exp2 only lower profitably with some fast-math flags; the
      // intrinsic is allowed in minsize functions only under afn.
      if (FMF.none())
        return false;
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::exp,
                                                  FMF.approxFunc());
    case AMDGPULibFunc::EI_EXP2:
      if (FMF.none())
        return false;
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::exp2,
                                                  FMF.approxFunc());
    // The following map directly onto LLVM intrinsics; the bool arguments are
    // (AllowMinSizeF32, AllowF64[, AllowStrictFP]).
    case AMDGPULibFunc::EI_FMIN:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::minnum, true,
                                                  true);
    case AMDGPULibFunc::EI_FMAX:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::maxnum, true,
                                                  true);
    case AMDGPULibFunc::EI_FMA:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::fma, true,
                                                  true);
    case AMDGPULibFunc::EI_MAD:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::fmuladd, true,
                                                  true);
    case AMDGPULibFunc::EI_FABS:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::fabs, true,
                                                  true, true);
    case AMDGPULibFunc::EI_COPYSIGN:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::copysign, true,
                                                  true, true);
    case AMDGPULibFunc::EI_FLOOR:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::floor, true,
                                                  true);
    case AMDGPULibFunc::EI_CEIL:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::ceil, true,
                                                  true);
    case AMDGPULibFunc::EI_TRUNC:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::trunc, true,
                                                  true);
    case AMDGPULibFunc::EI_RINT:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::rint, true,
                                                  true);
    case AMDGPULibFunc::EI_ROUND:
      return tryReplaceLibcallWithSimpleIntrinsic(CI, Intrinsic::round, true,
                                                  true);
    // Hand-written folds for the remaining interesting functions.
    case AMDGPULibFunc::EI_POW:
    case AMDGPULibFunc::EI_POWR:
    case AMDGPULibFunc::EI_POWN:
      return fold_pow(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_ROOTN:
      return fold_rootn(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_SQRT:
      return fold_sqrt(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_COS:
    case AMDGPULibFunc::EI_SIN:
      return fold_sincos(FPOp, B, FInfo);
    default:
      break;
    }
  } else {
    // Specialized optimizations for each function call
    switch (FInfo.getId()) {
    case AMDGPULibFunc::EI_READ_PIPE_2:
    case AMDGPULibFunc::EI_READ_PIPE_4:
    case AMDGPULibFunc::EI_WRITE_PIPE_2:
    case AMDGPULibFunc::EI_WRITE_PIPE_4:
      return fold_read_write_pipe(CI, B, FInfo);
    default:
      break;
    }
  }

  return false;
}
664 
665 bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
666   // Table-Driven optimization
667   const TableRef tr = getOptTable(FInfo.getId());
668   if (tr.empty())
669     return false;
670 
671   int const sz = (int)tr.size();
672   Value *opr0 = CI->getArgOperand(0);
673 
674   if (getVecSize(FInfo) > 1) {
675     if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(opr0)) {
676       SmallVector<double, 0> DVal;
677       for (int eltNo = 0; eltNo < getVecSize(FInfo); ++eltNo) {
678         ConstantFP *eltval = dyn_cast<ConstantFP>(
679                                CV->getElementAsConstant((unsigned)eltNo));
680         assert(eltval && "Non-FP arguments in math function!");
681         bool found = false;
682         for (int i=0; i < sz; ++i) {
683           if (eltval->isExactlyValue(tr[i].input)) {
684             DVal.push_back(tr[i].result);
685             found = true;
686             break;
687           }
688         }
689         if (!found) {
690           // This vector constants not handled yet.
691           return false;
692         }
693       }
694       LLVMContext &context = CI->getParent()->getParent()->getContext();
695       Constant *nval;
696       if (getArgType(FInfo) == AMDGPULibFunc::F32) {
697         SmallVector<float, 0> FVal;
698         for (unsigned i = 0; i < DVal.size(); ++i) {
699           FVal.push_back((float)DVal[i]);
700         }
701         ArrayRef<float> tmp(FVal);
702         nval = ConstantDataVector::get(context, tmp);
703       } else { // F64
704         ArrayRef<double> tmp(DVal);
705         nval = ConstantDataVector::get(context, tmp);
706       }
707       LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
708       replaceCall(CI, nval);
709       return true;
710     }
711   } else {
712     // Scalar version
713     if (ConstantFP *CF = dyn_cast<ConstantFP>(opr0)) {
714       for (int i = 0; i < sz; ++i) {
715         if (CF->isExactlyValue(tr[i].input)) {
716           Value *nval = ConstantFP::get(CF->getType(), tr[i].result);
717           LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
718           replaceCall(CI, nval);
719           return true;
720         }
721       }
722     }
723   }
724 
725   return false;
726 }
727 
namespace llvm {
// Host-side log2 used for constant folding, placed in namespace llvm so that
// unqualified log2 inside this file resolves here. <cmath> guarantees
// std::log2 since C++11, so the old feature-test-macro fallback of
// log(V) / ln2 (which could differ in the last ulp) is unnecessary.
static double log2(double V) {
  return std::log2(V);
}
}
737 
738 bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
739                               const FuncInfo &FInfo) {
740   assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
741           FInfo.getId() == AMDGPULibFunc::EI_POWR ||
742           FInfo.getId() == AMDGPULibFunc::EI_POWN) &&
743          "fold_pow: encounter a wrong function call");
744 
745   Module *M = B.GetInsertBlock()->getModule();
746   ConstantFP *CF;
747   ConstantInt *CINT;
748   Type *eltType;
749   Value *opr0 = FPOp->getOperand(0);
750   Value *opr1 = FPOp->getOperand(1);
751   ConstantAggregateZero *CZero = dyn_cast<ConstantAggregateZero>(opr1);
752 
753   if (getVecSize(FInfo) == 1) {
754     eltType = opr0->getType();
755     CF = dyn_cast<ConstantFP>(opr1);
756     CINT = dyn_cast<ConstantInt>(opr1);
757   } else {
758     VectorType *VTy = dyn_cast<VectorType>(opr0->getType());
759     assert(VTy && "Oprand of vector function should be of vectortype");
760     eltType = VTy->getElementType();
761     ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1);
762 
763     // Now, only Handle vector const whose elements have the same value.
764     CF = CDV ? dyn_cast_or_null<ConstantFP>(CDV->getSplatValue()) : nullptr;
765     CINT = CDV ? dyn_cast_or_null<ConstantInt>(CDV->getSplatValue()) : nullptr;
766   }
767 
768   // No unsafe math , no constant argument, do nothing
769   if (!isUnsafeMath(FPOp) && !CF && !CINT && !CZero)
770     return false;
771 
772   // 0x1111111 means that we don't do anything for this call.
773   int ci_opr1 = (CINT ? (int)CINT->getSExtValue() : 0x1111111);
774 
775   if ((CF && CF->isZero()) || (CINT && ci_opr1 == 0) || CZero) {
776     //  pow/powr/pown(x, 0) == 1
777     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1\n");
778     Constant *cnval = ConstantFP::get(eltType, 1.0);
779     if (getVecSize(FInfo) > 1) {
780       cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
781     }
782     replaceCall(FPOp, cnval);
783     return true;
784   }
785   if ((CF && CF->isExactlyValue(1.0)) || (CINT && ci_opr1 == 1)) {
786     // pow/powr/pown(x, 1.0) = x
787     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
788     replaceCall(FPOp, opr0);
789     return true;
790   }
791   if ((CF && CF->isExactlyValue(2.0)) || (CINT && ci_opr1 == 2)) {
792     // pow/powr/pown(x, 2.0) = x*x
793     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << " * "
794                       << *opr0 << "\n");
795     Value *nval = B.CreateFMul(opr0, opr0, "__pow2");
796     replaceCall(FPOp, nval);
797     return true;
798   }
799   if ((CF && CF->isExactlyValue(-1.0)) || (CINT && ci_opr1 == -1)) {
800     // pow/powr/pown(x, -1.0) = 1.0/x
801     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1 / " << *opr0 << "\n");
802     Constant *cnval = ConstantFP::get(eltType, 1.0);
803     if (getVecSize(FInfo) > 1) {
804       cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
805     }
806     Value *nval = B.CreateFDiv(cnval, opr0, "__powrecip");
807     replaceCall(FPOp, nval);
808     return true;
809   }
810 
811   if (CF && (CF->isExactlyValue(0.5) || CF->isExactlyValue(-0.5))) {
812     // pow[r](x, [-]0.5) = sqrt(x)
813     bool issqrt = CF->isExactlyValue(0.5);
814     if (FunctionCallee FPExpr =
815             getFunction(M, AMDGPULibFunc(issqrt ? AMDGPULibFunc::EI_SQRT
816                                                 : AMDGPULibFunc::EI_RSQRT,
817                                          FInfo))) {
818       LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << FInfo.getName()
819                         << '(' << *opr0 << ")\n");
820       Value *nval = CreateCallEx(B,FPExpr, opr0, issqrt ? "__pow2sqrt"
821                                                         : "__pow2rsqrt");
822       replaceCall(FPOp, nval);
823       return true;
824     }
825   }
826 
827   if (!isUnsafeMath(FPOp))
828     return false;
829 
830   // Unsafe Math optimization
831 
832   // Remember that ci_opr1 is set if opr1 is integral
833   if (CF) {
834     double dval = (getArgType(FInfo) == AMDGPULibFunc::F32)
835                     ? (double)CF->getValueAPF().convertToFloat()
836                     : CF->getValueAPF().convertToDouble();
837     int ival = (int)dval;
838     if ((double)ival == dval) {
839       ci_opr1 = ival;
840     } else
841       ci_opr1 = 0x11111111;
842   }
843 
844   // pow/powr/pown(x, c) = [1/](x*x*..x); where
845   //   trunc(c) == c && the number of x == c && |c| <= 12
846   unsigned abs_opr1 = (ci_opr1 < 0) ? -ci_opr1 : ci_opr1;
847   if (abs_opr1 <= 12) {
848     Constant *cnval;
849     Value *nval;
850     if (abs_opr1 == 0) {
851       cnval = ConstantFP::get(eltType, 1.0);
852       if (getVecSize(FInfo) > 1) {
853         cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
854       }
855       nval = cnval;
856     } else {
857       Value *valx2 = nullptr;
858       nval = nullptr;
859       while (abs_opr1 > 0) {
860         valx2 = valx2 ? B.CreateFMul(valx2, valx2, "__powx2") : opr0;
861         if (abs_opr1 & 1) {
862           nval = nval ? B.CreateFMul(nval, valx2, "__powprod") : valx2;
863         }
864         abs_opr1 >>= 1;
865       }
866     }
867 
868     if (ci_opr1 < 0) {
869       cnval = ConstantFP::get(eltType, 1.0);
870       if (getVecSize(FInfo) > 1) {
871         cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
872       }
873       nval = B.CreateFDiv(cnval, nval, "__1powprod");
874     }
875     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
876                       << ((ci_opr1 < 0) ? "1/prod(" : "prod(") << *opr0
877                       << ")\n");
878     replaceCall(FPOp, nval);
879     return true;
880   }
881 
882   // powr ---> exp2(y * log2(x))
883   // pown/pow ---> powr(fabs(x), y) | (x & ((int)y << 31))
884   FunctionCallee ExpExpr =
885       getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_EXP2, FInfo));
886   if (!ExpExpr)
887     return false;
888 
889   bool needlog = false;
890   bool needabs = false;
891   bool needcopysign = false;
892   Constant *cnval = nullptr;
893   if (getVecSize(FInfo) == 1) {
894     CF = dyn_cast<ConstantFP>(opr0);
895 
896     if (CF) {
897       double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
898                    ? (double)CF->getValueAPF().convertToFloat()
899                    : CF->getValueAPF().convertToDouble();
900 
901       V = log2(std::abs(V));
902       cnval = ConstantFP::get(eltType, V);
903       needcopysign = (FInfo.getId() != AMDGPULibFunc::EI_POWR) &&
904                      CF->isNegative();
905     } else {
906       needlog = true;
907       needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR &&
908                                (!CF || CF->isNegative());
909     }
910   } else {
911     ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr0);
912 
913     if (!CDV) {
914       needlog = true;
915       needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR;
916     } else {
917       assert ((int)CDV->getNumElements() == getVecSize(FInfo) &&
918               "Wrong vector size detected");
919 
920       SmallVector<double, 0> DVal;
921       for (int i=0; i < getVecSize(FInfo); ++i) {
922         double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
923                      ? (double)CDV->getElementAsFloat(i)
924                      : CDV->getElementAsDouble(i);
925         if (V < 0.0) needcopysign = true;
926         V = log2(std::abs(V));
927         DVal.push_back(V);
928       }
929       if (getArgType(FInfo) == AMDGPULibFunc::F32) {
930         SmallVector<float, 0> FVal;
931         for (unsigned i=0; i < DVal.size(); ++i) {
932           FVal.push_back((float)DVal[i]);
933         }
934         ArrayRef<float> tmp(FVal);
935         cnval = ConstantDataVector::get(M->getContext(), tmp);
936       } else {
937         ArrayRef<double> tmp(DVal);
938         cnval = ConstantDataVector::get(M->getContext(), tmp);
939       }
940     }
941   }
942 
943   if (needcopysign && (FInfo.getId() == AMDGPULibFunc::EI_POW)) {
944     // We cannot handle corner cases for a general pow() function, give up
945     // unless y is a constant integral value. Then proceed as if it were pown.
946     if (getVecSize(FInfo) == 1) {
947       if (const ConstantFP *CF = dyn_cast<ConstantFP>(opr1)) {
948         double y = (getArgType(FInfo) == AMDGPULibFunc::F32)
949                    ? (double)CF->getValueAPF().convertToFloat()
950                    : CF->getValueAPF().convertToDouble();
951         if (y != (double)(int64_t)y)
952           return false;
953       } else
954         return false;
955     } else {
956       if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1)) {
957         for (int i=0; i < getVecSize(FInfo); ++i) {
958           double y = (getArgType(FInfo) == AMDGPULibFunc::F32)
959                      ? (double)CDV->getElementAsFloat(i)
960                      : CDV->getElementAsDouble(i);
961           if (y != (double)(int64_t)y)
962             return false;
963         }
964       } else
965         return false;
966     }
967   }
968 
969   Value *nval;
970   if (needabs) {
971     nval = B.CreateUnaryIntrinsic(Intrinsic::fabs, opr0, nullptr, "__fabs");
972   } else {
973     nval = cnval ? cnval : opr0;
974   }
975   if (needlog) {
976     FunctionCallee LogExpr =
977         getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_LOG2, FInfo));
978     if (!LogExpr)
979       return false;
980     nval = CreateCallEx(B,LogExpr, nval, "__log2");
981   }
982 
983   if (FInfo.getId() == AMDGPULibFunc::EI_POWN) {
984     // convert int(32) to fp(f32 or f64)
985     opr1 = B.CreateSIToFP(opr1, nval->getType(), "pownI2F");
986   }
987   nval = B.CreateFMul(opr1, nval, "__ylogx");
988   nval = CreateCallEx(B,ExpExpr, nval, "__exp2");
989 
990   if (needcopysign) {
991     Value *opr_n;
992     Type* rTy = opr0->getType();
993     Type* nTyS = eltType->isDoubleTy() ? B.getInt64Ty() : B.getInt32Ty();
994     Type *nTy = nTyS;
995     if (const auto *vTy = dyn_cast<FixedVectorType>(rTy))
996       nTy = FixedVectorType::get(nTyS, vTy);
997     unsigned size = nTy->getScalarSizeInBits();
998     opr_n = FPOp->getOperand(1);
999     if (opr_n->getType()->isIntegerTy())
1000       opr_n = B.CreateZExtOrBitCast(opr_n, nTy, "__ytou");
1001     else
1002       opr_n = B.CreateFPToSI(opr1, nTy, "__ytou");
1003 
1004     Value *sign = B.CreateShl(opr_n, size-1, "__yeven");
1005     sign = B.CreateAnd(B.CreateBitCast(opr0, nTy), sign, "__pow_sign");
1006     nval = B.CreateOr(B.CreateBitCast(nval, nTy), sign);
1007     nval = B.CreateBitCast(nval, opr0->getType());
1008   }
1009 
1010   LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
1011                     << "exp2(" << *opr1 << " * log2(" << *opr0 << "))\n");
1012   replaceCall(FPOp, nval);
1013 
1014   return true;
1015 }
1016 
1017 bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
1018                                 const FuncInfo &FInfo) {
1019   // skip vector function
1020   if (getVecSize(FInfo) != 1)
1021     return false;
1022 
1023   Value *opr0 = FPOp->getOperand(0);
1024   Value *opr1 = FPOp->getOperand(1);
1025 
1026   ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
1027   if (!CINT) {
1028     return false;
1029   }
1030   int ci_opr1 = (int)CINT->getSExtValue();
1031   if (ci_opr1 == 1) {  // rootn(x, 1) = x
1032     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
1033     replaceCall(FPOp, opr0);
1034     return true;
1035   }
1036 
1037   Module *M = B.GetInsertBlock()->getModule();
1038   if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x)
1039     if (FunctionCallee FPExpr =
1040             getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
1041       LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0
1042                         << ")\n");
1043       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt");
1044       replaceCall(FPOp, nval);
1045       return true;
1046     }
1047   } else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
1048     if (FunctionCallee FPExpr =
1049             getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) {
1050       LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> cbrt(" << *opr0
1051                         << ")\n");
1052       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2cbrt");
1053       replaceCall(FPOp, nval);
1054       return true;
1055     }
1056   } else if (ci_opr1 == -1) { // rootn(x, -1) = 1.0/x
1057     LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1.0 / " << *opr0 << "\n");
1058     Value *nval = B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0),
1059                                opr0,
1060                                "__rootn2div");
1061     replaceCall(FPOp, nval);
1062     return true;
1063   } else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
1064     if (FunctionCallee FPExpr =
1065             getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT, FInfo))) {
1066       LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
1067                         << ")\n");
1068       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt");
1069       replaceCall(FPOp, nval);
1070       return true;
1071     }
1072   }
1073   return false;
1074 }
1075 
1076 // Get a scalar native builtin single argument FP function
1077 FunctionCallee AMDGPULibCalls::getNativeFunction(Module *M,
1078                                                  const FuncInfo &FInfo) {
1079   if (getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(FInfo.getId()))
1080     return nullptr;
1081   FuncInfo nf = FInfo;
1082   nf.setPrefix(AMDGPULibFunc::NATIVE);
1083   return getFunction(M, nf);
1084 }
1085 
1086 // Some library calls are just wrappers around llvm intrinsics, but compiled
1087 // conservatively. Preserve the flags from the original call site by
1088 // substituting them with direct calls with all the flags.
1089 bool AMDGPULibCalls::shouldReplaceLibcallWithIntrinsic(const CallInst *CI,
1090                                                        bool AllowMinSizeF32,
1091                                                        bool AllowF64,
1092                                                        bool AllowStrictFP) {
1093   Type *FltTy = CI->getType()->getScalarType();
1094   const bool IsF32 = FltTy->isFloatTy();
1095 
1096   // f64 intrinsics aren't implemented for most operations.
1097   if (!IsF32 && !FltTy->isHalfTy() && (!AllowF64 || !FltTy->isDoubleTy()))
1098     return false;
1099 
1100   // We're implicitly inlining by replacing the libcall with the intrinsic, so
1101   // don't do it for noinline call sites.
1102   if (CI->isNoInline())
1103     return false;
1104 
1105   const Function *ParentF = CI->getFunction();
1106   // TODO: Handle strictfp
1107   if (!AllowStrictFP && ParentF->hasFnAttribute(Attribute::StrictFP))
1108     return false;
1109 
1110   if (IsF32 && !AllowMinSizeF32 && ParentF->hasMinSize())
1111     return false;
1112   return true;
1113 }
1114 
1115 void AMDGPULibCalls::replaceLibCallWithSimpleIntrinsic(CallInst *CI,
1116                                                        Intrinsic::ID IntrID) {
1117   CI->setCalledFunction(
1118       Intrinsic::getDeclaration(CI->getModule(), IntrID, {CI->getType()}));
1119 }
1120 
1121 bool AMDGPULibCalls::tryReplaceLibcallWithSimpleIntrinsic(CallInst *CI,
1122                                                           Intrinsic::ID IntrID,
1123                                                           bool AllowMinSizeF32,
1124                                                           bool AllowF64,
1125                                                           bool AllowStrictFP) {
1126   if (!shouldReplaceLibcallWithIntrinsic(CI, AllowMinSizeF32, AllowF64,
1127                                          AllowStrictFP))
1128     return false;
1129   replaceLibCallWithSimpleIntrinsic(CI, IntrID);
1130   return true;
1131 }
1132 
1133 // fold sqrt -> native_sqrt (x)
1134 bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
1135                                const FuncInfo &FInfo) {
1136   if (!isUnsafeMath(FPOp))
1137     return false;
1138 
1139   if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
1140       (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
1141     Module *M = B.GetInsertBlock()->getModule();
1142 
1143     if (FunctionCallee FPExpr = getNativeFunction(
1144             M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
1145       Value *opr0 = FPOp->getOperand(0);
1146       LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
1147                         << "sqrt(" << *opr0 << ")\n");
1148       Value *nval = CreateCallEx(B,FPExpr, opr0, "__sqrt");
1149       replaceCall(FPOp, nval);
1150       return true;
1151     }
1152   }
1153   return false;
1154 }
1155 
/// Materialize a single sincos(x, &cos_out) call that can replace separate
/// sin(x) and cos(x) calls.
///
/// Returns {sin value, cos value, sincos call}: the call itself yields sin,
/// and the cos result is loaded back from the output-pointer alloca.
///
/// NOTE(review): the FMF parameter is not consulted in this body; fast-math
/// flags come from the builder state set up by the caller — confirm before
/// relying on it.
std::tuple<Value *, Value *, Value *>
AMDGPULibCalls::insertSinCos(Value *Arg, FastMathFlags FMF, IRBuilder<> &B,
                             FunctionCallee Fsincos) {
  // Remember the caller's debug location; insert-point changes below clobber
  // the builder's current location.
  DebugLoc DL = B.getCurrentDebugLocation();
  Function *F = B.GetInsertBlock()->getParent();
  // The alloca must go in the entry block, past any existing allocas.
  B.SetInsertPointPastAllocas(F);

  AllocaInst *Alloc = B.CreateAlloca(Arg->getType(), nullptr, "__sincos_");

  if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) {
    // If the argument is an instruction, it must dominate all uses so put our
    // sincos call there. Otherwise, right after the allocas works well enough
    // if it's an argument or constant.

    B.SetInsertPoint(ArgInst->getParent(), ++ArgInst->getIterator());

    // SetInsertPoint unwelcomely always tries to set the debug loc.
    B.SetCurrentDebugLocation(DL);
  }

  Type *CosPtrTy = Fsincos.getFunctionType()->getParamType(1);

  // The allocaInst allocates the memory in private address space. This need
  // to be addrspacecasted to point to the address space of cos pointer type.
  // In OpenCL 2.0 this is generic, while in 1.2 that is private.
  Value *CastAlloc = B.CreateAddrSpaceCast(Alloc, CosPtrTy);

  CallInst *SinCos = CreateCallEx2(B, Fsincos, Arg, CastAlloc);

  // TODO: Is it worth trying to preserve the location for the cos calls for the
  // load?

  // Read the cos result back out of the scratch slot sincos wrote into.
  LoadInst *LoadCos = B.CreateLoad(Alloc->getAllocatedType(), Alloc);
  return {SinCos, LoadCos, SinCos};
}
1191 
// fold sin, cos -> sincos.
//
// Given one sin(x) or cos(x) call, scan all other uses of the same argument
// in the same function; if both a sin and a cos of x exist, replace them all
// (plus any pre-existing sincos calls of x) with one shared sincos call.
// Returns true if the rewrite happened; the original call is erased, while
// the partner calls are only RAUW'd and left for later DCE.
bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
                                 const FuncInfo &fInfo) {
  assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
         fInfo.getId() == AMDGPULibFunc::EI_COS);

  // Only plain (unprefixed) f32/f64 sin/cos are merged.
  if ((getArgType(fInfo) != AMDGPULibFunc::F32 &&
       getArgType(fInfo) != AMDGPULibFunc::F64) ||
      fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
    return false;

  bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;

  Value *CArgVal = FPOp->getOperand(0);
  CallInst *CI = cast<CallInst>(FPOp);

  Function *F = B.GetInsertBlock()->getParent();
  Module *M = F->getParent();

  // Merge the sin and cos. For OpenCL 2.0, there may only be a generic pointer
  // implementation. Prefer the private form if available.
  AMDGPULibFunc SinCosLibFuncPrivate(AMDGPULibFunc::EI_SINCOS, fInfo);
  SinCosLibFuncPrivate.getLeads()[0].PtrKind =
      AMDGPULibFunc::getEPtrKindFromAddrSpace(AMDGPUAS::PRIVATE_ADDRESS);

  AMDGPULibFunc SinCosLibFuncGeneric(AMDGPULibFunc::EI_SINCOS, fInfo);
  SinCosLibFuncGeneric.getLeads()[0].PtrKind =
      AMDGPULibFunc::getEPtrKindFromAddrSpace(AMDGPUAS::FLAT_ADDRESS);

  FunctionCallee FSinCosPrivate = getFunction(M, SinCosLibFuncPrivate);
  FunctionCallee FSinCosGeneric = getFunction(M, SinCosLibFuncGeneric);
  FunctionCallee FSinCos = FSinCosPrivate ? FSinCosPrivate : FSinCosGeneric;
  if (!FSinCos)
    return false;

  SmallVector<CallInst *> SinCalls;
  SmallVector<CallInst *> CosCalls;
  SmallVector<CallInst *> SinCosCalls;
  // Mangled name of the complementary function (cos if we are sin, and vice
  // versa) — used to recognize partner calls on the same argument.
  FuncInfo PartnerInfo(isSin ? AMDGPULibFunc::EI_COS : AMDGPULibFunc::EI_SIN,
                       fInfo);
  const std::string PairName = PartnerInfo.mangle();

  StringRef SinName = isSin ? CI->getCalledFunction()->getName() : PairName;
  StringRef CosName = isSin ? PairName : CI->getCalledFunction()->getName();
  const std::string SinCosPrivateName = SinCosLibFuncPrivate.mangle();
  const std::string SinCosGenericName = SinCosLibFuncGeneric.mangle();

  // Intersect the two sets of flags.
  FastMathFlags FMF = FPOp->getFastMathFlags();
  MDNode *FPMath = CI->getMetadata(LLVMContext::MD_fpmath);

  SmallVector<DILocation *> MergeDbgLocs = {CI->getDebugLoc()};

  // Bucket every same-function, builtin call that uses the same argument into
  // sin / cos / sincos groups, intersecting FMF and merging !fpmath and debug
  // locations across everything we intend to replace.
  for (User* U : CArgVal->users()) {
    CallInst *XI = dyn_cast<CallInst>(U);
    if (!XI || XI->getFunction() != F || XI->isNoBuiltin())
      continue;

    Function *UCallee = XI->getCalledFunction();
    if (!UCallee)
      continue;

    bool Handled = true;

    if (UCallee->getName() == SinName)
      SinCalls.push_back(XI);
    else if (UCallee->getName() == CosName)
      CosCalls.push_back(XI);
    else if (UCallee->getName() == SinCosPrivateName ||
             UCallee->getName() == SinCosGenericName)
      SinCosCalls.push_back(XI);
    else
      Handled = false;

    if (Handled) {
      MergeDbgLocs.push_back(XI->getDebugLoc());
      auto *OtherOp = cast<FPMathOperator>(XI);
      FMF &= OtherOp->getFastMathFlags();
      FPMath = MDNode::getMostGenericFPMath(
          FPMath, XI->getMetadata(LLVMContext::MD_fpmath));
    }
  }

  // Merging only pays off when both a sin and a cos of the argument exist.
  if (SinCalls.empty() || CosCalls.empty())
    return false;

  B.setFastMathFlags(FMF);
  B.setDefaultFPMathTag(FPMath);
  DILocation *DbgLoc = DILocation::getMergedLocations(MergeDbgLocs);
  B.SetCurrentDebugLocation(DbgLoc);

  auto [Sin, Cos, SinCos] = insertSinCos(CArgVal, FMF, B, FSinCos);

  auto replaceTrigInsts = [](ArrayRef<CallInst *> Calls, Value *Res) {
    for (CallInst *C : Calls)
      C->replaceAllUsesWith(Res);

    // Leave the other dead instructions to avoid clobbering iterators.
  };

  replaceTrigInsts(SinCalls, Sin);
  replaceTrigInsts(CosCalls, Cos);
  replaceTrigInsts(SinCosCalls, SinCos);

  // It's safe to delete the original now.
  CI->eraseFromParent();
  return true;
}
1300 
/// Evaluate one scalar lane of a math library call at compile time.
///
/// \p copr0/copr1/copr2 are the (possibly null) constant operands; FP
/// operands are widened to double for evaluation. Results are written to
/// \p Res0 (and \p Res1 for sincos, which produces two values). Returns true
/// if the function is one this folder knows how to evaluate.
bool AMDGPULibCalls::evaluateScalarMathFunc(const FuncInfo &FInfo,
                                            double& Res0, double& Res1,
                                            Constant *copr0, Constant *copr1,
                                            Constant *copr2) {
  // By default, opr0/opr1/opr2 hold the operand values as doubles. Operands
  // that are not float/double (e.g. pown/rootn's integer exponent) are
  // handled specially in their own cases below.
  double opr0=0.0, opr1=0.0, opr2=0.0;
  ConstantFP *fpopr0 = dyn_cast_or_null<ConstantFP>(copr0);
  ConstantFP *fpopr1 = dyn_cast_or_null<ConstantFP>(copr1);
  ConstantFP *fpopr2 = dyn_cast_or_null<ConstantFP>(copr2);
  if (fpopr0) {
    opr0 = (getArgType(FInfo) == AMDGPULibFunc::F64)
             ? fpopr0->getValueAPF().convertToDouble()
             : (double)fpopr0->getValueAPF().convertToFloat();
  }

  if (fpopr1) {
    opr1 = (getArgType(FInfo) == AMDGPULibFunc::F64)
             ? fpopr1->getValueAPF().convertToDouble()
             : (double)fpopr1->getValueAPF().convertToFloat();
  }

  if (fpopr2) {
    opr2 = (getArgType(FInfo) == AMDGPULibFunc::F64)
             ? fpopr2->getValueAPF().convertToDouble()
             : (double)fpopr2->getValueAPF().convertToFloat();
  }

  switch (FInfo.getId()) {
  default : return false;

  case AMDGPULibFunc::EI_ACOS:
    Res0 = acos(opr0);
    return true;

  case AMDGPULibFunc::EI_ACOSH:
    // acosh(x) == log(x + sqrt(x*x - 1))
    Res0 = log(opr0 + sqrt(opr0*opr0 - 1.0));
    return true;

  case AMDGPULibFunc::EI_ACOSPI:
    Res0 = acos(opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_ASIN:
    Res0 = asin(opr0);
    return true;

  case AMDGPULibFunc::EI_ASINH:
    // asinh(x) == log(x + sqrt(x*x + 1))
    Res0 = log(opr0 + sqrt(opr0*opr0 + 1.0));
    return true;

  case AMDGPULibFunc::EI_ASINPI:
    Res0 = asin(opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_ATAN:
    Res0 = atan(opr0);
    return true;

  case AMDGPULibFunc::EI_ATANH:
    // atanh(x) == (log(x+1) - log(x-1))/2;
    Res0 = (log(opr0 + 1.0) - log(opr0 - 1.0))/2.0;
    return true;

  case AMDGPULibFunc::EI_ATANPI:
    Res0 = atan(opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_CBRT:
    // cbrt via pow, with the sign handled explicitly since pow of a
    // negative base with a fractional exponent is undefined.
    Res0 = (opr0 < 0.0) ? -pow(-opr0, 1.0/3.0) : pow(opr0, 1.0/3.0);
    return true;

  case AMDGPULibFunc::EI_COS:
    Res0 = cos(opr0);
    return true;

  case AMDGPULibFunc::EI_COSH:
    Res0 = cosh(opr0);
    return true;

  case AMDGPULibFunc::EI_COSPI:
    Res0 = cos(MATH_PI * opr0);
    return true;

  case AMDGPULibFunc::EI_EXP:
    Res0 = exp(opr0);
    return true;

  case AMDGPULibFunc::EI_EXP2:
    Res0 = pow(2.0, opr0);
    return true;

  case AMDGPULibFunc::EI_EXP10:
    Res0 = pow(10.0, opr0);
    return true;

  case AMDGPULibFunc::EI_LOG:
    Res0 = log(opr0);
    return true;

  case AMDGPULibFunc::EI_LOG2:
    // Computed as a ratio of natural logs rather than log2/log10.
    Res0 = log(opr0) / log(2.0);
    return true;

  case AMDGPULibFunc::EI_LOG10:
    Res0 = log(opr0) / log(10.0);
    return true;

  case AMDGPULibFunc::EI_RSQRT:
    Res0 = 1.0 / sqrt(opr0);
    return true;

  case AMDGPULibFunc::EI_SIN:
    Res0 = sin(opr0);
    return true;

  case AMDGPULibFunc::EI_SINH:
    Res0 = sinh(opr0);
    return true;

  case AMDGPULibFunc::EI_SINPI:
    Res0 = sin(MATH_PI * opr0);
    return true;

  case AMDGPULibFunc::EI_SQRT:
    Res0 = sqrt(opr0);
    return true;

  case AMDGPULibFunc::EI_TAN:
    Res0 = tan(opr0);
    return true;

  case AMDGPULibFunc::EI_TANH:
    Res0 = tanh(opr0);
    return true;

  case AMDGPULibFunc::EI_TANPI:
    Res0 = tan(MATH_PI * opr0);
    return true;

  case AMDGPULibFunc::EI_RECIP:
    Res0 = 1.0 / opr0;
    return true;

  // two-arg functions
  case AMDGPULibFunc::EI_DIVIDE:
    Res0 = opr0 / opr1;
    return true;

  case AMDGPULibFunc::EI_POW:
  case AMDGPULibFunc::EI_POWR:
    Res0 = pow(opr0, opr1);
    return true;

  case AMDGPULibFunc::EI_POWN: {
    // Second operand is an integer, not an FP constant.
    if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(copr1)) {
      double val = (double)iopr1->getSExtValue();
      Res0 = pow(opr0, val);
      return true;
    }
    return false;
  }

  case AMDGPULibFunc::EI_ROOTN: {
    // Second operand is an integer, not an FP constant.
    if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(copr1)) {
      double val = (double)iopr1->getSExtValue();
      Res0 = pow(opr0, 1.0 / val);
      return true;
    }
    return false;
  }

  // with ptr arg
  case AMDGPULibFunc::EI_SINCOS:
    // Two results: sin in Res0, cos in Res1; the caller stores Res1 through
    // the pointer operand.
    Res0 = sin(opr0);
    Res1 = cos(opr0);
    return true;

  // three-arg functions
  case AMDGPULibFunc::EI_FMA:
  case AMDGPULibFunc::EI_MAD:
    Res0 = opr0 * opr1 + opr2;
    return true;
  }

  return false;
}
1491 
1492 bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) {
1493   int numArgs = (int)aCI->arg_size();
1494   if (numArgs > 3)
1495     return false;
1496 
1497   Constant *copr0 = nullptr;
1498   Constant *copr1 = nullptr;
1499   Constant *copr2 = nullptr;
1500   if (numArgs > 0) {
1501     if ((copr0 = dyn_cast<Constant>(aCI->getArgOperand(0))) == nullptr)
1502       return false;
1503   }
1504 
1505   if (numArgs > 1) {
1506     if ((copr1 = dyn_cast<Constant>(aCI->getArgOperand(1))) == nullptr) {
1507       if (FInfo.getId() != AMDGPULibFunc::EI_SINCOS)
1508         return false;
1509     }
1510   }
1511 
1512   if (numArgs > 2) {
1513     if ((copr2 = dyn_cast<Constant>(aCI->getArgOperand(2))) == nullptr)
1514       return false;
1515   }
1516 
1517   // At this point, all arguments to aCI are constants.
1518 
1519   // max vector size is 16, and sincos will generate two results.
1520   double DVal0[16], DVal1[16];
1521   int FuncVecSize = getVecSize(FInfo);
1522   bool hasTwoResults = (FInfo.getId() == AMDGPULibFunc::EI_SINCOS);
1523   if (FuncVecSize == 1) {
1524     if (!evaluateScalarMathFunc(FInfo, DVal0[0],
1525                                 DVal1[0], copr0, copr1, copr2)) {
1526       return false;
1527     }
1528   } else {
1529     ConstantDataVector *CDV0 = dyn_cast_or_null<ConstantDataVector>(copr0);
1530     ConstantDataVector *CDV1 = dyn_cast_or_null<ConstantDataVector>(copr1);
1531     ConstantDataVector *CDV2 = dyn_cast_or_null<ConstantDataVector>(copr2);
1532     for (int i = 0; i < FuncVecSize; ++i) {
1533       Constant *celt0 = CDV0 ? CDV0->getElementAsConstant(i) : nullptr;
1534       Constant *celt1 = CDV1 ? CDV1->getElementAsConstant(i) : nullptr;
1535       Constant *celt2 = CDV2 ? CDV2->getElementAsConstant(i) : nullptr;
1536       if (!evaluateScalarMathFunc(FInfo, DVal0[i],
1537                                   DVal1[i], celt0, celt1, celt2)) {
1538         return false;
1539       }
1540     }
1541   }
1542 
1543   LLVMContext &context = aCI->getContext();
1544   Constant *nval0, *nval1;
1545   if (FuncVecSize == 1) {
1546     nval0 = ConstantFP::get(aCI->getType(), DVal0[0]);
1547     if (hasTwoResults)
1548       nval1 = ConstantFP::get(aCI->getType(), DVal1[0]);
1549   } else {
1550     if (getArgType(FInfo) == AMDGPULibFunc::F32) {
1551       SmallVector <float, 0> FVal0, FVal1;
1552       for (int i = 0; i < FuncVecSize; ++i)
1553         FVal0.push_back((float)DVal0[i]);
1554       ArrayRef<float> tmp0(FVal0);
1555       nval0 = ConstantDataVector::get(context, tmp0);
1556       if (hasTwoResults) {
1557         for (int i = 0; i < FuncVecSize; ++i)
1558           FVal1.push_back((float)DVal1[i]);
1559         ArrayRef<float> tmp1(FVal1);
1560         nval1 = ConstantDataVector::get(context, tmp1);
1561       }
1562     } else {
1563       ArrayRef<double> tmp0(DVal0);
1564       nval0 = ConstantDataVector::get(context, tmp0);
1565       if (hasTwoResults) {
1566         ArrayRef<double> tmp1(DVal1);
1567         nval1 = ConstantDataVector::get(context, tmp1);
1568       }
1569     }
1570   }
1571 
1572   if (hasTwoResults) {
1573     // sincos
1574     assert(FInfo.getId() == AMDGPULibFunc::EI_SINCOS &&
1575            "math function with ptr arg not supported yet");
1576     new StoreInst(nval1, aCI->getArgOperand(1), aCI);
1577   }
1578 
1579   replaceCall(aCI, nval0);
1580   return true;
1581 }
1582 
1583 PreservedAnalyses AMDGPUSimplifyLibCallsPass::run(Function &F,
1584                                                   FunctionAnalysisManager &AM) {
1585   AMDGPULibCalls Simplifier;
1586   Simplifier.initNativeFuncs();
1587   Simplifier.initFunction(F);
1588 
1589   bool Changed = false;
1590 
1591   LLVM_DEBUG(dbgs() << "AMDIC: process function ";
1592              F.printAsOperand(dbgs(), false, F.getParent()); dbgs() << '\n';);
1593 
1594   for (auto &BB : F) {
1595     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) {
1596       // Ignore non-calls.
1597       CallInst *CI = dyn_cast<CallInst>(I);
1598       ++I;
1599 
1600       if (CI) {
1601         if (Simplifier.fold(CI))
1602           Changed = true;
1603       }
1604     }
1605   }
1606   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
1607 }
1608 
1609 PreservedAnalyses AMDGPUUseNativeCallsPass::run(Function &F,
1610                                                 FunctionAnalysisManager &AM) {
1611   if (UseNative.empty())
1612     return PreservedAnalyses::all();
1613 
1614   AMDGPULibCalls Simplifier;
1615   Simplifier.initNativeFuncs();
1616   Simplifier.initFunction(F);
1617 
1618   bool Changed = false;
1619   for (auto &BB : F) {
1620     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) {
1621       // Ignore non-calls.
1622       CallInst *CI = dyn_cast<CallInst>(I);
1623       ++I;
1624       if (CI && Simplifier.useNative(CI))
1625         Changed = true;
1626     }
1627   }
1628   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
1629 }
1630