//===- AMDGPULibCalls.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This file does AMD library function optimizations.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPULibFunc.h"
#include "GCNSubtarget.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/InitializePasses.h"
#include <cmath>

#define DEBUG_TYPE "amdgpu-simplifylib"

using namespace llvm;

// -amdgpu-prelink: when set, getFunction() may insert new prototypes for
// library functions that are still external (pre-link stage).
static cl::opt<bool> EnablePreLink("amdgpu-prelink",
                                   cl::desc("Enable pre-link mode optimizations"),
                                   cl::init(false),
                                   cl::Hidden);

// -amdgpu-use-native: list of functions to replace with native_* variants.
// An empty value or "all" enables the replacement for every candidate.
static cl::list<std::string> UseNative("amdgpu-use-native",
  cl::desc("Comma separated list of functions to replace with native, or all"),
  cl::CommaSeparated, cl::ValueOptional,
  cl::Hidden);

#define MATH_PI numbers::pi
#define MATH_E numbers::e
#define MATH_SQRT2 numbers::sqrt2
#define MATH_SQRT1_2 numbers::inv_sqrt2

namespace llvm {

/// Folds and simplifies calls to the AMD device math library (pow, sqrt,
/// sincos, read/write_pipe, ...). Used by the amdgpu-simplifylib pass.
class AMDGPULibCalls {
private:

  typedef llvm::AMDGPULibFunc FuncInfo;

  // Mirrors the function's "unsafe-fp-math" attribute; set in initFunction().
  bool UnsafeFPMath = false;

  // -fuse-native. True when every candidate should use the native variant.
  bool AllNative = false;

  bool useNativeFunc(const StringRef F) const;

  // Return a pointer (pointer expr) to the function if function definition with
  // "FuncName" exists. It may create a new function prototype in pre-link mode.
  FunctionCallee getFunction(Module *M, const FuncInfo &fInfo);

  bool parseFunctionName(const StringRef &FMangledName, FuncInfo &FInfo);

  bool TDOFold(CallInst *CI, const FuncInfo &FInfo);

  /* Specialized optimizations */

  // recip (half or native)
  bool fold_recip(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);

  // divide (half or native)
  bool fold_divide(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);

  // pow/powr/pown
  bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // rootn
  bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // fma/mad
  bool fold_fma_mad(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);

  // -fuse-native for sincos
  bool sincosUseNative(CallInst *aCI, const FuncInfo &FInfo);

  // evaluate calls if calls' arguments are constants.
  bool evaluateScalarMathFunc(const FuncInfo &FInfo, double& Res0,
    double& Res1, Constant *copr0, Constant *copr1, Constant *copr2);
  bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);

  // sqrt
  bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // sin/cos
  bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

  // __read_pipe/__write_pipe
  bool fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
                            const FuncInfo &FInfo);

  // Get insertion point at entry.
  BasicBlock::iterator getEntryIns(CallInst * UI);
  // Insert an Alloc instruction.
  AllocaInst* insertAlloca(CallInst * UI, IRBuilder<> &B, const char *prefix);

  // Get a scalar native builtin single argument FP function
  FunctionCallee getNativeFunction(Module *M, const FuncInfo &FInfo);

protected:
  bool isUnsafeMath(const FPMathOperator *FPOp) const;

  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;

  // RAUW the call with \p With and delete the call instruction.
  static void replaceCall(Instruction *I, Value *With) {
    I->replaceAllUsesWith(With);
    I->eraseFromParent();
  }

  static void replaceCall(FPMathOperator *I, Value *With) {
    replaceCall(cast<Instruction>(I), With);
  }

public:
  AMDGPULibCalls() {}

  bool fold(CallInst *CI);

  void initFunction(const Function &F);
  void initNativeFuncs();

  // Replace a normal math function call with that native version
  bool useNative(CallInst *CI);
};

} // end llvm namespace

// Create a one-argument call and copy the callee's calling convention onto
// the call site (library functions may use a non-default CC).
template <typename IRB>
static CallInst *CreateCallEx(IRB &B, FunctionCallee Callee, Value *Arg,
                              const Twine &Name = "") {
  CallInst *R = B.CreateCall(Callee, Arg, Name);
  if (Function *F = dyn_cast<Function>(Callee.getCallee()))
    R->setCallingConv(F->getCallingConv());
  return R;
}

// Two-argument variant of CreateCallEx; same calling-convention propagation.
template <typename IRB>
static CallInst *CreateCallEx2(IRB &B, FunctionCallee Callee, Value *Arg1,
                               Value *Arg2, const Twine &Name = "") {
  CallInst *R = B.CreateCall(Callee, {Arg1, Arg2}, Name);
  if (Function *F = dyn_cast<Function>(Callee.getCallee()))
    R->setCallingConv(F->getCallingConv());
  return R;
}

// Data structures for table-driven optimizations.
// FuncTbl works for both f32 and f64 functions with 1 input argument

// One exact-value constant-folding entry: f(input) == result.
struct TableEntry {
  double result;
  double input;
};

