//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that NVPTX uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//

#include "NVPTXISelLowering.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "NVPTXSubtarget.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXTargetObjectFile.h"
#include "NVPTXUtilities.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetCallingConv.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/FPEnv.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <iterator>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#define DEBUG_TYPE "nvptx-lower"

using namespace llvm;

static std::atomic<unsigned> GlobalUniqueCallSite;

static cl::opt<bool> sched4reg(
    "nvptx-sched4reg",
    cl::desc("NVPTX Specific: schedule for register pressure"), cl::init(false));

static cl::opt<unsigned> FMAContractLevelOpt(
    "nvptx-fma-level", cl::Hidden,
    cl::desc("NVPTX Specific: FMA contraction (0: don't do it,"
             " 1: do it, 2: do it aggressively)"),
    cl::init(2));

static cl::opt<int> UsePrecDivF32(
    "nvptx-prec-divf32", cl::Hidden,
    cl::desc("NVPTX Specific: 0 use div.approx, 1 use div.full, 2 use"
             " IEEE Compliant F32 div.rnd if available."),
    cl::init(2));

static cl::opt<bool> UsePrecSqrtF32(
    "nvptx-prec-sqrtf32", cl::Hidden,
    cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
    cl::init(true));

/// Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it
/// does NOT use lg2.approx for log2, so this is disabled by default.
99 static cl::opt<bool> UseApproxLog2F32( 100 "nvptx-approx-log2f32", 101 cl::desc("NVPTX Specific: whether to use lg2.approx for log2"), 102 cl::init(false)); 103 104 static cl::opt<bool> ForceMinByValParamAlign( 105 "nvptx-force-min-byval-param-align", cl::Hidden, 106 cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval" 107 " params of device functions."), 108 cl::init(false)); 109 110 int NVPTXTargetLowering::getDivF32Level() const { 111 if (UsePrecDivF32.getNumOccurrences() > 0) { 112 // If nvptx-prec-div32=N is used on the command-line, always honor it 113 return UsePrecDivF32; 114 } else { 115 // Otherwise, use div.approx if fast math is enabled 116 if (getTargetMachine().Options.UnsafeFPMath) 117 return 0; 118 else 119 return 2; 120 } 121 } 122 123 bool NVPTXTargetLowering::usePrecSqrtF32() const { 124 if (UsePrecSqrtF32.getNumOccurrences() > 0) { 125 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it 126 return UsePrecSqrtF32; 127 } else { 128 // Otherwise, use sqrt.approx if fast math is enabled 129 return !getTargetMachine().Options.UnsafeFPMath; 130 } 131 } 132 133 bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const { 134 return MF.getDenormalMode(APFloat::IEEEsingle()).Output == 135 DenormalMode::PreserveSign; 136 } 137 138 static bool IsPTXVectorType(MVT VT) { 139 switch (VT.SimpleTy) { 140 default: 141 return false; 142 case MVT::v2i1: 143 case MVT::v4i1: 144 case MVT::v2i8: 145 case MVT::v4i8: 146 case MVT::v8i8: // <2 x i8x4> 147 case MVT::v16i8: // <4 x i8x4> 148 case MVT::v2i16: 149 case MVT::v4i16: 150 case MVT::v8i16: // <4 x i16x2> 151 case MVT::v2i32: 152 case MVT::v4i32: 153 case MVT::v2i64: 154 case MVT::v2f16: 155 case MVT::v4f16: 156 case MVT::v8f16: // <4 x f16x2> 157 case MVT::v2bf16: 158 case MVT::v4bf16: 159 case MVT::v8bf16: // <4 x bf16x2> 160 case MVT::v2f32: 161 case MVT::v4f32: 162 case MVT::v2f64: 163 return true; 164 } 165 } 166 167 static bool Is16bitsType(MVT VT) { 168 return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 || 169 VT.SimpleTy == MVT::i16); 170 } 171 172 // When legalizing vector loads/stores, this function is called, which does two 173 // things: 174 // 1. Determines Whether the vector is something we want to custom lower, 175 // std::nullopt is returned if we do not want to custom lower it. 176 // 2. If we do want to handle it, returns two parameters: 177 // - unsigned int NumElts - The number of elements in the final vector 178 // - EVT EltVT - The type of the elements in the final vector 179 static std::optional<std::pair<unsigned int, EVT>> 180 getVectorLoweringShape(EVT VectorVT) { 181 if (!VectorVT.isVector() || !VectorVT.isSimple()) 182 return std::nullopt; 183 184 EVT EltVT = VectorVT.getVectorElementType(); 185 unsigned NumElts = VectorVT.getVectorNumElements(); 186 187 // We only handle "native" vector sizes for now, e.g. <4 x double> is not 188 // legal. We can (and should) split that into 2 stores of <2 x double> here 189 // but I'm leaving that as a TODO for now. 
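  // Illustrative examples of the mapping performed by the switch below
  // (derived from its cases; not exhaustive):
  //   <4 x float>  -> (4, f32)     : already a "native" PTX vector shape
  //   <8 x half>   -> (4, v2f16)   : repacked as four f16x2 (b32) words
  //   <16 x i8>    -> (4, v4i8)    : repacked as four i8x4 (b32) words
  //   <4 x double> -> std::nullopt : not custom lowered here (see TODO above)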
190 switch (VectorVT.getSimpleVT().SimpleTy) { 191 default: 192 return std::nullopt; 193 case MVT::v2i8: 194 case MVT::v2i16: 195 case MVT::v2i32: 196 case MVT::v2i64: 197 case MVT::v2f16: 198 case MVT::v2bf16: 199 case MVT::v2f32: 200 case MVT::v2f64: 201 case MVT::v4i8: 202 case MVT::v4i16: 203 case MVT::v4i32: 204 case MVT::v4f16: 205 case MVT::v4bf16: 206 case MVT::v4f32: 207 // This is a "native" vector type 208 return std::pair(NumElts, EltVT); 209 case MVT::v8i8: // <2 x i8x4> 210 case MVT::v8f16: // <4 x f16x2> 211 case MVT::v8bf16: // <4 x bf16x2> 212 case MVT::v8i16: // <4 x i16x2> 213 case MVT::v16i8: // <4 x i8x4> 214 // This can be upsized into a "native" vector type. 215 // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for 216 // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use 217 // vectorized loads/stores with the actual element type for i8/i16 as that 218 // would require v8/v16 variants that do not exist. 219 // In order to load/store such vectors efficiently, here in Type 220 // Legalization, we split the vector into word-sized chunks (v2x16/v4i8). 221 // Later, we will lower to PTX as vectors of b32. 222 223 // Number of elements to pack in one word. 224 unsigned NPerWord = 32 / EltVT.getSizeInBits(); 225 226 return std::pair(NumElts / NPerWord, 227 MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord)); 228 } 229 230 llvm_unreachable("All cases in switch should return."); 231 } 232 233 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive 234 /// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors 235 /// into their primitive components. 236 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the 237 /// same number of types as the Ins/Outs arrays in LowerFormalArguments, 238 /// LowerCall, and LowerReturn. 239 static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL, 240 Type *Ty, SmallVectorImpl<EVT> &ValueVTs, 241 SmallVectorImpl<uint64_t> *Offsets = nullptr, 242 uint64_t StartingOffset = 0) { 243 SmallVector<EVT, 16> TempVTs; 244 SmallVector<uint64_t, 16> TempOffsets; 245 246 // Special case for i128 - decompose to (i64, i64) 247 if (Ty->isIntegerTy(128)) { 248 ValueVTs.push_back(EVT(MVT::i64)); 249 ValueVTs.push_back(EVT(MVT::i64)); 250 251 if (Offsets) { 252 Offsets->push_back(StartingOffset + 0); 253 Offsets->push_back(StartingOffset + 8); 254 } 255 256 return; 257 } 258 259 // Given a struct type, recursively traverse the elements with custom ComputePTXValueVTs. 260 if (StructType *STy = dyn_cast<StructType>(Ty)) { 261 auto const *SL = DL.getStructLayout(STy); 262 auto ElementNum = 0; 263 for(auto *EI : STy->elements()) { 264 ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets, 265 StartingOffset + SL->getElementOffset(ElementNum)); 266 ++ElementNum; 267 } 268 return; 269 } 270 271 // Given an array type, recursively traverse the elements with custom ComputePTXValueVTs. 
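  // For example (illustrative): an argument of IR type [4 x i32] starting at
  // offset 0 is flattened into four i32 ValueVTs at offsets 0, 4, 8 and 12.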
272 if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) { 273 Type *EltTy = ATy->getElementType(); 274 uint64_t EltSize = DL.getTypeAllocSize(EltTy); 275 for (int I : llvm::seq<int>(ATy->getNumElements())) 276 ComputePTXValueVTs(TLI, DL, EltTy, ValueVTs, Offsets, StartingOffset + I * EltSize); 277 return; 278 } 279 280 ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset); 281 for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) { 282 EVT VT = TempVTs[i]; 283 uint64_t Off = TempOffsets[i]; 284 // Split vectors into individual elements, except for v2f16, which 285 // we will pass as a single scalar. 286 if (VT.isVector()) { 287 unsigned NumElts = VT.getVectorNumElements(); 288 EVT EltVT = VT.getVectorElementType(); 289 // We require power-of-2 sized vectors becuase 290 // TargetLoweringBase::getVectorTypeBreakdown() which is invoked in 291 // ComputePTXValueVTs() cannot currently break down non-power-of-2 sized 292 // vectors. 293 if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 && 294 isPowerOf2_32(NumElts)) { 295 // Vectors with an even number of f16 elements will be passed to 296 // us as an array of v2f16/v2bf16 elements. We must match this so we 297 // stay in sync with Ins/Outs. 298 switch (EltVT.getSimpleVT().SimpleTy) { 299 case MVT::f16: 300 EltVT = MVT::v2f16; 301 break; 302 case MVT::bf16: 303 EltVT = MVT::v2bf16; 304 break; 305 case MVT::i16: 306 EltVT = MVT::v2i16; 307 break; 308 default: 309 llvm_unreachable("Unexpected type"); 310 } 311 NumElts /= 2; 312 } else if (EltVT.getSimpleVT() == MVT::i8 && 313 ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) || 314 NumElts == 3)) { 315 // v*i8 are formally lowered as v4i8 316 EltVT = MVT::v4i8; 317 NumElts = (NumElts + 3) / 4; 318 } else if (EltVT.getSimpleVT() == MVT::i8 && NumElts == 2) { 319 // v2i8 is promoted to v2i16 320 NumElts = 1; 321 EltVT = MVT::v2i16; 322 } 323 for (unsigned j = 0; j != NumElts; ++j) { 324 ValueVTs.push_back(EltVT); 325 if (Offsets) 326 Offsets->push_back(Off + j * EltVT.getStoreSize()); 327 } 328 } else { 329 ValueVTs.push_back(VT); 330 if (Offsets) 331 Offsets->push_back(Off); 332 } 333 } 334 } 335 336 /// PromoteScalarIntegerPTX 337 /// Used to make sure the arguments/returns are suitable for passing 338 /// and promote them to a larger size if they're not. 339 /// 340 /// The promoted type is placed in \p PromoteVT if the function returns true. 341 static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) { 342 if (VT.isScalarInteger()) { 343 switch (PowerOf2Ceil(VT.getFixedSizeInBits())) { 344 default: 345 llvm_unreachable( 346 "Promotion is not suitable for scalars of size larger than 64-bits"); 347 case 1: 348 *PromotedVT = MVT::i1; 349 break; 350 case 2: 351 case 4: 352 case 8: 353 *PromotedVT = MVT::i8; 354 break; 355 case 16: 356 *PromotedVT = MVT::i16; 357 break; 358 case 32: 359 *PromotedVT = MVT::i32; 360 break; 361 case 64: 362 *PromotedVT = MVT::i64; 363 break; 364 } 365 return EVT(*PromotedVT) != VT; 366 } 367 return false; 368 } 369 370 // Check whether we can merge loads/stores of some of the pieces of a 371 // flattened function parameter or return value into a single vector 372 // load/store. 373 // 374 // The flattened parameter is represented as a list of EVTs and 375 // offsets, and the whole structure is aligned to ParamAlignment. This 376 // function determines whether we can load/store pieces of the 377 // parameter starting at index Idx using a single vectorized op of 378 // size AccessSize. 
If so, it returns the number of param pieces 379 // covered by the vector op. Otherwise, it returns 1. 380 static unsigned CanMergeParamLoadStoresStartingAt( 381 unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs, 382 const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) { 383 384 // Can't vectorize if param alignment is not sufficient. 385 if (ParamAlignment < AccessSize) 386 return 1; 387 // Can't vectorize if offset is not aligned. 388 if (Offsets[Idx] & (AccessSize - 1)) 389 return 1; 390 391 EVT EltVT = ValueVTs[Idx]; 392 unsigned EltSize = EltVT.getStoreSize(); 393 394 // Element is too large to vectorize. 395 if (EltSize >= AccessSize) 396 return 1; 397 398 unsigned NumElts = AccessSize / EltSize; 399 // Can't vectorize if AccessBytes if not a multiple of EltSize. 400 if (AccessSize != EltSize * NumElts) 401 return 1; 402 403 // We don't have enough elements to vectorize. 404 if (Idx + NumElts > ValueVTs.size()) 405 return 1; 406 407 // PTX ISA can only deal with 2- and 4-element vector ops. 408 if (NumElts != 4 && NumElts != 2) 409 return 1; 410 411 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) { 412 // Types do not match. 413 if (ValueVTs[j] != EltVT) 414 return 1; 415 416 // Elements are not contiguous. 417 if (Offsets[j] - Offsets[j - 1] != EltSize) 418 return 1; 419 } 420 // OK. We can vectorize ValueVTs[i..i+NumElts) 421 return NumElts; 422 } 423 424 // Flags for tracking per-element vectorization state of loads/stores 425 // of a flattened function parameter or return value. 426 enum ParamVectorizationFlags { 427 PVF_INNER = 0x0, // Middle elements of a vector. 428 PVF_FIRST = 0x1, // First element of the vector. 429 PVF_LAST = 0x2, // Last element of the vector. 430 // Scalar is effectively a 1-element vector. 431 PVF_SCALAR = PVF_FIRST | PVF_LAST 432 }; 433 434 // Computes whether and how we can vectorize the loads/stores of a 435 // flattened function parameter or return value. 436 // 437 // The flattened parameter is represented as the list of ValueVTs and 438 // Offsets, and is aligned to ParamAlignment bytes. We return a vector 439 // of the same size as ValueVTs indicating how each piece should be 440 // loaded/stored (i.e. as a scalar, or as part of a vector 441 // load/store). 442 static SmallVector<ParamVectorizationFlags, 16> 443 VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs, 444 const SmallVectorImpl<uint64_t> &Offsets, 445 Align ParamAlignment, bool IsVAArg = false) { 446 // Set vector size to match ValueVTs and mark all elements as 447 // scalars by default. 448 SmallVector<ParamVectorizationFlags, 16> VectorInfo; 449 VectorInfo.assign(ValueVTs.size(), PVF_SCALAR); 450 451 if (IsVAArg) 452 return VectorInfo; 453 454 // Check what we can vectorize using 128/64/32-bit accesses. 455 for (int I = 0, E = ValueVTs.size(); I != E; ++I) { 456 // Skip elements we've already processed. 457 assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state."); 458 for (unsigned AccessSize : {16, 8, 4, 2}) { 459 unsigned NumElts = CanMergeParamLoadStoresStartingAt( 460 I, AccessSize, ValueVTs, Offsets, ParamAlignment); 461 // Mark vectorized elements. 462 switch (NumElts) { 463 default: 464 llvm_unreachable("Unexpected return value"); 465 case 1: 466 // Can't vectorize using this size, try next smaller size. 
467 continue; 468 case 2: 469 assert(I + 1 < E && "Not enough elements."); 470 VectorInfo[I] = PVF_FIRST; 471 VectorInfo[I + 1] = PVF_LAST; 472 I += 1; 473 break; 474 case 4: 475 assert(I + 3 < E && "Not enough elements."); 476 VectorInfo[I] = PVF_FIRST; 477 VectorInfo[I + 1] = PVF_INNER; 478 VectorInfo[I + 2] = PVF_INNER; 479 VectorInfo[I + 3] = PVF_LAST; 480 I += 3; 481 break; 482 } 483 // Break out of the inner loop because we've already succeeded 484 // using largest possible AccessSize. 485 break; 486 } 487 } 488 return VectorInfo; 489 } 490 491 static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT, 492 SDValue Value) { 493 if (Value->getValueType(0) == VT) 494 return Value; 495 return DAG.getNode(ISD::BITCAST, DL, VT, Value); 496 } 497 498 // NVPTXTargetLowering Constructor. 499 NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, 500 const NVPTXSubtarget &STI) 501 : TargetLowering(TM), nvTM(&TM), STI(STI) { 502 // always lower memset, memcpy, and memmove intrinsics to load/store 503 // instructions, rather 504 // then generating calls to memset, mempcy or memmove. 505 MaxStoresPerMemset = MaxStoresPerMemsetOptSize = (unsigned)0xFFFFFFFF; 506 MaxStoresPerMemcpy = MaxStoresPerMemcpyOptSize = (unsigned) 0xFFFFFFFF; 507 MaxStoresPerMemmove = MaxStoresPerMemmoveOptSize = (unsigned) 0xFFFFFFFF; 508 509 setBooleanContents(ZeroOrNegativeOneBooleanContent); 510 setBooleanVectorContents(ZeroOrNegativeOneBooleanContent); 511 512 // Jump is Expensive. Don't create extra control flow for 'and', 'or' 513 // condition branches. 514 setJumpIsExpensive(true); 515 516 // Wide divides are _very_ slow. Try to reduce the width of the divide if 517 // possible. 518 addBypassSlowDiv(64, 32); 519 520 // By default, use the Source scheduling 521 if (sched4reg) 522 setSchedulingPreference(Sched::RegPressure); 523 else 524 setSchedulingPreference(Sched::Source); 525 526 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action, 527 LegalizeAction NoF16Action) { 528 bool IsOpSupported = STI.allowFP16Math(); 529 switch (Op) { 530 // Several FP16 instructions are available on sm_80 only. 531 case ISD::FMINNUM: 532 case ISD::FMAXNUM: 533 case ISD::FMAXNUM_IEEE: 534 case ISD::FMINNUM_IEEE: 535 case ISD::FMAXIMUM: 536 case ISD::FMINIMUM: 537 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70; 538 break; 539 case ISD::FEXP2: 540 IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70; 541 break; 542 } 543 setOperationAction(Op, VT, IsOpSupported ? Action : NoF16Action); 544 }; 545 546 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action, 547 LegalizeAction NoBF16Action) { 548 bool IsOpSupported = STI.hasNativeBF16Support(Op); 549 setOperationAction( 550 Op, VT, IsOpSupported ? Action : NoBF16Action); 551 }; 552 553 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action, 554 LegalizeAction NoI16x2Action) { 555 bool IsOpSupported = false; 556 // instructions are available on sm_90 only 557 switch (Op) { 558 case ISD::ADD: 559 case ISD::SMAX: 560 case ISD::SMIN: 561 case ISD::UMIN: 562 case ISD::UMAX: 563 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80; 564 break; 565 } 566 setOperationAction(Op, VT, IsOpSupported ? 
Action : NoI16x2Action); 567 }; 568 569 addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass); 570 addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass); 571 addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass); 572 addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass); 573 addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass); 574 addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass); 575 addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass); 576 addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass); 577 addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass); 578 addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass); 579 addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass); 580 addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass); 581 582 // Conversion to/from FP16/FP16x2 is always legal. 583 setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom); 584 setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f16, Custom); 585 setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f16, Expand); 586 setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f16, Expand); 587 588 setOperationAction(ISD::READCYCLECOUNTER, MVT::i64, Legal); 589 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31) 590 setOperationAction(ISD::READSTEADYCOUNTER, MVT::i64, Legal); 591 592 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote); 593 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand); 594 595 // Conversion to/from BFP16/BFP16x2 is always legal. 596 setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Custom); 597 setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2bf16, Custom); 598 setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2bf16, Expand); 599 setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2bf16, Expand); 600 601 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand); 602 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote); 603 if (getOperationAction(ISD::SETCC, MVT::bf16) == Promote) 604 AddPromotedToType(ISD::SETCC, MVT::bf16, MVT::f32); 605 606 // Conversion to/from i16/i16x2 is always legal. 607 setOperationAction(ISD::BUILD_VECTOR, MVT::v2i16, Custom); 608 setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2i16, Custom); 609 setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand); 610 setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand); 611 612 setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom); 613 setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom); 614 setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom); 615 setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom); 616 617 // Custom conversions to/from v2i8. 618 setOperationAction(ISD::BITCAST, MVT::v2i8, Custom); 619 620 // Only logical ops can be done on v4i8 directly, others must be done 621 // elementwise. 
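  // For example (illustrative): an ISD::ADD of two v4i8 values is not kept as
  // a packed operation; the Expand action below unrolls it into per-element
  // i8 additions during legalization.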
622 setOperationAction( 623 {ISD::ABS, ISD::ADD, ISD::ADDC, ISD::ADDE, 624 ISD::BITREVERSE, ISD::CTLZ, ISD::CTPOP, ISD::CTTZ, 625 ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FSHL, ISD::FSHR, 626 ISD::MUL, ISD::MULHS, ISD::MULHU, ISD::PARITY, 627 ISD::ROTL, ISD::ROTR, ISD::SADDO, ISD::SADDO_CARRY, 628 ISD::SADDSAT, ISD::SDIV, ISD::SDIVREM, ISD::SELECT_CC, 629 ISD::SETCC, ISD::SHL, ISD::SINT_TO_FP, ISD::SMAX, 630 ISD::SMIN, ISD::SMULO, ISD::SMUL_LOHI, ISD::SRA, 631 ISD::SREM, ISD::SRL, ISD::SSHLSAT, ISD::SSUBO, 632 ISD::SSUBO_CARRY, ISD::SSUBSAT, ISD::SUB, ISD::SUBC, 633 ISD::SUBE, ISD::UADDO, ISD::UADDO_CARRY, ISD::UADDSAT, 634 ISD::UDIV, ISD::UDIVREM, ISD::UINT_TO_FP, ISD::UMAX, 635 ISD::UMIN, ISD::UMULO, ISD::UMUL_LOHI, ISD::UREM, 636 ISD::USHLSAT, ISD::USUBO, ISD::USUBO_CARRY, ISD::VSELECT, 637 ISD::USUBSAT}, 638 MVT::v4i8, Expand); 639 640 // Operations not directly supported by NVPTX. 641 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32, 642 MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8, 643 MVT::i32, MVT::i64}) { 644 setOperationAction(ISD::SELECT_CC, VT, Expand); 645 setOperationAction(ISD::BR_CC, VT, Expand); 646 } 647 648 // Some SIGN_EXTEND_INREG can be done using cvt instruction. 649 // For others we will expand to a SHL/SRA pair. 650 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal); 651 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Legal); 652 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal); 653 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8 , Legal); 654 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand); 655 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::v2i16, Expand); 656 657 setOperationAction(ISD::SHL_PARTS, MVT::i32 , Custom); 658 setOperationAction(ISD::SRA_PARTS, MVT::i32 , Custom); 659 setOperationAction(ISD::SRL_PARTS, MVT::i32 , Custom); 660 setOperationAction(ISD::SHL_PARTS, MVT::i64 , Custom); 661 setOperationAction(ISD::SRA_PARTS, MVT::i64 , Custom); 662 setOperationAction(ISD::SRL_PARTS, MVT::i64 , Custom); 663 664 setOperationAction(ISD::BITREVERSE, MVT::i32, Legal); 665 setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); 666 667 setOperationAction({ISD::ROTL, ISD::ROTR}, 668 {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}, 669 Expand); 670 671 if (STI.hasHWROT32()) 672 setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal); 673 674 setOperationAction(ISD::BSWAP, MVT::i16, Expand); 675 676 setOperationAction(ISD::BR_JT, MVT::Other, Custom); 677 setOperationAction(ISD::BRIND, MVT::Other, Expand); 678 679 setOperationAction(ISD::GlobalAddress, MVT::i32, Custom); 680 setOperationAction(ISD::GlobalAddress, MVT::i64, Custom); 681 682 // We want to legalize constant related memmove and memcopy 683 // intrinsics. 
684 setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom); 685 686 // Turn FP extload into load/fpextend 687 setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand); 688 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand); 689 setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand); 690 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand); 691 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand); 692 setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand); 693 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand); 694 setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand); 695 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand); 696 setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand); 697 setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand); 698 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand); 699 setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand); 700 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand); 701 setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand); 702 setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand); 703 setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand); 704 setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand); 705 setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand); 706 // Turn FP truncstore into trunc + store. 707 // FIXME: vector types should also be expanded 708 setTruncStoreAction(MVT::f32, MVT::f16, Expand); 709 setTruncStoreAction(MVT::f64, MVT::f16, Expand); 710 setTruncStoreAction(MVT::f32, MVT::bf16, Expand); 711 setTruncStoreAction(MVT::f64, MVT::bf16, Expand); 712 setTruncStoreAction(MVT::f64, MVT::f32, Expand); 713 714 // PTX does not support load / store predicate registers 715 setOperationAction(ISD::LOAD, MVT::i1, Custom); 716 setOperationAction(ISD::STORE, MVT::i1, Custom); 717 718 for (MVT VT : MVT::integer_valuetypes()) { 719 setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Promote); 720 setLoadExtAction(ISD::ZEXTLOAD, VT, MVT::i1, Promote); 721 setLoadExtAction(ISD::EXTLOAD, VT, MVT::i1, Promote); 722 setTruncStoreAction(VT, MVT::i1, Expand); 723 } 724 725 setCondCodeAction({ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE, 726 ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT, 727 ISD::SETGE, ISD::SETLE}, 728 MVT::i1, Expand); 729 730 // expand extload of vector of integers. 
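  // For example (illustrative): an extending load from <2 x i8> to <2 x i16>
  // is not matched as a single extload; it is legalized as a non-extending
  // load followed by a separate extension.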
731 setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16, 732 MVT::v2i8, Expand); 733 setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand); 734 735 // This is legal in NVPTX 736 setOperationAction(ISD::ConstantFP, MVT::f64, Legal); 737 setOperationAction(ISD::ConstantFP, MVT::f32, Legal); 738 setOperationAction(ISD::ConstantFP, MVT::f16, Legal); 739 setOperationAction(ISD::ConstantFP, MVT::bf16, Legal); 740 741 setOperationAction(ISD::DYNAMIC_STACKALLOC, {MVT::i32, MVT::i64}, Custom); 742 setOperationAction({ISD::STACKRESTORE, ISD::STACKSAVE}, MVT::Other, Custom); 743 744 // TRAP can be lowered to PTX trap 745 setOperationAction(ISD::TRAP, MVT::Other, Legal); 746 // DEBUGTRAP can be lowered to PTX brkpt 747 setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal); 748 749 // Register custom handling for vector loads/stores 750 for (MVT VT : MVT::fixedlen_vector_valuetypes()) { 751 if (IsPTXVectorType(VT)) { 752 setOperationAction(ISD::LOAD, VT, Custom); 753 setOperationAction(ISD::STORE, VT, Custom); 754 setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom); 755 } 756 } 757 758 // Support varargs. 759 setOperationAction(ISD::VASTART, MVT::Other, Custom); 760 setOperationAction(ISD::VAARG, MVT::Other, Custom); 761 setOperationAction(ISD::VACOPY, MVT::Other, Expand); 762 setOperationAction(ISD::VAEND, MVT::Other, Expand); 763 764 // Custom handling for i8 intrinsics 765 setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom); 766 767 for (const auto& Ty : {MVT::i16, MVT::i32, MVT::i64}) { 768 setOperationAction(ISD::ABS, Ty, Legal); 769 setOperationAction(ISD::SMIN, Ty, Legal); 770 setOperationAction(ISD::SMAX, Ty, Legal); 771 setOperationAction(ISD::UMIN, Ty, Legal); 772 setOperationAction(ISD::UMAX, Ty, Legal); 773 774 setOperationAction(ISD::CTPOP, Ty, Legal); 775 setOperationAction(ISD::CTLZ, Ty, Legal); 776 } 777 778 setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom); 779 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom); 780 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom); 781 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom); 782 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom); 783 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand); 784 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand); 785 786 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom); 787 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom); 788 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom); 789 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom); 790 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom); 791 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom); 792 793 // Other arithmetic and logic ops are unsupported. 
794 setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS, 795 ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT, 796 ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC}, 797 MVT::v2i16, Expand); 798 799 setOperationAction(ISD::ADDC, MVT::i32, Legal); 800 setOperationAction(ISD::ADDE, MVT::i32, Legal); 801 setOperationAction(ISD::SUBC, MVT::i32, Legal); 802 setOperationAction(ISD::SUBE, MVT::i32, Legal); 803 if (STI.getPTXVersion() >= 43) { 804 setOperationAction(ISD::ADDC, MVT::i64, Legal); 805 setOperationAction(ISD::ADDE, MVT::i64, Legal); 806 setOperationAction(ISD::SUBC, MVT::i64, Legal); 807 setOperationAction(ISD::SUBE, MVT::i64, Legal); 808 } 809 810 setOperationAction(ISD::CTTZ, MVT::i16, Expand); 811 setOperationAction(ISD::CTTZ, MVT::v2i16, Expand); 812 setOperationAction(ISD::CTTZ, MVT::i32, Expand); 813 setOperationAction(ISD::CTTZ, MVT::i64, Expand); 814 815 // PTX does not directly support SELP of i1, so promote to i32 first 816 setOperationAction(ISD::SELECT, MVT::i1, Custom); 817 818 // PTX cannot multiply two i64s in a single instruction. 819 setOperationAction(ISD::SMUL_LOHI, MVT::i64, Expand); 820 setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand); 821 822 // We have some custom DAG combine patterns for these nodes 823 setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD, 824 ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT, 825 ISD::BUILD_VECTOR}); 826 827 // setcc for f16x2 and bf16x2 needs special handling to prevent 828 // legalizer's attempt to scalarize it due to v2i1 not being legal. 829 if (STI.allowFP16Math() || STI.hasBF16Math()) 830 setTargetDAGCombine(ISD::SETCC); 831 832 // Promote fp16 arithmetic if fp16 hardware isn't available or the 833 // user passed --nvptx-no-fp16-math. The flag is useful because, 834 // although sm_53+ GPUs have some sort of FP16 support in 835 // hardware, only sm_53 and sm_60 have full implementation. Others 836 // only have token amount of hardware and are likely to run faster 837 // by using fp32 units instead. 838 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) { 839 setFP16OperationAction(Op, MVT::f16, Legal, Promote); 840 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand); 841 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand); 842 // bf16 must be promoted to f32. 843 setBF16OperationAction(Op, MVT::bf16, Legal, Promote); 844 if (getOperationAction(Op, MVT::bf16) == Promote) 845 AddPromotedToType(Op, MVT::bf16, MVT::f32); 846 } 847 848 // On SM80, we select add/mul/sub as fma to avoid promotion to float 849 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) { 850 for (const auto &VT : {MVT::bf16, MVT::v2bf16}) { 851 if (!STI.hasNativeBF16Support(Op) && STI.hasNativeBF16Support(ISD::FMA)) { 852 setOperationAction(Op, VT, Custom); 853 } 854 } 855 } 856 857 // f16/f16x2 neg was introduced in PTX 60, SM_53. 858 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 && 859 STI.getPTXVersion() >= 60 && 860 STI.allowFP16Math(); 861 for (const auto &VT : {MVT::f16, MVT::v2f16}) 862 setOperationAction(ISD::FNEG, VT, 863 IsFP16FP16x2NegAvailable ? Legal : Expand); 864 865 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand); 866 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand); 867 // (would be) Library functions. 868 869 // These map to conversion instructions for scalar FP types. 
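  // (Illustrative mapping, assuming the usual PTX cvt rounding modifiers; the
  // .td patterns are authoritative: FCEIL -> cvt.rpi, FFLOOR -> cvt.rmi,
  // FTRUNC -> cvt.rzi, FRINT/FROUNDEVEN/FNEARBYINT -> cvt.rni.)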
870 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT, 871 ISD::FROUNDEVEN, ISD::FTRUNC}) { 872 setOperationAction(Op, MVT::f16, Legal); 873 setOperationAction(Op, MVT::f32, Legal); 874 setOperationAction(Op, MVT::f64, Legal); 875 setOperationAction(Op, MVT::v2f16, Expand); 876 setOperationAction(Op, MVT::v2bf16, Expand); 877 setBF16OperationAction(Op, MVT::bf16, Legal, Promote); 878 if (getOperationAction(Op, MVT::bf16) == Promote) 879 AddPromotedToType(Op, MVT::bf16, MVT::f32); 880 } 881 882 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) { 883 setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand); 884 } 885 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) { 886 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) { 887 setOperationAction(ISD::FP_EXTEND, VT, Custom); 888 setOperationAction(ISD::FP_ROUND, VT, Custom); 889 } 890 } 891 892 // sm_80 only has conversions between f32 and bf16. Custom lower all other 893 // bf16 conversions. 894 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) { 895 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) { 896 setOperationAction( 897 {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT}, 898 VT, Custom); 899 } 900 setOperationAction( 901 {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT}, 902 MVT::bf16, Custom); 903 } 904 905 setOperationAction(ISD::FROUND, MVT::f16, Promote); 906 setOperationAction(ISD::FROUND, MVT::v2f16, Expand); 907 setOperationAction(ISD::FROUND, MVT::v2bf16, Expand); 908 setOperationAction(ISD::FROUND, MVT::f32, Custom); 909 setOperationAction(ISD::FROUND, MVT::f64, Custom); 910 setOperationAction(ISD::FROUND, MVT::bf16, Promote); 911 AddPromotedToType(ISD::FROUND, MVT::bf16, MVT::f32); 912 913 // 'Expand' implements FCOPYSIGN without calling an external library. 914 setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand); 915 setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand); 916 setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand); 917 setOperationAction(ISD::FCOPYSIGN, MVT::v2bf16, Expand); 918 setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom); 919 setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom); 920 921 // These map to corresponding instructions for f32/f64. f16 must be 922 // promoted to f32. v2f16 is expanded to f16, which is then promoted 923 // to f32. 
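  // For example (illustrative): an FDIV on v2f16 is first unrolled into two
  // f16 divides by the Expand action, and each of those is then promoted and
  // performed in f32, since PTX has no native f16 divide instruction.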
924 for (const auto &Op : 925 {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) { 926 setOperationAction(Op, MVT::f16, Promote); 927 setOperationAction(Op, MVT::f32, Legal); 928 setOperationAction(Op, MVT::f64, Legal); 929 setOperationAction(Op, MVT::v2f16, Expand); 930 setOperationAction(Op, MVT::v2bf16, Expand); 931 setOperationAction(Op, MVT::bf16, Promote); 932 AddPromotedToType(Op, MVT::bf16, MVT::f32); 933 } 934 935 setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal); 936 if (STI.getPTXVersion() >= 65) { 937 setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote); 938 setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand); 939 } else { 940 setOperationAction(ISD::FABS, MVT::f16, Promote); 941 setOperationAction(ISD::FABS, MVT::v2f16, Expand); 942 } 943 setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand); 944 setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote); 945 if (getOperationAction(ISD::FABS, MVT::bf16) == Promote) 946 AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32); 947 948 for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) { 949 setOperationAction(Op, MVT::f32, Legal); 950 setOperationAction(Op, MVT::f64, Legal); 951 setFP16OperationAction(Op, MVT::f16, Legal, Promote); 952 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand); 953 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand); 954 setBF16OperationAction(Op, MVT::bf16, Legal, Promote); 955 if (getOperationAction(Op, MVT::bf16) == Promote) 956 AddPromotedToType(Op, MVT::bf16, MVT::f32); 957 } 958 bool SupportsF32MinMaxNaN = 959 STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70; 960 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) { 961 setOperationAction(Op, MVT::f32, SupportsF32MinMaxNaN ? Legal : Expand); 962 setFP16OperationAction(Op, MVT::f16, Legal, Expand); 963 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand); 964 setBF16OperationAction(Op, MVT::bf16, Legal, Expand); 965 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand); 966 } 967 968 // Custom lowering for inline asm with 128-bit operands 969 setOperationAction(ISD::CopyToReg, MVT::i128, Custom); 970 setOperationAction(ISD::CopyFromReg, MVT::i128, Custom); 971 972 // FEXP2 support: 973 // - f32 974 // - f16/f16x2 (sm_70+, PTX 7.0+) 975 // - bf16/bf16x2 (sm_90+, PTX 7.8+) 976 // When f16/bf16 types aren't supported, they are promoted/expanded to f32. 977 setOperationAction(ISD::FEXP2, MVT::f32, Legal); 978 setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote); 979 setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand); 980 setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote); 981 setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand); 982 983 // FLOG2 supports f32 only 984 // f16/bf16 types aren't supported, but they are promoted/expanded to f32. 985 if (UseApproxLog2F32) { 986 setOperationAction(ISD::FLOG2, MVT::f32, Legal); 987 setOperationPromotedToType(ISD::FLOG2, MVT::f16, MVT::f32); 988 setOperationPromotedToType(ISD::FLOG2, MVT::bf16, MVT::f32); 989 setOperationAction(ISD::FLOG2, {MVT::v2f16, MVT::v2bf16}, Expand); 990 } 991 992 // No FPOW or FREM in PTX. 
993 994 // Now deduce the information based on the above mentioned 995 // actions 996 computeRegisterProperties(STI.getRegisterInfo()); 997 998 setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits()); 999 setMaxAtomicSizeInBitsSupported(64); 1000 setMaxDivRemBitWidthSupported(64); 1001 } 1002 1003 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { 1004 1005 #define MAKE_CASE(V) \ 1006 case V: \ 1007 return #V; 1008 1009 switch ((NVPTXISD::NodeType)Opcode) { 1010 case NVPTXISD::FIRST_NUMBER: 1011 break; 1012 1013 MAKE_CASE(NVPTXISD::CALL) 1014 MAKE_CASE(NVPTXISD::RET_GLUE) 1015 MAKE_CASE(NVPTXISD::LOAD_PARAM) 1016 MAKE_CASE(NVPTXISD::Wrapper) 1017 MAKE_CASE(NVPTXISD::DeclareParam) 1018 MAKE_CASE(NVPTXISD::DeclareScalarParam) 1019 MAKE_CASE(NVPTXISD::DeclareRet) 1020 MAKE_CASE(NVPTXISD::DeclareScalarRet) 1021 MAKE_CASE(NVPTXISD::DeclareRetParam) 1022 MAKE_CASE(NVPTXISD::PrintCall) 1023 MAKE_CASE(NVPTXISD::PrintConvergentCall) 1024 MAKE_CASE(NVPTXISD::PrintCallUni) 1025 MAKE_CASE(NVPTXISD::PrintConvergentCallUni) 1026 MAKE_CASE(NVPTXISD::LoadParam) 1027 MAKE_CASE(NVPTXISD::LoadParamV2) 1028 MAKE_CASE(NVPTXISD::LoadParamV4) 1029 MAKE_CASE(NVPTXISD::StoreParam) 1030 MAKE_CASE(NVPTXISD::StoreParamV2) 1031 MAKE_CASE(NVPTXISD::StoreParamV4) 1032 MAKE_CASE(NVPTXISD::StoreParamS32) 1033 MAKE_CASE(NVPTXISD::StoreParamU32) 1034 MAKE_CASE(NVPTXISD::CallArgBegin) 1035 MAKE_CASE(NVPTXISD::CallArg) 1036 MAKE_CASE(NVPTXISD::LastCallArg) 1037 MAKE_CASE(NVPTXISD::CallArgEnd) 1038 MAKE_CASE(NVPTXISD::CallVoid) 1039 MAKE_CASE(NVPTXISD::CallVal) 1040 MAKE_CASE(NVPTXISD::CallSymbol) 1041 MAKE_CASE(NVPTXISD::Prototype) 1042 MAKE_CASE(NVPTXISD::MoveParam) 1043 MAKE_CASE(NVPTXISD::StoreRetval) 1044 MAKE_CASE(NVPTXISD::StoreRetvalV2) 1045 MAKE_CASE(NVPTXISD::StoreRetvalV4) 1046 MAKE_CASE(NVPTXISD::PseudoUseParam) 1047 MAKE_CASE(NVPTXISD::RETURN) 1048 MAKE_CASE(NVPTXISD::CallSeqBegin) 1049 MAKE_CASE(NVPTXISD::CallSeqEnd) 1050 MAKE_CASE(NVPTXISD::CallPrototype) 1051 MAKE_CASE(NVPTXISD::ProxyReg) 1052 MAKE_CASE(NVPTXISD::LoadV2) 1053 MAKE_CASE(NVPTXISD::LoadV4) 1054 MAKE_CASE(NVPTXISD::LDUV2) 1055 MAKE_CASE(NVPTXISD::LDUV4) 1056 MAKE_CASE(NVPTXISD::StoreV2) 1057 MAKE_CASE(NVPTXISD::StoreV4) 1058 MAKE_CASE(NVPTXISD::FSHL_CLAMP) 1059 MAKE_CASE(NVPTXISD::FSHR_CLAMP) 1060 MAKE_CASE(NVPTXISD::BFE) 1061 MAKE_CASE(NVPTXISD::BFI) 1062 MAKE_CASE(NVPTXISD::PRMT) 1063 MAKE_CASE(NVPTXISD::FCOPYSIGN) 1064 MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC) 1065 MAKE_CASE(NVPTXISD::STACKRESTORE) 1066 MAKE_CASE(NVPTXISD::STACKSAVE) 1067 MAKE_CASE(NVPTXISD::SETP_F16X2) 1068 MAKE_CASE(NVPTXISD::SETP_BF16X2) 1069 MAKE_CASE(NVPTXISD::Dummy) 1070 MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED) 1071 MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED) 1072 MAKE_CASE(NVPTXISD::BrxEnd) 1073 MAKE_CASE(NVPTXISD::BrxItem) 1074 MAKE_CASE(NVPTXISD::BrxStart) 1075 } 1076 return nullptr; 1077 1078 #undef MAKE_CASE 1079 } 1080 1081 TargetLoweringBase::LegalizeTypeAction 1082 NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const { 1083 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 && 1084 VT.getScalarType() == MVT::i1) 1085 return TypeSplitVector; 1086 return TargetLoweringBase::getPreferredVectorAction(VT); 1087 } 1088 1089 SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, 1090 int Enabled, int &ExtraSteps, 1091 bool &UseOneConst, 1092 bool Reciprocal) const { 1093 if (!(Enabled == ReciprocalEstimate::Enabled || 1094 (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32()))) 1095 return SDValue(); 1096 
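  // Past this point the estimate is either explicitly enabled, or left
  // unspecified while approximate f32 sqrt is allowed (usePrecSqrtF32() is
  // false).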
1097 if (ExtraSteps == ReciprocalEstimate::Unspecified) 1098 ExtraSteps = 0; 1099 1100 SDLoc DL(Operand); 1101 EVT VT = Operand.getValueType(); 1102 bool Ftz = useF32FTZ(DAG.getMachineFunction()); 1103 1104 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) { 1105 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, 1106 DAG.getConstant(IID, DL, MVT::i32), Operand); 1107 }; 1108 1109 // The sqrt and rsqrt refinement processes assume we always start out with an 1110 // approximation of the rsqrt. Therefore, if we're going to do any refinement 1111 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing 1112 // any refinement, we must return a regular sqrt. 1113 if (Reciprocal || ExtraSteps > 0) { 1114 if (VT == MVT::f32) 1115 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f 1116 : Intrinsic::nvvm_rsqrt_approx_f); 1117 else if (VT == MVT::f64) 1118 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d); 1119 else 1120 return SDValue(); 1121 } else { 1122 if (VT == MVT::f32) 1123 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f 1124 : Intrinsic::nvvm_sqrt_approx_f); 1125 else { 1126 // There's no sqrt.approx.f64 instruction, so we emit 1127 // reciprocal(rsqrt(x)). This is faster than 1128 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain 1129 // x * rsqrt(x).) 1130 return DAG.getNode( 1131 ISD::INTRINSIC_WO_CHAIN, DL, VT, 1132 DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32), 1133 MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d)); 1134 } 1135 } 1136 } 1137 1138 SDValue 1139 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const { 1140 SDLoc dl(Op); 1141 const GlobalAddressSDNode *GAN = cast<GlobalAddressSDNode>(Op); 1142 auto PtrVT = getPointerTy(DAG.getDataLayout(), GAN->getAddressSpace()); 1143 Op = DAG.getTargetGlobalAddress(GAN->getGlobal(), dl, PtrVT); 1144 return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op); 1145 } 1146 1147 static bool IsTypePassedAsArray(const Type *Ty) { 1148 return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) || 1149 Ty->isHalfTy() || Ty->isBFloatTy(); 1150 } 1151 1152 std::string NVPTXTargetLowering::getPrototype( 1153 const DataLayout &DL, Type *retTy, const ArgListTy &Args, 1154 const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment, 1155 std::optional<std::pair<unsigned, const APInt &>> VAInfo, 1156 const CallBase &CB, unsigned UniqueCallSite) const { 1157 auto PtrVT = getPointerTy(DL); 1158 1159 bool isABI = (STI.getSmVersion() >= 20); 1160 assert(isABI && "Non-ABI compilation is not supported"); 1161 if (!isABI) 1162 return ""; 1163 1164 std::string Prototype; 1165 raw_string_ostream O(Prototype); 1166 O << "prototype_" << UniqueCallSite << " : .callprototype "; 1167 1168 if (retTy->getTypeID() == Type::VoidTyID) { 1169 O << "()"; 1170 } else { 1171 O << "("; 1172 if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) && 1173 !IsTypePassedAsArray(retTy)) { 1174 unsigned size = 0; 1175 if (auto *ITy = dyn_cast<IntegerType>(retTy)) { 1176 size = ITy->getBitWidth(); 1177 } else { 1178 assert(retTy->isFloatingPointTy() && 1179 "Floating point type expected here"); 1180 size = retTy->getPrimitiveSizeInBits(); 1181 } 1182 // PTX ABI requires all scalar return values to be at least 32 1183 // bits in size. fp16 normally uses .b16 as its storage type in 1184 // PTX, so its size must be adjusted here, too. 
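        // For example (illustrative): an i8 or half return value is declared
        // below as ".param .b32 _"; an i64 return value remains ".param .b64 _".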
1185 size = promoteScalarArgumentSize(size); 1186 1187 O << ".param .b" << size << " _"; 1188 } else if (isa<PointerType>(retTy)) { 1189 O << ".param .b" << PtrVT.getSizeInBits() << " _"; 1190 } else if (IsTypePassedAsArray(retTy)) { 1191 O << ".param .align " << (retAlignment ? retAlignment->value() : 0) 1192 << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]"; 1193 } else { 1194 llvm_unreachable("Unknown return type"); 1195 } 1196 O << ") "; 1197 } 1198 O << "_ ("; 1199 1200 bool first = true; 1201 1202 unsigned NumArgs = VAInfo ? VAInfo->first : Args.size(); 1203 for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) { 1204 Type *Ty = Args[i].Ty; 1205 if (!first) { 1206 O << ", "; 1207 } 1208 first = false; 1209 1210 if (!Outs[OIdx].Flags.isByVal()) { 1211 if (IsTypePassedAsArray(Ty)) { 1212 Align ParamAlign = 1213 getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL); 1214 O << ".param .align " << ParamAlign.value() << " .b8 "; 1215 O << "_"; 1216 O << "[" << DL.getTypeAllocSize(Ty) << "]"; 1217 // update the index for Outs 1218 SmallVector<EVT, 16> vtparts; 1219 ComputeValueVTs(*this, DL, Ty, vtparts); 1220 if (unsigned len = vtparts.size()) 1221 OIdx += len - 1; 1222 continue; 1223 } 1224 // i8 types in IR will be i16 types in SDAG 1225 assert((getValueType(DL, Ty) == Outs[OIdx].VT || 1226 (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) && 1227 "type mismatch between callee prototype and arguments"); 1228 // scalar type 1229 unsigned sz = 0; 1230 if (isa<IntegerType>(Ty)) { 1231 sz = cast<IntegerType>(Ty)->getBitWidth(); 1232 sz = promoteScalarArgumentSize(sz); 1233 } else if (isa<PointerType>(Ty)) { 1234 sz = PtrVT.getSizeInBits(); 1235 } else { 1236 sz = Ty->getPrimitiveSizeInBits(); 1237 } 1238 O << ".param .b" << sz << " "; 1239 O << "_"; 1240 continue; 1241 } 1242 1243 // Indirect calls need strict ABI alignment so we disable optimizations by 1244 // not providing a function to optimize. 1245 Type *ETy = Args[i].IndirectType; 1246 Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign(); 1247 Align ParamByValAlign = 1248 getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL); 1249 1250 O << ".param .align " << ParamByValAlign.value() << " .b8 "; 1251 O << "_"; 1252 O << "[" << Outs[OIdx].Flags.getByValSize() << "]"; 1253 } 1254 1255 if (VAInfo) 1256 O << (first ? "" : ",") << " .param .align " << VAInfo->second 1257 << " .b8 _[]\n"; 1258 O << ")"; 1259 if (shouldEmitPTXNoReturn(&CB, *nvTM)) 1260 O << " .noreturn"; 1261 O << ";"; 1262 1263 return Prototype; 1264 } 1265 1266 Align NVPTXTargetLowering::getFunctionArgumentAlignment( 1267 const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const { 1268 return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL)); 1269 } 1270 1271 Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty, 1272 unsigned Idx, 1273 const DataLayout &DL) const { 1274 if (!CB) { 1275 // CallSite is zero, fallback to ABI type alignment 1276 return DL.getABITypeAlign(Ty); 1277 } 1278 1279 const Function *DirectCallee = CB->getCalledFunction(); 1280 1281 if (!DirectCallee) { 1282 // We don't have a direct function symbol, but that may be because of 1283 // constant cast instructions in the call. 
1284 1285 // With bitcast'd call targets, the instruction will be the call 1286 if (const auto *CI = dyn_cast<CallInst>(CB)) { 1287 // Check if we have call alignment metadata 1288 if (MaybeAlign StackAlign = getAlign(*CI, Idx)) 1289 return StackAlign.value(); 1290 } 1291 DirectCallee = getMaybeBitcastedCallee(CB); 1292 } 1293 1294 // Check for function alignment information if we found that the 1295 // ultimate target is a Function 1296 if (DirectCallee) 1297 return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL); 1298 1299 // Call is indirect, fall back to the ABI type alignment 1300 return DL.getABITypeAlign(Ty); 1301 } 1302 1303 static bool adjustElementType(EVT &ElementType) { 1304 switch (ElementType.getSimpleVT().SimpleTy) { 1305 default: 1306 return false; 1307 case MVT::f16: 1308 case MVT::bf16: 1309 ElementType = MVT::i16; 1310 return true; 1311 case MVT::f32: 1312 case MVT::v2f16: 1313 case MVT::v2bf16: 1314 ElementType = MVT::i32; 1315 return true; 1316 case MVT::f64: 1317 ElementType = MVT::i64; 1318 return true; 1319 } 1320 } 1321 1322 // Use byte-store when the param address of the argument value is unaligned. 1323 // This may happen when the return value is a field of a packed structure. 1324 // 1325 // This is called in LowerCall() when passing the param values. 1326 static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain, 1327 uint64_t Offset, EVT ElementType, 1328 SDValue StVal, SDValue &InGlue, 1329 unsigned ArgID, const SDLoc &dl) { 1330 // Bit logic only works on integer types 1331 if (adjustElementType(ElementType)) 1332 StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal); 1333 1334 // Store each byte 1335 SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue); 1336 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) { 1337 // Shift the byte to the last byte position 1338 SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal, 1339 DAG.getConstant(i * 8, dl, MVT::i32)); 1340 SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32), 1341 DAG.getConstant(Offset + i, dl, MVT::i32), 1342 ShiftVal, InGlue}; 1343 // Trunc store only the last byte by using 1344 // st.param.b8 1345 // The register type can be larger than b8. 1346 Chain = DAG.getMemIntrinsicNode( 1347 NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8, 1348 MachinePointerInfo(), Align(1), MachineMemOperand::MOStore); 1349 InGlue = Chain.getValue(1); 1350 } 1351 return Chain; 1352 } 1353 1354 // Use byte-load when the param adress of the returned value is unaligned. 1355 // This may happen when the returned value is a field of a packed structure. 1356 static SDValue 1357 LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset, 1358 EVT ElementType, SDValue &InGlue, 1359 SmallVectorImpl<SDValue> &TempProxyRegOps, 1360 const SDLoc &dl) { 1361 // Bit logic only works on integer types 1362 EVT MergedType = ElementType; 1363 adjustElementType(MergedType); 1364 1365 // Load each byte and construct the whole value. 
Initial value to 0 1366 SDValue RetVal = DAG.getConstant(0, dl, MergedType); 1367 // LoadParamMemI8 loads into i16 register only 1368 SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue); 1369 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) { 1370 SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32), 1371 DAG.getConstant(Offset + i, dl, MVT::i32), 1372 InGlue}; 1373 // This will be selected to LoadParamMemI8 1374 SDValue LdVal = 1375 DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands, 1376 MVT::i8, MachinePointerInfo(), Align(1)); 1377 SDValue TmpLdVal = LdVal.getValue(0); 1378 Chain = LdVal.getValue(1); 1379 InGlue = LdVal.getValue(2); 1380 1381 TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl, 1382 TmpLdVal.getSimpleValueType(), TmpLdVal); 1383 TempProxyRegOps.push_back(TmpLdVal); 1384 1385 SDValue CMask = DAG.getConstant(255, dl, MergedType); 1386 SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32); 1387 // Need to extend the i16 register to the whole width. 1388 TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal); 1389 // Mask off the high bits. Leave only the lower 8bits. 1390 // Do this because we are using loadparam.b8. 1391 TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask); 1392 // Shift and merge 1393 TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift); 1394 RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal); 1395 } 1396 if (ElementType != MergedType) 1397 RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal); 1398 1399 return RetVal; 1400 } 1401 1402 static bool shouldConvertToIndirectCall(const CallBase *CB, 1403 const GlobalAddressSDNode *Func) { 1404 if (!Func) 1405 return false; 1406 if (auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal())) 1407 return CB->getFunctionType() != CalleeFunc->getFunctionType(); 1408 return false; 1409 } 1410 1411 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, 1412 SmallVectorImpl<SDValue> &InVals) const { 1413 1414 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30)) 1415 report_fatal_error( 1416 "Support for variadic functions (unsized array parameter) introduced " 1417 "in PTX ISA version 6.0 and requires target sm_30."); 1418 1419 SelectionDAG &DAG = CLI.DAG; 1420 SDLoc dl = CLI.DL; 1421 SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs; 1422 SmallVectorImpl<SDValue> &OutVals = CLI.OutVals; 1423 SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins; 1424 SDValue Chain = CLI.Chain; 1425 SDValue Callee = CLI.Callee; 1426 bool &isTailCall = CLI.IsTailCall; 1427 ArgListTy &Args = CLI.getArgs(); 1428 Type *RetTy = CLI.RetTy; 1429 const CallBase *CB = CLI.CB; 1430 const DataLayout &DL = DAG.getDataLayout(); 1431 1432 bool isABI = (STI.getSmVersion() >= 20); 1433 assert(isABI && "Non-ABI compilation is not supported"); 1434 if (!isABI) 1435 return Chain; 1436 1437 // Variadic arguments. 1438 // 1439 // Normally, for each argument, we declare a param scalar or a param 1440 // byte array in the .param space, and store the argument value to that 1441 // param scalar or array starting at offset 0. 1442 // 1443 // In the case of the first variadic argument, we declare a vararg byte array 1444 // with size 0. The exact size of this array isn't known at this point, so 1445 // it'll be patched later. All the variadic arguments will be stored to this 1446 // array at a certain offset (which gets tracked by 'VAOffset'). 
The offset is 1447 // initially set to 0, so it can be used for non-variadic arguments (which use 1448 // 0 offset) to simplify the code. 1449 // 1450 // After all vararg is processed, 'VAOffset' holds the size of the 1451 // vararg byte array. 1452 1453 SDValue VADeclareParam; // vararg byte array 1454 unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic 1455 unsigned VAOffset = 0; // current offset in the param array 1456 1457 unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1); 1458 SDValue TempChain = Chain; 1459 Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl); 1460 SDValue InGlue = Chain.getValue(1); 1461 1462 unsigned ParamCount = 0; 1463 // Args.size() and Outs.size() need not match. 1464 // Outs.size() will be larger 1465 // * if there is an aggregate argument with multiple fields (each field 1466 // showing up separately in Outs) 1467 // * if there is a vector argument with more than typical vector-length 1468 // elements (generally if more than 4) where each vector element is 1469 // individually present in Outs. 1470 // So a different index should be used for indexing into Outs/OutVals. 1471 // See similar issue in LowerFormalArguments. 1472 unsigned OIdx = 0; 1473 // Declare the .params or .reg need to pass values 1474 // to the function 1475 for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) { 1476 EVT VT = Outs[OIdx].VT; 1477 Type *Ty = Args[i].Ty; 1478 bool IsVAArg = (i >= CLI.NumFixedArgs); 1479 bool IsByVal = Outs[OIdx].Flags.isByVal(); 1480 1481 SmallVector<EVT, 16> VTs; 1482 SmallVector<uint64_t, 16> Offsets; 1483 1484 assert((!IsByVal || Args[i].IndirectType) && 1485 "byval arg must have indirect type"); 1486 Type *ETy = (IsByVal ? Args[i].IndirectType : Ty); 1487 ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset); 1488 1489 Align ArgAlign; 1490 if (IsByVal) { 1491 // The ByValAlign in the Outs[OIdx].Flags is always set at this point, 1492 // so we don't need to worry whether it's naturally aligned or not. 1493 // See TargetLowering::LowerCallTo(). 1494 Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign(); 1495 ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy, 1496 InitialAlign, DL); 1497 if (IsVAArg) 1498 VAOffset = alignTo(VAOffset, ArgAlign); 1499 } else { 1500 ArgAlign = getArgumentAlignment(CB, Ty, ParamCount + 1, DL); 1501 } 1502 1503 unsigned TypeSize = 1504 (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty)); 1505 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); 1506 1507 bool NeedAlign; // Does argument declaration specify alignment? 
1508 bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty); 1509 if (IsVAArg) { 1510 if (ParamCount == FirstVAArg) { 1511 SDValue DeclareParamOps[] = { 1512 Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32), 1513 DAG.getConstant(ParamCount, dl, MVT::i32), 1514 DAG.getConstant(1, dl, MVT::i32), InGlue}; 1515 VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, 1516 DeclareParamVTs, DeclareParamOps); 1517 } 1518 NeedAlign = PassAsArray; 1519 } else if (PassAsArray) { 1520 // declare .param .align <align> .b8 .param<n>[<size>]; 1521 SDValue DeclareParamOps[] = { 1522 Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32), 1523 DAG.getConstant(ParamCount, dl, MVT::i32), 1524 DAG.getConstant(TypeSize, dl, MVT::i32), InGlue}; 1525 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, 1526 DeclareParamOps); 1527 NeedAlign = true; 1528 } else { 1529 // declare .param .b<size> .param<n>; 1530 if (VT.isInteger() || VT.isFloatingPoint()) { 1531 // PTX ABI requires integral types to be at least 32 bits in 1532 // size. FP16 is loaded/stored using i16, so it's handled 1533 // here as well. 1534 TypeSize = promoteScalarArgumentSize(TypeSize * 8) / 8; 1535 } 1536 SDValue DeclareScalarParamOps[] = { 1537 Chain, DAG.getConstant(ParamCount, dl, MVT::i32), 1538 DAG.getConstant(TypeSize * 8, dl, MVT::i32), 1539 DAG.getConstant(0, dl, MVT::i32), InGlue}; 1540 Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs, 1541 DeclareScalarParamOps); 1542 NeedAlign = false; 1543 } 1544 InGlue = Chain.getValue(1); 1545 1546 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter 1547 // than 32-bits are sign extended or zero extended, depending on 1548 // whether they are signed or unsigned types. This case applies 1549 // only to scalar parameters and not to aggregate values. 1550 bool ExtendIntegerParam = 1551 Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32; 1552 1553 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg); 1554 SmallVector<SDValue, 6> StoreOperands; 1555 for (unsigned j = 0, je = VTs.size(); j != je; ++j) { 1556 EVT EltVT = VTs[j]; 1557 int CurOffset = Offsets[j]; 1558 MaybeAlign PartAlign; 1559 if (NeedAlign) 1560 PartAlign = commonAlignment(ArgAlign, CurOffset); 1561 1562 SDValue StVal = OutVals[OIdx]; 1563 1564 MVT PromotedVT; 1565 if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) { 1566 EltVT = EVT(PromotedVT); 1567 } 1568 if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) { 1569 llvm::ISD::NodeType Ext = 1570 Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; 1571 StVal = DAG.getNode(Ext, dl, PromotedVT, StVal); 1572 } 1573 1574 if (IsByVal) { 1575 auto PtrVT = getPointerTy(DL); 1576 SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal, 1577 DAG.getConstant(CurOffset, dl, PtrVT)); 1578 StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(), 1579 PartAlign); 1580 } else if (ExtendIntegerParam) { 1581 assert(VTs.size() == 1 && "Scalar can't have multiple parts."); 1582 // zext/sext to i32 1583 StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND 1584 : ISD::ZERO_EXTEND, 1585 dl, MVT::i32, StVal); 1586 } 1587 1588 if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) { 1589 // Use 16-bit registers for small stores as it's the 1590 // smallest general purpose register size supported by NVPTX. 
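// (Illustrative) an i8 element is therefore carried in a 16-bit virtual
// register, while the store keeps the 8-bit memory type, so it is printed
// as something like st.param.b8 and only the low byte reaches param space.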
1591 StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal); 1592 } 1593 1594 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a 1595 // scalar store. In such cases, fall back to byte stores. 1596 if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() && 1597 PartAlign.value() < 1598 DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) { 1599 assert(StoreOperands.empty() && "Unfinished preceeding store."); 1600 Chain = LowerUnalignedStoreParam( 1601 DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT, 1602 StVal, InGlue, ParamCount, dl); 1603 1604 // LowerUnalignedStoreParam took care of inserting the necessary nodes 1605 // into the SDAG, so just move on to the next element. 1606 if (!IsByVal) 1607 ++OIdx; 1608 continue; 1609 } 1610 1611 // New store. 1612 if (VectorInfo[j] & PVF_FIRST) { 1613 assert(StoreOperands.empty() && "Unfinished preceding store."); 1614 StoreOperands.push_back(Chain); 1615 StoreOperands.push_back( 1616 DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32)); 1617 1618 StoreOperands.push_back(DAG.getConstant( 1619 IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset), 1620 dl, MVT::i32)); 1621 } 1622 1623 // Record the value to store. 1624 StoreOperands.push_back(StVal); 1625 1626 if (VectorInfo[j] & PVF_LAST) { 1627 unsigned NumElts = StoreOperands.size() - 3; 1628 NVPTXISD::NodeType Op; 1629 switch (NumElts) { 1630 case 1: 1631 Op = NVPTXISD::StoreParam; 1632 break; 1633 case 2: 1634 Op = NVPTXISD::StoreParamV2; 1635 break; 1636 case 4: 1637 Op = NVPTXISD::StoreParamV4; 1638 break; 1639 default: 1640 llvm_unreachable("Invalid vector info."); 1641 } 1642 1643 StoreOperands.push_back(InGlue); 1644 1645 // Adjust type of the store op if we've extended the scalar 1646 // return value. 1647 EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT; 1648 1649 Chain = DAG.getMemIntrinsicNode( 1650 Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands, 1651 TheStoreType, MachinePointerInfo(), PartAlign, 1652 MachineMemOperand::MOStore); 1653 InGlue = Chain.getValue(1); 1654 1655 // Cleanup. 1656 StoreOperands.clear(); 1657 1658 // TODO: We may need to support vector types that can be passed 1659 // as scalars in variadic arguments. 
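// For reference (operands illustrative), the StoreParam nodes built above
// print as PTX param stores such as
//   st.param.b32    [param0+0], %r1;
//   st.param.v4.b32 [param0+0], {%r1, %r2, %r3, %r4};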
1660 if (!IsByVal && IsVAArg) { 1661 assert(NumElts == 1 && 1662 "Vectorization is expected to be disabled for variadics."); 1663 VAOffset += DL.getTypeAllocSize( 1664 TheStoreType.getTypeForEVT(*DAG.getContext())); 1665 } 1666 } 1667 if (!IsByVal) 1668 ++OIdx; 1669 } 1670 assert(StoreOperands.empty() && "Unfinished parameter store."); 1671 if (!IsByVal && VTs.size() > 0) 1672 --OIdx; 1673 ++ParamCount; 1674 if (IsByVal && IsVAArg) 1675 VAOffset += TypeSize; 1676 } 1677 1678 GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode()); 1679 MaybeAlign retAlignment = std::nullopt; 1680 1681 // Handle Result 1682 if (Ins.size() > 0) { 1683 SmallVector<EVT, 16> resvtparts; 1684 ComputeValueVTs(*this, DL, RetTy, resvtparts); 1685 1686 // Declare 1687 // .param .align N .b8 retval0[<size-in-bytes>], or 1688 // .param .b<size-in-bits> retval0 1689 unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy); 1690 if (!IsTypePassedAsArray(RetTy)) { 1691 resultsz = promoteScalarArgumentSize(resultsz); 1692 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue); 1693 SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32), 1694 DAG.getConstant(resultsz, dl, MVT::i32), 1695 DAG.getConstant(0, dl, MVT::i32), InGlue }; 1696 Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs, 1697 DeclareRetOps); 1698 InGlue = Chain.getValue(1); 1699 } else { 1700 retAlignment = getArgumentAlignment(CB, RetTy, 0, DL); 1701 assert(retAlignment && "retAlignment is guaranteed to be set"); 1702 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue); 1703 SDValue DeclareRetOps[] = { 1704 Chain, DAG.getConstant(retAlignment->value(), dl, MVT::i32), 1705 DAG.getConstant(resultsz / 8, dl, MVT::i32), 1706 DAG.getConstant(0, dl, MVT::i32), InGlue}; 1707 Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs, 1708 DeclareRetOps); 1709 InGlue = Chain.getValue(1); 1710 } 1711 } 1712 1713 bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs); 1714 // Set the size of the vararg param byte array if the callee is a variadic 1715 // function and the variadic part is not empty. 1716 if (HasVAArgs) { 1717 SDValue DeclareParamOps[] = { 1718 VADeclareParam.getOperand(0), VADeclareParam.getOperand(1), 1719 VADeclareParam.getOperand(2), DAG.getConstant(VAOffset, dl, MVT::i32), 1720 VADeclareParam.getOperand(4)}; 1721 DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(), 1722 VADeclareParam->getVTList(), DeclareParamOps); 1723 } 1724 1725 // If the type of the callsite does not match that of the function, convert 1726 // the callsite to an indirect call. 1727 bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func); 1728 1729 // Both indirect calls and libcalls have nullptr Func. In order to distinguish 1730 // between them we must rely on the call site value which is valid for 1731 // indirect calls but is always null for libcalls. 1732 bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall; 1733 1734 if (isa<ExternalSymbolSDNode>(Callee)) { 1735 Function* CalleeFunc = nullptr; 1736 1737 // Try to find the callee in the current module. 1738 Callee = DAG.getSymbolFunctionGlobalAddress(Callee, &CalleeFunc); 1739 assert(CalleeFunc != nullptr && "Libcall callee must be set."); 1740 1741 // Set the "libcall callee" attribute to indicate that the function 1742 // must always have a declaration. 
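// (Example only) in textual IR the callee then carries the string attribute
//   "nvptx-libcall-callee"="true"
// which the printer later uses to make sure a declaration is emitted.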
1743 CalleeFunc->addFnAttr("nvptx-libcall-callee", "true"); 1744 } 1745 1746 if (isIndirectCall) { 1747 // This is the indirect function call case: PTX requires a prototype of the 1748 // form 1749 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _); 1750 // to be emitted, and the label has to be used as the last arg of the call 1751 // instruction. 1752 // The prototype is embedded in a string and put as the operand for a 1753 // CallPrototype SDNode which will print out as the value of the string. 1754 SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue); 1755 std::string Proto = getPrototype( 1756 DL, RetTy, Args, Outs, retAlignment, 1757 HasVAArgs 1758 ? std::optional<std::pair<unsigned, const APInt &>>(std::make_pair( 1759 CLI.NumFixedArgs, VADeclareParam->getConstantOperandAPInt(1))) 1760 : std::nullopt, 1761 *CB, UniqueCallSite); 1762 const char *ProtoStr = nvTM->getStrPool().save(Proto).data(); 1763 SDValue ProtoOps[] = { 1764 Chain, 1765 DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), 1766 InGlue, 1767 }; 1768 Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, ProtoOps); 1769 InGlue = Chain.getValue(1); 1770 } 1771 // Op to just print "call" 1772 SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue); 1773 SDValue PrintCallOps[] = { 1774 Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InGlue 1775 }; 1776 // We model convergent calls as separate opcodes. 1777 unsigned Opcode = isIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni; 1778 if (CLI.IsConvergent) 1779 Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni 1780 : NVPTXISD::PrintConvergentCall; 1781 Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps); 1782 InGlue = Chain.getValue(1); 1783 1784 if (ConvertToIndirectCall) { 1785 // Copy the function ptr to a ptx register and use the register to call the 1786 // function. 1787 EVT DestVT = Callee.getValueType(); 1788 MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo(); 1789 const TargetLowering &TLI = DAG.getTargetLoweringInfo(); 1790 unsigned DestReg = 1791 RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT())); 1792 auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee); 1793 Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT); 1794 } 1795 1796 // Ops to print out the function name 1797 SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue); 1798 SDValue CallVoidOps[] = { Chain, Callee, InGlue }; 1799 Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps); 1800 InGlue = Chain.getValue(1); 1801 1802 // Ops to print out the param list 1803 SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue); 1804 SDValue CallArgBeginOps[] = { Chain, InGlue }; 1805 Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs, 1806 CallArgBeginOps); 1807 InGlue = Chain.getValue(1); 1808 1809 for (unsigned i = 0, e = std::min(CLI.NumFixedArgs + 1, ParamCount); i != e; 1810 ++i) { 1811 unsigned opcode; 1812 if (i == (e - 1)) 1813 opcode = NVPTXISD::LastCallArg; 1814 else 1815 opcode = NVPTXISD::CallArg; 1816 SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue); 1817 SDValue CallArgOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32), 1818 DAG.getConstant(i, dl, MVT::i32), InGlue }; 1819 Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps); 1820 InGlue = Chain.getValue(1); 1821 } 1822 SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue); 1823 SDValue CallArgEndOps[] = { Chain, 1824 DAG.getConstant(isIndirectCall ?
0 : 1, dl, MVT::i32), 1825 InGlue }; 1826 Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps); 1827 InGlue = Chain.getValue(1); 1828 1829 if (isIndirectCall) { 1830 SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue); 1831 SDValue PrototypeOps[] = { 1832 Chain, DAG.getConstant(UniqueCallSite, dl, MVT::i32), InGlue}; 1833 Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps); 1834 InGlue = Chain.getValue(1); 1835 } 1836 1837 SmallVector<SDValue, 16> ProxyRegOps; 1838 SmallVector<std::optional<MVT>, 16> ProxyRegTruncates; 1839 // An item of the vector is filled if the element does not need a ProxyReg 1840 // operation on it and should be added to InVals as is. ProxyRegOps and 1841 // ProxyRegTruncates contain empty/none items at the same index. 1842 SmallVector<SDValue, 16> RetElts; 1843 // A temporary ProxyReg operations inserted in `LowerUnalignedLoadRetParam()` 1844 // to use the values of `LoadParam`s and to be replaced later then 1845 // `CALLSEQ_END` is added. 1846 SmallVector<SDValue, 16> TempProxyRegOps; 1847 1848 // Generate loads from param memory/moves from registers for result 1849 if (Ins.size() > 0) { 1850 SmallVector<EVT, 16> VTs; 1851 SmallVector<uint64_t, 16> Offsets; 1852 ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0); 1853 assert(VTs.size() == Ins.size() && "Bad value decomposition"); 1854 1855 Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL); 1856 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign); 1857 1858 SmallVector<EVT, 6> LoadVTs; 1859 int VecIdx = -1; // Index of the first element of the vector. 1860 1861 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than 1862 // 32-bits are sign extended or zero extended, depending on whether 1863 // they are signed or unsigned types. 1864 bool ExtendIntegerRetVal = 1865 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32; 1866 1867 for (unsigned i = 0, e = VTs.size(); i != e; ++i) { 1868 bool needTruncate = false; 1869 EVT TheLoadType = VTs[i]; 1870 EVT EltType = Ins[i].VT; 1871 Align EltAlign = commonAlignment(RetAlign, Offsets[i]); 1872 MVT PromotedVT; 1873 1874 if (PromoteScalarIntegerPTX(TheLoadType, &PromotedVT)) { 1875 TheLoadType = EVT(PromotedVT); 1876 EltType = EVT(PromotedVT); 1877 needTruncate = true; 1878 } 1879 1880 if (ExtendIntegerRetVal) { 1881 TheLoadType = MVT::i32; 1882 EltType = MVT::i32; 1883 needTruncate = true; 1884 } else if (TheLoadType.getSizeInBits() < 16) { 1885 if (VTs[i].isInteger()) 1886 needTruncate = true; 1887 EltType = MVT::i16; 1888 } 1889 1890 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a 1891 // scalar load. In such cases, fall back to byte loads. 1892 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType() && 1893 EltAlign < DL.getABITypeAlign( 1894 TheLoadType.getTypeForEVT(*DAG.getContext()))) { 1895 assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list."); 1896 SDValue Ret = LowerUnalignedLoadRetParam( 1897 DAG, Chain, Offsets[i], TheLoadType, InGlue, TempProxyRegOps, dl); 1898 ProxyRegOps.push_back(SDValue()); 1899 ProxyRegTruncates.push_back(std::optional<MVT>()); 1900 RetElts.resize(i); 1901 RetElts.push_back(Ret); 1902 1903 continue; 1904 } 1905 1906 // Record index of the very first element of the vector. 
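// (Illustrative) once a full run of elements has been collected below, it is
// fetched with a single LoadParam/LoadParamV2/LoadParamV4 node; e.g. a pair
// of floats returned in retval0 prints roughly as
//   ld.param.v2.f32 {%f1, %f2}, [retval0+0];
// with registers and offsets depending on the callee.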
1907 if (VectorInfo[i] & PVF_FIRST) { 1908 assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list."); 1909 VecIdx = i; 1910 } 1911 1912 LoadVTs.push_back(EltType); 1913 1914 if (VectorInfo[i] & PVF_LAST) { 1915 unsigned NumElts = LoadVTs.size(); 1916 LoadVTs.push_back(MVT::Other); 1917 LoadVTs.push_back(MVT::Glue); 1918 NVPTXISD::NodeType Op; 1919 switch (NumElts) { 1920 case 1: 1921 Op = NVPTXISD::LoadParam; 1922 break; 1923 case 2: 1924 Op = NVPTXISD::LoadParamV2; 1925 break; 1926 case 4: 1927 Op = NVPTXISD::LoadParamV4; 1928 break; 1929 default: 1930 llvm_unreachable("Invalid vector info."); 1931 } 1932 1933 SDValue LoadOperands[] = { 1934 Chain, DAG.getConstant(1, dl, MVT::i32), 1935 DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InGlue}; 1936 SDValue RetVal = DAG.getMemIntrinsicNode( 1937 Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType, 1938 MachinePointerInfo(), EltAlign, 1939 MachineMemOperand::MOLoad); 1940 1941 for (unsigned j = 0; j < NumElts; ++j) { 1942 ProxyRegOps.push_back(RetVal.getValue(j)); 1943 1944 if (needTruncate) 1945 ProxyRegTruncates.push_back(std::optional<MVT>(Ins[VecIdx + j].VT)); 1946 else 1947 ProxyRegTruncates.push_back(std::optional<MVT>()); 1948 } 1949 1950 Chain = RetVal.getValue(NumElts); 1951 InGlue = RetVal.getValue(NumElts + 1); 1952 1953 // Cleanup 1954 VecIdx = -1; 1955 LoadVTs.clear(); 1956 } 1957 } 1958 } 1959 1960 Chain = 1961 DAG.getCALLSEQ_END(Chain, UniqueCallSite, UniqueCallSite + 1, InGlue, dl); 1962 InGlue = Chain.getValue(1); 1963 1964 // Append ProxyReg instructions to the chain to make sure that `callseq_end` 1965 // will not get lost. Otherwise, during libcalls expansion, the nodes can become 1966 // dangling. 1967 for (unsigned i = 0; i < ProxyRegOps.size(); ++i) { 1968 if (i < RetElts.size() && RetElts[i]) { 1969 InVals.push_back(RetElts[i]); 1970 continue; 1971 } 1972 1973 SDValue Ret = DAG.getNode( 1974 NVPTXISD::ProxyReg, dl, 1975 DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue), 1976 { Chain, ProxyRegOps[i], InGlue } 1977 ); 1978 1979 Chain = Ret.getValue(1); 1980 InGlue = Ret.getValue(2); 1981 1982 if (ProxyRegTruncates[i]) { 1983 Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[i], Ret); 1984 } 1985 1986 InVals.push_back(Ret); 1987 } 1988 1989 for (SDValue &T : TempProxyRegOps) { 1990 SDValue Repl = DAG.getNode( 1991 NVPTXISD::ProxyReg, dl, 1992 DAG.getVTList(T.getSimpleValueType(), MVT::Other, MVT::Glue), 1993 {Chain, T.getOperand(0), InGlue}); 1994 DAG.ReplaceAllUsesWith(T, Repl); 1995 DAG.RemoveDeadNode(T.getNode()); 1996 1997 Chain = Repl.getValue(1); 1998 InGlue = Repl.getValue(2); 1999 } 2000 2001 // set isTailCall to false for now, until we figure out how to express 2002 // tail call optimization in PTX 2003 isTailCall = false; 2004 return Chain; 2005 } 2006 2007 SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op, 2008 SelectionDAG &DAG) const { 2009 2010 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) { 2011 const Function &Fn = DAG.getMachineFunction().getFunction(); 2012 2013 DiagnosticInfoUnsupported NoDynamicAlloca( 2014 Fn, 2015 "Support for dynamic alloca introduced in PTX ISA version 7.3 and " 2016 "requires target sm_52.", 2017 SDLoc(Op).getDebugLoc()); 2018 DAG.getContext()->diagnose(NoDynamicAlloca); 2019 auto Ops = {DAG.getConstant(0, SDLoc(), Op.getValueType()), 2020 Op.getOperand(0)}; 2021 return DAG.getMergeValues(Ops, SDLoc()); 2022 } 2023 2024 SDValue Chain = Op.getOperand(0); 2025 SDValue Size = Op.getOperand(1); 2026 uint64_t Align 
= cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue(); 2027 SDLoc DL(Op.getNode()); 2028 2029 // The size for ptx alloca instruction is 64-bit for m64 and 32-bit for m32. 2030 MVT ValueSizeTy = nvTM->is64Bit() ? MVT::i64 : MVT::i32; 2031 2032 SDValue AllocOps[] = {Chain, DAG.getZExtOrTrunc(Size, DL, ValueSizeTy), 2033 DAG.getTargetConstant(Align, DL, MVT::i32)}; 2034 EVT RetTypes[] = {ValueSizeTy, MVT::Other}; 2035 return DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, RetTypes, AllocOps); 2036 } 2037 2038 SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op, 2039 SelectionDAG &DAG) const { 2040 SDLoc DL(Op.getNode()); 2041 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) { 2042 const Function &Fn = DAG.getMachineFunction().getFunction(); 2043 2044 DiagnosticInfoUnsupported NoStackRestore( 2045 Fn, 2046 "Support for stackrestore requires PTX ISA version >= 7.3 and target " 2047 ">= sm_52.", 2048 DL.getDebugLoc()); 2049 DAG.getContext()->diagnose(NoStackRestore); 2050 return Op.getOperand(0); 2051 } 2052 2053 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL); 2054 SDValue Chain = Op.getOperand(0); 2055 SDValue Ptr = Op.getOperand(1); 2056 SDValue ASC = DAG.getAddrSpaceCast(DL, LocalVT, Ptr, ADDRESS_SPACE_GENERIC, 2057 ADDRESS_SPACE_LOCAL); 2058 return DAG.getNode(NVPTXISD::STACKRESTORE, DL, MVT::Other, {Chain, ASC}); 2059 } 2060 2061 SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op, 2062 SelectionDAG &DAG) const { 2063 SDLoc DL(Op.getNode()); 2064 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) { 2065 const Function &Fn = DAG.getMachineFunction().getFunction(); 2066 2067 DiagnosticInfoUnsupported NoStackSave( 2068 Fn, 2069 "Support for stacksave requires PTX ISA version >= 7.3 and target >= " 2070 "sm_52.", 2071 DL.getDebugLoc()); 2072 DAG.getContext()->diagnose(NoStackSave); 2073 auto Ops = {DAG.getConstant(0, DL, Op.getValueType()), Op.getOperand(0)}; 2074 return DAG.getMergeValues(Ops, DL); 2075 } 2076 2077 const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL); 2078 SDValue Chain = Op.getOperand(0); 2079 SDValue SS = 2080 DAG.getNode(NVPTXISD::STACKSAVE, DL, {LocalVT, MVT::Other}, Chain); 2081 SDValue ASC = DAG.getAddrSpaceCast( 2082 DL, Op.getValueType(), SS, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC); 2083 return DAG.getMergeValues({ASC, SDValue(SS.getNode(), 1)}, DL); 2084 } 2085 2086 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack() 2087 // (see LegalizeDAG.cpp). This is slow and uses local memory. 
2088 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5 2089 SDValue 2090 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const { 2091 SDNode *Node = Op.getNode(); 2092 SDLoc dl(Node); 2093 SmallVector<SDValue, 8> Ops; 2094 unsigned NumOperands = Node->getNumOperands(); 2095 for (unsigned i = 0; i < NumOperands; ++i) { 2096 SDValue SubOp = Node->getOperand(i); 2097 EVT VVT = SubOp.getNode()->getValueType(0); 2098 EVT EltVT = VVT.getVectorElementType(); 2099 unsigned NumSubElem = VVT.getVectorNumElements(); 2100 for (unsigned j = 0; j < NumSubElem; ++j) { 2101 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp, 2102 DAG.getIntPtrConstant(j, dl))); 2103 } 2104 } 2105 return DAG.getBuildVector(Node->getValueType(0), dl, Ops); 2106 } 2107 2108 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const { 2109 // Handle bitcasting from v2i8 without hitting the default promotion 2110 // strategy which goes through stack memory. 2111 EVT FromVT = Op->getOperand(0)->getValueType(0); 2112 if (FromVT != MVT::v2i8) { 2113 return Op; 2114 } 2115 2116 // Pack vector elements into i16 and bitcast to final type 2117 SDLoc DL(Op); 2118 SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, 2119 Op->getOperand(0), DAG.getIntPtrConstant(0, DL)); 2120 SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, 2121 Op->getOperand(0), DAG.getIntPtrConstant(1, DL)); 2122 SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0); 2123 SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1); 2124 SDValue Const8 = DAG.getConstant(8, DL, MVT::i16); 2125 SDValue AsInt = DAG.getNode( 2126 ISD::OR, DL, MVT::i16, 2127 {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})}); 2128 EVT ToVT = Op->getValueType(0); 2129 return MaybeBitcast(DAG, DL, ToVT, AsInt); 2130 } 2131 2132 // We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it 2133 // would get lowered as two constant loads and vector-packing move. 2134 // Instead we want just a constant move: 2135 // mov.b32 %r2, 0x40003C00 2136 SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op, 2137 SelectionDAG &DAG) const { 2138 EVT VT = Op->getValueType(0); 2139 if (!(Isv2x16VT(VT) || VT == MVT::v4i8)) 2140 return Op; 2141 SDLoc DL(Op); 2142 2143 if (!llvm::all_of(Op->ops(), [](SDValue Operand) { 2144 return Operand->isUndef() || isa<ConstantSDNode>(Operand) || 2145 isa<ConstantFPSDNode>(Operand); 2146 })) { 2147 if (VT != MVT::v4i8) 2148 return Op; 2149 // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us 2150 // to optimize calculation of constant parts. 2151 auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast, 2152 uint64_t SelectionValue) -> SDValue { 2153 SDValue L = Left; 2154 SDValue R = Right; 2155 if (Cast) { 2156 L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32); 2157 R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32); 2158 } 2159 return DAG.getNode( 2160 NVPTXISD::PRMT, DL, MVT::v4i8, 2161 {L, R, DAG.getConstant(SelectionValue, DL, MVT::i32), 2162 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)}); 2163 }; 2164 auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340); 2165 auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340); 2166 auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410); 2167 return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210); 2168 } 2169 2170 // Get value or the Nth operand as an APInt(32). 
Undef values treated as 0. 2171 auto GetOperand = [](SDValue Op, int N) -> APInt { 2172 const SDValue &Operand = Op->getOperand(N); 2173 EVT VT = Op->getValueType(0); 2174 if (Operand->isUndef()) 2175 return APInt(32, 0); 2176 APInt Value; 2177 if (VT == MVT::v2f16 || VT == MVT::v2bf16) 2178 Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt(); 2179 else if (VT == MVT::v2i16 || VT == MVT::v4i8) 2180 Value = Operand->getAsAPIntVal(); 2181 else 2182 llvm_unreachable("Unsupported type"); 2183 // i8 values are carried around as i16, so we need to zero out upper bits, 2184 // so they do not get in the way of combining individual byte values 2185 if (VT == MVT::v4i8) 2186 Value = Value.trunc(8); 2187 return Value.zext(32); 2188 }; 2189 APInt Value; 2190 if (Isv2x16VT(VT)) { 2191 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16); 2192 } else if (VT == MVT::v4i8) { 2193 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) | 2194 GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24); 2195 } else { 2196 llvm_unreachable("Unsupported type"); 2197 } 2198 SDValue Const = DAG.getConstant(Value, DL, MVT::i32); 2199 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Const); 2200 } 2201 2202 SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, 2203 SelectionDAG &DAG) const { 2204 SDValue Index = Op->getOperand(1); 2205 SDValue Vector = Op->getOperand(0); 2206 SDLoc DL(Op); 2207 EVT VectorVT = Vector.getValueType(); 2208 2209 if (VectorVT == MVT::v4i8) { 2210 SDValue BFE = 2211 DAG.getNode(NVPTXISD::BFE, DL, MVT::i32, 2212 {Vector, 2213 DAG.getNode(ISD::MUL, DL, MVT::i32, 2214 DAG.getZExtOrTrunc(Index, DL, MVT::i32), 2215 DAG.getConstant(8, DL, MVT::i32)), 2216 DAG.getConstant(8, DL, MVT::i32)}); 2217 return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0)); 2218 } 2219 2220 // Constant index will be matched by tablegen. 2221 if (isa<ConstantSDNode>(Index.getNode())) 2222 return Op; 2223 2224 // Extract individual elements and select one of them. 
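// A minimal sketch of the result for a v2f16 %v indexed by a non-constant
// %i (assumed example, IR-like pseudocode):
//   %e0 = extractelement %v, 0
//   %e1 = extractelement %v, 1
//   %r  = select (%i == 0), %e0, %e1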
2225 assert(Isv2x16VT(VectorVT) && "Unexpected vector type."); 2226 EVT EltVT = VectorVT.getVectorElementType(); 2227 2228 SDLoc dl(Op.getNode()); 2229 SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector, 2230 DAG.getIntPtrConstant(0, dl)); 2231 SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector, 2232 DAG.getIntPtrConstant(1, dl)); 2233 return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1, 2234 ISD::CondCode::SETEQ); 2235 } 2236 2237 SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op, 2238 SelectionDAG &DAG) const { 2239 SDValue Vector = Op->getOperand(0); 2240 EVT VectorVT = Vector.getValueType(); 2241 2242 if (VectorVT != MVT::v4i8) 2243 return Op; 2244 SDLoc DL(Op); 2245 SDValue Value = Op->getOperand(1); 2246 if (Value->isUndef()) 2247 return Vector; 2248 2249 SDValue Index = Op->getOperand(2); 2250 2251 SDValue BFI = 2252 DAG.getNode(NVPTXISD::BFI, DL, MVT::i32, 2253 {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector, 2254 DAG.getNode(ISD::MUL, DL, MVT::i32, 2255 DAG.getZExtOrTrunc(Index, DL, MVT::i32), 2256 DAG.getConstant(8, DL, MVT::i32)), 2257 DAG.getConstant(8, DL, MVT::i32)}); 2258 return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI); 2259 } 2260 2261 SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, 2262 SelectionDAG &DAG) const { 2263 SDValue V1 = Op.getOperand(0); 2264 EVT VectorVT = V1.getValueType(); 2265 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8) 2266 return Op; 2267 2268 // Lower shuffle to PRMT instruction. 2269 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode()); 2270 SDValue V2 = Op.getOperand(1); 2271 uint32_t Selector = 0; 2272 for (auto I : llvm::enumerate(SVN->getMask())) { 2273 if (I.value() != -1) // -1 is a placeholder for undef. 2274 Selector |= (I.value() << (I.index() * 4)); 2275 } 2276 2277 SDLoc DL(Op); 2278 return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2, 2279 DAG.getConstant(Selector, DL, MVT::i32), 2280 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)); 2281 } 2282 /// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which 2283 /// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift 2284 /// amount, or 2285 /// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift 2286 /// amount. 2287 SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op, 2288 SelectionDAG &DAG) const { 2289 assert(Op.getNumOperands() == 3 && "Not a double-shift!"); 2290 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS); 2291 2292 EVT VT = Op.getValueType(); 2293 unsigned VTBits = VT.getSizeInBits(); 2294 SDLoc dl(Op); 2295 SDValue ShOpLo = Op.getOperand(0); 2296 SDValue ShOpHi = Op.getOperand(1); 2297 SDValue ShAmt = Op.getOperand(2); 2298 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL; 2299 2300 if (VTBits == 32 && STI.getSmVersion() >= 35) { 2301 // For 32bit and sm35, we can use the funnel shift 'shf' instruction. 
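// e.g. (illustrative operands) the low word becomes a single
//   shf.r.clamp.b32 %dLo, %aLo, %aHi, %amt;
// while the high word remains an ordinary shift.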
2302 // {dHi, dLo} = {aHi, aLo} >> Amt 2303 // dHi = aHi >> Amt 2304 // dLo = shf.r.clamp aLo, aHi, Amt 2305 2306 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt); 2307 SDValue Lo = 2308 DAG.getNode(NVPTXISD::FSHR_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt); 2309 2310 SDValue Ops[2] = { Lo, Hi }; 2311 return DAG.getMergeValues(Ops, dl); 2312 } 2313 else { 2314 // {dHi, dLo} = {aHi, aLo} >> Amt 2315 // - if (Amt>=size) then 2316 // dLo = aHi >> (Amt-size) 2317 // dHi = aHi >> Amt (this is either all 0 or all 1) 2318 // else 2319 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt)) 2320 // dHi = aHi >> Amt 2321 2322 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, 2323 DAG.getConstant(VTBits, dl, MVT::i32), 2324 ShAmt); 2325 SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt); 2326 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt, 2327 DAG.getConstant(VTBits, dl, MVT::i32)); 2328 SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt); 2329 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2); 2330 SDValue TrueVal = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt); 2331 2332 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt, 2333 DAG.getConstant(VTBits, dl, MVT::i32), 2334 ISD::SETGE); 2335 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt); 2336 SDValue Lo = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal); 2337 2338 SDValue Ops[2] = { Lo, Hi }; 2339 return DAG.getMergeValues(Ops, dl); 2340 } 2341 } 2342 2343 /// LowerShiftLeftParts - Lower SHL_PARTS, which 2344 /// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift 2345 /// amount, or 2346 /// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift 2347 /// amount. 2348 SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op, 2349 SelectionDAG &DAG) const { 2350 assert(Op.getNumOperands() == 3 && "Not a double-shift!"); 2351 assert(Op.getOpcode() == ISD::SHL_PARTS); 2352 2353 EVT VT = Op.getValueType(); 2354 unsigned VTBits = VT.getSizeInBits(); 2355 SDLoc dl(Op); 2356 SDValue ShOpLo = Op.getOperand(0); 2357 SDValue ShOpHi = Op.getOperand(1); 2358 SDValue ShAmt = Op.getOperand(2); 2359 2360 if (VTBits == 32 && STI.getSmVersion() >= 35) { 2361 // For 32bit and sm35, we can use the funnel shift 'shf' instruction. 
2362 // {dHi, dLo} = {aHi, aLo} << Amt 2363 // dHi = shf.l.clamp aLo, aHi, Amt 2364 // dLo = aLo << Amt 2365 2366 SDValue Hi = 2367 DAG.getNode(NVPTXISD::FSHL_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt); 2368 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt); 2369 2370 SDValue Ops[2] = { Lo, Hi }; 2371 return DAG.getMergeValues(Ops, dl); 2372 } 2373 else { 2374 // {dHi, dLo} = {aHi, aLo} << Amt 2375 // - if (Amt>=size) then 2376 // dLo = aLo << Amt (all 0) 2377 // dHi = aLo << (Amt-size) 2378 // else 2379 // dLo = aLo << Amt 2380 // dHi = (aHi << Amt) | (aLo >> (size-Amt)) 2381 2382 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, 2383 DAG.getConstant(VTBits, dl, MVT::i32), 2384 ShAmt); 2385 SDValue Tmp1 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt); 2386 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt, 2387 DAG.getConstant(VTBits, dl, MVT::i32)); 2388 SDValue Tmp2 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt); 2389 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2); 2390 SDValue TrueVal = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt); 2391 2392 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt, 2393 DAG.getConstant(VTBits, dl, MVT::i32), 2394 ISD::SETGE); 2395 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt); 2396 SDValue Hi = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal); 2397 2398 SDValue Ops[2] = { Lo, Hi }; 2399 return DAG.getMergeValues(Ops, dl); 2400 } 2401 } 2402 2403 /// If the types match, convert the generic copysign to the NVPTXISD version, 2404 /// otherwise bail ensuring that mismatched cases are properly expanded. 2405 SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op, 2406 SelectionDAG &DAG) const { 2407 EVT VT = Op.getValueType(); 2408 SDLoc DL(Op); 2409 2410 SDValue In1 = Op.getOperand(0); 2411 SDValue In2 = Op.getOperand(1); 2412 EVT SrcVT = In2.getValueType(); 2413 2414 if (!SrcVT.bitsEq(VT)) 2415 return SDValue(); 2416 2417 return DAG.getNode(NVPTXISD::FCOPYSIGN, DL, VT, In1, In2); 2418 } 2419 2420 SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const { 2421 EVT VT = Op.getValueType(); 2422 2423 if (VT == MVT::f32) 2424 return LowerFROUND32(Op, DAG); 2425 2426 if (VT == MVT::f64) 2427 return LowerFROUND64(Op, DAG); 2428 2429 llvm_unreachable("unhandled type"); 2430 } 2431 2432 // This is the rounding method used in CUDA libdevice in C-like code: 2433 // float roundf(float A) 2434 // { 2435 // float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f)); 2436 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA; 2437 // return abs(A) < 0.5 ? (float)(int)A : RoundedA; 2438 // } 2439 SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op, 2440 SelectionDAG &DAG) const { 2441 SDLoc SL(Op); 2442 SDValue A = Op.getOperand(0); 2443 EVT VT = Op.getValueType(); 2444 2445 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A); 2446 2447 // RoundedA = (float) (int) ( A > 0 ?
(A + 0.5f) : (A - 0.5f)) 2448 SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A); 2449 const unsigned SignBitMask = 0x80000000; 2450 SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast, 2451 DAG.getConstant(SignBitMask, SL, MVT::i32)); 2452 const unsigned PointFiveInBits = 0x3F000000; 2453 SDValue PointFiveWithSignRaw = 2454 DAG.getNode(ISD::OR, SL, MVT::i32, Sign, 2455 DAG.getConstant(PointFiveInBits, SL, MVT::i32)); 2456 SDValue PointFiveWithSign = 2457 DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw); 2458 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign); 2459 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA); 2460 2461 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA; 2462 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); 2463 SDValue IsLarge = 2464 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT), 2465 ISD::SETOGT); 2466 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA); 2467 2468 // return abs(A) < 0.5 ? (float)(int)A : RoundedA; 2469 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA, 2470 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT); 2471 SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A); 2472 return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA); 2473 } 2474 2475 // The implementation of round(double) is similar to that of round(float) in 2476 // that they both separate the value range into three regions and use a method 2477 // specific to the region to round the values. However, round(double) first 2478 // calculates the round of the absolute value and then adds the sign back while 2479 // round(float) directly rounds the value with sign. 2480 SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op, 2481 SelectionDAG &DAG) const { 2482 SDLoc SL(Op); 2483 SDValue A = Op.getOperand(0); 2484 EVT VT = Op.getValueType(); 2485 2486 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A); 2487 2488 // double RoundedA = (double) (int) (abs(A) + 0.5f); 2489 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA, 2490 DAG.getConstantFP(0.5, SL, VT)); 2491 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA); 2492 2493 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA; 2494 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); 2495 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA, 2496 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT); 2497 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall, 2498 DAG.getConstantFP(0, SL, VT), 2499 RoundedA); 2500 2501 // Add sign to rounded_A 2502 RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A); 2503 DAG.getNode(ISD::FTRUNC, SL, VT, A); 2504 2505 // RoundedA = abs(A) > 0x1.0p52 ? 
A : RoundedA; 2506 SDValue IsLarge = 2507 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT), 2508 ISD::SETOGT); 2509 return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA); 2510 } 2511 2512 static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) { 2513 EVT VT = N->getValueType(0); 2514 EVT NVT = MVT::f32; 2515 if (VT.isVector()) { 2516 NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount()); 2517 } 2518 SDLoc DL(N); 2519 SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT); 2520 SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT); 2521 SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags()); 2522 return DAG.getFPExtendOrRound(Res, DL, VT); 2523 } 2524 2525 SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op, 2526 SelectionDAG &DAG) const { 2527 if (useF32FTZ(DAG.getMachineFunction())) { 2528 return PromoteBinOpToF32(Op.getNode(), DAG); 2529 } 2530 return Op; 2531 } 2532 2533 SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op, 2534 SelectionDAG &DAG) const { 2535 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78); 2536 2537 if (Op.getValueType() == MVT::bf16) { 2538 SDLoc Loc(Op); 2539 return DAG.getNode( 2540 ISD::FP_ROUND, Loc, MVT::bf16, 2541 DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)), 2542 DAG.getIntPtrConstant(0, Loc, /*isTarget=*/true)); 2543 } 2544 2545 // Everything else is considered legal. 2546 return Op; 2547 } 2548 2549 SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op, 2550 SelectionDAG &DAG) const { 2551 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78); 2552 2553 if (Op.getOperand(0).getValueType() == MVT::bf16) { 2554 SDLoc Loc(Op); 2555 return DAG.getNode( 2556 Op.getOpcode(), Loc, Op.getValueType(), 2557 DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0))); 2558 } 2559 2560 // Everything else is considered legal. 2561 return Op; 2562 } 2563 2564 SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op, 2565 SelectionDAG &DAG) const { 2566 EVT NarrowVT = Op.getValueType(); 2567 SDValue Wide = Op.getOperand(0); 2568 EVT WideVT = Wide.getValueType(); 2569 if (NarrowVT.getScalarType() == MVT::bf16) { 2570 const TargetLowering *TLI = STI.getTargetLowering(); 2571 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) { 2572 return TLI->expandFP_ROUND(Op.getNode(), DAG); 2573 } 2574 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) { 2575 // This combination was the first to support f32 -> bf16. 2576 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) { 2577 if (WideVT.getScalarType() == MVT::f32) { 2578 return Op; 2579 } 2580 if (WideVT.getScalarType() == MVT::f64) { 2581 SDLoc Loc(Op); 2582 // Round-inexact-to-odd f64 to f32, then do the final rounding using 2583 // the hardware f32 -> bf16 instruction. 2584 SDValue rod = TLI->expandRoundInexactToOdd( 2585 WideVT.isVector() ? WideVT.changeVectorElementType(MVT::f32) 2586 : MVT::f32, 2587 Wide, Loc, DAG); 2588 return DAG.getFPExtendOrRound(rod, Loc, NarrowVT); 2589 } 2590 } 2591 return TLI->expandFP_ROUND(Op.getNode(), DAG); 2592 } 2593 } 2594 2595 // Everything else is considered legal. 
2596 return Op; 2597 } 2598 2599 SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op, 2600 SelectionDAG &DAG) const { 2601 SDValue Narrow = Op.getOperand(0); 2602 EVT NarrowVT = Narrow.getValueType(); 2603 EVT WideVT = Op.getValueType(); 2604 if (NarrowVT.getScalarType() == MVT::bf16) { 2605 if (WideVT.getScalarType() == MVT::f32 && 2606 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) { 2607 SDLoc Loc(Op); 2608 return DAG.getNode(ISD::BF16_TO_FP, Loc, WideVT, Narrow); 2609 } 2610 if (WideVT.getScalarType() == MVT::f64 && 2611 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) { 2612 EVT F32 = NarrowVT.isVector() ? NarrowVT.changeVectorElementType(MVT::f32) 2613 : MVT::f32; 2614 SDLoc Loc(Op); 2615 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) { 2616 Op = DAG.getNode(ISD::FP_EXTEND, Loc, F32, Narrow); 2617 } else { 2618 Op = DAG.getNode(ISD::BF16_TO_FP, Loc, F32, Narrow); 2619 } 2620 return DAG.getNode(ISD::FP_EXTEND, Loc, WideVT, Op); 2621 } 2622 } 2623 2624 // Everything else is considered legal. 2625 return Op; 2626 } 2627 2628 static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) { 2629 SDLoc DL(Op); 2630 if (Op.getValueType() != MVT::v2i16) 2631 return Op; 2632 EVT EltVT = Op.getValueType().getVectorElementType(); 2633 SmallVector<SDValue> VecElements; 2634 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) { 2635 SmallVector<SDValue> ScalarArgs; 2636 llvm::transform(Op->ops(), std::back_inserter(ScalarArgs), 2637 [&](const SDUse &O) { 2638 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, 2639 O.get(), DAG.getIntPtrConstant(I, DL)); 2640 }); 2641 VecElements.push_back(DAG.getNode(Op.getOpcode(), DL, EltVT, ScalarArgs)); 2642 } 2643 SDValue V = 2644 DAG.getNode(ISD::BUILD_VECTOR, DL, Op.getValueType(), VecElements); 2645 return V; 2646 } 2647 2648 SDValue 2649 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { 2650 switch (Op.getOpcode()) { 2651 case ISD::RETURNADDR: 2652 return SDValue(); 2653 case ISD::FRAMEADDR: 2654 return SDValue(); 2655 case ISD::GlobalAddress: 2656 return LowerGlobalAddress(Op, DAG); 2657 case ISD::INTRINSIC_W_CHAIN: 2658 return Op; 2659 case ISD::BUILD_VECTOR: 2660 return LowerBUILD_VECTOR(Op, DAG); 2661 case ISD::BITCAST: 2662 return LowerBITCAST(Op, DAG); 2663 case ISD::EXTRACT_SUBVECTOR: 2664 return Op; 2665 case ISD::EXTRACT_VECTOR_ELT: 2666 return LowerEXTRACT_VECTOR_ELT(Op, DAG); 2667 case ISD::INSERT_VECTOR_ELT: 2668 return LowerINSERT_VECTOR_ELT(Op, DAG); 2669 case ISD::VECTOR_SHUFFLE: 2670 return LowerVECTOR_SHUFFLE(Op, DAG); 2671 case ISD::CONCAT_VECTORS: 2672 return LowerCONCAT_VECTORS(Op, DAG); 2673 case ISD::STORE: 2674 return LowerSTORE(Op, DAG); 2675 case ISD::LOAD: 2676 return LowerLOAD(Op, DAG); 2677 case ISD::SHL_PARTS: 2678 return LowerShiftLeftParts(Op, DAG); 2679 case ISD::SRA_PARTS: 2680 case ISD::SRL_PARTS: 2681 return LowerShiftRightParts(Op, DAG); 2682 case ISD::SELECT: 2683 return LowerSelect(Op, DAG); 2684 case ISD::FROUND: 2685 return LowerFROUND(Op, DAG); 2686 case ISD::FCOPYSIGN: 2687 return LowerFCOPYSIGN(Op, DAG); 2688 case ISD::SINT_TO_FP: 2689 case ISD::UINT_TO_FP: 2690 return LowerINT_TO_FP(Op, DAG); 2691 case ISD::FP_TO_SINT: 2692 case ISD::FP_TO_UINT: 2693 return LowerFP_TO_INT(Op, DAG); 2694 case ISD::FP_ROUND: 2695 return LowerFP_ROUND(Op, DAG); 2696 case ISD::FP_EXTEND: 2697 return LowerFP_EXTEND(Op, DAG); 2698 case ISD::BR_JT: 2699 return LowerBR_JT(Op, DAG); 2700 case ISD::VAARG: 2701 return LowerVAARG(Op, DAG); 2702 case 
ISD::VASTART: 2703 return LowerVASTART(Op, DAG); 2704 case ISD::ABS: 2705 case ISD::SMIN: 2706 case ISD::SMAX: 2707 case ISD::UMIN: 2708 case ISD::UMAX: 2709 case ISD::ADD: 2710 case ISD::SUB: 2711 case ISD::MUL: 2712 case ISD::SHL: 2713 case ISD::SREM: 2714 case ISD::UREM: 2715 return LowerVectorArith(Op, DAG); 2716 case ISD::DYNAMIC_STACKALLOC: 2717 return LowerDYNAMIC_STACKALLOC(Op, DAG); 2718 case ISD::STACKRESTORE: 2719 return LowerSTACKRESTORE(Op, DAG); 2720 case ISD::STACKSAVE: 2721 return LowerSTACKSAVE(Op, DAG); 2722 case ISD::CopyToReg: 2723 return LowerCopyToReg_128(Op, DAG); 2724 case ISD::FADD: 2725 case ISD::FSUB: 2726 case ISD::FMUL: 2727 // Used only for bf16 on SM80, where we select fma for non-ftz operation 2728 return PromoteBinOpIfF32FTZ(Op, DAG); 2729 2730 default: 2731 llvm_unreachable("Custom lowering not defined for operation"); 2732 } 2733 } 2734 2735 SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const { 2736 SDLoc DL(Op); 2737 SDValue Chain = Op.getOperand(0); 2738 const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1)); 2739 SDValue Index = Op.getOperand(2); 2740 2741 unsigned JId = JT->getIndex(); 2742 MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo(); 2743 ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs; 2744 2745 SDValue IdV = DAG.getConstant(JId, DL, MVT::i32); 2746 2747 // Generate BrxStart node 2748 SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue); 2749 Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV); 2750 2751 // Generate BrxItem nodes 2752 assert(!MBBs.empty()); 2753 for (MachineBasicBlock *MBB : MBBs.drop_back()) 2754 Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0), 2755 DAG.getBasicBlock(MBB), Chain.getValue(1)); 2756 2757 // Generate BrxEnd nodes 2758 SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index, 2759 IdV, Chain.getValue(1)}; 2760 SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps); 2761 2762 return BrxEnd; 2763 } 2764 2765 // This will prevent AsmPrinter from trying to print the jump tables itself. 2766 unsigned NVPTXTargetLowering::getJumpTableEncoding() const { 2767 return MachineJumpTableInfo::EK_Inline; 2768 } 2769 2770 // This function is almost a copy of SelectionDAG::expandVAArg(). 2771 // The only diff is that this one produces loads from local address space. 
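// A rough sketch of the expansion for an 8-byte-aligned i64 vaarg (assumed
// example, IR-like pseudocode; address space 5 is NVPTX local memory):
//   %cur     = load ptr, ptr %ap                  ; current list position
//   %aligned = and (add %cur, 7), -8              ; apply the extra alignment
//   store (add %aligned, 8), ptr %ap              ; bump the list pointer
//   %val     = load i64, ptr addrspace(5) %aligned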
2772 SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const { 2773 const TargetLowering *TLI = STI.getTargetLowering(); 2774 SDLoc DL(Op); 2775 2776 SDNode *Node = Op.getNode(); 2777 const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue(); 2778 EVT VT = Node->getValueType(0); 2779 auto *Ty = VT.getTypeForEVT(*DAG.getContext()); 2780 SDValue Tmp1 = Node->getOperand(0); 2781 SDValue Tmp2 = Node->getOperand(1); 2782 const MaybeAlign MA(Node->getConstantOperandVal(3)); 2783 2784 SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL, 2785 Tmp1, Tmp2, MachinePointerInfo(V)); 2786 SDValue VAList = VAListLoad; 2787 2788 if (MA && *MA > TLI->getMinStackArgumentAlignment()) { 2789 VAList = DAG.getNode( 2790 ISD::ADD, DL, VAList.getValueType(), VAList, 2791 DAG.getConstant(MA->value() - 1, DL, VAList.getValueType())); 2792 2793 VAList = DAG.getNode(ISD::AND, DL, VAList.getValueType(), VAList, 2794 DAG.getSignedConstant(-(int64_t)MA->value(), DL, 2795 VAList.getValueType())); 2796 } 2797 2798 // Increment the pointer, VAList, to the next vaarg 2799 Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList, 2800 DAG.getConstant(DAG.getDataLayout().getTypeAllocSize(Ty), 2801 DL, VAList.getValueType())); 2802 2803 // Store the incremented VAList to the legalized pointer 2804 Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2, 2805 MachinePointerInfo(V)); 2806 2807 const Value *SrcV = Constant::getNullValue( 2808 PointerType::get(*DAG.getContext(), ADDRESS_SPACE_LOCAL)); 2809 2810 // Load the actual argument out of the pointer VAList 2811 return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV)); 2812 } 2813 2814 SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const { 2815 const TargetLowering *TLI = STI.getTargetLowering(); 2816 SDLoc DL(Op); 2817 EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout()); 2818 2819 // Store the address of unsized array <function>_vararg[] in the ap object. 2820 SDValue Arg = getParamSymbol(DAG, /* vararg */ -1, PtrVT); 2821 SDValue VAReg = DAG.getNode(NVPTXISD::Wrapper, DL, PtrVT, Arg); 2822 2823 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue(); 2824 return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1), 2825 MachinePointerInfo(SV)); 2826 } 2827 2828 SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const { 2829 SDValue Op0 = Op->getOperand(0); 2830 SDValue Op1 = Op->getOperand(1); 2831 SDValue Op2 = Op->getOperand(2); 2832 SDLoc DL(Op.getNode()); 2833 2834 assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1"); 2835 2836 Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1); 2837 Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2); 2838 SDValue Select = DAG.getNode(ISD::SELECT, DL, MVT::i32, Op0, Op1, Op2); 2839 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select); 2840 2841 return Trunc; 2842 } 2843 2844 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const { 2845 if (Op.getValueType() == MVT::i1) 2846 return LowerLOADi1(Op, DAG); 2847 2848 // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle 2849 // unaligned loads and have to handle it here. 
2850 EVT VT = Op.getValueType(); 2851 if (Isv2x16VT(VT) || VT == MVT::v4i8) { 2852 LoadSDNode *Load = cast<LoadSDNode>(Op); 2853 EVT MemVT = Load->getMemoryVT(); 2854 if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(), 2855 MemVT, *Load->getMemOperand())) { 2856 SDValue Ops[2]; 2857 std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG); 2858 return DAG.getMergeValues(Ops, SDLoc(Op)); 2859 } 2860 } 2861 2862 return SDValue(); 2863 } 2864 2865 // v = ld i1* addr 2866 // => 2867 // v1 = ld i8* addr (-> i16) 2868 // v = trunc i16 to i1 2869 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const { 2870 SDNode *Node = Op.getNode(); 2871 LoadSDNode *LD = cast<LoadSDNode>(Node); 2872 SDLoc dl(Node); 2873 assert(LD->getExtensionType() == ISD::NON_EXTLOAD); 2874 assert(Node->getValueType(0) == MVT::i1 && 2875 "Custom lowering for i1 load only"); 2876 SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(), 2877 LD->getBasePtr(), LD->getPointerInfo(), 2878 MVT::i8, LD->getAlign(), 2879 LD->getMemOperand()->getFlags()); 2880 SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD); 2881 // The legalizer (the caller) is expecting two values from the legalized 2882 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad() 2883 // in LegalizeDAG.cpp which also uses MergeValues. 2884 SDValue Ops[] = { result, LD->getChain() }; 2885 return DAG.getMergeValues(Ops, dl); 2886 } 2887 2888 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const { 2889 StoreSDNode *Store = cast<StoreSDNode>(Op); 2890 EVT VT = Store->getMemoryVT(); 2891 2892 if (VT == MVT::i1) 2893 return LowerSTOREi1(Op, DAG); 2894 2895 // v2f16 is legal, so we can't rely on legalizer to handle unaligned 2896 // stores and have to handle it here. 2897 if ((Isv2x16VT(VT) || VT == MVT::v4i8) && 2898 !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(), 2899 VT, *Store->getMemOperand())) 2900 return expandUnalignedStore(Store, DAG); 2901 2902 // v2f16, v2bf16 and v2i16 don't need special handling. 2903 if (Isv2x16VT(VT) || VT == MVT::v4i8) 2904 return SDValue(); 2905 2906 if (VT.isVector()) 2907 return LowerSTOREVector(Op, DAG); 2908 2909 return SDValue(); 2910 } 2911 2912 SDValue 2913 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const { 2914 SDNode *N = Op.getNode(); 2915 SDValue Val = N->getOperand(1); 2916 SDLoc DL(N); 2917 EVT ValVT = Val.getValueType(); 2918 2919 auto NumEltsAndEltVT = getVectorLoweringShape(ValVT); 2920 if (!NumEltsAndEltVT) 2921 return SDValue(); 2922 auto [NumElts, EltVT] = NumEltsAndEltVT.value(); 2923 2924 MemSDNode *MemSD = cast<MemSDNode>(N); 2925 const DataLayout &TD = DAG.getDataLayout(); 2926 2927 Align Alignment = MemSD->getAlign(); 2928 Align PrefAlign = TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext())); 2929 if (Alignment < PrefAlign) { 2930 // This store is not sufficiently aligned, so bail out and let this vector 2931 // store be scalarized. Note that we may still be able to emit smaller 2932 // vector stores. For example, if we are storing a <4 x float> with an 2933 // alignment of 8, this check will fail but the legalizer will try again 2934 // with 2 x <2 x float>, which will succeed with an alignment of 8. 2935 return SDValue(); 2936 } 2937 2938 // Since StoreV2 is a target node, we cannot rely on DAG type legalization. 2939 // Therefore, we must ensure the type is legal. 
For i1 and i8, we set the 2940 // stored type to i16 and propagate the "real" type as the memory type. 2941 bool NeedExt = false; 2942 if (EltVT.getSizeInBits() < 16) 2943 NeedExt = true; 2944 2945 unsigned Opcode = 0; 2946 switch (NumElts) { 2947 default: 2948 return SDValue(); 2949 case 2: 2950 Opcode = NVPTXISD::StoreV2; 2951 break; 2952 case 4: 2953 Opcode = NVPTXISD::StoreV4; 2954 break; 2955 } 2956 2957 SmallVector<SDValue, 8> Ops; 2958 2959 // First is the chain 2960 Ops.push_back(N->getOperand(0)); 2961 2962 // Then the split values 2963 assert(NumElts <= ValVT.getVectorNumElements() && 2964 "NumElts should not increase, only decrease or stay the same."); 2965 if (NumElts < ValVT.getVectorNumElements()) { 2966 // If the number of elements has decreased, getVectorLoweringShape has 2967 // upsized the element types 2968 assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 && 2969 EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type."); 2970 // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be 2971 // stored as b32s 2972 unsigned NumEltsPerSubVector = EltVT.getVectorNumElements(); 2973 for (unsigned i = 0; i < NumElts; ++i) { 2974 SmallVector<SDValue, 4> SubVectorElts; 2975 DAG.ExtractVectorElements(Val, SubVectorElts, i * NumEltsPerSubVector, 2976 NumEltsPerSubVector); 2977 SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts); 2978 Ops.push_back(SubVector); 2979 } 2980 } else { 2981 for (unsigned i = 0; i < NumElts; ++i) { 2982 SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val, 2983 DAG.getIntPtrConstant(i, DL)); 2984 if (NeedExt) 2985 ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal); 2986 Ops.push_back(ExtVal); 2987 } 2988 } 2989 2990 // Then any remaining arguments 2991 Ops.append(N->op_begin() + 2, N->op_end()); 2992 2993 SDValue NewSt = 2994 DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops, 2995 MemSD->getMemoryVT(), MemSD->getMemOperand()); 2996 2997 // return DCI.CombineTo(N, NewSt, true); 2998 return NewSt; 2999 } 3000 3001 // st i1 v, addr 3002 // => 3003 // v1 = zxt v to i16 3004 // st.u8 i16, addr 3005 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const { 3006 SDNode *Node = Op.getNode(); 3007 SDLoc dl(Node); 3008 StoreSDNode *ST = cast<StoreSDNode>(Node); 3009 SDValue Tmp1 = ST->getChain(); 3010 SDValue Tmp2 = ST->getBasePtr(); 3011 SDValue Tmp3 = ST->getValue(); 3012 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only"); 3013 Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3); 3014 SDValue Result = 3015 DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(), MVT::i8, 3016 ST->getAlign(), ST->getMemOperand()->getFlags()); 3017 return Result; 3018 } 3019 3020 SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op, 3021 SelectionDAG &DAG) const { 3022 // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit 3023 // operand so that it can pass the legalization. 
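// (Sketch) CopyToReg(ch, reg, i128 %v [, glue]) is rewritten below as
// CopyToReg(ch, reg, lo64, hi64 [, glue]), where lo64/hi64 are the two
// halves extracted through a v2i64 bitcast of %v.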
3024 3025 assert(Op.getOperand(1).getValueType() == MVT::i128 && 3026 "Custom lowering for 128-bit CopyToReg only"); 3027 3028 SDNode *Node = Op.getNode(); 3029 SDLoc DL(Node); 3030 3031 SDValue Cast = DAG.getBitcast(MVT::v2i64, Op->getOperand(2)); 3032 SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast, 3033 DAG.getIntPtrConstant(0, DL)); 3034 SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast, 3035 DAG.getIntPtrConstant(1, DL)); 3036 3037 SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1); 3038 SmallVector<EVT, 3> ResultsType(Node->values()); 3039 3040 NewOps[0] = Op->getOperand(0); // Chain 3041 NewOps[1] = Op->getOperand(1); // Dst Reg 3042 NewOps[2] = Lo; // Lower 64-bit 3043 NewOps[3] = Hi; // Higher 64-bit 3044 if (Op.getNumOperands() == 4) 3045 NewOps[4] = Op->getOperand(3); // Glue if exists 3046 3047 return DAG.getNode(ISD::CopyToReg, DL, ResultsType, NewOps); 3048 } 3049 3050 unsigned NVPTXTargetLowering::getNumRegisters( 3051 LLVMContext &Context, EVT VT, 3052 std::optional<MVT> RegisterVT = std::nullopt) const { 3053 if (VT == MVT::i128 && RegisterVT == MVT::i128) 3054 return 1; 3055 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT); 3056 } 3057 3058 bool NVPTXTargetLowering::splitValueIntoRegisterParts( 3059 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts, 3060 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const { 3061 if (Val.getValueType() == MVT::i128 && NumParts == 1) { 3062 Parts[0] = Val; 3063 return true; 3064 } 3065 return false; 3066 } 3067 3068 // This creates target external symbol for a function parameter. 3069 // Name of the symbol is composed from its index and the function name. 3070 // Negative index corresponds to special parameter (unsized array) used for 3071 // passing variable arguments. 3072 SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, 3073 EVT v) const { 3074 StringRef SavedStr = nvTM->getStrPool().save( 3075 getParamName(&DAG.getMachineFunction().getFunction(), idx)); 3076 return DAG.getTargetExternalSymbol(SavedStr.data(), v); 3077 } 3078 3079 SDValue NVPTXTargetLowering::LowerFormalArguments( 3080 SDValue Chain, CallingConv::ID CallConv, bool isVarArg, 3081 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl, 3082 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const { 3083 MachineFunction &MF = DAG.getMachineFunction(); 3084 const DataLayout &DL = DAG.getDataLayout(); 3085 auto PtrVT = getPointerTy(DAG.getDataLayout()); 3086 3087 const Function *F = &MF.getFunction(); 3088 const AttributeList &PAL = F->getAttributes(); 3089 const TargetLowering *TLI = STI.getTargetLowering(); 3090 3091 SDValue Root = DAG.getRoot(); 3092 std::vector<SDValue> OutChains; 3093 3094 bool isABI = (STI.getSmVersion() >= 20); 3095 assert(isABI && "Non-ABI compilation is not supported"); 3096 if (!isABI) 3097 return Chain; 3098 3099 std::vector<Type *> argTypes; 3100 std::vector<const Argument *> theArgs; 3101 for (const Argument &I : F->args()) { 3102 theArgs.push_back(&I); 3103 argTypes.push_back(I.getType()); 3104 } 3105 // argTypes.size() (or theArgs.size()) and Ins.size() need not match. 3106 // Ins.size() will be larger 3107 // * if there is an aggregate argument with multiple fields (each field 3108 // showing up separately in Ins) 3109 // * if there is a vector argument with more than typical vector-length 3110 // elements (generally if more than 4) where each vector element is 3111 // individually present in Ins. 
3112 // So a different index should be used for indexing into Ins. 3113 // See similar issue in LowerCall. 3114 unsigned InsIdx = 0; 3115 3116 for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) { 3117 Type *Ty = argTypes[i]; 3118 3119 if (theArgs[i]->use_empty()) { 3120 // argument is dead 3121 if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) { 3122 SmallVector<EVT, 16> vtparts; 3123 3124 ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts); 3125 if (vtparts.empty()) 3126 report_fatal_error("Empty parameter types are not supported"); 3127 3128 for (unsigned parti = 0, parte = vtparts.size(); parti != parte; 3129 ++parti) { 3130 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); 3131 ++InsIdx; 3132 } 3133 if (vtparts.size() > 0) 3134 --InsIdx; 3135 continue; 3136 } 3137 if (Ty->isVectorTy()) { 3138 EVT ObjectVT = getValueType(DL, Ty); 3139 unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT); 3140 for (unsigned parti = 0; parti < NumRegs; ++parti) { 3141 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); 3142 ++InsIdx; 3143 } 3144 if (NumRegs > 0) 3145 --InsIdx; 3146 continue; 3147 } 3148 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); 3149 continue; 3150 } 3151 3152 // In the following cases, assign a node order of "i+1" 3153 // to newly created nodes. The SDNodes for params have to 3154 // appear in the same order as their order of appearance 3155 // in the original function. "i+1" holds that order. 3156 if (!PAL.hasParamAttr(i, Attribute::ByVal)) { 3157 bool aggregateIsPacked = false; 3158 if (StructType *STy = dyn_cast<StructType>(Ty)) 3159 aggregateIsPacked = STy->isPacked(); 3160 3161 SmallVector<EVT, 16> VTs; 3162 SmallVector<uint64_t, 16> Offsets; 3163 ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0); 3164 if (VTs.empty()) 3165 report_fatal_error("Empty parameter types are not supported"); 3166 3167 Align ArgAlign = getFunctionArgumentAlignment( 3168 F, Ty, i + AttributeList::FirstArgIndex, DL); 3169 auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); 3170 3171 SDValue Arg = getParamSymbol(DAG, i, PtrVT); 3172 int VecIdx = -1; // Index of the first element of the current vector. 3173 for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) { 3174 if (VectorInfo[parti] & PVF_FIRST) { 3175 assert(VecIdx == -1 && "Orphaned vector."); 3176 VecIdx = parti; 3177 } 3178 3179 // That's the last element of this store op. 3180 if (VectorInfo[parti] & PVF_LAST) { 3181 unsigned NumElts = parti - VecIdx + 1; 3182 EVT EltVT = VTs[parti]; 3183 // i1 is loaded/stored as i8. 3184 EVT LoadVT = EltVT; 3185 if (EltVT == MVT::i1) 3186 LoadVT = MVT::i8; 3187 else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8) 3188 // getLoad needs a vector type, but it can't handle 3189 // vectors which contain v2f16 or v2bf16 elements. So we must load 3190 // using i32 here and then bitcast back. 
            LoadVT = MVT::i32;

          EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
          SDValue VecAddr =
              DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
                          DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
          Value *srcValue = Constant::getNullValue(
              PointerType::get(F->getContext(), ADDRESS_SPACE_PARAM));

          const MaybeAlign PartAlign = [&]() -> MaybeAlign {
            if (aggregateIsPacked)
              return Align(1);
            if (NumElts != 1)
              return std::nullopt;
            Align PartAlign =
                DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
            return commonAlignment(PartAlign, Offsets[parti]);
          }();
          SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
                                  MachinePointerInfo(srcValue), PartAlign,
                                  MachineMemOperand::MODereferenceable |
                                      MachineMemOperand::MOInvariant);
          if (P.getNode())
            P.getNode()->setIROrder(i + 1);
          for (unsigned j = 0; j < NumElts; ++j) {
            SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
                                      DAG.getIntPtrConstant(j, dl));
            // We've loaded i1 as an i8 and now must truncate it back to i1.
            if (EltVT == MVT::i1)
              Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
            // v2f16 was loaded as an i32. Now we must bitcast it back.
            else if (EltVT != LoadVT)
              Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);

            // If a promoted integer type is used, truncate down to the
            // original type.
            MVT PromotedVT;
            if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
              Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
            }

            // Extend the element if necessary (e.g. an i8 is loaded
            // into an i16 register).
            if (Ins[InsIdx].VT.isInteger() &&
                Ins[InsIdx].VT.getFixedSizeInBits() >
                    LoadVT.getFixedSizeInBits()) {
              unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
                                                           : ISD::ZERO_EXTEND;
              Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
            }
            InVals.push_back(Elt);
          }

          // Reset vector tracking state.
          VecIdx = -1;
        }
        ++InsIdx;
      }
      if (VTs.size() > 0)
        --InsIdx;
      continue;
    }

    // Param has ByVal attribute.
    // Return MoveParam(param symbol).
    // Ideally, the param symbol could be returned directly, but when the
    // SDNode builder decides to use it in a CopyToReg(), the machine
    // instruction fails because TargetExternalSymbol (not lowered) is target
    // dependent, and CopyToReg assumes the source is lowered.
    EVT ObjectVT = getValueType(DL, Ty);
    assert(ObjectVT == Ins[InsIdx].VT &&
           "Ins type did not match function type");
    SDValue Arg = getParamSymbol(DAG, i, PtrVT);
    SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
    if (p.getNode())
      p.getNode()->setIROrder(i + 1);
    InVals.push_back(p);
  }

  if (!OutChains.empty())
    DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));

  return Chain;
}

// Use byte-store when the param address of the return value is unaligned.
// This may happen when the return value is a field of a packed structure.
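// Illustrative sketch (hypothetical offsets, not from a specific test case):
// storing an i16 return value at byte offset 1 of a packed struct is expanded
// into two StoreRetval nodes with an i8 memory type, one storing the value
// shifted right by 0 bits at offset 1 and one storing it shifted right by 8
// bits at offset 2.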
3278 static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain, 3279 uint64_t Offset, EVT ElementType, 3280 SDValue RetVal, const SDLoc &dl) { 3281 // Bit logic only works on integer types 3282 if (adjustElementType(ElementType)) 3283 RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal); 3284 3285 // Store each byte 3286 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) { 3287 // Shift the byte to the last byte position 3288 SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal, 3289 DAG.getConstant(i * 8, dl, MVT::i32)); 3290 SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32), 3291 ShiftVal}; 3292 // Trunc store only the last byte by using 3293 // st.param.b8 3294 // The register type can be larger than b8. 3295 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl, 3296 DAG.getVTList(MVT::Other), StoreOperands, 3297 MVT::i8, MachinePointerInfo(), std::nullopt, 3298 MachineMemOperand::MOStore); 3299 } 3300 return Chain; 3301 } 3302 3303 SDValue 3304 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, 3305 bool isVarArg, 3306 const SmallVectorImpl<ISD::OutputArg> &Outs, 3307 const SmallVectorImpl<SDValue> &OutVals, 3308 const SDLoc &dl, SelectionDAG &DAG) const { 3309 const MachineFunction &MF = DAG.getMachineFunction(); 3310 const Function &F = MF.getFunction(); 3311 Type *RetTy = MF.getFunction().getReturnType(); 3312 3313 bool isABI = (STI.getSmVersion() >= 20); 3314 assert(isABI && "Non-ABI compilation is not supported"); 3315 if (!isABI) 3316 return Chain; 3317 3318 const DataLayout &DL = DAG.getDataLayout(); 3319 SmallVector<SDValue, 16> PromotedOutVals; 3320 SmallVector<EVT, 16> VTs; 3321 SmallVector<uint64_t, 16> Offsets; 3322 ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets); 3323 assert(VTs.size() == OutVals.size() && "Bad return value decomposition"); 3324 3325 for (unsigned i = 0, e = VTs.size(); i != e; ++i) { 3326 SDValue PromotedOutVal = OutVals[i]; 3327 MVT PromotedVT; 3328 if (PromoteScalarIntegerPTX(VTs[i], &PromotedVT)) { 3329 VTs[i] = EVT(PromotedVT); 3330 } 3331 if (PromoteScalarIntegerPTX(PromotedOutVal.getValueType(), &PromotedVT)) { 3332 llvm::ISD::NodeType Ext = 3333 Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; 3334 PromotedOutVal = DAG.getNode(Ext, dl, PromotedVT, PromotedOutVal); 3335 } 3336 PromotedOutVals.push_back(PromotedOutVal); 3337 } 3338 3339 auto VectorInfo = VectorizePTXValueVTs( 3340 VTs, Offsets, 3341 RetTy->isSized() ? getFunctionParamOptimizedAlign(&F, RetTy, DL) 3342 : Align(1)); 3343 3344 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than 3345 // 32-bits are sign extended or zero extended, depending on whether 3346 // they are signed or unsigned types. 3347 bool ExtendIntegerRetVal = 3348 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32; 3349 3350 SmallVector<SDValue, 6> StoreOperands; 3351 for (unsigned i = 0, e = VTs.size(); i != e; ++i) { 3352 SDValue OutVal = OutVals[i]; 3353 SDValue RetVal = PromotedOutVals[i]; 3354 3355 if (ExtendIntegerRetVal) { 3356 RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND 3357 : ISD::ZERO_EXTEND, 3358 dl, MVT::i32, RetVal); 3359 } else if (OutVal.getValueSizeInBits() < 16) { 3360 // Use 16-bit registers for small load-stores as it's the 3361 // smallest general purpose register size supported by NVPTX. 
3362 RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal); 3363 } 3364 3365 // If we have a PVF_SCALAR entry, it may not even be sufficiently aligned 3366 // for a scalar store. In such cases, fall back to byte stores. 3367 if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType()) { 3368 EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[i]; 3369 Align ElementTypeAlign = 3370 DL.getABITypeAlign(ElementType.getTypeForEVT(RetTy->getContext())); 3371 Align ElementAlign = 3372 commonAlignment(DL.getABITypeAlign(RetTy), Offsets[i]); 3373 if (ElementAlign < ElementTypeAlign) { 3374 assert(StoreOperands.empty() && "Orphaned operand list."); 3375 Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[i], ElementType, 3376 RetVal, dl); 3377 3378 // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes 3379 // into the graph, so just move on to the next element. 3380 continue; 3381 } 3382 } 3383 3384 // New load/store. Record chain and offset operands. 3385 if (VectorInfo[i] & PVF_FIRST) { 3386 assert(StoreOperands.empty() && "Orphaned operand list."); 3387 StoreOperands.push_back(Chain); 3388 StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32)); 3389 } 3390 3391 // Record the value to return. 3392 StoreOperands.push_back(RetVal); 3393 3394 // That's the last element of this store op. 3395 if (VectorInfo[i] & PVF_LAST) { 3396 NVPTXISD::NodeType Op; 3397 unsigned NumElts = StoreOperands.size() - 2; 3398 switch (NumElts) { 3399 case 1: 3400 Op = NVPTXISD::StoreRetval; 3401 break; 3402 case 2: 3403 Op = NVPTXISD::StoreRetvalV2; 3404 break; 3405 case 4: 3406 Op = NVPTXISD::StoreRetvalV4; 3407 break; 3408 default: 3409 llvm_unreachable("Invalid vector info."); 3410 } 3411 3412 // Adjust type of load/store op if we've extended the scalar 3413 // return value. 3414 EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[i]; 3415 Chain = DAG.getMemIntrinsicNode( 3416 Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType, 3417 MachinePointerInfo(), Align(1), MachineMemOperand::MOStore); 3418 // Cleanup vector state. 3419 StoreOperands.clear(); 3420 } 3421 } 3422 3423 return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain); 3424 } 3425 3426 void NVPTXTargetLowering::LowerAsmOperandForConstraint( 3427 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops, 3428 SelectionDAG &DAG) const { 3429 if (Constraint.size() > 1) 3430 return; 3431 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG); 3432 } 3433 3434 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as 3435 // TgtMemIntrinsic 3436 // because we need the information that is only available in the "Value" type 3437 // of destination 3438 // pointer. In particular, the address space information. 3439 bool NVPTXTargetLowering::getTgtMemIntrinsic( 3440 IntrinsicInfo &Info, const CallInst &I, 3441 MachineFunction &MF, unsigned Intrinsic) const { 3442 switch (Intrinsic) { 3443 default: 3444 return false; 3445 case Intrinsic::nvvm_match_all_sync_i32p: 3446 case Intrinsic::nvvm_match_all_sync_i64p: 3447 Info.opc = ISD::INTRINSIC_W_CHAIN; 3448 // memVT is bogus. These intrinsics have IntrInaccessibleMemOnly attribute 3449 // in order to model data exchange with other threads, but perform no real 3450 // memory accesses. 3451 Info.memVT = MVT::i1; 3452 3453 // Our result depends on both our and other thread's arguments. 
3454 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore; 3455 return true; 3456 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col: 3457 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row: 3458 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride: 3459 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride: 3460 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col: 3461 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row: 3462 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride: 3463 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride: 3464 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col: 3465 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row: 3466 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride: 3467 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride: 3468 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col: 3469 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row: 3470 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride: 3471 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride: 3472 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col: 3473 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row: 3474 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride: 3475 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride: 3476 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col: 3477 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row: 3478 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride: 3479 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: { 3480 Info.opc = ISD::INTRINSIC_W_CHAIN; 3481 Info.memVT = MVT::v8f16; 3482 Info.ptrVal = I.getArgOperand(0); 3483 Info.offset = 0; 3484 Info.flags = MachineMemOperand::MOLoad; 3485 Info.align = Align(16); 3486 return true; 3487 } 3488 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col: 3489 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride: 3490 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride: 3491 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col: 3492 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row: 3493 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride: 3494 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride: 3495 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row: 3496 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col: 3497 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride: 3498 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row: 3499 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride: 3500 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col: 3501 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride: 3502 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride: 3503 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col: 3504 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row: 3505 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride: 3506 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride: 3507 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: 3508 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col: 3509 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride: 3510 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row: 3511 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: { 3512 Info.opc = ISD::INTRINSIC_W_CHAIN; 3513 Info.memVT = MVT::v2i32; 3514 Info.ptrVal = I.getArgOperand(0); 3515 Info.offset = 0; 3516 Info.flags = MachineMemOperand::MOLoad; 3517 Info.align = Align(8); 3518 return true; 3519 } 3520 3521 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col: 3522 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride: 3523 
case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride: 3524 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col: 3525 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row: 3526 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride: 3527 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride: 3528 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row: 3529 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col: 3530 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride: 3531 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row: 3532 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride: 3533 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col: 3534 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride: 3535 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row: 3536 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride: 3537 3538 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col: 3539 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride: 3540 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride: 3541 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col: 3542 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row: 3543 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride: 3544 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride: 3545 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: 3546 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col: 3547 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride: 3548 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row: 3549 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride: 3550 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col: 3551 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride: 3552 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row: 3553 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride: 3554 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16: 3555 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: { 3556 Info.opc = ISD::INTRINSIC_W_CHAIN; 3557 Info.memVT = MVT::v4i32; 3558 Info.ptrVal = I.getArgOperand(0); 3559 Info.offset = 0; 3560 Info.flags = MachineMemOperand::MOLoad; 3561 Info.align = Align(16); 3562 return true; 3563 } 3564 3565 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col: 3566 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride: 3567 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride: 3568 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col: 3569 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row: 3570 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride: 3571 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride: 3572 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row: 3573 3574 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col: 3575 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride: 3576 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride: 3577 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col: 3578 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row: 3579 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride: 3580 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride: 3581 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row: 3582 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row: 3583 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride: 3584 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col: 3585 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride: 3586 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row: 3587 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride: 3588 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride: 3589 case 
Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row: 3590 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col: 3591 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride: 3592 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride: 3593 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col: 3594 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16: 3595 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: { 3596 Info.opc = ISD::INTRINSIC_W_CHAIN; 3597 Info.memVT = MVT::i32; 3598 Info.ptrVal = I.getArgOperand(0); 3599 Info.offset = 0; 3600 Info.flags = MachineMemOperand::MOLoad; 3601 Info.align = Align(4); 3602 return true; 3603 } 3604 3605 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col: 3606 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row: 3607 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride: 3608 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride: 3609 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col: 3610 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row: 3611 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride: 3612 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride: 3613 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col: 3614 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row: 3615 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride: 3616 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: { 3617 Info.opc = ISD::INTRINSIC_W_CHAIN; 3618 Info.memVT = MVT::v4f16; 3619 Info.ptrVal = I.getArgOperand(0); 3620 Info.offset = 0; 3621 Info.flags = MachineMemOperand::MOLoad; 3622 Info.align = Align(16); 3623 return true; 3624 } 3625 3626 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col: 3627 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row: 3628 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride: 3629 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride: 3630 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col: 3631 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row: 3632 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride: 3633 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride: 3634 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col: 3635 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row: 3636 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride: 3637 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: 3638 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col: 3639 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row: 3640 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride: 3641 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: { 3642 Info.opc = ISD::INTRINSIC_W_CHAIN; 3643 Info.memVT = MVT::v8f32; 3644 Info.ptrVal = I.getArgOperand(0); 3645 Info.offset = 0; 3646 Info.flags = MachineMemOperand::MOLoad; 3647 Info.align = Align(16); 3648 return true; 3649 } 3650 3651 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col: 3652 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride: 3653 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row: 3654 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride: 3655 3656 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col: 3657 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride: 3658 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row: 3659 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride: 3660 3661 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col: 3662 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride: 3663 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row: 3664 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride: 3665 case 
Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col: 3666 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride: 3667 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row: 3668 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride: 3669 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col: 3670 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride: 3671 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row: 3672 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: { 3673 Info.opc = ISD::INTRINSIC_W_CHAIN; 3674 Info.memVT = MVT::v8i32; 3675 Info.ptrVal = I.getArgOperand(0); 3676 Info.offset = 0; 3677 Info.flags = MachineMemOperand::MOLoad; 3678 Info.align = Align(16); 3679 return true; 3680 } 3681 3682 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col: 3683 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride: 3684 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row: 3685 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride: 3686 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col: 3687 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride: 3688 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row: 3689 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride: 3690 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16: 3691 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: { 3692 Info.opc = ISD::INTRINSIC_W_CHAIN; 3693 Info.memVT = MVT::v2i32; 3694 Info.ptrVal = I.getArgOperand(0); 3695 Info.offset = 0; 3696 Info.flags = MachineMemOperand::MOLoad; 3697 Info.align = Align(8); 3698 return true; 3699 } 3700 3701 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col: 3702 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride: 3703 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row: 3704 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride: 3705 3706 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col: 3707 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride: 3708 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row: 3709 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: { 3710 Info.opc = ISD::INTRINSIC_W_CHAIN; 3711 Info.memVT = MVT::f64; 3712 Info.ptrVal = I.getArgOperand(0); 3713 Info.offset = 0; 3714 Info.flags = MachineMemOperand::MOLoad; 3715 Info.align = Align(8); 3716 return true; 3717 } 3718 3719 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col: 3720 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride: 3721 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row: 3722 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: { 3723 Info.opc = ISD::INTRINSIC_W_CHAIN; 3724 Info.memVT = MVT::v2f64; 3725 Info.ptrVal = I.getArgOperand(0); 3726 Info.offset = 0; 3727 Info.flags = MachineMemOperand::MOLoad; 3728 Info.align = Align(16); 3729 return true; 3730 } 3731 3732 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col: 3733 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row: 3734 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride: 3735 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride: 3736 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col: 3737 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row: 3738 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride: 3739 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride: 3740 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col: 3741 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row: 3742 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride: 3743 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: { 3744 Info.opc = ISD::INTRINSIC_VOID; 3745 Info.memVT = MVT::v4f16; 3746 Info.ptrVal = I.getArgOperand(0); 3747 
Info.offset = 0; 3748 Info.flags = MachineMemOperand::MOStore; 3749 Info.align = Align(16); 3750 return true; 3751 } 3752 3753 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col: 3754 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row: 3755 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride: 3756 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride: 3757 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col: 3758 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row: 3759 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride: 3760 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride: 3761 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col: 3762 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row: 3763 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride: 3764 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: 3765 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col: 3766 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row: 3767 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride: 3768 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: { 3769 Info.opc = ISD::INTRINSIC_VOID; 3770 Info.memVT = MVT::v8f32; 3771 Info.ptrVal = I.getArgOperand(0); 3772 Info.offset = 0; 3773 Info.flags = MachineMemOperand::MOStore; 3774 Info.align = Align(16); 3775 return true; 3776 } 3777 3778 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col: 3779 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride: 3780 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row: 3781 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride: 3782 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col: 3783 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride: 3784 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row: 3785 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride: 3786 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col: 3787 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride: 3788 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row: 3789 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: { 3790 Info.opc = ISD::INTRINSIC_VOID; 3791 Info.memVT = MVT::v8i32; 3792 Info.ptrVal = I.getArgOperand(0); 3793 Info.offset = 0; 3794 Info.flags = MachineMemOperand::MOStore; 3795 Info.align = Align(16); 3796 return true; 3797 } 3798 3799 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col: 3800 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride: 3801 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row: 3802 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride: 3803 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col: 3804 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride: 3805 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row: 3806 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: { 3807 Info.opc = ISD::INTRINSIC_VOID; 3808 Info.memVT = MVT::v2i32; 3809 Info.ptrVal = I.getArgOperand(0); 3810 Info.offset = 0; 3811 Info.flags = MachineMemOperand::MOStore; 3812 Info.align = Align(8); 3813 return true; 3814 } 3815 3816 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col: 3817 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride: 3818 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row: 3819 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: { 3820 Info.opc = ISD::INTRINSIC_VOID; 3821 Info.memVT = MVT::v2f64; 3822 Info.ptrVal = I.getArgOperand(0); 3823 Info.offset = 0; 3824 Info.flags = MachineMemOperand::MOStore; 3825 Info.align = Align(16); 3826 return true; 3827 } 3828 3829 case Intrinsic::nvvm_atomic_load_inc_32: 3830 case 
Intrinsic::nvvm_atomic_load_dec_32: 3831 3832 case Intrinsic::nvvm_atomic_add_gen_f_cta: 3833 case Intrinsic::nvvm_atomic_add_gen_f_sys: 3834 case Intrinsic::nvvm_atomic_add_gen_i_cta: 3835 case Intrinsic::nvvm_atomic_add_gen_i_sys: 3836 case Intrinsic::nvvm_atomic_and_gen_i_cta: 3837 case Intrinsic::nvvm_atomic_and_gen_i_sys: 3838 case Intrinsic::nvvm_atomic_cas_gen_i_cta: 3839 case Intrinsic::nvvm_atomic_cas_gen_i_sys: 3840 case Intrinsic::nvvm_atomic_dec_gen_i_cta: 3841 case Intrinsic::nvvm_atomic_dec_gen_i_sys: 3842 case Intrinsic::nvvm_atomic_inc_gen_i_cta: 3843 case Intrinsic::nvvm_atomic_inc_gen_i_sys: 3844 case Intrinsic::nvvm_atomic_max_gen_i_cta: 3845 case Intrinsic::nvvm_atomic_max_gen_i_sys: 3846 case Intrinsic::nvvm_atomic_min_gen_i_cta: 3847 case Intrinsic::nvvm_atomic_min_gen_i_sys: 3848 case Intrinsic::nvvm_atomic_or_gen_i_cta: 3849 case Intrinsic::nvvm_atomic_or_gen_i_sys: 3850 case Intrinsic::nvvm_atomic_exch_gen_i_cta: 3851 case Intrinsic::nvvm_atomic_exch_gen_i_sys: 3852 case Intrinsic::nvvm_atomic_xor_gen_i_cta: 3853 case Intrinsic::nvvm_atomic_xor_gen_i_sys: { 3854 auto &DL = I.getDataLayout(); 3855 Info.opc = ISD::INTRINSIC_W_CHAIN; 3856 Info.memVT = getValueType(DL, I.getType()); 3857 Info.ptrVal = I.getArgOperand(0); 3858 Info.offset = 0; 3859 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore; 3860 Info.align.reset(); 3861 return true; 3862 } 3863 3864 case Intrinsic::nvvm_ldu_global_i: 3865 case Intrinsic::nvvm_ldu_global_f: 3866 case Intrinsic::nvvm_ldu_global_p: { 3867 auto &DL = I.getDataLayout(); 3868 Info.opc = ISD::INTRINSIC_W_CHAIN; 3869 if (Intrinsic == Intrinsic::nvvm_ldu_global_i) 3870 Info.memVT = getValueType(DL, I.getType()); 3871 else if(Intrinsic == Intrinsic::nvvm_ldu_global_p) 3872 Info.memVT = getPointerTy(DL); 3873 else 3874 Info.memVT = getValueType(DL, I.getType()); 3875 Info.ptrVal = I.getArgOperand(0); 3876 Info.offset = 0; 3877 Info.flags = MachineMemOperand::MOLoad; 3878 Info.align = cast<ConstantInt>(I.getArgOperand(1))->getMaybeAlignValue(); 3879 3880 return true; 3881 } 3882 case Intrinsic::nvvm_tex_1d_v4f32_s32: 3883 case Intrinsic::nvvm_tex_1d_v4f32_f32: 3884 case Intrinsic::nvvm_tex_1d_level_v4f32_f32: 3885 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32: 3886 case Intrinsic::nvvm_tex_1d_array_v4f32_s32: 3887 case Intrinsic::nvvm_tex_1d_array_v4f32_f32: 3888 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32: 3889 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32: 3890 case Intrinsic::nvvm_tex_2d_v4f32_s32: 3891 case Intrinsic::nvvm_tex_2d_v4f32_f32: 3892 case Intrinsic::nvvm_tex_2d_level_v4f32_f32: 3893 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32: 3894 case Intrinsic::nvvm_tex_2d_array_v4f32_s32: 3895 case Intrinsic::nvvm_tex_2d_array_v4f32_f32: 3896 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32: 3897 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32: 3898 case Intrinsic::nvvm_tex_3d_v4f32_s32: 3899 case Intrinsic::nvvm_tex_3d_v4f32_f32: 3900 case Intrinsic::nvvm_tex_3d_level_v4f32_f32: 3901 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32: 3902 case Intrinsic::nvvm_tex_cube_v4f32_f32: 3903 case Intrinsic::nvvm_tex_cube_level_v4f32_f32: 3904 case Intrinsic::nvvm_tex_cube_array_v4f32_f32: 3905 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32: 3906 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32: 3907 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32: 3908 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32: 3909 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32: 3910 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32: 3911 case 
Intrinsic::nvvm_tex_unified_1d_v4f32_f32: 3912 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32: 3913 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32: 3914 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32: 3915 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32: 3916 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32: 3917 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32: 3918 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32: 3919 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32: 3920 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32: 3921 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32: 3922 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32: 3923 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32: 3924 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32: 3925 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32: 3926 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32: 3927 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32: 3928 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32: 3929 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32: 3930 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32: 3931 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32: 3932 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32: 3933 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32: 3934 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32: 3935 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32: 3936 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32: 3937 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32: 3938 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32: 3939 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32: 3940 Info.opc = ISD::INTRINSIC_W_CHAIN; 3941 Info.memVT = MVT::v4f32; 3942 Info.ptrVal = nullptr; 3943 Info.offset = 0; 3944 Info.flags = MachineMemOperand::MOLoad; 3945 Info.align = Align(16); 3946 return true; 3947 3948 case Intrinsic::nvvm_tex_1d_v4s32_s32: 3949 case Intrinsic::nvvm_tex_1d_v4s32_f32: 3950 case Intrinsic::nvvm_tex_1d_level_v4s32_f32: 3951 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32: 3952 case Intrinsic::nvvm_tex_1d_array_v4s32_s32: 3953 case Intrinsic::nvvm_tex_1d_array_v4s32_f32: 3954 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32: 3955 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32: 3956 case Intrinsic::nvvm_tex_2d_v4s32_s32: 3957 case Intrinsic::nvvm_tex_2d_v4s32_f32: 3958 case Intrinsic::nvvm_tex_2d_level_v4s32_f32: 3959 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32: 3960 case Intrinsic::nvvm_tex_2d_array_v4s32_s32: 3961 case Intrinsic::nvvm_tex_2d_array_v4s32_f32: 3962 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32: 3963 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32: 3964 case Intrinsic::nvvm_tex_3d_v4s32_s32: 3965 case Intrinsic::nvvm_tex_3d_v4s32_f32: 3966 case Intrinsic::nvvm_tex_3d_level_v4s32_f32: 3967 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32: 3968 case Intrinsic::nvvm_tex_cube_v4s32_f32: 3969 case Intrinsic::nvvm_tex_cube_level_v4s32_f32: 3970 case Intrinsic::nvvm_tex_cube_array_v4s32_f32: 3971 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32: 3972 case Intrinsic::nvvm_tex_cube_v4u32_f32: 3973 case Intrinsic::nvvm_tex_cube_level_v4u32_f32: 3974 case Intrinsic::nvvm_tex_cube_array_v4u32_f32: 3975 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32: 3976 case Intrinsic::nvvm_tex_1d_v4u32_s32: 3977 case Intrinsic::nvvm_tex_1d_v4u32_f32: 3978 case Intrinsic::nvvm_tex_1d_level_v4u32_f32: 3979 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32: 3980 case Intrinsic::nvvm_tex_1d_array_v4u32_s32: 3981 case 
Intrinsic::nvvm_tex_1d_array_v4u32_f32: 3982 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32: 3983 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32: 3984 case Intrinsic::nvvm_tex_2d_v4u32_s32: 3985 case Intrinsic::nvvm_tex_2d_v4u32_f32: 3986 case Intrinsic::nvvm_tex_2d_level_v4u32_f32: 3987 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32: 3988 case Intrinsic::nvvm_tex_2d_array_v4u32_s32: 3989 case Intrinsic::nvvm_tex_2d_array_v4u32_f32: 3990 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32: 3991 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32: 3992 case Intrinsic::nvvm_tex_3d_v4u32_s32: 3993 case Intrinsic::nvvm_tex_3d_v4u32_f32: 3994 case Intrinsic::nvvm_tex_3d_level_v4u32_f32: 3995 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32: 3996 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32: 3997 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32: 3998 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32: 3999 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32: 4000 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32: 4001 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32: 4002 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32: 4003 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32: 4004 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32: 4005 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32: 4006 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32: 4007 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32: 4008 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32: 4009 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32: 4010 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32: 4011 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32: 4012 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32: 4013 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32: 4014 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32: 4015 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32: 4016 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32: 4017 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32: 4018 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32: 4019 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32: 4020 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32: 4021 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32: 4022 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32: 4023 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32: 4024 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32: 4025 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32: 4026 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32: 4027 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32: 4028 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32: 4029 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32: 4030 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32: 4031 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32: 4032 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32: 4033 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32: 4034 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32: 4035 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32: 4036 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32: 4037 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32: 4038 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32: 4039 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32: 4040 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32: 4041 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32: 4042 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32: 4043 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32: 4044 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32: 4045 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32: 4046 case 
Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32: 4047 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32: 4048 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32: 4049 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32: 4050 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32: 4051 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32: 4052 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32: 4053 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32: 4054 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32: 4055 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32: 4056 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32: 4057 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32: 4058 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32: 4059 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32: 4060 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32: 4061 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32: 4062 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32: 4063 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32: 4064 Info.opc = ISD::INTRINSIC_W_CHAIN; 4065 Info.memVT = MVT::v4i32; 4066 Info.ptrVal = nullptr; 4067 Info.offset = 0; 4068 Info.flags = MachineMemOperand::MOLoad; 4069 Info.align = Align(16); 4070 return true; 4071 4072 case Intrinsic::nvvm_suld_1d_i8_clamp: 4073 case Intrinsic::nvvm_suld_1d_v2i8_clamp: 4074 case Intrinsic::nvvm_suld_1d_v4i8_clamp: 4075 case Intrinsic::nvvm_suld_1d_array_i8_clamp: 4076 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp: 4077 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp: 4078 case Intrinsic::nvvm_suld_2d_i8_clamp: 4079 case Intrinsic::nvvm_suld_2d_v2i8_clamp: 4080 case Intrinsic::nvvm_suld_2d_v4i8_clamp: 4081 case Intrinsic::nvvm_suld_2d_array_i8_clamp: 4082 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp: 4083 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp: 4084 case Intrinsic::nvvm_suld_3d_i8_clamp: 4085 case Intrinsic::nvvm_suld_3d_v2i8_clamp: 4086 case Intrinsic::nvvm_suld_3d_v4i8_clamp: 4087 case Intrinsic::nvvm_suld_1d_i8_trap: 4088 case Intrinsic::nvvm_suld_1d_v2i8_trap: 4089 case Intrinsic::nvvm_suld_1d_v4i8_trap: 4090 case Intrinsic::nvvm_suld_1d_array_i8_trap: 4091 case Intrinsic::nvvm_suld_1d_array_v2i8_trap: 4092 case Intrinsic::nvvm_suld_1d_array_v4i8_trap: 4093 case Intrinsic::nvvm_suld_2d_i8_trap: 4094 case Intrinsic::nvvm_suld_2d_v2i8_trap: 4095 case Intrinsic::nvvm_suld_2d_v4i8_trap: 4096 case Intrinsic::nvvm_suld_2d_array_i8_trap: 4097 case Intrinsic::nvvm_suld_2d_array_v2i8_trap: 4098 case Intrinsic::nvvm_suld_2d_array_v4i8_trap: 4099 case Intrinsic::nvvm_suld_3d_i8_trap: 4100 case Intrinsic::nvvm_suld_3d_v2i8_trap: 4101 case Intrinsic::nvvm_suld_3d_v4i8_trap: 4102 case Intrinsic::nvvm_suld_1d_i8_zero: 4103 case Intrinsic::nvvm_suld_1d_v2i8_zero: 4104 case Intrinsic::nvvm_suld_1d_v4i8_zero: 4105 case Intrinsic::nvvm_suld_1d_array_i8_zero: 4106 case Intrinsic::nvvm_suld_1d_array_v2i8_zero: 4107 case Intrinsic::nvvm_suld_1d_array_v4i8_zero: 4108 case Intrinsic::nvvm_suld_2d_i8_zero: 4109 case Intrinsic::nvvm_suld_2d_v2i8_zero: 4110 case Intrinsic::nvvm_suld_2d_v4i8_zero: 4111 case Intrinsic::nvvm_suld_2d_array_i8_zero: 4112 case Intrinsic::nvvm_suld_2d_array_v2i8_zero: 4113 case Intrinsic::nvvm_suld_2d_array_v4i8_zero: 4114 case Intrinsic::nvvm_suld_3d_i8_zero: 4115 case Intrinsic::nvvm_suld_3d_v2i8_zero: 4116 case Intrinsic::nvvm_suld_3d_v4i8_zero: 4117 Info.opc = ISD::INTRINSIC_W_CHAIN; 4118 Info.memVT = MVT::i8; 4119 Info.ptrVal = nullptr; 4120 Info.offset = 0; 4121 Info.flags = MachineMemOperand::MOLoad; 4122 
Info.align = Align(16); 4123 return true; 4124 4125 case Intrinsic::nvvm_suld_1d_i16_clamp: 4126 case Intrinsic::nvvm_suld_1d_v2i16_clamp: 4127 case Intrinsic::nvvm_suld_1d_v4i16_clamp: 4128 case Intrinsic::nvvm_suld_1d_array_i16_clamp: 4129 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp: 4130 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp: 4131 case Intrinsic::nvvm_suld_2d_i16_clamp: 4132 case Intrinsic::nvvm_suld_2d_v2i16_clamp: 4133 case Intrinsic::nvvm_suld_2d_v4i16_clamp: 4134 case Intrinsic::nvvm_suld_2d_array_i16_clamp: 4135 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp: 4136 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp: 4137 case Intrinsic::nvvm_suld_3d_i16_clamp: 4138 case Intrinsic::nvvm_suld_3d_v2i16_clamp: 4139 case Intrinsic::nvvm_suld_3d_v4i16_clamp: 4140 case Intrinsic::nvvm_suld_1d_i16_trap: 4141 case Intrinsic::nvvm_suld_1d_v2i16_trap: 4142 case Intrinsic::nvvm_suld_1d_v4i16_trap: 4143 case Intrinsic::nvvm_suld_1d_array_i16_trap: 4144 case Intrinsic::nvvm_suld_1d_array_v2i16_trap: 4145 case Intrinsic::nvvm_suld_1d_array_v4i16_trap: 4146 case Intrinsic::nvvm_suld_2d_i16_trap: 4147 case Intrinsic::nvvm_suld_2d_v2i16_trap: 4148 case Intrinsic::nvvm_suld_2d_v4i16_trap: 4149 case Intrinsic::nvvm_suld_2d_array_i16_trap: 4150 case Intrinsic::nvvm_suld_2d_array_v2i16_trap: 4151 case Intrinsic::nvvm_suld_2d_array_v4i16_trap: 4152 case Intrinsic::nvvm_suld_3d_i16_trap: 4153 case Intrinsic::nvvm_suld_3d_v2i16_trap: 4154 case Intrinsic::nvvm_suld_3d_v4i16_trap: 4155 case Intrinsic::nvvm_suld_1d_i16_zero: 4156 case Intrinsic::nvvm_suld_1d_v2i16_zero: 4157 case Intrinsic::nvvm_suld_1d_v4i16_zero: 4158 case Intrinsic::nvvm_suld_1d_array_i16_zero: 4159 case Intrinsic::nvvm_suld_1d_array_v2i16_zero: 4160 case Intrinsic::nvvm_suld_1d_array_v4i16_zero: 4161 case Intrinsic::nvvm_suld_2d_i16_zero: 4162 case Intrinsic::nvvm_suld_2d_v2i16_zero: 4163 case Intrinsic::nvvm_suld_2d_v4i16_zero: 4164 case Intrinsic::nvvm_suld_2d_array_i16_zero: 4165 case Intrinsic::nvvm_suld_2d_array_v2i16_zero: 4166 case Intrinsic::nvvm_suld_2d_array_v4i16_zero: 4167 case Intrinsic::nvvm_suld_3d_i16_zero: 4168 case Intrinsic::nvvm_suld_3d_v2i16_zero: 4169 case Intrinsic::nvvm_suld_3d_v4i16_zero: 4170 Info.opc = ISD::INTRINSIC_W_CHAIN; 4171 Info.memVT = MVT::i16; 4172 Info.ptrVal = nullptr; 4173 Info.offset = 0; 4174 Info.flags = MachineMemOperand::MOLoad; 4175 Info.align = Align(16); 4176 return true; 4177 4178 case Intrinsic::nvvm_suld_1d_i32_clamp: 4179 case Intrinsic::nvvm_suld_1d_v2i32_clamp: 4180 case Intrinsic::nvvm_suld_1d_v4i32_clamp: 4181 case Intrinsic::nvvm_suld_1d_array_i32_clamp: 4182 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp: 4183 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp: 4184 case Intrinsic::nvvm_suld_2d_i32_clamp: 4185 case Intrinsic::nvvm_suld_2d_v2i32_clamp: 4186 case Intrinsic::nvvm_suld_2d_v4i32_clamp: 4187 case Intrinsic::nvvm_suld_2d_array_i32_clamp: 4188 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp: 4189 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp: 4190 case Intrinsic::nvvm_suld_3d_i32_clamp: 4191 case Intrinsic::nvvm_suld_3d_v2i32_clamp: 4192 case Intrinsic::nvvm_suld_3d_v4i32_clamp: 4193 case Intrinsic::nvvm_suld_1d_i32_trap: 4194 case Intrinsic::nvvm_suld_1d_v2i32_trap: 4195 case Intrinsic::nvvm_suld_1d_v4i32_trap: 4196 case Intrinsic::nvvm_suld_1d_array_i32_trap: 4197 case Intrinsic::nvvm_suld_1d_array_v2i32_trap: 4198 case Intrinsic::nvvm_suld_1d_array_v4i32_trap: 4199 case Intrinsic::nvvm_suld_2d_i32_trap: 4200 case Intrinsic::nvvm_suld_2d_v2i32_trap: 4201 case 
Intrinsic::nvvm_suld_2d_v4i32_trap: 4202 case Intrinsic::nvvm_suld_2d_array_i32_trap: 4203 case Intrinsic::nvvm_suld_2d_array_v2i32_trap: 4204 case Intrinsic::nvvm_suld_2d_array_v4i32_trap: 4205 case Intrinsic::nvvm_suld_3d_i32_trap: 4206 case Intrinsic::nvvm_suld_3d_v2i32_trap: 4207 case Intrinsic::nvvm_suld_3d_v4i32_trap: 4208 case Intrinsic::nvvm_suld_1d_i32_zero: 4209 case Intrinsic::nvvm_suld_1d_v2i32_zero: 4210 case Intrinsic::nvvm_suld_1d_v4i32_zero: 4211 case Intrinsic::nvvm_suld_1d_array_i32_zero: 4212 case Intrinsic::nvvm_suld_1d_array_v2i32_zero: 4213 case Intrinsic::nvvm_suld_1d_array_v4i32_zero: 4214 case Intrinsic::nvvm_suld_2d_i32_zero: 4215 case Intrinsic::nvvm_suld_2d_v2i32_zero: 4216 case Intrinsic::nvvm_suld_2d_v4i32_zero: 4217 case Intrinsic::nvvm_suld_2d_array_i32_zero: 4218 case Intrinsic::nvvm_suld_2d_array_v2i32_zero: 4219 case Intrinsic::nvvm_suld_2d_array_v4i32_zero: 4220 case Intrinsic::nvvm_suld_3d_i32_zero: 4221 case Intrinsic::nvvm_suld_3d_v2i32_zero: 4222 case Intrinsic::nvvm_suld_3d_v4i32_zero: 4223 Info.opc = ISD::INTRINSIC_W_CHAIN; 4224 Info.memVT = MVT::i32; 4225 Info.ptrVal = nullptr; 4226 Info.offset = 0; 4227 Info.flags = MachineMemOperand::MOLoad; 4228 Info.align = Align(16); 4229 return true; 4230 4231 case Intrinsic::nvvm_suld_1d_i64_clamp: 4232 case Intrinsic::nvvm_suld_1d_v2i64_clamp: 4233 case Intrinsic::nvvm_suld_1d_array_i64_clamp: 4234 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp: 4235 case Intrinsic::nvvm_suld_2d_i64_clamp: 4236 case Intrinsic::nvvm_suld_2d_v2i64_clamp: 4237 case Intrinsic::nvvm_suld_2d_array_i64_clamp: 4238 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp: 4239 case Intrinsic::nvvm_suld_3d_i64_clamp: 4240 case Intrinsic::nvvm_suld_3d_v2i64_clamp: 4241 case Intrinsic::nvvm_suld_1d_i64_trap: 4242 case Intrinsic::nvvm_suld_1d_v2i64_trap: 4243 case Intrinsic::nvvm_suld_1d_array_i64_trap: 4244 case Intrinsic::nvvm_suld_1d_array_v2i64_trap: 4245 case Intrinsic::nvvm_suld_2d_i64_trap: 4246 case Intrinsic::nvvm_suld_2d_v2i64_trap: 4247 case Intrinsic::nvvm_suld_2d_array_i64_trap: 4248 case Intrinsic::nvvm_suld_2d_array_v2i64_trap: 4249 case Intrinsic::nvvm_suld_3d_i64_trap: 4250 case Intrinsic::nvvm_suld_3d_v2i64_trap: 4251 case Intrinsic::nvvm_suld_1d_i64_zero: 4252 case Intrinsic::nvvm_suld_1d_v2i64_zero: 4253 case Intrinsic::nvvm_suld_1d_array_i64_zero: 4254 case Intrinsic::nvvm_suld_1d_array_v2i64_zero: 4255 case Intrinsic::nvvm_suld_2d_i64_zero: 4256 case Intrinsic::nvvm_suld_2d_v2i64_zero: 4257 case Intrinsic::nvvm_suld_2d_array_i64_zero: 4258 case Intrinsic::nvvm_suld_2d_array_v2i64_zero: 4259 case Intrinsic::nvvm_suld_3d_i64_zero: 4260 case Intrinsic::nvvm_suld_3d_v2i64_zero: 4261 Info.opc = ISD::INTRINSIC_W_CHAIN; 4262 Info.memVT = MVT::i64; 4263 Info.ptrVal = nullptr; 4264 Info.offset = 0; 4265 Info.flags = MachineMemOperand::MOLoad; 4266 Info.align = Align(16); 4267 return true; 4268 } 4269 return false; 4270 } 4271 4272 /// getFunctionParamOptimizedAlign - since function arguments are passed via 4273 /// .param space, we may want to increase their alignment in a way that 4274 /// ensures that we can effectively vectorize their loads & stores. We can 4275 /// increase alignment only if the function has internal or has private 4276 /// linkage as for other linkage types callers may already rely on default 4277 /// alignment. To allow using 128-bit vectorized loads/stores, this function 4278 /// ensures that alignment is 16 or greater. 
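/// For example (an illustrative scenario, not a guaranteed outcome): a
/// parameter of type [4 x float] (ABI alignment 4) of an internal function
/// whose address is never taken would be reported as 16-byte aligned, so its
/// four floats can be copied with a single 128-bit vectorized access.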
Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
    const Function *F, Type *ArgTy, const DataLayout &DL) const {
  // Cap the alignment at 128 bytes as that is the maximum alignment
  // supported by PTX.
  const Align ABITypeAlign = std::min(Align(128), DL.getABITypeAlign(ArgTy));

  // If a function has linkage different from internal or private, we
  // must use default ABI alignment as external users rely on it. Same
  // for a function that may be called from a function pointer.
  if (!F || !F->hasLocalLinkage() ||
      F->hasAddressTaken(/*Users=*/nullptr,
                         /*IgnoreCallbackUses=*/false,
                         /*IgnoreAssumeLikeCalls=*/true,
                         /*IgnoreLLVMUsed=*/true))
    return ABITypeAlign;

  assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
  return std::max(Align(16), ABITypeAlign);
}

/// Helper for computing alignment of a device function byval parameter.
Align NVPTXTargetLowering::getFunctionByValParamAlign(
    const Function *F, Type *ArgTy, Align InitialAlign,
    const DataLayout &DL) const {
  Align ArgAlign = InitialAlign;
  // Try to increase alignment to enhance vectorization options.
  if (F)
    ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));

  // Old ptx versions have a bug. When PTX code takes the address of a
  // byval parameter with alignment < 4, ptxas generates code to
  // spill the argument into memory. Alas, on sm_50+ ptxas generates
  // SASS code that fails with a misaligned access. To work around
  // the problem, make sure that we align byval parameters by at
  // least 4. This bug seems to be fixed at least starting from
  // ptxas > 9.0.
  // TODO: remove this after verifying the bug is not reproduced
  // on non-deprecated ptxas versions.
  if (ForceMinByValParamAlign)
    ArgAlign = std::max(ArgAlign, Align(4));

  return ArgAlign;
}

// Helper for getting a function parameter name. The name is composed from
// its index and the function name. A negative index corresponds to the
// special parameter (unsized array) used for passing variable arguments.
std::string NVPTXTargetLowering::getParamName(const Function *F,
                                              int Idx) const {
  std::string ParamName;
  raw_string_ostream ParamStr(ParamName);

  ParamStr << getTargetMachine().getSymbol(F)->getName();
  if (Idx < 0)
    ParamStr << "_vararg";
  else
    ParamStr << "_param_" << Idx;

  return ParamName;
}

/// isLegalAddressingMode - Return true if the addressing mode represented
/// by AM is legal for this target, for a load/store of the specified type.
4342 /// Used to guide target specific optimizations, like loop strength reduction 4343 /// (LoopStrengthReduce.cpp) and memory optimization for address mode 4344 /// (CodeGenPrepare.cpp) 4345 bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL, 4346 const AddrMode &AM, Type *Ty, 4347 unsigned AS, Instruction *I) const { 4348 // AddrMode - This represents an addressing mode of: 4349 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg 4350 // 4351 // The legal address modes are 4352 // - [avar] 4353 // - [areg] 4354 // - [areg+immoff] 4355 // - [immAddr] 4356 4357 // immoff must fit in a signed 32-bit int 4358 if (!APInt(64, AM.BaseOffs).isSignedIntN(32)) 4359 return false; 4360 4361 if (AM.BaseGV) 4362 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale; 4363 4364 switch (AM.Scale) { 4365 case 0: // "r", "r+i" or "i" is allowed 4366 break; 4367 case 1: 4368 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed. 4369 return false; 4370 // Otherwise we have r+i. 4371 break; 4372 default: 4373 // No scale > 1 is allowed 4374 return false; 4375 } 4376 return true; 4377 } 4378 4379 //===----------------------------------------------------------------------===// 4380 // NVPTX Inline Assembly Support 4381 //===----------------------------------------------------------------------===// 4382 4383 /// getConstraintType - Given a constraint letter, return the type of 4384 /// constraint it is for this target. 4385 NVPTXTargetLowering::ConstraintType 4386 NVPTXTargetLowering::getConstraintType(StringRef Constraint) const { 4387 if (Constraint.size() == 1) { 4388 switch (Constraint[0]) { 4389 default: 4390 break; 4391 case 'b': 4392 case 'r': 4393 case 'h': 4394 case 'c': 4395 case 'l': 4396 case 'f': 4397 case 'd': 4398 case 'q': 4399 case '0': 4400 case 'N': 4401 return C_RegisterClass; 4402 } 4403 } 4404 return TargetLowering::getConstraintType(Constraint); 4405 } 4406 4407 std::pair<unsigned, const TargetRegisterClass *> 4408 NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, 4409 StringRef Constraint, 4410 MVT VT) const { 4411 if (Constraint.size() == 1) { 4412 switch (Constraint[0]) { 4413 case 'b': 4414 return std::make_pair(0U, &NVPTX::Int1RegsRegClass); 4415 case 'c': 4416 return std::make_pair(0U, &NVPTX::Int16RegsRegClass); 4417 case 'h': 4418 return std::make_pair(0U, &NVPTX::Int16RegsRegClass); 4419 case 'r': 4420 return std::make_pair(0U, &NVPTX::Int32RegsRegClass); 4421 case 'l': 4422 case 'N': 4423 return std::make_pair(0U, &NVPTX::Int64RegsRegClass); 4424 case 'q': { 4425 if (STI.getSmVersion() < 70) 4426 report_fatal_error("Inline asm with 128 bit operands is only " 4427 "supported for sm_70 and higher!"); 4428 return std::make_pair(0U, &NVPTX::Int128RegsRegClass); 4429 } 4430 case 'f': 4431 return std::make_pair(0U, &NVPTX::Float32RegsRegClass); 4432 case 'd': 4433 return std::make_pair(0U, &NVPTX::Float64RegsRegClass); 4434 } 4435 } 4436 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT); 4437 } 4438 4439 //===----------------------------------------------------------------------===// 4440 // NVPTX DAG Combining 4441 //===----------------------------------------------------------------------===// 4442 4443 bool NVPTXTargetLowering::allowFMA(MachineFunction &MF, 4444 CodeGenOptLevel OptLevel) const { 4445 // Always honor command-line argument 4446 if (FMAContractLevelOpt.getNumOccurrences() > 0) 4447 return FMAContractLevelOpt > 0; 4448 4449 // Do not contract if we're not optimizing the code. 
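// (Summary of the decision order, mirroring the checks here and in
// allowUnsafeFPMath below rather than adding new behavior:
//   1. -nvptx-fma-level, if given on the command line, always wins.
//   2. At CodeGenOptLevel::None, no contraction.
//   3. FPOpFusion::Fast (e.g. -ffp-contract=fast) allows contraction.
//   4. Otherwise, contract only if unsafe FP math is enabled.)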
4450 if (OptLevel == CodeGenOptLevel::None)
4451 return false;
4452
4453 // Honor TargetOptions flags that explicitly say fusion is okay.
4454 if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
4455 return true;
4456
4457 return allowUnsafeFPMath(MF);
4458 }
4459
4460 bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
4461 // Honor TargetOptions flags that explicitly say unsafe math is okay.
4462 if (MF.getTarget().Options.UnsafeFPMath)
4463 return true;
4464
4465 // Allow unsafe math if the unsafe-fp-math attribute explicitly says so.
4466 const Function &F = MF.getFunction();
4467 return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
4468 }
4469
4470 static bool isConstZero(const SDValue &Operand) {
4471 const auto *Const = dyn_cast<ConstantSDNode>(Operand);
4472 return Const && Const->getZExtValue() == 0;
4473 }
4474
4475 /// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
4476 /// operands N0 and N1. This is a helper for PerformADDCombine that is
4477 /// called with the default operands, and if that fails, with commuted
4478 /// operands.
4479 static SDValue
4480 PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4481 TargetLowering::DAGCombinerInfo &DCI) {
4482 EVT VT = N0.getValueType();
4483
4484 // Since integer multiply-add costs the same as integer multiply
4485 // but is more costly than integer add, do the fusion only when
4486 // the mul is only used in the add.
4487 // TODO: this may not be true for later architectures, consider relaxing this
4488 if (!N0.getNode()->hasOneUse())
4489 return SDValue();
4490
4491 // fold (add (select cond, 0, (mul a, b)), c)
4492 // -> (select cond, c, (add (mul a, b), c))
4493 //
4494 if (N0.getOpcode() == ISD::SELECT) {
4495 unsigned ZeroOpNum;
4496 if (isConstZero(N0->getOperand(1)))
4497 ZeroOpNum = 1;
4498 else if (isConstZero(N0->getOperand(2)))
4499 ZeroOpNum = 2;
4500 else
4501 return SDValue();
4502
4503 SDValue M = N0->getOperand((ZeroOpNum == 1) ? 2 : 1);
4504 if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
4505 return SDValue();
4506
4507 SDLoc DL(N);
4508 SDValue Mul =
4509 DCI.DAG.getNode(ISD::MUL, DL, VT, M->getOperand(0), M->getOperand(1));
4510 SDValue MAD = DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, N1);
4511 return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
4512 ((ZeroOpNum == 1) ? N1 : MAD),
4513 ((ZeroOpNum == 1) ? MAD : N1));
4514 }
4515
4516 return SDValue();
4517 }
4518
4519 static SDValue
4520 PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4521 TargetLowering::DAGCombinerInfo &DCI,
4522 CodeGenOptLevel OptLevel) {
4523 EVT VT = N0.getValueType();
4524 if (N0.getOpcode() == ISD::FMUL) {
4525 const auto *TLI = static_cast<const NVPTXTargetLowering *>(
4526 &DCI.DAG.getTargetLoweringInfo());
4527 if (!TLI->allowFMA(DCI.DAG.getMachineFunction(), OptLevel))
4528 return SDValue();
4529
4530 // For floating point:
4531 // Do the fusion only when the mul has fewer than 5 uses and all
4532 // of them are adds.
4533 // The heuristic is that if a use is not an add, then that use
4534 // cannot be fused into an fma, so the mul is still needed anyway.
4535 // If there are more than 4 uses, even if they are all adds, fusing
4536 // them will increase register pressure.
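// As a concrete sketch (operand names are made up for illustration): with
// contraction allowed, the pair
//   t0 = fmul f32 a, b
//   t1 = fadd f32 t0, c
// is rewritten below into
//   t1 = fma f32 a, b, c
// which selects to a single PTX fma.rn.f32, provided t0's uses satisfy the
// heuristic above.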
4537 //
4538 int numUses = 0;
4539 int nonAddCount = 0;
4540 for (const SDNode *User : N0.getNode()->users()) {
4541 numUses++;
4542 if (User->getOpcode() != ISD::FADD)
4543 ++nonAddCount;
4544 if (numUses >= 5)
4545 return SDValue();
4546 }
4547 if (nonAddCount) {
4548 int orderNo = N->getIROrder();
4549 int orderNo2 = N0.getNode()->getIROrder();
4550 // Simple heuristic for estimating potential register pressure: the
4551 // difference in IR order approximates the distance between the def
4552 // and the use, and a longer distance is more likely to cause
4553 // register pressure.
4554 if (orderNo - orderNo2 < 500)
4555 return SDValue();
4556
4557 // Now, check if at least one of the FMUL's operands is live beyond the
4558 // node N, which guarantees that the FMA will not increase register
4559 // pressure at node N.
4560 bool opIsLive = false;
4561 const SDNode *left = N0.getOperand(0).getNode();
4562 const SDNode *right = N0.getOperand(1).getNode();
4563
4564 if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
4565 opIsLive = true;
4566
4567 if (!opIsLive)
4568 for (const SDNode *User : left->users()) {
4569 int orderNo3 = User->getIROrder();
4570 if (orderNo3 > orderNo) {
4571 opIsLive = true;
4572 break;
4573 }
4574 }
4575
4576 if (!opIsLive)
4577 for (const SDNode *User : right->users()) {
4578 int orderNo3 = User->getIROrder();
4579 if (orderNo3 > orderNo) {
4580 opIsLive = true;
4581 break;
4582 }
4583 }
4584
4585 if (!opIsLive)
4586 return SDValue();
4587 }
4588
4589 return DCI.DAG.getNode(ISD::FMA, SDLoc(N), VT, N0.getOperand(0),
4590 N0.getOperand(1), N1);
4591 }
4592
4593 return SDValue();
4594 }
4595
4596 static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
4597 std::size_t Back) {
4598 if (all_of(N->ops().drop_front(Front).drop_back(Back),
4599 [](const SDUse &U) { return U.get()->isUndef(); }))
4600 // Operand 0 is the previous value in the chain. Cannot return EntryToken
4601 // as the previous value will become unused and eliminated later.
4602 return N->getOperand(0);
4603
4604 return SDValue();
4605 }
4606
4607 static SDValue PerformStoreParamCombine(SDNode *N) {
4608 // Operands from the 3rd to the 2nd last one are the values to be stored.
4609 // {Chain, ArgID, Offset, Val, Glue}
4610 return PerformStoreCombineHelper(N, 3, 1);
4611 }
4612
4613 static SDValue PerformStoreRetvalCombine(SDNode *N) {
4614 // Operands from the 2nd to the last one are the values to be stored.
4615 return PerformStoreCombineHelper(N, 2, 0);
4616 }
4617
4618 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
4619 ///
4620 static SDValue PerformADDCombine(SDNode *N,
4621 TargetLowering::DAGCombinerInfo &DCI,
4622 CodeGenOptLevel OptLevel) {
4623 if (OptLevel == CodeGenOptLevel::None)
4624 return SDValue();
4625
4626 SDValue N0 = N->getOperand(0);
4627 SDValue N1 = N->getOperand(1);
4628
4629 // Only handle the scalar i32 case.
4630 EVT VT = N0.getValueType();
4631 if (VT.isVector() || VT != MVT::i32)
4632 return SDValue();
4633
4634 // First try with the default operand order.
4635 if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
4636 return Result;
4637
4638 // If that didn't work, try again with the operands commuted.
4639 return PerformADDCombineWithOperands(N, N1, N0, DCI);
4640 }
4641
4642 /// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
4643 /// 4644 static SDValue PerformFADDCombine(SDNode *N, 4645 TargetLowering::DAGCombinerInfo &DCI, 4646 CodeGenOptLevel OptLevel) { 4647 SDValue N0 = N->getOperand(0); 4648 SDValue N1 = N->getOperand(1); 4649 4650 EVT VT = N0.getValueType(); 4651 if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64)) 4652 return SDValue(); 4653 4654 // First try with the default operand order. 4655 if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel)) 4656 return Result; 4657 4658 // If that didn't work, try again with the operands commuted. 4659 return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel); 4660 } 4661 4662 static SDValue PerformANDCombine(SDNode *N, 4663 TargetLowering::DAGCombinerInfo &DCI) { 4664 // The type legalizer turns a vector load of i8 values into a zextload to i16 4665 // registers, optionally ANY_EXTENDs it (if target type is integer), 4666 // and ANDs off the high 8 bits. Since we turn this load into a 4667 // target-specific DAG node, the DAG combiner fails to eliminate these AND 4668 // nodes. Do that here. 4669 SDValue Val = N->getOperand(0); 4670 SDValue Mask = N->getOperand(1); 4671 4672 if (isa<ConstantSDNode>(Val)) { 4673 std::swap(Val, Mask); 4674 } 4675 4676 SDValue AExt; 4677 4678 // Convert BFE-> truncate i16 -> and 255 4679 // To just BFE-> truncate i16, as the value already has all the bits in the 4680 // right places. 4681 if (Val.getOpcode() == ISD::TRUNCATE) { 4682 SDValue BFE = Val.getOperand(0); 4683 if (BFE.getOpcode() != NVPTXISD::BFE) 4684 return SDValue(); 4685 4686 ConstantSDNode *BFEBits = dyn_cast<ConstantSDNode>(BFE.getOperand(0)); 4687 if (!BFEBits) 4688 return SDValue(); 4689 uint64_t BFEBitsVal = BFEBits->getZExtValue(); 4690 4691 ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask); 4692 if (!MaskCnst) { 4693 // Not an AND with a constant 4694 return SDValue(); 4695 } 4696 uint64_t MaskVal = MaskCnst->getZExtValue(); 4697 4698 if (MaskVal != (uint64_t(1) << BFEBitsVal) - 1) 4699 return SDValue(); 4700 // If we get here, the AND is unnecessary. Just replace it with the trunc 4701 DCI.CombineTo(N, Val, false); 4702 } 4703 // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and 4704 if (Val.getOpcode() == ISD::ANY_EXTEND) { 4705 AExt = Val; 4706 Val = Val->getOperand(0); 4707 } 4708 4709 if (Val->getOpcode() == NVPTXISD::LoadV2 || 4710 Val->getOpcode() == NVPTXISD::LoadV4) { 4711 ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask); 4712 if (!MaskCnst) { 4713 // Not an AND with a constant 4714 return SDValue(); 4715 } 4716 4717 uint64_t MaskVal = MaskCnst->getZExtValue(); 4718 if (MaskVal != 0xff) { 4719 // Not an AND that chops off top 8 bits 4720 return SDValue(); 4721 } 4722 4723 MemSDNode *Mem = dyn_cast<MemSDNode>(Val); 4724 if (!Mem) { 4725 // Not a MemSDNode?!? 4726 return SDValue(); 4727 } 4728 4729 EVT MemVT = Mem->getMemoryVT(); 4730 if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) { 4731 // We only handle the i8 case 4732 return SDValue(); 4733 } 4734 4735 unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1); 4736 if (ExtType == ISD::SEXTLOAD) { 4737 // If for some reason the load is a sextload, the and is needed to zero 4738 // out the high 8 bits 4739 return SDValue(); 4740 } 4741 4742 bool AddTo = false; 4743 if (AExt.getNode() != nullptr) { 4744 // Re-insert the ext as a zext. 4745 Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), 4746 AExt.getValueType(), Val); 4747 AddTo = true; 4748 } 4749 4750 // If we get here, the AND is unnecessary. 
Just replace it with the load 4751 DCI.CombineTo(N, Val, AddTo); 4752 } 4753 4754 return SDValue(); 4755 } 4756 4757 static SDValue PerformREMCombine(SDNode *N, 4758 TargetLowering::DAGCombinerInfo &DCI, 4759 CodeGenOptLevel OptLevel) { 4760 assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM); 4761 4762 // Don't do anything at less than -O2. 4763 if (OptLevel < CodeGenOptLevel::Default) 4764 return SDValue(); 4765 4766 SelectionDAG &DAG = DCI.DAG; 4767 SDLoc DL(N); 4768 EVT VT = N->getValueType(0); 4769 bool IsSigned = N->getOpcode() == ISD::SREM; 4770 unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV; 4771 4772 const SDValue &Num = N->getOperand(0); 4773 const SDValue &Den = N->getOperand(1); 4774 4775 for (const SDNode *U : Num->users()) { 4776 if (U->getOpcode() == DivOpc && U->getOperand(0) == Num && 4777 U->getOperand(1) == Den) { 4778 // Num % Den -> Num - (Num / Den) * Den 4779 return DAG.getNode(ISD::SUB, DL, VT, Num, 4780 DAG.getNode(ISD::MUL, DL, VT, 4781 DAG.getNode(DivOpc, DL, VT, Num, Den), 4782 Den)); 4783 } 4784 } 4785 return SDValue(); 4786 } 4787 4788 enum OperandSignedness { 4789 Signed = 0, 4790 Unsigned, 4791 Unknown 4792 }; 4793 4794 /// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand 4795 /// that can be demoted to \p OptSize bits without loss of information. The 4796 /// signedness of the operand, if determinable, is placed in \p S. 4797 static bool IsMulWideOperandDemotable(SDValue Op, 4798 unsigned OptSize, 4799 OperandSignedness &S) { 4800 S = Unknown; 4801 4802 if (Op.getOpcode() == ISD::SIGN_EXTEND || 4803 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) { 4804 EVT OrigVT = Op.getOperand(0).getValueType(); 4805 if (OrigVT.getFixedSizeInBits() <= OptSize) { 4806 S = Signed; 4807 return true; 4808 } 4809 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) { 4810 EVT OrigVT = Op.getOperand(0).getValueType(); 4811 if (OrigVT.getFixedSizeInBits() <= OptSize) { 4812 S = Unsigned; 4813 return true; 4814 } 4815 } 4816 4817 return false; 4818 } 4819 4820 /// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can 4821 /// be demoted to \p OptSize bits without loss of information. If the operands 4822 /// contain a constant, it should appear as the RHS operand. The signedness of 4823 /// the operands is placed in \p IsSigned. 4824 static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS, 4825 unsigned OptSize, 4826 bool &IsSigned) { 4827 OperandSignedness LHSSign; 4828 4829 // The LHS operand must be a demotable op 4830 if (!IsMulWideOperandDemotable(LHS, OptSize, LHSSign)) 4831 return false; 4832 4833 // We should have been able to determine the signedness from the LHS 4834 if (LHSSign == Unknown) 4835 return false; 4836 4837 IsSigned = (LHSSign == Signed); 4838 4839 // The RHS can be a demotable op or a constant 4840 if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(RHS)) { 4841 const APInt &Val = CI->getAPIntValue(); 4842 if (LHSSign == Unsigned) { 4843 return Val.isIntN(OptSize); 4844 } else { 4845 return Val.isSignedIntN(OptSize); 4846 } 4847 } else { 4848 OperandSignedness RHSSign; 4849 if (!IsMulWideOperandDemotable(RHS, OptSize, RHSSign)) 4850 return false; 4851 4852 return LHSSign == RHSSign; 4853 } 4854 } 4855 4856 /// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply 4857 /// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform 4858 /// works on both multiply DAG nodes and SHL DAG nodes with a constant shift 4859 /// amount. 
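/// For example (a sketch of the intended selection, not an exhaustive list):
/// a 32-bit multiply whose operands were sign-extended from 16 bits,
///   (mul (sext i16:a to i32), (sext i16:b to i32))
/// is turned into NVPTXISD::MUL_WIDE_SIGNED over the demoted i16 operands,
/// which selects to PTX mul.wide.s16. A left shift by a constant is handled
/// the same way by first rewriting, e.g., (shl x, 4) as (mul x, 16).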
4860 static SDValue TryMULWIDECombine(SDNode *N, 4861 TargetLowering::DAGCombinerInfo &DCI) { 4862 EVT MulType = N->getValueType(0); 4863 if (MulType != MVT::i32 && MulType != MVT::i64) { 4864 return SDValue(); 4865 } 4866 4867 SDLoc DL(N); 4868 unsigned OptSize = MulType.getSizeInBits() >> 1; 4869 SDValue LHS = N->getOperand(0); 4870 SDValue RHS = N->getOperand(1); 4871 4872 // Canonicalize the multiply so the constant (if any) is on the right 4873 if (N->getOpcode() == ISD::MUL) { 4874 if (isa<ConstantSDNode>(LHS)) { 4875 std::swap(LHS, RHS); 4876 } 4877 } 4878 4879 // If we have a SHL, determine the actual multiply amount 4880 if (N->getOpcode() == ISD::SHL) { 4881 ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(RHS); 4882 if (!ShlRHS) { 4883 return SDValue(); 4884 } 4885 4886 APInt ShiftAmt = ShlRHS->getAPIntValue(); 4887 unsigned BitWidth = MulType.getSizeInBits(); 4888 if (ShiftAmt.sge(0) && ShiftAmt.slt(BitWidth)) { 4889 APInt MulVal = APInt(BitWidth, 1) << ShiftAmt; 4890 RHS = DCI.DAG.getConstant(MulVal, DL, MulType); 4891 } else { 4892 return SDValue(); 4893 } 4894 } 4895 4896 bool Signed; 4897 // Verify that our operands are demotable 4898 if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, Signed)) { 4899 return SDValue(); 4900 } 4901 4902 EVT DemotedVT; 4903 if (MulType == MVT::i32) { 4904 DemotedVT = MVT::i16; 4905 } else { 4906 DemotedVT = MVT::i32; 4907 } 4908 4909 // Truncate the operands to the correct size. Note that these are just for 4910 // type consistency and will (likely) be eliminated in later phases. 4911 SDValue TruncLHS = 4912 DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, LHS); 4913 SDValue TruncRHS = 4914 DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, RHS); 4915 4916 unsigned Opc; 4917 if (Signed) { 4918 Opc = NVPTXISD::MUL_WIDE_SIGNED; 4919 } else { 4920 Opc = NVPTXISD::MUL_WIDE_UNSIGNED; 4921 } 4922 4923 return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS); 4924 } 4925 4926 static bool isConstOne(const SDValue &Operand) { 4927 const auto *Const = dyn_cast<ConstantSDNode>(Operand); 4928 return Const && Const->getZExtValue() == 1; 4929 } 4930 4931 static SDValue matchMADConstOnePattern(SDValue Add) { 4932 if (Add->getOpcode() != ISD::ADD) 4933 return SDValue(); 4934 4935 if (isConstOne(Add->getOperand(0))) 4936 return Add->getOperand(1); 4937 4938 if (isConstOne(Add->getOperand(1))) 4939 return Add->getOperand(0); 4940 4941 return SDValue(); 4942 } 4943 4944 static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL, 4945 TargetLowering::DAGCombinerInfo &DCI) { 4946 4947 if (SDValue Y = matchMADConstOnePattern(Add)) { 4948 SDValue Mul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y); 4949 return DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, X); 4950 } 4951 4952 return SDValue(); 4953 } 4954 4955 static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT, 4956 SDLoc DL, 4957 TargetLowering::DAGCombinerInfo &DCI) { 4958 if (Select->getOpcode() != ISD::SELECT) 4959 return SDValue(); 4960 4961 SDValue Cond = Select->getOperand(0); 4962 4963 unsigned ConstOpNo; 4964 if (isConstOne(Select->getOperand(1))) 4965 ConstOpNo = 1; 4966 else if (isConstOne(Select->getOperand(2))) 4967 ConstOpNo = 2; 4968 else 4969 return SDValue(); 4970 4971 SDValue Y = Select->getOperand((ConstOpNo == 1) ? 2 : 1); 4972 4973 // Do not combine if the resulting sequence is not obviously profitable. 
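// (Illustration of the case this is guarding for: when Y itself has the
// (add z, 1) shape, the whole expression
//   (mul x, (select c, 1, (add z, 1)))
// can eventually become
//   (select c, x, (add (mul x, z), x))
// i.e. a mad-friendly form with no multiply left on the constant-1 arm.)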
4974 if (!matchMADConstOnePattern(Y)) 4975 return SDValue(); 4976 4977 SDValue NewMul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y); 4978 4979 return DCI.DAG.getNode(ISD::SELECT, DL, VT, Cond, 4980 (ConstOpNo == 1) ? X : NewMul, 4981 (ConstOpNo == 1) ? NewMul : X); 4982 } 4983 4984 static SDValue 4985 PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, 4986 TargetLowering::DAGCombinerInfo &DCI) { 4987 4988 EVT VT = N0.getValueType(); 4989 if (VT.isVector()) 4990 return SDValue(); 4991 4992 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64) 4993 return SDValue(); 4994 4995 SDLoc DL(N); 4996 4997 // (mul x, (add y, 1)) -> (add (mul x, y), x) 4998 if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI)) 4999 return Res; 5000 if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI)) 5001 return Res; 5002 5003 // (mul x, (select y, 1)) -> (select (mul x, y), x) 5004 if (SDValue Res = combineMulSelectConstOne(N0, N1, VT, DL, DCI)) 5005 return Res; 5006 if (SDValue Res = combineMulSelectConstOne(N1, N0, VT, DL, DCI)) 5007 return Res; 5008 5009 return SDValue(); 5010 } 5011 5012 /// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes. 5013 static SDValue PerformMULCombine(SDNode *N, 5014 TargetLowering::DAGCombinerInfo &DCI, 5015 CodeGenOptLevel OptLevel) { 5016 if (OptLevel == CodeGenOptLevel::None) 5017 return SDValue(); 5018 5019 if (SDValue Ret = TryMULWIDECombine(N, DCI)) 5020 return Ret; 5021 5022 SDValue N0 = N->getOperand(0); 5023 SDValue N1 = N->getOperand(1); 5024 return PerformMULCombineWithOperands(N, N0, N1, DCI); 5025 } 5026 5027 /// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes. 5028 static SDValue PerformSHLCombine(SDNode *N, 5029 TargetLowering::DAGCombinerInfo &DCI, 5030 CodeGenOptLevel OptLevel) { 5031 if (OptLevel > CodeGenOptLevel::None) { 5032 // Try mul.wide combining at OptLevel > 0 5033 if (SDValue Ret = TryMULWIDECombine(N, DCI)) 5034 return Ret; 5035 } 5036 5037 return SDValue(); 5038 } 5039 5040 static SDValue PerformSETCCCombine(SDNode *N, 5041 TargetLowering::DAGCombinerInfo &DCI, 5042 unsigned int SmVersion) { 5043 EVT CCType = N->getValueType(0); 5044 SDValue A = N->getOperand(0); 5045 SDValue B = N->getOperand(1); 5046 5047 EVT AType = A.getValueType(); 5048 if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16))) 5049 return SDValue(); 5050 5051 if (A.getValueType() == MVT::v2bf16 && SmVersion < 90) 5052 return SDValue(); 5053 5054 SDLoc DL(N); 5055 // setp.f16x2 returns two scalar predicates, which we need to 5056 // convert back to v2i1. The returned result will be scalarized by 5057 // the legalizer, but the comparison will remain a single vector 5058 // instruction. 5059 SDValue CCNode = DCI.DAG.getNode( 5060 A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2 5061 : NVPTXISD::SETP_BF16X2, 5062 DL, DCI.DAG.getVTList(MVT::i1, MVT::i1), {A, B, N->getOperand(2)}); 5063 return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, CCType, CCNode.getValue(0), 5064 CCNode.getValue(1)); 5065 } 5066 5067 static SDValue PerformEXTRACTCombine(SDNode *N, 5068 TargetLowering::DAGCombinerInfo &DCI) { 5069 SDValue Vector = N->getOperand(0); 5070 if (Vector->getOpcode() == ISD::FREEZE) 5071 Vector = Vector->getOperand(0); 5072 SDLoc DL(N); 5073 EVT VectorVT = Vector.getValueType(); 5074 if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() && 5075 IsPTXVectorType(VectorVT.getSimpleVT())) 5076 return SDValue(); // Native vector loads already combine nicely w/ 5077 // extract_vector_elt. 
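// A sketch of the rewrite performed below (types chosen for illustration):
// extracting element 1 of a v2f32 that fits in a 64-bit register becomes
//   (bitcast f32 (trunc i32 (sra i64 (bitcast i64 v), 32)))
// i.e. a shift by Index * EltBits followed by a truncate, with a bitcast
// back to the original element type when that type is not an integer.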
5078 // Don't mess with singletons or v2*16, v4i8 and v8i8 types; we already
5079 // handle them OK.
5080 if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
5081 VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
5082 return SDValue();
5083
5084 // Don't mess with undef values as sra may be simplified to 0, not undef.
5085 if (Vector->isUndef() || ISD::allOperandsUndef(Vector.getNode()))
5086 return SDValue();
5087
5088 uint64_t VectorBits = VectorVT.getSizeInBits();
5089 // We only handle the types we can extract in-register.
5090 if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
5091 return SDValue();
5092
5093 ConstantSDNode *Index = dyn_cast<ConstantSDNode>(N->getOperand(1));
5094 // Index == 0 is handled by generic DAG combiner.
5095 if (!Index || Index->getZExtValue() == 0)
5096 return SDValue();
5097
5098 MVT IVT = MVT::getIntegerVT(VectorBits);
5099 EVT EltVT = VectorVT.getVectorElementType();
5100 EVT EltIVT = EltVT.changeTypeToInteger();
5101 uint64_t EltBits = EltVT.getScalarSizeInBits();
5102
5103 SDValue Result = DCI.DAG.getNode(
5104 ISD::TRUNCATE, DL, EltIVT,
5105 DCI.DAG.getNode(
5106 ISD::SRA, DL, IVT, DCI.DAG.getNode(ISD::BITCAST, DL, IVT, Vector),
5107 DCI.DAG.getConstant(Index->getZExtValue() * EltBits, DL, IVT)));
5108
5109 // If element has non-integer type, bitcast it back to the expected type.
5110 if (EltVT != EltIVT)
5111 Result = DCI.DAG.getNode(ISD::BITCAST, DL, EltVT, Result);
5112 // Past the legalizer, we may need to extend i8 -> i16 to match the register type.
5113 if (EltVT != N->getValueType(0))
5114 Result = DCI.DAG.getNode(ISD::ANY_EXTEND, DL, N->getValueType(0), Result);
5115
5116 return Result;
5117 }
5118
5119 static SDValue PerformVSELECTCombine(SDNode *N,
5120 TargetLowering::DAGCombinerInfo &DCI) {
5121 SDValue VA = N->getOperand(1);
5122 EVT VectorVT = VA.getValueType();
5123 if (VectorVT != MVT::v4i8)
5124 return SDValue();
5125
5126 // We need to split the vselect into individual per-element operations.
5127 // Because we use BFE/BFI instructions for byte extraction/insertion, we end
5128 // up with 32-bit values anyway, so we may as well do the comparison as i32
5129 // to avoid the conversions to/from i16 normally used for i8 values.
5130 SmallVector<SDValue, 4> E;
5131 SDLoc DL(N);
5132 SDValue VCond = N->getOperand(0);
5133 SDValue VB = N->getOperand(2);
5134 for (int I = 0; I < 4; ++I) {
5135 SDValue C = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i1, VCond,
5136 DCI.DAG.getConstant(I, DL, MVT::i32));
5137 SDValue EA = DCI.DAG.getAnyExtOrTrunc(
5138 DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VA,
5139 DCI.DAG.getConstant(I, DL, MVT::i32)),
5140 DL, MVT::i32);
5141 SDValue EB = DCI.DAG.getAnyExtOrTrunc(
5142 DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VB,
5143 DCI.DAG.getConstant(I, DL, MVT::i32)),
5144 DL, MVT::i32);
5145 E.push_back(DCI.DAG.getAnyExtOrTrunc(
5146 DCI.DAG.getNode(ISD::SELECT, DL, MVT::i32, C, EA, EB), DL, MVT::i8));
5147 }
5148 return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v4i8, E);
5149 }
5150
5151 static SDValue
5152 PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5153 auto VT = N->getValueType(0);
5154 if (!DCI.isAfterLegalizeDAG() || !Isv2x16VT(VT))
5155 return SDValue();
5156
5157 auto Op0 = N->getOperand(0);
5158 auto Op1 = N->getOperand(1);
5159
5160 // Start out by assuming we want to take the lower 2 bytes of each i32
5161 // operand.
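// (How the PRMT selector below is read, as a worked example: each hex digit
// picks one byte, with bytes 0-3 coming from Op0 and bytes 4-7 from Op1.
// 0x10 therefore selects bytes {1,0} of Op0 and 0x54 selects bytes {5,4},
// i.e. the low 16 bits of Op1; the combined selector is
// (0x54 << 8) | 0x10 = 0x5410. When a srl-by-16 is folded in, 0x22 is added
// to pick the upper bytes instead: 0x10 -> 0x32, 0x54 -> 0x76.)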
5162 uint64_t Op0Bytes = 0x10; 5163 uint64_t Op1Bytes = 0x54; 5164 5165 std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes}, 5166 {&Op1, &Op1Bytes}}; 5167 5168 // Check that each operand is an i16, truncated from an i32 operand. We'll 5169 // select individual bytes from those original operands. Optionally, fold in a 5170 // shift right of that original operand. 5171 for (auto &[Op, OpBytes] : OpData) { 5172 // Eat up any bitcast 5173 if (Op->getOpcode() == ISD::BITCAST) 5174 *Op = Op->getOperand(0); 5175 5176 if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE && 5177 Op->getOperand(0).getValueType() == MVT::i32)) 5178 return SDValue(); 5179 5180 // If the truncate has multiple uses, this optimization can increase 5181 // register pressure 5182 if (!Op->hasOneUse()) 5183 return SDValue(); 5184 5185 *Op = Op->getOperand(0); 5186 5187 // Optionally, fold in a shift-right of the original operand and let permute 5188 // pick the two higher bytes of the original value directly. 5189 if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Op->getOperand(1))) { 5190 if (cast<ConstantSDNode>(Op->getOperand(1))->getZExtValue() == 16) { 5191 // Shift the PRMT byte selector to pick upper bytes from each respective 5192 // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76 5193 assert((*OpBytes == 0x10 || *OpBytes == 0x54) && 5194 "PRMT selector values out of range"); 5195 *OpBytes += 0x22; 5196 *Op = Op->getOperand(0); 5197 } 5198 } 5199 } 5200 5201 SDLoc DL(N); 5202 auto &DAG = DCI.DAG; 5203 5204 auto PRMT = DAG.getNode( 5205 NVPTXISD::PRMT, DL, MVT::v4i8, 5206 {Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32), 5207 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)}); 5208 return DAG.getNode(ISD::BITCAST, DL, VT, PRMT); 5209 } 5210 5211 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, 5212 DAGCombinerInfo &DCI) const { 5213 CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel(); 5214 switch (N->getOpcode()) { 5215 default: break; 5216 case ISD::ADD: 5217 return PerformADDCombine(N, DCI, OptLevel); 5218 case ISD::FADD: 5219 return PerformFADDCombine(N, DCI, OptLevel); 5220 case ISD::MUL: 5221 return PerformMULCombine(N, DCI, OptLevel); 5222 case ISD::SHL: 5223 return PerformSHLCombine(N, DCI, OptLevel); 5224 case ISD::AND: 5225 return PerformANDCombine(N, DCI); 5226 case ISD::UREM: 5227 case ISD::SREM: 5228 return PerformREMCombine(N, DCI, OptLevel); 5229 case ISD::SETCC: 5230 return PerformSETCCCombine(N, DCI, STI.getSmVersion()); 5231 case NVPTXISD::StoreRetval: 5232 case NVPTXISD::StoreRetvalV2: 5233 case NVPTXISD::StoreRetvalV4: 5234 return PerformStoreRetvalCombine(N); 5235 case NVPTXISD::StoreParam: 5236 case NVPTXISD::StoreParamV2: 5237 case NVPTXISD::StoreParamV4: 5238 return PerformStoreParamCombine(N); 5239 case ISD::EXTRACT_VECTOR_ELT: 5240 return PerformEXTRACTCombine(N, DCI); 5241 case ISD::VSELECT: 5242 return PerformVSELECTCombine(N, DCI); 5243 case ISD::BUILD_VECTOR: 5244 return PerformBUILD_VECTORCombine(N, DCI); 5245 } 5246 return SDValue(); 5247 } 5248 5249 static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG, 5250 SmallVectorImpl<SDValue> &Results) { 5251 // Handle bitcasting to v2i8 without hitting the default promotion 5252 // strategy which goes through stack memory. 
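// As an illustration (sketch only): (v2i8 (bitcast i16:x)) is rebuilt below as
//   (BUILD_VECTOR (trunc i8 x), (trunc i8 (srl x, 8)))
// so the two bytes are unpacked with register operations instead of a
// store/reload through local stack memory.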
5253 SDValue Op(Node, 0);
5254 EVT ToVT = Op->getValueType(0);
5255 if (ToVT != MVT::v2i8) {
5256 return;
5257 }
5258
5259 // Bitcast to i16 and unpack elements into a vector
5260 SDLoc DL(Node);
5261 SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
5262 SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
5263 SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
5264 SDValue Vec1 =
5265 DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
5266 DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
5267 Results.push_back(
5268 DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
5269 }
5270
5271 /// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
5272 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5273 SmallVectorImpl<SDValue> &Results) {
5274 EVT ResVT = N->getValueType(0);
5275 SDLoc DL(N);
5276
5277 assert(ResVT.isVector() && "Vector load must have vector type");
5278
5279 auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
5280 if (!NumEltsAndEltVT)
5281 return;
5282 auto [NumElts, EltVT] = NumEltsAndEltVT.value();
5283
5284 LoadSDNode *LD = cast<LoadSDNode>(N);
5285
5286 Align Alignment = LD->getAlign();
5287 auto &TD = DAG.getDataLayout();
5288 Align PrefAlign =
5289 TD.getPrefTypeAlign(LD->getMemoryVT().getTypeForEVT(*DAG.getContext()));
5290 if (Alignment < PrefAlign) {
5291 // This load is not sufficiently aligned, so bail out and let this vector
5292 // load be scalarized. Note that we may still be able to emit smaller
5293 // vector loads. For example, if we are loading a <4 x float> with an
5294 // alignment of 8, this check will fail but the legalizer will try again
5295 // with 2 x <2 x float>, which will succeed with an alignment of 8.
5296 return;
5297 }
5298
5299 // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
5300 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
5301 // loaded type to i16 and propagate the "real" type as the memory type.
5302 bool NeedTrunc = false;
5303 if (EltVT.getSizeInBits() < 16) {
5304 EltVT = MVT::i16;
5305 NeedTrunc = true;
5306 }
5307
5308 unsigned Opcode = 0;
5309 SDVTList LdResVTs;
5310
5311 switch (NumElts) {
5312 default:
5313 return;
5314 case 2:
5315 Opcode = NVPTXISD::LoadV2;
5316 LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
5317 break;
5318 case 4: {
5319 Opcode = NVPTXISD::LoadV4;
5320 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
5321 LdResVTs = DAG.getVTList(ListVTs);
5322 break;
5323 }
5324 }
5325
5326 // Copy regular operands
5327 SmallVector<SDValue, 8> OtherOps(N->ops());
5328
5329 // The select routine does not have access to the LoadSDNode instance, so
5330 // pass along the extension information
5331 OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
5332
5333 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
5334 LD->getMemoryVT(),
5335 LD->getMemOperand());
5336
5337 SmallVector<SDValue> ScalarRes;
5338 assert(NumElts <= ResVT.getVectorNumElements() &&
5339 "NumElts should not increase, only decrease or stay the same.");
5340 if (NumElts < ResVT.getVectorNumElements()) {
5341 // If the number of elements has decreased, getVectorLoweringShape has
5342 // upsized the element types
5343 assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 &&
5344 EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type.");
5345 // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
5346 // into individual elements.
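// (Worked example, for illustration: a load of v8f16 comes back from
// getVectorLoweringShape as 4 x v2f16, so the LoadV4 above produces four
// v2f16 subvectors; they are split here so the final BUILD_VECTOR still has
// the original eight f16 elements.)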
5347 for (unsigned i = 0; i < NumElts; ++i) { 5348 SDValue SubVector = NewLD.getValue(i); 5349 DAG.ExtractVectorElements(SubVector, ScalarRes); 5350 } 5351 } else { 5352 for (unsigned i = 0; i < NumElts; ++i) { 5353 SDValue Res = NewLD.getValue(i); 5354 if (NeedTrunc) 5355 Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res); 5356 ScalarRes.push_back(Res); 5357 } 5358 } 5359 5360 SDValue LoadChain = NewLD.getValue(NumElts); 5361 5362 SDValue BuildVec = DAG.getBuildVector(ResVT, DL, ScalarRes); 5363 5364 Results.push_back(BuildVec); 5365 Results.push_back(LoadChain); 5366 } 5367 5368 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG, 5369 SmallVectorImpl<SDValue> &Results) { 5370 SDValue Chain = N->getOperand(0); 5371 SDValue Intrin = N->getOperand(1); 5372 SDLoc DL(N); 5373 5374 // Get the intrinsic ID 5375 unsigned IntrinNo = Intrin.getNode()->getAsZExtVal(); 5376 switch (IntrinNo) { 5377 default: 5378 return; 5379 case Intrinsic::nvvm_ldu_global_i: 5380 case Intrinsic::nvvm_ldu_global_f: 5381 case Intrinsic::nvvm_ldu_global_p: { 5382 EVT ResVT = N->getValueType(0); 5383 5384 if (ResVT.isVector()) { 5385 // Vector LDG/LDU 5386 5387 unsigned NumElts = ResVT.getVectorNumElements(); 5388 EVT EltVT = ResVT.getVectorElementType(); 5389 5390 // Since LDU/LDG are target nodes, we cannot rely on DAG type 5391 // legalization. 5392 // Therefore, we must ensure the type is legal. For i1 and i8, we set the 5393 // loaded type to i16 and propagate the "real" type as the memory type. 5394 bool NeedTrunc = false; 5395 if (EltVT.getSizeInBits() < 16) { 5396 EltVT = MVT::i16; 5397 NeedTrunc = true; 5398 } 5399 5400 unsigned Opcode = 0; 5401 SDVTList LdResVTs; 5402 5403 switch (NumElts) { 5404 default: 5405 return; 5406 case 2: 5407 Opcode = NVPTXISD::LDUV2; 5408 LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other); 5409 break; 5410 case 4: { 5411 Opcode = NVPTXISD::LDUV4; 5412 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other }; 5413 LdResVTs = DAG.getVTList(ListVTs); 5414 break; 5415 } 5416 } 5417 5418 SmallVector<SDValue, 8> OtherOps; 5419 5420 // Copy regular operands 5421 5422 OtherOps.push_back(Chain); // Chain 5423 // Skip operand 1 (intrinsic ID) 5424 // Others 5425 OtherOps.append(N->op_begin() + 2, N->op_end()); 5426 5427 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N); 5428 5429 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, 5430 MemSD->getMemoryVT(), 5431 MemSD->getMemOperand()); 5432 5433 SmallVector<SDValue, 4> ScalarRes; 5434 5435 for (unsigned i = 0; i < NumElts; ++i) { 5436 SDValue Res = NewLD.getValue(i); 5437 if (NeedTrunc) 5438 Res = 5439 DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res); 5440 ScalarRes.push_back(Res); 5441 } 5442 5443 SDValue LoadChain = NewLD.getValue(NumElts); 5444 5445 SDValue BuildVec = 5446 DAG.getBuildVector(ResVT, DL, ScalarRes); 5447 5448 Results.push_back(BuildVec); 5449 Results.push_back(LoadChain); 5450 } else { 5451 // i8 LDG/LDU 5452 assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 && 5453 "Custom handling of non-i8 ldu/ldg?"); 5454 5455 // Just copy all operands as-is 5456 SmallVector<SDValue, 4> Ops(N->ops()); 5457 5458 // Force output to i16 5459 SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other); 5460 5461 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N); 5462 5463 // We make sure the memory type is i8, which will be used during isel 5464 // to select the proper instruction. 
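// (For illustration: the i8 form of llvm.nvvm.ldu.global.i is re-created
// below with an i16 result type and an i8 memory VT, and the loaded value
// is truncated back to i8 afterwards, since i8 is not a legal register type
// here.)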
5465 SDValue NewLD = 5466 DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, Ops, 5467 MVT::i8, MemSD->getMemOperand()); 5468 5469 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, 5470 NewLD.getValue(0))); 5471 Results.push_back(NewLD.getValue(1)); 5472 } 5473 } 5474 } 5475 } 5476 5477 static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG, 5478 SmallVectorImpl<SDValue> &Results) { 5479 // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit 5480 // result so that it can pass the legalization 5481 SDLoc DL(N); 5482 SDValue Chain = N->getOperand(0); 5483 SDValue Reg = N->getOperand(1); 5484 SDValue Glue = N->getOperand(2); 5485 5486 assert(Reg.getValueType() == MVT::i128 && 5487 "Custom lowering for CopyFromReg with 128-bit reg only"); 5488 SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(1), 5489 N->getValueType(2)}; 5490 SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue}; 5491 5492 SDValue NewValue = DAG.getNode(ISD::CopyFromReg, DL, ResultsType, NewOps); 5493 SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128, 5494 {NewValue.getValue(0), NewValue.getValue(1)}); 5495 5496 Results.push_back(Pair); 5497 Results.push_back(NewValue.getValue(2)); 5498 Results.push_back(NewValue.getValue(3)); 5499 } 5500 5501 void NVPTXTargetLowering::ReplaceNodeResults( 5502 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const { 5503 switch (N->getOpcode()) { 5504 default: 5505 report_fatal_error("Unhandled custom legalization"); 5506 case ISD::BITCAST: 5507 ReplaceBITCAST(N, DAG, Results); 5508 return; 5509 case ISD::LOAD: 5510 ReplaceLoadVector(N, DAG, Results); 5511 return; 5512 case ISD::INTRINSIC_W_CHAIN: 5513 ReplaceINTRINSIC_W_CHAIN(N, DAG, Results); 5514 return; 5515 case ISD::CopyFromReg: 5516 ReplaceCopyFromReg_128(N, DAG, Results); 5517 return; 5518 } 5519 } 5520 5521 NVPTXTargetLowering::AtomicExpansionKind 5522 NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { 5523 Type *Ty = AI->getValOperand()->getType(); 5524 5525 if (AI->isFloatingPointOperation()) { 5526 if (AI->getOperation() == AtomicRMWInst::BinOp::FAdd) { 5527 if (Ty->isHalfTy() && STI.getSmVersion() >= 70 && 5528 STI.getPTXVersion() >= 63) 5529 return AtomicExpansionKind::None; 5530 if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 && 5531 STI.getPTXVersion() >= 78) 5532 return AtomicExpansionKind::None; 5533 if (Ty->isFloatTy()) 5534 return AtomicExpansionKind::None; 5535 if (Ty->isDoubleTy() && STI.hasAtomAddF64()) 5536 return AtomicExpansionKind::None; 5537 } 5538 return AtomicExpansionKind::CmpXChg; 5539 } 5540 5541 assert(Ty->isIntegerTy() && "Ty should be integer at this point"); 5542 auto ITy = cast<llvm::IntegerType>(Ty); 5543 5544 switch (AI->getOperation()) { 5545 default: 5546 return AtomicExpansionKind::CmpXChg; 5547 case AtomicRMWInst::BinOp::And: 5548 case AtomicRMWInst::BinOp::Or: 5549 case AtomicRMWInst::BinOp::Xor: 5550 case AtomicRMWInst::BinOp::Xchg: 5551 switch (ITy->getBitWidth()) { 5552 case 8: 5553 case 16: 5554 return AtomicExpansionKind::CmpXChg; 5555 case 32: 5556 return AtomicExpansionKind::None; 5557 case 64: 5558 if (STI.hasAtomBitwise64()) 5559 return AtomicExpansionKind::None; 5560 return AtomicExpansionKind::CmpXChg; 5561 default: 5562 llvm_unreachable("unsupported width encountered"); 5563 } 5564 case AtomicRMWInst::BinOp::Add: 5565 case AtomicRMWInst::BinOp::Sub: 5566 case AtomicRMWInst::BinOp::Max: 5567 case AtomicRMWInst::BinOp::Min: 5568 case AtomicRMWInst::BinOp::UMax: 5569 case 
AtomicRMWInst::BinOp::UMin: 5570 switch (ITy->getBitWidth()) { 5571 case 8: 5572 case 16: 5573 return AtomicExpansionKind::CmpXChg; 5574 case 32: 5575 return AtomicExpansionKind::None; 5576 case 64: 5577 if (STI.hasAtomMinMax64()) 5578 return AtomicExpansionKind::None; 5579 return AtomicExpansionKind::CmpXChg; 5580 default: 5581 llvm_unreachable("unsupported width encountered"); 5582 } 5583 } 5584 5585 return AtomicExpansionKind::CmpXChg; 5586 } 5587 5588 // Pin NVPTXTargetObjectFile's vtables to this file. 5589 NVPTXTargetObjectFile::~NVPTXTargetObjectFile() = default; 5590 5591 MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal( 5592 const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const { 5593 return getDataSection(); 5594 } 5595
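// Note on shouldExpandAtomicRMWInIR above (an illustrative example, not an
// exhaustive description): on sm_70 with PTX 6.3 or newer,
//   atomicrmw fadd ptr %p, half %v seq_cst
// is left alone (AtomicExpansionKind::None) and selects to a native f16
// atomic add, while on older targets it is expanded by AtomicExpandPass
// into a cmpxchg loop.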