xref: /llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp (revision ee7ca0dddafb609090ad1789570c099d95c0afb6)
1 //===- CombinerHelperVectorOps.cpp-----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements CombinerHelper for G_EXTRACT_VECTOR_ELT,
10 // G_INSERT_VECTOR_ELT, and G_VSCALE
11 //
12 //===----------------------------------------------------------------------===//
13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
14 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
15 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
16 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
17 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/GlobalISel/Utils.h"
20 #include "llvm/CodeGen/LowLevelTypeUtils.h"
21 #include "llvm/CodeGen/MachineOperand.h"
22 #include "llvm/CodeGen/MachineRegisterInfo.h"
23 #include "llvm/CodeGen/TargetLowering.h"
24 #include "llvm/CodeGen/TargetOpcodes.h"
25 #include "llvm/Support/Casting.h"
26 #include <optional>
27 
28 #define DEBUG_TYPE "gi-combiner"
29 
30 using namespace llvm;
31 using namespace MIPatternMatch;
32 
33 bool CombinerHelper::matchExtractVectorElement(MachineInstr &MI,
34                                                BuildFnTy &MatchInfo) const {
35   GExtractVectorElement *Extract = cast<GExtractVectorElement>(&MI);
36 
37   Register Dst = Extract->getReg(0);
38   Register Vector = Extract->getVectorReg();
39   Register Index = Extract->getIndexReg();
40   LLT DstTy = MRI.getType(Dst);
41   LLT VectorTy = MRI.getType(Vector);
42 
43   // The vector register can be def'd by various ops that have vector as its
44   // type. They can all be used for constant folding, scalarizing,
45   // canonicalization, or combining based on symmetry.
46   //
47   // vector like ops
48   // * build vector
49   // * build vector trunc
50   // * shuffle vector
51   // * splat vector
52   // * concat vectors
53   // * insert/extract vector element
54   // * insert/extract subvector
55   // * vector loads
56   // * scalable vector loads
57   //
58   // compute like ops
59   // * binary ops
60   // * unary ops
61   //  * exts and truncs
62   //  * casts
63   //  * fneg
64   // * select
65   // * phis
66   // * cmps
67   // * freeze
68   // * bitcast
69   // * undef
70 
71   // We try to get the value of the Index register.
72   std::optional<ValueAndVReg> MaybeIndex =
73       getIConstantVRegValWithLookThrough(Index, MRI);
74   std::optional<APInt> IndexC = std::nullopt;
75 
76   if (MaybeIndex)
77     IndexC = MaybeIndex->Value;
78 
79   // Fold extractVectorElement(Vector, TOOLARGE) -> undef
80   if (IndexC && VectorTy.isFixedVector() &&
81       IndexC->uge(VectorTy.getNumElements()) &&
82       isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
83     // For fixed-length vectors, it's invalid to extract out-of-range elements.
84     MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); };
85     return true;
86   }
87 
88   return false;
89 }
90 
91 bool CombinerHelper::matchExtractVectorElementWithDifferentIndices(
92     const MachineOperand &MO, BuildFnTy &MatchInfo) const {
93   MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
94   GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root);
95 
96   //
97   //  %idx1:_(s64) = G_CONSTANT i64 1
98   //  %idx2:_(s64) = G_CONSTANT i64 2
99   //  %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>),
100   //  %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %insert(<2
101   //  x s32>), %idx1(s64)
102   //
103   //  -->
104   //
105   //  %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>),
106   //  %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x
107   //  s32>), %idx1(s64)
108   //
109   //
110 
111   Register Index = Extract->getIndexReg();
112 
113   // We try to get the value of the Index register.
114   std::optional<ValueAndVReg> MaybeIndex =
115       getIConstantVRegValWithLookThrough(Index, MRI);
116   std::optional<APInt> IndexC = std::nullopt;
117 
118   if (!MaybeIndex)
119     return false;
120   else
121     IndexC = MaybeIndex->Value;
122 
123   Register Vector = Extract->getVectorReg();
124 
125   GInsertVectorElement *Insert =
126       getOpcodeDef<GInsertVectorElement>(Vector, MRI);
127   if (!Insert)
128     return false;
129 
130   Register Dst = Extract->getReg(0);
131 
132   std::optional<ValueAndVReg> MaybeInsertIndex =
133       getIConstantVRegValWithLookThrough(Insert->getIndexReg(), MRI);
134 
135   if (MaybeInsertIndex && MaybeInsertIndex->Value != *IndexC) {
136     // There is no one-use check. We have to keep the insert. When both Index
137     // registers are constants and not equal, we can look into the Vector
138     // register of the insert.
139     MatchInfo = [=](MachineIRBuilder &B) {
140       B.buildExtractVectorElement(Dst, Insert->getVectorReg(), Index);
141     };
142     return true;
143   }
144 
145   return false;
146 }
147 
148 bool CombinerHelper::matchExtractVectorElementWithBuildVector(
149     const MachineInstr &MI, const MachineInstr &MI2,
150     BuildFnTy &MatchInfo) const {
151   const GExtractVectorElement *Extract = cast<GExtractVectorElement>(&MI);
152   const GBuildVector *Build = cast<GBuildVector>(&MI2);
153 
154   //
155   //  %zero:_(s64) = G_CONSTANT i64 0
156   //  %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32)
157   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64)
158   //
159   //  -->
160   //
161   //  %extract:_(32) = COPY %arg1(s32)
162   //
163   //
164 
165   Register Vector = Extract->getVectorReg();
166   LLT VectorTy = MRI.getType(Vector);
167 
168   // There is a one-use check. There are more combines on build vectors.
169   EVT Ty(getMVTForLLT(VectorTy));
170   if (!MRI.hasOneNonDBGUse(Build->getReg(0)) ||
171       !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty))
172     return false;
173 
174   APInt Index = getIConstantFromReg(Extract->getIndexReg(), MRI);
175 
176   // We now know that there is a buildVector def'd on the Vector register and
177   // the index is const. The combine will succeed.
178 
179   Register Dst = Extract->getReg(0);
180 
181   MatchInfo = [=](MachineIRBuilder &B) {
182     B.buildCopy(Dst, Build->getSourceReg(Index.getZExtValue()));
183   };
184 
185   return true;
186 }
187 
188 bool CombinerHelper::matchExtractVectorElementWithBuildVectorTrunc(
189     const MachineOperand &MO, BuildFnTy &MatchInfo) const {
190   MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
191   GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root);
192 
193   //
194   //  %zero:_(s64) = G_CONSTANT i64 0
195   //  %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
196   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64)
197   //
198   //  -->
199   //
200   //  %extract:_(32) = G_TRUNC %arg1(s64)
201   //
202   //
203   //
204   //  %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
205   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
206   //
207   //  -->
208   //
209   //  %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
210   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
211   //
212 
213   Register Vector = Extract->getVectorReg();
214 
215   // We expect a buildVectorTrunc on the Vector register.
216   GBuildVectorTrunc *Build = getOpcodeDef<GBuildVectorTrunc>(Vector, MRI);
217   if (!Build)
218     return false;
219 
220   LLT VectorTy = MRI.getType(Vector);
221 
222   // There is a one-use check. There are more combines on build vectors.
223   EVT Ty(getMVTForLLT(VectorTy));
224   if (!MRI.hasOneNonDBGUse(Build->getReg(0)) ||
225       !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty))
226     return false;
227 
228   Register Index = Extract->getIndexReg();
229 
230   // If the Index is constant, then we can extract the element from the given
231   // offset.
232   std::optional<ValueAndVReg> MaybeIndex =
233       getIConstantVRegValWithLookThrough(Index, MRI);
234   if (!MaybeIndex)
235     return false;
236 
237   // We now know that there is a buildVectorTrunc def'd on the Vector register
238   // and the index is const. The combine will succeed.
239 
240   Register Dst = Extract->getReg(0);
241   LLT DstTy = MRI.getType(Dst);
242   LLT SrcTy = MRI.getType(Build->getSourceReg(0));
243 
244   // For buildVectorTrunc, the inputs are truncated.
245   if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}}))
246     return false;
247 
248   MatchInfo = [=](MachineIRBuilder &B) {
249     B.buildTrunc(Dst, Build->getSourceReg(MaybeIndex->Value.getZExtValue()));
250   };
251 
252   return true;
253 }
254 
255 bool CombinerHelper::matchExtractVectorElementWithShuffleVector(
256     const MachineInstr &MI, const MachineInstr &MI2,
257     BuildFnTy &MatchInfo) const {
258   const GExtractVectorElement *Extract = cast<GExtractVectorElement>(&MI);
259   const GShuffleVector *Shuffle = cast<GShuffleVector>(&MI2);
260 
261   //
262   //  %zero:_(s64) = G_CONSTANT i64 0
263   //  %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
264   //                     shufflemask(0, 0, 0, 0)
265   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %zero(s64)
266   //
267   //  -->
268   //
269   //  %zero1:_(s64) = G_CONSTANT i64 0
270   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %arg1(<4 x s32>), %zero1(s64)
271   //
272   //
273   //
274   //
275   //  %three:_(s64) = G_CONSTANT i64 3
276   //  %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
277   //                     shufflemask(0, 0, 0, -1)
278   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %three(s64)
279   //
280   //  -->
281   //
282   //  %extract:_(s32) = G_IMPLICIT_DEF
283   //
284   //
285 
286   APInt Index = getIConstantFromReg(Extract->getIndexReg(), MRI);
287 
288   ArrayRef<int> Mask = Shuffle->getMask();
289 
290   unsigned Offset = Index.getZExtValue();
291   int SrcIdx = Mask[Offset];
292 
293   LLT Src1Type = MRI.getType(Shuffle->getSrc1Reg());
294   // At the IR level a <1 x ty> shuffle  vector is valid, but we want to extract
295   // from a vector.
296   assert(Src1Type.isVector() && "expected to extract from a vector");
297   unsigned LHSWidth = Src1Type.isVector() ? Src1Type.getNumElements() : 1;
298 
299   // Note that there is no one use check.
300   Register Dst = Extract->getReg(0);
301   LLT DstTy = MRI.getType(Dst);
302 
303   if (SrcIdx < 0 &&
304       isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
305     MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); };
306     return true;
307   }
308 
309   // If the legality check failed, then we still have to abort.
310   if (SrcIdx < 0)
311     return false;
312 
313   Register NewVector;
314 
315   // We check in which vector and at what offset to look through.
316   if (SrcIdx < (int)LHSWidth) {
317     NewVector = Shuffle->getSrc1Reg();
318     // SrcIdx unchanged
319   } else { // SrcIdx >= LHSWidth
320     NewVector = Shuffle->getSrc2Reg();
321     SrcIdx -= LHSWidth;
322   }
323 
324   LLT IdxTy = MRI.getType(Extract->getIndexReg());
325   LLT NewVectorTy = MRI.getType(NewVector);
326 
327   // We check the legality of the look through.
328   if (!isLegalOrBeforeLegalizer(
329           {TargetOpcode::G_EXTRACT_VECTOR_ELT, {DstTy, NewVectorTy, IdxTy}}) ||
330       !isConstantLegalOrBeforeLegalizer({IdxTy}))
331     return false;
332 
333   // We look through the shuffle vector.
334   MatchInfo = [=](MachineIRBuilder &B) {
335     auto Idx = B.buildConstant(IdxTy, SrcIdx);
336     B.buildExtractVectorElement(Dst, NewVector, Idx);
337   };
338 
339   return true;
340 }
341 
342 bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,
343                                                  BuildFnTy &MatchInfo) const {
344   GInsertVectorElement *Insert = cast<GInsertVectorElement>(&MI);
345 
346   Register Dst = Insert->getReg(0);
347   LLT DstTy = MRI.getType(Dst);
348   Register Index = Insert->getIndexReg();
349 
350   if (!DstTy.isFixedVector())
351     return false;
352 
353   std::optional<ValueAndVReg> MaybeIndex =
354       getIConstantVRegValWithLookThrough(Index, MRI);
355 
356   if (MaybeIndex && MaybeIndex->Value.uge(DstTy.getNumElements()) &&
357       isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
358     MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); };
359     return true;
360   }
361 
362   return false;
363 }
364 
365 bool CombinerHelper::matchAddOfVScale(const MachineOperand &MO,
366                                       BuildFnTy &MatchInfo) const {
367   GAdd *Add = cast<GAdd>(MRI.getVRegDef(MO.getReg()));
368   GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Add->getLHSReg()));
369   GVScale *RHSVScale = cast<GVScale>(MRI.getVRegDef(Add->getRHSReg()));
370 
371   Register Dst = Add->getReg(0);
372 
373   if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)) ||
374       !MRI.hasOneNonDBGUse(RHSVScale->getReg(0)))
375     return false;
376 
377   MatchInfo = [=](MachineIRBuilder &B) {
378     B.buildVScale(Dst, LHSVScale->getSrc() + RHSVScale->getSrc());
379   };
380 
381   return true;
382 }
383 
384 bool CombinerHelper::matchMulOfVScale(const MachineOperand &MO,
385                                       BuildFnTy &MatchInfo) const {
386   GMul *Mul = cast<GMul>(MRI.getVRegDef(MO.getReg()));
387   GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Mul->getLHSReg()));
388 
389   std::optional<APInt> MaybeRHS = getIConstantVRegVal(Mul->getRHSReg(), MRI);
390   if (!MaybeRHS)
391     return false;
392 
393   Register Dst = MO.getReg();
394 
395   if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)))
396     return false;
397 
398   MatchInfo = [=](MachineIRBuilder &B) {
399     B.buildVScale(Dst, LHSVScale->getSrc() * *MaybeRHS);
400   };
401 
402   return true;
403 }
404 
405 bool CombinerHelper::matchSubOfVScale(const MachineOperand &MO,
406                                       BuildFnTy &MatchInfo) const {
407   GSub *Sub = cast<GSub>(MRI.getVRegDef(MO.getReg()));
408   GVScale *RHSVScale = cast<GVScale>(MRI.getVRegDef(Sub->getRHSReg()));
409 
410   Register Dst = MO.getReg();
411   LLT DstTy = MRI.getType(Dst);
412 
413   if (!MRI.hasOneNonDBGUse(RHSVScale->getReg(0)) ||
414       !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, DstTy}))
415     return false;
416 
417   MatchInfo = [=](MachineIRBuilder &B) {
418     auto VScale = B.buildVScale(DstTy, -RHSVScale->getSrc());
419     B.buildAdd(Dst, Sub->getLHSReg(), VScale, Sub->getFlags());
420   };
421 
422   return true;
423 }
424 
425 bool CombinerHelper::matchShlOfVScale(const MachineOperand &MO,
426                                       BuildFnTy &MatchInfo) const {
427   GShl *Shl = cast<GShl>(MRI.getVRegDef(MO.getReg()));
428   GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Shl->getSrcReg()));
429 
430   std::optional<APInt> MaybeRHS = getIConstantVRegVal(Shl->getShiftReg(), MRI);
431   if (!MaybeRHS)
432     return false;
433 
434   Register Dst = MO.getReg();
435   LLT DstTy = MRI.getType(Dst);
436 
437   if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)) ||
438       !isLegalOrBeforeLegalizer({TargetOpcode::G_VSCALE, DstTy}))
439     return false;
440 
441   MatchInfo = [=](MachineIRBuilder &B) {
442     B.buildVScale(Dst, LHSVScale->getSrc().shl(*MaybeRHS));
443   };
444 
445   return true;
446 }
447