/* a list of {result, input} */
static const TableEntry tbl_acos[] = {
  {MATH_PI / 2.0, 0.0},
  {MATH_PI / 2.0, -0.0},
  {0.0, 1.0},
  {MATH_PI, -1.0}
};
static const TableEntry tbl_acosh[] = {
  {0.0, 1.0}
};
static const TableEntry tbl_acospi[] = {
  {0.5, 0.0},
  {0.5, -0.0},
  {0.0, 1.0},
  {1.0, -1.0}
};
static const TableEntry tbl_asin[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 2.0, 1.0},
  {-MATH_PI / 2.0, -1.0}
};
static const TableEntry tbl_asinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_asinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.5, 1.0},
  {-0.5, -1.0}
};
static const TableEntry tbl_atan[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 4.0, 1.0},
  {-MATH_PI / 4.0, -1.0}
};
static const TableEntry tbl_atanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_atanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.25, 1.0},
  {-0.25, -1.0}
};
static const TableEntry tbl_cbrt[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {1.0, 1.0},
  {-1.0, -1.0},
};
static const TableEntry tbl_cos[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cosh[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cospi[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erfc[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erf[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_exp[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {MATH_E, 1.0}
};
static const TableEntry tbl_exp2[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {2.0, 1.0}
};
static const TableEntry tbl_exp10[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {10.0, 1.0}
};
static const TableEntry tbl_expm1[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_log[] = {
  {0.0, 1.0},
  {1.0, MATH_E}
};
static const TableEntry tbl_log2[] = {
  {0.0, 1.0},
  {1.0, 2.0}
};
static const TableEntry tbl_log10[] = {
  {0.0, 1.0},
  {1.0, 10.0}
};
static const TableEntry tbl_rsqrt[] = {
  {1.0, 1.0},
  {MATH_SQRT1_2, 2.0}
};
static const TableEntry tbl_sin[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sqrt[] = {
  {0.0, 0.0},
  {1.0, 1.0},
  {MATH_SQRT2, 2.0}
};
static const TableEntry tbl_tan[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tgamma[] = {
  {1.0, 1.0},
  {1.0, 2.0},
  {2.0, 3.0},
  {6.0, 4.0}
};

// Returns true if the library provides a native_* variant of function \p id.
static bool HasNative(AMDGPULibFunc::EFuncId id) {
  switch(id) {
  case AMDGPULibFunc::EI_DIVIDE:
  case AMDGPULibFunc::EI_COS:
  case AMDGPULibFunc::EI_EXP:
  case AMDGPULibFunc::EI_EXP2:
  case AMDGPULibFunc::EI_EXP10:
  case AMDGPULibFunc::EI_LOG:
  case AMDGPULibFunc::EI_LOG2:
  case AMDGPULibFunc::EI_LOG10:
  case AMDGPULibFunc::EI_POWR:
  case AMDGPULibFunc::EI_RECIP:
  case AMDGPULibFunc::EI_RSQRT:
  case AMDGPULibFunc::EI_SIN:
  case AMDGPULibFunc::EI_SINCOS:
  case AMDGPULibFunc::EI_SQRT:
  case AMDGPULibFunc::EI_TAN:
    return true;
  default:;
  }
  return false;
}

using TableRef = ArrayRef<TableEntry>;

// Map a function id to its exact-value folding table; an empty TableRef
// means the function has no table-driven folds. native_* ids (EI_N*) share
// the table of the plain function.
static TableRef getOptTable(AMDGPULibFunc::EFuncId id) {
  switch(id) {
  case AMDGPULibFunc::EI_ACOS:   return TableRef(tbl_acos);
  case AMDGPULibFunc::EI_ACOSH:  return TableRef(tbl_acosh);
  case AMDGPULibFunc::EI_ACOSPI: return TableRef(tbl_acospi);
  case AMDGPULibFunc::EI_ASIN:   return TableRef(tbl_asin);
  case AMDGPULibFunc::EI_ASINH:  return TableRef(tbl_asinh);
  case AMDGPULibFunc::EI_ASINPI: return TableRef(tbl_asinpi);
  case AMDGPULibFunc::EI_ATAN:   return TableRef(tbl_atan);
  case AMDGPULibFunc::EI_ATANH:  return TableRef(tbl_atanh);
  case AMDGPULibFunc::EI_ATANPI: return TableRef(tbl_atanpi);
  case AMDGPULibFunc::EI_CBRT:   return TableRef(tbl_cbrt);
  case AMDGPULibFunc::EI_NCOS:
  case AMDGPULibFunc::EI_COS:    return TableRef(tbl_cos);
  case AMDGPULibFunc::EI_COSH:   return TableRef(tbl_cosh);
  case AMDGPULibFunc::EI_COSPI:  return TableRef(tbl_cospi);
  case AMDGPULibFunc::EI_ERFC:   return TableRef(tbl_erfc);
  case AMDGPULibFunc::EI_ERF:    return TableRef(tbl_erf);
  case AMDGPULibFunc::EI_EXP:    return TableRef(tbl_exp);
  case AMDGPULibFunc::EI_NEXP2:
  case AMDGPULibFunc::EI_EXP2:   return TableRef(tbl_exp2);
  case AMDGPULibFunc::EI_EXP10:  return TableRef(tbl_exp10);
  case AMDGPULibFunc::EI_EXPM1:  return TableRef(tbl_expm1);
  case AMDGPULibFunc::EI_LOG:    return TableRef(tbl_log);
  case AMDGPULibFunc::EI_NLOG2:
  case AMDGPULibFunc::EI_LOG2:   return TableRef(tbl_log2);
  case AMDGPULibFunc::EI_LOG10:  return TableRef(tbl_log10);
  case AMDGPULibFunc::EI_NRSQRT:
  case AMDGPULibFunc::EI_RSQRT:  return TableRef(tbl_rsqrt);
  case AMDGPULibFunc::EI_NSIN:
  case AMDGPULibFunc::EI_SIN:    return TableRef(tbl_sin);
  case AMDGPULibFunc::EI_SINH:   return TableRef(tbl_sinh);
  case AMDGPULibFunc::EI_SINPI:  return TableRef(tbl_sinpi);
  case AMDGPULibFunc::EI_NSQRT:
  case AMDGPULibFunc::EI_SQRT:   return TableRef(tbl_sqrt);
  case AMDGPULibFunc::EI_TAN:    return TableRef(tbl_tan);
  case AMDGPULibFunc::EI_TANH:   return TableRef(tbl_tanh);
  case AMDGPULibFunc::EI_TANPI:  return TableRef(tbl_tanpi);
  case AMDGPULibFunc::EI_TGAMMA: return TableRef(tbl_tgamma);
  default:;
  }
  return TableRef();
}

// Vector width of the leading argument (1 for scalar functions).
static inline int getVecSize(const AMDGPULibFunc& FInfo) {
  return FInfo.getLeads()[0].VectorSize;
}

// Element type (e.g. F32/F64) of the leading argument.
static inline AMDGPULibFunc::EType getArgType(const AMDGPULibFunc& FInfo) {
  return (AMDGPULibFunc::EType)FInfo.getLeads()[0].ArgType;
}

FunctionCallee AMDGPULibCalls::getFunction(Module *M, const FuncInfo &fInfo) {
  // If we are doing PreLinkOpt, the function is external. So it is safe to
  // use getOrInsertFunction() at this stage.

  return EnablePreLink ? AMDGPULibFunc::getOrInsertFunction(M, fInfo)
                       : AMDGPULibFunc::getFunction(M, fInfo);
}

bool AMDGPULibCalls::parseFunctionName(const StringRef &FMangledName,
                                       FuncInfo &FInfo) {
  return AMDGPULibFunc::parse(FMangledName, FInfo);
}

bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const {
  return UnsafeFPMath || FPOp->isFast();
}

bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
    const FPMathOperator *FPOp) const {
  // TODO: Refine to approxFunc or contract
  return isUnsafeMath(FPOp);
}

void AMDGPULibCalls::initFunction(const Function &F) {
  UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
}

bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
  return AllNative || llvm::is_contained(UseNative, F);
}

void AMDGPULibCalls::initNativeFuncs() {
  // "all" in the list, or a single empty value (bare -amdgpu-use-native),
  // enables native replacement for every candidate function.
  AllNative = useNativeFunc("all") ||
              (UseNative.getNumOccurrences() && UseNative.size() == 1 &&
               UseNative.begin()->empty());
}

// Split sincos(x, &c) into native_sin(x) + native_cos(x) with a store of the
// cosine through the second argument. Only fires when both sin and cos are
// enabled for native replacement.
bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
  bool native_sin = useNativeFunc("sin");
  bool native_cos = useNativeFunc("cos");

  if (native_sin && native_cos) {
    Module *M = aCI->getModule();
    Value *opr0 = aCI->getArgOperand(0);

    AMDGPULibFunc nf;
    nf.getLeads()[0].ArgType = FInfo.getLeads()[0].ArgType;
    nf.getLeads()[0].VectorSize = FInfo.getLeads()[0].VectorSize;

    nf.setPrefix(AMDGPULibFunc::NATIVE);
    nf.setId(AMDGPULibFunc::EI_SIN);
    FunctionCallee sinExpr = getFunction(M, nf);

    nf.setPrefix(AMDGPULibFunc::NATIVE);
    nf.setId(AMDGPULibFunc::EI_COS);
    FunctionCallee cosExpr = getFunction(M, nf);
    if (sinExpr && cosExpr) {
      Value *sinval = CallInst::Create(sinExpr, opr0, "splitsin", aCI);
      Value *cosval = CallInst::Create(cosExpr, opr0, "splitcos", aCI);
      new StoreInst(cosval, aCI->getArgOperand(1), aCI);

      DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
                                          << " with native version of sin/cos");

      // The sincos call itself yields the sine value.
      replaceCall(aCI, sinval);
      return true;
    }
  }
  return false;
}

// Rewrite an eligible library call to its native_* variant. f64 calls are
// never converted (no native f64 variants are used here).
bool AMDGPULibCalls::useNative(CallInst *aCI) {
  Function *Callee = aCI->getCalledFunction();
  if (!Callee || aCI->isNoBuiltin())
    return false;

  FuncInfo FInfo;
  if (!parseFunctionName(Callee->getName(), FInfo) || !FInfo.isMangled() ||
      FInfo.getPrefix() != AMDGPULibFunc::NOPFX ||
      getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(FInfo.getId()) ||
      !(AllNative || useNativeFunc(FInfo.getName()))) {
    return false;
  }

  if (FInfo.getId() == AMDGPULibFunc::EI_SINCOS)
    return sincosUseNative(aCI, FInfo);

  FInfo.setPrefix(AMDGPULibFunc::NATIVE);
  FunctionCallee F = getFunction(aCI->getModule(), FInfo);
  if (!F)
    return false;

  aCI->setCalledFunction(F);
  DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
                                      << " with native version");
  return true;
}

// Clang emits call of __read_pipe_2 or __read_pipe_4 for OpenCL read_pipe
// builtin, with appended type size and alignment arguments, where 2 or 4
// indicates the original number of arguments. The library has optimized version
// of __read_pipe_2/__read_pipe_4 when the type size and alignment has the same
// power of 2 value.
// This function transforms __read_pipe_2 to __read_pipe_2_N
// for such cases where N is the size in bytes of the type (N = 1, 2, 4, 8, ...,
// 128). The same for __read_pipe_4, write_pipe_2, and write_pipe_4.
bool AMDGPULibCalls::fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
                                          const FuncInfo &FInfo) {
  auto *Callee = CI->getCalledFunction();
  if (!Callee->isDeclaration())
    return false;

  assert(Callee->hasName() && "Invalid read_pipe/write_pipe function");
  auto *M = Callee->getParent();
  std::string Name = std::string(Callee->getName());
  auto NumArg = CI->arg_size();
  // 4 or 6 arguments = original 2 or 4 user arguments plus the appended
  // packet size and alignment.
  if (NumArg != 4 && NumArg != 6)
    return false;
  ConstantInt *PacketSize =
      dyn_cast<ConstantInt>(CI->getArgOperand(NumArg - 2));
  ConstantInt *PacketAlign =
      dyn_cast<ConstantInt>(CI->getArgOperand(NumArg - 1));
  if (!PacketSize || !PacketAlign)
    return false;

  unsigned Size = PacketSize->getZExtValue();
  Align Alignment = PacketAlign->getAlignValue();
  // The specialized variants only exist when alignment equals the size
  // (i.e. a naturally-aligned power-of-two packet).
  if (Alignment != Size)
    return false;

  // Drop the trailing size/alignment arguments and keep the data pointer.
  unsigned PtrArgLoc = CI->arg_size() - 3;
  Value *PtrArg = CI->getArgOperand(PtrArgLoc);
  Type *PtrTy = PtrArg->getType();

  SmallVector<llvm::Type *, 6> ArgTys;
  for (unsigned I = 0; I != PtrArgLoc; ++I)
    ArgTys.push_back(CI->getArgOperand(I)->getType());
  ArgTys.push_back(PtrTy);

  Name = Name + "_" + std::to_string(Size);
  auto *FTy = FunctionType::get(Callee->getReturnType(),
                                ArrayRef<Type *>(ArgTys), false);
  AMDGPULibFunc NewLibFunc(Name, FTy);
  FunctionCallee F = AMDGPULibFunc::getOrInsertFunction(M, NewLibFunc);
  if (!F)
    return false;

  auto *BCast = B.CreatePointerCast(PtrArg, PtrTy);
  SmallVector<Value *, 6> Args;
  for (unsigned I = 0; I != PtrArgLoc; ++I)
    Args.push_back(CI->getArgOperand(I));
  Args.push_back(BCast);

  auto *NCI = B.CreateCall(F, Args);
  NCI->setAttributes(CI->getAttributes());
  CI->replaceAllUsesWith(NCI);
  CI->dropAllReferences();
  CI->eraseFromParent();

  return true;
}

// This function returns false if no change; return true otherwise.
bool AMDGPULibCalls::fold(CallInst *CI) {
  Function *Callee = CI->getCalledFunction();
  // Ignore indirect calls.
  if (!Callee || Callee->isIntrinsic() || CI->isNoBuiltin())
    return false;

  FuncInfo FInfo;
  if (!parseFunctionName(Callee->getName(), FInfo))
    return false;

  // Further check the number of arguments to see if they match.
  // TODO: Check calling convention matches too
  if (!FInfo.isCompatibleSignature(CI->getFunctionType()))
    return false;

  LLVM_DEBUG(dbgs() << "AMDIC: try folding " << *CI << '\n');

  if (TDOFold(CI, FInfo))
    return true;

  IRBuilder<> B(CI);

  if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) {
    // Under unsafe-math, evaluate calls if possible.
    // According to Brian Sumner, we can do this for all f32 function calls
    // using host's double function calls.
    if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo))
      return true;

    // Copy fast flags from the original call.
    B.setFastMathFlags(FPOp->getFastMathFlags());

    // Specialized optimizations for each function call
    switch (FInfo.getId()) {
    case AMDGPULibFunc::EI_POW:
    case AMDGPULibFunc::EI_POWR:
    case AMDGPULibFunc::EI_POWN:
      return fold_pow(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_ROOTN:
      return fold_rootn(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_SQRT:
      return fold_sqrt(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_COS:
    case AMDGPULibFunc::EI_SIN:
      return fold_sincos(FPOp, B, FInfo);
    case AMDGPULibFunc::EI_RECIP:
      // skip vector function
      assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
             "recip must be an either native or half function");
      return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);

    case AMDGPULibFunc::EI_DIVIDE:
      // skip vector function
      assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
             "divide must be an either native or half function");
      return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
    case AMDGPULibFunc::EI_FMA:
    case AMDGPULibFunc::EI_MAD:
    case AMDGPULibFunc::EI_NFMA:
      // skip vector function
      return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
    default:
      break;
    }
  } else {
    // Specialized optimizations for each function call
    switch (FInfo.getId()) {
    case AMDGPULibFunc::EI_READ_PIPE_2:
    case AMDGPULibFunc::EI_READ_PIPE_4:
    case AMDGPULibFunc::EI_WRITE_PIPE_2:
    case AMDGPULibFunc::EI_WRITE_PIPE_4:
      return fold_read_write_pipe(CI, B, FInfo);
    default:
      break;
    }
  }

  return false;
}

