//===- AMDGPULibCalls.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This file does AMD library function optimizations.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPULibFunc.h"
#include "GCNSubtarget.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/InitializePasses.h"
#include <cmath>

#define DEBUG_TYPE "amdgpu-simplifylib"

using namespace llvm;

// When set, getFunction() may insert new library-function prototypes into the
// module (pre-link mode); otherwise only already-present definitions are used.
static cl::opt<bool> EnablePreLink("amdgpu-prelink",
  cl::desc("Enable pre-link mode optimizations"),
  cl::init(false),
  cl::Hidden);

// -amdgpu-use-native: list of library functions to redirect to their native_*
// counterparts; "all" (or a single empty value) enables every eligible one.
static cl::list<std::string> UseNative("amdgpu-use-native",
  cl::desc("Comma separated list of functions to replace with native, or all"),
  cl::CommaSeparated, cl::ValueOptional,
  cl::Hidden);

#define MATH_PI      numbers::pi
#define MATH_E       numbers::e
#define MATH_SQRT2   numbers::sqrt2
#define MATH_SQRT1_2 numbers::inv_sqrt2

namespace llvm {

/// Folds and simplifies calls to the AMD device math library (table-driven
/// constant folding, pow/rootn/sqrt/sincos rewrites, native_* substitution,
/// and __read_pipe/__write_pipe specialization).
class AMDGPULibCalls {
private:

  typedef llvm::AMDGPULibFunc FuncInfo;

  // Cached from the current function's "unsafe-fp-math" attribute by
  // initFunction(); consulted alongside per-instruction fast-math flags.
  bool UnsafeFPMath = false;

  // -fuse-native.
  bool AllNative = false;

  bool useNativeFunc(const StringRef F) const;

  // Return a pointer (pointer expr) to the function if function definition with
  // "FuncName" exists. It may create a new function prototype in pre-link mode.
  FunctionCallee getFunction(Module *M, const FuncInfo &fInfo);

  bool parseFunctionName(const StringRef &FMangledName, FuncInfo &FInfo);

  bool TDOFold(CallInst *CI, const FuncInfo &FInfo);

  /* Specialized optimizations */

  // pow/powr/pown
  bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // rootn
  bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // -fuse-native for sincos
  bool sincosUseNative(CallInst *aCI, const FuncInfo &FInfo);

  // evaluate calls if calls' arguments are constants.
  bool evaluateScalarMathFunc(const FuncInfo &FInfo, double& Res0,
    double& Res1, Constant *copr0, Constant *copr1, Constant *copr2);
  bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);

  // sqrt
  bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  /// Insert a value to sincos function \p Fsincos. Returns (value of sin, value
  /// of cos, sincos call).
  std::tuple<Value *, Value *, Value *> insertSinCos(Value *Arg,
                                                     FastMathFlags FMF,
                                                     IRBuilder<> &B,
                                                     FunctionCallee Fsincos);

  // sin/cos
  bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // __read_pipe/__write_pipe
  bool fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
                            const FuncInfo &FInfo);

  // Get a scalar native builtin single argument FP function
  FunctionCallee getNativeFunction(Module *M, const FuncInfo &FInfo);

protected:
  bool isUnsafeMath(const FPMathOperator *FPOp) const;

  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;

  // Replace all uses of \p I with \p With and delete \p I.
  static void replaceCall(Instruction *I, Value *With) {
    I->replaceAllUsesWith(With);
    I->eraseFromParent();
  }

  static void replaceCall(FPMathOperator *I, Value *With) {
    replaceCall(cast<Instruction>(I), With);
  }

public:
  AMDGPULibCalls() {}

  bool fold(CallInst *CI);

  void initFunction(const Function &F);
  void initNativeFuncs();

  // Replace a normal math function call with that native version
  bool useNative(CallInst *CI);
};

} // end llvm namespace

// Create a one-argument call, propagating the callee's calling convention
// onto the new call site (library functions may use a non-default CC).
template <typename IRB>
static CallInst *CreateCallEx(IRB &B, FunctionCallee Callee, Value *Arg,
                              const Twine &Name = "") {
  CallInst *R = B.CreateCall(Callee, Arg, Name);
  if (Function *F = dyn_cast<Function>(Callee.getCallee()))
    R->setCallingConv(F->getCallingConv());
  return R;
}

// Two-argument variant of CreateCallEx.
template <typename IRB>
static CallInst *CreateCallEx2(IRB &B, FunctionCallee Callee, Value *Arg1,
                               Value *Arg2, const Twine &Name = "") {
  CallInst *R = B.CreateCall(Callee, {Arg1, Arg2}, Name);
  if (Function *F = dyn_cast<Function>(Callee.getCallee()))
    R->setCallingConv(F->getCallingConv());
  return R;
}

// Data structures for table-driven optimizations.
// FuncTbl works for both f32 and f64 functions with 1 input argument

struct TableEntry {
  double result;
  double input;
};

