//===-- X86FixupVectorConstants.cpp - optimize constant generation  -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file examines all full size vector constant pool loads and attempts to
// replace them with smaller constant pool entries, including:
// * Converting AVX512 memory-fold instructions to their broadcast-fold form
// * Broadcasting of full width loads.
// * TODO: Sign/Zero extension of full width loads.
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrFoldTables.h"
#include "X86InstrInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineConstantPool.h"

using namespace llvm;

#define DEBUG_TYPE "x86-fixup-vector-constants"

STATISTIC(NumInstChanges, "Number of instruction changes");

namespace {
class X86FixupVectorConstantsPass : public MachineFunctionPass {
public:
  static char ID;

  X86FixupVectorConstantsPass() : MachineFunctionPass(ID) {}

  StringRef getPassName() const override {
    return "X86 Fixup Vector Constants";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  bool processInstruction(MachineFunction &MF, MachineBasicBlock &MBB,
                          MachineInstr &MI);

  // This pass runs after regalloc and doesn't support VReg operands.
  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().set(
        MachineFunctionProperties::Property::NoVRegs);
  }

private:
  const X86InstrInfo *TII = nullptr;
  const X86Subtarget *ST = nullptr;
  const MCSchedModel *SM = nullptr;
};
} // end anonymous namespace

char X86FixupVectorConstantsPass::ID = 0;

INITIALIZE_PASS(X86FixupVectorConstantsPass, DEBUG_TYPE, DEBUG_TYPE, false,
                false)

FunctionPass *llvm::createX86FixupVectorConstants() {
  return new X86FixupVectorConstantsPass();
}

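// Return the constant referenced by the given constant pool operand, or
// nullptr if the entry can't be analyzed (non-zero offset or a machine
// constant pool entry).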
static const Constant *getConstantFromPool(const MachineInstr &MI,
                                           const MachineOperand &Op) {
  if (!Op.isCPI() || Op.getOffset() != 0)
    return nullptr;

  ArrayRef<MachineConstantPoolEntry> Constants =
      MI.getParent()->getParent()->getConstantPool()->getConstants();
  const MachineConstantPoolEntry &ConstantEntry = Constants[Op.getIndex()];

  // Bail if this is a machine constant pool entry; we won't be able to dig out
  // anything useful.
  if (ConstantEntry.isMachineConstantPoolEntry())
    return nullptr;

  return ConstantEntry.Val.ConstVal;
}

// Attempt to extract the full width of the constant's bits data as an APInt.
static std::optional<APInt> extractConstantBits(const Constant *C) {
  unsigned NumBits = C->getType()->getPrimitiveSizeInBits();

  if (auto *CInt = dyn_cast<ConstantInt>(C))
    return CInt->getValue();

  if (auto *CFP = dyn_cast<ConstantFP>(C))
    return CFP->getValue().bitcastToAPInt();

  if (auto *CV = dyn_cast<ConstantVector>(C)) {
    if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) {
      if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
        assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
        return APInt::getSplat(NumBits, *Bits);
      }
    }
  }

  if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
    bool IsInteger = CDS->getElementType()->isIntegerTy();
    bool IsFloat = CDS->getElementType()->isHalfTy() ||
                   CDS->getElementType()->isBFloatTy() ||
                   CDS->getElementType()->isFloatTy() ||
                   CDS->getElementType()->isDoubleTy();
    if (IsInteger || IsFloat) {
      APInt Bits = APInt::getZero(NumBits);
      unsigned EltBits = CDS->getElementType()->getPrimitiveSizeInBits();
      for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
        if (IsInteger)
          Bits.insertBits(CDS->getElementAsAPInt(I), I * EltBits);
        else
          Bits.insertBits(CDS->getElementAsAPFloat(I).bitcastToAPInt(),
                          I * EltBits);
      }
      return Bits;
    }
  }

  return std::nullopt;
}

// Attempt to extract a splat of the requested bit width from the constant's
// bits data, normalizing the splat to remove undefs.
static std::optional<APInt> getSplatableConstant(const Constant *C,
                                                 unsigned SplatBitWidth) {
  const Type *Ty = C->getType();
  assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
         "Illegal splat width");

  if (std::optional<APInt> Bits = extractConstantBits(C))
    if (Bits->isSplat(SplatBitWidth))
      return Bits->trunc(SplatBitWidth);

  // Detect general splats with undefs.
  // TODO: Do we need to handle NumEltsBits > SplatBitWidth splitting?
  if (auto *CV = dyn_cast<ConstantVector>(C)) {
    unsigned NumOps = CV->getNumOperands();
    unsigned NumEltsBits = Ty->getScalarSizeInBits();
    unsigned NumScaleOps = SplatBitWidth / NumEltsBits;
    if ((SplatBitWidth % NumEltsBits) == 0) {
      // Collect the elements and ensure that within the repeated splat
      // sequence they either match or are undef.
      SmallVector<Constant *, 16> Sequence(NumScaleOps, nullptr);
      for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
        if (Constant *Elt = CV->getAggregateElement(Idx)) {
          if (isa<UndefValue>(Elt))
            continue;
          unsigned SplatIdx = Idx % NumScaleOps;
          if (!Sequence[SplatIdx] || Sequence[SplatIdx] == Elt) {
            Sequence[SplatIdx] = Elt;
            continue;
          }
        }
        return std::nullopt;
      }
      // Extract the constant bits forming the splat and insert them into the
      // bits data, leaving any undef elements as zero.
      APInt SplatBits = APInt::getZero(SplatBitWidth);
      for (unsigned I = 0; I != NumScaleOps; ++I) {
        if (!Sequence[I])
          continue;
        if (std::optional<APInt> Bits = extractConstantBits(Sequence[I])) {
          SplatBits.insertBits(*Bits, I * Bits->getBitWidth());
          continue;
        }
        return std::nullopt;
      }
      return SplatBits;
    }
  }

  return std::nullopt;
}

// Attempt to rebuild a normalized splat vector constant of the requested splat
// width, built up of potentially smaller scalar values.
// NOTE: We don't always bother converting to scalars if the vector length
// is 1.
static Constant *rebuildSplatableConstant(const Constant *C,
                                          unsigned SplatBitWidth) {
  std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
  if (!Splat)
    return nullptr;

  // Determine scalar size to use for the constant splat vector, clamping as we
  // might have found a splat smaller than the original constant data.
  const Type *OriginalType = C->getType();
  Type *SclTy = OriginalType->getScalarType();
  unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
  NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);
  LLVMContext &Ctx = OriginalType->getContext();

  if (NumSclBits == 8) {
    SmallVector<uint8_t> RawBits;
    for (unsigned I = 0; I != SplatBitWidth; I += 8)
      RawBits.push_back(Splat->extractBits(8, I).getZExtValue());
    return ConstantDataVector::get(Ctx, RawBits);
  }

  if (NumSclBits == 16) {
    SmallVector<uint16_t> RawBits;
    for (unsigned I = 0; I != SplatBitWidth; I += 16)
      RawBits.push_back(Splat->extractBits(16, I).getZExtValue());
    if (SclTy->is16bitFPTy())
      return ConstantDataVector::getFP(SclTy, RawBits);
    return ConstantDataVector::get(Ctx, RawBits);
  }

  if (NumSclBits == 32) {
    SmallVector<uint32_t> RawBits;
    for (unsigned I = 0; I != SplatBitWidth; I += 32)
      RawBits.push_back(Splat->extractBits(32, I).getZExtValue());
    if (SclTy->isFloatTy())
      return ConstantDataVector::getFP(SclTy, RawBits);
    return ConstantDataVector::get(Ctx, RawBits);
  }

  // Fallback to i64 / double.
  SmallVector<uint64_t> RawBits;
  for (unsigned I = 0; I != SplatBitWidth; I += 64)
    RawBits.push_back(Splat->extractBits(64, I).getZExtValue());
  if (SclTy->isDoubleTy())
    return ConstantDataVector::getFP(SclTy, RawBits);
  return ConstantDataVector::get(Ctx, RawBits);
}

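// Attempt to replace a full-width constant pool load in MI with a narrower
// broadcast-folded equivalent, returning true if the instruction was changed.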
bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
                                                     MachineBasicBlock &MBB,
                                                     MachineInstr &MI) {
  unsigned Opc = MI.getOpcode();
  MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
  bool HasAVX2 = ST->hasAVX2();
  bool HasDQI = ST->hasDQI();
  bool HasBWI = ST->hasBWI();
  bool HasVLX = ST->hasVLX();

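  // Attempt to replace the full-width constant pool load at operand index
  // OperandNo with the narrowest suitable broadcast from the given opcodes
  // (an opcode of 0 means that broadcast width is unavailable).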
  auto ConvertToBroadcast = [&](unsigned OpBcst256, unsigned OpBcst128,
                                unsigned OpBcst64, unsigned OpBcst32,
                                unsigned OpBcst16, unsigned OpBcst8,
                                unsigned OperandNo) {
    assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
           "Unexpected number of operands!");

    MachineOperand &CstOp = MI.getOperand(OperandNo + X86::AddrDisp);
    if (auto *C = getConstantFromPool(MI, CstOp)) {
      // Attempt to detect a suitable splat from increasing splat widths.
      std::pair<unsigned, unsigned> Broadcasts[] = {
          {8, OpBcst8},   {16, OpBcst16},   {32, OpBcst32},
          {64, OpBcst64}, {128, OpBcst128}, {256, OpBcst256},
      };
      for (auto [BitWidth, OpBcst] : Broadcasts) {
        if (OpBcst) {
          // Construct a suitable splat constant and adjust the MI to
          // use the new constant pool entry.
          if (Constant *NewCst = rebuildSplatableConstant(C, BitWidth)) {
            unsigned NewCPI =
                CP->getConstantPoolIndex(NewCst, Align(BitWidth / 8));
            MI.setDesc(TII->get(OpBcst));
            CstOp.setIndex(NewCPI);
            return true;
          }
        }
      }
    }
    return false;
  };

  // Attempt to convert full width vector loads into broadcast loads.
  switch (Opc) {
  /* FP Loads */
  case X86::MOVAPDrm:
  case X86::MOVAPSrm:
  case X86::MOVUPDrm:
  case X86::MOVUPSrm:
    // TODO: SSE3 MOVDDUP Handling
    return false;
  case X86::VMOVAPDrm:
  case X86::VMOVAPSrm:
  case X86::VMOVUPDrm:
  case X86::VMOVUPSrm:
    return ConvertToBroadcast(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0,
                              1);
  case X86::VMOVAPDYrm:
  case X86::VMOVAPSYrm:
  case X86::VMOVUPDYrm:
  case X86::VMOVUPSYrm:
    return ConvertToBroadcast(0, X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
                              X86::VBROADCASTSSYrm, 0, 0, 1);
  case X86::VMOVAPDZ128rm:
  case X86::VMOVAPSZ128rm:
  case X86::VMOVUPDZ128rm:
  case X86::VMOVUPSZ128rm:
    return ConvertToBroadcast(0, 0, X86::VMOVDDUPZ128rm,
                              X86::VBROADCASTSSZ128rm, 0, 0, 1);
  case X86::VMOVAPDZ256rm:
  case X86::VMOVAPSZ256rm:
  case X86::VMOVUPDZ256rm:
  case X86::VMOVUPSZ256rm:
    return ConvertToBroadcast(0, X86::VBROADCASTF32X4Z256rm,
                              X86::VBROADCASTSDZ256rm, X86::VBROADCASTSSZ256rm,
                              0, 0, 1);
  case X86::VMOVAPDZrm:
  case X86::VMOVAPSZrm:
  case X86::VMOVUPDZrm:
  case X86::VMOVUPSZrm:
    return ConvertToBroadcast(X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
                              X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0, 0,
                              1);
  /* Integer Loads */
  case X86::VMOVDQArm:
  case X86::VMOVDQUrm:
    return ConvertToBroadcast(
        0, 0, HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
        HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm,
        HasAVX2 ? X86::VPBROADCASTWrm : 0, HasAVX2 ? X86::VPBROADCASTBrm : 0,
        1);
  case X86::VMOVDQAYrm:
  case X86::VMOVDQUYrm:
    return ConvertToBroadcast(
        0, HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm,
        HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm,
        HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm,
        HasAVX2 ? X86::VPBROADCASTWYrm : 0, HasAVX2 ? X86::VPBROADCASTBYrm : 0,
        1);
  case X86::VMOVDQA32Z128rm:
  case X86::VMOVDQA64Z128rm:
  case X86::VMOVDQU32Z128rm:
  case X86::VMOVDQU64Z128rm:
    return ConvertToBroadcast(0, 0, X86::VPBROADCASTQZ128rm,
                              X86::VPBROADCASTDZ128rm,
                              HasBWI ? X86::VPBROADCASTWZ128rm : 0,
                              HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1);
  case X86::VMOVDQA32Z256rm:
  case X86::VMOVDQA64Z256rm:
  case X86::VMOVDQU32Z256rm:
  case X86::VMOVDQU64Z256rm:
    return ConvertToBroadcast(0, X86::VBROADCASTI32X4Z256rm,
                              X86::VPBROADCASTQZ256rm, X86::VPBROADCASTDZ256rm,
                              HasBWI ? X86::VPBROADCASTWZ256rm : 0,
                              HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1);
  case X86::VMOVDQA32Zrm:
  case X86::VMOVDQA64Zrm:
  case X86::VMOVDQU32Zrm:
  case X86::VMOVDQU64Zrm:
    return ConvertToBroadcast(X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
                              X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
                              HasBWI ? X86::VPBROADCASTWZrm : 0,
                              HasBWI ? X86::VPBROADCASTBZrm : 0, 1);
  }

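  // Look up broadcast-folded forms of the given 32-bit and 64-bit element
  // opcodes in the fold tables and, if either exists, attempt the broadcast
  // conversion at the memory operand index recorded in the table entry.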
  auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
    unsigned OpBcst32 = 0, OpBcst64 = 0;
    unsigned OpNoBcst32 = 0, OpNoBcst64 = 0;
    if (OpSrc32) {
      if (const X86FoldTableEntry *Mem2Bcst =
              llvm::lookupBroadcastFoldTable(OpSrc32, 32)) {
        OpBcst32 = Mem2Bcst->DstOp;
        OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK;
      }
    }
    if (OpSrc64) {
      if (const X86FoldTableEntry *Mem2Bcst =
              llvm::lookupBroadcastFoldTable(OpSrc64, 64)) {
        OpBcst64 = Mem2Bcst->DstOp;
        OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK;
      }
    }
    assert(((OpBcst32 == 0) || (OpBcst64 == 0) || (OpNoBcst32 == OpNoBcst64)) &&
           "OperandNo mismatch");

    if (OpBcst32 || OpBcst64) {
      unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
      return ConvertToBroadcast(0, 0, OpBcst64, OpBcst32, 0, 0, OpNo);
    }
    return false;
  };

  // Attempt to find an AVX512 mapping from a full width memory-fold
  // instruction to a broadcast-fold instruction variant.
  if ((MI.getDesc().TSFlags & X86II::EncodingMask) == X86II::EVEX)
    return ConvertToBroadcastAVX512(Opc, Opc);

  // Reverse the X86InstrInfo::setExecutionDomainCustom EVEX->VEX logic
  // conversion to see if we can convert to a broadcasted (integer) logic op.
  if (HasVLX && !HasDQI) {
    unsigned OpSrc32 = 0, OpSrc64 = 0;
    switch (Opc) {
    case X86::VANDPDrm:
    case X86::VANDPSrm:
    case X86::VPANDrm:
      OpSrc32 = X86::VPANDDZ128rm;
      OpSrc64 = X86::VPANDQZ128rm;
      break;
    case X86::VANDPDYrm:
    case X86::VANDPSYrm:
    case X86::VPANDYrm:
      OpSrc32 = X86::VPANDDZ256rm;
      OpSrc64 = X86::VPANDQZ256rm;
      break;
    case X86::VANDNPDrm:
    case X86::VANDNPSrm:
    case X86::VPANDNrm:
      OpSrc32 = X86::VPANDNDZ128rm;
      OpSrc64 = X86::VPANDNQZ128rm;
      break;
    case X86::VANDNPDYrm:
    case X86::VANDNPSYrm:
    case X86::VPANDNYrm:
      OpSrc32 = X86::VPANDNDZ256rm;
      OpSrc64 = X86::VPANDNQZ256rm;
      break;
    case X86::VORPDrm:
    case X86::VORPSrm:
    case X86::VPORrm:
      OpSrc32 = X86::VPORDZ128rm;
      OpSrc64 = X86::VPORQZ128rm;
      break;
    case X86::VORPDYrm:
    case X86::VORPSYrm:
    case X86::VPORYrm:
      OpSrc32 = X86::VPORDZ256rm;
      OpSrc64 = X86::VPORQZ256rm;
      break;
    case X86::VXORPDrm:
    case X86::VXORPSrm:
    case X86::VPXORrm:
      OpSrc32 = X86::VPXORDZ128rm;
      OpSrc64 = X86::VPXORQZ128rm;
      break;
    case X86::VXORPDYrm:
    case X86::VXORPSYrm:
    case X86::VPXORYrm:
      OpSrc32 = X86::VPXORDZ256rm;
      OpSrc64 = X86::VPXORQZ256rm;
      break;
    }
    if (OpSrc32 || OpSrc64)
      return ConvertToBroadcastAVX512(OpSrc32, OpSrc64);
  }

  return false;
}

bool X86FixupVectorConstantsPass::runOnMachineFunction(MachineFunction &MF) {
  LLVM_DEBUG(dbgs() << "Start X86FixupVectorConstants\n";);
  bool Changed = false;
  ST = &MF.getSubtarget<X86Subtarget>();
  TII = ST->getInstrInfo();
  SM = &ST->getSchedModel();

  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (processInstruction(MF, MBB, MI)) {
        ++NumInstChanges;
        Changed = true;
      }
    }
  }
  LLVM_DEBUG(dbgs() << "End X86FixupVectorConstants\n";);
  return Changed;
}