1 //=== AArch64PostLegalizerLowering.cpp --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Post-legalization lowering for instructions.
11 ///
12 /// This is used to offload pattern matching from the selector.
13 ///
14 /// For example, this combiner will notice that a G_SHUFFLE_VECTOR is actually
15 /// a G_ZIP, G_UZP, etc.
16 ///
17 /// General optimization combines should be handled by either the
18 /// AArch64PostLegalizerCombiner or the AArch64PreLegalizerCombiner.
19 ///
20 //===----------------------------------------------------------------------===//
21 
22 #include "AArch64ExpandImm.h"
23 #include "AArch64GlobalISelUtils.h"
24 #include "AArch64PerfectShuffle.h"
25 #include "AArch64Subtarget.h"
26 #include "GISel/AArch64LegalizerInfo.h"
27 #include "MCTargetDesc/AArch64MCTargetDesc.h"
28 #include "Utils/AArch64BaseInfo.h"
29 #include "llvm/CodeGen/GlobalISel/Combiner.h"
30 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
31 #include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
32 #include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
33 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
34 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
35 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
36 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
37 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
38 #include "llvm/CodeGen/GlobalISel/Utils.h"
39 #include "llvm/CodeGen/MachineFrameInfo.h"
40 #include "llvm/CodeGen/MachineFunctionPass.h"
41 #include "llvm/CodeGen/MachineInstrBuilder.h"
42 #include "llvm/CodeGen/MachineRegisterInfo.h"
43 #include "llvm/CodeGen/TargetOpcodes.h"
44 #include "llvm/CodeGen/TargetPassConfig.h"
45 #include "llvm/IR/InstrTypes.h"
46 #include "llvm/InitializePasses.h"
47 #include "llvm/Support/ErrorHandling.h"
48 #include <optional>
49 
50 #define GET_GICOMBINER_DEPS
51 #include "AArch64GenPostLegalizeGILowering.inc"
52 #undef GET_GICOMBINER_DEPS
53 
54 #define DEBUG_TYPE "aarch64-postlegalizer-lowering"
55 
56 using namespace llvm;
57 using namespace MIPatternMatch;
58 using namespace AArch64GISelUtils;
59 
60 namespace {
61 
62 #define GET_GICOMBINER_TYPES
63 #include "AArch64GenPostLegalizeGILowering.inc"
64 #undef GET_GICOMBINER_TYPES
65 
66 /// Represents a pseudo instruction which replaces a G_SHUFFLE_VECTOR.
67 ///
68 /// Used for matching target-supported shuffles before codegen.
69 struct ShuffleVectorPseudo {
70   unsigned Opc;                 ///< Opcode for the instruction. (E.g. G_ZIP1)
71   Register Dst;                 ///< Destination register.
72   SmallVector<SrcOp, 2> SrcOps; ///< Source registers.
73   ShuffleVectorPseudo(unsigned Opc, Register Dst,
74                       std::initializer_list<SrcOp> SrcOps)
75       : Opc(Opc), Dst(Dst), SrcOps(SrcOps) {}
76   ShuffleVectorPseudo() = default;
77 };
78 
79 /// Check if a G_EXT instruction can handle a shuffle mask \p M when the vector
80 /// sources of the shuffle are different.
81 std::optional<std::pair<bool, uint64_t>> getExtMask(ArrayRef<int> M,
82                                                     unsigned NumElts) {
83   // Look for the first non-undef element.
84   auto FirstRealElt = find_if(M, [](int Elt) { return Elt >= 0; });
85   if (FirstRealElt == M.end())
86     return std::nullopt;
87 
88   // Use APInt to handle overflow when calculating expected element.
89   unsigned MaskBits = APInt(32, NumElts * 2).logBase2();
90   APInt ExpectedElt = APInt(MaskBits, *FirstRealElt + 1, false, true);
91 
92   // The following shuffle indices must be the successive elements after the
93   // first real element.
94   if (any_of(
95           make_range(std::next(FirstRealElt), M.end()),
96           [&ExpectedElt](int Elt) { return Elt != ExpectedElt++ && Elt >= 0; }))
97     return std::nullopt;
98 
99   // The index of an EXT is the first element if it is not UNDEF.
100   // Watch out for the beginning UNDEFs. The EXT index should be the expected
101   // value of the first element.  E.g.
102   // <-1, -1, 3, ...> is treated as <1, 2, 3, ...>.
103   // <-1, -1, 0, 1, ...> is treated as <2*NumElts-2, 2*NumElts-1, 0, 1, ...>.
104   // ExpectedElt is the last mask index plus 1.
105   uint64_t Imm = ExpectedElt.getZExtValue();
106   bool ReverseExt = false;
107 
108   // There are two different cases that require reversing the input vectors.
109   // For example, for vector <4 x i32> we have the following cases,
110   // Case 1: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, -1, 0>)
111   // Case 2: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, 7, 0>)
112   // For both cases, we finally use mask <5, 6, 7, 0>, which requires
113   // to reverse two input vectors.
114   if (Imm < NumElts)
115     ReverseExt = true;
116   else
117     Imm -= NumElts;
118   return std::make_pair(ReverseExt, Imm);
119 }
120 
121 /// Helper function for matchINS.
122 ///
123 /// \returns a value when \p M is an ins mask for \p NumInputElements.
124 ///
125 /// First element of the returned pair is true when the produced
126 /// G_INSERT_VECTOR_ELT destination should be the LHS of the G_SHUFFLE_VECTOR.
127 ///
128 /// Second element is the destination lane for the G_INSERT_VECTOR_ELT.
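/// For example, with 4 input elements, the mask <0, 1, 6, 3> matches the LHS
/// in every lane except lane 2, so this returns (true, 2): lane 2 of the LHS
/// is overwritten with an element taken from the RHS.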
129 std::optional<std::pair<bool, int>> isINSMask(ArrayRef<int> M,
130                                               int NumInputElements) {
131   if (M.size() != static_cast<size_t>(NumInputElements))
132     return std::nullopt;
133   int NumLHSMatch = 0, NumRHSMatch = 0;
134   int LastLHSMismatch = -1, LastRHSMismatch = -1;
135   for (int Idx = 0; Idx < NumInputElements; ++Idx) {
136     if (M[Idx] == -1) {
137       ++NumLHSMatch;
138       ++NumRHSMatch;
139       continue;
140     }
141     M[Idx] == Idx ? ++NumLHSMatch : LastLHSMismatch = Idx;
142     M[Idx] == Idx + NumInputElements ? ++NumRHSMatch : LastRHSMismatch = Idx;
143   }
144   const int NumNeededToMatch = NumInputElements - 1;
145   if (NumLHSMatch == NumNeededToMatch)
146     return std::make_pair(true, LastLHSMismatch);
147   if (NumRHSMatch == NumNeededToMatch)
148     return std::make_pair(false, LastRHSMismatch);
149   return std::nullopt;
150 }
151 
152 /// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with a
153 /// G_REV instruction. Returns the appropriate G_REV opcode in \p MatchInfo.
154 bool matchREV(MachineInstr &MI, MachineRegisterInfo &MRI,
155               ShuffleVectorPseudo &MatchInfo) {
156   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
157   ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
158   Register Dst = MI.getOperand(0).getReg();
159   Register Src = MI.getOperand(1).getReg();
160   LLT Ty = MRI.getType(Dst);
161   unsigned EltSize = Ty.getScalarSizeInBits();
162 
163   // Element size for a rev cannot be 64.
164   if (EltSize == 64)
165     return false;
166 
167   unsigned NumElts = Ty.getNumElements();
168 
169   // Try to produce a G_REV instruction
170   for (unsigned LaneSize : {64U, 32U, 16U}) {
171     if (isREVMask(ShuffleMask, EltSize, NumElts, LaneSize)) {
172       unsigned Opcode;
173       if (LaneSize == 64U)
174         Opcode = AArch64::G_REV64;
175       else if (LaneSize == 32U)
176         Opcode = AArch64::G_REV32;
177       else
178         Opcode = AArch64::G_REV16;
179 
180       MatchInfo = ShuffleVectorPseudo(Opcode, Dst, {Src});
181       return true;
182     }
183   }
184 
185   return false;
186 }
187 
188 /// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
189 /// a G_TRN1 or G_TRN2 instruction.
190 bool matchTRN(MachineInstr &MI, MachineRegisterInfo &MRI,
191               ShuffleVectorPseudo &MatchInfo) {
192   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
193   unsigned WhichResult;
194   ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
195   Register Dst = MI.getOperand(0).getReg();
196   unsigned NumElts = MRI.getType(Dst).getNumElements();
197   if (!isTRNMask(ShuffleMask, NumElts, WhichResult))
198     return false;
199   unsigned Opc = (WhichResult == 0) ? AArch64::G_TRN1 : AArch64::G_TRN2;
200   Register V1 = MI.getOperand(1).getReg();
201   Register V2 = MI.getOperand(2).getReg();
202   MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
203   return true;
204 }
205 
206 /// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
207 /// a G_UZP1 or G_UZP2 instruction.
208 ///
209 /// \param [in] MI - The shuffle vector instruction.
210 /// \param [out] MatchInfo - Either G_UZP1 or G_UZP2 on success.
211 bool matchUZP(MachineInstr &MI, MachineRegisterInfo &MRI,
212               ShuffleVectorPseudo &MatchInfo) {
213   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
214   unsigned WhichResult;
215   ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
216   Register Dst = MI.getOperand(0).getReg();
217   unsigned NumElts = MRI.getType(Dst).getNumElements();
218   if (!isUZPMask(ShuffleMask, NumElts, WhichResult))
219     return false;
220   unsigned Opc = (WhichResult == 0) ? AArch64::G_UZP1 : AArch64::G_UZP2;
221   Register V1 = MI.getOperand(1).getReg();
222   Register V2 = MI.getOperand(2).getReg();
223   MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
224   return true;
225 }
226 
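/// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
/// a G_ZIP1 or G_ZIP2 instruction.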
227 bool matchZip(MachineInstr &MI, MachineRegisterInfo &MRI,
228               ShuffleVectorPseudo &MatchInfo) {
229   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
230   unsigned WhichResult;
231   ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
232   Register Dst = MI.getOperand(0).getReg();
233   unsigned NumElts = MRI.getType(Dst).getNumElements();
234   if (!isZIPMask(ShuffleMask, NumElts, WhichResult))
235     return false;
236   unsigned Opc = (WhichResult == 0) ? AArch64::G_ZIP1 : AArch64::G_ZIP2;
237   Register V1 = MI.getOperand(1).getReg();
238   Register V2 = MI.getOperand(2).getReg();
239   MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
240   return true;
241 }
242 
243 /// Helper function for matchDup.
244 bool matchDupFromInsertVectorElt(int Lane, MachineInstr &MI,
245                                  MachineRegisterInfo &MRI,
246                                  ShuffleVectorPseudo &MatchInfo) {
247   if (Lane != 0)
248     return false;
249 
250   // Try to match a vector splat operation into a dup instruction.
251   // We're looking for this pattern:
252   //
253   // %scalar:gpr(s64) = COPY $x0
254   // %undef:fpr(<2 x s64>) = G_IMPLICIT_DEF
255   // %cst0:gpr(s32) = G_CONSTANT i32 0
256   // %zerovec:fpr(<2 x s32>) = G_BUILD_VECTOR %cst0(s32), %cst0(s32)
257   // %ins:fpr(<2 x s64>) = G_INSERT_VECTOR_ELT %undef, %scalar(s64), %cst0(s32)
258   // %splat:fpr(<2 x s64>) = G_SHUFFLE_VECTOR %ins(<2 x s64>), %undef,
259   // %zerovec(<2 x s32>)
260   //
261   // ...into:
262   // %splat = G_DUP %scalar
263 
264   // Begin matching the insert.
265   auto *InsMI = getOpcodeDef(TargetOpcode::G_INSERT_VECTOR_ELT,
266                              MI.getOperand(1).getReg(), MRI);
267   if (!InsMI)
268     return false;
269   // Match the undef vector operand.
270   if (!getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, InsMI->getOperand(1).getReg(),
271                     MRI))
272     return false;
273 
274   // Match the index constant 0.
275   if (!mi_match(InsMI->getOperand(3).getReg(), MRI, m_ZeroInt()))
276     return false;
277 
278   MatchInfo = ShuffleVectorPseudo(AArch64::G_DUP, MI.getOperand(0).getReg(),
279                                   {InsMI->getOperand(2).getReg()});
280   return true;
281 }
282 
283 /// Helper function for matchDup.
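/// For example, a shuffle which splats lane 2 of
///   %v:fpr(<4 x s32>) = G_BUILD_VECTOR %a, %b, %c, %d
/// can be lowered directly to G_DUP %c.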
284 bool matchDupFromBuildVector(int Lane, MachineInstr &MI,
285                              MachineRegisterInfo &MRI,
286                              ShuffleVectorPseudo &MatchInfo) {
287   assert(Lane >= 0 && "Expected positive lane?");
288   int NumElements = MRI.getType(MI.getOperand(1).getReg()).getNumElements();
289   // Test if the LHS is a BUILD_VECTOR. If it is, then we can just reference the
290   // lane's definition directly.
291   auto *BuildVecMI =
292       getOpcodeDef(TargetOpcode::G_BUILD_VECTOR,
293                    MI.getOperand(Lane < NumElements ? 1 : 2).getReg(), MRI);
294   // If Lane >= NumElements it points into the RHS; rebase it onto the RHS.
295   if (NumElements <= Lane)
296     Lane -= NumElements;
297 
298   if (!BuildVecMI)
299     return false;
300   Register Reg = BuildVecMI->getOperand(Lane + 1).getReg();
301   MatchInfo =
302       ShuffleVectorPseudo(AArch64::G_DUP, MI.getOperand(0).getReg(), {Reg});
303   return true;
304 }
305 
306 bool matchDup(MachineInstr &MI, MachineRegisterInfo &MRI,
307               ShuffleVectorPseudo &MatchInfo) {
308   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
309   auto MaybeLane = getSplatIndex(MI);
310   if (!MaybeLane)
311     return false;
312   int Lane = *MaybeLane;
313   // If the splat lane is undef, treat it as lane 0 and try to emit a dup.
314   if (Lane < 0)
315     Lane = 0;
316   if (matchDupFromInsertVectorElt(Lane, MI, MRI, MatchInfo))
317     return true;
318   if (matchDupFromBuildVector(Lane, MI, MRI, MatchInfo))
319     return true;
320   return false;
321 }
322 
323 // Check if an EXT instruction can handle the shuffle mask when the vector
324 // sources of the shuffle are the same.
325 bool isSingletonExtMask(ArrayRef<int> M, LLT Ty) {
326   unsigned NumElts = Ty.getNumElements();
327 
328   // Assume that the first shuffle index is not UNDEF.  Fail if it is.
329   if (M[0] < 0)
330     return false;
331 
332   // If this is a VEXT shuffle, the immediate value is the index of the first
333   // element.  The other shuffle indices must be the successive elements after
334   // the first one.
335   unsigned ExpectedElt = M[0];
336   for (unsigned I = 1; I < NumElts; ++I) {
337     // Increment the expected index.  If it wraps around, just follow it
338     // back to index zero and keep going.
339     ++ExpectedElt;
340     if (ExpectedElt == NumElts)
341       ExpectedElt = 0;
342 
343     if (M[I] < 0)
344       continue; // Ignore UNDEF indices.
345     if (ExpectedElt != static_cast<unsigned>(M[I]))
346       return false;
347   }
348 
349   return true;
350 }
351 
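/// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
/// a G_EXT instruction, i.e. a byte-wise extraction at a constant offset from
/// a pair of (possibly identical) source vectors.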
352 bool matchEXT(MachineInstr &MI, MachineRegisterInfo &MRI,
353               ShuffleVectorPseudo &MatchInfo) {
354   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
355   Register Dst = MI.getOperand(0).getReg();
356   LLT DstTy = MRI.getType(Dst);
357   Register V1 = MI.getOperand(1).getReg();
358   Register V2 = MI.getOperand(2).getReg();
359   auto Mask = MI.getOperand(3).getShuffleMask();
360   uint64_t Imm;
361   auto ExtInfo = getExtMask(Mask, DstTy.getNumElements());
362   uint64_t ExtFactor = MRI.getType(V1).getScalarSizeInBits() / 8;
363 
364   if (!ExtInfo) {
365     if (!getOpcodeDef<GImplicitDef>(V2, MRI) ||
366         !isSingletonExtMask(Mask, DstTy))
367       return false;
368 
369     Imm = Mask[0] * ExtFactor;
370     MatchInfo = ShuffleVectorPseudo(AArch64::G_EXT, Dst, {V1, V1, Imm});
371     return true;
372   }
373   bool ReverseExt;
374   std::tie(ReverseExt, Imm) = *ExtInfo;
375   if (ReverseExt)
376     std::swap(V1, V2);
377   Imm *= ExtFactor;
378   MatchInfo = ShuffleVectorPseudo(AArch64::G_EXT, Dst, {V1, V2, Imm});
379   return true;
380 }
381 
382 /// Replace a G_SHUFFLE_VECTOR instruction with a pseudo.
383 /// \p Opc is the opcode to use. \p MI is the G_SHUFFLE_VECTOR.
384 void applyShuffleVectorPseudo(MachineInstr &MI,
385                               ShuffleVectorPseudo &MatchInfo) {
386   MachineIRBuilder MIRBuilder(MI);
387   MIRBuilder.buildInstr(MatchInfo.Opc, {MatchInfo.Dst}, MatchInfo.SrcOps);
388   MI.eraseFromParent();
389 }
390 
391 /// Replace a G_SHUFFLE_VECTOR instruction with G_EXT.
392 /// Special-cased because the constant operand must be emitted as a G_CONSTANT
393 /// for the imported tablegen patterns to work.
394 void applyEXT(MachineInstr &MI, ShuffleVectorPseudo &MatchInfo) {
395   MachineIRBuilder MIRBuilder(MI);
396   if (MatchInfo.SrcOps[2].getImm() == 0)
397     MIRBuilder.buildCopy(MatchInfo.Dst, MatchInfo.SrcOps[0]);
398   else {
399     // Tablegen patterns expect an i32 G_CONSTANT as the final op.
400     auto Cst =
401         MIRBuilder.buildConstant(LLT::scalar(32), MatchInfo.SrcOps[2].getImm());
402     MIRBuilder.buildInstr(MatchInfo.Opc, {MatchInfo.Dst},
403                           {MatchInfo.SrcOps[0], MatchInfo.SrcOps[1], Cst});
404   }
405   MI.eraseFromParent();
406 }
407 
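/// Reverse a full 128-bit vector: reverse the elements within each 64-bit half
/// with G_REV64, then swap the two halves with a G_EXT at byte offset 8.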
408 void applyFullRev(MachineInstr &MI, MachineRegisterInfo &MRI) {
409   Register Dst = MI.getOperand(0).getReg();
410   Register Src = MI.getOperand(1).getReg();
411   LLT DstTy = MRI.getType(Dst);
412   assert(DstTy.getSizeInBits() == 128 &&
413          "Expected 128bit vector in applyFullRev");
414   MachineIRBuilder MIRBuilder(MI);
415   auto Cst = MIRBuilder.buildConstant(LLT::scalar(32), 8);
416   auto Rev = MIRBuilder.buildInstr(AArch64::G_REV64, {DstTy}, {Src});
417   MIRBuilder.buildInstr(AArch64::G_EXT, {Dst}, {Rev, Rev, Cst});
418   MI.eraseFromParent();
419 }
420 
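/// Match a G_INSERT_VECTOR_ELT whose index is not a compile-time constant.
/// Such inserts cannot use the lane-indexed INS form, so they are lowered
/// below via a stack temporary: spill the vector, store the new element at
/// the (masked) index, and reload the whole vector.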
421 bool matchNonConstInsert(MachineInstr &MI, MachineRegisterInfo &MRI) {
422   assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT);
423 
424   auto ValAndVReg =
425       getIConstantVRegValWithLookThrough(MI.getOperand(3).getReg(), MRI);
426   return !ValAndVReg;
427 }
428 
429 void applyNonConstInsert(MachineInstr &MI, MachineRegisterInfo &MRI,
430                          MachineIRBuilder &Builder) {
431   auto &Insert = cast<GInsertVectorElement>(MI);
432   Builder.setInstrAndDebugLoc(Insert);
433 
434   Register Offset = Insert.getIndexReg();
435   LLT VecTy = MRI.getType(Insert.getReg(0));
436   LLT EltTy = MRI.getType(Insert.getElementReg());
437   LLT IdxTy = MRI.getType(Insert.getIndexReg());
438 
439   if (VecTy.isScalableVector())
440     return;
441 
442   // Create a stack slot and store the vector into it
443   MachineFunction &MF = Builder.getMF();
444   Align Alignment(
445       std::min<uint64_t>(VecTy.getSizeInBytes().getKnownMinValue(), 16));
446   int FrameIdx = MF.getFrameInfo().CreateStackObject(VecTy.getSizeInBytes(),
447                                                      Alignment, false);
448   LLT FramePtrTy = LLT::pointer(0, 64);
449   MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIdx);
450   auto StackTemp = Builder.buildFrameIndex(FramePtrTy, FrameIdx);
451 
452   Builder.buildStore(Insert.getOperand(1), StackTemp, PtrInfo, Align(8));
453 
454   // Get the pointer to the element, and be sure not to hit undefined behavior
455   // if the index is out of bounds.
456   assert(isPowerOf2_64(VecTy.getNumElements()) &&
457          "Expected a power-of-2 vector size");
458   auto Mask = Builder.buildConstant(IdxTy, VecTy.getNumElements() - 1);
459   Register And = Builder.buildAnd(IdxTy, Offset, Mask).getReg(0);
460   auto EltSize = Builder.buildConstant(IdxTy, EltTy.getSizeInBytes());
461   Register Mul = Builder.buildMul(IdxTy, And, EltSize).getReg(0);
462   Register EltPtr =
463       Builder.buildPtrAdd(MRI.getType(StackTemp.getReg(0)), StackTemp, Mul)
464           .getReg(0);
465 
466   // Write the inserted element
467   Builder.buildStore(Insert.getElementReg(), EltPtr, PtrInfo, Align(1));
468   // Reload the whole vector.
469   Builder.buildLoad(Insert.getReg(0), StackTemp, PtrInfo, Align(8));
470   Insert.eraseFromParent();
471 }
472 
473 /// Match a G_SHUFFLE_VECTOR with a mask which corresponds to a
474 /// G_INSERT_VECTOR_ELT and G_EXTRACT_VECTOR_ELT pair.
475 ///
476 /// e.g.
477 ///   %shuf = G_SHUFFLE_VECTOR %left, %right, shufflemask(0, 0)
478 ///
479 /// Can be represented as
480 ///
481 ///   %extract = G_EXTRACT_VECTOR_ELT %left, 0
482 ///   %ins = G_INSERT_VECTOR_ELT %left, %extract, 1
483 ///
484 bool matchINS(MachineInstr &MI, MachineRegisterInfo &MRI,
485               std::tuple<Register, int, Register, int> &MatchInfo) {
486   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
487   ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
488   Register Dst = MI.getOperand(0).getReg();
489   int NumElts = MRI.getType(Dst).getNumElements();
490   auto DstIsLeftAndDstLane = isINSMask(ShuffleMask, NumElts);
491   if (!DstIsLeftAndDstLane)
492     return false;
493   bool DstIsLeft;
494   int DstLane;
495   std::tie(DstIsLeft, DstLane) = *DstIsLeftAndDstLane;
496   Register Left = MI.getOperand(1).getReg();
497   Register Right = MI.getOperand(2).getReg();
498   Register DstVec = DstIsLeft ? Left : Right;
499   Register SrcVec = Left;
500 
501   int SrcLane = ShuffleMask[DstLane];
502   if (SrcLane >= NumElts) {
503     SrcVec = Right;
504     SrcLane -= NumElts;
505   }
506 
507   MatchInfo = std::make_tuple(DstVec, DstLane, SrcVec, SrcLane);
508   return true;
509 }
510 
511 void applyINS(MachineInstr &MI, MachineRegisterInfo &MRI,
512               MachineIRBuilder &Builder,
513               std::tuple<Register, int, Register, int> &MatchInfo) {
514   Builder.setInstrAndDebugLoc(MI);
515   Register Dst = MI.getOperand(0).getReg();
516   auto ScalarTy = MRI.getType(Dst).getElementType();
517   Register DstVec, SrcVec;
518   int DstLane, SrcLane;
519   std::tie(DstVec, DstLane, SrcVec, SrcLane) = MatchInfo;
520   auto SrcCst = Builder.buildConstant(LLT::scalar(64), SrcLane);
521   auto Extract = Builder.buildExtractVectorElement(ScalarTy, SrcVec, SrcCst);
522   auto DstCst = Builder.buildConstant(LLT::scalar(64), DstLane);
523   Builder.buildInsertVectorElement(Dst, DstVec, Extract, DstCst);
524   MI.eraseFromParent();
525 }
526 
527 /// isVShiftRImm - Check if \p Reg is a constant splat vector that is valid
528 /// for the immediate operand of a vector shift right operation. The value
529 /// must be in the range: 1 <= Value <= ElementBits for a right shift.
530 bool isVShiftRImm(Register Reg, MachineRegisterInfo &MRI, LLT Ty,
531                   int64_t &Cnt) {
532   assert(Ty.isVector() && "vector shift count is not a vector type");
533   MachineInstr *MI = MRI.getVRegDef(Reg);
534   auto Cst = getAArch64VectorSplatScalar(*MI, MRI);
535   if (!Cst)
536     return false;
537   Cnt = *Cst;
538   int64_t ElementBits = Ty.getScalarSizeInBits();
539   return Cnt >= 1 && Cnt <= ElementBits;
540 }
541 
542 /// Match a vector G_ASHR or G_LSHR with a valid immediate shift.
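/// e.g. a G_ASHR of a <4 x s32> vector by a splat of 5 can use the immediate
/// form of the shift (G_VASHR ..., 5) instead of a register shift amount.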
543 bool matchVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
544                        int64_t &Imm) {
545   assert(MI.getOpcode() == TargetOpcode::G_ASHR ||
546          MI.getOpcode() == TargetOpcode::G_LSHR);
547   LLT Ty = MRI.getType(MI.getOperand(1).getReg());
548   if (!Ty.isVector())
549     return false;
550   return isVShiftRImm(MI.getOperand(2).getReg(), MRI, Ty, Imm);
551 }
552 
553 void applyVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
554                        int64_t &Imm) {
555   unsigned Opc = MI.getOpcode();
556   assert(Opc == TargetOpcode::G_ASHR || Opc == TargetOpcode::G_LSHR);
557   unsigned NewOpc =
558       Opc == TargetOpcode::G_ASHR ? AArch64::G_VASHR : AArch64::G_VLSHR;
559   MachineIRBuilder MIB(MI);
560   auto ImmDef = MIB.buildConstant(LLT::scalar(32), Imm);
561   MIB.buildInstr(NewOpc, {MI.getOperand(0)}, {MI.getOperand(1), ImmDef});
562   MI.eraseFromParent();
563 }
564 
565 /// Determine if it is possible to modify the \p RHS and predicate \p P of a
566 /// G_ICMP instruction such that the right-hand side is an arithmetic immediate.
567 ///
568 /// \returns A pair containing the updated immediate and predicate which may
569 /// be used to optimize the instruction.
570 ///
571 /// \note This assumes that the comparison has been legalized.
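///
/// For example, 0x1001 is not a legal arithmetic immediate, but "x ult 0x1001"
/// can be rewritten as "x ule 0x1000", and 0x1000 is encodable as a 12-bit
/// immediate shifted left by 12.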
572 std::optional<std::pair<uint64_t, CmpInst::Predicate>>
573 tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P,
574                         const MachineRegisterInfo &MRI) {
575   const auto &Ty = MRI.getType(RHS);
576   if (Ty.isVector())
577     return std::nullopt;
578   unsigned Size = Ty.getSizeInBits();
579   assert((Size == 32 || Size == 64) && "Expected 32 or 64 bit compare only?");
580 
581   // If the RHS is not a constant, or the RHS is already a valid arithmetic
582   // immediate, then there is nothing to change.
583   auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS, MRI);
584   if (!ValAndVReg)
585     return std::nullopt;
586   uint64_t OriginalC = ValAndVReg->Value.getZExtValue();
587   uint64_t C = OriginalC;
588   if (isLegalArithImmed(C))
589     return std::nullopt;
590 
591   // We have a non-arithmetic immediate. Check if adjusting the immediate and
592   // adjusting the predicate will result in a legal arithmetic immediate.
593   switch (P) {
594   default:
595     return std::nullopt;
596   case CmpInst::ICMP_SLT:
597   case CmpInst::ICMP_SGE:
598     // Check for
599     //
600     // x slt c => x sle c - 1
601     // x sge c => x sgt c - 1
602     //
603     // When c is not the smallest possible negative number.
604     if ((Size == 64 && static_cast<int64_t>(C) == INT64_MIN) ||
605         (Size == 32 && static_cast<int32_t>(C) == INT32_MIN))
606       return std::nullopt;
607     P = (P == CmpInst::ICMP_SLT) ? CmpInst::ICMP_SLE : CmpInst::ICMP_SGT;
608     C -= 1;
609     break;
610   case CmpInst::ICMP_ULT:
611   case CmpInst::ICMP_UGE:
612     // Check for
613     //
614     // x ult c => x ule c - 1
615     // x uge c => x ugt c - 1
616     //
617     // When c is not zero.
618     if (C == 0)
619       return std::nullopt;
620     P = (P == CmpInst::ICMP_ULT) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
621     C -= 1;
622     break;
623   case CmpInst::ICMP_SLE:
624   case CmpInst::ICMP_SGT:
625     // Check for
626     //
627     // x sle c => x slt c + 1
628   // x sgt c => x sge c + 1
629     //
630     // When c is not the largest possible signed integer.
631     if ((Size == 32 && static_cast<int32_t>(C) == INT32_MAX) ||
632         (Size == 64 && static_cast<int64_t>(C) == INT64_MAX))
633       return std::nullopt;
634     P = (P == CmpInst::ICMP_SLE) ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGE;
635     C += 1;
636     break;
637   case CmpInst::ICMP_ULE:
638   case CmpInst::ICMP_UGT:
639     // Check for
640     //
641     // x ule c => x ult c + 1
642   // x ugt c => x uge c + 1
643     //
644     // When c is not the largest possible unsigned integer.
645     if ((Size == 32 && static_cast<uint32_t>(C) == UINT32_MAX) ||
646         (Size == 64 && C == UINT64_MAX))
647       return std::nullopt;
648     P = (P == CmpInst::ICMP_ULE) ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE;
649     C += 1;
650     break;
651   }
652 
653   // Check if the new constant is valid, and return the updated constant and
654   // predicate if it is.
655   if (Size == 32)
656     C = static_cast<uint32_t>(C);
657   if (isLegalArithImmed(C))
658     return {{C, P}};
659 
660   auto IsMaterializableInSingleInstruction = [=](uint64_t Imm) {
661     SmallVector<AArch64_IMM::ImmInsnModel> Insn;
662     AArch64_IMM::expandMOVImm(Imm, 32, Insn);
663     return Insn.size() == 1;
664   };
665 
666   if (!IsMaterializableInSingleInstruction(OriginalC) &&
667       IsMaterializableInSingleInstruction(C))
668     return {{C, P}};
669 
670   return std::nullopt;
671 }
672 
673 /// Determine whether or not it is possible to update the RHS and predicate of
674 /// a G_ICMP instruction such that the RHS will be selected as an arithmetic
675 /// immediate.
676 ///
677 /// \p MI - The G_ICMP instruction
678 /// \p MatchInfo - The new RHS immediate and predicate on success
679 ///
680 /// See tryAdjustICmpImmAndPred for valid transformations.
681 bool matchAdjustICmpImmAndPred(
682     MachineInstr &MI, const MachineRegisterInfo &MRI,
683     std::pair<uint64_t, CmpInst::Predicate> &MatchInfo) {
684   assert(MI.getOpcode() == TargetOpcode::G_ICMP);
685   Register RHS = MI.getOperand(3).getReg();
686   auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
687   if (auto MaybeNewImmAndPred = tryAdjustICmpImmAndPred(RHS, Pred, MRI)) {
688     MatchInfo = *MaybeNewImmAndPred;
689     return true;
690   }
691   return false;
692 }
693 
694 void applyAdjustICmpImmAndPred(
695     MachineInstr &MI, std::pair<uint64_t, CmpInst::Predicate> &MatchInfo,
696     MachineIRBuilder &MIB, GISelChangeObserver &Observer) {
697   MIB.setInstrAndDebugLoc(MI);
698   MachineOperand &RHS = MI.getOperand(3);
699   MachineRegisterInfo &MRI = *MIB.getMRI();
700   auto Cst = MIB.buildConstant(MRI.cloneVirtualRegister(RHS.getReg()),
701                                MatchInfo.first);
702   Observer.changingInstr(MI);
703   RHS.setReg(Cst->getOperand(0).getReg());
704   MI.getOperand(1).setPredicate(MatchInfo.second);
705   Observer.changedInstr(MI);
706 }
707 
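/// Match a G_SHUFFLE_VECTOR which splats a single lane of its first source so
/// that it can be lowered to a lane-indexed duplicate (G_DUPLANE8/16/32/64,
/// chosen by element size and element count).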
708 bool matchDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
709                   std::pair<unsigned, int> &MatchInfo) {
710   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
711   Register Src1Reg = MI.getOperand(1).getReg();
712   const LLT SrcTy = MRI.getType(Src1Reg);
713   const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
714 
715   auto LaneIdx = getSplatIndex(MI);
716   if (!LaneIdx)
717     return false;
718 
719   // The lane idx should be within the first source vector.
720   if (*LaneIdx >= SrcTy.getNumElements())
721     return false;
722 
723   if (DstTy != SrcTy)
724     return false;
725 
726   LLT ScalarTy = SrcTy.getElementType();
727   unsigned ScalarSize = ScalarTy.getSizeInBits();
728 
729   unsigned Opc = 0;
730   switch (SrcTy.getNumElements()) {
731   case 2:
732     if (ScalarSize == 64)
733       Opc = AArch64::G_DUPLANE64;
734     else if (ScalarSize == 32)
735       Opc = AArch64::G_DUPLANE32;
736     break;
737   case 4:
738     if (ScalarSize == 32)
739       Opc = AArch64::G_DUPLANE32;
740     else if (ScalarSize == 16)
741       Opc = AArch64::G_DUPLANE16;
742     break;
743   case 8:
744     if (ScalarSize == 8)
745       Opc = AArch64::G_DUPLANE8;
746     else if (ScalarSize == 16)
747       Opc = AArch64::G_DUPLANE16;
748     break;
749   case 16:
750     if (ScalarSize == 8)
751       Opc = AArch64::G_DUPLANE8;
752     break;
753   default:
754     break;
755   }
756   if (!Opc)
757     return false;
758 
759   MatchInfo.first = Opc;
760   MatchInfo.second = *LaneIdx;
761   return true;
762 }
763 
764 void applyDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
765                   MachineIRBuilder &B, std::pair<unsigned, int> &MatchInfo) {
766   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
767   Register Src1Reg = MI.getOperand(1).getReg();
768   const LLT SrcTy = MRI.getType(Src1Reg);
769 
770   B.setInstrAndDebugLoc(MI);
771   auto Lane = B.buildConstant(LLT::scalar(64), MatchInfo.second);
772 
773   Register DupSrc = MI.getOperand(1).getReg();
774   // For types like <2 x s32>, we can use G_DUPLANE32, with a <4 x s32> source.
775   // To do this, we can use a G_CONCAT_VECTORS to do the widening.
776   if (SrcTy.getSizeInBits() == 64) {
777     auto Undef = B.buildUndef(SrcTy);
778     DupSrc = B.buildConcatVectors(SrcTy.multiplyElements(2),
779                                   {Src1Reg, Undef.getReg(0)})
780                  .getReg(0);
781   }
782   B.buildInstr(MatchInfo.first, {MI.getOperand(0).getReg()}, {DupSrc, Lane});
783   MI.eraseFromParent();
784 }
785 
786 bool matchScalarizeVectorUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI) {
787   auto &Unmerge = cast<GUnmerge>(MI);
788   Register Src1Reg = Unmerge.getReg(Unmerge.getNumOperands() - 1);
789   const LLT SrcTy = MRI.getType(Src1Reg);
790   if (SrcTy.getSizeInBits() != 128 && SrcTy.getSizeInBits() != 64)
791     return false;
792   return SrcTy.isVector() && !SrcTy.isScalable() &&
793          Unmerge.getNumOperands() == (unsigned)SrcTy.getNumElements() + 1;
794 }
795 
796 void applyScalarizeVectorUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
797                                  MachineIRBuilder &B) {
798   auto &Unmerge = cast<GUnmerge>(MI);
799   Register Src1Reg = Unmerge.getReg(Unmerge.getNumOperands() - 1);
800   const LLT SrcTy = MRI.getType(Src1Reg);
801   assert((SrcTy.isVector() && !SrcTy.isScalable()) &&
802          "Expected a fixed length vector");
803 
804   for (int I = 0; I < SrcTy.getNumElements(); ++I)
805     B.buildExtractVectorElementConstant(Unmerge.getReg(I), Src1Reg, I);
806   MI.eraseFromParent();
807 }
808 
809 bool matchBuildVectorToDup(MachineInstr &MI, MachineRegisterInfo &MRI) {
810   assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
811   auto Splat = getAArch64VectorSplat(MI, MRI);
812   if (!Splat)
813     return false;
814   if (Splat->isReg())
815     return true;
816   // Later, during selection, we'll try to match imported patterns using
817   // immAllOnesV and immAllZerosV. These require G_BUILD_VECTOR. Don't lower
818   // G_BUILD_VECTORs which could match those patterns.
819   int64_t Cst = Splat->getCst();
820   return (Cst != 0 && Cst != -1);
821 }
822 
823 void applyBuildVectorToDup(MachineInstr &MI, MachineRegisterInfo &MRI,
824                            MachineIRBuilder &B) {
825   B.setInstrAndDebugLoc(MI);
826   B.buildInstr(AArch64::G_DUP, {MI.getOperand(0).getReg()},
827                {MI.getOperand(1).getReg()});
828   MI.eraseFromParent();
829 }
830 
831 /// \returns how many instructions would be saved by folding a G_ICMP's shift
832 /// and/or extension operations.
833 unsigned getCmpOperandFoldingProfit(Register CmpOp, MachineRegisterInfo &MRI) {
834   // No instructions to save if there's more than one use or no uses.
835   if (!MRI.hasOneNonDBGUse(CmpOp))
836     return 0;
837 
838   // FIXME: This is duplicated with the selector. (See: selectShiftedRegister)
839   auto IsSupportedExtend = [&](const MachineInstr &MI) {
840     if (MI.getOpcode() == TargetOpcode::G_SEXT_INREG)
841       return true;
842     if (MI.getOpcode() != TargetOpcode::G_AND)
843       return false;
844     auto ValAndVReg =
845         getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
846     if (!ValAndVReg)
847       return false;
848     uint64_t Mask = ValAndVReg->Value.getZExtValue();
849     return (Mask == 0xFF || Mask == 0xFFFF || Mask == 0xFFFFFFFF);
850   };
851 
852   MachineInstr *Def = getDefIgnoringCopies(CmpOp, MRI);
853   if (IsSupportedExtend(*Def))
854     return 1;
855 
856   unsigned Opc = Def->getOpcode();
857   if (Opc != TargetOpcode::G_SHL && Opc != TargetOpcode::G_ASHR &&
858       Opc != TargetOpcode::G_LSHR)
859     return 0;
860 
861   auto MaybeShiftAmt =
862       getIConstantVRegValWithLookThrough(Def->getOperand(2).getReg(), MRI);
863   if (!MaybeShiftAmt)
864     return 0;
865   uint64_t ShiftAmt = MaybeShiftAmt->Value.getZExtValue();
866   MachineInstr *ShiftLHS =
867       getDefIgnoringCopies(Def->getOperand(1).getReg(), MRI);
868 
869   // Check if we can fold an extend and a shift.
870   // FIXME: This is duplicated with the selector. (See:
871   // selectArithExtendedRegister)
872   if (IsSupportedExtend(*ShiftLHS))
873     return (ShiftAmt <= 4) ? 2 : 1;
874 
875   LLT Ty = MRI.getType(Def->getOperand(0).getReg());
876   if (Ty.isVector())
877     return 0;
878   unsigned ShiftSize = Ty.getSizeInBits();
879   if ((ShiftSize == 32 && ShiftAmt <= 31) ||
880       (ShiftSize == 64 && ShiftAmt <= 63))
881     return 1;
882   return 0;
883 }
884 
885 /// \returns true if it would be profitable to swap the LHS and RHS of a G_ICMP
886 /// instruction \p MI.
887 bool trySwapICmpOperands(MachineInstr &MI, MachineRegisterInfo &MRI) {
888   assert(MI.getOpcode() == TargetOpcode::G_ICMP);
889   // Swap the operands if it would introduce a profitable folding opportunity.
890   // (e.g. a shift + extend).
891   //
892   //  For example:
893   //    lsl     w13, w11, #1
894   //    cmp     w13, w12
895   // can be turned into:
896   //    cmp     w12, w11, lsl #1
897 
898   // Don't swap if there's a constant on the RHS, because we know we can fold
899   // that.
900   Register RHS = MI.getOperand(3).getReg();
901   auto RHSCst = getIConstantVRegValWithLookThrough(RHS, MRI);
902   if (RHSCst && isLegalArithImmed(RHSCst->Value.getSExtValue()))
903     return false;
904 
905   Register LHS = MI.getOperand(2).getReg();
906   auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
907   auto GetRegForProfit = [&](Register Reg) {
908     MachineInstr *Def = getDefIgnoringCopies(Reg, MRI);
909     return isCMN(Def, Pred, MRI) ? Def->getOperand(2).getReg() : Reg;
910   };
911 
912   // Don't have a constant on the RHS. If we swap the LHS and RHS of the
913   // compare, would we be able to fold more instructions?
914   Register TheLHS = GetRegForProfit(LHS);
915   Register TheRHS = GetRegForProfit(RHS);
916 
917   // If the LHS is more likely to give us a folding opportunity, then swap the
918   // LHS and RHS.
919   return (getCmpOperandFoldingProfit(TheLHS, MRI) >
920           getCmpOperandFoldingProfit(TheRHS, MRI));
921 }
922 
923 void applySwapICmpOperands(MachineInstr &MI, GISelChangeObserver &Observer) {
924   auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
925   Register LHS = MI.getOperand(2).getReg();
926   Register RHS = MI.getOperand(3).getReg();
927   Observer.changingInstr(MI);
928   MI.getOperand(1).setPredicate(CmpInst::getSwappedPredicate(Pred));
929   MI.getOperand(2).setReg(RHS);
930   MI.getOperand(3).setReg(LHS);
931   Observer.changedInstr(MI);
932 }
933 
934 /// \returns a function which builds a vector floating point compare instruction
935 /// for a condition code \p CC.
936 /// \param [in] IsZero - True if the comparison is against 0.
937 /// \param [in] NoNans - True if the target has NoNansFPMath.
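///
/// Note that LS and MI have no two-register pseudo of their own here; they are
/// emitted as G_FCMGE/G_FCMGT with the operands swapped (a <= b iff b >= a).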
938 std::function<Register(MachineIRBuilder &)>
939 getVectorFCMP(AArch64CC::CondCode CC, Register LHS, Register RHS, bool IsZero,
940               bool NoNans, MachineRegisterInfo &MRI) {
941   LLT DstTy = MRI.getType(LHS);
942   assert(DstTy.isVector() && "Expected vector types only?");
943   assert(DstTy == MRI.getType(RHS) && "Src and Dst types must match!");
944   switch (CC) {
945   default:
946     llvm_unreachable("Unexpected condition code!");
947   case AArch64CC::NE:
948     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
949       auto FCmp = IsZero
950                       ? MIB.buildInstr(AArch64::G_FCMEQZ, {DstTy}, {LHS})
951                       : MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS});
952       return MIB.buildNot(DstTy, FCmp).getReg(0);
953     };
954   case AArch64CC::EQ:
955     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
956       return IsZero
957                  ? MIB.buildInstr(AArch64::G_FCMEQZ, {DstTy}, {LHS}).getReg(0)
958                  : MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS})
959                        .getReg(0);
960     };
961   case AArch64CC::GE:
962     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
963       return IsZero
964                  ? MIB.buildInstr(AArch64::G_FCMGEZ, {DstTy}, {LHS}).getReg(0)
965                  : MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {LHS, RHS})
966                        .getReg(0);
967     };
968   case AArch64CC::GT:
969     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
970       return IsZero
971                  ? MIB.buildInstr(AArch64::G_FCMGTZ, {DstTy}, {LHS}).getReg(0)
972                  : MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {LHS, RHS})
973                        .getReg(0);
974     };
975   case AArch64CC::LS:
976     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
977       return IsZero
978                  ? MIB.buildInstr(AArch64::G_FCMLEZ, {DstTy}, {LHS}).getReg(0)
979                  : MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {RHS, LHS})
980                        .getReg(0);
981     };
982   case AArch64CC::MI:
983     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
984       return IsZero
985                  ? MIB.buildInstr(AArch64::G_FCMLTZ, {DstTy}, {LHS}).getReg(0)
986                  : MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {RHS, LHS})
987                        .getReg(0);
988     };
989   }
990 }
991 
992 /// Try to lower a vector G_FCMP \p MI into an AArch64-specific pseudo.
993 bool matchLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
994                           MachineIRBuilder &MIB) {
995   assert(MI.getOpcode() == TargetOpcode::G_FCMP);
996   const auto &ST = MI.getMF()->getSubtarget<AArch64Subtarget>();
997 
998   Register Dst = MI.getOperand(0).getReg();
999   LLT DstTy = MRI.getType(Dst);
1000   if (!DstTy.isVector() || !ST.hasNEON())
1001     return false;
1002   Register LHS = MI.getOperand(2).getReg();
1003   unsigned EltSize = MRI.getType(LHS).getScalarSizeInBits();
1004   if (EltSize == 16 && !ST.hasFullFP16())
1005     return false;
1006   if (EltSize != 16 && EltSize != 32 && EltSize != 64)
1007     return false;
1008 
1009   return true;
1010 }
1011 
1012 /// Try to lower a vector G_FCMP \p MI into an AArch64-specific pseudo.
1013 void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
1014                           MachineIRBuilder &MIB) {
1015   assert(MI.getOpcode() == TargetOpcode::G_FCMP);
1016   const auto &ST = MI.getMF()->getSubtarget<AArch64Subtarget>();
1017 
1018   const auto &CmpMI = cast<GFCmp>(MI);
1019 
1020   Register Dst = CmpMI.getReg(0);
1021   CmpInst::Predicate Pred = CmpMI.getCond();
1022   Register LHS = CmpMI.getLHSReg();
1023   Register RHS = CmpMI.getRHSReg();
1024 
1025   LLT DstTy = MRI.getType(Dst);
1026 
1027   auto Splat = getAArch64VectorSplat(*MRI.getVRegDef(RHS), MRI);
1028 
1029   // Compares against 0 have special target-specific pseudos.
1030   bool IsZero = Splat && Splat->isCst() && Splat->getCst() == 0;
1031 
1032   bool Invert = false;
1033   AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
1034   if ((Pred == CmpInst::Predicate::FCMP_ORD ||
1035        Pred == CmpInst::Predicate::FCMP_UNO) &&
1036       IsZero) {
1037     // The special case "fcmp ord %a, 0" is the canonical check that LHS isn't
1038     // NaN, so equivalent to a == a and doesn't need the two comparisons an
1039     // "ord" normally would.
1040     // Similarly, "fcmp uno %a, 0" is the canonical check that LHS is NaN and is
1041     // thus equivalent to a != a.
1042     RHS = LHS;
1043     IsZero = false;
1044     CC = Pred == CmpInst::Predicate::FCMP_ORD ? AArch64CC::EQ : AArch64CC::NE;
1045   } else
1046     changeVectorFCMPPredToAArch64CC(Pred, CC, CC2, Invert);
1047 
1048   // Instead of having an apply function, just build here to simplify things.
1049   MIB.setInstrAndDebugLoc(MI);
1050 
1051   const bool NoNans =
1052       ST.getTargetLowering()->getTargetMachine().Options.NoNaNsFPMath;
1053 
1054   auto Cmp = getVectorFCMP(CC, LHS, RHS, IsZero, NoNans, MRI);
1055   Register CmpRes;
1056   if (CC2 == AArch64CC::AL)
1057     CmpRes = Cmp(MIB);
1058   else {
1059     auto Cmp2 = getVectorFCMP(CC2, LHS, RHS, IsZero, NoNans, MRI);
1060     auto Cmp2Dst = Cmp2(MIB);
1061     auto Cmp1Dst = Cmp(MIB);
1062     CmpRes = MIB.buildOr(DstTy, Cmp1Dst, Cmp2Dst).getReg(0);
1063   }
1064   if (Invert)
1065     CmpRes = MIB.buildNot(DstTy, CmpRes).getReg(0);
1066   MRI.replaceRegWith(Dst, CmpRes);
1067   MI.eraseFromParent();
1068 }
1069 
1070 // Matches G_BUILD_VECTOR where at least one source operand is not a constant
1071 bool matchLowerBuildToInsertVecElt(MachineInstr &MI, MachineRegisterInfo &MRI) {
1072   auto *GBuildVec = cast<GBuildVector>(&MI);
1073 
1074   // Check if the values are all constants
1075   for (unsigned I = 0; I < GBuildVec->getNumSources(); ++I) {
1076     auto ConstVal =
1077         getAnyConstantVRegValWithLookThrough(GBuildVec->getSourceReg(I), MRI);
1078 
1079     if (!ConstVal.has_value())
1080       return true;
1081   }
1082 
1083   return false;
1084 }
1085 
1086 void applyLowerBuildToInsertVecElt(MachineInstr &MI, MachineRegisterInfo &MRI,
1087                                    MachineIRBuilder &B) {
1088   auto *GBuildVec = cast<GBuildVector>(&MI);
1089   LLT DstTy = MRI.getType(GBuildVec->getReg(0));
1090   Register DstReg = B.buildUndef(DstTy).getReg(0);
1091 
1092   for (unsigned I = 0; I < GBuildVec->getNumSources(); ++I) {
1093     Register SrcReg = GBuildVec->getSourceReg(I);
1094     if (mi_match(SrcReg, MRI, m_GImplicitDef()))
1095       continue;
1096     auto IdxReg = B.buildConstant(LLT::scalar(64), I);
1097     DstReg =
1098         B.buildInsertVectorElement(DstTy, DstReg, SrcReg, IdxReg).getReg(0);
1099   }
1100   B.buildCopy(GBuildVec->getReg(0), DstReg);
1101   GBuildVec->eraseFromParent();
1102 }
1103 
1104 bool matchFormTruncstore(MachineInstr &MI, MachineRegisterInfo &MRI,
1105                          Register &SrcReg) {
1106   assert(MI.getOpcode() == TargetOpcode::G_STORE);
1107   Register DstReg = MI.getOperand(0).getReg();
1108   if (MRI.getType(DstReg).isVector())
1109     return false;
1110   // Match a store of a truncate.
1111   if (!mi_match(DstReg, MRI, m_GTrunc(m_Reg(SrcReg))))
1112     return false;
1113   // Only form truncstores for value types of max 64b.
1114   return MRI.getType(SrcReg).getSizeInBits() <= 64;
1115 }
1116 
1117 void applyFormTruncstore(MachineInstr &MI, MachineRegisterInfo &MRI,
1118                          MachineIRBuilder &B, GISelChangeObserver &Observer,
1119                          Register &SrcReg) {
1120   assert(MI.getOpcode() == TargetOpcode::G_STORE);
1121   Observer.changingInstr(MI);
1122   MI.getOperand(0).setReg(SrcReg);
1123   Observer.changedInstr(MI);
1124 }
1125 
1126 // Lower vector G_SEXT_INREG back to shifts for selection. We allowed them to
1127 // form in the first place for combine opportunities, so any remaining ones
1128 // at this stage need to be lowered back.
1129 bool matchVectorSextInReg(MachineInstr &MI, MachineRegisterInfo &MRI) {
1130   assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1131   Register DstReg = MI.getOperand(0).getReg();
1132   LLT DstTy = MRI.getType(DstReg);
1133   return DstTy.isVector();
1134 }
1135 
1136 void applyVectorSextInReg(MachineInstr &MI, MachineRegisterInfo &MRI,
1137                           MachineIRBuilder &B, GISelChangeObserver &Observer) {
1138   assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1139   B.setInstrAndDebugLoc(MI);
1140   LegalizerHelper Helper(*MI.getMF(), Observer, B);
1141   Helper.lower(MI, 0, /* Unused hint type */ LLT());
1142 }
1143 
1144 /// Combine <N x t>, unused = unmerge(G_EXT <2*N x t> v, undef, N)
1145 ///           => unused, <N x t> = unmerge v
1146 bool matchUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
1147                               Register &MatchInfo) {
1148   auto &Unmerge = cast<GUnmerge>(MI);
1149   if (Unmerge.getNumDefs() != 2)
1150     return false;
1151   if (!MRI.use_nodbg_empty(Unmerge.getReg(1)))
1152     return false;
1153 
1154   LLT DstTy = MRI.getType(Unmerge.getReg(0));
1155   if (!DstTy.isVector())
1156     return false;
1157 
1158   MachineInstr *Ext = getOpcodeDef(AArch64::G_EXT, Unmerge.getSourceReg(), MRI);
1159   if (!Ext)
1160     return false;
1161 
1162   Register ExtSrc1 = Ext->getOperand(1).getReg();
1163   Register ExtSrc2 = Ext->getOperand(2).getReg();
1164   auto LowestVal =
1165       getIConstantVRegValWithLookThrough(Ext->getOperand(3).getReg(), MRI);
1166   if (!LowestVal || LowestVal->Value.getZExtValue() != DstTy.getSizeInBytes())
1167     return false;
1168 
1169   if (!getOpcodeDef<GImplicitDef>(ExtSrc2, MRI))
1170     return false;
1171 
1172   MatchInfo = ExtSrc1;
1173   return true;
1174 }
1175 
1176 void applyUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
1177                               MachineIRBuilder &B,
1178                               GISelChangeObserver &Observer, Register &SrcReg) {
1179   Observer.changingInstr(MI);
1180   // Swap dst registers.
1181   Register Dst1 = MI.getOperand(0).getReg();
1182   MI.getOperand(0).setReg(MI.getOperand(1).getReg());
1183   MI.getOperand(1).setReg(Dst1);
1184   MI.getOperand(2).setReg(SrcReg);
1185   Observer.changedInstr(MI);
1186 }
1187 
1188 // Match mul({z/s}ext, {z/s}ext) => {u/s}mull, OR
1189 // match v2s64 mul instructions, which will then be scalarised later on.
1190 // These two matches are done in one function to ensure that the order of
1191 // matching will always be the same.
1192 // Try lowering MUL to MULL before trying to scalarize if needed.
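// e.g. G_MUL(G_ZEXT <4 x s16> -> <4 x s32>, G_ZEXT <4 x s16> -> <4 x s32>)
// becomes a single G_UMULL of the two <4 x s16> sources producing <4 x s32>.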
1193 bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI) {
1194   // Get the instructions that defined the source operand
1195   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
1196   MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
1197   MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
1198 
1199   if (DstTy.isVector()) {
1200     // If the source operands were EXTENDED before, then {U/S}MULL can be used
1201     unsigned I1Opc = I1->getOpcode();
1202     unsigned I2Opc = I2->getOpcode();
1203     if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
1204          (I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
1205         (MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
1206          MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
1207         (MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
1208          MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
1209       return true;
1210     }
1211     // If result type is v2s64, scalarise the instruction
1212     else if (DstTy == LLT::fixed_vector(2, 64)) {
1213       return true;
1214     }
1215   }
1216   return false;
1217 }
1218 
1219 void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
1220                        MachineIRBuilder &B, GISelChangeObserver &Observer) {
1221   assert(MI.getOpcode() == TargetOpcode::G_MUL &&
1222          "Expected a G_MUL instruction");
1223 
1224   // Get the instructions that defined the source operand
1225   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
1226   MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
1227   MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
1228 
1229   // If the source operands were EXTENDED before, then {U/S}MULL can be used
1230   unsigned I1Opc = I1->getOpcode();
1231   unsigned I2Opc = I2->getOpcode();
1232   if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
1233        (I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
1234       (MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
1235        MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
1236       (MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
1237        MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
1238 
1239     B.setInstrAndDebugLoc(MI);
1240     B.buildInstr(I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UMULL
1241                                                          : AArch64::G_SMULL,
1242                  {MI.getOperand(0).getReg()},
1243                  {I1->getOperand(1).getReg(), I2->getOperand(1).getReg()});
1244     MI.eraseFromParent();
1245   }
1246   // If result type is v2s64, scalarise the instruction
1247   else if (DstTy == LLT::fixed_vector(2, 64)) {
1248     LegalizerHelper Helper(*MI.getMF(), Observer, B);
1249     B.setInstrAndDebugLoc(MI);
1250     Helper.fewerElementsVector(
1251         MI, 0,
1252         DstTy.changeElementCount(
1253             DstTy.getElementCount().divideCoefficientBy(2)));
1254   }
1255 }
1256 
1257 class AArch64PostLegalizerLoweringImpl : public Combiner {
1258 protected:
1259   const CombinerHelper Helper;
1260   const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig;
1261   const AArch64Subtarget &STI;
1262 
1263 public:
1264   AArch64PostLegalizerLoweringImpl(
1265       MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
1266       GISelCSEInfo *CSEInfo,
1267       const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig,
1268       const AArch64Subtarget &STI);
1269 
1270   static const char *getName() { return "AArch64PostLegalizerLowering"; }
1271 
1272   bool tryCombineAll(MachineInstr &I) const override;
1273 
1274 private:
1275 #define GET_GICOMBINER_CLASS_MEMBERS
1276 #include "AArch64GenPostLegalizeGILowering.inc"
1277 #undef GET_GICOMBINER_CLASS_MEMBERS
1278 };
1279 
1280 #define GET_GICOMBINER_IMPL
1281 #include "AArch64GenPostLegalizeGILowering.inc"
1282 #undef GET_GICOMBINER_IMPL
1283 
1284 AArch64PostLegalizerLoweringImpl::AArch64PostLegalizerLoweringImpl(
1285     MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
1286     GISelCSEInfo *CSEInfo,
1287     const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig,
1288     const AArch64Subtarget &STI)
1289     : Combiner(MF, CInfo, TPC, /*KB*/ nullptr, CSEInfo),
1290       Helper(Observer, B, /*IsPreLegalize*/ true), RuleConfig(RuleConfig),
1291       STI(STI),
1292 #define GET_GICOMBINER_CONSTRUCTOR_INITS
1293 #include "AArch64GenPostLegalizeGILowering.inc"
1294 #undef GET_GICOMBINER_CONSTRUCTOR_INITS
1295 {
1296 }
1297 
1298 class AArch64PostLegalizerLowering : public MachineFunctionPass {
1299 public:
1300   static char ID;
1301 
1302   AArch64PostLegalizerLowering();
1303 
1304   StringRef getPassName() const override {
1305     return "AArch64PostLegalizerLowering";
1306   }
1307 
1308   bool runOnMachineFunction(MachineFunction &MF) override;
1309   void getAnalysisUsage(AnalysisUsage &AU) const override;
1310 
1311 private:
1312   AArch64PostLegalizerLoweringImplRuleConfig RuleConfig;
1313 };
1314 } // end anonymous namespace
1315 
1316 void AArch64PostLegalizerLowering::getAnalysisUsage(AnalysisUsage &AU) const {
1317   AU.addRequired<TargetPassConfig>();
1318   AU.setPreservesCFG();
1319   getSelectionDAGFallbackAnalysisUsage(AU);
1320   MachineFunctionPass::getAnalysisUsage(AU);
1321 }
1322 
1323 AArch64PostLegalizerLowering::AArch64PostLegalizerLowering()
1324     : MachineFunctionPass(ID) {
1325   initializeAArch64PostLegalizerLoweringPass(*PassRegistry::getPassRegistry());
1326 
1327   if (!RuleConfig.parseCommandLineOption())
1328     report_fatal_error("Invalid rule identifier");
1329 }
1330 
1331 bool AArch64PostLegalizerLowering::runOnMachineFunction(MachineFunction &MF) {
1332   if (MF.getProperties().hasProperty(
1333           MachineFunctionProperties::Property::FailedISel))
1334     return false;
1335   assert(MF.getProperties().hasProperty(
1336              MachineFunctionProperties::Property::Legalized) &&
1337          "Expected a legalized function?");
1338   auto *TPC = &getAnalysis<TargetPassConfig>();
1339   const Function &F = MF.getFunction();
1340 
1341   const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
1342   CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
1343                      /*LegalizerInfo*/ nullptr, /*OptEnabled=*/true,
1344                      F.hasOptSize(), F.hasMinSize());
1345   // Disable fixed-point iteration to reduce compile-time
1346   CInfo.MaxIterations = 1;
1347   CInfo.ObserverLvl = CombinerInfo::ObserverLevel::SinglePass;
1348   // PostLegalizerCombiner performs DCE, so a full DCE pass is unnecessary.
1349   CInfo.EnableFullDCE = false;
1350   AArch64PostLegalizerLoweringImpl Impl(MF, CInfo, TPC, /*CSEInfo*/ nullptr,
1351                                         RuleConfig, ST);
1352   return Impl.combineMachineInstrs();
1353 }
1354 
1355 char AArch64PostLegalizerLowering::ID = 0;
1356 INITIALIZE_PASS_BEGIN(AArch64PostLegalizerLowering, DEBUG_TYPE,
1357                       "Lower AArch64 MachineInstrs after legalization", false,
1358                       false)
1359 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1360 INITIALIZE_PASS_END(AArch64PostLegalizerLowering, DEBUG_TYPE,
1361                     "Lower AArch64 MachineInstrs after legalization", false,
1362                     false)
1363 
1364 namespace llvm {
1365 FunctionPass *createAArch64PostLegalizerLowering() {
1366   return new AArch64PostLegalizerLowering();
1367 }
1368 } // end namespace llvm
1369