/* a list of {result, input} */
static const TableEntry tbl_acos[] = {
  {MATH_PI / 2.0, 0.0},
  {MATH_PI / 2.0, -0.0},
  {0.0, 1.0},
  {MATH_PI, -1.0}
};
static const TableEntry tbl_acosh[] = {
  {0.0, 1.0}
};
static const TableEntry tbl_acospi[] = {
  {0.5, 0.0},
  {0.5, -0.0},
  {0.0, 1.0},
  {1.0, -1.0}
};
static const TableEntry tbl_asin[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 2.0, 1.0},
  {-MATH_PI / 2.0, -1.0}
};
static const TableEntry tbl_asinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_asinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.5, 1.0},
  {-0.5, -1.0}
};
static const TableEntry tbl_atan[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 4.0, 1.0},
  {-MATH_PI / 4.0, -1.0}
};
static const TableEntry tbl_atanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_atanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.25, 1.0},
  {-0.25, -1.0}
};
static const TableEntry tbl_cbrt[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {1.0, 1.0},
  {-1.0, -1.0},
};
static const TableEntry tbl_cos[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cosh[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cospi[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erfc[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erf[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_exp[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {MATH_E, 1.0}
};
static const TableEntry tbl_exp2[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {2.0, 1.0}
};
static const TableEntry tbl_exp10[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {10.0, 1.0}
};
static const TableEntry tbl_expm1[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_log[] = {
  {0.0, 1.0},
  {1.0, MATH_E}
};
static const TableEntry tbl_log2[] = {
  {0.0, 1.0},
  {1.0, 2.0}
};
static const TableEntry tbl_log10[] = {
  {0.0, 1.0},
  {1.0, 10.0}
};
static const TableEntry tbl_rsqrt[] = {
  {1.0, 1.0},
  {MATH_SQRT1_2, 2.0}
};
static const TableEntry tbl_sin[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sqrt[] = {
  {0.0, 0.0},
  {1.0, 1.0},
  {MATH_SQRT2, 2.0}
};
static const TableEntry tbl_tan[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tgamma[] = {
  {1.0, 1.0},
  {1.0, 2.0},
  {2.0, 3.0},
  {6.0, 4.0}
};

// True if the device library provides a native_* variant for \p id (the ids
// that useNative()/getNativeFunction() may redirect to).
static bool HasNative(AMDGPULibFunc::EFuncId id) {
  switch(id) {
  case AMDGPULibFunc::EI_DIVIDE:
  case AMDGPULibFunc::EI_COS:
  case AMDGPULibFunc::EI_EXP:
  case AMDGPULibFunc::EI_EXP2:
  case AMDGPULibFunc::EI_EXP10:
  case AMDGPULibFunc::EI_LOG:
  case AMDGPULibFunc::EI_LOG2:
  case AMDGPULibFunc::EI_LOG10:
  case AMDGPULibFunc::EI_POWR:
  case AMDGPULibFunc::EI_RECIP:
  case AMDGPULibFunc::EI_RSQRT:
  case AMDGPULibFunc::EI_SIN:
  case AMDGPULibFunc::EI_SINCOS:
  case AMDGPULibFunc::EI_SQRT:
  case AMDGPULibFunc::EI_TAN:
    return true;
  default:;
  }
  return false;
}

using TableRef = ArrayRef<TableEntry>;

// Map a function id to its {result, input} constant-fold table; native_*
// variants (EI_N*) share the table of the plain function. Empty ref if the
// function has no table.
static TableRef getOptTable(AMDGPULibFunc::EFuncId id) {
  switch(id) {
  case AMDGPULibFunc::EI_ACOS:    return TableRef(tbl_acos);
  case AMDGPULibFunc::EI_ACOSH:   return TableRef(tbl_acosh);
  case AMDGPULibFunc::EI_ACOSPI:  return TableRef(tbl_acospi);
  case AMDGPULibFunc::EI_ASIN:    return TableRef(tbl_asin);
  case AMDGPULibFunc::EI_ASINH:   return TableRef(tbl_asinh);
  case AMDGPULibFunc::EI_ASINPI:  return TableRef(tbl_asinpi);
  case AMDGPULibFunc::EI_ATAN:    return TableRef(tbl_atan);
  case AMDGPULibFunc::EI_ATANH:   return TableRef(tbl_atanh);
  case AMDGPULibFunc::EI_ATANPI:  return TableRef(tbl_atanpi);
  case AMDGPULibFunc::EI_CBRT:    return TableRef(tbl_cbrt);
  case AMDGPULibFunc::EI_NCOS:
  case AMDGPULibFunc::EI_COS:     return TableRef(tbl_cos);
  case AMDGPULibFunc::EI_COSH:    return TableRef(tbl_cosh);
  case AMDGPULibFunc::EI_COSPI:   return TableRef(tbl_cospi);
  case AMDGPULibFunc::EI_ERFC:    return TableRef(tbl_erfc);
  case AMDGPULibFunc::EI_ERF:     return TableRef(tbl_erf);
  case AMDGPULibFunc::EI_EXP:     return TableRef(tbl_exp);
  case AMDGPULibFunc::EI_NEXP2:
  case AMDGPULibFunc::EI_EXP2:    return TableRef(tbl_exp2);
  case AMDGPULibFunc::EI_EXP10:   return TableRef(tbl_exp10);
  case AMDGPULibFunc::EI_EXPM1:   return TableRef(tbl_expm1);
  case AMDGPULibFunc::EI_LOG:     return TableRef(tbl_log);
  case AMDGPULibFunc::EI_NLOG2:
  case AMDGPULibFunc::EI_LOG2:    return TableRef(tbl_log2);
  case AMDGPULibFunc::EI_LOG10:   return TableRef(tbl_log10);
  case AMDGPULibFunc::EI_NRSQRT:
  case AMDGPULibFunc::EI_RSQRT:   return TableRef(tbl_rsqrt);
  case AMDGPULibFunc::EI_NSIN:
  case AMDGPULibFunc::EI_SIN:     return TableRef(tbl_sin);
  case AMDGPULibFunc::EI_SINH:    return TableRef(tbl_sinh);
  case AMDGPULibFunc::EI_SINPI:   return TableRef(tbl_sinpi);
  case AMDGPULibFunc::EI_NSQRT:
  case AMDGPULibFunc::EI_SQRT:    return TableRef(tbl_sqrt);
  case AMDGPULibFunc::EI_TAN:     return TableRef(tbl_tan);
  case AMDGPULibFunc::EI_TANH:    return TableRef(tbl_tanh);
  case AMDGPULibFunc::EI_TANPI:   return TableRef(tbl_tanpi);
  case AMDGPULibFunc::EI_TGAMMA:  return TableRef(tbl_tgamma);
  default:;
  }
  return TableRef();
}

// Vector width of the function's leading argument (1 for scalar).
static inline int getVecSize(const AMDGPULibFunc& FInfo) {
  return FInfo.getLeads()[0].VectorSize;
}

// Element type (F32/F64/...) of the function's leading argument.
static inline AMDGPULibFunc::EType getArgType(const AMDGPULibFunc& FInfo) {
  return (AMDGPULibFunc::EType)FInfo.getLeads()[0].ArgType;
}

FunctionCallee AMDGPULibCalls::getFunction(Module *M, const FuncInfo &fInfo) {
  // If we are doing PreLinkOpt, the function is external. So it is safe to
  // use getOrInsertFunction() at this stage.

  return EnablePreLink ? AMDGPULibFunc::getOrInsertFunction(M, fInfo)
                       : AMDGPULibFunc::getFunction(M, fInfo);
}

bool AMDGPULibCalls::parseFunctionName(const StringRef &FMangledName,
                                       FuncInfo &FInfo) {
  return AMDGPULibFunc::parse(FMangledName, FInfo);
}

// Unsafe if the whole function opted in via attribute, or this particular
// operation carries the full 'fast' flag set.
bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const {
  return UnsafeFPMath || FPOp->isFast();
}

bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
    const FPMathOperator *FPOp) const {
  // TODO: Refine to approxFunc or contract
  return isUnsafeMath(FPOp);
}

void AMDGPULibCalls::initFunction(const Function &F) {
  UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
}

bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
  return AllNative || llvm::is_contained(UseNative, F);
}

void AMDGPULibCalls::initNativeFuncs() {
  // "all", or the flag given with a single empty value (bare
  // -amdgpu-use-native), turns on native substitution for everything.
  AllNative = useNativeFunc("all") ||
              (UseNative.getNumOccurrences() && UseNative.size() == 1 &&
               UseNative.begin()->empty());
}

// Split a sincos(x, cosptr) call into native_sin(x) + native_cos(x) when both
// are requested via -amdgpu-use-native; the cos result is stored through the
// original output pointer and the call's value is replaced by the sin result.
bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
  bool native_sin = useNativeFunc("sin");
  bool native_cos = useNativeFunc("cos");

  if (native_sin && native_cos) {
    Module *M = aCI->getModule();
    Value *opr0 = aCI->getArgOperand(0);

    AMDGPULibFunc nf;
    // Carry over the argument type/width from the sincos signature.
    nf.getLeads()[0].ArgType = FInfo.getLeads()[0].ArgType;
    nf.getLeads()[0].VectorSize = FInfo.getLeads()[0].VectorSize;

    nf.setPrefix(AMDGPULibFunc::NATIVE);
    nf.setId(AMDGPULibFunc::EI_SIN);
    FunctionCallee sinExpr = getFunction(M, nf);

    nf.setPrefix(AMDGPULibFunc::NATIVE);
    nf.setId(AMDGPULibFunc::EI_COS);
    FunctionCallee cosExpr = getFunction(M, nf);
    if (sinExpr && cosExpr) {
      Value *sinval = CallInst::Create(sinExpr, opr0, "splitsin", aCI);
      Value *cosval = CallInst::Create(cosExpr, opr0, "splitcos", aCI);
      new StoreInst(cosval, aCI->getArgOperand(1), aCI);

      DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
                                          << " with native version of sin/cos");

      replaceCall(aCI, sinval);
      return true;
    }
  }
  return false;
}

bool AMDGPULibCalls::useNative(CallInst *aCI) {
  Function *Callee = aCI->getCalledFunction();
  if (!Callee || aCI->isNoBuiltin())
    return false;

  FuncInfo FInfo;
  // Native substitution is limited to mangled, unprefixed, non-f64 library
  // calls that actually have a native_* counterpart and were requested.
  if (!parseFunctionName(Callee->getName(), FInfo) || !FInfo.isMangled() ||
      FInfo.getPrefix() != AMDGPULibFunc::NOPFX ||
      getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(FInfo.getId()) ||
      !(AllNative || useNativeFunc(FInfo.getName()))) {
    return false;
  }

  if (FInfo.getId() == AMDGPULibFunc::EI_SINCOS)
    return sincosUseNative(aCI, FInfo);

  FInfo.setPrefix(AMDGPULibFunc::NATIVE);
  FunctionCallee F = getFunction(aCI->getModule(), FInfo);
  if (!F)
    return false;

  aCI->setCalledFunction(F);
  DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
                                      << " with native version");
  return true;
}

// Clang emits call of __read_pipe_2 or __read_pipe_4 for OpenCL read_pipe
// builtin, with appended type size and alignment arguments, where 2 or 4
// indicates the original number of arguments. The library has optimized version
// of __read_pipe_2/__read_pipe_4 when the type size and alignment has the same
// power of 2 value. This function transforms __read_pipe_2 to __read_pipe_2_N
// for such cases where N is the size in bytes of the type (N = 1, 2, 4, 8, ...,
// 128). The same for __read_pipe_4, write_pipe_2, and write_pipe_4.
483 bool AMDGPULibCalls::fold_read_write_pipe(CallInst *CI, IRBuilder<> &B, 484 const FuncInfo &FInfo) { 485 auto *Callee = CI->getCalledFunction(); 486 if (!Callee->isDeclaration()) 487 return false; 488 489 assert(Callee->hasName() && "Invalid read_pipe/write_pipe function"); 490 auto *M = Callee->getParent(); 491 std::string Name = std::string(Callee->getName()); 492 auto NumArg = CI->arg_size(); 493 if (NumArg != 4 && NumArg != 6) 494 return false; 495 ConstantInt *PacketSize = 496 dyn_cast<ConstantInt>(CI->getArgOperand(NumArg - 2)); 497 ConstantInt *PacketAlign = 498 dyn_cast<ConstantInt>(CI->getArgOperand(NumArg - 1)); 499 if (!PacketSize || !PacketAlign) 500 return false; 501 502 unsigned Size = PacketSize->getZExtValue(); 503 Align Alignment = PacketAlign->getAlignValue(); 504 if (Alignment != Size) 505 return false; 506 507 unsigned PtrArgLoc = CI->arg_size() - 3; 508 Value *PtrArg = CI->getArgOperand(PtrArgLoc); 509 Type *PtrTy = PtrArg->getType(); 510 511 SmallVector<llvm::Type *, 6> ArgTys; 512 for (unsigned I = 0; I != PtrArgLoc; ++I) 513 ArgTys.push_back(CI->getArgOperand(I)->getType()); 514 ArgTys.push_back(PtrTy); 515 516 Name = Name + "_" + std::to_string(Size); 517 auto *FTy = FunctionType::get(Callee->getReturnType(), 518 ArrayRef<Type *>(ArgTys), false); 519 AMDGPULibFunc NewLibFunc(Name, FTy); 520 FunctionCallee F = AMDGPULibFunc::getOrInsertFunction(M, NewLibFunc); 521 if (!F) 522 return false; 523 524 auto *BCast = B.CreatePointerCast(PtrArg, PtrTy); 525 SmallVector<Value *, 6> Args; 526 for (unsigned I = 0; I != PtrArgLoc; ++I) 527 Args.push_back(CI->getArgOperand(I)); 528 Args.push_back(BCast); 529 530 auto *NCI = B.CreateCall(F, Args); 531 NCI->setAttributes(CI->getAttributes()); 532 CI->replaceAllUsesWith(NCI); 533 CI->dropAllReferences(); 534 CI->eraseFromParent(); 535 536 return true; 537 } 538 539 // This function returns false if no change; return true otherwise. 
540 bool AMDGPULibCalls::fold(CallInst *CI) { 541 Function *Callee = CI->getCalledFunction(); 542 // Ignore indirect calls. 543 if (!Callee || Callee->isIntrinsic() || CI->isNoBuiltin()) 544 return false; 545 546 FuncInfo FInfo; 547 if (!parseFunctionName(Callee->getName(), FInfo)) 548 return false; 549 550 // Further check the number of arguments to see if they match. 551 // TODO: Check calling convention matches too 552 if (!FInfo.isCompatibleSignature(CI->getFunctionType())) 553 return false; 554 555 LLVM_DEBUG(dbgs() << "AMDIC: try folding " << *CI << '\n'); 556 557 if (TDOFold(CI, FInfo)) 558 return true; 559 560 IRBuilder<> B(CI); 561 562 if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) { 563 // Under unsafe-math, evaluate calls if possible. 564 // According to Brian Sumner, we can do this for all f32 function calls 565 // using host's double function calls. 566 if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo)) 567 return true; 568 569 // Copy fast flags from the original call. 
570 B.setFastMathFlags(FPOp->getFastMathFlags()); 571 572 // Specialized optimizations for each function call 573 switch (FInfo.getId()) { 574 case AMDGPULibFunc::EI_POW: 575 case AMDGPULibFunc::EI_POWR: 576 case AMDGPULibFunc::EI_POWN: 577 return fold_pow(FPOp, B, FInfo); 578 case AMDGPULibFunc::EI_ROOTN: 579 return fold_rootn(FPOp, B, FInfo); 580 case AMDGPULibFunc::EI_SQRT: 581 return fold_sqrt(FPOp, B, FInfo); 582 case AMDGPULibFunc::EI_COS: 583 case AMDGPULibFunc::EI_SIN: 584 return fold_sincos(FPOp, B, FInfo); 585 default: 586 break; 587 } 588 } else { 589 // Specialized optimizations for each function call 590 switch (FInfo.getId()) { 591 case AMDGPULibFunc::EI_READ_PIPE_2: 592 case AMDGPULibFunc::EI_READ_PIPE_4: 593 case AMDGPULibFunc::EI_WRITE_PIPE_2: 594 case AMDGPULibFunc::EI_WRITE_PIPE_4: 595 return fold_read_write_pipe(CI, B, FInfo); 596 default: 597 break; 598 } 599 } 600 601 return false; 602 } 603 604 bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) { 605 // Table-Driven optimization 606 const TableRef tr = getOptTable(FInfo.getId()); 607 if (tr.empty()) 608 return false; 609 610 int const sz = (int)tr.size(); 611 Value *opr0 = CI->getArgOperand(0); 612 613 if (getVecSize(FInfo) > 1) { 614 if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(opr0)) { 615 SmallVector<double, 0> DVal; 616 for (int eltNo = 0; eltNo < getVecSize(FInfo); ++eltNo) { 617 ConstantFP *eltval = dyn_cast<ConstantFP>( 618 CV->getElementAsConstant((unsigned)eltNo)); 619 assert(eltval && "Non-FP arguments in math function!"); 620 bool found = false; 621 for (int i=0; i < sz; ++i) { 622 if (eltval->isExactlyValue(tr[i].input)) { 623 DVal.push_back(tr[i].result); 624 found = true; 625 break; 626 } 627 } 628 if (!found) { 629 // This vector constants not handled yet. 
630 return false; 631 } 632 } 633 LLVMContext &context = CI->getParent()->getParent()->getContext(); 634 Constant *nval; 635 if (getArgType(FInfo) == AMDGPULibFunc::F32) { 636 SmallVector<float, 0> FVal; 637 for (unsigned i = 0; i < DVal.size(); ++i) { 638 FVal.push_back((float)DVal[i]); 639 } 640 ArrayRef<float> tmp(FVal); 641 nval = ConstantDataVector::get(context, tmp); 642 } else { // F64 643 ArrayRef<double> tmp(DVal); 644 nval = ConstantDataVector::get(context, tmp); 645 } 646 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n"); 647 replaceCall(CI, nval); 648 return true; 649 } 650 } else { 651 // Scalar version 652 if (ConstantFP *CF = dyn_cast<ConstantFP>(opr0)) { 653 for (int i = 0; i < sz; ++i) { 654 if (CF->isExactlyValue(tr[i].input)) { 655 Value *nval = ConstantFP::get(CF->getType(), tr[i].result); 656 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n"); 657 replaceCall(CI, nval); 658 return true; 659 } 660 } 661 } 662 } 663 664 return false; 665 } 666 667 namespace llvm { 668 static double log2(double V) { 669 #if _XOPEN_SOURCE >= 600 || defined(_ISOC99_SOURCE) || _POSIX_C_SOURCE >= 200112L 670 return ::log2(V); 671 #else 672 return log(V) / numbers::ln2; 673 #endif 674 } 675 } 676 677 bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, 678 const FuncInfo &FInfo) { 679 assert((FInfo.getId() == AMDGPULibFunc::EI_POW || 680 FInfo.getId() == AMDGPULibFunc::EI_POWR || 681 FInfo.getId() == AMDGPULibFunc::EI_POWN) && 682 "fold_pow: encounter a wrong function call"); 683 684 Module *M = B.GetInsertBlock()->getModule(); 685 ConstantFP *CF; 686 ConstantInt *CINT; 687 Type *eltType; 688 Value *opr0 = FPOp->getOperand(0); 689 Value *opr1 = FPOp->getOperand(1); 690 ConstantAggregateZero *CZero = dyn_cast<ConstantAggregateZero>(opr1); 691 692 if (getVecSize(FInfo) == 1) { 693 eltType = opr0->getType(); 694 CF = dyn_cast<ConstantFP>(opr1); 695 CINT = dyn_cast<ConstantInt>(opr1); 696 } else { 697 VectorType *VTy = 
dyn_cast<VectorType>(opr0->getType()); 698 assert(VTy && "Oprand of vector function should be of vectortype"); 699 eltType = VTy->getElementType(); 700 ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1); 701 702 // Now, only Handle vector const whose elements have the same value. 703 CF = CDV ? dyn_cast_or_null<ConstantFP>(CDV->getSplatValue()) : nullptr; 704 CINT = CDV ? dyn_cast_or_null<ConstantInt>(CDV->getSplatValue()) : nullptr; 705 } 706 707 // No unsafe math , no constant argument, do nothing 708 if (!isUnsafeMath(FPOp) && !CF && !CINT && !CZero) 709 return false; 710 711 // 0x1111111 means that we don't do anything for this call. 712 int ci_opr1 = (CINT ? (int)CINT->getSExtValue() : 0x1111111); 713 714 if ((CF && CF->isZero()) || (CINT && ci_opr1 == 0) || CZero) { 715 // pow/powr/pown(x, 0) == 1 716 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1\n"); 717 Constant *cnval = ConstantFP::get(eltType, 1.0); 718 if (getVecSize(FInfo) > 1) { 719 cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval); 720 } 721 replaceCall(FPOp, cnval); 722 return true; 723 } 724 if ((CF && CF->isExactlyValue(1.0)) || (CINT && ci_opr1 == 1)) { 725 // pow/powr/pown(x, 1.0) = x 726 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n"); 727 replaceCall(FPOp, opr0); 728 return true; 729 } 730 if ((CF && CF->isExactlyValue(2.0)) || (CINT && ci_opr1 == 2)) { 731 // pow/powr/pown(x, 2.0) = x*x 732 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << " * " 733 << *opr0 << "\n"); 734 Value *nval = B.CreateFMul(opr0, opr0, "__pow2"); 735 replaceCall(FPOp, nval); 736 return true; 737 } 738 if ((CF && CF->isExactlyValue(-1.0)) || (CINT && ci_opr1 == -1)) { 739 // pow/powr/pown(x, -1.0) = 1.0/x 740 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1 / " << *opr0 << "\n"); 741 Constant *cnval = ConstantFP::get(eltType, 1.0); 742 if (getVecSize(FInfo) > 1) { 743 cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval); 744 } 745 Value *nval 
= B.CreateFDiv(cnval, opr0, "__powrecip"); 746 replaceCall(FPOp, nval); 747 return true; 748 } 749 750 if (CF && (CF->isExactlyValue(0.5) || CF->isExactlyValue(-0.5))) { 751 // pow[r](x, [-]0.5) = sqrt(x) 752 bool issqrt = CF->isExactlyValue(0.5); 753 if (FunctionCallee FPExpr = 754 getFunction(M, AMDGPULibFunc(issqrt ? AMDGPULibFunc::EI_SQRT 755 : AMDGPULibFunc::EI_RSQRT, 756 FInfo))) { 757 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << FInfo.getName() 758 << '(' << *opr0 << ")\n"); 759 Value *nval = CreateCallEx(B,FPExpr, opr0, issqrt ? "__pow2sqrt" 760 : "__pow2rsqrt"); 761 replaceCall(FPOp, nval); 762 return true; 763 } 764 } 765 766 if (!isUnsafeMath(FPOp)) 767 return false; 768 769 // Unsafe Math optimization 770 771 // Remember that ci_opr1 is set if opr1 is integral 772 if (CF) { 773 double dval = (getArgType(FInfo) == AMDGPULibFunc::F32) 774 ? (double)CF->getValueAPF().convertToFloat() 775 : CF->getValueAPF().convertToDouble(); 776 int ival = (int)dval; 777 if ((double)ival == dval) { 778 ci_opr1 = ival; 779 } else 780 ci_opr1 = 0x11111111; 781 } 782 783 // pow/powr/pown(x, c) = [1/](x*x*..x); where 784 // trunc(c) == c && the number of x == c && |c| <= 12 785 unsigned abs_opr1 = (ci_opr1 < 0) ? -ci_opr1 : ci_opr1; 786 if (abs_opr1 <= 12) { 787 Constant *cnval; 788 Value *nval; 789 if (abs_opr1 == 0) { 790 cnval = ConstantFP::get(eltType, 1.0); 791 if (getVecSize(FInfo) > 1) { 792 cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval); 793 } 794 nval = cnval; 795 } else { 796 Value *valx2 = nullptr; 797 nval = nullptr; 798 while (abs_opr1 > 0) { 799 valx2 = valx2 ? B.CreateFMul(valx2, valx2, "__powx2") : opr0; 800 if (abs_opr1 & 1) { 801 nval = nval ? 
B.CreateFMul(nval, valx2, "__powprod") : valx2; 802 } 803 abs_opr1 >>= 1; 804 } 805 } 806 807 if (ci_opr1 < 0) { 808 cnval = ConstantFP::get(eltType, 1.0); 809 if (getVecSize(FInfo) > 1) { 810 cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval); 811 } 812 nval = B.CreateFDiv(cnval, nval, "__1powprod"); 813 } 814 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " 815 << ((ci_opr1 < 0) ? "1/prod(" : "prod(") << *opr0 816 << ")\n"); 817 replaceCall(FPOp, nval); 818 return true; 819 } 820 821 // powr ---> exp2(y * log2(x)) 822 // pown/pow ---> powr(fabs(x), y) | (x & ((int)y << 31)) 823 FunctionCallee ExpExpr = 824 getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_EXP2, FInfo)); 825 if (!ExpExpr) 826 return false; 827 828 bool needlog = false; 829 bool needabs = false; 830 bool needcopysign = false; 831 Constant *cnval = nullptr; 832 if (getVecSize(FInfo) == 1) { 833 CF = dyn_cast<ConstantFP>(opr0); 834 835 if (CF) { 836 double V = (getArgType(FInfo) == AMDGPULibFunc::F32) 837 ? (double)CF->getValueAPF().convertToFloat() 838 : CF->getValueAPF().convertToDouble(); 839 840 V = log2(std::abs(V)); 841 cnval = ConstantFP::get(eltType, V); 842 needcopysign = (FInfo.getId() != AMDGPULibFunc::EI_POWR) && 843 CF->isNegative(); 844 } else { 845 needlog = true; 846 needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR && 847 (!CF || CF->isNegative()); 848 } 849 } else { 850 ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr0); 851 852 if (!CDV) { 853 needlog = true; 854 needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR; 855 } else { 856 assert ((int)CDV->getNumElements() == getVecSize(FInfo) && 857 "Wrong vector size detected"); 858 859 SmallVector<double, 0> DVal; 860 for (int i=0; i < getVecSize(FInfo); ++i) { 861 double V = (getArgType(FInfo) == AMDGPULibFunc::F32) 862 ? 
(double)CDV->getElementAsFloat(i) 863 : CDV->getElementAsDouble(i); 864 if (V < 0.0) needcopysign = true; 865 V = log2(std::abs(V)); 866 DVal.push_back(V); 867 } 868 if (getArgType(FInfo) == AMDGPULibFunc::F32) { 869 SmallVector<float, 0> FVal; 870 for (unsigned i=0; i < DVal.size(); ++i) { 871 FVal.push_back((float)DVal[i]); 872 } 873 ArrayRef<float> tmp(FVal); 874 cnval = ConstantDataVector::get(M->getContext(), tmp); 875 } else { 876 ArrayRef<double> tmp(DVal); 877 cnval = ConstantDataVector::get(M->getContext(), tmp); 878 } 879 } 880 } 881 882 if (needcopysign && (FInfo.getId() == AMDGPULibFunc::EI_POW)) { 883 // We cannot handle corner cases for a general pow() function, give up 884 // unless y is a constant integral value. Then proceed as if it were pown. 885 if (getVecSize(FInfo) == 1) { 886 if (const ConstantFP *CF = dyn_cast<ConstantFP>(opr1)) { 887 double y = (getArgType(FInfo) == AMDGPULibFunc::F32) 888 ? (double)CF->getValueAPF().convertToFloat() 889 : CF->getValueAPF().convertToDouble(); 890 if (y != (double)(int64_t)y) 891 return false; 892 } else 893 return false; 894 } else { 895 if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1)) { 896 for (int i=0; i < getVecSize(FInfo); ++i) { 897 double y = (getArgType(FInfo) == AMDGPULibFunc::F32) 898 ? (double)CDV->getElementAsFloat(i) 899 : CDV->getElementAsDouble(i); 900 if (y != (double)(int64_t)y) 901 return false; 902 } 903 } else 904 return false; 905 } 906 } 907 908 Value *nval; 909 if (needabs) { 910 nval = B.CreateUnaryIntrinsic(Intrinsic::fabs, opr0, nullptr, "__fabs"); 911 } else { 912 nval = cnval ? 
cnval : opr0; 913 } 914 if (needlog) { 915 FunctionCallee LogExpr = 916 getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_LOG2, FInfo)); 917 if (!LogExpr) 918 return false; 919 nval = CreateCallEx(B,LogExpr, nval, "__log2"); 920 } 921 922 if (FInfo.getId() == AMDGPULibFunc::EI_POWN) { 923 // convert int(32) to fp(f32 or f64) 924 opr1 = B.CreateSIToFP(opr1, nval->getType(), "pownI2F"); 925 } 926 nval = B.CreateFMul(opr1, nval, "__ylogx"); 927 nval = CreateCallEx(B,ExpExpr, nval, "__exp2"); 928 929 if (needcopysign) { 930 Value *opr_n; 931 Type* rTy = opr0->getType(); 932 Type* nTyS = eltType->isDoubleTy() ? B.getInt64Ty() : B.getInt32Ty(); 933 Type *nTy = nTyS; 934 if (const auto *vTy = dyn_cast<FixedVectorType>(rTy)) 935 nTy = FixedVectorType::get(nTyS, vTy); 936 unsigned size = nTy->getScalarSizeInBits(); 937 opr_n = FPOp->getOperand(1); 938 if (opr_n->getType()->isIntegerTy()) 939 opr_n = B.CreateZExtOrBitCast(opr_n, nTy, "__ytou"); 940 else 941 opr_n = B.CreateFPToSI(opr1, nTy, "__ytou"); 942 943 Value *sign = B.CreateShl(opr_n, size-1, "__yeven"); 944 sign = B.CreateAnd(B.CreateBitCast(opr0, nTy), sign, "__pow_sign"); 945 nval = B.CreateOr(B.CreateBitCast(nval, nTy), sign); 946 nval = B.CreateBitCast(nval, opr0->getType()); 947 } 948 949 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " 950 << "exp2(" << *opr1 << " * log2(" << *opr0 << "))\n"); 951 replaceCall(FPOp, nval); 952 953 return true; 954 } 955 956 bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, 957 const FuncInfo &FInfo) { 958 // skip vector function 959 if (getVecSize(FInfo) != 1) 960 return false; 961 962 Value *opr0 = FPOp->getOperand(0); 963 Value *opr1 = FPOp->getOperand(1); 964 965 ConstantInt *CINT = dyn_cast<ConstantInt>(opr1); 966 if (!CINT) { 967 return false; 968 } 969 int ci_opr1 = (int)CINT->getSExtValue(); 970 if (ci_opr1 == 1) { // rootn(x, 1) = x 971 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n"); 972 replaceCall(FPOp, opr0); 973 return true; 
974 } 975 976 Module *M = B.GetInsertBlock()->getModule(); 977 if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x) 978 if (FunctionCallee FPExpr = 979 getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) { 980 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0 981 << ")\n"); 982 Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt"); 983 replaceCall(FPOp, nval); 984 return true; 985 } 986 } else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x) 987 if (FunctionCallee FPExpr = 988 getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) { 989 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> cbrt(" << *opr0 990 << ")\n"); 991 Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2cbrt"); 992 replaceCall(FPOp, nval); 993 return true; 994 } 995 } else if (ci_opr1 == -1) { // rootn(x, -1) = 1.0/x 996 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1.0 / " << *opr0 << "\n"); 997 Value *nval = B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0), 998 opr0, 999 "__rootn2div"); 1000 replaceCall(FPOp, nval); 1001 return true; 1002 } else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x) 1003 if (FunctionCallee FPExpr = 1004 getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT, FInfo))) { 1005 LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0 1006 << ")\n"); 1007 Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt"); 1008 replaceCall(FPOp, nval); 1009 return true; 1010 } 1011 } 1012 return false; 1013 } 1014 1015 // Get a scalar native builtin single argument FP function 1016 FunctionCallee AMDGPULibCalls::getNativeFunction(Module *M, 1017 const FuncInfo &FInfo) { 1018 if (getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(FInfo.getId())) 1019 return nullptr; 1020 FuncInfo nf = FInfo; 1021 nf.setPrefix(AMDGPULibFunc::NATIVE); 1022 return getFunction(M, nf); 1023 } 1024 1025 // fold sqrt -> native_sqrt (x) 1026 bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, 1027 const FuncInfo &FInfo) { 1028 if 
(!isUnsafeMath(FPOp))
    return false;

  // Only a scalar f32 sqrt that is not already the native variant is
  // eligible for replacement with native_sqrt.
  if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
      (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
    Module *M = B.GetInsertBlock()->getModule();

    if (FunctionCallee FPExpr = getNativeFunction(
            M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
      Value *opr0 = FPOp->getOperand(0);
      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
                        << "sqrt(" << *opr0 << ")\n");
      Value *nval = CreateCallEx(B,FPExpr, opr0, "__sqrt");
      replaceCall(FPOp, nval);
      return true;
    }
  }
  return false;
}

