1 //===-- X86FixupVectorConstants.cpp - optimize constant generation -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file examines all full size vector constant pool loads and attempts to 10 // replace them with smaller constant pool entries, including: 11 // * Converting AVX512 memory-fold instructions to their broadcast-fold form 12 // * Broadcasting of full width loads. 13 // * TODO: Sign/Zero extension of full width loads. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "X86.h" 18 #include "X86InstrFoldTables.h" 19 #include "X86InstrInfo.h" 20 #include "X86Subtarget.h" 21 #include "llvm/ADT/Statistic.h" 22 #include "llvm/CodeGen/MachineConstantPool.h" 23 24 using namespace llvm; 25 26 #define DEBUG_TYPE "x86-fixup-vector-constants" 27 28 STATISTIC(NumInstChanges, "Number of instructions changes"); 29 30 namespace { 31 class X86FixupVectorConstantsPass : public MachineFunctionPass { 32 public: 33 static char ID; 34 35 X86FixupVectorConstantsPass() : MachineFunctionPass(ID) {} 36 37 StringRef getPassName() const override { 38 return "X86 Fixup Vector Constants"; 39 } 40 41 bool runOnMachineFunction(MachineFunction &MF) override; 42 bool processInstruction(MachineFunction &MF, MachineBasicBlock &MBB, 43 MachineInstr &MI); 44 45 // This pass runs after regalloc and doesn't support VReg operands. 46 MachineFunctionProperties getRequiredProperties() const override { 47 return MachineFunctionProperties().set( 48 MachineFunctionProperties::Property::NoVRegs); 49 } 50 51 private: 52 const X86InstrInfo *TII = nullptr; 53 const X86Subtarget *ST = nullptr; 54 const MCSchedModel *SM = nullptr; 55 }; 56 } // end anonymous namespace 57 58 char X86FixupVectorConstantsPass::ID = 0; 59 60 INITIALIZE_PASS(X86FixupVectorConstantsPass, DEBUG_TYPE, DEBUG_TYPE, false, false) 61 62 FunctionPass *llvm::createX86FixupVectorConstants() { 63 return new X86FixupVectorConstantsPass(); 64 } 65 66 static const Constant *getConstantFromPool(const MachineInstr &MI, 67 const MachineOperand &Op) { 68 if (!Op.isCPI() || Op.getOffset() != 0) 69 return nullptr; 70 71 ArrayRef<MachineConstantPoolEntry> Constants = 72 MI.getParent()->getParent()->getConstantPool()->getConstants(); 73 const MachineConstantPoolEntry &ConstantEntry = Constants[Op.getIndex()]; 74 75 // Bail if this is a machine constant pool entry, we won't be able to dig out 76 // anything useful. 77 if (ConstantEntry.isMachineConstantPoolEntry()) 78 return nullptr; 79 80 return ConstantEntry.Val.ConstVal; 81 } 82 83 // Attempt to extract the full width of bits data from the constant. 84 static std::optional<APInt> extractConstantBits(const Constant *C) { 85 unsigned NumBits = C->getType()->getPrimitiveSizeInBits(); 86 87 if (auto *CInt = dyn_cast<ConstantInt>(C)) 88 return CInt->getValue(); 89 90 if (auto *CFP = dyn_cast<ConstantFP>(C)) 91 return CFP->getValue().bitcastToAPInt(); 92 93 if (auto *CV = dyn_cast<ConstantVector>(C)) { 94 if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) { 95 if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) { 96 assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat"); 97 return APInt::getSplat(NumBits, *Bits); 98 } 99 } 100 } 101 102 if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) { 103 bool IsInteger = CDS->getElementType()->isIntegerTy(); 104 bool IsFloat = CDS->getElementType()->isHalfTy() || 105 CDS->getElementType()->isBFloatTy() || 106 CDS->getElementType()->isFloatTy() || 107 CDS->getElementType()->isDoubleTy(); 108 if (IsInteger || IsFloat) { 109 APInt Bits = APInt::getZero(NumBits); 110 unsigned EltBits = CDS->getElementType()->getPrimitiveSizeInBits(); 111 for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) { 112 if (IsInteger) 113 Bits.insertBits(CDS->getElementAsAPInt(I), I * EltBits); 114 else 115 Bits.insertBits(CDS->getElementAsAPFloat(I).bitcastToAPInt(), 116 I * EltBits); 117 } 118 return Bits; 119 } 120 } 121 122 return std::nullopt; 123 } 124 125 // Attempt to compute the splat width of bits data by normalizing the splat to 126 // remove undefs. 127 static std::optional<APInt> getSplatableConstant(const Constant *C, 128 unsigned SplatBitWidth) { 129 const Type *Ty = C->getType(); 130 assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 && 131 "Illegal splat width"); 132 133 if (std::optional<APInt> Bits = extractConstantBits(C)) 134 if (Bits->isSplat(SplatBitWidth)) 135 return Bits->trunc(SplatBitWidth); 136 137 // Detect general splats with undefs. 138 // TODO: Do we need to handle NumEltsBits > SplatBitWidth splitting? 139 if (auto *CV = dyn_cast<ConstantVector>(C)) { 140 unsigned NumOps = CV->getNumOperands(); 141 unsigned NumEltsBits = Ty->getScalarSizeInBits(); 142 unsigned NumScaleOps = SplatBitWidth / NumEltsBits; 143 if ((SplatBitWidth % NumEltsBits) == 0) { 144 // Collect the elements and ensure that within the repeated splat sequence 145 // they either match or are undef. 146 SmallVector<Constant *, 16> Sequence(NumScaleOps, nullptr); 147 for (unsigned Idx = 0; Idx != NumOps; ++Idx) { 148 if (Constant *Elt = CV->getAggregateElement(Idx)) { 149 if (isa<UndefValue>(Elt)) 150 continue; 151 unsigned SplatIdx = Idx % NumScaleOps; 152 if (!Sequence[SplatIdx] || Sequence[SplatIdx] == Elt) { 153 Sequence[SplatIdx] = Elt; 154 continue; 155 } 156 } 157 return std::nullopt; 158 } 159 // Extract the constant bits forming the splat and insert into the bits 160 // data, leave undef as zero. 161 APInt SplatBits = APInt::getZero(SplatBitWidth); 162 for (unsigned I = 0; I != NumScaleOps; ++I) { 163 if (!Sequence[I]) 164 continue; 165 if (std::optional<APInt> Bits = extractConstantBits(Sequence[I])) { 166 SplatBits.insertBits(*Bits, I * Bits->getBitWidth()); 167 continue; 168 } 169 return std::nullopt; 170 } 171 return SplatBits; 172 } 173 } 174 175 return std::nullopt; 176 } 177 178 // Attempt to rebuild a normalized splat vector constant of the requested splat 179 // width, built up of potentially smaller scalar values. 180 // NOTE: We don't always bother converting to scalars if the vector length is 1. 181 static Constant *rebuildSplatableConstant(const Constant *C, 182 unsigned SplatBitWidth) { 183 std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth); 184 if (!Splat) 185 return nullptr; 186 187 // Determine scalar size to use for the constant splat vector, clamping as we 188 // might have found a splat smaller than the original constant data. 189 const Type *OriginalType = C->getType(); 190 Type *SclTy = OriginalType->getScalarType(); 191 unsigned NumSclBits = SclTy->getPrimitiveSizeInBits(); 192 NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth); 193 LLVMContext &Ctx = OriginalType->getContext(); 194 195 if (NumSclBits == 8) { 196 SmallVector<uint8_t> RawBits; 197 for (unsigned I = 0; I != SplatBitWidth; I += 8) 198 RawBits.push_back(Splat->extractBits(8, I).getZExtValue()); 199 return ConstantDataVector::get(Ctx, RawBits); 200 } 201 202 if (NumSclBits == 16) { 203 SmallVector<uint16_t> RawBits; 204 for (unsigned I = 0; I != SplatBitWidth; I += 16) 205 RawBits.push_back(Splat->extractBits(16, I).getZExtValue()); 206 if (SclTy->is16bitFPTy()) 207 return ConstantDataVector::getFP(SclTy, RawBits); 208 return ConstantDataVector::get(Ctx, RawBits); 209 } 210 211 if (NumSclBits == 32) { 212 SmallVector<uint32_t> RawBits; 213 for (unsigned I = 0; I != SplatBitWidth; I += 32) 214 RawBits.push_back(Splat->extractBits(32, I).getZExtValue()); 215 if (SclTy->isFloatTy()) 216 return ConstantDataVector::getFP(SclTy, RawBits); 217 return ConstantDataVector::get(Ctx, RawBits); 218 } 219 220 // Fallback to i64 / double. 221 SmallVector<uint64_t> RawBits; 222 for (unsigned I = 0; I != SplatBitWidth; I += 64) 223 RawBits.push_back(Splat->extractBits(64, I).getZExtValue()); 224 if (SclTy->isDoubleTy()) 225 return ConstantDataVector::getFP(SclTy, RawBits); 226 return ConstantDataVector::get(Ctx, RawBits); 227 } 228 229 bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF, 230 MachineBasicBlock &MBB, 231 MachineInstr &MI) { 232 unsigned Opc = MI.getOpcode(); 233 MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool(); 234 bool HasAVX2 = ST->hasAVX2(); 235 bool HasDQI = ST->hasDQI(); 236 bool HasBWI = ST->hasBWI(); 237 bool HasVLX = ST->hasVLX(); 238 239 auto ConvertToBroadcast = [&](unsigned OpBcst256, unsigned OpBcst128, 240 unsigned OpBcst64, unsigned OpBcst32, 241 unsigned OpBcst16, unsigned OpBcst8, 242 unsigned OperandNo) { 243 assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) && 244 "Unexpected number of operands!"); 245 246 MachineOperand &CstOp = MI.getOperand(OperandNo + X86::AddrDisp); 247 if (auto *C = getConstantFromPool(MI, CstOp)) { 248 // Attempt to detect a suitable splat from increasing splat widths. 249 std::pair<unsigned, unsigned> Broadcasts[] = { 250 {8, OpBcst8}, {16, OpBcst16}, {32, OpBcst32}, 251 {64, OpBcst64}, {128, OpBcst128}, {256, OpBcst256}, 252 }; 253 for (auto [BitWidth, OpBcst] : Broadcasts) { 254 if (OpBcst) { 255 // Construct a suitable splat constant and adjust the MI to 256 // use the new constant pool entry. 257 if (Constant *NewCst = rebuildSplatableConstant(C, BitWidth)) { 258 unsigned NewCPI = 259 CP->getConstantPoolIndex(NewCst, Align(BitWidth / 8)); 260 MI.setDesc(TII->get(OpBcst)); 261 CstOp.setIndex(NewCPI); 262 return true; 263 } 264 } 265 } 266 } 267 return false; 268 }; 269 270 // Attempt to convert full width vector loads into broadcast loads. 271 switch (Opc) { 272 /* FP Loads */ 273 case X86::MOVAPDrm: 274 case X86::MOVAPSrm: 275 case X86::MOVUPDrm: 276 case X86::MOVUPSrm: 277 // TODO: SSE3 MOVDDUP Handling 278 return false; 279 case X86::VMOVAPDrm: 280 case X86::VMOVAPSrm: 281 case X86::VMOVUPDrm: 282 case X86::VMOVUPSrm: 283 return ConvertToBroadcast(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0, 284 1); 285 case X86::VMOVAPDYrm: 286 case X86::VMOVAPSYrm: 287 case X86::VMOVUPDYrm: 288 case X86::VMOVUPSYrm: 289 return ConvertToBroadcast(0, X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm, 290 X86::VBROADCASTSSYrm, 0, 0, 1); 291 case X86::VMOVAPDZ128rm: 292 case X86::VMOVAPSZ128rm: 293 case X86::VMOVUPDZ128rm: 294 case X86::VMOVUPSZ128rm: 295 return ConvertToBroadcast(0, 0, X86::VMOVDDUPZ128rm, 296 X86::VBROADCASTSSZ128rm, 0, 0, 1); 297 case X86::VMOVAPDZ256rm: 298 case X86::VMOVAPSZ256rm: 299 case X86::VMOVUPDZ256rm: 300 case X86::VMOVUPSZ256rm: 301 return ConvertToBroadcast(0, X86::VBROADCASTF32X4Z256rm, 302 X86::VBROADCASTSDZ256rm, X86::VBROADCASTSSZ256rm, 303 0, 0, 1); 304 case X86::VMOVAPDZrm: 305 case X86::VMOVAPSZrm: 306 case X86::VMOVUPDZrm: 307 case X86::VMOVUPSZrm: 308 return ConvertToBroadcast(X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm, 309 X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0, 0, 310 1); 311 /* Integer Loads */ 312 case X86::VMOVDQArm: 313 case X86::VMOVDQUrm: 314 return ConvertToBroadcast( 315 0, 0, HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 316 HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 317 HasAVX2 ? X86::VPBROADCASTWrm : 0, HasAVX2 ? X86::VPBROADCASTBrm : 0, 318 1); 319 case X86::VMOVDQAYrm: 320 case X86::VMOVDQUYrm: 321 return ConvertToBroadcast( 322 0, HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 323 HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 324 HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 325 HasAVX2 ? X86::VPBROADCASTWYrm : 0, HasAVX2 ? X86::VPBROADCASTBYrm : 0, 326 1); 327 case X86::VMOVDQA32Z128rm: 328 case X86::VMOVDQA64Z128rm: 329 case X86::VMOVDQU32Z128rm: 330 case X86::VMOVDQU64Z128rm: 331 return ConvertToBroadcast(0, 0, X86::VPBROADCASTQZ128rm, 332 X86::VPBROADCASTDZ128rm, 333 HasBWI ? X86::VPBROADCASTWZ128rm : 0, 334 HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1); 335 case X86::VMOVDQA32Z256rm: 336 case X86::VMOVDQA64Z256rm: 337 case X86::VMOVDQU32Z256rm: 338 case X86::VMOVDQU64Z256rm: 339 return ConvertToBroadcast(0, X86::VBROADCASTI32X4Z256rm, 340 X86::VPBROADCASTQZ256rm, X86::VPBROADCASTDZ256rm, 341 HasBWI ? X86::VPBROADCASTWZ256rm : 0, 342 HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1); 343 case X86::VMOVDQA32Zrm: 344 case X86::VMOVDQA64Zrm: 345 case X86::VMOVDQU32Zrm: 346 case X86::VMOVDQU64Zrm: 347 return ConvertToBroadcast(X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm, 348 X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm, 349 HasBWI ? X86::VPBROADCASTWZrm : 0, 350 HasBWI ? X86::VPBROADCASTBZrm : 0, 1); 351 } 352 353 auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) { 354 unsigned OpBcst32 = 0, OpBcst64 = 0; 355 unsigned OpNoBcst32 = 0, OpNoBcst64 = 0; 356 if (OpSrc32) { 357 if (const X86FoldTableEntry *Mem2Bcst = 358 llvm::lookupBroadcastFoldTable(OpSrc32, 32)) { 359 OpBcst32 = Mem2Bcst->DstOp; 360 OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK; 361 } 362 } 363 if (OpSrc64) { 364 if (const X86FoldTableEntry *Mem2Bcst = 365 llvm::lookupBroadcastFoldTable(OpSrc64, 64)) { 366 OpBcst64 = Mem2Bcst->DstOp; 367 OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK; 368 } 369 } 370 assert(((OpBcst32 == 0) || (OpBcst64 == 0) || (OpNoBcst32 == OpNoBcst64)) && 371 "OperandNo mismatch"); 372 373 if (OpBcst32 || OpBcst64) { 374 unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32; 375 return ConvertToBroadcast(0, 0, OpBcst64, OpBcst32, 0, 0, OpNo); 376 } 377 return false; 378 }; 379 380 // Attempt to find a AVX512 mapping from a full width memory-fold instruction 381 // to a broadcast-fold instruction variant. 382 if ((MI.getDesc().TSFlags & X86II::EncodingMask) == X86II::EVEX) 383 return ConvertToBroadcastAVX512(Opc, Opc); 384 385 // Reverse the X86InstrInfo::setExecutionDomainCustom EVEX->VEX logic 386 // conversion to see if we can convert to a broadcasted (integer) logic op. 387 if (HasVLX && !HasDQI) { 388 unsigned OpSrc32 = 0, OpSrc64 = 0; 389 switch (Opc) { 390 case X86::VANDPDrm: 391 case X86::VANDPSrm: 392 case X86::VPANDrm: 393 OpSrc32 = X86 ::VPANDDZ128rm; 394 OpSrc64 = X86 ::VPANDQZ128rm; 395 break; 396 case X86::VANDPDYrm: 397 case X86::VANDPSYrm: 398 case X86::VPANDYrm: 399 OpSrc32 = X86 ::VPANDDZ256rm; 400 OpSrc64 = X86 ::VPANDQZ256rm; 401 break; 402 case X86::VANDNPDrm: 403 case X86::VANDNPSrm: 404 case X86::VPANDNrm: 405 OpSrc32 = X86 ::VPANDNDZ128rm; 406 OpSrc64 = X86 ::VPANDNQZ128rm; 407 break; 408 case X86::VANDNPDYrm: 409 case X86::VANDNPSYrm: 410 case X86::VPANDNYrm: 411 OpSrc32 = X86 ::VPANDNDZ256rm; 412 OpSrc64 = X86 ::VPANDNQZ256rm; 413 break; 414 case X86::VORPDrm: 415 case X86::VORPSrm: 416 case X86::VPORrm: 417 OpSrc32 = X86 ::VPORDZ128rm; 418 OpSrc64 = X86 ::VPORQZ128rm; 419 break; 420 case X86::VORPDYrm: 421 case X86::VORPSYrm: 422 case X86::VPORYrm: 423 OpSrc32 = X86 ::VPORDZ256rm; 424 OpSrc64 = X86 ::VPORQZ256rm; 425 break; 426 case X86::VXORPDrm: 427 case X86::VXORPSrm: 428 case X86::VPXORrm: 429 OpSrc32 = X86 ::VPXORDZ128rm; 430 OpSrc64 = X86 ::VPXORQZ128rm; 431 break; 432 case X86::VXORPDYrm: 433 case X86::VXORPSYrm: 434 case X86::VPXORYrm: 435 OpSrc32 = X86 ::VPXORDZ256rm; 436 OpSrc64 = X86 ::VPXORQZ256rm; 437 break; 438 } 439 if (OpSrc32 || OpSrc64) 440 return ConvertToBroadcastAVX512(OpSrc32, OpSrc64); 441 } 442 443 return false; 444 } 445 446 bool X86FixupVectorConstantsPass::runOnMachineFunction(MachineFunction &MF) { 447 LLVM_DEBUG(dbgs() << "Start X86FixupVectorConstants\n";); 448 bool Changed = false; 449 ST = &MF.getSubtarget<X86Subtarget>(); 450 TII = ST->getInstrInfo(); 451 SM = &ST->getSchedModel(); 452 453 for (MachineBasicBlock &MBB : MF) { 454 for (MachineInstr &MI : MBB) { 455 if (processInstruction(MF, MBB, MI)) { 456 ++NumInstChanges; 457 Changed = true; 458 } 459 } 460 } 461 LLVM_DEBUG(dbgs() << "End X86FixupVectorConstants\n";); 462 return Changed; 463 } 464