// Table-driven fold: replace the call with a constant when the argument
// matches an exact input in the function's TableEntry list. Handles both
// scalar arguments and splat-able constant vectors (all lanes must match).
bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
  // Table-Driven optimization
  const TableRef tr = getOptTable(FInfo.getId());
  if (tr.empty())
    return false;

  int const sz = (int)tr.size();
  Value *opr0 = CI->getArgOperand(0);

  if (getVecSize(FInfo) > 1) {
    if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(opr0)) {
      SmallVector<double, 0> DVal;
      for (int eltNo = 0; eltNo < getVecSize(FInfo); ++eltNo) {
        ConstantFP *eltval = dyn_cast<ConstantFP>(
                               CV->getElementAsConstant((unsigned)eltNo));
        assert(eltval && "Non-FP arguments in math function!");
        bool found = false;
        for (int i=0; i < sz; ++i) {
          if (eltval->isExactlyValue(tr[i].input)) {
            DVal.push_back(tr[i].result);
            found = true;
            break;
          }
        }
        if (!found) {
          // This vector constants not handled yet.
          return false;
        }
      }
      LLVMContext &context = CI->getParent()->getParent()->getContext();
      Constant *nval;
      if (getArgType(FInfo) == AMDGPULibFunc::F32) {
        SmallVector<float, 0> FVal;
        for (unsigned i = 0; i < DVal.size(); ++i) {
          FVal.push_back((float)DVal[i]);
        }
        ArrayRef<float> tmp(FVal);
        nval = ConstantDataVector::get(context, tmp);
      } else { // F64
        ArrayRef<double> tmp(DVal);
        nval = ConstantDataVector::get(context, tmp);
      }
      LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
      replaceCall(CI, nval);
      return true;
    }
  } else {
    // Scalar version
    if (ConstantFP *CF = dyn_cast<ConstantFP>(opr0)) {
      for (int i = 0; i < sz; ++i) {
        if (CF->isExactlyValue(tr[i].input)) {
          Value *nval = ConstantFP::get(CF->getType(), tr[i].result);
          LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
          replaceCall(CI, nval);
          return true;
        }
      }
    }
  }

  return false;
}