// Insert a sincos(Arg, &cosOut) call. Returns (sin value, cos value, the
// sincos call itself): sin is the call's direct return, cos is loaded back
// from the output alloca.
std::tuple<Value *, Value *, Value *>
AMDGPULibCalls::insertSinCos(Value *Arg, FastMathFlags FMF, IRBuilder<> &B,
                             FunctionCallee Fsincos) {
  DebugLoc DL = B.getCurrentDebugLocation();
  Function *F = B.GetInsertBlock()->getParent();
  // Keep the output slot with the function's other allocas.
  B.SetInsertPointPastAllocas(F);

  AllocaInst *Alloc = B.CreateAlloca(Arg->getType(), nullptr, "__sincos_");

  if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) {
    // If the argument is an instruction, it must dominate all uses so put our
    // sincos call there. Otherwise, right after the allocas works well enough
    // if it's an argument or constant.

    B.SetInsertPoint(ArgInst->getParent(), ++ArgInst->getIterator());

    // SetInsertPoint unwelcomely always tries to set the debug loc.
    B.SetCurrentDebugLocation(DL);
  }

  Value *P = Alloc;
  Type *PTy = Fsincos.getFunctionType()->getParamType(1);
  // The allocaInst allocates the memory in private address space. This need
  // to be bitcasted to point to the address space of cos pointer type.
  // In OpenCL 2.0 this is generic, while in 1.2 that is private.
  if (PTy->getPointerAddressSpace() != AMDGPUAS::PRIVATE_ADDRESS)
    P = B.CreateAddrSpaceCast(Alloc, PTy);

  CallInst *SinCos = CreateCallEx2(B, Fsincos, Arg, P);

  // TODO: Is it worth trying to preserve the location for the cos calls for the
  // load?
  LoadInst *LoadCos = B.CreateLoad(Alloc->getAllocatedType(), Alloc);
  return {SinCos, LoadCos, SinCos};
}

