xref: /netbsd-src/external/apache2/llvm/dist/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (revision 82d56013d7b633d116a93943de88e08335357a7c)
1 //===-- AArch64ISelLowering.cpp - AArch64 DAG Lowering Implementation  ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AArch64TargetLowering class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "AArch64ISelLowering.h"
14 #include "AArch64CallingConvention.h"
15 #include "AArch64ExpandImm.h"
16 #include "AArch64MachineFunctionInfo.h"
17 #include "AArch64PerfectShuffle.h"
18 #include "AArch64RegisterInfo.h"
19 #include "AArch64Subtarget.h"
20 #include "MCTargetDesc/AArch64AddressingModes.h"
21 #include "Utils/AArch64BaseInfo.h"
22 #include "llvm/ADT/APFloat.h"
23 #include "llvm/ADT/APInt.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallSet.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/Statistic.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/ADT/Triple.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/Analysis/ObjCARCUtil.h"
33 #include "llvm/Analysis/VectorUtils.h"
34 #include "llvm/CodeGen/CallingConvLower.h"
35 #include "llvm/CodeGen/MachineBasicBlock.h"
36 #include "llvm/CodeGen/MachineFrameInfo.h"
37 #include "llvm/CodeGen/MachineFunction.h"
38 #include "llvm/CodeGen/MachineInstr.h"
39 #include "llvm/CodeGen/MachineInstrBuilder.h"
40 #include "llvm/CodeGen/MachineMemOperand.h"
41 #include "llvm/CodeGen/MachineRegisterInfo.h"
42 #include "llvm/CodeGen/RuntimeLibcalls.h"
43 #include "llvm/CodeGen/SelectionDAG.h"
44 #include "llvm/CodeGen/SelectionDAGNodes.h"
45 #include "llvm/CodeGen/TargetCallingConv.h"
46 #include "llvm/CodeGen/TargetInstrInfo.h"
47 #include "llvm/CodeGen/ValueTypes.h"
48 #include "llvm/IR/Attributes.h"
49 #include "llvm/IR/Constants.h"
50 #include "llvm/IR/DataLayout.h"
51 #include "llvm/IR/DebugLoc.h"
52 #include "llvm/IR/DerivedTypes.h"
53 #include "llvm/IR/Function.h"
54 #include "llvm/IR/GetElementPtrTypeIterator.h"
55 #include "llvm/IR/GlobalValue.h"
56 #include "llvm/IR/IRBuilder.h"
57 #include "llvm/IR/Instruction.h"
58 #include "llvm/IR/Instructions.h"
59 #include "llvm/IR/IntrinsicInst.h"
60 #include "llvm/IR/Intrinsics.h"
61 #include "llvm/IR/IntrinsicsAArch64.h"
62 #include "llvm/IR/Module.h"
63 #include "llvm/IR/OperandTraits.h"
64 #include "llvm/IR/PatternMatch.h"
65 #include "llvm/IR/Type.h"
66 #include "llvm/IR/Use.h"
67 #include "llvm/IR/Value.h"
68 #include "llvm/MC/MCRegisterInfo.h"
69 #include "llvm/Support/Casting.h"
70 #include "llvm/Support/CodeGen.h"
71 #include "llvm/Support/CommandLine.h"
72 #include "llvm/Support/Compiler.h"
73 #include "llvm/Support/Debug.h"
74 #include "llvm/Support/ErrorHandling.h"
75 #include "llvm/Support/KnownBits.h"
76 #include "llvm/Support/MachineValueType.h"
77 #include "llvm/Support/MathExtras.h"
78 #include "llvm/Support/raw_ostream.h"
79 #include "llvm/Target/TargetMachine.h"
80 #include "llvm/Target/TargetOptions.h"
81 #include <algorithm>
82 #include <bitset>
83 #include <cassert>
84 #include <cctype>
85 #include <cstdint>
86 #include <cstdlib>
87 #include <iterator>
88 #include <limits>
89 #include <tuple>
90 #include <utility>
91 #include <vector>
92 
93 using namespace llvm;
94 using namespace llvm::PatternMatch;
95 
96 #define DEBUG_TYPE "aarch64-lower"
97 
98 STATISTIC(NumTailCalls, "Number of tail calls");
99 STATISTIC(NumShiftInserts, "Number of vector shift inserts");
100 STATISTIC(NumOptimizedImms, "Number of times immediates were optimized");
101 
102 // FIXME: The necessary dtprel relocations don't seem to be supported
103 // well in the GNU bfd and gold linkers at the moment. Therefore, by
104 // default, for now, fall back to GeneralDynamic code generation.
105 cl::opt<bool> EnableAArch64ELFLocalDynamicTLSGeneration(
106     "aarch64-elf-ldtls-generation", cl::Hidden,
107     cl::desc("Allow AArch64 Local Dynamic TLS code generation"),
108     cl::init(false));
109 
110 static cl::opt<bool>
111 EnableOptimizeLogicalImm("aarch64-enable-logical-imm", cl::Hidden,
112                          cl::desc("Enable AArch64 logical imm instruction "
113                                   "optimization"),
114                          cl::init(true));
115 
116 // Temporary option added for the purpose of testing functionality added
117 // to DAGCombiner.cpp in D92230. It is expected that this can be removed
118 // in the future once both implementations are based on MGATHER rather
119 // than the GLD1 nodes added for the SVE gather load intrinsics.
120 static cl::opt<bool>
121 EnableCombineMGatherIntrinsics("aarch64-enable-mgather-combine", cl::Hidden,
122                                 cl::desc("Combine extends of AArch64 masked "
123                                          "gather intrinsics"),
124                                 cl::init(true));
125 
126 /// Value type used for condition codes.
127 static const MVT MVT_CC = MVT::i32;
128 
129 static inline EVT getPackedSVEVectorVT(EVT VT) {
130   switch (VT.getSimpleVT().SimpleTy) {
131   default:
132     llvm_unreachable("unexpected element type for vector");
133   case MVT::i8:
134     return MVT::nxv16i8;
135   case MVT::i16:
136     return MVT::nxv8i16;
137   case MVT::i32:
138     return MVT::nxv4i32;
139   case MVT::i64:
140     return MVT::nxv2i64;
141   case MVT::f16:
142     return MVT::nxv8f16;
143   case MVT::f32:
144     return MVT::nxv4f32;
145   case MVT::f64:
146     return MVT::nxv2f64;
147   case MVT::bf16:
148     return MVT::nxv8bf16;
149   }
150 }
151 
152 // NOTE: Currently there's only a need to return integer vector types. If this
153 // changes then just add an extra "type" parameter.
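// For example, an element count whose known minimum is 4 (such as one obtained
// from ElementCount::getScalable(4)) maps to MVT::nxv4i32 below.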
154 static inline EVT getPackedSVEVectorVT(ElementCount EC) {
155   switch (EC.getKnownMinValue()) {
156   default:
157     llvm_unreachable("unexpected element count for vector");
158   case 16:
159     return MVT::nxv16i8;
160   case 8:
161     return MVT::nxv8i16;
162   case 4:
163     return MVT::nxv4i32;
164   case 2:
165     return MVT::nxv2i64;
166   }
167 }
168 
169 static inline EVT getPromotedVTForPredicate(EVT VT) {
170   assert(VT.isScalableVector() && (VT.getVectorElementType() == MVT::i1) &&
171          "Expected scalable predicate vector type!");
172   switch (VT.getVectorMinNumElements()) {
173   default:
174     llvm_unreachable("unexpected element count for vector");
175   case 2:
176     return MVT::nxv2i64;
177   case 4:
178     return MVT::nxv4i32;
179   case 8:
180     return MVT::nxv8i16;
181   case 16:
182     return MVT::nxv16i8;
183   }
184 }
185 
186 /// Returns true if VT's elements occupy the lowest bit positions of its
187 /// associated register class without any intervening space.
188 ///
189 /// For example, nxv2f16, nxv4f16 and nxv8f16 are legal types that belong to the
190 /// same register class, but only nxv8f16 can be treated as a packed vector.
191 static inline bool isPackedVectorType(EVT VT, SelectionDAG &DAG) {
192   assert(VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
193          "Expected legal vector type!");
194   return VT.isFixedLengthVector() ||
195          VT.getSizeInBits().getKnownMinSize() == AArch64::SVEBitsPerBlock;
196 }
197 
198 // Returns true for ####_MERGE_PASSTHRU opcodes, whose operands have a leading
199 // predicate and end with a passthru value matching the result type.
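// For example, a unary node such as FNEG_MERGE_PASSTHRU is expected to take
// operands of the form (Pg, Op, Passthru), where lanes disabled by the
// predicate Pg take their result value from Passthru.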
200 static bool isMergePassthruOpcode(unsigned Opc) {
201   switch (Opc) {
202   default:
203     return false;
204   case AArch64ISD::BITREVERSE_MERGE_PASSTHRU:
205   case AArch64ISD::BSWAP_MERGE_PASSTHRU:
206   case AArch64ISD::CTLZ_MERGE_PASSTHRU:
207   case AArch64ISD::CTPOP_MERGE_PASSTHRU:
208   case AArch64ISD::DUP_MERGE_PASSTHRU:
209   case AArch64ISD::ABS_MERGE_PASSTHRU:
210   case AArch64ISD::NEG_MERGE_PASSTHRU:
211   case AArch64ISD::FNEG_MERGE_PASSTHRU:
212   case AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU:
213   case AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU:
214   case AArch64ISD::FCEIL_MERGE_PASSTHRU:
215   case AArch64ISD::FFLOOR_MERGE_PASSTHRU:
216   case AArch64ISD::FNEARBYINT_MERGE_PASSTHRU:
217   case AArch64ISD::FRINT_MERGE_PASSTHRU:
218   case AArch64ISD::FROUND_MERGE_PASSTHRU:
219   case AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU:
220   case AArch64ISD::FTRUNC_MERGE_PASSTHRU:
221   case AArch64ISD::FP_ROUND_MERGE_PASSTHRU:
222   case AArch64ISD::FP_EXTEND_MERGE_PASSTHRU:
223   case AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU:
224   case AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU:
225   case AArch64ISD::FCVTZU_MERGE_PASSTHRU:
226   case AArch64ISD::FCVTZS_MERGE_PASSTHRU:
227   case AArch64ISD::FSQRT_MERGE_PASSTHRU:
228   case AArch64ISD::FRECPX_MERGE_PASSTHRU:
229   case AArch64ISD::FABS_MERGE_PASSTHRU:
230     return true;
231   }
232 }
233 
234 AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
235                                              const AArch64Subtarget &STI)
236     : TargetLowering(TM), Subtarget(&STI) {
237   // AArch64 doesn't have comparison instructions that set GPRs (i.e. no setcc
238   // instruction), so we have to make something up. Arbitrarily, choose ZeroOrOne.
239   setBooleanContents(ZeroOrOneBooleanContent);
240   // When comparing vectors the result sets the different elements in the
241   // vector to all-one or all-zero.
242   setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);
243 
244   // Set up the register classes.
245   addRegisterClass(MVT::i32, &AArch64::GPR32allRegClass);
246   addRegisterClass(MVT::i64, &AArch64::GPR64allRegClass);
247 
248   if (Subtarget->hasFPARMv8()) {
249     addRegisterClass(MVT::f16, &AArch64::FPR16RegClass);
250     addRegisterClass(MVT::bf16, &AArch64::FPR16RegClass);
251     addRegisterClass(MVT::f32, &AArch64::FPR32RegClass);
252     addRegisterClass(MVT::f64, &AArch64::FPR64RegClass);
253     addRegisterClass(MVT::f128, &AArch64::FPR128RegClass);
254   }
255 
256   if (Subtarget->hasNEON()) {
257     addRegisterClass(MVT::v16i8, &AArch64::FPR8RegClass);
258     addRegisterClass(MVT::v8i16, &AArch64::FPR16RegClass);
259     // Someone set us up the NEON.
260     addDRTypeForNEON(MVT::v2f32);
261     addDRTypeForNEON(MVT::v8i8);
262     addDRTypeForNEON(MVT::v4i16);
263     addDRTypeForNEON(MVT::v2i32);
264     addDRTypeForNEON(MVT::v1i64);
265     addDRTypeForNEON(MVT::v1f64);
266     addDRTypeForNEON(MVT::v4f16);
267     if (Subtarget->hasBF16())
268       addDRTypeForNEON(MVT::v4bf16);
269 
270     addQRTypeForNEON(MVT::v4f32);
271     addQRTypeForNEON(MVT::v2f64);
272     addQRTypeForNEON(MVT::v16i8);
273     addQRTypeForNEON(MVT::v8i16);
274     addQRTypeForNEON(MVT::v4i32);
275     addQRTypeForNEON(MVT::v2i64);
276     addQRTypeForNEON(MVT::v8f16);
277     if (Subtarget->hasBF16())
278       addQRTypeForNEON(MVT::v8bf16);
279   }
280 
281   if (Subtarget->hasSVE()) {
282     // Add legal SVE predicate types.
283     addRegisterClass(MVT::nxv2i1, &AArch64::PPRRegClass);
284     addRegisterClass(MVT::nxv4i1, &AArch64::PPRRegClass);
285     addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass);
286     addRegisterClass(MVT::nxv16i1, &AArch64::PPRRegClass);
287 
288     // Add legal SVE data types.
289     addRegisterClass(MVT::nxv16i8, &AArch64::ZPRRegClass);
290     addRegisterClass(MVT::nxv8i16, &AArch64::ZPRRegClass);
291     addRegisterClass(MVT::nxv4i32, &AArch64::ZPRRegClass);
292     addRegisterClass(MVT::nxv2i64, &AArch64::ZPRRegClass);
293 
294     addRegisterClass(MVT::nxv2f16, &AArch64::ZPRRegClass);
295     addRegisterClass(MVT::nxv4f16, &AArch64::ZPRRegClass);
296     addRegisterClass(MVT::nxv8f16, &AArch64::ZPRRegClass);
297     addRegisterClass(MVT::nxv2f32, &AArch64::ZPRRegClass);
298     addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
299     addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);
300 
301     if (Subtarget->hasBF16()) {
302       addRegisterClass(MVT::nxv2bf16, &AArch64::ZPRRegClass);
303       addRegisterClass(MVT::nxv4bf16, &AArch64::ZPRRegClass);
304       addRegisterClass(MVT::nxv8bf16, &AArch64::ZPRRegClass);
305     }
306 
307     if (Subtarget->useSVEForFixedLengthVectors()) {
308       for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
309         if (useSVEForFixedLengthVectorVT(VT))
310           addRegisterClass(VT, &AArch64::ZPRRegClass);
311 
312       for (MVT VT : MVT::fp_fixedlen_vector_valuetypes())
313         if (useSVEForFixedLengthVectorVT(VT))
314           addRegisterClass(VT, &AArch64::ZPRRegClass);
315     }
316 
317     for (auto VT : { MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64 }) {
318       setOperationAction(ISD::SADDSAT, VT, Legal);
319       setOperationAction(ISD::UADDSAT, VT, Legal);
320       setOperationAction(ISD::SSUBSAT, VT, Legal);
321       setOperationAction(ISD::USUBSAT, VT, Legal);
322       setOperationAction(ISD::UREM, VT, Expand);
323       setOperationAction(ISD::SREM, VT, Expand);
324       setOperationAction(ISD::SDIVREM, VT, Expand);
325       setOperationAction(ISD::UDIVREM, VT, Expand);
326     }
327 
328     for (auto VT :
329          { MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv4i8,
330            MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 })
331       setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal);
332 
333     for (auto VT :
334          { MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32, MVT::nxv4f32,
335            MVT::nxv2f64 }) {
336       setCondCodeAction(ISD::SETO, VT, Expand);
337       setCondCodeAction(ISD::SETOLT, VT, Expand);
338       setCondCodeAction(ISD::SETLT, VT, Expand);
339       setCondCodeAction(ISD::SETOLE, VT, Expand);
340       setCondCodeAction(ISD::SETLE, VT, Expand);
341       setCondCodeAction(ISD::SETULT, VT, Expand);
342       setCondCodeAction(ISD::SETULE, VT, Expand);
343       setCondCodeAction(ISD::SETUGE, VT, Expand);
344       setCondCodeAction(ISD::SETUGT, VT, Expand);
345       setCondCodeAction(ISD::SETUEQ, VT, Expand);
346       setCondCodeAction(ISD::SETUNE, VT, Expand);
347 
348       setOperationAction(ISD::FREM, VT, Expand);
349       setOperationAction(ISD::FPOW, VT, Expand);
350       setOperationAction(ISD::FPOWI, VT, Expand);
351       setOperationAction(ISD::FCOS, VT, Expand);
352       setOperationAction(ISD::FSIN, VT, Expand);
353       setOperationAction(ISD::FSINCOS, VT, Expand);
354       setOperationAction(ISD::FEXP, VT, Expand);
355       setOperationAction(ISD::FEXP2, VT, Expand);
356       setOperationAction(ISD::FLOG, VT, Expand);
357       setOperationAction(ISD::FLOG2, VT, Expand);
358       setOperationAction(ISD::FLOG10, VT, Expand);
359     }
360   }
361 
362   // Compute derived properties from the register classes
363   computeRegisterProperties(Subtarget->getRegisterInfo());
364 
365   // Provide all sorts of operation actions
366   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
367   setOperationAction(ISD::GlobalTLSAddress, MVT::i64, Custom);
368   setOperationAction(ISD::SETCC, MVT::i32, Custom);
369   setOperationAction(ISD::SETCC, MVT::i64, Custom);
370   setOperationAction(ISD::SETCC, MVT::f16, Custom);
371   setOperationAction(ISD::SETCC, MVT::f32, Custom);
372   setOperationAction(ISD::SETCC, MVT::f64, Custom);
373   setOperationAction(ISD::STRICT_FSETCC, MVT::f16, Custom);
374   setOperationAction(ISD::STRICT_FSETCC, MVT::f32, Custom);
375   setOperationAction(ISD::STRICT_FSETCC, MVT::f64, Custom);
376   setOperationAction(ISD::STRICT_FSETCCS, MVT::f16, Custom);
377   setOperationAction(ISD::STRICT_FSETCCS, MVT::f32, Custom);
378   setOperationAction(ISD::STRICT_FSETCCS, MVT::f64, Custom);
379   setOperationAction(ISD::BITREVERSE, MVT::i32, Legal);
380   setOperationAction(ISD::BITREVERSE, MVT::i64, Legal);
381   setOperationAction(ISD::BRCOND, MVT::Other, Expand);
382   setOperationAction(ISD::BR_CC, MVT::i32, Custom);
383   setOperationAction(ISD::BR_CC, MVT::i64, Custom);
384   setOperationAction(ISD::BR_CC, MVT::f16, Custom);
385   setOperationAction(ISD::BR_CC, MVT::f32, Custom);
386   setOperationAction(ISD::BR_CC, MVT::f64, Custom);
387   setOperationAction(ISD::SELECT, MVT::i32, Custom);
388   setOperationAction(ISD::SELECT, MVT::i64, Custom);
389   setOperationAction(ISD::SELECT, MVT::f16, Custom);
390   setOperationAction(ISD::SELECT, MVT::f32, Custom);
391   setOperationAction(ISD::SELECT, MVT::f64, Custom);
392   setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
393   setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
394   setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
395   setOperationAction(ISD::SELECT_CC, MVT::f32, Custom);
396   setOperationAction(ISD::SELECT_CC, MVT::f64, Custom);
397   setOperationAction(ISD::BR_JT, MVT::Other, Custom);
398   setOperationAction(ISD::JumpTable, MVT::i64, Custom);
399 
400   setOperationAction(ISD::SHL_PARTS, MVT::i64, Custom);
401   setOperationAction(ISD::SRA_PARTS, MVT::i64, Custom);
402   setOperationAction(ISD::SRL_PARTS, MVT::i64, Custom);
403 
404   setOperationAction(ISD::FREM, MVT::f32, Expand);
405   setOperationAction(ISD::FREM, MVT::f64, Expand);
406   setOperationAction(ISD::FREM, MVT::f80, Expand);
407 
408   setOperationAction(ISD::BUILD_PAIR, MVT::i64, Expand);
409 
410   // Custom lowering hooks are needed for XOR
411   // to fold it into CSINC/CSINV.
412   setOperationAction(ISD::XOR, MVT::i32, Custom);
413   setOperationAction(ISD::XOR, MVT::i64, Custom);
414 
415   // Virtually no operation on f128 is legal, but LLVM can't expand such ops when
416   // there's a valid register class, so we need custom operations in most cases.
417   setOperationAction(ISD::FABS, MVT::f128, Expand);
418   setOperationAction(ISD::FADD, MVT::f128, LibCall);
419   setOperationAction(ISD::FCOPYSIGN, MVT::f128, Expand);
420   setOperationAction(ISD::FCOS, MVT::f128, Expand);
421   setOperationAction(ISD::FDIV, MVT::f128, LibCall);
422   setOperationAction(ISD::FMA, MVT::f128, Expand);
423   setOperationAction(ISD::FMUL, MVT::f128, LibCall);
424   setOperationAction(ISD::FNEG, MVT::f128, Expand);
425   setOperationAction(ISD::FPOW, MVT::f128, Expand);
426   setOperationAction(ISD::FREM, MVT::f128, Expand);
427   setOperationAction(ISD::FRINT, MVT::f128, Expand);
428   setOperationAction(ISD::FSIN, MVT::f128, Expand);
429   setOperationAction(ISD::FSINCOS, MVT::f128, Expand);
430   setOperationAction(ISD::FSQRT, MVT::f128, Expand);
431   setOperationAction(ISD::FSUB, MVT::f128, LibCall);
432   setOperationAction(ISD::FTRUNC, MVT::f128, Expand);
433   setOperationAction(ISD::SETCC, MVT::f128, Custom);
434   setOperationAction(ISD::STRICT_FSETCC, MVT::f128, Custom);
435   setOperationAction(ISD::STRICT_FSETCCS, MVT::f128, Custom);
436   setOperationAction(ISD::BR_CC, MVT::f128, Custom);
437   setOperationAction(ISD::SELECT, MVT::f128, Custom);
438   setOperationAction(ISD::SELECT_CC, MVT::f128, Custom);
439   setOperationAction(ISD::FP_EXTEND, MVT::f128, Custom);
440 
441   // Lowering for many of the conversions is actually specified by the non-f128
442   // type. The LowerXXX function will be trivial when f128 isn't involved.
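  // For example, an FP_TO_SINT from f128 is expected to end up as a library
  // call, whereas the same node on f64 passes through essentially unchanged.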
443   setOperationAction(ISD::FP_TO_SINT, MVT::i32, Custom);
444   setOperationAction(ISD::FP_TO_SINT, MVT::i64, Custom);
445   setOperationAction(ISD::FP_TO_SINT, MVT::i128, Custom);
446   setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i32, Custom);
447   setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i64, Custom);
448   setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i128, Custom);
449   setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
450   setOperationAction(ISD::FP_TO_UINT, MVT::i64, Custom);
451   setOperationAction(ISD::FP_TO_UINT, MVT::i128, Custom);
452   setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i32, Custom);
453   setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i64, Custom);
454   setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i128, Custom);
455   setOperationAction(ISD::SINT_TO_FP, MVT::i32, Custom);
456   setOperationAction(ISD::SINT_TO_FP, MVT::i64, Custom);
457   setOperationAction(ISD::SINT_TO_FP, MVT::i128, Custom);
458   setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i32, Custom);
459   setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i64, Custom);
460   setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i128, Custom);
461   setOperationAction(ISD::UINT_TO_FP, MVT::i32, Custom);
462   setOperationAction(ISD::UINT_TO_FP, MVT::i64, Custom);
463   setOperationAction(ISD::UINT_TO_FP, MVT::i128, Custom);
464   setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i32, Custom);
465   setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i64, Custom);
466   setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i128, Custom);
467   setOperationAction(ISD::FP_ROUND, MVT::f16, Custom);
468   setOperationAction(ISD::FP_ROUND, MVT::f32, Custom);
469   setOperationAction(ISD::FP_ROUND, MVT::f64, Custom);
470   setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Custom);
471   setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Custom);
472   setOperationAction(ISD::STRICT_FP_ROUND, MVT::f64, Custom);
473 
474   setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i32, Custom);
475   setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i64, Custom);
476   setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i32, Custom);
477   setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i64, Custom);
478 
479   // Variable arguments.
480   setOperationAction(ISD::VASTART, MVT::Other, Custom);
481   setOperationAction(ISD::VAARG, MVT::Other, Custom);
482   setOperationAction(ISD::VACOPY, MVT::Other, Custom);
483   setOperationAction(ISD::VAEND, MVT::Other, Expand);
484 
485   // Variable-sized objects.
486   setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
487   setOperationAction(ISD::STACKRESTORE, MVT::Other, Expand);
488 
489   if (Subtarget->isTargetWindows())
490     setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom);
491   else
492     setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Expand);
493 
494   // Constant pool entries
495   setOperationAction(ISD::ConstantPool, MVT::i64, Custom);
496 
497   // BlockAddress
498   setOperationAction(ISD::BlockAddress, MVT::i64, Custom);
499 
500   // Add/Sub overflow ops with MVT::Glues are lowered to NZCV dependences.
501   setOperationAction(ISD::ADDC, MVT::i32, Custom);
502   setOperationAction(ISD::ADDE, MVT::i32, Custom);
503   setOperationAction(ISD::SUBC, MVT::i32, Custom);
504   setOperationAction(ISD::SUBE, MVT::i32, Custom);
505   setOperationAction(ISD::ADDC, MVT::i64, Custom);
506   setOperationAction(ISD::ADDE, MVT::i64, Custom);
507   setOperationAction(ISD::SUBC, MVT::i64, Custom);
508   setOperationAction(ISD::SUBE, MVT::i64, Custom);
509 
510   // AArch64 lacks both left-rotate and popcount instructions.
511   setOperationAction(ISD::ROTL, MVT::i32, Expand);
512   setOperationAction(ISD::ROTL, MVT::i64, Expand);
513   for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
514     setOperationAction(ISD::ROTL, VT, Expand);
515     setOperationAction(ISD::ROTR, VT, Expand);
516   }
517 
518   // AArch64 doesn't have i32 MULH{S|U}.
519   setOperationAction(ISD::MULHU, MVT::i32, Expand);
520   setOperationAction(ISD::MULHS, MVT::i32, Expand);
521 
522   // AArch64 doesn't have {U|S}MUL_LOHI.
523   setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
524   setOperationAction(ISD::SMUL_LOHI, MVT::i64, Expand);
525 
526   setOperationAction(ISD::CTPOP, MVT::i32, Custom);
527   setOperationAction(ISD::CTPOP, MVT::i64, Custom);
528   setOperationAction(ISD::CTPOP, MVT::i128, Custom);
529 
530   setOperationAction(ISD::ABS, MVT::i32, Custom);
531   setOperationAction(ISD::ABS, MVT::i64, Custom);
532 
533   setOperationAction(ISD::SDIVREM, MVT::i32, Expand);
534   setOperationAction(ISD::SDIVREM, MVT::i64, Expand);
535   for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
536     setOperationAction(ISD::SDIVREM, VT, Expand);
537     setOperationAction(ISD::UDIVREM, VT, Expand);
538   }
539   setOperationAction(ISD::SREM, MVT::i32, Expand);
540   setOperationAction(ISD::SREM, MVT::i64, Expand);
541   setOperationAction(ISD::UDIVREM, MVT::i32, Expand);
542   setOperationAction(ISD::UDIVREM, MVT::i64, Expand);
543   setOperationAction(ISD::UREM, MVT::i32, Expand);
544   setOperationAction(ISD::UREM, MVT::i64, Expand);
545 
546   // Custom lower Add/Sub/Mul with overflow.
547   setOperationAction(ISD::SADDO, MVT::i32, Custom);
548   setOperationAction(ISD::SADDO, MVT::i64, Custom);
549   setOperationAction(ISD::UADDO, MVT::i32, Custom);
550   setOperationAction(ISD::UADDO, MVT::i64, Custom);
551   setOperationAction(ISD::SSUBO, MVT::i32, Custom);
552   setOperationAction(ISD::SSUBO, MVT::i64, Custom);
553   setOperationAction(ISD::USUBO, MVT::i32, Custom);
554   setOperationAction(ISD::USUBO, MVT::i64, Custom);
555   setOperationAction(ISD::SMULO, MVT::i32, Custom);
556   setOperationAction(ISD::SMULO, MVT::i64, Custom);
557   setOperationAction(ISD::UMULO, MVT::i32, Custom);
558   setOperationAction(ISD::UMULO, MVT::i64, Custom);
559 
560   setOperationAction(ISD::FSIN, MVT::f32, Expand);
561   setOperationAction(ISD::FSIN, MVT::f64, Expand);
562   setOperationAction(ISD::FCOS, MVT::f32, Expand);
563   setOperationAction(ISD::FCOS, MVT::f64, Expand);
564   setOperationAction(ISD::FPOW, MVT::f32, Expand);
565   setOperationAction(ISD::FPOW, MVT::f64, Expand);
566   setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom);
567   setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom);
568   if (Subtarget->hasFullFP16())
569     setOperationAction(ISD::FCOPYSIGN, MVT::f16, Custom);
570   else
571     setOperationAction(ISD::FCOPYSIGN, MVT::f16, Promote);
572 
573   setOperationAction(ISD::FREM,    MVT::f16,   Promote);
574   setOperationAction(ISD::FREM,    MVT::v4f16, Expand);
575   setOperationAction(ISD::FREM,    MVT::v8f16, Expand);
576   setOperationAction(ISD::FPOW,    MVT::f16,   Promote);
577   setOperationAction(ISD::FPOW,    MVT::v4f16, Expand);
578   setOperationAction(ISD::FPOW,    MVT::v8f16, Expand);
579   setOperationAction(ISD::FPOWI,   MVT::f16,   Promote);
580   setOperationAction(ISD::FPOWI,   MVT::v4f16, Expand);
581   setOperationAction(ISD::FPOWI,   MVT::v8f16, Expand);
582   setOperationAction(ISD::FCOS,    MVT::f16,   Promote);
583   setOperationAction(ISD::FCOS,    MVT::v4f16, Expand);
584   setOperationAction(ISD::FCOS,    MVT::v8f16, Expand);
585   setOperationAction(ISD::FSIN,    MVT::f16,   Promote);
586   setOperationAction(ISD::FSIN,    MVT::v4f16, Expand);
587   setOperationAction(ISD::FSIN,    MVT::v8f16, Expand);
588   setOperationAction(ISD::FSINCOS, MVT::f16,   Promote);
589   setOperationAction(ISD::FSINCOS, MVT::v4f16, Expand);
590   setOperationAction(ISD::FSINCOS, MVT::v8f16, Expand);
591   setOperationAction(ISD::FEXP,    MVT::f16,   Promote);
592   setOperationAction(ISD::FEXP,    MVT::v4f16, Expand);
593   setOperationAction(ISD::FEXP,    MVT::v8f16, Expand);
594   setOperationAction(ISD::FEXP2,   MVT::f16,   Promote);
595   setOperationAction(ISD::FEXP2,   MVT::v4f16, Expand);
596   setOperationAction(ISD::FEXP2,   MVT::v8f16, Expand);
597   setOperationAction(ISD::FLOG,    MVT::f16,   Promote);
598   setOperationAction(ISD::FLOG,    MVT::v4f16, Expand);
599   setOperationAction(ISD::FLOG,    MVT::v8f16, Expand);
600   setOperationAction(ISD::FLOG2,   MVT::f16,   Promote);
601   setOperationAction(ISD::FLOG2,   MVT::v4f16, Expand);
602   setOperationAction(ISD::FLOG2,   MVT::v8f16, Expand);
603   setOperationAction(ISD::FLOG10,  MVT::f16,   Promote);
604   setOperationAction(ISD::FLOG10,  MVT::v4f16, Expand);
605   setOperationAction(ISD::FLOG10,  MVT::v8f16, Expand);
606 
607   if (!Subtarget->hasFullFP16()) {
608     setOperationAction(ISD::SELECT,      MVT::f16,  Promote);
609     setOperationAction(ISD::SELECT_CC,   MVT::f16,  Promote);
610     setOperationAction(ISD::SETCC,       MVT::f16,  Promote);
611     setOperationAction(ISD::BR_CC,       MVT::f16,  Promote);
612     setOperationAction(ISD::FADD,        MVT::f16,  Promote);
613     setOperationAction(ISD::FSUB,        MVT::f16,  Promote);
614     setOperationAction(ISD::FMUL,        MVT::f16,  Promote);
615     setOperationAction(ISD::FDIV,        MVT::f16,  Promote);
616     setOperationAction(ISD::FMA,         MVT::f16,  Promote);
617     setOperationAction(ISD::FNEG,        MVT::f16,  Promote);
618     setOperationAction(ISD::FABS,        MVT::f16,  Promote);
619     setOperationAction(ISD::FCEIL,       MVT::f16,  Promote);
620     setOperationAction(ISD::FSQRT,       MVT::f16,  Promote);
621     setOperationAction(ISD::FFLOOR,      MVT::f16,  Promote);
622     setOperationAction(ISD::FNEARBYINT,  MVT::f16,  Promote);
623     setOperationAction(ISD::FRINT,       MVT::f16,  Promote);
624     setOperationAction(ISD::FROUND,      MVT::f16,  Promote);
625     setOperationAction(ISD::FROUNDEVEN,  MVT::f16,  Promote);
626     setOperationAction(ISD::FTRUNC,      MVT::f16,  Promote);
627     setOperationAction(ISD::FMINNUM,     MVT::f16,  Promote);
628     setOperationAction(ISD::FMAXNUM,     MVT::f16,  Promote);
629     setOperationAction(ISD::FMINIMUM,    MVT::f16,  Promote);
630     setOperationAction(ISD::FMAXIMUM,    MVT::f16,  Promote);
631 
632     // promote v4f16 to v4f32 when that is known to be safe.
633     setOperationAction(ISD::FADD,        MVT::v4f16, Promote);
634     setOperationAction(ISD::FSUB,        MVT::v4f16, Promote);
635     setOperationAction(ISD::FMUL,        MVT::v4f16, Promote);
636     setOperationAction(ISD::FDIV,        MVT::v4f16, Promote);
637     AddPromotedToType(ISD::FADD,         MVT::v4f16, MVT::v4f32);
638     AddPromotedToType(ISD::FSUB,         MVT::v4f16, MVT::v4f32);
639     AddPromotedToType(ISD::FMUL,         MVT::v4f16, MVT::v4f32);
640     AddPromotedToType(ISD::FDIV,         MVT::v4f16, MVT::v4f32);
641 
642     setOperationAction(ISD::FABS,        MVT::v4f16, Expand);
643     setOperationAction(ISD::FNEG,        MVT::v4f16, Expand);
644     setOperationAction(ISD::FROUND,      MVT::v4f16, Expand);
645     setOperationAction(ISD::FROUNDEVEN,  MVT::v4f16, Expand);
646     setOperationAction(ISD::FMA,         MVT::v4f16, Expand);
647     setOperationAction(ISD::SETCC,       MVT::v4f16, Expand);
648     setOperationAction(ISD::BR_CC,       MVT::v4f16, Expand);
649     setOperationAction(ISD::SELECT,      MVT::v4f16, Expand);
650     setOperationAction(ISD::SELECT_CC,   MVT::v4f16, Expand);
651     setOperationAction(ISD::FTRUNC,      MVT::v4f16, Expand);
652     setOperationAction(ISD::FCOPYSIGN,   MVT::v4f16, Expand);
653     setOperationAction(ISD::FFLOOR,      MVT::v4f16, Expand);
654     setOperationAction(ISD::FCEIL,       MVT::v4f16, Expand);
655     setOperationAction(ISD::FRINT,       MVT::v4f16, Expand);
656     setOperationAction(ISD::FNEARBYINT,  MVT::v4f16, Expand);
657     setOperationAction(ISD::FSQRT,       MVT::v4f16, Expand);
658 
659     setOperationAction(ISD::FABS,        MVT::v8f16, Expand);
660     setOperationAction(ISD::FADD,        MVT::v8f16, Expand);
661     setOperationAction(ISD::FCEIL,       MVT::v8f16, Expand);
662     setOperationAction(ISD::FCOPYSIGN,   MVT::v8f16, Expand);
663     setOperationAction(ISD::FDIV,        MVT::v8f16, Expand);
664     setOperationAction(ISD::FFLOOR,      MVT::v8f16, Expand);
665     setOperationAction(ISD::FMA,         MVT::v8f16, Expand);
666     setOperationAction(ISD::FMUL,        MVT::v8f16, Expand);
667     setOperationAction(ISD::FNEARBYINT,  MVT::v8f16, Expand);
668     setOperationAction(ISD::FNEG,        MVT::v8f16, Expand);
669     setOperationAction(ISD::FROUND,      MVT::v8f16, Expand);
670     setOperationAction(ISD::FROUNDEVEN,  MVT::v8f16, Expand);
671     setOperationAction(ISD::FRINT,       MVT::v8f16, Expand);
672     setOperationAction(ISD::FSQRT,       MVT::v8f16, Expand);
673     setOperationAction(ISD::FSUB,        MVT::v8f16, Expand);
674     setOperationAction(ISD::FTRUNC,      MVT::v8f16, Expand);
675     setOperationAction(ISD::SETCC,       MVT::v8f16, Expand);
676     setOperationAction(ISD::BR_CC,       MVT::v8f16, Expand);
677     setOperationAction(ISD::SELECT,      MVT::v8f16, Expand);
678     setOperationAction(ISD::SELECT_CC,   MVT::v8f16, Expand);
679     setOperationAction(ISD::FP_EXTEND,   MVT::v8f16, Expand);
680   }
681 
682   // AArch64 has implementations of a lot of rounding-like FP operations.
683   for (MVT Ty : {MVT::f32, MVT::f64}) {
684     setOperationAction(ISD::FFLOOR, Ty, Legal);
685     setOperationAction(ISD::FNEARBYINT, Ty, Legal);
686     setOperationAction(ISD::FCEIL, Ty, Legal);
687     setOperationAction(ISD::FRINT, Ty, Legal);
688     setOperationAction(ISD::FTRUNC, Ty, Legal);
689     setOperationAction(ISD::FROUND, Ty, Legal);
690     setOperationAction(ISD::FROUNDEVEN, Ty, Legal);
691     setOperationAction(ISD::FMINNUM, Ty, Legal);
692     setOperationAction(ISD::FMAXNUM, Ty, Legal);
693     setOperationAction(ISD::FMINIMUM, Ty, Legal);
694     setOperationAction(ISD::FMAXIMUM, Ty, Legal);
695     setOperationAction(ISD::LROUND, Ty, Legal);
696     setOperationAction(ISD::LLROUND, Ty, Legal);
697     setOperationAction(ISD::LRINT, Ty, Legal);
698     setOperationAction(ISD::LLRINT, Ty, Legal);
699   }
700 
701   if (Subtarget->hasFullFP16()) {
702     setOperationAction(ISD::FNEARBYINT, MVT::f16, Legal);
703     setOperationAction(ISD::FFLOOR,  MVT::f16, Legal);
704     setOperationAction(ISD::FCEIL,   MVT::f16, Legal);
705     setOperationAction(ISD::FRINT,   MVT::f16, Legal);
706     setOperationAction(ISD::FTRUNC,  MVT::f16, Legal);
707     setOperationAction(ISD::FROUND,  MVT::f16, Legal);
708     setOperationAction(ISD::FROUNDEVEN,  MVT::f16, Legal);
709     setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
710     setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
711     setOperationAction(ISD::FMINIMUM, MVT::f16, Legal);
712     setOperationAction(ISD::FMAXIMUM, MVT::f16, Legal);
713   }
714 
715   setOperationAction(ISD::PREFETCH, MVT::Other, Custom);
716 
717   setOperationAction(ISD::FLT_ROUNDS_, MVT::i32, Custom);
718   setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
719 
720   setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i128, Custom);
721   setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i32, Custom);
722   setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i64, Custom);
723   setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i32, Custom);
724   setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i64, Custom);
725 
726   // Generate outline atomics library calls only if LSE was not specified for
727   // the subtarget.
728   if (Subtarget->outlineAtomics() && !Subtarget->hasLSE()) {
729     setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i8, LibCall);
730     setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i16, LibCall);
731     setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i32, LibCall);
732     setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i64, LibCall);
733     setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i128, LibCall);
734     setOperationAction(ISD::ATOMIC_SWAP, MVT::i8, LibCall);
735     setOperationAction(ISD::ATOMIC_SWAP, MVT::i16, LibCall);
736     setOperationAction(ISD::ATOMIC_SWAP, MVT::i32, LibCall);
737     setOperationAction(ISD::ATOMIC_SWAP, MVT::i64, LibCall);
738     setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i8, LibCall);
739     setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i16, LibCall);
740     setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i32, LibCall);
741     setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i64, LibCall);
742     setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i8, LibCall);
743     setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i16, LibCall);
744     setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i32, LibCall);
745     setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i64, LibCall);
746     setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i8, LibCall);
747     setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i16, LibCall);
748     setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i32, LibCall);
749     setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i64, LibCall);
750     setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i8, LibCall);
751     setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i16, LibCall);
752     setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i32, LibCall);
753     setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i64, LibCall);
754 #define LCALLNAMES(A, B, N)                                                    \
755   setLibcallName(A##N##_RELAX, #B #N "_relax");                                \
756   setLibcallName(A##N##_ACQ, #B #N "_acq");                                    \
757   setLibcallName(A##N##_REL, #B #N "_rel");                                    \
758   setLibcallName(A##N##_ACQ_REL, #B #N "_acq_rel");
759 #define LCALLNAME4(A, B)                                                       \
760   LCALLNAMES(A, B, 1)                                                          \
761   LCALLNAMES(A, B, 2) LCALLNAMES(A, B, 4) LCALLNAMES(A, B, 8)
762 #define LCALLNAME5(A, B)                                                       \
763   LCALLNAMES(A, B, 1)                                                          \
764   LCALLNAMES(A, B, 2)                                                          \
765   LCALLNAMES(A, B, 4) LCALLNAMES(A, B, 8) LCALLNAMES(A, B, 16)
766     LCALLNAME5(RTLIB::OUTLINE_ATOMIC_CAS, __aarch64_cas)
767     LCALLNAME4(RTLIB::OUTLINE_ATOMIC_SWP, __aarch64_swp)
768     LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDADD, __aarch64_ldadd)
769     LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDSET, __aarch64_ldset)
770     LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDCLR, __aarch64_ldclr)
771     LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDEOR, __aarch64_ldeor)
772 #undef LCALLNAMES
773 #undef LCALLNAME4
774 #undef LCALLNAME5
775   }
776 
777   // 128-bit loads and stores can be done without expanding
778   setOperationAction(ISD::LOAD, MVT::i128, Custom);
779   setOperationAction(ISD::STORE, MVT::i128, Custom);
780 
781   // 256-bit non-temporal stores can be lowered to STNP. Do this as part of the
782   // custom lowering, as there are no un-paired non-temporal stores and
783   // legalization will break up 256-bit inputs.
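  // For instance, a 256-bit v8i32 non-temporal store can become a single STNP
  // of two 128-bit Q registers instead of being split into separate stores.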
784   setOperationAction(ISD::STORE, MVT::v32i8, Custom);
785   setOperationAction(ISD::STORE, MVT::v16i16, Custom);
786   setOperationAction(ISD::STORE, MVT::v16f16, Custom);
787   setOperationAction(ISD::STORE, MVT::v8i32, Custom);
788   setOperationAction(ISD::STORE, MVT::v8f32, Custom);
789   setOperationAction(ISD::STORE, MVT::v4f64, Custom);
790   setOperationAction(ISD::STORE, MVT::v4i64, Custom);
791 
792   // Lower READCYCLECOUNTER using an mrs from PMCCNTR_EL0.
793   // This requires the Performance Monitors extension.
794   if (Subtarget->hasPerfMon())
795     setOperationAction(ISD::READCYCLECOUNTER, MVT::i64, Legal);
796 
797   if (getLibcallName(RTLIB::SINCOS_STRET_F32) != nullptr &&
798       getLibcallName(RTLIB::SINCOS_STRET_F64) != nullptr) {
799     // Issue __sincos_stret if available.
800     setOperationAction(ISD::FSINCOS, MVT::f64, Custom);
801     setOperationAction(ISD::FSINCOS, MVT::f32, Custom);
802   } else {
803     setOperationAction(ISD::FSINCOS, MVT::f64, Expand);
804     setOperationAction(ISD::FSINCOS, MVT::f32, Expand);
805   }
806 
807   if (Subtarget->getTargetTriple().isOSMSVCRT()) {
808     // MSVCRT doesn't have powi; fall back to pow
809     setLibcallName(RTLIB::POWI_F32, nullptr);
810     setLibcallName(RTLIB::POWI_F64, nullptr);
811   }
812 
813   // Make floating-point constants legal for the large code model, so they don't
814   // become loads from the constant pool.
815   if (Subtarget->isTargetMachO() && TM.getCodeModel() == CodeModel::Large) {
816     setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
817     setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
818   }
819 
820   // AArch64 does not have floating-point extending loads, i1 sign-extending
821   // loads, floating-point truncating stores, or v2i32->v2i16 truncating stores.
822   for (MVT VT : MVT::fp_valuetypes()) {
823     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f16, Expand);
824     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f32, Expand);
825     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f64, Expand);
826     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f80, Expand);
827   }
828   for (MVT VT : MVT::integer_valuetypes())
829     setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Expand);
830 
831   setTruncStoreAction(MVT::f32, MVT::f16, Expand);
832   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
833   setTruncStoreAction(MVT::f64, MVT::f16, Expand);
834   setTruncStoreAction(MVT::f128, MVT::f80, Expand);
835   setTruncStoreAction(MVT::f128, MVT::f64, Expand);
836   setTruncStoreAction(MVT::f128, MVT::f32, Expand);
837   setTruncStoreAction(MVT::f128, MVT::f16, Expand);
838 
839   setOperationAction(ISD::BITCAST, MVT::i16, Custom);
840   setOperationAction(ISD::BITCAST, MVT::f16, Custom);
841   setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
842 
843   // Indexed loads and stores are supported.
844   for (unsigned im = (unsigned)ISD::PRE_INC;
845        im != (unsigned)ISD::LAST_INDEXED_MODE; ++im) {
846     setIndexedLoadAction(im, MVT::i8, Legal);
847     setIndexedLoadAction(im, MVT::i16, Legal);
848     setIndexedLoadAction(im, MVT::i32, Legal);
849     setIndexedLoadAction(im, MVT::i64, Legal);
850     setIndexedLoadAction(im, MVT::f64, Legal);
851     setIndexedLoadAction(im, MVT::f32, Legal);
852     setIndexedLoadAction(im, MVT::f16, Legal);
853     setIndexedLoadAction(im, MVT::bf16, Legal);
854     setIndexedStoreAction(im, MVT::i8, Legal);
855     setIndexedStoreAction(im, MVT::i16, Legal);
856     setIndexedStoreAction(im, MVT::i32, Legal);
857     setIndexedStoreAction(im, MVT::i64, Legal);
858     setIndexedStoreAction(im, MVT::f64, Legal);
859     setIndexedStoreAction(im, MVT::f32, Legal);
860     setIndexedStoreAction(im, MVT::f16, Legal);
861     setIndexedStoreAction(im, MVT::bf16, Legal);
862   }
863 
864   // Trap.
865   setOperationAction(ISD::TRAP, MVT::Other, Legal);
866   setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
867   setOperationAction(ISD::UBSANTRAP, MVT::Other, Legal);
868 
869   // We combine OR nodes for bitfield operations.
870   setTargetDAGCombine(ISD::OR);
871   // Try to create BICs for vector ANDs.
872   setTargetDAGCombine(ISD::AND);
873 
874   // Vector add and sub nodes may conceal a high-half opportunity.
875   // Also, try to fold ADD into CSINC/CSINV.
876   setTargetDAGCombine(ISD::ADD);
877   setTargetDAGCombine(ISD::ABS);
878   setTargetDAGCombine(ISD::SUB);
879   setTargetDAGCombine(ISD::SRL);
880   setTargetDAGCombine(ISD::XOR);
881   setTargetDAGCombine(ISD::SINT_TO_FP);
882   setTargetDAGCombine(ISD::UINT_TO_FP);
883 
884   // TODO: Do the same for FP_TO_*INT_SAT.
885   setTargetDAGCombine(ISD::FP_TO_SINT);
886   setTargetDAGCombine(ISD::FP_TO_UINT);
887   setTargetDAGCombine(ISD::FDIV);
888 
889   setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
890 
891   setTargetDAGCombine(ISD::ANY_EXTEND);
892   setTargetDAGCombine(ISD::ZERO_EXTEND);
893   setTargetDAGCombine(ISD::SIGN_EXTEND);
894   setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
895   setTargetDAGCombine(ISD::TRUNCATE);
896   setTargetDAGCombine(ISD::CONCAT_VECTORS);
897   setTargetDAGCombine(ISD::STORE);
898   if (Subtarget->supportsAddressTopByteIgnored())
899     setTargetDAGCombine(ISD::LOAD);
900 
901   setTargetDAGCombine(ISD::MUL);
902 
903   setTargetDAGCombine(ISD::SELECT);
904   setTargetDAGCombine(ISD::VSELECT);
905 
906   setTargetDAGCombine(ISD::INTRINSIC_VOID);
907   setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN);
908   setTargetDAGCombine(ISD::INSERT_VECTOR_ELT);
909   setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
910   setTargetDAGCombine(ISD::VECREDUCE_ADD);
911   setTargetDAGCombine(ISD::STEP_VECTOR);
912 
913   setTargetDAGCombine(ISD::GlobalAddress);
914 
915   // In case of strict alignment, avoid an excessive number of byte wide stores.
916   MaxStoresPerMemsetOptSize = 8;
917   MaxStoresPerMemset = Subtarget->requiresStrictAlign()
918                        ? MaxStoresPerMemsetOptSize : 32;
919 
920   MaxGluedStoresPerMemcpy = 4;
921   MaxStoresPerMemcpyOptSize = 4;
922   MaxStoresPerMemcpy = Subtarget->requiresStrictAlign()
923                        ? MaxStoresPerMemcpyOptSize : 16;
924 
925   MaxStoresPerMemmoveOptSize = MaxStoresPerMemmove = 4;
926 
927   MaxLoadsPerMemcmpOptSize = 4;
928   MaxLoadsPerMemcmp = Subtarget->requiresStrictAlign()
929                       ? MaxLoadsPerMemcmpOptSize : 8;
930 
931   setStackPointerRegisterToSaveRestore(AArch64::SP);
932 
933   setSchedulingPreference(Sched::Hybrid);
934 
935   EnableExtLdPromotion = true;
936 
937   // Set required alignment.
938   setMinFunctionAlignment(Align(4));
939   // Set preferred alignments.
940   setPrefLoopAlignment(Align(1ULL << STI.getPrefLoopLogAlignment()));
941   setPrefFunctionAlignment(Align(1ULL << STI.getPrefFunctionLogAlignment()));
942 
943   // Only change the limit for entries in a jump table if specified by
944   // the subtarget, but not at the command line.
945   unsigned MaxJT = STI.getMaximumJumpTableSize();
946   if (MaxJT && getMaximumJumpTableSize() == UINT_MAX)
947     setMaximumJumpTableSize(MaxJT);
948 
949   setHasExtractBitsInsn(true);
950 
951   setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
952 
953   if (Subtarget->hasNEON()) {
954     // FIXME: v1f64 shouldn't be legal if we can avoid it, because it leads to
955     // silliness like this:
956     setOperationAction(ISD::FABS, MVT::v1f64, Expand);
957     setOperationAction(ISD::FADD, MVT::v1f64, Expand);
958     setOperationAction(ISD::FCEIL, MVT::v1f64, Expand);
959     setOperationAction(ISD::FCOPYSIGN, MVT::v1f64, Expand);
960     setOperationAction(ISD::FCOS, MVT::v1f64, Expand);
961     setOperationAction(ISD::FDIV, MVT::v1f64, Expand);
962     setOperationAction(ISD::FFLOOR, MVT::v1f64, Expand);
963     setOperationAction(ISD::FMA, MVT::v1f64, Expand);
964     setOperationAction(ISD::FMUL, MVT::v1f64, Expand);
965     setOperationAction(ISD::FNEARBYINT, MVT::v1f64, Expand);
966     setOperationAction(ISD::FNEG, MVT::v1f64, Expand);
967     setOperationAction(ISD::FPOW, MVT::v1f64, Expand);
968     setOperationAction(ISD::FREM, MVT::v1f64, Expand);
969     setOperationAction(ISD::FROUND, MVT::v1f64, Expand);
970     setOperationAction(ISD::FROUNDEVEN, MVT::v1f64, Expand);
971     setOperationAction(ISD::FRINT, MVT::v1f64, Expand);
972     setOperationAction(ISD::FSIN, MVT::v1f64, Expand);
973     setOperationAction(ISD::FSINCOS, MVT::v1f64, Expand);
974     setOperationAction(ISD::FSQRT, MVT::v1f64, Expand);
975     setOperationAction(ISD::FSUB, MVT::v1f64, Expand);
976     setOperationAction(ISD::FTRUNC, MVT::v1f64, Expand);
977     setOperationAction(ISD::SETCC, MVT::v1f64, Expand);
978     setOperationAction(ISD::BR_CC, MVT::v1f64, Expand);
979     setOperationAction(ISD::SELECT, MVT::v1f64, Expand);
980     setOperationAction(ISD::SELECT_CC, MVT::v1f64, Expand);
981     setOperationAction(ISD::FP_EXTEND, MVT::v1f64, Expand);
982 
983     setOperationAction(ISD::FP_TO_SINT, MVT::v1i64, Expand);
984     setOperationAction(ISD::FP_TO_UINT, MVT::v1i64, Expand);
985     setOperationAction(ISD::SINT_TO_FP, MVT::v1i64, Expand);
986     setOperationAction(ISD::UINT_TO_FP, MVT::v1i64, Expand);
987     setOperationAction(ISD::FP_ROUND, MVT::v1f64, Expand);
988 
989     setOperationAction(ISD::MUL, MVT::v1i64, Expand);
990 
991     // AArch64 doesn't have direct vector -> f32 conversion instructions for
992     // elements smaller than i32, so promote the input to i32 first.
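    // For example, a v4i8 -> v4f32 conversion is performed as
    // v4i8 -> v4i32 -> v4f32 rather than in a single step.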
993     setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v4i8, MVT::v4i32);
994     setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v4i8, MVT::v4i32);
995     setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v8i8, MVT::v8i32);
996     setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v8i8, MVT::v8i32);
997     setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v16i8, MVT::v16i32);
998     setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v16i8, MVT::v16i32);
999 
1000     // Similarly, there is no direct i32 -> f64 vector conversion instruction.
1001     setOperationAction(ISD::SINT_TO_FP, MVT::v2i32, Custom);
1002     setOperationAction(ISD::UINT_TO_FP, MVT::v2i32, Custom);
1003     setOperationAction(ISD::SINT_TO_FP, MVT::v2i64, Custom);
1004     setOperationAction(ISD::UINT_TO_FP, MVT::v2i64, Custom);
1005     // Likewise, there is no direct i32 -> f16 vector conversion; set it to
1006     // Custom so the conversion happens in two steps: v4i32 -> v4f32 -> v4f16.
1007     setOperationAction(ISD::SINT_TO_FP, MVT::v4i32, Custom);
1008     setOperationAction(ISD::UINT_TO_FP, MVT::v4i32, Custom);
1009 
1010     if (Subtarget->hasFullFP16()) {
1011       setOperationAction(ISD::SINT_TO_FP, MVT::v4i16, Custom);
1012       setOperationAction(ISD::UINT_TO_FP, MVT::v4i16, Custom);
1013       setOperationAction(ISD::SINT_TO_FP, MVT::v8i16, Custom);
1014       setOperationAction(ISD::UINT_TO_FP, MVT::v8i16, Custom);
1015     } else {
1016       // When AArch64 doesn't have fullfp16 support, promote the input
1017       // to i32 first.
1018       setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v4i16, MVT::v4i32);
1019       setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v4i16, MVT::v4i32);
1020       setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v8i16, MVT::v8i32);
1021       setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v8i16, MVT::v8i32);
1022     }
1023 
1024     setOperationAction(ISD::CTLZ,       MVT::v1i64, Expand);
1025     setOperationAction(ISD::CTLZ,       MVT::v2i64, Expand);
1026     setOperationAction(ISD::BITREVERSE, MVT::v8i8, Legal);
1027     setOperationAction(ISD::BITREVERSE, MVT::v16i8, Legal);
1028 
1029     // AArch64 doesn't have MUL.2d:
1030     setOperationAction(ISD::MUL, MVT::v2i64, Expand);
1031     // Custom handling for some quad-vector types to detect MULL.
1032     setOperationAction(ISD::MUL, MVT::v8i16, Custom);
1033     setOperationAction(ISD::MUL, MVT::v4i32, Custom);
1034     setOperationAction(ISD::MUL, MVT::v2i64, Custom);
1035 
1036     // Saturates
1037     for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
1038                     MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
1039       setOperationAction(ISD::SADDSAT, VT, Legal);
1040       setOperationAction(ISD::UADDSAT, VT, Legal);
1041       setOperationAction(ISD::SSUBSAT, VT, Legal);
1042       setOperationAction(ISD::USUBSAT, VT, Legal);
1043     }
1044 
1045     // Vector reductions
1046     for (MVT VT : { MVT::v4f16, MVT::v2f32,
1047                     MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
1048       if (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16()) {
1049         setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
1050         setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
1051 
1052         setOperationAction(ISD::VECREDUCE_FADD, VT, Legal);
1053       }
1054     }
1055     for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
1056                     MVT::v16i8, MVT::v8i16, MVT::v4i32 }) {
1057       setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
1058       setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
1059       setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
1060       setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
1061       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
1062     }
1063     setOperationAction(ISD::VECREDUCE_ADD, MVT::v2i64, Custom);
1064 
1065     setOperationAction(ISD::ANY_EXTEND, MVT::v4i32, Legal);
1066     setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
1067     // Likewise, narrowing and extending vector loads/stores aren't handled
1068     // directly.
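    // For example, an extending load from v8i8 to v8i16 is marked Expand here
    // rather than being treated as a single legal wide load.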
1069     for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
1070       setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand);
1071 
1072       if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32) {
1073         setOperationAction(ISD::MULHS, VT, Legal);
1074         setOperationAction(ISD::MULHU, VT, Legal);
1075       } else {
1076         setOperationAction(ISD::MULHS, VT, Expand);
1077         setOperationAction(ISD::MULHU, VT, Expand);
1078       }
1079       setOperationAction(ISD::SMUL_LOHI, VT, Expand);
1080       setOperationAction(ISD::UMUL_LOHI, VT, Expand);
1081 
1082       setOperationAction(ISD::BSWAP, VT, Expand);
1083       setOperationAction(ISD::CTTZ, VT, Expand);
1084 
1085       for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) {
1086         setTruncStoreAction(VT, InnerVT, Expand);
1087         setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
1088         setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand);
1089         setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Expand);
1090       }
1091     }
1092 
1093     // AArch64 has implementations of a lot of rounding-like FP operations.
1094     for (MVT Ty : {MVT::v2f32, MVT::v4f32, MVT::v2f64}) {
1095       setOperationAction(ISD::FFLOOR, Ty, Legal);
1096       setOperationAction(ISD::FNEARBYINT, Ty, Legal);
1097       setOperationAction(ISD::FCEIL, Ty, Legal);
1098       setOperationAction(ISD::FRINT, Ty, Legal);
1099       setOperationAction(ISD::FTRUNC, Ty, Legal);
1100       setOperationAction(ISD::FROUND, Ty, Legal);
1101       setOperationAction(ISD::FROUNDEVEN, Ty, Legal);
1102     }
1103 
1104     if (Subtarget->hasFullFP16()) {
1105       for (MVT Ty : {MVT::v4f16, MVT::v8f16}) {
1106         setOperationAction(ISD::FFLOOR, Ty, Legal);
1107         setOperationAction(ISD::FNEARBYINT, Ty, Legal);
1108         setOperationAction(ISD::FCEIL, Ty, Legal);
1109         setOperationAction(ISD::FRINT, Ty, Legal);
1110         setOperationAction(ISD::FTRUNC, Ty, Legal);
1111         setOperationAction(ISD::FROUND, Ty, Legal);
1112         setOperationAction(ISD::FROUNDEVEN, Ty, Legal);
1113       }
1114     }
1115 
1116     if (Subtarget->hasSVE())
1117       setOperationAction(ISD::VSCALE, MVT::i32, Custom);
1118 
1119     setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
1120   }
1121 
1122   if (Subtarget->hasSVE()) {
1123     for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64}) {
1124       setOperationAction(ISD::BITREVERSE, VT, Custom);
1125       setOperationAction(ISD::BSWAP, VT, Custom);
1126       setOperationAction(ISD::CTLZ, VT, Custom);
1127       setOperationAction(ISD::CTPOP, VT, Custom);
1128       setOperationAction(ISD::CTTZ, VT, Custom);
1129       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1130       setOperationAction(ISD::UINT_TO_FP, VT, Custom);
1131       setOperationAction(ISD::SINT_TO_FP, VT, Custom);
1132       setOperationAction(ISD::FP_TO_UINT, VT, Custom);
1133       setOperationAction(ISD::FP_TO_SINT, VT, Custom);
1134       setOperationAction(ISD::MGATHER, VT, Custom);
1135       setOperationAction(ISD::MSCATTER, VT, Custom);
1136       setOperationAction(ISD::MUL, VT, Custom);
1137       setOperationAction(ISD::MULHS, VT, Custom);
1138       setOperationAction(ISD::MULHU, VT, Custom);
1139       setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1140       setOperationAction(ISD::SELECT, VT, Custom);
1141       setOperationAction(ISD::SETCC, VT, Custom);
1142       setOperationAction(ISD::SDIV, VT, Custom);
1143       setOperationAction(ISD::UDIV, VT, Custom);
1144       setOperationAction(ISD::SMIN, VT, Custom);
1145       setOperationAction(ISD::UMIN, VT, Custom);
1146       setOperationAction(ISD::SMAX, VT, Custom);
1147       setOperationAction(ISD::UMAX, VT, Custom);
1148       setOperationAction(ISD::SHL, VT, Custom);
1149       setOperationAction(ISD::SRL, VT, Custom);
1150       setOperationAction(ISD::SRA, VT, Custom);
1151       setOperationAction(ISD::ABS, VT, Custom);
1152       setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
1153       setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1154       setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1155       setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1156       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
1157       setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
1158       setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
1159       setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
1160 
1161       setOperationAction(ISD::UMUL_LOHI, VT, Expand);
1162       setOperationAction(ISD::SMUL_LOHI, VT, Expand);
1163       setOperationAction(ISD::SELECT_CC, VT, Expand);
1164       setOperationAction(ISD::ROTL, VT, Expand);
1165       setOperationAction(ISD::ROTR, VT, Expand);
1166     }
1167 
1168     // Illegal unpacked integer vector types.
1169     for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) {
1170       setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
1171       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1172     }
1173 
1174     for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) {
1175       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1176       setOperationAction(ISD::SELECT, VT, Custom);
1177       setOperationAction(ISD::SETCC, VT, Custom);
1178       setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1179       setOperationAction(ISD::TRUNCATE, VT, Custom);
1180       setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1181       setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1182       setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1183 
1184       setOperationAction(ISD::SELECT_CC, VT, Expand);
1185       setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
1186 
1187       // There are no legal MVT::nxv16f## based types.
1188       if (VT != MVT::nxv16i1) {
1189         setOperationAction(ISD::SINT_TO_FP, VT, Custom);
1190         setOperationAction(ISD::UINT_TO_FP, VT, Custom);
1191       }
1192 
1193       // NEON doesn't support masked loads or stores, but SVE does
1194       for (auto VT :
1195            {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, MVT::v1f64,
1196             MVT::v2f64, MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
1197             MVT::v2i32, MVT::v4i32, MVT::v1i64, MVT::v2i64}) {
1198         setOperationAction(ISD::MLOAD, VT, Custom);
1199         setOperationAction(ISD::MSTORE, VT, Custom);
1200       }
1201     }
1202 
1203     for (MVT VT : MVT::fp_scalable_vector_valuetypes()) {
1204       for (MVT InnerVT : MVT::fp_scalable_vector_valuetypes()) {
1205         // Avoid marking truncating FP stores as legal to prevent the
1206         // DAGCombiner from creating unsupported truncating stores.
1207         setTruncStoreAction(VT, InnerVT, Expand);
1208         // SVE does not have floating-point extending loads.
1209         setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
1210         setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand);
1211         setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Expand);
1212       }
1213     }
1214 
1215     for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32,
1216                     MVT::nxv4f32, MVT::nxv2f64}) {
1217       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1218       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1219       setOperationAction(ISD::MGATHER, VT, Custom);
1220       setOperationAction(ISD::MSCATTER, VT, Custom);
1221       setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1222       setOperationAction(ISD::SELECT, VT, Custom);
1223       setOperationAction(ISD::FADD, VT, Custom);
1224       setOperationAction(ISD::FDIV, VT, Custom);
1225       setOperationAction(ISD::FMA, VT, Custom);
1226       setOperationAction(ISD::FMAXIMUM, VT, Custom);
1227       setOperationAction(ISD::FMAXNUM, VT, Custom);
1228       setOperationAction(ISD::FMINIMUM, VT, Custom);
1229       setOperationAction(ISD::FMINNUM, VT, Custom);
1230       setOperationAction(ISD::FMUL, VT, Custom);
1231       setOperationAction(ISD::FNEG, VT, Custom);
1232       setOperationAction(ISD::FSUB, VT, Custom);
1233       setOperationAction(ISD::FCEIL, VT, Custom);
1234       setOperationAction(ISD::FFLOOR, VT, Custom);
1235       setOperationAction(ISD::FNEARBYINT, VT, Custom);
1236       setOperationAction(ISD::FRINT, VT, Custom);
1237       setOperationAction(ISD::FROUND, VT, Custom);
1238       setOperationAction(ISD::FROUNDEVEN, VT, Custom);
1239       setOperationAction(ISD::FTRUNC, VT, Custom);
1240       setOperationAction(ISD::FSQRT, VT, Custom);
1241       setOperationAction(ISD::FABS, VT, Custom);
1242       setOperationAction(ISD::FP_EXTEND, VT, Custom);
1243       setOperationAction(ISD::FP_ROUND, VT, Custom);
1244       setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
1245       setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
1246       setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
1247       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
1248 
1249       setOperationAction(ISD::SELECT_CC, VT, Expand);
1250     }
1251 
1252     for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
1253       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1254       setOperationAction(ISD::MGATHER, VT, Custom);
1255       setOperationAction(ISD::MSCATTER, VT, Custom);
1256     }
1257 
1258     setOperationAction(ISD::SPLAT_VECTOR, MVT::nxv8bf16, Custom);
1259 
1260     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
1261     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom);
1262 
1263     // NOTE: Currently this has to happen after computeRegisterProperties rather
1264     // than the preferred option of combining it with the addRegisterClass call.
1265     if (Subtarget->useSVEForFixedLengthVectors()) {
1266       for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
1267         if (useSVEForFixedLengthVectorVT(VT))
1268           addTypeForFixedLengthSVE(VT);
1269       for (MVT VT : MVT::fp_fixedlen_vector_valuetypes())
1270         if (useSVEForFixedLengthVectorVT(VT))
1271           addTypeForFixedLengthSVE(VT);
1272 
1273       // A 64-bit result can come from an input wider than a NEON register.
1274       for (auto VT : {MVT::v8i8, MVT::v4i16})
1275         setOperationAction(ISD::TRUNCATE, VT, Custom);
1276       setOperationAction(ISD::FP_ROUND, MVT::v4f16, Custom);
1277 
1278       // A 128-bit result implies an input wider than a NEON register.
1279       for (auto VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32})
1280         setOperationAction(ISD::TRUNCATE, VT, Custom);
1281       for (auto VT : {MVT::v8f16, MVT::v4f32})
1282         setOperationAction(ISD::FP_ROUND, VT, Expand);
1283 
1284       // These operations are not supported on NEON but SVE can do them.
1285       setOperationAction(ISD::BITREVERSE, MVT::v1i64, Custom);
1286       setOperationAction(ISD::CTLZ, MVT::v1i64, Custom);
1287       setOperationAction(ISD::CTLZ, MVT::v2i64, Custom);
1288       setOperationAction(ISD::CTTZ, MVT::v1i64, Custom);
1289       setOperationAction(ISD::MUL, MVT::v1i64, Custom);
1290       setOperationAction(ISD::MUL, MVT::v2i64, Custom);
1291       setOperationAction(ISD::MULHS, MVT::v1i64, Custom);
1292       setOperationAction(ISD::MULHS, MVT::v2i64, Custom);
1293       setOperationAction(ISD::MULHU, MVT::v1i64, Custom);
1294       setOperationAction(ISD::MULHU, MVT::v2i64, Custom);
1295       setOperationAction(ISD::SDIV, MVT::v8i8, Custom);
1296       setOperationAction(ISD::SDIV, MVT::v16i8, Custom);
1297       setOperationAction(ISD::SDIV, MVT::v4i16, Custom);
1298       setOperationAction(ISD::SDIV, MVT::v8i16, Custom);
1299       setOperationAction(ISD::SDIV, MVT::v2i32, Custom);
1300       setOperationAction(ISD::SDIV, MVT::v4i32, Custom);
1301       setOperationAction(ISD::SDIV, MVT::v1i64, Custom);
1302       setOperationAction(ISD::SDIV, MVT::v2i64, Custom);
1303       setOperationAction(ISD::SMAX, MVT::v1i64, Custom);
1304       setOperationAction(ISD::SMAX, MVT::v2i64, Custom);
1305       setOperationAction(ISD::SMIN, MVT::v1i64, Custom);
1306       setOperationAction(ISD::SMIN, MVT::v2i64, Custom);
1307       setOperationAction(ISD::UDIV, MVT::v8i8, Custom);
1308       setOperationAction(ISD::UDIV, MVT::v16i8, Custom);
1309       setOperationAction(ISD::UDIV, MVT::v4i16, Custom);
1310       setOperationAction(ISD::UDIV, MVT::v8i16, Custom);
1311       setOperationAction(ISD::UDIV, MVT::v2i32, Custom);
1312       setOperationAction(ISD::UDIV, MVT::v4i32, Custom);
1313       setOperationAction(ISD::UDIV, MVT::v1i64, Custom);
1314       setOperationAction(ISD::UDIV, MVT::v2i64, Custom);
1315       setOperationAction(ISD::UMAX, MVT::v1i64, Custom);
1316       setOperationAction(ISD::UMAX, MVT::v2i64, Custom);
1317       setOperationAction(ISD::UMIN, MVT::v1i64, Custom);
1318       setOperationAction(ISD::UMIN, MVT::v2i64, Custom);
1319       setOperationAction(ISD::VECREDUCE_SMAX, MVT::v2i64, Custom);
1320       setOperationAction(ISD::VECREDUCE_SMIN, MVT::v2i64, Custom);
1321       setOperationAction(ISD::VECREDUCE_UMAX, MVT::v2i64, Custom);
1322       setOperationAction(ISD::VECREDUCE_UMIN, MVT::v2i64, Custom);
1323 
1324       // Int operations with no NEON support.
1325       for (auto VT : {MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
1326                       MVT::v2i32, MVT::v4i32, MVT::v2i64}) {
1327         setOperationAction(ISD::BITREVERSE, VT, Custom);
1328         setOperationAction(ISD::CTTZ, VT, Custom);
1329         setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1330         setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1331         setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1332       }
1333 
1334       // FP operations with no NEON support.
1335       for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32,
1336                       MVT::v1f64, MVT::v2f64})
1337         setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
1338 
1339       // Use SVE for vectors with more than 2 elements.
1340       for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32})
1341         setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
1342     }
1343 
1344     setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv2i1, MVT::nxv2i64);
1345     setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv4i1, MVT::nxv4i32);
1346     setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv8i1, MVT::nxv8i16);
1347     setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv16i1, MVT::nxv16i8);
1348   }
1349 
1350   PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
1351 }
1352 
1353 void AArch64TargetLowering::addTypeForNEON(MVT VT, MVT PromotedBitwiseVT) {
1354   assert(VT.isVector() && "VT should be a vector type");
1355 
1356   if (VT.isFloatingPoint()) {
1357     MVT PromoteTo = EVT(VT).changeVectorElementTypeToInteger().getSimpleVT();
1358     setOperationPromotedToType(ISD::LOAD, VT, PromoteTo);
1359     setOperationPromotedToType(ISD::STORE, VT, PromoteTo);
1360   }
1361 
1362   // Mark vector float intrinsics as expand.
1363   if (VT == MVT::v2f32 || VT == MVT::v4f32 || VT == MVT::v2f64) {
1364     setOperationAction(ISD::FSIN, VT, Expand);
1365     setOperationAction(ISD::FCOS, VT, Expand);
1366     setOperationAction(ISD::FPOW, VT, Expand);
1367     setOperationAction(ISD::FLOG, VT, Expand);
1368     setOperationAction(ISD::FLOG2, VT, Expand);
1369     setOperationAction(ISD::FLOG10, VT, Expand);
1370     setOperationAction(ISD::FEXP, VT, Expand);
1371     setOperationAction(ISD::FEXP2, VT, Expand);
1372 
1373     // But we do support custom-lowering for FCOPYSIGN.
1374     setOperationAction(ISD::FCOPYSIGN, VT, Custom);
1375   }
1376 
1377   setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
1378   setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
1379   setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
1380   setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
1381   setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
1382   setOperationAction(ISD::SRA, VT, Custom);
1383   setOperationAction(ISD::SRL, VT, Custom);
1384   setOperationAction(ISD::SHL, VT, Custom);
1385   setOperationAction(ISD::OR, VT, Custom);
1386   setOperationAction(ISD::SETCC, VT, Custom);
1387   setOperationAction(ISD::CONCAT_VECTORS, VT, Legal);
1388 
1389   setOperationAction(ISD::SELECT, VT, Expand);
1390   setOperationAction(ISD::SELECT_CC, VT, Expand);
1391   setOperationAction(ISD::VSELECT, VT, Expand);
1392   for (MVT InnerVT : MVT::all_valuetypes())
1393     setLoadExtAction(ISD::EXTLOAD, InnerVT, VT, Expand);
1394 
1395   // CNT supports only the B element size; wider element types are custom-lowered using CNT then UADDLP to widen.
1396   if (VT != MVT::v8i8 && VT != MVT::v16i8)
1397     setOperationAction(ISD::CTPOP, VT, Custom);
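  // (Illustrative note: for example, a v4i32 CTPOP is expected to lower to a
  // byte-wise CNT followed by UADDLP steps that widen the per-byte counts up
  // to 32-bit lanes; the exact sequence is chosen by the custom lowering.)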
1398 
1399   setOperationAction(ISD::UDIV, VT, Expand);
1400   setOperationAction(ISD::SDIV, VT, Expand);
1401   setOperationAction(ISD::UREM, VT, Expand);
1402   setOperationAction(ISD::SREM, VT, Expand);
1403   setOperationAction(ISD::FREM, VT, Expand);
1404 
1405   setOperationAction(ISD::FP_TO_SINT, VT, Custom);
1406   setOperationAction(ISD::FP_TO_UINT, VT, Custom);
1407 
1408   if (!VT.isFloatingPoint())
1409     setOperationAction(ISD::ABS, VT, Legal);
1410 
1411   // [SU][MIN|MAX] are available for all NEON types apart from i64.
1412   if (!VT.isFloatingPoint() && VT != MVT::v2i64 && VT != MVT::v1i64)
1413     for (unsigned Opcode : {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX})
1414       setOperationAction(Opcode, VT, Legal);
1415 
1416   // F[MIN|MAX][NUM|NAN] are available for all FP NEON types.
1417   if (VT.isFloatingPoint() &&
1418       VT.getVectorElementType() != MVT::bf16 &&
1419       (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16()))
1420     for (unsigned Opcode :
1421          {ISD::FMINIMUM, ISD::FMAXIMUM, ISD::FMINNUM, ISD::FMAXNUM})
1422       setOperationAction(Opcode, VT, Legal);
1423 
1424   if (Subtarget->isLittleEndian()) {
1425     for (unsigned im = (unsigned)ISD::PRE_INC;
1426          im != (unsigned)ISD::LAST_INDEXED_MODE; ++im) {
1427       setIndexedLoadAction(im, VT, Legal);
1428       setIndexedStoreAction(im, VT, Legal);
1429     }
1430   }
1431 }
1432 
1433 void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
1434   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
1435 
1436   // By default everything must be expanded.
1437   for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
1438     setOperationAction(Op, VT, Expand);
1439 
1440   // We use EXTRACT_SUBVECTOR to "cast" a scalable vector to a fixed length one.
1441   setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
1442 
1443   if (VT.isFloatingPoint()) {
1444     setCondCodeAction(ISD::SETO, VT, Expand);
1445     setCondCodeAction(ISD::SETOLT, VT, Expand);
1446     setCondCodeAction(ISD::SETLT, VT, Expand);
1447     setCondCodeAction(ISD::SETOLE, VT, Expand);
1448     setCondCodeAction(ISD::SETLE, VT, Expand);
1449     setCondCodeAction(ISD::SETULT, VT, Expand);
1450     setCondCodeAction(ISD::SETULE, VT, Expand);
1451     setCondCodeAction(ISD::SETUGE, VT, Expand);
1452     setCondCodeAction(ISD::SETUGT, VT, Expand);
1453     setCondCodeAction(ISD::SETUEQ, VT, Expand);
1454     setCondCodeAction(ISD::SETUNE, VT, Expand);
1455   }
1456 
1457   // Lower fixed length vector operations to scalable equivalents.
1458   setOperationAction(ISD::ABS, VT, Custom);
1459   setOperationAction(ISD::ADD, VT, Custom);
1460   setOperationAction(ISD::AND, VT, Custom);
1461   setOperationAction(ISD::ANY_EXTEND, VT, Custom);
1462   setOperationAction(ISD::BITCAST, VT, Custom);
1463   setOperationAction(ISD::BITREVERSE, VT, Custom);
1464   setOperationAction(ISD::BSWAP, VT, Custom);
1465   setOperationAction(ISD::CTLZ, VT, Custom);
1466   setOperationAction(ISD::CTPOP, VT, Custom);
1467   setOperationAction(ISD::CTTZ, VT, Custom);
1468   setOperationAction(ISD::FABS, VT, Custom);
1469   setOperationAction(ISD::FADD, VT, Custom);
1470   setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
1471   setOperationAction(ISD::FCEIL, VT, Custom);
1472   setOperationAction(ISD::FDIV, VT, Custom);
1473   setOperationAction(ISD::FFLOOR, VT, Custom);
1474   setOperationAction(ISD::FMA, VT, Custom);
1475   setOperationAction(ISD::FMAXIMUM, VT, Custom);
1476   setOperationAction(ISD::FMAXNUM, VT, Custom);
1477   setOperationAction(ISD::FMINIMUM, VT, Custom);
1478   setOperationAction(ISD::FMINNUM, VT, Custom);
1479   setOperationAction(ISD::FMUL, VT, Custom);
1480   setOperationAction(ISD::FNEARBYINT, VT, Custom);
1481   setOperationAction(ISD::FNEG, VT, Custom);
1482   setOperationAction(ISD::FRINT, VT, Custom);
1483   setOperationAction(ISD::FROUND, VT, Custom);
1484   setOperationAction(ISD::FROUNDEVEN, VT, Custom);
1485   setOperationAction(ISD::FSQRT, VT, Custom);
1486   setOperationAction(ISD::FSUB, VT, Custom);
1487   setOperationAction(ISD::FTRUNC, VT, Custom);
1488   setOperationAction(ISD::LOAD, VT, Custom);
1489   setOperationAction(ISD::MLOAD, VT, Custom);
1490   setOperationAction(ISD::MSTORE, VT, Custom);
1491   setOperationAction(ISD::MUL, VT, Custom);
1492   setOperationAction(ISD::MULHS, VT, Custom);
1493   setOperationAction(ISD::MULHU, VT, Custom);
1494   setOperationAction(ISD::OR, VT, Custom);
1495   setOperationAction(ISD::SDIV, VT, Custom);
1496   setOperationAction(ISD::SELECT, VT, Custom);
1497   setOperationAction(ISD::SETCC, VT, Custom);
1498   setOperationAction(ISD::SHL, VT, Custom);
1499   setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
1500   setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Custom);
1501   setOperationAction(ISD::SMAX, VT, Custom);
1502   setOperationAction(ISD::SMIN, VT, Custom);
1503   setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1504   setOperationAction(ISD::SRA, VT, Custom);
1505   setOperationAction(ISD::SRL, VT, Custom);
1506   setOperationAction(ISD::STORE, VT, Custom);
1507   setOperationAction(ISD::SUB, VT, Custom);
1508   setOperationAction(ISD::TRUNCATE, VT, Custom);
1509   setOperationAction(ISD::UDIV, VT, Custom);
1510   setOperationAction(ISD::UMAX, VT, Custom);
1511   setOperationAction(ISD::UMIN, VT, Custom);
1512   setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
1513   setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1514   setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
1515   setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
1516   setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
1517   setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
1518   setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1519   setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
1520   setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
1521   setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
1522   setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
1523   setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
1524   setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1525   setOperationAction(ISD::VSELECT, VT, Custom);
1526   setOperationAction(ISD::XOR, VT, Custom);
1527   setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
1528 }
1529 
1530 void AArch64TargetLowering::addDRTypeForNEON(MVT VT) {
1531   addRegisterClass(VT, &AArch64::FPR64RegClass);
1532   addTypeForNEON(VT, MVT::v2i32);
1533 }
1534 
1535 void AArch64TargetLowering::addQRTypeForNEON(MVT VT) {
1536   addRegisterClass(VT, &AArch64::FPR128RegClass);
1537   addTypeForNEON(VT, MVT::v4i32);
1538 }
1539 
1540 EVT AArch64TargetLowering::getSetCCResultType(const DataLayout &,
1541                                               LLVMContext &C, EVT VT) const {
1542   if (!VT.isVector())
1543     return MVT::i32;
1544   if (VT.isScalableVector())
1545     return EVT::getVectorVT(C, MVT::i1, VT.getVectorElementCount());
1546   return VT.changeVectorElementTypeToInteger();
1547 }
1548 
1549 static bool optimizeLogicalImm(SDValue Op, unsigned Size, uint64_t Imm,
1550                                const APInt &Demanded,
1551                                TargetLowering::TargetLoweringOpt &TLO,
1552                                unsigned NewOpc) {
1553   uint64_t OldImm = Imm, NewImm, Enc;
1554   uint64_t Mask = ((uint64_t)(-1LL) >> (64 - Size)), OrigMask = Mask;
1555 
1556   // Return if the immediate is already all zeros, all ones, a bimm32 or a
1557   // bimm64.
1558   if (Imm == 0 || Imm == Mask ||
1559       AArch64_AM::isLogicalImmediate(Imm & Mask, Size))
1560     return false;
1561 
1562   unsigned EltSize = Size;
1563   uint64_t DemandedBits = Demanded.getZExtValue();
1564 
1565   // Clear bits that are not demanded.
1566   Imm &= DemandedBits;
1567 
1568   while (true) {
1569     // The goal here is to set the non-demanded bits in a way that minimizes
1570     // the number of transitions between 0 and 1. To achieve this goal,
1571     // we set the non-demanded bits to the value of the preceding demanded bits.
1572     // For example, if we have an immediate 0bx10xx0x1 ('x' indicates a
1573     // non-demanded bit), we copy bit0 (1) to the least significant 'x',
1574     // bit2 (0) to 'xx', and bit6 (1) to the most significant 'x'.
1575     // The final result is 0b11000011.
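    // (Illustrative trace of that example through the variables below,
    // restricted to 8 bits purely for readability; in the real code EltSize
    // starts at the full register Size.)
    //   Imm             = 0b01000001  (demanded bits only)
    //   DemandedBits    = 0b01100101
    //   NonDemandedBits = 0b10011010
    //   InvertedImm     = 0b00100100
    //   RotatedImm      = 0b00001000
    //   Sum             = 0b10100010, Carry = 0
    //   Ones            = 0b10000010
    //   NewImm          = 0b11000011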
1576     uint64_t NonDemandedBits = ~DemandedBits;
1577     uint64_t InvertedImm = ~Imm & DemandedBits;
1578     uint64_t RotatedImm =
1579         ((InvertedImm << 1) | (InvertedImm >> (EltSize - 1) & 1)) &
1580         NonDemandedBits;
1581     uint64_t Sum = RotatedImm + NonDemandedBits;
1582     bool Carry = NonDemandedBits & ~Sum & (1ULL << (EltSize - 1));
1583     uint64_t Ones = (Sum + Carry) & NonDemandedBits;
1584     NewImm = (Imm | Ones) & Mask;
1585 
1586     // If NewImm or its bitwise NOT is a shifted mask, it is a bitmask immediate
1587     // or all-ones or all-zeros, in which case we can stop searching. Otherwise,
1588     // we halve the element size and continue the search.
1589     if (isShiftedMask_64(NewImm) || isShiftedMask_64(~(NewImm | ~Mask)))
1590       break;
1591 
1592     // We cannot shrink the element size any further if it is 2-bits.
1593     if (EltSize == 2)
1594       return false;
1595 
1596     EltSize /= 2;
1597     Mask >>= EltSize;
1598     uint64_t Hi = Imm >> EltSize, DemandedBitsHi = DemandedBits >> EltSize;
1599 
1600     // Return if there is a mismatch in any of the demanded bits of Imm and Hi.
1601     if (((Imm ^ Hi) & (DemandedBits & DemandedBitsHi) & Mask) != 0)
1602       return false;
1603 
1604     // Merge the upper and lower halves of Imm and DemandedBits.
1605     Imm |= Hi;
1606     DemandedBits |= DemandedBitsHi;
1607   }
1608 
1609   ++NumOptimizedImms;
1610 
1611   // Replicate the element across the register width.
1612   while (EltSize < Size) {
1613     NewImm |= NewImm << EltSize;
1614     EltSize *= 2;
1615   }
1616 
1617   (void)OldImm;
1618   assert(((OldImm ^ NewImm) & Demanded.getZExtValue()) == 0 &&
1619          "demanded bits should never be altered");
1620   assert(OldImm != NewImm && "the new imm shouldn't be equal to the old imm");
1621 
1622   // Create the new constant immediate node.
1623   EVT VT = Op.getValueType();
1624   SDLoc DL(Op);
1625   SDValue New;
1626 
1627   // If the new constant immediate is all-zeros or all-ones, let the target
1628   // independent DAG combine optimize this node.
1629   if (NewImm == 0 || NewImm == OrigMask) {
1630     New = TLO.DAG.getNode(Op.getOpcode(), DL, VT, Op.getOperand(0),
1631                           TLO.DAG.getConstant(NewImm, DL, VT));
1632   // Otherwise, create a machine node so that target independent DAG combine
1633   // doesn't undo this optimization.
1634   } else {
1635     Enc = AArch64_AM::encodeLogicalImmediate(NewImm, Size);
1636     SDValue EncConst = TLO.DAG.getTargetConstant(Enc, DL, VT);
1637     New = SDValue(
1638         TLO.DAG.getMachineNode(NewOpc, DL, VT, Op.getOperand(0), EncConst), 0);
1639   }
1640 
1641   return TLO.CombineTo(Op, New);
1642 }
1643 
1644 bool AArch64TargetLowering::targetShrinkDemandedConstant(
1645     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
1646     TargetLoweringOpt &TLO) const {
1647   // Delay this optimization to as late as possible.
1648   if (!TLO.LegalOps)
1649     return false;
1650 
1651   if (!EnableOptimizeLogicalImm)
1652     return false;
1653 
1654   EVT VT = Op.getValueType();
1655   if (VT.isVector())
1656     return false;
1657 
1658   unsigned Size = VT.getSizeInBits();
1659   assert((Size == 32 || Size == 64) &&
1660          "i32 or i64 is expected after legalization.");
1661 
1662   // Exit early if we demand all bits.
1663   if (DemandedBits.countPopulation() == Size)
1664     return false;
1665 
1666   unsigned NewOpc;
1667   switch (Op.getOpcode()) {
1668   default:
1669     return false;
1670   case ISD::AND:
1671     NewOpc = Size == 32 ? AArch64::ANDWri : AArch64::ANDXri;
1672     break;
1673   case ISD::OR:
1674     NewOpc = Size == 32 ? AArch64::ORRWri : AArch64::ORRXri;
1675     break;
1676   case ISD::XOR:
1677     NewOpc = Size == 32 ? AArch64::EORWri : AArch64::EORXri;
1678     break;
1679   }
1680   ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
1681   if (!C)
1682     return false;
1683   uint64_t Imm = C->getZExtValue();
1684   return optimizeLogicalImm(Op, Size, Imm, DemandedBits, TLO, NewOpc);
1685 }
1686 
1687 /// computeKnownBitsForTargetNode - Determine which of the bits specified in
1688 /// Mask are known to be either zero or one and return them in Known.
1689 void AArch64TargetLowering::computeKnownBitsForTargetNode(
1690     const SDValue Op, KnownBits &Known,
1691     const APInt &DemandedElts, const SelectionDAG &DAG, unsigned Depth) const {
1692   switch (Op.getOpcode()) {
1693   default:
1694     break;
1695   case AArch64ISD::CSEL: {
1696     KnownBits Known2;
1697     Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
1698     Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
1699     Known = KnownBits::commonBits(Known, Known2);
1700     break;
1701   }
1702   case AArch64ISD::LOADgot:
1703   case AArch64ISD::ADDlow: {
1704     if (!Subtarget->isTargetILP32())
1705       break;
1706     // In ILP32 mode all valid pointers are in the low 4GB of the address-space.
1707     Known.Zero = APInt::getHighBitsSet(64, 32);
1708     break;
1709   }
1710   case ISD::INTRINSIC_W_CHAIN: {
1711     ConstantSDNode *CN = cast<ConstantSDNode>(Op->getOperand(1));
1712     Intrinsic::ID IntID = static_cast<Intrinsic::ID>(CN->getZExtValue());
1713     switch (IntID) {
1714     default: return;
1715     case Intrinsic::aarch64_ldaxr:
1716     case Intrinsic::aarch64_ldxr: {
1717       unsigned BitWidth = Known.getBitWidth();
1718       EVT VT = cast<MemIntrinsicSDNode>(Op)->getMemoryVT();
1719       unsigned MemBits = VT.getScalarSizeInBits();
1720       Known.Zero |= APInt::getHighBitsSet(BitWidth, BitWidth - MemBits);
1721       return;
1722     }
1723     }
1724     break;
1725   }
1726   case ISD::INTRINSIC_WO_CHAIN:
1727   case ISD::INTRINSIC_VOID: {
1728     unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
1729     switch (IntNo) {
1730     default:
1731       break;
1732     case Intrinsic::aarch64_neon_umaxv:
1733     case Intrinsic::aarch64_neon_uminv: {
1734       // Figure out the datatype of the vector operand. The UMINV instruction
1735       // will zero extend the result, so we can mark as known zero all the
1736       // bits larger than the element datatype. 32-bit or larger doesn't need
1737       // this as those are legal types and will be handled by isel directly.
1738       MVT VT = Op.getOperand(1).getValueType().getSimpleVT();
1739       unsigned BitWidth = Known.getBitWidth();
1740       if (VT == MVT::v8i8 || VT == MVT::v16i8) {
1741         assert(BitWidth >= 8 && "Unexpected width!");
1742         APInt Mask = APInt::getHighBitsSet(BitWidth, BitWidth - 8);
1743         Known.Zero |= Mask;
1744       } else if (VT == MVT::v4i16 || VT == MVT::v8i16) {
1745         assert(BitWidth >= 16 && "Unexpected width!");
1746         APInt Mask = APInt::getHighBitsSet(BitWidth, BitWidth - 16);
1747         Known.Zero |= Mask;
1748       }
1749       break;
1750     } break;
1751     }
1752   }
1753   }
1754 }
1755 
1756 MVT AArch64TargetLowering::getScalarShiftAmountTy(const DataLayout &DL,
1757                                                   EVT) const {
1758   return MVT::i64;
1759 }
1760 
1761 bool AArch64TargetLowering::allowsMisalignedMemoryAccesses(
1762     EVT VT, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags,
1763     bool *Fast) const {
1764   if (Subtarget->requiresStrictAlign())
1765     return false;
1766 
1767   if (Fast) {
1768     // Some CPUs are fine with unaligned stores except for 128-bit ones.
1769     *Fast = !Subtarget->isMisaligned128StoreSlow() || VT.getStoreSize() != 16 ||
1770             // See comments in performSTORECombine() for more details about
1771             // these conditions.
1772 
1773             // Code that uses clang vector extensions can mark that it
1774             // wants unaligned accesses to be treated as fast by
1775             // underspecifying alignment to be 1 or 2.
1776             Alignment <= 2 ||
1777 
1778             // Disregard v2i64. Memcpy lowering produces those and splitting
1779             // them regresses performance on micro-benchmarks and olden/bh.
1780             VT == MVT::v2i64;
1781   }
1782   return true;
1783 }
1784 
1785 // Same as above but handling LLTs instead.
1786 bool AArch64TargetLowering::allowsMisalignedMemoryAccesses(
1787     LLT Ty, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags,
1788     bool *Fast) const {
1789   if (Subtarget->requiresStrictAlign())
1790     return false;
1791 
1792   if (Fast) {
1793     // Some CPUs are fine with unaligned stores except for 128-bit ones.
1794     *Fast = !Subtarget->isMisaligned128StoreSlow() ||
1795             Ty.getSizeInBytes() != 16 ||
1796             // See comments in performSTORECombine() for more details about
1797             // these conditions.
1798 
1799             // Code that uses clang vector extensions can mark that it
1800             // wants unaligned accesses to be treated as fast by
1801             // underspecifying alignment to be 1 or 2.
1802             Alignment <= 2 ||
1803 
1804             // Disregard v2i64. Memcpy lowering produces those and splitting
1805             // them regresses performance on micro-benchmarks and olden/bh.
1806             Ty == LLT::vector(2, 64);
1807   }
1808   return true;
1809 }
1810 
1811 FastISel *
1812 AArch64TargetLowering::createFastISel(FunctionLoweringInfo &funcInfo,
1813                                       const TargetLibraryInfo *libInfo) const {
1814   return AArch64::createFastISel(funcInfo, libInfo);
1815 }
1816 
1817 const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
1818 #define MAKE_CASE(V)                                                           \
1819   case V:                                                                      \
1820     return #V;
1821   switch ((AArch64ISD::NodeType)Opcode) {
1822   case AArch64ISD::FIRST_NUMBER:
1823     break;
1824     MAKE_CASE(AArch64ISD::CALL)
1825     MAKE_CASE(AArch64ISD::ADRP)
1826     MAKE_CASE(AArch64ISD::ADR)
1827     MAKE_CASE(AArch64ISD::ADDlow)
1828     MAKE_CASE(AArch64ISD::LOADgot)
1829     MAKE_CASE(AArch64ISD::RET_FLAG)
1830     MAKE_CASE(AArch64ISD::BRCOND)
1831     MAKE_CASE(AArch64ISD::CSEL)
1832     MAKE_CASE(AArch64ISD::CSINV)
1833     MAKE_CASE(AArch64ISD::CSNEG)
1834     MAKE_CASE(AArch64ISD::CSINC)
1835     MAKE_CASE(AArch64ISD::THREAD_POINTER)
1836     MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ)
1837     MAKE_CASE(AArch64ISD::ADD_PRED)
1838     MAKE_CASE(AArch64ISD::MUL_PRED)
1839     MAKE_CASE(AArch64ISD::MULHS_PRED)
1840     MAKE_CASE(AArch64ISD::MULHU_PRED)
1841     MAKE_CASE(AArch64ISD::SDIV_PRED)
1842     MAKE_CASE(AArch64ISD::SHL_PRED)
1843     MAKE_CASE(AArch64ISD::SMAX_PRED)
1844     MAKE_CASE(AArch64ISD::SMIN_PRED)
1845     MAKE_CASE(AArch64ISD::SRA_PRED)
1846     MAKE_CASE(AArch64ISD::SRL_PRED)
1847     MAKE_CASE(AArch64ISD::SUB_PRED)
1848     MAKE_CASE(AArch64ISD::UDIV_PRED)
1849     MAKE_CASE(AArch64ISD::UMAX_PRED)
1850     MAKE_CASE(AArch64ISD::UMIN_PRED)
1851     MAKE_CASE(AArch64ISD::FNEG_MERGE_PASSTHRU)
1852     MAKE_CASE(AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU)
1853     MAKE_CASE(AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU)
1854     MAKE_CASE(AArch64ISD::FCEIL_MERGE_PASSTHRU)
1855     MAKE_CASE(AArch64ISD::FFLOOR_MERGE_PASSTHRU)
1856     MAKE_CASE(AArch64ISD::FNEARBYINT_MERGE_PASSTHRU)
1857     MAKE_CASE(AArch64ISD::FRINT_MERGE_PASSTHRU)
1858     MAKE_CASE(AArch64ISD::FROUND_MERGE_PASSTHRU)
1859     MAKE_CASE(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU)
1860     MAKE_CASE(AArch64ISD::FTRUNC_MERGE_PASSTHRU)
1861     MAKE_CASE(AArch64ISD::FP_ROUND_MERGE_PASSTHRU)
1862     MAKE_CASE(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU)
1863     MAKE_CASE(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU)
1864     MAKE_CASE(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU)
1865     MAKE_CASE(AArch64ISD::FCVTZU_MERGE_PASSTHRU)
1866     MAKE_CASE(AArch64ISD::FCVTZS_MERGE_PASSTHRU)
1867     MAKE_CASE(AArch64ISD::FSQRT_MERGE_PASSTHRU)
1868     MAKE_CASE(AArch64ISD::FRECPX_MERGE_PASSTHRU)
1869     MAKE_CASE(AArch64ISD::FABS_MERGE_PASSTHRU)
1870     MAKE_CASE(AArch64ISD::ABS_MERGE_PASSTHRU)
1871     MAKE_CASE(AArch64ISD::NEG_MERGE_PASSTHRU)
1872     MAKE_CASE(AArch64ISD::SETCC_MERGE_ZERO)
1873     MAKE_CASE(AArch64ISD::ADC)
1874     MAKE_CASE(AArch64ISD::SBC)
1875     MAKE_CASE(AArch64ISD::ADDS)
1876     MAKE_CASE(AArch64ISD::SUBS)
1877     MAKE_CASE(AArch64ISD::ADCS)
1878     MAKE_CASE(AArch64ISD::SBCS)
1879     MAKE_CASE(AArch64ISD::ANDS)
1880     MAKE_CASE(AArch64ISD::CCMP)
1881     MAKE_CASE(AArch64ISD::CCMN)
1882     MAKE_CASE(AArch64ISD::FCCMP)
1883     MAKE_CASE(AArch64ISD::FCMP)
1884     MAKE_CASE(AArch64ISD::STRICT_FCMP)
1885     MAKE_CASE(AArch64ISD::STRICT_FCMPE)
1886     MAKE_CASE(AArch64ISD::DUP)
1887     MAKE_CASE(AArch64ISD::DUPLANE8)
1888     MAKE_CASE(AArch64ISD::DUPLANE16)
1889     MAKE_CASE(AArch64ISD::DUPLANE32)
1890     MAKE_CASE(AArch64ISD::DUPLANE64)
1891     MAKE_CASE(AArch64ISD::MOVI)
1892     MAKE_CASE(AArch64ISD::MOVIshift)
1893     MAKE_CASE(AArch64ISD::MOVIedit)
1894     MAKE_CASE(AArch64ISD::MOVImsl)
1895     MAKE_CASE(AArch64ISD::FMOV)
1896     MAKE_CASE(AArch64ISD::MVNIshift)
1897     MAKE_CASE(AArch64ISD::MVNImsl)
1898     MAKE_CASE(AArch64ISD::BICi)
1899     MAKE_CASE(AArch64ISD::ORRi)
1900     MAKE_CASE(AArch64ISD::BSP)
1901     MAKE_CASE(AArch64ISD::NEG)
1902     MAKE_CASE(AArch64ISD::EXTR)
1903     MAKE_CASE(AArch64ISD::ZIP1)
1904     MAKE_CASE(AArch64ISD::ZIP2)
1905     MAKE_CASE(AArch64ISD::UZP1)
1906     MAKE_CASE(AArch64ISD::UZP2)
1907     MAKE_CASE(AArch64ISD::TRN1)
1908     MAKE_CASE(AArch64ISD::TRN2)
1909     MAKE_CASE(AArch64ISD::REV16)
1910     MAKE_CASE(AArch64ISD::REV32)
1911     MAKE_CASE(AArch64ISD::REV64)
1912     MAKE_CASE(AArch64ISD::EXT)
1913     MAKE_CASE(AArch64ISD::VSHL)
1914     MAKE_CASE(AArch64ISD::VLSHR)
1915     MAKE_CASE(AArch64ISD::VASHR)
1916     MAKE_CASE(AArch64ISD::VSLI)
1917     MAKE_CASE(AArch64ISD::VSRI)
1918     MAKE_CASE(AArch64ISD::CMEQ)
1919     MAKE_CASE(AArch64ISD::CMGE)
1920     MAKE_CASE(AArch64ISD::CMGT)
1921     MAKE_CASE(AArch64ISD::CMHI)
1922     MAKE_CASE(AArch64ISD::CMHS)
1923     MAKE_CASE(AArch64ISD::FCMEQ)
1924     MAKE_CASE(AArch64ISD::FCMGE)
1925     MAKE_CASE(AArch64ISD::FCMGT)
1926     MAKE_CASE(AArch64ISD::CMEQz)
1927     MAKE_CASE(AArch64ISD::CMGEz)
1928     MAKE_CASE(AArch64ISD::CMGTz)
1929     MAKE_CASE(AArch64ISD::CMLEz)
1930     MAKE_CASE(AArch64ISD::CMLTz)
1931     MAKE_CASE(AArch64ISD::FCMEQz)
1932     MAKE_CASE(AArch64ISD::FCMGEz)
1933     MAKE_CASE(AArch64ISD::FCMGTz)
1934     MAKE_CASE(AArch64ISD::FCMLEz)
1935     MAKE_CASE(AArch64ISD::FCMLTz)
1936     MAKE_CASE(AArch64ISD::SADDV)
1937     MAKE_CASE(AArch64ISD::UADDV)
1938     MAKE_CASE(AArch64ISD::SRHADD)
1939     MAKE_CASE(AArch64ISD::URHADD)
1940     MAKE_CASE(AArch64ISD::SHADD)
1941     MAKE_CASE(AArch64ISD::UHADD)
1942     MAKE_CASE(AArch64ISD::SDOT)
1943     MAKE_CASE(AArch64ISD::UDOT)
1944     MAKE_CASE(AArch64ISD::SMINV)
1945     MAKE_CASE(AArch64ISD::UMINV)
1946     MAKE_CASE(AArch64ISD::SMAXV)
1947     MAKE_CASE(AArch64ISD::UMAXV)
1948     MAKE_CASE(AArch64ISD::SADDV_PRED)
1949     MAKE_CASE(AArch64ISD::UADDV_PRED)
1950     MAKE_CASE(AArch64ISD::SMAXV_PRED)
1951     MAKE_CASE(AArch64ISD::UMAXV_PRED)
1952     MAKE_CASE(AArch64ISD::SMINV_PRED)
1953     MAKE_CASE(AArch64ISD::UMINV_PRED)
1954     MAKE_CASE(AArch64ISD::ORV_PRED)
1955     MAKE_CASE(AArch64ISD::EORV_PRED)
1956     MAKE_CASE(AArch64ISD::ANDV_PRED)
1957     MAKE_CASE(AArch64ISD::CLASTA_N)
1958     MAKE_CASE(AArch64ISD::CLASTB_N)
1959     MAKE_CASE(AArch64ISD::LASTA)
1960     MAKE_CASE(AArch64ISD::LASTB)
1961     MAKE_CASE(AArch64ISD::REINTERPRET_CAST)
1962     MAKE_CASE(AArch64ISD::TBL)
1963     MAKE_CASE(AArch64ISD::FADD_PRED)
1964     MAKE_CASE(AArch64ISD::FADDA_PRED)
1965     MAKE_CASE(AArch64ISD::FADDV_PRED)
1966     MAKE_CASE(AArch64ISD::FDIV_PRED)
1967     MAKE_CASE(AArch64ISD::FMA_PRED)
1968     MAKE_CASE(AArch64ISD::FMAX_PRED)
1969     MAKE_CASE(AArch64ISD::FMAXV_PRED)
1970     MAKE_CASE(AArch64ISD::FMAXNM_PRED)
1971     MAKE_CASE(AArch64ISD::FMAXNMV_PRED)
1972     MAKE_CASE(AArch64ISD::FMIN_PRED)
1973     MAKE_CASE(AArch64ISD::FMINV_PRED)
1974     MAKE_CASE(AArch64ISD::FMINNM_PRED)
1975     MAKE_CASE(AArch64ISD::FMINNMV_PRED)
1976     MAKE_CASE(AArch64ISD::FMUL_PRED)
1977     MAKE_CASE(AArch64ISD::FSUB_PRED)
1978     MAKE_CASE(AArch64ISD::BIC)
1979     MAKE_CASE(AArch64ISD::BIT)
1980     MAKE_CASE(AArch64ISD::CBZ)
1981     MAKE_CASE(AArch64ISD::CBNZ)
1982     MAKE_CASE(AArch64ISD::TBZ)
1983     MAKE_CASE(AArch64ISD::TBNZ)
1984     MAKE_CASE(AArch64ISD::TC_RETURN)
1985     MAKE_CASE(AArch64ISD::PREFETCH)
1986     MAKE_CASE(AArch64ISD::SITOF)
1987     MAKE_CASE(AArch64ISD::UITOF)
1988     MAKE_CASE(AArch64ISD::NVCAST)
1989     MAKE_CASE(AArch64ISD::MRS)
1990     MAKE_CASE(AArch64ISD::SQSHL_I)
1991     MAKE_CASE(AArch64ISD::UQSHL_I)
1992     MAKE_CASE(AArch64ISD::SRSHR_I)
1993     MAKE_CASE(AArch64ISD::URSHR_I)
1994     MAKE_CASE(AArch64ISD::SQSHLU_I)
1995     MAKE_CASE(AArch64ISD::WrapperLarge)
1996     MAKE_CASE(AArch64ISD::LD2post)
1997     MAKE_CASE(AArch64ISD::LD3post)
1998     MAKE_CASE(AArch64ISD::LD4post)
1999     MAKE_CASE(AArch64ISD::ST2post)
2000     MAKE_CASE(AArch64ISD::ST3post)
2001     MAKE_CASE(AArch64ISD::ST4post)
2002     MAKE_CASE(AArch64ISD::LD1x2post)
2003     MAKE_CASE(AArch64ISD::LD1x3post)
2004     MAKE_CASE(AArch64ISD::LD1x4post)
2005     MAKE_CASE(AArch64ISD::ST1x2post)
2006     MAKE_CASE(AArch64ISD::ST1x3post)
2007     MAKE_CASE(AArch64ISD::ST1x4post)
2008     MAKE_CASE(AArch64ISD::LD1DUPpost)
2009     MAKE_CASE(AArch64ISD::LD2DUPpost)
2010     MAKE_CASE(AArch64ISD::LD3DUPpost)
2011     MAKE_CASE(AArch64ISD::LD4DUPpost)
2012     MAKE_CASE(AArch64ISD::LD1LANEpost)
2013     MAKE_CASE(AArch64ISD::LD2LANEpost)
2014     MAKE_CASE(AArch64ISD::LD3LANEpost)
2015     MAKE_CASE(AArch64ISD::LD4LANEpost)
2016     MAKE_CASE(AArch64ISD::ST2LANEpost)
2017     MAKE_CASE(AArch64ISD::ST3LANEpost)
2018     MAKE_CASE(AArch64ISD::ST4LANEpost)
2019     MAKE_CASE(AArch64ISD::SMULL)
2020     MAKE_CASE(AArch64ISD::UMULL)
2021     MAKE_CASE(AArch64ISD::FRECPE)
2022     MAKE_CASE(AArch64ISD::FRECPS)
2023     MAKE_CASE(AArch64ISD::FRSQRTE)
2024     MAKE_CASE(AArch64ISD::FRSQRTS)
2025     MAKE_CASE(AArch64ISD::STG)
2026     MAKE_CASE(AArch64ISD::STZG)
2027     MAKE_CASE(AArch64ISD::ST2G)
2028     MAKE_CASE(AArch64ISD::STZ2G)
2029     MAKE_CASE(AArch64ISD::SUNPKHI)
2030     MAKE_CASE(AArch64ISD::SUNPKLO)
2031     MAKE_CASE(AArch64ISD::UUNPKHI)
2032     MAKE_CASE(AArch64ISD::UUNPKLO)
2033     MAKE_CASE(AArch64ISD::INSR)
2034     MAKE_CASE(AArch64ISD::PTEST)
2035     MAKE_CASE(AArch64ISD::PTRUE)
2036     MAKE_CASE(AArch64ISD::LD1_MERGE_ZERO)
2037     MAKE_CASE(AArch64ISD::LD1S_MERGE_ZERO)
2038     MAKE_CASE(AArch64ISD::LDNF1_MERGE_ZERO)
2039     MAKE_CASE(AArch64ISD::LDNF1S_MERGE_ZERO)
2040     MAKE_CASE(AArch64ISD::LDFF1_MERGE_ZERO)
2041     MAKE_CASE(AArch64ISD::LDFF1S_MERGE_ZERO)
2042     MAKE_CASE(AArch64ISD::LD1RQ_MERGE_ZERO)
2043     MAKE_CASE(AArch64ISD::LD1RO_MERGE_ZERO)
2044     MAKE_CASE(AArch64ISD::SVE_LD2_MERGE_ZERO)
2045     MAKE_CASE(AArch64ISD::SVE_LD3_MERGE_ZERO)
2046     MAKE_CASE(AArch64ISD::SVE_LD4_MERGE_ZERO)
2047     MAKE_CASE(AArch64ISD::GLD1_MERGE_ZERO)
2048     MAKE_CASE(AArch64ISD::GLD1_SCALED_MERGE_ZERO)
2049     MAKE_CASE(AArch64ISD::GLD1_SXTW_MERGE_ZERO)
2050     MAKE_CASE(AArch64ISD::GLD1_UXTW_MERGE_ZERO)
2051     MAKE_CASE(AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO)
2052     MAKE_CASE(AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO)
2053     MAKE_CASE(AArch64ISD::GLD1_IMM_MERGE_ZERO)
2054     MAKE_CASE(AArch64ISD::GLD1S_MERGE_ZERO)
2055     MAKE_CASE(AArch64ISD::GLD1S_SCALED_MERGE_ZERO)
2056     MAKE_CASE(AArch64ISD::GLD1S_SXTW_MERGE_ZERO)
2057     MAKE_CASE(AArch64ISD::GLD1S_UXTW_MERGE_ZERO)
2058     MAKE_CASE(AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO)
2059     MAKE_CASE(AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO)
2060     MAKE_CASE(AArch64ISD::GLD1S_IMM_MERGE_ZERO)
2061     MAKE_CASE(AArch64ISD::GLDFF1_MERGE_ZERO)
2062     MAKE_CASE(AArch64ISD::GLDFF1_SCALED_MERGE_ZERO)
2063     MAKE_CASE(AArch64ISD::GLDFF1_SXTW_MERGE_ZERO)
2064     MAKE_CASE(AArch64ISD::GLDFF1_UXTW_MERGE_ZERO)
2065     MAKE_CASE(AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO)
2066     MAKE_CASE(AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO)
2067     MAKE_CASE(AArch64ISD::GLDFF1_IMM_MERGE_ZERO)
2068     MAKE_CASE(AArch64ISD::GLDFF1S_MERGE_ZERO)
2069     MAKE_CASE(AArch64ISD::GLDFF1S_SCALED_MERGE_ZERO)
2070     MAKE_CASE(AArch64ISD::GLDFF1S_SXTW_MERGE_ZERO)
2071     MAKE_CASE(AArch64ISD::GLDFF1S_UXTW_MERGE_ZERO)
2072     MAKE_CASE(AArch64ISD::GLDFF1S_SXTW_SCALED_MERGE_ZERO)
2073     MAKE_CASE(AArch64ISD::GLDFF1S_UXTW_SCALED_MERGE_ZERO)
2074     MAKE_CASE(AArch64ISD::GLDFF1S_IMM_MERGE_ZERO)
2075     MAKE_CASE(AArch64ISD::GLDNT1_MERGE_ZERO)
2076     MAKE_CASE(AArch64ISD::GLDNT1_INDEX_MERGE_ZERO)
2077     MAKE_CASE(AArch64ISD::GLDNT1S_MERGE_ZERO)
2078     MAKE_CASE(AArch64ISD::ST1_PRED)
2079     MAKE_CASE(AArch64ISD::SST1_PRED)
2080     MAKE_CASE(AArch64ISD::SST1_SCALED_PRED)
2081     MAKE_CASE(AArch64ISD::SST1_SXTW_PRED)
2082     MAKE_CASE(AArch64ISD::SST1_UXTW_PRED)
2083     MAKE_CASE(AArch64ISD::SST1_SXTW_SCALED_PRED)
2084     MAKE_CASE(AArch64ISD::SST1_UXTW_SCALED_PRED)
2085     MAKE_CASE(AArch64ISD::SST1_IMM_PRED)
2086     MAKE_CASE(AArch64ISD::SSTNT1_PRED)
2087     MAKE_CASE(AArch64ISD::SSTNT1_INDEX_PRED)
2088     MAKE_CASE(AArch64ISD::LDP)
2089     MAKE_CASE(AArch64ISD::STP)
2090     MAKE_CASE(AArch64ISD::STNP)
2091     MAKE_CASE(AArch64ISD::BITREVERSE_MERGE_PASSTHRU)
2092     MAKE_CASE(AArch64ISD::BSWAP_MERGE_PASSTHRU)
2093     MAKE_CASE(AArch64ISD::CTLZ_MERGE_PASSTHRU)
2094     MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU)
2095     MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
2096     MAKE_CASE(AArch64ISD::INDEX_VECTOR)
2097     MAKE_CASE(AArch64ISD::UABD)
2098     MAKE_CASE(AArch64ISD::SABD)
2099     MAKE_CASE(AArch64ISD::CALL_RVMARKER)
2100   }
2101 #undef MAKE_CASE
2102   return nullptr;
2103 }
2104 
2105 MachineBasicBlock *
2106 AArch64TargetLowering::EmitF128CSEL(MachineInstr &MI,
2107                                     MachineBasicBlock *MBB) const {
2108   // We materialise the F128CSEL pseudo-instruction as some control flow and a
2109   // phi node:
2110 
2111   // OrigBB:
2112   //     [... previous instrs leading to comparison ...]
2113   //     b.ne TrueBB
2114   //     b EndBB
2115   // TrueBB:
2116   //     ; Fallthrough
2117   // EndBB:
2118   //     Dest = PHI [IfTrue, TrueBB], [IfFalse, OrigBB]
2119 
2120   MachineFunction *MF = MBB->getParent();
2121   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2122   const BasicBlock *LLVM_BB = MBB->getBasicBlock();
2123   DebugLoc DL = MI.getDebugLoc();
2124   MachineFunction::iterator It = ++MBB->getIterator();
2125 
2126   Register DestReg = MI.getOperand(0).getReg();
2127   Register IfTrueReg = MI.getOperand(1).getReg();
2128   Register IfFalseReg = MI.getOperand(2).getReg();
2129   unsigned CondCode = MI.getOperand(3).getImm();
2130   bool NZCVKilled = MI.getOperand(4).isKill();
2131 
2132   MachineBasicBlock *TrueBB = MF->CreateMachineBasicBlock(LLVM_BB);
2133   MachineBasicBlock *EndBB = MF->CreateMachineBasicBlock(LLVM_BB);
2134   MF->insert(It, TrueBB);
2135   MF->insert(It, EndBB);
2136 
2137   // Transfer the rest of the current basic block to EndBB.
2138   EndBB->splice(EndBB->begin(), MBB, std::next(MachineBasicBlock::iterator(MI)),
2139                 MBB->end());
2140   EndBB->transferSuccessorsAndUpdatePHIs(MBB);
2141 
2142   BuildMI(MBB, DL, TII->get(AArch64::Bcc)).addImm(CondCode).addMBB(TrueBB);
2143   BuildMI(MBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
2144   MBB->addSuccessor(TrueBB);
2145   MBB->addSuccessor(EndBB);
2146 
2147   // TrueBB falls through to the end.
2148   TrueBB->addSuccessor(EndBB);
2149 
2150   if (!NZCVKilled) {
2151     TrueBB->addLiveIn(AArch64::NZCV);
2152     EndBB->addLiveIn(AArch64::NZCV);
2153   }
2154 
2155   BuildMI(*EndBB, EndBB->begin(), DL, TII->get(AArch64::PHI), DestReg)
2156       .addReg(IfTrueReg)
2157       .addMBB(TrueBB)
2158       .addReg(IfFalseReg)
2159       .addMBB(MBB);
2160 
2161   MI.eraseFromParent();
2162   return EndBB;
2163 }
2164 
2165 MachineBasicBlock *AArch64TargetLowering::EmitLoweredCatchRet(
2166        MachineInstr &MI, MachineBasicBlock *BB) const {
2167   assert(!isAsynchronousEHPersonality(classifyEHPersonality(
2168              BB->getParent()->getFunction().getPersonalityFn())) &&
2169          "SEH does not use catchret!");
2170   return BB;
2171 }
2172 
2173 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
2174     MachineInstr &MI, MachineBasicBlock *BB) const {
2175   switch (MI.getOpcode()) {
2176   default:
2177 #ifndef NDEBUG
2178     MI.dump();
2179 #endif
2180     llvm_unreachable("Unexpected instruction for custom inserter!");
2181 
2182   case AArch64::F128CSEL:
2183     return EmitF128CSEL(MI, BB);
2184 
2185   case TargetOpcode::STACKMAP:
2186   case TargetOpcode::PATCHPOINT:
2187   case TargetOpcode::STATEPOINT:
2188     return emitPatchPoint(MI, BB);
2189 
2190   case AArch64::CATCHRET:
2191     return EmitLoweredCatchRet(MI, BB);
2192   }
2193 }
2194 
2195 //===----------------------------------------------------------------------===//
2196 // AArch64 Lowering private implementation.
2197 //===----------------------------------------------------------------------===//
2198 
2199 //===----------------------------------------------------------------------===//
2200 // Lowering Code
2201 //===----------------------------------------------------------------------===//
2202 
2203 /// isZerosVector - Check whether SDNode N is a zero-filled vector.
2204 static bool isZerosVector(const SDNode *N) {
2205   // Look through a bit convert.
2206   while (N->getOpcode() == ISD::BITCAST)
2207     N = N->getOperand(0).getNode();
2208 
2209   if (ISD::isConstantSplatVectorAllZeros(N))
2210     return true;
2211 
2212   if (N->getOpcode() != AArch64ISD::DUP)
2213     return false;
2214 
2215   auto Opnd0 = N->getOperand(0);
2216   auto *CINT = dyn_cast<ConstantSDNode>(Opnd0);
2217   auto *CFP = dyn_cast<ConstantFPSDNode>(Opnd0);
2218   return (CINT && CINT->isNullValue()) || (CFP && CFP->isZero());
2219 }
2220 
2221 /// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64
2222 /// CC
2223 static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) {
2224   switch (CC) {
2225   default:
2226     llvm_unreachable("Unknown condition code!");
2227   case ISD::SETNE:
2228     return AArch64CC::NE;
2229   case ISD::SETEQ:
2230     return AArch64CC::EQ;
2231   case ISD::SETGT:
2232     return AArch64CC::GT;
2233   case ISD::SETGE:
2234     return AArch64CC::GE;
2235   case ISD::SETLT:
2236     return AArch64CC::LT;
2237   case ISD::SETLE:
2238     return AArch64CC::LE;
2239   case ISD::SETUGT:
2240     return AArch64CC::HI;
2241   case ISD::SETUGE:
2242     return AArch64CC::HS;
2243   case ISD::SETULT:
2244     return AArch64CC::LO;
2245   case ISD::SETULE:
2246     return AArch64CC::LS;
2247   }
2248 }
2249 
2250 /// changeFPCCToAArch64CC - Convert a DAG fp condition code to an AArch64 CC.
2251 static void changeFPCCToAArch64CC(ISD::CondCode CC,
2252                                   AArch64CC::CondCode &CondCode,
2253                                   AArch64CC::CondCode &CondCode2) {
2254   CondCode2 = AArch64CC::AL;
2255   switch (CC) {
2256   default:
2257     llvm_unreachable("Unknown FP condition!");
2258   case ISD::SETEQ:
2259   case ISD::SETOEQ:
2260     CondCode = AArch64CC::EQ;
2261     break;
2262   case ISD::SETGT:
2263   case ISD::SETOGT:
2264     CondCode = AArch64CC::GT;
2265     break;
2266   case ISD::SETGE:
2267   case ISD::SETOGE:
2268     CondCode = AArch64CC::GE;
2269     break;
2270   case ISD::SETOLT:
2271     CondCode = AArch64CC::MI;
2272     break;
2273   case ISD::SETOLE:
2274     CondCode = AArch64CC::LS;
2275     break;
2276   case ISD::SETONE:
2277     CondCode = AArch64CC::MI;
2278     CondCode2 = AArch64CC::GT;
2279     break;
2280   case ISD::SETO:
2281     CondCode = AArch64CC::VC;
2282     break;
2283   case ISD::SETUO:
2284     CondCode = AArch64CC::VS;
2285     break;
2286   case ISD::SETUEQ:
2287     CondCode = AArch64CC::EQ;
2288     CondCode2 = AArch64CC::VS;
2289     break;
2290   case ISD::SETUGT:
2291     CondCode = AArch64CC::HI;
2292     break;
2293   case ISD::SETUGE:
2294     CondCode = AArch64CC::PL;
2295     break;
2296   case ISD::SETLT:
2297   case ISD::SETULT:
2298     CondCode = AArch64CC::LT;
2299     break;
2300   case ISD::SETLE:
2301   case ISD::SETULE:
2302     CondCode = AArch64CC::LE;
2303     break;
2304   case ISD::SETNE:
2305   case ISD::SETUNE:
2306     CondCode = AArch64CC::NE;
2307     break;
2308   }
2309 }
2310 
2311 /// Convert a DAG fp condition code to an AArch64 CC.
2312 /// This differs from changeFPCCToAArch64CC in that it returns cond codes that
2313 /// should be AND'ed instead of OR'ed.
2314 static void changeFPCCToANDAArch64CC(ISD::CondCode CC,
2315                                      AArch64CC::CondCode &CondCode,
2316                                      AArch64CC::CondCode &CondCode2) {
2317   CondCode2 = AArch64CC::AL;
2318   switch (CC) {
2319   default:
2320     changeFPCCToAArch64CC(CC, CondCode, CondCode2);
2321     assert(CondCode2 == AArch64CC::AL);
2322     break;
2323   case ISD::SETONE:
2324     // (a one b)
2325     // == ((a olt b) || (a ogt b))
2326     // == ((a ord b) && (a une b))
2327     CondCode = AArch64CC::VC;
2328     CondCode2 = AArch64CC::NE;
2329     break;
2330   case ISD::SETUEQ:
2331     // (a ueq b)
2332     // == ((a uno b) || (a oeq b))
2333     // == ((a ule b) && (a uge b))
2334     CondCode = AArch64CC::PL;
2335     CondCode2 = AArch64CC::LE;
2336     break;
2337   }
2338 }
2339 
2340 /// changeVectorFPCCToAArch64CC - Convert a DAG fp condition code to an AArch64
2341 /// CC usable with the vector instructions. Fewer operations are available
2342 /// without a real NZCV register, so we have to use less efficient combinations
2343 /// to get the same effect.
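/// For example, there is no unordered compare-mask form for SETUO; it is
/// handled below by emitting the SETO combination (MI and GE) and then
/// inverting the final result.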
2344 static void changeVectorFPCCToAArch64CC(ISD::CondCode CC,
2345                                         AArch64CC::CondCode &CondCode,
2346                                         AArch64CC::CondCode &CondCode2,
2347                                         bool &Invert) {
2348   Invert = false;
2349   switch (CC) {
2350   default:
2351     // Mostly the scalar mappings work fine.
2352     changeFPCCToAArch64CC(CC, CondCode, CondCode2);
2353     break;
2354   case ISD::SETUO:
2355     Invert = true;
2356     LLVM_FALLTHROUGH;
2357   case ISD::SETO:
2358     CondCode = AArch64CC::MI;
2359     CondCode2 = AArch64CC::GE;
2360     break;
2361   case ISD::SETUEQ:
2362   case ISD::SETULT:
2363   case ISD::SETULE:
2364   case ISD::SETUGT:
2365   case ISD::SETUGE:
2366     // All of the compare-mask comparisons are ordered. We can switch between
2367     // the ordered and unordered forms by a double inversion, e.g. ULE == !OGT.
2368     Invert = true;
2369     changeFPCCToAArch64CC(getSetCCInverse(CC, /* FP inverse */ MVT::f32),
2370                           CondCode, CondCode2);
2371     break;
2372   }
2373 }
2374 
2375 static bool isLegalArithImmed(uint64_t C) {
2376   // Matches AArch64DAGToDAGISel::SelectArithImmed().
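  // For illustration: 0x0 through 0xfff are legal (plain 12-bit immediates),
  // 0x1000 through 0xfff000 are legal when the low 12 bits are zero (the
  // immediate shifted left by 12), and values such as 0x1001 are not.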
2377   bool IsLegal = (C >> 12 == 0) || ((C & 0xFFFULL) == 0 && C >> 24 == 0);
2378   LLVM_DEBUG(dbgs() << "Is imm " << C
2379                     << " legal: " << (IsLegal ? "yes\n" : "no\n"));
2380   return IsLegal;
2381 }
2382 
2383 // Can a (CMP op1, (sub 0, op2)) be turned into a CMN instruction on
2384 // the grounds that "op1 - (-op2) == op1 + op2"? Not always: the C and V flags
2385 // can be set differently by this operation. It comes down to whether
2386 // "SInt(~op2)+1 == SInt(~op2+1)" (and the same for UInt). If they are equal
2387 // then everything is fine; if not, the optimization is wrong. Thus general
2388 // comparisons are only valid if op2 != 0.
2389 //
2390 // So, finally, the only LLVM-native comparisons that don't mention C and V
2391 // are SETEQ and SETNE. They're the only ones we can safely use CMN for in
2392 // the absence of information about op2.
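// For example (registers chosen arbitrarily): (CMP x0, (sub 0, x1)) can be
// emitted as "cmn x0, x1", but only when the consumer is SETEQ/SETNE, i.e.
// only the Z flag is observed.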
2393 static bool isCMN(SDValue Op, ISD::CondCode CC) {
2394   return Op.getOpcode() == ISD::SUB && isNullConstant(Op.getOperand(0)) &&
2395          (CC == ISD::SETEQ || CC == ISD::SETNE);
2396 }
2397 
2398 static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &dl,
2399                                       SelectionDAG &DAG, SDValue Chain,
2400                                       bool IsSignaling) {
2401   EVT VT = LHS.getValueType();
2402   assert(VT != MVT::f128);
2403   assert(VT != MVT::f16 && "Lowering of strict fp16 not yet implemented");
2404   unsigned Opcode =
2405       IsSignaling ? AArch64ISD::STRICT_FCMPE : AArch64ISD::STRICT_FCMP;
2406   return DAG.getNode(Opcode, dl, {VT, MVT::Other}, {Chain, LHS, RHS});
2407 }
2408 
2409 static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
2410                               const SDLoc &dl, SelectionDAG &DAG) {
2411   EVT VT = LHS.getValueType();
2412   const bool FullFP16 =
2413     static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16();
2414 
2415   if (VT.isFloatingPoint()) {
2416     assert(VT != MVT::f128);
2417     if (VT == MVT::f16 && !FullFP16) {
2418       LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
2419       RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
2420       VT = MVT::f32;
2421     }
2422     return DAG.getNode(AArch64ISD::FCMP, dl, VT, LHS, RHS);
2423   }
2424 
2425   // The CMP instruction is just an alias for SUBS, and representing it as
2426   // SUBS means that it's possible to get CSE with subtract operations.
2427   // A later phase can perform the optimization of setting the destination
2428   // register to WZR/XZR if it ends up being unused.
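  // For example, "cmp w0, w1" is the same encoding as "subs wzr, w0, w1", so
  // keeping the SUBS form lets CSE reuse an existing subtract of the same
  // operands.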
2429   unsigned Opcode = AArch64ISD::SUBS;
2430 
2431   if (isCMN(RHS, CC)) {
2432     // Can we combine a (CMP op1, (sub 0, op2)) into a CMN instruction?
2433     Opcode = AArch64ISD::ADDS;
2434     RHS = RHS.getOperand(1);
2435   } else if (isCMN(LHS, CC)) {
2436     // As we are looking for EQ/NE compares, the operands can be commuted; can
2437     // we combine a (CMP (sub 0, op1), op2) into a CMN instruction?
2438     Opcode = AArch64ISD::ADDS;
2439     LHS = LHS.getOperand(1);
2440   } else if (isNullConstant(RHS) && !isUnsignedIntSetCC(CC)) {
2441     if (LHS.getOpcode() == ISD::AND) {
2442       // Similarly, (CMP (and X, Y), 0) can be implemented with a TST
2443       // (a.k.a. ANDS) except that the flags are only guaranteed to work for one
2444       // of the signed comparisons.
2445       const SDValue ANDSNode = DAG.getNode(AArch64ISD::ANDS, dl,
2446                                            DAG.getVTList(VT, MVT_CC),
2447                                            LHS.getOperand(0),
2448                                            LHS.getOperand(1));
2449       // Replace all users of (and X, Y) with newly generated (ands X, Y)
2450       DAG.ReplaceAllUsesWith(LHS, ANDSNode);
2451       return ANDSNode.getValue(1);
2452     } else if (LHS.getOpcode() == AArch64ISD::ANDS) {
2453       // Use result of ANDS
2454       return LHS.getValue(1);
2455     }
2456   }
2457 
2458   return DAG.getNode(Opcode, dl, DAG.getVTList(VT, MVT_CC), LHS, RHS)
2459       .getValue(1);
2460 }
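// Illustrative example (editorial sketch) of the ANDS/TST path above: for
//   (setcc (and w0, w1), 0, ne)
// emitComparison() replaces the AND with an ANDS and returns its flags value,
// which the instruction selector can print as "tst w0, w1" (ANDS with the
// destination tied to WZR), avoiding a separate "cmp ..., #0".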
2461 
2462 /// \defgroup AArch64CCMP CMP;CCMP matching
2463 ///
2464 /// These functions deal with the formation of CMP;CCMP;... sequences.
2465 /// The CCMP/CCMN/FCCMP/FCCMPE instructions allow the conditional execution of
2466 /// a comparison. They set the NZCV flags to a predefined value if their
2467 /// predicate is false. This allows us to express arbitrary conjunctions, for
2468 /// example "cmp 0 (and (setCA (cmp A)) (setCB (cmp B)))"
2469 /// expressed as:
2470 ///   cmp A
2471 ///   ccmp B, inv(CB), CA
2472 ///   check for CB flags
2473 ///
2474 /// This naturally lets us implement chains of AND operations with SETCC
2475 /// operands. And we can even implement some other situations by transforming
2476 /// them:
2477 ///   - We can implement (NEG SETCC) i.e. negating a single comparison by
2478 ///     negating the flags used in a CCMP/FCCMP operations.
2479 ///   - We can negate the result of a whole chain of CMP/CCMP/FCCMP operations
2480 ///     by negating the flags we test for afterwards. i.e.
2481 ///     NEG (CMP CCMP CCCMP ...) can be implemented.
2482 ///   - Note that we can only ever negate all previously processed results.
2483 ///     What we can not implement by flipping the flags to test is a negation
2484 ///     of two sub-trees (because the negation affects all sub-trees emitted so
2485 ///     far, so the 2nd sub-tree we emit would also affect the first).
2486 /// With those tools we can implement some OR operations:
2487 ///   - (OR (SETCC A) (SETCC B)) can be implemented via:
2488 ///     NEG (AND (NEG (SETCC A)) (NEG (SETCC B)))
2489 ///   - After transforming OR to NEG/AND combinations we may be able to use NEG
2490 ///     elimination rules from earlier to implement the whole thing as a
2491 ///     CCMP/FCCMP chain.
2492 ///
2493 /// As a complete example:
2494 ///     or (or (setCA (cmp A)) (setCB (cmp B)))
2495 ///        (and (setCC (cmp C)) (setCD (cmp D)))
2496 /// can be reassociated to:
2497 ///     or (and (setCC (cmp C)) (setCD (cmp D)))
2498 ///        (or (setCA (cmp A)) (setCB (cmp B)))
2499 /// can be transformed to:
2500 ///     not (and (not (and (setCC (cmp C)) (setCD (cmp D))))
2501 ///              (and (not (setCA (cmp A))) (not (setCB (cmp B))))))
2502 /// which can be implemented as:
2503 ///   cmp C
2504 ///   ccmp D, inv(CD), CC
2505 ///   ccmp A, CA, inv(CD)
2506 ///   ccmp B, CB, inv(CA)
2507 ///   check for CB flags
2508 ///
2509 /// A counterexample is "or (and A B) (and C D)" which translates to
2510 /// not (and (not (and (not A) (not B))) (not (and (not C) (not D))))); we
2511 /// can only implement one of the inner (not) operations, but not both!
2512 /// @{
2513 
2514 /// Create a conditional comparison; Use CCMP, CCMN or FCCMP as appropriate.
2515 static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
2516                                          ISD::CondCode CC, SDValue CCOp,
2517                                          AArch64CC::CondCode Predicate,
2518                                          AArch64CC::CondCode OutCC,
2519                                          const SDLoc &DL, SelectionDAG &DAG) {
2520   unsigned Opcode = 0;
2521   const bool FullFP16 =
2522     static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16();
2523 
2524   if (LHS.getValueType().isFloatingPoint()) {
2525     assert(LHS.getValueType() != MVT::f128);
2526     if (LHS.getValueType() == MVT::f16 && !FullFP16) {
2527       LHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, LHS);
2528       RHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, RHS);
2529     }
2530     Opcode = AArch64ISD::FCCMP;
2531   } else if (RHS.getOpcode() == ISD::SUB) {
2532     SDValue SubOp0 = RHS.getOperand(0);
2533     if (isNullConstant(SubOp0) && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
2534       // See emitComparison() on why we can only do this for SETEQ and SETNE.
2535       Opcode = AArch64ISD::CCMN;
2536       RHS = RHS.getOperand(1);
2537     }
2538   }
2539   if (Opcode == 0)
2540     Opcode = AArch64ISD::CCMP;
2541 
2542   SDValue Condition = DAG.getConstant(Predicate, DL, MVT_CC);
2543   AArch64CC::CondCode InvOutCC = AArch64CC::getInvertedCondCode(OutCC);
2544   unsigned NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvOutCC);
2545   SDValue NZCVOp = DAG.getConstant(NZCV, DL, MVT::i32);
2546   return DAG.getNode(Opcode, DL, MVT_CC, LHS, RHS, NZCVOp, Condition, CCOp);
2547 }
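// Illustrative example (editorial sketch) of the NZCV immediate chosen above:
// if OutCC is EQ, then InvOutCC is NE and getNZCVToSatisfyCondCode(NE) yields
// an immediate with Z clear. Whenever the ccmp's predicate is false the flags
// are forced to that value, so the final EQ test fails, which is exactly the
// behaviour required for a short-circuited conjunction.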
2548 
2549 /// Returns true if @p Val is a tree of AND/OR/SETCC operations that can be
2550 /// expressed as a conjunction. See \ref AArch64CCMP.
2551 /// \param CanNegate    Set to true if we can negate the whole sub-tree just by
2552 ///                     changing the conditions on the SETCC tests.
2553 ///                     (this means we can call emitConjunctionRec() with
2554 ///                      Negate==true on this sub-tree)
2555 /// \param MustBeFirst  Set to true if this subtree needs to be negated and we
2556 ///                     cannot do the negation naturally. We are required to
2557 ///                     emit the subtree first in this case.
2558 /// \param WillNegate   Is true if we are called when the result of this
2559 ///                     subexpression must be negated. This happens when the
2560 ///                     outer expression is an OR. We can use this fact to know
2561 ///                     that we have a double negation (or (or ...) ...) that
2562 ///                     can be implemented for free.
2563 static bool canEmitConjunction(const SDValue Val, bool &CanNegate,
2564                                bool &MustBeFirst, bool WillNegate,
2565                                unsigned Depth = 0) {
2566   if (!Val.hasOneUse())
2567     return false;
2568   unsigned Opcode = Val->getOpcode();
2569   if (Opcode == ISD::SETCC) {
2570     if (Val->getOperand(0).getValueType() == MVT::f128)
2571       return false;
2572     CanNegate = true;
2573     MustBeFirst = false;
2574     return true;
2575   }
2576   // Protect against exponential runtime and stack overflow.
2577   if (Depth > 6)
2578     return false;
2579   if (Opcode == ISD::AND || Opcode == ISD::OR) {
2580     bool IsOR = Opcode == ISD::OR;
2581     SDValue O0 = Val->getOperand(0);
2582     SDValue O1 = Val->getOperand(1);
2583     bool CanNegateL;
2584     bool MustBeFirstL;
2585     if (!canEmitConjunction(O0, CanNegateL, MustBeFirstL, IsOR, Depth+1))
2586       return false;
2587     bool CanNegateR;
2588     bool MustBeFirstR;
2589     if (!canEmitConjunction(O1, CanNegateR, MustBeFirstR, IsOR, Depth+1))
2590       return false;
2591 
2592     if (MustBeFirstL && MustBeFirstR)
2593       return false;
2594 
2595     if (IsOR) {
2596       // For an OR expression we need to be able to naturally negate at least
2597       // one side or we cannot do the transformation at all.
2598       if (!CanNegateL && !CanNegateR)
2599         return false;
2600       // If the result of the OR will be negated and we can naturally negate
2601       // the leaves, then this sub-tree as a whole negates naturally.
2602       CanNegate = WillNegate && CanNegateL && CanNegateR;
2603       // If we cannot naturally negate the whole sub-tree, then this must be
2604       // emitted first.
2605       MustBeFirst = !CanNegate;
2606     } else {
2607       assert(Opcode == ISD::AND && "Must be OR or AND");
2608       // We cannot naturally negate an AND operation.
2609       CanNegate = false;
2610       MustBeFirst = MustBeFirstL || MustBeFirstR;
2611     }
2612     return true;
2613   }
2614   return false;
2615 }
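// Illustrative examples (editorial sketch, not from the original source):
// for (and (setcc A) (setcc B)) both leaves report CanNegate = true, so the
// AND is accepted with CanNegate = false and MustBeFirst = false. For
// (or (setcc A) (setcc B)) at the top level (WillNegate = false) the OR is
// accepted, but with MustBeFirst = true because the OR itself cannot be
// negated just by flipping its SETCC conditions.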
2616 
2617 /// Emit a conjunction or disjunction tree with the CMP/FCMP followed by a
2618 /// chain of CCMP/FCCMP ops. See @ref AArch64CCMP.
2619 /// Tries to transform the given i1-producing node @p Val into a series of
2620 /// compare and conditional compare operations. @returns an NZCV-flags-producing
2621 /// node and sets @p OutCC to the flags that should be tested, or returns
2622 /// SDValue() if the transformation was not possible.
2623 /// \p Negate is true if we want this sub-tree to be negated just by changing
2624 /// SETCC conditions.
2625 static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
2626     AArch64CC::CondCode &OutCC, bool Negate, SDValue CCOp,
2627     AArch64CC::CondCode Predicate) {
2628   // We're at a tree leaf, produce a conditional comparison operation.
2629   unsigned Opcode = Val->getOpcode();
2630   if (Opcode == ISD::SETCC) {
2631     SDValue LHS = Val->getOperand(0);
2632     SDValue RHS = Val->getOperand(1);
2633     ISD::CondCode CC = cast<CondCodeSDNode>(Val->getOperand(2))->get();
2634     bool isInteger = LHS.getValueType().isInteger();
2635     if (Negate)
2636       CC = getSetCCInverse(CC, LHS.getValueType());
2637     SDLoc DL(Val);
2638     // Determine OutCC and handle FP special case.
2639     if (isInteger) {
2640       OutCC = changeIntCCToAArch64CC(CC);
2641     } else {
2642       assert(LHS.getValueType().isFloatingPoint());
2643       AArch64CC::CondCode ExtraCC;
2644       changeFPCCToANDAArch64CC(CC, OutCC, ExtraCC);
2645       // Some floating point conditions can't be tested with a single condition
2646       // code. Construct an additional comparison in this case.
2647       if (ExtraCC != AArch64CC::AL) {
2648         SDValue ExtraCmp;
2649         if (!CCOp.getNode())
2650           ExtraCmp = emitComparison(LHS, RHS, CC, DL, DAG);
2651         else
2652           ExtraCmp = emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate,
2653                                                ExtraCC, DL, DAG);
2654         CCOp = ExtraCmp;
2655         Predicate = ExtraCC;
2656       }
2657     }
2658 
2659     // Produce a normal comparison if we are first in the chain
2660     if (!CCOp)
2661       return emitComparison(LHS, RHS, CC, DL, DAG);
2662     // Otherwise produce a ccmp.
2663     return emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate, OutCC, DL,
2664                                      DAG);
2665   }
2666   assert(Val->hasOneUse() && "Valid conjunction/disjunction tree");
2667 
2668   bool IsOR = Opcode == ISD::OR;
2669 
2670   SDValue LHS = Val->getOperand(0);
2671   bool CanNegateL;
2672   bool MustBeFirstL;
2673   bool ValidL = canEmitConjunction(LHS, CanNegateL, MustBeFirstL, IsOR);
2674   assert(ValidL && "Valid conjunction/disjunction tree");
2675   (void)ValidL;
2676 
2677   SDValue RHS = Val->getOperand(1);
2678   bool CanNegateR;
2679   bool MustBeFirstR;
2680   bool ValidR = canEmitConjunction(RHS, CanNegateR, MustBeFirstR, IsOR);
2681   assert(ValidR && "Valid conjunction/disjunction tree");
2682   (void)ValidR;
2683 
2684   // Swap sub-tree that must come first to the right side.
2685   if (MustBeFirstL) {
2686     assert(!MustBeFirstR && "Valid conjunction/disjunction tree");
2687     std::swap(LHS, RHS);
2688     std::swap(CanNegateL, CanNegateR);
2689     std::swap(MustBeFirstL, MustBeFirstR);
2690   }
2691 
2692   bool NegateR;
2693   bool NegateAfterR;
2694   bool NegateL;
2695   bool NegateAfterAll;
2696   if (Opcode == ISD::OR) {
2697     // Swap the sub-tree that we can negate naturally to the left.
2698     if (!CanNegateL) {
2699       assert(CanNegateR && "at least one side must be negatable");
2700       assert(!MustBeFirstR && "invalid conjunction/disjunction tree");
2701       assert(!Negate);
2702       std::swap(LHS, RHS);
2703       NegateR = false;
2704       NegateAfterR = true;
2705     } else {
2706       // Negate the left sub-tree if possible, otherwise negate the result.
2707       NegateR = CanNegateR;
2708       NegateAfterR = !CanNegateR;
2709     }
2710     NegateL = true;
2711     NegateAfterAll = !Negate;
2712   } else {
2713     assert(Opcode == ISD::AND && "Valid conjunction/disjunction tree");
2714     assert(!Negate && "Valid conjunction/disjunction tree");
2715 
2716     NegateL = false;
2717     NegateR = false;
2718     NegateAfterR = false;
2719     NegateAfterAll = false;
2720   }
2721 
2722   // Emit sub-trees.
2723   AArch64CC::CondCode RHSCC;
2724   SDValue CmpR = emitConjunctionRec(DAG, RHS, RHSCC, NegateR, CCOp, Predicate);
2725   if (NegateAfterR)
2726     RHSCC = AArch64CC::getInvertedCondCode(RHSCC);
2727   SDValue CmpL = emitConjunctionRec(DAG, LHS, OutCC, NegateL, CmpR, RHSCC);
2728   if (NegateAfterAll)
2729     OutCC = AArch64CC::getInvertedCondCode(OutCC);
2730   return CmpL;
2731 }
2732 
2733 /// Emit expression as a conjunction (a series of CCMP/FCCMP ops).
2734 /// In some cases this is even possible with OR operations in the expression.
2735 /// See \ref AArch64CCMP.
2736 /// \see emitConjunctionRec().
2737 static SDValue emitConjunction(SelectionDAG &DAG, SDValue Val,
2738                                AArch64CC::CondCode &OutCC) {
2739   bool DummyCanNegate;
2740   bool DummyMustBeFirst;
2741   if (!canEmitConjunction(Val, DummyCanNegate, DummyMustBeFirst, false))
2742     return SDValue();
2743 
2744   return emitConjunctionRec(DAG, Val, OutCC, false, SDValue(), AArch64CC::AL);
2745 }
2746 
2747 /// @}
2748 
2749 /// Returns how profitable it is to fold a comparison's operand's shift and/or
2750 /// extension operations.
2751 static unsigned getCmpOperandFoldingProfit(SDValue Op) {
2752   auto isSupportedExtend = [&](SDValue V) {
2753     if (V.getOpcode() == ISD::SIGN_EXTEND_INREG)
2754       return true;
2755 
2756     if (V.getOpcode() == ISD::AND)
2757       if (ConstantSDNode *MaskCst = dyn_cast<ConstantSDNode>(V.getOperand(1))) {
2758         uint64_t Mask = MaskCst->getZExtValue();
2759         return (Mask == 0xFF || Mask == 0xFFFF || Mask == 0xFFFFFFFF);
2760       }
2761 
2762     return false;
2763   };
2764 
2765   if (!Op.hasOneUse())
2766     return 0;
2767 
2768   if (isSupportedExtend(Op))
2769     return 1;
2770 
2771   unsigned Opc = Op.getOpcode();
2772   if (Opc == ISD::SHL || Opc == ISD::SRL || Opc == ISD::SRA)
2773     if (ConstantSDNode *ShiftCst = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
2774       uint64_t Shift = ShiftCst->getZExtValue();
2775       if (isSupportedExtend(Op.getOperand(0)))
2776         return (Shift <= 4) ? 2 : 1;
2777       EVT VT = Op.getValueType();
2778       if ((VT == MVT::i32 && Shift <= 31) || (VT == MVT::i64 && Shift <= 63))
2779         return 1;
2780     }
2781 
2782   return 0;
2783 }
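// Illustrative profit values (editorial sketch), assuming each operand has a
// single use:
//   (sign_extend_inreg X, i8)        -> 1   (extend folds into the compare)
//   (shl (and X, 0xFF), 2)           -> 2   (extend plus small shift)
//   (shl X, 3) with X : i32          -> 1   (plain shifted-register form)
//   X with no foldable extend/shift  -> 0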
2784 
2785 static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
2786                              SDValue &AArch64cc, SelectionDAG &DAG,
2787                              const SDLoc &dl) {
2788   if (ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS.getNode())) {
2789     EVT VT = RHS.getValueType();
2790     uint64_t C = RHSC->getZExtValue();
2791     if (!isLegalArithImmed(C)) {
2792       // Constant does not fit, try adjusting it by one?
2793       switch (CC) {
2794       default:
2795         break;
2796       case ISD::SETLT:
2797       case ISD::SETGE:
2798         if ((VT == MVT::i32 && C != 0x80000000 &&
2799              isLegalArithImmed((uint32_t)(C - 1))) ||
2800             (VT == MVT::i64 && C != 0x80000000ULL &&
2801              isLegalArithImmed(C - 1ULL))) {
2802           CC = (CC == ISD::SETLT) ? ISD::SETLE : ISD::SETGT;
2803           C = (VT == MVT::i32) ? (uint32_t)(C - 1) : C - 1;
2804           RHS = DAG.getConstant(C, dl, VT);
2805         }
2806         break;
2807       case ISD::SETULT:
2808       case ISD::SETUGE:
2809         if ((VT == MVT::i32 && C != 0 &&
2810              isLegalArithImmed((uint32_t)(C - 1))) ||
2811             (VT == MVT::i64 && C != 0ULL && isLegalArithImmed(C - 1ULL))) {
2812           CC = (CC == ISD::SETULT) ? ISD::SETULE : ISD::SETUGT;
2813           C = (VT == MVT::i32) ? (uint32_t)(C - 1) : C - 1;
2814           RHS = DAG.getConstant(C, dl, VT);
2815         }
2816         break;
2817       case ISD::SETLE:
2818       case ISD::SETGT:
2819         if ((VT == MVT::i32 && C != INT32_MAX &&
2820              isLegalArithImmed((uint32_t)(C + 1))) ||
2821             (VT == MVT::i64 && C != INT64_MAX &&
2822              isLegalArithImmed(C + 1ULL))) {
2823           CC = (CC == ISD::SETLE) ? ISD::SETLT : ISD::SETGE;
2824           C = (VT == MVT::i32) ? (uint32_t)(C + 1) : C + 1;
2825           RHS = DAG.getConstant(C, dl, VT);
2826         }
2827         break;
2828       case ISD::SETULE:
2829       case ISD::SETUGT:
2830         if ((VT == MVT::i32 && C != UINT32_MAX &&
2831              isLegalArithImmed((uint32_t)(C + 1))) ||
2832             (VT == MVT::i64 && C != UINT64_MAX &&
2833              isLegalArithImmed(C + 1ULL))) {
2834           CC = (CC == ISD::SETULE) ? ISD::SETULT : ISD::SETUGE;
2835           C = (VT == MVT::i32) ? (uint32_t)(C + 1) : C + 1;
2836           RHS = DAG.getConstant(C, dl, VT);
2837         }
2838         break;
2839       }
2840     }
2841   }
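  // Illustrative example (editorial sketch) of the adjustment above: for
  // "x < 0x1001" the immediate 0x1001 is not encodable, but 0x1000 is, so the
  // comparison is rewritten as "x <= 0x1000" (SETLT becomes SETLE) and the
  // cmp can use the immediate form instead of materializing the constant.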
2842 
2843   // Comparisons are canonicalized so that the RHS operand is simpler than the
2844   // LHS one, the extreme case being when RHS is an immediate. However, AArch64
2845   // can fold some shift+extend operations on the RHS operand, so swap the
2846   // operands if that can be done.
2847   //
2848   // For example:
2849   //    lsl     w13, w11, #1
2850   //    cmp     w13, w12
2851   // can be turned into:
2852   //    cmp     w12, w11, lsl #1
2853   if (!isa<ConstantSDNode>(RHS) ||
2854       !isLegalArithImmed(cast<ConstantSDNode>(RHS)->getZExtValue())) {
2855     SDValue TheLHS = isCMN(LHS, CC) ? LHS.getOperand(1) : LHS;
2856 
2857     if (getCmpOperandFoldingProfit(TheLHS) > getCmpOperandFoldingProfit(RHS)) {
2858       std::swap(LHS, RHS);
2859       CC = ISD::getSetCCSwappedOperands(CC);
2860     }
2861   }
2862 
2863   SDValue Cmp;
2864   AArch64CC::CondCode AArch64CC;
2865   if ((CC == ISD::SETEQ || CC == ISD::SETNE) && isa<ConstantSDNode>(RHS)) {
2866     const ConstantSDNode *RHSC = cast<ConstantSDNode>(RHS);
2867 
2868     // The imm operand of ADDS is an unsigned immediate, in the range 0 to 4095.
2869     // For the i8 operand, the largest immediate is 255, so this can be easily
2870     // encoded in the compare instruction. For the i16 operand, however, the
2871     // largest immediate cannot be encoded in the compare.
2872     // Therefore, use a sign extending load and cmn to avoid materializing the
2873     // -1 constant. For example,
2874     // movz w1, #65535
2875     // ldrh w0, [x0, #0]
2876     // cmp w0, w1
2877     // >
2878     // ldrsh w0, [x0, #0]
2879     // cmn w0, #1
2880     // Fundamentally, we're relying on the property that (zext LHS) == (zext RHS)
2881     // if and only if (sext LHS) == (sext RHS). The checks are in place to
2882     // ensure both the LHS and RHS are truly zero extended and to make sure the
2883     // transformation is profitable.
2884     if ((RHSC->getZExtValue() >> 16 == 0) && isa<LoadSDNode>(LHS) &&
2885         cast<LoadSDNode>(LHS)->getExtensionType() == ISD::ZEXTLOAD &&
2886         cast<LoadSDNode>(LHS)->getMemoryVT() == MVT::i16 &&
2887         LHS.getNode()->hasNUsesOfValue(1, 0)) {
2888       int16_t ValueofRHS = cast<ConstantSDNode>(RHS)->getZExtValue();
2889       if (ValueofRHS < 0 && isLegalArithImmed(-ValueofRHS)) {
2890         SDValue SExt =
2891             DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, LHS.getValueType(), LHS,
2892                         DAG.getValueType(MVT::i16));
2893         Cmp = emitComparison(SExt, DAG.getConstant(ValueofRHS, dl,
2894                                                    RHS.getValueType()),
2895                              CC, dl, DAG);
2896         AArch64CC = changeIntCCToAArch64CC(CC);
2897       }
2898     }
2899 
2900     if (!Cmp && (RHSC->isNullValue() || RHSC->isOne())) {
2901       if ((Cmp = emitConjunction(DAG, LHS, AArch64CC))) {
2902         if ((CC == ISD::SETNE) ^ RHSC->isNullValue())
2903           AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
2904       }
2905     }
2906   }
2907 
2908   if (!Cmp) {
2909     Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
2910     AArch64CC = changeIntCCToAArch64CC(CC);
2911   }
2912   AArch64cc = DAG.getConstant(AArch64CC, dl, MVT_CC);
2913   return Cmp;
2914 }
2915 
2916 static std::pair<SDValue, SDValue>
2917 getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) {
2918   assert((Op.getValueType() == MVT::i32 || Op.getValueType() == MVT::i64) &&
2919          "Unsupported value type");
2920   SDValue Value, Overflow;
2921   SDLoc DL(Op);
2922   SDValue LHS = Op.getOperand(0);
2923   SDValue RHS = Op.getOperand(1);
2924   unsigned Opc = 0;
2925   switch (Op.getOpcode()) {
2926   default:
2927     llvm_unreachable("Unknown overflow instruction!");
2928   case ISD::SADDO:
2929     Opc = AArch64ISD::ADDS;
2930     CC = AArch64CC::VS;
2931     break;
2932   case ISD::UADDO:
2933     Opc = AArch64ISD::ADDS;
2934     CC = AArch64CC::HS;
2935     break;
2936   case ISD::SSUBO:
2937     Opc = AArch64ISD::SUBS;
2938     CC = AArch64CC::VS;
2939     break;
2940   case ISD::USUBO:
2941     Opc = AArch64ISD::SUBS;
2942     CC = AArch64CC::LO;
2943     break;
2944   // Multiply needs a little bit extra work.
2945   case ISD::SMULO:
2946   case ISD::UMULO: {
2947     CC = AArch64CC::NE;
2948     bool IsSigned = Op.getOpcode() == ISD::SMULO;
2949     if (Op.getValueType() == MVT::i32) {
2950       unsigned ExtendOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
2951       // For a 32 bit multiply with overflow check we want the instruction
2952       // selector to generate a widening multiply (SMADDL/UMADDL). For that we
2953       // need to generate the following pattern:
2954       // (i64 add 0, (i64 mul (i64 sext|zext i32 %a), (i64 sext|zext i32 %b)))
2955       LHS = DAG.getNode(ExtendOpc, DL, MVT::i64, LHS);
2956       RHS = DAG.getNode(ExtendOpc, DL, MVT::i64, RHS);
2957       SDValue Mul = DAG.getNode(ISD::MUL, DL, MVT::i64, LHS, RHS);
2958       SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::i64, Mul,
2959                                 DAG.getConstant(0, DL, MVT::i64));
2960       // On AArch64 the upper 32 bits are always zero extended for a 32 bit
2961       // operation. We need to clear out the upper 32 bits, because we used a
2962       // widening multiply that wrote all 64 bits. In the end this should be a
2963       // noop.
2964       Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Add);
2965       if (IsSigned) {
2966         // The signed overflow check requires more than just a simple check for
2967         // any bit set in the upper 32 bits of the result. These bits could be
2968         // just the sign bits of a negative number. To perform the overflow
2969         // check we arithmetic-shift-right the lower 32 bits of the result by
2970         // 31 and compare that word against the upper 32 bits.
2971         SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Add,
2972                                         DAG.getConstant(32, DL, MVT::i64));
2973         UpperBits = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, UpperBits);
2974         SDValue LowerBits = DAG.getNode(ISD::SRA, DL, MVT::i32, Value,
2975                                         DAG.getConstant(31, DL, MVT::i64));
2976         // It is important that LowerBits is last, otherwise the arithmetic
2977         // shift will not be folded into the compare (SUBS).
2978         SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i32);
2979         Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, UpperBits, LowerBits)
2980                        .getValue(1);
2981       } else {
2982         // The overflow check for unsigned multiply is easy. We only need to
2983         // check if any of the upper 32 bits are set. This can be done with a
2984         // CMP (shifted register). For that we need to generate the following
2985         // pattern:
2986         // (i64 AArch64ISD::SUBS i64 0, (i64 srl i64 %Mul, i64 32))
2987         SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Mul,
2988                                         DAG.getConstant(32, DL, MVT::i64));
2989         SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
2990         Overflow =
2991             DAG.getNode(AArch64ISD::SUBS, DL, VTs,
2992                         DAG.getConstant(0, DL, MVT::i64),
2993                         UpperBits).getValue(1);
2994       }
2995       break;
2996     }
2997     assert(Op.getValueType() == MVT::i64 && "Expected an i64 value type");
2998     // For the 64 bit multiply
2999     Value = DAG.getNode(ISD::MUL, DL, MVT::i64, LHS, RHS);
3000     if (IsSigned) {
3001       SDValue UpperBits = DAG.getNode(ISD::MULHS, DL, MVT::i64, LHS, RHS);
3002       SDValue LowerBits = DAG.getNode(ISD::SRA, DL, MVT::i64, Value,
3003                                       DAG.getConstant(63, DL, MVT::i64));
3004       // It is important that LowerBits is last, otherwise the arithmetic
3005       // shift will not be folded into the compare (SUBS).
3006       SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
3007       Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, UpperBits, LowerBits)
3008                      .getValue(1);
3009     } else {
3010       SDValue UpperBits = DAG.getNode(ISD::MULHU, DL, MVT::i64, LHS, RHS);
3011       SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
3012       Overflow =
3013           DAG.getNode(AArch64ISD::SUBS, DL, VTs,
3014                       DAG.getConstant(0, DL, MVT::i64),
3015                       UpperBits).getValue(1);
3016     }
3017     break;
3018   }
3019   } // switch (...)
3020 
3021   if (Opc) {
3022     SDVTList VTs = DAG.getVTList(Op->getValueType(0), MVT::i32);
3023 
3024     // Emit the AArch64 operation with overflow check.
3025     Value = DAG.getNode(Opc, DL, VTs, LHS, RHS);
3026     Overflow = Value.getValue(1);
3027   }
3028   return std::make_pair(Value, Overflow);
3029 }
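// Illustrative example (editorial sketch): for (uaddo x, y) the pair returned
// above is (ADDS x, y) for the value and its NZCV result for the overflow,
// with CC = HS, i.e. unsigned overflow is simply the carry flag; callers such
// as LowerXALUO() then turn the flag into a 0/1 value with a CSEL/CSINC.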
3030 
3031 SDValue AArch64TargetLowering::LowerXOR(SDValue Op, SelectionDAG &DAG) const {
3032   if (useSVEForFixedLengthVectorVT(Op.getValueType()))
3033     return LowerToScalableOp(Op, DAG);
3034 
3035   SDValue Sel = Op.getOperand(0);
3036   SDValue Other = Op.getOperand(1);
3037   SDLoc dl(Sel);
3038 
3039   // If the operand is an overflow checking operation, invert the condition
3040   // code and kill the Not operation. I.e., transform:
3041   // (xor overflow_op_bool, 1)
3042   //   -->
3043   // (csel 1, 0, invert(cc), overflow_op_bool)
3044   // ... which later gets transformed to just a cset instruction with an
3045   // inverted condition code, rather than a cset + eor sequence.
3046   if (isOneConstant(Other) && ISD::isOverflowIntrOpRes(Sel)) {
3047     // Only lower legal XALUO ops.
3048     if (!DAG.getTargetLoweringInfo().isTypeLegal(Sel->getValueType(0)))
3049       return SDValue();
3050 
3051     SDValue TVal = DAG.getConstant(1, dl, MVT::i32);
3052     SDValue FVal = DAG.getConstant(0, dl, MVT::i32);
3053     AArch64CC::CondCode CC;
3054     SDValue Value, Overflow;
3055     std::tie(Value, Overflow) = getAArch64XALUOOp(CC, Sel.getValue(0), DAG);
3056     SDValue CCVal = DAG.getConstant(getInvertedCondCode(CC), dl, MVT::i32);
3057     return DAG.getNode(AArch64ISD::CSEL, dl, Op.getValueType(), TVal, FVal,
3058                        CCVal, Overflow);
3059   }
3060   // If neither operand is a SELECT_CC, give up.
3061   if (Sel.getOpcode() != ISD::SELECT_CC)
3062     std::swap(Sel, Other);
3063   if (Sel.getOpcode() != ISD::SELECT_CC)
3064     return Op;
3065 
3066   // The folding we want to perform is:
3067   // (xor x, (select_cc a, b, cc, 0, -1) )
3068   //   -->
3069   // (csel x, (xor x, -1), cc ...)
3070   //
3071   // The latter will get matched to a CSINV instruction.
3072 
3073   ISD::CondCode CC = cast<CondCodeSDNode>(Sel.getOperand(4))->get();
3074   SDValue LHS = Sel.getOperand(0);
3075   SDValue RHS = Sel.getOperand(1);
3076   SDValue TVal = Sel.getOperand(2);
3077   SDValue FVal = Sel.getOperand(3);
3078 
3079   // FIXME: This could be generalized to non-integer comparisons.
3080   if (LHS.getValueType() != MVT::i32 && LHS.getValueType() != MVT::i64)
3081     return Op;
3082 
3083   ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal);
3084   ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal);
3085 
3086   // The values aren't constants, this isn't the pattern we're looking for.
3087   if (!CFVal || !CTVal)
3088     return Op;
3089 
3090   // We can commute the SELECT_CC by inverting the condition.  This
3091   // might be needed to make this fit into a CSINV pattern.
3092   if (CTVal->isAllOnesValue() && CFVal->isNullValue()) {
3093     std::swap(TVal, FVal);
3094     std::swap(CTVal, CFVal);
3095     CC = ISD::getSetCCInverse(CC, LHS.getValueType());
3096   }
3097 
3098   // If the constants line up, perform the transform!
3099   if (CTVal->isNullValue() && CFVal->isAllOnesValue()) {
3100     SDValue CCVal;
3101     SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
3102 
3103     FVal = Other;
3104     TVal = DAG.getNode(ISD::XOR, dl, Other.getValueType(), Other,
3105                        DAG.getConstant(-1ULL, dl, Other.getValueType()));
3106 
3107     return DAG.getNode(AArch64ISD::CSEL, dl, Sel.getValueType(), FVal, TVal,
3108                        CCVal, Cmp);
3109   }
3110 
3111   return Op;
3112 }
3113 
3114 static SDValue LowerADDC_ADDE_SUBC_SUBE(SDValue Op, SelectionDAG &DAG) {
3115   EVT VT = Op.getValueType();
3116 
3117   // Let legalize expand this if it isn't a legal type yet.
3118   if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
3119     return SDValue();
3120 
3121   SDVTList VTs = DAG.getVTList(VT, MVT::i32);
3122 
3123   unsigned Opc;
3124   bool ExtraOp = false;
3125   switch (Op.getOpcode()) {
3126   default:
3127     llvm_unreachable("Invalid code");
3128   case ISD::ADDC:
3129     Opc = AArch64ISD::ADDS;
3130     break;
3131   case ISD::SUBC:
3132     Opc = AArch64ISD::SUBS;
3133     break;
3134   case ISD::ADDE:
3135     Opc = AArch64ISD::ADCS;
3136     ExtraOp = true;
3137     break;
3138   case ISD::SUBE:
3139     Opc = AArch64ISD::SBCS;
3140     ExtraOp = true;
3141     break;
3142   }
3143 
3144   if (!ExtraOp)
3145     return DAG.getNode(Opc, SDLoc(Op), VTs, Op.getOperand(0), Op.getOperand(1));
3146   return DAG.getNode(Opc, SDLoc(Op), VTs, Op.getOperand(0), Op.getOperand(1),
3147                      Op.getOperand(2));
3148 }
3149 
3150 static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
3151   // Let legalize expand this if it isn't a legal type yet.
3152   if (!DAG.getTargetLoweringInfo().isTypeLegal(Op.getValueType()))
3153     return SDValue();
3154 
3155   SDLoc dl(Op);
3156   AArch64CC::CondCode CC;
3157   // The actual operation that sets the overflow or carry flag.
3158   SDValue Value, Overflow;
3159   std::tie(Value, Overflow) = getAArch64XALUOOp(CC, Op, DAG);
3160 
3161   // We use 0 and 1 as false and true values.
3162   SDValue TVal = DAG.getConstant(1, dl, MVT::i32);
3163   SDValue FVal = DAG.getConstant(0, dl, MVT::i32);
3164 
3165   // We use an inverted condition, because the conditional select is inverted
3166   // too. This will allow it to be selected to a single instruction:
3167   // CSINC Wd, WZR, WZR, invert(cond).
3168   SDValue CCVal = DAG.getConstant(getInvertedCondCode(CC), dl, MVT::i32);
3169   Overflow = DAG.getNode(AArch64ISD::CSEL, dl, MVT::i32, FVal, TVal,
3170                          CCVal, Overflow);
3171 
3172   SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
3173   return DAG.getNode(ISD::MERGE_VALUES, dl, VTs, Value, Overflow);
3174 }
3175 
3176 // Prefetch operands are:
3177 // 1: Address to prefetch
3178 // 2: bool isWrite
3179 // 3: int locality (0 = no locality ... 3 = extreme locality)
3180 // 4: bool isDataCache
3181 static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) {
3182   SDLoc DL(Op);
3183   unsigned IsWrite = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue();
3184   unsigned Locality = cast<ConstantSDNode>(Op.getOperand(3))->getZExtValue();
3185   unsigned IsData = cast<ConstantSDNode>(Op.getOperand(4))->getZExtValue();
3186 
3187   bool IsStream = !Locality;
3188   // When the locality number is set
3189   if (Locality) {
3190     // The front-end should have filtered out the out-of-range values
3191     assert(Locality <= 3 && "Prefetch locality out-of-range");
3192     // The locality argument is highest (3) for the fastest cache level,
3193     // while the prfop encoding is lowest (0) for L1, so invert the value
3194     // before encoding it.
3195     Locality = 3 - Locality;
3196   }
3197 
3198   // Build the mask value encoding the expected behavior.
3199   unsigned PrfOp = (IsWrite << 4) |     // Load/Store bit
3200                    (!IsData << 3) |     // IsDataCache bit
3201                    (Locality << 1) |    // Cache level bits
3202                    (unsigned)IsStream;  // Stream bit
3203   return DAG.getNode(AArch64ISD::PREFETCH, DL, MVT::Other, Op.getOperand(0),
3204                      DAG.getConstant(PrfOp, DL, MVT::i32), Op.getOperand(1));
3205 }
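// Illustrative encodings (editorial sketch) of the PrfOp value built above:
//   IsWrite=0, Locality=3, IsData=1  ->  0b00000  (PLDL1KEEP)
//   IsWrite=0, Locality=0, IsData=1  ->  0b00001  (PLDL1STRM)
//   IsWrite=1, Locality=3, IsData=1  ->  0b10000  (PSTL1KEEP)
//   IsWrite=0, Locality=3, IsData=0  ->  0b01000  (PLIL1KEEP)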
3206 
3207 SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
3208                                               SelectionDAG &DAG) const {
3209   if (Op.getValueType().isScalableVector())
3210     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU);
3211 
3212   assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
3213   return SDValue();
3214 }
3215 
3216 SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
3217                                              SelectionDAG &DAG) const {
3218   if (Op.getValueType().isScalableVector())
3219     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
3220 
3221   bool IsStrict = Op->isStrictFPOpcode();
3222   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
3223   EVT SrcVT = SrcVal.getValueType();
3224 
3225   if (SrcVT != MVT::f128) {
3226     // Expand cases where the input is a vector bigger than NEON.
3227     if (useSVEForFixedLengthVectorVT(SrcVT))
3228       return SDValue();
3229 
3230     // It's legal except when f128 is involved
3231     return Op;
3232   }
3233 
3234   return SDValue();
3235 }
3236 
3237 SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
3238                                                     SelectionDAG &DAG) const {
3239   // Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
3240   // Any additional optimization in this function should be recorded
3241   // in the cost tables.
3242   EVT InVT = Op.getOperand(0).getValueType();
3243   EVT VT = Op.getValueType();
3244 
3245   if (VT.isScalableVector()) {
3246     unsigned Opcode = Op.getOpcode() == ISD::FP_TO_UINT
3247                           ? AArch64ISD::FCVTZU_MERGE_PASSTHRU
3248                           : AArch64ISD::FCVTZS_MERGE_PASSTHRU;
3249     return LowerToPredicatedOp(Op, DAG, Opcode);
3250   }
3251 
3252   unsigned NumElts = InVT.getVectorNumElements();
3253 
3254   // f16 conversions are promoted to f32 when full fp16 is not supported.
3255   if (InVT.getVectorElementType() == MVT::f16 &&
3256       !Subtarget->hasFullFP16()) {
3257     MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts);
3258     SDLoc dl(Op);
3259     return DAG.getNode(
3260         Op.getOpcode(), dl, Op.getValueType(),
3261         DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
3262   }
3263 
3264   uint64_t VTSize = VT.getFixedSizeInBits();
3265   uint64_t InVTSize = InVT.getFixedSizeInBits();
3266   if (VTSize < InVTSize) {
3267     SDLoc dl(Op);
3268     SDValue Cv =
3269         DAG.getNode(Op.getOpcode(), dl, InVT.changeVectorElementTypeToInteger(),
3270                     Op.getOperand(0));
3271     return DAG.getNode(ISD::TRUNCATE, dl, VT, Cv);
3272   }
3273 
3274   if (VTSize > InVTSize) {
3275     SDLoc dl(Op);
3276     MVT ExtVT =
3277         MVT::getVectorVT(MVT::getFloatingPointVT(VT.getScalarSizeInBits()),
3278                          VT.getVectorNumElements());
3279     SDValue Ext = DAG.getNode(ISD::FP_EXTEND, dl, ExtVT, Op.getOperand(0));
3280     return DAG.getNode(Op.getOpcode(), dl, VT, Ext);
3281   }
3282 
3283   // Type changing conversions are illegal.
3284   return Op;
3285 }
3286 
3287 SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
3288                                               SelectionDAG &DAG) const {
3289   bool IsStrict = Op->isStrictFPOpcode();
3290   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
3291 
3292   if (SrcVal.getValueType().isVector())
3293     return LowerVectorFP_TO_INT(Op, DAG);
3294 
3295   // f16 conversions are promoted to f32 when full fp16 is not supported.
3296   if (SrcVal.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
3297     assert(!IsStrict && "Lowering of strict fp16 not yet implemented");
3298     SDLoc dl(Op);
3299     return DAG.getNode(
3300         Op.getOpcode(), dl, Op.getValueType(),
3301         DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, SrcVal));
3302   }
3303 
3304   if (SrcVal.getValueType() != MVT::f128) {
3305     // It's legal except when f128 is involved
3306     return Op;
3307   }
3308 
3309   return SDValue();
3310 }
3311 
3312 SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
3313                                                   SelectionDAG &DAG) const {
3314   // AArch64 FP-to-int conversions saturate to the destination register size, so
3315   // we can lower common saturating conversions to simple instructions.
3316   SDValue SrcVal = Op.getOperand(0);
3317 
3318   EVT SrcVT = SrcVal.getValueType();
3319   EVT DstVT = Op.getValueType();
3320 
3321   EVT SatVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
3322   uint64_t SatWidth = SatVT.getScalarSizeInBits();
3323   uint64_t DstWidth = DstVT.getScalarSizeInBits();
3324   assert(SatWidth <= DstWidth && "Saturation width cannot exceed result width");
3325 
3326   // TODO: Support lowering of NEON and SVE conversions.
3327   if (SrcVT.isVector())
3328     return SDValue();
3329 
3330   // TODO: Saturate to SatWidth explicitly.
3331   if (SatWidth != DstWidth)
3332     return SDValue();
3333 
3334   // In the absence of FP16 support, promote f16 to f32, like LowerFP_TO_INT().
3335   if (SrcVT == MVT::f16 && !Subtarget->hasFullFP16())
3336     return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
3337                        DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, SrcVal),
3338                        Op.getOperand(1));
3339 
3340   // Cases that we can emit directly.
3341   if ((SrcVT == MVT::f64 || SrcVT == MVT::f32 ||
3342        (SrcVT == MVT::f16 && Subtarget->hasFullFP16())) &&
3343       (DstVT == MVT::i64 || DstVT == MVT::i32))
3344     return Op;
3345 
3346   // For all other cases, fall back on the expanded form.
3347   return SDValue();
3348 }
3349 
3350 SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
3351                                                     SelectionDAG &DAG) const {
3352   // Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
3353   // Any additional optimization in this function should be recorded
3354   // in the cost tables.
3355   EVT VT = Op.getValueType();
3356   SDLoc dl(Op);
3357   SDValue In = Op.getOperand(0);
3358   EVT InVT = In.getValueType();
3359   unsigned Opc = Op.getOpcode();
3360   bool IsSigned = Opc == ISD::SINT_TO_FP || Opc == ISD::STRICT_SINT_TO_FP;
3361 
3362   if (VT.isScalableVector()) {
3363     if (InVT.getVectorElementType() == MVT::i1) {
3364       // We can't convert an SVE predicate directly; extend it first.
3365       unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
3366       EVT CastVT = getPromotedVTForPredicate(InVT);
3367       In = DAG.getNode(CastOpc, dl, CastVT, In);
3368       return DAG.getNode(Opc, dl, VT, In);
3369     }
3370 
3371     unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
3372                                : AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
3373     return LowerToPredicatedOp(Op, DAG, Opcode);
3374   }
3375 
3376   uint64_t VTSize = VT.getFixedSizeInBits();
3377   uint64_t InVTSize = InVT.getFixedSizeInBits();
3378   if (VTSize < InVTSize) {
3379     MVT CastVT =
3380         MVT::getVectorVT(MVT::getFloatingPointVT(InVT.getScalarSizeInBits()),
3381                          InVT.getVectorNumElements());
3382     In = DAG.getNode(Opc, dl, CastVT, In);
3383     return DAG.getNode(ISD::FP_ROUND, dl, VT, In, DAG.getIntPtrConstant(0, dl));
3384   }
3385 
3386   if (VTSize > InVTSize) {
3387     unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
3388     EVT CastVT = VT.changeVectorElementTypeToInteger();
3389     In = DAG.getNode(CastOpc, dl, CastVT, In);
3390     return DAG.getNode(Opc, dl, VT, In);
3391   }
3392 
3393   return Op;
3394 }
3395 
3396 SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
3397                                             SelectionDAG &DAG) const {
3398   if (Op.getValueType().isVector())
3399     return LowerVectorINT_TO_FP(Op, DAG);
3400 
3401   bool IsStrict = Op->isStrictFPOpcode();
3402   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
3403 
3404   // f16 conversions are promoted to f32 when full fp16 is not supported.
3405   if (Op.getValueType() == MVT::f16 &&
3406       !Subtarget->hasFullFP16()) {
3407     assert(!IsStrict && "Lowering of strict fp16 not yet implemented");
3408     SDLoc dl(Op);
3409     return DAG.getNode(
3410         ISD::FP_ROUND, dl, MVT::f16,
3411         DAG.getNode(Op.getOpcode(), dl, MVT::f32, SrcVal),
3412         DAG.getIntPtrConstant(0, dl));
3413   }
3414 
3415   // i128 conversions are libcalls.
3416   if (SrcVal.getValueType() == MVT::i128)
3417     return SDValue();
3418 
3419   // Other conversions are legal, unless it's to the completely software-based
3420   // fp128.
3421   if (Op.getValueType() != MVT::f128)
3422     return Op;
3423   return SDValue();
3424 }
3425 
3426 SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op,
3427                                             SelectionDAG &DAG) const {
3428   // For iOS, we want to call an alternative entry point: __sincos_stret,
3429   // which returns the values in two S / D registers.
3430   SDLoc dl(Op);
3431   SDValue Arg = Op.getOperand(0);
3432   EVT ArgVT = Arg.getValueType();
3433   Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext());
3434 
3435   ArgListTy Args;
3436   ArgListEntry Entry;
3437 
3438   Entry.Node = Arg;
3439   Entry.Ty = ArgTy;
3440   Entry.IsSExt = false;
3441   Entry.IsZExt = false;
3442   Args.push_back(Entry);
3443 
3444   RTLIB::Libcall LC = ArgVT == MVT::f64 ? RTLIB::SINCOS_STRET_F64
3445                                         : RTLIB::SINCOS_STRET_F32;
3446   const char *LibcallName = getLibcallName(LC);
3447   SDValue Callee =
3448       DAG.getExternalSymbol(LibcallName, getPointerTy(DAG.getDataLayout()));
3449 
3450   StructType *RetTy = StructType::get(ArgTy, ArgTy);
3451   TargetLowering::CallLoweringInfo CLI(DAG);
3452   CLI.setDebugLoc(dl)
3453       .setChain(DAG.getEntryNode())
3454       .setLibCallee(CallingConv::Fast, RetTy, Callee, std::move(Args));
3455 
3456   std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
3457   return CallResult.first;
3458 }
3459 
3460 SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
3461                                             SelectionDAG &DAG) const {
3462   EVT OpVT = Op.getValueType();
3463 
3464   if (useSVEForFixedLengthVectorVT(OpVT))
3465     return LowerFixedLengthBitcastToSVE(Op, DAG);
3466 
3467   if (OpVT != MVT::f16 && OpVT != MVT::bf16)
3468     return SDValue();
3469 
3470   assert(Op.getOperand(0).getValueType() == MVT::i16);
3471   SDLoc DL(Op);
3472 
3473   Op = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op.getOperand(0));
3474   Op = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Op);
3475   return SDValue(
3476       DAG.getMachineNode(TargetOpcode::EXTRACT_SUBREG, DL, OpVT, Op,
3477                          DAG.getTargetConstant(AArch64::hsub, DL, MVT::i32)),
3478       0);
3479 }
3480 
3481 static EVT getExtensionTo64Bits(const EVT &OrigVT) {
3482   if (OrigVT.getSizeInBits() >= 64)
3483     return OrigVT;
3484 
3485   assert(OrigVT.isSimple() && "Expecting a simple value type");
3486 
3487   MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy;
3488   switch (OrigSimpleTy) {
3489   default: llvm_unreachable("Unexpected Vector Type");
3490   case MVT::v2i8:
3491   case MVT::v2i16:
3492      return MVT::v2i32;
3493   case MVT::v4i8:
3494     return  MVT::v4i16;
3495   }
3496 }
3497 
3498 static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG,
3499                                                  const EVT &OrigTy,
3500                                                  const EVT &ExtTy,
3501                                                  unsigned ExtOpcode) {
3502   // The vector originally had a size of OrigTy. It was then extended to ExtTy.
3503   // We expect the ExtTy to be 128-bits total. If the OrigTy is less than
3504   // 64-bits we need to insert a new extension so that it will be 64-bits.
3505   assert(ExtTy.is128BitVector() && "Unexpected extension size");
3506   if (OrigTy.getSizeInBits() >= 64)
3507     return N;
3508 
3509   // Must extend size to at least 64 bits to be used as an operand for VMULL.
3510   EVT NewVT = getExtensionTo64Bits(OrigTy);
3511 
3512   return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N);
3513 }
3514 
3515 static bool isExtendedBUILD_VECTOR(SDNode *N, SelectionDAG &DAG,
3516                                    bool isSigned) {
3517   EVT VT = N->getValueType(0);
3518 
3519   if (N->getOpcode() != ISD::BUILD_VECTOR)
3520     return false;
3521 
3522   for (const SDValue &Elt : N->op_values()) {
3523     if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Elt)) {
3524       unsigned EltSize = VT.getScalarSizeInBits();
3525       unsigned HalfSize = EltSize / 2;
3526       if (isSigned) {
3527         if (!isIntN(HalfSize, C->getSExtValue()))
3528           return false;
3529       } else {
3530         if (!isUIntN(HalfSize, C->getZExtValue()))
3531           return false;
3532       }
3533       continue;
3534     }
3535     return false;
3536   }
3537 
3538   return true;
3539 }
3540 
3541 static SDValue skipExtensionForVectorMULL(SDNode *N, SelectionDAG &DAG) {
3542   if (N->getOpcode() == ISD::SIGN_EXTEND ||
3543       N->getOpcode() == ISD::ZERO_EXTEND || N->getOpcode() == ISD::ANY_EXTEND)
3544     return addRequiredExtensionForVectorMULL(N->getOperand(0), DAG,
3545                                              N->getOperand(0)->getValueType(0),
3546                                              N->getValueType(0),
3547                                              N->getOpcode());
3548 
3549   assert(N->getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
3550   EVT VT = N->getValueType(0);
3551   SDLoc dl(N);
3552   unsigned EltSize = VT.getScalarSizeInBits() / 2;
3553   unsigned NumElts = VT.getVectorNumElements();
3554   MVT TruncVT = MVT::getIntegerVT(EltSize);
3555   SmallVector<SDValue, 8> Ops;
3556   for (unsigned i = 0; i != NumElts; ++i) {
3557     ConstantSDNode *C = cast<ConstantSDNode>(N->getOperand(i));
3558     const APInt &CInt = C->getAPIntValue();
3559     // Element types smaller than 32 bits are not legal, so use i32 elements.
3560     // The values are implicitly truncated so sext vs. zext doesn't matter.
3561     Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
3562   }
3563   return DAG.getBuildVector(MVT::getVectorVT(TruncVT, NumElts), dl, Ops);
3564 }
3565 
3566 static bool isSignExtended(SDNode *N, SelectionDAG &DAG) {
3567   return N->getOpcode() == ISD::SIGN_EXTEND ||
3568          N->getOpcode() == ISD::ANY_EXTEND ||
3569          isExtendedBUILD_VECTOR(N, DAG, true);
3570 }
3571 
3572 static bool isZeroExtended(SDNode *N, SelectionDAG &DAG) {
3573   return N->getOpcode() == ISD::ZERO_EXTEND ||
3574          N->getOpcode() == ISD::ANY_EXTEND ||
3575          isExtendedBUILD_VECTOR(N, DAG, false);
3576 }
3577 
3578 static bool isAddSubSExt(SDNode *N, SelectionDAG &DAG) {
3579   unsigned Opcode = N->getOpcode();
3580   if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
3581     SDNode *N0 = N->getOperand(0).getNode();
3582     SDNode *N1 = N->getOperand(1).getNode();
3583     return N0->hasOneUse() && N1->hasOneUse() &&
3584       isSignExtended(N0, DAG) && isSignExtended(N1, DAG);
3585   }
3586   return false;
3587 }
3588 
3589 static bool isAddSubZExt(SDNode *N, SelectionDAG &DAG) {
3590   unsigned Opcode = N->getOpcode();
3591   if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
3592     SDNode *N0 = N->getOperand(0).getNode();
3593     SDNode *N1 = N->getOperand(1).getNode();
3594     return N0->hasOneUse() && N1->hasOneUse() &&
3595       isZeroExtended(N0, DAG) && isZeroExtended(N1, DAG);
3596   }
3597   return false;
3598 }
3599 
3600 SDValue AArch64TargetLowering::LowerFLT_ROUNDS_(SDValue Op,
3601                                                 SelectionDAG &DAG) const {
3602   // The rounding mode is in bits 23:22 of the FPCR.
3603   // The ARM rounding mode value to FLT_ROUNDS mapping is 0->1, 1->2, 2->3, 3->0.
3604   // The formula we use to implement this is ((FPCR + (1 << 22)) >> 22) & 3,
3605   // so that the shift + and get folded into a bitfield extract.
3606   SDLoc dl(Op);
3607 
3608   SDValue Chain = Op.getOperand(0);
3609   SDValue FPCR_64 = DAG.getNode(
3610       ISD::INTRINSIC_W_CHAIN, dl, {MVT::i64, MVT::Other},
3611       {Chain, DAG.getConstant(Intrinsic::aarch64_get_fpcr, dl, MVT::i64)});
3612   Chain = FPCR_64.getValue(1);
3613   SDValue FPCR_32 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, FPCR_64);
3614   SDValue FltRounds = DAG.getNode(ISD::ADD, dl, MVT::i32, FPCR_32,
3615                                   DAG.getConstant(1U << 22, dl, MVT::i32));
3616   SDValue RMODE = DAG.getNode(ISD::SRL, dl, MVT::i32, FltRounds,
3617                               DAG.getConstant(22, dl, MVT::i32));
3618   SDValue AND = DAG.getNode(ISD::AND, dl, MVT::i32, RMODE,
3619                             DAG.getConstant(3, dl, MVT::i32));
3620   return DAG.getMergeValues({AND, Chain}, dl);
3621 }
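// Illustrative example (editorial sketch) of the formula above: if FPCR has
// RMode = 0b11 (round toward zero) in bits 23:22, then
//   ((FPCR + (1 << 22)) >> 22) & 3 == (0b11 + 1) & 3 == 0,
// which is the FLT_ROUNDS value for "toward zero", matching the 3->0 entry of
// the mapping.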
3622 
3623 SDValue AArch64TargetLowering::LowerSET_ROUNDING(SDValue Op,
3624                                                  SelectionDAG &DAG) const {
3625   SDLoc DL(Op);
3626   SDValue Chain = Op->getOperand(0);
3627   SDValue RMValue = Op->getOperand(1);
3628 
3629   // The rounding mode is in bits 23:22 of the FPCR.
3630   // The llvm.set.rounding argument value to the rounding mode in FPCR mapping
3631   // is 0->3, 1->0, 2->1, 3->2. The formula we use to implement this is
3632   // (((arg - 1) & 3) << 22).
3633   //
3634   // The argument of llvm.set.rounding must be within the segment [0, 3], so
3635   // NearestTiesToAway (4) is not handled here. It is the responsibility of the
3636   // code that generated llvm.set.rounding to ensure this condition.
3637 
3638   // Calculate new value of FPCR[23:22].
3639   RMValue = DAG.getNode(ISD::SUB, DL, MVT::i32, RMValue,
3640                         DAG.getConstant(1, DL, MVT::i32));
3641   RMValue = DAG.getNode(ISD::AND, DL, MVT::i32, RMValue,
3642                         DAG.getConstant(0x3, DL, MVT::i32));
3643   RMValue =
3644       DAG.getNode(ISD::SHL, DL, MVT::i32, RMValue,
3645                   DAG.getConstant(AArch64::RoundingBitsPos, DL, MVT::i32));
3646   RMValue = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, RMValue);
3647 
3648   // Get current value of FPCR.
3649   SDValue Ops[] = {
3650       Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)};
3651   SDValue FPCR =
3652       DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, Ops);
3653   Chain = FPCR.getValue(1);
3654   FPCR = FPCR.getValue(0);
3655 
3656   // Put the new rounding mode into FPCR[23:22].
3657   const int RMMask = ~(AArch64::Rounding::rmMask << AArch64::RoundingBitsPos);
3658   FPCR = DAG.getNode(ISD::AND, DL, MVT::i64, FPCR,
3659                      DAG.getConstant(RMMask, DL, MVT::i64));
3660   FPCR = DAG.getNode(ISD::OR, DL, MVT::i64, FPCR, RMValue);
3661   SDValue Ops2[] = {
3662       Chain, DAG.getTargetConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64),
3663       FPCR};
3664   return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
3665 }
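// Illustrative example (editorial sketch) of the inverse mapping above: a call
// to llvm.set.rounding(1) (to-nearest) yields ((1 - 1) & 3) << 22 == 0, i.e.
// FPCR.RMode = 0b00 (RN), while llvm.set.rounding(0) (toward zero) yields
// ((0 - 1) & 3) == 3, i.e. FPCR.RMode = 0b11 (RZ).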
3666 
3667 SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
3668   EVT VT = Op.getValueType();
3669 
3670   // If SVE is available then i64 vector multiplications can also be made legal.
3671   bool OverrideNEON = VT == MVT::v2i64 || VT == MVT::v1i64;
3672 
3673   if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT, OverrideNEON))
3674     return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED, OverrideNEON);
3675 
3676   // Multiplications are only custom-lowered for 128-bit vectors so that
3677   // VMULL can be detected.  Otherwise v2i64 multiplications are not legal.
3678   assert(VT.is128BitVector() && VT.isInteger() &&
3679          "unexpected type for custom-lowering ISD::MUL");
3680   SDNode *N0 = Op.getOperand(0).getNode();
3681   SDNode *N1 = Op.getOperand(1).getNode();
3682   unsigned NewOpc = 0;
3683   bool isMLA = false;
3684   bool isN0SExt = isSignExtended(N0, DAG);
3685   bool isN1SExt = isSignExtended(N1, DAG);
3686   if (isN0SExt && isN1SExt)
3687     NewOpc = AArch64ISD::SMULL;
3688   else {
3689     bool isN0ZExt = isZeroExtended(N0, DAG);
3690     bool isN1ZExt = isZeroExtended(N1, DAG);
3691     if (isN0ZExt && isN1ZExt)
3692       NewOpc = AArch64ISD::UMULL;
3693     else if (isN1SExt || isN1ZExt) {
3694       // Look for (s/zext A + s/zext B) * (s/zext C). We want to turn these
3695       // into (s/zext A * s/zext C) + (s/zext B * s/zext C)
3696       if (isN1SExt && isAddSubSExt(N0, DAG)) {
3697         NewOpc = AArch64ISD::SMULL;
3698         isMLA = true;
3699       } else if (isN1ZExt && isAddSubZExt(N0, DAG)) {
3700         NewOpc =  AArch64ISD::UMULL;
3701         isMLA = true;
3702       } else if (isN0ZExt && isAddSubZExt(N1, DAG)) {
3703         std::swap(N0, N1);
3704         NewOpc =  AArch64ISD::UMULL;
3705         isMLA = true;
3706       }
3707     }
3708 
3709     if (!NewOpc) {
3710       if (VT == MVT::v2i64)
3711         // Fall through to expand this.  It is not legal.
3712         return SDValue();
3713       else
3714         // Other vector multiplications are legal.
3715         return Op;
3716     }
3717   }
3718 
3719   // Legalize to a S/UMULL instruction
3720   SDLoc DL(Op);
3721   SDValue Op0;
3722   SDValue Op1 = skipExtensionForVectorMULL(N1, DAG);
3723   if (!isMLA) {
3724     Op0 = skipExtensionForVectorMULL(N0, DAG);
3725     assert(Op0.getValueType().is64BitVector() &&
3726            Op1.getValueType().is64BitVector() &&
3727            "unexpected types for extended operands to VMULL");
3728     return DAG.getNode(NewOpc, DL, VT, Op0, Op1);
3729   }
3730   // Optimize (zext A + zext B) * C to (S/UMULL A, C) + (S/UMULL B, C) during
3731   // isel lowering to take advantage of no-stall back-to-back s/umul + s/umla.
3732   // This holds for CPUs with accumulate forwarding such as Cortex-A53/A57.
3733   SDValue N00 = skipExtensionForVectorMULL(N0->getOperand(0).getNode(), DAG);
3734   SDValue N01 = skipExtensionForVectorMULL(N0->getOperand(1).getNode(), DAG);
3735   EVT Op1VT = Op1.getValueType();
3736   return DAG.getNode(N0->getOpcode(), DL, VT,
3737                      DAG.getNode(NewOpc, DL, VT,
3738                                DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1),
3739                      DAG.getNode(NewOpc, DL, VT,
3740                                DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1));
3741 }
3742 
3743 static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
3744                                int Pattern) {
3745   return DAG.getNode(AArch64ISD::PTRUE, DL, VT,
3746                      DAG.getTargetConstant(Pattern, DL, MVT::i32));
3747 }
3748 
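// Lower a conversion to svbool (<n x 16 x i1>): reinterpret the predicate to
// the wider type and, unless the extra lanes are already known to be zero,
// mask them off using a ptrue of the original element width.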
3749 static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) {
3750   SDLoc DL(Op);
3751   EVT OutVT = Op.getValueType();
3752   SDValue InOp = Op.getOperand(1);
3753   EVT InVT = InOp.getValueType();
3754 
3755   // Return the operand if the cast isn't changing type,
3756   // i.e. <n x 16 x i1> -> <n x 16 x i1>
3757   if (InVT == OutVT)
3758     return InOp;
3759 
3760   SDValue Reinterpret =
3761       DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, InOp);
3762 
3763   // If the argument converted to an svbool is a ptrue or a comparison, the
3764   // lanes introduced by the widening are zero by construction.
3765   switch (InOp.getOpcode()) {
3766   case AArch64ISD::SETCC_MERGE_ZERO:
3767     return Reinterpret;
3768   case ISD::INTRINSIC_WO_CHAIN:
3769     if (InOp.getConstantOperandVal(0) == Intrinsic::aarch64_sve_ptrue)
3770       return Reinterpret;
3771   }
3772 
3773   // Otherwise, zero the newly introduced lanes.
3774   SDValue Mask = getPTrue(DAG, DL, InVT, AArch64SVEPredPattern::all);
3775   SDValue MaskReinterpret =
3776       DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, Mask);
3777   return DAG.getNode(ISD::AND, DL, OutVT, Reinterpret, MaskReinterpret);
3778 }
3779 
3780 SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
3781                                                      SelectionDAG &DAG) const {
3782   unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
3783   SDLoc dl(Op);
3784   switch (IntNo) {
3785   default: return SDValue();    // Don't custom lower most intrinsics.
3786   case Intrinsic::thread_pointer: {
3787     EVT PtrVT = getPointerTy(DAG.getDataLayout());
3788     return DAG.getNode(AArch64ISD::THREAD_POINTER, dl, PtrVT);
3789   }
3790   case Intrinsic::aarch64_neon_abs: {
3791     EVT Ty = Op.getValueType();
3792     if (Ty == MVT::i64) {
3793       SDValue Result = DAG.getNode(ISD::BITCAST, dl, MVT::v1i64,
3794                                    Op.getOperand(1));
3795       Result = DAG.getNode(ISD::ABS, dl, MVT::v1i64, Result);
3796       return DAG.getNode(ISD::BITCAST, dl, MVT::i64, Result);
3797     } else if (Ty.isVector() && Ty.isInteger() && isTypeLegal(Ty)) {
3798       return DAG.getNode(ISD::ABS, dl, Ty, Op.getOperand(1));
3799     } else {
3800       report_fatal_error("Unexpected type for AArch64 NEON intrinic");
3801     }
3802   }
3803   case Intrinsic::aarch64_neon_smax:
3804     return DAG.getNode(ISD::SMAX, dl, Op.getValueType(),
3805                        Op.getOperand(1), Op.getOperand(2));
3806   case Intrinsic::aarch64_neon_umax:
3807     return DAG.getNode(ISD::UMAX, dl, Op.getValueType(),
3808                        Op.getOperand(1), Op.getOperand(2));
3809   case Intrinsic::aarch64_neon_smin:
3810     return DAG.getNode(ISD::SMIN, dl, Op.getValueType(),
3811                        Op.getOperand(1), Op.getOperand(2));
3812   case Intrinsic::aarch64_neon_umin:
3813     return DAG.getNode(ISD::UMIN, dl, Op.getValueType(),
3814                        Op.getOperand(1), Op.getOperand(2));
3815 
3816   case Intrinsic::aarch64_sve_sunpkhi:
3817     return DAG.getNode(AArch64ISD::SUNPKHI, dl, Op.getValueType(),
3818                        Op.getOperand(1));
3819   case Intrinsic::aarch64_sve_sunpklo:
3820     return DAG.getNode(AArch64ISD::SUNPKLO, dl, Op.getValueType(),
3821                        Op.getOperand(1));
3822   case Intrinsic::aarch64_sve_uunpkhi:
3823     return DAG.getNode(AArch64ISD::UUNPKHI, dl, Op.getValueType(),
3824                        Op.getOperand(1));
3825   case Intrinsic::aarch64_sve_uunpklo:
3826     return DAG.getNode(AArch64ISD::UUNPKLO, dl, Op.getValueType(),
3827                        Op.getOperand(1));
3828   case Intrinsic::aarch64_sve_clasta_n:
3829     return DAG.getNode(AArch64ISD::CLASTA_N, dl, Op.getValueType(),
3830                        Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
3831   case Intrinsic::aarch64_sve_clastb_n:
3832     return DAG.getNode(AArch64ISD::CLASTB_N, dl, Op.getValueType(),
3833                        Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
3834   case Intrinsic::aarch64_sve_lasta:
3835     return DAG.getNode(AArch64ISD::LASTA, dl, Op.getValueType(),
3836                        Op.getOperand(1), Op.getOperand(2));
3837   case Intrinsic::aarch64_sve_lastb:
3838     return DAG.getNode(AArch64ISD::LASTB, dl, Op.getValueType(),
3839                        Op.getOperand(1), Op.getOperand(2));
3840   case Intrinsic::aarch64_sve_rev:
3841     return DAG.getNode(ISD::VECTOR_REVERSE, dl, Op.getValueType(),
3842                        Op.getOperand(1));
3843   case Intrinsic::aarch64_sve_tbl:
3844     return DAG.getNode(AArch64ISD::TBL, dl, Op.getValueType(),
3845                        Op.getOperand(1), Op.getOperand(2));
3846   case Intrinsic::aarch64_sve_trn1:
3847     return DAG.getNode(AArch64ISD::TRN1, dl, Op.getValueType(),
3848                        Op.getOperand(1), Op.getOperand(2));
3849   case Intrinsic::aarch64_sve_trn2:
3850     return DAG.getNode(AArch64ISD::TRN2, dl, Op.getValueType(),
3851                        Op.getOperand(1), Op.getOperand(2));
3852   case Intrinsic::aarch64_sve_uzp1:
3853     return DAG.getNode(AArch64ISD::UZP1, dl, Op.getValueType(),
3854                        Op.getOperand(1), Op.getOperand(2));
3855   case Intrinsic::aarch64_sve_uzp2:
3856     return DAG.getNode(AArch64ISD::UZP2, dl, Op.getValueType(),
3857                        Op.getOperand(1), Op.getOperand(2));
3858   case Intrinsic::aarch64_sve_zip1:
3859     return DAG.getNode(AArch64ISD::ZIP1, dl, Op.getValueType(),
3860                        Op.getOperand(1), Op.getOperand(2));
3861   case Intrinsic::aarch64_sve_zip2:
3862     return DAG.getNode(AArch64ISD::ZIP2, dl, Op.getValueType(),
3863                        Op.getOperand(1), Op.getOperand(2));
3864   case Intrinsic::aarch64_sve_ptrue:
3865     return DAG.getNode(AArch64ISD::PTRUE, dl, Op.getValueType(),
3866                        Op.getOperand(1));
3867   case Intrinsic::aarch64_sve_clz:
3868     return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, dl, Op.getValueType(),
3869                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3870   case Intrinsic::aarch64_sve_cnt: {
3871     SDValue Data = Op.getOperand(3);
3872     // CTPOP only supports integer operands.
3873     if (Data.getValueType().isFloatingPoint())
3874       Data = DAG.getNode(ISD::BITCAST, dl, Op.getValueType(), Data);
3875     return DAG.getNode(AArch64ISD::CTPOP_MERGE_PASSTHRU, dl, Op.getValueType(),
3876                        Op.getOperand(2), Data, Op.getOperand(1));
3877   }
3878   case Intrinsic::aarch64_sve_dupq_lane:
3879     return LowerDUPQLane(Op, DAG);
3880   case Intrinsic::aarch64_sve_convert_from_svbool:
3881     return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(),
3882                        Op.getOperand(1));
3883   case Intrinsic::aarch64_sve_convert_to_svbool:
3884     return lowerConvertToSVBool(Op, DAG);
3885   case Intrinsic::aarch64_sve_fneg:
3886     return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(),
3887                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3888   case Intrinsic::aarch64_sve_frintp:
3889     return DAG.getNode(AArch64ISD::FCEIL_MERGE_PASSTHRU, dl, Op.getValueType(),
3890                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3891   case Intrinsic::aarch64_sve_frintm:
3892     return DAG.getNode(AArch64ISD::FFLOOR_MERGE_PASSTHRU, dl, Op.getValueType(),
3893                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3894   case Intrinsic::aarch64_sve_frinti:
3895     return DAG.getNode(AArch64ISD::FNEARBYINT_MERGE_PASSTHRU, dl, Op.getValueType(),
3896                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3897   case Intrinsic::aarch64_sve_frintx:
3898     return DAG.getNode(AArch64ISD::FRINT_MERGE_PASSTHRU, dl, Op.getValueType(),
3899                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3900   case Intrinsic::aarch64_sve_frinta:
3901     return DAG.getNode(AArch64ISD::FROUND_MERGE_PASSTHRU, dl, Op.getValueType(),
3902                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3903   case Intrinsic::aarch64_sve_frintn:
3904     return DAG.getNode(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU, dl, Op.getValueType(),
3905                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3906   case Intrinsic::aarch64_sve_frintz:
3907     return DAG.getNode(AArch64ISD::FTRUNC_MERGE_PASSTHRU, dl, Op.getValueType(),
3908                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3909   case Intrinsic::aarch64_sve_ucvtf:
3910     return DAG.getNode(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU, dl,
3911                        Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
3912                        Op.getOperand(1));
3913   case Intrinsic::aarch64_sve_scvtf:
3914     return DAG.getNode(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU, dl,
3915                        Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
3916                        Op.getOperand(1));
3917   case Intrinsic::aarch64_sve_fcvtzu:
3918     return DAG.getNode(AArch64ISD::FCVTZU_MERGE_PASSTHRU, dl,
3919                        Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
3920                        Op.getOperand(1));
3921   case Intrinsic::aarch64_sve_fcvtzs:
3922     return DAG.getNode(AArch64ISD::FCVTZS_MERGE_PASSTHRU, dl,
3923                        Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
3924                        Op.getOperand(1));
3925   case Intrinsic::aarch64_sve_fsqrt:
3926     return DAG.getNode(AArch64ISD::FSQRT_MERGE_PASSTHRU, dl, Op.getValueType(),
3927                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3928   case Intrinsic::aarch64_sve_frecpx:
3929     return DAG.getNode(AArch64ISD::FRECPX_MERGE_PASSTHRU, dl, Op.getValueType(),
3930                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3931   case Intrinsic::aarch64_sve_fabs:
3932     return DAG.getNode(AArch64ISD::FABS_MERGE_PASSTHRU, dl, Op.getValueType(),
3933                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3934   case Intrinsic::aarch64_sve_abs:
3935     return DAG.getNode(AArch64ISD::ABS_MERGE_PASSTHRU, dl, Op.getValueType(),
3936                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3937   case Intrinsic::aarch64_sve_neg:
3938     return DAG.getNode(AArch64ISD::NEG_MERGE_PASSTHRU, dl, Op.getValueType(),
3939                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3940   case Intrinsic::aarch64_sve_insr: {
3941     SDValue Scalar = Op.getOperand(2);
3942     EVT ScalarTy = Scalar.getValueType();
3943     if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
3944       Scalar = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Scalar);
3945 
3946     return DAG.getNode(AArch64ISD::INSR, dl, Op.getValueType(),
3947                        Op.getOperand(1), Scalar);
3948   }
3949   case Intrinsic::aarch64_sve_rbit:
3950     return DAG.getNode(AArch64ISD::BITREVERSE_MERGE_PASSTHRU, dl,
3951                        Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
3952                        Op.getOperand(1));
3953   case Intrinsic::aarch64_sve_revb:
3954     return DAG.getNode(AArch64ISD::BSWAP_MERGE_PASSTHRU, dl, Op.getValueType(),
3955                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
3956   case Intrinsic::aarch64_sve_sxtb:
3957     return DAG.getNode(
3958         AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
3959         Op.getOperand(2), Op.getOperand(3),
3960         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i8)),
3961         Op.getOperand(1));
3962   case Intrinsic::aarch64_sve_sxth:
3963     return DAG.getNode(
3964         AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
3965         Op.getOperand(2), Op.getOperand(3),
3966         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i16)),
3967         Op.getOperand(1));
3968   case Intrinsic::aarch64_sve_sxtw:
3969     return DAG.getNode(
3970         AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
3971         Op.getOperand(2), Op.getOperand(3),
3972         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)),
3973         Op.getOperand(1));
3974   case Intrinsic::aarch64_sve_uxtb:
3975     return DAG.getNode(
3976         AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
3977         Op.getOperand(2), Op.getOperand(3),
3978         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i8)),
3979         Op.getOperand(1));
3980   case Intrinsic::aarch64_sve_uxth:
3981     return DAG.getNode(
3982         AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
3983         Op.getOperand(2), Op.getOperand(3),
3984         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i16)),
3985         Op.getOperand(1));
3986   case Intrinsic::aarch64_sve_uxtw:
3987     return DAG.getNode(
3988         AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
3989         Op.getOperand(2), Op.getOperand(3),
3990         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)),
3991         Op.getOperand(1));
3992 
3993   case Intrinsic::localaddress: {
3994     const auto &MF = DAG.getMachineFunction();
3995     const auto *RegInfo = Subtarget->getRegisterInfo();
3996     unsigned Reg = RegInfo->getLocalAddressRegister(MF);
3997     return DAG.getCopyFromReg(DAG.getEntryNode(), dl, Reg,
3998                               Op.getSimpleValueType());
3999   }
4000 
4001   case Intrinsic::eh_recoverfp: {
4002     // FIXME: This needs to be implemented to correctly handle highly aligned
4003     // stack objects. For now we simply return the incoming FP. Refer D53541
4004     // for more details.
4005     SDValue FnOp = Op.getOperand(1);
4006     SDValue IncomingFPOp = Op.getOperand(2);
4007     GlobalAddressSDNode *GSD = dyn_cast<GlobalAddressSDNode>(FnOp);
4008     auto *Fn = dyn_cast_or_null<Function>(GSD ? GSD->getGlobal() : nullptr);
4009     if (!Fn)
4010       report_fatal_error(
4011           "llvm.eh.recoverfp must take a function as the first argument");
4012     return IncomingFPOp;
4013   }
4014 
4015   case Intrinsic::aarch64_neon_vsri:
4016   case Intrinsic::aarch64_neon_vsli: {
4017     EVT Ty = Op.getValueType();
4018 
4019     if (!Ty.isVector())
4020       report_fatal_error("Unexpected type for aarch64_neon_vsli");
4021 
4022     assert(Op.getConstantOperandVal(3) <= Ty.getScalarSizeInBits());
4023 
4024     bool IsShiftRight = IntNo == Intrinsic::aarch64_neon_vsri;
4025     unsigned Opcode = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
4026     return DAG.getNode(Opcode, dl, Ty, Op.getOperand(1), Op.getOperand(2),
4027                        Op.getOperand(3));
4028   }
4029 
4030   case Intrinsic::aarch64_neon_srhadd:
4031   case Intrinsic::aarch64_neon_urhadd:
4032   case Intrinsic::aarch64_neon_shadd:
4033   case Intrinsic::aarch64_neon_uhadd: {
4034     bool IsSignedAdd = (IntNo == Intrinsic::aarch64_neon_srhadd ||
4035                         IntNo == Intrinsic::aarch64_neon_shadd);
4036     bool IsRoundingAdd = (IntNo == Intrinsic::aarch64_neon_srhadd ||
4037                           IntNo == Intrinsic::aarch64_neon_urhadd);
4038     unsigned Opcode =
4039         IsSignedAdd ? (IsRoundingAdd ? AArch64ISD::SRHADD : AArch64ISD::SHADD)
4040                     : (IsRoundingAdd ? AArch64ISD::URHADD : AArch64ISD::UHADD);
4041     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
4042                        Op.getOperand(2));
4043   }
4044   case Intrinsic::aarch64_neon_sabd:
4045   case Intrinsic::aarch64_neon_uabd: {
4046     unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? AArch64ISD::UABD
4047                                                             : AArch64ISD::SABD;
4048     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
4049                        Op.getOperand(2));
4050   }
4051   case Intrinsic::aarch64_neon_sdot:
4052   case Intrinsic::aarch64_neon_udot:
4053   case Intrinsic::aarch64_sve_sdot:
4054   case Intrinsic::aarch64_sve_udot: {
4055     unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot ||
4056                        IntNo == Intrinsic::aarch64_sve_udot)
4057                           ? AArch64ISD::UDOT
4058                           : AArch64ISD::SDOT;
4059     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
4060                        Op.getOperand(2), Op.getOperand(3));
4061   }
4062   }
4063 }
4064 
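// Gather/scatter indices with i8 or i16 elements have no direct addressing
// mode, so request that the generic code widen them to i32 first.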
4065 bool AArch64TargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
4066   if (VT.getVectorElementType() == MVT::i8 ||
4067       VT.getVectorElementType() == MVT::i16) {
4068     EltTy = MVT::i32;
4069     return true;
4070   }
4071   return false;
4072 }
4073 
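// 32-bit vector indices can be used directly by the SVE gather/scatter
// addressing modes, so an explicit extend of such an index is unnecessary
// when it has at least four elements.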
4074 bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
4075   if (VT.getVectorElementType() == MVT::i32 &&
4076       VT.getVectorElementCount().getKnownMinValue() >= 4)
4077     return true;
4078 
4079   return false;
4080 }
4081 
4082 bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
4083   return ExtVal.getValueType().isScalableVector();
4084 }
4085 
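// Select the SVE gather load opcode for the given addressing mode: scaled or
// unscaled offsets, with the index optionally sign- or zero-extended from i32.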
4086 unsigned getGatherVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
4087   std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = {
4088       {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false),
4089        AArch64ISD::GLD1_MERGE_ZERO},
4090       {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true),
4091        AArch64ISD::GLD1_UXTW_MERGE_ZERO},
4092       {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false),
4093        AArch64ISD::GLD1_MERGE_ZERO},
4094       {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true),
4095        AArch64ISD::GLD1_SXTW_MERGE_ZERO},
4096       {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false),
4097        AArch64ISD::GLD1_SCALED_MERGE_ZERO},
4098       {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true),
4099        AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO},
4100       {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false),
4101        AArch64ISD::GLD1_SCALED_MERGE_ZERO},
4102       {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true),
4103        AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO},
4104   };
4105   auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend);
4106   return AddrModes.find(Key)->second;
4107 }
4108 
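// As above, but selecting the SVE scatter store (SST1) opcode instead.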
4109 unsigned getScatterVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
4110   std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = {
4111       {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false),
4112        AArch64ISD::SST1_PRED},
4113       {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true),
4114        AArch64ISD::SST1_UXTW_PRED},
4115       {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false),
4116        AArch64ISD::SST1_PRED},
4117       {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true),
4118        AArch64ISD::SST1_SXTW_PRED},
4119       {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false),
4120        AArch64ISD::SST1_SCALED_PRED},
4121       {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true),
4122        AArch64ISD::SST1_UXTW_SCALED_PRED},
4123       {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false),
4124        AArch64ISD::SST1_SCALED_PRED},
4125       {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true),
4126        AArch64ISD::SST1_SXTW_SCALED_PRED},
4127   };
4128   auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend);
4129   return AddrModes.find(Key)->second;
4130 }
4131 
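// Map a gather opcode onto its sign-extending (GLD1S) counterpart.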
4132 unsigned getSignExtendedGatherOpcode(unsigned Opcode) {
4133   switch (Opcode) {
4134   default:
4135     llvm_unreachable("unimplemented opcode");
4136     return Opcode;
4137   case AArch64ISD::GLD1_MERGE_ZERO:
4138     return AArch64ISD::GLD1S_MERGE_ZERO;
4139   case AArch64ISD::GLD1_IMM_MERGE_ZERO:
4140     return AArch64ISD::GLD1S_IMM_MERGE_ZERO;
4141   case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
4142     return AArch64ISD::GLD1S_UXTW_MERGE_ZERO;
4143   case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
4144     return AArch64ISD::GLD1S_SXTW_MERGE_ZERO;
4145   case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
4146     return AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
4147   case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
4148     return AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO;
4149   case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
4150     return AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO;
4151   }
4152 }
4153 
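// Returns true if Index represents a 32-bit value that has been extended to
// 64 bits, either via SIGN_EXTEND_INREG or via an AND with a splatted
// 0xFFFFFFFF mask (which acts as a zero-extend from i32).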
4154 bool getGatherScatterIndexIsExtended(SDValue Index) {
4155   unsigned Opcode = Index.getOpcode();
4156   if (Opcode == ISD::SIGN_EXTEND_INREG)
4157     return true;
4158 
4159   if (Opcode == ISD::AND) {
4160     SDValue Splat = Index.getOperand(1);
4161     if (Splat.getOpcode() != ISD::SPLAT_VECTOR)
4162       return false;
4163     ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(Splat.getOperand(0));
4164     if (!Mask || Mask->getZExtValue() != 0xFFFFFFFF)
4165       return false;
4166     return true;
4167   }
4168 
4169   return false;
4170 }
4171 
4172 // If the base pointer of a masked gather or scatter is null, we
4173 // may be able to swap BasePtr & Index and use the vector + register
4174 // or vector + immediate addressing mode, e.g.
4175 // VECTOR + REGISTER:
4176 //    getelementptr nullptr, <vscale x N x T> (splat(%offset) + %indices)
4177 // -> getelementptr %offset, <vscale x N x T> %indices
4178 // VECTOR + IMMEDIATE:
4179 //    getelementptr nullptr, <vscale x N x T> (splat(#x) + %indices)
4180 // -> getelementptr #x, <vscale x N x T> %indices
4181 void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, EVT MemVT,
4182                                  unsigned &Opcode, bool IsGather,
4183                                  SelectionDAG &DAG) {
4184   if (!isNullConstant(BasePtr))
4185     return;
4186 
4187   ConstantSDNode *Offset = nullptr;
4188   if (Index.getOpcode() == ISD::ADD)
4189     if (auto SplatVal = DAG.getSplatValue(Index.getOperand(1))) {
4190       if (isa<ConstantSDNode>(SplatVal))
4191         Offset = cast<ConstantSDNode>(SplatVal);
4192       else {
4193         BasePtr = SplatVal;
4194         Index = Index->getOperand(0);
4195         return;
4196       }
4197     }
4198 
4199   unsigned NewOp =
4200       IsGather ? AArch64ISD::GLD1_IMM_MERGE_ZERO : AArch64ISD::SST1_IMM_PRED;
4201 
4202   if (!Offset) {
4203     std::swap(BasePtr, Index);
4204     Opcode = NewOp;
4205     return;
4206   }
4207 
4208   uint64_t OffsetVal = Offset->getZExtValue();
4209   unsigned ScalarSizeInBytes = MemVT.getScalarSizeInBits() / 8;
4210   auto ConstOffset = DAG.getConstant(OffsetVal, SDLoc(Index), MVT::i64);
4211 
4212   if (OffsetVal % ScalarSizeInBytes || OffsetVal / ScalarSizeInBytes > 31) {
4213     // Index is out of range for the immediate addressing mode
4214     BasePtr = ConstOffset;
4215     Index = Index->getOperand(0);
4216     return;
4217   }
4218 
4219   // Immediate is in range
4220   Opcode = NewOp;
4221   BasePtr = Index->getOperand(0);
4222   Index = ConstOffset;
4223 }
4224 
4225 SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
4226                                             SelectionDAG &DAG) const {
4227   SDLoc DL(Op);
4228   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op);
4229   assert(MGT && "Can only custom lower gather load nodes");
4230 
4231   SDValue Index = MGT->getIndex();
4232   SDValue Chain = MGT->getChain();
4233   SDValue PassThru = MGT->getPassThru();
4234   SDValue Mask = MGT->getMask();
4235   SDValue BasePtr = MGT->getBasePtr();
4236   ISD::LoadExtType ExtTy = MGT->getExtensionType();
4237 
4238   ISD::MemIndexType IndexType = MGT->getIndexType();
4239   bool IsScaled =
4240       IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
4241   bool IsSigned =
4242       IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
4243   bool IdxNeedsExtend =
4244       getGatherScatterIndexIsExtended(Index) ||
4245       Index.getSimpleValueType().getVectorElementType() == MVT::i32;
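  // An any-extending load may safely use the sign-extending gather form, since
  // its high bits are unspecified.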
4246   bool ResNeedsSignExtend = ExtTy == ISD::EXTLOAD || ExtTy == ISD::SEXTLOAD;
4247 
4248   EVT VT = PassThru.getSimpleValueType();
4249   EVT MemVT = MGT->getMemoryVT();
4250   SDValue InputVT = DAG.getValueType(MemVT);
4251 
4252   if (VT.getVectorElementType() == MVT::bf16 &&
4253       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
4254     return SDValue();
4255 
4256   // Handle FP data by using an integer gather and casting the result.
4257   if (VT.isFloatingPoint()) {
4258     EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount());
4259     PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG);
4260     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
4261   }
4262 
4263   SDVTList VTs = DAG.getVTList(PassThru.getSimpleValueType(), MVT::Other);
4264 
4265   if (getGatherScatterIndexIsExtended(Index))
4266     Index = Index.getOperand(0);
4267 
4268   unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend);
4269   selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
4270                               /*isGather=*/true, DAG);
4271 
4272   if (ResNeedsSignExtend)
4273     Opcode = getSignExtendedGatherOpcode(Opcode);
4274 
4275   SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT, PassThru};
4276   SDValue Gather = DAG.getNode(Opcode, DL, VTs, Ops);
4277 
4278   if (VT.isFloatingPoint()) {
4279     SDValue Cast = getSVESafeBitCast(VT, Gather, DAG);
4280     return DAG.getMergeValues({Cast, Gather.getValue(1)}, DL);
4281   }
4282 
4283   return Gather;
4284 }
4285 
4286 SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
4287                                              SelectionDAG &DAG) const {
4288   SDLoc DL(Op);
4289   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Op);
4290   assert(MSC && "Can only custom lower scatter store nodes");
4291 
4292   SDValue Index = MSC->getIndex();
4293   SDValue Chain = MSC->getChain();
4294   SDValue StoreVal = MSC->getValue();
4295   SDValue Mask = MSC->getMask();
4296   SDValue BasePtr = MSC->getBasePtr();
4297 
4298   ISD::MemIndexType IndexType = MSC->getIndexType();
4299   bool IsScaled =
4300       IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
4301   bool IsSigned =
4302       IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
4303   bool NeedsExtend =
4304       getGatherScatterIndexIsExtended(Index) ||
4305       Index.getSimpleValueType().getVectorElementType() == MVT::i32;
4306 
4307   EVT VT = StoreVal.getSimpleValueType();
4308   SDVTList VTs = DAG.getVTList(MVT::Other);
4309   EVT MemVT = MSC->getMemoryVT();
4310   SDValue InputVT = DAG.getValueType(MemVT);
4311 
4312   if (VT.getVectorElementType() == MVT::bf16 &&
4313       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
4314     return SDValue();
4315 
4316   // Handle FP data by casting the data so an integer scatter can be used.
4317   if (VT.isFloatingPoint()) {
4318     EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
4319     StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
4320     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
4321   }
4322 
4323   if (getGatherScatterIndexIsExtended(Index))
4324     Index = Index.getOperand(0);
4325 
4326   unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend);
4327   selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
4328                               /*isGather=*/false, DAG);
4329 
4330   SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
4331   return DAG.getNode(Opcode, DL, VTs, Ops);
4332 }
4333 
4334 // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16.
4335 static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST,
4336                                         EVT VT, EVT MemVT,
4337                                         SelectionDAG &DAG) {
4338   assert(VT.isVector() && "VT should be a vector type");
4339   assert(MemVT == MVT::v4i8 && VT == MVT::v4i16);
4340 
4341   SDValue Value = ST->getValue();
4342 
4343   // It first extends the promoted v4i16 to v8i16, truncates to v8i8, and
4344   // extracts the word lane which represents the v4i8 subvector.  It optimizes
4345   // the store to:
4346   //
4347   //   xtn  v0.8b, v0.8h
4348   //   str  s0, [x0]
4349 
4350   SDValue Undef = DAG.getUNDEF(MVT::i16);
4351   SDValue UndefVec = DAG.getBuildVector(MVT::v4i16, DL,
4352                                         {Undef, Undef, Undef, Undef});
4353 
4354   SDValue TruncExt = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i16,
4355                                  Value, UndefVec);
4356   SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i8, TruncExt);
4357 
4358   Trunc = DAG.getNode(ISD::BITCAST, DL, MVT::v2i32, Trunc);
4359   SDValue ExtractTrunc = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32,
4360                                      Trunc, DAG.getConstant(0, DL, MVT::i64));
4361 
4362   return DAG.getStore(ST->getChain(), DL, ExtractTrunc,
4363                       ST->getBasePtr(), ST->getMemOperand());
4364 }
4365 
4366 // Custom lowering for any store, vector or scalar, with or without a
4367 // truncate operation.  Currently we only custom-lower truncating stores
4368 // from v4i16 to v4i8 and volatile stores of i128.
4369 SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
4370                                           SelectionDAG &DAG) const {
4371   SDLoc Dl(Op);
4372   StoreSDNode *StoreNode = cast<StoreSDNode>(Op);
4373   assert (StoreNode && "Can only custom lower store nodes");
4374 
4375   SDValue Value = StoreNode->getValue();
4376 
4377   EVT VT = Value.getValueType();
4378   EVT MemVT = StoreNode->getMemoryVT();
4379 
4380   if (VT.isVector()) {
4381     if (useSVEForFixedLengthVectorVT(VT))
4382       return LowerFixedLengthVectorStoreToSVE(Op, DAG);
4383 
4384     unsigned AS = StoreNode->getAddressSpace();
4385     Align Alignment = StoreNode->getAlign();
4386     if (Alignment < MemVT.getStoreSize() &&
4387         !allowsMisalignedMemoryAccesses(MemVT, AS, Alignment,
4388                                         StoreNode->getMemOperand()->getFlags(),
4389                                         nullptr)) {
4390       return scalarizeVectorStore(StoreNode, DAG);
4391     }
4392 
4393     if (StoreNode->isTruncatingStore()) {
4394       return LowerTruncateVectorStore(Dl, StoreNode, VT, MemVT, DAG);
4395     }
4396     // 256 bit non-temporal stores can be lowered to STNP. Do this as part of
4397     // the custom lowering, as there are no un-paired non-temporal stores and
4398     // legalization will break up 256 bit inputs.
4399     ElementCount EC = MemVT.getVectorElementCount();
4400     if (StoreNode->isNonTemporal() && MemVT.getSizeInBits() == 256u &&
4401         EC.isKnownEven() &&
4402         ((MemVT.getScalarSizeInBits() == 8u ||
4403           MemVT.getScalarSizeInBits() == 16u ||
4404           MemVT.getScalarSizeInBits() == 32u ||
4405           MemVT.getScalarSizeInBits() == 64u))) {
4406       SDValue Lo =
4407           DAG.getNode(ISD::EXTRACT_SUBVECTOR, Dl,
4408                       MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
4409                       StoreNode->getValue(), DAG.getConstant(0, Dl, MVT::i64));
4410       SDValue Hi =
4411           DAG.getNode(ISD::EXTRACT_SUBVECTOR, Dl,
4412                       MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
4413                       StoreNode->getValue(),
4414                       DAG.getConstant(EC.getKnownMinValue() / 2, Dl, MVT::i64));
4415       SDValue Result = DAG.getMemIntrinsicNode(
4416           AArch64ISD::STNP, Dl, DAG.getVTList(MVT::Other),
4417           {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()},
4418           StoreNode->getMemoryVT(), StoreNode->getMemOperand());
4419       return Result;
4420     }
4421   } else if (MemVT == MVT::i128 && StoreNode->isVolatile()) {
4422     assert(StoreNode->getValue()->getValueType(0) == MVT::i128);
4423     SDValue Lo =
4424         DAG.getNode(ISD::EXTRACT_ELEMENT, Dl, MVT::i64, StoreNode->getValue(),
4425                     DAG.getConstant(0, Dl, MVT::i64));
4426     SDValue Hi =
4427         DAG.getNode(ISD::EXTRACT_ELEMENT, Dl, MVT::i64, StoreNode->getValue(),
4428                     DAG.getConstant(1, Dl, MVT::i64));
4429     SDValue Result = DAG.getMemIntrinsicNode(
4430         AArch64ISD::STP, Dl, DAG.getVTList(MVT::Other),
4431         {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()},
4432         StoreNode->getMemoryVT(), StoreNode->getMemOperand());
4433     return Result;
4434   }
4435 
4436   return SDValue();
4437 }
4438 
4439 // Generate SUBS and CSEL for integer abs.
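// That is, abs(x) = (x >= 0) ? x : 0 - x: SUBS compares x against zero and
// CSEL selects x or its negation based on the PL (plus or zero) condition.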
4440 SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
4441   MVT VT = Op.getSimpleValueType();
4442 
4443   if (VT.isVector())
4444     return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABS_MERGE_PASSTHRU);
4445 
4446   SDLoc DL(Op);
4447   SDValue Neg = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
4448                             Op.getOperand(0));
4449   // Generate SUBS & CSEL.
4450   SDValue Cmp =
4451       DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::i32),
4452                   Op.getOperand(0), DAG.getConstant(0, DL, VT));
4453   return DAG.getNode(AArch64ISD::CSEL, DL, VT, Op.getOperand(0), Neg,
4454                      DAG.getConstant(AArch64CC::PL, DL, MVT::i32),
4455                      Cmp.getValue(1));
4456 }
4457 
4458 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
4459                                               SelectionDAG &DAG) const {
4460   LLVM_DEBUG(dbgs() << "Custom lowering: ");
4461   LLVM_DEBUG(Op.dump());
4462 
4463   switch (Op.getOpcode()) {
4464   default:
4465     llvm_unreachable("unimplemented operand");
4466     return SDValue();
4467   case ISD::BITCAST:
4468     return LowerBITCAST(Op, DAG);
4469   case ISD::GlobalAddress:
4470     return LowerGlobalAddress(Op, DAG);
4471   case ISD::GlobalTLSAddress:
4472     return LowerGlobalTLSAddress(Op, DAG);
4473   case ISD::SETCC:
4474   case ISD::STRICT_FSETCC:
4475   case ISD::STRICT_FSETCCS:
4476     return LowerSETCC(Op, DAG);
4477   case ISD::BR_CC:
4478     return LowerBR_CC(Op, DAG);
4479   case ISD::SELECT:
4480     return LowerSELECT(Op, DAG);
4481   case ISD::SELECT_CC:
4482     return LowerSELECT_CC(Op, DAG);
4483   case ISD::JumpTable:
4484     return LowerJumpTable(Op, DAG);
4485   case ISD::BR_JT:
4486     return LowerBR_JT(Op, DAG);
4487   case ISD::ConstantPool:
4488     return LowerConstantPool(Op, DAG);
4489   case ISD::BlockAddress:
4490     return LowerBlockAddress(Op, DAG);
4491   case ISD::VASTART:
4492     return LowerVASTART(Op, DAG);
4493   case ISD::VACOPY:
4494     return LowerVACOPY(Op, DAG);
4495   case ISD::VAARG:
4496     return LowerVAARG(Op, DAG);
4497   case ISD::ADDC:
4498   case ISD::ADDE:
4499   case ISD::SUBC:
4500   case ISD::SUBE:
4501     return LowerADDC_ADDE_SUBC_SUBE(Op, DAG);
4502   case ISD::SADDO:
4503   case ISD::UADDO:
4504   case ISD::SSUBO:
4505   case ISD::USUBO:
4506   case ISD::SMULO:
4507   case ISD::UMULO:
4508     return LowerXALUO(Op, DAG);
4509   case ISD::FADD:
4510     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED);
4511   case ISD::FSUB:
4512     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
4513   case ISD::FMUL:
4514     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
4515   case ISD::FMA:
4516     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
4517   case ISD::FDIV:
4518     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
4519   case ISD::FNEG:
4520     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
4521   case ISD::FCEIL:
4522     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FCEIL_MERGE_PASSTHRU);
4523   case ISD::FFLOOR:
4524     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FFLOOR_MERGE_PASSTHRU);
4525   case ISD::FNEARBYINT:
4526     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEARBYINT_MERGE_PASSTHRU);
4527   case ISD::FRINT:
4528     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FRINT_MERGE_PASSTHRU);
4529   case ISD::FROUND:
4530     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUND_MERGE_PASSTHRU);
4531   case ISD::FROUNDEVEN:
4532     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU);
4533   case ISD::FTRUNC:
4534     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU);
4535   case ISD::FSQRT:
4536     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU);
4537   case ISD::FABS:
4538     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FABS_MERGE_PASSTHRU);
4539   case ISD::FP_ROUND:
4540   case ISD::STRICT_FP_ROUND:
4541     return LowerFP_ROUND(Op, DAG);
4542   case ISD::FP_EXTEND:
4543     return LowerFP_EXTEND(Op, DAG);
4544   case ISD::FRAMEADDR:
4545     return LowerFRAMEADDR(Op, DAG);
4546   case ISD::SPONENTRY:
4547     return LowerSPONENTRY(Op, DAG);
4548   case ISD::RETURNADDR:
4549     return LowerRETURNADDR(Op, DAG);
4550   case ISD::ADDROFRETURNADDR:
4551     return LowerADDROFRETURNADDR(Op, DAG);
4552   case ISD::CONCAT_VECTORS:
4553     return LowerCONCAT_VECTORS(Op, DAG);
4554   case ISD::INSERT_VECTOR_ELT:
4555     return LowerINSERT_VECTOR_ELT(Op, DAG);
4556   case ISD::EXTRACT_VECTOR_ELT:
4557     return LowerEXTRACT_VECTOR_ELT(Op, DAG);
4558   case ISD::BUILD_VECTOR:
4559     return LowerBUILD_VECTOR(Op, DAG);
4560   case ISD::VECTOR_SHUFFLE:
4561     return LowerVECTOR_SHUFFLE(Op, DAG);
4562   case ISD::SPLAT_VECTOR:
4563     return LowerSPLAT_VECTOR(Op, DAG);
4564   case ISD::EXTRACT_SUBVECTOR:
4565     return LowerEXTRACT_SUBVECTOR(Op, DAG);
4566   case ISD::INSERT_SUBVECTOR:
4567     return LowerINSERT_SUBVECTOR(Op, DAG);
4568   case ISD::SDIV:
4569   case ISD::UDIV:
4570     return LowerDIV(Op, DAG);
4571   case ISD::SMIN:
4572     return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMIN_PRED,
4573                                /*OverrideNEON=*/true);
4574   case ISD::UMIN:
4575     return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMIN_PRED,
4576                                /*OverrideNEON=*/true);
4577   case ISD::SMAX:
4578     return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMAX_PRED,
4579                                /*OverrideNEON=*/true);
4580   case ISD::UMAX:
4581     return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMAX_PRED,
4582                                /*OverrideNEON=*/true);
4583   case ISD::SRA:
4584   case ISD::SRL:
4585   case ISD::SHL:
4586     return LowerVectorSRA_SRL_SHL(Op, DAG);
4587   case ISD::SHL_PARTS:
4588   case ISD::SRL_PARTS:
4589   case ISD::SRA_PARTS:
4590     return LowerShiftParts(Op, DAG);
4591   case ISD::CTPOP:
4592     return LowerCTPOP(Op, DAG);
4593   case ISD::FCOPYSIGN:
4594     return LowerFCOPYSIGN(Op, DAG);
4595   case ISD::OR:
4596     return LowerVectorOR(Op, DAG);
4597   case ISD::XOR:
4598     return LowerXOR(Op, DAG);
4599   case ISD::PREFETCH:
4600     return LowerPREFETCH(Op, DAG);
4601   case ISD::SINT_TO_FP:
4602   case ISD::UINT_TO_FP:
4603   case ISD::STRICT_SINT_TO_FP:
4604   case ISD::STRICT_UINT_TO_FP:
4605     return LowerINT_TO_FP(Op, DAG);
4606   case ISD::FP_TO_SINT:
4607   case ISD::FP_TO_UINT:
4608   case ISD::STRICT_FP_TO_SINT:
4609   case ISD::STRICT_FP_TO_UINT:
4610     return LowerFP_TO_INT(Op, DAG);
4611   case ISD::FP_TO_SINT_SAT:
4612   case ISD::FP_TO_UINT_SAT:
4613     return LowerFP_TO_INT_SAT(Op, DAG);
4614   case ISD::FSINCOS:
4615     return LowerFSINCOS(Op, DAG);
4616   case ISD::FLT_ROUNDS_:
4617     return LowerFLT_ROUNDS_(Op, DAG);
4618   case ISD::SET_ROUNDING:
4619     return LowerSET_ROUNDING(Op, DAG);
4620   case ISD::MUL:
4621     return LowerMUL(Op, DAG);
4622   case ISD::MULHS:
4623     return LowerToPredicatedOp(Op, DAG, AArch64ISD::MULHS_PRED,
4624                                /*OverrideNEON=*/true);
4625   case ISD::MULHU:
4626     return LowerToPredicatedOp(Op, DAG, AArch64ISD::MULHU_PRED,
4627                                /*OverrideNEON=*/true);
4628   case ISD::INTRINSIC_WO_CHAIN:
4629     return LowerINTRINSIC_WO_CHAIN(Op, DAG);
4630   case ISD::STORE:
4631     return LowerSTORE(Op, DAG);
4632   case ISD::MSTORE:
4633     return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
4634   case ISD::MGATHER:
4635     return LowerMGATHER(Op, DAG);
4636   case ISD::MSCATTER:
4637     return LowerMSCATTER(Op, DAG);
4638   case ISD::VECREDUCE_SEQ_FADD:
4639     return LowerVECREDUCE_SEQ_FADD(Op, DAG);
4640   case ISD::VECREDUCE_ADD:
4641   case ISD::VECREDUCE_AND:
4642   case ISD::VECREDUCE_OR:
4643   case ISD::VECREDUCE_XOR:
4644   case ISD::VECREDUCE_SMAX:
4645   case ISD::VECREDUCE_SMIN:
4646   case ISD::VECREDUCE_UMAX:
4647   case ISD::VECREDUCE_UMIN:
4648   case ISD::VECREDUCE_FADD:
4649   case ISD::VECREDUCE_FMAX:
4650   case ISD::VECREDUCE_FMIN:
4651     return LowerVECREDUCE(Op, DAG);
4652   case ISD::ATOMIC_LOAD_SUB:
4653     return LowerATOMIC_LOAD_SUB(Op, DAG);
4654   case ISD::ATOMIC_LOAD_AND:
4655     return LowerATOMIC_LOAD_AND(Op, DAG);
4656   case ISD::DYNAMIC_STACKALLOC:
4657     return LowerDYNAMIC_STACKALLOC(Op, DAG);
4658   case ISD::VSCALE:
4659     return LowerVSCALE(Op, DAG);
4660   case ISD::ANY_EXTEND:
4661   case ISD::SIGN_EXTEND:
4662   case ISD::ZERO_EXTEND:
4663     return LowerFixedLengthVectorIntExtendToSVE(Op, DAG);
4664   case ISD::SIGN_EXTEND_INREG: {
4665     // Only custom lower when ExtraVT has a legal byte based element type.
4666     EVT ExtraVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
4667     EVT ExtraEltVT = ExtraVT.getVectorElementType();
4668     if ((ExtraEltVT != MVT::i8) && (ExtraEltVT != MVT::i16) &&
4669         (ExtraEltVT != MVT::i32) && (ExtraEltVT != MVT::i64))
4670       return SDValue();
4671 
4672     return LowerToPredicatedOp(Op, DAG,
4673                                AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU);
4674   }
4675   case ISD::TRUNCATE:
4676     return LowerTRUNCATE(Op, DAG);
4677   case ISD::MLOAD:
4678     return LowerFixedLengthVectorMLoadToSVE(Op, DAG);
4679   case ISD::LOAD:
4680     if (useSVEForFixedLengthVectorVT(Op.getValueType()))
4681       return LowerFixedLengthVectorLoadToSVE(Op, DAG);
4682     llvm_unreachable("Unexpected request to lower ISD::LOAD");
4683   case ISD::ADD:
4684     return LowerToPredicatedOp(Op, DAG, AArch64ISD::ADD_PRED);
4685   case ISD::AND:
4686     return LowerToScalableOp(Op, DAG);
4687   case ISD::SUB:
4688     return LowerToPredicatedOp(Op, DAG, AArch64ISD::SUB_PRED);
4689   case ISD::FMAXIMUM:
4690     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAX_PRED);
4691   case ISD::FMAXNUM:
4692     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED);
4693   case ISD::FMINIMUM:
4694     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMIN_PRED);
4695   case ISD::FMINNUM:
4696     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED);
4697   case ISD::VSELECT:
4698     return LowerFixedLengthVectorSelectToSVE(Op, DAG);
4699   case ISD::ABS:
4700     return LowerABS(Op, DAG);
4701   case ISD::BITREVERSE:
4702     return LowerToPredicatedOp(Op, DAG, AArch64ISD::BITREVERSE_MERGE_PASSTHRU,
4703                                /*OverrideNEON=*/true);
4704   case ISD::BSWAP:
4705     return LowerToPredicatedOp(Op, DAG, AArch64ISD::BSWAP_MERGE_PASSTHRU);
4706   case ISD::CTLZ:
4707     return LowerToPredicatedOp(Op, DAG, AArch64ISD::CTLZ_MERGE_PASSTHRU,
4708                                /*OverrideNEON=*/true);
4709   case ISD::CTTZ:
4710     return LowerCTTZ(Op, DAG);
4711   }
4712 }
4713 
4714 bool AArch64TargetLowering::mergeStoresAfterLegalization(EVT VT) const {
4715   return !Subtarget->useSVEForFixedLengthVectors();
4716 }
4717 
4718 bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(
4719     EVT VT, bool OverrideNEON) const {
4720   if (!Subtarget->useSVEForFixedLengthVectors())
4721     return false;
4722 
4723   if (!VT.isFixedLengthVector())
4724     return false;
4725 
4726   // Don't use SVE for vectors we cannot scalarize if required.
4727   switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
4728   // Fixed length predicates should be promoted to i8.
4729   // NOTE: This is consistent with how NEON (and thus 64/128bit vectors) work.
4730   case MVT::i1:
4731   default:
4732     return false;
4733   case MVT::i8:
4734   case MVT::i16:
4735   case MVT::i32:
4736   case MVT::i64:
4737   case MVT::f16:
4738   case MVT::f32:
4739   case MVT::f64:
4740     break;
4741   }
4742 
4743   // All SVE implementations support NEON sized vectors.
4744   if (OverrideNEON && (VT.is128BitVector() || VT.is64BitVector()))
4745     return true;
4746 
4747   // Ensure NEON MVTs only belong to a single register class.
4748   if (VT.getFixedSizeInBits() <= 128)
4749     return false;
4750 
4751   // Don't use SVE for types that don't fit.
4752   if (VT.getFixedSizeInBits() > Subtarget->getMinSVEVectorSizeInBits())
4753     return false;
4754 
4755   // TODO: Perhaps an artificial restriction, but worth having whilst getting
4756   // the base fixed length SVE support in place.
4757   if (!VT.isPow2VectorType())
4758     return false;
4759 
4760   return true;
4761 }
4762 
4763 //===----------------------------------------------------------------------===//
4764 //                      Calling Convention Implementation
4765 //===----------------------------------------------------------------------===//
4766 
4767 /// Selects the correct CCAssignFn for a given CallingConvention value.
4768 CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
4769                                                      bool IsVarArg) const {
4770   switch (CC) {
4771   default:
4772     report_fatal_error("Unsupported calling convention.");
4773   case CallingConv::WebKit_JS:
4774     return CC_AArch64_WebKit_JS;
4775   case CallingConv::GHC:
4776     return CC_AArch64_GHC;
4777   case CallingConv::C:
4778   case CallingConv::Fast:
4779   case CallingConv::PreserveMost:
4780   case CallingConv::CXX_FAST_TLS:
4781   case CallingConv::Swift:
4782   case CallingConv::SwiftTail:
4783   case CallingConv::Tail:
4784     if (Subtarget->isTargetWindows() && IsVarArg)
4785       return CC_AArch64_Win64_VarArg;
4786     if (!Subtarget->isTargetDarwin())
4787       return CC_AArch64_AAPCS;
4788     if (!IsVarArg)
4789       return CC_AArch64_DarwinPCS;
4790     return Subtarget->isTargetILP32() ? CC_AArch64_DarwinPCS_ILP32_VarArg
4791                                       : CC_AArch64_DarwinPCS_VarArg;
4792    case CallingConv::Win64:
4793     return IsVarArg ? CC_AArch64_Win64_VarArg : CC_AArch64_AAPCS;
4794    case CallingConv::CFGuard_Check:
4795      return CC_AArch64_Win64_CFGuard_Check;
4796    case CallingConv::AArch64_VectorCall:
4797    case CallingConv::AArch64_SVE_VectorCall:
4798      return CC_AArch64_AAPCS;
4799   }
4800 }
4801 
4802 CCAssignFn *
4803 AArch64TargetLowering::CCAssignFnForReturn(CallingConv::ID CC) const {
4804   return CC == CallingConv::WebKit_JS ? RetCC_AArch64_WebKit_JS
4805                                       : RetCC_AArch64_AAPCS;
4806 }
4807 
4808 SDValue AArch64TargetLowering::LowerFormalArguments(
4809     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
4810     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
4811     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
4812   MachineFunction &MF = DAG.getMachineFunction();
4813   MachineFrameInfo &MFI = MF.getFrameInfo();
4814   bool IsWin64 = Subtarget->isCallingConvWin64(MF.getFunction().getCallingConv());
4815 
4816   // Assign locations to all of the incoming arguments.
4817   SmallVector<CCValAssign, 16> ArgLocs;
4818   DenseMap<unsigned, SDValue> CopiedRegs;
4819   CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), ArgLocs,
4820                  *DAG.getContext());
4821 
4822   // At this point, Ins[].VT may already be promoted to i32. To correctly
4823   // handle passing i8 as i8 instead of i32 on stack, we pass in both i32 and
4824   // i8 to CC_AArch64_AAPCS with i32 being ValVT and i8 being LocVT.
4825   // Since AnalyzeFormalArguments uses Ins[].VT for both ValVT and LocVT, here
4826   // we use a special version of AnalyzeFormalArguments to pass in ValVT and
4827   // LocVT.
4828   unsigned NumArgs = Ins.size();
4829   Function::const_arg_iterator CurOrigArg = MF.getFunction().arg_begin();
4830   unsigned CurArgIdx = 0;
4831   for (unsigned i = 0; i != NumArgs; ++i) {
4832     MVT ValVT = Ins[i].VT;
4833     if (Ins[i].isOrigArg()) {
4834       std::advance(CurOrigArg, Ins[i].getOrigArgIndex() - CurArgIdx);
4835       CurArgIdx = Ins[i].getOrigArgIndex();
4836 
4837       // Get type of the original argument.
4838       EVT ActualVT = getValueType(DAG.getDataLayout(), CurOrigArg->getType(),
4839                                   /*AllowUnknown*/ true);
4840       MVT ActualMVT = ActualVT.isSimple() ? ActualVT.getSimpleVT() : MVT::Other;
4841       // If ActualMVT is i1/i8/i16, we should set LocVT to i8/i8/i16.
4842       if (ActualMVT == MVT::i1 || ActualMVT == MVT::i8)
4843         ValVT = MVT::i8;
4844       else if (ActualMVT == MVT::i16)
4845         ValVT = MVT::i16;
4846     }
4847     bool UseVarArgCC = false;
4848     if (IsWin64)
4849       UseVarArgCC = isVarArg;
4850     CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, UseVarArgCC);
4851     bool Res =
4852         AssignFn(i, ValVT, ValVT, CCValAssign::Full, Ins[i].Flags, CCInfo);
4853     assert(!Res && "Call operand has unhandled type");
4854     (void)Res;
4855   }
4856   SmallVector<SDValue, 16> ArgValues;
4857   unsigned ExtraArgLocs = 0;
4858   for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
4859     CCValAssign &VA = ArgLocs[i - ExtraArgLocs];
4860 
4861     if (Ins[i].Flags.isByVal()) {
4862       // Byval is used for HFAs in the PCS, but the system should work in a
4863       // non-compliant manner for larger structs.
4864       EVT PtrVT = getPointerTy(DAG.getDataLayout());
4865       int Size = Ins[i].Flags.getByValSize();
4866       unsigned NumRegs = (Size + 7) / 8;
4867 
4868       // FIXME: This works on big-endian for composite byvals, which are the common
4869       // case. It should also work for fundamental types.
4870       unsigned FrameIdx =
4871         MFI.CreateFixedObject(8 * NumRegs, VA.getLocMemOffset(), false);
4872       SDValue FrameIdxN = DAG.getFrameIndex(FrameIdx, PtrVT);
4873       InVals.push_back(FrameIdxN);
4874 
4875       continue;
4876     }
4877 
4878     if (Ins[i].Flags.isSwiftAsync())
4879       MF.getInfo<AArch64FunctionInfo>()->setHasSwiftAsyncContext(true);
4880 
4881     SDValue ArgValue;
4882     if (VA.isRegLoc()) {
4883       // Arguments stored in registers.
4884       EVT RegVT = VA.getLocVT();
4885       const TargetRegisterClass *RC;
4886 
4887       if (RegVT == MVT::i32)
4888         RC = &AArch64::GPR32RegClass;
4889       else if (RegVT == MVT::i64)
4890         RC = &AArch64::GPR64RegClass;
4891       else if (RegVT == MVT::f16 || RegVT == MVT::bf16)
4892         RC = &AArch64::FPR16RegClass;
4893       else if (RegVT == MVT::f32)
4894         RC = &AArch64::FPR32RegClass;
4895       else if (RegVT == MVT::f64 || RegVT.is64BitVector())
4896         RC = &AArch64::FPR64RegClass;
4897       else if (RegVT == MVT::f128 || RegVT.is128BitVector())
4898         RC = &AArch64::FPR128RegClass;
4899       else if (RegVT.isScalableVector() &&
4900                RegVT.getVectorElementType() == MVT::i1)
4901         RC = &AArch64::PPRRegClass;
4902       else if (RegVT.isScalableVector())
4903         RC = &AArch64::ZPRRegClass;
4904       else
4905         llvm_unreachable("RegVT not supported by FORMAL_ARGUMENTS Lowering");
4906 
4907       // Transform the arguments in physical registers into virtual ones.
4908       unsigned Reg = MF.addLiveIn(VA.getLocReg(), RC);
4909       ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, RegVT);
4910 
4911       // If this is an 8, 16 or 32-bit value, it is really passed promoted
4912       // to 64 bits.  Insert an assert[sz]ext to capture this, then
4913       // truncate to the right size.
4914       switch (VA.getLocInfo()) {
4915       default:
4916         llvm_unreachable("Unknown loc info!");
4917       case CCValAssign::Full:
4918         break;
4919       case CCValAssign::Indirect:
4920         assert(VA.getValVT().isScalableVector() &&
4921                "Only scalable vectors can be passed indirectly");
4922         break;
4923       case CCValAssign::BCvt:
4924         ArgValue = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), ArgValue);
4925         break;
4926       case CCValAssign::AExt:
4927       case CCValAssign::SExt:
4928       case CCValAssign::ZExt:
4929         break;
4930       case CCValAssign::AExtUpper:
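        // The value was passed in the upper 32 bits of the 64-bit register;
        // shift it down before truncating to the expected type.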
4931         ArgValue = DAG.getNode(ISD::SRL, DL, RegVT, ArgValue,
4932                                DAG.getConstant(32, DL, RegVT));
4933         ArgValue = DAG.getZExtOrTrunc(ArgValue, DL, VA.getValVT());
4934         break;
4935       }
4936     } else { // VA.isRegLoc()
4937       assert(VA.isMemLoc() && "CCValAssign is neither reg nor mem");
4938       unsigned ArgOffset = VA.getLocMemOffset();
4939       unsigned ArgSize = (VA.getLocInfo() == CCValAssign::Indirect
4940                               ? VA.getLocVT().getSizeInBits()
4941                               : VA.getValVT().getSizeInBits()) / 8;
4942 
4943       uint32_t BEAlign = 0;
4944       if (!Subtarget->isLittleEndian() && ArgSize < 8 &&
4945           !Ins[i].Flags.isInConsecutiveRegs())
4946         BEAlign = 8 - ArgSize;
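      // Worked example (illustrative): on a big-endian target a 4-byte
      // argument that is not part of a consecutive-register block gets
      // BEAlign = 8 - 4 = 4, placing the fixed object below at the half of
      // the 8-byte slot that actually holds the value.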
4947 
4948       int FI = MFI.CreateFixedObject(ArgSize, ArgOffset + BEAlign, true);
4949 
4950       // Create load nodes to retrieve arguments from the stack.
4951       SDValue FIN = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
4952 
4953       // For NON_EXTLOAD, generic code in getLoad asserts that ValVT == MemVT.
4954       ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
4955       MVT MemVT = VA.getValVT();
4956 
4957       switch (VA.getLocInfo()) {
4958       default:
4959         break;
4960       case CCValAssign::Trunc:
4961       case CCValAssign::BCvt:
4962         MemVT = VA.getLocVT();
4963         break;
4964       case CCValAssign::Indirect:
4965         assert(VA.getValVT().isScalableVector() &&
4966                "Only scalable vectors can be passed indirectly");
4967         MemVT = VA.getLocVT();
4968         break;
4969       case CCValAssign::SExt:
4970         ExtType = ISD::SEXTLOAD;
4971         break;
4972       case CCValAssign::ZExt:
4973         ExtType = ISD::ZEXTLOAD;
4974         break;
4975       case CCValAssign::AExt:
4976         ExtType = ISD::EXTLOAD;
4977         break;
4978       }
4979 
4980       ArgValue = DAG.getExtLoad(
4981           ExtType, DL, VA.getLocVT(), Chain, FIN,
4982           MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI),
4983           MemVT);
4984 
4985     }
4986 
4987     if (VA.getLocInfo() == CCValAssign::Indirect) {
4988       assert(VA.getValVT().isScalableVector() &&
4989            "Only scalable vectors can be passed indirectly");
4990 
4991       uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinSize();
4992       unsigned NumParts = 1;
4993       if (Ins[i].Flags.isInConsecutiveRegs()) {
4994         assert(!Ins[i].Flags.isInConsecutiveRegsLast());
4995         while (!Ins[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
4996           ++NumParts;
4997       }
4998 
4999       MVT PartLoad = VA.getValVT();
5000       SDValue Ptr = ArgValue;
5001 
5002       // Ensure we generate all loads for each tuple part, whilst updating the
5003       // pointer after each load correctly using vscale.
5004       while (NumParts > 0) {
5005         ArgValue = DAG.getLoad(PartLoad, DL, Chain, Ptr, MachinePointerInfo());
5006         InVals.push_back(ArgValue);
5007         NumParts--;
5008         if (NumParts > 0) {
5009           SDValue BytesIncrement = DAG.getVScale(
5010               DL, Ptr.getValueType(),
5011               APInt(Ptr.getValueSizeInBits().getFixedSize(), PartSize));
5012           SDNodeFlags Flags;
5013           Flags.setNoUnsignedWrap(true);
5014           Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
5015                             BytesIncrement, Flags);
5016           ExtraArgLocs++;
5017           i++;
5018         }
5019       }
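      // Illustrative sketch (assumed sizes): for a two-part scalable tuple
      // with PartSize = 16 bytes, the first load uses Ptr and the second uses
      // Ptr + vscale * 16, so the byte offset scales with the runtime vector
      // length instead of being a fixed constant.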
5020     } else {
5021       if (Subtarget->isTargetILP32() && Ins[i].Flags.isPointer())
5022         ArgValue = DAG.getNode(ISD::AssertZext, DL, ArgValue.getValueType(),
5023                                ArgValue, DAG.getValueType(MVT::i32));
5024       InVals.push_back(ArgValue);
5025     }
5026   }
5027   assert((ArgLocs.size() + ExtraArgLocs) == Ins.size());
5028 
5029   // varargs
5030   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
5031   if (isVarArg) {
5032     if (!Subtarget->isTargetDarwin() || IsWin64) {
5033       // The AAPCS variadic function ABI is identical to the non-variadic
5034       // one. As a result there may be more arguments in registers and we should
5035       // save them for future reference.
5036       // Win64 variadic functions also pass arguments in registers, but all float
5037       // arguments are passed in integer registers.
5038       saveVarArgRegisters(CCInfo, DAG, DL, Chain);
5039     }
5040 
5041     // This will point to the next argument passed via stack.
5042     unsigned StackOffset = CCInfo.getNextStackOffset();
5043     // We currently pass all varargs at 8-byte alignment, or 4 for ILP32
5044     StackOffset = alignTo(StackOffset, Subtarget->isTargetILP32() ? 4 : 8);
5045     FuncInfo->setVarArgsStackIndex(MFI.CreateFixedObject(4, StackOffset, true));
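    // Worked example (illustrative): if getNextStackOffset() returns 20, LP64
    // rounds it up to alignTo(20, 8) = 24 while ILP32 keeps alignTo(20, 4) =
    // 20, so va_arg values on the stack begin at that offset past the named
    // arguments.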
5046 
5047     if (MFI.hasMustTailInVarArgFunc()) {
5048       SmallVector<MVT, 2> RegParmTypes;
5049       RegParmTypes.push_back(MVT::i64);
5050       RegParmTypes.push_back(MVT::f128);
5051       // Compute the set of forwarded registers. The rest are scratch.
5052       SmallVectorImpl<ForwardedRegister> &Forwards =
5053                                        FuncInfo->getForwardedMustTailRegParms();
5054       CCInfo.analyzeMustTailForwardedRegisters(Forwards, RegParmTypes,
5055                                                CC_AArch64_AAPCS);
5056 
5057       // Conservatively forward X8, since it might be used for aggregate return.
5058       if (!CCInfo.isAllocated(AArch64::X8)) {
5059         unsigned X8VReg = MF.addLiveIn(AArch64::X8, &AArch64::GPR64RegClass);
5060         Forwards.push_back(ForwardedRegister(X8VReg, AArch64::X8, MVT::i64));
5061       }
5062     }
5063   }
5064 
5065   // On Windows, InReg pointers must be returned, so record the pointer in a
5066   // virtual register at the start of the function so it can be returned in the
5067   // epilogue.
5068   if (IsWin64) {
5069     for (unsigned I = 0, E = Ins.size(); I != E; ++I) {
5070       if (Ins[I].Flags.isInReg()) {
5071         assert(!FuncInfo->getSRetReturnReg());
5072 
5073         MVT PtrTy = getPointerTy(DAG.getDataLayout());
5074         Register Reg =
5075             MF.getRegInfo().createVirtualRegister(getRegClassFor(PtrTy));
5076         FuncInfo->setSRetReturnReg(Reg);
5077 
5078         SDValue Copy = DAG.getCopyToReg(DAG.getEntryNode(), DL, Reg, InVals[I]);
5079         Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Copy, Chain);
5080         break;
5081       }
5082     }
5083   }
5084 
5085   unsigned StackArgSize = CCInfo.getNextStackOffset();
5086   bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
5087   if (DoesCalleeRestoreStack(CallConv, TailCallOpt)) {
5088     // This is a non-standard ABI so by fiat I say we're allowed to make full
5089     // use of the stack area to be popped, which must be aligned to 16 bytes in
5090     // any case:
5091     StackArgSize = alignTo(StackArgSize, 16);
5092 
5093     // If we're expected to restore the stack (e.g. fastcc) then we'll be adding
5094     // a multiple of 16.
5095     FuncInfo->setArgumentStackToRestore(StackArgSize);
5096 
5097     // This realignment carries over to the available bytes below. Our own
5098     // callers will guarantee the space is free by giving an aligned value to
5099     // CALLSEQ_START.
5100   }
5101   // Even if we're not expected to free up the space, it's useful to know how
5102   // much is there while considering tail calls (because we can reuse it).
5103   FuncInfo->setBytesInStackArgArea(StackArgSize);
5104 
5105   if (Subtarget->hasCustomCallingConv())
5106     Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
5107 
5108   return Chain;
5109 }
5110 
5111 void AArch64TargetLowering::saveVarArgRegisters(CCState &CCInfo,
5112                                                 SelectionDAG &DAG,
5113                                                 const SDLoc &DL,
5114                                                 SDValue &Chain) const {
5115   MachineFunction &MF = DAG.getMachineFunction();
5116   MachineFrameInfo &MFI = MF.getFrameInfo();
5117   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
5118   auto PtrVT = getPointerTy(DAG.getDataLayout());
5119   bool IsWin64 = Subtarget->isCallingConvWin64(MF.getFunction().getCallingConv());
5120 
5121   SmallVector<SDValue, 8> MemOps;
5122 
5123   static const MCPhysReg GPRArgRegs[] = { AArch64::X0, AArch64::X1, AArch64::X2,
5124                                           AArch64::X3, AArch64::X4, AArch64::X5,
5125                                           AArch64::X6, AArch64::X7 };
5126   static const unsigned NumGPRArgRegs = array_lengthof(GPRArgRegs);
5127   unsigned FirstVariadicGPR = CCInfo.getFirstUnallocated(GPRArgRegs);
5128 
5129   unsigned GPRSaveSize = 8 * (NumGPRArgRegs - FirstVariadicGPR);
5130   int GPRIdx = 0;
5131   if (GPRSaveSize != 0) {
5132     if (IsWin64) {
5133       GPRIdx = MFI.CreateFixedObject(GPRSaveSize, -(int)GPRSaveSize, false);
5134       if (GPRSaveSize & 15)
5135         // The extra size here, if triggered, will always be 8.
5136         MFI.CreateFixedObject(16 - (GPRSaveSize & 15), -(int)alignTo(GPRSaveSize, 16), false);
5137     } else
5138       GPRIdx = MFI.CreateStackObject(GPRSaveSize, Align(8), false);
5139 
5140     SDValue FIN = DAG.getFrameIndex(GPRIdx, PtrVT);
5141 
5142     for (unsigned i = FirstVariadicGPR; i < NumGPRArgRegs; ++i) {
5143       unsigned VReg = MF.addLiveIn(GPRArgRegs[i], &AArch64::GPR64RegClass);
5144       SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::i64);
5145       SDValue Store = DAG.getStore(
5146           Val.getValue(1), DL, Val, FIN,
5147           IsWin64
5148               ? MachinePointerInfo::getFixedStack(DAG.getMachineFunction(),
5149                                                   GPRIdx,
5150                                                   (i - FirstVariadicGPR) * 8)
5151               : MachinePointerInfo::getStack(DAG.getMachineFunction(), i * 8));
5152       MemOps.push_back(Store);
5153       FIN =
5154           DAG.getNode(ISD::ADD, DL, PtrVT, FIN, DAG.getConstant(8, DL, PtrVT));
5155     }
5156   }
5157   FuncInfo->setVarArgsGPRIndex(GPRIdx);
5158   FuncInfo->setVarArgsGPRSize(GPRSaveSize);
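  // Worked example (illustrative): with three named integer arguments,
  // FirstVariadicGPR = 3 and GPRSaveSize = 8 * (8 - 3) = 40, so X3..X7 are
  // spilled into a 40-byte save area for later consumption by va_arg.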
5159 
5160   if (Subtarget->hasFPARMv8() && !IsWin64) {
5161     static const MCPhysReg FPRArgRegs[] = {
5162         AArch64::Q0, AArch64::Q1, AArch64::Q2, AArch64::Q3,
5163         AArch64::Q4, AArch64::Q5, AArch64::Q6, AArch64::Q7};
5164     static const unsigned NumFPRArgRegs = array_lengthof(FPRArgRegs);
5165     unsigned FirstVariadicFPR = CCInfo.getFirstUnallocated(FPRArgRegs);
5166 
5167     unsigned FPRSaveSize = 16 * (NumFPRArgRegs - FirstVariadicFPR);
5168     int FPRIdx = 0;
5169     if (FPRSaveSize != 0) {
5170       FPRIdx = MFI.CreateStackObject(FPRSaveSize, Align(16), false);
5171 
5172       SDValue FIN = DAG.getFrameIndex(FPRIdx, PtrVT);
5173 
5174       for (unsigned i = FirstVariadicFPR; i < NumFPRArgRegs; ++i) {
5175         unsigned VReg = MF.addLiveIn(FPRArgRegs[i], &AArch64::FPR128RegClass);
5176         SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::f128);
5177 
5178         SDValue Store = DAG.getStore(
5179             Val.getValue(1), DL, Val, FIN,
5180             MachinePointerInfo::getStack(DAG.getMachineFunction(), i * 16));
5181         MemOps.push_back(Store);
5182         FIN = DAG.getNode(ISD::ADD, DL, PtrVT, FIN,
5183                           DAG.getConstant(16, DL, PtrVT));
5184       }
5185     }
5186     FuncInfo->setVarArgsFPRIndex(FPRIdx);
5187     FuncInfo->setVarArgsFPRSize(FPRSaveSize);
5188   }
5189 
5190   if (!MemOps.empty()) {
5191     Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOps);
5192   }
5193 }
5194 
5195 /// LowerCallResult - Lower the result values of a call into the
5196 /// appropriate copies out of appropriate physical registers.
5197 SDValue AArch64TargetLowering::LowerCallResult(
5198     SDValue Chain, SDValue InFlag, CallingConv::ID CallConv, bool isVarArg,
5199     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
5200     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
5201     SDValue ThisVal) const {
5202   CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
5203   // Assign locations to each value returned by this call.
5204   SmallVector<CCValAssign, 16> RVLocs;
5205   DenseMap<unsigned, SDValue> CopiedRegs;
5206   CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), RVLocs,
5207                  *DAG.getContext());
5208   CCInfo.AnalyzeCallResult(Ins, RetCC);
5209 
5210   // Copy all of the result registers out of their specified physreg.
5211   for (unsigned i = 0; i != RVLocs.size(); ++i) {
5212     CCValAssign VA = RVLocs[i];
5213 
5214     // Pass the 'this' value directly from the argument to the return value,
5215     // to avoid register unit interference.
5216     if (i == 0 && isThisReturn) {
5217       assert(!VA.needsCustom() && VA.getLocVT() == MVT::i64 &&
5218              "unexpected return calling convention register assignment");
5219       InVals.push_back(ThisVal);
5220       continue;
5221     }
5222 
5223     // Avoid copying a physreg twice since RegAllocFast is incompetent and only
5224     // allows one use of a physreg per block.
5225     SDValue Val = CopiedRegs.lookup(VA.getLocReg());
5226     if (!Val) {
5227       Val =
5228           DAG.getCopyFromReg(Chain, DL, VA.getLocReg(), VA.getLocVT(), InFlag);
5229       Chain = Val.getValue(1);
5230       InFlag = Val.getValue(2);
5231       CopiedRegs[VA.getLocReg()] = Val;
5232     }
5233 
5234     switch (VA.getLocInfo()) {
5235     default:
5236       llvm_unreachable("Unknown loc info!");
5237     case CCValAssign::Full:
5238       break;
5239     case CCValAssign::BCvt:
5240       Val = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), Val);
5241       break;
5242     case CCValAssign::AExtUpper:
5243       Val = DAG.getNode(ISD::SRL, DL, VA.getLocVT(), Val,
5244                         DAG.getConstant(32, DL, VA.getLocVT()));
5245       LLVM_FALLTHROUGH;
5246     case CCValAssign::AExt:
5247       LLVM_FALLTHROUGH;
5248     case CCValAssign::ZExt:
5249       Val = DAG.getZExtOrTrunc(Val, DL, VA.getValVT());
5250       break;
5251     }
5252 
5253     InVals.push_back(Val);
5254   }
5255 
5256   return Chain;
5257 }
5258 
5259 /// Return true if the calling convention is one that we can guarantee TCO for.
5260 static bool canGuaranteeTCO(CallingConv::ID CC, bool GuaranteeTailCalls) {
5261   return (CC == CallingConv::Fast && GuaranteeTailCalls) ||
5262          CC == CallingConv::Tail || CC == CallingConv::SwiftTail;
5263 }
5264 
5265 /// Return true if we might ever do TCO for calls with this calling convention.
5266 static bool mayTailCallThisCC(CallingConv::ID CC) {
5267   switch (CC) {
5268   case CallingConv::C:
5269   case CallingConv::AArch64_SVE_VectorCall:
5270   case CallingConv::PreserveMost:
5271   case CallingConv::Swift:
5272   case CallingConv::SwiftTail:
5273   case CallingConv::Tail:
5274   case CallingConv::Fast:
5275     return true;
5276   default:
5277     return false;
5278   }
5279 }
5280 
5281 bool AArch64TargetLowering::isEligibleForTailCallOptimization(
5282     SDValue Callee, CallingConv::ID CalleeCC, bool isVarArg,
5283     const SmallVectorImpl<ISD::OutputArg> &Outs,
5284     const SmallVectorImpl<SDValue> &OutVals,
5285     const SmallVectorImpl<ISD::InputArg> &Ins, SelectionDAG &DAG) const {
5286   if (!mayTailCallThisCC(CalleeCC))
5287     return false;
5288 
5289   MachineFunction &MF = DAG.getMachineFunction();
5290   const Function &CallerF = MF.getFunction();
5291   CallingConv::ID CallerCC = CallerF.getCallingConv();
5292 
5293   // Functions using the C or Fast calling convention that have an SVE signature
5294   // preserve more registers and should assume the SVE_VectorCall CC.
5295   // The check for matching callee-saved regs will determine whether it is
5296   // eligible for TCO.
5297   if ((CallerCC == CallingConv::C || CallerCC == CallingConv::Fast) &&
5298       AArch64RegisterInfo::hasSVEArgsOrReturn(&MF))
5299     CallerCC = CallingConv::AArch64_SVE_VectorCall;
5300 
5301   bool CCMatch = CallerCC == CalleeCC;
5302 
5303   // When using the Windows calling convention on a non-windows OS, we want
5304   // to back up and restore X18 in such functions; we can't do a tail call
5305   // from those functions.
5306   if (CallerCC == CallingConv::Win64 && !Subtarget->isTargetWindows() &&
5307       CalleeCC != CallingConv::Win64)
5308     return false;
5309 
5310   // Byval parameters hand the function a pointer directly into the stack area
5311   // we want to reuse during a tail call. Working around this *is* possible (see
5312   // X86) but less efficient and uglier in LowerCall.
5313   for (Function::const_arg_iterator i = CallerF.arg_begin(),
5314                                     e = CallerF.arg_end();
5315        i != e; ++i) {
5316     if (i->hasByValAttr())
5317       return false;
5318 
5319     // On Windows, "inreg" attributes signify non-aggregate indirect returns.
5320     // In this case, it is necessary to save/restore X0 in the callee. Tail
5321     // call opt interferes with this. So we disable tail call opt when the
5322     // caller has an argument with "inreg" attribute.
5323 
5324     // FIXME: Check whether the callee also has an "inreg" argument.
5325     if (i->hasInRegAttr())
5326       return false;
5327   }
5328 
5329   if (canGuaranteeTCO(CalleeCC, getTargetMachine().Options.GuaranteedTailCallOpt))
5330     return CCMatch;
5331 
5332   // Externally-defined functions with weak linkage should not be
5333   // tail-called on AArch64 when the OS does not support dynamic
5334   // pre-emption of symbols, as the AAELF spec requires normal calls
5335   // to undefined weak functions to be replaced with a NOP or jump to the
5336   // next instruction. The behaviour of branch instructions in this
5337   // situation (as used for tail calls) is implementation-defined, so we
5338   // cannot rely on the linker replacing the tail call with a return.
5339   if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
5340     const GlobalValue *GV = G->getGlobal();
5341     const Triple &TT = getTargetMachine().getTargetTriple();
5342     if (GV->hasExternalWeakLinkage() &&
5343         (!TT.isOSWindows() || TT.isOSBinFormatELF() || TT.isOSBinFormatMachO()))
5344       return false;
5345   }
5346 
5347   // Now we search for cases where we can use a tail call without changing the
5348   // ABI. Sibcall is used in some places (particularly gcc) to refer to this
5349   // concept.
5350 
5351   // I want anyone implementing a new calling convention to think long and hard
5352   // about this assert.
5353   assert((!isVarArg || CalleeCC == CallingConv::C) &&
5354          "Unexpected variadic calling convention");
5355 
5356   LLVMContext &C = *DAG.getContext();
5357   if (isVarArg && !Outs.empty()) {
5358     // At least two cases here: if caller is fastcc then we can't have any
5359     // memory arguments (we'd be expected to clean up the stack afterwards). If
5360     // caller is C then we could potentially use its argument area.
5361 
5362     // FIXME: for now we take the most conservative of these in both cases:
5363     // disallow all variadic memory operands.
5364     SmallVector<CCValAssign, 16> ArgLocs;
5365     CCState CCInfo(CalleeCC, isVarArg, MF, ArgLocs, C);
5366 
5367     CCInfo.AnalyzeCallOperands(Outs, CCAssignFnForCall(CalleeCC, true));
5368     for (const CCValAssign &ArgLoc : ArgLocs)
5369       if (!ArgLoc.isRegLoc())
5370         return false;
5371   }
5372 
5373   // Check that the call results are passed in the same way.
5374   if (!CCState::resultsCompatible(CalleeCC, CallerCC, MF, C, Ins,
5375                                   CCAssignFnForCall(CalleeCC, isVarArg),
5376                                   CCAssignFnForCall(CallerCC, isVarArg)))
5377     return false;
5378   // The callee has to preserve all registers the caller needs to preserve.
5379   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
5380   const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
5381   if (!CCMatch) {
5382     const uint32_t *CalleePreserved = TRI->getCallPreservedMask(MF, CalleeCC);
5383     if (Subtarget->hasCustomCallingConv()) {
5384       TRI->UpdateCustomCallPreservedMask(MF, &CallerPreserved);
5385       TRI->UpdateCustomCallPreservedMask(MF, &CalleePreserved);
5386     }
5387     if (!TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved))
5388       return false;
5389   }
5390 
5391   // Nothing more to check if the callee is taking no arguments
5392   if (Outs.empty())
5393     return true;
5394 
5395   SmallVector<CCValAssign, 16> ArgLocs;
5396   CCState CCInfo(CalleeCC, isVarArg, MF, ArgLocs, C);
5397 
5398   CCInfo.AnalyzeCallOperands(Outs, CCAssignFnForCall(CalleeCC, isVarArg));
5399 
5400   const AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
5401 
5402   // If any of the arguments is passed indirectly, it must be SVE, so the
5403   // 'getBytesInStackArgArea' is not sufficient to determine whether we need to
5404   // allocate space on the stack. That is why we check this explicitly here:
5405   // if any argument is passed indirectly, the call cannot be a tailcall.
5406   if (llvm::any_of(ArgLocs, [](CCValAssign &A) {
5407         assert((A.getLocInfo() != CCValAssign::Indirect ||
5408                 A.getValVT().isScalableVector()) &&
5409                "Expected value to be scalable");
5410         return A.getLocInfo() == CCValAssign::Indirect;
5411       }))
5412     return false;
5413 
5414   // If the stack arguments for this call do not fit into our own save area then
5415   // the call cannot be made tail.
5416   if (CCInfo.getNextStackOffset() > FuncInfo->getBytesInStackArgArea())
5417     return false;
5418 
5419   const MachineRegisterInfo &MRI = MF.getRegInfo();
5420   if (!parametersInCSRMatch(MRI, CallerPreserved, ArgLocs, OutVals))
5421     return false;
5422 
5423   return true;
5424 }
5425 
5426 SDValue AArch64TargetLowering::addTokenForArgument(SDValue Chain,
5427                                                    SelectionDAG &DAG,
5428                                                    MachineFrameInfo &MFI,
5429                                                    int ClobberedFI) const {
5430   SmallVector<SDValue, 8> ArgChains;
5431   int64_t FirstByte = MFI.getObjectOffset(ClobberedFI);
5432   int64_t LastByte = FirstByte + MFI.getObjectSize(ClobberedFI) - 1;
5433 
5434   // Include the original chain at the beginning of the list. When this is
5435   // used by target LowerCall hooks, this helps legalize find the
5436   // CALLSEQ_BEGIN node.
5437   ArgChains.push_back(Chain);
5438 
5439   // Add a chain value for each stack-argument load that overlaps ClobberedFI.
5440   for (SDNode::use_iterator U = DAG.getEntryNode().getNode()->use_begin(),
5441                             UE = DAG.getEntryNode().getNode()->use_end();
5442        U != UE; ++U)
5443     if (LoadSDNode *L = dyn_cast<LoadSDNode>(*U))
5444       if (FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(L->getBasePtr()))
5445         if (FI->getIndex() < 0) {
5446           int64_t InFirstByte = MFI.getObjectOffset(FI->getIndex());
5447           int64_t InLastByte = InFirstByte;
5448           InLastByte += MFI.getObjectSize(FI->getIndex()) - 1;
5449 
5450           if ((InFirstByte <= FirstByte && FirstByte <= InLastByte) ||
5451               (FirstByte <= InFirstByte && InFirstByte <= LastByte))
5452             ArgChains.push_back(SDValue(L, 1));
5453         }
5454 
5455   // Build a tokenfactor for all the chains.
5456   return DAG.getNode(ISD::TokenFactor, SDLoc(Chain), MVT::Other, ArgChains);
5457 }
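// Illustrative overlap check (offsets assumed): if the clobbered fixed object
// covers bytes [16, 23] and an incoming-argument load covers bytes [20, 27],
// the ranges intersect, so that load's chain is added to the TokenFactor and
// is guaranteed to complete before the slot is overwritten.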
5458 
5459 bool AArch64TargetLowering::DoesCalleeRestoreStack(CallingConv::ID CallCC,
5460                                                    bool TailCallOpt) const {
5461   return (CallCC == CallingConv::Fast && TailCallOpt) ||
5462          CallCC == CallingConv::Tail || CallCC == CallingConv::SwiftTail;
5463 }
5464 
5465 /// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
5466 /// and add input and output parameter nodes.
5467 SDValue
5468 AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
5469                                  SmallVectorImpl<SDValue> &InVals) const {
5470   SelectionDAG &DAG = CLI.DAG;
5471   SDLoc &DL = CLI.DL;
5472   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
5473   SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
5474   SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
5475   SDValue Chain = CLI.Chain;
5476   SDValue Callee = CLI.Callee;
5477   bool &IsTailCall = CLI.IsTailCall;
5478   CallingConv::ID CallConv = CLI.CallConv;
5479   bool IsVarArg = CLI.IsVarArg;
5480 
5481   MachineFunction &MF = DAG.getMachineFunction();
5482   MachineFunction::CallSiteInfo CSInfo;
5483   bool IsThisReturn = false;
5484 
5485   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
5486   bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
5487   bool IsSibCall = false;
5488   bool IsCalleeWin64 = Subtarget->isCallingConvWin64(CallConv);
5489 
5490   // Check callee args/returns for SVE registers and set calling convention
5491   // accordingly.
5492   if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) {
5493     bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){
5494       return Out.VT.isScalableVector();
5495     });
5496     bool CalleeInSVE = any_of(Ins, [](ISD::InputArg &In){
5497       return In.VT.isScalableVector();
5498     });
5499 
5500     if (CalleeInSVE || CalleeOutSVE)
5501       CallConv = CallingConv::AArch64_SVE_VectorCall;
5502   }
5503 
5504   if (IsTailCall) {
5505     // Check if it's really possible to do a tail call.
5506     IsTailCall = isEligibleForTailCallOptimization(
5507         Callee, CallConv, IsVarArg, Outs, OutVals, Ins, DAG);
5508     if (!IsTailCall && CLI.CB && CLI.CB->isMustTailCall())
5509       report_fatal_error("failed to perform tail call elimination on a call "
5510                          "site marked musttail");
5511 
5512     // A sibling call is one where we're under the usual C ABI and not planning
5513     // to change that but can still do a tail call:
5514     if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
5515         CallConv != CallingConv::SwiftTail)
5516       IsSibCall = true;
5517 
5518     if (IsTailCall)
5519       ++NumTailCalls;
5520   }
5521 
5522   // Analyze operands of the call, assigning locations to each operand.
5523   SmallVector<CCValAssign, 16> ArgLocs;
5524   CCState CCInfo(CallConv, IsVarArg, DAG.getMachineFunction(), ArgLocs,
5525                  *DAG.getContext());
5526 
5527   if (IsVarArg) {
5528     // Handle fixed and variable vector arguments differently.
5529     // Variable vector arguments always go into memory.
5530     unsigned NumArgs = Outs.size();
5531 
5532     for (unsigned i = 0; i != NumArgs; ++i) {
5533       MVT ArgVT = Outs[i].VT;
5534       if (!Outs[i].IsFixed && ArgVT.isScalableVector())
5535         report_fatal_error("Passing SVE types to variadic functions is "
5536                            "currently not supported");
5537 
5538       ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
5539       bool UseVarArgCC = !Outs[i].IsFixed;
5540       // On Windows, the fixed arguments in a vararg call are passed in GPRs
5541       // too, so use the vararg CC to force them to integer registers.
5542       if (IsCalleeWin64)
5543         UseVarArgCC = true;
5544       CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, UseVarArgCC);
5545       bool Res = AssignFn(i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, CCInfo);
5546       assert(!Res && "Call operand has unhandled type");
5547       (void)Res;
5548     }
5549   } else {
5550     // At this point, Outs[].VT may already be promoted to i32. To correctly
5551     // handle passing i8 as i8 instead of i32 on stack, we pass in both i32 and
5552     // i8 to CC_AArch64_AAPCS with i32 being ValVT and i8 being LocVT.
5553     // Since AnalyzeCallOperands uses Outs[].VT for both ValVT and LocVT, here
5554     // we use a special version of AnalyzeCallOperands to pass in ValVT and
5555     // LocVT.
5556     unsigned NumArgs = Outs.size();
5557     for (unsigned i = 0; i != NumArgs; ++i) {
5558       MVT ValVT = Outs[i].VT;
5559       // Get type of the original argument.
5560       EVT ActualVT = getValueType(DAG.getDataLayout(),
5561                                   CLI.getArgs()[Outs[i].OrigArgIndex].Ty,
5562                                   /*AllowUnknown*/ true);
5563       MVT ActualMVT = ActualVT.isSimple() ? ActualVT.getSimpleVT() : ValVT;
5564       ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
5565       // If ActualMVT is i1/i8/i16, we should set LocVT to i8/i8/i16.
5566       if (ActualMVT == MVT::i1 || ActualMVT == MVT::i8)
5567         ValVT = MVT::i8;
5568       else if (ActualMVT == MVT::i16)
5569         ValVT = MVT::i16;
5570 
5571       CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, /*IsVarArg=*/false);
5572       bool Res = AssignFn(i, ValVT, ValVT, CCValAssign::Full, ArgFlags, CCInfo);
5573       assert(!Res && "Call operand has unhandled type");
5574       (void)Res;
5575     }
5576   }
5577 
5578   // Get a count of how many bytes are to be pushed on the stack.
5579   unsigned NumBytes = CCInfo.getNextStackOffset();
5580 
5581   if (IsSibCall) {
5582     // Since we're not changing the ABI to make this a tail call, the memory
5583     // operands are already available in the caller's incoming argument space.
5584     NumBytes = 0;
5585   }
5586 
5587   // FPDiff is the byte offset of the call's argument area from the callee's.
5588   // Stores to callee stack arguments will be placed in FixedStackSlots offset
5589   // by this amount for a tail call. In a sibling call it must be 0 because the
5590   // caller will deallocate the entire stack and the callee still expects its
5591   // arguments to begin at SP+0. Completely unused for non-tail calls.
5592   int FPDiff = 0;
5593 
5594   if (IsTailCall && !IsSibCall) {
5595     unsigned NumReusableBytes = FuncInfo->getBytesInStackArgArea();
5596 
5597     // Since callee will pop argument stack as a tail call, we must keep the
5598     // popped size 16-byte aligned.
5599     NumBytes = alignTo(NumBytes, 16);
5600 
5601     // FPDiff will be negative if this tail call requires more space than we
5602     // would automatically have in our incoming argument space. Positive if we
5603     // can actually shrink the stack.
5604     FPDiff = NumReusableBytes - NumBytes;
5605 
5606     // Update the required reserved area if this is the tail call requiring the
5607     // most argument stack space.
5608     if (FPDiff < 0 && FuncInfo->getTailCallReservedStack() < (unsigned)-FPDiff)
5609       FuncInfo->setTailCallReservedStack(-FPDiff);
5610 
5611     // The stack pointer must be 16-byte aligned at all times it's used for a
5612     // memory operation, which in practice means at *all* times and in
5613     // particular across call boundaries. Therefore our own arguments started at
5614     // a 16-byte aligned SP and the delta applied for the tail call should
5615     // satisfy the same constraint.
5616     assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
5617   }
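  // Worked example (illustrative): if the caller received 16 bytes of stack
  // arguments but this tail call needs 32 bytes (after 16-byte alignment),
  // then FPDiff = 16 - 32 = -16 and TailCallReservedStack grows to 16 so the
  // prologue reserves room for the extra outgoing bytes.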
5618 
5619   // Adjust the stack pointer for the new arguments...
5620   // These operations are automatically eliminated by the prolog/epilog pass
5621   if (!IsSibCall)
5622     Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
5623 
5624   SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
5625                                         getPointerTy(DAG.getDataLayout()));
5626 
5627   SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass;
5628   SmallSet<unsigned, 8> RegsUsed;
5629   SmallVector<SDValue, 8> MemOpChains;
5630   auto PtrVT = getPointerTy(DAG.getDataLayout());
5631 
5632   if (IsVarArg && CLI.CB && CLI.CB->isMustTailCall()) {
5633     const auto &Forwards = FuncInfo->getForwardedMustTailRegParms();
5634     for (const auto &F : Forwards) {
5635       SDValue Val = DAG.getCopyFromReg(Chain, DL, F.VReg, F.VT);
5636       RegsToPass.emplace_back(F.PReg, Val);
5637     }
5638   }
5639 
5640   // Walk the register/memloc assignments, inserting copies/loads.
5641   unsigned ExtraArgLocs = 0;
5642   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
5643     CCValAssign &VA = ArgLocs[i - ExtraArgLocs];
5644     SDValue Arg = OutVals[i];
5645     ISD::ArgFlagsTy Flags = Outs[i].Flags;
5646 
5647     // Promote the value if needed.
5648     switch (VA.getLocInfo()) {
5649     default:
5650       llvm_unreachable("Unknown loc info!");
5651     case CCValAssign::Full:
5652       break;
5653     case CCValAssign::SExt:
5654       Arg = DAG.getNode(ISD::SIGN_EXTEND, DL, VA.getLocVT(), Arg);
5655       break;
5656     case CCValAssign::ZExt:
5657       Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, VA.getLocVT(), Arg);
5658       break;
5659     case CCValAssign::AExt:
5660       if (Outs[i].ArgVT == MVT::i1) {
5661         // AAPCS requires i1 to be zero-extended to 8-bits by the caller.
5662         Arg = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Arg);
5663         Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i8, Arg);
5664       }
5665       Arg = DAG.getNode(ISD::ANY_EXTEND, DL, VA.getLocVT(), Arg);
5666       break;
5667     case CCValAssign::AExtUpper:
5668       assert(VA.getValVT() == MVT::i32 && "only expect 32 -> 64 upper bits");
5669       Arg = DAG.getNode(ISD::ANY_EXTEND, DL, VA.getLocVT(), Arg);
5670       Arg = DAG.getNode(ISD::SHL, DL, VA.getLocVT(), Arg,
5671                         DAG.getConstant(32, DL, VA.getLocVT()));
5672       break;
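    // Illustrative effect (value assumed): an i32 argument 0x1234 assigned to
    // the upper half of an X register is any-extended to i64 and shifted left
    // by 32, giving 0x0000123400000000 before being merged with whatever
    // occupies the lower 32 bits of that register.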
5673     case CCValAssign::BCvt:
5674       Arg = DAG.getBitcast(VA.getLocVT(), Arg);
5675       break;
5676     case CCValAssign::Trunc:
5677       Arg = DAG.getZExtOrTrunc(Arg, DL, VA.getLocVT());
5678       break;
5679     case CCValAssign::FPExt:
5680       Arg = DAG.getNode(ISD::FP_EXTEND, DL, VA.getLocVT(), Arg);
5681       break;
5682     case CCValAssign::Indirect:
5683       assert(VA.getValVT().isScalableVector() &&
5684              "Only scalable vectors can be passed indirectly");
5685 
5686       uint64_t StoreSize = VA.getValVT().getStoreSize().getKnownMinSize();
5687       uint64_t PartSize = StoreSize;
5688       unsigned NumParts = 1;
5689       if (Outs[i].Flags.isInConsecutiveRegs()) {
5690         assert(!Outs[i].Flags.isInConsecutiveRegsLast());
5691         while (!Outs[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
5692           ++NumParts;
5693         StoreSize *= NumParts;
5694       }
5695 
5696       MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
5697       Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
5698       Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
5699       int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
5700       MFI.setStackID(FI, TargetStackID::ScalableVector);
5701 
5702       MachinePointerInfo MPI =
5703           MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
5704       SDValue Ptr = DAG.getFrameIndex(
5705           FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
5706       SDValue SpillSlot = Ptr;
5707 
5708       // Ensure we generate all stores for each tuple part, whilst updating the
5709       // pointer after each store correctly using vscale.
5710       while (NumParts) {
5711         Chain = DAG.getStore(Chain, DL, OutVals[i], Ptr, MPI);
5712         NumParts--;
5713         if (NumParts > 0) {
5714           SDValue BytesIncrement = DAG.getVScale(
5715               DL, Ptr.getValueType(),
5716               APInt(Ptr.getValueSizeInBits().getFixedSize(), PartSize));
5717           SDNodeFlags Flags;
5718           Flags.setNoUnsignedWrap(true);
5719 
5720           MPI = MachinePointerInfo(MPI.getAddrSpace());
5721           Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
5722                             BytesIncrement, Flags);
5723           ExtraArgLocs++;
5724           i++;
5725         }
5726       }
5727 
5728       Arg = SpillSlot;
5729       break;
5730     }
5731 
5732     if (VA.isRegLoc()) {
5733       if (i == 0 && Flags.isReturned() && !Flags.isSwiftSelf() &&
5734           Outs[0].VT == MVT::i64) {
5735         assert(VA.getLocVT() == MVT::i64 &&
5736                "unexpected calling convention register assignment");
5737         assert(!Ins.empty() && Ins[0].VT == MVT::i64 &&
5738                "unexpected use of 'returned'");
5739         IsThisReturn = true;
5740       }
5741       if (RegsUsed.count(VA.getLocReg())) {
5742         // If this register has already been used then we're trying to pack
5743         // parts of an [N x i32] into an X-register. The extension type will
5744         // take care of putting the two halves in the right place but we have to
5745         // combine them.
5746         SDValue &Bits =
5747             llvm::find_if(RegsToPass,
5748                           [=](const std::pair<unsigned, SDValue> &Elt) {
5749                             return Elt.first == VA.getLocReg();
5750                           })
5751                 ->second;
5752         Bits = DAG.getNode(ISD::OR, DL, Bits.getValueType(), Bits, Arg);
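        // Illustrative sketch (types assumed): for an argument lowered as
        // [2 x i32] into one X register, the second half arrives with the
        // register already used; its pre-shifted upper bits are OR'd into the
        // existing value so a single i64 copy carries both halves.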
5753         // Call site info is used for function's parameter entry value
5754         // tracking. For now we track only simple cases when parameter
5755         // is transferred through whole register.
5756         llvm::erase_if(CSInfo, [&VA](MachineFunction::ArgRegPair ArgReg) {
5757           return ArgReg.Reg == VA.getLocReg();
5758         });
5759       } else {
5760         RegsToPass.emplace_back(VA.getLocReg(), Arg);
5761         RegsUsed.insert(VA.getLocReg());
5762         const TargetOptions &Options = DAG.getTarget().Options;
5763         if (Options.EmitCallSiteInfo)
5764           CSInfo.emplace_back(VA.getLocReg(), i);
5765       }
5766     } else {
5767       assert(VA.isMemLoc());
5768 
5769       SDValue DstAddr;
5770       MachinePointerInfo DstInfo;
5771 
5772       // FIXME: This works on big-endian for composite byvals, which are the
5773       // common case. It should also work for fundamental types.
5774       uint32_t BEAlign = 0;
5775       unsigned OpSize;
5776       if (VA.getLocInfo() == CCValAssign::Indirect)
5777         OpSize = VA.getLocVT().getFixedSizeInBits();
5778       else
5779         OpSize = Flags.isByVal() ? Flags.getByValSize() * 8
5780                                  : VA.getValVT().getSizeInBits();
5781       OpSize = (OpSize + 7) / 8;
5782       if (!Subtarget->isLittleEndian() && !Flags.isByVal() &&
5783           !Flags.isInConsecutiveRegs()) {
5784         if (OpSize < 8)
5785           BEAlign = 8 - OpSize;
5786       }
5787       unsigned LocMemOffset = VA.getLocMemOffset();
5788       int32_t Offset = LocMemOffset + BEAlign;
5789       SDValue PtrOff = DAG.getIntPtrConstant(Offset, DL);
5790       PtrOff = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, PtrOff);
5791 
5792       if (IsTailCall) {
5793         Offset = Offset + FPDiff;
5794         int FI = MF.getFrameInfo().CreateFixedObject(OpSize, Offset, true);
5795 
5796         DstAddr = DAG.getFrameIndex(FI, PtrVT);
5797         DstInfo =
5798             MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
5799 
5800         // Make sure any stack arguments overlapping with where we're storing
5801         // are loaded before this eventual operation. Otherwise they'll be
5802         // clobbered.
5803         Chain = addTokenForArgument(Chain, DAG, MF.getFrameInfo(), FI);
5804       } else {
5805         SDValue PtrOff = DAG.getIntPtrConstant(Offset, DL);
5806 
5807         DstAddr = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, PtrOff);
5808         DstInfo = MachinePointerInfo::getStack(DAG.getMachineFunction(),
5809                                                LocMemOffset);
5810       }
5811 
5812       if (Outs[i].Flags.isByVal()) {
5813         SDValue SizeNode =
5814             DAG.getConstant(Outs[i].Flags.getByValSize(), DL, MVT::i64);
5815         SDValue Cpy = DAG.getMemcpy(
5816             Chain, DL, DstAddr, Arg, SizeNode,
5817             Outs[i].Flags.getNonZeroByValAlign(),
5818             /*isVol = */ false, /*AlwaysInline = */ false,
5819             /*isTailCall = */ false, DstInfo, MachinePointerInfo());
5820 
5821         MemOpChains.push_back(Cpy);
5822       } else {
5823         // Since we pass i1/i8/i16 as i1/i8/i16 on stack and Arg is already
5824         // promoted to a legal register type i32, we should truncate Arg back to
5825         // i1/i8/i16.
5826         if (VA.getValVT() == MVT::i1 || VA.getValVT() == MVT::i8 ||
5827             VA.getValVT() == MVT::i16)
5828           Arg = DAG.getNode(ISD::TRUNCATE, DL, VA.getValVT(), Arg);
5829 
5830         SDValue Store = DAG.getStore(Chain, DL, Arg, DstAddr, DstInfo);
5831         MemOpChains.push_back(Store);
5832       }
5833     }
5834   }
5835 
5836   if (!MemOpChains.empty())
5837     Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
5838 
5839   // Build a sequence of copy-to-reg nodes chained together with token chain
5840   // and flag operands which copy the outgoing args into the appropriate regs.
5841   SDValue InFlag;
5842   for (auto &RegToPass : RegsToPass) {
5843     Chain = DAG.getCopyToReg(Chain, DL, RegToPass.first,
5844                              RegToPass.second, InFlag);
5845     InFlag = Chain.getValue(1);
5846   }
5847 
5848   // If the callee is a GlobalAddress/ExternalSymbol node (quite common, every
5849   // direct call is) turn it into a TargetGlobalAddress/TargetExternalSymbol
5850   // node so that legalize doesn't hack it.
5851   if (auto *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
5852     auto GV = G->getGlobal();
5853     unsigned OpFlags =
5854         Subtarget->classifyGlobalFunctionReference(GV, getTargetMachine());
5855     if (OpFlags & AArch64II::MO_GOT) {
5856       Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, OpFlags);
5857       Callee = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, Callee);
5858     } else {
5859       const GlobalValue *GV = G->getGlobal();
5860       Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, 0);
5861     }
5862   } else if (auto *S = dyn_cast<ExternalSymbolSDNode>(Callee)) {
5863     if (getTargetMachine().getCodeModel() == CodeModel::Large &&
5864         Subtarget->isTargetMachO()) {
5865       const char *Sym = S->getSymbol();
5866       Callee = DAG.getTargetExternalSymbol(Sym, PtrVT, AArch64II::MO_GOT);
5867       Callee = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, Callee);
5868     } else {
5869       const char *Sym = S->getSymbol();
5870       Callee = DAG.getTargetExternalSymbol(Sym, PtrVT, 0);
5871     }
5872   }
5873 
5874   // We don't usually want to end the call-sequence here because we would tidy
5875   // the frame up *after* the call, however in the ABI-changing tail-call case
5876   // we've carefully laid out the parameters so that when sp is reset they'll be
5877   // in the correct location.
5878   if (IsTailCall && !IsSibCall) {
5879     Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(0, DL, true),
5880                                DAG.getIntPtrConstant(0, DL, true), InFlag, DL);
5881     InFlag = Chain.getValue(1);
5882   }
5883 
5884   std::vector<SDValue> Ops;
5885   Ops.push_back(Chain);
5886   Ops.push_back(Callee);
5887 
5888   if (IsTailCall) {
5889     // Each tail call may have to adjust the stack by a different amount, so
5890     // this information must travel along with the operation for eventual
5891     // consumption by emitEpilogue.
5892     Ops.push_back(DAG.getTargetConstant(FPDiff, DL, MVT::i32));
5893   }
5894 
5895   // Add argument registers to the end of the list so that they are known live
5896   // into the call.
5897   for (auto &RegToPass : RegsToPass)
5898     Ops.push_back(DAG.getRegister(RegToPass.first,
5899                                   RegToPass.second.getValueType()));
5900 
5901   // Add a register mask operand representing the call-preserved registers.
5902   const uint32_t *Mask;
5903   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
5904   if (IsThisReturn) {
5905     // For 'this' returns, use the X0-preserving mask if applicable
5906     Mask = TRI->getThisReturnPreservedMask(MF, CallConv);
5907     if (!Mask) {
5908       IsThisReturn = false;
5909       Mask = TRI->getCallPreservedMask(MF, CallConv);
5910     }
5911   } else
5912     Mask = TRI->getCallPreservedMask(MF, CallConv);
5913 
5914   if (Subtarget->hasCustomCallingConv())
5915     TRI->UpdateCustomCallPreservedMask(MF, &Mask);
5916 
5917   if (TRI->isAnyArgRegReserved(MF))
5918     TRI->emitReservedArgRegCallError(MF);
5919 
5920   assert(Mask && "Missing call preserved mask for calling convention");
5921   Ops.push_back(DAG.getRegisterMask(Mask));
5922 
5923   if (InFlag.getNode())
5924     Ops.push_back(InFlag);
5925 
5926   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
5927 
5928   // If we're doing a tail call, use a TC_RETURN here rather than an
5929   // actual call instruction.
5930   if (IsTailCall) {
5931     MF.getFrameInfo().setHasTailCall();
5932     SDValue Ret = DAG.getNode(AArch64ISD::TC_RETURN, DL, NodeTys, Ops);
5933     DAG.addCallSiteInfo(Ret.getNode(), std::move(CSInfo));
5934     return Ret;
5935   }
5936 
5937   unsigned CallOpc = AArch64ISD::CALL;
5938   // Calls with operand bundle "clang.arc.attachedcall" are special. They should
5939   // be expanded to the call, directly followed by a special marker sequence.
5940   // Use the CALL_RVMARKER to do that.
5941   if (CLI.CB && objcarc::hasAttachedCallOpBundle(CLI.CB)) {
5942     assert(!IsTailCall &&
5943            "tail calls cannot be marked with clang.arc.attachedcall");
5944     CallOpc = AArch64ISD::CALL_RVMARKER;
5945   }
5946 
5947   // Returns a chain and a flag for retval copy to use.
5948   Chain = DAG.getNode(CallOpc, DL, NodeTys, Ops);
5949   DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
5950   InFlag = Chain.getValue(1);
5951   DAG.addCallSiteInfo(Chain.getNode(), std::move(CSInfo));
5952 
5953   uint64_t CalleePopBytes =
5954       DoesCalleeRestoreStack(CallConv, TailCallOpt) ? alignTo(NumBytes, 16) : 0;
5955 
5956   Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(NumBytes, DL, true),
5957                              DAG.getIntPtrConstant(CalleePopBytes, DL, true),
5958                              InFlag, DL);
5959   if (!Ins.empty())
5960     InFlag = Chain.getValue(1);
5961 
5962   // Handle result values, copying them out of physregs into vregs that we
5963   // return.
5964   return LowerCallResult(Chain, InFlag, CallConv, IsVarArg, Ins, DL, DAG,
5965                          InVals, IsThisReturn,
5966                          IsThisReturn ? OutVals[0] : SDValue());
5967 }
5968 
5969 bool AArch64TargetLowering::CanLowerReturn(
5970     CallingConv::ID CallConv, MachineFunction &MF, bool isVarArg,
5971     const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext &Context) const {
5972   CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
5973   SmallVector<CCValAssign, 16> RVLocs;
5974   CCState CCInfo(CallConv, isVarArg, MF, RVLocs, Context);
5975   return CCInfo.CheckReturn(Outs, RetCC);
5976 }
5977 
5978 SDValue
5979 AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
5980                                    bool isVarArg,
5981                                    const SmallVectorImpl<ISD::OutputArg> &Outs,
5982                                    const SmallVectorImpl<SDValue> &OutVals,
5983                                    const SDLoc &DL, SelectionDAG &DAG) const {
5984   auto &MF = DAG.getMachineFunction();
5985   auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
5986 
5987   CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
5988   SmallVector<CCValAssign, 16> RVLocs;
5989   CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), RVLocs,
5990                  *DAG.getContext());
5991   CCInfo.AnalyzeReturn(Outs, RetCC);
5992 
5993   // Copy the result values into the output registers.
5994   SDValue Flag;
5995   SmallVector<std::pair<unsigned, SDValue>, 4> RetVals;
5996   SmallSet<unsigned, 4> RegsUsed;
5997   for (unsigned i = 0, realRVLocIdx = 0; i != RVLocs.size();
5998        ++i, ++realRVLocIdx) {
5999     CCValAssign &VA = RVLocs[i];
6000     assert(VA.isRegLoc() && "Can only return in registers!");
6001     SDValue Arg = OutVals[realRVLocIdx];
6002 
6003     switch (VA.getLocInfo()) {
6004     default:
6005       llvm_unreachable("Unknown loc info!");
6006     case CCValAssign::Full:
6007       if (Outs[i].ArgVT == MVT::i1) {
6008         // AAPCS requires i1 to be zero-extended to i8 by the producer of the
6009         // value. This is strictly redundant on Darwin (which uses "zeroext
6010         // i1"), but will be optimised out before ISel.
6011         Arg = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Arg);
6012         Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, VA.getLocVT(), Arg);
6013       }
6014       break;
6015     case CCValAssign::BCvt:
6016       Arg = DAG.getNode(ISD::BITCAST, DL, VA.getLocVT(), Arg);
6017       break;
6018     case CCValAssign::AExt:
6019     case CCValAssign::ZExt:
6020       Arg = DAG.getZExtOrTrunc(Arg, DL, VA.getLocVT());
6021       break;
6022     case CCValAssign::AExtUpper:
6023       assert(VA.getValVT() == MVT::i32 && "only expect 32 -> 64 upper bits");
6024       Arg = DAG.getZExtOrTrunc(Arg, DL, VA.getLocVT());
6025       Arg = DAG.getNode(ISD::SHL, DL, VA.getLocVT(), Arg,
6026                         DAG.getConstant(32, DL, VA.getLocVT()));
6027       break;
6028     }
6029 
6030     if (RegsUsed.count(VA.getLocReg())) {
6031       SDValue &Bits =
6032           llvm::find_if(RetVals, [=](const std::pair<unsigned, SDValue> &Elt) {
6033             return Elt.first == VA.getLocReg();
6034           })->second;
6035       Bits = DAG.getNode(ISD::OR, DL, Bits.getValueType(), Bits, Arg);
6036     } else {
6037       RetVals.emplace_back(VA.getLocReg(), Arg);
6038       RegsUsed.insert(VA.getLocReg());
6039     }
6040   }
6041 
6042   SmallVector<SDValue, 4> RetOps(1, Chain);
6043   for (auto &RetVal : RetVals) {
6044     Chain = DAG.getCopyToReg(Chain, DL, RetVal.first, RetVal.second, Flag);
6045     Flag = Chain.getValue(1);
6046     RetOps.push_back(
6047         DAG.getRegister(RetVal.first, RetVal.second.getValueType()));
6048   }
6049 
6050   // Windows AArch64 ABIs require that for returning structs by value we copy
6051   // the sret argument into X0 for the return.
6052   // We saved the argument into a virtual register in the entry block,
6053   // so now we copy the value out and into X0.
6054   if (unsigned SRetReg = FuncInfo->getSRetReturnReg()) {
6055     SDValue Val = DAG.getCopyFromReg(RetOps[0], DL, SRetReg,
6056                                      getPointerTy(MF.getDataLayout()));
6057 
6058     unsigned RetValReg = AArch64::X0;
6059     Chain = DAG.getCopyToReg(Chain, DL, RetValReg, Val, Flag);
6060     Flag = Chain.getValue(1);
6061 
6062     RetOps.push_back(
6063       DAG.getRegister(RetValReg, getPointerTy(DAG.getDataLayout())));
6064   }
6065 
6066   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
6067   const MCPhysReg *I =
6068       TRI->getCalleeSavedRegsViaCopy(&DAG.getMachineFunction());
6069   if (I) {
6070     for (; *I; ++I) {
6071       if (AArch64::GPR64RegClass.contains(*I))
6072         RetOps.push_back(DAG.getRegister(*I, MVT::i64));
6073       else if (AArch64::FPR64RegClass.contains(*I))
6074         RetOps.push_back(DAG.getRegister(*I, MVT::getFloatingPointVT(64)));
6075       else
6076         llvm_unreachable("Unexpected register class in CSRsViaCopy!");
6077     }
6078   }
6079 
6080   RetOps[0] = Chain; // Update chain.
6081 
6082   // Add the flag if we have it.
6083   if (Flag.getNode())
6084     RetOps.push_back(Flag);
6085 
6086   return DAG.getNode(AArch64ISD::RET_FLAG, DL, MVT::Other, RetOps);
6087 }
6088 
6089 //===----------------------------------------------------------------------===//
6090 //  Other Lowering Code
6091 //===----------------------------------------------------------------------===//
6092 
6093 SDValue AArch64TargetLowering::getTargetNode(GlobalAddressSDNode *N, EVT Ty,
6094                                              SelectionDAG &DAG,
6095                                              unsigned Flag) const {
6096   return DAG.getTargetGlobalAddress(N->getGlobal(), SDLoc(N), Ty,
6097                                     N->getOffset(), Flag);
6098 }
6099 
6100 SDValue AArch64TargetLowering::getTargetNode(JumpTableSDNode *N, EVT Ty,
6101                                              SelectionDAG &DAG,
6102                                              unsigned Flag) const {
6103   return DAG.getTargetJumpTable(N->getIndex(), Ty, Flag);
6104 }
6105 
6106 SDValue AArch64TargetLowering::getTargetNode(ConstantPoolSDNode *N, EVT Ty,
6107                                              SelectionDAG &DAG,
6108                                              unsigned Flag) const {
6109   return DAG.getTargetConstantPool(N->getConstVal(), Ty, N->getAlign(),
6110                                    N->getOffset(), Flag);
6111 }
6112 
6113 SDValue AArch64TargetLowering::getTargetNode(BlockAddressSDNode* N, EVT Ty,
6114                                              SelectionDAG &DAG,
6115                                              unsigned Flag) const {
6116   return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, 0, Flag);
6117 }
6118 
6119 // (loadGOT sym)
6120 template <class NodeTy>
6121 SDValue AArch64TargetLowering::getGOT(NodeTy *N, SelectionDAG &DAG,
6122                                       unsigned Flags) const {
6123   LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getGOT\n");
6124   SDLoc DL(N);
6125   EVT Ty = getPointerTy(DAG.getDataLayout());
6126   SDValue GotAddr = getTargetNode(N, Ty, DAG, AArch64II::MO_GOT | Flags);
6127   // FIXME: Once remat is capable of dealing with instructions with register
6128   // operands, expand this into two nodes instead of using a wrapper node.
6129   return DAG.getNode(AArch64ISD::LOADgot, DL, Ty, GotAddr);
6130 }
6131 
6132 // (wrapper %highest(sym), %higher(sym), %hi(sym), %lo(sym))
6133 template <class NodeTy>
6134 SDValue AArch64TargetLowering::getAddrLarge(NodeTy *N, SelectionDAG &DAG,
6135                                             unsigned Flags) const {
6136   LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getAddrLarge\n");
6137   SDLoc DL(N);
6138   EVT Ty = getPointerTy(DAG.getDataLayout());
6139   const unsigned char MO_NC = AArch64II::MO_NC;
6140   return DAG.getNode(
6141       AArch64ISD::WrapperLarge, DL, Ty,
6142       getTargetNode(N, Ty, DAG, AArch64II::MO_G3 | Flags),
6143       getTargetNode(N, Ty, DAG, AArch64II::MO_G2 | MO_NC | Flags),
6144       getTargetNode(N, Ty, DAG, AArch64II::MO_G1 | MO_NC | Flags),
6145       getTargetNode(N, Ty, DAG, AArch64II::MO_G0 | MO_NC | Flags));
6146 }
6147 
6148 // (addlow (adrp %hi(sym)) %lo(sym))
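// Under the small code model this typically becomes:
//   adrp xN, sym            ; address of the 4KiB page containing sym
//   add  xN, xN, :lo12:sym  ; add the low 12 bits of sym's page offset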
6149 template <class NodeTy>
6150 SDValue AArch64TargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG,
6151                                        unsigned Flags) const {
6152   LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getAddr\n");
6153   SDLoc DL(N);
6154   EVT Ty = getPointerTy(DAG.getDataLayout());
6155   SDValue Hi = getTargetNode(N, Ty, DAG, AArch64II::MO_PAGE | Flags);
6156   SDValue Lo = getTargetNode(N, Ty, DAG,
6157                              AArch64II::MO_PAGEOFF | AArch64II::MO_NC | Flags);
6158   SDValue ADRP = DAG.getNode(AArch64ISD::ADRP, DL, Ty, Hi);
6159   return DAG.getNode(AArch64ISD::ADDlow, DL, Ty, ADRP, Lo);
6160 }
6161 
6162 // (adr sym)
6163 template <class NodeTy>
6164 SDValue AArch64TargetLowering::getAddrTiny(NodeTy *N, SelectionDAG &DAG,
6165                                            unsigned Flags) const {
6166   LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getAddrTiny\n");
6167   SDLoc DL(N);
6168   EVT Ty = getPointerTy(DAG.getDataLayout());
6169   SDValue Sym = getTargetNode(N, Ty, DAG, Flags);
6170   return DAG.getNode(AArch64ISD::ADR, DL, Ty, Sym);
6171 }
6172 
6173 SDValue AArch64TargetLowering::LowerGlobalAddress(SDValue Op,
6174                                                   SelectionDAG &DAG) const {
6175   GlobalAddressSDNode *GN = cast<GlobalAddressSDNode>(Op);
6176   const GlobalValue *GV = GN->getGlobal();
6177   unsigned OpFlags = Subtarget->ClassifyGlobalReference(GV, getTargetMachine());
6178 
6179   if (OpFlags != AArch64II::MO_NO_FLAG)
6180     assert(cast<GlobalAddressSDNode>(Op)->getOffset() == 0 &&
6181            "unexpected offset in global node");
6182 
6183   // This also catches the large code model case for Darwin, and tiny code
6184   // model with got relocations.
6185   if ((OpFlags & AArch64II::MO_GOT) != 0) {
6186     return getGOT(GN, DAG, OpFlags);
6187   }
6188 
6189   SDValue Result;
6190   if (getTargetMachine().getCodeModel() == CodeModel::Large) {
6191     Result = getAddrLarge(GN, DAG, OpFlags);
6192   } else if (getTargetMachine().getCodeModel() == CodeModel::Tiny) {
6193     Result = getAddrTiny(GN, DAG, OpFlags);
6194   } else {
6195     Result = getAddr(GN, DAG, OpFlags);
6196   }
6197   EVT PtrVT = getPointerTy(DAG.getDataLayout());
6198   SDLoc DL(GN);
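  // DLL-imported globals (and COFF stubs) are accessed through a pointer,
  // e.g. an __imp_<name> symbol for dllimport, so an extra load is needed to
  // reach the real address.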
6199   if (OpFlags & (AArch64II::MO_DLLIMPORT | AArch64II::MO_COFFSTUB))
6200     Result = DAG.getLoad(PtrVT, DL, DAG.getEntryNode(), Result,
6201                          MachinePointerInfo::getGOT(DAG.getMachineFunction()));
6202   return Result;
6203 }
6204 
6205 /// Convert a TLS address reference into the correct sequence of loads
6206 /// and calls to compute the variable's address (for Darwin, currently) and
6207 /// return an SDValue containing the final node.
6208 
6209 /// Darwin only has one TLS scheme which must be capable of dealing with the
6210 /// fully general situation, in the worst case. This means:
6211 ///     + "extern __thread" declaration.
6212 ///     + Defined in a possibly unknown dynamic library.
6213 ///
6214 /// The general system is that each __thread variable has a [3 x i64] descriptor
6215 /// which contains information used by the runtime to calculate the address. The
6216 /// only part of this the compiler needs to know about is the first xword, which
6217 /// contains a function pointer that must be called with the address of the
6218 /// entire descriptor in "x0".
6219 ///
6220 /// Since this descriptor may be in a different unit, in general even the
6221 /// descriptor must be accessed via an indirect load. The "ideal" code sequence
6222 /// is:
6223 ///     adrp x0, _var@TLVPPAGE
6224 ///     ldr x0, [x0, _var@TLVPPAGEOFF]   ; x0 now contains address of descriptor
6225 ///     ldr x1, [x0]                     ; x1 contains 1st entry of descriptor,
6226 ///                                      ; the function pointer
6227 ///     blr x1                           ; Uses descriptor address in x0
6228 ///     ; Address of _var is now in x0.
6229 ///
6230 /// If the address of _var's descriptor *is* known to the linker, then it can
6231 /// change the first "ldr" instruction to an appropriate "add x0, x0, #imm" for
6232 /// a slight efficiency gain.
6233 SDValue
6234 AArch64TargetLowering::LowerDarwinGlobalTLSAddress(SDValue Op,
6235                                                    SelectionDAG &DAG) const {
6236   assert(Subtarget->isTargetDarwin() &&
6237          "This function expects a Darwin target");
6238 
6239   SDLoc DL(Op);
6240   MVT PtrVT = getPointerTy(DAG.getDataLayout());
6241   MVT PtrMemVT = getPointerMemTy(DAG.getDataLayout());
6242   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
6243 
6244   SDValue TLVPAddr =
6245       DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_TLS);
6246   SDValue DescAddr = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, TLVPAddr);
6247 
6248   // The first entry in the descriptor is a function pointer that we must call
6249   // to obtain the address of the variable.
6250   SDValue Chain = DAG.getEntryNode();
6251   SDValue FuncTLVGet = DAG.getLoad(
6252       PtrMemVT, DL, Chain, DescAddr,
6253       MachinePointerInfo::getGOT(DAG.getMachineFunction()),
6254       Align(PtrMemVT.getSizeInBits() / 8),
6255       MachineMemOperand::MOInvariant | MachineMemOperand::MODereferenceable);
6256   Chain = FuncTLVGet.getValue(1);
6257 
6258   // Extend loaded pointer if necessary (i.e. if ILP32) to DAG pointer.
6259   FuncTLVGet = DAG.getZExtOrTrunc(FuncTLVGet, DL, PtrVT);
6260 
6261   MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
6262   MFI.setAdjustsStack(true);
6263 
6264   // TLS calls preserve all registers except those that absolutely must be
6265   // trashed: X0 (it takes an argument), LR (it's a call) and NZCV (let's not be
6266   // silly).
6267   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
6268   const uint32_t *Mask = TRI->getTLSCallPreservedMask();
6269   if (Subtarget->hasCustomCallingConv())
6270     TRI->UpdateCustomCallPreservedMask(DAG.getMachineFunction(), &Mask);
6271 
6272   // Finally, we can make the call. This is just a degenerate version of a
6273   // normal AArch64 call node: x0 takes the address of the descriptor, and
6274   // returns the address of the variable in this thread.
6275   Chain = DAG.getCopyToReg(Chain, DL, AArch64::X0, DescAddr, SDValue());
6276   Chain =
6277       DAG.getNode(AArch64ISD::CALL, DL, DAG.getVTList(MVT::Other, MVT::Glue),
6278                   Chain, FuncTLVGet, DAG.getRegister(AArch64::X0, MVT::i64),
6279                   DAG.getRegisterMask(Mask), Chain.getValue(1));
6280   return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Chain.getValue(1));
6281 }
6282 
6283 /// Convert a thread-local variable reference into a sequence of instructions to
6284 /// compute the variable's address for the local exec TLS model of ELF targets.
6285 /// The sequence depends on the maximum TLS area size.
6286 SDValue AArch64TargetLowering::LowerELFTLSLocalExec(const GlobalValue *GV,
6287                                                     SDValue ThreadBase,
6288                                                     const SDLoc &DL,
6289                                                     SelectionDAG &DAG) const {
6290   EVT PtrVT = getPointerTy(DAG.getDataLayout());
6291   SDValue TPOff, Addr;
6292 
6293   switch (DAG.getTarget().Options.TLSSize) {
6294   default:
6295     llvm_unreachable("Unexpected TLS size");
6296 
6297   case 12: {
6298     // mrs   x0, TPIDR_EL0
6299     // add   x0, x0, :tprel_lo12:a
6300     SDValue Var = DAG.getTargetGlobalAddress(
6301         GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_PAGEOFF);
6302     return SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, ThreadBase,
6303                                       Var,
6304                                       DAG.getTargetConstant(0, DL, MVT::i32)),
6305                    0);
6306   }
6307 
6308   case 24: {
6309     // mrs   x0, TPIDR_EL0
6310     // add   x0, x0, :tprel_hi12:a
6311     // add   x0, x0, :tprel_lo12_nc:a
6312     SDValue HiVar = DAG.getTargetGlobalAddress(
6313         GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_HI12);
6314     SDValue LoVar = DAG.getTargetGlobalAddress(
6315         GV, DL, PtrVT, 0,
6316         AArch64II::MO_TLS | AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
6317     Addr = SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, ThreadBase,
6318                                       HiVar,
6319                                       DAG.getTargetConstant(0, DL, MVT::i32)),
6320                    0);
6321     return SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, Addr,
6322                                       LoVar,
6323                                       DAG.getTargetConstant(0, DL, MVT::i32)),
6324                    0);
6325   }
6326 
6327   case 32: {
6328     // mrs   x1, TPIDR_EL0
6329     // movz  x0, #:tprel_g1:a
6330     // movk  x0, #:tprel_g0_nc:a
6331     // add   x0, x1, x0
6332     SDValue HiVar = DAG.getTargetGlobalAddress(
6333         GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_G1);
6334     SDValue LoVar = DAG.getTargetGlobalAddress(
6335         GV, DL, PtrVT, 0,
6336         AArch64II::MO_TLS | AArch64II::MO_G0 | AArch64II::MO_NC);
6337     TPOff = SDValue(DAG.getMachineNode(AArch64::MOVZXi, DL, PtrVT, HiVar,
6338                                        DAG.getTargetConstant(16, DL, MVT::i32)),
6339                     0);
6340     TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, LoVar,
6341                                        DAG.getTargetConstant(0, DL, MVT::i32)),
6342                     0);
6343     return DAG.getNode(ISD::ADD, DL, PtrVT, ThreadBase, TPOff);
6344   }
6345 
6346   case 48: {
6347     // mrs   x1, TPIDR_EL0
6348     // movz  x0, #:tprel_g2:a
6349     // movk  x0, #:tprel_g1_nc:a
6350     // movk  x0, #:tprel_g0_nc:a
6351     // add   x0, x1, x0
6352     SDValue HiVar = DAG.getTargetGlobalAddress(
6353         GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_G2);
6354     SDValue MiVar = DAG.getTargetGlobalAddress(
6355         GV, DL, PtrVT, 0,
6356         AArch64II::MO_TLS | AArch64II::MO_G1 | AArch64II::MO_NC);
6357     SDValue LoVar = DAG.getTargetGlobalAddress(
6358         GV, DL, PtrVT, 0,
6359         AArch64II::MO_TLS | AArch64II::MO_G0 | AArch64II::MO_NC);
6360     TPOff = SDValue(DAG.getMachineNode(AArch64::MOVZXi, DL, PtrVT, HiVar,
6361                                        DAG.getTargetConstant(32, DL, MVT::i32)),
6362                     0);
6363     TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, MiVar,
6364                                        DAG.getTargetConstant(16, DL, MVT::i32)),
6365                     0);
6366     TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, LoVar,
6367                                        DAG.getTargetConstant(0, DL, MVT::i32)),
6368                     0);
6369     return DAG.getNode(ISD::ADD, DL, PtrVT, ThreadBase, TPOff);
6370   }
6371   }
6372 }
6373 
6374 /// When accessing thread-local variables under either the general-dynamic or
6375 /// local-dynamic system, we make a "TLS-descriptor" call. The variable will
6376 /// have a descriptor, accessible via a PC-relative ADRP, and whose first entry
6377 /// is a function pointer to carry out the resolution.
6378 ///
6379 /// The sequence is:
6380 ///    adrp  x0, :tlsdesc:var
6381 ///    ldr   x1, [x0, #:tlsdesc_lo12:var]
6382 ///    add   x0, x0, #:tlsdesc_lo12:var
6383 ///    .tlsdesccall var
6384 ///    blr   x1
6385 ///    (TPIDR_EL0 offset now in x0)
6386 ///
6387 ///  The above sequence must be produced unscheduled, to enable the linker to
6388 ///  optimize/relax this sequence.
6389 ///  Therefore, a pseudo-instruction (TLSDESC_CALLSEQ) is used to represent the
6390 ///  above sequence, and expanded really late in the compilation flow, to ensure
6391 ///  the sequence is produced as per above.
6392 SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr,
6393                                                       const SDLoc &DL,
6394                                                       SelectionDAG &DAG) const {
6395   EVT PtrVT = getPointerTy(DAG.getDataLayout());
6396 
6397   SDValue Chain = DAG.getEntryNode();
6398   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
6399 
6400   Chain =
6401       DAG.getNode(AArch64ISD::TLSDESC_CALLSEQ, DL, NodeTys, {Chain, SymAddr});
6402   SDValue Glue = Chain.getValue(1);
6403 
6404   return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue);
6405 }
6406 
6407 SDValue
6408 AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op,
6409                                                 SelectionDAG &DAG) const {
6410   assert(Subtarget->isTargetELF() && "This function expects an ELF target");
6411 
6412   const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
6413 
6414   TLSModel::Model Model = getTargetMachine().getTLSModel(GA->getGlobal());
6415 
6416   if (!EnableAArch64ELFLocalDynamicTLSGeneration) {
6417     if (Model == TLSModel::LocalDynamic)
6418       Model = TLSModel::GeneralDynamic;
6419   }
6420 
6421   if (getTargetMachine().getCodeModel() == CodeModel::Large &&
6422       Model != TLSModel::LocalExec)
6423     report_fatal_error("ELF TLS only supported in small memory model or "
6424                        "in local exec TLS model");
6425   // Different choices can be made for the maximum size of the TLS area for a
6426   // module. For the small address model, the default TLS size is 16MiB and the
6427   // maximum TLS size is 4GiB.
6428   // FIXME: add tiny and large code model support for TLS access models other
6429   // than local exec. We currently generate the same code as small for tiny,
6430   // which may be larger than needed.
6431 
6432   SDValue TPOff;
6433   EVT PtrVT = getPointerTy(DAG.getDataLayout());
6434   SDLoc DL(Op);
6435   const GlobalValue *GV = GA->getGlobal();
6436 
6437   SDValue ThreadBase = DAG.getNode(AArch64ISD::THREAD_POINTER, DL, PtrVT);
6438 
6439   if (Model == TLSModel::LocalExec) {
6440     return LowerELFTLSLocalExec(GV, ThreadBase, DL, DAG);
6441   } else if (Model == TLSModel::InitialExec) {
6442     TPOff = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_TLS);
6443     TPOff = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, TPOff);
6444   } else if (Model == TLSModel::LocalDynamic) {
6445     // Local-dynamic accesses proceed in two phases. A general-dynamic TLS
6446     // descriptor call against the special symbol _TLS_MODULE_BASE_ to calculate
6447     // the beginning of the module's TLS region, followed by a DTPREL offset
6448     // calculation.
6449 
6450     // These accesses will need deduplicating if there's more than one.
6451     AArch64FunctionInfo *MFI =
6452         DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
6453     MFI->incNumLocalDynamicTLSAccesses();
6454 
6455     // The call needs a relocation too for linker relaxation. It doesn't make
6456     // sense to call it MO_PAGE or MO_PAGEOFF though so we need another copy of
6457     // the address.
6458     SDValue SymAddr = DAG.getTargetExternalSymbol("_TLS_MODULE_BASE_", PtrVT,
6459                                                   AArch64II::MO_TLS);
6460 
6461     // Now we can calculate the offset from TPIDR_EL0 to this module's
6462     // thread-local area.
6463     TPOff = LowerELFTLSDescCallSeq(SymAddr, DL, DAG);
6464 
6465     // Now use :dtprel_whatever: operations to calculate this variable's offset
6466     // in its thread-storage area.
6467     SDValue HiVar = DAG.getTargetGlobalAddress(
6468         GV, DL, MVT::i64, 0, AArch64II::MO_TLS | AArch64II::MO_HI12);
6469     SDValue LoVar = DAG.getTargetGlobalAddress(
6470         GV, DL, MVT::i64, 0,
6471         AArch64II::MO_TLS | AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
6472 
6473     TPOff = SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, TPOff, HiVar,
6474                                        DAG.getTargetConstant(0, DL, MVT::i32)),
6475                     0);
6476     TPOff = SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, TPOff, LoVar,
6477                                        DAG.getTargetConstant(0, DL, MVT::i32)),
6478                     0);
6479   } else if (Model == TLSModel::GeneralDynamic) {
6480     // The call needs a relocation too for linker relaxation. It doesn't make
6481     // sense to call it MO_PAGE or MO_PAGEOFF though so we need another copy of
6482     // the address.
6483     SDValue SymAddr =
6484         DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_TLS);
6485 
6486     // Finally we can make a call to calculate the offset from tpidr_el0.
6487     TPOff = LowerELFTLSDescCallSeq(SymAddr, DL, DAG);
6488   } else
6489     llvm_unreachable("Unsupported ELF TLS access model");
6490 
6491   return DAG.getNode(ISD::ADD, DL, PtrVT, ThreadBase, TPOff);
6492 }
6493 
6494 SDValue
6495 AArch64TargetLowering::LowerWindowsGlobalTLSAddress(SDValue Op,
6496                                                     SelectionDAG &DAG) const {
6497   assert(Subtarget->isTargetWindows() && "Windows specific TLS lowering");
6498 
6499   SDValue Chain = DAG.getEntryNode();
6500   EVT PtrVT = getPointerTy(DAG.getDataLayout());
6501   SDLoc DL(Op);
6502 
6503   SDValue TEB = DAG.getRegister(AArch64::X18, MVT::i64);
6504 
6505   // Load the ThreadLocalStoragePointer from the TEB
6506   // A pointer to the TLS array is located at offset 0x58 from the TEB.
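  // In C-like pseudocode the sequence built here computes roughly:
  //   char *Base = ((char **)*(void **)(TEB + 0x58))[_tls_index];
  //   return Base + <offset of GV within this module's TLS data>;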
6507   SDValue TLSArray =
6508       DAG.getNode(ISD::ADD, DL, PtrVT, TEB, DAG.getIntPtrConstant(0x58, DL));
6509   TLSArray = DAG.getLoad(PtrVT, DL, Chain, TLSArray, MachinePointerInfo());
6510   Chain = TLSArray.getValue(1);
6511 
6512   // Load the TLS index from the C runtime;
6513   // This does the same as getAddr(), but without having a GlobalAddressSDNode.
6514   // This also does the same as LOADgot, but using a generic i32 load,
6515   // while LOADgot only loads i64.
6516   SDValue TLSIndexHi =
6517       DAG.getTargetExternalSymbol("_tls_index", PtrVT, AArch64II::MO_PAGE);
6518   SDValue TLSIndexLo = DAG.getTargetExternalSymbol(
6519       "_tls_index", PtrVT, AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
6520   SDValue ADRP = DAG.getNode(AArch64ISD::ADRP, DL, PtrVT, TLSIndexHi);
6521   SDValue TLSIndex =
6522       DAG.getNode(AArch64ISD::ADDlow, DL, PtrVT, ADRP, TLSIndexLo);
6523   TLSIndex = DAG.getLoad(MVT::i32, DL, Chain, TLSIndex, MachinePointerInfo());
6524   Chain = TLSIndex.getValue(1);
6525 
6526   // The pointer to the thread's TLS data area is at the TLS Index scaled by 8
6527   // offset into the TLSArray.
6528   TLSIndex = DAG.getNode(ISD::ZERO_EXTEND, DL, PtrVT, TLSIndex);
6529   SDValue Slot = DAG.getNode(ISD::SHL, DL, PtrVT, TLSIndex,
6530                              DAG.getConstant(3, DL, PtrVT));
6531   SDValue TLS = DAG.getLoad(PtrVT, DL, Chain,
6532                             DAG.getNode(ISD::ADD, DL, PtrVT, TLSArray, Slot),
6533                             MachinePointerInfo());
6534   Chain = TLS.getValue(1);
6535 
6536   const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
6537   const GlobalValue *GV = GA->getGlobal();
6538   SDValue TGAHi = DAG.getTargetGlobalAddress(
6539       GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_HI12);
6540   SDValue TGALo = DAG.getTargetGlobalAddress(
6541       GV, DL, PtrVT, 0,
6542       AArch64II::MO_TLS | AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
6543 
6544   // Add the offset from the start of the .tls section (section base).
6545   SDValue Addr =
6546       SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, TLS, TGAHi,
6547                                  DAG.getTargetConstant(0, DL, MVT::i32)),
6548               0);
6549   Addr = DAG.getNode(AArch64ISD::ADDlow, DL, PtrVT, Addr, TGALo);
6550   return Addr;
6551 }
6552 
6553 SDValue AArch64TargetLowering::LowerGlobalTLSAddress(SDValue Op,
6554                                                      SelectionDAG &DAG) const {
6555   const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
6556   if (DAG.getTarget().useEmulatedTLS())
6557     return LowerToTLSEmulatedModel(GA, DAG);
6558 
6559   if (Subtarget->isTargetDarwin())
6560     return LowerDarwinGlobalTLSAddress(Op, DAG);
6561   if (Subtarget->isTargetELF())
6562     return LowerELFGlobalTLSAddress(Op, DAG);
6563   if (Subtarget->isTargetWindows())
6564     return LowerWindowsGlobalTLSAddress(Op, DAG);
6565 
6566   llvm_unreachable("Unexpected platform trying to use TLS");
6567 }
6568 
6569 // Looks through \param Val to determine the bit that can be used to
6570 // check the sign of the value. It returns the unextended value and
6571 // the sign bit position.
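// For example, (sign_extend_inreg i32 %x, i8) yields {%x, 7}, so a later
// TBZ/TBNZ can test bit 7 of the unextended value directly.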
6572 std::pair<SDValue, uint64_t> lookThroughSignExtension(SDValue Val) {
6573   if (Val.getOpcode() == ISD::SIGN_EXTEND_INREG)
6574     return {Val.getOperand(0),
6575             cast<VTSDNode>(Val.getOperand(1))->getVT().getFixedSizeInBits() -
6576                 1};
6577 
6578   if (Val.getOpcode() == ISD::SIGN_EXTEND)
6579     return {Val.getOperand(0),
6580             Val.getOperand(0)->getValueType(0).getFixedSizeInBits() - 1};
6581 
6582   return {Val, Val.getValueSizeInBits() - 1};
6583 }
6584 
6585 SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const {
6586   SDValue Chain = Op.getOperand(0);
6587   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(1))->get();
6588   SDValue LHS = Op.getOperand(2);
6589   SDValue RHS = Op.getOperand(3);
6590   SDValue Dest = Op.getOperand(4);
6591   SDLoc dl(Op);
6592 
6593   MachineFunction &MF = DAG.getMachineFunction();
6594   // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z instructions
6595   // will not be produced, as they are conditional branch instructions that do
6596   // not set flags.
6597   bool ProduceNonFlagSettingCondBr =
6598       !MF.getFunction().hasFnAttribute(Attribute::SpeculativeLoadHardening);
6599 
6600   // Handle f128 first, since lowering it will result in comparing the return
6601   // value of a libcall against zero, which is just what the rest of LowerBR_CC
6602   // is expecting to deal with.
6603   if (LHS.getValueType() == MVT::f128) {
6604     softenSetCCOperands(DAG, MVT::f128, LHS, RHS, CC, dl, LHS, RHS);
6605 
6606     // If softenSetCCOperands returned a scalar, we need to compare the result
6607     // against zero to select between true and false values.
6608     if (!RHS.getNode()) {
6609       RHS = DAG.getConstant(0, dl, LHS.getValueType());
6610       CC = ISD::SETNE;
6611     }
6612   }
6613 
6614   // Optimize {s|u}{add|sub|mul}.with.overflow feeding into a branch
6615   // instruction.
6616   if (ISD::isOverflowIntrOpRes(LHS) && isOneConstant(RHS) &&
6617       (CC == ISD::SETEQ || CC == ISD::SETNE)) {
6618     // Only lower legal XALUO ops.
6619     if (!DAG.getTargetLoweringInfo().isTypeLegal(LHS->getValueType(0)))
6620       return SDValue();
6621 
6622     // The actual operation with overflow check.
6623     AArch64CC::CondCode OFCC;
6624     SDValue Value, Overflow;
6625     std::tie(Value, Overflow) = getAArch64XALUOOp(OFCC, LHS.getValue(0), DAG);
6626 
6627     if (CC == ISD::SETNE)
6628       OFCC = getInvertedCondCode(OFCC);
6629     SDValue CCVal = DAG.getConstant(OFCC, dl, MVT::i32);
6630 
6631     return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
6632                        Overflow);
6633   }
6634 
6635   if (LHS.getValueType().isInteger()) {
6636     assert((LHS.getValueType() == RHS.getValueType()) &&
6637            (LHS.getValueType() == MVT::i32 || LHS.getValueType() == MVT::i64));
6638 
6639     // If the RHS of the comparison is zero, we can potentially fold this
6640     // to a specialized branch.
6641     const ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS);
6642     if (RHSC && RHSC->getZExtValue() == 0 && ProduceNonFlagSettingCondBr) {
6643       if (CC == ISD::SETEQ) {
6644         // See if we can use a TBZ to fold in an AND as well.
6645         // TBZ has a smaller branch displacement than CBZ.  If the offset is
6646         // out of bounds, a late MI-layer pass rewrites branches.
6647         // 403.gcc is an example that hits this case.
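        // For example, (seteq (and w0, #0x4), 0) can be emitted as
        // "tbz w0, #2, <dest>".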
6648         if (LHS.getOpcode() == ISD::AND &&
6649             isa<ConstantSDNode>(LHS.getOperand(1)) &&
6650             isPowerOf2_64(LHS.getConstantOperandVal(1))) {
6651           SDValue Test = LHS.getOperand(0);
6652           uint64_t Mask = LHS.getConstantOperandVal(1);
6653           return DAG.getNode(AArch64ISD::TBZ, dl, MVT::Other, Chain, Test,
6654                              DAG.getConstant(Log2_64(Mask), dl, MVT::i64),
6655                              Dest);
6656         }
6657 
6658         return DAG.getNode(AArch64ISD::CBZ, dl, MVT::Other, Chain, LHS, Dest);
6659       } else if (CC == ISD::SETNE) {
6660         // See if we can use a TBZ to fold in an AND as well.
6661         // TBZ has a smaller branch displacement than CBZ.  If the offset is
6662         // out of bounds, a late MI-layer pass rewrites branches.
6663         // 403.gcc is an example that hits this case.
6664         if (LHS.getOpcode() == ISD::AND &&
6665             isa<ConstantSDNode>(LHS.getOperand(1)) &&
6666             isPowerOf2_64(LHS.getConstantOperandVal(1))) {
6667           SDValue Test = LHS.getOperand(0);
6668           uint64_t Mask = LHS.getConstantOperandVal(1);
6669           return DAG.getNode(AArch64ISD::TBNZ, dl, MVT::Other, Chain, Test,
6670                              DAG.getConstant(Log2_64(Mask), dl, MVT::i64),
6671                              Dest);
6672         }
6673 
6674         return DAG.getNode(AArch64ISD::CBNZ, dl, MVT::Other, Chain, LHS, Dest);
6675       } else if (CC == ISD::SETLT && LHS.getOpcode() != ISD::AND) {
6676         // Don't combine AND since emitComparison converts the AND to an ANDS
6677         // (a.k.a. TST) and the test in the test bit and branch instruction
6678         // becomes redundant.  This would also increase register pressure.
6679         uint64_t SignBitPos;
6680         std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS);
6681         return DAG.getNode(AArch64ISD::TBNZ, dl, MVT::Other, Chain, LHS,
6682                            DAG.getConstant(SignBitPos, dl, MVT::i64), Dest);
6683       }
6684     }
6685     if (RHSC && RHSC->getSExtValue() == -1 && CC == ISD::SETGT &&
6686         LHS.getOpcode() != ISD::AND && ProduceNonFlagSettingCondBr) {
6687       // Don't combine AND since emitComparison converts the AND to an ANDS
6688       // (a.k.a. TST) and the test in the test bit and branch instruction
6689       // becomes redundant.  This would also increase register pressure.
6690       uint64_t SignBitPos;
6691       std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS);
6692       return DAG.getNode(AArch64ISD::TBZ, dl, MVT::Other, Chain, LHS,
6693                          DAG.getConstant(SignBitPos, dl, MVT::i64), Dest);
6694     }
6695 
6696     SDValue CCVal;
6697     SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
6698     return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
6699                        Cmp);
6700   }
6701 
6702   assert(LHS.getValueType() == MVT::f16 || LHS.getValueType() == MVT::bf16 ||
6703          LHS.getValueType() == MVT::f32 || LHS.getValueType() == MVT::f64);
6704 
6705   // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
6706   // clean.  Some of them require two branches to implement.
6707   SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
6708   AArch64CC::CondCode CC1, CC2;
6709   changeFPCCToAArch64CC(CC, CC1, CC2);
6710   SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
6711   SDValue BR1 =
6712       DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CC1Val, Cmp);
6713   if (CC2 != AArch64CC::AL) {
6714     SDValue CC2Val = DAG.getConstant(CC2, dl, MVT::i32);
6715     return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, BR1, Dest, CC2Val,
6716                        Cmp);
6717   }
6718 
6719   return BR1;
6720 }
6721 
6722 SDValue AArch64TargetLowering::LowerFCOPYSIGN(SDValue Op,
6723                                               SelectionDAG &DAG) const {
6724   EVT VT = Op.getValueType();
6725   SDLoc DL(Op);
6726 
6727   SDValue In1 = Op.getOperand(0);
6728   SDValue In2 = Op.getOperand(1);
6729   EVT SrcVT = In2.getValueType();
6730 
6731   if (SrcVT.bitsLT(VT))
6732     In2 = DAG.getNode(ISD::FP_EXTEND, DL, VT, In2);
6733   else if (SrcVT.bitsGT(VT))
6734     In2 = DAG.getNode(ISD::FP_ROUND, DL, VT, In2, DAG.getIntPtrConstant(0, DL));
6735 
6736   EVT VecVT;
6737   uint64_t EltMask;
6738   SDValue VecVal1, VecVal2;
6739 
6740   auto setVecVal = [&] (int Idx) {
6741     if (!VT.isVector()) {
6742       VecVal1 = DAG.getTargetInsertSubreg(Idx, DL, VecVT,
6743                                           DAG.getUNDEF(VecVT), In1);
6744       VecVal2 = DAG.getTargetInsertSubreg(Idx, DL, VecVT,
6745                                           DAG.getUNDEF(VecVT), In2);
6746     } else {
6747       VecVal1 = DAG.getNode(ISD::BITCAST, DL, VecVT, In1);
6748       VecVal2 = DAG.getNode(ISD::BITCAST, DL, VecVT, In2);
6749     }
6750   };
6751 
6752   if (VT == MVT::f32 || VT == MVT::v2f32 || VT == MVT::v4f32) {
6753     VecVT = (VT == MVT::v2f32 ? MVT::v2i32 : MVT::v4i32);
6754     EltMask = 0x80000000ULL;
6755     setVecVal(AArch64::ssub);
6756   } else if (VT == MVT::f64 || VT == MVT::v2f64) {
6757     VecVT = MVT::v2i64;
6758 
6759     // We want to materialize a mask with the high bit set, but the AdvSIMD
6760     // immediate moves cannot materialize that in a single instruction for
6761     // 64-bit elements. Instead, materialize zero and then negate it.
6762     EltMask = 0;
6763 
6764     setVecVal(AArch64::dsub);
6765   } else if (VT == MVT::f16 || VT == MVT::v4f16 || VT == MVT::v8f16) {
6766     VecVT = (VT == MVT::v4f16 ? MVT::v4i16 : MVT::v8i16);
6767     EltMask = 0x8000ULL;
6768     setVecVal(AArch64::hsub);
6769   } else {
6770     llvm_unreachable("Invalid type for copysign!");
6771   }
6772 
6773   SDValue BuildVec = DAG.getConstant(EltMask, DL, VecVT);
6774 
6775   // If we couldn't materialize the mask above, then the mask vector will be
6776   // the zero vector, and we need to negate it here.
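  // (FNEG of +0.0 gives -0.0, i.e. the 0x8000000000000000 bit pattern with
  // only the sign bit set in each 64-bit lane.)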
6777   if (VT == MVT::f64 || VT == MVT::v2f64) {
6778     BuildVec = DAG.getNode(ISD::BITCAST, DL, MVT::v2f64, BuildVec);
6779     BuildVec = DAG.getNode(ISD::FNEG, DL, MVT::v2f64, BuildVec);
6780     BuildVec = DAG.getNode(ISD::BITCAST, DL, MVT::v2i64, BuildVec);
6781   }
6782 
6783   SDValue Sel =
6784       DAG.getNode(AArch64ISD::BIT, DL, VecVT, VecVal1, VecVal2, BuildVec);
6785 
6786   if (VT == MVT::f16)
6787     return DAG.getTargetExtractSubreg(AArch64::hsub, DL, VT, Sel);
6788   if (VT == MVT::f32)
6789     return DAG.getTargetExtractSubreg(AArch64::ssub, DL, VT, Sel);
6790   else if (VT == MVT::f64)
6791     return DAG.getTargetExtractSubreg(AArch64::dsub, DL, VT, Sel);
6792   else
6793     return DAG.getNode(ISD::BITCAST, DL, VT, Sel);
6794 }
6795 
6796 SDValue AArch64TargetLowering::LowerCTPOP(SDValue Op, SelectionDAG &DAG) const {
6797   if (DAG.getMachineFunction().getFunction().hasFnAttribute(
6798           Attribute::NoImplicitFloat))
6799     return SDValue();
6800 
6801   if (!Subtarget->hasNEON())
6802     return SDValue();
6803 
6804   // While there is no integer popcount instruction, it can
6805   // be more efficiently lowered to the following sequence that uses
6806   // AdvSIMD registers/instructions as long as the copies to/from
6807   // the AdvSIMD registers are cheap.
6808   //  FMOV    D0, X0        // copy 64-bit int to vector, high bits zero'd
6809   //  CNT     V0.8B, V0.8B  // 8xbyte pop-counts
6810   //  ADDV    B0, V0.8B     // sum 8xbyte pop-counts
6811   //  UMOV    X0, V0.B[0]   // copy byte result back to integer reg
6812   SDValue Val = Op.getOperand(0);
6813   SDLoc DL(Op);
6814   EVT VT = Op.getValueType();
6815 
6816   if (VT == MVT::i32 || VT == MVT::i64) {
6817     if (VT == MVT::i32)
6818       Val = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Val);
6819     Val = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Val);
6820 
6821     SDValue CtPop = DAG.getNode(ISD::CTPOP, DL, MVT::v8i8, Val);
6822     SDValue UaddLV = DAG.getNode(
6823         ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
6824         DAG.getConstant(Intrinsic::aarch64_neon_uaddlv, DL, MVT::i32), CtPop);
6825 
6826     if (VT == MVT::i64)
6827       UaddLV = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, UaddLV);
6828     return UaddLV;
6829   } else if (VT == MVT::i128) {
6830     Val = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, Val);
6831 
6832     SDValue CtPop = DAG.getNode(ISD::CTPOP, DL, MVT::v16i8, Val);
6833     SDValue UaddLV = DAG.getNode(
6834         ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
6835         DAG.getConstant(Intrinsic::aarch64_neon_uaddlv, DL, MVT::i32), CtPop);
6836 
6837     return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i128, UaddLV);
6838   }
6839 
6840   if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT))
6841     return LowerToPredicatedOp(Op, DAG, AArch64ISD::CTPOP_MERGE_PASSTHRU);
6842 
6843   assert((VT == MVT::v1i64 || VT == MVT::v2i64 || VT == MVT::v2i32 ||
6844           VT == MVT::v4i32 || VT == MVT::v4i16 || VT == MVT::v8i16) &&
6845          "Unexpected type for custom ctpop lowering");
6846 
6847   EVT VT8Bit = VT.is64BitVector() ? MVT::v8i8 : MVT::v16i8;
6848   Val = DAG.getBitcast(VT8Bit, Val);
6849   Val = DAG.getNode(ISD::CTPOP, DL, VT8Bit, Val);
6850 
6851   // Widen v8i8/v16i8 CTPOP result to VT by repeatedly widening pairwise adds.
6852   unsigned EltSize = 8;
6853   unsigned NumElts = VT.is64BitVector() ? 8 : 16;
6854   while (EltSize != VT.getScalarSizeInBits()) {
6855     EltSize *= 2;
6856     NumElts /= 2;
6857     MVT WidenVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
6858     Val = DAG.getNode(
6859         ISD::INTRINSIC_WO_CHAIN, DL, WidenVT,
6860         DAG.getConstant(Intrinsic::aarch64_neon_uaddlp, DL, MVT::i32), Val);
6861   }
6862 
6863   return Val;
6864 }
6865 
6866 SDValue AArch64TargetLowering::LowerCTTZ(SDValue Op, SelectionDAG &DAG) const {
6867   EVT VT = Op.getValueType();
6868   assert(VT.isScalableVector() ||
6869          useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true));
6870 
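  // cttz(x) == ctlz(bitreverse(x)), so reuse the (predicated) CTLZ lowering.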
6871   SDLoc DL(Op);
6872   SDValue RBIT = DAG.getNode(ISD::BITREVERSE, DL, VT, Op.getOperand(0));
6873   return DAG.getNode(ISD::CTLZ, DL, VT, RBIT);
6874 }
6875 
6876 SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
6877 
6878   if (Op.getValueType().isVector())
6879     return LowerVSETCC(Op, DAG);
6880 
6881   bool IsStrict = Op->isStrictFPOpcode();
6882   bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
6883   unsigned OpNo = IsStrict ? 1 : 0;
6884   SDValue Chain;
6885   if (IsStrict)
6886     Chain = Op.getOperand(0);
6887   SDValue LHS = Op.getOperand(OpNo + 0);
6888   SDValue RHS = Op.getOperand(OpNo + 1);
6889   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(OpNo + 2))->get();
6890   SDLoc dl(Op);
6891 
6892   // We chose ZeroOrOneBooleanContents, so use zero and one.
6893   EVT VT = Op.getValueType();
6894   SDValue TVal = DAG.getConstant(1, dl, VT);
6895   SDValue FVal = DAG.getConstant(0, dl, VT);
6896 
6897   // Handle f128 first, since one possible outcome is a normal integer
6898   // comparison which gets picked up by the next if statement.
6899   if (LHS.getValueType() == MVT::f128) {
6900     softenSetCCOperands(DAG, MVT::f128, LHS, RHS, CC, dl, LHS, RHS, Chain,
6901                         IsSignaling);
6902 
6903     // If softenSetCCOperands returned a scalar, use it.
6904     if (!RHS.getNode()) {
6905       assert(LHS.getValueType() == Op.getValueType() &&
6906              "Unexpected setcc expansion!");
6907       return IsStrict ? DAG.getMergeValues({LHS, Chain}, dl) : LHS;
6908     }
6909   }
6910 
6911   if (LHS.getValueType().isInteger()) {
6912     SDValue CCVal;
6913     SDValue Cmp = getAArch64Cmp(
6914         LHS, RHS, ISD::getSetCCInverse(CC, LHS.getValueType()), CCVal, DAG, dl);
6915 
6916     // Note that we inverted the condition above, so we reverse the order of
6917     // the true and false operands here.  This will allow the setcc to be
6918     // matched to a single CSINC instruction.
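    // In practice this is emitted as "cset wN, <cc>", an alias of
    // "csinc wN, wzr, wzr, <inverted cc>".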
6919     SDValue Res = DAG.getNode(AArch64ISD::CSEL, dl, VT, FVal, TVal, CCVal, Cmp);
6920     return IsStrict ? DAG.getMergeValues({Res, Chain}, dl) : Res;
6921   }
6922 
6923   // Now we know we're dealing with FP values.
6924   assert(LHS.getValueType() == MVT::f16 || LHS.getValueType() == MVT::f32 ||
6925          LHS.getValueType() == MVT::f64);
6926 
6927   // If that fails, we'll need to perform an FCMP + CSEL sequence.  Go ahead
6928   // and do the comparison.
6929   SDValue Cmp;
6930   if (IsStrict)
6931     Cmp = emitStrictFPComparison(LHS, RHS, dl, DAG, Chain, IsSignaling);
6932   else
6933     Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
6934 
6935   AArch64CC::CondCode CC1, CC2;
6936   changeFPCCToAArch64CC(CC, CC1, CC2);
6937   SDValue Res;
6938   if (CC2 == AArch64CC::AL) {
6939     changeFPCCToAArch64CC(ISD::getSetCCInverse(CC, LHS.getValueType()), CC1,
6940                           CC2);
6941     SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
6942 
6943     // Note that we inverted the condition above, so we reverse the order of
6944     // the true and false operands here.  This will allow the setcc to be
6945     // matched to a single CSINC instruction.
6946     Res = DAG.getNode(AArch64ISD::CSEL, dl, VT, FVal, TVal, CC1Val, Cmp);
6947   } else {
6948     // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't
6949     // totally clean.  Some of them require two CSELs to implement.  As is in
6950     // this case, we emit the first CSEL and then emit a second using the output
6951     // of the first as the RHS.  We're effectively OR'ing the two CC's together.
6952 
6953     // FIXME: It would be nice if we could match the two CSELs to two CSINCs.
6954     SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
6955     SDValue CS1 =
6956         DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, FVal, CC1Val, Cmp);
6957 
6958     SDValue CC2Val = DAG.getConstant(CC2, dl, MVT::i32);
6959     Res = DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, CS1, CC2Val, Cmp);
6960   }
6961   return IsStrict ? DAG.getMergeValues({Res, Cmp.getValue(1)}, dl) : Res;
6962 }
6963 
6964 SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
6965                                               SDValue RHS, SDValue TVal,
6966                                               SDValue FVal, const SDLoc &dl,
6967                                               SelectionDAG &DAG) const {
6968   // Handle f128 first, because it will result in a comparison of some RTLIB
6969   // call result against zero.
6970   if (LHS.getValueType() == MVT::f128) {
6971     softenSetCCOperands(DAG, MVT::f128, LHS, RHS, CC, dl, LHS, RHS);
6972 
6973     // If softenSetCCOperands returned a scalar, we need to compare the result
6974     // against zero to select between true and false values.
6975     if (!RHS.getNode()) {
6976       RHS = DAG.getConstant(0, dl, LHS.getValueType());
6977       CC = ISD::SETNE;
6978     }
6979   }
6980 
6981   // Also handle f16, for which we need to do a f32 comparison.
6982   if (LHS.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
6983     LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
6984     RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
6985   }
6986 
6987   // Next, handle integers.
6988   if (LHS.getValueType().isInteger()) {
6989     assert((LHS.getValueType() == RHS.getValueType()) &&
6990            (LHS.getValueType() == MVT::i32 || LHS.getValueType() == MVT::i64));
6991 
6992     ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal);
6993     ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal);
6994     ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS);
6995     // Check for sign pattern (SELECT_CC setgt, iN lhs, -1, 1, -1) and transform
6996     // into (OR (ASR lhs, N-1), 1), which requires fewer instructions for the
6997     // supported types.
6998     if (CC == ISD::SETGT && RHSC && RHSC->isAllOnesValue() && CTVal && CFVal &&
6999         CTVal->isOne() && CFVal->isAllOnesValue() &&
7000         LHS.getValueType() == TVal.getValueType()) {
7001       EVT VT = LHS.getValueType();
7002       SDValue Shift =
7003           DAG.getNode(ISD::SRA, dl, VT, LHS,
7004                       DAG.getConstant(VT.getSizeInBits() - 1, dl, VT));
7005       return DAG.getNode(ISD::OR, dl, VT, Shift, DAG.getConstant(1, dl, VT));
7006     }
7007 
7008     unsigned Opcode = AArch64ISD::CSEL;
7009 
7010     // If both the TVal and the FVal are constants, see if we can swap them in
7011     // order to form a CSINV or CSINC out of them.
7012     if (CTVal && CFVal && CTVal->isAllOnesValue() && CFVal->isNullValue()) {
7013       std::swap(TVal, FVal);
7014       std::swap(CTVal, CFVal);
7015       CC = ISD::getSetCCInverse(CC, LHS.getValueType());
7016     } else if (CTVal && CFVal && CTVal->isOne() && CFVal->isNullValue()) {
7017       std::swap(TVal, FVal);
7018       std::swap(CTVal, CFVal);
7019       CC = ISD::getSetCCInverse(CC, LHS.getValueType());
7020     } else if (TVal.getOpcode() == ISD::XOR) {
7021       // If TVal is a NOT we want to swap TVal and FVal so that we can match
7022       // with a CSINV rather than a CSEL.
7023       if (isAllOnesConstant(TVal.getOperand(1))) {
7024         std::swap(TVal, FVal);
7025         std::swap(CTVal, CFVal);
7026         CC = ISD::getSetCCInverse(CC, LHS.getValueType());
7027       }
7028     } else if (TVal.getOpcode() == ISD::SUB) {
7029       // If TVal is a negation (SUB from 0) we want to swap TVal and FVal so
7030       // that we can match with a CSNEG rather than a CSEL.
7031       if (isNullConstant(TVal.getOperand(0))) {
7032         std::swap(TVal, FVal);
7033         std::swap(CTVal, CFVal);
7034         CC = ISD::getSetCCInverse(CC, LHS.getValueType());
7035       }
7036     } else if (CTVal && CFVal) {
7037       const int64_t TrueVal = CTVal->getSExtValue();
7038       const int64_t FalseVal = CFVal->getSExtValue();
7039       bool Swap = false;
7040 
7041       // If both TVal and FVal are constants, see if FVal is the
7042       // inverse/negation/increment of TVal and generate a CSINV/CSNEG/CSINC
7043       // instead of a CSEL in that case.
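      // For example, with TVal == 5: FVal == -6 (~5) maps to CSINV,
      // FVal == -5 maps to CSNEG, and FVal == 4 or 6 maps to CSINC.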
7044       if (TrueVal == ~FalseVal) {
7045         Opcode = AArch64ISD::CSINV;
7046       } else if (FalseVal > std::numeric_limits<int64_t>::min() &&
7047                  TrueVal == -FalseVal) {
7048         Opcode = AArch64ISD::CSNEG;
7049       } else if (TVal.getValueType() == MVT::i32) {
7050         // If our operands are only 32-bit wide, make sure we use 32-bit
7051         // arithmetic for the check whether we can use CSINC. This ensures that
7052         // the addition in the check will wrap around properly in case there is
7053         // an overflow (which would not be the case if we do the check with
7054         // 64-bit arithmetic).
7055         const uint32_t TrueVal32 = CTVal->getZExtValue();
7056         const uint32_t FalseVal32 = CFVal->getZExtValue();
7057 
7058         if ((TrueVal32 == FalseVal32 + 1) || (TrueVal32 + 1 == FalseVal32)) {
7059           Opcode = AArch64ISD::CSINC;
7060 
7061           if (TrueVal32 > FalseVal32) {
7062             Swap = true;
7063           }
7064         }
7065         // 64-bit check whether we can use CSINC.
7066       } else if ((TrueVal == FalseVal + 1) || (TrueVal + 1 == FalseVal)) {
7067         Opcode = AArch64ISD::CSINC;
7068 
7069         if (TrueVal > FalseVal) {
7070           Swap = true;
7071         }
7072       }
7073 
7074       // Swap TVal and FVal if necessary.
7075       if (Swap) {
7076         std::swap(TVal, FVal);
7077         std::swap(CTVal, CFVal);
7078         CC = ISD::getSetCCInverse(CC, LHS.getValueType());
7079       }
7080 
7081       if (Opcode != AArch64ISD::CSEL) {
7082         // Drop FVal since we can get its value by simply inverting/negating
7083         // TVal.
7084         FVal = TVal;
7085       }
7086     }
7087 
7088     // Avoid materializing a constant when possible by reusing a known value in
7089     // a register.  However, don't perform this optimization if the known value
7090     // is one, zero or negative one in the case of a CSEL.  We can always
7091     // materialize these values using CSINC, CSEL and CSINV with wzr/xzr as the
7092     // FVal, respectively.
7093     ConstantSDNode *RHSVal = dyn_cast<ConstantSDNode>(RHS);
7094     if (Opcode == AArch64ISD::CSEL && RHSVal && !RHSVal->isOne() &&
7095         !RHSVal->isNullValue() && !RHSVal->isAllOnesValue()) {
7096       AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
7097       // Transform "a == C ? C : x" to "a == C ? a : x" and "a != C ? x : C" to
7098       // "a != C ? x : a" to avoid materializing C.
7099       if (CTVal && CTVal == RHSVal && AArch64CC == AArch64CC::EQ)
7100         TVal = LHS;
7101       else if (CFVal && CFVal == RHSVal && AArch64CC == AArch64CC::NE)
7102         FVal = LHS;
7103     } else if (Opcode == AArch64ISD::CSNEG && RHSVal && RHSVal->isOne()) {
7104       assert (CTVal && CFVal && "Expected constant operands for CSNEG.");
7105       // Use a CSINV to transform "a == C ? 1 : -1" to "a == C ? a : -1" to
7106       // avoid materializing C.
7107       AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
7108       if (CTVal == RHSVal && AArch64CC == AArch64CC::EQ) {
7109         Opcode = AArch64ISD::CSINV;
7110         TVal = LHS;
7111         FVal = DAG.getConstant(0, dl, FVal.getValueType());
7112       }
7113     }
7114 
7115     SDValue CCVal;
7116     SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
7117     EVT VT = TVal.getValueType();
7118     return DAG.getNode(Opcode, dl, VT, TVal, FVal, CCVal, Cmp);
7119   }
7120 
7121   // Now we know we're dealing with FP values.
7122   assert(LHS.getValueType() == MVT::f16 || LHS.getValueType() == MVT::f32 ||
7123          LHS.getValueType() == MVT::f64);
7124   assert(LHS.getValueType() == RHS.getValueType());
7125   EVT VT = TVal.getValueType();
7126   SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
7127 
7128   // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
7129   // clean.  Some of them require two CSELs to implement.
7130   AArch64CC::CondCode CC1, CC2;
7131   changeFPCCToAArch64CC(CC, CC1, CC2);
7132 
7133   if (DAG.getTarget().Options.UnsafeFPMath) {
7134     // Transform "a == 0.0 ? 0.0 : x" to "a == 0.0 ? a : x" and
7135     // "a != 0.0 ? x : 0.0" to "a != 0.0 ? x : a" to avoid materializing 0.0.
7136     ConstantFPSDNode *RHSVal = dyn_cast<ConstantFPSDNode>(RHS);
7137     if (RHSVal && RHSVal->isZero()) {
7138       ConstantFPSDNode *CFVal = dyn_cast<ConstantFPSDNode>(FVal);
7139       ConstantFPSDNode *CTVal = dyn_cast<ConstantFPSDNode>(TVal);
7140 
7141       if ((CC == ISD::SETEQ || CC == ISD::SETOEQ || CC == ISD::SETUEQ) &&
7142           CTVal && CTVal->isZero() && TVal.getValueType() == LHS.getValueType())
7143         TVal = LHS;
7144       else if ((CC == ISD::SETNE || CC == ISD::SETONE || CC == ISD::SETUNE) &&
7145                CFVal && CFVal->isZero() &&
7146                FVal.getValueType() == LHS.getValueType())
7147         FVal = LHS;
7148     }
7149   }
7150 
7151   // Emit first, and possibly only, CSEL.
7152   SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
7153   SDValue CS1 = DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, FVal, CC1Val, Cmp);
7154 
7155   // If we need a second CSEL, emit it, using the output of the first as the
7156   // RHS.  We're effectively OR'ing the two CC's together.
7157   if (CC2 != AArch64CC::AL) {
7158     SDValue CC2Val = DAG.getConstant(CC2, dl, MVT::i32);
7159     return DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, CS1, CC2Val, Cmp);
7160   }
7161 
7162   // Otherwise, return the output of the first CSEL.
7163   return CS1;
7164 }
7165 
7166 SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
7167                                               SelectionDAG &DAG) const {
7168   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(4))->get();
7169   SDValue LHS = Op.getOperand(0);
7170   SDValue RHS = Op.getOperand(1);
7171   SDValue TVal = Op.getOperand(2);
7172   SDValue FVal = Op.getOperand(3);
7173   SDLoc DL(Op);
7174   return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
7175 }
7176 
7177 SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
7178                                            SelectionDAG &DAG) const {
7179   SDValue CCVal = Op->getOperand(0);
7180   SDValue TVal = Op->getOperand(1);
7181   SDValue FVal = Op->getOperand(2);
7182   SDLoc DL(Op);
7183 
7184   EVT Ty = Op.getValueType();
7185   if (Ty.isScalableVector()) {
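    // Scalar selects on scalable vectors are lowered by splatting the i1
    // condition into a predicate vector and using a vector select.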
7186     SDValue TruncCC = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, CCVal);
7187     MVT PredVT = MVT::getVectorVT(MVT::i1, Ty.getVectorElementCount());
7188     SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, PredVT, TruncCC);
7189     return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal);
7190   }
7191 
7192   if (useSVEForFixedLengthVectorVT(Ty)) {
7193     // FIXME: Ideally this would be the same as above using i1 types, however
7194     // for the moment we can't deal with fixed i1 vector types properly, so
7195     // instead extend the predicate to a result type sized integer vector.
7196     MVT SplatValVT = MVT::getIntegerVT(Ty.getScalarSizeInBits());
7197     MVT PredVT = MVT::getVectorVT(SplatValVT, Ty.getVectorElementCount());
7198     SDValue SplatVal = DAG.getSExtOrTrunc(CCVal, DL, SplatValVT);
7199     SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, PredVT, SplatVal);
7200     return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal);
7201   }
7202 
7203   // Optimize {s|u}{add|sub|mul}.with.overflow feeding into a select
7204   // instruction.
7205   if (ISD::isOverflowIntrOpRes(CCVal)) {
7206     // Only lower legal XALUO ops.
7207     if (!DAG.getTargetLoweringInfo().isTypeLegal(CCVal->getValueType(0)))
7208       return SDValue();
7209 
7210     AArch64CC::CondCode OFCC;
7211     SDValue Value, Overflow;
7212     std::tie(Value, Overflow) = getAArch64XALUOOp(OFCC, CCVal.getValue(0), DAG);
7213     SDValue CCVal = DAG.getConstant(OFCC, DL, MVT::i32);
7214 
7215     return DAG.getNode(AArch64ISD::CSEL, DL, Op.getValueType(), TVal, FVal,
7216                        CCVal, Overflow);
7217   }
7218 
7219   // Lower it the same way as we would lower a SELECT_CC node.
7220   ISD::CondCode CC;
7221   SDValue LHS, RHS;
7222   if (CCVal.getOpcode() == ISD::SETCC) {
7223     LHS = CCVal.getOperand(0);
7224     RHS = CCVal.getOperand(1);
7225     CC = cast<CondCodeSDNode>(CCVal.getOperand(2))->get();
7226   } else {
7227     LHS = CCVal;
7228     RHS = DAG.getConstant(0, DL, CCVal.getValueType());
7229     CC = ISD::SETNE;
7230   }
7231   return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
7232 }
7233 
7234 SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op,
7235                                               SelectionDAG &DAG) const {
7236   // Jump table entries as PC relative offsets. No additional tweaking
7237   // is necessary here. Just get the address of the jump table.
7238   JumpTableSDNode *JT = cast<JumpTableSDNode>(Op);
7239 
7240   if (getTargetMachine().getCodeModel() == CodeModel::Large &&
7241       !Subtarget->isTargetMachO()) {
7242     return getAddrLarge(JT, DAG);
7243   } else if (getTargetMachine().getCodeModel() == CodeModel::Tiny) {
7244     return getAddrTiny(JT, DAG);
7245   }
7246   return getAddr(JT, DAG);
7247 }
7248 
7249 SDValue AArch64TargetLowering::LowerBR_JT(SDValue Op,
7250                                           SelectionDAG &DAG) const {
7251   // Jump table entries as PC relative offsets. No additional tweaking
7252   // is necessary here. Just get the address of the jump table.
7253   SDLoc DL(Op);
7254   SDValue JT = Op.getOperand(1);
7255   SDValue Entry = Op.getOperand(2);
7256   int JTI = cast<JumpTableSDNode>(JT.getNode())->getIndex();
7257 
7258   auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
7259   AFI->setJumpTableEntryInfo(JTI, 4, nullptr);
7260 
7261   SDNode *Dest =
7262       DAG.getMachineNode(AArch64::JumpTableDest32, DL, MVT::i64, MVT::i64, JT,
7263                          Entry, DAG.getTargetJumpTable(JTI, MVT::i32));
7264   return DAG.getNode(ISD::BRIND, DL, MVT::Other, Op.getOperand(0),
7265                      SDValue(Dest, 0));
7266 }
7267 
7268 SDValue AArch64TargetLowering::LowerConstantPool(SDValue Op,
7269                                                  SelectionDAG &DAG) const {
7270   ConstantPoolSDNode *CP = cast<ConstantPoolSDNode>(Op);
7271 
7272   if (getTargetMachine().getCodeModel() == CodeModel::Large) {
7273     // Use the GOT for the large code model on iOS.
7274     if (Subtarget->isTargetMachO()) {
7275       return getGOT(CP, DAG);
7276     }
7277     return getAddrLarge(CP, DAG);
7278   } else if (getTargetMachine().getCodeModel() == CodeModel::Tiny) {
7279     return getAddrTiny(CP, DAG);
7280   } else {
7281     return getAddr(CP, DAG);
7282   }
7283 }
7284 
7285 SDValue AArch64TargetLowering::LowerBlockAddress(SDValue Op,
7286                                                SelectionDAG &DAG) const {
7287   BlockAddressSDNode *BA = cast<BlockAddressSDNode>(Op);
7288   if (getTargetMachine().getCodeModel() == CodeModel::Large &&
7289       !Subtarget->isTargetMachO()) {
7290     return getAddrLarge(BA, DAG);
7291   } else if (getTargetMachine().getCodeModel() == CodeModel::Tiny) {
7292     return getAddrTiny(BA, DAG);
7293   }
7294   return getAddr(BA, DAG);
7295 }
7296 
7297 SDValue AArch64TargetLowering::LowerDarwin_VASTART(SDValue Op,
7298                                                  SelectionDAG &DAG) const {
7299   AArch64FunctionInfo *FuncInfo =
7300       DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
7301 
7302   SDLoc DL(Op);
7303   SDValue FR = DAG.getFrameIndex(FuncInfo->getVarArgsStackIndex(),
7304                                  getPointerTy(DAG.getDataLayout()));
7305   FR = DAG.getZExtOrTrunc(FR, DL, getPointerMemTy(DAG.getDataLayout()));
7306   const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
7307   return DAG.getStore(Op.getOperand(0), DL, FR, Op.getOperand(1),
7308                       MachinePointerInfo(SV));
7309 }
7310 
7311 SDValue AArch64TargetLowering::LowerWin64_VASTART(SDValue Op,
7312                                                   SelectionDAG &DAG) const {
7313   AArch64FunctionInfo *FuncInfo =
7314       DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
7315 
7316   SDLoc DL(Op);
7317   SDValue FR = DAG.getFrameIndex(FuncInfo->getVarArgsGPRSize() > 0
7318                                      ? FuncInfo->getVarArgsGPRIndex()
7319                                      : FuncInfo->getVarArgsStackIndex(),
7320                                  getPointerTy(DAG.getDataLayout()));
7321   const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
7322   return DAG.getStore(Op.getOperand(0), DL, FR, Op.getOperand(1),
7323                       MachinePointerInfo(SV));
7324 }
7325 
7326 SDValue AArch64TargetLowering::LowerAAPCS_VASTART(SDValue Op,
7327                                                   SelectionDAG &DAG) const {
7328   // The layout of the va_list struct is specified in the AArch64 Procedure Call
7329   // Standard, section B.3.
7330   MachineFunction &MF = DAG.getMachineFunction();
7331   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
7332   unsigned PtrSize = Subtarget->isTargetILP32() ? 4 : 8;
7333   auto PtrMemVT = getPointerMemTy(DAG.getDataLayout());
7334   auto PtrVT = getPointerTy(DAG.getDataLayout());
7335   SDLoc DL(Op);
7336 
7337   SDValue Chain = Op.getOperand(0);
7338   SDValue VAList = Op.getOperand(1);
7339   const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
7340   SmallVector<SDValue, 4> MemOps;
7341 
7342   // void *__stack at offset 0
7343   unsigned Offset = 0;
7344   SDValue Stack = DAG.getFrameIndex(FuncInfo->getVarArgsStackIndex(), PtrVT);
7345   Stack = DAG.getZExtOrTrunc(Stack, DL, PtrMemVT);
7346   MemOps.push_back(DAG.getStore(Chain, DL, Stack, VAList,
7347                                 MachinePointerInfo(SV), Align(PtrSize)));
7348 
7349   // void *__gr_top at offset 8 (4 on ILP32)
7350   Offset += PtrSize;
7351   int GPRSize = FuncInfo->getVarArgsGPRSize();
7352   if (GPRSize > 0) {
7353     SDValue GRTop, GRTopAddr;
7354 
7355     GRTopAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
7356                             DAG.getConstant(Offset, DL, PtrVT));
7357 
7358     GRTop = DAG.getFrameIndex(FuncInfo->getVarArgsGPRIndex(), PtrVT);
7359     GRTop = DAG.getNode(ISD::ADD, DL, PtrVT, GRTop,
7360                         DAG.getConstant(GPRSize, DL, PtrVT));
7361     GRTop = DAG.getZExtOrTrunc(GRTop, DL, PtrMemVT);
7362 
7363     MemOps.push_back(DAG.getStore(Chain, DL, GRTop, GRTopAddr,
7364                                   MachinePointerInfo(SV, Offset),
7365                                   Align(PtrSize)));
7366   }
7367 
7368   // void *__vr_top at offset 16 (8 on ILP32)
7369   Offset += PtrSize;
7370   int FPRSize = FuncInfo->getVarArgsFPRSize();
7371   if (FPRSize > 0) {
7372     SDValue VRTop, VRTopAddr;
7373     VRTopAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
7374                             DAG.getConstant(Offset, DL, PtrVT));
7375 
7376     VRTop = DAG.getFrameIndex(FuncInfo->getVarArgsFPRIndex(), PtrVT);
7377     VRTop = DAG.getNode(ISD::ADD, DL, PtrVT, VRTop,
7378                         DAG.getConstant(FPRSize, DL, PtrVT));
7379     VRTop = DAG.getZExtOrTrunc(VRTop, DL, PtrMemVT);
7380 
7381     MemOps.push_back(DAG.getStore(Chain, DL, VRTop, VRTopAddr,
7382                                   MachinePointerInfo(SV, Offset),
7383                                   Align(PtrSize)));
7384   }
7385 
7386   // int __gr_offs at offset 24 (12 on ILP32)
7387   Offset += PtrSize;
7388   SDValue GROffsAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
7389                                    DAG.getConstant(Offset, DL, PtrVT));
7390   MemOps.push_back(
7391       DAG.getStore(Chain, DL, DAG.getConstant(-GPRSize, DL, MVT::i32),
7392                    GROffsAddr, MachinePointerInfo(SV, Offset), Align(4)));
7393 
7394   // int __vr_offs at offset 28 (16 on ILP32)
7395   Offset += 4;
7396   SDValue VROffsAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
7397                                    DAG.getConstant(Offset, DL, PtrVT));
7398   MemOps.push_back(
7399       DAG.getStore(Chain, DL, DAG.getConstant(-FPRSize, DL, MVT::i32),
7400                    VROffsAddr, MachinePointerInfo(SV, Offset), Align(4)));
7401 
7402   return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOps);
7403 }
7404 
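// Illustrative sketch, not part of the original file: the five stores built
// above populate a record with the following shape (field names follow the
// AAPCS64 va_list described in section B.3; the C++ struct itself is only an
// illustration).
struct IllustrativeAAPCS64VaList {
  void *__stack;  // offset 0: next stacked argument
  void *__gr_top; // offset 8 (4 on ILP32): end of the GP register save area
  void *__vr_top; // offset 16 (8 on ILP32): end of the FP/SIMD save area
  int __gr_offs;  // offset 24 (12 on ILP32): -GPRSize
  int __vr_offs;  // offset 28 (16 on ILP32): -FPRSize
};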
7405 SDValue AArch64TargetLowering::LowerVASTART(SDValue Op,
7406                                             SelectionDAG &DAG) const {
7407   MachineFunction &MF = DAG.getMachineFunction();
7408 
7409   if (Subtarget->isCallingConvWin64(MF.getFunction().getCallingConv()))
7410     return LowerWin64_VASTART(Op, DAG);
7411   else if (Subtarget->isTargetDarwin())
7412     return LowerDarwin_VASTART(Op, DAG);
7413   else
7414     return LowerAAPCS_VASTART(Op, DAG);
7415 }
7416 
7417 SDValue AArch64TargetLowering::LowerVACOPY(SDValue Op,
7418                                            SelectionDAG &DAG) const {
7419   // AAPCS has three pointers and two ints (= 32 bytes); Darwin and Windows
7420   // use a single pointer.
7421   SDLoc DL(Op);
7422   unsigned PtrSize = Subtarget->isTargetILP32() ? 4 : 8;
7423   unsigned VaListSize =
7424       (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows())
7425           ? PtrSize
7426           : Subtarget->isTargetILP32() ? 20 : 32;
7427   const Value *DestSV = cast<SrcValueSDNode>(Op.getOperand(3))->getValue();
7428   const Value *SrcSV = cast<SrcValueSDNode>(Op.getOperand(4))->getValue();
7429 
7430   return DAG.getMemcpy(Op.getOperand(0), DL, Op.getOperand(1), Op.getOperand(2),
7431                        DAG.getConstant(VaListSize, DL, MVT::i32),
7432                        Align(PtrSize), false, false, false,
7433                        MachinePointerInfo(DestSV), MachinePointerInfo(SrcSV));
7434 }
7435 
7436 SDValue AArch64TargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
7437   assert(Subtarget->isTargetDarwin() &&
7438          "automatic va_arg instruction only works on Darwin");
7439 
7440   const Value *V = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
7441   EVT VT = Op.getValueType();
7442   SDLoc DL(Op);
7443   SDValue Chain = Op.getOperand(0);
7444   SDValue Addr = Op.getOperand(1);
7445   MaybeAlign Align(Op.getConstantOperandVal(3));
7446   unsigned MinSlotSize = Subtarget->isTargetILP32() ? 4 : 8;
7447   auto PtrVT = getPointerTy(DAG.getDataLayout());
7448   auto PtrMemVT = getPointerMemTy(DAG.getDataLayout());
7449   SDValue VAList =
7450       DAG.getLoad(PtrMemVT, DL, Chain, Addr, MachinePointerInfo(V));
7451   Chain = VAList.getValue(1);
7452   VAList = DAG.getZExtOrTrunc(VAList, DL, PtrVT);
7453 
7454   if (VT.isScalableVector())
7455     report_fatal_error("Passing SVE types to variadic functions is "
7456                        "currently not supported");
7457 
7458   if (Align && *Align > MinSlotSize) {
7459     VAList = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
7460                          DAG.getConstant(Align->value() - 1, DL, PtrVT));
7461     VAList = DAG.getNode(ISD::AND, DL, PtrVT, VAList,
7462                          DAG.getConstant(-(int64_t)Align->value(), DL, PtrVT));
7463   }
7464 
7465   Type *ArgTy = VT.getTypeForEVT(*DAG.getContext());
7466   unsigned ArgSize = DAG.getDataLayout().getTypeAllocSize(ArgTy);
7467 
7468   // Scalar integer and FP values smaller than 64 bits are implicitly extended
7469   // up to 64 bits.  At the very least, we have to increase the striding of the
7470   // vaargs list to match this, and for FP values we need to introduce
7471   // FP_ROUND nodes as well.
7472   if (VT.isInteger() && !VT.isVector())
7473     ArgSize = std::max(ArgSize, MinSlotSize);
7474   bool NeedFPTrunc = false;
7475   if (VT.isFloatingPoint() && !VT.isVector() && VT != MVT::f64) {
7476     ArgSize = 8;
7477     NeedFPTrunc = true;
7478   }
7479 
7480   // Increment the pointer, VAList, to the next vaarg
7481   SDValue VANext = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
7482                                DAG.getConstant(ArgSize, DL, PtrVT));
7483   VANext = DAG.getZExtOrTrunc(VANext, DL, PtrMemVT);
7484 
7485   // Store the incremented VAList to the legalized pointer
7486   SDValue APStore =
7487       DAG.getStore(Chain, DL, VANext, Addr, MachinePointerInfo(V));
7488 
7489   // Load the actual argument out of the pointer VAList
7490   if (NeedFPTrunc) {
7491     // Load the value as an f64.
7492     SDValue WideFP =
7493         DAG.getLoad(MVT::f64, DL, APStore, VAList, MachinePointerInfo());
7494     // Round the value down to an f32.
7495     SDValue NarrowFP = DAG.getNode(ISD::FP_ROUND, DL, VT, WideFP.getValue(0),
7496                                    DAG.getIntPtrConstant(1, DL));
7497     SDValue Ops[] = { NarrowFP, WideFP.getValue(1) };
7498     // Merge the rounded value with the chain output of the load.
7499     return DAG.getMergeValues(Ops, DL);
7500   }
7501 
7502   return DAG.getLoad(VT, DL, APStore, VAList, MachinePointerInfo());
7503 }
7504 
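// C-level note (illustrative only): because the caller applies the default
// argument promotions, a float vararg arrives in an 8-byte slot as a double;
// the f64 load + FP_ROUND sequence above is the callee-side equivalent of
//   float f = (float)va_arg(ap, double);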
7505 SDValue AArch64TargetLowering::LowerFRAMEADDR(SDValue Op,
7506                                               SelectionDAG &DAG) const {
7507   MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
7508   MFI.setFrameAddressIsTaken(true);
7509 
7510   EVT VT = Op.getValueType();
7511   SDLoc DL(Op);
7512   unsigned Depth = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
7513   SDValue FrameAddr =
7514       DAG.getCopyFromReg(DAG.getEntryNode(), DL, AArch64::FP, MVT::i64);
7515   while (Depth--)
7516     FrameAddr = DAG.getLoad(VT, DL, DAG.getEntryNode(), FrameAddr,
7517                             MachinePointerInfo());
7518 
7519   if (Subtarget->isTargetILP32())
7520     FrameAddr = DAG.getNode(ISD::AssertZext, DL, MVT::i64, FrameAddr,
7521                             DAG.getValueType(VT));
7522 
7523   return FrameAddr;
7524 }
7525 
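// Usage note (illustration, not from the original file):
// __builtin_frame_address(N) reaches this hook; the Depth loop above emits a
// copy of FP followed by N loads through the saved frame-pointer chain.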
7526 SDValue AArch64TargetLowering::LowerSPONENTRY(SDValue Op,
7527                                               SelectionDAG &DAG) const {
7528   MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
7529 
7530   EVT VT = getPointerTy(DAG.getDataLayout());
7531   SDLoc DL(Op);
7532   int FI = MFI.CreateFixedObject(4, 0, false);
7533   return DAG.getFrameIndex(FI, VT);
7534 }
7535 
7536 #define GET_REGISTER_MATCHER
7537 #include "AArch64GenAsmMatcher.inc"
7538 
7539 // FIXME? Maybe this could be a TableGen attribute on some registers and
7540 // this table could be generated automatically from RegInfo.
7541 Register AArch64TargetLowering::
7542 getRegisterByName(const char* RegName, LLT VT, const MachineFunction &MF) const {
7543   Register Reg = MatchRegisterName(RegName);
7544   if (AArch64::X1 <= Reg && Reg <= AArch64::X28) {
7545     const MCRegisterInfo *MRI = Subtarget->getRegisterInfo();
7546     unsigned DwarfRegNum = MRI->getDwarfRegNum(Reg, false);
7547     if (!Subtarget->isXRegisterReserved(DwarfRegNum))
7548       Reg = 0;
7549   }
7550   if (Reg)
7551     return Reg;
7552   report_fatal_error(Twine("Invalid register name \""
7553                               + StringRef(RegName)  + "\"."));
7554 }
7555 
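// Usage sketch (an assumption for illustration): this hook is reached for
// named-register constructs such as
//   register long base asm("x18");
// which only succeed here when the corresponding X register is reserved
// (e.g. built with -ffixed-x18).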
7556 SDValue AArch64TargetLowering::LowerADDROFRETURNADDR(SDValue Op,
7557                                                      SelectionDAG &DAG) const {
7558   DAG.getMachineFunction().getFrameInfo().setFrameAddressIsTaken(true);
7559 
7560   EVT VT = Op.getValueType();
7561   SDLoc DL(Op);
7562 
7563   SDValue FrameAddr =
7564       DAG.getCopyFromReg(DAG.getEntryNode(), DL, AArch64::FP, VT);
7565   SDValue Offset = DAG.getConstant(8, DL, getPointerTy(DAG.getDataLayout()));
7566 
7567   return DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset);
7568 }
7569 
7570 SDValue AArch64TargetLowering::LowerRETURNADDR(SDValue Op,
7571                                                SelectionDAG &DAG) const {
7572   MachineFunction &MF = DAG.getMachineFunction();
7573   MachineFrameInfo &MFI = MF.getFrameInfo();
7574   MFI.setReturnAddressIsTaken(true);
7575 
7576   EVT VT = Op.getValueType();
7577   SDLoc DL(Op);
7578   unsigned Depth = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
7579   SDValue ReturnAddress;
7580   if (Depth) {
7581     SDValue FrameAddr = LowerFRAMEADDR(Op, DAG);
7582     SDValue Offset = DAG.getConstant(8, DL, getPointerTy(DAG.getDataLayout()));
7583     ReturnAddress = DAG.getLoad(
7584         VT, DL, DAG.getEntryNode(),
7585         DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset), MachinePointerInfo());
7586   } else {
7587     // Return LR, which contains the return address. Mark it an implicit
7588     // live-in.
7589     unsigned Reg = MF.addLiveIn(AArch64::LR, &AArch64::GPR64RegClass);
7590     ReturnAddress = DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, VT);
7591   }
7592 
7593   // The XPACLRI instruction assembles to a hint-space instruction before
7594   // Armv8.3-A, so it can be used safely on any pre-Armv8.3-A
7595   // architecture. On Armv8.3-A and onwards XPACI is available, so use
7596   // that instead.
7597   SDNode *St;
7598   if (Subtarget->hasPAuth()) {
7599     St = DAG.getMachineNode(AArch64::XPACI, DL, VT, ReturnAddress);
7600   } else {
7601     // XPACLRI operates on LR therefore we must move the operand accordingly.
7602     SDValue Chain =
7603         DAG.getCopyToReg(DAG.getEntryNode(), DL, AArch64::LR, ReturnAddress);
7604     St = DAG.getMachineNode(AArch64::XPACLRI, DL, VT, Chain);
7605   }
7606   return SDValue(St, 0);
7607 }
7608 
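// Usage note (illustrative): __builtin_return_address(0) takes the Depth == 0
// path above and returns LR with any pointer-authentication signature
// stripped (XPACI when PAuth is available, the hint-encoded XPACLRI
// otherwise).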
7609 /// LowerShiftParts - Lower SHL_PARTS/SRA_PARTS/SRL_PARTS, which return two
7610 /// i32 values and take a 2 x i32 value to shift plus a shift amount.
7611 SDValue AArch64TargetLowering::LowerShiftParts(SDValue Op,
7612                                                SelectionDAG &DAG) const {
7613   SDValue Lo, Hi;
7614   expandShiftParts(Op.getNode(), Lo, Hi, DAG);
7615   return DAG.getMergeValues({Lo, Hi}, SDLoc(Op));
7616 }
7617 
7618 bool AArch64TargetLowering::isOffsetFoldingLegal(
7619     const GlobalAddressSDNode *GA) const {
7620   // Offsets are folded in the DAG combine rather than here so that we can
7621   // intelligently choose an offset based on the uses.
7622   return false;
7623 }
7624 
7625 bool AArch64TargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
7626                                          bool OptForSize) const {
7627   bool IsLegal = false;
7628   // We can materialize #0.0 as fmov $Rd, XZR for 64-bit, 32-bit cases, and
7629   // 16-bit case when target has full fp16 support.
7630   // FIXME: We should be able to handle f128 as well with a clever lowering.
7631   const APInt ImmInt = Imm.bitcastToAPInt();
7632   if (VT == MVT::f64)
7633     IsLegal = AArch64_AM::getFP64Imm(ImmInt) != -1 || Imm.isPosZero();
7634   else if (VT == MVT::f32)
7635     IsLegal = AArch64_AM::getFP32Imm(ImmInt) != -1 || Imm.isPosZero();
7636   else if (VT == MVT::f16 && Subtarget->hasFullFP16())
7637     IsLegal = AArch64_AM::getFP16Imm(ImmInt) != -1 || Imm.isPosZero();
7638   // TODO: fmov h0, w0 is also legal, however we don't have an isel pattern
7639   //       to generate that fmov.
7640 
7641   // If we cannot materialize the value in the immediate field of an fmov,
7642   // check if it can be encoded as the immediate operand of a logical
7643   // instruction. The immediate value will be created with MOVZ, MOVN, or ORR.
7644   if (!IsLegal && (VT == MVT::f64 || VT == MVT::f32)) {
7645     // The cost is actually exactly the same for mov+fmov vs. adrp+ldr;
7646     // however the mov+fmov sequence is always better because of the reduced
7647     // cache pressure. The timings are still the same if you consider
7648     // movw+movk+fmov vs. adrp+ldr (it's one instruction longer, but the
7649     // movw+movk is fused). So we limit the expansion to at most 2 instructions.
7650     SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
7651     AArch64_IMM::expandMOVImm(ImmInt.getZExtValue(), VT.getSizeInBits(),
7652                               Insn);
7653     unsigned Limit = (OptForSize ? 1 : (Subtarget->hasFuseLiterals() ? 5 : 2));
7654     IsLegal = Insn.size() <= Limit;
7655   }
7656 
7657   LLVM_DEBUG(dbgs() << (IsLegal ? "Legal " : "Illegal ") << VT.getEVTString()
7658                     << " imm value: "; Imm.dump(););
7659   return IsLegal;
7660 }
7661 
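// Worked example (illustrative, assuming the usual FMOV imm8 encoding of
// +/- n/16 * 2^r with n in [16, 31] and r in [-3, 4]): constants such as
// 1.0, -0.5 and 31.0 are legal immediately, while a value like 0.1 falls
// through to the expandMOVImm path and is legal only if it fits the
// instruction-count limit computed above.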
7662 //===----------------------------------------------------------------------===//
7663 //                          AArch64 Optimization Hooks
7664 //===----------------------------------------------------------------------===//
7665 
7666 static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
7667                            SDValue Operand, SelectionDAG &DAG,
7668                            int &ExtraSteps) {
7669   EVT VT = Operand.getValueType();
7670   if (ST->hasNEON() &&
7671       (VT == MVT::f64 || VT == MVT::v1f64 || VT == MVT::v2f64 ||
7672        VT == MVT::f32 || VT == MVT::v1f32 ||
7673        VT == MVT::v2f32 || VT == MVT::v4f32)) {
7674     if (ExtraSteps == TargetLoweringBase::ReciprocalEstimate::Unspecified)
7675       // For the reciprocal estimates, convergence is quadratic, so the number
7676       // of digits is doubled after each iteration.  In ARMv8, the accuracy of
7677       // the initial estimate is 2^-8.  Thus the number of extra steps to refine
7678       // the result for float (23 mantissa bits) is 2 and for double (52
7679       // mantissa bits) is 3.
7680       ExtraSteps = VT.getScalarType() == MVT::f64 ? 3 : 2;
7681 
7682     return DAG.getNode(Opcode, SDLoc(Operand), VT, Operand);
7683   }
7684 
7685   return SDValue();
7686 }
7687 
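// Quick arithmetic check (illustration only): starting from ~8 correct bits
// (2^-8) and doubling per step, 8 -> 16 -> 32 bits covers the 23-bit float
// mantissa after 2 steps, and 8 -> 16 -> 32 -> 64 bits covers the 52-bit
// double mantissa after 3 steps, matching the ExtraSteps values above.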
7688 SDValue
7689 AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
7690                                         const DenormalMode &Mode) const {
7691   SDLoc DL(Op);
7692   EVT VT = Op.getValueType();
7693   EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
7694   SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
7695   return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
7696 }
7697 
7698 SDValue
7699 AArch64TargetLowering::getSqrtResultForDenormInput(SDValue Op,
7700                                                    SelectionDAG &DAG) const {
7701   return Op;
7702 }
7703 
7704 SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
7705                                                SelectionDAG &DAG, int Enabled,
7706                                                int &ExtraSteps,
7707                                                bool &UseOneConst,
7708                                                bool Reciprocal) const {
7709   if (Enabled == ReciprocalEstimate::Enabled ||
7710       (Enabled == ReciprocalEstimate::Unspecified && Subtarget->useRSqrt()))
7711     if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRSQRTE, Operand,
7712                                        DAG, ExtraSteps)) {
7713       SDLoc DL(Operand);
7714       EVT VT = Operand.getValueType();
7715 
7716       SDNodeFlags Flags;
7717       Flags.setAllowReassociation(true);
7718 
7719       // Newton reciprocal square root iteration: E * 0.5 * (3 - X * E^2)
7720       // AArch64 reciprocal square root iteration instruction: 0.5 * (3 - M * N)
7721       for (int i = ExtraSteps; i > 0; --i) {
7722         SDValue Step = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Estimate,
7723                                    Flags);
7724         Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags);
7725         Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
7726       }
7727       if (!Reciprocal)
7728         Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags);
7729 
7730       ExtraSteps = 0;
7731       return Estimate;
7732     }
7733 
7734   return SDValue();
7735 }
7736 
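// Minimal scalar sketch (an illustration, not used by the lowering): the loop
// above builds the Newton-Raphson refinement E' = E * 0.5 * (3 - X * E * E),
// where FRSQRTS supplies the 0.5 * (3 - M * N) term. A plain C++ equivalent,
// assuming some rough initial estimate E0, would be:
static inline float IllustrativeRSqrtRefine(float X, float E0, int Steps) {
  float E = E0;
  for (int I = 0; I < Steps; ++I)
    E = E * 0.5f * (3.0f - X * E * E); // one quadratic-convergence step
  return E; // approximates 1 / sqrt(X)
}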
7737 SDValue AArch64TargetLowering::getRecipEstimate(SDValue Operand,
7738                                                 SelectionDAG &DAG, int Enabled,
7739                                                 int &ExtraSteps) const {
7740   if (Enabled == ReciprocalEstimate::Enabled)
7741     if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRECPE, Operand,
7742                                        DAG, ExtraSteps)) {
7743       SDLoc DL(Operand);
7744       EVT VT = Operand.getValueType();
7745 
7746       SDNodeFlags Flags;
7747       Flags.setAllowReassociation(true);
7748 
7749       // Newton reciprocal iteration: E * (2 - X * E)
7750       // AArch64 reciprocal iteration instruction: (2 - M * N)
7751       for (int i = ExtraSteps; i > 0; --i) {
7752         SDValue Step = DAG.getNode(AArch64ISD::FRECPS, DL, VT, Operand,
7753                                    Estimate, Flags);
7754         Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
7755       }
7756 
7757       ExtraSteps = 0;
7758       return Estimate;
7759     }
7760 
7761   return SDValue();
7762 }
7763 
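// Scalar illustration (an assumption, not part of the original file): the
// FRECPS loop above implements the Newton reciprocal step E' = E * (2 - X * E).
static inline float IllustrativeRecipRefine(float X, float E0, int Steps) {
  float E = E0;
  for (int I = 0; I < Steps; ++I)
    E = E * (2.0f - X * E); // FRECPS provides the (2 - M * N) term
  return E; // approximates 1 / X
}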
7764 //===----------------------------------------------------------------------===//
7765 //                          AArch64 Inline Assembly Support
7766 //===----------------------------------------------------------------------===//
7767 
7768 // Table of Constraints
7769 // TODO: This is the current set of constraints supported by ARM for the
7770 // compiler, not all of them may make sense.
7771 //
7772 // r - A general register
7773 // w - An FP/SIMD register of some size in the range v0-v31
7774 // x - An FP/SIMD register of some size in the range v0-v15
7775 // I - Constant that can be used with an ADD instruction
7776 // J - Constant that can be used with a SUB instruction
7777 // K - Constant that can be used with a 32-bit logical instruction
7778 // L - Constant that can be used with a 64-bit logical instruction
7779 // M - Constant that can be used as a 32-bit MOV immediate
7780 // N - Constant that can be used as a 64-bit MOV immediate
7781 // Q - A memory reference with base register and no offset
7782 // S - A symbolic address
7783 // Y - Floating point constant zero
7784 // Z - Integer constant zero
7785 //
7786 //   Note that general register operands will be output using their 64-bit x
7787 // register name, whatever the size of the variable, unless the asm operand
7788 // is prefixed by the %w modifier. Floating-point and SIMD register operands
7789 // will be output with the v prefix unless prefixed by the %b, %h, %s, %d or
7790 // %q modifier.
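// Hedged usage sketch (not part of LLVM): one way the 'r' and 'I' constraints
// above might be exercised from GNU-style inline assembly; the constant 42 is
// an arbitrary example value and the helper is purely illustrative.
static inline long IllustrativeAddImm(long X) {
  long R;
  asm("add %0, %1, %2" : "=r"(R) : "r"(X), "I"(42)); // 'I': valid ADD immediate
  return R;
}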
7791 const char *AArch64TargetLowering::LowerXConstraint(EVT ConstraintVT) const {
7792   // At this point, we have to lower this constraint to something else, so we
7793   // lower it to an "r" or "w". However, by doing this we will force the result
7794   // to be in register, while the X constraint is much more permissive.
7795   //
7796   // Although we are correct (we are free to emit anything, without
7797   // constraints), we might break use cases that would expect us to be more
7798   // efficient and emit something else.
7799   if (!Subtarget->hasFPARMv8())
7800     return "r";
7801 
7802   if (ConstraintVT.isFloatingPoint())
7803     return "w";
7804 
7805   if (ConstraintVT.isVector() &&
7806      (ConstraintVT.getSizeInBits() == 64 ||
7807       ConstraintVT.getSizeInBits() == 128))
7808     return "w";
7809 
7810   return "r";
7811 }
7812 
7813 enum PredicateConstraint {
7814   Upl,
7815   Upa,
7816   Invalid
7817 };
7818 
7819 static PredicateConstraint parsePredicateConstraint(StringRef Constraint) {
7820   PredicateConstraint P = PredicateConstraint::Invalid;
7821   if (Constraint == "Upa")
7822     P = PredicateConstraint::Upa;
7823   if (Constraint == "Upl")
7824     P = PredicateConstraint::Upl;
7825   return P;
7826 }
7827 
7828 /// getConstraintType - Given a constraint letter, return the type of
7829 /// constraint it is for this target.
7830 AArch64TargetLowering::ConstraintType
7831 AArch64TargetLowering::getConstraintType(StringRef Constraint) const {
7832   if (Constraint.size() == 1) {
7833     switch (Constraint[0]) {
7834     default:
7835       break;
7836     case 'x':
7837     case 'w':
7838     case 'y':
7839       return C_RegisterClass;
7840     // An address with a single base register. Due to the way we
7841     // currently handle addresses it is the same as 'r'.
7842     case 'Q':
7843       return C_Memory;
7844     case 'I':
7845     case 'J':
7846     case 'K':
7847     case 'L':
7848     case 'M':
7849     case 'N':
7850     case 'Y':
7851     case 'Z':
7852       return C_Immediate;
7853     case 'z':
7854     case 'S': // A symbolic address
7855       return C_Other;
7856     }
7857   } else if (parsePredicateConstraint(Constraint) !=
7858              PredicateConstraint::Invalid)
7859       return C_RegisterClass;
7860   return TargetLowering::getConstraintType(Constraint);
7861 }
7862 
7863 /// Examine constraint type and operand type and determine a weight value.
7864 /// This object must already have been set up with the operand type
7865 /// and the current alternative constraint selected.
7866 TargetLowering::ConstraintWeight
7867 AArch64TargetLowering::getSingleConstraintMatchWeight(
7868     AsmOperandInfo &info, const char *constraint) const {
7869   ConstraintWeight weight = CW_Invalid;
7870   Value *CallOperandVal = info.CallOperandVal;
7871   // If we don't have a value, we can't do a match,
7872   // but allow it at the lowest weight.
7873   if (!CallOperandVal)
7874     return CW_Default;
7875   Type *type = CallOperandVal->getType();
7876   // Look at the constraint type.
7877   switch (*constraint) {
7878   default:
7879     weight = TargetLowering::getSingleConstraintMatchWeight(info, constraint);
7880     break;
7881   case 'x':
7882   case 'w':
7883   case 'y':
7884     if (type->isFloatingPointTy() || type->isVectorTy())
7885       weight = CW_Register;
7886     break;
7887   case 'z':
7888     weight = CW_Constant;
7889     break;
7890   case 'U':
7891     if (parsePredicateConstraint(constraint) != PredicateConstraint::Invalid)
7892       weight = CW_Register;
7893     break;
7894   }
7895   return weight;
7896 }
7897 
7898 std::pair<unsigned, const TargetRegisterClass *>
7899 AArch64TargetLowering::getRegForInlineAsmConstraint(
7900     const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const {
7901   if (Constraint.size() == 1) {
7902     switch (Constraint[0]) {
7903     case 'r':
7904       if (VT.isScalableVector())
7905         return std::make_pair(0U, nullptr);
7906       if (VT.getFixedSizeInBits() == 64)
7907         return std::make_pair(0U, &AArch64::GPR64commonRegClass);
7908       return std::make_pair(0U, &AArch64::GPR32commonRegClass);
7909     case 'w': {
7910       if (!Subtarget->hasFPARMv8())
7911         break;
7912       if (VT.isScalableVector()) {
7913         if (VT.getVectorElementType() != MVT::i1)
7914           return std::make_pair(0U, &AArch64::ZPRRegClass);
7915         return std::make_pair(0U, nullptr);
7916       }
7917       uint64_t VTSize = VT.getFixedSizeInBits();
7918       if (VTSize == 16)
7919         return std::make_pair(0U, &AArch64::FPR16RegClass);
7920       if (VTSize == 32)
7921         return std::make_pair(0U, &AArch64::FPR32RegClass);
7922       if (VTSize == 64)
7923         return std::make_pair(0U, &AArch64::FPR64RegClass);
7924       if (VTSize == 128)
7925         return std::make_pair(0U, &AArch64::FPR128RegClass);
7926       break;
7927     }
7928     // The instructions that this constraint is designed for can
7929     // only take 128-bit registers so just use that regclass.
7930     case 'x':
7931       if (!Subtarget->hasFPARMv8())
7932         break;
7933       if (VT.isScalableVector())
7934         return std::make_pair(0U, &AArch64::ZPR_4bRegClass);
7935       if (VT.getSizeInBits() == 128)
7936         return std::make_pair(0U, &AArch64::FPR128_loRegClass);
7937       break;
7938     case 'y':
7939       if (!Subtarget->hasFPARMv8())
7940         break;
7941       if (VT.isScalableVector())
7942         return std::make_pair(0U, &AArch64::ZPR_3bRegClass);
7943       break;
7944     }
7945   } else {
7946     PredicateConstraint PC = parsePredicateConstraint(Constraint);
7947     if (PC != PredicateConstraint::Invalid) {
7948       if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1)
7949         return std::make_pair(0U, nullptr);
7950       bool restricted = (PC == PredicateConstraint::Upl);
7951       return restricted ? std::make_pair(0U, &AArch64::PPR_3bRegClass)
7952                         : std::make_pair(0U, &AArch64::PPRRegClass);
7953     }
7954   }
7955   if (StringRef("{cc}").equals_lower(Constraint))
7956     return std::make_pair(unsigned(AArch64::NZCV), &AArch64::CCRRegClass);
7957 
7958   // Use the default implementation in TargetLowering to convert the register
7959   // constraint into a member of a register class.
7960   std::pair<unsigned, const TargetRegisterClass *> Res;
7961   Res = TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
7962 
7963   // Not found as a standard register?
7964   if (!Res.second) {
7965     unsigned Size = Constraint.size();
7966     if ((Size == 4 || Size == 5) && Constraint[0] == '{' &&
7967         tolower(Constraint[1]) == 'v' && Constraint[Size - 1] == '}') {
7968       int RegNo;
7969       bool Failed = Constraint.slice(2, Size - 1).getAsInteger(10, RegNo);
7970       if (!Failed && RegNo >= 0 && RegNo <= 31) {
7971         // v0 - v31 are aliases of q0 - q31 or d0 - d31 depending on size.
7972         // By default we'll emit v0-v31 for this unless there's a modifier where
7973         // we'll emit the correct register as well.
7974         if (VT != MVT::Other && VT.getSizeInBits() == 64) {
7975           Res.first = AArch64::FPR64RegClass.getRegister(RegNo);
7976           Res.second = &AArch64::FPR64RegClass;
7977         } else {
7978           Res.first = AArch64::FPR128RegClass.getRegister(RegNo);
7979           Res.second = &AArch64::FPR128RegClass;
7980         }
7981       }
7982     }
7983   }
7984 
7985   if (Res.second && !Subtarget->hasFPARMv8() &&
7986       !AArch64::GPR32allRegClass.hasSubClassEq(Res.second) &&
7987       !AArch64::GPR64allRegClass.hasSubClassEq(Res.second))
7988     return std::make_pair(0U, nullptr);
7989 
7990   return Res;
7991 }
7992 
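// Hedged usage sketch (illustration only): the Upa/Upl predicate constraints
// map SVE predicate operands in inline assembly, roughly as in
//   asm("ptrue %0.b" : "=Upa"(P));   // any of p0-p15
//   asm("ptrue %0.b" : "=Upl"(P));   // restricted to p0-p7
// where P would be an SVE predicate value such as an ACLE svbool_t.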
7993 /// LowerAsmOperandForConstraint - Lower the specified operand into the Ops
7994 /// vector.  If it is invalid, don't add anything to Ops.
7995 void AArch64TargetLowering::LowerAsmOperandForConstraint(
7996     SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
7997     SelectionDAG &DAG) const {
7998   SDValue Result;
7999 
8000   // Currently only support length 1 constraints.
8001   if (Constraint.length() != 1)
8002     return;
8003 
8004   char ConstraintLetter = Constraint[0];
8005   switch (ConstraintLetter) {
8006   default:
8007     break;
8008 
8009   // This set of constraints deal with valid constants for various instructions.
8010   // Validate and return a target constant for them if we can.
8011   case 'z': {
8012     // 'z' maps to xzr or wzr so it needs an input of 0.
8013     if (!isNullConstant(Op))
8014       return;
8015 
8016     if (Op.getValueType() == MVT::i64)
8017       Result = DAG.getRegister(AArch64::XZR, MVT::i64);
8018     else
8019       Result = DAG.getRegister(AArch64::WZR, MVT::i32);
8020     break;
8021   }
8022   case 'S': {
8023     // An absolute symbolic address or label reference.
8024     if (const GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Op)) {
8025       Result = DAG.getTargetGlobalAddress(GA->getGlobal(), SDLoc(Op),
8026                                           GA->getValueType(0));
8027     } else if (const BlockAddressSDNode *BA =
8028                    dyn_cast<BlockAddressSDNode>(Op)) {
8029       Result =
8030           DAG.getTargetBlockAddress(BA->getBlockAddress(), BA->getValueType(0));
8031     } else if (const ExternalSymbolSDNode *ES =
8032                    dyn_cast<ExternalSymbolSDNode>(Op)) {
8033       Result =
8034           DAG.getTargetExternalSymbol(ES->getSymbol(), ES->getValueType(0));
8035     } else
8036       return;
8037     break;
8038   }
8039 
8040   case 'I':
8041   case 'J':
8042   case 'K':
8043   case 'L':
8044   case 'M':
8045   case 'N':
8046     ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op);
8047     if (!C)
8048       return;
8049 
8050     // Grab the value and do some validation.
8051     uint64_t CVal = C->getZExtValue();
8052     switch (ConstraintLetter) {
8053     // The I constraint applies only to simple ADD or SUB immediate operands:
8054     // i.e. 0 to 4095 with optional shift by 12
8055     // The J constraint applies only to ADD or SUB immediates that would be
8056     // valid when negated, i.e. if [an add pattern] were to be output as a SUB
8057     // instruction [or vice versa], in other words -1 to -4095 with optional
8058     // left shift by 12.
8059     case 'I':
8060       if (isUInt<12>(CVal) || isShiftedUInt<12, 12>(CVal))
8061         break;
8062       return;
8063     case 'J': {
8064       uint64_t NVal = -C->getSExtValue();
8065       if (isUInt<12>(NVal) || isShiftedUInt<12, 12>(NVal)) {
8066         CVal = C->getSExtValue();
8067         break;
8068       }
8069       return;
8070     }
8071     // The K and L constraints apply *only* to logical immediates, including
8072     // what used to be the MOVI alias for ORR (though the MOVI alias has now
8073     // been removed and MOV should be used). So these constraints have to
8074     // distinguish between bit patterns that are valid 32-bit or 64-bit
8075     // "bitmask immediates": for example 0xaaaaaaaa is a valid bimm32 (K), but
8076     // not a valid bimm64 (L) where 0xaaaaaaaaaaaaaaaa would be valid, and vice
8077     // versa.
8078     case 'K':
8079       if (AArch64_AM::isLogicalImmediate(CVal, 32))
8080         break;
8081       return;
8082     case 'L':
8083       if (AArch64_AM::isLogicalImmediate(CVal, 64))
8084         break;
8085       return;
8086     // The M and N constraints are a superset of K and L respectively, for use
8087     // with the MOV (immediate) alias. As well as the logical immediates they
8088     // also match 32 or 64-bit immediates that can be loaded either using a
8089     // *single* MOVZ or MOVN, such as 32-bit 0x12340000, 0x00001234, 0xffffedca
8090     // (M) or 64-bit 0x1234000000000000 (N) etc.
8091     // As a note some of this code is liberally stolen from the asm parser.
8092     case 'M': {
8093       if (!isUInt<32>(CVal))
8094         return;
8095       if (AArch64_AM::isLogicalImmediate(CVal, 32))
8096         break;
8097       if ((CVal & 0xFFFF) == CVal)
8098         break;
8099       if ((CVal & 0xFFFF0000ULL) == CVal)
8100         break;
8101       uint64_t NCVal = ~(uint32_t)CVal;
8102       if ((NCVal & 0xFFFFULL) == NCVal)
8103         break;
8104       if ((NCVal & 0xFFFF0000ULL) == NCVal)
8105         break;
8106       return;
8107     }
8108     case 'N': {
8109       if (AArch64_AM::isLogicalImmediate(CVal, 64))
8110         break;
8111       if ((CVal & 0xFFFFULL) == CVal)
8112         break;
8113       if ((CVal & 0xFFFF0000ULL) == CVal)
8114         break;
8115       if ((CVal & 0xFFFF00000000ULL) == CVal)
8116         break;
8117       if ((CVal & 0xFFFF000000000000ULL) == CVal)
8118         break;
8119       uint64_t NCVal = ~CVal;
8120       if ((NCVal & 0xFFFFULL) == NCVal)
8121         break;
8122       if ((NCVal & 0xFFFF0000ULL) == NCVal)
8123         break;
8124       if ((NCVal & 0xFFFF00000000ULL) == NCVal)
8125         break;
8126       if ((NCVal & 0xFFFF000000000000ULL) == NCVal)
8127         break;
8128       return;
8129     }
8130     default:
8131       return;
8132     }
8133 
8134     // All assembler immediates are 64-bit integers.
8135     Result = DAG.getTargetConstant(CVal, SDLoc(Op), MVT::i64);
8136     break;
8137   }
8138 
8139   if (Result.getNode()) {
8140     Ops.push_back(Result);
8141     return;
8142   }
8143 
8144   return TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
8145 }
8146 
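// Small standalone sketch (illustrative, mirroring the 'I'/'J' checks above):
// an ADD/SUB immediate is valid when it fits in 12 bits, optionally shifted
// left by 12.
static inline bool IllustrativeIsAddSubImm(uint64_t V) {
  return (V >> 12) == 0 || ((V & 0xfffULL) == 0 && (V >> 24) == 0);
}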
8147 //===----------------------------------------------------------------------===//
8148 //                     AArch64 Advanced SIMD Support
8149 //===----------------------------------------------------------------------===//
8150 
8151 /// WidenVector - Given a value in the V64 register class, produce the
8152 /// equivalent value in the V128 register class.
8153 static SDValue WidenVector(SDValue V64Reg, SelectionDAG &DAG) {
8154   EVT VT = V64Reg.getValueType();
8155   unsigned NarrowSize = VT.getVectorNumElements();
8156   MVT EltTy = VT.getVectorElementType().getSimpleVT();
8157   MVT WideTy = MVT::getVectorVT(EltTy, 2 * NarrowSize);
8158   SDLoc DL(V64Reg);
8159 
8160   return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideTy, DAG.getUNDEF(WideTy),
8161                      V64Reg, DAG.getConstant(0, DL, MVT::i64));
8162 }
8163 
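// Worked example (illustration only): widening a v2i32 value produces a v4i32
// whose low half is the original value and whose high half is undef, i.e.
// INSERT_SUBVECTOR(undef:v4i32, V:v2i32, 0).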
8164 /// getExtFactor - Determine the adjustment factor for the position when
8165 /// generating an "extract from vector registers" instruction.
8166 static unsigned getExtFactor(SDValue &V) {
8167   EVT EltType = V.getValueType().getVectorElementType();
8168   return EltType.getSizeInBits() / 8;
8169 }
8170 
8171 /// NarrowVector - Given a value in the V128 register class, produce the
8172 /// equivalent value in the V64 register class.
8173 static SDValue NarrowVector(SDValue V128Reg, SelectionDAG &DAG) {
8174   EVT VT = V128Reg.getValueType();
8175   unsigned WideSize = VT.getVectorNumElements();
8176   MVT EltTy = VT.getVectorElementType().getSimpleVT();
8177   MVT NarrowTy = MVT::getVectorVT(EltTy, WideSize / 2);
8178   SDLoc DL(V128Reg);
8179 
8180   return DAG.getTargetExtractSubreg(AArch64::dsub, DL, NarrowTy, V128Reg);
8181 }
8182 
8183 // Gather data to see if the operation can be modelled as a
8184 // shuffle in combination with VEXTs.
8185 SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op,
8186                                                   SelectionDAG &DAG) const {
8187   assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unknown opcode!");
8188   LLVM_DEBUG(dbgs() << "AArch64TargetLowering::ReconstructShuffle\n");
8189   SDLoc dl(Op);
8190   EVT VT = Op.getValueType();
8191   assert(!VT.isScalableVector() &&
8192          "Scalable vectors cannot be used with ISD::BUILD_VECTOR");
8193   unsigned NumElts = VT.getVectorNumElements();
8194 
8195   struct ShuffleSourceInfo {
8196     SDValue Vec;
8197     unsigned MinElt;
8198     unsigned MaxElt;
8199 
8200     // We may insert some combination of BITCASTs and VEXT nodes to force Vec to
8201     // be compatible with the shuffle we intend to construct. As a result
8202     // ShuffleVec will be some sliding window into the original Vec.
8203     SDValue ShuffleVec;
8204 
8205     // Code should guarantee that element i in Vec starts at element "WindowBase
8206     // + i * WindowScale in ShuffleVec".
8207     int WindowBase;
8208     int WindowScale;
8209 
8210     ShuffleSourceInfo(SDValue Vec)
8211       : Vec(Vec), MinElt(std::numeric_limits<unsigned>::max()), MaxElt(0),
8212           ShuffleVec(Vec), WindowBase(0), WindowScale(1) {}
8213 
8214     bool operator ==(SDValue OtherVec) { return Vec == OtherVec; }
8215   };
8216 
8217   // First gather all vectors used as an immediate source for this BUILD_VECTOR
8218   // node.
8219   SmallVector<ShuffleSourceInfo, 2> Sources;
8220   for (unsigned i = 0; i < NumElts; ++i) {
8221     SDValue V = Op.getOperand(i);
8222     if (V.isUndef())
8223       continue;
8224     else if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
8225              !isa<ConstantSDNode>(V.getOperand(1))) {
8226       LLVM_DEBUG(
8227           dbgs() << "Reshuffle failed: "
8228                     "a shuffle can only come from building a vector from "
8229                     "various elements of other vectors, provided their "
8230                     "indices are constant\n");
8231       return SDValue();
8232     }
8233 
8234     // Add this element source to the list if it's not already there.
8235     SDValue SourceVec = V.getOperand(0);
8236     auto Source = find(Sources, SourceVec);
8237     if (Source == Sources.end())
8238       Source = Sources.insert(Sources.end(), ShuffleSourceInfo(SourceVec));
8239 
8240     // Update the minimum and maximum lane number seen.
8241     unsigned EltNo = cast<ConstantSDNode>(V.getOperand(1))->getZExtValue();
8242     Source->MinElt = std::min(Source->MinElt, EltNo);
8243     Source->MaxElt = std::max(Source->MaxElt, EltNo);
8244   }
8245 
8246   if (Sources.size() > 2) {
8247     LLVM_DEBUG(
8248         dbgs() << "Reshuffle failed: currently only do something sane when at "
8249                   "most two source vectors are involved\n");
8250     return SDValue();
8251   }
8252 
8253   // Find out the smallest element size among result and two sources, and use
8254   // it as element size to build the shuffle_vector.
8255   EVT SmallestEltTy = VT.getVectorElementType();
8256   for (auto &Source : Sources) {
8257     EVT SrcEltTy = Source.Vec.getValueType().getVectorElementType();
8258     if (SrcEltTy.bitsLT(SmallestEltTy)) {
8259       SmallestEltTy = SrcEltTy;
8260     }
8261   }
8262   unsigned ResMultiplier =
8263       VT.getScalarSizeInBits() / SmallestEltTy.getFixedSizeInBits();
8264   uint64_t VTSize = VT.getFixedSizeInBits();
8265   NumElts = VTSize / SmallestEltTy.getFixedSizeInBits();
8266   EVT ShuffleVT = EVT::getVectorVT(*DAG.getContext(), SmallestEltTy, NumElts);
8267 
8268   // If the source vector is too wide or too narrow, we may nevertheless be able
8269   // to construct a compatible shuffle either by concatenating it with UNDEF or
8270   // extracting a suitable range of elements.
8271   for (auto &Src : Sources) {
8272     EVT SrcVT = Src.ShuffleVec.getValueType();
8273 
8274     uint64_t SrcVTSize = SrcVT.getFixedSizeInBits();
8275     if (SrcVTSize == VTSize)
8276       continue;
8277 
8278     // This stage of the search produces a source with the same element type as
8279     // the original, but with a total width matching the BUILD_VECTOR output.
8280     EVT EltVT = SrcVT.getVectorElementType();
8281     unsigned NumSrcElts = VTSize / EltVT.getFixedSizeInBits();
8282     EVT DestVT = EVT::getVectorVT(*DAG.getContext(), EltVT, NumSrcElts);
8283 
8284     if (SrcVTSize < VTSize) {
8285       assert(2 * SrcVTSize == VTSize);
8286       // We can pad out the smaller vector for free, so if it's part of a
8287       // shuffle...
8288       Src.ShuffleVec =
8289           DAG.getNode(ISD::CONCAT_VECTORS, dl, DestVT, Src.ShuffleVec,
8290                       DAG.getUNDEF(Src.ShuffleVec.getValueType()));
8291       continue;
8292     }
8293 
8294     if (SrcVTSize != 2 * VTSize) {
8295       LLVM_DEBUG(
8296           dbgs() << "Reshuffle failed: result vector too small to extract\n");
8297       return SDValue();
8298     }
8299 
8300     if (Src.MaxElt - Src.MinElt >= NumSrcElts) {
8301       LLVM_DEBUG(
8302           dbgs() << "Reshuffle failed: span too large for a VEXT to cope\n");
8303       return SDValue();
8304     }
8305 
8306     if (Src.MinElt >= NumSrcElts) {
8307       // The extraction can just take the second half
8308       Src.ShuffleVec =
8309           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
8310                       DAG.getConstant(NumSrcElts, dl, MVT::i64));
8311       Src.WindowBase = -NumSrcElts;
8312     } else if (Src.MaxElt < NumSrcElts) {
8313       // The extraction can just take the first half
8314       Src.ShuffleVec =
8315           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
8316                       DAG.getConstant(0, dl, MVT::i64));
8317     } else {
8318       // An actual VEXT is needed
8319       SDValue VEXTSrc1 =
8320           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
8321                       DAG.getConstant(0, dl, MVT::i64));
8322       SDValue VEXTSrc2 =
8323           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
8324                       DAG.getConstant(NumSrcElts, dl, MVT::i64));
8325       unsigned Imm = Src.MinElt * getExtFactor(VEXTSrc1);
8326 
8327       if (!SrcVT.is64BitVector()) {
8328         LLVM_DEBUG(
8329           dbgs() << "Reshuffle failed: don't know how to lower AArch64ISD::EXT "
8330                     "for SVE vectors.");
8331         return SDValue();
8332       }
8333 
8334       Src.ShuffleVec = DAG.getNode(AArch64ISD::EXT, dl, DestVT, VEXTSrc1,
8335                                    VEXTSrc2,
8336                                    DAG.getConstant(Imm, dl, MVT::i32));
8337       Src.WindowBase = -Src.MinElt;
8338     }
8339   }
8340 
8341   // Another possible incompatibility occurs from the vector element types. We
8342   // can fix this by bitcasting the source vectors to the same type we intend
8343   // for the shuffle.
8344   for (auto &Src : Sources) {
8345     EVT SrcEltTy = Src.ShuffleVec.getValueType().getVectorElementType();
8346     if (SrcEltTy == SmallestEltTy)
8347       continue;
8348     assert(ShuffleVT.getVectorElementType() == SmallestEltTy);
8349     Src.ShuffleVec = DAG.getNode(ISD::BITCAST, dl, ShuffleVT, Src.ShuffleVec);
8350     Src.WindowScale =
8351         SrcEltTy.getFixedSizeInBits() / SmallestEltTy.getFixedSizeInBits();
8352     Src.WindowBase *= Src.WindowScale;
8353   }
8354 
8355   // Final sanity check before we try to actually produce a shuffle.
8356   LLVM_DEBUG(for (auto Src
8357                   : Sources)
8358                  assert(Src.ShuffleVec.getValueType() == ShuffleVT););
8359 
8360   // The stars all align, our next step is to produce the mask for the shuffle.
8361   SmallVector<int, 8> Mask(ShuffleVT.getVectorNumElements(), -1);
8362   int BitsPerShuffleLane = ShuffleVT.getScalarSizeInBits();
8363   for (unsigned i = 0; i < VT.getVectorNumElements(); ++i) {
8364     SDValue Entry = Op.getOperand(i);
8365     if (Entry.isUndef())
8366       continue;
8367 
8368     auto Src = find(Sources, Entry.getOperand(0));
8369     int EltNo = cast<ConstantSDNode>(Entry.getOperand(1))->getSExtValue();
8370 
8371     // EXTRACT_VECTOR_ELT performs an implicit any_ext; BUILD_VECTOR an implicit
8372     // trunc. So only std::min(SrcBits, DestBits) actually get defined in this
8373     // segment.
8374     EVT OrigEltTy = Entry.getOperand(0).getValueType().getVectorElementType();
8375     int BitsDefined = std::min(OrigEltTy.getScalarSizeInBits(),
8376                                VT.getScalarSizeInBits());
8377     int LanesDefined = BitsDefined / BitsPerShuffleLane;
8378 
8379     // This source is expected to fill ResMultiplier lanes of the final shuffle,
8380     // starting at the appropriate offset.
8381     int *LaneMask = &Mask[i * ResMultiplier];
8382 
8383     int ExtractBase = EltNo * Src->WindowScale + Src->WindowBase;
8384     ExtractBase += NumElts * (Src - Sources.begin());
8385     for (int j = 0; j < LanesDefined; ++j)
8386       LaneMask[j] = ExtractBase + j;
8387   }
8388 
8389   // Final check before we try to produce nonsense...
8390   if (!isShuffleMaskLegal(Mask, ShuffleVT)) {
8391     LLVM_DEBUG(dbgs() << "Reshuffle failed: illegal shuffle mask\n");
8392     return SDValue();
8393   }
8394 
8395   SDValue ShuffleOps[] = { DAG.getUNDEF(ShuffleVT), DAG.getUNDEF(ShuffleVT) };
8396   for (unsigned i = 0; i < Sources.size(); ++i)
8397     ShuffleOps[i] = Sources[i].ShuffleVec;
8398 
8399   SDValue Shuffle = DAG.getVectorShuffle(ShuffleVT, dl, ShuffleOps[0],
8400                                          ShuffleOps[1], Mask);
8401   SDValue V = DAG.getNode(ISD::BITCAST, dl, VT, Shuffle);
8402 
8403   LLVM_DEBUG(dbgs() << "Reshuffle, creating node: "; Shuffle.dump();
8404              dbgs() << "Reshuffle, creating node: "; V.dump(););
8405 
8406   return V;
8407 }
8408 
8409 // check if an EXT instruction can handle the shuffle mask when the
8410 // vector sources of the shuffle are the same.
8411 static bool isSingletonEXTMask(ArrayRef<int> M, EVT VT, unsigned &Imm) {
8412   unsigned NumElts = VT.getVectorNumElements();
8413 
8414   // Assume that the first shuffle index is not UNDEF.  Fail if it is.
8415   if (M[0] < 0)
8416     return false;
8417 
8418   Imm = M[0];
8419 
8420   // If this is a VEXT shuffle, the immediate value is the index of the first
8421   // element.  The other shuffle indices must be the successive elements after
8422   // the first one.
8423   unsigned ExpectedElt = Imm;
8424   for (unsigned i = 1; i < NumElts; ++i) {
8425     // Increment the expected index.  If it wraps around, just follow it
8426     // back to index zero and keep going.
8427     ++ExpectedElt;
8428     if (ExpectedElt == NumElts)
8429       ExpectedElt = 0;
8430 
8431     if (M[i] < 0)
8432       continue; // ignore UNDEF indices
8433     if (ExpectedElt != static_cast<unsigned>(M[i]))
8434       return false;
8435   }
8436 
8437   return true;
8438 }
8439 
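// Worked example (illustration only): for a single-source 4 x i32 shuffle,
// the mask <1,2,3,0> rotates the vector by one element, so Imm == 1 and the
// byte immediate of the EXT instruction is Imm * getExtFactor == 1 * 4 == 4.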
8440 /// Check if a vector shuffle corresponds to a DUP instructions with a larger
8441 /// element width than the vector lane type. If that is the case the function
8442 /// returns true and writes the value of the DUP instruction lane operand into
8443 /// DupLaneOp
8444 static bool isWideDUPMask(ArrayRef<int> M, EVT VT, unsigned BlockSize,
8445                           unsigned &DupLaneOp) {
8446   assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64) &&
8447          "Only possible block sizes for wide DUP are: 16, 32, 64");
8448 
8449   if (BlockSize <= VT.getScalarSizeInBits())
8450     return false;
8451   if (BlockSize % VT.getScalarSizeInBits() != 0)
8452     return false;
8453   if (VT.getSizeInBits() % BlockSize != 0)
8454     return false;
8455 
8456   size_t SingleVecNumElements = VT.getVectorNumElements();
8457   size_t NumEltsPerBlock = BlockSize / VT.getScalarSizeInBits();
8458   size_t NumBlocks = VT.getSizeInBits() / BlockSize;
8459 
8460   // We are looking for masks like
8461   // [0, 1, 0, 1] or [2, 3, 2, 3] or [4, 5, 6, 7, 4, 5, 6, 7] where any element
8462   // might be replaced by 'undefined'. BlockIndices will eventually contain
8463   // lane indices of the duplicated block (i.e. [0, 1], [2, 3] and [4, 5, 6, 7]
8464   // for the above examples)
8465   SmallVector<int, 8> BlockElts(NumEltsPerBlock, -1);
8466   for (size_t BlockIndex = 0; BlockIndex < NumBlocks; BlockIndex++)
8467     for (size_t I = 0; I < NumEltsPerBlock; I++) {
8468       int Elt = M[BlockIndex * NumEltsPerBlock + I];
8469       if (Elt < 0)
8470         continue;
8471       // For now we don't support shuffles that use the second operand
8472       if ((unsigned)Elt >= SingleVecNumElements)
8473         return false;
8474       if (BlockElts[I] < 0)
8475         BlockElts[I] = Elt;
8476       else if (BlockElts[I] != Elt)
8477         return false;
8478     }
8479 
8480   // We found a candidate block (possibly with some undefs). It must be a
8481   // sequence of consecutive integers starting with a value divisible by
8482   // NumEltsPerBlock with some values possibly replaced by undef-s.
8483 
8484   // Find first non-undef element
8485   auto FirstRealEltIter = find_if(BlockElts, [](int Elt) { return Elt >= 0; });
8486   assert(FirstRealEltIter != BlockElts.end() &&
8487          "Shuffle with all-undefs must have been caught by previous cases, "
8488          "e.g. isSplat()");
8489   if (FirstRealEltIter == BlockElts.end()) {
8490     DupLaneOp = 0;
8491     return true;
8492   }
8493 
8494   // Index of FirstRealElt in BlockElts
8495   size_t FirstRealIndex = FirstRealEltIter - BlockElts.begin();
8496 
8497   if ((unsigned)*FirstRealEltIter < FirstRealIndex)
8498     return false;
8499   // BlockElts[0] must have the following value if it isn't undef:
8500   size_t Elt0 = *FirstRealEltIter - FirstRealIndex;
8501 
8502   // Check the first element
8503   if (Elt0 % NumEltsPerBlock != 0)
8504     return false;
8505   // Check that the sequence indeed consists of consecutive integers (modulo
8506   // undefs)
8507   for (size_t I = 0; I < NumEltsPerBlock; I++)
8508     if (BlockElts[I] >= 0 && (unsigned)BlockElts[I] != Elt0 + I)
8509       return false;
8510 
8511   DupLaneOp = Elt0 / NumEltsPerBlock;
8512   return true;
8513 }
8514 
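// Worked example (illustration only): for a 4 x i32 shuffle with
// BlockSize == 64, the mask <2,3,2,3> duplicates the second 64-bit block:
// NumEltsPerBlock == 2, Elt0 == 2, so DupLaneOp == 1 (a DUP of lane 1 of the
// vector reinterpreted as 2 x i64).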
8515 // check if an EXT instruction can handle the shuffle mask when the
8516 // vector sources of the shuffle are different.
8517 static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT,
8518                       unsigned &Imm) {
8519   // Look for the first non-undef element.
8520   const int *FirstRealElt = find_if(M, [](int Elt) { return Elt >= 0; });
8521 
8522   // Benefit from APInt to handle overflow when calculating the expected element.
8523   unsigned NumElts = VT.getVectorNumElements();
8524   unsigned MaskBits = APInt(32, NumElts * 2).logBase2();
8525   APInt ExpectedElt = APInt(MaskBits, *FirstRealElt + 1);
8526   // The following shuffle indices must be the successive elements after the
8527   // first real element.
8528   const int *FirstWrongElt = std::find_if(FirstRealElt + 1, M.end(),
8529       [&](int Elt) {return Elt != ExpectedElt++ && Elt != -1;});
8530   if (FirstWrongElt != M.end())
8531     return false;
8532 
8533   // The index of an EXT is the first element if it is not UNDEF.
8534   // Watch out for the beginning UNDEFs. The EXT index should be the expected
8535   // value of the first element.  E.g.
8536   // <-1, -1, 3, ...> is treated as <1, 2, 3, ...>.
8537   // <-1, -1, 0, 1, ...> is treated as <2*NumElts-2, 2*NumElts-1, 0, 1, ...>.
8538   // ExpectedElt is the last mask index plus 1.
8539   Imm = ExpectedElt.getZExtValue();
8540 
8541   // There are two different cases that require reversing the input vectors.
8542   // For example, for vector <4 x i32> we have the following cases,
8543   // Case 1: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, -1, 0>)
8544   // Case 2: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, 7, 0>)
8545   // For both cases, we finally use mask <5, 6, 7, 0>, which requires
8546   // to reverse two input vectors.
8547   if (Imm < NumElts)
8548     ReverseEXT = true;
8549   else
8550     Imm -= NumElts;
8551 
8552   return true;
8553 }
8554 
8555 /// isREVMask - Check if a vector shuffle corresponds to a REV
8556 /// instruction with the specified blocksize.  (The order of the elements
8557 /// within each block of the vector is reversed.)
8558 static bool isREVMask(ArrayRef<int> M, EVT VT, unsigned BlockSize) {
8559   assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64) &&
8560          "Only possible block sizes for REV are: 16, 32, 64");
8561 
8562   unsigned EltSz = VT.getScalarSizeInBits();
8563   if (EltSz == 64)
8564     return false;
8565 
8566   unsigned NumElts = VT.getVectorNumElements();
8567   unsigned BlockElts = M[0] + 1;
8568   // If the first shuffle index is UNDEF, be optimistic.
8569   if (M[0] < 0)
8570     BlockElts = BlockSize / EltSz;
8571 
8572   if (BlockSize <= EltSz || BlockSize != BlockElts * EltSz)
8573     return false;
8574 
8575   for (unsigned i = 0; i < NumElts; ++i) {
8576     if (M[i] < 0)
8577       continue; // ignore UNDEF indices
8578     if ((unsigned)M[i] != (i - i % BlockElts) + (BlockElts - 1 - i % BlockElts))
8579       return false;
8580   }
8581 
8582   return true;
8583 }
8584 
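// Illustrative check (not part of LLVM): for an 8 x i8 vector with
// BlockSize == 32 each 4-byte block is reversed, so the REV32 mask is
// <3,2,1,0,7,6,5,4>. A standalone predicate over a plain mask array would be:
static inline bool IllustrativeIsRev32Mask8xi8(const int M[8]) {
  for (unsigned I = 0; I < 8; ++I)
    if (M[I] >= 0 && (unsigned)M[I] != (I - I % 4) + (3 - I % 4))
      return false; // mirrors the block-reversal test above for EltSz == 8
  return true;
}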
8585 static bool isZIPMask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
8586   unsigned NumElts = VT.getVectorNumElements();
8587   if (NumElts % 2 != 0)
8588     return false;
8589   WhichResult = (M[0] == 0 ? 0 : 1);
8590   unsigned Idx = WhichResult * NumElts / 2;
8591   for (unsigned i = 0; i != NumElts; i += 2) {
8592     if ((M[i] >= 0 && (unsigned)M[i] != Idx) ||
8593         (M[i + 1] >= 0 && (unsigned)M[i + 1] != Idx + NumElts))
8594       return false;
8595     Idx += 1;
8596   }
8597 
8598   return true;
8599 }
8600 
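/// isUZPMask - Check if a vector shuffle corresponds to a UZP1/UZP2
/// instruction, which concatenates the even (WhichResult == 0, UZP1) or odd
/// (WhichResult == 1, UZP2) elements of the two sources.  E.g. for v4i32,
/// UZP1 is <0, 2, 4, 6> and UZP2 is <1, 3, 5, 7>.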
8601 static bool isUZPMask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
8602   unsigned NumElts = VT.getVectorNumElements();
8603   WhichResult = (M[0] == 0 ? 0 : 1);
8604   for (unsigned i = 0; i != NumElts; ++i) {
8605     if (M[i] < 0)
8606       continue; // ignore UNDEF indices
8607     if ((unsigned)M[i] != 2 * i + WhichResult)
8608       return false;
8609   }
8610 
8611   return true;
8612 }
8613 
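/// isTRNMask - Check if a vector shuffle corresponds to a TRN1/TRN2
/// instruction, which transposes the even (WhichResult == 0, TRN1) or odd
/// (WhichResult == 1, TRN2) elements of each pair from the two sources.
/// E.g. for v4i32, TRN1 is <0, 4, 2, 6> and TRN2 is <1, 5, 3, 7>.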
8614 static bool isTRNMask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
8615   unsigned NumElts = VT.getVectorNumElements();
8616   if (NumElts % 2 != 0)
8617     return false;
8618   WhichResult = (M[0] == 0 ? 0 : 1);
8619   for (unsigned i = 0; i < NumElts; i += 2) {
8620     if ((M[i] >= 0 && (unsigned)M[i] != i + WhichResult) ||
8621         (M[i + 1] >= 0 && (unsigned)M[i + 1] != i + NumElts + WhichResult))
8622       return false;
8623   }
8624   return true;
8625 }
8626 
8627 /// isZIP_v_undef_Mask - Special case of isZIPMask for canonical form of
8628 /// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
8629 /// Mask is e.g., <0, 0, 1, 1> instead of <0, 4, 1, 5>.
8630 static bool isZIP_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
8631   unsigned NumElts = VT.getVectorNumElements();
8632   if (NumElts % 2 != 0)
8633     return false;
8634   WhichResult = (M[0] == 0 ? 0 : 1);
8635   unsigned Idx = WhichResult * NumElts / 2;
8636   for (unsigned i = 0; i != NumElts; i += 2) {
8637     if ((M[i] >= 0 && (unsigned)M[i] != Idx) ||
8638         (M[i + 1] >= 0 && (unsigned)M[i + 1] != Idx))
8639       return false;
8640     Idx += 1;
8641   }
8642 
8643   return true;
8644 }
8645 
8646 /// isUZP_v_undef_Mask - Special case of isUZPMask for canonical form of
8647 /// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
8648 /// Mask is e.g., <0, 2, 0, 2> instead of <0, 2, 4, 6>,
8649 static bool isUZP_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
8650   unsigned Half = VT.getVectorNumElements() / 2;
8651   WhichResult = (M[0] == 0 ? 0 : 1);
8652   for (unsigned j = 0; j != 2; ++j) {
8653     unsigned Idx = WhichResult;
8654     for (unsigned i = 0; i != Half; ++i) {
8655       int MIdx = M[i + j * Half];
8656       if (MIdx >= 0 && (unsigned)MIdx != Idx)
8657         return false;
8658       Idx += 2;
8659     }
8660   }
8661 
8662   return true;
8663 }
8664 
8665 /// isTRN_v_undef_Mask - Special case of isTRNMask for canonical form of
8666 /// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
8667 /// Mask is e.g., <0, 0, 2, 2> instead of <0, 4, 2, 6>.
8668 static bool isTRN_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
8669   unsigned NumElts = VT.getVectorNumElements();
8670   if (NumElts % 2 != 0)
8671     return false;
8672   WhichResult = (M[0] == 0 ? 0 : 1);
8673   for (unsigned i = 0; i < NumElts; i += 2) {
8674     if ((M[i] >= 0 && (unsigned)M[i] != i + WhichResult) ||
8675         (M[i + 1] >= 0 && (unsigned)M[i + 1] != i + WhichResult))
8676       return false;
8677   }
8678   return true;
8679 }
8680 
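/// isINSMask - Check whether the shuffle is (modulo undefs) an identity copy
/// of one source with exactly one lane replaced from either source.  DstIsLeft
/// reports which source is being copied and Anomaly is the index of the
/// replaced lane, so the whole shuffle can be lowered to a single lane insert.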
8681 static bool isINSMask(ArrayRef<int> M, int NumInputElements,
8682                       bool &DstIsLeft, int &Anomaly) {
8683   if (M.size() != static_cast<size_t>(NumInputElements))
8684     return false;
8685 
8686   int NumLHSMatch = 0, NumRHSMatch = 0;
8687   int LastLHSMismatch = -1, LastRHSMismatch = -1;
8688 
8689   for (int i = 0; i < NumInputElements; ++i) {
8690     if (M[i] == -1) {
8691       ++NumLHSMatch;
8692       ++NumRHSMatch;
8693       continue;
8694     }
8695 
8696     if (M[i] == i)
8697       ++NumLHSMatch;
8698     else
8699       LastLHSMismatch = i;
8700 
8701     if (M[i] == i + NumInputElements)
8702       ++NumRHSMatch;
8703     else
8704       LastRHSMismatch = i;
8705   }
8706 
8707   if (NumLHSMatch == NumInputElements - 1) {
8708     DstIsLeft = true;
8709     Anomaly = LastLHSMismatch;
8710     return true;
8711   } else if (NumRHSMatch == NumInputElements - 1) {
8712     DstIsLeft = false;
8713     Anomaly = LastRHSMismatch;
8714     return true;
8715   }
8716 
8717   return false;
8718 }
8719 
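/// isConcatMask - Check whether the shuffle is equivalent to concatenating two
/// subvectors of its inputs, e.g. <0, 1, 4, 5> for v4i32 when SplitLHS is true
/// (the low half of each source, in order).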
8720 static bool isConcatMask(ArrayRef<int> Mask, EVT VT, bool SplitLHS) {
8721   if (VT.getSizeInBits() != 128)
8722     return false;
8723 
8724   unsigned NumElts = VT.getVectorNumElements();
8725 
8726   for (int I = 0, E = NumElts / 2; I != E; I++) {
8727     if (Mask[I] != I)
8728       return false;
8729   }
8730 
8731   int Offset = NumElts / 2;
8732   for (int I = NumElts / 2, E = NumElts; I != E; I++) {
8733     if (Mask[I] != I + SplitLHS * Offset)
8734       return false;
8735   }
8736 
8737   return true;
8738 }
8739 
8740 static SDValue tryFormConcatFromShuffle(SDValue Op, SelectionDAG &DAG) {
8741   SDLoc DL(Op);
8742   EVT VT = Op.getValueType();
8743   SDValue V0 = Op.getOperand(0);
8744   SDValue V1 = Op.getOperand(1);
8745   ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(Op)->getMask();
8746 
8747   if (VT.getVectorElementType() != V0.getValueType().getVectorElementType() ||
8748       VT.getVectorElementType() != V1.getValueType().getVectorElementType())
8749     return SDValue();
8750 
8751   bool SplitV0 = V0.getValueSizeInBits() == 128;
8752 
8753   if (!isConcatMask(Mask, VT, SplitV0))
8754     return SDValue();
8755 
8756   EVT CastVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
8757   if (SplitV0) {
8758     V0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, CastVT, V0,
8759                      DAG.getConstant(0, DL, MVT::i64));
8760   }
8761   if (V1.getValueSizeInBits() == 128) {
8762     V1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, CastVT, V1,
8763                      DAG.getConstant(0, DL, MVT::i64));
8764   }
8765   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, V0, V1);
8766 }
8767 
8768 /// GeneratePerfectShuffle - Given an entry in the perfect-shuffle table, emit
8769 /// the specified operations to build the shuffle.
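/// Each PFEntry packs a cost in bits [31:30], an opcode in bits [29:26] and
/// two 13-bit operand shuffle IDs in bits [25:13] and [12:0]; the operands are
/// expanded recursively from the same table.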
8770 static SDValue GeneratePerfectShuffle(unsigned PFEntry, SDValue LHS,
8771                                       SDValue RHS, SelectionDAG &DAG,
8772                                       const SDLoc &dl) {
8773   unsigned OpNum = (PFEntry >> 26) & 0x0F;
8774   unsigned LHSID = (PFEntry >> 13) & ((1 << 13) - 1);
8775   unsigned RHSID = (PFEntry >> 0) & ((1 << 13) - 1);
8776 
8777   enum {
8778     OP_COPY = 0, // Copy, used for things like <u,u,u,3> to say it is <0,1,2,3>
8779     OP_VREV,
8780     OP_VDUP0,
8781     OP_VDUP1,
8782     OP_VDUP2,
8783     OP_VDUP3,
8784     OP_VEXT1,
8785     OP_VEXT2,
8786     OP_VEXT3,
8787     OP_VUZPL, // VUZP, left result
8788     OP_VUZPR, // VUZP, right result
8789     OP_VZIPL, // VZIP, left result
8790     OP_VZIPR, // VZIP, right result
8791     OP_VTRNL, // VTRN, left result
8792     OP_VTRNR  // VTRN, right result
8793   };
8794 
8795   if (OpNum == OP_COPY) {
8796     if (LHSID == (1 * 9 + 2) * 9 + 3)
8797       return LHS;
8798     assert(LHSID == ((4 * 9 + 5) * 9 + 6) * 9 + 7 && "Illegal OP_COPY!");
8799     return RHS;
8800   }
8801 
8802   SDValue OpLHS, OpRHS;
8803   OpLHS = GeneratePerfectShuffle(PerfectShuffleTable[LHSID], LHS, RHS, DAG, dl);
8804   OpRHS = GeneratePerfectShuffle(PerfectShuffleTable[RHSID], LHS, RHS, DAG, dl);
8805   EVT VT = OpLHS.getValueType();
8806 
8807   switch (OpNum) {
8808   default:
8809     llvm_unreachable("Unknown shuffle opcode!");
8810   case OP_VREV:
8811     // VREV divides the vector in half and swaps within the half.
8812     if (VT.getVectorElementType() == MVT::i32 ||
8813         VT.getVectorElementType() == MVT::f32)
8814       return DAG.getNode(AArch64ISD::REV64, dl, VT, OpLHS);
8815     // vrev <4 x i16> -> REV32
8816     if (VT.getVectorElementType() == MVT::i16 ||
8817         VT.getVectorElementType() == MVT::f16 ||
8818         VT.getVectorElementType() == MVT::bf16)
8819       return DAG.getNode(AArch64ISD::REV32, dl, VT, OpLHS);
8820     // vrev <4 x i8> -> REV16
8821     assert(VT.getVectorElementType() == MVT::i8);
8822     return DAG.getNode(AArch64ISD::REV16, dl, VT, OpLHS);
8823   case OP_VDUP0:
8824   case OP_VDUP1:
8825   case OP_VDUP2:
8826   case OP_VDUP3: {
8827     EVT EltTy = VT.getVectorElementType();
8828     unsigned Opcode;
8829     if (EltTy == MVT::i8)
8830       Opcode = AArch64ISD::DUPLANE8;
8831     else if (EltTy == MVT::i16 || EltTy == MVT::f16 || EltTy == MVT::bf16)
8832       Opcode = AArch64ISD::DUPLANE16;
8833     else if (EltTy == MVT::i32 || EltTy == MVT::f32)
8834       Opcode = AArch64ISD::DUPLANE32;
8835     else if (EltTy == MVT::i64 || EltTy == MVT::f64)
8836       Opcode = AArch64ISD::DUPLANE64;
8837     else
8838       llvm_unreachable("Invalid vector element type?");
8839 
8840     if (VT.getSizeInBits() == 64)
8841       OpLHS = WidenVector(OpLHS, DAG);
8842     SDValue Lane = DAG.getConstant(OpNum - OP_VDUP0, dl, MVT::i64);
8843     return DAG.getNode(Opcode, dl, VT, OpLHS, Lane);
8844   }
8845   case OP_VEXT1:
8846   case OP_VEXT2:
8847   case OP_VEXT3: {
8848     unsigned Imm = (OpNum - OP_VEXT1 + 1) * getExtFactor(OpLHS);
8849     return DAG.getNode(AArch64ISD::EXT, dl, VT, OpLHS, OpRHS,
8850                        DAG.getConstant(Imm, dl, MVT::i32));
8851   }
8852   case OP_VUZPL:
8853     return DAG.getNode(AArch64ISD::UZP1, dl, DAG.getVTList(VT, VT), OpLHS,
8854                        OpRHS);
8855   case OP_VUZPR:
8856     return DAG.getNode(AArch64ISD::UZP2, dl, DAG.getVTList(VT, VT), OpLHS,
8857                        OpRHS);
8858   case OP_VZIPL:
8859     return DAG.getNode(AArch64ISD::ZIP1, dl, DAG.getVTList(VT, VT), OpLHS,
8860                        OpRHS);
8861   case OP_VZIPR:
8862     return DAG.getNode(AArch64ISD::ZIP2, dl, DAG.getVTList(VT, VT), OpLHS,
8863                        OpRHS);
8864   case OP_VTRNL:
8865     return DAG.getNode(AArch64ISD::TRN1, dl, DAG.getVTList(VT, VT), OpLHS,
8866                        OpRHS);
8867   case OP_VTRNR:
8868     return DAG.getNode(AArch64ISD::TRN2, dl, DAG.getVTList(VT, VT), OpLHS,
8869                        OpRHS);
8870   }
8871 }
8872 
8873 static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask,
8874                            SelectionDAG &DAG) {
8875   // Check to see if we can use the TBL instruction.
8876   SDValue V1 = Op.getOperand(0);
8877   SDValue V2 = Op.getOperand(1);
8878   SDLoc DL(Op);
8879 
8880   EVT EltVT = Op.getValueType().getVectorElementType();
8881   unsigned BytesPerElt = EltVT.getSizeInBits() / 8;
8882 
8883   SmallVector<SDValue, 8> TBLMask;
8884   for (int Val : ShuffleMask) {
8885     for (unsigned Byte = 0; Byte < BytesPerElt; ++Byte) {
8886       unsigned Offset = Byte + Val * BytesPerElt;
8887       TBLMask.push_back(DAG.getConstant(Offset, DL, MVT::i32));
8888     }
8889   }
8890 
8891   MVT IndexVT = MVT::v8i8;
8892   unsigned IndexLen = 8;
8893   if (Op.getValueSizeInBits() == 128) {
8894     IndexVT = MVT::v16i8;
8895     IndexLen = 16;
8896   }
8897 
8898   SDValue V1Cst = DAG.getNode(ISD::BITCAST, DL, IndexVT, V1);
8899   SDValue V2Cst = DAG.getNode(ISD::BITCAST, DL, IndexVT, V2);
8900 
8901   SDValue Shuffle;
8902   if (V2.getNode()->isUndef()) {
8903     if (IndexLen == 8)
8904       V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V1Cst);
8905     Shuffle = DAG.getNode(
8906         ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
8907         DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst,
8908         DAG.getBuildVector(IndexVT, DL,
8909                            makeArrayRef(TBLMask.data(), IndexLen)));
8910   } else {
8911     if (IndexLen == 8) {
8912       V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V2Cst);
8913       Shuffle = DAG.getNode(
8914           ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
8915           DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst,
8916           DAG.getBuildVector(IndexVT, DL,
8917                              makeArrayRef(TBLMask.data(), IndexLen)));
8918     } else {
8919       // FIXME: We cannot, for the moment, emit a TBL2 instruction because we
8920       // cannot currently represent the register constraints on the input
8921       // table registers.
8922       //  Shuffle = DAG.getNode(AArch64ISD::TBL2, DL, IndexVT, V1Cst, V2Cst,
8923       //                   DAG.getBuildVector(IndexVT, DL, &TBLMask[0],
8924       //                   IndexLen));
8925       Shuffle = DAG.getNode(
8926           ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
8927           DAG.getConstant(Intrinsic::aarch64_neon_tbl2, DL, MVT::i32), V1Cst,
8928           V2Cst, DAG.getBuildVector(IndexVT, DL,
8929                                     makeArrayRef(TBLMask.data(), IndexLen)));
8930     }
8931   }
8932   return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Shuffle);
8933 }
8934 
8935 static unsigned getDUPLANEOp(EVT EltType) {
8936   if (EltType == MVT::i8)
8937     return AArch64ISD::DUPLANE8;
8938   if (EltType == MVT::i16 || EltType == MVT::f16 || EltType == MVT::bf16)
8939     return AArch64ISD::DUPLANE16;
8940   if (EltType == MVT::i32 || EltType == MVT::f32)
8941     return AArch64ISD::DUPLANE32;
8942   if (EltType == MVT::i64 || EltType == MVT::f64)
8943     return AArch64ISD::DUPLANE64;
8944 
8945   llvm_unreachable("Invalid vector element type?");
8946 }
8947 
8948 static SDValue constructDup(SDValue V, int Lane, SDLoc dl, EVT VT,
8949                             unsigned Opcode, SelectionDAG &DAG) {
8950   // Try to eliminate a bitcasted extract subvector before a DUPLANE.
8951   auto getScaledOffsetDup = [](SDValue BitCast, int &LaneC, MVT &CastVT) {
8952     // Match: dup (bitcast (extract_subv X, C)), LaneC
8953     if (BitCast.getOpcode() != ISD::BITCAST ||
8954         BitCast.getOperand(0).getOpcode() != ISD::EXTRACT_SUBVECTOR)
8955       return false;
8956 
8957     // The extract index must align in the destination type. That may not
8958     // happen if the bitcast is from a narrow to a wide type.
8959     SDValue Extract = BitCast.getOperand(0);
8960     unsigned ExtIdx = Extract.getConstantOperandVal(1);
8961     unsigned SrcEltBitWidth = Extract.getScalarValueSizeInBits();
8962     unsigned ExtIdxInBits = ExtIdx * SrcEltBitWidth;
8963     unsigned CastedEltBitWidth = BitCast.getScalarValueSizeInBits();
8964     if (ExtIdxInBits % CastedEltBitWidth != 0)
8965       return false;
8966 
8967     // Update the lane value by offsetting with the scaled extract index.
8968     LaneC += ExtIdxInBits / CastedEltBitWidth;
8969 
8970     // Determine the casted vector type of the wide vector input.
8971     // dup (bitcast (extract_subv X, C)), LaneC --> dup (bitcast X), LaneC'
8972     // Examples:
8973     // dup (bitcast (extract_subv v2f64 X, 1) to v2f32), 1 --> dup v4f32 X, 3
8974     // dup (bitcast (extract_subv v16i8 X, 8) to v4i16), 1 --> dup v8i16 X, 5
8975     unsigned SrcVecNumElts =
8976         Extract.getOperand(0).getValueSizeInBits() / CastedEltBitWidth;
8977     CastVT = MVT::getVectorVT(BitCast.getSimpleValueType().getScalarType(),
8978                               SrcVecNumElts);
8979     return true;
8980   };
8981   MVT CastVT;
8982   if (getScaledOffsetDup(V, Lane, CastVT)) {
8983     V = DAG.getBitcast(CastVT, V.getOperand(0).getOperand(0));
8984   } else if (V.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
8985     // The lane is incremented by the index of the extract.
8986     // Example: dup v2f32 (extract v4f32 X, 2), 1 --> dup v4f32 X, 3
8987     Lane += V.getConstantOperandVal(1);
8988     V = V.getOperand(0);
8989   } else if (V.getOpcode() == ISD::CONCAT_VECTORS) {
8990     // The lane is decremented if we are splatting from the 2nd operand.
8991     // Example: dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1
8992     unsigned Idx = Lane >= (int)VT.getVectorNumElements() / 2;
8993     Lane -= Idx * VT.getVectorNumElements() / 2;
8994     V = WidenVector(V.getOperand(Idx), DAG);
8995   } else if (VT.getSizeInBits() == 64) {
8996     // Widen the operand to 128-bit register with undef.
8997     V = WidenVector(V, DAG);
8998   }
8999   return DAG.getNode(Opcode, dl, VT, V, DAG.getConstant(Lane, dl, MVT::i64));
9000 }
9001 
9002 SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
9003                                                    SelectionDAG &DAG) const {
9004   SDLoc dl(Op);
9005   EVT VT = Op.getValueType();
9006 
9007   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
9008 
9009   // Convert shuffles that are directly supported on NEON to target-specific
9010   // DAG nodes, instead of keeping them as shuffles and matching them again
9011   // during code selection.  This is more efficient and avoids the possibility
9012   // of inconsistencies between legalization and selection.
9013   ArrayRef<int> ShuffleMask = SVN->getMask();
9014 
9015   SDValue V1 = Op.getOperand(0);
9016   SDValue V2 = Op.getOperand(1);
9017 
9018   assert(V1.getValueType() == VT && "Unexpected VECTOR_SHUFFLE type!");
9019   assert(ShuffleMask.size() == VT.getVectorNumElements() &&
9020          "Unexpected VECTOR_SHUFFLE mask size!");
9021 
9022   if (SVN->isSplat()) {
9023     int Lane = SVN->getSplatIndex();
9024     // If this is an undef splat, generate it via "just" vdup, if possible.
9025     if (Lane == -1)
9026       Lane = 0;
9027 
9028     if (Lane == 0 && V1.getOpcode() == ISD::SCALAR_TO_VECTOR)
9029       return DAG.getNode(AArch64ISD::DUP, dl, V1.getValueType(),
9030                          V1.getOperand(0));
9031     // Test if V1 is a BUILD_VECTOR and the lane being referenced is a non-
9032     // constant. If so, we can just reference the lane's definition directly.
9033     if (V1.getOpcode() == ISD::BUILD_VECTOR &&
9034         !isa<ConstantSDNode>(V1.getOperand(Lane)))
9035       return DAG.getNode(AArch64ISD::DUP, dl, VT, V1.getOperand(Lane));
9036 
9037     // Otherwise, duplicate from the lane of the input vector.
9038     unsigned Opcode = getDUPLANEOp(V1.getValueType().getVectorElementType());
9039     return constructDup(V1, Lane, dl, VT, Opcode, DAG);
9040   }
9041 
9042   // Check if the mask matches a DUP for a wider element
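  // (e.g. the v8i16 mask <0, 1, 0, 1, 0, 1, 0, 1> can be lowered as a
  // DUPLANE32 of lane 0 after bitcasting to v4i32).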
9043   for (unsigned LaneSize : {64U, 32U, 16U}) {
9044     unsigned Lane = 0;
9045     if (isWideDUPMask(ShuffleMask, VT, LaneSize, Lane)) {
9046       unsigned Opcode = LaneSize == 64 ? AArch64ISD::DUPLANE64
9047                                        : LaneSize == 32 ? AArch64ISD::DUPLANE32
9048                                                         : AArch64ISD::DUPLANE16;
9049       // Cast V1 to an integer vector with the required lane size.
9050       MVT NewEltTy = MVT::getIntegerVT(LaneSize);
9051       unsigned NewEltCount = VT.getSizeInBits() / LaneSize;
9052       MVT NewVecTy = MVT::getVectorVT(NewEltTy, NewEltCount);
9053       V1 = DAG.getBitcast(NewVecTy, V1);
9054       // Construct the DUP instruction.
9055       V1 = constructDup(V1, Lane, dl, NewVecTy, Opcode, DAG);
9056       // Cast back to the original type
9057       return DAG.getBitcast(VT, V1);
9058     }
9059   }
9060 
9061   if (isREVMask(ShuffleMask, VT, 64))
9062     return DAG.getNode(AArch64ISD::REV64, dl, V1.getValueType(), V1, V2);
9063   if (isREVMask(ShuffleMask, VT, 32))
9064     return DAG.getNode(AArch64ISD::REV32, dl, V1.getValueType(), V1, V2);
9065   if (isREVMask(ShuffleMask, VT, 16))
9066     return DAG.getNode(AArch64ISD::REV16, dl, V1.getValueType(), V1, V2);
9067 
9068   if (((VT.getVectorNumElements() == 8 && VT.getScalarSizeInBits() == 16) ||
9069        (VT.getVectorNumElements() == 16 && VT.getScalarSizeInBits() == 8)) &&
9070       ShuffleVectorInst::isReverseMask(ShuffleMask)) {
9071     SDValue Rev = DAG.getNode(AArch64ISD::REV64, dl, VT, V1);
9072     return DAG.getNode(AArch64ISD::EXT, dl, VT, Rev, Rev,
9073                        DAG.getConstant(8, dl, MVT::i32));
9074   }
9075 
9076   bool ReverseEXT = false;
9077   unsigned Imm;
9078   if (isEXTMask(ShuffleMask, VT, ReverseEXT, Imm)) {
9079     if (ReverseEXT)
9080       std::swap(V1, V2);
9081     Imm *= getExtFactor(V1);
9082     return DAG.getNode(AArch64ISD::EXT, dl, V1.getValueType(), V1, V2,
9083                        DAG.getConstant(Imm, dl, MVT::i32));
9084   } else if (V2->isUndef() && isSingletonEXTMask(ShuffleMask, VT, Imm)) {
9085     Imm *= getExtFactor(V1);
9086     return DAG.getNode(AArch64ISD::EXT, dl, V1.getValueType(), V1, V1,
9087                        DAG.getConstant(Imm, dl, MVT::i32));
9088   }
9089 
9090   unsigned WhichResult;
9091   if (isZIPMask(ShuffleMask, VT, WhichResult)) {
9092     unsigned Opc = (WhichResult == 0) ? AArch64ISD::ZIP1 : AArch64ISD::ZIP2;
9093     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
9094   }
9095   if (isUZPMask(ShuffleMask, VT, WhichResult)) {
9096     unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
9097     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
9098   }
9099   if (isTRNMask(ShuffleMask, VT, WhichResult)) {
9100     unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
9101     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
9102   }
9103 
9104   if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
9105     unsigned Opc = (WhichResult == 0) ? AArch64ISD::ZIP1 : AArch64ISD::ZIP2;
9106     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V1);
9107   }
9108   if (isUZP_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
9109     unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
9110     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V1);
9111   }
9112   if (isTRN_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
9113     unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
9114     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V1);
9115   }
9116 
9117   if (SDValue Concat = tryFormConcatFromShuffle(Op, DAG))
9118     return Concat;
9119 
9120   bool DstIsLeft;
9121   int Anomaly;
9122   int NumInputElements = V1.getValueType().getVectorNumElements();
9123   if (isINSMask(ShuffleMask, NumInputElements, DstIsLeft, Anomaly)) {
9124     SDValue DstVec = DstIsLeft ? V1 : V2;
9125     SDValue DstLaneV = DAG.getConstant(Anomaly, dl, MVT::i64);
9126 
9127     SDValue SrcVec = V1;
9128     int SrcLane = ShuffleMask[Anomaly];
9129     if (SrcLane >= NumInputElements) {
9130       SrcVec = V2;
9131       SrcLane -= VT.getVectorNumElements();
9132     }
9133     SDValue SrcLaneV = DAG.getConstant(SrcLane, dl, MVT::i64);
9134 
9135     EVT ScalarVT = VT.getVectorElementType();
9136 
9137     if (ScalarVT.getFixedSizeInBits() < 32 && ScalarVT.isInteger())
9138       ScalarVT = MVT::i32;
9139 
9140     return DAG.getNode(
9141         ISD::INSERT_VECTOR_ELT, dl, VT, DstVec,
9142         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, SrcVec, SrcLaneV),
9143         DstLaneV);
9144   }
9145 
9146   // If the shuffle is not directly supported and it has 4 elements, use
9147   // the PerfectShuffle-generated table to synthesize it from other shuffles.
9148   unsigned NumElts = VT.getVectorNumElements();
9149   if (NumElts == 4) {
9150     unsigned PFIndexes[4];
9151     for (unsigned i = 0; i != 4; ++i) {
9152       if (ShuffleMask[i] < 0)
9153         PFIndexes[i] = 8;
9154       else
9155         PFIndexes[i] = ShuffleMask[i];
9156     }
9157 
9158     // Compute the index in the perfect shuffle table.
9159     unsigned PFTableIndex = PFIndexes[0] * 9 * 9 * 9 + PFIndexes[1] * 9 * 9 +
9160                             PFIndexes[2] * 9 + PFIndexes[3];
9161     unsigned PFEntry = PerfectShuffleTable[PFTableIndex];
9162     unsigned Cost = (PFEntry >> 30);
9163 
9164     if (Cost <= 4)
9165       return GeneratePerfectShuffle(PFEntry, V1, V2, DAG, dl);
9166   }
9167 
9168   return GenerateTBL(Op, ShuffleMask, DAG);
9169 }
9170 
9171 SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op,
9172                                                  SelectionDAG &DAG) const {
9173   SDLoc dl(Op);
9174   EVT VT = Op.getValueType();
9175   EVT ElemVT = VT.getScalarType();
9176   SDValue SplatVal = Op.getOperand(0);
9177 
9178   if (useSVEForFixedLengthVectorVT(VT))
9179     return LowerToScalableOp(Op, DAG);
9180 
9181   // Extend input splat value where needed to fit into a GPR (32b or 64b only)
9182   // FPRs don't have this restriction.
9183   switch (ElemVT.getSimpleVT().SimpleTy) {
9184   case MVT::i1: {
9185     // The only legal i1 vectors are SVE vectors, so we can use SVE-specific
9186     // lowering code.
9187     if (auto *ConstVal = dyn_cast<ConstantSDNode>(SplatVal)) {
9188       if (ConstVal->isOne())
9189         return getPTrue(DAG, dl, VT, AArch64SVEPredPattern::all);
9190       // TODO: Add special case for constant false
9191     }
9192     // The general case of i1.  There isn't any natural way to do this,
9193     // so we use some trickery with whilelo.
9194     SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64);
9195     SplatVal = DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, MVT::i64, SplatVal,
9196                            DAG.getValueType(MVT::i1));
9197     SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl,
9198                                        MVT::i64);
9199     return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID,
9200                        DAG.getConstant(0, dl, MVT::i64), SplatVal);
9201   }
9202   case MVT::i8:
9203   case MVT::i16:
9204   case MVT::i32:
9205     SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32);
9206     break;
9207   case MVT::i64:
9208     SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64);
9209     break;
9210   case MVT::f16:
9211   case MVT::bf16:
9212   case MVT::f32:
9213   case MVT::f64:
9214     // Fine as is
9215     break;
9216   default:
9217     report_fatal_error("Unsupported SPLAT_VECTOR input operand type");
9218   }
9219 
9220   return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
9221 }
9222 
9223 SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op,
9224                                              SelectionDAG &DAG) const {
9225   SDLoc DL(Op);
9226 
9227   EVT VT = Op.getValueType();
9228   if (!isTypeLegal(VT) || !VT.isScalableVector())
9229     return SDValue();
9230 
9231   // Current lowering only supports the SVE-ACLE types.
9232   if (VT.getSizeInBits().getKnownMinSize() != AArch64::SVEBitsPerBlock)
9233     return SDValue();
9234 
9235   // The DUPQ operation is independent of the element type, so normalise to i64s.
9236   SDValue V = DAG.getNode(ISD::BITCAST, DL, MVT::nxv2i64, Op.getOperand(1));
9237   SDValue Idx128 = Op.getOperand(2);
9238 
9239   // DUPQ can be used when idx is in range.
9240   auto *CIdx = dyn_cast<ConstantSDNode>(Idx128);
9241   if (CIdx && (CIdx->getZExtValue() <= 3)) {
9242     SDValue CI = DAG.getTargetConstant(CIdx->getZExtValue(), DL, MVT::i64);
9243     SDNode *DUPQ =
9244         DAG.getMachineNode(AArch64::DUP_ZZI_Q, DL, MVT::nxv2i64, V, CI);
9245     return DAG.getNode(ISD::BITCAST, DL, VT, SDValue(DUPQ, 0));
9246   }
9247 
9248   // The ACLE says this must produce the same result as:
9249   //   svtbl(data, svadd_x(svptrue_b64(),
9250   //                       svand_x(svptrue_b64(), svindex_u64(0, 1), 1),
9251   //                       index * 2))
9252   SDValue One = DAG.getConstant(1, DL, MVT::i64);
9253   SDValue SplatOne = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, One);
9254 
9255   // create the vector 0,1,0,1,...
9256   SDValue SV = DAG.getNode(ISD::STEP_VECTOR, DL, MVT::nxv2i64, One);
9257   SV = DAG.getNode(ISD::AND, DL, MVT::nxv2i64, SV, SplatOne);
9258 
9259   // create the vector idx64,idx64+1,idx64,idx64+1,...
9260   SDValue Idx64 = DAG.getNode(ISD::ADD, DL, MVT::i64, Idx128, Idx128);
9261   SDValue SplatIdx64 = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, Idx64);
9262   SDValue ShuffleMask = DAG.getNode(ISD::ADD, DL, MVT::nxv2i64, SV, SplatIdx64);
9263 
9264   // create the vector Val[idx64],Val[idx64+1],Val[idx64],Val[idx64+1],...
9265   SDValue TBL = DAG.getNode(AArch64ISD::TBL, DL, MVT::nxv2i64, V, ShuffleMask);
9266   return DAG.getNode(ISD::BITCAST, DL, VT, TBL);
9267 }
9268 
9269 
9270 static bool resolveBuildVector(BuildVectorSDNode *BVN, APInt &CnstBits,
9271                                APInt &UndefBits) {
9272   EVT VT = BVN->getValueType(0);
9273   APInt SplatBits, SplatUndef;
9274   unsigned SplatBitSize;
9275   bool HasAnyUndefs;
9276   if (BVN->isConstantSplat(SplatBits, SplatUndef, SplatBitSize, HasAnyUndefs)) {
9277     unsigned NumSplats = VT.getSizeInBits() / SplatBitSize;
9278 
9279     for (unsigned i = 0; i < NumSplats; ++i) {
9280       CnstBits <<= SplatBitSize;
9281       UndefBits <<= SplatBitSize;
9282       CnstBits |= SplatBits.zextOrTrunc(VT.getSizeInBits());
9283       UndefBits |= (SplatBits ^ SplatUndef).zextOrTrunc(VT.getSizeInBits());
9284     }
9285 
9286     return true;
9287   }
9288 
9289   return false;
9290 }
9291 
9292 // Try 64-bit splatted SIMD immediate.
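// (isAdvSIMDModImmType10 corresponds to the MOVI 64-bit variant, where each
// bit of an 8-bit immediate is expanded to a full byte of the value.)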
9293 static SDValue tryAdvSIMDModImm64(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
9294                                  const APInt &Bits) {
9295   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
9296     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
9297     EVT VT = Op.getValueType();
9298     MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v2i64 : MVT::f64;
9299 
9300     if (AArch64_AM::isAdvSIMDModImmType10(Value)) {
9301       Value = AArch64_AM::encodeAdvSIMDModImmType10(Value);
9302 
9303       SDLoc dl(Op);
9304       SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
9305                                 DAG.getConstant(Value, dl, MVT::i32));
9306       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
9307     }
9308   }
9309 
9310   return SDValue();
9311 }
9312 
9313 // Try 32-bit splatted SIMD immediate.
9314 static SDValue tryAdvSIMDModImm32(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
9315                                   const APInt &Bits,
9316                                   const SDValue *LHS = nullptr) {
9317   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
9318     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
9319     EVT VT = Op.getValueType();
9320     MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v4i32 : MVT::v2i32;
9321     bool isAdvSIMDModImm = false;
9322     uint64_t Shift;
9323 
9324     if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType1(Value))) {
9325       Value = AArch64_AM::encodeAdvSIMDModImmType1(Value);
9326       Shift = 0;
9327     }
9328     else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType2(Value))) {
9329       Value = AArch64_AM::encodeAdvSIMDModImmType2(Value);
9330       Shift = 8;
9331     }
9332     else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType3(Value))) {
9333       Value = AArch64_AM::encodeAdvSIMDModImmType3(Value);
9334       Shift = 16;
9335     }
9336     else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType4(Value))) {
9337       Value = AArch64_AM::encodeAdvSIMDModImmType4(Value);
9338       Shift = 24;
9339     }
9340 
9341     if (isAdvSIMDModImm) {
9342       SDLoc dl(Op);
9343       SDValue Mov;
9344 
9345       if (LHS)
9346         Mov = DAG.getNode(NewOp, dl, MovTy, *LHS,
9347                           DAG.getConstant(Value, dl, MVT::i32),
9348                           DAG.getConstant(Shift, dl, MVT::i32));
9349       else
9350         Mov = DAG.getNode(NewOp, dl, MovTy,
9351                           DAG.getConstant(Value, dl, MVT::i32),
9352                           DAG.getConstant(Shift, dl, MVT::i32));
9353 
9354       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
9355     }
9356   }
9357 
9358   return SDValue();
9359 }
9360 
9361 // Try 16-bit splatted SIMD immediate.
9362 static SDValue tryAdvSIMDModImm16(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
9363                                   const APInt &Bits,
9364                                   const SDValue *LHS = nullptr) {
9365   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
9366     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
9367     EVT VT = Op.getValueType();
9368     MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v8i16 : MVT::v4i16;
9369     bool isAdvSIMDModImm = false;
9370     uint64_t Shift;
9371 
9372     if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType5(Value))) {
9373       Value = AArch64_AM::encodeAdvSIMDModImmType5(Value);
9374       Shift = 0;
9375     }
9376     else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType6(Value))) {
9377       Value = AArch64_AM::encodeAdvSIMDModImmType6(Value);
9378       Shift = 8;
9379     }
9380 
9381     if (isAdvSIMDModImm) {
9382       SDLoc dl(Op);
9383       SDValue Mov;
9384 
9385       if (LHS)
9386         Mov = DAG.getNode(NewOp, dl, MovTy, *LHS,
9387                           DAG.getConstant(Value, dl, MVT::i32),
9388                           DAG.getConstant(Shift, dl, MVT::i32));
9389       else
9390         Mov = DAG.getNode(NewOp, dl, MovTy,
9391                           DAG.getConstant(Value, dl, MVT::i32),
9392                           DAG.getConstant(Shift, dl, MVT::i32));
9393 
9394       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
9395     }
9396   }
9397 
9398   return SDValue();
9399 }
9400 
9401 // Try 32-bit splatted SIMD immediate with shifted ones.
9402 static SDValue tryAdvSIMDModImm321s(unsigned NewOp, SDValue Op,
9403                                     SelectionDAG &DAG, const APInt &Bits) {
9404   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
9405     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
9406     EVT VT = Op.getValueType();
9407     MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v4i32 : MVT::v2i32;
9408     bool isAdvSIMDModImm = false;
9409     uint64_t Shift;
9410 
9411     if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType7(Value))) {
9412       Value = AArch64_AM::encodeAdvSIMDModImmType7(Value);
9413       Shift = 264;
9414     }
9415     else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType8(Value))) {
9416       Value = AArch64_AM::encodeAdvSIMDModImmType8(Value);
9417       Shift = 272;
9418     }
9419 
9420     if (isAdvSIMDModImm) {
9421       SDLoc dl(Op);
9422       SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
9423                                 DAG.getConstant(Value, dl, MVT::i32),
9424                                 DAG.getConstant(Shift, dl, MVT::i32));
9425       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
9426     }
9427   }
9428 
9429   return SDValue();
9430 }
9431 
9432 // Try 8-bit splatted SIMD immediate.
9433 static SDValue tryAdvSIMDModImm8(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
9434                                  const APInt &Bits) {
9435   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
9436     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
9437     EVT VT = Op.getValueType();
9438     MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v16i8 : MVT::v8i8;
9439 
9440     if (AArch64_AM::isAdvSIMDModImmType9(Value)) {
9441       Value = AArch64_AM::encodeAdvSIMDModImmType9(Value);
9442 
9443       SDLoc dl(Op);
9444       SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
9445                                 DAG.getConstant(Value, dl, MVT::i32));
9446       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
9447     }
9448   }
9449 
9450   return SDValue();
9451 }
9452 
9453 // Try FP splatted SIMD immediate.
9454 static SDValue tryAdvSIMDModImmFP(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
9455                                   const APInt &Bits) {
9456   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
9457     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
9458     EVT VT = Op.getValueType();
9459     bool isWide = (VT.getSizeInBits() == 128);
9460     MVT MovTy;
9461     bool isAdvSIMDModImm = false;
9462 
9463     if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType11(Value))) {
9464       Value = AArch64_AM::encodeAdvSIMDModImmType11(Value);
9465       MovTy = isWide ? MVT::v4f32 : MVT::v2f32;
9466     }
9467     else if (isWide &&
9468              (isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType12(Value))) {
9469       Value = AArch64_AM::encodeAdvSIMDModImmType12(Value);
9470       MovTy = MVT::v2f64;
9471     }
9472 
9473     if (isAdvSIMDModImm) {
9474       SDLoc dl(Op);
9475       SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
9476                                 DAG.getConstant(Value, dl, MVT::i32));
9477       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
9478     }
9479   }
9480 
9481   return SDValue();
9482 }
9483 
9484 // Specialized code to quickly find if PotentialBVec is a BuildVector that
9485 // consists of only the same constant int value, which is returned in the
9486 // reference argument ConstVal.
9487 static bool isAllConstantBuildVector(const SDValue &PotentialBVec,
9488                                      uint64_t &ConstVal) {
9489   BuildVectorSDNode *Bvec = dyn_cast<BuildVectorSDNode>(PotentialBVec);
9490   if (!Bvec)
9491     return false;
9492   ConstantSDNode *FirstElt = dyn_cast<ConstantSDNode>(Bvec->getOperand(0));
9493   if (!FirstElt)
9494     return false;
9495   EVT VT = Bvec->getValueType(0);
9496   unsigned NumElts = VT.getVectorNumElements();
9497   for (unsigned i = 1; i < NumElts; ++i)
9498     if (dyn_cast<ConstantSDNode>(Bvec->getOperand(i)) != FirstElt)
9499       return false;
9500   ConstVal = FirstElt->getZExtValue();
9501   return true;
9502 }
9503 
9504 static unsigned getIntrinsicID(const SDNode *N) {
9505   unsigned Opcode = N->getOpcode();
9506   switch (Opcode) {
9507   default:
9508     return Intrinsic::not_intrinsic;
9509   case ISD::INTRINSIC_WO_CHAIN: {
9510     unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
9511     if (IID < Intrinsic::num_intrinsics)
9512       return IID;
9513     return Intrinsic::not_intrinsic;
9514   }
9515   }
9516 }
9517 
9518 // Attempt to form a vector S[LR]I from (or (and X, BvecC1), (lsl Y, C2)),
9519 // to (SLI X, Y, C2), where X and Y have matching vector types, BvecC1 is a
9520 // BUILD_VECTORs with constant element C1, C2 is a constant, and:
9521 //   - for the SLI case: C1 == ~(Ones(ElemSizeInBits) << C2)
9522 //   - for the SRI case: C1 == ~(Ones(ElemSizeInBits) >> C2)
9523 // The (or (lsl Y, C2), (and X, BvecC1)) case is also handled.
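// For example, with 32-bit elements and C2 == 8:
//   (or (and X, <0xff,...>), (AArch64ISD::VSHL Y, 8)) --> (VSLI X, Y, 8)
// since SLI keeps the low C2 bits of the destination (X) and inserts Y << C2
// above them.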
9524 static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
9525   EVT VT = N->getValueType(0);
9526 
9527   if (!VT.isVector())
9528     return SDValue();
9529 
9530   SDLoc DL(N);
9531 
9532   SDValue And;
9533   SDValue Shift;
9534 
9535   SDValue FirstOp = N->getOperand(0);
9536   unsigned FirstOpc = FirstOp.getOpcode();
9537   SDValue SecondOp = N->getOperand(1);
9538   unsigned SecondOpc = SecondOp.getOpcode();
9539 
9540   // Is one of the operands an AND or a BICi? The AND may have been optimised to
9541   // a BICi in order to use an immediate instead of a register.
9542   // Is the other operand a shl or lshr? This will have been turned into:
9543   // AArch64ISD::VSHL vector, #shift or AArch64ISD::VLSHR vector, #shift.
9544   if ((FirstOpc == ISD::AND || FirstOpc == AArch64ISD::BICi) &&
9545       (SecondOpc == AArch64ISD::VSHL || SecondOpc == AArch64ISD::VLSHR)) {
9546     And = FirstOp;
9547     Shift = SecondOp;
9548 
9549   } else if ((SecondOpc == ISD::AND || SecondOpc == AArch64ISD::BICi) &&
9550              (FirstOpc == AArch64ISD::VSHL || FirstOpc == AArch64ISD::VLSHR)) {
9551     And = SecondOp;
9552     Shift = FirstOp;
9553   } else
9554     return SDValue();
9555 
9556   bool IsAnd = And.getOpcode() == ISD::AND;
9557   bool IsShiftRight = Shift.getOpcode() == AArch64ISD::VLSHR;
9558 
9559   // Is the shift amount constant?
9560   ConstantSDNode *C2node = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
9561   if (!C2node)
9562     return SDValue();
9563 
9564   uint64_t C1;
9565   if (IsAnd) {
9566     // Is the and mask vector all constant?
9567     if (!isAllConstantBuildVector(And.getOperand(1), C1))
9568       return SDValue();
9569   } else {
9570     // Reconstruct the corresponding AND immediate from the two BICi immediates.
9571     ConstantSDNode *C1nodeImm = dyn_cast<ConstantSDNode>(And.getOperand(1));
9572     ConstantSDNode *C1nodeShift = dyn_cast<ConstantSDNode>(And.getOperand(2));
9573     assert(C1nodeImm && C1nodeShift);
9574     C1 = ~(C1nodeImm->getZExtValue() << C1nodeShift->getZExtValue());
9575   }
9576 
9577   // Is C1 == ~(Ones(ElemSizeInBits) << C2) or
9578   // C1 == ~(Ones(ElemSizeInBits) >> C2), taking into account
9579   // how much one can shift elements of a particular size?
9580   uint64_t C2 = C2node->getZExtValue();
9581   unsigned ElemSizeInBits = VT.getScalarSizeInBits();
9582   if (C2 > ElemSizeInBits)
9583     return SDValue();
9584 
9585   APInt C1AsAPInt(ElemSizeInBits, C1);
9586   APInt RequiredC1 = IsShiftRight ? APInt::getHighBitsSet(ElemSizeInBits, C2)
9587                                   : APInt::getLowBitsSet(ElemSizeInBits, C2);
9588   if (C1AsAPInt != RequiredC1)
9589     return SDValue();
9590 
9591   SDValue X = And.getOperand(0);
9592   SDValue Y = Shift.getOperand(0);
9593 
9594   unsigned Inst = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
9595   SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Shift.getOperand(1));
9596 
9597   LLVM_DEBUG(dbgs() << "aarch64-lower: transformed: \n");
9598   LLVM_DEBUG(N->dump(&DAG));
9599   LLVM_DEBUG(dbgs() << "into: \n");
9600   LLVM_DEBUG(ResultSLI->dump(&DAG));
9601 
9602   ++NumShiftInserts;
9603   return ResultSLI;
9604 }
9605 
9606 SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
9607                                              SelectionDAG &DAG) const {
9608   if (useSVEForFixedLengthVectorVT(Op.getValueType()))
9609     return LowerToScalableOp(Op, DAG);
9610 
9611   // Attempt to form a vector S[LR]I from (or (and X, C1), (lsl Y, C2))
9612   if (SDValue Res = tryLowerToSLI(Op.getNode(), DAG))
9613     return Res;
9614 
9615   EVT VT = Op.getValueType();
9616 
9617   SDValue LHS = Op.getOperand(0);
9618   BuildVectorSDNode *BVN =
9619       dyn_cast<BuildVectorSDNode>(Op.getOperand(1).getNode());
9620   if (!BVN) {
9621     // OR commutes, so try swapping the operands.
9622     LHS = Op.getOperand(1);
9623     BVN = dyn_cast<BuildVectorSDNode>(Op.getOperand(0).getNode());
9624   }
9625   if (!BVN)
9626     return Op;
9627 
9628   APInt DefBits(VT.getSizeInBits(), 0);
9629   APInt UndefBits(VT.getSizeInBits(), 0);
9630   if (resolveBuildVector(BVN, DefBits, UndefBits)) {
9631     SDValue NewOp;
9632 
9633     if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::ORRi, Op, DAG,
9634                                     DefBits, &LHS)) ||
9635         (NewOp = tryAdvSIMDModImm16(AArch64ISD::ORRi, Op, DAG,
9636                                     DefBits, &LHS)))
9637       return NewOp;
9638 
9639     if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::ORRi, Op, DAG,
9640                                     UndefBits, &LHS)) ||
9641         (NewOp = tryAdvSIMDModImm16(AArch64ISD::ORRi, Op, DAG,
9642                                     UndefBits, &LHS)))
9643       return NewOp;
9644   }
9645 
9646   // We can always fall back to a non-immediate OR.
9647   return Op;
9648 }
9649 
9650 // Normalize the operands of BUILD_VECTOR. The value of constant operands will
9651 // be truncated to fit the element width.
9652 static SDValue NormalizeBuildVector(SDValue Op,
9653                                     SelectionDAG &DAG) {
9654   assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unknown opcode!");
9655   SDLoc dl(Op);
9656   EVT VT = Op.getValueType();
9657   EVT EltTy= VT.getVectorElementType();
9658 
9659   if (EltTy.isFloatingPoint() || EltTy.getSizeInBits() > 16)
9660     return Op;
9661 
9662   SmallVector<SDValue, 16> Ops;
9663   for (SDValue Lane : Op->ops()) {
9664     // For integer vectors, type legalization would have promoted the
9665     // operands already. Otherwise, if Op is a floating-point splat
9666     // (with operands cast to integers), then the only possibilities
9667     // are constants and UNDEFs.
9668     if (auto *CstLane = dyn_cast<ConstantSDNode>(Lane)) {
9669       APInt LowBits(EltTy.getSizeInBits(),
9670                     CstLane->getZExtValue());
9671       Lane = DAG.getConstant(LowBits.getZExtValue(), dl, MVT::i32);
9672     } else if (Lane.getNode()->isUndef()) {
9673       Lane = DAG.getUNDEF(MVT::i32);
9674     } else {
9675       assert(Lane.getValueType() == MVT::i32 &&
9676              "Unexpected BUILD_VECTOR operand type");
9677     }
9678     Ops.push_back(Lane);
9679   }
9680   return DAG.getBuildVector(VT, dl, Ops);
9681 }
9682 
9683 static SDValue ConstantBuildVector(SDValue Op, SelectionDAG &DAG) {
9684   EVT VT = Op.getValueType();
9685 
9686   APInt DefBits(VT.getSizeInBits(), 0);
9687   APInt UndefBits(VT.getSizeInBits(), 0);
9688   BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(Op.getNode());
9689   if (resolveBuildVector(BVN, DefBits, UndefBits)) {
9690     SDValue NewOp;
9691     if ((NewOp = tryAdvSIMDModImm64(AArch64ISD::MOVIedit, Op, DAG, DefBits)) ||
9692         (NewOp = tryAdvSIMDModImm32(AArch64ISD::MOVIshift, Op, DAG, DefBits)) ||
9693         (NewOp = tryAdvSIMDModImm321s(AArch64ISD::MOVImsl, Op, DAG, DefBits)) ||
9694         (NewOp = tryAdvSIMDModImm16(AArch64ISD::MOVIshift, Op, DAG, DefBits)) ||
9695         (NewOp = tryAdvSIMDModImm8(AArch64ISD::MOVI, Op, DAG, DefBits)) ||
9696         (NewOp = tryAdvSIMDModImmFP(AArch64ISD::FMOV, Op, DAG, DefBits)))
9697       return NewOp;
9698 
9699     DefBits = ~DefBits;
9700     if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::MVNIshift, Op, DAG, DefBits)) ||
9701         (NewOp = tryAdvSIMDModImm321s(AArch64ISD::MVNImsl, Op, DAG, DefBits)) ||
9702         (NewOp = tryAdvSIMDModImm16(AArch64ISD::MVNIshift, Op, DAG, DefBits)))
9703       return NewOp;
9704 
9705     DefBits = UndefBits;
9706     if ((NewOp = tryAdvSIMDModImm64(AArch64ISD::MOVIedit, Op, DAG, DefBits)) ||
9707         (NewOp = tryAdvSIMDModImm32(AArch64ISD::MOVIshift, Op, DAG, DefBits)) ||
9708         (NewOp = tryAdvSIMDModImm321s(AArch64ISD::MOVImsl, Op, DAG, DefBits)) ||
9709         (NewOp = tryAdvSIMDModImm16(AArch64ISD::MOVIshift, Op, DAG, DefBits)) ||
9710         (NewOp = tryAdvSIMDModImm8(AArch64ISD::MOVI, Op, DAG, DefBits)) ||
9711         (NewOp = tryAdvSIMDModImmFP(AArch64ISD::FMOV, Op, DAG, DefBits)))
9712       return NewOp;
9713 
9714     DefBits = ~UndefBits;
9715     if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::MVNIshift, Op, DAG, DefBits)) ||
9716         (NewOp = tryAdvSIMDModImm321s(AArch64ISD::MVNImsl, Op, DAG, DefBits)) ||
9717         (NewOp = tryAdvSIMDModImm16(AArch64ISD::MVNIshift, Op, DAG, DefBits)))
9718       return NewOp;
9719   }
9720 
9721   return SDValue();
9722 }
9723 
9724 SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
9725                                                  SelectionDAG &DAG) const {
9726   EVT VT = Op.getValueType();
9727 
9728   // Try to build a simple constant vector.
9729   Op = NormalizeBuildVector(Op, DAG);
9730   if (VT.isInteger()) {
9731     // Certain vector constants, used to express things like logical NOT and
9732     // arithmetic NEG, are passed through unmodified.  This allows special
9733     // patterns for these operations to match, which will lower these constants
9734     // to whatever is proven necessary.
9735     BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(Op.getNode());
9736     if (BVN->isConstant())
9737       if (ConstantSDNode *Const = BVN->getConstantSplatNode()) {
9738         unsigned BitSize = VT.getVectorElementType().getSizeInBits();
9739         APInt Val(BitSize,
9740                   Const->getAPIntValue().zextOrTrunc(BitSize).getZExtValue());
9741         if (Val.isNullValue() || Val.isAllOnesValue())
9742           return Op;
9743       }
9744   }
9745 
9746   if (SDValue V = ConstantBuildVector(Op, DAG))
9747     return V;
9748 
9749   // Scan through the operands to find some interesting properties we can
9750   // exploit:
9751   //   1) If only one value is used, we can use a DUP, or
9752   //   2) if only the low element is not undef, we can just insert that, or
9753   //   3) if only one constant value is used (w/ some non-constant lanes),
9754   //      we can splat the constant value into the whole vector then fill
9755   //      in the non-constant lanes.
9756   //   4) FIXME: If different constant values are used, but we can intelligently
9757   //             select the values we'll be overwriting for the non-constant
9758   //             lanes such that we can directly materialize the vector
9759   //             some other way (MOVI, e.g.), we can be sneaky.
9760   //   5) if all operands are EXTRACT_VECTOR_ELT, check for VUZP.
9761   SDLoc dl(Op);
9762   unsigned NumElts = VT.getVectorNumElements();
9763   bool isOnlyLowElement = true;
9764   bool usesOnlyOneValue = true;
9765   bool usesOnlyOneConstantValue = true;
9766   bool isConstant = true;
9767   bool AllLanesExtractElt = true;
9768   unsigned NumConstantLanes = 0;
9769   unsigned NumDifferentLanes = 0;
9770   unsigned NumUndefLanes = 0;
9771   SDValue Value;
9772   SDValue ConstantValue;
9773   for (unsigned i = 0; i < NumElts; ++i) {
9774     SDValue V = Op.getOperand(i);
9775     if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
9776       AllLanesExtractElt = false;
9777     if (V.isUndef()) {
9778       ++NumUndefLanes;
9779       continue;
9780     }
9781     if (i > 0)
9782       isOnlyLowElement = false;
9783     if (!isIntOrFPConstant(V))
9784       isConstant = false;
9785 
9786     if (isIntOrFPConstant(V)) {
9787       ++NumConstantLanes;
9788       if (!ConstantValue.getNode())
9789         ConstantValue = V;
9790       else if (ConstantValue != V)
9791         usesOnlyOneConstantValue = false;
9792     }
9793 
9794     if (!Value.getNode())
9795       Value = V;
9796     else if (V != Value) {
9797       usesOnlyOneValue = false;
9798       ++NumDifferentLanes;
9799     }
9800   }
9801 
9802   if (!Value.getNode()) {
9803     LLVM_DEBUG(
9804         dbgs() << "LowerBUILD_VECTOR: value undefined, creating undef node\n");
9805     return DAG.getUNDEF(VT);
9806   }
9807 
9808   // Convert BUILD_VECTOR where all elements but the lowest are undef into
9809   // SCALAR_TO_VECTOR, except for when we have a single-element constant vector
9810   // as SimplifyDemandedBits will just turn that back into BUILD_VECTOR.
9811   if (isOnlyLowElement && !(NumElts == 1 && isIntOrFPConstant(Value))) {
9812     LLVM_DEBUG(dbgs() << "LowerBUILD_VECTOR: only low element used, creating 1 "
9813                          "SCALAR_TO_VECTOR node\n");
9814     return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Value);
9815   }
9816 
9817   if (AllLanesExtractElt) {
9818     SDNode *Vector = nullptr;
9819     bool Even = false;
9820     bool Odd = false;
9821     // Check whether the extract elements match the Even pattern <0,2,4,...> or
9822     // the Odd pattern <1,3,5,...>.
9823     for (unsigned i = 0; i < NumElts; ++i) {
9824       SDValue V = Op.getOperand(i);
9825       const SDNode *N = V.getNode();
9826       if (!isa<ConstantSDNode>(N->getOperand(1)))
9827         break;
9828       SDValue N0 = N->getOperand(0);
9829 
9830       // All elements are extracted from the same vector.
9831       if (!Vector) {
9832         Vector = N0.getNode();
9833         // Check that the type of EXTRACT_VECTOR_ELT matches the type of
9834         // BUILD_VECTOR.
9835         if (VT.getVectorElementType() !=
9836             N0.getValueType().getVectorElementType())
9837           break;
9838       } else if (Vector != N0.getNode()) {
9839         Odd = false;
9840         Even = false;
9841         break;
9842       }
9843 
9844       // Extracted values are either at Even indices <0,2,4,...> or at Odd
9845       // indices <1,3,5,...>.
9846       uint64_t Val = N->getConstantOperandVal(1);
9847       if (Val == 2 * i) {
9848         Even = true;
9849         continue;
9850       }
9851       if (Val - 1 == 2 * i) {
9852         Odd = true;
9853         continue;
9854       }
9855 
9856       // Something does not match: abort.
9857       Odd = false;
9858       Even = false;
9859       break;
9860     }
9861     if (Even || Odd) {
9862       SDValue LHS =
9863           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, SDValue(Vector, 0),
9864                       DAG.getConstant(0, dl, MVT::i64));
9865       SDValue RHS =
9866           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, SDValue(Vector, 0),
9867                       DAG.getConstant(NumElts, dl, MVT::i64));
9868 
9869       if (Even && !Odd)
9870         return DAG.getNode(AArch64ISD::UZP1, dl, DAG.getVTList(VT, VT), LHS,
9871                            RHS);
9872       if (Odd && !Even)
9873         return DAG.getNode(AArch64ISD::UZP2, dl, DAG.getVTList(VT, VT), LHS,
9874                            RHS);
9875     }
9876   }
9877 
9878   // Use DUP for non-constant splats. For f32 constant splats, reduce to
9879   // i32 and try again.
9880   if (usesOnlyOneValue) {
9881     if (!isConstant) {
9882       if (Value.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
9883           Value.getValueType() != VT) {
9884         LLVM_DEBUG(
9885             dbgs() << "LowerBUILD_VECTOR: use DUP for non-constant splats\n");
9886         return DAG.getNode(AArch64ISD::DUP, dl, VT, Value);
9887       }
9888 
9889       // This is actually a DUPLANExx operation, which keeps everything vectory.
9890 
9891       SDValue Lane = Value.getOperand(1);
9892       Value = Value.getOperand(0);
9893       if (Value.getValueSizeInBits() == 64) {
9894         LLVM_DEBUG(
9895             dbgs() << "LowerBUILD_VECTOR: DUPLANE works on 128-bit vectors, "
9896                       "widening it\n");
9897         Value = WidenVector(Value, DAG);
9898       }
9899 
9900       unsigned Opcode = getDUPLANEOp(VT.getVectorElementType());
9901       return DAG.getNode(Opcode, dl, VT, Value, Lane);
9902     }
9903 
9904     if (VT.getVectorElementType().isFloatingPoint()) {
9905       SmallVector<SDValue, 8> Ops;
9906       EVT EltTy = VT.getVectorElementType();
9907       assert ((EltTy == MVT::f16 || EltTy == MVT::bf16 || EltTy == MVT::f32 ||
9908                EltTy == MVT::f64) && "Unsupported floating-point vector type");
9909       LLVM_DEBUG(
9910           dbgs() << "LowerBUILD_VECTOR: float constant splats, creating int "
9911                     "BITCASTS, and try again\n");
9912       MVT NewType = MVT::getIntegerVT(EltTy.getSizeInBits());
9913       for (unsigned i = 0; i < NumElts; ++i)
9914         Ops.push_back(DAG.getNode(ISD::BITCAST, dl, NewType, Op.getOperand(i)));
9915       EVT VecVT = EVT::getVectorVT(*DAG.getContext(), NewType, NumElts);
9916       SDValue Val = DAG.getBuildVector(VecVT, dl, Ops);
9917       LLVM_DEBUG(dbgs() << "LowerBUILD_VECTOR: trying to lower new vector: ";
9918                  Val.dump(););
9919       Val = LowerBUILD_VECTOR(Val, DAG);
9920       if (Val.getNode())
9921         return DAG.getNode(ISD::BITCAST, dl, VT, Val);
9922     }
9923   }
9924 
9925   // If we need to insert a small number of different non-constant elements and
9926   // the vector width is sufficiently large, prefer using DUP with the common
9927   // value and INSERT_VECTOR_ELT for the different lanes. If DUP is preferred,
9928   // skip the constant lane handling below.
9929   bool PreferDUPAndInsert =
9930       !isConstant && NumDifferentLanes >= 1 &&
9931       NumDifferentLanes < ((NumElts - NumUndefLanes) / 2) &&
9932       NumDifferentLanes >= NumConstantLanes;
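  // Illustration: for a v8i16 <a,a,a,a,a,a,b,a> with non-constant a and b,
  // one DUP of 'a' plus a single INSERT_VECTOR_ELT of 'b' is preferred over
  // initialising all eight lanes individually.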
9933 
9934   // If only one constant value was used, and it was used for more than one
9935   // lane, start by splatting that value, then replace the non-constant
9936   // lanes. This is better than the default, which performs a separate
9937   // initialization for each lane.
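  // Illustration: v4i32 <7, 7, x, 7> first materialises the all-sevens splat,
  // then performs one INSERT_VECTOR_ELT of the non-constant x into lane 2.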
9938   if (!PreferDUPAndInsert && NumConstantLanes > 0 && usesOnlyOneConstantValue) {
9939     // Firstly, try to materialize the splat constant.
9940     SDValue Vec = DAG.getSplatBuildVector(VT, dl, ConstantValue),
9941             Val = ConstantBuildVector(Vec, DAG);
9942     if (!Val) {
9943       // Otherwise, materialize the constant and splat it.
9944       Val = DAG.getNode(AArch64ISD::DUP, dl, VT, ConstantValue);
9945       DAG.ReplaceAllUsesWith(Vec.getNode(), &Val);
9946     }
9947 
9948     // Now insert the non-constant lanes.
9949     for (unsigned i = 0; i < NumElts; ++i) {
9950       SDValue V = Op.getOperand(i);
9951       SDValue LaneIdx = DAG.getConstant(i, dl, MVT::i64);
9952       if (!isIntOrFPConstant(V))
9953         // Note that type legalization likely mucked about with the VT of the
9954         // source operand, so we may have to convert it here before inserting.
9955         Val = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Val, V, LaneIdx);
9956     }
9957     return Val;
9958   }
9959 
9960   // This will generate a load from the constant pool.
9961   if (isConstant) {
9962     LLVM_DEBUG(
9963         dbgs() << "LowerBUILD_VECTOR: all elements are constant, use default "
9964                   "expansion\n");
9965     return SDValue();
9966   }
9967 
9968   // Empirical tests suggest this is rarely worth it for vectors of length <= 2.
9969   if (NumElts >= 4) {
9970     if (SDValue shuffle = ReconstructShuffle(Op, DAG))
9971       return shuffle;
9972   }
9973 
9974   if (PreferDUPAndInsert) {
9975     // First, build a constant vector with the common element.
9976     SmallVector<SDValue, 8> Ops(NumElts, Value);
9977     SDValue NewVector = LowerBUILD_VECTOR(DAG.getBuildVector(VT, dl, Ops), DAG);
9978     // Next, insert the elements that do not match the common value.
9979     for (unsigned I = 0; I < NumElts; ++I)
9980       if (Op.getOperand(I) != Value)
9981         NewVector =
9982             DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, NewVector,
9983                         Op.getOperand(I), DAG.getConstant(I, dl, MVT::i64));
9984 
9985     return NewVector;
9986   }
9987 
9988   // If all else fails, just use a sequence of INSERT_VECTOR_ELT when we
9989   // know the default expansion would otherwise fall back on something even
9990   // worse. For a vector with one or two non-undef values, that's
9991   // scalar_to_vector for the elements followed by a shuffle (provided the
9992   // shuffle is valid for the target) and materialization element by element
9993   // on the stack followed by a load for everything else.
9994   if (!isConstant && !usesOnlyOneValue) {
9995     LLVM_DEBUG(
9996         dbgs() << "LowerBUILD_VECTOR: alternatives failed, creating sequence "
9997                   "of INSERT_VECTOR_ELT\n");
9998 
9999     SDValue Vec = DAG.getUNDEF(VT);
10000     SDValue Op0 = Op.getOperand(0);
10001     unsigned i = 0;
10002 
10003     // Use SCALAR_TO_VECTOR for lane zero to
10004     // a) Avoid a RMW dependency on the full vector register, and
10005     // b) Allow the register coalescer to fold away the copy if the
10006     //    value is already in an S or D register, and we're forced to emit an
10007     //    INSERT_SUBREG that we can't fold anywhere.
10008     //
10009     // We also allow types like i8 and i16 which are illegal scalar but legal
10010     // vector element types. After type-legalization the inserted value is
10011     // extended (i32) and it is safe to cast them to the vector type by ignoring
10012     // the upper bits of the lowest lane (e.g. v8i8, v4i16).
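    // Illustration: a v4i32 <a, b, c, d> with distinct non-constant lanes
    // becomes:
    //   t0 = SCALAR_TO_VECTOR a
    //   t1 = INSERT_VECTOR_ELT t0, b, 1
    //   t2 = INSERT_VECTOR_ELT t1, c, 2
    //   t3 = INSERT_VECTOR_ELT t2, d, 3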
10013     if (!Op0.isUndef()) {
10014       LLVM_DEBUG(dbgs() << "Creating node for op0, it is not undefined:\n");
10015       Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Op0);
10016       ++i;
10017     }
10018     LLVM_DEBUG(if (i < NumElts) dbgs()
10019                    << "Creating nodes for the other vector elements:\n";);
10020     for (; i < NumElts; ++i) {
10021       SDValue V = Op.getOperand(i);
10022       if (V.isUndef())
10023         continue;
10024       SDValue LaneIdx = DAG.getConstant(i, dl, MVT::i64);
10025       Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Vec, V, LaneIdx);
10026     }
10027     return Vec;
10028   }
10029 
10030   LLVM_DEBUG(
10031       dbgs() << "LowerBUILD_VECTOR: use default expansion, failed to find "
10032                 "better alternative\n");
10033   return SDValue();
10034 }
10035 
10036 SDValue AArch64TargetLowering::LowerCONCAT_VECTORS(SDValue Op,
10037                                                    SelectionDAG &DAG) const {
10038   assert(Op.getValueType().isScalableVector() &&
10039          isTypeLegal(Op.getValueType()) &&
10040          "Expected legal scalable vector type!");
10041 
10042   if (isTypeLegal(Op.getOperand(0).getValueType()) && Op.getNumOperands() == 2)
10043     return Op;
10044 
10045   return SDValue();
10046 }
10047 
10048 SDValue AArch64TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
10049                                                       SelectionDAG &DAG) const {
10050   assert(Op.getOpcode() == ISD::INSERT_VECTOR_ELT && "Unknown opcode!");
10051 
10052   if (useSVEForFixedLengthVectorVT(Op.getValueType()))
10053     return LowerFixedLengthInsertVectorElt(Op, DAG);
10054 
10055   // Check for non-constant or out of range lane.
10056   EVT VT = Op.getOperand(0).getValueType();
10057   ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(2));
10058   if (!CI || CI->getZExtValue() >= VT.getVectorNumElements())
10059     return SDValue();
10060 
10061 
10062   // Insertion/extraction are legal for V128 types.
10063   if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 ||
10064       VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
10065       VT == MVT::v8f16 || VT == MVT::v8bf16)
10066     return Op;
10067 
10068   if (VT != MVT::v8i8 && VT != MVT::v4i16 && VT != MVT::v2i32 &&
10069       VT != MVT::v1i64 && VT != MVT::v2f32 && VT != MVT::v4f16 &&
10070       VT != MVT::v4bf16)
10071     return SDValue();
10072 
10073   // For V64 types, we perform insertion by expanding the value
10074   // to a V128 type and perform the insertion on that.
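  // Illustration: an insert into v4i16 is widened to v8i16, the
  // INSERT_VECTOR_ELT is performed there, and the result is narrowed back.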
10075   SDLoc DL(Op);
10076   SDValue WideVec = WidenVector(Op.getOperand(0), DAG);
10077   EVT WideTy = WideVec.getValueType();
10078 
10079   SDValue Node = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, WideTy, WideVec,
10080                              Op.getOperand(1), Op.getOperand(2));
10081   // Re-narrow the resultant vector.
10082   return NarrowVector(Node, DAG);
10083 }
10084 
10085 SDValue
10086 AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
10087                                                SelectionDAG &DAG) const {
10088   assert(Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT && "Unknown opcode!");
10089   EVT VT = Op.getOperand(0).getValueType();
10090 
10091   if (VT.getScalarType() == MVT::i1) {
10092     // We can't directly extract from an SVE predicate; extend it first.
10093     // (This isn't the only possible lowering, but it's straightforward.)
10094     EVT VectorVT = getPromotedVTForPredicate(VT);
10095     SDLoc DL(Op);
10096     SDValue Extend =
10097         DAG.getNode(ISD::ANY_EXTEND, DL, VectorVT, Op.getOperand(0));
10098     MVT ExtractTy = VectorVT == MVT::nxv2i64 ? MVT::i64 : MVT::i32;
10099     SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractTy,
10100                                   Extend, Op.getOperand(1));
10101     return DAG.getAnyExtOrTrunc(Extract, DL, Op.getValueType());
10102   }
10103 
10104   if (useSVEForFixedLengthVectorVT(VT))
10105     return LowerFixedLengthExtractVectorElt(Op, DAG);
10106 
10107   // Check for non-constant or out of range lane.
10108   ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(1));
10109   if (!CI || CI->getZExtValue() >= VT.getVectorNumElements())
10110     return SDValue();
10111 
10112   // Insertion/extraction are legal for V128 types.
10113   if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 ||
10114       VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
10115       VT == MVT::v8f16 || VT == MVT::v8bf16)
10116     return Op;
10117 
10118   if (VT != MVT::v8i8 && VT != MVT::v4i16 && VT != MVT::v2i32 &&
10119       VT != MVT::v1i64 && VT != MVT::v2f32 && VT != MVT::v4f16 &&
10120       VT != MVT::v4bf16)
10121     return SDValue();
10122 
10123   // For V64 types, we perform extraction by expanding the value
10124   // to a V128 type and perform the extraction on that.
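  // Illustration: an extract from v8i8 first widens the source to v16i8;
  // i8/i16 element results are returned as i32.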
10125   SDLoc DL(Op);
10126   SDValue WideVec = WidenVector(Op.getOperand(0), DAG);
10127   EVT WideTy = WideVec.getValueType();
10128 
10129   EVT ExtrTy = WideTy.getVectorElementType();
10130   if (ExtrTy == MVT::i16 || ExtrTy == MVT::i8)
10131     ExtrTy = MVT::i32;
10132 
10133   // For extractions, we just return the result directly.
10134   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtrTy, WideVec,
10135                      Op.getOperand(1));
10136 }
10137 
10138 SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op,
10139                                                       SelectionDAG &DAG) const {
10140   assert(Op.getValueType().isFixedLengthVector() &&
10141          "Only cases that extract a fixed length vector are supported!");
10142 
10143   EVT InVT = Op.getOperand(0).getValueType();
10144   unsigned Idx = cast<ConstantSDNode>(Op.getOperand(1))->getZExtValue();
10145   unsigned Size = Op.getValueSizeInBits();
10146 
10147   if (InVT.isScalableVector()) {
10148     // This will be matched by custom code during ISelDAGToDAG.
10149     if (Idx == 0 && isPackedVectorType(InVT, DAG))
10150       return Op;
10151 
10152     return SDValue();
10153   }
10154 
10155   // This will get lowered to an appropriate EXTRACT_SUBREG in ISel.
10156   if (Idx == 0 && InVT.getSizeInBits() <= 128)
10157     return Op;
10158 
10159   // If this is extracting the upper 64-bits of a 128-bit vector, we match
10160   // that directly.
10161   if (Size == 64 && Idx * InVT.getScalarSizeInBits() == 64 &&
10162       InVT.getSizeInBits() == 128)
10163     return Op;
10164 
10165   return SDValue();
10166 }
10167 
10168 SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
10169                                                      SelectionDAG &DAG) const {
10170   assert(Op.getValueType().isScalableVector() &&
10171          "Only expect to lower inserts into scalable vectors!");
10172 
10173   EVT InVT = Op.getOperand(1).getValueType();
10174   unsigned Idx = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue();
10175 
10176   if (InVT.isScalableVector()) {
10177     SDLoc DL(Op);
10178     EVT VT = Op.getValueType();
10179 
10180     if (!isTypeLegal(VT) || !VT.isInteger())
10181       return SDValue();
10182 
10183     SDValue Vec0 = Op.getOperand(0);
10184     SDValue Vec1 = Op.getOperand(1);
10185 
10186     // Ensure the subvector is half the size of the main vector.
10187     if (VT.getVectorElementCount() != (InVT.getVectorElementCount() * 2))
10188       return SDValue();
10189 
10190     // Extend elements of smaller vector...
10191     EVT WideVT = InVT.widenIntegerVectorElementType(*(DAG.getContext()));
10192     SDValue ExtVec = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Vec1);
10193 
10194     if (Idx == 0) {
10195       SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0);
10196       return DAG.getNode(AArch64ISD::UZP1, DL, VT, ExtVec, HiVec0);
10197     } else if (Idx == InVT.getVectorMinNumElements()) {
10198       SDValue LoVec0 = DAG.getNode(AArch64ISD::UUNPKLO, DL, WideVT, Vec0);
10199       return DAG.getNode(AArch64ISD::UZP1, DL, VT, LoVec0, ExtVec);
10200     }
10201 
10202     return SDValue();
10203   }
10204 
10205   // This will be matched by custom code during ISelDAGToDAG.
10206   if (Idx == 0 && isPackedVectorType(InVT, DAG) && Op.getOperand(0).isUndef())
10207     return Op;
10208 
10209   return SDValue();
10210 }
10211 
10212 SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const {
10213   EVT VT = Op.getValueType();
10214 
10215   if (useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true))
10216     return LowerFixedLengthVectorIntDivideToSVE(Op, DAG);
10217 
10218   assert(VT.isScalableVector() && "Expected a scalable vector.");
10219 
10220   bool Signed = Op.getOpcode() == ISD::SDIV;
10221   unsigned PredOpcode = Signed ? AArch64ISD::SDIV_PRED : AArch64ISD::UDIV_PRED;
10222 
10223   if (VT == MVT::nxv4i32 || VT == MVT::nxv2i64)
10224     return LowerToPredicatedOp(Op, DAG, PredOpcode);
10225 
10226   // SVE doesn't have i8 and i16 DIV operations; widen them to 32-bit
10227   // operations, and truncate the result.
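  // Illustration: an nxv16i8 SDIV is unpacked into two nxv8i16 halves (which
  // in turn widen again to nxv4i32), divided, and re-packed with UZP1.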
10228   EVT WidenedVT;
10229   if (VT == MVT::nxv16i8)
10230     WidenedVT = MVT::nxv8i16;
10231   else if (VT == MVT::nxv8i16)
10232     WidenedVT = MVT::nxv4i32;
10233   else
10234     llvm_unreachable("Unexpected Custom DIV operation");
10235 
10236   SDLoc dl(Op);
10237   unsigned UnpkLo = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
10238   unsigned UnpkHi = Signed ? AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI;
10239   SDValue Op0Lo = DAG.getNode(UnpkLo, dl, WidenedVT, Op.getOperand(0));
10240   SDValue Op1Lo = DAG.getNode(UnpkLo, dl, WidenedVT, Op.getOperand(1));
10241   SDValue Op0Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(0));
10242   SDValue Op1Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(1));
10243   SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Lo, Op1Lo);
10244   SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Hi, Op1Hi);
10245   return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLo, ResultHi);
10246 }
10247 
10248 bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
10249   // Currently no fixed length shuffles that require SVE are legal.
10250   if (useSVEForFixedLengthVectorVT(VT))
10251     return false;
10252 
10253   if (VT.getVectorNumElements() == 4 &&
10254       (VT.is128BitVector() || VT.is64BitVector())) {
10255     unsigned PFIndexes[4];
10256     for (unsigned i = 0; i != 4; ++i) {
10257       if (M[i] < 0)
10258         PFIndexes[i] = 8;
10259       else
10260         PFIndexes[i] = M[i];
10261     }
10262 
10263     // Compute the index in the perfect shuffle table.
10264     unsigned PFTableIndex = PFIndexes[0] * 9 * 9 * 9 + PFIndexes[1] * 9 * 9 +
10265                             PFIndexes[2] * 9 + PFIndexes[3];
10266     unsigned PFEntry = PerfectShuffleTable[PFTableIndex];
10267     unsigned Cost = (PFEntry >> 30);
10268 
10269     if (Cost <= 4)
10270       return true;
10271   }
10272 
10273   bool DummyBool;
10274   int DummyInt;
10275   unsigned DummyUnsigned;
10276 
10277   return (ShuffleVectorSDNode::isSplatMask(&M[0], VT) || isREVMask(M, VT, 64) ||
10278           isREVMask(M, VT, 32) || isREVMask(M, VT, 16) ||
10279           isEXTMask(M, VT, DummyBool, DummyUnsigned) ||
10280           // isTBLMask(M, VT) || // FIXME: Port TBL support from ARM.
10281           isTRNMask(M, VT, DummyUnsigned) || isUZPMask(M, VT, DummyUnsigned) ||
10282           isZIPMask(M, VT, DummyUnsigned) ||
10283           isTRN_v_undef_Mask(M, VT, DummyUnsigned) ||
10284           isUZP_v_undef_Mask(M, VT, DummyUnsigned) ||
10285           isZIP_v_undef_Mask(M, VT, DummyUnsigned) ||
10286           isINSMask(M, VT.getVectorNumElements(), DummyBool, DummyInt) ||
10287           isConcatMask(M, VT, VT.getSizeInBits() == 128));
10288 }
10289 
10290 /// getVShiftImm - Check if this is a valid build_vector for the immediate
10291 /// operand of a vector shift operation, where all the elements of the
10292 /// build_vector must have the same constant integer value.
10293 static bool getVShiftImm(SDValue Op, unsigned ElementBits, int64_t &Cnt) {
10294   // Ignore bit_converts.
10295   while (Op.getOpcode() == ISD::BITCAST)
10296     Op = Op.getOperand(0);
10297   BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(Op.getNode());
10298   APInt SplatBits, SplatUndef;
10299   unsigned SplatBitSize;
10300   bool HasAnyUndefs;
10301   if (!BVN || !BVN->isConstantSplat(SplatBits, SplatUndef, SplatBitSize,
10302                                     HasAnyUndefs, ElementBits) ||
10303       SplatBitSize > ElementBits)
10304     return false;
10305   Cnt = SplatBits.getSExtValue();
10306   return true;
10307 }
10308 
10309 /// isVShiftLImm - Check if this is a valid build_vector for the immediate
10310 /// operand of a vector shift left operation.  That value must be in the range:
10311 ///   0 <= Value < ElementBits for a left shift; or
10312 ///   0 <= Value <= ElementBits for a long left shift.
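/// For example, for v4i32 a splatted shift amount of 31 is a valid immediate
/// left shift, while 32 is only accepted for the long (isLong) forms.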
10313 static bool isVShiftLImm(SDValue Op, EVT VT, bool isLong, int64_t &Cnt) {
10314   assert(VT.isVector() && "vector shift count is not a vector type");
10315   int64_t ElementBits = VT.getScalarSizeInBits();
10316   if (!getVShiftImm(Op, ElementBits, Cnt))
10317     return false;
10318   return (Cnt >= 0 && (isLong ? Cnt - 1 : Cnt) < ElementBits);
10319 }
10320 
10321 /// isVShiftRImm - Check if this is a valid build_vector for the immediate
10322 /// operand of a vector shift right operation. The value must be in the range:
10323 ///   1 <= Value <= ElementBits for a right shift; or
///   1 <= Value <= ElementBits/2 for a narrow right shift.
10324 static bool isVShiftRImm(SDValue Op, EVT VT, bool isNarrow, int64_t &Cnt) {
10325   assert(VT.isVector() && "vector shift count is not a vector type");
10326   int64_t ElementBits = VT.getScalarSizeInBits();
10327   if (!getVShiftImm(Op, ElementBits, Cnt))
10328     return false;
10329   return (Cnt >= 1 && Cnt <= (isNarrow ? ElementBits / 2 : ElementBits));
10330 }
10331 
10332 SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
10333                                              SelectionDAG &DAG) const {
10334   EVT VT = Op.getValueType();
10335 
10336   if (VT.getScalarType() == MVT::i1) {
10337     // Lower i1 truncate to `(x & 1) != 0`.
10338     SDLoc dl(Op);
10339     EVT OpVT = Op.getOperand(0).getValueType();
10340     SDValue Zero = DAG.getConstant(0, dl, OpVT);
10341     SDValue One = DAG.getConstant(1, dl, OpVT);
10342     SDValue And = DAG.getNode(ISD::AND, dl, OpVT, Op.getOperand(0), One);
10343     return DAG.getSetCC(dl, VT, And, Zero, ISD::SETNE);
10344   }
10345 
10346   if (!VT.isVector() || VT.isScalableVector())
10347     return SDValue();
10348 
10349   if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType()))
10350     return LowerFixedLengthVectorTruncateToSVE(Op, DAG);
10351 
10352   return SDValue();
10353 }
10354 
10355 SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
10356                                                       SelectionDAG &DAG) const {
10357   EVT VT = Op.getValueType();
10358   SDLoc DL(Op);
10359   int64_t Cnt;
10360 
10361   if (!Op.getOperand(1).getValueType().isVector())
10362     return Op;
10363   unsigned EltSize = VT.getScalarSizeInBits();
10364 
10365   switch (Op.getOpcode()) {
10366   default:
10367     llvm_unreachable("unexpected shift opcode");
10368 
10369   case ISD::SHL:
10370     if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT))
10371       return LowerToPredicatedOp(Op, DAG, AArch64ISD::SHL_PRED);
10372 
10373     if (isVShiftLImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize)
10374       return DAG.getNode(AArch64ISD::VSHL, DL, VT, Op.getOperand(0),
10375                          DAG.getConstant(Cnt, DL, MVT::i32));
10376     return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
10377                        DAG.getConstant(Intrinsic::aarch64_neon_ushl, DL,
10378                                        MVT::i32),
10379                        Op.getOperand(0), Op.getOperand(1));
10380   case ISD::SRA:
10381   case ISD::SRL:
10382     if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT)) {
10383       unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
10384                                                 : AArch64ISD::SRL_PRED;
10385       return LowerToPredicatedOp(Op, DAG, Opc);
10386     }
10387 
10388     // Right shift immediate
10389     if (isVShiftRImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize) {
10390       unsigned Opc =
10391           (Op.getOpcode() == ISD::SRA) ? AArch64ISD::VASHR : AArch64ISD::VLSHR;
10392       return DAG.getNode(Opc, DL, VT, Op.getOperand(0),
10393                          DAG.getConstant(Cnt, DL, MVT::i32));
10394     }
10395 
10396     // Right shift register.  Note, there is not a shift right register
10397     // instruction, but the shift left register instruction takes a signed
10398     // value, where negative numbers specify a right shift.
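    // Illustration: (srl v, amt) becomes ushl(v, neg(amt)) and (sra v, amt)
    // becomes sshl(v, neg(amt)); the shift instructions interpret negative
    // per-lane amounts as right shifts.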
10399     unsigned Opc = (Op.getOpcode() == ISD::SRA) ? Intrinsic::aarch64_neon_sshl
10400                                                 : Intrinsic::aarch64_neon_ushl;
10401     // negate the shift amount
10402     SDValue NegShift = DAG.getNode(AArch64ISD::NEG, DL, VT, Op.getOperand(1));
10403     SDValue NegShiftLeft =
10404         DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
10405                     DAG.getConstant(Opc, DL, MVT::i32), Op.getOperand(0),
10406                     NegShift);
10407     return NegShiftLeft;
10408   }
10409 
10410   return SDValue();
10411 }
10412 
10413 static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
10414                                     AArch64CC::CondCode CC, bool NoNans, EVT VT,
10415                                     const SDLoc &dl, SelectionDAG &DAG) {
10416   EVT SrcVT = LHS.getValueType();
10417   assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
10418          "function only supposed to emit natural comparisons");
10419 
10420   BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(RHS.getNode());
10421   APInt CnstBits(VT.getSizeInBits(), 0);
10422   APInt UndefBits(VT.getSizeInBits(), 0);
10423   bool IsCnst = BVN && resolveBuildVector(BVN, CnstBits, UndefBits);
10424   bool IsZero = IsCnst && (CnstBits == 0);
10425 
10426   if (SrcVT.getVectorElementType().isFloatingPoint()) {
10427     switch (CC) {
10428     default:
10429       return SDValue();
10430     case AArch64CC::NE: {
10431       SDValue Fcmeq;
10432       if (IsZero)
10433         Fcmeq = DAG.getNode(AArch64ISD::FCMEQz, dl, VT, LHS);
10434       else
10435         Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
10436       return DAG.getNOT(dl, Fcmeq, VT);
10437     }
10438     case AArch64CC::EQ:
10439       if (IsZero)
10440         return DAG.getNode(AArch64ISD::FCMEQz, dl, VT, LHS);
10441       return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
10442     case AArch64CC::GE:
10443       if (IsZero)
10444         return DAG.getNode(AArch64ISD::FCMGEz, dl, VT, LHS);
10445       return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
10446     case AArch64CC::GT:
10447       if (IsZero)
10448         return DAG.getNode(AArch64ISD::FCMGTz, dl, VT, LHS);
10449       return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
10450     case AArch64CC::LS:
10451       if (IsZero)
10452         return DAG.getNode(AArch64ISD::FCMLEz, dl, VT, LHS);
10453       return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
10454     case AArch64CC::LT:
10455       if (!NoNans)
10456         return SDValue();
10457       // If we ignore NaNs then we can use the MI implementation.
10458       LLVM_FALLTHROUGH;
10459     case AArch64CC::MI:
10460       if (IsZero)
10461         return DAG.getNode(AArch64ISD::FCMLTz, dl, VT, LHS);
10462       return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
10463     }
10464   }
10465 
10466   switch (CC) {
10467   default:
10468     return SDValue();
10469   case AArch64CC::NE: {
10470     SDValue Cmeq;
10471     if (IsZero)
10472       Cmeq = DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
10473     else
10474       Cmeq = DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
10475     return DAG.getNOT(dl, Cmeq, VT);
10476   }
10477   case AArch64CC::EQ:
10478     if (IsZero)
10479       return DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
10480     return DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
10481   case AArch64CC::GE:
10482     if (IsZero)
10483       return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS);
10484     return DAG.getNode(AArch64ISD::CMGE, dl, VT, LHS, RHS);
10485   case AArch64CC::GT:
10486     if (IsZero)
10487       return DAG.getNode(AArch64ISD::CMGTz, dl, VT, LHS);
10488     return DAG.getNode(AArch64ISD::CMGT, dl, VT, LHS, RHS);
10489   case AArch64CC::LE:
10490     if (IsZero)
10491       return DAG.getNode(AArch64ISD::CMLEz, dl, VT, LHS);
10492     return DAG.getNode(AArch64ISD::CMGE, dl, VT, RHS, LHS);
10493   case AArch64CC::LS:
10494     return DAG.getNode(AArch64ISD::CMHS, dl, VT, RHS, LHS);
10495   case AArch64CC::LO:
10496     return DAG.getNode(AArch64ISD::CMHI, dl, VT, RHS, LHS);
10497   case AArch64CC::LT:
10498     if (IsZero)
10499       return DAG.getNode(AArch64ISD::CMLTz, dl, VT, LHS);
10500     return DAG.getNode(AArch64ISD::CMGT, dl, VT, RHS, LHS);
10501   case AArch64CC::HI:
10502     return DAG.getNode(AArch64ISD::CMHI, dl, VT, LHS, RHS);
10503   case AArch64CC::HS:
10504     return DAG.getNode(AArch64ISD::CMHS, dl, VT, LHS, RHS);
10505   }
10506 }
10507 
10508 SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
10509                                            SelectionDAG &DAG) const {
10510   if (Op.getValueType().isScalableVector())
10511     return LowerToPredicatedOp(Op, DAG, AArch64ISD::SETCC_MERGE_ZERO);
10512 
10513   if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType()))
10514     return LowerFixedLengthVectorSetccToSVE(Op, DAG);
10515 
10516   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
10517   SDValue LHS = Op.getOperand(0);
10518   SDValue RHS = Op.getOperand(1);
10519   EVT CmpVT = LHS.getValueType().changeVectorElementTypeToInteger();
10520   SDLoc dl(Op);
10521 
10522   if (LHS.getValueType().getVectorElementType().isInteger()) {
10523     assert(LHS.getValueType() == RHS.getValueType());
10524     AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
10525     SDValue Cmp =
10526         EmitVectorComparison(LHS, RHS, AArch64CC, false, CmpVT, dl, DAG);
10527     return DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
10528   }
10529 
10530   const bool FullFP16 =
10531     static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16();
10532 
10533   // Make v4f16 (only) fcmp operations utilise vector instructions.
10534   // v8f16 support will be a little more complicated.
10535   if (!FullFP16 && LHS.getValueType().getVectorElementType() == MVT::f16) {
10536     if (LHS.getValueType().getVectorNumElements() == 4) {
10537       LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, LHS);
10538       RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, RHS);
10539       SDValue NewSetcc = DAG.getSetCC(dl, MVT::v4i16, LHS, RHS, CC);
10540       DAG.ReplaceAllUsesWith(Op, NewSetcc);
10541       CmpVT = MVT::v4i32;
10542     } else
10543       return SDValue();
10544   }
10545 
10546   assert((!FullFP16 && LHS.getValueType().getVectorElementType() != MVT::f16) ||
10547           LHS.getValueType().getVectorElementType() != MVT::f128);
10548 
10549   // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
10550   // clean.  Some of them require two branches to implement.
10551   AArch64CC::CondCode CC1, CC2;
10552   bool ShouldInvert;
10553   changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
10554 
10555   bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath;
10556   SDValue Cmp =
10557       EmitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
10558   if (!Cmp.getNode())
10559     return SDValue();
10560 
10561   if (CC2 != AArch64CC::AL) {
10562     SDValue Cmp2 =
10563         EmitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
10564     if (!Cmp2.getNode())
10565       return SDValue();
10566 
10567     Cmp = DAG.getNode(ISD::OR, dl, CmpVT, Cmp, Cmp2);
10568   }
10569 
10570   Cmp = DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
10571 
10572   if (ShouldInvert)
10573     Cmp = DAG.getNOT(dl, Cmp, Cmp.getValueType());
10574 
10575   return Cmp;
10576 }
10577 
10578 static SDValue getReductionSDNode(unsigned Op, SDLoc DL, SDValue ScalarOp,
10579                                   SelectionDAG &DAG) {
10580   SDValue VecOp = ScalarOp.getOperand(0);
10581   auto Rdx = DAG.getNode(Op, DL, VecOp.getSimpleValueType(), VecOp);
10582   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarOp.getValueType(), Rdx,
10583                      DAG.getConstant(0, DL, MVT::i64));
10584 }
10585 
10586 SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
10587                                               SelectionDAG &DAG) const {
10588   SDValue Src = Op.getOperand(0);
10589 
10590   // Try to lower fixed length reductions to SVE.
10591   EVT SrcVT = Src.getValueType();
10592   bool OverrideNEON = Op.getOpcode() == ISD::VECREDUCE_AND ||
10593                       Op.getOpcode() == ISD::VECREDUCE_OR ||
10594                       Op.getOpcode() == ISD::VECREDUCE_XOR ||
10595                       Op.getOpcode() == ISD::VECREDUCE_FADD ||
10596                       (Op.getOpcode() != ISD::VECREDUCE_ADD &&
10597                        SrcVT.getVectorElementType() == MVT::i64);
10598   if (SrcVT.isScalableVector() ||
10599       useSVEForFixedLengthVectorVT(SrcVT, OverrideNEON)) {
10600 
10601     if (SrcVT.getVectorElementType() == MVT::i1)
10602       return LowerPredReductionToSVE(Op, DAG);
10603 
10604     switch (Op.getOpcode()) {
10605     case ISD::VECREDUCE_ADD:
10606       return LowerReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
10607     case ISD::VECREDUCE_AND:
10608       return LowerReductionToSVE(AArch64ISD::ANDV_PRED, Op, DAG);
10609     case ISD::VECREDUCE_OR:
10610       return LowerReductionToSVE(AArch64ISD::ORV_PRED, Op, DAG);
10611     case ISD::VECREDUCE_SMAX:
10612       return LowerReductionToSVE(AArch64ISD::SMAXV_PRED, Op, DAG);
10613     case ISD::VECREDUCE_SMIN:
10614       return LowerReductionToSVE(AArch64ISD::SMINV_PRED, Op, DAG);
10615     case ISD::VECREDUCE_UMAX:
10616       return LowerReductionToSVE(AArch64ISD::UMAXV_PRED, Op, DAG);
10617     case ISD::VECREDUCE_UMIN:
10618       return LowerReductionToSVE(AArch64ISD::UMINV_PRED, Op, DAG);
10619     case ISD::VECREDUCE_XOR:
10620       return LowerReductionToSVE(AArch64ISD::EORV_PRED, Op, DAG);
10621     case ISD::VECREDUCE_FADD:
10622       return LowerReductionToSVE(AArch64ISD::FADDV_PRED, Op, DAG);
10623     case ISD::VECREDUCE_FMAX:
10624       return LowerReductionToSVE(AArch64ISD::FMAXNMV_PRED, Op, DAG);
10625     case ISD::VECREDUCE_FMIN:
10626       return LowerReductionToSVE(AArch64ISD::FMINNMV_PRED, Op, DAG);
10627     default:
10628       llvm_unreachable("Unhandled fixed length reduction");
10629     }
10630   }
10631 
10632   // Lower NEON reductions.
10633   SDLoc dl(Op);
10634   switch (Op.getOpcode()) {
10635   case ISD::VECREDUCE_ADD:
10636     return getReductionSDNode(AArch64ISD::UADDV, dl, Op, DAG);
10637   case ISD::VECREDUCE_SMAX:
10638     return getReductionSDNode(AArch64ISD::SMAXV, dl, Op, DAG);
10639   case ISD::VECREDUCE_SMIN:
10640     return getReductionSDNode(AArch64ISD::SMINV, dl, Op, DAG);
10641   case ISD::VECREDUCE_UMAX:
10642     return getReductionSDNode(AArch64ISD::UMAXV, dl, Op, DAG);
10643   case ISD::VECREDUCE_UMIN:
10644     return getReductionSDNode(AArch64ISD::UMINV, dl, Op, DAG);
10645   case ISD::VECREDUCE_FMAX: {
10646     return DAG.getNode(
10647         ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(),
10648         DAG.getConstant(Intrinsic::aarch64_neon_fmaxnmv, dl, MVT::i32),
10649         Src);
10650   }
10651   case ISD::VECREDUCE_FMIN: {
10652     return DAG.getNode(
10653         ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(),
10654         DAG.getConstant(Intrinsic::aarch64_neon_fminnmv, dl, MVT::i32),
10655         Src);
10656   }
10657   default:
10658     llvm_unreachable("Unhandled reduction");
10659   }
10660 }
10661 
10662 SDValue AArch64TargetLowering::LowerATOMIC_LOAD_SUB(SDValue Op,
10663                                                     SelectionDAG &DAG) const {
10664   auto &Subtarget = static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
10665   if (!Subtarget.hasLSE() && !Subtarget.outlineAtomics())
10666     return SDValue();
10667 
10668   // LSE has an atomic load-add instruction, but not a load-sub.
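  // Illustration: atomicrmw sub is rewritten as an atomic load-add of
  // (0 - RHS), which can then be selected as an LSE LDADD (or an outlined
  // atomic call).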
10669   SDLoc dl(Op);
10670   MVT VT = Op.getSimpleValueType();
10671   SDValue RHS = Op.getOperand(2);
10672   AtomicSDNode *AN = cast<AtomicSDNode>(Op.getNode());
10673   RHS = DAG.getNode(ISD::SUB, dl, VT, DAG.getConstant(0, dl, VT), RHS);
10674   return DAG.getAtomic(ISD::ATOMIC_LOAD_ADD, dl, AN->getMemoryVT(),
10675                        Op.getOperand(0), Op.getOperand(1), RHS,
10676                        AN->getMemOperand());
10677 }
10678 
10679 SDValue AArch64TargetLowering::LowerATOMIC_LOAD_AND(SDValue Op,
10680                                                     SelectionDAG &DAG) const {
10681   auto &Subtarget = static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
10682   if (!Subtarget.hasLSE() && !Subtarget.outlineAtomics())
10683     return SDValue();
10684 
10685   // LSE has an atomic load-clear instruction, but not a load-and.
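  // Illustration: atomicrmw and is rewritten as an atomic load-clear of
  // (RHS xor -1); clearing the bits set in the complement is equivalent to
  // ANDing with the original value.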
10686   SDLoc dl(Op);
10687   MVT VT = Op.getSimpleValueType();
10688   SDValue RHS = Op.getOperand(2);
10689   AtomicSDNode *AN = cast<AtomicSDNode>(Op.getNode());
10690   RHS = DAG.getNode(ISD::XOR, dl, VT, DAG.getConstant(-1ULL, dl, VT), RHS);
10691   return DAG.getAtomic(ISD::ATOMIC_LOAD_CLR, dl, AN->getMemoryVT(),
10692                        Op.getOperand(0), Op.getOperand(1), RHS,
10693                        AN->getMemOperand());
10694 }
10695 
10696 SDValue AArch64TargetLowering::LowerWindowsDYNAMIC_STACKALLOC(
10697     SDValue Op, SDValue Chain, SDValue &Size, SelectionDAG &DAG) const {
10698   SDLoc dl(Op);
10699   EVT PtrVT = getPointerTy(DAG.getDataLayout());
10700   SDValue Callee = DAG.getTargetExternalSymbol("__chkstk", PtrVT, 0);
10701 
10702   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
10703   const uint32_t *Mask = TRI->getWindowsStackProbePreservedMask();
10704   if (Subtarget->hasCustomCallingConv())
10705     TRI->UpdateCustomCallPreservedMask(DAG.getMachineFunction(), &Mask);
10706 
10707   Size = DAG.getNode(ISD::SRL, dl, MVT::i64, Size,
10708                      DAG.getConstant(4, dl, MVT::i64));
10709   Chain = DAG.getCopyToReg(Chain, dl, AArch64::X15, Size, SDValue());
10710   Chain =
10711       DAG.getNode(AArch64ISD::CALL, dl, DAG.getVTList(MVT::Other, MVT::Glue),
10712                   Chain, Callee, DAG.getRegister(AArch64::X15, MVT::i64),
10713                   DAG.getRegisterMask(Mask), Chain.getValue(1));
10714   // To match the actual intent better, we should read the output from X15 here
10715   // again (instead of potentially spilling it to the stack), but rereading Size
10716   // from X15 here doesn't work at -O0, since it thinks that X15 is undefined
10717   // here.
10718 
10719   Size = DAG.getNode(ISD::SHL, dl, MVT::i64, Size,
10720                      DAG.getConstant(4, dl, MVT::i64));
10721   return Chain;
10722 }
10723 
10724 SDValue
10725 AArch64TargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
10726                                                SelectionDAG &DAG) const {
10727   assert(Subtarget->isTargetWindows() &&
10728          "Only Windows alloca probing supported");
10729   SDLoc dl(Op);
10730   // Get the inputs.
10731   SDNode *Node = Op.getNode();
10732   SDValue Chain = Op.getOperand(0);
10733   SDValue Size = Op.getOperand(1);
10734   MaybeAlign Align =
10735       cast<ConstantSDNode>(Op.getOperand(2))->getMaybeAlignValue();
10736   EVT VT = Node->getValueType(0);
10737 
10738   if (DAG.getMachineFunction().getFunction().hasFnAttribute(
10739           "no-stack-arg-probe")) {
10740     SDValue SP = DAG.getCopyFromReg(Chain, dl, AArch64::SP, MVT::i64);
10741     Chain = SP.getValue(1);
10742     SP = DAG.getNode(ISD::SUB, dl, MVT::i64, SP, Size);
10743     if (Align)
10744       SP = DAG.getNode(ISD::AND, dl, VT, SP.getValue(0),
10745                        DAG.getConstant(-(uint64_t)Align->value(), dl, VT));
10746     Chain = DAG.getCopyToReg(Chain, dl, AArch64::SP, SP);
10747     SDValue Ops[2] = {SP, Chain};
10748     return DAG.getMergeValues(Ops, dl);
10749   }
10750 
10751   Chain = DAG.getCALLSEQ_START(Chain, 0, 0, dl);
10752 
10753   Chain = LowerWindowsDYNAMIC_STACKALLOC(Op, Chain, Size, DAG);
10754 
10755   SDValue SP = DAG.getCopyFromReg(Chain, dl, AArch64::SP, MVT::i64);
10756   Chain = SP.getValue(1);
10757   SP = DAG.getNode(ISD::SUB, dl, MVT::i64, SP, Size);
10758   if (Align)
10759     SP = DAG.getNode(ISD::AND, dl, VT, SP.getValue(0),
10760                      DAG.getConstant(-(uint64_t)Align->value(), dl, VT));
10761   Chain = DAG.getCopyToReg(Chain, dl, AArch64::SP, SP);
10762 
10763   Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(0, dl, true),
10764                              DAG.getIntPtrConstant(0, dl, true), SDValue(), dl);
10765 
10766   SDValue Ops[2] = {SP, Chain};
10767   return DAG.getMergeValues(Ops, dl);
10768 }
10769 
10770 SDValue AArch64TargetLowering::LowerVSCALE(SDValue Op,
10771                                            SelectionDAG &DAG) const {
10772   EVT VT = Op.getValueType();
10773   assert(VT != MVT::i64 && "Expected illegal VSCALE node");
10774 
10775   SDLoc DL(Op);
10776   APInt MulImm = cast<ConstantSDNode>(Op.getOperand(0))->getAPIntValue();
10777   return DAG.getZExtOrTrunc(DAG.getVScale(DL, MVT::i64, MulImm.sextOrSelf(64)),
10778                             DL, VT);
10779 }
10780 
10781 /// Set the IntrinsicInfo for the `aarch64_sve_st<N>` intrinsics.
10782 template <unsigned NumVecs>
10783 static bool
10784 setInfoSVEStN(const AArch64TargetLowering &TLI, const DataLayout &DL,
10785               AArch64TargetLowering::IntrinsicInfo &Info, const CallInst &CI) {
10786   Info.opc = ISD::INTRINSIC_VOID;
10787   // Retrieve EC from first vector argument.
10788   const EVT VT = TLI.getMemValueType(DL, CI.getArgOperand(0)->getType());
10789   ElementCount EC = VT.getVectorElementCount();
10790 #ifndef NDEBUG
10791   // Check the assumption that all input vectors are the same type.
10792   for (unsigned I = 0; I < NumVecs; ++I)
10793     assert(VT == TLI.getMemValueType(DL, CI.getArgOperand(I)->getType()) &&
10794            "Invalid type.");
10795 #endif
10796   // memVT is `NumVecs * VT`.
10797   Info.memVT = EVT::getVectorVT(CI.getType()->getContext(), VT.getScalarType(),
10798                                 EC * NumVecs);
10799   Info.ptrVal = CI.getArgOperand(CI.getNumArgOperands() - 1);
10800   Info.offset = 0;
10801   Info.align.reset();
10802   Info.flags = MachineMemOperand::MOStore;
10803   return true;
10804 }
10805 
10806 /// getTgtMemIntrinsic - Represent NEON load and store intrinsics as
10807 /// MemIntrinsicNodes.  The associated MachineMemOperands record the alignment
10808 /// specified in the intrinsic calls.
10809 bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
10810                                                const CallInst &I,
10811                                                MachineFunction &MF,
10812                                                unsigned Intrinsic) const {
10813   auto &DL = I.getModule()->getDataLayout();
10814   switch (Intrinsic) {
10815   case Intrinsic::aarch64_sve_st2:
10816     return setInfoSVEStN<2>(*this, DL, Info, I);
10817   case Intrinsic::aarch64_sve_st3:
10818     return setInfoSVEStN<3>(*this, DL, Info, I);
10819   case Intrinsic::aarch64_sve_st4:
10820     return setInfoSVEStN<4>(*this, DL, Info, I);
10821   case Intrinsic::aarch64_neon_ld2:
10822   case Intrinsic::aarch64_neon_ld3:
10823   case Intrinsic::aarch64_neon_ld4:
10824   case Intrinsic::aarch64_neon_ld1x2:
10825   case Intrinsic::aarch64_neon_ld1x3:
10826   case Intrinsic::aarch64_neon_ld1x4:
10827   case Intrinsic::aarch64_neon_ld2lane:
10828   case Intrinsic::aarch64_neon_ld3lane:
10829   case Intrinsic::aarch64_neon_ld4lane:
10830   case Intrinsic::aarch64_neon_ld2r:
10831   case Intrinsic::aarch64_neon_ld3r:
10832   case Intrinsic::aarch64_neon_ld4r: {
10833     Info.opc = ISD::INTRINSIC_W_CHAIN;
10834     // Conservatively set memVT to the entire set of vectors loaded.
10835     uint64_t NumElts = DL.getTypeSizeInBits(I.getType()) / 64;
10836     Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts);
10837     Info.ptrVal = I.getArgOperand(I.getNumArgOperands() - 1);
10838     Info.offset = 0;
10839     Info.align.reset();
10840     // volatile loads with NEON intrinsics not supported
10841     Info.flags = MachineMemOperand::MOLoad;
10842     return true;
10843   }
10844   case Intrinsic::aarch64_neon_st2:
10845   case Intrinsic::aarch64_neon_st3:
10846   case Intrinsic::aarch64_neon_st4:
10847   case Intrinsic::aarch64_neon_st1x2:
10848   case Intrinsic::aarch64_neon_st1x3:
10849   case Intrinsic::aarch64_neon_st1x4:
10850   case Intrinsic::aarch64_neon_st2lane:
10851   case Intrinsic::aarch64_neon_st3lane:
10852   case Intrinsic::aarch64_neon_st4lane: {
10853     Info.opc = ISD::INTRINSIC_VOID;
10854     // Conservatively set memVT to the entire set of vectors stored.
10855     unsigned NumElts = 0;
10856     for (unsigned ArgI = 0, ArgE = I.getNumArgOperands(); ArgI < ArgE; ++ArgI) {
10857       Type *ArgTy = I.getArgOperand(ArgI)->getType();
10858       if (!ArgTy->isVectorTy())
10859         break;
10860       NumElts += DL.getTypeSizeInBits(ArgTy) / 64;
10861     }
10862     Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts);
10863     Info.ptrVal = I.getArgOperand(I.getNumArgOperands() - 1);
10864     Info.offset = 0;
10865     Info.align.reset();
10866     // volatile stores with NEON intrinsics not supported
10867     Info.flags = MachineMemOperand::MOStore;
10868     return true;
10869   }
10870   case Intrinsic::aarch64_ldaxr:
10871   case Intrinsic::aarch64_ldxr: {
10872     PointerType *PtrTy = cast<PointerType>(I.getArgOperand(0)->getType());
10873     Info.opc = ISD::INTRINSIC_W_CHAIN;
10874     Info.memVT = MVT::getVT(PtrTy->getElementType());
10875     Info.ptrVal = I.getArgOperand(0);
10876     Info.offset = 0;
10877     Info.align = DL.getABITypeAlign(PtrTy->getElementType());
10878     Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile;
10879     return true;
10880   }
10881   case Intrinsic::aarch64_stlxr:
10882   case Intrinsic::aarch64_stxr: {
10883     PointerType *PtrTy = cast<PointerType>(I.getArgOperand(1)->getType());
10884     Info.opc = ISD::INTRINSIC_W_CHAIN;
10885     Info.memVT = MVT::getVT(PtrTy->getElementType());
10886     Info.ptrVal = I.getArgOperand(1);
10887     Info.offset = 0;
10888     Info.align = DL.getABITypeAlign(PtrTy->getElementType());
10889     Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MOVolatile;
10890     return true;
10891   }
10892   case Intrinsic::aarch64_ldaxp:
10893   case Intrinsic::aarch64_ldxp:
10894     Info.opc = ISD::INTRINSIC_W_CHAIN;
10895     Info.memVT = MVT::i128;
10896     Info.ptrVal = I.getArgOperand(0);
10897     Info.offset = 0;
10898     Info.align = Align(16);
10899     Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile;
10900     return true;
10901   case Intrinsic::aarch64_stlxp:
10902   case Intrinsic::aarch64_stxp:
10903     Info.opc = ISD::INTRINSIC_W_CHAIN;
10904     Info.memVT = MVT::i128;
10905     Info.ptrVal = I.getArgOperand(2);
10906     Info.offset = 0;
10907     Info.align = Align(16);
10908     Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MOVolatile;
10909     return true;
10910   case Intrinsic::aarch64_sve_ldnt1: {
10911     PointerType *PtrTy = cast<PointerType>(I.getArgOperand(1)->getType());
10912     Info.opc = ISD::INTRINSIC_W_CHAIN;
10913     Info.memVT = MVT::getVT(I.getType());
10914     Info.ptrVal = I.getArgOperand(1);
10915     Info.offset = 0;
10916     Info.align = DL.getABITypeAlign(PtrTy->getElementType());
10917     Info.flags = MachineMemOperand::MOLoad;
10918     if (Intrinsic == Intrinsic::aarch64_sve_ldnt1)
10919       Info.flags |= MachineMemOperand::MONonTemporal;
10920     return true;
10921   }
10922   case Intrinsic::aarch64_sve_stnt1: {
10923     PointerType *PtrTy = cast<PointerType>(I.getArgOperand(2)->getType());
10924     Info.opc = ISD::INTRINSIC_W_CHAIN;
10925     Info.memVT = MVT::getVT(I.getOperand(0)->getType());
10926     Info.ptrVal = I.getArgOperand(2);
10927     Info.offset = 0;
10928     Info.align = DL.getABITypeAlign(PtrTy->getElementType());
10929     Info.flags = MachineMemOperand::MOStore;
10930     if (Intrinsic == Intrinsic::aarch64_sve_stnt1)
10931       Info.flags |= MachineMemOperand::MONonTemporal;
10932     return true;
10933   }
10934   default:
10935     break;
10936   }
10937 
10938   return false;
10939 }
10940 
10941 bool AArch64TargetLowering::shouldReduceLoadWidth(SDNode *Load,
10942                                                   ISD::LoadExtType ExtTy,
10943                                                   EVT NewVT) const {
10944   // TODO: This may be worth removing. Check regression tests for diffs.
10945   if (!TargetLoweringBase::shouldReduceLoadWidth(Load, ExtTy, NewVT))
10946     return false;
10947 
10948   // If we're reducing the load width in order to avoid having to use an extra
10949   // instruction to do extension then it's probably a good idea.
10950   if (ExtTy != ISD::NON_EXTLOAD)
10951     return true;
10952   // Don't reduce load width if it would prevent us from combining a shift into
10953   // the offset.
10954   MemSDNode *Mem = dyn_cast<MemSDNode>(Load);
10955   assert(Mem);
10956   const SDValue &Base = Mem->getBasePtr();
10957   if (Base.getOpcode() == ISD::ADD &&
10958       Base.getOperand(1).getOpcode() == ISD::SHL &&
10959       Base.getOperand(1).hasOneUse() &&
10960       Base.getOperand(1).getOperand(1).getOpcode() == ISD::Constant) {
10961     // The shift can be combined if it matches the size of the value being
10962     // loaded (and so reducing the width would make it not match).
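    // Illustration: for "ldr x0, [x1, x2, lsl #3]" the 8-byte load width
    // matches the lsl #3 scale; narrowing the load would lose that free fold.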
10963     uint64_t ShiftAmount = Base.getOperand(1).getConstantOperandVal(1);
10964     uint64_t LoadBytes = Mem->getMemoryVT().getSizeInBits()/8;
10965     if (ShiftAmount == Log2_32(LoadBytes))
10966       return false;
10967   }
10968   // We have no reason to disallow reducing the load width, so allow it.
10969   return true;
10970 }
10971 
10972 // Truncations from 64-bit GPR to 32-bit GPR are free.
10973 bool AArch64TargetLowering::isTruncateFree(Type *Ty1, Type *Ty2) const {
10974   if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
10975     return false;
10976   uint64_t NumBits1 = Ty1->getPrimitiveSizeInBits().getFixedSize();
10977   uint64_t NumBits2 = Ty2->getPrimitiveSizeInBits().getFixedSize();
10978   return NumBits1 > NumBits2;
10979 }
10980 bool AArch64TargetLowering::isTruncateFree(EVT VT1, EVT VT2) const {
10981   if (VT1.isVector() || VT2.isVector() || !VT1.isInteger() || !VT2.isInteger())
10982     return false;
10983   uint64_t NumBits1 = VT1.getFixedSizeInBits();
10984   uint64_t NumBits2 = VT2.getFixedSizeInBits();
10985   return NumBits1 > NumBits2;
10986 }
10987 
10988 /// Check if it is profitable to hoist an instruction in then/else to if.
10989 /// Not profitable if I and its user can form an FMA instruction
10990 /// because we prefer FMSUB/FMADD.
10991 bool AArch64TargetLowering::isProfitableToHoist(Instruction *I) const {
10992   if (I->getOpcode() != Instruction::FMul)
10993     return true;
10994 
10995   if (!I->hasOneUse())
10996     return true;
10997 
10998   Instruction *User = I->user_back();
10999 
11000   if (User &&
11001       !(User->getOpcode() == Instruction::FSub ||
11002         User->getOpcode() == Instruction::FAdd))
11003     return true;
11004 
11005   const TargetOptions &Options = getTargetMachine().Options;
11006   const Function *F = I->getFunction();
11007   const DataLayout &DL = F->getParent()->getDataLayout();
11008   Type *Ty = User->getOperand(0)->getType();
11009 
11010   return !(isFMAFasterThanFMulAndFAdd(*F, Ty) &&
11011            isOperationLegalOrCustom(ISD::FMA, getValueType(DL, Ty)) &&
11012            (Options.AllowFPOpFusion == FPOpFusion::Fast ||
11013             Options.UnsafeFPMath));
11014 }
11015 
11016 // All 32-bit GPR operations implicitly zero the high-half of the corresponding
11017 // 64-bit GPR.
11018 bool AArch64TargetLowering::isZExtFree(Type *Ty1, Type *Ty2) const {
11019   if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
11020     return false;
11021   unsigned NumBits1 = Ty1->getPrimitiveSizeInBits();
11022   unsigned NumBits2 = Ty2->getPrimitiveSizeInBits();
11023   return NumBits1 == 32 && NumBits2 == 64;
11024 }
11025 bool AArch64TargetLowering::isZExtFree(EVT VT1, EVT VT2) const {
11026   if (VT1.isVector() || VT2.isVector() || !VT1.isInteger() || !VT2.isInteger())
11027     return false;
11028   unsigned NumBits1 = VT1.getSizeInBits();
11029   unsigned NumBits2 = VT2.getSizeInBits();
11030   return NumBits1 == 32 && NumBits2 == 64;
11031 }
11032 
11033 bool AArch64TargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
11034   EVT VT1 = Val.getValueType();
11035   if (isZExtFree(VT1, VT2)) {
11036     return true;
11037   }
11038 
11039   if (Val.getOpcode() != ISD::LOAD)
11040     return false;
11041 
11042   // 8-, 16-, and 32-bit integer loads all implicitly zero-extend.
11043   return (VT1.isSimple() && !VT1.isVector() && VT1.isInteger() &&
11044           VT2.isSimple() && !VT2.isVector() && VT2.isInteger() &&
11045           VT1.getSizeInBits() <= 32);
11046 }
11047 
11048 bool AArch64TargetLowering::isExtFreeImpl(const Instruction *Ext) const {
11049   if (isa<FPExtInst>(Ext))
11050     return false;
11051 
11052   // Vector types are not free.
11053   if (Ext->getType()->isVectorTy())
11054     return false;
11055 
11056   for (const Use &U : Ext->uses()) {
11057     // The extension is free if we can fold it with a left shift in an
11058     // addressing mode or an arithmetic operation: add, sub, and cmp.
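    // Illustration: (add x, (shl (zext i32 w to i64), 2)) can be selected as
    // "add x0, x0, w1, uxtw #2", so the zext itself costs nothing.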
11059 
11060     // Is there a shift?
11061     const Instruction *Instr = cast<Instruction>(U.getUser());
11062 
11063     // Is this a constant shift?
11064     switch (Instr->getOpcode()) {
11065     case Instruction::Shl:
11066       if (!isa<ConstantInt>(Instr->getOperand(1)))
11067         return false;
11068       break;
11069     case Instruction::GetElementPtr: {
11070       gep_type_iterator GTI = gep_type_begin(Instr);
11071       auto &DL = Ext->getModule()->getDataLayout();
11072       std::advance(GTI, U.getOperandNo()-1);
11073       Type *IdxTy = GTI.getIndexedType();
11074       // This extension will end up with a shift because of the scaling factor.
11075       // 8-bit sized types have a scaling factor of 1, thus a shift amount of 0.
11076       // Get the shift amount based on the scaling factor:
11077       // log2(sizeof(IdxTy)) - log2(8).
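      // Illustration: indexing an array of i64 gives ShiftAmt == 3, which
      // folds into the addressing mode; an i8 array gives ShiftAmt == 0,
      // which does not.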
11078       uint64_t ShiftAmt =
11079         countTrailingZeros(DL.getTypeStoreSizeInBits(IdxTy).getFixedSize()) - 3;
11080       // Is the constant foldable in the shift of the addressing mode?
11081       // I.e., shift amount is between 1 and 4 inclusive.
11082       if (ShiftAmt == 0 || ShiftAmt > 4)
11083         return false;
11084       break;
11085     }
11086     case Instruction::Trunc:
11087       // Check if this is a noop.
11088       // trunc(sext ty1 to ty2) to ty1.
11089       if (Instr->getType() == Ext->getOperand(0)->getType())
11090         continue;
11091       LLVM_FALLTHROUGH;
11092     default:
11093       return false;
11094     }
11095 
11096     // At this point we can use the bfm family, so this extension is free
11097     // for that use.
11098   }
11099   return true;
11100 }
11101 
11102 /// Check if both Op1 and Op2 are shufflevector extracts of either the lower
11103 /// or upper half of the vector elements.
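/// For example, both operands being
///   shufflevector <8 x i16> %v, <8 x i16> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
/// (a high-half extract) satisfies this; mixing a low-half extract with a
/// high-half extract does not.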
11104 static bool areExtractShuffleVectors(Value *Op1, Value *Op2) {
11105   auto areTypesHalfed = [](Value *FullV, Value *HalfV) {
11106     auto *FullTy = FullV->getType();
11107     auto *HalfTy = HalfV->getType();
11108     return FullTy->getPrimitiveSizeInBits().getFixedSize() ==
11109            2 * HalfTy->getPrimitiveSizeInBits().getFixedSize();
11110   };
11111 
11112   auto extractHalf = [](Value *FullV, Value *HalfV) {
11113     auto *FullVT = cast<FixedVectorType>(FullV->getType());
11114     auto *HalfVT = cast<FixedVectorType>(HalfV->getType());
11115     return FullVT->getNumElements() == 2 * HalfVT->getNumElements();
11116   };
11117 
11118   ArrayRef<int> M1, M2;
11119   Value *S1Op1, *S2Op1;
11120   if (!match(Op1, m_Shuffle(m_Value(S1Op1), m_Undef(), m_Mask(M1))) ||
11121       !match(Op2, m_Shuffle(m_Value(S2Op1), m_Undef(), m_Mask(M2))))
11122     return false;
11123 
11124   // Check that the operands are half as wide as the result and we extract
11125   // half of the elements of the input vectors.
11126   if (!areTypesHalfed(S1Op1, Op1) || !areTypesHalfed(S2Op1, Op2) ||
11127       !extractHalf(S1Op1, Op1) || !extractHalf(S2Op1, Op2))
11128     return false;
11129 
11130   // Check the mask extracts either the lower or upper half of vector
11131   // elements.
11132   int M1Start = -1;
11133   int M2Start = -1;
11134   int NumElements = cast<FixedVectorType>(Op1->getType())->getNumElements() * 2;
11135   if (!ShuffleVectorInst::isExtractSubvectorMask(M1, NumElements, M1Start) ||
11136       !ShuffleVectorInst::isExtractSubvectorMask(M2, NumElements, M2Start) ||
11137       M1Start != M2Start || (M1Start != 0 && M2Start != (NumElements / 2)))
11138     return false;
11139 
11140   return true;
11141 }
11142 
11143 /// Check if Ext1 and Ext2 are extends of the same type, doubling the bitwidth
11144 /// of the vector elements.
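/// E.g. (sext <8 x i8> %a to <8 x i16>) qualifies, while an extend to
/// <8 x i32> does not, since the element width is quadrupled rather than
/// doubled.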
11145 static bool areExtractExts(Value *Ext1, Value *Ext2) {
11146   auto areExtDoubled = [](Instruction *Ext) {
11147     return Ext->getType()->getScalarSizeInBits() ==
11148            2 * Ext->getOperand(0)->getType()->getScalarSizeInBits();
11149   };
11150 
11151   if (!match(Ext1, m_ZExtOrSExt(m_Value())) ||
11152       !match(Ext2, m_ZExtOrSExt(m_Value())) ||
11153       !areExtDoubled(cast<Instruction>(Ext1)) ||
11154       !areExtDoubled(cast<Instruction>(Ext2)))
11155     return false;
11156 
11157   return true;
11158 }
11159 
11160 /// Check if Op could be used with vmull_high_p64 intrinsic.
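/// E.g. (extractelement <2 x i64> %v, i64 1) qualifies, i.e. the high lane of
/// a 128-bit vector; lane 0 or a non-constant lane index does not.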
11161 static bool isOperandOfVmullHighP64(Value *Op) {
11162   Value *VectorOperand = nullptr;
11163   ConstantInt *ElementIndex = nullptr;
11164   return match(Op, m_ExtractElt(m_Value(VectorOperand),
11165                                 m_ConstantInt(ElementIndex))) &&
11166          ElementIndex->getValue() == 1 &&
11167          isa<FixedVectorType>(VectorOperand->getType()) &&
11168          cast<FixedVectorType>(VectorOperand->getType())->getNumElements() == 2;
11169 }
11170 
11171 /// Check if Op1 and Op2 could be used with vmull_high_p64 intrinsic.
11172 static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {
11173   return isOperandOfVmullHighP64(Op1) && isOperandOfVmullHighP64(Op2);
11174 }
11175 
11176 /// Check if sinking \p I's operands to I's basic block is profitable, because
11177 /// the operands can be folded into a target instruction, e.g.
11178 /// shufflevectors extracts and/or sext/zext can be folded into (u,s)subl(2).
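/// E.g. if the extends and extracting shufflevectors feeding a vector sub live
/// in another basic block, sinking them next to the sub lets instruction
/// selection fold the whole pattern into a single usubl2/ssubl2.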
11179 bool AArch64TargetLowering::shouldSinkOperands(
11180     Instruction *I, SmallVectorImpl<Use *> &Ops) const {
11181   if (!I->getType()->isVectorTy())
11182     return false;
11183 
11184   if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
11185     switch (II->getIntrinsicID()) {
11186     case Intrinsic::aarch64_neon_umull:
11187       if (!areExtractShuffleVectors(II->getOperand(0), II->getOperand(1)))
11188         return false;
11189       Ops.push_back(&II->getOperandUse(0));
11190       Ops.push_back(&II->getOperandUse(1));
11191       return true;
11192 
11193     case Intrinsic::aarch64_neon_pmull64:
11194       if (!areOperandsOfVmullHighP64(II->getArgOperand(0),
11195                                      II->getArgOperand(1)))
11196         return false;
11197       Ops.push_back(&II->getArgOperandUse(0));
11198       Ops.push_back(&II->getArgOperandUse(1));
11199       return true;
11200 
11201     default:
11202       return false;
11203     }
11204   }
11205 
11206   switch (I->getOpcode()) {
11207   case Instruction::Sub:
11208   case Instruction::Add: {
11209     if (!areExtractExts(I->getOperand(0), I->getOperand(1)))
11210       return false;
11211 
11212     // If the exts' operands extract either the lower or upper elements, we
11213     // can sink them too.
11214     auto Ext1 = cast<Instruction>(I->getOperand(0));
11215     auto Ext2 = cast<Instruction>(I->getOperand(1));
11216     if (areExtractShuffleVectors(Ext1, Ext2)) {
11217       Ops.push_back(&Ext1->getOperandUse(0));
11218       Ops.push_back(&Ext2->getOperandUse(0));
11219     }
11220 
11221     Ops.push_back(&I->getOperandUse(0));
11222     Ops.push_back(&I->getOperandUse(1));
11223 
11224     return true;
11225   }
11226   case Instruction::Mul: {
11227     bool IsProfitable = false;
11228     for (auto &Op : I->operands()) {
11229       // Make sure we are not already sinking this operand
11230       if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
11231         continue;
11232 
11233       ShuffleVectorInst *Shuffle = dyn_cast<ShuffleVectorInst>(Op);
11234       if (!Shuffle || !Shuffle->isZeroEltSplat())
11235         continue;
11236 
11237       Value *ShuffleOperand = Shuffle->getOperand(0);
11238       InsertElementInst *Insert = dyn_cast<InsertElementInst>(ShuffleOperand);
11239       if (!Insert)
11240         continue;
11241 
11242       Instruction *OperandInstr = dyn_cast<Instruction>(Insert->getOperand(1));
11243       if (!OperandInstr)
11244         continue;
11245 
11246       ConstantInt *ElementConstant =
11247           dyn_cast<ConstantInt>(Insert->getOperand(2));
11248       // Check that the insertelement is inserting into element 0
11249       if (!ElementConstant || ElementConstant->getZExtValue() != 0)
11250         continue;
11251 
11252       unsigned Opcode = OperandInstr->getOpcode();
11253       if (Opcode != Instruction::SExt && Opcode != Instruction::ZExt)
11254         continue;
11255 
11256       Ops.push_back(&Shuffle->getOperandUse(0));
11257       Ops.push_back(&Op);
11258       IsProfitable = true;
11259     }
11260 
11261     return IsProfitable;
11262   }
11263   default:
11264     return false;
11265   }
11266   return false;
11267 }
11268 
11269 bool AArch64TargetLowering::hasPairedLoad(EVT LoadedType,
11270                                           Align &RequiredAligment) const {
11271   if (!LoadedType.isSimple() ||
11272       (!LoadedType.isInteger() && !LoadedType.isFloatingPoint()))
11273     return false;
11274   // Cyclone supports unaligned accesses.
11275   RequiredAligment = Align(1);
11276   unsigned NumBits = LoadedType.getSizeInBits();
11277   return NumBits == 32 || NumBits == 64;
11278 }
11279 
11280 /// A helper function for determining the number of interleaved accesses we
11281 /// will generate when lowering accesses of the given type.
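/// E.g. a 512-bit <16 x i32> results in (512 + 127) / 128 == 4 accesses, while
/// a 64-bit <4 x i16> results in a single access.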
11282 unsigned
11283 AArch64TargetLowering::getNumInterleavedAccesses(VectorType *VecTy,
11284                                                  const DataLayout &DL) const {
11285   return (DL.getTypeSizeInBits(VecTy) + 127) / 128;
11286 }
11287 
11288 MachineMemOperand::Flags
11289 AArch64TargetLowering::getTargetMMOFlags(const Instruction &I) const {
11290   if (Subtarget->getProcFamily() == AArch64Subtarget::Falkor &&
11291       I.getMetadata(FALKOR_STRIDED_ACCESS_MD) != nullptr)
11292     return MOStridedAccess;
11293   return MachineMemOperand::MONone;
11294 }
11295 
11296 bool AArch64TargetLowering::isLegalInterleavedAccessType(
11297     VectorType *VecTy, const DataLayout &DL) const {
11298 
11299   unsigned VecSize = DL.getTypeSizeInBits(VecTy);
11300   unsigned ElSize = DL.getTypeSizeInBits(VecTy->getElementType());
11301 
11302   // Ensure the number of vector elements is greater than 1.
11303   if (cast<FixedVectorType>(VecTy)->getNumElements() < 2)
11304     return false;
11305 
11306   // Ensure the element type is legal.
11307   if (ElSize != 8 && ElSize != 16 && ElSize != 32 && ElSize != 64)
11308     return false;
11309 
11310   // Ensure the total vector size is 64 or a multiple of 128. Types larger than
11311   // 128 will be split into multiple interleaved accesses.
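  // E.g. a 256-bit <8 x i32> is legal here and will later be lowered as two
  // 128-bit interleaved accesses.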
11312   return VecSize == 64 || VecSize % 128 == 0;
11313 }
11314 
11315 /// Lower an interleaved load into a ldN intrinsic.
11316 ///
11317 /// E.g. Lower an interleaved load (Factor = 2):
11318 ///        %wide.vec = load <8 x i32>, <8 x i32>* %ptr
11319 ///        %v0 = shuffle %wide.vec, undef, <0, 2, 4, 6>  ; Extract even elements
11320 ///        %v1 = shuffle %wide.vec, undef, <1, 3, 5, 7>  ; Extract odd elements
11321 ///
11322 ///      Into:
11323 ///        %ld2 = { <4 x i32>, <4 x i32> } call llvm.aarch64.neon.ld2(%ptr)
11324 ///        %vec0 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 0
11325 ///        %vec1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1
11326 bool AArch64TargetLowering::lowerInterleavedLoad(
11327     LoadInst *LI, ArrayRef<ShuffleVectorInst *> Shuffles,
11328     ArrayRef<unsigned> Indices, unsigned Factor) const {
11329   assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
11330          "Invalid interleave factor");
11331   assert(!Shuffles.empty() && "Empty shufflevector input");
11332   assert(Shuffles.size() == Indices.size() &&
11333          "Unmatched number of shufflevectors and indices");
11334 
11335   const DataLayout &DL = LI->getModule()->getDataLayout();
11336 
11337   VectorType *VTy = Shuffles[0]->getType();
11338 
11339   // Skip if we do not have NEON or if the vector type is not legal. We can
11340   // "legalize" wide vector types into multiple interleaved accesses as long
11341   // as the vector types are divisible by 128.
11342   if (!Subtarget->hasNEON() || !isLegalInterleavedAccessType(VTy, DL))
11343     return false;
11344 
11345   unsigned NumLoads = getNumInterleavedAccesses(VTy, DL);
11346 
11347   auto *FVTy = cast<FixedVectorType>(VTy);
11348 
11349   // A pointer vector cannot be the return type of the ldN intrinsics. We need
11350   // to load integer vectors first and then convert to pointer vectors.
11351   Type *EltTy = FVTy->getElementType();
11352   if (EltTy->isPointerTy())
11353     FVTy =
11354         FixedVectorType::get(DL.getIntPtrType(EltTy), FVTy->getNumElements());
11355 
11356   IRBuilder<> Builder(LI);
11357 
11358   // The base address of the load.
11359   Value *BaseAddr = LI->getPointerOperand();
11360 
11361   if (NumLoads > 1) {
11362     // If we're going to generate more than one load, reset the sub-vector type
11363     // to something legal.
11364     FVTy = FixedVectorType::get(FVTy->getElementType(),
11365                                 FVTy->getNumElements() / NumLoads);
11366 
11367     // We will compute the pointer operand of each load from the original base
11368     // address using GEPs. Cast the base address to a pointer to the scalar
11369     // element type.
11370     BaseAddr = Builder.CreateBitCast(
11371         BaseAddr,
11372         FVTy->getElementType()->getPointerTo(LI->getPointerAddressSpace()));
11373   }
11374 
11375   Type *PtrTy = FVTy->getPointerTo(LI->getPointerAddressSpace());
11376   Type *Tys[2] = {FVTy, PtrTy};
11377   static const Intrinsic::ID LoadInts[3] = {Intrinsic::aarch64_neon_ld2,
11378                                             Intrinsic::aarch64_neon_ld3,
11379                                             Intrinsic::aarch64_neon_ld4};
11380   Function *LdNFunc =
11381       Intrinsic::getDeclaration(LI->getModule(), LoadInts[Factor - 2], Tys);
11382 
11383   // Holds sub-vectors extracted from the load intrinsic return values. The
11384   // sub-vectors are associated with the shufflevector instructions they will
11385   // replace.
11386   DenseMap<ShuffleVectorInst *, SmallVector<Value *, 4>> SubVecs;
11387 
11388   for (unsigned LoadCount = 0; LoadCount < NumLoads; ++LoadCount) {
11389 
11390     // If we're generating more than one load, compute the base address of
11391     // subsequent loads as an offset from the previous.
11392     if (LoadCount > 0)
11393       BaseAddr = Builder.CreateConstGEP1_32(FVTy->getElementType(), BaseAddr,
11394                                             FVTy->getNumElements() * Factor);
11395 
11396     CallInst *LdN = Builder.CreateCall(
11397         LdNFunc, Builder.CreateBitCast(BaseAddr, PtrTy), "ldN");
11398 
11399     // Extract and store the sub-vectors returned by the load intrinsic.
11400     for (unsigned i = 0; i < Shuffles.size(); i++) {
11401       ShuffleVectorInst *SVI = Shuffles[i];
11402       unsigned Index = Indices[i];
11403 
11404       Value *SubVec = Builder.CreateExtractValue(LdN, Index);
11405 
11406       // Convert the integer vector to a pointer vector if the elements are pointers.
11407       if (EltTy->isPointerTy())
11408         SubVec = Builder.CreateIntToPtr(
11409             SubVec, FixedVectorType::get(SVI->getType()->getElementType(),
11410                                          FVTy->getNumElements()));
11411       SubVecs[SVI].push_back(SubVec);
11412     }
11413   }
11414 
11415   // Replace uses of the shufflevector instructions with the sub-vectors
11416   // returned by the load intrinsic. If a shufflevector instruction is
11417   // associated with more than one sub-vector, those sub-vectors will be
11418   // concatenated into a single wide vector.
11419   for (ShuffleVectorInst *SVI : Shuffles) {
11420     auto &SubVec = SubVecs[SVI];
11421     auto *WideVec =
11422         SubVec.size() > 1 ? concatenateVectors(Builder, SubVec) : SubVec[0];
11423     SVI->replaceAllUsesWith(WideVec);
11424   }
11425 
11426   return true;
11427 }
11428 
11429 /// Lower an interleaved store into a stN intrinsic.
11430 ///
11431 /// E.g. Lower an interleaved store (Factor = 3):
11432 ///        %i.vec = shuffle <8 x i32> %v0, <8 x i32> %v1,
11433 ///                 <0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11>
11434 ///        store <12 x i32> %i.vec, <12 x i32>* %ptr
11435 ///
11436 ///      Into:
11437 ///        %sub.v0 = shuffle <8 x i32> %v0, <8 x i32> v1, <0, 1, 2, 3>
11438 ///        %sub.v1 = shuffle <8 x i32> %v0, <8 x i32> v1, <4, 5, 6, 7>
11439 ///        %sub.v2 = shuffle <8 x i32> %v0, <8 x i32> v1, <8, 9, 10, 11>
11440 ///        call void llvm.aarch64.neon.st3(%sub.v0, %sub.v1, %sub.v2, %ptr)
11441 ///
11442 /// Note that the new shufflevectors will be removed and we'll only generate one
11443 /// st3 instruction in CodeGen.
11444 ///
11445 /// Example for a more general valid mask (Factor 3). Lower:
11446 ///        %i.vec = shuffle <32 x i32> %v0, <32 x i32> %v1,
11447 ///                 <4, 32, 16, 5, 33, 17, 6, 34, 18, 7, 35, 19>
11448 ///        store <12 x i32> %i.vec, <12 x i32>* %ptr
11449 ///
11450 ///      Into:
11451 ///        %sub.v0 = shuffle <32 x i32> %v0, <32 x i32> v1, <4, 5, 6, 7>
11452 ///        %sub.v1 = shuffle <32 x i32> %v0, <32 x i32> v1, <32, 33, 34, 35>
11453 ///        %sub.v2 = shuffle <32 x i32> %v0, <32 x i32> v1, <16, 17, 18, 19>
11454 ///        call void llvm.aarch64.neon.st3(%sub.v0, %sub.v1, %sub.v2, %ptr)
11455 bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
11456                                                   ShuffleVectorInst *SVI,
11457                                                   unsigned Factor) const {
11458   assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
11459          "Invalid interleave factor");
11460 
11461   auto *VecTy = cast<FixedVectorType>(SVI->getType());
11462   assert(VecTy->getNumElements() % Factor == 0 && "Invalid interleaved store");
11463 
11464   unsigned LaneLen = VecTy->getNumElements() / Factor;
11465   Type *EltTy = VecTy->getElementType();
11466   auto *SubVecTy = FixedVectorType::get(EltTy, LaneLen);
11467 
11468   const DataLayout &DL = SI->getModule()->getDataLayout();
11469 
11470   // Skip if we do not have NEON or if the vector type is not legal. We can
11471   // "legalize" wide vector types into multiple interleaved accesses as long
11472   // as the vector types are divisible by 128.
11473   if (!Subtarget->hasNEON() || !isLegalInterleavedAccessType(SubVecTy, DL))
11474     return false;
11475 
11476   unsigned NumStores = getNumInterleavedAccesses(SubVecTy, DL);
11477 
11478   Value *Op0 = SVI->getOperand(0);
11479   Value *Op1 = SVI->getOperand(1);
11480   IRBuilder<> Builder(SI);
11481 
11482   // StN intrinsics don't support pointer vectors as arguments. Convert pointer
11483   // vectors to integer vectors.
11484   if (EltTy->isPointerTy()) {
11485     Type *IntTy = DL.getIntPtrType(EltTy);
11486     unsigned NumOpElts =
11487         cast<FixedVectorType>(Op0->getType())->getNumElements();
11488 
11489     // Convert to the corresponding integer vector.
11490     auto *IntVecTy = FixedVectorType::get(IntTy, NumOpElts);
11491     Op0 = Builder.CreatePtrToInt(Op0, IntVecTy);
11492     Op1 = Builder.CreatePtrToInt(Op1, IntVecTy);
11493 
11494     SubVecTy = FixedVectorType::get(IntTy, LaneLen);
11495   }
11496 
11497   // The base address of the store.
11498   Value *BaseAddr = SI->getPointerOperand();
11499 
11500   if (NumStores > 1) {
11501     // If we're going to generate more than one store, reset the lane length
11502     // and sub-vector type to something legal.
11503     LaneLen /= NumStores;
11504     SubVecTy = FixedVectorType::get(SubVecTy->getElementType(), LaneLen);
11505 
11506     // We will compute the pointer operand of each store from the original base
11507     // address using GEPs. Cast the base address to a pointer to the scalar
11508     // element type.
11509     BaseAddr = Builder.CreateBitCast(
11510         BaseAddr,
11511         SubVecTy->getElementType()->getPointerTo(SI->getPointerAddressSpace()));
11512   }
11513 
11514   auto Mask = SVI->getShuffleMask();
11515 
11516   Type *PtrTy = SubVecTy->getPointerTo(SI->getPointerAddressSpace());
11517   Type *Tys[2] = {SubVecTy, PtrTy};
11518   static const Intrinsic::ID StoreInts[3] = {Intrinsic::aarch64_neon_st2,
11519                                              Intrinsic::aarch64_neon_st3,
11520                                              Intrinsic::aarch64_neon_st4};
11521   Function *StNFunc =
11522       Intrinsic::getDeclaration(SI->getModule(), StoreInts[Factor - 2], Tys);
11523 
11524   for (unsigned StoreCount = 0; StoreCount < NumStores; ++StoreCount) {
11525 
11526     SmallVector<Value *, 5> Ops;
11527 
11528     // Split the shufflevector operands into sub vectors for the new stN call.
11529     for (unsigned i = 0; i < Factor; i++) {
11530       unsigned IdxI = StoreCount * LaneLen * Factor + i;
11531       if (Mask[IdxI] >= 0) {
11532         Ops.push_back(Builder.CreateShuffleVector(
11533             Op0, Op1, createSequentialMask(Mask[IdxI], LaneLen, 0)));
11534       } else {
11535         unsigned StartMask = 0;
11536         for (unsigned j = 1; j < LaneLen; j++) {
11537           unsigned IdxJ = StoreCount * LaneLen * Factor + j;
11538           if (Mask[IdxJ * Factor + IdxI] >= 0) {
11539             StartMask = Mask[IdxJ * Factor + IdxI] - IdxJ;
11540             break;
11541           }
11542         }
11543         // Note: Filling undef gaps with random elements is OK, since those
11544         // elements were being written anyway (with undefs). If the mask is
11545         // all undefs we default to using elements starting from 0.
11546         // Note: StartMask cannot be negative; that is checked in
11547         // isReInterleaveMask.
11548         Ops.push_back(Builder.CreateShuffleVector(
11549             Op0, Op1, createSequentialMask(StartMask, LaneLen, 0)));
11550       }
11551     }
11552 
11553     // If we're generating more than one store, compute the base address of
11554     // subsequent stores as an offset from the previous one.
11555     if (StoreCount > 0)
11556       BaseAddr = Builder.CreateConstGEP1_32(SubVecTy->getElementType(),
11557                                             BaseAddr, LaneLen * Factor);
11558 
11559     Ops.push_back(Builder.CreateBitCast(BaseAddr, PtrTy));
11560     Builder.CreateCall(StNFunc, Ops);
11561   }
11562   return true;
11563 }
11564 
11565 // Lower an SVE structured load intrinsic returning a tuple type to target
11566 // specific intrinsic taking the same input but returning a multi-result value
11567 // of the split tuple type.
11568 //
11569 // E.g. Lowering an LD3:
11570 //
11571 //  call <vscale x 12 x i32> @llvm.aarch64.sve.ld3.nxv12i32(
11572 //                                                    <vscale x 4 x i1> %pred,
11573 //                                                    <vscale x 4 x i32>* %addr)
11574 //
11575 //  Output DAG:
11576 //
11577 //    t0: ch = EntryToken
11578 //        t2: nxv4i1,ch = CopyFromReg t0, Register:nxv4i1 %0
11579 //        t4: i64,ch = CopyFromReg t0, Register:i64 %1
11580 //    t5: nxv4i32,nxv4i32,nxv4i32,ch = AArch64ISD::SVE_LD3 t0, t2, t4
11581 //    t6: nxv12i32 = concat_vectors t5, t5:1, t5:2
11582 //
11583 // This is called pre-legalization to avoid widening/splitting issues with
11584 // non-power-of-2 tuple types used for LD3, such as nxv12i32.
11585 SDValue AArch64TargetLowering::LowerSVEStructLoad(unsigned Intrinsic,
11586                                                   ArrayRef<SDValue> LoadOps,
11587                                                   EVT VT, SelectionDAG &DAG,
11588                                                   const SDLoc &DL) const {
11589   assert(VT.isScalableVector() && "Can only lower scalable vectors");
11590 
11591   unsigned N, Opcode;
11592   static std::map<unsigned, std::pair<unsigned, unsigned>> IntrinsicMap = {
11593       {Intrinsic::aarch64_sve_ld2, {2, AArch64ISD::SVE_LD2_MERGE_ZERO}},
11594       {Intrinsic::aarch64_sve_ld3, {3, AArch64ISD::SVE_LD3_MERGE_ZERO}},
11595       {Intrinsic::aarch64_sve_ld4, {4, AArch64ISD::SVE_LD4_MERGE_ZERO}}};
11596 
11597   std::tie(N, Opcode) = IntrinsicMap[Intrinsic];
11598   assert(VT.getVectorElementCount().getKnownMinValue() % N == 0 &&
11599          "invalid tuple vector type!");
11600 
11601   EVT SplitVT =
11602       EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
11603                        VT.getVectorElementCount().divideCoefficientBy(N));
11604   assert(isTypeLegal(SplitVT));
11605 
11606   SmallVector<EVT, 5> VTs(N, SplitVT);
11607   VTs.push_back(MVT::Other); // Chain
11608   SDVTList NodeTys = DAG.getVTList(VTs);
11609 
11610   SDValue PseudoLoad = DAG.getNode(Opcode, DL, NodeTys, LoadOps);
11611   SmallVector<SDValue, 4> PseudoLoadOps;
11612   for (unsigned I = 0; I < N; ++I)
11613     PseudoLoadOps.push_back(SDValue(PseudoLoad.getNode(), I));
11614   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, PseudoLoadOps);
11615 }
11616 
11617 EVT AArch64TargetLowering::getOptimalMemOpType(
11618     const MemOp &Op, const AttributeList &FuncAttributes) const {
11619   bool CanImplicitFloat =
11620       !FuncAttributes.hasFnAttribute(Attribute::NoImplicitFloat);
11621   bool CanUseNEON = Subtarget->hasNEON() && CanImplicitFloat;
11622   bool CanUseFP = Subtarget->hasFPARMv8() && CanImplicitFloat;
11623   // Only use AdvSIMD to implement memsets of 32 bytes and above. Below that
11624   // it would take one instruction to materialize the v2i64 zero and one store
11625   // (with a restrictive addressing mode), so just do i64 stores.
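  // E.g. (illustrative, assuming NEON and implicit FP are allowed): a 64-byte,
  // 16-byte-aligned memset can use v2i64 stores, while a 16-byte memset falls
  // through to i64 stores below.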
11626   bool IsSmallMemset = Op.isMemset() && Op.size() < 32;
11627   auto AlignmentIsAcceptable = [&](EVT VT, Align AlignCheck) {
11628     if (Op.isAligned(AlignCheck))
11629       return true;
11630     bool Fast;
11631     return allowsMisalignedMemoryAccesses(VT, 0, Align(1),
11632                                           MachineMemOperand::MONone, &Fast) &&
11633            Fast;
11634   };
11635 
11636   if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
11637       AlignmentIsAcceptable(MVT::v2i64, Align(16)))
11638     return MVT::v2i64;
11639   if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
11640     return MVT::f128;
11641   if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))
11642     return MVT::i64;
11643   if (Op.size() >= 4 && AlignmentIsAcceptable(MVT::i32, Align(4)))
11644     return MVT::i32;
11645   return MVT::Other;
11646 }
11647 
11648 LLT AArch64TargetLowering::getOptimalMemOpLLT(
11649     const MemOp &Op, const AttributeList &FuncAttributes) const {
11650   bool CanImplicitFloat =
11651       !FuncAttributes.hasFnAttribute(Attribute::NoImplicitFloat);
11652   bool CanUseNEON = Subtarget->hasNEON() && CanImplicitFloat;
11653   bool CanUseFP = Subtarget->hasFPARMv8() && CanImplicitFloat;
11654   // Only use AdvSIMD to implement memsets of 32 bytes and above. Below that
11655   // it would take one instruction to materialize the v2i64 zero and one store
11656   // (with a restrictive addressing mode), so just do i64 stores.
11657   bool IsSmallMemset = Op.isMemset() && Op.size() < 32;
11658   auto AlignmentIsAcceptable = [&](EVT VT, Align AlignCheck) {
11659     if (Op.isAligned(AlignCheck))
11660       return true;
11661     bool Fast;
11662     return allowsMisalignedMemoryAccesses(VT, 0, Align(1),
11663                                           MachineMemOperand::MONone, &Fast) &&
11664            Fast;
11665   };
11666 
11667   if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
11668       AlignmentIsAcceptable(MVT::v2i64, Align(16)))
11669     return LLT::vector(2, 64);
11670   if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
11671     return LLT::scalar(128);
11672   if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))
11673     return LLT::scalar(64);
11674   if (Op.size() >= 4 && AlignmentIsAcceptable(MVT::i32, Align(4)))
11675     return LLT::scalar(32);
11676   return LLT();
11677 }
11678 
11679 // 12-bit optionally shifted immediates are legal for adds.
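// E.g. 0xfff (4095) and 0xfff000 are both legal, while 0x1001 is not: its low
// 12 bits are nonzero and it does not fit in an unshifted 12-bit immediate.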
11680 bool AArch64TargetLowering::isLegalAddImmediate(int64_t Immed) const {
11681   if (Immed == std::numeric_limits<int64_t>::min()) {
11682     LLVM_DEBUG(dbgs() << "Illegal add imm " << Immed
11683                       << ": avoid UB for INT64_MIN\n");
11684     return false;
11685   }
11686   // Same encoding for add/sub, just flip the sign.
11687   Immed = std::abs(Immed);
11688   bool IsLegal = ((Immed >> 12) == 0 ||
11689                   ((Immed & 0xfff) == 0 && Immed >> 24 == 0));
11690   LLVM_DEBUG(dbgs() << "Is " << Immed
11691                     << " legal add imm: " << (IsLegal ? "yes" : "no") << "\n");
11692   return IsLegal;
11693 }
11694 
11695 // Integer comparisons are implemented with ADDS/SUBS, so the range of valid
11696 // immediates is the same as for an add or a sub.
11697 bool AArch64TargetLowering::isLegalICmpImmediate(int64_t Immed) const {
11698   return isLegalAddImmediate(Immed);
11699 }
11700 
11701 /// isLegalAddressingMode - Return true if the addressing mode represented
11702 /// by AM is legal for this target, for a load/store of the specified type.
11703 bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
11704                                                   const AddrMode &AM, Type *Ty,
11705                                                   unsigned AS, Instruction *I) const {
11706   // AArch64 has five basic addressing modes:
11707   //  reg
11708   //  reg + 9-bit signed offset
11709   //  reg + SIZE_IN_BYTES * 12-bit unsigned offset
11710   //  reg1 + reg2
11711   //  reg + SIZE_IN_BYTES * reg
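  // E.g. for an i64 load with no scaled register, a byte offset of 32760
  // (4095 * 8) is legal, while -260 is not: it is outside the 9-bit signed
  // range and the unsigned form requires a positive multiple of the size.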
11712 
11713   // No global is ever allowed as a base.
11714   if (AM.BaseGV)
11715     return false;
11716 
11717   // No reg+reg+imm addressing.
11718   if (AM.HasBaseReg && AM.BaseOffs && AM.Scale)
11719     return false;
11720 
11721   // FIXME: Update this method to support scalable addressing modes.
11722   if (isa<ScalableVectorType>(Ty))
11723     return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale;
11724 
11725   // check reg + imm case:
11726   // i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12
11727   uint64_t NumBytes = 0;
11728   if (Ty->isSized()) {
11729     uint64_t NumBits = DL.getTypeSizeInBits(Ty);
11730     NumBytes = NumBits / 8;
11731     if (!isPowerOf2_64(NumBits))
11732       NumBytes = 0;
11733   }
11734 
11735   if (!AM.Scale) {
11736     int64_t Offset = AM.BaseOffs;
11737 
11738     // 9-bit signed offset
11739     if (isInt<9>(Offset))
11740       return true;
11741 
11742     // 12-bit unsigned offset
11743     unsigned shift = Log2_64(NumBytes);
11744     if (NumBytes && Offset > 0 && (Offset / NumBytes) <= (1LL << 12) - 1 &&
11745         // Must be a multiple of NumBytes (NumBytes is a power of 2)
11746         (Offset >> shift) << shift == Offset)
11747       return true;
11748     return false;
11749   }
11750 
11751   // Check reg1 + SIZE_IN_BYTES * reg2 and reg1 + reg2
11752 
11753   return AM.Scale == 1 || (AM.Scale > 0 && (uint64_t)AM.Scale == NumBytes);
11754 }
11755 
11756 bool AArch64TargetLowering::shouldConsiderGEPOffsetSplit() const {
11757   // Consider splitting large offset of struct or array.
11758   return true;
11759 }
11760 
11761 InstructionCost AArch64TargetLowering::getScalingFactorCost(
11762     const DataLayout &DL, const AddrMode &AM, Type *Ty, unsigned AS) const {
11763   // Scaling factors are not free at all.
11764   // Operands                     | Rt Latency
11765   // -------------------------------------------
11766   // Rt, [Xn, Xm]                 | 4
11767   // -------------------------------------------
11768   // Rt, [Xn, Xm, lsl #imm]       | Rn: 4 Rm: 5
11769   // Rt, [Xn, Wm, <extend> #imm]  |
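  // E.g. for an i64 access, [Xn, Xm, lsl #3] (Scale == 8) is legal but is
  // charged a cost of 1 below, while a Scale of 0 or 1 is treated as free.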
11770   if (isLegalAddressingMode(DL, AM, Ty, AS))
11771     // Scale represents reg2 * scale, thus account for 1 if
11772     // it is not equal to 0 or 1.
11773     return AM.Scale != 0 && AM.Scale != 1;
11774   return -1;
11775 }
11776 
11777 bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(
11778     const MachineFunction &MF, EVT VT) const {
11779   VT = VT.getScalarType();
11780 
11781   if (!VT.isSimple())
11782     return false;
11783 
11784   switch (VT.getSimpleVT().SimpleTy) {
11785   case MVT::f16:
11786     return Subtarget->hasFullFP16();
11787   case MVT::f32:
11788   case MVT::f64:
11789     return true;
11790   default:
11791     break;
11792   }
11793 
11794   return false;
11795 }
11796 
11797 bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(const Function &F,
11798                                                        Type *Ty) const {
11799   switch (Ty->getScalarType()->getTypeID()) {
11800   case Type::FloatTyID:
11801   case Type::DoubleTyID:
11802     return true;
11803   default:
11804     return false;
11805   }
11806 }
11807 
11808 bool AArch64TargetLowering::generateFMAsInMachineCombiner(
11809     EVT VT, CodeGenOpt::Level OptLevel) const {
11810   return (OptLevel >= CodeGenOpt::Aggressive) && !VT.isScalableVector();
11811 }
11812 
11813 const MCPhysReg *
11814 AArch64TargetLowering::getScratchRegisters(CallingConv::ID) const {
11815   // LR is a callee-save register, but we must treat it as clobbered by any call
11816   // site. Hence we include LR in the scratch registers, which are in turn added
11817   // as implicit-defs for stackmaps and patchpoints.
11818   static const MCPhysReg ScratchRegs[] = {
11819     AArch64::X16, AArch64::X17, AArch64::LR, 0
11820   };
11821   return ScratchRegs;
11822 }
11823 
11824 bool
11825 AArch64TargetLowering::isDesirableToCommuteWithShift(const SDNode *N,
11826                                                      CombineLevel Level) const {
11827   N = N->getOperand(0).getNode();
11828   EVT VT = N->getValueType(0);
11829   // If N is an unsigned bit extraction ((x >> C) & mask), do not combine it
11830   // with a shift so that it can be lowered to UBFX.
11831   if (N->getOpcode() == ISD::AND && (VT == MVT::i32 || VT == MVT::i64) &&
11832       isa<ConstantSDNode>(N->getOperand(1))) {
11833     uint64_t TruncMask = N->getConstantOperandVal(1);
11834     if (isMask_64(TruncMask) &&
11835         N->getOperand(0).getOpcode() == ISD::SRL &&
11836         isa<ConstantSDNode>(N->getOperand(0)->getOperand(1)))
11837       return false;
11838   }
11839   return true;
11840 }
11841 
11842 bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
11843                                                               Type *Ty) const {
11844   assert(Ty->isIntegerTy());
11845 
11846   unsigned BitSize = Ty->getPrimitiveSizeInBits();
11847   if (BitSize == 0)
11848     return false;
11849 
11850   int64_t Val = Imm.getSExtValue();
11851   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, BitSize))
11852     return true;
11853 
11854   if ((int64_t)Val < 0)
11855     Val = ~Val;
11856   if (BitSize == 32)
11857     Val &= (1LL << 32) - 1;
11858 
11859   unsigned LZ = countLeadingZeros((uint64_t)Val);
11860   unsigned Shift = (63 - LZ) / 16;
11861   // MOVZ is free so return true for one or fewer MOVK.
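  // E.g. 0x0000123456780000 gives Shift == 2 (MOVZ plus at most two MOVKs), so
  // it is kept as an immediate; 0x1234567890ABCDEF gives Shift == 3 and is not.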
11862   return Shift < 3;
11863 }
11864 
11865 bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
11866                                                     unsigned Index) const {
11867   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
11868     return false;
11869 
11870   return (Index == 0 || Index == ResVT.getVectorNumElements());
11871 }
11872 
11873 /// Turn vector tests of the signbit in the form of:
11874 ///   xor (sra X, elt_size(X)-1), -1
11875 /// into:
11876 ///   cmge X, X, #0
11877 static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
11878                                          const AArch64Subtarget *Subtarget) {
11879   EVT VT = N->getValueType(0);
11880   if (!Subtarget->hasNEON() || !VT.isVector())
11881     return SDValue();
11882 
11883   // There must be a shift right algebraic before the xor, and the xor must be a
11884   // 'not' operation.
11885   SDValue Shift = N->getOperand(0);
11886   SDValue Ones = N->getOperand(1);
11887   if (Shift.getOpcode() != AArch64ISD::VASHR || !Shift.hasOneUse() ||
11888       !ISD::isBuildVectorAllOnes(Ones.getNode()))
11889     return SDValue();
11890 
11891   // The shift should be smearing the sign bit across each vector element.
11892   auto *ShiftAmt = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
11893   EVT ShiftEltTy = Shift.getValueType().getVectorElementType();
11894   if (!ShiftAmt || ShiftAmt->getZExtValue() != ShiftEltTy.getSizeInBits() - 1)
11895     return SDValue();
11896 
11897   return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0));
11898 }
11899 
11900 // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
11901 //   vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
11902 //   vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
11903 static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
11904                                           const AArch64Subtarget *ST) {
11905   SDValue Op0 = N->getOperand(0);
11906   if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32 ||
11907       Op0.getValueType().getVectorElementType() != MVT::i32)
11908     return SDValue();
11909 
11910   unsigned ExtOpcode = Op0.getOpcode();
11911   SDValue A = Op0;
11912   SDValue B;
11913   if (ExtOpcode == ISD::MUL) {
11914     A = Op0.getOperand(0);
11915     B = Op0.getOperand(1);
11916     if (A.getOpcode() != B.getOpcode() ||
11917         A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
11918       return SDValue();
11919     ExtOpcode = A.getOpcode();
11920   }
11921   if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
11922     return SDValue();
11923 
11924   EVT Op0VT = A.getOperand(0).getValueType();
11925   if (Op0VT != MVT::v8i8 && Op0VT != MVT::v16i8)
11926     return SDValue();
11927 
11928   SDLoc DL(Op0);
11929   // For non-MLA reductions B can be set to 1. For an MLA reduction we take
11930   // the operand of the extend B.
11931   if (!B)
11932     B = DAG.getConstant(1, DL, Op0VT);
11933   else
11934     B = B.getOperand(0);
11935 
11936   SDValue Zeros =
11937       DAG.getConstant(0, DL, Op0VT == MVT::v8i8 ? MVT::v2i32 : MVT::v4i32);
11938   auto DotOpcode =
11939       (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
11940   SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
11941                             A.getOperand(0), B);
11942   return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
11943 }
11944 
11945 // Given a ABS node, detect the following pattern:
11946 // (ABS (SUB (EXTEND a), (EXTEND b))).
11947 // Generates UABD/SABD instruction.
11948 static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
11949                                  TargetLowering::DAGCombinerInfo &DCI,
11950                                  const AArch64Subtarget *Subtarget) {
11951   SDValue AbsOp1 = N->getOperand(0);
11952   SDValue Op0, Op1;
11953 
11954   if (AbsOp1.getOpcode() != ISD::SUB)
11955     return SDValue();
11956 
11957   Op0 = AbsOp1.getOperand(0);
11958   Op1 = AbsOp1.getOperand(1);
11959 
11960   unsigned Opc0 = Op0.getOpcode();
11961   // Check if the operands of the sub are (zero|sign)-extended.
11962   if (Opc0 != Op1.getOpcode() ||
11963       (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
11964     return SDValue();
11965 
11966   EVT VectorT1 = Op0.getOperand(0).getValueType();
11967   EVT VectorT2 = Op1.getOperand(0).getValueType();
11968   // Check if vectors are of same type and valid size.
11969   uint64_t Size = VectorT1.getFixedSizeInBits();
11970   if (VectorT1 != VectorT2 || (Size != 64 && Size != 128))
11971     return SDValue();
11972 
11973   // Check if vector element types are valid.
11974   EVT VT1 = VectorT1.getVectorElementType();
11975   if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32)
11976     return SDValue();
11977 
11978   Op0 = Op0.getOperand(0);
11979   Op1 = Op1.getOperand(0);
11980   unsigned ABDOpcode =
11981       (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD;
11982   SDValue ABD =
11983       DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
11984   return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
11985 }
11986 
11987 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
11988                                  TargetLowering::DAGCombinerInfo &DCI,
11989                                  const AArch64Subtarget *Subtarget) {
11990   if (DCI.isBeforeLegalizeOps())
11991     return SDValue();
11992 
11993   return foldVectorXorShiftIntoCmp(N, DAG, Subtarget);
11994 }
11995 
11996 SDValue
11997 AArch64TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor,
11998                                      SelectionDAG &DAG,
11999                                      SmallVectorImpl<SDNode *> &Created) const {
12000   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
12001   if (isIntDivCheap(N->getValueType(0), Attr))
12002     return SDValue(N,0); // Lower SDIV as SDIV
12003 
12004   // fold (sdiv X, pow2)
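  // E.g. X sdiv 8 becomes: add the bias (X < 0 ? 7 : 0) via CSEL, then shift
  // right arithmetically by 3; a negative divisor additionally negates the
  // result at the end.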
12005   EVT VT = N->getValueType(0);
12006   if ((VT != MVT::i32 && VT != MVT::i64) ||
12007       !(Divisor.isPowerOf2() || (-Divisor).isPowerOf2()))
12008     return SDValue();
12009 
12010   SDLoc DL(N);
12011   SDValue N0 = N->getOperand(0);
12012   unsigned Lg2 = Divisor.countTrailingZeros();
12013   SDValue Zero = DAG.getConstant(0, DL, VT);
12014   SDValue Pow2MinusOne = DAG.getConstant((1ULL << Lg2) - 1, DL, VT);
12015 
12016   // Add (N0 < 0) ? Pow2 - 1 : 0;
12017   SDValue CCVal;
12018   SDValue Cmp = getAArch64Cmp(N0, Zero, ISD::SETLT, CCVal, DAG, DL);
12019   SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Pow2MinusOne);
12020   SDValue CSel = DAG.getNode(AArch64ISD::CSEL, DL, VT, Add, N0, CCVal, Cmp);
12021 
12022   Created.push_back(Cmp.getNode());
12023   Created.push_back(Add.getNode());
12024   Created.push_back(CSel.getNode());
12025 
12026   // Divide by pow2.
12027   SDValue SRA =
12028       DAG.getNode(ISD::SRA, DL, VT, CSel, DAG.getConstant(Lg2, DL, MVT::i64));
12029 
12030   // If we're dividing by a positive value, we're done.  Otherwise, we must
12031   // negate the result.
12032   if (Divisor.isNonNegative())
12033     return SRA;
12034 
12035   Created.push_back(SRA.getNode());
12036   return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), SRA);
12037 }
12038 
12039 static bool IsSVECntIntrinsic(SDValue S) {
12040   switch(getIntrinsicID(S.getNode())) {
12041   default:
12042     break;
12043   case Intrinsic::aarch64_sve_cntb:
12044   case Intrinsic::aarch64_sve_cnth:
12045   case Intrinsic::aarch64_sve_cntw:
12046   case Intrinsic::aarch64_sve_cntd:
12047     return true;
12048   }
12049   return false;
12050 }
12051 
12052 /// Calculates what the pre-extend type is, based on the extension
12053 /// operation node provided by \p Extend.
12054 ///
12055 /// In the case that \p Extend is a SIGN_EXTEND or a ZERO_EXTEND, the
12056 /// pre-extend type is pulled directly from the operand, while other extend
12057 /// operations need a bit more inspection to get this information.
12058 ///
12059 /// \param Extend The SDNode from the DAG that represents the extend operation
12060 /// \param DAG The SelectionDAG hosting the \p Extend node
12061 ///
12062 /// \returns The type representing the \p Extend source type, or \p MVT::Other
12063 /// if no valid type can be determined
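/// E.g. (and x, 0xff) yields MVT::i8, and an AssertSext/AssertZext of i16
/// yields MVT::i16.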
12064 static EVT calculatePreExtendType(SDValue Extend, SelectionDAG &DAG) {
12065   switch (Extend.getOpcode()) {
12066   case ISD::SIGN_EXTEND:
12067   case ISD::ZERO_EXTEND:
12068     return Extend.getOperand(0).getValueType();
12069   case ISD::AssertSext:
12070   case ISD::AssertZext:
12071   case ISD::SIGN_EXTEND_INREG: {
12072     VTSDNode *TypeNode = dyn_cast<VTSDNode>(Extend.getOperand(1));
12073     if (!TypeNode)
12074       return MVT::Other;
12075     return TypeNode->getVT();
12076   }
12077   case ISD::AND: {
12078     ConstantSDNode *Constant =
12079         dyn_cast<ConstantSDNode>(Extend.getOperand(1).getNode());
12080     if (!Constant)
12081       return MVT::Other;
12082 
12083     uint32_t Mask = Constant->getZExtValue();
12084 
12085     if (Mask == UCHAR_MAX)
12086       return MVT::i8;
12087     else if (Mask == USHRT_MAX)
12088       return MVT::i16;
12089     else if (Mask == UINT_MAX)
12090       return MVT::i32;
12091 
12092     return MVT::Other;
12093   }
12094   default:
12095     return MVT::Other;
12096   }
12097 
12098   llvm_unreachable("Code path unhandled in calculatePreExtendType!");
12099 }
12100 
12101 /// Combines a dup(sext/zext) node pattern into sext/zext(dup)
12102 /// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
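/// E.g. (roughly) a v8i16 splat of (sext i8 %b to i16) becomes
/// (sext (v8i8 splat of %b) to v8i16).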
12103 static SDValue performCommonVectorExtendCombine(SDValue VectorShuffle,
12104                                                 SelectionDAG &DAG) {
12105 
12106   ShuffleVectorSDNode *ShuffleNode =
12107       dyn_cast<ShuffleVectorSDNode>(VectorShuffle.getNode());
12108   if (!ShuffleNode)
12109     return SDValue();
12110 
12111   // Ensure the shuffle is a splat of lane 0 before continuing.
12112   if (!ShuffleNode->isSplat() || ShuffleNode->getSplatIndex() != 0)
12113     return SDValue();
12114 
12115   SDValue InsertVectorElt = VectorShuffle.getOperand(0);
12116 
12117   if (InsertVectorElt.getOpcode() != ISD::INSERT_VECTOR_ELT)
12118     return SDValue();
12119 
12120   SDValue InsertLane = InsertVectorElt.getOperand(2);
12121   ConstantSDNode *Constant = dyn_cast<ConstantSDNode>(InsertLane.getNode());
12122   // Ensures the insert is inserting into lane 0
12123   if (!Constant || Constant->getZExtValue() != 0)
12124     return SDValue();
12125 
12126   SDValue Extend = InsertVectorElt.getOperand(1);
12127   unsigned ExtendOpcode = Extend.getOpcode();
12128 
12129   bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
12130                 ExtendOpcode == ISD::SIGN_EXTEND_INREG ||
12131                 ExtendOpcode == ISD::AssertSext;
12132   if (!IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
12133       ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND)
12134     return SDValue();
12135 
12136   EVT TargetType = VectorShuffle.getValueType();
12137   EVT PreExtendType = calculatePreExtendType(Extend, DAG);
12138 
12139   if ((TargetType != MVT::v8i16 && TargetType != MVT::v4i32 &&
12140        TargetType != MVT::v2i64) ||
12141       (PreExtendType == MVT::Other))
12142     return SDValue();
12143 
12144   // Restrict valid pre-extend data type
12145   if (PreExtendType != MVT::i8 && PreExtendType != MVT::i16 &&
12146       PreExtendType != MVT::i32)
12147     return SDValue();
12148 
12149   EVT PreExtendVT = TargetType.changeVectorElementType(PreExtendType);
12150 
12151   if (PreExtendVT.getVectorElementCount() != TargetType.getVectorElementCount())
12152     return SDValue();
12153 
12154   if (TargetType.getScalarSizeInBits() != PreExtendVT.getScalarSizeInBits() * 2)
12155     return SDValue();
12156 
12157   SDLoc DL(VectorShuffle);
12158 
12159   SDValue InsertVectorNode = DAG.getNode(
12160       InsertVectorElt.getOpcode(), DL, PreExtendVT, DAG.getUNDEF(PreExtendVT),
12161       DAG.getAnyExtOrTrunc(Extend.getOperand(0), DL, PreExtendType),
12162       DAG.getConstant(0, DL, MVT::i64));
12163 
12164   std::vector<int> ShuffleMask(TargetType.getVectorElementCount().getValue());
12165 
12166   SDValue VectorShuffleNode =
12167       DAG.getVectorShuffle(PreExtendVT, DL, InsertVectorNode,
12168                            DAG.getUNDEF(PreExtendVT), ShuffleMask);
12169 
12170   SDValue ExtendNode = DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND,
12171                                    DL, TargetType, VectorShuffleNode);
12172 
12173   return ExtendNode;
12174 }
12175 
12176 /// Combines a mul(dup(sext/zext)) node pattern into mul(sext/zext(dup))
12177 /// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
12178 static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
12179   // If the value type isn't a vector, none of the operands are going to be dups
12180   if (!Mul->getValueType(0).isVector())
12181     return SDValue();
12182 
12183   SDValue Op0 = performCommonVectorExtendCombine(Mul->getOperand(0), DAG);
12184   SDValue Op1 = performCommonVectorExtendCombine(Mul->getOperand(1), DAG);
12185 
12186   // Neither operand has been changed; don't make any further changes.
12187   if (!Op0 && !Op1)
12188     return SDValue();
12189 
12190   SDLoc DL(Mul);
12191   return DAG.getNode(Mul->getOpcode(), DL, Mul->getValueType(0),
12192                      Op0 ? Op0 : Mul->getOperand(0),
12193                      Op1 ? Op1 : Mul->getOperand(1));
12194 }
12195 
12196 static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
12197                                  TargetLowering::DAGCombinerInfo &DCI,
12198                                  const AArch64Subtarget *Subtarget) {
12199 
12200   if (SDValue Ext = performMulVectorExtendCombine(N, DAG))
12201     return Ext;
12202 
12203   if (DCI.isBeforeLegalizeOps())
12204     return SDValue();
12205 
12206   // The below optimizations require a constant RHS.
12207   if (!isa<ConstantSDNode>(N->getOperand(1)))
12208     return SDValue();
12209 
12210   SDValue N0 = N->getOperand(0);
12211   ConstantSDNode *C = cast<ConstantSDNode>(N->getOperand(1));
12212   const APInt &ConstValue = C->getAPIntValue();
12213 
12214   // Allow the scaling to be folded into the `cnt` instruction by preventing
12215   // the scaling from being obscured here. This makes it easier to pattern match.
12216   if (IsSVECntIntrinsic(N0) ||
12217      (N0->getOpcode() == ISD::TRUNCATE &&
12218       (IsSVECntIntrinsic(N0->getOperand(0)))))
12219        if (ConstValue.sge(1) && ConstValue.sle(16))
12220          return SDValue();
12221 
12222   // Multiplication of a power of two plus/minus one can be done more
12223   // cheaply as a shift+add/sub. For now, this is true unilaterally. If
12224   // future CPUs have a cheaper MADD instruction, this may need to be
12225   // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
12226   // 64-bit is 5 cycles, so this is always a win.
12227   // More aggressively, some multiplications N0 * C can be lowered to
12228   // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
12229   // e.g. 6=3*2=(2+1)*2.
12230   // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
12231   // which equals (1+2)*16-(1+2).
12232   // TrailingZeroes is used to test if the mul can be lowered to
12233   // shift+add+shift.
12234   unsigned TrailingZeroes = ConstValue.countTrailingZeros();
12235   if (TrailingZeroes) {
12236     // Conservatively do not lower to shift+add+shift if the mul might be
12237     // folded into smul or umul.
12238     if (N0->hasOneUse() && (isSignExtended(N0.getNode(), DAG) ||
12239                             isZeroExtended(N0.getNode(), DAG)))
12240       return SDValue();
12241     // Conservatively do not lower to shift+add+shift if the mul might be
12242     // folded into madd or msub.
12243     if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ADD ||
12244                            N->use_begin()->getOpcode() == ISD::SUB))
12245       return SDValue();
12246   }
12247   // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
12248   // and shift+add+shift.
12249   APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);
12250 
12251   unsigned ShiftAmt, AddSubOpc;
12252   // Is the shifted value the LHS operand of the add/sub?
12253   bool ShiftValUseIsN0 = true;
12254   // Do we need to negate the result?
12255   bool NegateResult = false;
12256 
12257   if (ConstValue.isNonNegative()) {
12258     // (mul x, 2^N + 1) => (add (shl x, N), x)
12259     // (mul x, 2^N - 1) => (sub (shl x, N), x)
12260     // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
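    // E.g. (mul x, 12): TrailingZeroes == 2 and ShiftedConstValue == 3, so this
    // can be lowered as (shl (add (shl x, 1), x), 2).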
12261     APInt SCVMinus1 = ShiftedConstValue - 1;
12262     APInt CVPlus1 = ConstValue + 1;
12263     if (SCVMinus1.isPowerOf2()) {
12264       ShiftAmt = SCVMinus1.logBase2();
12265       AddSubOpc = ISD::ADD;
12266     } else if (CVPlus1.isPowerOf2()) {
12267       ShiftAmt = CVPlus1.logBase2();
12268       AddSubOpc = ISD::SUB;
12269     } else
12270       return SDValue();
12271   } else {
12272     // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
12273     // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
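    // E.g. (mul x, -5): CVNegMinus1 == 4 is a power of two, so this can be
    // lowered as (sub 0, (add (shl x, 2), x)).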
12274     APInt CVNegPlus1 = -ConstValue + 1;
12275     APInt CVNegMinus1 = -ConstValue - 1;
12276     if (CVNegPlus1.isPowerOf2()) {
12277       ShiftAmt = CVNegPlus1.logBase2();
12278       AddSubOpc = ISD::SUB;
12279       ShiftValUseIsN0 = false;
12280     } else if (CVNegMinus1.isPowerOf2()) {
12281       ShiftAmt = CVNegMinus1.logBase2();
12282       AddSubOpc = ISD::ADD;
12283       NegateResult = true;
12284     } else
12285       return SDValue();
12286   }
12287 
12288   SDLoc DL(N);
12289   EVT VT = N->getValueType(0);
12290   SDValue ShiftedVal = DAG.getNode(ISD::SHL, DL, VT, N0,
12291                                    DAG.getConstant(ShiftAmt, DL, MVT::i64));
12292 
12293   SDValue AddSubN0 = ShiftValUseIsN0 ? ShiftedVal : N0;
12294   SDValue AddSubN1 = ShiftValUseIsN0 ? N0 : ShiftedVal;
12295   SDValue Res = DAG.getNode(AddSubOpc, DL, VT, AddSubN0, AddSubN1);
12296   assert(!(NegateResult && TrailingZeroes) &&
12297          "NegateResult and TrailingZeroes cannot both be true for now.");
12298   // Negate the result.
12299   if (NegateResult)
12300     return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Res);
12301   // Shift the result.
12302   if (TrailingZeroes)
12303     return DAG.getNode(ISD::SHL, DL, VT, Res,
12304                        DAG.getConstant(TrailingZeroes, DL, MVT::i64));
12305   return Res;
12306 }
12307 
12308 static SDValue performVectorCompareAndMaskUnaryOpCombine(SDNode *N,
12309                                                          SelectionDAG &DAG) {
12310   // Take advantage of vector comparisons producing 0 or -1 in each lane to
12311   // optimize away operation when it's from a constant.
12312   //
12313   // The general transformation is:
12314   //    UNARYOP(AND(VECTOR_CMP(x,y), constant)) -->
12315   //       AND(VECTOR_CMP(x,y), constant2)
12316   //    constant2 = UNARYOP(constant)
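  //
  // E.g. with UNARYOP == sitofp and constant == <4 x i32> splat 1, the result
  // is a bitcast of AND(VECTOR_CMP(x,y), bitcast(<4 x float> splat 1.0)).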
12317 
12318   // Early exit if this isn't a vector operation, the operand of the
12319   // unary operation isn't a bitwise AND, or if the sizes of the operations
12320   // aren't the same.
12321   EVT VT = N->getValueType(0);
12322   if (!VT.isVector() || N->getOperand(0)->getOpcode() != ISD::AND ||
12323       N->getOperand(0)->getOperand(0)->getOpcode() != ISD::SETCC ||
12324       VT.getSizeInBits() != N->getOperand(0)->getValueType(0).getSizeInBits())
12325     return SDValue();
12326 
12327   // Now check that the other operand of the AND is a constant. We could
12328   // make the transformation for non-constant splats as well, but it's unclear
12329   // that would be a benefit as it would not eliminate any operations, just
12330   // perform one more step in scalar code before moving to the vector unit.
12331   if (BuildVectorSDNode *BV =
12332           dyn_cast<BuildVectorSDNode>(N->getOperand(0)->getOperand(1))) {
12333     // Bail out if the vector isn't a constant.
12334     if (!BV->isConstant())
12335       return SDValue();
12336 
12337     // Everything checks out. Build up the new and improved node.
12338     SDLoc DL(N);
12339     EVT IntVT = BV->getValueType(0);
12340     // Create a new constant of the appropriate type for the transformed
12341     // DAG.
12342     SDValue SourceConst = DAG.getNode(N->getOpcode(), DL, VT, SDValue(BV, 0));
12343     // The AND node needs bitcasts to/from an integer vector type around it.
12344     SDValue MaskConst = DAG.getNode(ISD::BITCAST, DL, IntVT, SourceConst);
12345     SDValue NewAnd = DAG.getNode(ISD::AND, DL, IntVT,
12346                                  N->getOperand(0)->getOperand(0), MaskConst);
12347     SDValue Res = DAG.getNode(ISD::BITCAST, DL, VT, NewAnd);
12348     return Res;
12349   }
12350 
12351   return SDValue();
12352 }
12353 
12354 static SDValue performIntToFpCombine(SDNode *N, SelectionDAG &DAG,
12355                                      const AArch64Subtarget *Subtarget) {
12356   // First try to optimize away the conversion when it's conditionally from
12357   // a constant. Vectors only.
12358   if (SDValue Res = performVectorCompareAndMaskUnaryOpCombine(N, DAG))
12359     return Res;
12360 
12361   EVT VT = N->getValueType(0);
12362   if (VT != MVT::f32 && VT != MVT::f64)
12363     return SDValue();
12364 
12365   // Only optimize when the source and destination types have the same width.
12366   if (VT.getSizeInBits() != N->getOperand(0).getValueSizeInBits())
12367     return SDValue();
12368 
12369   // If the result of an integer load is only used by an integer-to-float
12370   // conversion, use an FP load and an AdvSIMD scalar {S|U}CVTF instead.
12371   // This eliminates an "integer-to-vector-move" UOP and improves throughput.
12372   SDValue N0 = N->getOperand(0);
12373   if (Subtarget->hasNEON() && ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
12374       // Do not change the width of a volatile load.
12375       !cast<LoadSDNode>(N0)->isVolatile()) {
12376     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
12377     SDValue Load = DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
12378                                LN0->getPointerInfo(), LN0->getAlignment(),
12379                                LN0->getMemOperand()->getFlags());
12380 
12381     // Make sure successors of the original load stay after it by updating them
12382     // to use the new Chain.
12383     DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), Load.getValue(1));
12384 
12385     unsigned Opcode =
12386         (N->getOpcode() == ISD::SINT_TO_FP) ? AArch64ISD::SITOF : AArch64ISD::UITOF;
12387     return DAG.getNode(Opcode, SDLoc(N), VT, Load);
12388   }
12389 
12390   return SDValue();
12391 }
12392 
12393 /// Fold a floating-point multiply by power of two into floating-point to
12394 /// fixed-point conversion.
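///
/// For example (an illustrative sketch, not from the original comment), with a
/// splat multiplier of 8.0 == 2^3,
///    (v2i32 (fp_to_sint (fmul (v2f32 x), (splat 8.0))))
/// can become a single fixed-point convert with 3 fractional bits, e.g.
///    fcvtzs v0.2s, v0.2s, #3
/// via the aarch64_neon_vcvtfp2fxs intrinsic used below.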
12395 static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
12396                                      TargetLowering::DAGCombinerInfo &DCI,
12397                                      const AArch64Subtarget *Subtarget) {
12398   if (!Subtarget->hasNEON())
12399     return SDValue();
12400 
12401   if (!N->getValueType(0).isSimple())
12402     return SDValue();
12403 
12404   SDValue Op = N->getOperand(0);
12405   if (!Op.getValueType().isVector() || !Op.getValueType().isSimple() ||
12406       Op.getOpcode() != ISD::FMUL)
12407     return SDValue();
12408 
12409   SDValue ConstVec = Op->getOperand(1);
12410   if (!isa<BuildVectorSDNode>(ConstVec))
12411     return SDValue();
12412 
12413   MVT FloatTy = Op.getSimpleValueType().getVectorElementType();
12414   uint32_t FloatBits = FloatTy.getSizeInBits();
12415   if (FloatBits != 32 && FloatBits != 64)
12416     return SDValue();
12417 
12418   MVT IntTy = N->getSimpleValueType(0).getVectorElementType();
12419   uint32_t IntBits = IntTy.getSizeInBits();
12420   if (IntBits != 16 && IntBits != 32 && IntBits != 64)
12421     return SDValue();
12422 
12423   // Avoid conversions where iN is larger than the float (e.g., float -> i64).
12424   if (IntBits > FloatBits)
12425     return SDValue();
12426 
12427   BitVector UndefElements;
12428   BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
12429   int32_t Bits = IntBits == 64 ? 64 : 32;
12430   int32_t C = BV->getConstantFPSplatPow2ToLog2Int(&UndefElements, Bits + 1);
12431   if (C == -1 || C == 0 || C > Bits)
12432     return SDValue();
12433 
12434   MVT ResTy;
12435   unsigned NumLanes = Op.getValueType().getVectorNumElements();
12436   switch (NumLanes) {
12437   default:
12438     return SDValue();
12439   case 2:
12440     ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
12441     break;
12442   case 4:
12443     ResTy = FloatBits == 32 ? MVT::v4i32 : MVT::v4i64;
12444     break;
12445   }
12446 
12447   if (ResTy == MVT::v4i64 && DCI.isBeforeLegalizeOps())
12448     return SDValue();
12449 
12450   assert((ResTy != MVT::v4i64 || DCI.isBeforeLegalizeOps()) &&
12451          "Illegal vector type after legalization");
12452 
12453   SDLoc DL(N);
12454   bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT;
12455   unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfp2fxs
12456                                       : Intrinsic::aarch64_neon_vcvtfp2fxu;
12457   SDValue FixConv =
12458       DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ResTy,
12459                   DAG.getConstant(IntrinsicOpcode, DL, MVT::i32),
12460                   Op->getOperand(0), DAG.getConstant(C, DL, MVT::i32));
12461   // We can handle smaller integers by generating an extra trunc.
12462   if (IntBits < FloatBits)
12463     FixConv = DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), FixConv);
12464 
12465   return FixConv;
12466 }
12467 
12468 /// Fold a floating-point divide by power of two into fixed-point to
12469 /// floating-point conversion.
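///
/// For example (an illustrative sketch, not from the original comment), with a
/// splat divisor of 16.0 == 2^4,
///    (v4f32 (fdiv (sint_to_fp (v4i32 x)), (splat 16.0)))
/// can become a single fixed-point convert with 4 fractional bits, e.g.
///    scvtf v0.4s, v0.4s, #4
/// via the aarch64_neon_vcvtfxs2fp intrinsic used below.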
12470 static SDValue performFDivCombine(SDNode *N, SelectionDAG &DAG,
12471                                   TargetLowering::DAGCombinerInfo &DCI,
12472                                   const AArch64Subtarget *Subtarget) {
12473   if (!Subtarget->hasNEON())
12474     return SDValue();
12475 
12476   SDValue Op = N->getOperand(0);
12477   unsigned Opc = Op->getOpcode();
12478   if (!Op.getValueType().isVector() || !Op.getValueType().isSimple() ||
12479       !Op.getOperand(0).getValueType().isSimple() ||
12480       (Opc != ISD::SINT_TO_FP && Opc != ISD::UINT_TO_FP))
12481     return SDValue();
12482 
12483   SDValue ConstVec = N->getOperand(1);
12484   if (!isa<BuildVectorSDNode>(ConstVec))
12485     return SDValue();
12486 
12487   MVT IntTy = Op.getOperand(0).getSimpleValueType().getVectorElementType();
12488   int32_t IntBits = IntTy.getSizeInBits();
12489   if (IntBits != 16 && IntBits != 32 && IntBits != 64)
12490     return SDValue();
12491 
12492   MVT FloatTy = N->getSimpleValueType(0).getVectorElementType();
12493   int32_t FloatBits = FloatTy.getSizeInBits();
12494   if (FloatBits != 32 && FloatBits != 64)
12495     return SDValue();
12496 
12497   // Avoid conversions where iN is larger than the float (e.g., i64 -> float).
12498   if (IntBits > FloatBits)
12499     return SDValue();
12500 
12501   BitVector UndefElements;
12502   BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
12503   int32_t C = BV->getConstantFPSplatPow2ToLog2Int(&UndefElements, FloatBits + 1);
12504   if (C == -1 || C == 0 || C > FloatBits)
12505     return SDValue();
12506 
12507   MVT ResTy;
12508   unsigned NumLanes = Op.getValueType().getVectorNumElements();
12509   switch (NumLanes) {
12510   default:
12511     return SDValue();
12512   case 2:
12513     ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
12514     break;
12515   case 4:
12516     ResTy = FloatBits == 32 ? MVT::v4i32 : MVT::v4i64;
12517     break;
12518   }
12519 
12520   if (ResTy == MVT::v4i64 && DCI.isBeforeLegalizeOps())
12521     return SDValue();
12522 
12523   SDLoc DL(N);
12524   SDValue ConvInput = Op.getOperand(0);
12525   bool IsSigned = Opc == ISD::SINT_TO_FP;
12526   if (IntBits < FloatBits)
12527     ConvInput = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
12528                             ResTy, ConvInput);
12529 
12530   unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfxs2fp
12531                                       : Intrinsic::aarch64_neon_vcvtfxu2fp;
12532   return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
12533                      DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), ConvInput,
12534                      DAG.getConstant(C, DL, MVT::i32));
12535 }
12536 
12537 /// An EXTR instruction is made up of two shifts, ORed together. This helper
12538 /// searches for and classifies those shifts.
12539 static bool findEXTRHalf(SDValue N, SDValue &Src, uint32_t &ShiftAmount,
12540                          bool &FromHi) {
12541   if (N.getOpcode() == ISD::SHL)
12542     FromHi = false;
12543   else if (N.getOpcode() == ISD::SRL)
12544     FromHi = true;
12545   else
12546     return false;
12547 
12548   if (!isa<ConstantSDNode>(N.getOperand(1)))
12549     return false;
12550 
12551   ShiftAmount = N->getConstantOperandVal(1);
12552   Src = N->getOperand(0);
12553   return true;
12554 }
12555 
12556 /// EXTR instruction extracts a contiguous chunk of bits from two existing
12557 /// registers viewed as a high/low pair. This function looks for the pattern:
12558 /// <tt>(or (shl VAL1, \#N), (srl VAL2, \#RegWidth-N))</tt> and replaces it
12559 /// with an EXTR. Can't quite be done in TableGen because the two immediates
12560 /// aren't independent.
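///
/// For example (illustrative, not from the original comment), on i32:
///    (or (shl x, 8), (srl y, 24))
/// selects bits [55:24] of the 64-bit concatenation x:y and can therefore be
/// selected as "EXTR Wd, Wx, Wy, #24" (register names here are placeholders).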
12561 static SDValue tryCombineToEXTR(SDNode *N,
12562                                 TargetLowering::DAGCombinerInfo &DCI) {
12563   SelectionDAG &DAG = DCI.DAG;
12564   SDLoc DL(N);
12565   EVT VT = N->getValueType(0);
12566 
12567   assert(N->getOpcode() == ISD::OR && "Unexpected root");
12568 
12569   if (VT != MVT::i32 && VT != MVT::i64)
12570     return SDValue();
12571 
12572   SDValue LHS;
12573   uint32_t ShiftLHS = 0;
12574   bool LHSFromHi = false;
12575   if (!findEXTRHalf(N->getOperand(0), LHS, ShiftLHS, LHSFromHi))
12576     return SDValue();
12577 
12578   SDValue RHS;
12579   uint32_t ShiftRHS = 0;
12580   bool RHSFromHi = false;
12581   if (!findEXTRHalf(N->getOperand(1), RHS, ShiftRHS, RHSFromHi))
12582     return SDValue();
12583 
12584   // If they're both trying to come from the high part of the register, they're
12585   // not really an EXTR.
12586   if (LHSFromHi == RHSFromHi)
12587     return SDValue();
12588 
12589   if (ShiftLHS + ShiftRHS != VT.getSizeInBits())
12590     return SDValue();
12591 
12592   if (LHSFromHi) {
12593     std::swap(LHS, RHS);
12594     std::swap(ShiftLHS, ShiftRHS);
12595   }
12596 
12597   return DAG.getNode(AArch64ISD::EXTR, DL, VT, LHS, RHS,
12598                      DAG.getConstant(ShiftRHS, DL, MVT::i64));
12599 }
12600 
12601 static SDValue tryCombineToBSL(SDNode *N,
12602                                 TargetLowering::DAGCombinerInfo &DCI) {
12603   EVT VT = N->getValueType(0);
12604   SelectionDAG &DAG = DCI.DAG;
12605   SDLoc DL(N);
12606 
12607   if (!VT.isVector())
12608     return SDValue();
12609 
12610   // The combining code currently only works for NEON vectors. In particular,
12611   // it does not work for SVE when dealing with vectors wider than 128 bits.
12612   if (!VT.is64BitVector() && !VT.is128BitVector())
12613     return SDValue();
12614 
12615   SDValue N0 = N->getOperand(0);
12616   if (N0.getOpcode() != ISD::AND)
12617     return SDValue();
12618 
12619   SDValue N1 = N->getOperand(1);
12620   if (N1.getOpcode() != ISD::AND)
12621     return SDValue();
12622 
12623   // InstCombine does (not (neg a)) => (add a -1).
12624   // Try: (or (and (neg a) b) (and (add a -1) c)) => (bsl (neg a) b c)
12625   // Loop over all combinations of AND operands.
12626   for (int i = 1; i >= 0; --i) {
12627     for (int j = 1; j >= 0; --j) {
12628       SDValue O0 = N0->getOperand(i);
12629       SDValue O1 = N1->getOperand(j);
12630       SDValue Sub, Add, SubSibling, AddSibling;
12631 
12632       // Find a SUB and an ADD operand, one from each AND.
12633       if (O0.getOpcode() == ISD::SUB && O1.getOpcode() == ISD::ADD) {
12634         Sub = O0;
12635         Add = O1;
12636         SubSibling = N0->getOperand(1 - i);
12637         AddSibling = N1->getOperand(1 - j);
12638       } else if (O0.getOpcode() == ISD::ADD && O1.getOpcode() == ISD::SUB) {
12639         Add = O0;
12640         Sub = O1;
12641         AddSibling = N0->getOperand(1 - i);
12642         SubSibling = N1->getOperand(1 - j);
12643       } else
12644         continue;
12645 
12646       if (!ISD::isBuildVectorAllZeros(Sub.getOperand(0).getNode()))
12647         continue;
12648 
12649       // The all-ones constant is always the right-hand operand of the ADD.
12650       if (!ISD::isBuildVectorAllOnes(Add.getOperand(1).getNode()))
12651         continue;
12652 
12653       if (Sub.getOperand(1) != Add.getOperand(0))
12654         continue;
12655 
12656       return DAG.getNode(AArch64ISD::BSP, DL, VT, Sub, SubSibling, AddSibling);
12657     }
12658   }
12659 
12660   // (or (and a b) (and (not a) c)) => (bsl a b c)
12661   // We only have to look for constant vectors here since the general, variable
12662   // case can be handled in TableGen.
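  //
  // Illustrative sketch (not from the original comment): for a v4i32 constant
  // mask M, e.g. M = (splat 0xff00ff00),
  //    (or (and b, M), (and c, ~M))
  // picks each bit from b where M is set and from c where M is clear, which
  // is exactly BSL/BSP with M as the selector, i.e. (bsl M, b, c).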
12663   unsigned Bits = VT.getScalarSizeInBits();
12664   uint64_t BitMask = Bits == 64 ? -1ULL : ((1ULL << Bits) - 1);
12665   for (int i = 1; i >= 0; --i)
12666     for (int j = 1; j >= 0; --j) {
12667       BuildVectorSDNode *BVN0 = dyn_cast<BuildVectorSDNode>(N0->getOperand(i));
12668       BuildVectorSDNode *BVN1 = dyn_cast<BuildVectorSDNode>(N1->getOperand(j));
12669       if (!BVN0 || !BVN1)
12670         continue;
12671 
12672       bool FoundMatch = true;
12673       for (unsigned k = 0; k < VT.getVectorNumElements(); ++k) {
12674         ConstantSDNode *CN0 = dyn_cast<ConstantSDNode>(BVN0->getOperand(k));
12675         ConstantSDNode *CN1 = dyn_cast<ConstantSDNode>(BVN1->getOperand(k));
12676         if (!CN0 || !CN1 ||
12677             CN0->getZExtValue() != (BitMask & ~CN1->getZExtValue())) {
12678           FoundMatch = false;
12679           break;
12680         }
12681       }
12682 
12683       if (FoundMatch)
12684         return DAG.getNode(AArch64ISD::BSP, DL, VT, SDValue(BVN0, 0),
12685                            N0->getOperand(1 - i), N1->getOperand(1 - j));
12686     }
12687 
12688   return SDValue();
12689 }
12690 
12691 static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
12692                                 const AArch64Subtarget *Subtarget) {
12693   // Attempt to form an EXTR from (or (shl VAL1, #N), (srl VAL2, #RegWidth-N))
12694   SelectionDAG &DAG = DCI.DAG;
12695   EVT VT = N->getValueType(0);
12696 
12697   if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
12698     return SDValue();
12699 
12700   if (SDValue Res = tryCombineToEXTR(N, DCI))
12701     return Res;
12702 
12703   if (SDValue Res = tryCombineToBSL(N, DCI))
12704     return Res;
12705 
12706   return SDValue();
12707 }
12708 
12709 static bool isConstantSplatVectorMaskForType(SDNode *N, EVT MemVT) {
12710   if (!MemVT.getVectorElementType().isSimple())
12711     return false;
12712 
12713   uint64_t MaskForTy = 0ull;
12714   switch (MemVT.getVectorElementType().getSimpleVT().SimpleTy) {
12715   case MVT::i8:
12716     MaskForTy = 0xffull;
12717     break;
12718   case MVT::i16:
12719     MaskForTy = 0xffffull;
12720     break;
12721   case MVT::i32:
12722     MaskForTy = 0xffffffffull;
12723     break;
12724   default:
12725     return false;
12726     break;
12727   }
12728 
12729   if (N->getOpcode() == AArch64ISD::DUP || N->getOpcode() == ISD::SPLAT_VECTOR)
12730     if (auto *Op0 = dyn_cast<ConstantSDNode>(N->getOperand(0)))
12731       return Op0->getAPIntValue().getLimitedValue() == MaskForTy;
12732 
12733   return false;
12734 }
12735 
12736 static SDValue performSVEAndCombine(SDNode *N,
12737                                     TargetLowering::DAGCombinerInfo &DCI) {
12738   if (DCI.isBeforeLegalizeOps())
12739     return SDValue();
12740 
12741   SelectionDAG &DAG = DCI.DAG;
12742   SDValue Src = N->getOperand(0);
12743   unsigned Opc = Src->getOpcode();
12744 
12745   // Zero/any extend of an unsigned unpack
12746   if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
12747     SDValue UnpkOp = Src->getOperand(0);
12748     SDValue Dup = N->getOperand(1);
12749 
12750     if (Dup.getOpcode() != AArch64ISD::DUP)
12751       return SDValue();
12752 
12753     SDLoc DL(N);
12754     ConstantSDNode *C = dyn_cast<ConstantSDNode>(Dup->getOperand(0));
          // Bail out if the DUP operand is not a constant, rather than
          // dereferencing the dyn_cast result unconditionally.
          if (!C)
            return SDValue();
12755     uint64_t ExtVal = C->getZExtValue();
12756 
12757     // If the mask is fully covered by the unpack, we don't need to push
12758     // a new AND onto the operand
12759     EVT EltTy = UnpkOp->getValueType(0).getVectorElementType();
12760     if ((ExtVal == 0xFF && EltTy == MVT::i8) ||
12761         (ExtVal == 0xFFFF && EltTy == MVT::i16) ||
12762         (ExtVal == 0xFFFFFFFF && EltTy == MVT::i32))
12763       return Src;
12764 
12765     // Truncate to prevent a DUP with an overly wide constant.
12766     APInt Mask = C->getAPIntValue().trunc(EltTy.getSizeInBits());
12767 
12768     // Otherwise, make sure we propagate the AND to the operand
12769     // of the unpack
12770     Dup = DAG.getNode(AArch64ISD::DUP, DL,
12771                       UnpkOp->getValueType(0),
12772                       DAG.getConstant(Mask.zextOrTrunc(32), DL, MVT::i32));
12773 
12774     SDValue And = DAG.getNode(ISD::AND, DL,
12775                               UnpkOp->getValueType(0), UnpkOp, Dup);
12776 
12777     return DAG.getNode(Opc, DL, N->getValueType(0), And);
12778   }
12779 
12780   if (!EnableCombineMGatherIntrinsics)
12781     return SDValue();
12782 
12783   SDValue Mask = N->getOperand(1);
12784 
12785   if (!Src.hasOneUse())
12786     return SDValue();
12787 
12788   EVT MemVT;
12789 
12790   // SVE load instructions perform an implicit zero-extend, which makes them
12791   // perfect candidates for combining.
12792   switch (Opc) {
12793   case AArch64ISD::LD1_MERGE_ZERO:
12794   case AArch64ISD::LDNF1_MERGE_ZERO:
12795   case AArch64ISD::LDFF1_MERGE_ZERO:
12796     MemVT = cast<VTSDNode>(Src->getOperand(3))->getVT();
12797     break;
12798   case AArch64ISD::GLD1_MERGE_ZERO:
12799   case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
12800   case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
12801   case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
12802   case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
12803   case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
12804   case AArch64ISD::GLD1_IMM_MERGE_ZERO:
12805   case AArch64ISD::GLDFF1_MERGE_ZERO:
12806   case AArch64ISD::GLDFF1_SCALED_MERGE_ZERO:
12807   case AArch64ISD::GLDFF1_SXTW_MERGE_ZERO:
12808   case AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO:
12809   case AArch64ISD::GLDFF1_UXTW_MERGE_ZERO:
12810   case AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO:
12811   case AArch64ISD::GLDFF1_IMM_MERGE_ZERO:
12812   case AArch64ISD::GLDNT1_MERGE_ZERO:
12813     MemVT = cast<VTSDNode>(Src->getOperand(4))->getVT();
12814     break;
12815   default:
12816     return SDValue();
12817   }
12818 
12819   if (isConstantSplatVectorMaskForType(Mask.getNode(), MemVT))
12820     return Src;
12821 
12822   return SDValue();
12823 }
12824 
12825 static SDValue performANDCombine(SDNode *N,
12826                                  TargetLowering::DAGCombinerInfo &DCI) {
12827   SelectionDAG &DAG = DCI.DAG;
12828   SDValue LHS = N->getOperand(0);
12829   EVT VT = N->getValueType(0);
12830   if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
12831     return SDValue();
12832 
12833   if (VT.isScalableVector())
12834     return performSVEAndCombine(N, DCI);
12835 
12836   // The combining code below works only for NEON vectors. In particular, it
12837   // does not work for SVE when dealing with vectors wider than 128 bits.
12838   if (!(VT.is64BitVector() || VT.is128BitVector()))
12839     return SDValue();
12840 
12841   BuildVectorSDNode *BVN =
12842       dyn_cast<BuildVectorSDNode>(N->getOperand(1).getNode());
12843   if (!BVN)
12844     return SDValue();
12845 
12846   // AND does not accept an immediate, so check if we can use a BIC immediate
12847   // instruction instead. We do this here instead of using a (and x, (mvni imm))
12848   // pattern in isel, because some immediates may be lowered to the preferred
12849   // (and x, (movi imm)) form, even though an mvni representation also exists.
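  //
  // For instance (an illustrative sketch, not from the original comment), a
  // v4i32 (and x, (splat 0xffffff00)) has no MOVI encoding for the mask, but
  // its complement 0x000000ff is a valid modified immediate, so the node can
  // be selected as BIC with that byte-sized immediate instead.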
12850   APInt DefBits(VT.getSizeInBits(), 0);
12851   APInt UndefBits(VT.getSizeInBits(), 0);
12852   if (resolveBuildVector(BVN, DefBits, UndefBits)) {
12853     SDValue NewOp;
12854 
12855     DefBits = ~DefBits;
12856     if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::BICi, SDValue(N, 0), DAG,
12857                                     DefBits, &LHS)) ||
12858         (NewOp = tryAdvSIMDModImm16(AArch64ISD::BICi, SDValue(N, 0), DAG,
12859                                     DefBits, &LHS)))
12860       return NewOp;
12861 
12862     UndefBits = ~UndefBits;
12863     if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::BICi, SDValue(N, 0), DAG,
12864                                     UndefBits, &LHS)) ||
12865         (NewOp = tryAdvSIMDModImm16(AArch64ISD::BICi, SDValue(N, 0), DAG,
12866                                     UndefBits, &LHS)))
12867       return NewOp;
12868   }
12869 
12870   return SDValue();
12871 }
12872 
12873 static SDValue performSRLCombine(SDNode *N,
12874                                  TargetLowering::DAGCombinerInfo &DCI) {
12875   SelectionDAG &DAG = DCI.DAG;
12876   EVT VT = N->getValueType(0);
12877   if (VT != MVT::i32 && VT != MVT::i64)
12878     return SDValue();
12879 
12880   // Canonicalize (srl (bswap i32 x), 16) to (rotr (bswap i32 x), 16), if the
12881   // high 16-bits of x are zero. Similarly, canonicalize (srl (bswap i64 x), 32)
12882   // to (rotr (bswap i64 x), 32), if the high 32-bits of x are zero.
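  //
  // Byte-level sketch (illustrative, not from the original comment): for i32
  // x = [0 0 b1 b0], bswap gives [b0 b1 0 0], and both (srl ..., 16) and
  // (rotr ..., 16) then produce [0 0 b0 b1]; the rotate form is preferred
  // because it can be matched by the backend's byte-reversal (REV16) patterns.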
12883   SDValue N0 = N->getOperand(0);
12884   if (N0.getOpcode() == ISD::BSWAP) {
12885     SDLoc DL(N);
12886     SDValue N1 = N->getOperand(1);
12887     SDValue N00 = N0.getOperand(0);
12888     if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(N1)) {
12889       uint64_t ShiftAmt = C->getZExtValue();
12890       if (VT == MVT::i32 && ShiftAmt == 16 &&
12891           DAG.MaskedValueIsZero(N00, APInt::getHighBitsSet(32, 16)))
12892         return DAG.getNode(ISD::ROTR, DL, VT, N0, N1);
12893       if (VT == MVT::i64 && ShiftAmt == 32 &&
12894           DAG.MaskedValueIsZero(N00, APInt::getHighBitsSet(64, 32)))
12895         return DAG.getNode(ISD::ROTR, DL, VT, N0, N1);
12896     }
12897   }
12898   return SDValue();
12899 }
12900 
12901 // Attempt to form urhadd(OpA, OpB) from
12902 // truncate(vlshr(sub(zext(OpB), xor(zext(OpA), Ones(ElemSizeInBits))), 1))
12903 // or uhadd(OpA, OpB) from truncate(vlshr(add(zext(OpA), zext(OpB)), 1)).
12904 // The original form of the first expression is
12905 // truncate(srl(add(zext(OpB), add(zext(OpA), 1)), 1)) and the
12906 // (OpA + OpB + 1) subexpression will have been changed to (OpB - (~OpA)).
12907 // Before this function is called the srl will have been lowered to
12908 // AArch64ISD::VLSHR.
12909 // This pass can also recognize signed variants of the patterns that use sign
12910 // extension instead of zero extension and form a srhadd(OpA, OpB) or a
12911 // shadd(OpA, OpB) from them.
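//
// Worked example (illustrative, not from the original comment): for i8 lanes
// OpA = 200 and OpB = 100, zero-extended to i16,
//   OpB - (~OpA) = 100 - 0xFF37 = 301 (mod 2^16) = OpA + OpB + 1,
// so the logical shift right by one and the truncate produce 150, which is
// exactly urhadd(200, 100).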
12912 static SDValue
12913 performVectorTruncateCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
12914                              SelectionDAG &DAG) {
12915   EVT VT = N->getValueType(0);
12916 
12917   // Since we are looking for a right shift by a constant value of 1 and we are
12918   // operating on types at least 16 bits in length (sign/zero extended OpA and
12919   // OpB, which are at least 8 bits), it follows that the truncate will always
12920   // discard the shifted-in bit and therefore the right shift will be logical
12921   // regardless of the signedness of OpA and OpB.
12922   SDValue Shift = N->getOperand(0);
12923   if (Shift.getOpcode() != AArch64ISD::VLSHR)
12924     return SDValue();
12925 
12926   // Is the right shift using an immediate value of 1?
12927   uint64_t ShiftAmount = Shift.getConstantOperandVal(1);
12928   if (ShiftAmount != 1)
12929     return SDValue();
12930 
12931   SDValue ExtendOpA, ExtendOpB;
12932   SDValue ShiftOp0 = Shift.getOperand(0);
12933   unsigned ShiftOp0Opc = ShiftOp0.getOpcode();
12934   if (ShiftOp0Opc == ISD::SUB) {
12935 
12936     SDValue Xor = ShiftOp0.getOperand(1);
12937     if (Xor.getOpcode() != ISD::XOR)
12938       return SDValue();
12939 
12940     // Is the XOR using an all-ones constant on the right-hand side?
12941     uint64_t C;
12942     if (!isAllConstantBuildVector(Xor.getOperand(1), C))
12943       return SDValue();
12944 
12945     unsigned ElemSizeInBits = VT.getScalarSizeInBits();
12946     APInt CAsAPInt(ElemSizeInBits, C);
12947     if (CAsAPInt != APInt::getAllOnesValue(ElemSizeInBits))
12948       return SDValue();
12949 
12950     ExtendOpA = Xor.getOperand(0);
12951     ExtendOpB = ShiftOp0.getOperand(0);
12952   } else if (ShiftOp0Opc == ISD::ADD) {
12953     ExtendOpA = ShiftOp0.getOperand(0);
12954     ExtendOpB = ShiftOp0.getOperand(1);
12955   } else
12956     return SDValue();
12957 
12958   unsigned ExtendOpAOpc = ExtendOpA.getOpcode();
12959   unsigned ExtendOpBOpc = ExtendOpB.getOpcode();
12960   if (!(ExtendOpAOpc == ExtendOpBOpc &&
12961         (ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND)))
12962     return SDValue();
12963 
12964   // Is the result of the right shift being truncated to the same value type as
12965   // the original operands, OpA and OpB?
12966   SDValue OpA = ExtendOpA.getOperand(0);
12967   SDValue OpB = ExtendOpB.getOperand(0);
12968   EVT OpAVT = OpA.getValueType();
12969   assert(ExtendOpA.getValueType() == ExtendOpB.getValueType());
12970   if (!(VT == OpAVT && OpAVT == OpB.getValueType()))
12971     return SDValue();
12972 
12973   SDLoc DL(N);
12974   bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND;
12975   bool IsRHADD = ShiftOp0Opc == ISD::SUB;
12976   unsigned HADDOpc = IsSignExtend
12977                          ? (IsRHADD ? AArch64ISD::SRHADD : AArch64ISD::SHADD)
12978                          : (IsRHADD ? AArch64ISD::URHADD : AArch64ISD::UHADD);
12979   SDValue ResultHADD = DAG.getNode(HADDOpc, DL, VT, OpA, OpB);
12980 
12981   return ResultHADD;
12982 }
12983 
12984 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
12985   switch (Opcode) {
12986   case ISD::FADD:
12987     return (FullFP16 && VT == MVT::f16) || VT == MVT::f32 || VT == MVT::f64;
12988   case ISD::ADD:
12989     return VT == MVT::i64;
12990   default:
12991     return false;
12992   }
12993 }
12994 
12995 static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) {
12996   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
12997   ConstantSDNode *ConstantN1 = dyn_cast<ConstantSDNode>(N1);
12998 
12999   EVT VT = N->getValueType(0);
13000   const bool FullFP16 =
13001       static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16();
13002 
13003   // Rewrite for pairwise fadd pattern
13004   //   (f32 (extract_vector_elt
13005   //           (fadd (vXf32 Other)
13006   //                 (vector_shuffle (vXf32 Other) undef <1,X,...> )) 0))
13007   // ->
13008   //   (f32 (fadd (extract_vector_elt (vXf32 Other) 0)
13009   //              (extract_vector_elt (vXf32 Other) 1))
13010   if (ConstantN1 && ConstantN1->getZExtValue() == 0 &&
13011       hasPairwiseAdd(N0->getOpcode(), VT, FullFP16)) {
13012     SDLoc DL(N0);
13013     SDValue N00 = N0->getOperand(0);
13014     SDValue N01 = N0->getOperand(1);
13015 
13016     ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(N01);
13017     SDValue Other = N00;
13018 
13019     // And handle the commutative case.
13020     if (!Shuffle) {
13021       Shuffle = dyn_cast<ShuffleVectorSDNode>(N00);
13022       Other = N01;
13023     }
13024 
13025     if (Shuffle && Shuffle->getMaskElt(0) == 1 &&
13026         Other == Shuffle->getOperand(0)) {
13027       return DAG.getNode(N0->getOpcode(), DL, VT,
13028                          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
13029                                      DAG.getConstant(0, DL, MVT::i64)),
13030                          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
13031                                      DAG.getConstant(1, DL, MVT::i64)));
13032     }
13033   }
13034 
13035   return SDValue();
13036 }
13037 
13038 static SDValue performConcatVectorsCombine(SDNode *N,
13039                                            TargetLowering::DAGCombinerInfo &DCI,
13040                                            SelectionDAG &DAG) {
13041   SDLoc dl(N);
13042   EVT VT = N->getValueType(0);
13043   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
13044   unsigned N0Opc = N0->getOpcode(), N1Opc = N1->getOpcode();
13045 
13046   // Optimize concat_vectors of truncated vectors, where the intermediate
13047   // type is illegal, to avoid said illegality,  e.g.,
13048   //   (v4i16 (concat_vectors (v2i16 (truncate (v2i64))),
13049   //                          (v2i16 (truncate (v2i64)))))
13050   // ->
13051   //   (v4i16 (truncate (vector_shuffle (v4i32 (bitcast (v2i64))),
13052   //                                    (v4i32 (bitcast (v2i64))),
13053   //                                    <0, 2, 4, 6>)))
13054   // This isn't really target-specific, but ISD::TRUNCATE legality isn't keyed
13055   // on both input and result type, so we might generate worse code.
13056   // On AArch64 we know it's fine for v2i64->v4i16 and v4i32->v8i8.
13057   if (N->getNumOperands() == 2 && N0Opc == ISD::TRUNCATE &&
13058       N1Opc == ISD::TRUNCATE) {
13059     SDValue N00 = N0->getOperand(0);
13060     SDValue N10 = N1->getOperand(0);
13061     EVT N00VT = N00.getValueType();
13062 
13063     if (N00VT == N10.getValueType() &&
13064         (N00VT == MVT::v2i64 || N00VT == MVT::v4i32) &&
13065         N00VT.getScalarSizeInBits() == 4 * VT.getScalarSizeInBits()) {
13066       MVT MidVT = (N00VT == MVT::v2i64 ? MVT::v4i32 : MVT::v8i16);
13067       SmallVector<int, 8> Mask(MidVT.getVectorNumElements());
13068       for (size_t i = 0; i < Mask.size(); ++i)
13069         Mask[i] = i * 2;
13070       return DAG.getNode(ISD::TRUNCATE, dl, VT,
13071                          DAG.getVectorShuffle(
13072                              MidVT, dl,
13073                              DAG.getNode(ISD::BITCAST, dl, MidVT, N00),
13074                              DAG.getNode(ISD::BITCAST, dl, MidVT, N10), Mask));
13075     }
13076   }
13077 
13078   // Wait 'til after everything is legalized to try this. That way we have
13079   // legal vector types and such.
13080   if (DCI.isBeforeLegalizeOps())
13081     return SDValue();
13082 
13083   // Optimise concat_vectors of two [us]rhadds or [us]hadds that use extracted
13084   // subvectors from the same original vectors. Combine these into a single
13085   // [us]rhadd or [us]hadd that operates on the two original vectors. Example:
13086   //  (v16i8 (concat_vectors (v8i8 (urhadd (extract_subvector (v16i8 OpA, <0>),
13087   //                                        extract_subvector (v16i8 OpB,
13088   //                                        <0>))),
13089   //                         (v8i8 (urhadd (extract_subvector (v16i8 OpA, <8>),
13090   //                                        extract_subvector (v16i8 OpB,
13091   //                                        <8>)))))
13092   // ->
13093   //  (v16i8(urhadd(v16i8 OpA, v16i8 OpB)))
13094   if (N->getNumOperands() == 2 && N0Opc == N1Opc &&
13095       (N0Opc == AArch64ISD::URHADD || N0Opc == AArch64ISD::SRHADD ||
13096        N0Opc == AArch64ISD::UHADD || N0Opc == AArch64ISD::SHADD)) {
13097     SDValue N00 = N0->getOperand(0);
13098     SDValue N01 = N0->getOperand(1);
13099     SDValue N10 = N1->getOperand(0);
13100     SDValue N11 = N1->getOperand(1);
13101 
13102     EVT N00VT = N00.getValueType();
13103     EVT N10VT = N10.getValueType();
13104 
13105     if (N00->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
13106         N01->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
13107         N10->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
13108         N11->getOpcode() == ISD::EXTRACT_SUBVECTOR && N00VT == N10VT) {
13109       SDValue N00Source = N00->getOperand(0);
13110       SDValue N01Source = N01->getOperand(0);
13111       SDValue N10Source = N10->getOperand(0);
13112       SDValue N11Source = N11->getOperand(0);
13113 
13114       if (N00Source == N10Source && N01Source == N11Source &&
13115           N00Source.getValueType() == VT && N01Source.getValueType() == VT) {
13116         assert(N0.getValueType() == N1.getValueType());
13117 
13118         uint64_t N00Index = N00.getConstantOperandVal(1);
13119         uint64_t N01Index = N01.getConstantOperandVal(1);
13120         uint64_t N10Index = N10.getConstantOperandVal(1);
13121         uint64_t N11Index = N11.getConstantOperandVal(1);
13122 
13123         if (N00Index == N01Index && N10Index == N11Index && N00Index == 0 &&
13124             N10Index == N00VT.getVectorNumElements())
13125           return DAG.getNode(N0Opc, dl, VT, N00Source, N01Source);
13126       }
13127     }
13128   }
13129 
13130   // If we see a (concat_vectors (v1x64 A), (v1x64 A)) it's really a vector
13131   // splat. The indexed instructions are going to be expecting a DUPLANE64, so
13132   // canonicalise to that.
13133   if (N0 == N1 && VT.getVectorNumElements() == 2) {
13134     assert(VT.getScalarSizeInBits() == 64);
13135     return DAG.getNode(AArch64ISD::DUPLANE64, dl, VT, WidenVector(N0, DAG),
13136                        DAG.getConstant(0, dl, MVT::i64));
13137   }
13138 
13139   // Canonicalise concat_vectors so that the right-hand vector has as few
13140   // bit-casts as possible before its real operation. The primary matching
13141   // destination for these operations will be the narrowing "2" instructions,
13142   // which depend on the operation being performed on this right-hand vector.
13143   // For example,
13144   //    (concat_vectors LHS,  (v1i64 (bitconvert (v4i16 RHS))))
13145   // becomes
13146   //    (bitconvert (concat_vectors (v4i16 (bitconvert LHS)), RHS))
13147 
13148   if (N1Opc != ISD::BITCAST)
13149     return SDValue();
13150   SDValue RHS = N1->getOperand(0);
13151   MVT RHSTy = RHS.getValueType().getSimpleVT();
13152   // If the RHS is not a vector, this is not the pattern we're looking for.
13153   if (!RHSTy.isVector())
13154     return SDValue();
13155 
13156   LLVM_DEBUG(
13157       dbgs() << "aarch64-lower: concat_vectors bitcast simplification\n");
13158 
13159   MVT ConcatTy = MVT::getVectorVT(RHSTy.getVectorElementType(),
13160                                   RHSTy.getVectorNumElements() * 2);
13161   return DAG.getNode(ISD::BITCAST, dl, VT,
13162                      DAG.getNode(ISD::CONCAT_VECTORS, dl, ConcatTy,
13163                                  DAG.getNode(ISD::BITCAST, dl, RHSTy, N0),
13164                                  RHS));
13165 }
13166 
13167 static SDValue tryCombineFixedPointConvert(SDNode *N,
13168                                            TargetLowering::DAGCombinerInfo &DCI,
13169                                            SelectionDAG &DAG) {
13170   // Wait until after everything is legalized to try this. That way we have
13171   // legal vector types and such.
13172   if (DCI.isBeforeLegalizeOps())
13173     return SDValue();
13174   // Transform a scalar conversion of a value from a lane extract into a
13175   // lane extract of a vector conversion. E.g., from foo1 to foo2:
13176   // double foo1(int64x2_t a) { return vcvtd_n_f64_s64(a[1], 9); }
13177   // double foo2(int64x2_t a) { return vcvtq_n_f64_s64(a, 9)[1]; }
13178   //
13179   // The second form interacts better with instruction selection and the
13180   // register allocator to avoid cross-class register copies that aren't
13181   // coalescable due to a lane reference.
13182 
13183   // Check the operand and see if it originates from a lane extract.
13184   SDValue Op1 = N->getOperand(1);
13185   if (Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
13186     // Yep, no additional predication needed. Perform the transform.
13187     SDValue IID = N->getOperand(0);
13188     SDValue Shift = N->getOperand(2);
13189     SDValue Vec = Op1.getOperand(0);
13190     SDValue Lane = Op1.getOperand(1);
13191     EVT ResTy = N->getValueType(0);
13192     EVT VecResTy;
13193     SDLoc DL(N);
13194 
13195     // The vector width should be 128 bits by the time we get here, even
13196     // if it started as 64 bits (the extract_vector handling will have
13197     // widened it by then).
13198     assert(Vec.getValueSizeInBits() == 128 &&
13199            "unexpected vector size on extract_vector_elt!");
13200     if (Vec.getValueType() == MVT::v4i32)
13201       VecResTy = MVT::v4f32;
13202     else if (Vec.getValueType() == MVT::v2i64)
13203       VecResTy = MVT::v2f64;
13204     else
13205       llvm_unreachable("unexpected vector type!");
13206 
13207     SDValue Convert =
13208         DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VecResTy, IID, Vec, Shift);
13209     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResTy, Convert, Lane);
13210   }
13211   return SDValue();
13212 }
13213 
13214 // AArch64 high-vector "long" operations are formed by performing the non-high
13215 // version on an extract_subvector of each operand which gets the high half:
13216 //
13217 //  (longop2 LHS, RHS) == (longop (extract_high LHS), (extract_high RHS))
13218 //
13219 // However, there are cases which don't have an extract_high explicitly, but
13220 // have another operation that can be made compatible with one for free. For
13221 // example:
13222 //
13223 //  (dupv64 scalar) --> (extract_high (dup128 scalar))
13224 //
13225 // This routine does the actual conversion of such DUPs, once outer routines
13226 // have determined that everything else is in order.
13227 // It also supports immediate DUP-like nodes (MOVI/MVNi), which we can fold
13228 // similarly here.
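//
// For example (an illustrative sketch, not from the original comment):
//   (v4i16 (AArch64ISD::DUP scalar))
// becomes
//   (extract_subvector (v8i16 (AArch64ISD::DUP scalar)), (i64 4))
// i.e. an extract of the high half of a 128-bit DUP, which is the shape the
// "2" forms of the long instructions expect.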
13229 static SDValue tryExtendDUPToExtractHigh(SDValue N, SelectionDAG &DAG) {
13230   switch (N.getOpcode()) {
13231   case AArch64ISD::DUP:
13232   case AArch64ISD::DUPLANE8:
13233   case AArch64ISD::DUPLANE16:
13234   case AArch64ISD::DUPLANE32:
13235   case AArch64ISD::DUPLANE64:
13236   case AArch64ISD::MOVI:
13237   case AArch64ISD::MOVIshift:
13238   case AArch64ISD::MOVIedit:
13239   case AArch64ISD::MOVImsl:
13240   case AArch64ISD::MVNIshift:
13241   case AArch64ISD::MVNImsl:
13242     break;
13243   default:
13244     // FMOV could be supported, but isn't very useful, as it would only occur
13245     // if you passed a bitcast'd floating point immediate to an eligible long
13246     // integer op (addl, smull, ...).
13247     return SDValue();
13248   }
13249 
13250   MVT NarrowTy = N.getSimpleValueType();
13251   if (!NarrowTy.is64BitVector())
13252     return SDValue();
13253 
13254   MVT ElementTy = NarrowTy.getVectorElementType();
13255   unsigned NumElems = NarrowTy.getVectorNumElements();
13256   MVT NewVT = MVT::getVectorVT(ElementTy, NumElems * 2);
13257 
13258   SDLoc dl(N);
13259   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NarrowTy,
13260                      DAG.getNode(N->getOpcode(), dl, NewVT, N->ops()),
13261                      DAG.getConstant(NumElems, dl, MVT::i64));
13262 }
13263 
13264 static bool isEssentiallyExtractHighSubvector(SDValue N) {
13265   if (N.getOpcode() == ISD::BITCAST)
13266     N = N.getOperand(0);
13267   if (N.getOpcode() != ISD::EXTRACT_SUBVECTOR)
13268     return false;
13269   return cast<ConstantSDNode>(N.getOperand(1))->getAPIntValue() ==
13270          N.getOperand(0).getValueType().getVectorNumElements() / 2;
13271 }
13272 
13273 /// Helper structure to keep track of ISD::SET_CC operands.
13274 struct GenericSetCCInfo {
13275   const SDValue *Opnd0;
13276   const SDValue *Opnd1;
13277   ISD::CondCode CC;
13278 };
13279 
13280 /// Helper structure to keep track of a SET_CC lowered into AArch64 code.
13281 struct AArch64SetCCInfo {
13282   const SDValue *Cmp;
13283   AArch64CC::CondCode CC;
13284 };
13285 
13286 /// Helper structure to keep track of SetCC information.
13287 union SetCCInfo {
13288   GenericSetCCInfo Generic;
13289   AArch64SetCCInfo AArch64;
13290 };
13291 
13292 /// Helper structure to be able to read SetCC information. If the IsAArch64
13293 /// field is set to true, Info is an AArch64SetCCInfo, otherwise Info is a
13294 /// GenericSetCCInfo.
13295 struct SetCCInfoAndKind {
13296   SetCCInfo Info;
13297   bool IsAArch64;
13298 };
13299 
13300 /// Check whether or not \p Op is a SET_CC operation, either a generic or an
13301 /// AArch64 lowered one.
13303 /// \p SetCCInfo is filled accordingly.
13304 /// \post SetCCInfo is meaningful only when this function returns true.
13305 /// \return True when Op is a kind of SET_CC operation.
13306 static bool isSetCC(SDValue Op, SetCCInfoAndKind &SetCCInfo) {
13307   // If this is a setcc, this is straightforward.
13308   if (Op.getOpcode() == ISD::SETCC) {
13309     SetCCInfo.Info.Generic.Opnd0 = &Op.getOperand(0);
13310     SetCCInfo.Info.Generic.Opnd1 = &Op.getOperand(1);
13311     SetCCInfo.Info.Generic.CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
13312     SetCCInfo.IsAArch64 = false;
13313     return true;
13314   }
13315   // Otherwise, check if this is a matching csel instruction.
13316   // In other words:
13317   // - csel 1, 0, cc
13318   // - csel 0, 1, !cc
13319   if (Op.getOpcode() != AArch64ISD::CSEL)
13320     return false;
13321   // Set the information about the operands.
13322   // TODO: we want the operands of the Cmp not the csel
13323   SetCCInfo.Info.AArch64.Cmp = &Op.getOperand(3);
13324   SetCCInfo.IsAArch64 = true;
13325   SetCCInfo.Info.AArch64.CC = static_cast<AArch64CC::CondCode>(
13326       cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
13327 
13328   // Check that the operands match the constraints:
13329   // (1) Both operands must be constants.
13330   // (2) One must be 1 and the other must be 0.
13331   ConstantSDNode *TValue = dyn_cast<ConstantSDNode>(Op.getOperand(0));
13332   ConstantSDNode *FValue = dyn_cast<ConstantSDNode>(Op.getOperand(1));
13333 
13334   // Check (1).
13335   if (!TValue || !FValue)
13336     return false;
13337 
13338   // Check (2).
13339   if (!TValue->isOne()) {
13340     // Update the comparison when we are interested in !cc.
13341     std::swap(TValue, FValue);
13342     SetCCInfo.Info.AArch64.CC =
13343         AArch64CC::getInvertedCondCode(SetCCInfo.Info.AArch64.CC);
13344   }
13345   return TValue->isOne() && FValue->isNullValue();
13346 }
13347 
13348 // Returns true if Op is setcc or zext of setcc.
13349 static bool isSetCCOrZExtSetCC(const SDValue& Op, SetCCInfoAndKind &Info) {
13350   if (isSetCC(Op, Info))
13351     return true;
13352   return ((Op.getOpcode() == ISD::ZERO_EXTEND) &&
13353     isSetCC(Op->getOperand(0), Info));
13354 }
13355 
13356 // The folding we want to perform is:
13357 // (add x, [zext] (setcc cc ...) )
13358 //   -->
13359 // (csel x, (add x, 1), !cc ...)
13360 //
13361 // The latter will get matched to a CSINC instruction.
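//
// As a rough source-level illustration (a sketch, not from the original
// comment), "x + (a == b)" can then be selected along the lines of
//   cmp  w1, w2
//   cinc w0, w0, eq
// (register allocation here is hypothetical) instead of materialising the
// setcc result in a register and adding it.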
13362 static SDValue performSetccAddFolding(SDNode *Op, SelectionDAG &DAG) {
13363   assert(Op && Op->getOpcode() == ISD::ADD && "Unexpected operation!");
13364   SDValue LHS = Op->getOperand(0);
13365   SDValue RHS = Op->getOperand(1);
13366   SetCCInfoAndKind InfoAndKind;
13367 
13368   // If both operands are a SET_CC, then we don't want to perform this
13369   // folding and create another csel as this results in more instructions
13370   // (and higher register usage).
13371   if (isSetCCOrZExtSetCC(LHS, InfoAndKind) &&
13372       isSetCCOrZExtSetCC(RHS, InfoAndKind))
13373     return SDValue();
13374 
13375   // If the LHS is not a SET_CC, try the RHS; if neither operand is one, give up.
13376   if (!isSetCCOrZExtSetCC(LHS, InfoAndKind)) {
13377     std::swap(LHS, RHS);
13378     if (!isSetCCOrZExtSetCC(LHS, InfoAndKind))
13379       return SDValue();
13380   }
13381 
13382   // FIXME: This could be generalized to work for FP comparisons.
13383   EVT CmpVT = InfoAndKind.IsAArch64
13384                   ? InfoAndKind.Info.AArch64.Cmp->getOperand(0).getValueType()
13385                   : InfoAndKind.Info.Generic.Opnd0->getValueType();
13386   if (CmpVT != MVT::i32 && CmpVT != MVT::i64)
13387     return SDValue();
13388 
13389   SDValue CCVal;
13390   SDValue Cmp;
13391   SDLoc dl(Op);
13392   if (InfoAndKind.IsAArch64) {
13393     CCVal = DAG.getConstant(
13394         AArch64CC::getInvertedCondCode(InfoAndKind.Info.AArch64.CC), dl,
13395         MVT::i32);
13396     Cmp = *InfoAndKind.Info.AArch64.Cmp;
13397   } else
13398     Cmp = getAArch64Cmp(
13399         *InfoAndKind.Info.Generic.Opnd0, *InfoAndKind.Info.Generic.Opnd1,
13400         ISD::getSetCCInverse(InfoAndKind.Info.Generic.CC, CmpVT), CCVal, DAG,
13401         dl);
13402 
13403   EVT VT = Op->getValueType(0);
13404   LHS = DAG.getNode(ISD::ADD, dl, VT, RHS, DAG.getConstant(1, dl, VT));
13405   return DAG.getNode(AArch64ISD::CSEL, dl, VT, RHS, LHS, CCVal, Cmp);
13406 }
13407 
13408 // ADD(UADDV a, UADDV b) -->  UADDV(ADD a, b)
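//
// Intuition (illustrative, not from the original comment): both reductions
// use wrapping adds over the same number of lanes, so adding the vectors
// lane-wise first and reducing once gives the same result, e.g.
//   vaddvq_u32(a) + vaddvq_u32(b) == vaddvq_u32(vaddq_u32(a, b)).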
13409 static SDValue performUADDVCombine(SDNode *N, SelectionDAG &DAG) {
13410   EVT VT = N->getValueType(0);
13411   // Only handle additions that produce a scalar integer.
13412   if (N->getOpcode() != ISD::ADD || !VT.isScalarInteger())
13413     return SDValue();
13414 
13415   SDValue LHS = N->getOperand(0);
13416   SDValue RHS = N->getOperand(1);
13417   if (LHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
13418       RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT || LHS.getValueType() != VT)
13419     return SDValue();
13420 
13421   auto *LHSN1 = dyn_cast<ConstantSDNode>(LHS->getOperand(1));
13422   auto *RHSN1 = dyn_cast<ConstantSDNode>(RHS->getOperand(1));
13423   if (!LHSN1 || LHSN1 != RHSN1 || !RHSN1->isNullValue())
13424     return SDValue();
13425 
13426   SDValue Op1 = LHS->getOperand(0);
13427   SDValue Op2 = RHS->getOperand(0);
13428   EVT OpVT1 = Op1.getValueType();
13429   EVT OpVT2 = Op2.getValueType();
13430   if (Op1.getOpcode() != AArch64ISD::UADDV || OpVT1 != OpVT2 ||
13431       Op2.getOpcode() != AArch64ISD::UADDV ||
13432       OpVT1.getVectorElementType() != VT)
13433     return SDValue();
13434 
13435   SDValue Val1 = Op1.getOperand(0);
13436   SDValue Val2 = Op2.getOperand(0);
13437   EVT ValVT = Val1->getValueType(0);
13438   SDLoc DL(N);
13439   SDValue AddVal = DAG.getNode(ISD::ADD, DL, ValVT, Val1, Val2);
13440   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
13441                      DAG.getNode(AArch64ISD::UADDV, DL, ValVT, AddVal),
13442                      DAG.getConstant(0, DL, MVT::i64));
13443 }
13444 
13445 // ADD(UDOT(zero, x, y), A) -->  UDOT(A, x, y)
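// This is valid (illustrative note, not from the original comment) because
// UDOT/SDOT accumulate into their first operand: a dot product into a zero
// accumulator followed by an ADD of A equals accumulating directly into A.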
13446 static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) {
13447   EVT VT = N->getValueType(0);
13448   if (N->getOpcode() != ISD::ADD)
13449     return SDValue();
13450 
13451   SDValue Dot = N->getOperand(0);
13452   SDValue A = N->getOperand(1);
13453   // Handle commutativity.
13454   auto isZeroDot = [](SDValue Dot) {
13455     return (Dot.getOpcode() == AArch64ISD::UDOT ||
13456             Dot.getOpcode() == AArch64ISD::SDOT) &&
13457            isZerosVector(Dot.getOperand(0).getNode());
13458   };
13459   if (!isZeroDot(Dot))
13460     std::swap(Dot, A);
13461   if (!isZeroDot(Dot))
13462     return SDValue();
13463 
13464   return DAG.getNode(Dot.getOpcode(), SDLoc(N), VT, A, Dot.getOperand(1),
13465                      Dot.getOperand(2));
13466 }
13467 
13468 // The basic add/sub long vector instructions have variants with "2" on the end
13469 // which act on the high-half of their inputs. They are normally matched by
13470 // patterns like:
13471 //
13472 // (add (zeroext (extract_high LHS)),
13473 //      (zeroext (extract_high RHS)))
13474 // -> uaddl2 vD, vN, vM
13475 //
13476 // However, if one of the extracts is something like a duplicate, this
13477 // instruction can still be used profitably. This function puts the DAG into a
13478 // more appropriate form for those patterns to trigger.
13479 static SDValue performAddSubLongCombine(SDNode *N,
13480                                         TargetLowering::DAGCombinerInfo &DCI,
13481                                         SelectionDAG &DAG) {
13482   if (DCI.isBeforeLegalizeOps())
13483     return SDValue();
13484 
13485   MVT VT = N->getSimpleValueType(0);
13486   if (!VT.is128BitVector()) {
13487     if (N->getOpcode() == ISD::ADD)
13488       return performSetccAddFolding(N, DAG);
13489     return SDValue();
13490   }
13491 
13492   // Make sure both branches are extended in the same way.
13493   SDValue LHS = N->getOperand(0);
13494   SDValue RHS = N->getOperand(1);
13495   if ((LHS.getOpcode() != ISD::ZERO_EXTEND &&
13496        LHS.getOpcode() != ISD::SIGN_EXTEND) ||
13497       LHS.getOpcode() != RHS.getOpcode())
13498     return SDValue();
13499 
13500   unsigned ExtType = LHS.getOpcode();
13501 
13502   // It's only worth doing this if at least one of the inputs is already an
13503   // extract, but we don't know which it'll be so we have to try both.
13504   if (isEssentiallyExtractHighSubvector(LHS.getOperand(0))) {
13505     RHS = tryExtendDUPToExtractHigh(RHS.getOperand(0), DAG);
13506     if (!RHS.getNode())
13507       return SDValue();
13508 
13509     RHS = DAG.getNode(ExtType, SDLoc(N), VT, RHS);
13510   } else if (isEssentiallyExtractHighSubvector(RHS.getOperand(0))) {
13511     LHS = tryExtendDUPToExtractHigh(LHS.getOperand(0), DAG);
13512     if (!LHS.getNode())
13513       return SDValue();
13514 
13515     LHS = DAG.getNode(ExtType, SDLoc(N), VT, LHS);
13516   }
13517 
13518   return DAG.getNode(N->getOpcode(), SDLoc(N), VT, LHS, RHS);
13519 }
13520 
13521 static SDValue performAddSubCombine(SDNode *N,
13522                                     TargetLowering::DAGCombinerInfo &DCI,
13523                                     SelectionDAG &DAG) {
13524   // Try to change sum of two reductions.
13525   if (SDValue Val = performUADDVCombine(N, DAG))
13526     return Val;
13527   if (SDValue Val = performAddDotCombine(N, DAG))
13528     return Val;
13529 
13530   return performAddSubLongCombine(N, DCI, DAG);
13531 }
13532 
13533 // Massage DAGs which we can use the high-half "long" operations on into
13534 // something isel will recognize better. E.g.
13535 //
13536 // (aarch64_neon_umull (extract_high vec) (dupv64 scalar)) -->
13537 //   (aarch64_neon_umull (extract_high (v2i64 vec)))
13538 //                     (extract_high (v2i64 (dup128 scalar)))))
13539 //
13540 static SDValue tryCombineLongOpWithDup(unsigned IID, SDNode *N,
13541                                        TargetLowering::DAGCombinerInfo &DCI,
13542                                        SelectionDAG &DAG) {
13543   if (DCI.isBeforeLegalizeOps())
13544     return SDValue();
13545 
13546   SDValue LHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 0 : 1);
13547   SDValue RHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 1 : 2);
13548   assert(LHS.getValueType().is64BitVector() &&
13549          RHS.getValueType().is64BitVector() &&
13550          "unexpected shape for long operation");
13551 
13552   // Either node could be a DUP, but it's not worth doing both of them (you
13553   // might as well use the non-high version), so look for a corresponding
13554   // extract operation on the other "wing".
13555   if (isEssentiallyExtractHighSubvector(LHS)) {
13556     RHS = tryExtendDUPToExtractHigh(RHS, DAG);
13557     if (!RHS.getNode())
13558       return SDValue();
13559   } else if (isEssentiallyExtractHighSubvector(RHS)) {
13560     LHS = tryExtendDUPToExtractHigh(LHS, DAG);
13561     if (!LHS.getNode())
13562       return SDValue();
13563   }
13564 
13565   if (IID == Intrinsic::not_intrinsic)
13566     return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), LHS, RHS);
13567 
13568   return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), N->getValueType(0),
13569                      N->getOperand(0), LHS, RHS);
13570 }
13571 
13572 static SDValue tryCombineShiftImm(unsigned IID, SDNode *N, SelectionDAG &DAG) {
13573   MVT ElemTy = N->getSimpleValueType(0).getScalarType();
13574   unsigned ElemBits = ElemTy.getSizeInBits();
13575 
13576   int64_t ShiftAmount;
13577   if (BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(N->getOperand(2))) {
13578     APInt SplatValue, SplatUndef;
13579     unsigned SplatBitSize;
13580     bool HasAnyUndefs;
13581     if (!BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
13582                               HasAnyUndefs, ElemBits) ||
13583         SplatBitSize != ElemBits)
13584       return SDValue();
13585 
13586     ShiftAmount = SplatValue.getSExtValue();
13587   } else if (ConstantSDNode *CVN = dyn_cast<ConstantSDNode>(N->getOperand(2))) {
13588     ShiftAmount = CVN->getSExtValue();
13589   } else
13590     return SDValue();
13591 
13592   unsigned Opcode;
13593   bool IsRightShift;
13594   switch (IID) {
13595   default:
13596     llvm_unreachable("Unknown shift intrinsic");
13597   case Intrinsic::aarch64_neon_sqshl:
13598     Opcode = AArch64ISD::SQSHL_I;
13599     IsRightShift = false;
13600     break;
13601   case Intrinsic::aarch64_neon_uqshl:
13602     Opcode = AArch64ISD::UQSHL_I;
13603     IsRightShift = false;
13604     break;
13605   case Intrinsic::aarch64_neon_srshl:
13606     Opcode = AArch64ISD::SRSHR_I;
13607     IsRightShift = true;
13608     break;
13609   case Intrinsic::aarch64_neon_urshl:
13610     Opcode = AArch64ISD::URSHR_I;
13611     IsRightShift = true;
13612     break;
13613   case Intrinsic::aarch64_neon_sqshlu:
13614     Opcode = AArch64ISD::SQSHLU_I;
13615     IsRightShift = false;
13616     break;
13617   case Intrinsic::aarch64_neon_sshl:
13618   case Intrinsic::aarch64_neon_ushl:
13619     // For positive shift amounts we can use SHL, as ushl/sshl perform a regular
13620     // left shift for positive shift amounts. Below, we only replace the current
13621     // node with VSHL if this condition is met.
13622     Opcode = AArch64ISD::VSHL;
13623     IsRightShift = false;
13624     break;
13625   }
13626 
13627   if (IsRightShift && ShiftAmount <= -1 && ShiftAmount >= -(int)ElemBits) {
13628     SDLoc dl(N);
13629     return DAG.getNode(Opcode, dl, N->getValueType(0), N->getOperand(1),
13630                        DAG.getConstant(-ShiftAmount, dl, MVT::i32));
13631   } else if (!IsRightShift && ShiftAmount >= 0 && ShiftAmount < ElemBits) {
13632     SDLoc dl(N);
13633     return DAG.getNode(Opcode, dl, N->getValueType(0), N->getOperand(1),
13634                        DAG.getConstant(ShiftAmount, dl, MVT::i32));
13635   }
13636 
13637   return SDValue();
13638 }
13639 
13640 // The CRC32[BH] instructions ignore the high bits of their data operand. Since
13641 // the intrinsics must be legal and take an i32, this means there's almost
13642 // certainly going to be a zext in the DAG which we can eliminate.
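//
// For example (an illustrative sketch, not from the original comment), for
// "__crc32b(crc, c)" the i8 data argument is widened to i32 as (and x, 0xff);
// since CRC32B only reads the low 8 bits of its data operand, that AND is
// dropped here and its input is fed to the intrinsic directly.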
13643 static SDValue tryCombineCRC32(unsigned Mask, SDNode *N, SelectionDAG &DAG) {
13644   SDValue AndN = N->getOperand(2);
13645   if (AndN.getOpcode() != ISD::AND)
13646     return SDValue();
13647 
13648   ConstantSDNode *CMask = dyn_cast<ConstantSDNode>(AndN.getOperand(1));
13649   if (!CMask || CMask->getZExtValue() != Mask)
13650     return SDValue();
13651 
13652   return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), MVT::i32,
13653                      N->getOperand(0), N->getOperand(1), AndN.getOperand(0));
13654 }
13655 
13656 static SDValue combineAcrossLanesIntrinsic(unsigned Opc, SDNode *N,
13657                                            SelectionDAG &DAG) {
13658   SDLoc dl(N);
13659   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, N->getValueType(0),
13660                      DAG.getNode(Opc, dl,
13661                                  N->getOperand(1).getSimpleValueType(),
13662                                  N->getOperand(1)),
13663                      DAG.getConstant(0, dl, MVT::i64));
13664 }
13665 
13666 static SDValue LowerSVEIntrinsicIndex(SDNode *N, SelectionDAG &DAG) {
13667   SDLoc DL(N);
13668   SDValue Op1 = N->getOperand(1);
13669   SDValue Op2 = N->getOperand(2);
13670   EVT ScalarTy = Op2.getValueType();
13671   if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
13672     ScalarTy = MVT::i32;
13673 
13674   // Lower index_vector(base, step) to mul(step, step_vector(1)) + splat(base).
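  // For example (illustrative, not from the original comment), for nxv4i32
  // index(2, 3) this builds step_vector(1) = <0, 1, 2, 3, ...>, multiplies it
  // by splat(3) and adds splat(2), giving <2, 5, 8, 11, ...>.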
13675   SDValue One = DAG.getConstant(1, DL, ScalarTy);
13676   SDValue StepVector =
13677       DAG.getNode(ISD::STEP_VECTOR, DL, N->getValueType(0), One);
13678   SDValue Step = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op2);
13679   SDValue Mul = DAG.getNode(ISD::MUL, DL, N->getValueType(0), StepVector, Step);
13680   SDValue Base = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op1);
13681   return DAG.getNode(ISD::ADD, DL, N->getValueType(0), Mul, Base);
13682 }
13683 
13684 static SDValue LowerSVEIntrinsicDUP(SDNode *N, SelectionDAG &DAG) {
13685   SDLoc dl(N);
13686   SDValue Scalar = N->getOperand(3);
13687   EVT ScalarTy = Scalar.getValueType();
13688 
13689   if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
13690     Scalar = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Scalar);
13691 
13692   SDValue Passthru = N->getOperand(1);
13693   SDValue Pred = N->getOperand(2);
13694   return DAG.getNode(AArch64ISD::DUP_MERGE_PASSTHRU, dl, N->getValueType(0),
13695                      Pred, Scalar, Passthru);
13696 }
13697 
13698 static SDValue LowerSVEIntrinsicEXT(SDNode *N, SelectionDAG &DAG) {
13699   SDLoc dl(N);
13700   LLVMContext &Ctx = *DAG.getContext();
13701   EVT VT = N->getValueType(0);
13702 
13703   assert(VT.isScalableVector() && "Expected a scalable vector.");
13704 
13705   // Current lowering only supports the SVE-ACLE types.
13706   if (VT.getSizeInBits().getKnownMinSize() != AArch64::SVEBitsPerBlock)
13707     return SDValue();
13708 
13709   unsigned ElemSize = VT.getVectorElementType().getSizeInBits() / 8;
13710   unsigned ByteSize = VT.getSizeInBits().getKnownMinSize() / 8;
13711   EVT ByteVT =
13712       EVT::getVectorVT(Ctx, MVT::i8, ElementCount::getScalable(ByteSize));
13713 
13714   // Convert everything to the domain of EXT (i.e. bytes).
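  // Illustrative sketch: for an nxv4i32 input with element index Idx, ElemSize
  // is 4, so the EXT below operates on nxv16i8 with byte index Idx * 4.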
13715   SDValue Op0 = DAG.getNode(ISD::BITCAST, dl, ByteVT, N->getOperand(1));
13716   SDValue Op1 = DAG.getNode(ISD::BITCAST, dl, ByteVT, N->getOperand(2));
13717   SDValue Op2 = DAG.getNode(ISD::MUL, dl, MVT::i32, N->getOperand(3),
13718                             DAG.getConstant(ElemSize, dl, MVT::i32));
13719 
13720   SDValue EXT = DAG.getNode(AArch64ISD::EXT, dl, ByteVT, Op0, Op1, Op2);
13721   return DAG.getNode(ISD::BITCAST, dl, VT, EXT);
13722 }
13723 
13724 static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC,
13725                                         TargetLowering::DAGCombinerInfo &DCI,
13726                                         SelectionDAG &DAG) {
13727   if (DCI.isBeforeLegalize())
13728     return SDValue();
13729 
13730   SDValue Comparator = N->getOperand(3);
13731   if (Comparator.getOpcode() == AArch64ISD::DUP ||
13732       Comparator.getOpcode() == ISD::SPLAT_VECTOR) {
13733     unsigned IID = getIntrinsicID(N);
13734     EVT VT = N->getValueType(0);
13735     EVT CmpVT = N->getOperand(2).getValueType();
13736     SDValue Pred = N->getOperand(1);
13737     SDValue Imm;
13738     SDLoc DL(N);
13739 
13740     switch (IID) {
13741     default:
13742       llvm_unreachable("Called with wrong intrinsic!");
13743       break;
13744 
13745     // Signed comparisons
13746     case Intrinsic::aarch64_sve_cmpeq_wide:
13747     case Intrinsic::aarch64_sve_cmpne_wide:
13748     case Intrinsic::aarch64_sve_cmpge_wide:
13749     case Intrinsic::aarch64_sve_cmpgt_wide:
13750     case Intrinsic::aarch64_sve_cmplt_wide:
13751     case Intrinsic::aarch64_sve_cmple_wide: {
13752       if (auto *CN = dyn_cast<ConstantSDNode>(Comparator.getOperand(0))) {
13753         int64_t ImmVal = CN->getSExtValue();
13754         if (ImmVal >= -16 && ImmVal <= 15)
13755           Imm = DAG.getConstant(ImmVal, DL, MVT::i32);
13756         else
13757           return SDValue();
13758       }
13759       break;
13760     }
13761     // Unsigned comparisons
13762     case Intrinsic::aarch64_sve_cmphs_wide:
13763     case Intrinsic::aarch64_sve_cmphi_wide:
13764     case Intrinsic::aarch64_sve_cmplo_wide:
13765     case Intrinsic::aarch64_sve_cmpls_wide:  {
13766       if (auto *CN = dyn_cast<ConstantSDNode>(Comparator.getOperand(0))) {
13767         uint64_t ImmVal = CN->getZExtValue();
13768         if (ImmVal <= 127)
13769           Imm = DAG.getConstant(ImmVal, DL, MVT::i32);
13770         else
13771           return SDValue();
13772       }
13773       break;
13774     }
13775     }
13776 
13777     if (!Imm)
13778       return SDValue();
13779 
13780     SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, CmpVT, Imm);
13781     return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, VT, Pred,
13782                        N->getOperand(2), Splat, DAG.getCondCode(CC));
13783   }
13784 
13785   return SDValue();
13786 }
13787 
13788 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
13789                         AArch64CC::CondCode Cond) {
13790   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13791 
13792   SDLoc DL(Op);
13793   assert(Op.getValueType().isScalableVector() &&
13794          TLI.isTypeLegal(Op.getValueType()) &&
13795          "Expected legal scalable vector type!");
13796 
13797   // Ensure target-specific opcodes are using a legal type.
13798   EVT OutVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
13799   SDValue TVal = DAG.getConstant(1, DL, OutVT);
13800   SDValue FVal = DAG.getConstant(0, DL, OutVT);
13801 
13802   // Set condition code (CC) flags.
13803   SDValue Test = DAG.getNode(AArch64ISD::PTEST, DL, MVT::Other, Pg, Op);
13804 
13805   // Convert CC to integer based on requested condition.
13806   // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
13807   SDValue CC = DAG.getConstant(getInvertedCondCode(Cond), DL, MVT::i32);
13808   SDValue Res = DAG.getNode(AArch64ISD::CSEL, DL, OutVT, FVal, TVal, CC, Test);
13809   return DAG.getZExtOrTrunc(Res, DL, VT);
13810 }
13811 
13812 static SDValue combineSVEReductionInt(SDNode *N, unsigned Opc,
13813                                       SelectionDAG &DAG) {
13814   SDLoc DL(N);
13815 
13816   SDValue Pred = N->getOperand(1);
13817   SDValue VecToReduce = N->getOperand(2);
13818 
13819   // NOTE: The integer reduction's result type is not always linked to the
13820   // operand's element type, so we construct it from the intrinsic's result type.
13821   EVT ReduceVT = getPackedSVEVectorVT(N->getValueType(0));
13822   SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, VecToReduce);
13823 
13824   // SVE reductions set the whole vector register with the first element
13825   // containing the reduction result, which we'll now extract.
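  // Illustrative example: for aarch64_sve_uaddv the intrinsic result is i64,
  // so ReduceVT is nxv2i64 and the scalar sum lives in lane 0 of Reduce.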
13826   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
13827   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
13828                      Zero);
13829 }
13830 
13831 static SDValue combineSVEReductionFP(SDNode *N, unsigned Opc,
13832                                      SelectionDAG &DAG) {
13833   SDLoc DL(N);
13834 
13835   SDValue Pred = N->getOperand(1);
13836   SDValue VecToReduce = N->getOperand(2);
13837 
13838   EVT ReduceVT = VecToReduce.getValueType();
13839   SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, VecToReduce);
13840 
13841   // SVE reductions set the whole vector register with the first element
13842   // containing the reduction result, which we'll now extract.
13843   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
13844   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
13845                      Zero);
13846 }
13847 
13848 static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
13849                                             SelectionDAG &DAG) {
13850   SDLoc DL(N);
13851 
13852   SDValue Pred = N->getOperand(1);
13853   SDValue InitVal = N->getOperand(2);
13854   SDValue VecToReduce = N->getOperand(3);
13855   EVT ReduceVT = VecToReduce.getValueType();
13856 
13857   // Ordered reductions use the first lane of the result vector as the
13858   // reduction's initial value.
13859   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
13860   InitVal = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ReduceVT,
13861                         DAG.getUNDEF(ReduceVT), InitVal, Zero);
13862 
13863   SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, InitVal, VecToReduce);
13864 
13865   // SVE reductions set the whole vector register with the first element
13866   // containing the reduction result, which we'll now extract.
13867   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
13868                      Zero);
13869 }
13870 
13871 static bool isAllActivePredicate(SDValue N) {
13872   unsigned NumElts = N.getValueType().getVectorMinNumElements();
13873 
13874   // Look through cast.
13875   while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
13876     N = N.getOperand(0);
13877     // When reinterpreting from a type with fewer elements the "new" elements
13878     // are not active, so bail if they're likely to be used.
13879     if (N.getValueType().getVectorMinNumElements() < NumElts)
13880       return false;
13881   }
13882 
13883   // "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
13884   // or smaller than the implicit element type represented by N.
13885   // NOTE: A larger element count implies a smaller element type.
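  // Illustrative example: a ptrue over 16 x i8 lanes reinterpreted as a
  // predicate for 8 x i16 lanes still covers every lane (16 >= 8), whereas the
  // reverse direction would not.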
13886   if (N.getOpcode() == AArch64ISD::PTRUE &&
13887       N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
13888     return N.getValueType().getVectorMinNumElements() >= NumElts;
13889 
13890   return false;
13891 }
13892 
13893 // If a merged operation has no inactive lanes we can relax it to a predicated
13894 // or unpredicated operation, which potentially allows better isel (perhaps
13895 // using immediate forms) or relaxing register reuse requirements.
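// Illustrative example: given an all-active governing predicate,
// (sve.add pg, X, Y) can be emitted as the unpredicated (ISD::ADD X, Y); see
// the aarch64_sve_add case in performIntrinsicCombine below.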
13896 static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
13897                                        SelectionDAG &DAG,
13898                                        bool UnpredOp = false) {
13899   assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Expected intrinsic!");
13900   assert(N->getNumOperands() == 4 && "Expected 3 operand intrinsic!");
13901   SDValue Pg = N->getOperand(1);
13902 
13903   // ISD way to specify an all active predicate.
13904   if (isAllActivePredicate(Pg)) {
13905     if (UnpredOp)
13906       return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), N->getOperand(2),
13907                          N->getOperand(3));
13908     else
13909       return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Pg,
13910                          N->getOperand(2), N->getOperand(3));
13911   }
13912 
13913   // FUTURE: SplatVector(true)
13914   return SDValue();
13915 }
13916 
13917 static SDValue performIntrinsicCombine(SDNode *N,
13918                                        TargetLowering::DAGCombinerInfo &DCI,
13919                                        const AArch64Subtarget *Subtarget) {
13920   SelectionDAG &DAG = DCI.DAG;
13921   unsigned IID = getIntrinsicID(N);
13922   switch (IID) {
13923   default:
13924     break;
13925   case Intrinsic::aarch64_neon_vcvtfxs2fp:
13926   case Intrinsic::aarch64_neon_vcvtfxu2fp:
13927     return tryCombineFixedPointConvert(N, DCI, DAG);
13928   case Intrinsic::aarch64_neon_saddv:
13929     return combineAcrossLanesIntrinsic(AArch64ISD::SADDV, N, DAG);
13930   case Intrinsic::aarch64_neon_uaddv:
13931     return combineAcrossLanesIntrinsic(AArch64ISD::UADDV, N, DAG);
13932   case Intrinsic::aarch64_neon_sminv:
13933     return combineAcrossLanesIntrinsic(AArch64ISD::SMINV, N, DAG);
13934   case Intrinsic::aarch64_neon_uminv:
13935     return combineAcrossLanesIntrinsic(AArch64ISD::UMINV, N, DAG);
13936   case Intrinsic::aarch64_neon_smaxv:
13937     return combineAcrossLanesIntrinsic(AArch64ISD::SMAXV, N, DAG);
13938   case Intrinsic::aarch64_neon_umaxv:
13939     return combineAcrossLanesIntrinsic(AArch64ISD::UMAXV, N, DAG);
13940   case Intrinsic::aarch64_neon_fmax:
13941     return DAG.getNode(ISD::FMAXIMUM, SDLoc(N), N->getValueType(0),
13942                        N->getOperand(1), N->getOperand(2));
13943   case Intrinsic::aarch64_neon_fmin:
13944     return DAG.getNode(ISD::FMINIMUM, SDLoc(N), N->getValueType(0),
13945                        N->getOperand(1), N->getOperand(2));
13946   case Intrinsic::aarch64_neon_fmaxnm:
13947     return DAG.getNode(ISD::FMAXNUM, SDLoc(N), N->getValueType(0),
13948                        N->getOperand(1), N->getOperand(2));
13949   case Intrinsic::aarch64_neon_fminnm:
13950     return DAG.getNode(ISD::FMINNUM, SDLoc(N), N->getValueType(0),
13951                        N->getOperand(1), N->getOperand(2));
13952   case Intrinsic::aarch64_neon_smull:
13953   case Intrinsic::aarch64_neon_umull:
13954   case Intrinsic::aarch64_neon_pmull:
13955   case Intrinsic::aarch64_neon_sqdmull:
13956     return tryCombineLongOpWithDup(IID, N, DCI, DAG);
13957   case Intrinsic::aarch64_neon_sqshl:
13958   case Intrinsic::aarch64_neon_uqshl:
13959   case Intrinsic::aarch64_neon_sqshlu:
13960   case Intrinsic::aarch64_neon_srshl:
13961   case Intrinsic::aarch64_neon_urshl:
13962   case Intrinsic::aarch64_neon_sshl:
13963   case Intrinsic::aarch64_neon_ushl:
13964     return tryCombineShiftImm(IID, N, DAG);
13965   case Intrinsic::aarch64_crc32b:
13966   case Intrinsic::aarch64_crc32cb:
13967     return tryCombineCRC32(0xff, N, DAG);
13968   case Intrinsic::aarch64_crc32h:
13969   case Intrinsic::aarch64_crc32ch:
13970     return tryCombineCRC32(0xffff, N, DAG);
13971   case Intrinsic::aarch64_sve_saddv:
13972     // There is no i64 version of SADDV because the sign is irrelevant.
13973     if (N->getOperand(2)->getValueType(0).getVectorElementType() == MVT::i64)
13974       return combineSVEReductionInt(N, AArch64ISD::UADDV_PRED, DAG);
13975     else
13976       return combineSVEReductionInt(N, AArch64ISD::SADDV_PRED, DAG);
13977   case Intrinsic::aarch64_sve_uaddv:
13978     return combineSVEReductionInt(N, AArch64ISD::UADDV_PRED, DAG);
13979   case Intrinsic::aarch64_sve_smaxv:
13980     return combineSVEReductionInt(N, AArch64ISD::SMAXV_PRED, DAG);
13981   case Intrinsic::aarch64_sve_umaxv:
13982     return combineSVEReductionInt(N, AArch64ISD::UMAXV_PRED, DAG);
13983   case Intrinsic::aarch64_sve_sminv:
13984     return combineSVEReductionInt(N, AArch64ISD::SMINV_PRED, DAG);
13985   case Intrinsic::aarch64_sve_uminv:
13986     return combineSVEReductionInt(N, AArch64ISD::UMINV_PRED, DAG);
13987   case Intrinsic::aarch64_sve_orv:
13988     return combineSVEReductionInt(N, AArch64ISD::ORV_PRED, DAG);
13989   case Intrinsic::aarch64_sve_eorv:
13990     return combineSVEReductionInt(N, AArch64ISD::EORV_PRED, DAG);
13991   case Intrinsic::aarch64_sve_andv:
13992     return combineSVEReductionInt(N, AArch64ISD::ANDV_PRED, DAG);
13993   case Intrinsic::aarch64_sve_index:
13994     return LowerSVEIntrinsicIndex(N, DAG);
13995   case Intrinsic::aarch64_sve_dup:
13996     return LowerSVEIntrinsicDUP(N, DAG);
13997   case Intrinsic::aarch64_sve_dup_x:
13998     return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), N->getValueType(0),
13999                        N->getOperand(1));
14000   case Intrinsic::aarch64_sve_ext:
14001     return LowerSVEIntrinsicEXT(N, DAG);
14002   case Intrinsic::aarch64_sve_mul:
14003     return convertMergedOpToPredOp(N, AArch64ISD::MUL_PRED, DAG);
14004   case Intrinsic::aarch64_sve_smulh:
14005     return convertMergedOpToPredOp(N, AArch64ISD::MULHS_PRED, DAG);
14006   case Intrinsic::aarch64_sve_umulh:
14007     return convertMergedOpToPredOp(N, AArch64ISD::MULHU_PRED, DAG);
14008   case Intrinsic::aarch64_sve_smin:
14009     return convertMergedOpToPredOp(N, AArch64ISD::SMIN_PRED, DAG);
14010   case Intrinsic::aarch64_sve_umin:
14011     return convertMergedOpToPredOp(N, AArch64ISD::UMIN_PRED, DAG);
14012   case Intrinsic::aarch64_sve_smax:
14013     return convertMergedOpToPredOp(N, AArch64ISD::SMAX_PRED, DAG);
14014   case Intrinsic::aarch64_sve_umax:
14015     return convertMergedOpToPredOp(N, AArch64ISD::UMAX_PRED, DAG);
14016   case Intrinsic::aarch64_sve_lsl:
14017     return convertMergedOpToPredOp(N, AArch64ISD::SHL_PRED, DAG);
14018   case Intrinsic::aarch64_sve_lsr:
14019     return convertMergedOpToPredOp(N, AArch64ISD::SRL_PRED, DAG);
14020   case Intrinsic::aarch64_sve_asr:
14021     return convertMergedOpToPredOp(N, AArch64ISD::SRA_PRED, DAG);
14022   case Intrinsic::aarch64_sve_fadd:
14023     return convertMergedOpToPredOp(N, AArch64ISD::FADD_PRED, DAG);
14024   case Intrinsic::aarch64_sve_fsub:
14025     return convertMergedOpToPredOp(N, AArch64ISD::FSUB_PRED, DAG);
14026   case Intrinsic::aarch64_sve_fmul:
14027     return convertMergedOpToPredOp(N, AArch64ISD::FMUL_PRED, DAG);
14028   case Intrinsic::aarch64_sve_add:
14029     return convertMergedOpToPredOp(N, ISD::ADD, DAG, true);
14030   case Intrinsic::aarch64_sve_sub:
14031     return convertMergedOpToPredOp(N, ISD::SUB, DAG, true);
14032   case Intrinsic::aarch64_sve_and:
14033     return convertMergedOpToPredOp(N, ISD::AND, DAG, true);
14034   case Intrinsic::aarch64_sve_bic:
14035     return convertMergedOpToPredOp(N, AArch64ISD::BIC, DAG, true);
14036   case Intrinsic::aarch64_sve_eor:
14037     return convertMergedOpToPredOp(N, ISD::XOR, DAG, true);
14038   case Intrinsic::aarch64_sve_orr:
14039     return convertMergedOpToPredOp(N, ISD::OR, DAG, true);
14040   case Intrinsic::aarch64_sve_sqadd:
14041     return convertMergedOpToPredOp(N, ISD::SADDSAT, DAG, true);
14042   case Intrinsic::aarch64_sve_sqsub:
14043     return convertMergedOpToPredOp(N, ISD::SSUBSAT, DAG, true);
14044   case Intrinsic::aarch64_sve_uqadd:
14045     return convertMergedOpToPredOp(N, ISD::UADDSAT, DAG, true);
14046   case Intrinsic::aarch64_sve_uqsub:
14047     return convertMergedOpToPredOp(N, ISD::USUBSAT, DAG, true);
14048   case Intrinsic::aarch64_sve_sqadd_x:
14049     return DAG.getNode(ISD::SADDSAT, SDLoc(N), N->getValueType(0),
14050                        N->getOperand(1), N->getOperand(2));
14051   case Intrinsic::aarch64_sve_sqsub_x:
14052     return DAG.getNode(ISD::SSUBSAT, SDLoc(N), N->getValueType(0),
14053                        N->getOperand(1), N->getOperand(2));
14054   case Intrinsic::aarch64_sve_uqadd_x:
14055     return DAG.getNode(ISD::UADDSAT, SDLoc(N), N->getValueType(0),
14056                        N->getOperand(1), N->getOperand(2));
14057   case Intrinsic::aarch64_sve_uqsub_x:
14058     return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
14059                        N->getOperand(1), N->getOperand(2));
14060   case Intrinsic::aarch64_sve_cmphs:
14061     if (!N->getOperand(2).getValueType().isFloatingPoint())
14062       return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
14063                          N->getValueType(0), N->getOperand(1), N->getOperand(2),
14064                          N->getOperand(3), DAG.getCondCode(ISD::SETUGE));
14065     break;
14066   case Intrinsic::aarch64_sve_cmphi:
14067     if (!N->getOperand(2).getValueType().isFloatingPoint())
14068       return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
14069                          N->getValueType(0), N->getOperand(1), N->getOperand(2),
14070                          N->getOperand(3), DAG.getCondCode(ISD::SETUGT));
14071     break;
14072   case Intrinsic::aarch64_sve_cmpge:
14073     if (!N->getOperand(2).getValueType().isFloatingPoint())
14074       return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
14075                          N->getValueType(0), N->getOperand(1), N->getOperand(2),
14076                          N->getOperand(3), DAG.getCondCode(ISD::SETGE));
14077     break;
14078   case Intrinsic::aarch64_sve_cmpgt:
14079     if (!N->getOperand(2).getValueType().isFloatingPoint())
14080       return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
14081                          N->getValueType(0), N->getOperand(1), N->getOperand(2),
14082                          N->getOperand(3), DAG.getCondCode(ISD::SETGT));
14083     break;
14084   case Intrinsic::aarch64_sve_cmpeq:
14085     if (!N->getOperand(2).getValueType().isFloatingPoint())
14086       return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
14087                          N->getValueType(0), N->getOperand(1), N->getOperand(2),
14088                          N->getOperand(3), DAG.getCondCode(ISD::SETEQ));
14089     break;
14090   case Intrinsic::aarch64_sve_cmpne:
14091     if (!N->getOperand(2).getValueType().isFloatingPoint())
14092       return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
14093                          N->getValueType(0), N->getOperand(1), N->getOperand(2),
14094                          N->getOperand(3), DAG.getCondCode(ISD::SETNE));
14095     break;
14096   case Intrinsic::aarch64_sve_fadda:
14097     return combineSVEReductionOrderedFP(N, AArch64ISD::FADDA_PRED, DAG);
14098   case Intrinsic::aarch64_sve_faddv:
14099     return combineSVEReductionFP(N, AArch64ISD::FADDV_PRED, DAG);
14100   case Intrinsic::aarch64_sve_fmaxnmv:
14101     return combineSVEReductionFP(N, AArch64ISD::FMAXNMV_PRED, DAG);
14102   case Intrinsic::aarch64_sve_fmaxv:
14103     return combineSVEReductionFP(N, AArch64ISD::FMAXV_PRED, DAG);
14104   case Intrinsic::aarch64_sve_fminnmv:
14105     return combineSVEReductionFP(N, AArch64ISD::FMINNMV_PRED, DAG);
14106   case Intrinsic::aarch64_sve_fminv:
14107     return combineSVEReductionFP(N, AArch64ISD::FMINV_PRED, DAG);
14108   case Intrinsic::aarch64_sve_sel:
14109     return DAG.getNode(ISD::VSELECT, SDLoc(N), N->getValueType(0),
14110                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
14111   case Intrinsic::aarch64_sve_cmpeq_wide:
14112     return tryConvertSVEWideCompare(N, ISD::SETEQ, DCI, DAG);
14113   case Intrinsic::aarch64_sve_cmpne_wide:
14114     return tryConvertSVEWideCompare(N, ISD::SETNE, DCI, DAG);
14115   case Intrinsic::aarch64_sve_cmpge_wide:
14116     return tryConvertSVEWideCompare(N, ISD::SETGE, DCI, DAG);
14117   case Intrinsic::aarch64_sve_cmpgt_wide:
14118     return tryConvertSVEWideCompare(N, ISD::SETGT, DCI, DAG);
14119   case Intrinsic::aarch64_sve_cmplt_wide:
14120     return tryConvertSVEWideCompare(N, ISD::SETLT, DCI, DAG);
14121   case Intrinsic::aarch64_sve_cmple_wide:
14122     return tryConvertSVEWideCompare(N, ISD::SETLE, DCI, DAG);
14123   case Intrinsic::aarch64_sve_cmphs_wide:
14124     return tryConvertSVEWideCompare(N, ISD::SETUGE, DCI, DAG);
14125   case Intrinsic::aarch64_sve_cmphi_wide:
14126     return tryConvertSVEWideCompare(N, ISD::SETUGT, DCI, DAG);
14127   case Intrinsic::aarch64_sve_cmplo_wide:
14128     return tryConvertSVEWideCompare(N, ISD::SETULT, DCI, DAG);
14129   case Intrinsic::aarch64_sve_cmpls_wide:
14130     return tryConvertSVEWideCompare(N, ISD::SETULE, DCI, DAG);
14131   case Intrinsic::aarch64_sve_ptest_any:
14132     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
14133                     AArch64CC::ANY_ACTIVE);
14134   case Intrinsic::aarch64_sve_ptest_first:
14135     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
14136                     AArch64CC::FIRST_ACTIVE);
14137   case Intrinsic::aarch64_sve_ptest_last:
14138     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
14139                     AArch64CC::LAST_ACTIVE);
14140   }
14141   return SDValue();
14142 }
14143 
14144 static SDValue performExtendCombine(SDNode *N,
14145                                     TargetLowering::DAGCombinerInfo &DCI,
14146                                     SelectionDAG &DAG) {
14147   // If we see something like (zext (sabd (extract_high ...), (DUP ...))) then
14148   // we can convert that DUP into another extract_high (of a bigger DUP), which
14149   // helps the backend to decide that an sabdl2 would be useful, saving a real
14150   // extract_high operation.
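  // Illustrative sketch (assuming the usual isel patterns): once the DUP has
  // been rewritten, (zext (uabd (extract_high X), (extract_high (dup ...))))
  // can be matched as a single uabdl2 instead of an extract_high plus uabdl.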
14151   if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
14152       (N->getOperand(0).getOpcode() == AArch64ISD::UABD ||
14153        N->getOperand(0).getOpcode() == AArch64ISD::SABD)) {
14154     SDNode *ABDNode = N->getOperand(0).getNode();
14155     SDValue NewABD =
14156         tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
14157     if (!NewABD.getNode())
14158       return SDValue();
14159 
14160     return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), NewABD);
14161   }
14162   return SDValue();
14163 }
14164 
14165 static SDValue splitStoreSplat(SelectionDAG &DAG, StoreSDNode &St,
14166                                SDValue SplatVal, unsigned NumVecElts) {
14167   assert(!St.isTruncatingStore() && "cannot split truncating vector store");
14168   unsigned OrigAlignment = St.getAlignment();
14169   unsigned EltOffset = SplatVal.getValueType().getSizeInBits() / 8;
14170 
14171   // Create scalar stores. This is at least as good as the code sequence for a
14172   // split unaligned store which is a dup.s, ext.b, and two stores.
14173   // Most of the time the three stores should be replaced by store pair
14174   // instructions (stp).
14175   SDLoc DL(&St);
14176   SDValue BasePtr = St.getBasePtr();
14177   uint64_t BaseOffset = 0;
14178 
14179   const MachinePointerInfo &PtrInfo = St.getPointerInfo();
14180   SDValue NewST1 =
14181       DAG.getStore(St.getChain(), DL, SplatVal, BasePtr, PtrInfo,
14182                    OrigAlignment, St.getMemOperand()->getFlags());
14183 
14184   // As this is in ISel, we will not merge this add, which may degrade results.
14185   if (BasePtr->getOpcode() == ISD::ADD &&
14186       isa<ConstantSDNode>(BasePtr->getOperand(1))) {
14187     BaseOffset = cast<ConstantSDNode>(BasePtr->getOperand(1))->getSExtValue();
14188     BasePtr = BasePtr->getOperand(0);
14189   }
14190 
14191   unsigned Offset = EltOffset;
14192   while (--NumVecElts) {
14193     unsigned Alignment = MinAlign(OrigAlignment, Offset);
14194     SDValue OffsetPtr =
14195         DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr,
14196                     DAG.getConstant(BaseOffset + Offset, DL, MVT::i64));
14197     NewST1 = DAG.getStore(NewST1.getValue(0), DL, SplatVal, OffsetPtr,
14198                           PtrInfo.getWithOffset(Offset), Alignment,
14199                           St.getMemOperand()->getFlags());
14200     Offset += EltOffset;
14201   }
14202   return NewST1;
14203 }
14204 
14205 // Returns an SVE type that ContentTy can be trivially sign or zero extended
14206 // into.
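// For example, nxv2i8/nxv2i16/nxv2i32 are all held in an nxv2i64 container,
// while nxv4i16 widens to nxv4i32 (see the switch below).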
14207 static MVT getSVEContainerType(EVT ContentTy) {
14208   assert(ContentTy.isSimple() && "No SVE containers for extended types");
14209 
14210   switch (ContentTy.getSimpleVT().SimpleTy) {
14211   default:
14212     llvm_unreachable("No known SVE container for this MVT type");
14213   case MVT::nxv2i8:
14214   case MVT::nxv2i16:
14215   case MVT::nxv2i32:
14216   case MVT::nxv2i64:
14217   case MVT::nxv2f32:
14218   case MVT::nxv2f64:
14219     return MVT::nxv2i64;
14220   case MVT::nxv4i8:
14221   case MVT::nxv4i16:
14222   case MVT::nxv4i32:
14223   case MVT::nxv4f32:
14224     return MVT::nxv4i32;
14225   case MVT::nxv8i8:
14226   case MVT::nxv8i16:
14227   case MVT::nxv8f16:
14228   case MVT::nxv8bf16:
14229     return MVT::nxv8i16;
14230   case MVT::nxv16i8:
14231     return MVT::nxv16i8;
14232   }
14233 }
14234 
14235 static SDValue performLD1Combine(SDNode *N, SelectionDAG &DAG, unsigned Opc) {
14236   SDLoc DL(N);
14237   EVT VT = N->getValueType(0);
14238 
14239   if (VT.getSizeInBits().getKnownMinSize() > AArch64::SVEBitsPerBlock)
14240     return SDValue();
14241 
14242   EVT ContainerVT = VT;
14243   if (ContainerVT.isInteger())
14244     ContainerVT = getSVEContainerType(ContainerVT);
14245 
14246   SDVTList VTs = DAG.getVTList(ContainerVT, MVT::Other);
14247   SDValue Ops[] = { N->getOperand(0), // Chain
14248                     N->getOperand(2), // Pg
14249                     N->getOperand(3), // Base
14250                     DAG.getValueType(VT) };
14251 
14252   SDValue Load = DAG.getNode(Opc, DL, VTs, Ops);
14253   SDValue LoadChain = SDValue(Load.getNode(), 1);
14254 
14255   if (ContainerVT.isInteger() && (VT != ContainerVT))
14256     Load = DAG.getNode(ISD::TRUNCATE, DL, VT, Load.getValue(0));
14257 
14258   return DAG.getMergeValues({ Load, LoadChain }, DL);
14259 }
14260 
14261 static SDValue performLDNT1Combine(SDNode *N, SelectionDAG &DAG) {
14262   SDLoc DL(N);
14263   EVT VT = N->getValueType(0);
14264   EVT PtrTy = N->getOperand(3).getValueType();
14265 
14266   if (VT == MVT::nxv8bf16 &&
14267       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
14268     return SDValue();
14269 
14270   EVT LoadVT = VT;
14271   if (VT.isFloatingPoint())
14272     LoadVT = VT.changeTypeToInteger();
14273 
14274   auto *MINode = cast<MemIntrinsicSDNode>(N);
14275   SDValue PassThru = DAG.getConstant(0, DL, LoadVT);
14276   SDValue L = DAG.getMaskedLoad(LoadVT, DL, MINode->getChain(),
14277                                 MINode->getOperand(3), DAG.getUNDEF(PtrTy),
14278                                 MINode->getOperand(2), PassThru,
14279                                 MINode->getMemoryVT(), MINode->getMemOperand(),
14280                                 ISD::UNINDEXED, ISD::NON_EXTLOAD, false);
14281 
14282   if (VT.isFloatingPoint()) {
14283     SDValue Ops[] = { DAG.getNode(ISD::BITCAST, DL, VT, L), L.getValue(1) };
14284     return DAG.getMergeValues(Ops, DL);
14285   }
14286 
14287   return L;
14288 }
14289 
14290 template <unsigned Opcode>
14291 static SDValue performLD1ReplicateCombine(SDNode *N, SelectionDAG &DAG) {
14292   static_assert(Opcode == AArch64ISD::LD1RQ_MERGE_ZERO ||
14293                     Opcode == AArch64ISD::LD1RO_MERGE_ZERO,
14294                 "Unsupported opcode.");
14295   SDLoc DL(N);
14296   EVT VT = N->getValueType(0);
14297   if (VT == MVT::nxv8bf16 &&
14298       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
14299     return SDValue();
14300 
14301   EVT LoadVT = VT;
14302   if (VT.isFloatingPoint())
14303     LoadVT = VT.changeTypeToInteger();
14304 
14305   SDValue Ops[] = {N->getOperand(0), N->getOperand(2), N->getOperand(3)};
14306   SDValue Load = DAG.getNode(Opcode, DL, {LoadVT, MVT::Other}, Ops);
14307   SDValue LoadChain = SDValue(Load.getNode(), 1);
14308 
14309   if (VT.isFloatingPoint())
14310     Load = DAG.getNode(ISD::BITCAST, DL, VT, Load.getValue(0));
14311 
14312   return DAG.getMergeValues({Load, LoadChain}, DL);
14313 }
14314 
14315 static SDValue performST1Combine(SDNode *N, SelectionDAG &DAG) {
14316   SDLoc DL(N);
14317   SDValue Data = N->getOperand(2);
14318   EVT DataVT = Data.getValueType();
14319   EVT HwSrcVt = getSVEContainerType(DataVT);
14320   SDValue InputVT = DAG.getValueType(DataVT);
14321 
14322   if (DataVT == MVT::nxv8bf16 &&
14323       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
14324     return SDValue();
14325 
14326   if (DataVT.isFloatingPoint())
14327     InputVT = DAG.getValueType(HwSrcVt);
14328 
14329   SDValue SrcNew;
14330   if (Data.getValueType().isFloatingPoint())
14331     SrcNew = DAG.getNode(ISD::BITCAST, DL, HwSrcVt, Data);
14332   else
14333     SrcNew = DAG.getNode(ISD::ANY_EXTEND, DL, HwSrcVt, Data);
14334 
14335   SDValue Ops[] = { N->getOperand(0), // Chain
14336                     SrcNew,
14337                     N->getOperand(4), // Base
14338                     N->getOperand(3), // Pg
14339                     InputVT
14340                   };
14341 
14342   return DAG.getNode(AArch64ISD::ST1_PRED, DL, N->getValueType(0), Ops);
14343 }
14344 
14345 static SDValue performSTNT1Combine(SDNode *N, SelectionDAG &DAG) {
14346   SDLoc DL(N);
14347 
14348   SDValue Data = N->getOperand(2);
14349   EVT DataVT = Data.getValueType();
14350   EVT PtrTy = N->getOperand(4).getValueType();
14351 
14352   if (DataVT == MVT::nxv8bf16 &&
14353       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
14354     return SDValue();
14355 
14356   if (DataVT.isFloatingPoint())
14357     Data = DAG.getNode(ISD::BITCAST, DL, DataVT.changeTypeToInteger(), Data);
14358 
14359   auto *MINode = cast<MemIntrinsicSDNode>(N);
14360   return DAG.getMaskedStore(MINode->getChain(), DL, Data, MINode->getOperand(4),
14361                             DAG.getUNDEF(PtrTy), MINode->getOperand(3),
14362                             MINode->getMemoryVT(), MINode->getMemOperand(),
14363                             ISD::UNINDEXED, false, false);
14364 }
14365 
14366 /// Replace a vector store of a zero splat by scalar stores of WZR/XZR.  The
14367 /// load store optimizer pass will merge them to store pair stores.  This should
14368 /// be better than a movi to create the vector zero followed by a vector store
14369 /// if the zero constant is not re-used, since one instruction and one register
14370 /// live range will be removed.
14371 ///
14372 /// For example, the final generated code should be:
14373 ///
14374 ///   stp xzr, xzr, [x0]
14375 ///
14376 /// instead of:
14377 ///
14378 ///   movi v0.2d, #0
14379 ///   str q0, [x0]
14380 ///
14381 static SDValue replaceZeroVectorStore(SelectionDAG &DAG, StoreSDNode &St) {
14382   SDValue StVal = St.getValue();
14383   EVT VT = StVal.getValueType();
14384 
14385   // Avoid scalarizing zero splat stores for scalable vectors.
14386   if (VT.isScalableVector())
14387     return SDValue();
14388 
14389   // It is beneficial to scalarize a zero splat store for 2 or 3 i64 elements or
14390   // 2, 3 or 4 i32 elements.
14391   int NumVecElts = VT.getVectorNumElements();
14392   if (!(((NumVecElts == 2 || NumVecElts == 3) &&
14393          VT.getVectorElementType().getSizeInBits() == 64) ||
14394         ((NumVecElts == 2 || NumVecElts == 3 || NumVecElts == 4) &&
14395          VT.getVectorElementType().getSizeInBits() == 32)))
14396     return SDValue();
14397 
14398   if (StVal.getOpcode() != ISD::BUILD_VECTOR)
14399     return SDValue();
14400 
14401   // If the zero constant has more than one use then the vector store could be
14402   // better since the constant mov will be amortized and stp q instructions
14403   // should be able to be formed.
14404   if (!StVal.hasOneUse())
14405     return SDValue();
14406 
14407   // If the store is truncating then it's going down to i16 or smaller, which
14408   // means it can be implemented in a single store anyway.
14409   if (St.isTruncatingStore())
14410     return SDValue();
14411 
14412   // If the immediate offset of the address operand is too large for the stp
14413   // instruction, then bail out.
14414   if (DAG.isBaseWithConstantOffset(St.getBasePtr())) {
14415     int64_t Offset = St.getBasePtr()->getConstantOperandVal(1);
14416     if (Offset < -512 || Offset > 504)
14417       return SDValue();
14418   }
14419 
14420   for (int I = 0; I < NumVecElts; ++I) {
14421     SDValue EltVal = StVal.getOperand(I);
14422     if (!isNullConstant(EltVal) && !isNullFPConstant(EltVal))
14423       return SDValue();
14424   }
14425 
14426   // Use a CopyFromReg WZR/XZR here to prevent
14427   // DAGCombiner::MergeConsecutiveStores from undoing this transformation.
14428   SDLoc DL(&St);
14429   unsigned ZeroReg;
14430   EVT ZeroVT;
14431   if (VT.getVectorElementType().getSizeInBits() == 32) {
14432     ZeroReg = AArch64::WZR;
14433     ZeroVT = MVT::i32;
14434   } else {
14435     ZeroReg = AArch64::XZR;
14436     ZeroVT = MVT::i64;
14437   }
14438   SDValue SplatVal =
14439       DAG.getCopyFromReg(DAG.getEntryNode(), DL, ZeroReg, ZeroVT);
14440   return splitStoreSplat(DAG, St, SplatVal, NumVecElts);
14441 }
14442 
14443 /// Replace a vector store of a splatted scalar by scalar stores of the scalar
14444 /// value. The load store optimizer pass will merge them to store pair stores.
14445 /// This has better performance than a splat of the scalar followed by a split
14446 /// vector store. Even if the stores are not merged, it is four stores vs. a dup
14447 /// followed by an ext.b and two stores.
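/// Illustrative sketch: a store of a v4i32 splat of w8 becomes four scalar
/// stores of w8, which the load/store optimizer can then pair into two stp
/// instructions.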
14448 static SDValue replaceSplatVectorStore(SelectionDAG &DAG, StoreSDNode &St) {
14449   SDValue StVal = St.getValue();
14450   EVT VT = StVal.getValueType();
14451 
14452   // Don't replace floating point stores, they possibly won't be transformed to
14453   // stp because of the store pair suppress pass.
14454   if (VT.isFloatingPoint())
14455     return SDValue();
14456 
14457   // We can express a splat as store pair(s) for 2 or 4 elements.
14458   unsigned NumVecElts = VT.getVectorNumElements();
14459   if (NumVecElts != 4 && NumVecElts != 2)
14460     return SDValue();
14461 
14462   // If the store is truncating then it's going down to i16 or smaller, which
14463   // means it can be implemented in a single store anyway.
14464   if (St.isTruncatingStore())
14465     return SDValue();
14466 
14467   // Check that this is a splat.
14468   // Make sure that each of the relevant vector element locations are inserted
14469   // to, i.e. 0 and 1 for v2i64 and 0, 1, 2, 3 for v4i32.
14470   std::bitset<4> IndexNotInserted((1 << NumVecElts) - 1);
14471   SDValue SplatVal;
14472   for (unsigned I = 0; I < NumVecElts; ++I) {
14473     // Check for insert vector elements.
14474     if (StVal.getOpcode() != ISD::INSERT_VECTOR_ELT)
14475       return SDValue();
14476 
14477     // Check that same value is inserted at each vector element.
14478     if (I == 0)
14479       SplatVal = StVal.getOperand(1);
14480     else if (StVal.getOperand(1) != SplatVal)
14481       return SDValue();
14482 
14483     // Check insert element index.
14484     ConstantSDNode *CIndex = dyn_cast<ConstantSDNode>(StVal.getOperand(2));
14485     if (!CIndex)
14486       return SDValue();
14487     uint64_t IndexVal = CIndex->getZExtValue();
14488     if (IndexVal >= NumVecElts)
14489       return SDValue();
14490     IndexNotInserted.reset(IndexVal);
14491 
14492     StVal = StVal.getOperand(0);
14493   }
14494   // Check that all vector element locations were inserted to.
14495   if (IndexNotInserted.any())
14496       return SDValue();
14497 
14498   return splitStoreSplat(DAG, St, SplatVal, NumVecElts);
14499 }
14500 
14501 static SDValue splitStores(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
14502                            SelectionDAG &DAG,
14503                            const AArch64Subtarget *Subtarget) {
14504 
14505   StoreSDNode *S = cast<StoreSDNode>(N);
14506   if (S->isVolatile() || S->isIndexed())
14507     return SDValue();
14508 
14509   SDValue StVal = S->getValue();
14510   EVT VT = StVal.getValueType();
14511 
14512   if (!VT.isFixedLengthVector())
14513     return SDValue();
14514 
14515   // If we get a splat of zeros, convert this vector store to a store of
14516   // scalars. They will be merged into store pairs of xzr, thereby removing one
14517   // instruction and one register.
14518   if (SDValue ReplacedZeroSplat = replaceZeroVectorStore(DAG, *S))
14519     return ReplacedZeroSplat;
14520 
14521   // FIXME: The logic for deciding if an unaligned store should be split should
14522   // be included in TLI.allowsMisalignedMemoryAccesses(), and there should be
14523   // a call to that function here.
14524 
14525   if (!Subtarget->isMisaligned128StoreSlow())
14526     return SDValue();
14527 
14528   // Don't split at -Oz.
14529   if (DAG.getMachineFunction().getFunction().hasMinSize())
14530     return SDValue();
14531 
14532   // Don't split v2i64 vectors. Memcpy lowering produces those and splitting
14533   // those up regresses performance on micro-benchmarks and olden/bh.
14534   if (VT.getVectorNumElements() < 2 || VT == MVT::v2i64)
14535     return SDValue();
14536 
14537   // Split unaligned 16B stores. They are terrible for performance.
14538   // Don't split stores with alignment of 1 or 2. Code that uses clang vector
14539   // extensions can use this to mark that it does not want splitting to happen
14540   // (by underspecifying alignment to be 1 or 2). Furthermore, the chance of
14541   // eliminating alignment hazards is only 1 in 8 for alignment of 2.
14542   if (VT.getSizeInBits() != 128 || S->getAlignment() >= 16 ||
14543       S->getAlignment() <= 2)
14544     return SDValue();
14545 
14546   // If we get a splat of a scalar, convert this vector store to a store of
14547   // scalars. They will be merged into store pairs thereby removing two
14548   // instructions.
14549   if (SDValue ReplacedSplat = replaceSplatVectorStore(DAG, *S))
14550     return ReplacedSplat;
14551 
14552   SDLoc DL(S);
14553 
14554   // Split VT into two.
14555   EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
14556   unsigned NumElts = HalfVT.getVectorNumElements();
14557   SDValue SubVector0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, StVal,
14558                                    DAG.getConstant(0, DL, MVT::i64));
14559   SDValue SubVector1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, StVal,
14560                                    DAG.getConstant(NumElts, DL, MVT::i64));
14561   SDValue BasePtr = S->getBasePtr();
14562   SDValue NewST1 =
14563       DAG.getStore(S->getChain(), DL, SubVector0, BasePtr, S->getPointerInfo(),
14564                    S->getAlignment(), S->getMemOperand()->getFlags());
14565   SDValue OffsetPtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr,
14566                                   DAG.getConstant(8, DL, MVT::i64));
14567   return DAG.getStore(NewST1.getValue(0), DL, SubVector1, OffsetPtr,
14568                       S->getPointerInfo(), S->getAlignment(),
14569                       S->getMemOperand()->getFlags());
14570 }
14571 
14572 static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG) {
14573   SDLoc DL(N);
14574   SDValue Op0 = N->getOperand(0);
14575   SDValue Op1 = N->getOperand(1);
14576   EVT ResVT = N->getValueType(0);
14577 
14578   // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
14579   if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
14580     if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
14581       SDValue X = Op0.getOperand(0).getOperand(0);
14582       return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
14583     }
14584   }
14585 
14586   // uzp1(x, unpkhi(uzp1(y, z))) => uzp1(x, z)
14587   if (Op1.getOpcode() == AArch64ISD::UUNPKHI) {
14588     if (Op1.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
14589       SDValue Z = Op1.getOperand(0).getOperand(1);
14590       return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
14591     }
14592   }
14593 
14594   return SDValue();
14595 }
14596 
14597 static SDValue performGLD1Combine(SDNode *N, SelectionDAG &DAG) {
14598   unsigned Opc = N->getOpcode();
14599 
14600   assert(((Opc >= AArch64ISD::GLD1_MERGE_ZERO && // unsigned gather loads
14601            Opc <= AArch64ISD::GLD1_IMM_MERGE_ZERO) ||
14602           (Opc >= AArch64ISD::GLD1S_MERGE_ZERO && // signed gather loads
14603            Opc <= AArch64ISD::GLD1S_IMM_MERGE_ZERO)) &&
14604          "Invalid opcode.");
14605 
14606   const bool Scaled = Opc == AArch64ISD::GLD1_SCALED_MERGE_ZERO ||
14607                       Opc == AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
14608   const bool Signed = Opc == AArch64ISD::GLD1S_MERGE_ZERO ||
14609                       Opc == AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
14610   const bool Extended = Opc == AArch64ISD::GLD1_SXTW_MERGE_ZERO ||
14611                         Opc == AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO ||
14612                         Opc == AArch64ISD::GLD1_UXTW_MERGE_ZERO ||
14613                         Opc == AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO;
14614 
14615   SDLoc DL(N);
14616   SDValue Chain = N->getOperand(0);
14617   SDValue Pg = N->getOperand(1);
14618   SDValue Base = N->getOperand(2);
14619   SDValue Offset = N->getOperand(3);
14620   SDValue Ty = N->getOperand(4);
14621 
14622   EVT ResVT = N->getValueType(0);
14623 
14624   const auto OffsetOpc = Offset.getOpcode();
14625   const bool OffsetIsZExt =
14626       OffsetOpc == AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU;
14627   const bool OffsetIsSExt =
14628       OffsetOpc == AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU;
14629 
14630   // Fold sign/zero extensions of vector offsets into GLD1 nodes where possible.
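  // Illustrative sketch: if the offset is a predicated sign-extend from i32 of
  // Offs using the same predicate as the gather, the gather can switch to the
  // corresponding SXTW-extending GLD1 opcode and use Offs directly.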
14631   if (!Extended && (OffsetIsSExt || OffsetIsZExt)) {
14632     SDValue ExtPg = Offset.getOperand(0);
14633     VTSDNode *ExtFrom = cast<VTSDNode>(Offset.getOperand(2).getNode());
14634     EVT ExtFromEVT = ExtFrom->getVT().getVectorElementType();
14635 
14636     // If the predicate for the sign- or zero-extended offset is the
14637     // same as the predicate used for this load and the sign-/zero-extension
14638     // was from 32 bits...
14639     if (ExtPg == Pg && ExtFromEVT == MVT::i32) {
14640       SDValue UnextendedOffset = Offset.getOperand(1);
14641 
14642       unsigned NewOpc = getGatherVecOpcode(Scaled, OffsetIsSExt, true);
14643       if (Signed)
14644         NewOpc = getSignExtendedGatherOpcode(NewOpc);
14645 
14646       return DAG.getNode(NewOpc, DL, {ResVT, MVT::Other},
14647                          {Chain, Pg, Base, UnextendedOffset, Ty});
14648     }
14649   }
14650 
14651   return SDValue();
14652 }
14653 
14654 /// Optimize a vector shift instruction and its operand if shifted out
14655 /// bits are not used.
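/// Illustrative example: in (VLSHR X, #8) the low 8 bits of X are never
/// observed, so SimplifyDemandedBits may simplify or drop the computation that
/// only produces those bits.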
14656 static SDValue performVectorShiftCombine(SDNode *N,
14657                                          const AArch64TargetLowering &TLI,
14658                                          TargetLowering::DAGCombinerInfo &DCI) {
14659   assert(N->getOpcode() == AArch64ISD::VASHR ||
14660          N->getOpcode() == AArch64ISD::VLSHR);
14661 
14662   SDValue Op = N->getOperand(0);
14663   unsigned OpScalarSize = Op.getScalarValueSizeInBits();
14664 
14665   unsigned ShiftImm = N->getConstantOperandVal(1);
14666   assert(OpScalarSize > ShiftImm && "Invalid shift imm");
14667 
14668   APInt ShiftedOutBits = APInt::getLowBitsSet(OpScalarSize, ShiftImm);
14669   APInt DemandedMask = ~ShiftedOutBits;
14670 
14671   if (TLI.SimplifyDemandedBits(Op, DemandedMask, DCI))
14672     return SDValue(N, 0);
14673 
14674   return SDValue();
14675 }
14676 
14677 /// Target-specific DAG combine function for post-increment LD1 (lane) and
14678 /// post-increment LD1R.
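/// Illustrative sketch: "ld1 { v0.s }[1], [x0]" followed by "add x0, x0, #4"
/// can be folded into the post-indexed "ld1 { v0.s }[1], [x0], #4".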
14679 static SDValue performPostLD1Combine(SDNode *N,
14680                                      TargetLowering::DAGCombinerInfo &DCI,
14681                                      bool IsLaneOp) {
14682   if (DCI.isBeforeLegalizeOps())
14683     return SDValue();
14684 
14685   SelectionDAG &DAG = DCI.DAG;
14686   EVT VT = N->getValueType(0);
14687 
14688   if (VT.isScalableVector())
14689     return SDValue();
14690 
14691   unsigned LoadIdx = IsLaneOp ? 1 : 0;
14692   SDNode *LD = N->getOperand(LoadIdx).getNode();
14693   // If it is not a LOAD, we cannot do this combine.
14694   if (LD->getOpcode() != ISD::LOAD)
14695     return SDValue();
14696 
14697   // The vector lane must be a constant in the LD1LANE opcode.
14698   SDValue Lane;
14699   if (IsLaneOp) {
14700     Lane = N->getOperand(2);
14701     auto *LaneC = dyn_cast<ConstantSDNode>(Lane);
14702     if (!LaneC || LaneC->getZExtValue() >= VT.getVectorNumElements())
14703       return SDValue();
14704   }
14705 
14706   LoadSDNode *LoadSDN = cast<LoadSDNode>(LD);
14707   EVT MemVT = LoadSDN->getMemoryVT();
14708   // Check if memory operand is the same type as the vector element.
14709   if (MemVT != VT.getVectorElementType())
14710     return SDValue();
14711 
14712   // Check if there are other uses. If so, do not combine as it will introduce
14713   // an extra load.
14714   for (SDNode::use_iterator UI = LD->use_begin(), UE = LD->use_end(); UI != UE;
14715        ++UI) {
14716     if (UI.getUse().getResNo() == 1) // Ignore uses of the chain result.
14717       continue;
14718     if (*UI != N)
14719       return SDValue();
14720   }
14721 
14722   SDValue Addr = LD->getOperand(1);
14723   SDValue Vector = N->getOperand(0);
14724   // Search for a use of the address operand that is an increment.
14725   for (SDNode::use_iterator UI = Addr.getNode()->use_begin(), UE =
14726        Addr.getNode()->use_end(); UI != UE; ++UI) {
14727     SDNode *User = *UI;
14728     if (User->getOpcode() != ISD::ADD
14729         || UI.getUse().getResNo() != Addr.getResNo())
14730       continue;
14731 
14732     // If the increment is a constant, it must match the memory ref size.
14733     SDValue Inc = User->getOperand(User->getOperand(0) == Addr ? 1 : 0);
14734     if (ConstantSDNode *CInc = dyn_cast<ConstantSDNode>(Inc.getNode())) {
14735       uint32_t IncVal = CInc->getZExtValue();
14736       unsigned NumBytes = VT.getScalarSizeInBits() / 8;
14737       if (IncVal != NumBytes)
14738         continue;
14739       Inc = DAG.getRegister(AArch64::XZR, MVT::i64);
14740     }
14741 
14742     // To avoid constructing a cycle, make sure that neither the load nor the
14743     // add is a predecessor of the other or of the Vector.
14744     SmallPtrSet<const SDNode *, 32> Visited;
14745     SmallVector<const SDNode *, 16> Worklist;
14746     Visited.insert(Addr.getNode());
14747     Worklist.push_back(User);
14748     Worklist.push_back(LD);
14749     Worklist.push_back(Vector.getNode());
14750     if (SDNode::hasPredecessorHelper(LD, Visited, Worklist) ||
14751         SDNode::hasPredecessorHelper(User, Visited, Worklist))
14752       continue;
14753 
14754     SmallVector<SDValue, 8> Ops;
14755     Ops.push_back(LD->getOperand(0));  // Chain
14756     if (IsLaneOp) {
14757       Ops.push_back(Vector);           // The vector to be inserted
14758       Ops.push_back(Lane);             // The lane to be inserted in the vector
14759     }
14760     Ops.push_back(Addr);
14761     Ops.push_back(Inc);
14762 
14763     EVT Tys[3] = { VT, MVT::i64, MVT::Other };
14764     SDVTList SDTys = DAG.getVTList(Tys);
14765     unsigned NewOp = IsLaneOp ? AArch64ISD::LD1LANEpost : AArch64ISD::LD1DUPpost;
14766     SDValue UpdN = DAG.getMemIntrinsicNode(NewOp, SDLoc(N), SDTys, Ops,
14767                                            MemVT,
14768                                            LoadSDN->getMemOperand());
14769 
14770     // Update the uses.
14771     SDValue NewResults[] = {
14772         SDValue(LD, 0),            // The result of load
14773         SDValue(UpdN.getNode(), 2) // Chain
14774     };
14775     DCI.CombineTo(LD, NewResults);
14776     DCI.CombineTo(N, SDValue(UpdN.getNode(), 0));     // Dup/Inserted Result
14777     DCI.CombineTo(User, SDValue(UpdN.getNode(), 1));  // Write back register
14778 
14779     break;
14780   }
14781   return SDValue();
14782 }
14783 
14784 /// Simplify ``Addr`` given that the top byte of it is ignored by HW during
14785 /// address translation.
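/// Illustrative example: under TBI, an address mask that only clears the top
/// byte, such as (and X, 0x00ffffffffffffff), is redundant and can be removed
/// because only the low 56 bits of the address are demanded.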
14786 static bool performTBISimplification(SDValue Addr,
14787                                      TargetLowering::DAGCombinerInfo &DCI,
14788                                      SelectionDAG &DAG) {
14789   APInt DemandedMask = APInt::getLowBitsSet(64, 56);
14790   KnownBits Known;
14791   TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
14792                                         !DCI.isBeforeLegalizeOps());
14793   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
14794   if (TLI.SimplifyDemandedBits(Addr, DemandedMask, Known, TLO)) {
14795     DCI.CommitTargetLoweringOpt(TLO);
14796     return true;
14797   }
14798   return false;
14799 }
14800 
14801 static SDValue performSTORECombine(SDNode *N,
14802                                    TargetLowering::DAGCombinerInfo &DCI,
14803                                    SelectionDAG &DAG,
14804                                    const AArch64Subtarget *Subtarget) {
14805   if (SDValue Split = splitStores(N, DCI, DAG, Subtarget))
14806     return Split;
14807 
14808   if (Subtarget->supportsAddressTopByteIgnored() &&
14809       performTBISimplification(N->getOperand(2), DCI, DAG))
14810     return SDValue(N, 0);
14811 
14812   return SDValue();
14813 }
14814 
14815 /// Target-specific DAG combine function for NEON load/store intrinsics
14816 /// to merge base address updates.
14817 static SDValue performNEONPostLDSTCombine(SDNode *N,
14818                                           TargetLowering::DAGCombinerInfo &DCI,
14819                                           SelectionDAG &DAG) {
14820   if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer())
14821     return SDValue();
14822 
14823   unsigned AddrOpIdx = N->getNumOperands() - 1;
14824   SDValue Addr = N->getOperand(AddrOpIdx);
14825 
14826   // Search for a use of the address operand that is an increment.
14827   for (SDNode::use_iterator UI = Addr.getNode()->use_begin(),
14828        UE = Addr.getNode()->use_end(); UI != UE; ++UI) {
14829     SDNode *User = *UI;
14830     if (User->getOpcode() != ISD::ADD ||
14831         UI.getUse().getResNo() != Addr.getResNo())
14832       continue;
14833 
14834     // Check that the add is independent of the load/store.  Otherwise, folding
14835     // it would create a cycle.
14836     SmallPtrSet<const SDNode *, 32> Visited;
14837     SmallVector<const SDNode *, 16> Worklist;
14838     Visited.insert(Addr.getNode());
14839     Worklist.push_back(N);
14840     Worklist.push_back(User);
14841     if (SDNode::hasPredecessorHelper(N, Visited, Worklist) ||
14842         SDNode::hasPredecessorHelper(User, Visited, Worklist))
14843       continue;
14844 
14845     // Find the new opcode for the updating load/store.
14846     bool IsStore = false;
14847     bool IsLaneOp = false;
14848     bool IsDupOp = false;
14849     unsigned NewOpc = 0;
14850     unsigned NumVecs = 0;
14851     unsigned IntNo = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
14852     switch (IntNo) {
14853     default: llvm_unreachable("unexpected intrinsic for Neon base update");
14854     case Intrinsic::aarch64_neon_ld2:       NewOpc = AArch64ISD::LD2post;
14855       NumVecs = 2; break;
14856     case Intrinsic::aarch64_neon_ld3:       NewOpc = AArch64ISD::LD3post;
14857       NumVecs = 3; break;
14858     case Intrinsic::aarch64_neon_ld4:       NewOpc = AArch64ISD::LD4post;
14859       NumVecs = 4; break;
14860     case Intrinsic::aarch64_neon_st2:       NewOpc = AArch64ISD::ST2post;
14861       NumVecs = 2; IsStore = true; break;
14862     case Intrinsic::aarch64_neon_st3:       NewOpc = AArch64ISD::ST3post;
14863       NumVecs = 3; IsStore = true; break;
14864     case Intrinsic::aarch64_neon_st4:       NewOpc = AArch64ISD::ST4post;
14865       NumVecs = 4; IsStore = true; break;
14866     case Intrinsic::aarch64_neon_ld1x2:     NewOpc = AArch64ISD::LD1x2post;
14867       NumVecs = 2; break;
14868     case Intrinsic::aarch64_neon_ld1x3:     NewOpc = AArch64ISD::LD1x3post;
14869       NumVecs = 3; break;
14870     case Intrinsic::aarch64_neon_ld1x4:     NewOpc = AArch64ISD::LD1x4post;
14871       NumVecs = 4; break;
14872     case Intrinsic::aarch64_neon_st1x2:     NewOpc = AArch64ISD::ST1x2post;
14873       NumVecs = 2; IsStore = true; break;
14874     case Intrinsic::aarch64_neon_st1x3:     NewOpc = AArch64ISD::ST1x3post;
14875       NumVecs = 3; IsStore = true; break;
14876     case Intrinsic::aarch64_neon_st1x4:     NewOpc = AArch64ISD::ST1x4post;
14877       NumVecs = 4; IsStore = true; break;
14878     case Intrinsic::aarch64_neon_ld2r:      NewOpc = AArch64ISD::LD2DUPpost;
14879       NumVecs = 2; IsDupOp = true; break;
14880     case Intrinsic::aarch64_neon_ld3r:      NewOpc = AArch64ISD::LD3DUPpost;
14881       NumVecs = 3; IsDupOp = true; break;
14882     case Intrinsic::aarch64_neon_ld4r:      NewOpc = AArch64ISD::LD4DUPpost;
14883       NumVecs = 4; IsDupOp = true; break;
14884     case Intrinsic::aarch64_neon_ld2lane:   NewOpc = AArch64ISD::LD2LANEpost;
14885       NumVecs = 2; IsLaneOp = true; break;
14886     case Intrinsic::aarch64_neon_ld3lane:   NewOpc = AArch64ISD::LD3LANEpost;
14887       NumVecs = 3; IsLaneOp = true; break;
14888     case Intrinsic::aarch64_neon_ld4lane:   NewOpc = AArch64ISD::LD4LANEpost;
14889       NumVecs = 4; IsLaneOp = true; break;
14890     case Intrinsic::aarch64_neon_st2lane:   NewOpc = AArch64ISD::ST2LANEpost;
14891       NumVecs = 2; IsStore = true; IsLaneOp = true; break;
14892     case Intrinsic::aarch64_neon_st3lane:   NewOpc = AArch64ISD::ST3LANEpost;
14893       NumVecs = 3; IsStore = true; IsLaneOp = true; break;
14894     case Intrinsic::aarch64_neon_st4lane:   NewOpc = AArch64ISD::ST4LANEpost;
14895       NumVecs = 4; IsStore = true; IsLaneOp = true; break;
14896     }
14897 
14898     EVT VecTy;
14899     if (IsStore)
14900       VecTy = N->getOperand(2).getValueType();
14901     else
14902       VecTy = N->getValueType(0);
14903 
14904     // If the increment is a constant, it must match the memory ref size.
14905     SDValue Inc = User->getOperand(User->getOperand(0) == Addr ? 1 : 0);
14906     if (ConstantSDNode *CInc = dyn_cast<ConstantSDNode>(Inc.getNode())) {
14907       uint32_t IncVal = CInc->getZExtValue();
14908       unsigned NumBytes = NumVecs * VecTy.getSizeInBits() / 8;
14909       if (IsLaneOp || IsDupOp)
14910         NumBytes /= VecTy.getVectorNumElements();
14911       if (IncVal != NumBytes)
14912         continue;
14913       Inc = DAG.getRegister(AArch64::XZR, MVT::i64);
14914     }
14915     SmallVector<SDValue, 8> Ops;
14916     Ops.push_back(N->getOperand(0)); // Incoming chain
14917     // Lane loads and stores have a vector list as input.
14918     if (IsLaneOp || IsStore)
14919       for (unsigned i = 2; i < AddrOpIdx; ++i)
14920         Ops.push_back(N->getOperand(i));
14921     Ops.push_back(Addr); // Base register
14922     Ops.push_back(Inc);
14923 
14924     // Return Types.
14925     EVT Tys[6];
14926     unsigned NumResultVecs = (IsStore ? 0 : NumVecs);
14927     unsigned n;
14928     for (n = 0; n < NumResultVecs; ++n)
14929       Tys[n] = VecTy;
14930     Tys[n++] = MVT::i64;  // Type of write back register
14931     Tys[n] = MVT::Other;  // Type of the chain
14932     SDVTList SDTys = DAG.getVTList(makeArrayRef(Tys, NumResultVecs + 2));
14933 
14934     MemIntrinsicSDNode *MemInt = cast<MemIntrinsicSDNode>(N);
14935     SDValue UpdN = DAG.getMemIntrinsicNode(NewOpc, SDLoc(N), SDTys, Ops,
14936                                            MemInt->getMemoryVT(),
14937                                            MemInt->getMemOperand());
14938 
14939     // Update the uses.
14940     std::vector<SDValue> NewResults;
14941     for (unsigned i = 0; i < NumResultVecs; ++i) {
14942       NewResults.push_back(SDValue(UpdN.getNode(), i));
14943     }
14944     NewResults.push_back(SDValue(UpdN.getNode(), NumResultVecs + 1));
14945     DCI.CombineTo(N, NewResults);
14946     DCI.CombineTo(User, SDValue(UpdN.getNode(), NumResultVecs));
14947 
14948     break;
14949   }
14950   return SDValue();
14951 }
14952 
14953 // Checks to see if the value is the prescribed width and returns information
14954 // about its extension mode.
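// Illustrative note on the constant cases below: a constant "fits" when
// |value| < 2^(width-1); e.g. for width == 8 any constant in [-127, 127]
// passes, while 200 or -200 does not.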
14955 static
14956 bool checkValueWidth(SDValue V, unsigned width, ISD::LoadExtType &ExtType) {
14957   ExtType = ISD::NON_EXTLOAD;
14958   switch(V.getNode()->getOpcode()) {
14959   default:
14960     return false;
14961   case ISD::LOAD: {
14962     LoadSDNode *LoadNode = cast<LoadSDNode>(V.getNode());
14963     if ((LoadNode->getMemoryVT() == MVT::i8 && width == 8)
14964        || (LoadNode->getMemoryVT() == MVT::i16 && width == 16)) {
14965       ExtType = LoadNode->getExtensionType();
14966       return true;
14967     }
14968     return false;
14969   }
14970   case ISD::AssertSext: {
14971     VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
14972     if ((TypeNode->getVT() == MVT::i8 && width == 8)
14973        || (TypeNode->getVT() == MVT::i16 && width == 16)) {
14974       ExtType = ISD::SEXTLOAD;
14975       return true;
14976     }
14977     return false;
14978   }
14979   case ISD::AssertZext: {
14980     VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
14981     if ((TypeNode->getVT() == MVT::i8 && width == 8)
14982        || (TypeNode->getVT() == MVT::i16 && width == 16)) {
14983       ExtType = ISD::ZEXTLOAD;
14984       return true;
14985     }
14986     return false;
14987   }
14988   case ISD::Constant:
14989   case ISD::TargetConstant: {
14990     return std::abs(cast<ConstantSDNode>(V.getNode())->getSExtValue()) <
14991            1LL << (width - 1);
14992   }
14993   }
14994 
14995   return true;
14996 }
14997 
14998 // This function does a whole lot of voodoo to determine if the tests are
14999 // equivalent without and with a mask. Essentially what happens is that given a
15000 // DAG resembling:
15001 //
15002 //  +-------------+ +-------------+ +-------------+ +-------------+
15003 //  |    Input    | | AddConstant | | CompConstant| |     CC      |
15004 //  +-------------+ +-------------+ +-------------+ +-------------+
15005 //           |           |           |               |
15006 //           V           V           |    +----------+
15007 //          +-------------+  +----+  |    |
15008 //          |     ADD     |  |0xff|  |    |
15009 //          +-------------+  +----+  |    |
15010 //                  |           |    |    |
15011 //                  V           V    |    |
15012 //                 +-------------+   |    |
15013 //                 |     AND     |   |    |
15014 //                 +-------------+   |    |
15015 //                      |            |    |
15016 //                      +-----+      |    |
15017 //                            |      |    |
15018 //                            V      V    V
15019 //                           +-------------+
15020 //                           |     CMP     |
15021 //                           +-------------+
15022 //
15023 // The AND node may be safely removed for some combinations of inputs. In
15024 // particular we need to take into account the extension type of the Input,
15025 // the exact values of AddConstant, CompConstant, and CC, along with the nominal
15026 // width of the input (this can work for inputs of any width; the above graph
15027 // is specific to 8 bits).
15028 //
15029 // The specific equations were worked out by generating output tables for each
15030 // AArch64CC value in terms of the AddConstant (w1) and CompConstant (w2). The
15031 // problem was simplified by working with 4 bit inputs, which means we only
15032 // needed to reason about 24 distinct bit patterns: 8 patterns unique to zero
15033 // extension (8,15), 8 patterns unique to sign extensions (-8,-1), and 8
15034 // patterns present in both extensions (0,7). For every distinct set of
15035 // AddConstant and CompConstant bit patterns we can consider the masked and
15036 // unmasked versions to be equivalent if the result of this function is true for
15037 // all 16 distinct bit patterns for the current extension type of Input (w0).
15038 //
15039 //   sub      w8, w0, w1
15040 //   and      w10, w8, #0x0f
15041 //   cmp      w8, w2
15042 //   cset     w9, AArch64CC
15043 //   cmp      w10, w2
15044 //   cset     w11, AArch64CC
15045 //   cmp      w9, w11
15046 //   cset     w0, eq
15047 //   ret
15048 //
15049 // Since the above function shows when the outputs are equivalent it defines
15050 // when it is safe to remove the AND. Unfortunately it only runs on AArch64 and
15051 // would be expensive to run during compiles. The equations below were written
15052 // in a test harness that confirmed they gave outputs equivalent to the above
15053 // function for all inputs, so they can be used to determine whether the removal
15054 // is legal instead.
15055 //
15056 // isEquivalentMaskless() is the code for testing whether the AND can be removed,
15057 // factored out of the DAG recognition because the DAG can take several forms.
15058 
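// Worked example (illustrative only): width == 8 with a zero-extended Input
// (so Input is in [0, 255]) and AddConstant == 0. The AND #0xff is then a
// no-op on the sum, so the masked and unmasked compares always set the same
// flags; accordingly, for e.g. CC == NE the clause
// (AddConstant >= 0 && CompConstant >= 0 && CompConstant >= AddConstant)
// below accepts any non-negative CompConstant.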
15059 static bool isEquivalentMaskless(unsigned CC, unsigned width,
15060                                  ISD::LoadExtType ExtType, int AddConstant,
15061                                  int CompConstant) {
15062   // By being careful about our equations and only writing them in terms of
15063   // symbolic values and well-known constants (0, 1, -1, MaxUInt) we can
15064   // make them generally applicable to all bit widths.
15065   int MaxUInt = (1 << width);
15066 
15067   // For the purposes of these comparisons sign extending the type is
15068   // equivalent to zero extending the add and displacing it by half the integer
15069   // width. Provided we are careful and make sure our equations are valid over
15070   // the whole range we can just adjust the input and avoid writing equations
15071   // for sign extended inputs.
15072   if (ExtType == ISD::SEXTLOAD)
15073     AddConstant -= (1 << (width-1));
15074 
15075   switch(CC) {
15076   case AArch64CC::LE:
15077   case AArch64CC::GT:
15078     if ((AddConstant == 0) ||
15079         (CompConstant == MaxUInt - 1 && AddConstant < 0) ||
15080         (AddConstant >= 0 && CompConstant < 0) ||
15081         (AddConstant <= 0 && CompConstant <= 0 && CompConstant < AddConstant))
15082       return true;
15083     break;
15084   case AArch64CC::LT:
15085   case AArch64CC::GE:
15086     if ((AddConstant == 0) ||
15087         (AddConstant >= 0 && CompConstant <= 0) ||
15088         (AddConstant <= 0 && CompConstant <= 0 && CompConstant <= AddConstant))
15089       return true;
15090     break;
15091   case AArch64CC::HI:
15092   case AArch64CC::LS:
15093     if ((AddConstant >= 0 && CompConstant < 0) ||
15094        (AddConstant <= 0 && CompConstant >= -1 &&
15095         CompConstant < AddConstant + MaxUInt))
15096       return true;
15097    break;
15098   case AArch64CC::PL:
15099   case AArch64CC::MI:
15100     if ((AddConstant == 0) ||
15101         (AddConstant > 0 && CompConstant <= 0) ||
15102         (AddConstant < 0 && CompConstant <= AddConstant))
15103       return true;
15104     break;
15105   case AArch64CC::LO:
15106   case AArch64CC::HS:
15107     if ((AddConstant >= 0 && CompConstant <= 0) ||
15108         (AddConstant <= 0 && CompConstant >= 0 &&
15109          CompConstant <= AddConstant + MaxUInt))
15110       return true;
15111     break;
15112   case AArch64CC::EQ:
15113   case AArch64CC::NE:
15114     if ((AddConstant > 0 && CompConstant < 0) ||
15115         (AddConstant < 0 && CompConstant >= 0 &&
15116          CompConstant < AddConstant + MaxUInt) ||
15117         (AddConstant >= 0 && CompConstant >= 0 &&
15118          CompConstant >= AddConstant) ||
15119         (AddConstant <= 0 && CompConstant < 0 && CompConstant < AddConstant))
15120       return true;
15121     break;
15122   case AArch64CC::VS:
15123   case AArch64CC::VC:
15124   case AArch64CC::AL:
15125   case AArch64CC::NV:
15126     return true;
15127   case AArch64CC::Invalid:
15128     break;
15129   }
15130 
15131   return false;
15132 }
15133 
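// A minimal before/after sketch of the rewrite performed by performCONDCombine
// (assuming the mask and constants pass the width checks above):
//
//   SUBS (AND (ADD x, c1), 0xff), c2
//     -->
//   SUBS (ADD x, c1), c2
//
// i.e. the superfluous AND is dropped and the SUBS is rebuilt on the ADD.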
15134 static
15135 SDValue performCONDCombine(SDNode *N,
15136                            TargetLowering::DAGCombinerInfo &DCI,
15137                            SelectionDAG &DAG, unsigned CCIndex,
15138                            unsigned CmpIndex) {
15139   unsigned CC = cast<ConstantSDNode>(N->getOperand(CCIndex))->getSExtValue();
15140   SDNode *SubsNode = N->getOperand(CmpIndex).getNode();
15141   unsigned CondOpcode = SubsNode->getOpcode();
15142 
15143   if (CondOpcode != AArch64ISD::SUBS)
15144     return SDValue();
15145 
15146   // There is a SUBS feeding this condition. Is it fed by a mask we can
15147   // use?
15148 
15149   SDNode *AndNode = SubsNode->getOperand(0).getNode();
15150   unsigned MaskBits = 0;
15151 
15152   if (AndNode->getOpcode() != ISD::AND)
15153     return SDValue();
15154 
15155   if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(AndNode->getOperand(1))) {
15156     uint32_t CNV = CN->getZExtValue();
15157     if (CNV == 255)
15158       MaskBits = 8;
15159     else if (CNV == 65535)
15160       MaskBits = 16;
15161   }
15162 
15163   if (!MaskBits)
15164     return SDValue();
15165 
15166   SDValue AddValue = AndNode->getOperand(0);
15167 
15168   if (AddValue.getOpcode() != ISD::ADD)
15169     return SDValue();
15170 
15171   // The basic dag structure is correct, grab the inputs and validate them.
15172 
15173   SDValue AddInputValue1 = AddValue.getNode()->getOperand(0);
15174   SDValue AddInputValue2 = AddValue.getNode()->getOperand(1);
15175   SDValue SubsInputValue = SubsNode->getOperand(1);
15176 
15177   // The mask is present and all of the values originate from a narrower type;
15178   // let's see if the mask is superfluous.
15179 
15180   if (!isa<ConstantSDNode>(AddInputValue2.getNode()) ||
15181       !isa<ConstantSDNode>(SubsInputValue.getNode()))
15182     return SDValue();
15183 
15184   ISD::LoadExtType ExtType;
15185 
15186   if (!checkValueWidth(SubsInputValue, MaskBits, ExtType) ||
15187       !checkValueWidth(AddInputValue2, MaskBits, ExtType) ||
15188       !checkValueWidth(AddInputValue1, MaskBits, ExtType) )
15189     return SDValue();
15190 
15191   if(!isEquivalentMaskless(CC, MaskBits, ExtType,
15192                 cast<ConstantSDNode>(AddInputValue2.getNode())->getSExtValue(),
15193                 cast<ConstantSDNode>(SubsInputValue.getNode())->getSExtValue()))
15194     return SDValue();
15195 
15196   // The AND is not necessary, remove it.
15197 
15198   SDVTList VTs = DAG.getVTList(SubsNode->getValueType(0),
15199                                SubsNode->getValueType(1));
15200   SDValue Ops[] = { AddValue, SubsNode->getOperand(1) };
15201 
15202   SDValue NewValue = DAG.getNode(CondOpcode, SDLoc(SubsNode), VTs, Ops);
15203   DAG.ReplaceAllUsesWith(SubsNode, NewValue.getNode());
15204 
15205   return SDValue(N, 0);
15206 }
15207 
15208 // Optimize compare with zero and branch.
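// Sketch of the fold below (illustrative): a flag-setting compare against zero
// whose flags have a single use and whose value result is unused is absorbed
// into a CBZ/CBNZ, e.g.
//
//   brcond (SUBS x, #0), eq, dest   -->   CBZ  x, dest
//   brcond (SUBS x, #0), ne, dest   -->   CBNZ x, dest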
15209 static SDValue performBRCONDCombine(SDNode *N,
15210                                     TargetLowering::DAGCombinerInfo &DCI,
15211                                     SelectionDAG &DAG) {
15212   MachineFunction &MF = DAG.getMachineFunction();
15213   // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z instructions
15214   // will not be produced, as they are conditional branch instructions that do
15215   // not set flags.
15216   if (MF.getFunction().hasFnAttribute(Attribute::SpeculativeLoadHardening))
15217     return SDValue();
15218 
15219   if (SDValue NV = performCONDCombine(N, DCI, DAG, 2, 3))
15220     N = NV.getNode();
15221   SDValue Chain = N->getOperand(0);
15222   SDValue Dest = N->getOperand(1);
15223   SDValue CCVal = N->getOperand(2);
15224   SDValue Cmp = N->getOperand(3);
15225 
15226   assert(isa<ConstantSDNode>(CCVal) && "Expected a ConstantSDNode here!");
15227   unsigned CC = cast<ConstantSDNode>(CCVal)->getZExtValue();
15228   if (CC != AArch64CC::EQ && CC != AArch64CC::NE)
15229     return SDValue();
15230 
15231   unsigned CmpOpc = Cmp.getOpcode();
15232   if (CmpOpc != AArch64ISD::ADDS && CmpOpc != AArch64ISD::SUBS)
15233     return SDValue();
15234 
15235   // Only attempt folding if there is only one use of the flag and no use of the
15236   // value.
15237   if (!Cmp->hasNUsesOfValue(0, 0) || !Cmp->hasNUsesOfValue(1, 1))
15238     return SDValue();
15239 
15240   SDValue LHS = Cmp.getOperand(0);
15241   SDValue RHS = Cmp.getOperand(1);
15242 
15243   assert(LHS.getValueType() == RHS.getValueType() &&
15244          "Expected the value type to be the same for both operands!");
15245   if (LHS.getValueType() != MVT::i32 && LHS.getValueType() != MVT::i64)
15246     return SDValue();
15247 
15248   if (isNullConstant(LHS))
15249     std::swap(LHS, RHS);
15250 
15251   if (!isNullConstant(RHS))
15252     return SDValue();
15253 
15254   if (LHS.getOpcode() == ISD::SHL || LHS.getOpcode() == ISD::SRA ||
15255       LHS.getOpcode() == ISD::SRL)
15256     return SDValue();
15257 
15258   // Fold the compare into the branch instruction.
15259   SDValue BR;
15260   if (CC == AArch64CC::EQ)
15261     BR = DAG.getNode(AArch64ISD::CBZ, SDLoc(N), MVT::Other, Chain, LHS, Dest);
15262   else
15263     BR = DAG.getNode(AArch64ISD::CBNZ, SDLoc(N), MVT::Other, Chain, LHS, Dest);
15264 
15265   // Do not add new nodes to DAG combiner worklist.
15266   DCI.CombineTo(N, BR, false);
15267 
15268   return SDValue();
15269 }
15270 
15271 // Optimize CSEL instructions
15272 static SDValue performCSELCombine(SDNode *N,
15273                                   TargetLowering::DAGCombinerInfo &DCI,
15274                                   SelectionDAG &DAG) {
15275   // CSEL x, x, cc -> x
15276   if (N->getOperand(0) == N->getOperand(1))
15277     return N->getOperand(0);
15278 
15279   return performCONDCombine(N, DCI, DAG, 2, 3);
15280 }
15281 
15282 // Optimize some simple tbz/tbnz cases.  Returns the new operand and bit to test
15283 // as well as whether the test should be inverted.  This code is required to
15284 // catch these cases (as opposed to standard dag combines) because
15285 // AArch64ISD::TBZ is matched during legalization.
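// For example (illustrative only, assuming each intermediate node has a single
// use), testing bit 3 of (srl (xor x, -1), 2) becomes a test of bit 5 of x
// with the branch sense inverted:
//
//   tbz (srl (xor x, -1), 2), #3   -->   tbnz x, #5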
15286 static SDValue getTestBitOperand(SDValue Op, unsigned &Bit, bool &Invert,
15287                                  SelectionDAG &DAG) {
15288 
15289   if (!Op->hasOneUse())
15290     return Op;
15291 
15292   // We don't handle undef/constant-fold cases below, as they should have
15293   // already been taken care of (e.g. and of 0, test of undefined shifted bits,
15294   // etc.)
15295 
15296   // (tbz (trunc x), b) -> (tbz x, b)
15297   // This case is just here to enable more of the below cases to be caught.
15298   if (Op->getOpcode() == ISD::TRUNCATE &&
15299       Bit < Op->getValueType(0).getSizeInBits()) {
15300     return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
15301   }
15302 
15303   // (tbz (any_ext x), b) -> (tbz x, b) if we don't use the extended bits.
15304   if (Op->getOpcode() == ISD::ANY_EXTEND &&
15305       Bit < Op->getOperand(0).getValueSizeInBits()) {
15306     return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
15307   }
15308 
15309   if (Op->getNumOperands() != 2)
15310     return Op;
15311 
15312   auto *C = dyn_cast<ConstantSDNode>(Op->getOperand(1));
15313   if (!C)
15314     return Op;
15315 
15316   switch (Op->getOpcode()) {
15317   default:
15318     return Op;
15319 
15320   // (tbz (and x, m), b) -> (tbz x, b)
15321   case ISD::AND:
15322     if ((C->getZExtValue() >> Bit) & 1)
15323       return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
15324     return Op;
15325 
15326   // (tbz (shl x, c), b) -> (tbz x, b-c)
15327   case ISD::SHL:
15328     if (C->getZExtValue() <= Bit &&
15329         (Bit - C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) {
15330       Bit = Bit - C->getZExtValue();
15331       return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
15332     }
15333     return Op;
15334 
15335   // (tbz (sra x, c), b) -> (tbz x, b+c) or (tbz x, msb) if b+c is > # bits in x
15336   case ISD::SRA:
15337     Bit = Bit + C->getZExtValue();
15338     if (Bit >= Op->getValueType(0).getSizeInBits())
15339       Bit = Op->getValueType(0).getSizeInBits() - 1;
15340     return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
15341 
15342   // (tbz (srl x, c), b) -> (tbz x, b+c)
15343   case ISD::SRL:
15344     if ((Bit + C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) {
15345       Bit = Bit + C->getZExtValue();
15346       return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
15347     }
15348     return Op;
15349 
15350   // (tbz (xor x, -1), b) -> (tbnz x, b)
15351   case ISD::XOR:
15352     if ((C->getZExtValue() >> Bit) & 1)
15353       Invert = !Invert;
15354     return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
15355   }
15356 }
15357 
15358 // Optimize test single bit zero/non-zero and branch.
15359 static SDValue performTBZCombine(SDNode *N,
15360                                  TargetLowering::DAGCombinerInfo &DCI,
15361                                  SelectionDAG &DAG) {
15362   unsigned Bit = cast<ConstantSDNode>(N->getOperand(2))->getZExtValue();
15363   bool Invert = false;
15364   SDValue TestSrc = N->getOperand(1);
15365   SDValue NewTestSrc = getTestBitOperand(TestSrc, Bit, Invert, DAG);
15366 
15367   if (TestSrc == NewTestSrc)
15368     return SDValue();
15369 
15370   unsigned NewOpc = N->getOpcode();
15371   if (Invert) {
15372     if (NewOpc == AArch64ISD::TBZ)
15373       NewOpc = AArch64ISD::TBNZ;
15374     else {
15375       assert(NewOpc == AArch64ISD::TBNZ);
15376       NewOpc = AArch64ISD::TBZ;
15377     }
15378   }
15379 
15380   SDLoc DL(N);
15381   return DAG.getNode(NewOpc, DL, MVT::Other, N->getOperand(0), NewTestSrc,
15382                      DAG.getConstant(Bit, DL, MVT::i64), N->getOperand(3));
15383 }
15384 
15385 // vselect (v1i1 setcc) ->
15386 //     vselect (v1iXX setcc)  (XX is the size of the compared operand type)
15387 // FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
15388 // condition. If it can legalize "VSELECT v1i1" correctly, no need to combine
15389 // such VSELECT.
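// Illustrative example of the v1i1 rewrite below, assuming 64-bit compare
// operands:
//
//   vselect (v1i1 (setcc a, b, cc)), x, y
//     -->
//   vselect (v1i64 (setcc a, b, cc)), x, y
//
// so the condition has the same element size as the compared operands.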
15390 static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
15391   SDValue N0 = N->getOperand(0);
15392   EVT CCVT = N0.getValueType();
15393 
15394   // Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
15395   // into (OR (ASR lhs, N-1), 1), which requires less instructions for the
15396   // supported types.
15397   SDValue SetCC = N->getOperand(0);
15398   if (SetCC.getOpcode() == ISD::SETCC &&
15399       SetCC.getOperand(2) == DAG.getCondCode(ISD::SETGT)) {
15400     SDValue CmpLHS = SetCC.getOperand(0);
15401     EVT VT = CmpLHS.getValueType();
15402     SDNode *CmpRHS = SetCC.getOperand(1).getNode();
15403     SDNode *SplatLHS = N->getOperand(1).getNode();
15404     SDNode *SplatRHS = N->getOperand(2).getNode();
15405     APInt SplatLHSVal;
15406     if (CmpLHS.getValueType() == N->getOperand(1).getValueType() &&
15407         VT.isSimple() &&
15408         is_contained(
15409             makeArrayRef({MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
15410                           MVT::v2i32, MVT::v4i32, MVT::v2i64}),
15411             VT.getSimpleVT().SimpleTy) &&
15412         ISD::isConstantSplatVector(SplatLHS, SplatLHSVal) &&
15413         SplatLHSVal.isOneValue() && ISD::isConstantSplatVectorAllOnes(CmpRHS) &&
15414         ISD::isConstantSplatVectorAllOnes(SplatRHS)) {
15415       unsigned NumElts = VT.getVectorNumElements();
15416       SmallVector<SDValue, 8> Ops(
15417           NumElts, DAG.getConstant(VT.getScalarSizeInBits() - 1, SDLoc(N),
15418                                    VT.getScalarType()));
15419       SDValue Val = DAG.getBuildVector(VT, SDLoc(N), Ops);
15420 
15421       auto Shift = DAG.getNode(ISD::SRA, SDLoc(N), VT, CmpLHS, Val);
15422       auto Or = DAG.getNode(ISD::OR, SDLoc(N), VT, Shift, N->getOperand(1));
15423       return Or;
15424     }
15425   }
15426 
15427   if (N0.getOpcode() != ISD::SETCC ||
15428       CCVT.getVectorElementCount() != ElementCount::getFixed(1) ||
15429       CCVT.getVectorElementType() != MVT::i1)
15430     return SDValue();
15431 
15432   EVT ResVT = N->getValueType(0);
15433   EVT CmpVT = N0.getOperand(0).getValueType();
15434   // Only combine when the result type is of the same size as the compared
15435   // operands.
15436   if (ResVT.getSizeInBits() != CmpVT.getSizeInBits())
15437     return SDValue();
15438 
15439   SDValue IfTrue = N->getOperand(1);
15440   SDValue IfFalse = N->getOperand(2);
15441   SetCC = DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(),
15442                        N0.getOperand(0), N0.getOperand(1),
15443                        cast<CondCodeSDNode>(N0.getOperand(2))->get());
15444   return DAG.getNode(ISD::VSELECT, SDLoc(N), ResVT, SetCC,
15445                      IfTrue, IfFalse);
15446 }
15447 
15448 /// A vector select: "(select vL, vR, (setcc LHS, RHS))" is best performed with
15449 /// the compare-mask instructions rather than going via NZCV, even if LHS and
15450 /// RHS are really scalar. This replaces any scalar setcc in the above pattern
15451 /// with a vector one followed by a DUP shuffle on the result.
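/// Illustrative example (types assume a v4i32 select fed by an i32 compare):
///
///   select (setcc i32 a, b, cc), v4i32 L, v4i32 R
///     -->
///   vselect (dup-lane0 (setcc v4i32 (scalar_to_vector a),
///                                   (scalar_to_vector b), cc)), L, R
///
/// so the mask comes from a NEON compare-mask instruction instead of NZCV.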
15452 static SDValue performSelectCombine(SDNode *N,
15453                                     TargetLowering::DAGCombinerInfo &DCI) {
15454   SelectionDAG &DAG = DCI.DAG;
15455   SDValue N0 = N->getOperand(0);
15456   EVT ResVT = N->getValueType(0);
15457 
15458   if (N0.getOpcode() != ISD::SETCC)
15459     return SDValue();
15460 
15461   if (ResVT.isScalableVector())
15462     return SDValue();
15463 
15464   // Make sure the SETCC result is either i1 (initial DAG), or i32, the lowered
15465   // scalar SetCCResultType. We also don't expect vectors, because we assume
15466   // that selects fed by vector SETCCs are canonicalized to VSELECT.
15467   assert((N0.getValueType() == MVT::i1 || N0.getValueType() == MVT::i32) &&
15468          "Scalar-SETCC feeding SELECT has unexpected result type!");
15469 
15470   // If NumMaskElts == 0, the comparison is larger than the select result. The
15471   // largest real NEON comparison is 64 bits per lane, which means the result is
15472   // at most 32 bits and an illegal vector. Just bail out for now.
15473   EVT SrcVT = N0.getOperand(0).getValueType();
15474 
15475   // Don't try to do this optimization when the setcc itself has i1 operands.
15476   // There are no legal vectors of i1, so this would be pointless.
15477   if (SrcVT == MVT::i1)
15478     return SDValue();
15479 
15480   int NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits();
15481   if (!ResVT.isVector() || NumMaskElts == 0)
15482     return SDValue();
15483 
15484   SrcVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumMaskElts);
15485   EVT CCVT = SrcVT.changeVectorElementTypeToInteger();
15486 
15487   // Also bail out if the vector CCVT isn't the same size as ResVT.
15488   // This can happen if the SETCC operand size doesn't divide the ResVT size
15489   // (e.g., f64 vs v3f32).
15490   if (CCVT.getSizeInBits() != ResVT.getSizeInBits())
15491     return SDValue();
15492 
15493   // Make sure we didn't create illegal types, if we're not supposed to.
15494   assert(DCI.isBeforeLegalize() ||
15495          DAG.getTargetLoweringInfo().isTypeLegal(SrcVT));
15496 
15497   // First perform a vector comparison, where lane 0 is the one we're interested
15498   // in.
15499   SDLoc DL(N0);
15500   SDValue LHS =
15501       DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, SrcVT, N0.getOperand(0));
15502   SDValue RHS =
15503       DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, SrcVT, N0.getOperand(1));
15504   SDValue SetCC = DAG.getNode(ISD::SETCC, DL, CCVT, LHS, RHS, N0.getOperand(2));
15505 
15506   // Now duplicate the comparison mask we want across all other lanes.
15507   SmallVector<int, 8> DUPMask(CCVT.getVectorNumElements(), 0);
15508   SDValue Mask = DAG.getVectorShuffle(CCVT, DL, SetCC, SetCC, DUPMask);
15509   Mask = DAG.getNode(ISD::BITCAST, DL,
15510                      ResVT.changeVectorElementTypeToInteger(), Mask);
15511 
15512   return DAG.getSelect(DL, ResVT, Mask, N->getOperand(1), N->getOperand(2));
15513 }
15514 
15515 /// Get rid of unnecessary NVCASTs (that don't change the type).
15516 static SDValue performNVCASTCombine(SDNode *N) {
15517   if (N->getValueType(0) == N->getOperand(0).getValueType())
15518     return N->getOperand(0);
15519 
15520   return SDValue();
15521 }
15522 
15523 // If all users of the globaladdr are of the form (globaladdr + constant), find
15524 // the smallest constant, fold it into the globaladdr's offset and rewrite the
15525 // globaladdr as (globaladdr + constant) - constant.
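// Worked example (illustrative): if g is only used as (g + 8) and (g + 12),
// MinOffset is 8 and g is rewritten to ((globaladdr g, offset 8) - 8); the
// subtracted constant can then cancel against the users' additions in later
// combines.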
15526 static SDValue performGlobalAddressCombine(SDNode *N, SelectionDAG &DAG,
15527                                            const AArch64Subtarget *Subtarget,
15528                                            const TargetMachine &TM) {
15529   auto *GN = cast<GlobalAddressSDNode>(N);
15530   if (Subtarget->ClassifyGlobalReference(GN->getGlobal(), TM) !=
15531       AArch64II::MO_NO_FLAG)
15532     return SDValue();
15533 
15534   uint64_t MinOffset = -1ull;
15535   for (SDNode *N : GN->uses()) {
15536     if (N->getOpcode() != ISD::ADD)
15537       return SDValue();
15538     auto *C = dyn_cast<ConstantSDNode>(N->getOperand(0));
15539     if (!C)
15540       C = dyn_cast<ConstantSDNode>(N->getOperand(1));
15541     if (!C)
15542       return SDValue();
15543     MinOffset = std::min(MinOffset, C->getZExtValue());
15544   }
15545   uint64_t Offset = MinOffset + GN->getOffset();
15546 
15547   // Require that the new offset is larger than the existing one. Otherwise, we
15548   // can end up oscillating between two possible DAGs, for example,
15549   // (add (add globaladdr + 10, -1), 1) and (add globaladdr + 9, 1).
15550   if (Offset <= uint64_t(GN->getOffset()))
15551     return SDValue();
15552 
15553   // Check whether folding this offset is legal. It must not go out of bounds of
15554   // the referenced object to avoid violating the code model, and must be
15555   // smaller than 2^21 because this is the largest offset expressible in all
15556   // object formats.
15557   //
15558   // This check also prevents us from folding negative offsets, which will end
15559   // up being treated in the same way as large positive ones. They could also
15560   // cause code model violations, and aren't really common enough to matter.
15561   if (Offset >= (1 << 21))
15562     return SDValue();
15563 
15564   const GlobalValue *GV = GN->getGlobal();
15565   Type *T = GV->getValueType();
15566   if (!T->isSized() ||
15567       Offset > GV->getParent()->getDataLayout().getTypeAllocSize(T))
15568     return SDValue();
15569 
15570   SDLoc DL(GN);
15571   SDValue Result = DAG.getGlobalAddress(GV, DL, MVT::i64, Offset);
15572   return DAG.getNode(ISD::SUB, DL, MVT::i64, Result,
15573                      DAG.getConstant(MinOffset, DL, MVT::i64));
15574 }
15575 
15576 // Turns the vector of indices into a vector of byte offsets by scaling Offset
15577 // by (BitWidth / 8).
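// For example (illustrative), with 32-bit elements (BitWidth == 32) an index
// vector <0, 1, 2, 3, ...> is shifted left by log2(32 / 8) == 2, giving byte
// offsets <0, 4, 8, 12, ...>.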
15578 static SDValue getScaledOffsetForBitWidth(SelectionDAG &DAG, SDValue Offset,
15579                                           SDLoc DL, unsigned BitWidth) {
15580   assert(Offset.getValueType().isScalableVector() &&
15581          "This method is only for scalable vectors of offsets");
15582 
15583   SDValue Shift = DAG.getConstant(Log2_32(BitWidth / 8), DL, MVT::i64);
15584   SDValue SplatShift = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, Shift);
15585 
15586   return DAG.getNode(ISD::SHL, DL, MVT::nxv2i64, Offset, SplatShift);
15587 }
15588 
15589 /// Check if the value of \p OffsetInBytes can be used as an immediate for
15590 /// the gather load/prefetch and scatter store instructions with vector base and
15591 /// immediate offset addressing mode:
15592 ///
15593 ///      [<Zn>.[S|D]{, #<imm>}]
15594 ///
15595 /// where <imm> = sizeof(<T>) * k, for k = 0, 1, ..., 31.
15596 
15597 inline static bool isValidImmForSVEVecImmAddrMode(unsigned OffsetInBytes,
15598                                                   unsigned ScalarSizeInBytes) {
15599   // The immediate is not a multiple of the scalar size.
15600   if (OffsetInBytes % ScalarSizeInBytes)
15601     return false;
15602 
15603   // The immediate is out of range.
15604   if (OffsetInBytes / ScalarSizeInBytes > 31)
15605     return false;
15606 
15607   return true;
15608 }
15609 
15610 /// Check if the value of \p Offset represents a valid immediate for the SVE
15611 /// gather load/prefetch and scatter store instructions with vector base and
15612 /// immediate offset addressing mode:
15613 ///
15614 ///      [<Zn>.[S|D]{, #<imm>}]
15615 ///
15616 /// where <imm> = sizeof(<T>) * k, for k = 0, 1, ..., 31.
15617 static bool isValidImmForSVEVecImmAddrMode(SDValue Offset,
15618                                            unsigned ScalarSizeInBytes) {
15619   ConstantSDNode *OffsetConst = dyn_cast<ConstantSDNode>(Offset.getNode());
15620   return OffsetConst && isValidImmForSVEVecImmAddrMode(
15621                             OffsetConst->getZExtValue(), ScalarSizeInBytes);
15622 }
15623 
15624 static SDValue performScatterStoreCombine(SDNode *N, SelectionDAG &DAG,
15625                                           unsigned Opcode,
15626                                           bool OnlyPackedOffsets = true) {
15627   const SDValue Src = N->getOperand(2);
15628   const EVT SrcVT = Src->getValueType(0);
15629   assert(SrcVT.isScalableVector() &&
15630          "Scatter stores are only possible for SVE vectors");
15631 
15632   SDLoc DL(N);
15633   MVT SrcElVT = SrcVT.getVectorElementType().getSimpleVT();
15634 
15635   // Make sure that source data will fit into an SVE register
15636   if (SrcVT.getSizeInBits().getKnownMinSize() > AArch64::SVEBitsPerBlock)
15637     return SDValue();
15638 
15639   // For FPs, ACLE only supports _packed_ single and double precision types.
15640   if (SrcElVT.isFloatingPoint())
15641     if ((SrcVT != MVT::nxv4f32) && (SrcVT != MVT::nxv2f64))
15642       return SDValue();
15643 
15644   // Depending on the addressing mode, this is either a pointer or a vector of
15645   // pointers (that fits into one register)
15646   SDValue Base = N->getOperand(4);
15647   // Depending on the addressing mode, this is either a single offset or a
15648   // vector of offsets  (that fits into one register)
15649   SDValue Offset = N->getOperand(5);
15650 
15651   // For "scalar + vector of indices", just scale the indices. This only
15652   // applies to non-temporal scatters because there's no instruction that takes
15653   // indices.
15654   if (Opcode == AArch64ISD::SSTNT1_INDEX_PRED) {
15655     Offset =
15656         getScaledOffsetForBitWidth(DAG, Offset, DL, SrcElVT.getSizeInBits());
15657     Opcode = AArch64ISD::SSTNT1_PRED;
15658   }
15659 
15660   // In the case of non-temporal scatter stores there's only one SVE instruction
15661   // per data-size: "scalar + vector", i.e.
15662   //    * stnt1{b|h|w|d} { z0.s }, p0, [z0.s, x0]
15663   // Since we do have intrinsics that allow the arguments to be in a different
15664   // order, we may need to swap them to match the spec.
15665   if (Opcode == AArch64ISD::SSTNT1_PRED && Offset.getValueType().isVector())
15666     std::swap(Base, Offset);
15667 
15668   // SST1_IMM requires that the offset is an immediate that is:
15669   //    * a multiple of #SizeInBytes,
15670   //    * in the range [0, 31 x #SizeInBytes],
15671   // where #SizeInBytes is the size in bytes of the stored items. For
15672   // immediates outside that range and non-immediate scalar offsets use SST1 or
15673   // SST1_UXTW instead.
15674   if (Opcode == AArch64ISD::SST1_IMM_PRED) {
15675     if (!isValidImmForSVEVecImmAddrMode(Offset,
15676                                         SrcVT.getScalarSizeInBits() / 8)) {
15677       if (MVT::nxv4i32 == Base.getValueType().getSimpleVT().SimpleTy)
15678         Opcode = AArch64ISD::SST1_UXTW_PRED;
15679       else
15680         Opcode = AArch64ISD::SST1_PRED;
15681 
15682       std::swap(Base, Offset);
15683     }
15684   }
15685 
15686   auto &TLI = DAG.getTargetLoweringInfo();
15687   if (!TLI.isTypeLegal(Base.getValueType()))
15688     return SDValue();
15689 
15690   // Some scatter store variants allow unpacked offsets, but only as nxv2i32
15691   // vectors. These are implicitly sign- (sxtw) or zero-extended (uxtw) to
15692   // nxv2i64. Legalize accordingly.
15693   if (!OnlyPackedOffsets &&
15694       Offset.getValueType().getSimpleVT().SimpleTy == MVT::nxv2i32)
15695     Offset = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv2i64, Offset).getValue(0);
15696 
15697   if (!TLI.isTypeLegal(Offset.getValueType()))
15698     return SDValue();
15699 
15700   // Source value type that is representable in hardware
15701   EVT HwSrcVt = getSVEContainerType(SrcVT);
15702 
15703   // Keep the original type of the input data to store - this is needed to be
15704   // able to select the correct instruction, e.g. ST1B, ST1H, ST1W and ST1D. For
15705   // FP values we want the integer equivalent, so just use HwSrcVt.
15706   SDValue InputVT = DAG.getValueType(SrcVT);
15707   if (SrcVT.isFloatingPoint())
15708     InputVT = DAG.getValueType(HwSrcVt);
15709 
15710   SDVTList VTs = DAG.getVTList(MVT::Other);
15711   SDValue SrcNew;
15712 
15713   if (Src.getValueType().isFloatingPoint())
15714     SrcNew = DAG.getNode(ISD::BITCAST, DL, HwSrcVt, Src);
15715   else
15716     SrcNew = DAG.getNode(ISD::ANY_EXTEND, DL, HwSrcVt, Src);
15717 
15718   SDValue Ops[] = {N->getOperand(0), // Chain
15719                    SrcNew,
15720                    N->getOperand(3), // Pg
15721                    Base,
15722                    Offset,
15723                    InputVT};
15724 
15725   return DAG.getNode(Opcode, DL, VTs, Ops);
15726 }
15727 
15728 static SDValue performGatherLoadCombine(SDNode *N, SelectionDAG &DAG,
15729                                         unsigned Opcode,
15730                                         bool OnlyPackedOffsets = true) {
15731   const EVT RetVT = N->getValueType(0);
15732   assert(RetVT.isScalableVector() &&
15733          "Gather loads are only possible for SVE vectors");
15734 
15735   SDLoc DL(N);
15736 
15737   // Make sure that the loaded data will fit into an SVE register
15738   if (RetVT.getSizeInBits().getKnownMinSize() > AArch64::SVEBitsPerBlock)
15739     return SDValue();
15740 
15741   // Depending on the addressing mode, this is either a pointer or a vector of
15742   // pointers (that fits into one register)
15743   SDValue Base = N->getOperand(3);
15744   // Depending on the addressing mode, this is either a single offset or a
15745   // vector of offsets  (that fits into one register)
15746   SDValue Offset = N->getOperand(4);
15747 
15748   // For "scalar + vector of indices", just scale the indices. This only
15749   // applies to non-temporal gathers because there's no instruction that takes
15750   // indices.
15751   if (Opcode == AArch64ISD::GLDNT1_INDEX_MERGE_ZERO) {
15752     Offset = getScaledOffsetForBitWidth(DAG, Offset, DL,
15753                                         RetVT.getScalarSizeInBits());
15754     Opcode = AArch64ISD::GLDNT1_MERGE_ZERO;
15755   }
15756 
15757   // In the case of non-temporal gather loads there's only one SVE instruction
15758   // per data-size: "scalar + vector", i.e.
15759   //    * ldnt1{b|h|w|d} { z0.s }, p0/z, [z0.s, x0]
15760   // Since we do have intrinsics that allow the arguments to be in a different
15761   // order, we may need to swap them to match the spec.
15762   if (Opcode == AArch64ISD::GLDNT1_MERGE_ZERO &&
15763       Offset.getValueType().isVector())
15764     std::swap(Base, Offset);
15765 
15766   // GLD{FF}1_IMM requires that the offset is an immediate that is:
15767   //    * a multiple of #SizeInBytes,
15768   //    * in the range [0, 31 x #SizeInBytes],
15769   // where #SizeInBytes is the size in bytes of the loaded items. For
15770   // immediates outside that range and non-immediate scalar offsets use
15771   // GLD1_MERGE_ZERO or GLD1_UXTW_MERGE_ZERO instead.
15772   if (Opcode == AArch64ISD::GLD1_IMM_MERGE_ZERO ||
15773       Opcode == AArch64ISD::GLDFF1_IMM_MERGE_ZERO) {
15774     if (!isValidImmForSVEVecImmAddrMode(Offset,
15775                                         RetVT.getScalarSizeInBits() / 8)) {
15776       if (MVT::nxv4i32 == Base.getValueType().getSimpleVT().SimpleTy)
15777         Opcode = (Opcode == AArch64ISD::GLD1_IMM_MERGE_ZERO)
15778                      ? AArch64ISD::GLD1_UXTW_MERGE_ZERO
15779                      : AArch64ISD::GLDFF1_UXTW_MERGE_ZERO;
15780       else
15781         Opcode = (Opcode == AArch64ISD::GLD1_IMM_MERGE_ZERO)
15782                      ? AArch64ISD::GLD1_MERGE_ZERO
15783                      : AArch64ISD::GLDFF1_MERGE_ZERO;
15784 
15785       std::swap(Base, Offset);
15786     }
15787   }
15788 
15789   auto &TLI = DAG.getTargetLoweringInfo();
15790   if (!TLI.isTypeLegal(Base.getValueType()))
15791     return SDValue();
15792 
15793   // Some gather load variants allow unpacked offsets, but only as nxv2i32
15794   // vectors. These are implicitly sign- (sxtw) or zero-extended (uxtw) to
15795   // nxv2i64. Legalize accordingly.
15796   if (!OnlyPackedOffsets &&
15797       Offset.getValueType().getSimpleVT().SimpleTy == MVT::nxv2i32)
15798     Offset = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv2i64, Offset).getValue(0);
15799 
15800   // Return value type that is representable in hardware
15801   EVT HwRetVt = getSVEContainerType(RetVT);
15802 
15803   // Keep the original output value type around - this is needed to be able to
15804   // select the correct instruction, e.g. LD1B, LD1H, LD1W and LD1D. For FP
15805   // values we want the integer equivalent, so just use HwRetVt.
15806   SDValue OutVT = DAG.getValueType(RetVT);
15807   if (RetVT.isFloatingPoint())
15808     OutVT = DAG.getValueType(HwRetVt);
15809 
15810   SDVTList VTs = DAG.getVTList(HwRetVt, MVT::Other);
15811   SDValue Ops[] = {N->getOperand(0), // Chain
15812                    N->getOperand(2), // Pg
15813                    Base, Offset, OutVT};
15814 
15815   SDValue Load = DAG.getNode(Opcode, DL, VTs, Ops);
15816   SDValue LoadChain = SDValue(Load.getNode(), 1);
15817 
15818   if (RetVT.isInteger() && (RetVT != HwRetVt))
15819     Load = DAG.getNode(ISD::TRUNCATE, DL, RetVT, Load.getValue(0));
15820 
15821   // If the original return value was FP, bitcast accordingly. Doing it here
15822   // means that we can avoid adding TableGen patterns for FPs.
15823   if (RetVT.isFloatingPoint())
15824     Load = DAG.getNode(ISD::BITCAST, DL, RetVT, Load.getValue(0));
15825 
15826   return DAG.getMergeValues({Load, LoadChain}, DL);
15827 }
15828 
15829 static SDValue
15830 performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
15831                               SelectionDAG &DAG) {
15832   SDLoc DL(N);
15833   SDValue Src = N->getOperand(0);
15834   unsigned Opc = Src->getOpcode();
15835 
15836   // Sign extend of an unsigned unpack -> signed unpack
15837   if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
15838 
15839     unsigned SOpc = Opc == AArch64ISD::UUNPKHI ? AArch64ISD::SUNPKHI
15840                                                : AArch64ISD::SUNPKLO;
15841 
15842     // Push the sign extend to the operand of the unpack
15843     // This is necessary where, for example, the operand of the unpack
15844     // is another unpack:
15845     // 4i32 sign_extend_inreg (4i32 uunpklo(8i16 uunpklo (16i8 opnd)), from 4i8)
15846     // ->
15847     // 4i32 sunpklo (8i16 sign_extend_inreg(8i16 uunpklo (16i8 opnd), from 8i8)
15848     // ->
15849     // 4i32 sunpklo(8i16 sunpklo(16i8 opnd))
15850     SDValue ExtOp = Src->getOperand(0);
15851     auto VT = cast<VTSDNode>(N->getOperand(1))->getVT();
15852     EVT EltTy = VT.getVectorElementType();
15853     (void)EltTy;
15854 
15855     assert((EltTy == MVT::i8 || EltTy == MVT::i16 || EltTy == MVT::i32) &&
15856            "Sign extending from an invalid type");
15857 
15858     EVT ExtVT = VT.getDoubleNumVectorElementsVT(*DAG.getContext());
15859 
15860     SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, ExtOp.getValueType(),
15861                               ExtOp, DAG.getValueType(ExtVT));
15862 
15863     return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
15864   }
15865 
15866   if (DCI.isBeforeLegalizeOps())
15867     return SDValue();
15868 
15869   if (!EnableCombineMGatherIntrinsics)
15870     return SDValue();
15871 
15872   // SVE load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates
15873   // for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes.
15874   unsigned NewOpc;
15875   unsigned MemVTOpNum = 4;
15876   switch (Opc) {
15877   case AArch64ISD::LD1_MERGE_ZERO:
15878     NewOpc = AArch64ISD::LD1S_MERGE_ZERO;
15879     MemVTOpNum = 3;
15880     break;
15881   case AArch64ISD::LDNF1_MERGE_ZERO:
15882     NewOpc = AArch64ISD::LDNF1S_MERGE_ZERO;
15883     MemVTOpNum = 3;
15884     break;
15885   case AArch64ISD::LDFF1_MERGE_ZERO:
15886     NewOpc = AArch64ISD::LDFF1S_MERGE_ZERO;
15887     MemVTOpNum = 3;
15888     break;
15889   case AArch64ISD::GLD1_MERGE_ZERO:
15890     NewOpc = AArch64ISD::GLD1S_MERGE_ZERO;
15891     break;
15892   case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
15893     NewOpc = AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
15894     break;
15895   case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
15896     NewOpc = AArch64ISD::GLD1S_SXTW_MERGE_ZERO;
15897     break;
15898   case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
15899     NewOpc = AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO;
15900     break;
15901   case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
15902     NewOpc = AArch64ISD::GLD1S_UXTW_MERGE_ZERO;
15903     break;
15904   case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
15905     NewOpc = AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO;
15906     break;
15907   case AArch64ISD::GLD1_IMM_MERGE_ZERO:
15908     NewOpc = AArch64ISD::GLD1S_IMM_MERGE_ZERO;
15909     break;
15910   case AArch64ISD::GLDFF1_MERGE_ZERO:
15911     NewOpc = AArch64ISD::GLDFF1S_MERGE_ZERO;
15912     break;
15913   case AArch64ISD::GLDFF1_SCALED_MERGE_ZERO:
15914     NewOpc = AArch64ISD::GLDFF1S_SCALED_MERGE_ZERO;
15915     break;
15916   case AArch64ISD::GLDFF1_SXTW_MERGE_ZERO:
15917     NewOpc = AArch64ISD::GLDFF1S_SXTW_MERGE_ZERO;
15918     break;
15919   case AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO:
15920     NewOpc = AArch64ISD::GLDFF1S_SXTW_SCALED_MERGE_ZERO;
15921     break;
15922   case AArch64ISD::GLDFF1_UXTW_MERGE_ZERO:
15923     NewOpc = AArch64ISD::GLDFF1S_UXTW_MERGE_ZERO;
15924     break;
15925   case AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO:
15926     NewOpc = AArch64ISD::GLDFF1S_UXTW_SCALED_MERGE_ZERO;
15927     break;
15928   case AArch64ISD::GLDFF1_IMM_MERGE_ZERO:
15929     NewOpc = AArch64ISD::GLDFF1S_IMM_MERGE_ZERO;
15930     break;
15931   case AArch64ISD::GLDNT1_MERGE_ZERO:
15932     NewOpc = AArch64ISD::GLDNT1S_MERGE_ZERO;
15933     break;
15934   default:
15935     return SDValue();
15936   }
15937 
15938   EVT SignExtSrcVT = cast<VTSDNode>(N->getOperand(1))->getVT();
15939   EVT SrcMemVT = cast<VTSDNode>(Src->getOperand(MemVTOpNum))->getVT();
15940 
15941   if ((SignExtSrcVT != SrcMemVT) || !Src.hasOneUse())
15942     return SDValue();
15943 
15944   EVT DstVT = N->getValueType(0);
15945   SDVTList VTs = DAG.getVTList(DstVT, MVT::Other);
15946 
15947   SmallVector<SDValue, 5> Ops;
15948   for (unsigned I = 0; I < Src->getNumOperands(); ++I)
15949     Ops.push_back(Src->getOperand(I));
15950 
15951   SDValue ExtLoad = DAG.getNode(NewOpc, SDLoc(N), VTs, Ops);
15952   DCI.CombineTo(N, ExtLoad);
15953   DCI.CombineTo(Src.getNode(), ExtLoad, ExtLoad.getValue(1));
15954 
15955   // Return N so it doesn't get rechecked
15956   return SDValue(N, 0);
15957 }
15958 
15959 /// Legalize the gather prefetch (scalar + vector addressing mode) when the
15960 /// offset vector is an unpacked 32-bit scalable vector. The other cases (Offset
15961 /// != nxv2i32) do not need legalization.
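/// For example (illustrative): an nxv2i32 offset operand is widened with
/// ISD::ANY_EXTEND to nxv2i64 and substituted back into the node's operand
/// list; nxv4i32 offsets are already legal and are left untouched.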
15962 static SDValue legalizeSVEGatherPrefetchOffsVec(SDNode *N, SelectionDAG &DAG) {
15963   const unsigned OffsetPos = 4;
15964   SDValue Offset = N->getOperand(OffsetPos);
15965 
15966   // Not an unpacked vector, bail out.
15967   if (Offset.getValueType().getSimpleVT().SimpleTy != MVT::nxv2i32)
15968     return SDValue();
15969 
15970   // Extend the unpacked offset vector to 64-bit lanes.
15971   SDLoc DL(N);
15972   Offset = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv2i64, Offset);
15973   SmallVector<SDValue, 5> Ops(N->op_begin(), N->op_end());
15974   // Replace the offset operand with the 64-bit one.
15975   Ops[OffsetPos] = Offset;
15976 
15977   return DAG.getNode(N->getOpcode(), DL, DAG.getVTList(MVT::Other), Ops);
15978 }
15979 
15980 /// Combines a node carrying the intrinsic
15981 /// `aarch64_sve_prf<T>_gather_scalar_offset` into a node that uses
15982 /// `aarch64_sve_prfb_gather_uxtw_index` when the scalar offset passed to
15983 /// `aarch64_sve_prf<T>_gather_scalar_offset` is not a valid immediate for the
15984 /// SVE gather prefetch instruction with vector plus immediate addressing mode.
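/// For example (illustrative): for prfh (2-byte elements) a scalar offset of 3
/// is not a multiple of 2, and an offset of 128 exceeds 31 * 2, so in both
/// cases the base and offset operands are swapped and the intrinsic is
/// remapped to aarch64_sve_prfb_gather_uxtw_index.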
15985 static SDValue combineSVEPrefetchVecBaseImmOff(SDNode *N, SelectionDAG &DAG,
15986                                                unsigned ScalarSizeInBytes) {
15987   const unsigned ImmPos = 4, OffsetPos = 3;
15988   // No need to combine the node if the immediate is valid...
15989   if (isValidImmForSVEVecImmAddrMode(N->getOperand(ImmPos), ScalarSizeInBytes))
15990     return SDValue();
15991 
15992   // ...otherwise swap the offset base with the offset...
15993   SmallVector<SDValue, 5> Ops(N->op_begin(), N->op_end());
15994   std::swap(Ops[ImmPos], Ops[OffsetPos]);
15995   // ...and remap the intrinsic `aarch64_sve_prf<T>_gather_scalar_offset` to
15996   // `aarch64_sve_prfb_gather_uxtw_index`.
15997   SDLoc DL(N);
15998   Ops[1] = DAG.getConstant(Intrinsic::aarch64_sve_prfb_gather_uxtw_index, DL,
15999                            MVT::i64);
16000 
16001   return DAG.getNode(N->getOpcode(), DL, DAG.getVTList(MVT::Other), Ops);
16002 }
16003 
16004 // Return true if the vector operation can guarantee that only the first lane of
16005 // its result contains data, with all bits in other lanes set to zero.
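// For example (illustrative): AArch64ISD::UADDV_PRED reduces the whole vector
// into lane 0 of its result and leaves lanes 1..N zero, so explicitly
// re-inserting its extracted lane 0 into a zeroed vector (see
// removeRedundantInsertVectorElt below) adds nothing.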
16006 static bool isLanes1toNKnownZero(SDValue Op) {
16007   switch (Op.getOpcode()) {
16008   default:
16009     return false;
16010   case AArch64ISD::ANDV_PRED:
16011   case AArch64ISD::EORV_PRED:
16012   case AArch64ISD::FADDA_PRED:
16013   case AArch64ISD::FADDV_PRED:
16014   case AArch64ISD::FMAXNMV_PRED:
16015   case AArch64ISD::FMAXV_PRED:
16016   case AArch64ISD::FMINNMV_PRED:
16017   case AArch64ISD::FMINV_PRED:
16018   case AArch64ISD::ORV_PRED:
16019   case AArch64ISD::SADDV_PRED:
16020   case AArch64ISD::SMAXV_PRED:
16021   case AArch64ISD::SMINV_PRED:
16022   case AArch64ISD::UADDV_PRED:
16023   case AArch64ISD::UMAXV_PRED:
16024   case AArch64ISD::UMINV_PRED:
16025     return true;
16026   }
16027 }
16028 
16029 static SDValue removeRedundantInsertVectorElt(SDNode *N) {
16030   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT && "Unexpected node!");
16031   SDValue InsertVec = N->getOperand(0);
16032   SDValue InsertElt = N->getOperand(1);
16033   SDValue InsertIdx = N->getOperand(2);
16034 
16035   // We only care about inserts into the first element...
16036   if (!isNullConstant(InsertIdx))
16037     return SDValue();
16038   // ...of a zero'd vector...
16039   if (!ISD::isConstantSplatVectorAllZeros(InsertVec.getNode()))
16040     return SDValue();
16041   // ...where the inserted data was previously extracted...
16042   if (InsertElt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
16043     return SDValue();
16044 
16045   SDValue ExtractVec = InsertElt.getOperand(0);
16046   SDValue ExtractIdx = InsertElt.getOperand(1);
16047 
16048   // ...from the first element of a vector.
16049   if (!isNullConstant(ExtractIdx))
16050     return SDValue();
16051 
16052   // If we get here we are effectively trying to zero lanes 1-N of a vector.
16053 
16054   // Ensure there's no type conversion going on.
16055   if (N->getValueType(0) != ExtractVec.getValueType())
16056     return SDValue();
16057 
16058   if (!isLanes1toNKnownZero(ExtractVec))
16059     return SDValue();
16060 
16061   // The explicit zeroing is redundant.
16062   return ExtractVec;
16063 }
16064 
16065 static SDValue
16066 performInsertVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
16067   if (SDValue Res = removeRedundantInsertVectorElt(N))
16068     return Res;
16069 
16070   return performPostLD1Combine(N, DCI, true);
16071 }
16072 
16073 SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
16074                                                  DAGCombinerInfo &DCI) const {
16075   SelectionDAG &DAG = DCI.DAG;
16076   switch (N->getOpcode()) {
16077   default:
16078     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
16079     break;
16080   case ISD::ABS:
16081     return performABSCombine(N, DAG, DCI, Subtarget);
16082   case ISD::ADD:
16083   case ISD::SUB:
16084     return performAddSubCombine(N, DCI, DAG);
16085   case ISD::XOR:
16086     return performXorCombine(N, DAG, DCI, Subtarget);
16087   case ISD::MUL:
16088     return performMulCombine(N, DAG, DCI, Subtarget);
16089   case ISD::SINT_TO_FP:
16090   case ISD::UINT_TO_FP:
16091     return performIntToFpCombine(N, DAG, Subtarget);
16092   case ISD::FP_TO_SINT:
16093   case ISD::FP_TO_UINT:
16094     return performFpToIntCombine(N, DAG, DCI, Subtarget);
16095   case ISD::FDIV:
16096     return performFDivCombine(N, DAG, DCI, Subtarget);
16097   case ISD::OR:
16098     return performORCombine(N, DCI, Subtarget);
16099   case ISD::AND:
16100     return performANDCombine(N, DCI);
16101   case ISD::SRL:
16102     return performSRLCombine(N, DCI);
16103   case ISD::INTRINSIC_WO_CHAIN:
16104     return performIntrinsicCombine(N, DCI, Subtarget);
16105   case ISD::ANY_EXTEND:
16106   case ISD::ZERO_EXTEND:
16107   case ISD::SIGN_EXTEND:
16108     return performExtendCombine(N, DCI, DAG);
16109   case ISD::SIGN_EXTEND_INREG:
16110     return performSignExtendInRegCombine(N, DCI, DAG);
16111   case ISD::TRUNCATE:
16112     return performVectorTruncateCombine(N, DCI, DAG);
16113   case ISD::CONCAT_VECTORS:
16114     return performConcatVectorsCombine(N, DCI, DAG);
16115   case ISD::SELECT:
16116     return performSelectCombine(N, DCI);
16117   case ISD::VSELECT:
16118     return performVSelectCombine(N, DCI.DAG);
16119   case ISD::LOAD:
16120     if (performTBISimplification(N->getOperand(1), DCI, DAG))
16121       return SDValue(N, 0);
16122     break;
16123   case ISD::STORE:
16124     return performSTORECombine(N, DCI, DAG, Subtarget);
16125   case AArch64ISD::BRCOND:
16126     return performBRCONDCombine(N, DCI, DAG);
16127   case AArch64ISD::TBNZ:
16128   case AArch64ISD::TBZ:
16129     return performTBZCombine(N, DCI, DAG);
16130   case AArch64ISD::CSEL:
16131     return performCSELCombine(N, DCI, DAG);
16132   case AArch64ISD::DUP:
16133     return performPostLD1Combine(N, DCI, false);
16134   case AArch64ISD::NVCAST:
16135     return performNVCASTCombine(N);
16136   case AArch64ISD::UZP1:
16137     return performUzpCombine(N, DAG);
16138   case AArch64ISD::GLD1_MERGE_ZERO:
16139   case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
16140   case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
16141   case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
16142   case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
16143   case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
16144   case AArch64ISD::GLD1_IMM_MERGE_ZERO:
16145   case AArch64ISD::GLD1S_MERGE_ZERO:
16146   case AArch64ISD::GLD1S_SCALED_MERGE_ZERO:
16147   case AArch64ISD::GLD1S_UXTW_MERGE_ZERO:
16148   case AArch64ISD::GLD1S_SXTW_MERGE_ZERO:
16149   case AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO:
16150   case AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO:
16151   case AArch64ISD::GLD1S_IMM_MERGE_ZERO:
16152     return performGLD1Combine(N, DAG);
16153   case AArch64ISD::VASHR:
16154   case AArch64ISD::VLSHR:
16155     return performVectorShiftCombine(N, *this, DCI);
16156   case ISD::INSERT_VECTOR_ELT:
16157     return performInsertVectorEltCombine(N, DCI);
16158   case ISD::EXTRACT_VECTOR_ELT:
16159     return performExtractVectorEltCombine(N, DAG);
16160   case ISD::VECREDUCE_ADD:
16161     return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
16162   case ISD::INTRINSIC_VOID:
16163   case ISD::INTRINSIC_W_CHAIN:
16164     switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {
16165     case Intrinsic::aarch64_sve_prfb_gather_scalar_offset:
16166       return combineSVEPrefetchVecBaseImmOff(N, DAG, 1 /*=ScalarSizeInBytes*/);
16167     case Intrinsic::aarch64_sve_prfh_gather_scalar_offset:
16168       return combineSVEPrefetchVecBaseImmOff(N, DAG, 2 /*=ScalarSizeInBytes*/);
16169     case Intrinsic::aarch64_sve_prfw_gather_scalar_offset:
16170       return combineSVEPrefetchVecBaseImmOff(N, DAG, 4 /*=ScalarSizeInBytes*/);
16171     case Intrinsic::aarch64_sve_prfd_gather_scalar_offset:
16172       return combineSVEPrefetchVecBaseImmOff(N, DAG, 8 /*=ScalarSizeInBytes*/);
16173     case Intrinsic::aarch64_sve_prfb_gather_uxtw_index:
16174     case Intrinsic::aarch64_sve_prfb_gather_sxtw_index:
16175     case Intrinsic::aarch64_sve_prfh_gather_uxtw_index:
16176     case Intrinsic::aarch64_sve_prfh_gather_sxtw_index:
16177     case Intrinsic::aarch64_sve_prfw_gather_uxtw_index:
16178     case Intrinsic::aarch64_sve_prfw_gather_sxtw_index:
16179     case Intrinsic::aarch64_sve_prfd_gather_uxtw_index:
16180     case Intrinsic::aarch64_sve_prfd_gather_sxtw_index:
16181       return legalizeSVEGatherPrefetchOffsVec(N, DAG);
16182     case Intrinsic::aarch64_neon_ld2:
16183     case Intrinsic::aarch64_neon_ld3:
16184     case Intrinsic::aarch64_neon_ld4:
16185     case Intrinsic::aarch64_neon_ld1x2:
16186     case Intrinsic::aarch64_neon_ld1x3:
16187     case Intrinsic::aarch64_neon_ld1x4:
16188     case Intrinsic::aarch64_neon_ld2lane:
16189     case Intrinsic::aarch64_neon_ld3lane:
16190     case Intrinsic::aarch64_neon_ld4lane:
16191     case Intrinsic::aarch64_neon_ld2r:
16192     case Intrinsic::aarch64_neon_ld3r:
16193     case Intrinsic::aarch64_neon_ld4r:
16194     case Intrinsic::aarch64_neon_st2:
16195     case Intrinsic::aarch64_neon_st3:
16196     case Intrinsic::aarch64_neon_st4:
16197     case Intrinsic::aarch64_neon_st1x2:
16198     case Intrinsic::aarch64_neon_st1x3:
16199     case Intrinsic::aarch64_neon_st1x4:
16200     case Intrinsic::aarch64_neon_st2lane:
16201     case Intrinsic::aarch64_neon_st3lane:
16202     case Intrinsic::aarch64_neon_st4lane:
16203       return performNEONPostLDSTCombine(N, DCI, DAG);
16204     case Intrinsic::aarch64_sve_ldnt1:
16205       return performLDNT1Combine(N, DAG);
16206     case Intrinsic::aarch64_sve_ld1rq:
16207       return performLD1ReplicateCombine<AArch64ISD::LD1RQ_MERGE_ZERO>(N, DAG);
16208     case Intrinsic::aarch64_sve_ld1ro:
16209       return performLD1ReplicateCombine<AArch64ISD::LD1RO_MERGE_ZERO>(N, DAG);
16210     case Intrinsic::aarch64_sve_ldnt1_gather_scalar_offset:
16211       return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1_MERGE_ZERO);
16212     case Intrinsic::aarch64_sve_ldnt1_gather:
16213       return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1_MERGE_ZERO);
16214     case Intrinsic::aarch64_sve_ldnt1_gather_index:
16215       return performGatherLoadCombine(N, DAG,
16216                                       AArch64ISD::GLDNT1_INDEX_MERGE_ZERO);
16217     case Intrinsic::aarch64_sve_ldnt1_gather_uxtw:
16218       return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1_MERGE_ZERO);
16219     case Intrinsic::aarch64_sve_ld1:
16220       return performLD1Combine(N, DAG, AArch64ISD::LD1_MERGE_ZERO);
16221     case Intrinsic::aarch64_sve_ldnf1:
16222       return performLD1Combine(N, DAG, AArch64ISD::LDNF1_MERGE_ZERO);
16223     case Intrinsic::aarch64_sve_ldff1:
16224       return performLD1Combine(N, DAG, AArch64ISD::LDFF1_MERGE_ZERO);
16225     case Intrinsic::aarch64_sve_st1:
16226       return performST1Combine(N, DAG);
16227     case Intrinsic::aarch64_sve_stnt1:
16228       return performSTNT1Combine(N, DAG);
16229     case Intrinsic::aarch64_sve_stnt1_scatter_scalar_offset:
16230       return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_PRED);
16231     case Intrinsic::aarch64_sve_stnt1_scatter_uxtw:
16232       return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_PRED);
16233     case Intrinsic::aarch64_sve_stnt1_scatter:
16234       return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_PRED);
16235     case Intrinsic::aarch64_sve_stnt1_scatter_index:
16236       return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_INDEX_PRED);
16237     case Intrinsic::aarch64_sve_ld1_gather:
16238       return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_MERGE_ZERO);
16239     case Intrinsic::aarch64_sve_ld1_gather_index:
16240       return performGatherLoadCombine(N, DAG,
16241                                       AArch64ISD::GLD1_SCALED_MERGE_ZERO);
16242     case Intrinsic::aarch64_sve_ld1_gather_sxtw:
16243       return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_SXTW_MERGE_ZERO,
16244                                       /*OnlyPackedOffsets=*/false);
16245     case Intrinsic::aarch64_sve_ld1_gather_uxtw:
16246       return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_UXTW_MERGE_ZERO,
16247                                       /*OnlyPackedOffsets=*/false);
16248     case Intrinsic::aarch64_sve_ld1_gather_sxtw_index:
16249       return performGatherLoadCombine(N, DAG,
16250                                       AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO,
16251                                       /*OnlyPackedOffsets=*/false);
16252     case Intrinsic::aarch64_sve_ld1_gather_uxtw_index:
16253       return performGatherLoadCombine(N, DAG,
16254                                       AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO,
16255                                       /*OnlyPackedOffsets=*/false);
16256     case Intrinsic::aarch64_sve_ld1_gather_scalar_offset:
16257       return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_IMM_MERGE_ZERO);
16258     case Intrinsic::aarch64_sve_ldff1_gather:
16259       return performGatherLoadCombine(N, DAG, AArch64ISD::GLDFF1_MERGE_ZERO);
16260     case Intrinsic::aarch64_sve_ldff1_gather_index:
16261       return performGatherLoadCombine(N, DAG,
16262                                       AArch64ISD::GLDFF1_SCALED_MERGE_ZERO);
16263     case Intrinsic::aarch64_sve_ldff1_gather_sxtw:
16264       return performGatherLoadCombine(N, DAG,
16265                                       AArch64ISD::GLDFF1_SXTW_MERGE_ZERO,
16266                                       /*OnlyPackedOffsets=*/false);
16267     case Intrinsic::aarch64_sve_ldff1_gather_uxtw:
16268       return performGatherLoadCombine(N, DAG,
16269                                       AArch64ISD::GLDFF1_UXTW_MERGE_ZERO,
16270                                       /*OnlyPackedOffsets=*/false);
16271     case Intrinsic::aarch64_sve_ldff1_gather_sxtw_index:
16272       return performGatherLoadCombine(N, DAG,
16273                                       AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO,
16274                                       /*OnlyPackedOffsets=*/false);
16275     case Intrinsic::aarch64_sve_ldff1_gather_uxtw_index:
16276       return performGatherLoadCombine(N, DAG,
16277                                       AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO,
16278                                       /*OnlyPackedOffsets=*/false);
16279     case Intrinsic::aarch64_sve_ldff1_gather_scalar_offset:
16280       return performGatherLoadCombine(N, DAG,
16281                                       AArch64ISD::GLDFF1_IMM_MERGE_ZERO);
16282     case Intrinsic::aarch64_sve_st1_scatter:
16283       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_PRED);
16284     case Intrinsic::aarch64_sve_st1_scatter_index:
16285       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_SCALED_PRED);
16286     case Intrinsic::aarch64_sve_st1_scatter_sxtw:
16287       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_SXTW_PRED,
16288                                         /*OnlyPackedOffsets=*/false);
16289     case Intrinsic::aarch64_sve_st1_scatter_uxtw:
16290       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_UXTW_PRED,
16291                                         /*OnlyPackedOffsets=*/false);
16292     case Intrinsic::aarch64_sve_st1_scatter_sxtw_index:
16293       return performScatterStoreCombine(N, DAG,
16294                                         AArch64ISD::SST1_SXTW_SCALED_PRED,
16295                                         /*OnlyPackedOffsets=*/false);
16296     case Intrinsic::aarch64_sve_st1_scatter_uxtw_index:
16297       return performScatterStoreCombine(N, DAG,
16298                                         AArch64ISD::SST1_UXTW_SCALED_PRED,
16299                                         /*OnlyPackedOffsets=*/false);
16300     case Intrinsic::aarch64_sve_st1_scatter_scalar_offset:
16301       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_IMM_PRED);
16302     case Intrinsic::aarch64_sve_tuple_get: {
16303       SDLoc DL(N);
16304       SDValue Chain = N->getOperand(0);
16305       SDValue Src1 = N->getOperand(2);
16306       SDValue Idx = N->getOperand(3);
16307 
16308       uint64_t IdxConst = cast<ConstantSDNode>(Idx)->getZExtValue();
16309       EVT ResVT = N->getValueType(0);
16310       uint64_t NumLanes = ResVT.getVectorElementCount().getKnownMinValue();
16311       SDValue ExtIdx = DAG.getVectorIdxConstant(IdxConst * NumLanes, DL);
16312       SDValue Val =
16313           DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Src1, ExtIdx);
16314       return DAG.getMergeValues({Val, Chain}, DL);
16315     }
16316     case Intrinsic::aarch64_sve_tuple_set: {
16317       SDLoc DL(N);
16318       SDValue Chain = N->getOperand(0);
16319       SDValue Tuple = N->getOperand(2);
16320       SDValue Idx = N->getOperand(3);
16321       SDValue Vec = N->getOperand(4);
16322 
16323       EVT TupleVT = Tuple.getValueType();
16324       uint64_t TupleLanes = TupleVT.getVectorElementCount().getKnownMinValue();
16325 
16326       uint64_t IdxConst = cast<ConstantSDNode>(Idx)->getZExtValue();
16327       uint64_t NumLanes =
16328           Vec.getValueType().getVectorElementCount().getKnownMinValue();
16329 
16330       if ((TupleLanes % NumLanes) != 0)
16331         report_fatal_error("invalid tuple vector!");
16332 
16333       uint64_t NumVecs = TupleLanes / NumLanes;
16334 
16335       SmallVector<SDValue, 4> Opnds;
16336       for (unsigned I = 0; I < NumVecs; ++I) {
16337         if (I == IdxConst)
16338           Opnds.push_back(Vec);
16339         else {
16340           SDValue ExtIdx = DAG.getVectorIdxConstant(I * NumLanes, DL);
16341           Opnds.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL,
16342                                       Vec.getValueType(), Tuple, ExtIdx));
16343         }
16344       }
16345       SDValue Concat =
16346           DAG.getNode(ISD::CONCAT_VECTORS, DL, Tuple.getValueType(), Opnds);
16347       return DAG.getMergeValues({Concat, Chain}, DL);
16348     }
16349     case Intrinsic::aarch64_sve_tuple_create2:
16350     case Intrinsic::aarch64_sve_tuple_create3:
16351     case Intrinsic::aarch64_sve_tuple_create4: {
16352       SDLoc DL(N);
16353       SDValue Chain = N->getOperand(0);
16354 
16355       SmallVector<SDValue, 4> Opnds;
16356       for (unsigned I = 2; I < N->getNumOperands(); ++I)
16357         Opnds.push_back(N->getOperand(I));
16358 
16359       EVT VT = Opnds[0].getValueType();
16360       EVT EltVT = VT.getVectorElementType();
16361       EVT DestVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
16362                                     VT.getVectorElementCount() *
16363                                         (N->getNumOperands() - 2));
16364       SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, DestVT, Opnds);
16365       return DAG.getMergeValues({Concat, Chain}, DL);
16366     }
16367     case Intrinsic::aarch64_sve_ld2:
16368     case Intrinsic::aarch64_sve_ld3:
16369     case Intrinsic::aarch64_sve_ld4: {
16370       SDLoc DL(N);
16371       SDValue Chain = N->getOperand(0);
16372       SDValue Mask = N->getOperand(2);
16373       SDValue BasePtr = N->getOperand(3);
16374       SDValue LoadOps[] = {Chain, Mask, BasePtr};
16375       unsigned IntrinsicID =
16376           cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
16377       SDValue Result =
16378           LowerSVEStructLoad(IntrinsicID, LoadOps, N->getValueType(0), DAG, DL);
16379       return DAG.getMergeValues({Result, Chain}, DL);
16380     }
16381     case Intrinsic::aarch64_rndr:
16382     case Intrinsic::aarch64_rndrrs: {
16383       unsigned IntrinsicID =
16384           cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
16385       auto Register =
16386           (IntrinsicID == Intrinsic::aarch64_rndr ? AArch64SysReg::RNDR
16387                                                   : AArch64SysReg::RNDRRS);
16388       SDLoc DL(N);
16389       SDValue A = DAG.getNode(
16390           AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other),
16391           N->getOperand(0), DAG.getConstant(Register, DL, MVT::i64));
16392       SDValue B = DAG.getNode(
16393           AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32),
16394           DAG.getConstant(0, DL, MVT::i32),
16395           DAG.getConstant(AArch64CC::NE, DL, MVT::i32), A.getValue(1));
16396       return DAG.getMergeValues(
16397           {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
16398     }
16399     default:
16400       break;
16401     }
16402     break;
16403   case ISD::GlobalAddress:
16404     return performGlobalAddressCombine(N, DAG, Subtarget, getTargetMachine());
16405   }
16406   return SDValue();
16407 }
16408 
16409 // Check if the return value is used as only a return value, as otherwise
16410 // we can't perform a tail-call. In particular, we need to check for
16411 // target ISD nodes that are returns and any other "odd" constructs
16412 // that the generic analysis code won't necessarily catch.
16413 bool AArch64TargetLowering::isUsedByReturnOnly(SDNode *N,
16414                                                SDValue &Chain) const {
16415   if (N->getNumValues() != 1)
16416     return false;
16417   if (!N->hasNUsesOfValue(1, 0))
16418     return false;
16419 
16420   SDValue TCChain = Chain;
16421   SDNode *Copy = *N->use_begin();
16422   if (Copy->getOpcode() == ISD::CopyToReg) {
16423     // If the copy has a glue operand, we conservatively assume it isn't safe to
16424     // perform a tail call.
16425     if (Copy->getOperand(Copy->getNumOperands() - 1).getValueType() ==
16426         MVT::Glue)
16427       return false;
16428     TCChain = Copy->getOperand(0);
16429   } else if (Copy->getOpcode() != ISD::FP_EXTEND)
16430     return false;
16431 
16432   bool HasRet = false;
16433   for (SDNode *Node : Copy->uses()) {
16434     if (Node->getOpcode() != AArch64ISD::RET_FLAG)
16435       return false;
16436     HasRet = true;
16437   }
16438 
16439   if (!HasRet)
16440     return false;
16441 
16442   Chain = TCChain;
16443   return true;
16444 }
16445 
16446 // Return whether an instruction can potentially be optimized to a tail
16447 // call. This will cause the optimizers to attempt to move, or duplicate,
16448 // return instructions to help enable tail call optimizations for this
16449 // instruction.
16450 bool AArch64TargetLowering::mayBeEmittedAsTailCall(const CallInst *CI) const {
16451   return CI->isTailCall();
16452 }
16453 
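// Rough illustration for the indexed-addressing helpers below: they recognise
// base +/- constant-offset patterns that map onto AArch64's writeback
// addressing modes, e.g.
//   ldr x0, [x1, #16]!   // pre-indexed:  x1 is updated before the access
//   ldr x0, [x1], #16    // post-indexed: x1 is updated after the access
// where the offset must fit the signed 9-bit immediate these forms accept.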
16454 bool AArch64TargetLowering::getIndexedAddressParts(SDNode *Op, SDValue &Base,
16455                                                    SDValue &Offset,
16456                                                    ISD::MemIndexedMode &AM,
16457                                                    bool &IsInc,
16458                                                    SelectionDAG &DAG) const {
16459   if (Op->getOpcode() != ISD::ADD && Op->getOpcode() != ISD::SUB)
16460     return false;
16461 
16462   Base = Op->getOperand(0);
16463   // All of the indexed addressing mode instructions take a signed
16464   // 9 bit immediate offset.
16465   if (ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Op->getOperand(1))) {
16466     int64_t RHSC = RHS->getSExtValue();
16467     if (Op->getOpcode() == ISD::SUB)
16468       RHSC = -(uint64_t)RHSC;
16469     if (!isInt<9>(RHSC))
16470       return false;
16471     IsInc = (Op->getOpcode() == ISD::ADD);
16472     Offset = Op->getOperand(1);
16473     return true;
16474   }
16475   return false;
16476 }
16477 
16478 bool AArch64TargetLowering::getPreIndexedAddressParts(SDNode *N, SDValue &Base,
16479                                                       SDValue &Offset,
16480                                                       ISD::MemIndexedMode &AM,
16481                                                       SelectionDAG &DAG) const {
16482   EVT VT;
16483   SDValue Ptr;
16484   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
16485     VT = LD->getMemoryVT();
16486     Ptr = LD->getBasePtr();
16487   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
16488     VT = ST->getMemoryVT();
16489     Ptr = ST->getBasePtr();
16490   } else
16491     return false;
16492 
16493   bool IsInc;
16494   if (!getIndexedAddressParts(Ptr.getNode(), Base, Offset, AM, IsInc, DAG))
16495     return false;
16496   AM = IsInc ? ISD::PRE_INC : ISD::PRE_DEC;
16497   return true;
16498 }
16499 
16500 bool AArch64TargetLowering::getPostIndexedAddressParts(
16501     SDNode *N, SDNode *Op, SDValue &Base, SDValue &Offset,
16502     ISD::MemIndexedMode &AM, SelectionDAG &DAG) const {
16503   EVT VT;
16504   SDValue Ptr;
16505   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
16506     VT = LD->getMemoryVT();
16507     Ptr = LD->getBasePtr();
16508   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
16509     VT = ST->getMemoryVT();
16510     Ptr = ST->getBasePtr();
16511   } else
16512     return false;
16513 
16514   bool IsInc;
16515   if (!getIndexedAddressParts(Op, Base, Offset, AM, IsInc, DAG))
16516     return false;
16517   // Post-indexing updates the base, so it's not a valid transform
16518   // if that's not the same as the load's pointer.
16519   if (Ptr != Base)
16520     return false;
16521   AM = IsInc ? ISD::POST_INC : ISD::POST_DEC;
16522   return true;
16523 }
16524 
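// Rough illustration for ReplaceBITCASTResults below: an illegal
// (i16 (bitcast f16/bf16)) result is recovered by moving the half value into
// the 'hsub' subregister of a 32-bit FP register and then narrowing:
//   f16 --INSERT_SUBREG(hsub)--> f32 --bitcast--> i32 --truncate--> i16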
16525 static void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
16526                                   SelectionDAG &DAG) {
16527   SDLoc DL(N);
16528   SDValue Op = N->getOperand(0);
16529 
16530   if (N->getValueType(0) != MVT::i16 ||
16531       (Op.getValueType() != MVT::f16 && Op.getValueType() != MVT::bf16))
16532     return;
16533 
16534   Op = SDValue(
16535       DAG.getMachineNode(TargetOpcode::INSERT_SUBREG, DL, MVT::f32,
16536                          DAG.getUNDEF(MVT::i32), Op,
16537                          DAG.getTargetConstant(AArch64::hsub, DL, MVT::i32)),
16538       0);
16539   Op = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Op);
16540   Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Op));
16541 }
16542 
16543 static void ReplaceReductionResults(SDNode *N,
16544                                     SmallVectorImpl<SDValue> &Results,
16545                                     SelectionDAG &DAG, unsigned InterOp,
16546                                     unsigned AcrossOp) {
16547   EVT LoVT, HiVT;
16548   SDValue Lo, Hi;
16549   SDLoc dl(N);
16550   std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0));
16551   std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
16552   SDValue InterVal = DAG.getNode(InterOp, dl, LoVT, Lo, Hi);
16553   SDValue SplitVal = DAG.getNode(AcrossOp, dl, LoVT, InterVal);
16554   Results.push_back(SplitVal);
16555 }
16556 
16557 static std::pair<SDValue, SDValue> splitInt128(SDValue N, SelectionDAG &DAG) {
16558   SDLoc DL(N);
16559   SDValue Lo = DAG.getNode(ISD::TRUNCATE, DL, MVT::i64, N);
16560   SDValue Hi = DAG.getNode(ISD::TRUNCATE, DL, MVT::i64,
16561                            DAG.getNode(ISD::SRL, DL, MVT::i128, N,
16562                                        DAG.getConstant(64, DL, MVT::i64)));
16563   return std::make_pair(Lo, Hi);
16564 }
16565 
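// Rough illustration for ReplaceExtractSubVectorResults below: extracting the
// low or high half of a scalable integer vector, e.g. an nxv2i32 half of an
// nxv4i32 input, is done with uunpklo/uunpkhi (which also widens the lanes to
// nxv2i64) followed by a truncate back to the nxv2i32 result type.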
16566 void AArch64TargetLowering::ReplaceExtractSubVectorResults(
16567     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
16568   SDValue In = N->getOperand(0);
16569   EVT InVT = In.getValueType();
16570 
16571   // Common code will handle these just fine.
16572   if (!InVT.isScalableVector() || !InVT.isInteger())
16573     return;
16574 
16575   SDLoc DL(N);
16576   EVT VT = N->getValueType(0);
16577 
16578   // The following checks bail if this is not a halving operation.
16579 
16580   ElementCount ResEC = VT.getVectorElementCount();
16581 
16582   if (InVT.getVectorElementCount() != (ResEC * 2))
16583     return;
16584 
16585   auto *CIndex = dyn_cast<ConstantSDNode>(N->getOperand(1));
16586   if (!CIndex)
16587     return;
16588 
16589   unsigned Index = CIndex->getZExtValue();
16590   if ((Index != 0) && (Index != ResEC.getKnownMinValue()))
16591     return;
16592 
16593   unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI;
16594   EVT ExtendedHalfVT = VT.widenIntegerVectorElementType(*DAG.getContext());
16595 
16596   SDValue Half = DAG.getNode(Opcode, DL, ExtendedHalfVT, N->getOperand(0));
16597   Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
16598 }
16599 
16600 // Create an even/odd pair of X registers holding integer value V.
16601 static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
16602   SDLoc dl(V.getNode());
16603   SDValue VLo = DAG.getAnyExtOrTrunc(V, dl, MVT::i64);
16604   SDValue VHi = DAG.getAnyExtOrTrunc(
16605       DAG.getNode(ISD::SRL, dl, MVT::i128, V, DAG.getConstant(64, dl, MVT::i64)),
16606       dl, MVT::i64);
16607   if (DAG.getDataLayout().isBigEndian())
16608     std::swap(VLo, VHi);
16609   SDValue RegClass =
16610       DAG.getTargetConstant(AArch64::XSeqPairsClassRegClassID, dl, MVT::i32);
16611   SDValue SubReg0 = DAG.getTargetConstant(AArch64::sube64, dl, MVT::i32);
16612   SDValue SubReg1 = DAG.getTargetConstant(AArch64::subo64, dl, MVT::i32);
16613   const SDValue Ops[] = { RegClass, VLo, SubReg0, VHi, SubReg1 };
16614   return SDValue(
16615       DAG.getMachineNode(TargetOpcode::REG_SEQUENCE, dl, MVT::Untyped, Ops), 0);
16616 }
16617 
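// Rough illustration for ReplaceCMP_SWAP_128Results below: a 128-bit
//   cmpxchg i128* %p, i128 %old, i128 %new seq_cst seq_cst
// is lowered either to a CASP-family instruction (when LSE or outlined
// atomics are available), with the compare and store values marshalled into
// even/odd X register pairs, or to the CMP_SWAP_128 pseudo operating on i64
// halves, which is expanded to a load/store-exclusive loop later on.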
16618 static void ReplaceCMP_SWAP_128Results(SDNode *N,
16619                                        SmallVectorImpl<SDValue> &Results,
16620                                        SelectionDAG &DAG,
16621                                        const AArch64Subtarget *Subtarget) {
16622   assert(N->getValueType(0) == MVT::i128 &&
16623          "AtomicCmpSwap on types less than 128 should be legal");
16624 
16625   if (Subtarget->hasLSE() || Subtarget->outlineAtomics()) {
16626     // LSE has a 128-bit compare and swap (CASP), but i128 is not a legal type,
16627     // so lower it here, wrapped in REG_SEQUENCE and EXTRACT_SUBREG.
16628     SDValue Ops[] = {
16629         createGPRPairNode(DAG, N->getOperand(2)), // Compare value
16630         createGPRPairNode(DAG, N->getOperand(3)), // Store value
16631         N->getOperand(1), // Ptr
16632         N->getOperand(0), // Chain in
16633     };
16634 
16635     MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand();
16636 
16637     unsigned Opcode;
16638     switch (MemOp->getOrdering()) {
16639     case AtomicOrdering::Monotonic:
16640       Opcode = AArch64::CASPX;
16641       break;
16642     case AtomicOrdering::Acquire:
16643       Opcode = AArch64::CASPAX;
16644       break;
16645     case AtomicOrdering::Release:
16646       Opcode = AArch64::CASPLX;
16647       break;
16648     case AtomicOrdering::AcquireRelease:
16649     case AtomicOrdering::SequentiallyConsistent:
16650       Opcode = AArch64::CASPALX;
16651       break;
16652     default:
16653       llvm_unreachable("Unexpected ordering!");
16654     }
16655 
16656     MachineSDNode *CmpSwap = DAG.getMachineNode(
16657         Opcode, SDLoc(N), DAG.getVTList(MVT::Untyped, MVT::Other), Ops);
16658     DAG.setNodeMemRefs(CmpSwap, {MemOp});
16659 
16660     unsigned SubReg1 = AArch64::sube64, SubReg2 = AArch64::subo64;
16661     if (DAG.getDataLayout().isBigEndian())
16662       std::swap(SubReg1, SubReg2);
16663     SDValue Lo = DAG.getTargetExtractSubreg(SubReg1, SDLoc(N), MVT::i64,
16664                                             SDValue(CmpSwap, 0));
16665     SDValue Hi = DAG.getTargetExtractSubreg(SubReg2, SDLoc(N), MVT::i64,
16666                                             SDValue(CmpSwap, 0));
16667     Results.push_back(
16668         DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128, Lo, Hi));
16669     Results.push_back(SDValue(CmpSwap, 1)); // Chain out
16670     return;
16671   }
16672 
16673   auto Desired = splitInt128(N->getOperand(2), DAG);
16674   auto New = splitInt128(N->getOperand(3), DAG);
16675   SDValue Ops[] = {N->getOperand(1), Desired.first, Desired.second,
16676                    New.first,        New.second,    N->getOperand(0)};
16677   SDNode *CmpSwap = DAG.getMachineNode(
16678       AArch64::CMP_SWAP_128, SDLoc(N),
16679       DAG.getVTList(MVT::i64, MVT::i64, MVT::i32, MVT::Other), Ops);
16680 
16681   MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand();
16682   DAG.setNodeMemRefs(cast<MachineSDNode>(CmpSwap), {MemOp});
16683 
16684   Results.push_back(DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128,
16685                                 SDValue(CmpSwap, 0), SDValue(CmpSwap, 1)));
16686   Results.push_back(SDValue(CmpSwap, 3));
16687 }
16688 
16689 void AArch64TargetLowering::ReplaceNodeResults(
16690     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
16691   switch (N->getOpcode()) {
16692   default:
16693     llvm_unreachable("Don't know how to custom expand this");
16694   case ISD::BITCAST:
16695     ReplaceBITCASTResults(N, Results, DAG);
16696     return;
16697   case ISD::VECREDUCE_ADD:
16698   case ISD::VECREDUCE_SMAX:
16699   case ISD::VECREDUCE_SMIN:
16700   case ISD::VECREDUCE_UMAX:
16701   case ISD::VECREDUCE_UMIN:
16702     Results.push_back(LowerVECREDUCE(SDValue(N, 0), DAG));
16703     return;
16704 
16705   case ISD::CTPOP:
16706     if (SDValue Result = LowerCTPOP(SDValue(N, 0), DAG))
16707       Results.push_back(Result);
16708     return;
16709   case AArch64ISD::SADDV:
16710     ReplaceReductionResults(N, Results, DAG, ISD::ADD, AArch64ISD::SADDV);
16711     return;
16712   case AArch64ISD::UADDV:
16713     ReplaceReductionResults(N, Results, DAG, ISD::ADD, AArch64ISD::UADDV);
16714     return;
16715   case AArch64ISD::SMINV:
16716     ReplaceReductionResults(N, Results, DAG, ISD::SMIN, AArch64ISD::SMINV);
16717     return;
16718   case AArch64ISD::UMINV:
16719     ReplaceReductionResults(N, Results, DAG, ISD::UMIN, AArch64ISD::UMINV);
16720     return;
16721   case AArch64ISD::SMAXV:
16722     ReplaceReductionResults(N, Results, DAG, ISD::SMAX, AArch64ISD::SMAXV);
16723     return;
16724   case AArch64ISD::UMAXV:
16725     ReplaceReductionResults(N, Results, DAG, ISD::UMAX, AArch64ISD::UMAXV);
16726     return;
16727   case ISD::FP_TO_UINT:
16728   case ISD::FP_TO_SINT:
16729     assert(N->getValueType(0) == MVT::i128 && "unexpected illegal conversion");
16730     // Let normal code take care of it by not adding anything to Results.
16731     return;
16732   case ISD::ATOMIC_CMP_SWAP:
16733     ReplaceCMP_SWAP_128Results(N, Results, DAG, Subtarget);
16734     return;
16735   case ISD::LOAD: {
16736     assert(SDValue(N, 0).getValueType() == MVT::i128 &&
16737            "unexpected load's value type");
16738     LoadSDNode *LoadNode = cast<LoadSDNode>(N);
16739     if (!LoadNode->isVolatile() || LoadNode->getMemoryVT() != MVT::i128) {
16740       // Non-volatile loads are optimized later in AArch64's load/store
16741       // optimizer.
16742       return;
16743     }
16744 
16745     SDValue Result = DAG.getMemIntrinsicNode(
16746         AArch64ISD::LDP, SDLoc(N),
16747         DAG.getVTList({MVT::i64, MVT::i64, MVT::Other}),
16748         {LoadNode->getChain(), LoadNode->getBasePtr()}, LoadNode->getMemoryVT(),
16749         LoadNode->getMemOperand());
16750 
16751     SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128,
16752                                Result.getValue(0), Result.getValue(1));
16753     Results.append({Pair, Result.getValue(2) /* Chain */});
16754     return;
16755   }
16756   case ISD::EXTRACT_SUBVECTOR:
16757     ReplaceExtractSubVectorResults(N, Results, DAG);
16758     return;
16759   case ISD::INTRINSIC_WO_CHAIN: {
16760     EVT VT = N->getValueType(0);
16761     assert((VT == MVT::i8 || VT == MVT::i16) &&
16762            "custom lowering for unexpected type");
16763 
16764     ConstantSDNode *CN = cast<ConstantSDNode>(N->getOperand(0));
16765     Intrinsic::ID IntID = static_cast<Intrinsic::ID>(CN->getZExtValue());
16766     switch (IntID) {
16767     default:
16768       return;
16769     case Intrinsic::aarch64_sve_clasta_n: {
16770       SDLoc DL(N);
16771       auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
16772       auto V = DAG.getNode(AArch64ISD::CLASTA_N, DL, MVT::i32,
16773                            N->getOperand(1), Op2, N->getOperand(3));
16774       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
16775       return;
16776     }
16777     case Intrinsic::aarch64_sve_clastb_n: {
16778       SDLoc DL(N);
16779       auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
16780       auto V = DAG.getNode(AArch64ISD::CLASTB_N, DL, MVT::i32,
16781                            N->getOperand(1), Op2, N->getOperand(3));
16782       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
16783       return;
16784     }
16785     case Intrinsic::aarch64_sve_lasta: {
16786       SDLoc DL(N);
16787       auto V = DAG.getNode(AArch64ISD::LASTA, DL, MVT::i32,
16788                            N->getOperand(1), N->getOperand(2));
16789       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
16790       return;
16791     }
16792     case Intrinsic::aarch64_sve_lastb: {
16793       SDLoc DL(N);
16794       auto V = DAG.getNode(AArch64ISD::LASTB, DL, MVT::i32,
16795                            N->getOperand(1), N->getOperand(2));
16796       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
16797       return;
16798     }
16799     }
16800   }
16801   }
16802 }
16803 
16804 bool AArch64TargetLowering::useLoadStackGuardNode() const {
16805   if (Subtarget->isTargetAndroid() || Subtarget->isTargetFuchsia())
16806     return TargetLowering::useLoadStackGuardNode();
16807   return true;
16808 }
16809 
16810 unsigned AArch64TargetLowering::combineRepeatedFPDivisors() const {
16811   // Combine multiple FDIVs with the same divisor into multiple FMULs by the
16812   // reciprocal if there are three or more FDIVs.
16813   return 3;
16814 }
16815 
16816 TargetLoweringBase::LegalizeTypeAction
16817 AArch64TargetLowering::getPreferredVectorAction(MVT VT) const {
16818   // During type legalization, we prefer to widen v1i8, v1i16 and v1i32 to
16819   // v8i8, v4i16 and v2i32, respectively, instead of promoting them.
16820   if (VT == MVT::v1i8 || VT == MVT::v1i16 || VT == MVT::v1i32 ||
16821       VT == MVT::v1f32)
16822     return TypeWidenVector;
16823 
16824   return TargetLoweringBase::getPreferredVectorAction(VT);
16825 }
16826 
16827 // Loads and stores less than 128 bits are already atomic; ones above that
16828 // are doomed anyway, so defer to the default libcall and blame the OS when
16829 // things go wrong.
16830 bool AArch64TargetLowering::shouldExpandAtomicStoreInIR(StoreInst *SI) const {
16831   unsigned Size = SI->getValueOperand()->getType()->getPrimitiveSizeInBits();
16832   return Size == 128;
16833 }
16834 
16835 // Loads and stores less than 128 bits are already atomic; ones above that
16836 // are doomed anyway, so defer to the default libcall and blame the OS when
16837 // things go wrong.
16838 TargetLowering::AtomicExpansionKind
16839 AArch64TargetLowering::shouldExpandAtomicLoadInIR(LoadInst *LI) const {
16840   unsigned Size = LI->getType()->getPrimitiveSizeInBits();
16841   return Size == 128 ? AtomicExpansionKind::LLSC : AtomicExpansionKind::None;
16842 }
16843 
16844 // For the real atomic operations, we have ldxr/stxr up to 128 bits.
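// Rough illustration: with LSE, an "atomicrmw add i32* %p, i32 1 seq_cst" is
// left intact (AtomicExpansionKind::None) so it can select to an LSE atomic
// such as LDADDAL; without LSE it is expanded here to an ldxr/stxr loop, and
// at -O0 it becomes a cmpxchg loop instead (see the note below about
// fast-regalloc and the exclusive monitor).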
16845 TargetLowering::AtomicExpansionKind
16846 AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
16847   if (AI->isFloatingPointOperation())
16848     return AtomicExpansionKind::CmpXChg;
16849 
16850   unsigned Size = AI->getType()->getPrimitiveSizeInBits();
16851   if (Size > 128) return AtomicExpansionKind::None;
16852 
16853   // Nand is not supported in LSE.
16854   // Leave 128 bits to LLSC or CmpXChg.
16855   if (AI->getOperation() != AtomicRMWInst::Nand && Size < 128) {
16856     if (Subtarget->hasLSE())
16857       return AtomicExpansionKind::None;
16858     if (Subtarget->outlineAtomics()) {
16859       // [U]Min/[U]Max RMW atomics are used in __sync_fetch_ libcalls so far.
16860       // Don't outline them unless
16861       // (1) high level <atomic> support approved:
16862       //   http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf
16863       // (2) low level libgcc and compiler-rt support implemented by:
16864       //   min/max outline atomics helpers
16865       if (AI->getOperation() != AtomicRMWInst::Min &&
16866           AI->getOperation() != AtomicRMWInst::Max &&
16867           AI->getOperation() != AtomicRMWInst::UMin &&
16868           AI->getOperation() != AtomicRMWInst::UMax) {
16869         return AtomicExpansionKind::None;
16870       }
16871     }
16872   }
16873 
16874   // At -O0, fast-regalloc cannot cope with the live vregs necessary to
16875   // implement atomicrmw without spilling. If the target address is also on the
16876   // stack and close enough to the spill slot, this can lead to a situation
16877   // where the monitor always gets cleared and the atomic operation can never
16878   // succeed. So at -O0 lower this operation to a CAS loop.
16879   if (getTargetMachine().getOptLevel() == CodeGenOpt::None)
16880     return AtomicExpansionKind::CmpXChg;
16881 
16882   return AtomicExpansionKind::LLSC;
16883 }
16884 
16885 TargetLowering::AtomicExpansionKind
16886 AArch64TargetLowering::shouldExpandAtomicCmpXchgInIR(
16887     AtomicCmpXchgInst *AI) const {
16888   // If subtarget has LSE, leave cmpxchg intact for codegen.
16889   if (Subtarget->hasLSE() || Subtarget->outlineAtomics())
16890     return AtomicExpansionKind::None;
16891   // At -O0, fast-regalloc cannot cope with the live vregs necessary to
16892   // implement cmpxchg without spilling. If the address being exchanged is also
16893   // on the stack and close enough to the spill slot, this can lead to a
16894   // situation where the monitor always gets cleared and the atomic operation
16895   // can never succeed. So at -O0 we need a late-expanded pseudo-inst instead.
16896   if (getTargetMachine().getOptLevel() == CodeGenOpt::None)
16897     return AtomicExpansionKind::None;
16898   return AtomicExpansionKind::LLSC;
16899 }
16900 
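// Rough illustration for emitLoadLinked below: for a 128-bit acquire
// load-linked, the builder emits IR along the lines of
//   %lohi = call { i64, i64 } @llvm.aarch64.ldaxp(i8* %addr)
//   %lo   = extractvalue { i64, i64 } %lohi, 0   ; zext'd to i128
//   %hi   = extractvalue { i64, i64 } %lohi, 1   ; zext'd, shifted left by 64
//   %val  = or of the two halves
// while smaller sizes use a single ldxr/ldaxr intrinsic call whose result is
// truncated and bitcast back to the original element type.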
16901 Value *AArch64TargetLowering::emitLoadLinked(IRBuilder<> &Builder, Value *Addr,
16902                                              AtomicOrdering Ord) const {
16903   Module *M = Builder.GetInsertBlock()->getParent()->getParent();
16904   Type *ValTy = cast<PointerType>(Addr->getType())->getElementType();
16905   bool IsAcquire = isAcquireOrStronger(Ord);
16906 
16907   // Since i128 isn't legal and intrinsics don't get type-lowered, the ldrexd
16908   // intrinsic must return {i64, i64} and we have to recombine them into a
16909   // single i128 here.
16910   if (ValTy->getPrimitiveSizeInBits() == 128) {
16911     Intrinsic::ID Int =
16912         IsAcquire ? Intrinsic::aarch64_ldaxp : Intrinsic::aarch64_ldxp;
16913     Function *Ldxr = Intrinsic::getDeclaration(M, Int);
16914 
16915     Addr = Builder.CreateBitCast(Addr, Type::getInt8PtrTy(M->getContext()));
16916     Value *LoHi = Builder.CreateCall(Ldxr, Addr, "lohi");
16917 
16918     Value *Lo = Builder.CreateExtractValue(LoHi, 0, "lo");
16919     Value *Hi = Builder.CreateExtractValue(LoHi, 1, "hi");
16920     Lo = Builder.CreateZExt(Lo, ValTy, "lo64");
16921     Hi = Builder.CreateZExt(Hi, ValTy, "hi64");
16922     return Builder.CreateOr(
16923         Lo, Builder.CreateShl(Hi, ConstantInt::get(ValTy, 64)), "val64");
16924   }
16925 
16926   Type *Tys[] = { Addr->getType() };
16927   Intrinsic::ID Int =
16928       IsAcquire ? Intrinsic::aarch64_ldaxr : Intrinsic::aarch64_ldxr;
16929   Function *Ldxr = Intrinsic::getDeclaration(M, Int, Tys);
16930 
16931   Type *EltTy = cast<PointerType>(Addr->getType())->getElementType();
16932 
16933   const DataLayout &DL = M->getDataLayout();
16934   IntegerType *IntEltTy = Builder.getIntNTy(DL.getTypeSizeInBits(EltTy));
16935   Value *Trunc = Builder.CreateTrunc(Builder.CreateCall(Ldxr, Addr), IntEltTy);
16936 
16937   return Builder.CreateBitCast(Trunc, EltTy);
16938 }
16939 
16940 void AArch64TargetLowering::emitAtomicCmpXchgNoStoreLLBalance(
16941     IRBuilder<> &Builder) const {
16942   Module *M = Builder.GetInsertBlock()->getParent()->getParent();
16943   Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::aarch64_clrex));
16944 }
16945 
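// Rough illustration for emitStoreConditional below: the 128-bit release case
// splits the value into i64 halves and emits roughly
//   %res = call i32 @llvm.aarch64.stlxp(i64 %lo, i64 %hi, i8* %addr)
// while smaller sizes go through a single stxr/stlxr intrinsic whose operand
// is zero-extended or bitcast to the intrinsic's parameter type.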
16946 Value *AArch64TargetLowering::emitStoreConditional(IRBuilder<> &Builder,
16947                                                    Value *Val, Value *Addr,
16948                                                    AtomicOrdering Ord) const {
16949   Module *M = Builder.GetInsertBlock()->getParent()->getParent();
16950   bool IsRelease = isReleaseOrStronger(Ord);
16951 
16952   // Since the intrinsics must have legal type, the i128 intrinsics take two
16953   // parameters: "i64, i64". We must marshal Val into the appropriate form
16954   // before the call.
16955   if (Val->getType()->getPrimitiveSizeInBits() == 128) {
16956     Intrinsic::ID Int =
16957         IsRelease ? Intrinsic::aarch64_stlxp : Intrinsic::aarch64_stxp;
16958     Function *Stxr = Intrinsic::getDeclaration(M, Int);
16959     Type *Int64Ty = Type::getInt64Ty(M->getContext());
16960 
16961     Value *Lo = Builder.CreateTrunc(Val, Int64Ty, "lo");
16962     Value *Hi = Builder.CreateTrunc(Builder.CreateLShr(Val, 64), Int64Ty, "hi");
16963     Addr = Builder.CreateBitCast(Addr, Type::getInt8PtrTy(M->getContext()));
16964     return Builder.CreateCall(Stxr, {Lo, Hi, Addr});
16965   }
16966 
16967   Intrinsic::ID Int =
16968       IsRelease ? Intrinsic::aarch64_stlxr : Intrinsic::aarch64_stxr;
16969   Type *Tys[] = { Addr->getType() };
16970   Function *Stxr = Intrinsic::getDeclaration(M, Int, Tys);
16971 
16972   const DataLayout &DL = M->getDataLayout();
16973   IntegerType *IntValTy = Builder.getIntNTy(DL.getTypeSizeInBits(Val->getType()));
16974   Val = Builder.CreateBitCast(Val, IntValTy);
16975 
16976   return Builder.CreateCall(Stxr,
16977                             {Builder.CreateZExtOrBitCast(
16978                                  Val, Stxr->getFunctionType()->getParamType(0)),
16979                              Addr});
16980 }
16981 
16982 bool AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters(
16983     Type *Ty, CallingConv::ID CallConv, bool isVarArg) const {
16984   if (Ty->isArrayTy())
16985     return true;
16986 
16987   const TypeSize &TySize = Ty->getPrimitiveSizeInBits();
16988   if (TySize.isScalable() && TySize.getKnownMinSize() > 128)
16989     return true;
16990 
16991   return false;
16992 }
16993 
16994 bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &,
16995                                                             EVT) const {
16996   return false;
16997 }
16998 
16999 static Value *UseTlsOffset(IRBuilder<> &IRB, unsigned Offset) {
17000   Module *M = IRB.GetInsertBlock()->getParent()->getParent();
17001   Function *ThreadPointerFunc =
17002       Intrinsic::getDeclaration(M, Intrinsic::thread_pointer);
17003   return IRB.CreatePointerCast(
17004       IRB.CreateConstGEP1_32(IRB.getInt8Ty(), IRB.CreateCall(ThreadPointerFunc),
17005                              Offset),
17006       IRB.getInt8PtrTy()->getPointerTo(0));
17007 }
17008 
17009 Value *AArch64TargetLowering::getIRStackGuard(IRBuilder<> &IRB) const {
17010   // Android provides a fixed TLS slot for the stack cookie. See the definition
17011   // of TLS_SLOT_STACK_GUARD in
17012   // https://android.googlesource.com/platform/bionic/+/master/libc/private/bionic_tls.h
17013   if (Subtarget->isTargetAndroid())
17014     return UseTlsOffset(IRB, 0x28);
17015 
17016   // Fuchsia is similar.
17017   // <zircon/tls.h> defines ZX_TLS_STACK_GUARD_OFFSET with this value.
17018   if (Subtarget->isTargetFuchsia())
17019     return UseTlsOffset(IRB, -0x10);
17020 
17021   return TargetLowering::getIRStackGuard(IRB);
17022 }
17023 
17024 void AArch64TargetLowering::insertSSPDeclarations(Module &M) const {
17025   // MSVC CRT provides functionalities for stack protection.
17026   if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment()) {
17027     // MSVC CRT has a global variable holding security cookie.
17028     M.getOrInsertGlobal("__security_cookie",
17029                         Type::getInt8PtrTy(M.getContext()));
17030 
17031     // MSVC CRT has a function to validate security cookie.
17032     FunctionCallee SecurityCheckCookie = M.getOrInsertFunction(
17033         "__security_check_cookie", Type::getVoidTy(M.getContext()),
17034         Type::getInt8PtrTy(M.getContext()));
17035     if (Function *F = dyn_cast<Function>(SecurityCheckCookie.getCallee())) {
17036       F->setCallingConv(CallingConv::Win64);
17037       F->addAttribute(1, Attribute::AttrKind::InReg);
17038     }
17039     return;
17040   }
17041   TargetLowering::insertSSPDeclarations(M);
17042 }
17043 
17044 Value *AArch64TargetLowering::getSDagStackGuard(const Module &M) const {
17045   // MSVC CRT has a global variable holding security cookie.
17046   if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment())
17047     return M.getGlobalVariable("__security_cookie");
17048   return TargetLowering::getSDagStackGuard(M);
17049 }
17050 
17051 Function *AArch64TargetLowering::getSSPStackGuardCheck(const Module &M) const {
17052   // MSVC CRT has a function to validate security cookie.
17053   if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment())
17054     return M.getFunction("__security_check_cookie");
17055   return TargetLowering::getSSPStackGuardCheck(M);
17056 }
17057 
17058 Value *AArch64TargetLowering::getSafeStackPointerLocation(IRBuilder<> &IRB) const {
17059   // Android provides a fixed TLS slot for the SafeStack pointer. See the
17060   // definition of TLS_SLOT_SAFESTACK in
17061   // https://android.googlesource.com/platform/bionic/+/master/libc/private/bionic_tls.h
17062   if (Subtarget->isTargetAndroid())
17063     return UseTlsOffset(IRB, 0x48);
17064 
17065   // Fuchsia is similar.
17066   // <zircon/tls.h> defines ZX_TLS_UNSAFE_SP_OFFSET with this value.
17067   if (Subtarget->isTargetFuchsia())
17068     return UseTlsOffset(IRB, -0x8);
17069 
17070   return TargetLowering::getSafeStackPointerLocation(IRB);
17071 }
17072 
17073 bool AArch64TargetLowering::isMaskAndCmp0FoldingBeneficial(
17074     const Instruction &AndI) const {
17075   // Only sink 'and' mask to cmp use block if it is masking a single bit, since
17076   // this is likely to fold the and/cmp/br into a single tbz instruction. It
17077   // may be beneficial to sink in other cases, but we would have to check that
17078   // the cmp would not get folded into the br to form a cbz for these to be
17079   // beneficial.
17080   ConstantInt* Mask = dyn_cast<ConstantInt>(AndI.getOperand(1));
17081   if (!Mask)
17082     return false;
17083   return Mask->getValue().isPowerOf2();
17084 }
17085 
17086 bool AArch64TargetLowering::
17087     shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
17088         SDValue X, ConstantSDNode *XC, ConstantSDNode *CC, SDValue Y,
17089         unsigned OldShiftOpcode, unsigned NewShiftOpcode,
17090         SelectionDAG &DAG) const {
17091   // Does baseline recommend not to perform the fold by default?
17092   if (!TargetLowering::shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
17093           X, XC, CC, Y, OldShiftOpcode, NewShiftOpcode, DAG))
17094     return false;
17095   // Else, if this is a vector shift, prefer 'shl'.
17096   return X.getValueType().isScalarInteger() || NewShiftOpcode == ISD::SHL;
17097 }
17098 
17099 bool AArch64TargetLowering::shouldExpandShift(SelectionDAG &DAG,
17100                                               SDNode *N) const {
17101   if (DAG.getMachineFunction().getFunction().hasMinSize() &&
17102       !Subtarget->isTargetWindows() && !Subtarget->isTargetDarwin())
17103     return false;
17104   return true;
17105 }
17106 
17107 void AArch64TargetLowering::initializeSplitCSR(MachineBasicBlock *Entry) const {
17108   // Update IsSplitCSR in AArch64FunctionInfo.
17109   AArch64FunctionInfo *AFI = Entry->getParent()->getInfo<AArch64FunctionInfo>();
17110   AFI->setIsSplitCSR(true);
17111 }
17112 
17113 void AArch64TargetLowering::insertCopiesSplitCSR(
17114     MachineBasicBlock *Entry,
17115     const SmallVectorImpl<MachineBasicBlock *> &Exits) const {
17116   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
17117   const MCPhysReg *IStart = TRI->getCalleeSavedRegsViaCopy(Entry->getParent());
17118   if (!IStart)
17119     return;
17120 
17121   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
17122   MachineRegisterInfo *MRI = &Entry->getParent()->getRegInfo();
17123   MachineBasicBlock::iterator MBBI = Entry->begin();
17124   for (const MCPhysReg *I = IStart; *I; ++I) {
17125     const TargetRegisterClass *RC = nullptr;
17126     if (AArch64::GPR64RegClass.contains(*I))
17127       RC = &AArch64::GPR64RegClass;
17128     else if (AArch64::FPR64RegClass.contains(*I))
17129       RC = &AArch64::FPR64RegClass;
17130     else
17131       llvm_unreachable("Unexpected register class in CSRsViaCopy!");
17132 
17133     Register NewVR = MRI->createVirtualRegister(RC);
17134     // Create copy from CSR to a virtual register.
17135     // FIXME: this currently does not emit CFI pseudo-instructions, it works
17136     // fine for CXX_FAST_TLS since the C++-style TLS access functions should be
17137     // nounwind. If we want to generalize this later, we may need to emit
17138     // CFI pseudo-instructions.
17139     assert(Entry->getParent()->getFunction().hasFnAttribute(
17140                Attribute::NoUnwind) &&
17141            "Function should be nounwind in insertCopiesSplitCSR!");
17142     Entry->addLiveIn(*I);
17143     BuildMI(*Entry, MBBI, DebugLoc(), TII->get(TargetOpcode::COPY), NewVR)
17144         .addReg(*I);
17145 
17146     // Insert the copy-back instructions right before the terminator.
17147     for (auto *Exit : Exits)
17148       BuildMI(*Exit, Exit->getFirstTerminator(), DebugLoc(),
17149               TII->get(TargetOpcode::COPY), *I)
17150           .addReg(NewVR);
17151   }
17152 }
17153 
17154 bool AArch64TargetLowering::isIntDivCheap(EVT VT, AttributeList Attr) const {
17155   // Integer division on AArch64 is expensive. However, when aggressively
17156   // optimizing for code size, we prefer to use a div instruction, as it is
17157   // usually smaller than the alternative sequence.
17158   // The exception to this is vector division. Since AArch64 doesn't have vector
17159   // integer division, leaving the division as-is is a loss even in terms of
17160   // size, because it will have to be scalarized, while the alternative code
17161   // sequence can be performed in vector form.
17162   bool OptSize = Attr.hasFnAttribute(Attribute::MinSize);
17163   return OptSize && !VT.isVector();
17164 }
17165 
17166 bool AArch64TargetLowering::preferIncOfAddToSubOfNot(EVT VT) const {
17167   // We want inc-of-add for scalars and sub-of-not for vectors.
17168   return VT.isScalarInteger();
17169 }
17170 
17171 bool AArch64TargetLowering::enableAggressiveFMAFusion(EVT VT) const {
17172   return Subtarget->hasAggressiveFMA() && VT.isFloatingPoint();
17173 }
17174 
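// Rough illustration: on AAPCS64 targets va_list is, in effect,
//   struct { void *__stack; void *__gr_top; void *__vr_top;
//            int __gr_offs; int __vr_offs; };
// i.e. three pointers plus two 32-bit offsets, which is where the
// "3 * pointer size + 2 * 32" below comes from; Darwin and Windows use a
// plain char* instead.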
17175 unsigned
17176 AArch64TargetLowering::getVaListSizeInBits(const DataLayout &DL) const {
17177   if (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows())
17178     return getPointerTy(DL).getSizeInBits();
17179 
17180   return 3 * getPointerTy(DL).getSizeInBits() + 2 * 32;
17181 }
17182 
17183 void AArch64TargetLowering::finalizeLowering(MachineFunction &MF) const {
17184   MF.getFrameInfo().computeMaxCallFrameSize(MF);
17185   TargetLoweringBase::finalizeLowering(MF);
17186 }
17187 
17188 // Unlike X86, we let frame lowering assign offsets to all catch objects.
17189 bool AArch64TargetLowering::needsFixedCatchObjects() const {
17190   return false;
17191 }
17192 
17193 bool AArch64TargetLowering::shouldLocalize(
17194     const MachineInstr &MI, const TargetTransformInfo *TTI) const {
17195   switch (MI.getOpcode()) {
17196   case TargetOpcode::G_GLOBAL_VALUE: {
17197     // On Darwin, TLS global vars get selected into function calls, which
17198     // we don't want localized, as they can get moved into the middle of
17199     // another call sequence.
17200     const GlobalValue &GV = *MI.getOperand(1).getGlobal();
17201     if (GV.isThreadLocal() && Subtarget->isTargetMachO())
17202       return false;
17203     break;
17204   }
17205   // If we legalized G_GLOBAL_VALUE into ADRP + G_ADD_LOW, mark both as being
17206   // localizable.
17207   case AArch64::ADRP:
17208   case AArch64::G_ADD_LOW:
17209     return true;
17210   default:
17211     break;
17212   }
17213   return TargetLoweringBase::shouldLocalize(MI, TTI);
17214 }
17215 
17216 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
17217   if (isa<ScalableVectorType>(Inst.getType()))
17218     return true;
17219 
17220   for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
17221     if (isa<ScalableVectorType>(Inst.getOperand(i)->getType()))
17222       return true;
17223 
17224   if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
17225     if (isa<ScalableVectorType>(AI->getAllocatedType()))
17226       return true;
17227   }
17228 
17229   return false;
17230 }
17231 
17232 // Return the largest legal scalable vector type that matches VT's element type.
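// Rough illustration: a legal fixed-length vector such as v4i32 or v8i32 is
// given the container type nxv4i32 (the full-register SVE type with the same
// element type); the fixed-length data then occupies the container's low
// lanes.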
17233 static EVT getContainerForFixedLengthVector(SelectionDAG &DAG, EVT VT) {
17234   assert(VT.isFixedLengthVector() &&
17235          DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
17236          "Expected legal fixed length vector!");
17237   switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
17238   default:
17239     llvm_unreachable("unexpected element type for SVE container");
17240   case MVT::i8:
17241     return EVT(MVT::nxv16i8);
17242   case MVT::i16:
17243     return EVT(MVT::nxv8i16);
17244   case MVT::i32:
17245     return EVT(MVT::nxv4i32);
17246   case MVT::i64:
17247     return EVT(MVT::nxv2i64);
17248   case MVT::f16:
17249     return EVT(MVT::nxv8f16);
17250   case MVT::f32:
17251     return EVT(MVT::nxv4f32);
17252   case MVT::f64:
17253     return EVT(MVT::nxv2f64);
17254   }
17255 }
17256 
17257 // Return a PTRUE with active lanes corresponding to the extent of VT.
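// Rough illustration: a fixed-length v8i32 (legal only when the SVE register
// size is at least 256 bits) gets the predicate
//   ptrue p0.s, vl8
// so that only its eight lanes within the nxv4i32 container are active.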
17258 static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
17259                                                 EVT VT) {
17260   assert(VT.isFixedLengthVector() &&
17261          DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
17262          "Expected legal fixed length vector!");
17263 
17264   int PgPattern;
17265   switch (VT.getVectorNumElements()) {
17266   default:
17267     llvm_unreachable("unexpected element count for SVE predicate");
17268   case 1:
17269     PgPattern = AArch64SVEPredPattern::vl1;
17270     break;
17271   case 2:
17272     PgPattern = AArch64SVEPredPattern::vl2;
17273     break;
17274   case 4:
17275     PgPattern = AArch64SVEPredPattern::vl4;
17276     break;
17277   case 8:
17278     PgPattern = AArch64SVEPredPattern::vl8;
17279     break;
17280   case 16:
17281     PgPattern = AArch64SVEPredPattern::vl16;
17282     break;
17283   case 32:
17284     PgPattern = AArch64SVEPredPattern::vl32;
17285     break;
17286   case 64:
17287     PgPattern = AArch64SVEPredPattern::vl64;
17288     break;
17289   case 128:
17290     PgPattern = AArch64SVEPredPattern::vl128;
17291     break;
17292   case 256:
17293     PgPattern = AArch64SVEPredPattern::vl256;
17294     break;
17295   }
17296 
17297   // TODO: For vectors that are exactly getMaxSVEVectorSizeInBits big, we can
17298   // use AArch64SVEPredPattern::all, which can enable the use of unpredicated
17299   // variants of instructions when available.
17300 
17301   MVT MaskVT;
17302   switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
17303   default:
17304     llvm_unreachable("unexpected element type for SVE predicate");
17305   case MVT::i8:
17306     MaskVT = MVT::nxv16i1;
17307     break;
17308   case MVT::i16:
17309   case MVT::f16:
17310     MaskVT = MVT::nxv8i1;
17311     break;
17312   case MVT::i32:
17313   case MVT::f32:
17314     MaskVT = MVT::nxv4i1;
17315     break;
17316   case MVT::i64:
17317   case MVT::f64:
17318     MaskVT = MVT::nxv2i1;
17319     break;
17320   }
17321 
17322   return DAG.getNode(AArch64ISD::PTRUE, DL, MaskVT,
17323                      DAG.getTargetConstant(PgPattern, DL, MVT::i64));
17324 }
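// For example, a legal fixed length v8i16 yields PTRUE nxv8i1, vl8: a
// predicate whose first eight 16-bit lanes are active regardless of the
// actual SVE register width.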
17325 
17326 static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
17327                                              EVT VT) {
17328   assert(VT.isScalableVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
17329          "Expected legal scalable vector!");
17330   auto PredTy = VT.changeVectorElementType(MVT::i1);
17331   return getPTrue(DAG, DL, PredTy, AArch64SVEPredPattern::all);
17332 }
17333 
17334 static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT) {
17335   if (VT.isFixedLengthVector())
17336     return getPredicateForFixedLengthVector(DAG, DL, VT);
17337 
17338   return getPredicateForScalableVector(DAG, DL, VT);
17339 }
17340 
17341 // Grow V to consume an entire SVE register.
17342 static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V) {
17343   assert(VT.isScalableVector() &&
17344          "Expected to convert into a scalable vector!");
17345   assert(V.getValueType().isFixedLengthVector() &&
17346          "Expected a fixed length vector operand!");
17347   SDLoc DL(V);
17348   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
17349   return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT), V, Zero);
17350 }
17351 
17352 // Shrink V so it's just big enough to maintain a VT's worth of data.
17353 static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V) {
17354   assert(VT.isFixedLengthVector() &&
17355          "Expected to convert into a fixed length vector!");
17356   assert(V.getValueType().isScalableVector() &&
17357          "Expected a scalable vector operand!");
17358   SDLoc DL(V);
17359   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
17360   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero);
17361 }
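// Together these helpers "cast" between fixed and scalable forms, conceptually:
//   nxv4i32 Scalable = insert_subvector undef:nxv4i32, v4i32 V, 0
//   v4i32   Fixed    = extract_subvector nxv4i32 Scalable, 0
// with the fixed length data always living in the low lanes of the register.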
17362 
17363 // Convert all fixed length vector loads larger than NEON to masked_loads.
17364 SDValue AArch64TargetLowering::LowerFixedLengthVectorLoadToSVE(
17365     SDValue Op, SelectionDAG &DAG) const {
17366   auto Load = cast<LoadSDNode>(Op);
17367 
17368   SDLoc DL(Op);
17369   EVT VT = Op.getValueType();
17370   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
17371 
17372   auto NewLoad = DAG.getMaskedLoad(
17373       ContainerVT, DL, Load->getChain(), Load->getBasePtr(), Load->getOffset(),
17374       getPredicateForFixedLengthVector(DAG, DL, VT), DAG.getUNDEF(ContainerVT),
17375       Load->getMemoryVT(), Load->getMemOperand(), Load->getAddressingMode(),
17376       Load->getExtensionType());
17377 
17378   auto Result = convertFromScalableVector(DAG, VT, NewLoad);
17379   SDValue MergedValues[2] = {Result, Load->getChain()};
17380   return DAG.getMergeValues(MergedValues, DL);
17381 }
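// As a sketch, a fixed length load such as
//   v8i16 = load <ptr>
// becomes
//   nxv8i16 = masked_load <ptr>, PTRUE(vl8), undef
//   v8i16   = extract_subvector nxv8i16, 0
// so only the lanes covered by the original fixed length type are accessed.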
17382 
17383 static SDValue convertFixedMaskToScalableVector(SDValue Mask,
17384                                                 SelectionDAG &DAG) {
17385   SDLoc DL(Mask);
17386   EVT InVT = Mask.getValueType();
17387   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
17388 
17389   auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
17390   auto Op2 = DAG.getConstant(0, DL, ContainerVT);
17391   auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
17392 
17393   EVT CmpVT = Pg.getValueType();
17394   return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, CmpVT,
17395                      {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)});
17396 }
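// The fixed length mask arrives as an integer vector of 0/-1 lanes, so the
// equivalent scalable predicate is formed by widening it into its container
// and comparing against zero, conceptually (for an i16-element mask):
//   nxv8i1 = setcc_merge_zero PTRUE(vl<N>), nxv8i16 Mask, splat(0), setne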
17397 
17398 // Convert fixed length vector masked loads larger than NEON to SVE masked loads.
17399 SDValue AArch64TargetLowering::LowerFixedLengthVectorMLoadToSVE(
17400     SDValue Op, SelectionDAG &DAG) const {
17401   auto Load = cast<MaskedLoadSDNode>(Op);
17402 
17403   if (Load->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD)
17404     return SDValue();
17405 
17406   SDLoc DL(Op);
17407   EVT VT = Op.getValueType();
17408   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
17409 
17410   SDValue Mask = convertFixedMaskToScalableVector(Load->getMask(), DAG);
17411 
17412   SDValue PassThru;
17413   bool IsPassThruZeroOrUndef = false;
17414 
17415   if (Load->getPassThru()->isUndef()) {
17416     PassThru = DAG.getUNDEF(ContainerVT);
17417     IsPassThruZeroOrUndef = true;
17418   } else {
17419     if (ContainerVT.isInteger())
17420       PassThru = DAG.getConstant(0, DL, ContainerVT);
17421     else
17422       PassThru = DAG.getConstantFP(0, DL, ContainerVT);
17423     if (isZerosVector(Load->getPassThru().getNode()))
17424       IsPassThruZeroOrUndef = true;
17425   }
17426 
17427   auto NewLoad = DAG.getMaskedLoad(
17428       ContainerVT, DL, Load->getChain(), Load->getBasePtr(), Load->getOffset(),
17429       Mask, PassThru, Load->getMemoryVT(), Load->getMemOperand(),
17430       Load->getAddressingMode(), Load->getExtensionType());
17431 
17432   if (!IsPassThruZeroOrUndef) {
17433     SDValue OldPassThru =
17434         convertToScalableVector(DAG, ContainerVT, Load->getPassThru());
17435     NewLoad = DAG.getSelect(DL, ContainerVT, Mask, NewLoad, OldPassThru);
17436   }
17437 
17438   auto Result = convertFromScalableVector(DAG, VT, NewLoad);
17439   SDValue MergedValues[2] = {Result, Load->getChain()};
17440   return DAG.getMergeValues(MergedValues, DL);
17441 }
17442 
17443 // Convert all fixed length vector stores larger than NEON to masked_stores.
17444 SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
17445     SDValue Op, SelectionDAG &DAG) const {
17446   auto Store = cast<StoreSDNode>(Op);
17447 
17448   SDLoc DL(Op);
17449   EVT VT = Store->getValue().getValueType();
17450   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
17451 
17452   auto NewValue = convertToScalableVector(DAG, ContainerVT, Store->getValue());
17453   return DAG.getMaskedStore(
17454       Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
17455       getPredicateForFixedLengthVector(DAG, DL, VT), Store->getMemoryVT(),
17456       Store->getMemOperand(), Store->getAddressingMode(),
17457       Store->isTruncatingStore());
17458 }
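// Mirroring the load case above, the fixed length value is widened into its
// container and written with a masked_store predicated on PTRUE(vl<N>), so
// lanes beyond the original fixed length type are never stored.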
17459 
17460 SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
17461     SDValue Op, SelectionDAG &DAG) const {
17462   auto Store = cast<MaskedStoreSDNode>(Op);
17463 
17464   if (Store->isTruncatingStore())
17465     return SDValue();
17466 
17467   SDLoc DL(Op);
17468   EVT VT = Store->getValue().getValueType();
17469   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
17470 
17471   auto NewValue = convertToScalableVector(DAG, ContainerVT, Store->getValue());
17472   SDValue Mask = convertFixedMaskToScalableVector(Store->getMask(), DAG);
17473 
17474   return DAG.getMaskedStore(
17475       Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
17476       Mask, Store->getMemoryVT(), Store->getMemOperand(),
17477       Store->getAddressingMode(), Store->isTruncatingStore());
17478 }
17479 
17480 SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
17481     SDValue Op, SelectionDAG &DAG) const {
17482   SDLoc dl(Op);
17483   EVT VT = Op.getValueType();
17484   EVT EltVT = VT.getVectorElementType();
17485 
17486   bool Signed = Op.getOpcode() == ISD::SDIV;
17487   unsigned PredOpcode = Signed ? AArch64ISD::SDIV_PRED : AArch64ISD::UDIV_PRED;
17488 
17489   // Scalable vector i32/i64 DIV is supported.
17490   if (EltVT == MVT::i32 || EltVT == MVT::i64)
17491     return LowerToPredicatedOp(Op, DAG, PredOpcode, /*OverrideNEON=*/true);
17492 
17493   // Scalable vector i8/i16 DIV is not supported. Promote it to i32.
17494   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
17495   EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
17496   EVT FixedWidenedVT = HalfVT.widenIntegerVectorElementType(*DAG.getContext());
17497   EVT ScalableWidenedVT = getContainerForFixedLengthVector(DAG, FixedWidenedVT);
17498 
17499   // If the type with widened elements is legal, extend, divide and truncate.
17500   EVT WidenedVT = VT.widenIntegerVectorElementType(*DAG.getContext());
17501   if (DAG.getTargetLoweringInfo().isTypeLegal(WidenedVT)) {
17502     unsigned ExtendOpcode = Signed ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
17503     SDValue Op0 = DAG.getNode(ExtendOpcode, dl, WidenedVT, Op.getOperand(0));
17504     SDValue Op1 = DAG.getNode(ExtendOpcode, dl, WidenedVT, Op.getOperand(1));
17505     SDValue Div = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0, Op1);
17506     return DAG.getNode(ISD::TRUNCATE, dl, VT, Div);
17507   }
17508 
17509   // Convert the operands to scalable vectors.
17510   SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
17511   SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1));
17512 
17513   // Extend the scalable operands.
17514   unsigned UnpkLo = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
17515   unsigned UnpkHi = Signed ? AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI;
17516   SDValue Op0Lo = DAG.getNode(UnpkLo, dl, ScalableWidenedVT, Op0);
17517   SDValue Op1Lo = DAG.getNode(UnpkLo, dl, ScalableWidenedVT, Op1);
17518   SDValue Op0Hi = DAG.getNode(UnpkHi, dl, ScalableWidenedVT, Op0);
17519   SDValue Op1Hi = DAG.getNode(UnpkHi, dl, ScalableWidenedVT, Op1);
17520 
17521   // Convert back to fixed vectors so the DIV can be further lowered.
17522   Op0Lo = convertFromScalableVector(DAG, FixedWidenedVT, Op0Lo);
17523   Op1Lo = convertFromScalableVector(DAG, FixedWidenedVT, Op1Lo);
17524   Op0Hi = convertFromScalableVector(DAG, FixedWidenedVT, Op0Hi);
17525   Op1Hi = convertFromScalableVector(DAG, FixedWidenedVT, Op1Hi);
17526   SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, FixedWidenedVT,
17527                                  Op0Lo, Op1Lo);
17528   SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, FixedWidenedVT,
17529                                  Op0Hi, Op1Hi);
17530 
17531   // Convert again to scalable vectors to truncate.
17532   ResultLo = convertToScalableVector(DAG, ScalableWidenedVT, ResultLo);
17533   ResultHi = convertToScalableVector(DAG, ScalableWidenedVT, ResultHi);
17534   SDValue ScalableResult = DAG.getNode(AArch64ISD::UZP1, dl, ContainerVT,
17535                                        ResultLo, ResultHi);
17536 
17537   return convertFromScalableVector(DAG, VT, ScalableResult);
17538 }
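// As a sketch, a fixed length v16i8 sdiv with no legal widened fixed type is
// split: the operands are unpacked to i16 elements with SUNPKLO/SUNPKHI, each
// half is divided (recursively promoting until the legal i32/i64 DIV above is
// reached), and the halves are then narrowed and concatenated again via UZP1.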
17539 
17540 SDValue AArch64TargetLowering::LowerFixedLengthVectorIntExtendToSVE(
17541     SDValue Op, SelectionDAG &DAG) const {
17542   EVT VT = Op.getValueType();
17543   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
17544 
17545   SDLoc DL(Op);
17546   SDValue Val = Op.getOperand(0);
17547   EVT ContainerVT = getContainerForFixedLengthVector(DAG, Val.getValueType());
17548   Val = convertToScalableVector(DAG, ContainerVT, Val);
17549 
17550   bool Signed = Op.getOpcode() == ISD::SIGN_EXTEND;
17551   unsigned ExtendOpc = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
17552 
17553   // Repeatedly unpack Val until the result is of the desired element type.
17554   switch (ContainerVT.getSimpleVT().SimpleTy) {
17555   default:
17556     llvm_unreachable("unimplemented container type");
17557   case MVT::nxv16i8:
17558     Val = DAG.getNode(ExtendOpc, DL, MVT::nxv8i16, Val);
17559     if (VT.getVectorElementType() == MVT::i16)
17560       break;
17561     LLVM_FALLTHROUGH;
17562   case MVT::nxv8i16:
17563     Val = DAG.getNode(ExtendOpc, DL, MVT::nxv4i32, Val);
17564     if (VT.getVectorElementType() == MVT::i32)
17565       break;
17566     LLVM_FALLTHROUGH;
17567   case MVT::nxv4i32:
17568     Val = DAG.getNode(ExtendOpc, DL, MVT::nxv2i64, Val);
17569     assert(VT.getVectorElementType() == MVT::i64 && "Unexpected element type!");
17570     break;
17571   }
17572 
17573   return convertFromScalableVector(DAG, VT, Val);
17574 }
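// For example, extending i16 elements all the way to i64 performs two unpacks:
//   nxv8i16 -> (s|u)unpklo -> nxv4i32 -> (s|u)unpklo -> nxv2i64
// with the original elements remaining in the low lanes throughout.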
17575 
17576 SDValue AArch64TargetLowering::LowerFixedLengthVectorTruncateToSVE(
17577     SDValue Op, SelectionDAG &DAG) const {
17578   EVT VT = Op.getValueType();
17579   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
17580 
17581   SDLoc DL(Op);
17582   SDValue Val = Op.getOperand(0);
17583   EVT ContainerVT = getContainerForFixedLengthVector(DAG, Val.getValueType());
17584   Val = convertToScalableVector(DAG, ContainerVT, Val);
17585 
17586   // Repeatedly truncate Val until the result is of the desired element type.
17587   switch (ContainerVT.getSimpleVT().SimpleTy) {
17588   default:
17589     llvm_unreachable("unimplemented container type");
17590   case MVT::nxv2i64:
17591     Val = DAG.getNode(ISD::BITCAST, DL, MVT::nxv4i32, Val);
17592     Val = DAG.getNode(AArch64ISD::UZP1, DL, MVT::nxv4i32, Val, Val);
17593     if (VT.getVectorElementType() == MVT::i32)
17594       break;
17595     LLVM_FALLTHROUGH;
17596   case MVT::nxv4i32:
17597     Val = DAG.getNode(ISD::BITCAST, DL, MVT::nxv8i16, Val);
17598     Val = DAG.getNode(AArch64ISD::UZP1, DL, MVT::nxv8i16, Val, Val);
17599     if (VT.getVectorElementType() == MVT::i16)
17600       break;
17601     LLVM_FALLTHROUGH;
17602   case MVT::nxv8i16:
17603     Val = DAG.getNode(ISD::BITCAST, DL, MVT::nxv16i8, Val);
17604     Val = DAG.getNode(AArch64ISD::UZP1, DL, MVT::nxv16i8, Val, Val);
17605     assert(VT.getVectorElementType() == MVT::i8 && "Unexpected element type!");
17606     break;
17607   }
17608 
17609   return convertFromScalableVector(DAG, VT, Val);
17610 }
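// For example, truncating i64 elements down to i16 applies two UZP1 steps:
//   nxv2i64 -> bitcast nxv4i32 -> uzp1 -> bitcast nxv8i16 -> uzp1
// each step keeping the low (even-indexed) halves of the wider elements.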
17611 
17612 SDValue AArch64TargetLowering::LowerFixedLengthExtractVectorElt(
17613     SDValue Op, SelectionDAG &DAG) const {
17614   EVT VT = Op.getValueType();
17615   EVT InVT = Op.getOperand(0).getValueType();
17616   assert(InVT.isFixedLengthVector() && "Expected fixed length vector type!");
17617 
17618   SDLoc DL(Op);
17619   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
17620   SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0));
17621 
17622   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Op.getOperand(1));
17623 }
17624 
17625 SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt(
17626     SDValue Op, SelectionDAG &DAG) const {
17627   EVT VT = Op.getValueType();
17628   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
17629 
17630   SDLoc DL(Op);
17631   EVT InVT = Op.getOperand(0).getValueType();
17632   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
17633   SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0));
17634 
17635   auto ScalableRes = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT, Op0,
17636                                  Op.getOperand(1), Op.getOperand(2));
17637 
17638   return convertFromScalableVector(DAG, VT, ScalableRes);
17639 }
17640 
17641 // Convert vector operation 'Op' to an equivalent predicated operation whereby
17642 // the original operation's type is used to construct a suitable predicate.
17643 // NOTE: The results for inactive lanes are undefined.
17644 SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
17645                                                    SelectionDAG &DAG,
17646                                                    unsigned NewOp,
17647                                                    bool OverrideNEON) const {
17648   EVT VT = Op.getValueType();
17649   SDLoc DL(Op);
17650   auto Pg = getPredicateForVector(DAG, DL, VT);
17651 
17652   if (useSVEForFixedLengthVectorVT(VT, OverrideNEON)) {
17653     EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
17654 
17655     // Create list of operands by converting existing ones to scalable types.
17656     SmallVector<SDValue, 4> Operands = {Pg};
17657     for (const SDValue &V : Op->op_values()) {
17658       if (isa<CondCodeSDNode>(V)) {
17659         Operands.push_back(V);
17660         continue;
17661       }
17662 
17663       if (const VTSDNode *VTNode = dyn_cast<VTSDNode>(V)) {
17664         EVT VTArg = VTNode->getVT().getVectorElementType();
17665         EVT NewVTArg = ContainerVT.changeVectorElementType(VTArg);
17666         Operands.push_back(DAG.getValueType(NewVTArg));
17667         continue;
17668       }
17669 
17670       assert(useSVEForFixedLengthVectorVT(V.getValueType(), OverrideNEON) &&
17671              "Only fixed length vectors are supported!");
17672       Operands.push_back(convertToScalableVector(DAG, ContainerVT, V));
17673     }
17674 
17675     if (isMergePassthruOpcode(NewOp))
17676       Operands.push_back(DAG.getUNDEF(ContainerVT));
17677 
17678     auto ScalableRes = DAG.getNode(NewOp, DL, ContainerVT, Operands);
17679     return convertFromScalableVector(DAG, VT, ScalableRes);
17680   }
17681 
17682   assert(VT.isScalableVector() && "Only expect to lower scalable vector op!");
17683 
17684   SmallVector<SDValue, 4> Operands = {Pg};
17685   for (const SDValue &V : Op->op_values()) {
17686     assert((!V.getValueType().isVector() ||
17687             V.getValueType().isScalableVector()) &&
17688            "Only scalable vectors are supported!");
17689     Operands.push_back(V);
17690   }
17691 
17692   if (isMergePassthruOpcode(NewOp))
17693     Operands.push_back(DAG.getUNDEF(VT));
17694 
17695   return DAG.getNode(NewOp, DL, VT, Operands);
17696 }
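// For example, a fixed length v8f32 fadd lowered through here conceptually
// becomes:
//   nxv4f32 = FADD_PRED (PTRUE nxv4i1, vl8), nxv4f32 Op0, nxv4f32 Op1
//   v8f32   = extract_subvector nxv4f32, 0
// with a trailing undef passthru operand added only for merging opcodes.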
17697 
17698 // If a fixed length vector operation has no side effects when applied to
17699 // undefined elements, we can safely use scalable vectors to perform the same
17700 // operation without needing to worry about predication.
17701 SDValue AArch64TargetLowering::LowerToScalableOp(SDValue Op,
17702                                                  SelectionDAG &DAG) const {
17703   EVT VT = Op.getValueType();
17704   assert(useSVEForFixedLengthVectorVT(VT) &&
17705          "Only expected to lower fixed length vector operation!");
17706   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
17707 
17708   // Create list of operands by converting existing ones to scalable types.
17709   SmallVector<SDValue, 4> Ops;
17710   for (const SDValue &V : Op->op_values()) {
17711     assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
17712 
17713     // Pass through non-vector operands.
17714     if (!V.getValueType().isVector()) {
17715       Ops.push_back(V);
17716       continue;
17717     }
17718 
17719     // "cast" fixed length vector to a scalable vector.
17720     assert(useSVEForFixedLengthVectorVT(V.getValueType()) &&
17721            "Only fixed length vectors are supported!");
17722     Ops.push_back(convertToScalableVector(DAG, ContainerVT, V));
17723   }
17724 
17725   auto ScalableRes = DAG.getNode(Op.getOpcode(), SDLoc(Op), ContainerVT, Ops);
17726   return convertFromScalableVector(DAG, VT, ScalableRes);
17727 }
17728 
17729 SDValue AArch64TargetLowering::LowerVECREDUCE_SEQ_FADD(SDValue ScalarOp,
17730     SelectionDAG &DAG) const {
17731   SDLoc DL(ScalarOp);
17732   SDValue AccOp = ScalarOp.getOperand(0);
17733   SDValue VecOp = ScalarOp.getOperand(1);
17734   EVT SrcVT = VecOp.getValueType();
17735   EVT ResVT = SrcVT.getVectorElementType();
17736 
17737   EVT ContainerVT = SrcVT;
17738   if (SrcVT.isFixedLengthVector()) {
17739     ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
17740     VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
17741   }
17742 
17743   SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
17744   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
17745 
17746   // Move the scalar accumulator into lane 0 of a scalable vector.
17747   AccOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT,
17748                       DAG.getUNDEF(ContainerVT), AccOp, Zero);
17749 
17750   // Perform reduction.
17751   SDValue Rdx = DAG.getNode(AArch64ISD::FADDA_PRED, DL, ContainerVT,
17752                             Pg, AccOp, VecOp);
17753 
17754   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Rdx, Zero);
17755 }
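// FADDA performs a strictly ordered floating point add reduction: the scalar
// accumulator is placed in lane 0 of an otherwise undef vector and folded in
// ahead of the vector lanes, preserving vecreduce_seq_fadd semantics, and the
// scalar result is read back out of lane 0.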
17756 
17757 SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
17758                                                        SelectionDAG &DAG) const {
17759   SDLoc DL(ReduceOp);
17760   SDValue Op = ReduceOp.getOperand(0);
17761   EVT OpVT = Op.getValueType();
17762   EVT VT = ReduceOp.getValueType();
17763 
17764   if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1)
17765     return SDValue();
17766 
17767   SDValue Pg = getPredicateForVector(DAG, DL, OpVT);
17768 
17769   switch (ReduceOp.getOpcode()) {
17770   default:
17771     return SDValue();
17772   case ISD::VECREDUCE_OR:
17773     return getPTest(DAG, VT, Pg, Op, AArch64CC::ANY_ACTIVE);
17774   case ISD::VECREDUCE_AND: {
17775     Op = DAG.getNode(ISD::XOR, DL, OpVT, Op, Pg);
17776     return getPTest(DAG, VT, Pg, Op, AArch64CC::NONE_ACTIVE);
17777   }
17778   case ISD::VECREDUCE_XOR: {
17779     SDValue ID =
17780         DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64);
17781     SDValue Cntp =
17782         DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64, ID, Pg, Op);
17783     return DAG.getAnyExtOrTrunc(Cntp, DL, VT);
17784   }
17785   }
17786 
17787   return SDValue();
17788 }
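// Predicate reductions map onto SVE's flag-setting predicate operations:
//   vecreduce_or  -> PTEST(Pg, Op), checking for any active lane
//   vecreduce_and -> PTEST(Pg, Op ^ Pg), checking that no lane under Pg is
//                    clear in Op
//   vecreduce_xor -> CNTP(Pg, Op) truncated to the result type, whose low bit
//                    is the parity of the active lanes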
17789 
17790 SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
17791                                                    SDValue ScalarOp,
17792                                                    SelectionDAG &DAG) const {
17793   SDLoc DL(ScalarOp);
17794   SDValue VecOp = ScalarOp.getOperand(0);
17795   EVT SrcVT = VecOp.getValueType();
17796 
17797   if (useSVEForFixedLengthVectorVT(SrcVT, true)) {
17798     EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
17799     VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
17800   }
17801 
17802   // UADDV always returns an i64 result.
17803   EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64 :
17804                                                    SrcVT.getVectorElementType();
17805   EVT RdxVT = SrcVT;
17806   if (SrcVT.isFixedLengthVector() || Opcode == AArch64ISD::UADDV_PRED)
17807     RdxVT = getPackedSVEVectorVT(ResVT);
17808 
17809   SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
17810   SDValue Rdx = DAG.getNode(Opcode, DL, RdxVT, Pg, VecOp);
17811   SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT,
17812                             Rdx, DAG.getConstant(0, DL, MVT::i64));
17813 
17814   // The VEC_REDUCE nodes expect an element-sized result.
17815   if (ResVT != ScalarOp.getValueType())
17816     Res = DAG.getAnyExtOrTrunc(Res, DL, ScalarOp.getValueType());
17817 
17818   return Res;
17819 }
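// For example, a fixed length v8i16 vecreduce_add reaches here as
// AArch64ISD::UADDV_PRED: the input is widened to nxv8i16, reduced under
// PTRUE(vl8) into an i64-element result, and lane 0 is extracted and then
// truncated back to the original i16 scalar type.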
17820 
17821 SDValue
17822 AArch64TargetLowering::LowerFixedLengthVectorSelectToSVE(SDValue Op,
17823     SelectionDAG &DAG) const {
17824   EVT VT = Op.getValueType();
17825   SDLoc DL(Op);
17826 
17827   EVT InVT = Op.getOperand(1).getValueType();
17828   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
17829   SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(1));
17830   SDValue Op2 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(2));
17831 
17832   // Convert the mask to a predicate (NOTE: We don't need to worry about
17833   // inactive lanes since VSELECT is safe when given undefined elements).
17834   EVT MaskVT = Op.getOperand(0).getValueType();
17835   EVT MaskContainerVT = getContainerForFixedLengthVector(DAG, MaskVT);
17836   auto Mask = convertToScalableVector(DAG, MaskContainerVT, Op.getOperand(0));
17837   Mask = DAG.getNode(ISD::TRUNCATE, DL,
17838                      MaskContainerVT.changeVectorElementType(MVT::i1), Mask);
17839 
17840   auto ScalableRes = DAG.getNode(ISD::VSELECT, DL, ContainerVT,
17841                                 Mask, Op1, Op2);
17842 
17843   return convertFromScalableVector(DAG, VT, ScalableRes);
17844 }
17845 
17846 SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
17847     SDValue Op, SelectionDAG &DAG) const {
17848   SDLoc DL(Op);
17849   EVT InVT = Op.getOperand(0).getValueType();
17850   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
17851 
17852   assert(useSVEForFixedLengthVectorVT(InVT) &&
17853          "Only expected to lower fixed length vector operation!");
17854   assert(Op.getValueType() == InVT.changeTypeToInteger() &&
17855          "Expected integer result of the same bit length as the inputs!");
17856 
17857   auto Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
17858   auto Op2 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1));
17859   auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
17860 
17861   EVT CmpVT = Pg.getValueType();
17862   auto Cmp = DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, CmpVT,
17863                          {Pg, Op1, Op2, Op.getOperand(2)});
17864 
17865   EVT PromoteVT = ContainerVT.changeTypeToInteger();
17866   auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, PromoteVT, InVT);
17867   return convertFromScalableVector(DAG, Op.getValueType(), Promote);
17868 }
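// The comparison is emitted as SETCC_MERGE_ZERO, producing an i1 predicate
// under Pg; getBoolExtOrTrunc then widens that predicate back into an integer
// vector with the same element width as the inputs, which is the result type
// a fixed length SETCC is expected to produce.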
17869 
17870 SDValue
17871 AArch64TargetLowering::LowerFixedLengthBitcastToSVE(SDValue Op,
17872                                                     SelectionDAG &DAG) const {
17873   SDLoc DL(Op);
17874   auto SrcOp = Op.getOperand(0);
17875   EVT VT = Op.getValueType();
17876   EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT);
17877   EVT ContainerSrcVT =
17878       getContainerForFixedLengthVector(DAG, SrcOp.getValueType());
17879 
17880   SrcOp = convertToScalableVector(DAG, ContainerSrcVT, SrcOp);
17881   Op = DAG.getNode(ISD::BITCAST, DL, ContainerDstVT, SrcOp);
17882   return convertFromScalableVector(DAG, VT, Op);
17883 }
17884 
17885 SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
17886                                                  SelectionDAG &DAG) const {
17887   SDLoc DL(Op);
17888   EVT InVT = Op.getValueType();
17889   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
17890   (void)TLI;
17891 
17892   assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
17893          InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
17894          "Only expect to cast between legal scalable vector types!");
17895   assert((VT.getVectorElementType() == MVT::i1) ==
17896              (InVT.getVectorElementType() == MVT::i1) &&
17897          "Cannot cast between data and predicate scalable vector types!");
17898 
17899   if (InVT == VT)
17900     return Op;
17901 
17902   if (VT.getVectorElementType() == MVT::i1)
17903     return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
17904 
17905   EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
17906   EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
17907   assert((VT == PackedVT || InVT == PackedInVT) &&
17908          "Cannot cast between unpacked scalable vector types!");
17909 
17910   // Pack input if required.
17911   if (InVT != PackedInVT)
17912     Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
17913 
17914   Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
17915 
17916   // Unpack result if required.
17917   if (VT != PackedVT)
17918     Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
17919 
17920   return Op;
17921 }
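// For example, casting nxv2f32 (unpacked) to nxv4i32 first reinterprets the
// input as its packed form nxv4f32 and then bitcasts to nxv4i32; going the
// other way, the bitcast happens first and the result is reinterpreted down
// to the unpacked type afterwards.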
17922 
17923 bool AArch64TargetLowering::isAllActivePredicate(SDValue N) const {
17924   return ::isAllActivePredicate(N);
17925 }
17926 
17927 bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
17928     SDValue Op, const APInt &OriginalDemandedBits,
17929     const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
17930     unsigned Depth) const {
17931 
17932   unsigned Opc = Op.getOpcode();
17933   switch (Opc) {
17934   case AArch64ISD::VSHL: {
17935     // Match (VSHL (VLSHR Val X) X)
17936     SDValue ShiftL = Op;
17937     SDValue ShiftR = Op->getOperand(0);
17938     if (ShiftR->getOpcode() != AArch64ISD::VLSHR)
17939       return false;
17940 
17941     if (!ShiftL.hasOneUse() || !ShiftR.hasOneUse())
17942       return false;
17943 
17944     unsigned ShiftLBits = ShiftL->getConstantOperandVal(1);
17945     unsigned ShiftRBits = ShiftR->getConstantOperandVal(1);
17946 
17947     // Other cases can be handled as well, but this is not
17948     // implemented.
17949     if (ShiftRBits != ShiftLBits)
17950       return false;
17951 
17952     unsigned ScalarSize = Op.getScalarValueSizeInBits();
17953     assert(ScalarSize > ShiftLBits && "Invalid shift imm");
17954 
17955     APInt ZeroBits = APInt::getLowBitsSet(ScalarSize, ShiftLBits);
17956     APInt UnusedBits = ~OriginalDemandedBits;
17957 
17958     if ((ZeroBits & UnusedBits) != ZeroBits)
17959       return false;
17960 
17961     // All bits that are zeroed by (VSHL (VLSHR Val X) X) are not
17962     // used - simplify to just Val.
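    // For example, with 32-bit elements and a shift amount of 8 the pair
    // zeroes bits [7:0] of every lane; if the caller never demands those
    // low 8 bits, Val already provides every demanded bit.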
17963     return TLO.CombineTo(Op, ShiftR->getOperand(0));
17964   }
17965   }
17966 
17967   return TargetLowering::SimplifyDemandedBitsForTargetNode(
17968       Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth);
17969 }
17970