// [native_]half_recip(c) ==> 1.0/c
bool AMDGPULibCalls::fold_recip(CallInst *CI, IRBuilder<> &B,
                                const FuncInfo &FInfo) {
  Value *opr0 = CI->getArgOperand(0);
  if (ConstantFP *CF = dyn_cast<ConstantFP>(opr0)) {
    // Just create a normal div. Later, InstCombine will be able
    // to compute the divide into a constant (avoid check float infinity
    // or subnormal at this point).
    Value *nval = B.CreateFDiv(ConstantFP::get(CF->getType(), 1.0),
                               opr0,
                               "recip2div");
    LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
    replaceCall(CI, nval);
    return true;
  }
  return false;
}

// [native_]half_divide(x, c) ==> x/c
bool AMDGPULibCalls::fold_divide(CallInst *CI, IRBuilder<> &B,
                                 const FuncInfo &FInfo) {
  Value *opr0 = CI->getArgOperand(0);
  Value *opr1 = CI->getArgOperand(1);
  ConstantFP *CF0 = dyn_cast<ConstantFP>(opr0);
  ConstantFP *CF1 = dyn_cast<ConstantFP>(opr1);

  if ((CF0 && CF1) ||  // both are constants
      (CF1 && (getArgType(FInfo) == AMDGPULibFunc::F32)))
      // CF1 is constant && f32 divide
  {
    // Emit x * (1/c) so the reciprocal can be constant-folded later.
    Value *nval1 = B.CreateFDiv(ConstantFP::get(opr1->getType(), 1.0),
                                opr1, "__div2recip");
    Value *nval  = B.CreateFMul(opr0, nval1, "__div2mul");
    replaceCall(CI, nval);
    return true;
  }
  return false;
}

namespace llvm {
// Host-side log2 used for compile-time folding; falls back to log(V)/ln(2)
// when the C library's ::log2 is not guaranteed available.
static double log2(double V) {
#if _XOPEN_SOURCE >= 600 || defined(_ISOC99_SOURCE) || _POSIX_C_SOURCE >= 200112L
  return ::log2(V);
#else
  return log(V) / numbers::ln2;
#endif
}
}