// fold sin, cos -> sincos.
bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
                                 const FuncInfo &fInfo) {
  assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
         fInfo.getId() == AMDGPULibFunc::EI_COS);

  // Only plain (unprefixed) f32/f64 sin/cos are handled.
  if ((getArgType(fInfo) != AMDGPULibFunc::F32 &&
       getArgType(fInfo) != AMDGPULibFunc::F64) ||
      fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
    return false;

  bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;

  Value *CArgVal = FPOp->getOperand(0);
  CallInst *CI = cast<CallInst>(FPOp);

  Function *F = B.GetInsertBlock()->getParent();
  Module *M = F->getParent();

  // Merge the sin and cos.

  // for OpenCL 2.0 we have only generic implementation of sincos
  // function.
  // FIXME: This is not true anymore
  AMDGPULibFunc SinCosLibFunc(AMDGPULibFunc::EI_SINCOS, fInfo);
  SinCosLibFunc.getLeads()[0].PtrKind =
      AMDGPULibFunc::getEPtrKindFromAddrSpace(AMDGPUAS::FLAT_ADDRESS);
  FunctionCallee FSinCos = getFunction(M, SinCosLibFunc);
  if (!FSinCos)
    return false;

  // Buckets for every sibling call on the same argument that we can merge.
  SmallVector<CallInst *> SinCalls;
  SmallVector<CallInst *> CosCalls;
  SmallVector<CallInst *> SinCosCalls;
  FuncInfo PartnerInfo(isSin ? AMDGPULibFunc::EI_COS : AMDGPULibFunc::EI_SIN,
                       fInfo);
  const std::string PairName = PartnerInfo.mangle();

  StringRef SinName = isSin ? CI->getCalledFunction()->getName() : PairName;
  StringRef CosName = isSin ?
PairName : CI->getCalledFunction()->getName();
  const std::string SinCosName = SinCosLibFunc.mangle();

  // Intersect the two sets of flags.
  FastMathFlags FMF = FPOp->getFastMathFlags();
  MDNode *FPMath = CI->getMetadata(LLVMContext::MD_fpmath);

  SmallVector<DILocation *> MergeDbgLocs = {CI->getDebugLoc()};

  // Scan all users of the argument for sin/cos/sincos calls inside this
  // function; they will all be rewritten to share a single sincos call.
  for (User* U : CArgVal->users()) {
    CallInst *XI = dyn_cast<CallInst>(U);
    if (!XI || XI->getFunction() != F || XI->isNoBuiltin())
      continue;

    Function *UCallee = XI->getCalledFunction();
    if (!UCallee)
      continue;

    bool Handled = true;

    if (UCallee->getName() == SinName)
      SinCalls.push_back(XI);
    else if (UCallee->getName() == CosName)
      CosCalls.push_back(XI);
    else if (UCallee->getName() == SinCosName)
      SinCosCalls.push_back(XI);
    else
      Handled = false;

    if (Handled) {
      // Conservatively merge FP semantics and debug info across all calls
      // that are folded into the single sincos.
      MergeDbgLocs.push_back(XI->getDebugLoc());
      auto *OtherOp = cast<FPMathOperator>(XI);
      FMF &= OtherOp->getFastMathFlags();
      FPMath = MDNode::getMostGenericFPMath(
          FPMath, XI->getMetadata(LLVMContext::MD_fpmath));
    }
  }

  // Only profitable when both a sin and a cos of the same argument exist.
  if (SinCalls.empty() || CosCalls.empty())
    return false;

  B.setFastMathFlags(FMF);
  B.setDefaultFPMathTag(FPMath);
  DILocation *DbgLoc = DILocation::getMergedLocations(MergeDbgLocs);
  B.SetCurrentDebugLocation(DbgLoc);

  auto [Sin, Cos, SinCos] = insertSinCos(CArgVal, FMF, B, FSinCos);

  auto replaceTrigInsts = [](ArrayRef<CallInst *> Calls, Value *Res) {
    for (CallInst *C : Calls)
      C->replaceAllUsesWith(Res);

    // Leave the other dead instructions to avoid clobbering iterators.
  };

  replaceTrigInsts(SinCalls, Sin);
  replaceTrigInsts(CosCalls, Cos);
  replaceTrigInsts(SinCosCalls, SinCos);

  // It's safe to delete the original now.
  CI->eraseFromParent();
  return true;
}

