xref: /llvm-project/llvm/lib/Target/X86/X86ShuffleDecodeConstantPool.cpp (revision 3ace13adfa7a296595c3569d0e9bfc262aff33b5)
1 //===-- X86ShuffleDecodeConstantPool.cpp - X86 shuffle decode -------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Define several functions to decode x86 specific shuffle semantics using
11 // constants from the constant pool.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "X86ShuffleDecodeConstantPool.h"
16 #include "Utils/X86ShuffleDecode.h"
17 #include "llvm/ADT/SmallBitVector.h"
18 #include "llvm/CodeGen/MachineValueType.h"
19 #include "llvm/IR/Constants.h"
20 
21 //===----------------------------------------------------------------------===//
22 //  Vector Mask Decoding
23 //===----------------------------------------------------------------------===//
24 
25 namespace llvm {
26 
27 static bool extractConstantMask(const Constant *C, unsigned MaskEltSizeInBits,
28                                 SmallBitVector &UndefElts,
29                                 SmallVectorImpl<uint64_t> &RawMask) {
30   // It is not an error for shuffle masks to not be a vector of
31   // MaskEltSizeInBits because the constant pool uniques constants by their
32   // bit representation.
33   // e.g. the following take up the same space in the constant pool:
34   //   i128 -170141183420855150465331762880109871104
35   //
36   //   <2 x i64> <i64 -9223372034707292160, i64 -9223372034707292160>
37   //
38   //   <4 x i32> <i32 -2147483648, i32 -2147483648,
39   //              i32 -2147483648, i32 -2147483648>
40   Type *CstTy = C->getType();
41   if (!CstTy->isVectorTy())
42     return false;
43 
44   Type *CstEltTy = CstTy->getVectorElementType();
45   if (!CstEltTy->isIntegerTy())
46     return false;
47 
48   unsigned CstSizeInBits = CstTy->getPrimitiveSizeInBits();
49   unsigned CstEltSizeInBits = CstTy->getScalarSizeInBits();
50   unsigned NumCstElts = CstTy->getVectorNumElements();
51 
52   // Extract all the undef/constant element data and pack into single bitsets.
53   APInt UndefBits(CstSizeInBits, 0);
54   APInt MaskBits(CstSizeInBits, 0);
55   for (unsigned i = 0; i != NumCstElts; ++i) {
56     Constant *COp = C->getAggregateElement(i);
57     if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp)))
58       return false;
59 
60     if (isa<UndefValue>(COp)) {
61       APInt EltUndef = APInt::getLowBitsSet(CstSizeInBits, CstEltSizeInBits);
62       UndefBits |= EltUndef.shl(i * CstEltSizeInBits);
63       continue;
64     }
65 
66     APInt EltBits = cast<ConstantInt>(COp)->getValue();
67     EltBits = EltBits.zextOrTrunc(CstSizeInBits);
68     MaskBits |= EltBits.shl(i * CstEltSizeInBits);
69   }
70 
71   // Now extract the undef/constant bit data into the raw shuffle masks.
72   assert((CstSizeInBits % MaskEltSizeInBits) == 0 && "");
73   unsigned NumMaskElts = CstSizeInBits / MaskEltSizeInBits;
74 
75   UndefElts = SmallBitVector(NumMaskElts, false);
76   RawMask.resize(NumMaskElts, 0);
77 
78   for (unsigned i = 0; i != NumMaskElts; ++i) {
79     APInt EltUndef = UndefBits.lshr(i * MaskEltSizeInBits);
80     EltUndef = EltUndef.zextOrTrunc(MaskEltSizeInBits);
81 
82     // Only treat the element as UNDEF if all bits are UNDEF, otherwise
83     // treat it as zero.
84     if (EltUndef.countPopulation() == MaskEltSizeInBits) {
85       UndefElts[i] = true;
86       RawMask[i] = 0;
87       continue;
88     }
89 
90     APInt EltBits = MaskBits.lshr(i * MaskEltSizeInBits);
91     EltBits = EltBits.zextOrTrunc(MaskEltSizeInBits);
92     RawMask[i] = EltBits.getZExtValue();
93   }
94 
95   return true;
96 }
97 
98 void DecodePSHUFBMask(const Constant *C, SmallVectorImpl<int> &ShuffleMask) {
99   Type *MaskTy = C->getType();
100   unsigned MaskTySize = MaskTy->getPrimitiveSizeInBits();
101   (void)MaskTySize;
102   assert(MaskTySize == 128 || MaskTySize == 256 || MaskTySize == 512);
103 
104   // The shuffle mask requires a byte vector.
105   SmallBitVector UndefElts;
106   SmallVector<uint64_t, 32> RawMask;
107   if (!extractConstantMask(C, 8, UndefElts, RawMask))
108     return;
109 
110   unsigned NumElts = RawMask.size();
111   assert((NumElts == 16 || NumElts == 32 || NumElts == 64) &&
112          "Unexpected number of vector elements.");
113 
114   for (unsigned i = 0; i != NumElts; ++i) {
115     if (UndefElts[i]) {
116       ShuffleMask.push_back(SM_SentinelUndef);
117       continue;
118     }
119 
120     uint64_t Element = RawMask[i];
121     // If the high bit (7) of the byte is set, the element is zeroed.
122     if (Element & (1 << 7))
123       ShuffleMask.push_back(SM_SentinelZero);
124     else {
125       // For AVX vectors with 32 bytes the base of the shuffle is the 16-byte
126       // lane of the vector we're inside.
127       unsigned Base = i & ~0xf;
128 
129       // Only the least significant 4 bits of the byte are used.
130       int Index = Base + (Element & 0xf);
131       ShuffleMask.push_back(Index);
132     }
133   }
134 }
135 
136 void DecodeVPERMILPMask(const Constant *C, unsigned ElSize,
137                         SmallVectorImpl<int> &ShuffleMask) {
138   Type *MaskTy = C->getType();
139   unsigned MaskTySize = MaskTy->getPrimitiveSizeInBits();
140   (void)MaskTySize;
141   assert(MaskTySize == 128 || MaskTySize == 256 || MaskTySize == 512);
142   assert(ElSize == 32 || ElSize == 64);
143 
144   // The shuffle mask requires elements the same size as the target.
145   SmallBitVector UndefElts;
146   SmallVector<uint64_t, 8> RawMask;
147   if (!extractConstantMask(C, ElSize, UndefElts, RawMask))
148     return;
149 
150   unsigned NumElts = RawMask.size();
151   unsigned NumEltsPerLane = 128 / ElSize;
152   assert((NumElts == 2 || NumElts == 4 || NumElts == 8 || NumElts == 16) &&
153          "Unexpected number of vector elements.");
154 
155   for (unsigned i = 0; i != NumElts; ++i) {
156     if (UndefElts[i]) {
157       ShuffleMask.push_back(SM_SentinelUndef);
158       continue;
159     }
160 
161     int Index = i & ~(NumEltsPerLane - 1);
162     uint64_t Element = RawMask[i];
163     if (ElSize == 64)
164       Index += (Element >> 1) & 0x1;
165     else
166       Index += Element & 0x3;
167 
168     ShuffleMask.push_back(Index);
169   }
170 }
171 
172 void DecodeVPERMIL2PMask(const Constant *C, unsigned M2Z, unsigned ElSize,
173                          SmallVectorImpl<int> &ShuffleMask) {
174   Type *MaskTy = C->getType();
175   unsigned MaskTySize = MaskTy->getPrimitiveSizeInBits();
176   (void)MaskTySize;
177   assert(MaskTySize == 128 || MaskTySize == 256);
178 
179   // The shuffle mask requires elements the same size as the target.
180   SmallBitVector UndefElts;
181   SmallVector<uint64_t, 8> RawMask;
182   if (!extractConstantMask(C, ElSize, UndefElts, RawMask))
183     return;
184 
185   unsigned NumElts = RawMask.size();
186   unsigned NumEltsPerLane = 128 / ElSize;
187   assert((NumElts == 2 || NumElts == 4 || NumElts == 8) &&
188          "Unexpected number of vector elements.");
189 
190   for (unsigned i = 0; i != NumElts; ++i) {
191     if (UndefElts[i]) {
192       ShuffleMask.push_back(SM_SentinelUndef);
193       continue;
194     }
195 
196     // VPERMIL2 Operation.
197     // Bits[3] - Match Bit.
198     // Bits[2:1] - (Per Lane) PD Shuffle Mask.
199     // Bits[2:0] - (Per Lane) PS Shuffle Mask.
200     uint64_t Selector = RawMask[i];
201     unsigned MatchBit = (Selector >> 3) & 0x1;
202 
203     // M2Z[0:1]     MatchBit
204     //   0Xb           X        Source selected by Selector index.
205     //   10b           0        Source selected by Selector index.
206     //   10b           1        Zero.
207     //   11b           0        Zero.
208     //   11b           1        Source selected by Selector index.
209     if ((M2Z & 0x2) != 0u && MatchBit != (M2Z & 0x1)) {
210       ShuffleMask.push_back(SM_SentinelZero);
211       continue;
212     }
213 
214     int Index = i & ~(NumEltsPerLane - 1);
215     if (ElSize == 64)
216       Index += (Selector >> 1) & 0x1;
217     else
218       Index += Selector & 0x3;
219 
220     int Src = (Selector >> 2) & 0x1;
221     Index += Src * NumElts;
222     ShuffleMask.push_back(Index);
223   }
224 }
225 
226 void DecodeVPPERMMask(const Constant *C, SmallVectorImpl<int> &ShuffleMask) {
227   assert(C->getType()->getPrimitiveSizeInBits() == 128);
228 
229   // The shuffle mask requires a byte vector.
230   SmallBitVector UndefElts;
231   SmallVector<uint64_t, 32> RawMask;
232   if (!extractConstantMask(C, 8, UndefElts, RawMask))
233     return;
234 
235   unsigned NumElts = RawMask.size();
236   assert(NumElts == 16 && "Unexpected number of vector elements.");
237 
238   for (unsigned i = 0; i != NumElts; ++i) {
239     if (UndefElts[i]) {
240       ShuffleMask.push_back(SM_SentinelUndef);
241       continue;
242     }
243 
244     // VPPERM Operation
245     // Bits[4:0] - Byte Index (0 - 31)
246     // Bits[7:5] - Permute Operation
247     //
248     // Permute Operation:
249     // 0 - Source byte (no logical operation).
250     // 1 - Invert source byte.
251     // 2 - Bit reverse of source byte.
252     // 3 - Bit reverse of inverted source byte.
253     // 4 - 00h (zero - fill).
254     // 5 - FFh (ones - fill).
255     // 6 - Most significant bit of source byte replicated in all bit positions.
256     // 7 - Invert most significant bit of source byte and replicate in all bit
257     // positions.
258     uint64_t Element = RawMask[i];
259     uint64_t Index = Element & 0x1F;
260     uint64_t PermuteOp = (Element >> 5) & 0x7;
261 
262     if (PermuteOp == 4) {
263       ShuffleMask.push_back(SM_SentinelZero);
264       continue;
265     }
266     if (PermuteOp != 0) {
267       ShuffleMask.clear();
268       return;
269     }
270     ShuffleMask.push_back((int)Index);
271   }
272 }
273 
274 void DecodeVPERMVMask(const Constant *C, MVT VT,
275                       SmallVectorImpl<int> &ShuffleMask) {
276   Type *MaskTy = C->getType();
277   if (MaskTy->isVectorTy()) {
278     unsigned NumElements = MaskTy->getVectorNumElements();
279     if (NumElements == VT.getVectorNumElements()) {
280       unsigned EltMaskSize = Log2_64(NumElements);
281       for (unsigned i = 0; i < NumElements; ++i) {
282         Constant *COp = C->getAggregateElement(i);
283         if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp))) {
284           ShuffleMask.clear();
285           return;
286         }
287         if (isa<UndefValue>(COp))
288           ShuffleMask.push_back(SM_SentinelUndef);
289         else {
290           APInt Element = cast<ConstantInt>(COp)->getValue();
291           Element = Element.getLoBits(EltMaskSize);
292           ShuffleMask.push_back(Element.getZExtValue());
293         }
294       }
295     }
296     return;
297   }
298   // Scalar value; just broadcast it
299   if (!isa<ConstantInt>(C))
300     return;
301   uint64_t Element = cast<ConstantInt>(C)->getZExtValue();
302   int NumElements = VT.getVectorNumElements();
303   Element &= (1 << NumElements) - 1;
304   for (int i = 0; i < NumElements; ++i)
305     ShuffleMask.push_back(Element);
306 }
307 
308 void DecodeVPERMV3Mask(const Constant *C, MVT VT,
309                        SmallVectorImpl<int> &ShuffleMask) {
310   Type *MaskTy = C->getType();
311   unsigned NumElements = MaskTy->getVectorNumElements();
312   if (NumElements == VT.getVectorNumElements()) {
313     unsigned EltMaskSize = Log2_64(NumElements * 2);
314     for (unsigned i = 0; i < NumElements; ++i) {
315       Constant *COp = C->getAggregateElement(i);
316       if (!COp) {
317         ShuffleMask.clear();
318         return;
319       }
320       if (isa<UndefValue>(COp))
321         ShuffleMask.push_back(SM_SentinelUndef);
322       else {
323         APInt Element = cast<ConstantInt>(COp)->getValue();
324         Element = Element.getLoBits(EltMaskSize);
325         ShuffleMask.push_back(Element.getZExtValue());
326       }
327     }
328   }
329 }
330 } // llvm namespace
331