// Fold pow/powr/pown. Handles exact special exponents (0, 1, 2, -1, +/-0.5),
// small integral exponents (|n| <= 12) via repeated multiplication, and the
// general unsafe-math rewrite pow(x, y) ==> exp2(y * log2(x)) with a manual
// copysign fixup for negative bases.
bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
                              const FuncInfo &FInfo) {
  assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
          FInfo.getId() == AMDGPULibFunc::EI_POWR ||
          FInfo.getId() == AMDGPULibFunc::EI_POWN) &&
         "fold_pow: encounter a wrong function call");

  Module *M = B.GetInsertBlock()->getModule();
  ConstantFP *CF;
  ConstantInt *CINT;
  Type *eltType;
  Value *opr0 = FPOp->getOperand(0);
  Value *opr1 = FPOp->getOperand(1);
  ConstantAggregateZero *CZero = dyn_cast<ConstantAggregateZero>(opr1);

  if (getVecSize(FInfo) == 1) {
    eltType = opr0->getType();
    CF = dyn_cast<ConstantFP>(opr1);
    CINT = dyn_cast<ConstantInt>(opr1);
  } else {
    VectorType *VTy = dyn_cast<VectorType>(opr0->getType());
    assert(VTy && "Oprand of vector function should be of vectortype");
    eltType = VTy->getElementType();
    ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1);

    // Now, only Handle vector const whose elements have the same value.
    CF = CDV ? dyn_cast_or_null<ConstantFP>(CDV->getSplatValue()) : nullptr;
    CINT = CDV ? dyn_cast_or_null<ConstantInt>(CDV->getSplatValue()) : nullptr;
  }

  // No unsafe math , no constant argument, do nothing
  if (!isUnsafeMath(FPOp) && !CF && !CINT && !CZero)
    return false;

  // 0x1111111 means that we don't do anything for this call.
  int ci_opr1 = (CINT ? (int)CINT->getSExtValue() : 0x1111111);

  if ((CF && CF->isZero()) || (CINT && ci_opr1 == 0) || CZero) {
    //  pow/powr/pown(x, 0) == 1
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1\n");
    Constant *cnval = ConstantFP::get(eltType, 1.0);
    if (getVecSize(FInfo) > 1) {
      cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
    }
    replaceCall(FPOp, cnval);
    return true;
  }
  if ((CF && CF->isExactlyValue(1.0)) || (CINT && ci_opr1 == 1)) {
    // pow/powr/pown(x, 1.0) = x
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
    replaceCall(FPOp, opr0);
    return true;
  }
  if ((CF && CF->isExactlyValue(2.0)) || (CINT && ci_opr1 == 2)) {
    // pow/powr/pown(x, 2.0) = x*x
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << " * "
                      << *opr0 << "\n");
    Value *nval = B.CreateFMul(opr0, opr0, "__pow2");
    replaceCall(FPOp, nval);
    return true;
  }
  if ((CF && CF->isExactlyValue(-1.0)) || (CINT && ci_opr1 == -1)) {
    // pow/powr/pown(x, -1.0) = 1.0/x
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1 / " << *opr0 << "\n");
    Constant *cnval = ConstantFP::get(eltType, 1.0);
    if (getVecSize(FInfo) > 1) {
      cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
    }
    Value *nval = B.CreateFDiv(cnval, opr0, "__powrecip");
    replaceCall(FPOp, nval);
    return true;
  }

  if (CF && (CF->isExactlyValue(0.5) || CF->isExactlyValue(-0.5))) {
    // pow[r](x, [-]0.5) = sqrt(x)
    bool issqrt = CF->isExactlyValue(0.5);
    if (FunctionCallee FPExpr =
            getFunction(M, AMDGPULibFunc(issqrt ? AMDGPULibFunc::EI_SQRT
                                                : AMDGPULibFunc::EI_RSQRT,
                                         FInfo))) {
      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << FInfo.getName()
                        << '(' << *opr0 << ")\n");
      Value *nval = CreateCallEx(B,FPExpr, opr0, issqrt ? "__pow2sqrt"
                                                        : "__pow2rsqrt");
      replaceCall(FPOp, nval);
      return true;
    }
  }

  if (!isUnsafeMath(FPOp))
    return false;

  // Unsafe Math optimization

  // Remember that ci_opr1 is set if opr1 is integral
  if (CF) {
    double dval = (getArgType(FInfo) == AMDGPULibFunc::F32)
                    ? (double)CF->getValueAPF().convertToFloat()
                    : CF->getValueAPF().convertToDouble();
    int ival = (int)dval;
    if ((double)ival == dval) {
      ci_opr1 = ival;
    } else
      // NOTE(review): this sentinel (0x11111111) differs from the 0x1111111
      // documented above; both exceed the |c| <= 12 bound below, so behavior
      // is unaffected — confirm and unify upstream.
      ci_opr1 = 0x11111111;
  }

  // pow/powr/pown(x, c) = [1/](x*x*..x); where
  //   trunc(c) == c && the number of x == c && |c| <= 12
  unsigned abs_opr1 = (ci_opr1 < 0) ? -ci_opr1 : ci_opr1;
  if (abs_opr1 <= 12) {
    Constant *cnval;
    Value *nval;
    if (abs_opr1 == 0) {
      cnval = ConstantFP::get(eltType, 1.0);
      if (getVecSize(FInfo) > 1) {
        cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
      }
      nval = cnval;
    } else {
      // Square-and-multiply chain for the small integral exponent.
      Value *valx2 = nullptr;
      nval = nullptr;
      while (abs_opr1 > 0) {
        valx2 = valx2 ? B.CreateFMul(valx2, valx2, "__powx2") : opr0;
        if (abs_opr1 & 1) {
          nval = nval ? B.CreateFMul(nval, valx2, "__powprod") : valx2;
        }
        abs_opr1 >>= 1;
      }
    }

    if (ci_opr1 < 0) {
      cnval = ConstantFP::get(eltType, 1.0);
      if (getVecSize(FInfo) > 1) {
        cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
      }
      nval = B.CreateFDiv(cnval, nval, "__1powprod");
    }
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
                      << ((ci_opr1 < 0) ? "1/prod(" : "prod(") << *opr0
                      << ")\n");
    replaceCall(FPOp, nval);
    return true;
  }

  // powr ---> exp2(y * log2(x))
  // pown/pow ---> powr(fabs(x), y) | (x & ((int)y << 31))
  FunctionCallee ExpExpr =
      getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_EXP2, FInfo));
  if (!ExpExpr)
    return false;

  bool needlog = false;
  bool needabs = false;
  bool needcopysign = false;
  Constant *cnval = nullptr;
  if (getVecSize(FInfo) == 1) {
    CF = dyn_cast<ConstantFP>(opr0);

    if (CF) {
      // Constant base: fold log2(|x|) at compile time.
      double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
                   ? (double)CF->getValueAPF().convertToFloat()
                   : CF->getValueAPF().convertToDouble();

      V = log2(std::abs(V));
      cnval = ConstantFP::get(eltType, V);
      needcopysign = (FInfo.getId() != AMDGPULibFunc::EI_POWR) &&
                     CF->isNegative();
    } else {
      needlog = true;
      needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR &&
                               (!CF || CF->isNegative());
    }
  } else {
    ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr0);

    if (!CDV) {
      needlog = true;
      needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR;
    } else {
      assert ((int)CDV->getNumElements() == getVecSize(FInfo) &&
              "Wrong vector size detected");

      SmallVector<double, 0> DVal;
      for (int i=0; i < getVecSize(FInfo); ++i) {
        double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
                     ? (double)CDV->getElementAsFloat(i)
                     : CDV->getElementAsDouble(i);
        if (V < 0.0) needcopysign = true;
        V = log2(std::abs(V));
        DVal.push_back(V);
      }
      if (getArgType(FInfo) == AMDGPULibFunc::F32) {
        SmallVector<float, 0> FVal;
        for (unsigned i=0; i < DVal.size(); ++i) {
          FVal.push_back((float)DVal[i]);
        }
        ArrayRef<float> tmp(FVal);
        cnval = ConstantDataVector::get(M->getContext(), tmp);
      } else {
        ArrayRef<double> tmp(DVal);
        cnval = ConstantDataVector::get(M->getContext(), tmp);
      }
    }
  }

  if (needcopysign && (FInfo.getId() == AMDGPULibFunc::EI_POW)) {
    // We cannot handle corner cases for a general pow() function, give up
    // unless y is a constant integral value. Then proceed as if it were pown.
    if (getVecSize(FInfo) == 1) {
      if (const ConstantFP *CF = dyn_cast<ConstantFP>(opr1)) {
        double y = (getArgType(FInfo) == AMDGPULibFunc::F32)
                     ? (double)CF->getValueAPF().convertToFloat()
                     : CF->getValueAPF().convertToDouble();
        if (y != (double)(int64_t)y)
          return false;
      } else
        return false;
    } else {
      if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1)) {
        for (int i=0; i < getVecSize(FInfo); ++i) {
          double y = (getArgType(FInfo) == AMDGPULibFunc::F32)
                       ? (double)CDV->getElementAsFloat(i)
                       : CDV->getElementAsDouble(i);
          if (y != (double)(int64_t)y)
            return false;
        }
      } else
        return false;
    }
  }

  Value *nval;
  if (needabs) {
    nval = B.CreateUnaryIntrinsic(Intrinsic::fabs, opr0, nullptr, "__fabs");
  } else {
    nval = cnval ? cnval : opr0;
  }
  if (needlog) {
    FunctionCallee LogExpr =
        getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_LOG2, FInfo));
    if (!LogExpr)
      return false;
    nval = CreateCallEx(B,LogExpr, nval, "__log2");
  }

  if (FInfo.getId() == AMDGPULibFunc::EI_POWN) {
    // convert int(32) to fp(f32 or f64)
    opr1 = B.CreateSIToFP(opr1, nval->getType(), "pownI2F");
  }
  nval = B.CreateFMul(opr1, nval, "__ylogx");
  nval = CreateCallEx(B,ExpExpr, nval, "__exp2");

  if (needcopysign) {
    // Reapply the base's sign: negative base with odd integral exponent
    // yields a negative result; done with integer bit manipulation.
    Value *opr_n;
    Type* rTy = opr0->getType();
    Type* nTyS = eltType->isDoubleTy() ? B.getInt64Ty() : B.getInt32Ty();
    Type *nTy = nTyS;
    if (const auto *vTy = dyn_cast<FixedVectorType>(rTy))
      nTy = FixedVectorType::get(nTyS, vTy);
    unsigned size = nTy->getScalarSizeInBits();
    opr_n = FPOp->getOperand(1);
    if (opr_n->getType()->isIntegerTy())
      opr_n = B.CreateZExtOrBitCast(opr_n, nTy, "__ytou");
    else
      // NOTE(review): converts opr1 (not opr_n) here; for pown the integer
      // exponent takes the branch above, so the two only differ when opr1
      // was not rewritten — confirm this is intentional.
      opr_n = B.CreateFPToSI(opr1, nTy, "__ytou");

    Value *sign = B.CreateShl(opr_n, size-1, "__yeven");
    sign = B.CreateAnd(B.CreateBitCast(opr0, nTy), sign, "__pow_sign");
    nval = B.CreateOr(B.CreateBitCast(nval, nTy), sign);
    nval = B.CreateBitCast(nval, opr0->getType());
  }

  LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
                    << "exp2(" << *opr1 << " * log2(" << *opr0 << "))\n");
  replaceCall(FPOp, nval);

  return true;
}