// Constant-fold one scalar element of a math library call. copr0..copr2 are
// the (possibly null) constant operands; results are written to Res0 (and
// Res1 for sincos). Returns false when the id cannot be evaluated.
bool AMDGPULibCalls::evaluateScalarMathFunc(const FuncInfo &FInfo,
                                            double& Res0, double& Res1,
                                            Constant *copr0, Constant *copr1,
                                            Constant *copr2) {
  // By default, opr0/opr1/opr2 hold values of float/double type.
  // If they are not float/double, each case has to decode its
  // operand separately (e.g. pown/rootn take an integer exponent).
  double opr0=0.0, opr1=0.0, opr2=0.0;
  ConstantFP *fpopr0 = dyn_cast_or_null<ConstantFP>(copr0);
  ConstantFP *fpopr1 = dyn_cast_or_null<ConstantFP>(copr1);
  ConstantFP *fpopr2 = dyn_cast_or_null<ConstantFP>(copr2);
  if (fpopr0) {
    opr0 = (getArgType(FInfo) == AMDGPULibFunc::F64)
               ? fpopr0->getValueAPF().convertToDouble()
               : (double)fpopr0->getValueAPF().convertToFloat();
  }

  if (fpopr1) {
    opr1 = (getArgType(FInfo) == AMDGPULibFunc::F64)
               ? fpopr1->getValueAPF().convertToDouble()
               : (double)fpopr1->getValueAPF().convertToFloat();
  }

  if (fpopr2) {
    opr2 = (getArgType(FInfo) == AMDGPULibFunc::F64)
               ?
fpopr2->getValueAPF().convertToDouble()
               : (double)fpopr2->getValueAPF().convertToFloat();
  }

  // Dispatch on the library function id; each case evaluates with host
  // libm in double precision.
  switch (FInfo.getId()) {
  default : return false;

  case AMDGPULibFunc::EI_ACOS:
    Res0 = acos(opr0);
    return true;

  case AMDGPULibFunc::EI_ACOSH:
    // acosh(x) == log(x + sqrt(x*x - 1))
    Res0 = log(opr0 + sqrt(opr0*opr0 - 1.0));
    return true;

  case AMDGPULibFunc::EI_ACOSPI:
    Res0 = acos(opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_ASIN:
    Res0 = asin(opr0);
    return true;

  case AMDGPULibFunc::EI_ASINH:
    // asinh(x) == log(x + sqrt(x*x + 1))
    Res0 = log(opr0 + sqrt(opr0*opr0 + 1.0));
    return true;

  case AMDGPULibFunc::EI_ASINPI:
    Res0 = asin(opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_ATAN:
    Res0 = atan(opr0);
    return true;

  case AMDGPULibFunc::EI_ATANH:
    // atanh(x) == (log(x+1) - log(x-1))/2;
    Res0 = (log(opr0 + 1.0) - log(opr0 - 1.0))/2.0;
    return true;

  case AMDGPULibFunc::EI_ATANPI:
    Res0 = atan(opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_CBRT:
    // Handle negative inputs explicitly since pow is undefined there.
    Res0 = (opr0 < 0.0) ? -pow(-opr0, 1.0/3.0) : pow(opr0, 1.0/3.0);
    return true;

  case AMDGPULibFunc::EI_COS:
    Res0 = cos(opr0);
    return true;

  case AMDGPULibFunc::EI_COSH:
    Res0 = cosh(opr0);
    return true;

  case AMDGPULibFunc::EI_COSPI:
    Res0 = cos(MATH_PI * opr0);
    return true;

  case AMDGPULibFunc::EI_EXP:
    Res0 = exp(opr0);
    return true;

  case AMDGPULibFunc::EI_EXP2:
    Res0 = pow(2.0, opr0);
    return true;

  case AMDGPULibFunc::EI_EXP10:
    Res0 = pow(10.0, opr0);
    return true;

  case AMDGPULibFunc::EI_LOG:
    Res0 = log(opr0);
    return true;

  case AMDGPULibFunc::EI_LOG2:
    // log2(x) == log(x) / log(2)
    Res0 = log(opr0) / log(2.0);
    return true;

  case AMDGPULibFunc::EI_LOG10:
    // log10(x) == log(x) / log(10)
    Res0 = log(opr0) / log(10.0);
    return true;

  case AMDGPULibFunc::EI_RSQRT:
    Res0 = 1.0 / sqrt(opr0);
    return true;

  case AMDGPULibFunc::EI_SIN:
    Res0 = sin(opr0);
    return true;

  case AMDGPULibFunc::EI_SINH:
    Res0 = sinh(opr0);
    return true;

  case AMDGPULibFunc::EI_SINPI:
    Res0 = sin(MATH_PI * opr0);
    return true;

  case AMDGPULibFunc::EI_SQRT:
    Res0 = sqrt(opr0);
    return true;

  case AMDGPULibFunc::EI_TAN:
    Res0 = tan(opr0);
    return true;

  case AMDGPULibFunc::EI_TANH:
    Res0 = tanh(opr0);
    return true;

  case AMDGPULibFunc::EI_TANPI:
    Res0 = tan(MATH_PI * opr0);
    return true;

  case AMDGPULibFunc::EI_RECIP:
    Res0 = 1.0 / opr0;
    return true;

  // two-arg functions
  case AMDGPULibFunc::EI_DIVIDE:
    Res0 = opr0 / opr1;
    return true;

  case AMDGPULibFunc::EI_POW:
  case AMDGPULibFunc::EI_POWR:
    Res0 = pow(opr0, opr1);
    return true;

  case AMDGPULibFunc::EI_POWN: {
    // pown's exponent is an integer constant, not a ConstantFP.
    if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(copr1)) {
      double val = (double)iopr1->getSExtValue();
      Res0 = pow(opr0, val);
1348 return true; 1349 } 1350 return false; 1351 } 1352 1353 case AMDGPULibFunc::EI_ROOTN: { 1354 if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(copr1)) { 1355 double val = (double)iopr1->getSExtValue(); 1356 Res0 = pow(opr0, 1.0 / val); 1357 return true; 1358 } 1359 return false; 1360 } 1361 1362 // with ptr arg 1363 case AMDGPULibFunc::EI_SINCOS: 1364 Res0 = sin(opr0); 1365 Res1 = cos(opr0); 1366 return true; 1367 1368 // three-arg functions 1369 case AMDGPULibFunc::EI_FMA: 1370 case AMDGPULibFunc::EI_MAD: 1371 Res0 = opr0 * opr1 + opr2; 1372 return true; 1373 } 1374 1375 return false; 1376 } 1377 1378 bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) { 1379 int numArgs = (int)aCI->arg_size(); 1380 if (numArgs > 3) 1381 return false; 1382 1383 Constant *copr0 = nullptr; 1384 Constant *copr1 = nullptr; 1385 Constant *copr2 = nullptr; 1386 if (numArgs > 0) { 1387 if ((copr0 = dyn_cast<Constant>(aCI->getArgOperand(0))) == nullptr) 1388 return false; 1389 } 1390 1391 if (numArgs > 1) { 1392 if ((copr1 = dyn_cast<Constant>(aCI->getArgOperand(1))) == nullptr) { 1393 if (FInfo.getId() != AMDGPULibFunc::EI_SINCOS) 1394 return false; 1395 } 1396 } 1397 1398 if (numArgs > 2) { 1399 if ((copr2 = dyn_cast<Constant>(aCI->getArgOperand(2))) == nullptr) 1400 return false; 1401 } 1402 1403 // At this point, all arguments to aCI are constants. 1404 1405 // max vector size is 16, and sincos will generate two results. 
1406 double DVal0[16], DVal1[16]; 1407 int FuncVecSize = getVecSize(FInfo); 1408 bool hasTwoResults = (FInfo.getId() == AMDGPULibFunc::EI_SINCOS); 1409 if (FuncVecSize == 1) { 1410 if (!evaluateScalarMathFunc(FInfo, DVal0[0], 1411 DVal1[0], copr0, copr1, copr2)) { 1412 return false; 1413 } 1414 } else { 1415 ConstantDataVector *CDV0 = dyn_cast_or_null<ConstantDataVector>(copr0); 1416 ConstantDataVector *CDV1 = dyn_cast_or_null<ConstantDataVector>(copr1); 1417 ConstantDataVector *CDV2 = dyn_cast_or_null<ConstantDataVector>(copr2); 1418 for (int i = 0; i < FuncVecSize; ++i) { 1419 Constant *celt0 = CDV0 ? CDV0->getElementAsConstant(i) : nullptr; 1420 Constant *celt1 = CDV1 ? CDV1->getElementAsConstant(i) : nullptr; 1421 Constant *celt2 = CDV2 ? CDV2->getElementAsConstant(i) : nullptr; 1422 if (!evaluateScalarMathFunc(FInfo, DVal0[i], 1423 DVal1[i], celt0, celt1, celt2)) { 1424 return false; 1425 } 1426 } 1427 } 1428 1429 LLVMContext &context = aCI->getContext(); 1430 Constant *nval0, *nval1; 1431 if (FuncVecSize == 1) { 1432 nval0 = ConstantFP::get(aCI->getType(), DVal0[0]); 1433 if (hasTwoResults) 1434 nval1 = ConstantFP::get(aCI->getType(), DVal1[0]); 1435 } else { 1436 if (getArgType(FInfo) == AMDGPULibFunc::F32) { 1437 SmallVector <float, 0> FVal0, FVal1; 1438 for (int i = 0; i < FuncVecSize; ++i) 1439 FVal0.push_back((float)DVal0[i]); 1440 ArrayRef<float> tmp0(FVal0); 1441 nval0 = ConstantDataVector::get(context, tmp0); 1442 if (hasTwoResults) { 1443 for (int i = 0; i < FuncVecSize; ++i) 1444 FVal1.push_back((float)DVal1[i]); 1445 ArrayRef<float> tmp1(FVal1); 1446 nval1 = ConstantDataVector::get(context, tmp1); 1447 } 1448 } else { 1449 ArrayRef<double> tmp0(DVal0); 1450 nval0 = ConstantDataVector::get(context, tmp0); 1451 if (hasTwoResults) { 1452 ArrayRef<double> tmp1(DVal1); 1453 nval1 = ConstantDataVector::get(context, tmp1); 1454 } 1455 } 1456 } 1457 1458 if (hasTwoResults) { 1459 // sincos 1460 assert(FInfo.getId() == AMDGPULibFunc::EI_SINCOS && 1461 "math 
function with ptr arg not supported yet"); 1462 new StoreInst(nval1, aCI->getArgOperand(1), aCI); 1463 } 1464 1465 replaceCall(aCI, nval0); 1466 return true; 1467 } 1468 1469 PreservedAnalyses AMDGPUSimplifyLibCallsPass::run(Function &F, 1470 FunctionAnalysisManager &AM) { 1471 AMDGPULibCalls Simplifier; 1472 Simplifier.initNativeFuncs(); 1473 Simplifier.initFunction(F); 1474 1475 bool Changed = false; 1476 1477 LLVM_DEBUG(dbgs() << "AMDIC: process function "; 1478 F.printAsOperand(dbgs(), false, F.getParent()); dbgs() << '\n';); 1479 1480 for (auto &BB : F) { 1481 for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) { 1482 // Ignore non-calls. 1483 CallInst *CI = dyn_cast<CallInst>(I); 1484 ++I; 1485 1486 if (CI) { 1487 if (Simplifier.fold(CI)) 1488 Changed = true; 1489 } 1490 } 1491 } 1492 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); 1493 } 1494 1495 PreservedAnalyses AMDGPUUseNativeCallsPass::run(Function &F, 1496 FunctionAnalysisManager &AM) { 1497 if (UseNative.empty()) 1498 return PreservedAnalyses::all(); 1499 1500 AMDGPULibCalls Simplifier; 1501 Simplifier.initNativeFuncs(); 1502 Simplifier.initFunction(F); 1503 1504 bool Changed = false; 1505 for (auto &BB : F) { 1506 for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) { 1507 // Ignore non-calls. 1508 CallInst *CI = dyn_cast<CallInst>(I); 1509 ++I; 1510 if (CI && Simplifier.useNative(CI)) 1511 Changed = true; 1512 } 1513 } 1514 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); 1515 } 1516