//===-- X86FixupVectorConstants.cpp - optimize constant generation  -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file examines all full-size vector constant pool loads and attempts to
// replace them with smaller constant pool entries, including:
// * Converting AVX512 memory-fold instructions to their broadcast-fold form.
// * Broadcasting of full-width loads.
// * TODO: Zero extension of full-width loads.
//
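// For example (illustrative only; the exact instructions depend on the
// subtarget), a full-width load of a splatted constant such as:
//   vmovaps .LCPI0_0(%rip), %ymm0        # 32-byte constant pool entry
// can be rewritten as:
//   vbroadcastss .LCPI0_0(%rip), %ymm0   # 4-byte constant pool entry
//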
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrFoldTables.h"
#include "X86InstrInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineConstantPool.h"

using namespace llvm;

#define DEBUG_TYPE "x86-fixup-vector-constants"

STATISTIC(NumInstChanges, "Number of instruction changes");

namespace {
class X86FixupVectorConstantsPass : public MachineFunctionPass {
public:
  static char ID;

  X86FixupVectorConstantsPass() : MachineFunctionPass(ID) {}

  StringRef getPassName() const override {
    return "X86 Fixup Vector Constants";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  bool processInstruction(MachineFunction &MF, MachineBasicBlock &MBB,
                          MachineInstr &MI);

  // This pass runs after regalloc and doesn't support VReg operands.
  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().set(
        MachineFunctionProperties::Property::NoVRegs);
  }

private:
  const X86InstrInfo *TII = nullptr;
  const X86Subtarget *ST = nullptr;
  const MCSchedModel *SM = nullptr;
};
} // end anonymous namespace

char X86FixupVectorConstantsPass::ID = 0;

INITIALIZE_PASS(X86FixupVectorConstantsPass, DEBUG_TYPE, DEBUG_TYPE, false, false)

FunctionPass *llvm::createX86FixupVectorConstants() {
  return new X86FixupVectorConstantsPass();
}

// Attempt to extract the constant's full-width bit data as an APInt.
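// For example, a <4 x i32> <1, 2, 3, 4> constant yields a 128-bit APInt with
// element 0 in the lowest bits; undef elements contribute zero bits.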
static std::optional<APInt> extractConstantBits(const Constant *C) {
  unsigned NumBits = C->getType()->getPrimitiveSizeInBits();

  if (isa<UndefValue>(C))
    return APInt::getZero(NumBits);

  if (auto *CInt = dyn_cast<ConstantInt>(C))
    return CInt->getValue();

  if (auto *CFP = dyn_cast<ConstantFP>(C))
    return CFP->getValue().bitcastToAPInt();

  if (auto *CV = dyn_cast<ConstantVector>(C)) {
    if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) {
      if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
        assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
        return APInt::getSplat(NumBits, *Bits);
      }
    }

    APInt Bits = APInt::getZero(NumBits);
    for (unsigned I = 0, E = CV->getNumOperands(); I != E; ++I) {
      Constant *Elt = CV->getOperand(I);
      std::optional<APInt> SubBits = extractConstantBits(Elt);
      if (!SubBits)
        return std::nullopt;
      assert(NumBits == (E * SubBits->getBitWidth()) &&
             "Illegal vector element size");
      Bits.insertBits(*SubBits, I * SubBits->getBitWidth());
    }
    return Bits;
  }

  if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
    bool IsInteger = CDS->getElementType()->isIntegerTy();
    bool IsFloat = CDS->getElementType()->isHalfTy() ||
                   CDS->getElementType()->isBFloatTy() ||
                   CDS->getElementType()->isFloatTy() ||
                   CDS->getElementType()->isDoubleTy();
    if (IsInteger || IsFloat) {
      APInt Bits = APInt::getZero(NumBits);
      unsigned EltBits = CDS->getElementType()->getPrimitiveSizeInBits();
      for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
        if (IsInteger)
          Bits.insertBits(CDS->getElementAsAPInt(I), I * EltBits);
        else
          Bits.insertBits(CDS->getElementAsAPFloat(I).bitcastToAPInt(),
                          I * EltBits);
      }
      return Bits;
    }
  }

  return std::nullopt;
}

// Attempt to compute the bit data of a splat of width SplatBitWidth,
// normalizing the splat to remove undefs.
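// For example, <4 x i32> <42, undef, 42, undef> with SplatBitWidth == 32
// yields a 32-bit APInt holding 42; non-splat constants yield std::nullopt.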
static std::optional<APInt> getSplatableConstant(const Constant *C,
                                                 unsigned SplatBitWidth) {
  const Type *Ty = C->getType();
  assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
         "Illegal splat width");

  if (std::optional<APInt> Bits = extractConstantBits(C))
    if (Bits->isSplat(SplatBitWidth))
      return Bits->trunc(SplatBitWidth);

  // Detect general splats with undefs.
  // TODO: Do we need to handle NumEltsBits > SplatBitWidth splitting?
  if (auto *CV = dyn_cast<ConstantVector>(C)) {
    unsigned NumOps = CV->getNumOperands();
    unsigned NumEltsBits = Ty->getScalarSizeInBits();
    unsigned NumScaleOps = SplatBitWidth / NumEltsBits;
    if ((SplatBitWidth % NumEltsBits) == 0) {
      // Collect the elements and ensure that within the repeated splat sequence
      // they either match or are undef.
      SmallVector<Constant *, 16> Sequence(NumScaleOps, nullptr);
      for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
        if (Constant *Elt = CV->getAggregateElement(Idx)) {
          if (isa<UndefValue>(Elt))
            continue;
          unsigned SplatIdx = Idx % NumScaleOps;
          if (!Sequence[SplatIdx] || Sequence[SplatIdx] == Elt) {
            Sequence[SplatIdx] = Elt;
            continue;
          }
        }
        return std::nullopt;
      }
      // Extract the constant bits forming the splat and insert them into the
      // bits data, leaving undef elements as zero.
      APInt SplatBits = APInt::getZero(SplatBitWidth);
      for (unsigned I = 0; I != NumScaleOps; ++I) {
        if (!Sequence[I])
          continue;
        if (std::optional<APInt> Bits = extractConstantBits(Sequence[I])) {
          SplatBits.insertBits(*Bits, I * Bits->getBitWidth());
          continue;
        }
        return std::nullopt;
      }
      return SplatBits;
    }
  }

  return std::nullopt;
}

// Split raw bits into a constant vector of elements of a specific bit width.
// NOTE: We don't always bother converting to scalars if the vector length is 1.
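// For example, 128 bits rebuilt with NumSclBits == 32 become a 4-element
// constant data vector (of floats if SclTy is float, otherwise of i32).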
static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
                                 const APInt &Bits, unsigned NumSclBits) {
  unsigned BitWidth = Bits.getBitWidth();

  if (NumSclBits == 8) {
    SmallVector<uint8_t> RawBits;
    for (unsigned I = 0; I != BitWidth; I += 8)
      RawBits.push_back(Bits.extractBits(8, I).getZExtValue());
    return ConstantDataVector::get(Ctx, RawBits);
  }

  if (NumSclBits == 16) {
    SmallVector<uint16_t> RawBits;
    for (unsigned I = 0; I != BitWidth; I += 16)
      RawBits.push_back(Bits.extractBits(16, I).getZExtValue());
    if (SclTy->is16bitFPTy())
      return ConstantDataVector::getFP(SclTy, RawBits);
    return ConstantDataVector::get(Ctx, RawBits);
  }

  if (NumSclBits == 32) {
    SmallVector<uint32_t> RawBits;
    for (unsigned I = 0; I != BitWidth; I += 32)
      RawBits.push_back(Bits.extractBits(32, I).getZExtValue());
    if (SclTy->isFloatTy())
      return ConstantDataVector::getFP(SclTy, RawBits);
    return ConstantDataVector::get(Ctx, RawBits);
  }

  assert(NumSclBits == 64 && "Unhandled vector element width");

  SmallVector<uint64_t> RawBits;
  for (unsigned I = 0; I != BitWidth; I += 64)
    RawBits.push_back(Bits.extractBits(64, I).getZExtValue());
  if (SclTy->isDoubleTy())
    return ConstantDataVector::getFP(SclTy, RawBits);
  return ConstantDataVector::get(Ctx, RawBits);
}

// Attempt to rebuild a normalized splat vector constant of the requested splat
// width, built up of potentially smaller scalar values.
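// For example, a <4 x i64> constant splatting the value 42, rebuilt with
// SplatBitWidth == 64, becomes a single i64 element holding 42 - small enough
// to feed a 64-bit broadcast load.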
static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts*/,
                                 unsigned SplatBitWidth) {
  std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
  if (!Splat)
    return nullptr;

  // Determine scalar size to use for the constant splat vector, clamping as we
  // might have found a splat smaller than the original constant data.
  const Type *OriginalType = C->getType();
  Type *SclTy = OriginalType->getScalarType();
  unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
  NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);

  // Fallback to i64 / double.
  NumSclBits = (NumSclBits == 8 || NumSclBits == 16 || NumSclBits == 32)
                   ? NumSclBits
                   : 64;

  // Extract per-element bits.
  return rebuildConstant(OriginalType->getContext(), SclTy, *Splat, NumSclBits);
}

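// Attempt to rebuild a constant for use by a scalar load that implicitly
// zeroes the upper vector elements (vzload): the original data must have all
// bits above ScalarBitWidth set to zero, otherwise nullptr is returned.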
static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
                                     unsigned ScalarBitWidth) {
  Type *Ty = C->getType();
  Type *SclTy = Ty->getScalarType();
  unsigned NumBits = Ty->getPrimitiveSizeInBits();
  unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
  LLVMContext &Ctx = C->getContext();

  if (NumBits > ScalarBitWidth) {
    // Determine if the upper bits are all zero.
    if (std::optional<APInt> Bits = extractConstantBits(C)) {
      if (Bits->countLeadingZeros() >= (NumBits - ScalarBitWidth)) {
        // If the original constant was made of smaller elements, try to retain
        // those types.
        if (ScalarBitWidth > NumSclBits && (ScalarBitWidth % NumSclBits) == 0)
          return rebuildConstant(Ctx, SclTy, *Bits, NumSclBits);

        // Fallback to raw integer bits.
        APInt RawBits = Bits->zextOrTrunc(ScalarBitWidth);
        return ConstantInt::get(Ctx, RawBits);
      }
    }
  }

  return nullptr;
}

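// Attempt to rebuild a constant of NumElts elements, each truncated to
// SrcEltBitWidth bits, for use by a sign/zero-extending load (IsSExt selects
// which); returns nullptr if any element can't be losslessly represented in
// the source width.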
static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
                               unsigned SrcEltBitWidth) {
  Type *Ty = C->getType();
  unsigned NumBits = Ty->getPrimitiveSizeInBits();
  unsigned DstEltBitWidth = NumBits / NumElts;
  assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
         (DstEltBitWidth % SrcEltBitWidth) == 0 &&
         (DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");

  if (std::optional<APInt> Bits = extractConstantBits(C)) {
    assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
           (Bits->getBitWidth() % DstEltBitWidth) == 0 &&
           "Unexpected constant extension");

    // Ensure every vector element can be represented by the src bitwidth.
    APInt TruncBits = APInt::getZero(NumElts * SrcEltBitWidth);
    for (unsigned I = 0; I != NumElts; ++I) {
      APInt Elt = Bits->extractBits(DstEltBitWidth, I * DstEltBitWidth);
      if ((IsSExt && Elt.getSignificantBits() > SrcEltBitWidth) ||
          (!IsSExt && Elt.getActiveBits() > SrcEltBitWidth))
        return nullptr;
      TruncBits.insertBits(Elt.trunc(SrcEltBitWidth), I * SrcEltBitWidth);
    }

    return rebuildConstant(Ty->getContext(), Ty->getScalarType(), TruncBits,
                           SrcEltBitWidth);
  }

  return nullptr;
}

static Constant *rebuildSExtCst(const Constant *C, unsigned NumElts,
                                unsigned SrcEltBitWidth) {
  return rebuildExtCst(C, true, NumElts, SrcEltBitWidth);
}

bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
                                                     MachineBasicBlock &MBB,
                                                     MachineInstr &MI) {
  unsigned Opc = MI.getOpcode();
  MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
  bool HasSSE41 = ST->hasSSE41();
  bool HasAVX2 = ST->hasAVX2();
  bool HasDQI = ST->hasDQI();
  bool HasBWI = ST->hasBWI();
  bool HasVLX = ST->hasVLX();

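  // A candidate fixup: Op is the replacement opcode (0 if the fixup isn't
  // available on this subtarget), NumCstElts and BitWidth are forwarded to
  // RebuildConstant and define the table's ascending sort order
  // (NumCstElts * BitWidth), with BitWidth / 8 used as the alignment of the
  // new constant pool entry. RebuildConstant returns the smaller constant or
  // nullptr if the fixup doesn't apply.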
  struct FixupEntry {
    int Op;
    int NumCstElts;
    int BitWidth;
    std::function<Constant *(const Constant *, unsigned, unsigned)>
        RebuildConstant;
  };
  auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned OperandNo) {
#ifdef EXPENSIVE_CHECKS
    assert(llvm::is_sorted(Fixups,
                           [](const FixupEntry &A, const FixupEntry &B) {
                             return (A.NumCstElts * A.BitWidth) <
                                    (B.NumCstElts * B.BitWidth);
                           }) &&
           "Constant fixup table not sorted in ascending constant size");
#endif
    assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
           "Unexpected number of operands!");
    if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
      for (const FixupEntry &Fixup : Fixups) {
        if (Fixup.Op) {
          // Construct a suitable constant and adjust the MI to use the new
          // constant pool entry.
          if (Constant *NewCst =
                  Fixup.RebuildConstant(C, Fixup.NumCstElts, Fixup.BitWidth)) {
            unsigned NewCPI =
                CP->getConstantPoolIndex(NewCst, Align(Fixup.BitWidth / 8));
            MI.setDesc(TII->get(Fixup.Op));
            MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
            return true;
          }
        }
      }
    }
    return false;
  };

  // Attempt to detect a suitable vzload/broadcast/vextload from increasing
  // constant bitwidths. For the same bitwidth, prefer vzload over broadcast
  // over vextload:
  // - vzload should never need a shuffle port to zero the upper elements, and
  // the fp/int domain versions are equally available so we don't introduce a
  // domain crossing penalty.
  // - broadcasts sometimes need a shuffle port (especially the 8/16-bit
  // variants); AVX1 only has fp domain broadcasts but AVX2+ has good fp/int
  // domain equivalents.
  // - vextload always needs a shuffle port and is only ever int domain.
  switch (Opc) {
  /* FP Loads */
  case X86::MOVAPDrm:
  case X86::MOVAPSrm:
  case X86::MOVUPDrm:
  case X86::MOVUPSrm:
    // TODO: SSE3 MOVDDUP Handling
    return FixupConstant({{X86::MOVSSrm, 1, 32, rebuildZeroUpperCst},
                          {X86::MOVSDrm, 1, 64, rebuildZeroUpperCst}},
                         1);
  case X86::VMOVAPDrm:
  case X86::VMOVAPSrm:
  case X86::VMOVUPDrm:
  case X86::VMOVUPSrm:
    return FixupConstant({{X86::VMOVSSrm, 1, 32, rebuildZeroUpperCst},
                          {X86::VBROADCASTSSrm, 1, 32, rebuildSplatCst},
                          {X86::VMOVSDrm, 1, 64, rebuildZeroUpperCst},
                          {X86::VMOVDDUPrm, 1, 64, rebuildSplatCst}},
                         1);
  case X86::VMOVAPDYrm:
  case X86::VMOVAPSYrm:
  case X86::VMOVUPDYrm:
  case X86::VMOVUPSYrm:
    return FixupConstant({{X86::VBROADCASTSSYrm, 1, 32, rebuildSplatCst},
                          {X86::VBROADCASTSDYrm, 1, 64, rebuildSplatCst},
                          {X86::VBROADCASTF128rm, 1, 128, rebuildSplatCst}},
                         1);
  case X86::VMOVAPDZ128rm:
  case X86::VMOVAPSZ128rm:
  case X86::VMOVUPDZ128rm:
  case X86::VMOVUPSZ128rm:
    return FixupConstant({{X86::VMOVSSZrm, 1, 32, rebuildZeroUpperCst},
                          {X86::VBROADCASTSSZ128rm, 1, 32, rebuildSplatCst},
                          {X86::VMOVSDZrm, 1, 64, rebuildZeroUpperCst},
                          {X86::VMOVDDUPZ128rm, 1, 64, rebuildSplatCst}},
                         1);
  case X86::VMOVAPDZ256rm:
  case X86::VMOVAPSZ256rm:
  case X86::VMOVUPDZ256rm:
  case X86::VMOVUPSZ256rm:
    return FixupConstant(
        {{X86::VBROADCASTSSZ256rm, 1, 32, rebuildSplatCst},
         {X86::VBROADCASTSDZ256rm, 1, 64, rebuildSplatCst},
         {X86::VBROADCASTF32X4Z256rm, 1, 128, rebuildSplatCst}},
        1);
  case X86::VMOVAPDZrm:
  case X86::VMOVAPSZrm:
  case X86::VMOVUPDZrm:
  case X86::VMOVUPSZrm:
    return FixupConstant({{X86::VBROADCASTSSZrm, 1, 32, rebuildSplatCst},
                          {X86::VBROADCASTSDZrm, 1, 64, rebuildSplatCst},
                          {X86::VBROADCASTF32X4rm, 1, 128, rebuildSplatCst},
                          {X86::VBROADCASTF64X4rm, 1, 256, rebuildSplatCst}},
                         1);
  /* Integer Loads */
  case X86::MOVDQArm:
  case X86::MOVDQUrm: {
    FixupEntry Fixups[] = {
        {HasSSE41 ? X86::PMOVSXBQrm : 0, 2, 8, rebuildSExtCst},
        {X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
        {HasSSE41 ? X86::PMOVSXBDrm : 0, 4, 8, rebuildSExtCst},
        {HasSSE41 ? X86::PMOVSXWQrm : 0, 2, 16, rebuildSExtCst},
        {X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
        {HasSSE41 ? X86::PMOVSXBWrm : 0, 8, 8, rebuildSExtCst},
        {HasSSE41 ? X86::PMOVSXWDrm : 0, 4, 16, rebuildSExtCst},
        {HasSSE41 ? X86::PMOVSXDQrm : 0, 2, 32, rebuildSExtCst}};
    return FixupConstant(Fixups, 1);
  }
  case X86::VMOVDQArm:
  case X86::VMOVDQUrm: {
    FixupEntry Fixups[] = {
        {HasAVX2 ? X86::VPBROADCASTBrm : 0, 1, 8, rebuildSplatCst},
        {HasAVX2 ? X86::VPBROADCASTWrm : 0, 1, 16, rebuildSplatCst},
        {X86::VPMOVSXBQrm, 2, 8, rebuildSExtCst},
        {X86::VMOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
        {HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 1, 32,
         rebuildSplatCst},
        {X86::VPMOVSXBDrm, 4, 8, rebuildSExtCst},
        {X86::VPMOVSXWQrm, 2, 16, rebuildSExtCst},
        {X86::VMOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
        {HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 1, 64,
         rebuildSplatCst},
        {X86::VPMOVSXBWrm, 8, 8, rebuildSExtCst},
        {X86::VPMOVSXWDrm, 4, 16, rebuildSExtCst},
        {X86::VPMOVSXDQrm, 2, 32, rebuildSExtCst}};
    return FixupConstant(Fixups, 1);
  }
  case X86::VMOVDQAYrm:
  case X86::VMOVDQUYrm: {
    FixupEntry Fixups[] = {
        {HasAVX2 ? X86::VPBROADCASTBYrm : 0, 1, 8, rebuildSplatCst},
        {HasAVX2 ? X86::VPBROADCASTWYrm : 0, 1, 16, rebuildSplatCst},
        {HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 1, 32,
         rebuildSplatCst},
        {HasAVX2 ? X86::VPMOVSXBQYrm : 0, 4, 8, rebuildSExtCst},
        {HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 1, 64,
         rebuildSplatCst},
        {HasAVX2 ? X86::VPMOVSXBDYrm : 0, 8, 8, rebuildSExtCst},
        {HasAVX2 ? X86::VPMOVSXWQYrm : 0, 4, 16, rebuildSExtCst},
        {HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 1, 128,
         rebuildSplatCst},
        {HasAVX2 ? X86::VPMOVSXBWYrm : 0, 16, 8, rebuildSExtCst},
        {HasAVX2 ? X86::VPMOVSXWDYrm : 0, 8, 16, rebuildSExtCst},
        {HasAVX2 ? X86::VPMOVSXDQYrm : 0, 4, 32, rebuildSExtCst}};
    return FixupConstant(Fixups, 1);
  }
  case X86::VMOVDQA32Z128rm:
  case X86::VMOVDQA64Z128rm:
  case X86::VMOVDQU32Z128rm:
  case X86::VMOVDQU64Z128rm: {
    FixupEntry Fixups[] = {
        {HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1, 8, rebuildSplatCst},
        {HasBWI ? X86::VPBROADCASTWZ128rm : 0, 1, 16, rebuildSplatCst},
        {X86::VPMOVSXBQZ128rm, 2, 8, rebuildSExtCst},
        {X86::VMOVDI2PDIZrm, 1, 32, rebuildZeroUpperCst},
        {X86::VPBROADCASTDZ128rm, 1, 32, rebuildSplatCst},
        {X86::VPMOVSXBDZ128rm, 4, 8, rebuildSExtCst},
        {X86::VPMOVSXWQZ128rm, 2, 16, rebuildSExtCst},
        {X86::VMOVQI2PQIZrm, 1, 64, rebuildZeroUpperCst},
        {X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst},
        {HasBWI ? X86::VPMOVSXBWZ128rm : 0, 8, 8, rebuildSExtCst},
        {X86::VPMOVSXWDZ128rm, 4, 16, rebuildSExtCst},
        {X86::VPMOVSXDQZ128rm, 2, 32, rebuildSExtCst}};
    return FixupConstant(Fixups, 1);
  }
  case X86::VMOVDQA32Z256rm:
  case X86::VMOVDQA64Z256rm:
  case X86::VMOVDQU32Z256rm:
  case X86::VMOVDQU64Z256rm: {
    FixupEntry Fixups[] = {
        {HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1, 8, rebuildSplatCst},
        {HasBWI ? X86::VPBROADCASTWZ256rm : 0, 1, 16, rebuildSplatCst},
        {X86::VPBROADCASTDZ256rm, 1, 32, rebuildSplatCst},
        {X86::VPMOVSXBQZ256rm, 4, 8, rebuildSExtCst},
        {X86::VPBROADCASTQZ256rm, 1, 64, rebuildSplatCst},
        {X86::VPMOVSXBDZ256rm, 8, 8, rebuildSExtCst},
        {X86::VPMOVSXWQZ256rm, 4, 16, rebuildSExtCst},
        {X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst},
        {HasBWI ? X86::VPMOVSXBWZ256rm : 0, 16, 8, rebuildSExtCst},
        {X86::VPMOVSXWDZ256rm, 8, 16, rebuildSExtCst},
        {X86::VPMOVSXDQZ256rm, 4, 32, rebuildSExtCst}};
    return FixupConstant(Fixups, 1);
  }
  case X86::VMOVDQA32Zrm:
  case X86::VMOVDQA64Zrm:
  case X86::VMOVDQU32Zrm:
  case X86::VMOVDQU64Zrm: {
    FixupEntry Fixups[] = {
        {HasBWI ? X86::VPBROADCASTBZrm : 0, 1, 8, rebuildSplatCst},
        {HasBWI ? X86::VPBROADCASTWZrm : 0, 1, 16, rebuildSplatCst},
        {X86::VPBROADCASTDZrm, 1, 32, rebuildSplatCst},
        {X86::VPBROADCASTQZrm, 1, 64, rebuildSplatCst},
        {X86::VPMOVSXBQZrm, 8, 8, rebuildSExtCst},
        {X86::VBROADCASTI32X4rm, 1, 128, rebuildSplatCst},
        {X86::VPMOVSXBDZrm, 16, 8, rebuildSExtCst},
        {X86::VPMOVSXWQZrm, 8, 16, rebuildSExtCst},
        {X86::VBROADCASTI64X4rm, 1, 256, rebuildSplatCst},
        {HasBWI ? X86::VPMOVSXBWZrm : 0, 32, 8, rebuildSExtCst},
        {X86::VPMOVSXWDZrm, 16, 16, rebuildSExtCst},
        {X86::VPMOVSXDQZrm, 8, 32, rebuildSExtCst}};
    return FixupConstant(Fixups, 1);
  }
  }

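  // Attempt to replace a full-width memory-fold instruction with its
  // broadcast-fold variant: look up the 32-bit and 64-bit element forms in the
  // broadcast fold tables and, if the folded constant is a suitable splat,
  // switch to the broadcast opcode with a smaller constant pool entry.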
  auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
    unsigned OpBcst32 = 0, OpBcst64 = 0;
    unsigned OpNoBcst32 = 0, OpNoBcst64 = 0;
    if (OpSrc32) {
      if (const X86FoldTableEntry *Mem2Bcst =
              llvm::lookupBroadcastFoldTableBySize(OpSrc32, 32)) {
        OpBcst32 = Mem2Bcst->DstOp;
        OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK;
      }
    }
    if (OpSrc64) {
      if (const X86FoldTableEntry *Mem2Bcst =
              llvm::lookupBroadcastFoldTableBySize(OpSrc64, 64)) {
        OpBcst64 = Mem2Bcst->DstOp;
        OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK;
      }
    }
    assert(((OpBcst32 == 0) || (OpBcst64 == 0) || (OpNoBcst32 == OpNoBcst64)) &&
           "OperandNo mismatch");

    if (OpBcst32 || OpBcst64) {
      unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
      FixupEntry Fixups[] = {{(int)OpBcst32, 32, 32, rebuildSplatCst},
                             {(int)OpBcst64, 64, 64, rebuildSplatCst}};
      return FixupConstant(Fixups, OpNo);
    }
    return false;
  };

  // Attempt to find an AVX512 mapping from a full-width memory-fold
  // instruction to a broadcast-fold instruction variant.
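  // For example, a 512-bit VADDPSZrm whose folded constant is a 32-bit splat
  // would typically map to its VADDPSZrmb broadcast form (the exact mapping
  // comes from the broadcast fold tables).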
  if ((MI.getDesc().TSFlags & X86II::EncodingMask) == X86II::EVEX)
    return ConvertToBroadcastAVX512(Opc, Opc);

  // Reverse the X86InstrInfo::setExecutionDomainCustom EVEX->VEX logic
  // conversion to see if we can convert to a broadcasted (integer) logic op.
  if (HasVLX && !HasDQI) {
    unsigned OpSrc32 = 0, OpSrc64 = 0;
    switch (Opc) {
    case X86::VANDPDrm:
    case X86::VANDPSrm:
    case X86::VPANDrm:
      OpSrc32 = X86::VPANDDZ128rm;
      OpSrc64 = X86::VPANDQZ128rm;
      break;
    case X86::VANDPDYrm:
    case X86::VANDPSYrm:
    case X86::VPANDYrm:
      OpSrc32 = X86::VPANDDZ256rm;
      OpSrc64 = X86::VPANDQZ256rm;
      break;
    case X86::VANDNPDrm:
    case X86::VANDNPSrm:
    case X86::VPANDNrm:
      OpSrc32 = X86::VPANDNDZ128rm;
      OpSrc64 = X86::VPANDNQZ128rm;
      break;
    case X86::VANDNPDYrm:
    case X86::VANDNPSYrm:
    case X86::VPANDNYrm:
      OpSrc32 = X86::VPANDNDZ256rm;
      OpSrc64 = X86::VPANDNQZ256rm;
      break;
    case X86::VORPDrm:
    case X86::VORPSrm:
    case X86::VPORrm:
      OpSrc32 = X86::VPORDZ128rm;
      OpSrc64 = X86::VPORQZ128rm;
      break;
    case X86::VORPDYrm:
    case X86::VORPSYrm:
    case X86::VPORYrm:
      OpSrc32 = X86::VPORDZ256rm;
      OpSrc64 = X86::VPORQZ256rm;
      break;
    case X86::VXORPDrm:
    case X86::VXORPSrm:
    case X86::VPXORrm:
      OpSrc32 = X86::VPXORDZ128rm;
      OpSrc64 = X86::VPXORQZ128rm;
      break;
    case X86::VXORPDYrm:
    case X86::VXORPSYrm:
    case X86::VPXORYrm:
      OpSrc32 = X86::VPXORDZ256rm;
      OpSrc64 = X86::VPXORQZ256rm;
      break;
    }
    if (OpSrc32 || OpSrc64)
      return ConvertToBroadcastAVX512(OpSrc32, OpSrc64);
  }

  return false;
}

bool X86FixupVectorConstantsPass::runOnMachineFunction(MachineFunction &MF) {
  LLVM_DEBUG(dbgs() << "Start X86FixupVectorConstants\n";);
  bool Changed = false;
  ST = &MF.getSubtarget<X86Subtarget>();
  TII = ST->getInstrInfo();
  SM = &ST->getSchedModel();

  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (processInstruction(MF, MBB, MI)) {
        ++NumInstChanges;
        Changed = true;
      }
    }
  }
  LLVM_DEBUG(dbgs() << "End X86FixupVectorConstants\n";);
  return Changed;
}