// Fold rootn(x, n) for n in {1, 2, 3, -1, -2}: identity, sqrt, cbrt,
// reciprocal, and rsqrt respectively. Scalar calls only.
bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
                                const FuncInfo &FInfo) {
  // skip vector function
  if (getVecSize(FInfo) != 1)
    return false;

  Value *opr0 = FPOp->getOperand(0);
  Value *opr1 = FPOp->getOperand(1);

  ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
  if (!CINT) {
    return false;
  }
  int ci_opr1 = (int)CINT->getSExtValue();
  if (ci_opr1 == 1) {  // rootn(x, 1) = x
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
    replaceCall(FPOp, opr0);
    return true;
  }

  Module *M = B.GetInsertBlock()->getModule();
  if (ci_opr1 == 2) {  // rootn(x, 2) = sqrt(x)
    if (FunctionCallee FPExpr =
            getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0
                        << ")\n");
      Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt");
      replaceCall(FPOp, nval);
      return true;
    }
  } else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
    if (FunctionCallee FPExpr =
            getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) {
      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> cbrt(" << *opr0
                        << ")\n");
      Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2cbrt");
      replaceCall(FPOp, nval);
      return true;
    }
  } else if (ci_opr1 == -1) { // rootn(x, -1) = 1.0/x
    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1.0 / " << *opr0 << "\n");
    Value *nval = B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0),
                               opr0,
                               "__rootn2div");
    replaceCall(FPOp, nval);
    return true;
  } else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
    if (FunctionCallee FPExpr =
            getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT, FInfo))) {
      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
                        << ")\n");
      Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt");
      replaceCall(FPOp, nval);
      return true;
    }
  }
  // Any other degree: leave the call alone.
  return false;
}

// Simplify fma/mad when one operand is a known 0.0 or 1.0 constant.
bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
                                  const FuncInfo &FInfo) {
  Value *opr0 = CI->getArgOperand(0);
  Value *opr1 = CI->getArgOperand(1);
  Value *opr2 = CI->getArgOperand(2);

  ConstantFP *CF0 = dyn_cast<ConstantFP>(opr0);
  ConstantFP *CF1 = dyn_cast<ConstantFP>(opr1);
  if ((CF0 && CF0->isZero()) || (CF1 && CF1->isZero())) {
    // fma/mad(a, b, c) = c if a=0 || b=0
    LLVM_DEBUG(errs() << "AMDIC: "
                      << *CI << " ---> " << *opr2 << "\n");
    replaceCall(CI, opr2);
    return true;
  }
  if (CF0 && CF0->isExactlyValue(1.0f)) {
    // fma/mad(a, b, c) = b+c if a=1
    LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr1 << " + " << *opr2
                      << "\n");
    Value *nval = B.CreateFAdd(opr1, opr2, "fmaadd");
    replaceCall(CI, nval);
    return true;
  }
  if (CF1 && CF1->isExactlyValue(1.0f)) {
    // fma/mad(a, b, c) = a+c if b=1
    LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " + " << *opr2
                      << "\n");
    Value *nval = B.CreateFAdd(opr0, opr2, "fmaadd");
    replaceCall(CI, nval);
    return true;
  }
  if (ConstantFP *CF = dyn_cast<ConstantFP>(opr2)) {
    if (CF->isZero()) {
      // fma/mad(a, b, c) = a*b if c=0
      LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " * "
                        << *opr1 << "\n");
      Value *nval = B.CreateFMul(opr0, opr1, "fmamul");
      replaceCall(CI, nval);
      return true;
    }
  }

  return false;
}

// Get a scalar native builtin single argument FP function
FunctionCallee AMDGPULibCalls::getNativeFunction(Module *M,
                                                 const FuncInfo &FInfo) {
  // native_* variants are f32-only; HasNative filters to builtins that
  // actually have a native counterpart.
  if (getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(FInfo.getId()))
    return nullptr;
  FuncInfo nf = FInfo;
  nf.setPrefix(AMDGPULibFunc::NATIVE);
  return getFunction(M, nf);
}

// fold sqrt -> native_sqrt (x)
bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
                               const FuncInfo &FInfo) {
  // The native approximation is only acceptable under unsafe/fast math.
  if (!isUnsafeMath(FPOp))
    return false;

  // Scalar f32 only, and never rewrite a call that is already native_*.
  if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
      (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
    Module *M = B.GetInsertBlock()->getModule();

    if (FunctionCallee FPExpr = getNativeFunction(
            M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
      Value *opr0 = FPOp->getOperand(0);
      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
                        << "sqrt(" << *opr0 << ")\n");
      Value *nval = CreateCallEx(B,FPExpr, opr0, "__sqrt");
      replaceCall(FPOp, nval);
      return true;
    }
  }
  return false;
}

// fold sin, cos -> sincos.
bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
                                 const FuncInfo &fInfo) {
  assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
         fInfo.getId() == AMDGPULibFunc::EI_COS);

  // Only plain (unprefixed) f32/f64 sin/cos are eligible.
  if ((getArgType(fInfo) != AMDGPULibFunc::F32 &&
       getArgType(fInfo) != AMDGPULibFunc::F64) ||
      fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
    return false;

  bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;

  Value *CArgVal = FPOp->getOperand(0);
  CallInst *CI = cast<CallInst>(FPOp);
  BasicBlock * const CBB = CI->getParent();

  // Bound how far above CI we scan for the partner call.
  int const MaxScan = 30;
  bool Changed = false;

  Module *M = CI->getModule();
  FuncInfo PartnerInfo(isSin ? AMDGPULibFunc::EI_COS : AMDGPULibFunc::EI_SIN,
                       fInfo);
  const std::string PairName = PartnerInfo.mangle();

  // Look for a matching cos (resp. sin) call on the same argument, in the
  // same basic block, at most MaxScan instructions before CI.
  CallInst *UI = nullptr;
  for (User* U : CArgVal->users()) {
    CallInst *XI = dyn_cast_or_null<CallInst>(U);
    if (!XI || XI == CI || XI->getParent() != CBB)
      continue;

    Function *UCallee = XI->getCalledFunction();
    if (!UCallee || !UCallee->getName().equals(PairName))
      continue;

    BasicBlock::iterator BBI = CI->getIterator();
    if (BBI == CI->getParent()->begin())
      break;
    --BBI;
    for (int I = MaxScan; I > 0 && BBI != CBB->begin(); --BBI, --I) {
      if (cast<Instruction>(BBI) == XI) {
        UI = XI;
        break;
      }
    }
    if (UI) break;
  }

  if (!UI)
    return Changed;

  // Merge the sin and cos.

  // for OpenCL 2.0 we have only generic implementation of sincos
  // function.
  AMDGPULibFunc nf(AMDGPULibFunc::EI_SINCOS, fInfo);
  nf.getLeads()[0].PtrKind = AMDGPULibFunc::getEPtrKindFromAddrSpace(AMDGPUAS::FLAT_ADDRESS);
  FunctionCallee Fsincos = getFunction(M, nf);
  if (!Fsincos)
    return Changed;

  BasicBlock::iterator ItOld = B.GetInsertPoint();
  // The cosine result comes back through this stack slot (see the reloads
  // below); sincos returns the sine directly.
  AllocaInst *Alloc = insertAlloca(UI, B, "__sincos_");
  B.SetInsertPoint(UI);

  Value *P = Alloc;
  Type *PTy = Fsincos.getFunctionType()->getParamType(1);
  // The allocaInst allocates the memory in private address space. This need
  // to be bitcasted to point to the address space of cos pointer type.
  // In OpenCL 2.0 this is generic, while in 1.2 that is private.
  if (PTy->getPointerAddressSpace() != AMDGPUAS::PRIVATE_ADDRESS)
    P = B.CreateAddrSpaceCast(Alloc, PTy);
  CallInst *Call = CreateCallEx2(B, Fsincos, UI->getArgOperand(0), P);

  LLVM_DEBUG(errs() << "AMDIC: fold_sincos (" << *CI << ", " << *UI << ") with "
                    << *Call << "\n");

  // In both branches: the sin user gets the call's return value and the cos
  // user gets a load of the output slot; then both originals are erased.
  if (!isSin) { // CI->cos, UI->sin
    B.SetInsertPoint(&*ItOld);
    UI->replaceAllUsesWith(&*Call);
    Instruction *Reload = B.CreateLoad(Alloc->getAllocatedType(), Alloc);
    CI->replaceAllUsesWith(Reload);
    UI->eraseFromParent();
    CI->eraseFromParent();
  } else { // CI->sin, UI->cos
    Instruction *Reload = B.CreateLoad(Alloc->getAllocatedType(), Alloc);
    UI->replaceAllUsesWith(Reload);
    CI->replaceAllUsesWith(Call);
    UI->eraseFromParent();
    CI->eraseFromParent();
  }
  return true;
}

