//===-- X86FixupVectorConstants.cpp - optimize constant generation  -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file examines all full size vector constant pool loads and attempts to
// replace them with smaller constant pool entries, including:
// * Converting AVX512 memory-fold instructions to their broadcast-fold form
// * Broadcasting of full width loads.
// * TODO: Sign/Zero extension of full width loads.
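//
// For example, on AVX targets a full width splat constant load such as:
//   vmovaps .LCPI0_0(%rip), %xmm0      # 16-byte constant {1.0, 1.0, 1.0, 1.0}
// can instead be emitted as a broadcast from a 4-byte constant pool entry:
//   vbroadcastss .LCPI0_0(%rip), %xmm0 # 4-byte constant {1.0}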
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrFoldTables.h"
#include "X86InstrInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineConstantPool.h"

using namespace llvm;

#define DEBUG_TYPE "x86-fixup-vector-constants"

STATISTIC(NumInstChanges, "Number of instruction changes");

namespace {
class X86FixupVectorConstantsPass : public MachineFunctionPass {
public:
  static char ID;

  X86FixupVectorConstantsPass() : MachineFunctionPass(ID) {}

  StringRef getPassName() const override {
    return "X86 Fixup Vector Constants";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  bool processInstruction(MachineFunction &MF, MachineBasicBlock &MBB,
                          MachineInstr &MI);

  // This pass runs after regalloc and doesn't support VReg operands.
  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().set(
        MachineFunctionProperties::Property::NoVRegs);
  }

private:
  const X86InstrInfo *TII = nullptr;
  const X86Subtarget *ST = nullptr;
  const MCSchedModel *SM = nullptr;
};
} // end anonymous namespace

char X86FixupVectorConstantsPass::ID = 0;

INITIALIZE_PASS(X86FixupVectorConstantsPass, DEBUG_TYPE, DEBUG_TYPE, false, false)

FunctionPass *llvm::createX86FixupVectorConstants() {
  return new X86FixupVectorConstantsPass();
}

// Attempt to extract the full width of bits data from the constant.
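// e.g. <2 x i32> <i32 1, i32 2> yields a 64-bit APInt with element 0 in the
// low bits: 0x0000000200000001. Undef constants are treated as zero.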
static std::optional<APInt> extractConstantBits(const Constant *C) {
  unsigned NumBits = C->getType()->getPrimitiveSizeInBits();

  if (isa<UndefValue>(C))
    return APInt::getZero(NumBits);

  if (auto *CInt = dyn_cast<ConstantInt>(C))
    return CInt->getValue();

  if (auto *CFP = dyn_cast<ConstantFP>(C))
    return CFP->getValue().bitcastToAPInt();

  if (auto *CV = dyn_cast<ConstantVector>(C)) {
    if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) {
      if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
        assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
        return APInt::getSplat(NumBits, *Bits);
      }
    }

    APInt Bits = APInt::getZero(NumBits);
    for (unsigned I = 0, E = CV->getNumOperands(); I != E; ++I) {
      Constant *Elt = CV->getOperand(I);
      std::optional<APInt> SubBits = extractConstantBits(Elt);
      if (!SubBits)
        return std::nullopt;
      assert(NumBits == (E * SubBits->getBitWidth()) &&
             "Illegal vector element size");
      Bits.insertBits(*SubBits, I * SubBits->getBitWidth());
    }
    return Bits;
  }

  if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
    bool IsInteger = CDS->getElementType()->isIntegerTy();
    bool IsFloat = CDS->getElementType()->isHalfTy() ||
                   CDS->getElementType()->isBFloatTy() ||
                   CDS->getElementType()->isFloatTy() ||
                   CDS->getElementType()->isDoubleTy();
    if (IsInteger || IsFloat) {
      APInt Bits = APInt::getZero(NumBits);
      unsigned EltBits = CDS->getElementType()->getPrimitiveSizeInBits();
      for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
        if (IsInteger)
          Bits.insertBits(CDS->getElementAsAPInt(I), I * EltBits);
        else
          Bits.insertBits(CDS->getElementAsAPFloat(I).bitcastToAPInt(),
                          I * EltBits);
      }
      return Bits;
    }
  }

  return std::nullopt;
}

// Attempt to compute the splat width of bits data by normalizing the splat to
// remove undefs.
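// e.g. with SplatBitWidth == 64, <4 x i32> <i32 1, i32 undef, i32 1, i32 2> is
// treated as a splat of the 64-bit pattern formed by <i32 1, i32 2>; the undef
// lane is filled from the matching lane of the repeated sequence.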
static std::optional<APInt> getSplatableConstant(const Constant *C,
                                                 unsigned SplatBitWidth) {
  const Type *Ty = C->getType();
  assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
         "Illegal splat width");

  if (std::optional<APInt> Bits = extractConstantBits(C))
    if (Bits->isSplat(SplatBitWidth))
      return Bits->trunc(SplatBitWidth);

  // Detect general splats with undefs.
  // TODO: Do we need to handle NumEltsBits > SplatBitWidth splitting?
  if (auto *CV = dyn_cast<ConstantVector>(C)) {
    unsigned NumOps = CV->getNumOperands();
    unsigned NumEltsBits = Ty->getScalarSizeInBits();
    unsigned NumScaleOps = SplatBitWidth / NumEltsBits;
    if ((SplatBitWidth % NumEltsBits) == 0) {
      // Collect the elements and ensure that within the repeated splat sequence
      // they either match or are undef.
      SmallVector<Constant *, 16> Sequence(NumScaleOps, nullptr);
      for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
        if (Constant *Elt = CV->getAggregateElement(Idx)) {
          if (isa<UndefValue>(Elt))
            continue;
          unsigned SplatIdx = Idx % NumScaleOps;
          if (!Sequence[SplatIdx] || Sequence[SplatIdx] == Elt) {
            Sequence[SplatIdx] = Elt;
            continue;
          }
        }
        return std::nullopt;
      }
      // Extract the constant bits forming the splat and insert into the bits
      // data, leave undef as zero.
      APInt SplatBits = APInt::getZero(SplatBitWidth);
      for (unsigned I = 0; I != NumScaleOps; ++I) {
        if (!Sequence[I])
          continue;
        if (std::optional<APInt> Bits = extractConstantBits(Sequence[I])) {
          SplatBits.insertBits(*Bits, I * Bits->getBitWidth());
          continue;
        }
        return std::nullopt;
      }
      return SplatBits;
    }
  }

  return std::nullopt;
}

// Split raw bits into a constant vector of elements of a specific bit width.
// NOTE: We don't always bother converting to scalars if the vector length is 1.
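// e.g. a 64-bit APInt 0x0000002A0000002A rebuilt with NumSclBits == 32 and an
// integer scalar type becomes the constant vector <2 x i32> <i32 42, i32 42>.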
static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
                                 const APInt &Bits, unsigned NumSclBits) {
  unsigned BitWidth = Bits.getBitWidth();

  if (NumSclBits == 8) {
    SmallVector<uint8_t> RawBits;
    for (unsigned I = 0; I != BitWidth; I += 8)
      RawBits.push_back(Bits.extractBits(8, I).getZExtValue());
    return ConstantDataVector::get(Ctx, RawBits);
  }

  if (NumSclBits == 16) {
    SmallVector<uint16_t> RawBits;
    for (unsigned I = 0; I != BitWidth; I += 16)
      RawBits.push_back(Bits.extractBits(16, I).getZExtValue());
    if (SclTy->is16bitFPTy())
      return ConstantDataVector::getFP(SclTy, RawBits);
    return ConstantDataVector::get(Ctx, RawBits);
  }

  if (NumSclBits == 32) {
    SmallVector<uint32_t> RawBits;
    for (unsigned I = 0; I != BitWidth; I += 32)
      RawBits.push_back(Bits.extractBits(32, I).getZExtValue());
    if (SclTy->isFloatTy())
      return ConstantDataVector::getFP(SclTy, RawBits);
    return ConstantDataVector::get(Ctx, RawBits);
  }

  assert(NumSclBits == 64 && "Unhandled vector element width");

  SmallVector<uint64_t> RawBits;
  for (unsigned I = 0; I != BitWidth; I += 64)
    RawBits.push_back(Bits.extractBits(64, I).getZExtValue());
  if (SclTy->isDoubleTy())
    return ConstantDataVector::getFP(SclTy, RawBits);
  return ConstantDataVector::get(Ctx, RawBits);
}

// Attempt to rebuild a normalized splat vector constant of the requested splat
// width, built up of potentially smaller scalar values.
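// e.g. a <8 x i16> constant that repeats every 32 bits is rebuilt as a
// <2 x i16> constant covering a single 32-bit splat element.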
static Constant *rebuildSplatableConstant(const Constant *C,
                                          unsigned SplatBitWidth) {
  std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
  if (!Splat)
    return nullptr;

  // Determine scalar size to use for the constant splat vector, clamping as we
  // might have found a splat smaller than the original constant data.
  const Type *OriginalType = C->getType();
  Type *SclTy = OriginalType->getScalarType();
  unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
  NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);

  // Fallback to i64 / double.
  NumSclBits = (NumSclBits == 8 || NumSclBits == 16 || NumSclBits == 32)
                   ? NumSclBits
                   : 64;

  // Extract per-element bits.
  return rebuildConstant(OriginalType->getContext(), SclTy, *Splat, NumSclBits);
}

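// Attempt to rebuild a constant that keeps only the lowest ScalarBitWidth bits
// of the original data, which is legal when all of the upper bits are known to
// be zero (allowing the full width load to become a narrower "vzload").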
static Constant *rebuildZeroUpperConstant(const Constant *C,
                                          unsigned ScalarBitWidth) {
  Type *Ty = C->getType();
  Type *SclTy = Ty->getScalarType();
  unsigned NumBits = Ty->getPrimitiveSizeInBits();
  unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
  LLVMContext &Ctx = C->getContext();

  if (NumBits > ScalarBitWidth) {
    // Determine if the upper bits are all zero.
    if (std::optional<APInt> Bits = extractConstantBits(C)) {
      if (Bits->countLeadingZeros() >= (NumBits - ScalarBitWidth)) {
        // If the original constant was made of smaller elements, try to retain
        // those types.
        if (ScalarBitWidth > NumSclBits && (ScalarBitWidth % NumSclBits) == 0)
          return rebuildConstant(Ctx, SclTy, *Bits, NumSclBits);

        // Fallback to raw integer bits.
        APInt RawBits = Bits->zextOrTrunc(ScalarBitWidth);
        return ConstantInt::get(Ctx, RawBits);
      }
    }
  }

  return nullptr;
}

typedef std::function<Constant *(const Constant *, unsigned)> RebuildFn;

bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
                                                     MachineBasicBlock &MBB,
                                                     MachineInstr &MI) {
  unsigned Opc = MI.getOpcode();
  MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
  bool HasAVX2 = ST->hasAVX2();
  bool HasDQI = ST->hasDQI();
  bool HasBWI = ST->hasBWI();
  bool HasVLX = ST->hasVLX();

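  // FixupConstant attempts to replace the constant pool load at OperandNo with
  // a smaller entry. Each OpBcst* opcode (if non-zero) broadcasts a splat of
  // that bit width, and each OpUpper* opcode loads only the low 64/32 bits and
  // zeroes the remainder.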
  auto FixupConstant =
      [&](unsigned OpBcst256, unsigned OpBcst128, unsigned OpBcst64,
          unsigned OpBcst32, unsigned OpBcst16, unsigned OpBcst8,
          unsigned OpUpper64, unsigned OpUpper32, unsigned OperandNo) {
        assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
               "Unexpected number of operands!");

        if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
          // Attempt to detect a suitable splat/vzload from increasing constant
          // bitwidths.
          // Prefer vzload vs broadcast for same bitwidth to avoid domain flips.
          std::tuple<unsigned, unsigned, RebuildFn> FixupLoad[] = {
              {8, OpBcst8, rebuildSplatableConstant},
              {16, OpBcst16, rebuildSplatableConstant},
              {32, OpUpper32, rebuildZeroUpperConstant},
              {32, OpBcst32, rebuildSplatableConstant},
              {64, OpUpper64, rebuildZeroUpperConstant},
              {64, OpBcst64, rebuildSplatableConstant},
              {128, OpBcst128, rebuildSplatableConstant},
              {256, OpBcst256, rebuildSplatableConstant},
          };
          for (auto [BitWidth, Op, RebuildConstant] : FixupLoad) {
            if (Op) {
              // Construct a suitable constant and adjust the MI to use the new
              // constant pool entry.
              if (Constant *NewCst = RebuildConstant(C, BitWidth)) {
                unsigned NewCPI =
                    CP->getConstantPoolIndex(NewCst, Align(BitWidth / 8));
                MI.setDesc(TII->get(Op));
                MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
                return true;
              }
            }
          }
        }
        return false;
      };

  // Attempt to convert full width vector loads into broadcast/vzload loads.
  switch (Opc) {
  /* FP Loads */
  case X86::MOVAPDrm:
  case X86::MOVAPSrm:
  case X86::MOVUPDrm:
  case X86::MOVUPSrm:
    // TODO: SSE3 MOVDDUP Handling
    return FixupConstant(0, 0, 0, 0, 0, 0, X86::MOVSDrm, X86::MOVSSrm, 1);
  case X86::VMOVAPDrm:
  case X86::VMOVAPSrm:
  case X86::VMOVUPDrm:
  case X86::VMOVUPSrm:
    return FixupConstant(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0,
                         X86::VMOVSDrm, X86::VMOVSSrm, 1);
  case X86::VMOVAPDYrm:
  case X86::VMOVAPSYrm:
  case X86::VMOVUPDYrm:
  case X86::VMOVUPSYrm:
    return FixupConstant(0, X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
                         X86::VBROADCASTSSYrm, 0, 0, 0, 0, 1);
  case X86::VMOVAPDZ128rm:
  case X86::VMOVAPSZ128rm:
  case X86::VMOVUPDZ128rm:
  case X86::VMOVUPSZ128rm:
    return FixupConstant(0, 0, X86::VMOVDDUPZ128rm, X86::VBROADCASTSSZ128rm, 0,
                         0, X86::VMOVSDZrm, X86::VMOVSSZrm, 1);
  case X86::VMOVAPDZ256rm:
  case X86::VMOVAPSZ256rm:
  case X86::VMOVUPDZ256rm:
  case X86::VMOVUPSZ256rm:
    return FixupConstant(0, X86::VBROADCASTF32X4Z256rm, X86::VBROADCASTSDZ256rm,
                         X86::VBROADCASTSSZ256rm, 0, 0, 0, 0, 1);
  case X86::VMOVAPDZrm:
  case X86::VMOVAPSZrm:
  case X86::VMOVUPDZrm:
  case X86::VMOVUPSZrm:
    return FixupConstant(X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
                         X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0, 0, 0, 0,
                         1);
  /* Integer Loads */
  case X86::MOVDQArm:
  case X86::MOVDQUrm:
    return FixupConstant(0, 0, 0, 0, 0, 0, X86::MOVQI2PQIrm, X86::MOVDI2PDIrm,
                         1);
  case X86::VMOVDQArm:
  case X86::VMOVDQUrm:
    return FixupConstant(0, 0, HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
                         HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm,
                         HasAVX2 ? X86::VPBROADCASTWrm : 0,
                         HasAVX2 ? X86::VPBROADCASTBrm : 0, X86::VMOVQI2PQIrm,
                         X86::VMOVDI2PDIrm, 1);
  case X86::VMOVDQAYrm:
  case X86::VMOVDQUYrm:
    return FixupConstant(
        0, HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm,
        HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm,
        HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm,
        HasAVX2 ? X86::VPBROADCASTWYrm : 0, HasAVX2 ? X86::VPBROADCASTBYrm : 0,
        0, 0, 1);
  case X86::VMOVDQA32Z128rm:
  case X86::VMOVDQA64Z128rm:
  case X86::VMOVDQU32Z128rm:
  case X86::VMOVDQU64Z128rm:
    return FixupConstant(0, 0, X86::VPBROADCASTQZ128rm, X86::VPBROADCASTDZ128rm,
                         HasBWI ? X86::VPBROADCASTWZ128rm : 0,
                         HasBWI ? X86::VPBROADCASTBZ128rm : 0,
                         X86::VMOVQI2PQIZrm, X86::VMOVDI2PDIZrm, 1);
  case X86::VMOVDQA32Z256rm:
  case X86::VMOVDQA64Z256rm:
  case X86::VMOVDQU32Z256rm:
  case X86::VMOVDQU64Z256rm:
    return FixupConstant(0, X86::VBROADCASTI32X4Z256rm, X86::VPBROADCASTQZ256rm,
                         X86::VPBROADCASTDZ256rm,
                         HasBWI ? X86::VPBROADCASTWZ256rm : 0,
                         HasBWI ? X86::VPBROADCASTBZ256rm : 0, 0, 0, 1);
  case X86::VMOVDQA32Zrm:
  case X86::VMOVDQA64Zrm:
  case X86::VMOVDQU32Zrm:
  case X86::VMOVDQU64Zrm:
    return FixupConstant(X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
                         X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
                         HasBWI ? X86::VPBROADCASTWZrm : 0,
                         HasBWI ? X86::VPBROADCASTBZrm : 0, 0, 0, 1);
  }

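  // ConvertToBroadcastAVX512 queries the broadcast fold tables for the 32-bit
  // and 64-bit element broadcast-fold variants of the given opcodes and, if
  // either exists, retries FixupConstant with those opcodes and the memory
  // operand index recorded in the fold table entry.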
  auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
    unsigned OpBcst32 = 0, OpBcst64 = 0;
    unsigned OpNoBcst32 = 0, OpNoBcst64 = 0;
    if (OpSrc32) {
      if (const X86FoldTableEntry *Mem2Bcst =
              llvm::lookupBroadcastFoldTable(OpSrc32, 32)) {
        OpBcst32 = Mem2Bcst->DstOp;
        OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK;
      }
    }
    if (OpSrc64) {
      if (const X86FoldTableEntry *Mem2Bcst =
              llvm::lookupBroadcastFoldTable(OpSrc64, 64)) {
        OpBcst64 = Mem2Bcst->DstOp;
        OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK;
      }
    }
    assert(((OpBcst32 == 0) || (OpBcst64 == 0) || (OpNoBcst32 == OpNoBcst64)) &&
           "OperandNo mismatch");

    if (OpBcst32 || OpBcst64) {
      unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
      return FixupConstant(0, 0, OpBcst64, OpBcst32, 0, 0, 0, 0, OpNo);
    }
    return false;
  };

  // Attempt to find an AVX512 mapping from a full width memory-fold instruction
  // to a broadcast-fold instruction variant.
  if ((MI.getDesc().TSFlags & X86II::EncodingMask) == X86II::EVEX)
    return ConvertToBroadcastAVX512(Opc, Opc);

  // Reverse the X86InstrInfo::setExecutionDomainCustom EVEX->VEX logic
  // conversion to see if we can convert to a broadcasted (integer) logic op.
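  // e.g. with AVX512VL but no AVX512DQ, a VEX VPANDrm constant pool load can be
  // re-encoded via the EVEX VPANDDZ128rm / VPANDQZ128rm broadcast-fold forms so
  // that only a single 32/64-bit element needs to be stored.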
  if (HasVLX && !HasDQI) {
    unsigned OpSrc32 = 0, OpSrc64 = 0;
    switch (Opc) {
    case X86::VANDPDrm:
    case X86::VANDPSrm:
    case X86::VPANDrm:
      OpSrc32 = X86::VPANDDZ128rm;
      OpSrc64 = X86::VPANDQZ128rm;
      break;
    case X86::VANDPDYrm:
    case X86::VANDPSYrm:
    case X86::VPANDYrm:
      OpSrc32 = X86::VPANDDZ256rm;
      OpSrc64 = X86::VPANDQZ256rm;
      break;
    case X86::VANDNPDrm:
    case X86::VANDNPSrm:
    case X86::VPANDNrm:
      OpSrc32 = X86::VPANDNDZ128rm;
      OpSrc64 = X86::VPANDNQZ128rm;
      break;
    case X86::VANDNPDYrm:
    case X86::VANDNPSYrm:
    case X86::VPANDNYrm:
      OpSrc32 = X86::VPANDNDZ256rm;
      OpSrc64 = X86::VPANDNQZ256rm;
      break;
    case X86::VORPDrm:
    case X86::VORPSrm:
    case X86::VPORrm:
      OpSrc32 = X86::VPORDZ128rm;
      OpSrc64 = X86::VPORQZ128rm;
      break;
    case X86::VORPDYrm:
    case X86::VORPSYrm:
    case X86::VPORYrm:
      OpSrc32 = X86::VPORDZ256rm;
      OpSrc64 = X86::VPORQZ256rm;
      break;
    case X86::VXORPDrm:
    case X86::VXORPSrm:
    case X86::VPXORrm:
      OpSrc32 = X86::VPXORDZ128rm;
      OpSrc64 = X86::VPXORQZ128rm;
      break;
    case X86::VXORPDYrm:
    case X86::VXORPSYrm:
    case X86::VPXORYrm:
      OpSrc32 = X86::VPXORDZ256rm;
      OpSrc64 = X86::VPXORQZ256rm;
      break;
    }
    if (OpSrc32 || OpSrc64)
      return ConvertToBroadcastAVX512(OpSrc32, OpSrc64);
  }

  return false;
}

bool X86FixupVectorConstantsPass::runOnMachineFunction(MachineFunction &MF) {
  LLVM_DEBUG(dbgs() << "Start X86FixupVectorConstants\n";);
  bool Changed = false;
  ST = &MF.getSubtarget<X86Subtarget>();
  TII = ST->getInstrInfo();
  SM = &ST->getSchedModel();

  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (processInstruction(MF, MBB, MI)) {
        ++NumInstChanges;
        Changed = true;
      }
    }
  }
  LLVM_DEBUG(dbgs() << "End X86FixupVectorConstants\n";);
  return Changed;
}