// Get insertion point at entry.
BasicBlock::iterator AMDGPULibCalls::getEntryIns(CallInst * UI) {
  Function * Func = UI->getParent()->getParent();
  BasicBlock * BB = &Func->getEntryBlock();
  assert(BB && "Entry block not found!");
  BasicBlock::iterator ItNew = BB->begin();
  return ItNew;
}

// Insert an AllocaInst at the beginning of the function entry block.
AllocaInst* AMDGPULibCalls::insertAlloca(CallInst *UI, IRBuilder<> &B,
                                         const char *prefix) {
  BasicBlock::iterator ItNew = getEntryIns(UI);
  Function *UCallee = UI->getCalledFunction();
  Type *RetType = UCallee->getReturnType();
  B.SetInsertPoint(&*ItNew);
  AllocaInst *Alloc =
      B.CreateAlloca(RetType, nullptr, std::string(prefix) + UI->getName());
  // Align the slot to the ABI alloc size of the callee's return type.
  Alloc->setAlignment(
      Align(UCallee->getParent()->getDataLayout().getTypeAllocSize(RetType)));
  return Alloc;
}

// Evaluate one scalar library call on constant operands. Res0 receives the
// result; Res1 is additionally written for sincos. Returns false when the
// function cannot be constant-folded here.
bool AMDGPULibCalls::evaluateScalarMathFunc(const FuncInfo &FInfo,
                                            double& Res0, double& Res1,
                                            Constant *copr0, Constant *copr1,
                                            Constant *copr2) {
  // By default, opr0/opr1/opr2 hold the operands as float/double values.
  // Operands of any other type (e.g. pown/rootn's integer exponent) are
  // converted separately in the corresponding case below.
  double opr0=0.0, opr1=0.0, opr2=0.0;
  ConstantFP *fpopr0 = dyn_cast_or_null<ConstantFP>(copr0);
  ConstantFP *fpopr1 = dyn_cast_or_null<ConstantFP>(copr1);
  ConstantFP *fpopr2 = dyn_cast_or_null<ConstantFP>(copr2);
  if (fpopr0) {
    opr0 = (getArgType(FInfo) == AMDGPULibFunc::F64)
             ? fpopr0->getValueAPF().convertToDouble()
             : (double)fpopr0->getValueAPF().convertToFloat();
  }

  if (fpopr1) {
    opr1 = (getArgType(FInfo) == AMDGPULibFunc::F64)
             ? fpopr1->getValueAPF().convertToDouble()
             : (double)fpopr1->getValueAPF().convertToFloat();
  }

  if (fpopr2) {
    opr2 = (getArgType(FInfo) == AMDGPULibFunc::F64)
             ? fpopr2->getValueAPF().convertToDouble()
             : (double)fpopr2->getValueAPF().convertToFloat();
  }

  switch (FInfo.getId()) {
  default : return false;

  case AMDGPULibFunc::EI_ACOS:
    Res0 = acos(opr0);
    return true;

  case AMDGPULibFunc::EI_ACOSH:
    // acosh(x) == log(x + sqrt(x*x - 1))
    Res0 = log(opr0 + sqrt(opr0*opr0 - 1.0));
    return true;

  case AMDGPULibFunc::EI_ACOSPI:
    Res0 = acos(opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_ASIN:
    Res0 = asin(opr0);
    return true;

  case AMDGPULibFunc::EI_ASINH:
    // asinh(x) == log(x + sqrt(x*x + 1))
    Res0 = log(opr0 + sqrt(opr0*opr0 + 1.0));
    return true;

  case AMDGPULibFunc::EI_ASINPI:
    Res0 = asin(opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_ATAN:
    Res0 = atan(opr0);
    return true;

  case AMDGPULibFunc::EI_ATANH:
    // atanh(x) == (log(x+1) - log(x-1))/2;
    Res0 = (log(opr0 + 1.0) - log(opr0 - 1.0))/2.0;
    return true;

  case AMDGPULibFunc::EI_ATANPI:
    Res0 = atan(opr0) / MATH_PI;
    return true;

  case AMDGPULibFunc::EI_CBRT:
    // Preserve the sign for negative inputs; pow alone would not.
    Res0 = (opr0 < 0.0) ? -pow(-opr0, 1.0/3.0) : pow(opr0, 1.0/3.0);
    return true;

  case AMDGPULibFunc::EI_COS:
    Res0 = cos(opr0);
    return true;

  case AMDGPULibFunc::EI_COSH:
    Res0 = cosh(opr0);
    return true;

  case AMDGPULibFunc::EI_COSPI:
    Res0 = cos(MATH_PI * opr0);
    return true;

  case AMDGPULibFunc::EI_EXP:
    Res0 = exp(opr0);
    return true;

  case AMDGPULibFunc::EI_EXP2:
    Res0 = pow(2.0, opr0);
    return true;

  case AMDGPULibFunc::EI_EXP10:
    Res0 = pow(10.0, opr0);
    return true;

  case AMDGPULibFunc::EI_EXPM1:
    Res0 = exp(opr0) - 1.0;
    return true;

  case AMDGPULibFunc::EI_LOG:
    Res0 = log(opr0);
    return true;

  case AMDGPULibFunc::EI_LOG2:
    Res0 = log(opr0) / log(2.0);
    return true;

  case AMDGPULibFunc::EI_LOG10:
    Res0 = log(opr0) / log(10.0);
    return true;

  case AMDGPULibFunc::EI_RSQRT:
    Res0 = 1.0 / sqrt(opr0);
    return true;

  case AMDGPULibFunc::EI_SIN:
    Res0 = sin(opr0);
    return true;

  case AMDGPULibFunc::EI_SINH:
    Res0 = sinh(opr0);
    return true;

  case AMDGPULibFunc::EI_SINPI:
    Res0 = sin(MATH_PI * opr0);
    return true;

  case AMDGPULibFunc::EI_SQRT:
    Res0 = sqrt(opr0);
    return true;

  case AMDGPULibFunc::EI_TAN:
    Res0 = tan(opr0);
    return true;

  case AMDGPULibFunc::EI_TANH:
    Res0 = tanh(opr0);
    return true;

  case AMDGPULibFunc::EI_TANPI:
    Res0 = tan(MATH_PI * opr0);
    return true;

  case AMDGPULibFunc::EI_RECIP:
    Res0 = 1.0 / opr0;
    return true;

  // two-arg functions
  case AMDGPULibFunc::EI_DIVIDE:
    Res0 = opr0 / opr1;
    return true;

  case AMDGPULibFunc::EI_POW:
  case AMDGPULibFunc::EI_POWR:
    Res0 = pow(opr0, opr1);
    return true;

  case AMDGPULibFunc::EI_POWN: {
    // pown's exponent is an integer constant, not a ConstantFP.
    if (ConstantInt *iopr1 =
dyn_cast_or_null<ConstantInt>(copr1)) { 1436 double val = (double)iopr1->getSExtValue(); 1437 Res0 = pow(opr0, val); 1438 return true; 1439 } 1440 return false; 1441 } 1442 1443 case AMDGPULibFunc::EI_ROOTN: { 1444 if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(copr1)) { 1445 double val = (double)iopr1->getSExtValue(); 1446 Res0 = pow(opr0, 1.0 / val); 1447 return true; 1448 } 1449 return false; 1450 } 1451 1452 // with ptr arg 1453 case AMDGPULibFunc::EI_SINCOS: 1454 Res0 = sin(opr0); 1455 Res1 = cos(opr0); 1456 return true; 1457 1458 // three-arg functions 1459 case AMDGPULibFunc::EI_FMA: 1460 case AMDGPULibFunc::EI_MAD: 1461 Res0 = opr0 * opr1 + opr2; 1462 return true; 1463 } 1464 1465 return false; 1466 } 1467 1468 bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) { 1469 int numArgs = (int)aCI->arg_size(); 1470 if (numArgs > 3) 1471 return false; 1472 1473 Constant *copr0 = nullptr; 1474 Constant *copr1 = nullptr; 1475 Constant *copr2 = nullptr; 1476 if (numArgs > 0) { 1477 if ((copr0 = dyn_cast<Constant>(aCI->getArgOperand(0))) == nullptr) 1478 return false; 1479 } 1480 1481 if (numArgs > 1) { 1482 if ((copr1 = dyn_cast<Constant>(aCI->getArgOperand(1))) == nullptr) { 1483 if (FInfo.getId() != AMDGPULibFunc::EI_SINCOS) 1484 return false; 1485 } 1486 } 1487 1488 if (numArgs > 2) { 1489 if ((copr2 = dyn_cast<Constant>(aCI->getArgOperand(2))) == nullptr) 1490 return false; 1491 } 1492 1493 // At this point, all arguments to aCI are constants. 1494 1495 // max vector size is 16, and sincos will generate two results. 
1496 double DVal0[16], DVal1[16]; 1497 int FuncVecSize = getVecSize(FInfo); 1498 bool hasTwoResults = (FInfo.getId() == AMDGPULibFunc::EI_SINCOS); 1499 if (FuncVecSize == 1) { 1500 if (!evaluateScalarMathFunc(FInfo, DVal0[0], 1501 DVal1[0], copr0, copr1, copr2)) { 1502 return false; 1503 } 1504 } else { 1505 ConstantDataVector *CDV0 = dyn_cast_or_null<ConstantDataVector>(copr0); 1506 ConstantDataVector *CDV1 = dyn_cast_or_null<ConstantDataVector>(copr1); 1507 ConstantDataVector *CDV2 = dyn_cast_or_null<ConstantDataVector>(copr2); 1508 for (int i = 0; i < FuncVecSize; ++i) { 1509 Constant *celt0 = CDV0 ? CDV0->getElementAsConstant(i) : nullptr; 1510 Constant *celt1 = CDV1 ? CDV1->getElementAsConstant(i) : nullptr; 1511 Constant *celt2 = CDV2 ? CDV2->getElementAsConstant(i) : nullptr; 1512 if (!evaluateScalarMathFunc(FInfo, DVal0[i], 1513 DVal1[i], celt0, celt1, celt2)) { 1514 return false; 1515 } 1516 } 1517 } 1518 1519 LLVMContext &context = aCI->getContext(); 1520 Constant *nval0, *nval1; 1521 if (FuncVecSize == 1) { 1522 nval0 = ConstantFP::get(aCI->getType(), DVal0[0]); 1523 if (hasTwoResults) 1524 nval1 = ConstantFP::get(aCI->getType(), DVal1[0]); 1525 } else { 1526 if (getArgType(FInfo) == AMDGPULibFunc::F32) { 1527 SmallVector <float, 0> FVal0, FVal1; 1528 for (int i = 0; i < FuncVecSize; ++i) 1529 FVal0.push_back((float)DVal0[i]); 1530 ArrayRef<float> tmp0(FVal0); 1531 nval0 = ConstantDataVector::get(context, tmp0); 1532 if (hasTwoResults) { 1533 for (int i = 0; i < FuncVecSize; ++i) 1534 FVal1.push_back((float)DVal1[i]); 1535 ArrayRef<float> tmp1(FVal1); 1536 nval1 = ConstantDataVector::get(context, tmp1); 1537 } 1538 } else { 1539 ArrayRef<double> tmp0(DVal0); 1540 nval0 = ConstantDataVector::get(context, tmp0); 1541 if (hasTwoResults) { 1542 ArrayRef<double> tmp1(DVal1); 1543 nval1 = ConstantDataVector::get(context, tmp1); 1544 } 1545 } 1546 } 1547 1548 if (hasTwoResults) { 1549 // sincos 1550 assert(FInfo.getId() == AMDGPULibFunc::EI_SINCOS && 1551 "math 
function with ptr arg not supported yet"); 1552 new StoreInst(nval1, aCI->getArgOperand(1), aCI); 1553 } 1554 1555 replaceCall(aCI, nval0); 1556 return true; 1557 } 1558 1559 PreservedAnalyses AMDGPUSimplifyLibCallsPass::run(Function &F, 1560 FunctionAnalysisManager &AM) { 1561 AMDGPULibCalls Simplifier; 1562 Simplifier.initNativeFuncs(); 1563 Simplifier.initFunction(F); 1564 1565 bool Changed = false; 1566 1567 LLVM_DEBUG(dbgs() << "AMDIC: process function "; 1568 F.printAsOperand(dbgs(), false, F.getParent()); dbgs() << '\n';); 1569 1570 for (auto &BB : F) { 1571 for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) { 1572 // Ignore non-calls. 1573 CallInst *CI = dyn_cast<CallInst>(I); 1574 ++I; 1575 1576 if (CI) { 1577 if (Simplifier.fold(CI)) 1578 Changed = true; 1579 } 1580 } 1581 } 1582 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); 1583 } 1584 1585 PreservedAnalyses AMDGPUUseNativeCallsPass::run(Function &F, 1586 FunctionAnalysisManager &AM) { 1587 if (UseNative.empty()) 1588 return PreservedAnalyses::all(); 1589 1590 AMDGPULibCalls Simplifier; 1591 Simplifier.initNativeFuncs(); 1592 Simplifier.initFunction(F); 1593 1594 bool Changed = false; 1595 for (auto &BB : F) { 1596 for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) { 1597 // Ignore non-calls. 1598 CallInst *CI = dyn_cast<CallInst>(I); 1599 ++I; 1600 if (CI && Simplifier.useNative(CI)) 1601 Changed = true; 1602 } 1603 } 1604 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); 1605 } 1606