//===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "AArch64TargetTransformInfo.h"
#include "AArch64ExpandImm.h"
#include "AArch64PerfectShuffle.h"
#include "MCTargetDesc/AArch64AddressingModes.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/CostTable.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
#include <algorithm>
#include <optional>
using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "aarch64tti"

static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
                                               cl::init(true), cl::Hidden);

static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10),
                                           cl::Hidden);

static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead",
                                            cl::init(10), cl::Hidden);

namespace {
class TailFoldingKind {
private:
  uint8_t Bits = 0; // Currently defaults to disabled.

public:
  enum TailFoldingOpts {
    TFDisabled = 0x0,
    TFReductions = 0x01,
    TFRecurrences = 0x02,
    TFSimple = 0x80,
    TFAll = TFReductions | TFRecurrences | TFSimple
  };

  void operator=(const std::string &Val) {
    if (Val.empty())
      return;
    SmallVector<StringRef, 6> TailFoldTypes;
    StringRef(Val).split(TailFoldTypes, '+', -1, false);
    for (auto TailFoldType : TailFoldTypes) {
      if (TailFoldType == "disabled")
        Bits = 0;
      else if (TailFoldType == "all")
        Bits = TFAll;
      else if (TailFoldType == "default")
        Bits = 0; // Currently defaults to never tail-folding.
      else if (TailFoldType == "simple")
        add(TFSimple);
      else if (TailFoldType == "reductions")
        add(TFReductions);
      else if (TailFoldType == "recurrences")
        add(TFRecurrences);
      else if (TailFoldType == "noreductions")
        remove(TFReductions);
      else if (TailFoldType == "norecurrences")
        remove(TFRecurrences);
      else {
        errs()
            << "invalid argument " << TailFoldType.str()
            << " to -sve-tail-folding=; each element must be one of: disabled, "
               "all, default, simple, reductions, noreductions, recurrences, "
               "norecurrences\n";
      }
    }
  }

  operator uint8_t() const { return Bits; }

  void add(uint8_t Flag) { Bits |= Flag; }
  void remove(uint8_t Flag) { Bits &= ~Flag; }
};
} // namespace

TailFoldingKind TailFoldingKindLoc;

cl::opt<TailFoldingKind, true, cl::parser<std::string>> SVETailFolding(
    "sve-tail-folding",
    cl::desc(
        "Control the use of vectorisation using tail-folding for SVE:"
        "\ndisabled    No loop types will vectorize using tail-folding"
        "\ndefault     Uses the default tail-folding settings for the target "
        "CPU"
        "\nall         All legal loop types will vectorize using tail-folding"
        "\nsimple      Use tail-folding for simple loops (not reductions or "
        "recurrences)"
        "\nreductions  Use tail-folding for loops containing reductions"
        "\nrecurrences Use tail-folding for loops containing fixed order "
        "recurrences"),
    cl::location(TailFoldingKindLoc));
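
// For example (an illustrative combination, not an exhaustive list): passing
// -sve-tail-folding=all+noreductions first sets Bits to TFAll and then clears
// TFReductions, leaving TFSimple and TFRecurrences enabled.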

// Experimental option that will only be fully functional when the
// code-generator is changed to use SVE instead of NEON for all fixed-width
// operations.
static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
    "enable-fixedwidth-autovec-in-streaming-mode", cl::init(false), cl::Hidden);

// Experimental option that will only be fully functional when the cost-model
// and code-generator have been changed to avoid using scalable vector
// instructions that are not legal in streaming SVE mode.
static cl::opt<bool> EnableScalableAutovecInStreamingMode(
    "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);

bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
                                         const Function *Callee) const {
  SMEAttrs CallerAttrs(*Caller);
  SMEAttrs CalleeAttrs(*Callee);
  if (CallerAttrs.requiresSMChange(CalleeAttrs,
                                   /*BodyOverridesInterface=*/true) ||
      CallerAttrs.requiresLazySave(CalleeAttrs) ||
      CalleeAttrs.hasNewZAInterface())
    return false;

  const TargetMachine &TM = getTLI()->getTargetMachine();

  const FeatureBitset &CallerBits =
      TM.getSubtargetImpl(*Caller)->getFeatureBits();
  const FeatureBitset &CalleeBits =
      TM.getSubtargetImpl(*Callee)->getFeatureBits();

  // Inline a callee if its target-features are a subset of the caller's
  // target-features.
  return (CallerBits & CalleeBits) == CalleeBits;
}
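
// For example, under the subset rule above a caller built with {+neon,+sve}
// may inline a callee requiring only {+neon}, but not one that additionally
// requires a feature the caller lacks (such as +sve2).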
14509467b48Spatrick 
shouldMaximizeVectorBandwidth(TargetTransformInfo::RegisterKind K) const146*d415bd75Srobert bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
147*d415bd75Srobert     TargetTransformInfo::RegisterKind K) const {
148*d415bd75Srobert   assert(K != TargetTransformInfo::RGK_Scalar);
149*d415bd75Srobert   return K == TargetTransformInfo::RGK_FixedWidthVector;
150*d415bd75Srobert }
151*d415bd75Srobert 
15209467b48Spatrick /// Calculate the cost of materializing a 64-bit value. This helper
15309467b48Spatrick /// method might only calculate a fraction of a larger immediate. Therefore it
15409467b48Spatrick /// is valid to return a cost of ZERO.
getIntImmCost(int64_t Val)15573471bf0Spatrick InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
15609467b48Spatrick   // Check if the immediate can be encoded within an instruction.
15709467b48Spatrick   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
15809467b48Spatrick     return 0;
15909467b48Spatrick 
16009467b48Spatrick   if (Val < 0)
16109467b48Spatrick     Val = ~Val;
16209467b48Spatrick 
16309467b48Spatrick   // Calculate how many moves we will need to materialize this constant.
16409467b48Spatrick   SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
16509467b48Spatrick   AArch64_IMM::expandMOVImm(Val, 64, Insn);
16609467b48Spatrick   return Insn.size();
16709467b48Spatrick }
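
// For example, zero and any valid 64-bit logical immediate (a replicated bit
// pattern) cost 0 above because they can be encoded directly, whereas an
// arbitrary 64-bit constant may take a MOVZ plus up to three MOVKs, giving a
// cost of up to 4.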

/// Calculate the cost of materializing the given constant.
InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
                                              TTI::TargetCostKind CostKind) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  if (BitSize == 0)
    return ~0U;

  // Sign-extend all constants to a multiple of 64 bits.
  APInt ImmVal = Imm;
  if (BitSize & 0x3f)
    ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);

  // Split the constant into 64-bit chunks and calculate the cost for each
  // chunk.
  InstructionCost Cost = 0;
  for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
    APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
    int64_t Val = Tmp.getSExtValue();
    Cost += getIntImmCost(Val);
  }
  // We need at least one instruction to materialize the constant.
  return std::max<InstructionCost>(1, Cost);
}
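
// For example, a 128-bit constant is sign-extended to a multiple of 64 bits
// above, split into two 64-bit chunks, and each chunk is costed via
// getIntImmCost(int64_t); the total is clamped to a minimum of 1.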

InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
                                                  const APInt &Imm, Type *Ty,
                                                  TTI::TargetCostKind CostKind,
                                                  Instruction *Inst) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  unsigned ImmIdx = ~0U;
  switch (Opcode) {
  default:
    return TTI::TCC_Free;
  case Instruction::GetElementPtr:
    // Always hoist the base address of a GetElementPtr.
    if (Idx == 0)
      return 2 * TTI::TCC_Basic;
    return TTI::TCC_Free;
  case Instruction::Store:
    ImmIdx = 0;
    break;
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::URem:
  case Instruction::SRem:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::ICmp:
    ImmIdx = 1;
    break;
  // Always return TCC_Free for the shift value of a shift instruction.
  case Instruction::Shl:
  case Instruction::LShr:
  case Instruction::AShr:
    if (Idx == 1)
      return TTI::TCC_Free;
    break;
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::IntToPtr:
  case Instruction::PtrToInt:
  case Instruction::BitCast:
  case Instruction::PHI:
  case Instruction::Call:
  case Instruction::Select:
  case Instruction::Ret:
  case Instruction::Load:
    break;
  }

  if (Idx == ImmIdx) {
    int NumConstants = (BitSize + 63) / 64;
    InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
    return (Cost <= NumConstants * TTI::TCC_Basic)
               ? static_cast<int>(TTI::TCC_Free)
               : Cost;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}
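
// For example, the immediate of an i64 add whose materialization cost does
// not exceed NumConstants * TCC_Basic is reported as TCC_Free above, so
// constant hoisting leaves it alone; a more expensive immediate keeps its
// real cost and becomes a hoisting candidate.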

InstructionCost
AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
                                    const APInt &Imm, Type *Ty,
                                    TTI::TargetCostKind CostKind) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  // Most (all?) AArch64 intrinsics do not support folding immediates into the
  // selected instruction, so we compute the materialization cost for the
  // immediate directly.
  if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
    return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);

  switch (IID) {
  default:
    return TTI::TCC_Free;
  case Intrinsic::sadd_with_overflow:
  case Intrinsic::uadd_with_overflow:
  case Intrinsic::ssub_with_overflow:
  case Intrinsic::usub_with_overflow:
  case Intrinsic::smul_with_overflow:
  case Intrinsic::umul_with_overflow:
    if (Idx == 1) {
      int NumConstants = (BitSize + 63) / 64;
      InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
      return (Cost <= NumConstants * TTI::TCC_Basic)
                 ? static_cast<int>(TTI::TCC_Free)
                 : Cost;
    }
    break;
  case Intrinsic::experimental_stackmap:
    if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_patchpoint_void:
  case Intrinsic::experimental_patchpoint_i64:
    if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_gc_statepoint:
    if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}

TargetTransformInfo::PopcntSupportKind
AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
  assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
  if (TyWidth == 32 || TyWidth == 64)
    return TTI::PSK_FastHardware;
  // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
  return TTI::PSK_Software;
}

InstructionCost
AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                      TTI::TargetCostKind CostKind) {
  auto *RetTy = ICA.getReturnType();
  switch (ICA.getID()) {
  case Intrinsic::umin:
  case Intrinsic::umax:
  case Intrinsic::smin:
  case Intrinsic::smax: {
    static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                        MVT::v8i16, MVT::v2i32, MVT::v4i32};
    auto LT = getTypeLegalizationCost(RetTy);
    // v2i64 types get converted to cmp+bif hence the cost of 2
    if (LT.second == MVT::v2i64)
      return LT.first * 2;
    if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first;
    break;
  }
  case Intrinsic::sadd_sat:
  case Intrinsic::ssub_sat:
  case Intrinsic::uadd_sat:
  case Intrinsic::usub_sat: {
    static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                     MVT::v8i16, MVT::v2i32, MVT::v4i32,
                                     MVT::v2i64};
    auto LT = getTypeLegalizationCost(RetTy);
    // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
    // need to extend the type, as it uses shr(qadd(shl, shl)).
    unsigned Instrs =
        LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
    if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first * Instrs;
    break;
  }
  case Intrinsic::abs: {
    static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                     MVT::v8i16, MVT::v2i32, MVT::v4i32,
                                     MVT::v2i64};
    auto LT = getTypeLegalizationCost(RetTy);
    if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first;
    break;
  }
  case Intrinsic::experimental_stepvector: {
    InstructionCost Cost = 1; // Cost of the `index' instruction
    auto LT = getTypeLegalizationCost(RetTy);
    // Legalisation of illegal vectors involves an `index' instruction plus
    // (LT.first - 1) vector adds.
    if (LT.first > 1) {
      Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
      InstructionCost AddCost =
          getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
      Cost += AddCost * (LT.first - 1);
    }
    return Cost;
  }
  case Intrinsic::bitreverse: {
    static const CostTblEntry BitreverseTbl[] = {
        {Intrinsic::bitreverse, MVT::i32, 1},
        {Intrinsic::bitreverse, MVT::i64, 1},
        {Intrinsic::bitreverse, MVT::v8i8, 1},
        {Intrinsic::bitreverse, MVT::v16i8, 1},
        {Intrinsic::bitreverse, MVT::v4i16, 2},
        {Intrinsic::bitreverse, MVT::v8i16, 2},
        {Intrinsic::bitreverse, MVT::v2i32, 2},
        {Intrinsic::bitreverse, MVT::v4i32, 2},
        {Intrinsic::bitreverse, MVT::v1i64, 2},
        {Intrinsic::bitreverse, MVT::v2i64, 2},
    };
    const auto LegalisationCost = getTypeLegalizationCost(RetTy);
    const auto *Entry =
        CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
    if (Entry) {
      // The cost model uses the legal type (i32) that i8 and i16 will be
      // promoted to, plus 1 so that we match the actual lowering cost.
      if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
          TLI->getValueType(DL, RetTy, true) == MVT::i16)
        return LegalisationCost.first * Entry->Cost + 1;

      return LegalisationCost.first * Entry->Cost;
    }
    break;
  }
  case Intrinsic::ctpop: {
    if (!ST->hasNEON()) {
      // 32-bit or 64-bit ctpop without NEON is 12 instructions.
      return getTypeLegalizationCost(RetTy).first * 12;
    }
    static const CostTblEntry CtpopCostTbl[] = {
        {ISD::CTPOP, MVT::v2i64, 4},
        {ISD::CTPOP, MVT::v4i32, 3},
        {ISD::CTPOP, MVT::v8i16, 2},
        {ISD::CTPOP, MVT::v16i8, 1},
        {ISD::CTPOP, MVT::i64,   4},
        {ISD::CTPOP, MVT::v2i32, 3},
        {ISD::CTPOP, MVT::v4i16, 2},
        {ISD::CTPOP, MVT::v8i8,  1},
        {ISD::CTPOP, MVT::i32,   5},
    };
    auto LT = getTypeLegalizationCost(RetTy);
    MVT MTy = LT.second;
    if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
      // Extra cost of +1 when illegal vector types are legalized by promoting
      // the integer type.
      int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
                                            RetTy->getScalarSizeInBits()
                          ? 1
                          : 0;
      return LT.first * Entry->Cost + ExtraCost;
    }
    break;
  }
  case Intrinsic::sadd_with_overflow:
  case Intrinsic::uadd_with_overflow:
  case Intrinsic::ssub_with_overflow:
  case Intrinsic::usub_with_overflow:
  case Intrinsic::smul_with_overflow:
  case Intrinsic::umul_with_overflow: {
    static const CostTblEntry WithOverflowCostTbl[] = {
        {Intrinsic::sadd_with_overflow, MVT::i8, 3},
        {Intrinsic::uadd_with_overflow, MVT::i8, 3},
        {Intrinsic::sadd_with_overflow, MVT::i16, 3},
        {Intrinsic::uadd_with_overflow, MVT::i16, 3},
        {Intrinsic::sadd_with_overflow, MVT::i32, 1},
        {Intrinsic::uadd_with_overflow, MVT::i32, 1},
        {Intrinsic::sadd_with_overflow, MVT::i64, 1},
        {Intrinsic::uadd_with_overflow, MVT::i64, 1},
        {Intrinsic::ssub_with_overflow, MVT::i8, 3},
        {Intrinsic::usub_with_overflow, MVT::i8, 3},
        {Intrinsic::ssub_with_overflow, MVT::i16, 3},
        {Intrinsic::usub_with_overflow, MVT::i16, 3},
        {Intrinsic::ssub_with_overflow, MVT::i32, 1},
        {Intrinsic::usub_with_overflow, MVT::i32, 1},
        {Intrinsic::ssub_with_overflow, MVT::i64, 1},
        {Intrinsic::usub_with_overflow, MVT::i64, 1},
        {Intrinsic::smul_with_overflow, MVT::i8, 5},
        {Intrinsic::umul_with_overflow, MVT::i8, 4},
        {Intrinsic::smul_with_overflow, MVT::i16, 5},
        {Intrinsic::umul_with_overflow, MVT::i16, 4},
        {Intrinsic::smul_with_overflow, MVT::i32, 2}, // eg umull;tst
        {Intrinsic::umul_with_overflow, MVT::i32, 2}, // eg umull;cmp sxtw
        {Intrinsic::smul_with_overflow, MVT::i64, 3}, // eg mul;smulh;cmp
        {Intrinsic::umul_with_overflow, MVT::i64, 3}, // eg mul;umulh;cmp asr
    };
    EVT MTy = TLI->getValueType(DL, RetTy->getContainedType(0), true);
    if (MTy.isSimple())
      if (const auto *Entry = CostTableLookup(WithOverflowCostTbl, ICA.getID(),
                                              MTy.getSimpleVT()))
        return Entry->Cost;
    break;
  }
  case Intrinsic::fptosi_sat:
  case Intrinsic::fptoui_sat: {
    if (ICA.getArgTypes().empty())
      break;
    bool IsSigned = ICA.getID() == Intrinsic::fptosi_sat;
    auto LT = getTypeLegalizationCost(ICA.getArgTypes()[0]);
    EVT MTy = TLI->getValueType(DL, RetTy);
    // Check for the legal types, which are where the size of the input and the
    // output are the same, or we are using cvt f64->i32 or f32->i64.
    if ((LT.second == MVT::f32 || LT.second == MVT::f64 ||
         LT.second == MVT::v2f32 || LT.second == MVT::v4f32 ||
         LT.second == MVT::v2f64) &&
        (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
         (LT.second == MVT::f64 && MTy == MVT::i32) ||
         (LT.second == MVT::f32 && MTy == MVT::i64)))
      return LT.first;
    // Similarly for fp16 sizes
    if (ST->hasFullFP16() &&
        ((LT.second == MVT::f16 && MTy == MVT::i32) ||
         ((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
          (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits()))))
      return LT.first;

    // Otherwise we use a legal convert followed by a min+max
    if ((LT.second.getScalarType() == MVT::f32 ||
         LT.second.getScalarType() == MVT::f64 ||
         (ST->hasFullFP16() && LT.second.getScalarType() == MVT::f16)) &&
        LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) {
      Type *LegalTy =
          Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits());
      if (LT.second.isVector())
        LegalTy = VectorType::get(LegalTy, LT.second.getVectorElementCount());
      InstructionCost Cost = 1;
      IntrinsicCostAttributes Attrs1(IsSigned ? Intrinsic::smin : Intrinsic::umin,
                                    LegalTy, {LegalTy, LegalTy});
      Cost += getIntrinsicInstrCost(Attrs1, CostKind);
      IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax : Intrinsic::umax,
                                    LegalTy, {LegalTy, LegalTy});
      Cost += getIntrinsicInstrCost(Attrs2, CostKind);
      return LT.first * Cost;
    }
    break;
  }
  default:
    break;
  }
  return BaseT::getIntrinsicInstrCost(ICA, CostKind);
}

/// The function removes redundant reinterpret (svbool conversion) casts in
/// the presence of control flow.
static std::optional<Instruction *> processPhiNode(InstCombiner &IC,
                                                   IntrinsicInst &II) {
  SmallVector<Instruction *, 32> Worklist;
  auto RequiredType = II.getType();

  auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
  assert(PN && "Expected Phi Node!");

  // Don't create a new Phi unless we can remove the old one.
  if (!PN->hasOneUse())
    return std::nullopt;

  for (Value *IncValPhi : PN->incoming_values()) {
    auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
    if (!Reinterpret ||
        Reinterpret->getIntrinsicID() !=
            Intrinsic::aarch64_sve_convert_to_svbool ||
        RequiredType != Reinterpret->getArgOperand(0)->getType())
      return std::nullopt;
  }

  // Create the new Phi
  LLVMContext &Ctx = PN->getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(PN);
  PHINode *NPN = Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
  Worklist.push_back(PN);

  for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
    auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
    NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
    Worklist.push_back(Reinterpret);
  }

  // Cleanup Phi Node and reinterprets
  return IC.replaceInstUsesWith(II, NPN);
}
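
// For example (illustrative IR, not taken from an existing test), given
//   %a   = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %p0)
//   %b   = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %p1)
//   %phi = phi <vscale x 16 x i1> [ %a, %bb0 ], [ %b, %bb1 ]
//   %r   = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %phi)
// the combine above rewrites %r as a phi over the original predicates:
//   %phi.new = phi <vscale x 4 x i1> [ %p0, %bb0 ], [ %p1, %bb1 ]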

// (from_svbool (binop (to_svbool pred) (svbool_t _) (svbool_t _))))
// => (binop (pred) (from_svbool _) (from_svbool _))
//
// The above transformation eliminates a `to_svbool` in the predicate
// operand of bitwise operation `binop` by narrowing the vector width of
// the operation. For example, it would convert a `<vscale x 16 x i1>
// and` into a `<vscale x 4 x i1> and`. This is profitable because
// to_svbool must zero the new lanes during widening, whereas
// from_svbool is free.
static std::optional<Instruction *>
tryCombineFromSVBoolBinOp(InstCombiner &IC, IntrinsicInst &II) {
  auto BinOp = dyn_cast<IntrinsicInst>(II.getOperand(0));
  if (!BinOp)
    return std::nullopt;

  auto IntrinsicID = BinOp->getIntrinsicID();
  switch (IntrinsicID) {
  case Intrinsic::aarch64_sve_and_z:
  case Intrinsic::aarch64_sve_bic_z:
  case Intrinsic::aarch64_sve_eor_z:
  case Intrinsic::aarch64_sve_nand_z:
  case Intrinsic::aarch64_sve_nor_z:
  case Intrinsic::aarch64_sve_orn_z:
  case Intrinsic::aarch64_sve_orr_z:
    break;
  default:
    return std::nullopt;
  }

  auto BinOpPred = BinOp->getOperand(0);
  auto BinOpOp1 = BinOp->getOperand(1);
  auto BinOpOp2 = BinOp->getOperand(2);

  auto PredIntr = dyn_cast<IntrinsicInst>(BinOpPred);
  if (!PredIntr ||
      PredIntr->getIntrinsicID() != Intrinsic::aarch64_sve_convert_to_svbool)
    return std::nullopt;

  auto PredOp = PredIntr->getOperand(0);
  auto PredOpTy = cast<VectorType>(PredOp->getType());
  if (PredOpTy != II.getType())
    return std::nullopt;

  IRBuilder<> Builder(II.getContext());
  Builder.SetInsertPoint(&II);

  SmallVector<Value *> NarrowedBinOpArgs = {PredOp};
  auto NarrowBinOpOp1 = Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp1});
  NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
  if (BinOpOp1 == BinOpOp2)
    NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
  else
    NarrowedBinOpArgs.push_back(Builder.CreateIntrinsic(
        Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp2}));

  auto NarrowedBinOp =
      Builder.CreateIntrinsic(IntrinsicID, {PredOpTy}, NarrowedBinOpArgs);
  return IC.replaceInstUsesWith(II, NarrowedBinOp);
}

static std::optional<Instruction *>
instCombineConvertFromSVBool(InstCombiner &IC, IntrinsicInst &II) {
  // If the reinterpret instruction operand is a PHI Node
  if (isa<PHINode>(II.getArgOperand(0)))
    return processPhiNode(IC, II);

  if (auto BinOpCombine = tryCombineFromSVBoolBinOp(IC, II))
    return BinOpCombine;

  SmallVector<Instruction *, 32> CandidatesForRemoval;
  Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;

  const auto *IVTy = cast<VectorType>(II.getType());

  // Walk the chain of conversions.
  while (Cursor) {
    // If the type of the cursor has fewer lanes than the final result, zeroing
    // must take place, which breaks the equivalence chain.
    const auto *CursorVTy = cast<VectorType>(Cursor->getType());
    if (CursorVTy->getElementCount().getKnownMinValue() <
        IVTy->getElementCount().getKnownMinValue())
      break;

    // If the cursor has the same type as I, it is a viable replacement.
    if (Cursor->getType() == IVTy)
      EarliestReplacement = Cursor;

    auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);

    // If this is not an SVE conversion intrinsic, this is the end of the chain.
    if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
                                  Intrinsic::aarch64_sve_convert_to_svbool ||
                              IntrinsicCursor->getIntrinsicID() ==
                                  Intrinsic::aarch64_sve_convert_from_svbool))
      break;

    CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
    Cursor = IntrinsicCursor->getOperand(0);
  }

  // If no viable replacement in the conversion chain was found, there is
  // nothing to do.
  if (!EarliestReplacement)
    return std::nullopt;

  return IC.replaceInstUsesWith(II, EarliestReplacement);
}

static std::optional<Instruction *> instCombineSVESel(InstCombiner &IC,
                                                      IntrinsicInst &II) {
  IRBuilder<> Builder(&II);
  auto Select = Builder.CreateSelect(II.getOperand(0), II.getOperand(1),
                                     II.getOperand(2));
  return IC.replaceInstUsesWith(II, Select);
}

static std::optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
                                                      IntrinsicInst &II) {
  IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
  if (!Pg)
    return std::nullopt;

  if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return std::nullopt;

  const auto PTruePattern =
      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
  if (PTruePattern != AArch64SVEPredPattern::vl1)
    return std::nullopt;

  // The intrinsic is inserting into lane zero so use an insert instead.
  auto *IdxTy = Type::getInt64Ty(II.getContext());
  auto *Insert = InsertElementInst::Create(
      II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
  Insert->insertBefore(&II);
  Insert->takeName(&II);

  return IC.replaceInstUsesWith(II, Insert);
}
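
// For example, a dup into lane zero under a ptrue with the vl1 pattern,
// dup(v, ptrue(vl1), s), only updates the first element, so the combine above
// emits a plain insertelement of s into lane 0 of v instead.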

static std::optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
                                                       IntrinsicInst &II) {
  // Replace DupX with a regular IR splat.
  IRBuilder<> Builder(II.getContext());
  Builder.SetInsertPoint(&II);
  auto *RetTy = cast<ScalableVectorType>(II.getType());
  Value *Splat =
      Builder.CreateVectorSplat(RetTy->getElementCount(), II.getArgOperand(0));
  Splat->takeName(&II);
  return IC.replaceInstUsesWith(II, Splat);
}

static std::optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
                                                        IntrinsicInst &II) {
  LLVMContext &Ctx = II.getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(&II);

  // Check that the predicate is all active
  auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
  if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return std::nullopt;

  const auto PTruePattern =
      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
  if (PTruePattern != AArch64SVEPredPattern::all)
    return std::nullopt;

  // Check that we have a compare of zero..
  auto *SplatValue =
      dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
  if (!SplatValue || !SplatValue->isZero())
    return std::nullopt;

  // ..against a dupq
  auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
  if (!DupQLane ||
      DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
    return std::nullopt;

  // Where the dupq is a lane 0 replicate of a vector insert
  if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
    return std::nullopt;

  auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
  if (!VecIns || VecIns->getIntrinsicID() != Intrinsic::vector_insert)
    return std::nullopt;

  // Where the vector insert is a fixed constant vector insert into undef at
  // index zero
  if (!isa<UndefValue>(VecIns->getArgOperand(0)))
    return std::nullopt;

  if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
    return std::nullopt;

  auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
  if (!ConstVec)
    return std::nullopt;

  auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
  auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
  if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
    return std::nullopt;

  unsigned NumElts = VecTy->getNumElements();
  unsigned PredicateBits = 0;

  // Expand intrinsic operands to a 16-bit byte level predicate
  for (unsigned I = 0; I < NumElts; ++I) {
    auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
    if (!Arg)
      return std::nullopt;
    if (!Arg->isZero())
      PredicateBits |= 1 << (I * (16 / NumElts));
  }

  // If all bits are zero bail early with an empty predicate
  if (PredicateBits == 0) {
    auto *PFalse = Constant::getNullValue(II.getType());
    PFalse->takeName(&II);
    return IC.replaceInstUsesWith(II, PFalse);
  }

  // Calculate largest predicate type used (where byte predicate is largest)
  unsigned Mask = 8;
  for (unsigned I = 0; I < 16; ++I)
    if ((PredicateBits & (1 << I)) != 0)
      Mask |= (I % 8);

  unsigned PredSize = Mask & -Mask;
  auto *PredType = ScalableVectorType::get(
      Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));

  // Ensure all relevant bits are set
  for (unsigned I = 0; I < 16; I += PredSize)
    if ((PredicateBits & (1 << I)) == 0)
      return std::nullopt;

  auto *PTruePat =
      ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
  auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
                                        {PredType}, {PTruePat});
  auto *ConvertToSVBool = Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
  auto *ConvertFromSVBool =
      Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                              {II.getType()}, {ConvertToSVBool});

  ConvertFromSVBool->takeName(&II);
  return IC.replaceInstUsesWith(II, ConvertFromSVBool);
}

static std::optional<Instruction *> instCombineSVELast(InstCombiner &IC,
                                                       IntrinsicInst &II) {
  IRBuilder<> Builder(II.getContext());
  Builder.SetInsertPoint(&II);
  Value *Pg = II.getArgOperand(0);
  Value *Vec = II.getArgOperand(1);
  auto IntrinsicID = II.getIntrinsicID();
  bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;

  // lastX(splat(X)) --> X
  if (auto *SplatVal = getSplatValue(Vec))
    return IC.replaceInstUsesWith(II, SplatVal);

  // If x and/or y is a splat value then:
  // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
  Value *LHS, *RHS;
  if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
    if (isSplatValue(LHS) || isSplatValue(RHS)) {
      auto *OldBinOp = cast<BinaryOperator>(Vec);
      auto OpC = OldBinOp->getOpcode();
      auto *NewLHS =
          Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
      auto *NewRHS =
          Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
      auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
          OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
      return IC.replaceInstUsesWith(II, NewBinOp);
    }
  }

  auto *C = dyn_cast<Constant>(Pg);
  if (IsAfter && C && C->isNullValue()) {
    // The intrinsic is extracting lane 0 so use an extract instead.
    auto *IdxTy = Type::getInt64Ty(II.getContext());
    auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
    Extract->insertBefore(&II);
    Extract->takeName(&II);
    return IC.replaceInstUsesWith(II, Extract);
  }

  auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
  if (!IntrPG)
    return std::nullopt;

  if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return std::nullopt;

  const auto PTruePattern =
      cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();

  // Can the intrinsic's predicate be converted to a known constant index?
  unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
  if (!MinNumElts)
    return std::nullopt;

  unsigned Idx = MinNumElts - 1;
  // Increment the index if extracting the element after the last active
  // predicate element.
  if (IsAfter)
    ++Idx;

  // Ignore extracts whose index is larger than the known minimum vector
  // length. NOTE: This is an artificial constraint where we prefer to
  // maintain what the user asked for until an alternative is proven faster.
  auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
  if (Idx >= PgVTy->getMinNumElements())
    return std::nullopt;

  // The intrinsic is extracting a fixed lane so use an extract instead.
  auto *IdxTy = Type::getInt64Ty(II.getContext());
  auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
  Extract->insertBefore(&II);
  Extract->takeName(&II);
  return IC.replaceInstUsesWith(II, Extract);
}
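
// For example, lastb(ptrue(vl4), v) reads the last of the first four lanes
// and becomes extractelement v, 3 above; lasta with the same predicate reads
// the element after it (lane 4), provided the index stays below the known
// minimum vector length.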
89473471bf0Spatrick 
instCombineSVECondLast(InstCombiner & IC,IntrinsicInst & II)895*d415bd75Srobert static std::optional<Instruction *> instCombineSVECondLast(InstCombiner &IC,
896*d415bd75Srobert                                                            IntrinsicInst &II) {
897*d415bd75Srobert   // The SIMD&FP variant of CLAST[AB] is significantly faster than the scalar
898*d415bd75Srobert   // integer variant across a variety of micro-architectures. Replace scalar
899*d415bd75Srobert   // integer CLAST[AB] intrinsic with optimal SIMD&FP variant. A simple
900*d415bd75Srobert   // bitcast-to-fp + clast[ab] + bitcast-to-int will cost a cycle or two more
901*d415bd75Srobert   // depending on the micro-architecture, but has been observed as generally
902*d415bd75Srobert   // being faster, particularly when the CLAST[AB] op is a loop-carried
903*d415bd75Srobert   // dependency.
904*d415bd75Srobert   IRBuilder<> Builder(II.getContext());
905*d415bd75Srobert   Builder.SetInsertPoint(&II);
906*d415bd75Srobert   Value *Pg = II.getArgOperand(0);
907*d415bd75Srobert   Value *Fallback = II.getArgOperand(1);
908*d415bd75Srobert   Value *Vec = II.getArgOperand(2);
909*d415bd75Srobert   Type *Ty = II.getType();
910*d415bd75Srobert 
911*d415bd75Srobert   if (!Ty->isIntegerTy())
912*d415bd75Srobert     return std::nullopt;
913*d415bd75Srobert 
914*d415bd75Srobert   Type *FPTy;
915*d415bd75Srobert   switch (cast<IntegerType>(Ty)->getBitWidth()) {
916*d415bd75Srobert   default:
917*d415bd75Srobert     return std::nullopt;
918*d415bd75Srobert   case 16:
919*d415bd75Srobert     FPTy = Builder.getHalfTy();
920*d415bd75Srobert     break;
921*d415bd75Srobert   case 32:
922*d415bd75Srobert     FPTy = Builder.getFloatTy();
923*d415bd75Srobert     break;
924*d415bd75Srobert   case 64:
925*d415bd75Srobert     FPTy = Builder.getDoubleTy();
926*d415bd75Srobert     break;
927*d415bd75Srobert   }
928*d415bd75Srobert 
929*d415bd75Srobert   Value *FPFallBack = Builder.CreateBitCast(Fallback, FPTy);
930*d415bd75Srobert   auto *FPVTy = VectorType::get(
931*d415bd75Srobert       FPTy, cast<VectorType>(Vec->getType())->getElementCount());
932*d415bd75Srobert   Value *FPVec = Builder.CreateBitCast(Vec, FPVTy);
933*d415bd75Srobert   auto *FPII = Builder.CreateIntrinsic(II.getIntrinsicID(), {FPVec->getType()},
934*d415bd75Srobert                                        {Pg, FPFallBack, FPVec});
935*d415bd75Srobert   Value *FPIItoInt = Builder.CreateBitCast(FPII, II.getType());
936*d415bd75Srobert   return IC.replaceInstUsesWith(II, FPIItoInt);
937*d415bd75Srobert }
938*d415bd75Srobert 
instCombineRDFFR(InstCombiner & IC,IntrinsicInst & II)939*d415bd75Srobert static std::optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
94073471bf0Spatrick                                                      IntrinsicInst &II) {
94173471bf0Spatrick   LLVMContext &Ctx = II.getContext();
94273471bf0Spatrick   IRBuilder<> Builder(Ctx);
94373471bf0Spatrick   Builder.SetInsertPoint(&II);
94473471bf0Spatrick   // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
94573471bf0Spatrick   // can work with RDFFR_PP for ptest elimination.
94673471bf0Spatrick   auto *AllPat =
94773471bf0Spatrick       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
94873471bf0Spatrick   auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
94973471bf0Spatrick                                         {II.getType()}, {AllPat});
95073471bf0Spatrick   auto *RDFFR =
95173471bf0Spatrick       Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
95273471bf0Spatrick   RDFFR->takeName(&II);
95373471bf0Spatrick   return IC.replaceInstUsesWith(II, RDFFR);
95473471bf0Spatrick }
95573471bf0Spatrick 
956*d415bd75Srobert static std::optional<Instruction *>
instCombineSVECntElts(InstCombiner & IC,IntrinsicInst & II,unsigned NumElts)95773471bf0Spatrick instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
95873471bf0Spatrick   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
95973471bf0Spatrick 
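  // With the "all" pattern, cnt[bhwd] returns the full element count, i.e.
  // NumElts scaled by vscale; for example (illustrative):
  //   cntw(all) --> vscale * 4,  cntd(all) --> vscale * 2
  // For other patterns, the call folds to the pattern's fixed element count
  // when that count is guaranteed to be available, as checked below.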
96073471bf0Spatrick   if (Pattern == AArch64SVEPredPattern::all) {
96173471bf0Spatrick     LLVMContext &Ctx = II.getContext();
96273471bf0Spatrick     IRBuilder<> Builder(Ctx);
96373471bf0Spatrick     Builder.SetInsertPoint(&II);
96473471bf0Spatrick 
96573471bf0Spatrick     Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
96673471bf0Spatrick     auto *VScale = Builder.CreateVScale(StepVal);
96773471bf0Spatrick     VScale->takeName(&II);
96873471bf0Spatrick     return IC.replaceInstUsesWith(II, VScale);
96973471bf0Spatrick   }
97073471bf0Spatrick 
971*d415bd75Srobert   unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
97273471bf0Spatrick 
973*d415bd75Srobert   return MinNumElts && NumElts >= MinNumElts
974*d415bd75Srobert              ? std::optional<Instruction *>(IC.replaceInstUsesWith(
97573471bf0Spatrick                    II, ConstantInt::get(II.getType(), MinNumElts)))
976*d415bd75Srobert              : std::nullopt;
97773471bf0Spatrick }
97873471bf0Spatrick 
979*d415bd75Srobert static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
98073471bf0Spatrick                                                         IntrinsicInst &II) {
981*d415bd75Srobert   Value *PgVal = II.getArgOperand(0);
982*d415bd75Srobert   Value *OpVal = II.getArgOperand(1);
98373471bf0Spatrick 
98473471bf0Spatrick   IRBuilder<> Builder(II.getContext());
98573471bf0Spatrick   Builder.SetInsertPoint(&II);
98673471bf0Spatrick 
987*d415bd75Srobert   // PTEST_<FIRST|LAST>(X, X) is equivalent to PTEST_ANY(X, X).
988*d415bd75Srobert   // Later optimizations prefer this form.
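  // (With identical operands the first/last active lane of X is set exactly
  // when any lane of X is set, so the three forms agree here.)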
989*d415bd75Srobert   if (PgVal == OpVal &&
990*d415bd75Srobert       (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_first ||
991*d415bd75Srobert        II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_last)) {
992*d415bd75Srobert     Value *Ops[] = {PgVal, OpVal};
993*d415bd75Srobert     Type *Tys[] = {PgVal->getType()};
994*d415bd75Srobert 
995*d415bd75Srobert     auto *PTest =
996*d415bd75Srobert         Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptest_any, Tys, Ops);
997*d415bd75Srobert     PTest->takeName(&II);
998*d415bd75Srobert 
999*d415bd75Srobert     return IC.replaceInstUsesWith(II, PTest);
1000*d415bd75Srobert   }
1001*d415bd75Srobert 
1002*d415bd75Srobert   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(PgVal);
1003*d415bd75Srobert   IntrinsicInst *Op = dyn_cast<IntrinsicInst>(OpVal);
1004*d415bd75Srobert 
1005*d415bd75Srobert   if (!Pg || !Op)
1006*d415bd75Srobert     return std::nullopt;
1007*d415bd75Srobert 
1008*d415bd75Srobert   Intrinsic::ID OpIID = Op->getIntrinsicID();
1009*d415bd75Srobert 
1010*d415bd75Srobert   if (Pg->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
1011*d415bd75Srobert       OpIID == Intrinsic::aarch64_sve_convert_to_svbool &&
1012*d415bd75Srobert       Pg->getArgOperand(0)->getType() == Op->getArgOperand(0)->getType()) {
1013*d415bd75Srobert     Value *Ops[] = {Pg->getArgOperand(0), Op->getArgOperand(0)};
1014*d415bd75Srobert     Type *Tys[] = {Pg->getArgOperand(0)->getType()};
101573471bf0Spatrick 
101673471bf0Spatrick     auto *PTest = Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
101773471bf0Spatrick 
101873471bf0Spatrick     PTest->takeName(&II);
101973471bf0Spatrick     return IC.replaceInstUsesWith(II, PTest);
102073471bf0Spatrick   }
102173471bf0Spatrick 
1022*d415bd75Srobert   // Transform PTEST_ANY(X=OP(PG,...), X) -> PTEST_ANY(PG, X).
1023*d415bd75Srobert   // Later optimizations may rewrite the sequence to use the flag-setting
1024*d415bd75Srobert   // variant of instruction X to remove the PTEST.
1025*d415bd75Srobert   if ((Pg == Op) && (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_any) &&
1026*d415bd75Srobert       ((OpIID == Intrinsic::aarch64_sve_brka_z) ||
1027*d415bd75Srobert        (OpIID == Intrinsic::aarch64_sve_brkb_z) ||
1028*d415bd75Srobert        (OpIID == Intrinsic::aarch64_sve_brkpa_z) ||
1029*d415bd75Srobert        (OpIID == Intrinsic::aarch64_sve_brkpb_z) ||
1030*d415bd75Srobert        (OpIID == Intrinsic::aarch64_sve_rdffr_z) ||
1031*d415bd75Srobert        (OpIID == Intrinsic::aarch64_sve_and_z) ||
1032*d415bd75Srobert        (OpIID == Intrinsic::aarch64_sve_bic_z) ||
1033*d415bd75Srobert        (OpIID == Intrinsic::aarch64_sve_eor_z) ||
1034*d415bd75Srobert        (OpIID == Intrinsic::aarch64_sve_nand_z) ||
1035*d415bd75Srobert        (OpIID == Intrinsic::aarch64_sve_nor_z) ||
1036*d415bd75Srobert        (OpIID == Intrinsic::aarch64_sve_orn_z) ||
1037*d415bd75Srobert        (OpIID == Intrinsic::aarch64_sve_orr_z))) {
1038*d415bd75Srobert     Value *Ops[] = {Pg->getArgOperand(0), Pg};
1039*d415bd75Srobert     Type *Tys[] = {Pg->getType()};
1040*d415bd75Srobert 
1041*d415bd75Srobert     auto *PTest = Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
1042*d415bd75Srobert     PTest->takeName(&II);
1043*d415bd75Srobert 
1044*d415bd75Srobert     return IC.replaceInstUsesWith(II, PTest);
104573471bf0Spatrick   }
104673471bf0Spatrick 
1047*d415bd75Srobert   return std::nullopt;
1048*d415bd75Srobert }
1049*d415bd75Srobert 
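// Fuse a predicated multiply into a predicated add/sub when both use the same
// predicate, the multiply has no other users and, for FP, contraction is
// allowed. Rough sketch of the MergeIntoAddendOp == true case (illustrative
// operand order only):
//   fadd(%pg, %a, fmul(%pg, %b, %c)) --> fmla(%pg, %a, %b, %c)
// and of the MergeIntoAddendOp == false case:
//   fadd(%pg, fmul(%pg, %b, %c), %a) --> fmad(%pg, %b, %c, %a)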
1050*d415bd75Srobert template <Intrinsic::ID MulOpc, typename Intrinsic::ID FuseOpc>
1051*d415bd75Srobert static std::optional<Instruction *>
1052*d415bd75Srobert instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II,
1053*d415bd75Srobert                                   bool MergeIntoAddendOp) {
1054*d415bd75Srobert   Value *P = II.getOperand(0);
1055*d415bd75Srobert   Value *MulOp0, *MulOp1, *AddendOp, *Mul;
1056*d415bd75Srobert   if (MergeIntoAddendOp) {
1057*d415bd75Srobert     AddendOp = II.getOperand(1);
1058*d415bd75Srobert     Mul = II.getOperand(2);
1059*d415bd75Srobert   } else {
1060*d415bd75Srobert     AddendOp = II.getOperand(2);
1061*d415bd75Srobert     Mul = II.getOperand(1);
1062*d415bd75Srobert   }
1063*d415bd75Srobert 
1064*d415bd75Srobert   if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0),
1065*d415bd75Srobert                                       m_Value(MulOp1))))
1066*d415bd75Srobert     return std::nullopt;
1067*d415bd75Srobert 
1068*d415bd75Srobert   if (!Mul->hasOneUse())
1069*d415bd75Srobert     return std::nullopt;
1070*d415bd75Srobert 
1071*d415bd75Srobert   Instruction *FMFSource = nullptr;
1072*d415bd75Srobert   if (II.getType()->isFPOrFPVectorTy()) {
1073*d415bd75Srobert     llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
1074*d415bd75Srobert     // Stop the combine when the flags on the inputs differ in case dropping
1075*d415bd75Srobert     // flags would lead to us missing out on more beneficial optimizations.
1076*d415bd75Srobert     if (FAddFlags != cast<CallInst>(Mul)->getFastMathFlags())
1077*d415bd75Srobert       return std::nullopt;
1078*d415bd75Srobert     if (!FAddFlags.allowContract())
1079*d415bd75Srobert       return std::nullopt;
1080*d415bd75Srobert     FMFSource = &II;
1081*d415bd75Srobert   }
1082*d415bd75Srobert 
1083*d415bd75Srobert   IRBuilder<> Builder(II.getContext());
1084*d415bd75Srobert   Builder.SetInsertPoint(&II);
1085*d415bd75Srobert 
1086*d415bd75Srobert   CallInst *Res;
1087*d415bd75Srobert   if (MergeIntoAddendOp)
1088*d415bd75Srobert     Res = Builder.CreateIntrinsic(FuseOpc, {II.getType()},
1089*d415bd75Srobert                                   {P, AddendOp, MulOp0, MulOp1}, FMFSource);
1090*d415bd75Srobert   else
1091*d415bd75Srobert     Res = Builder.CreateIntrinsic(FuseOpc, {II.getType()},
1092*d415bd75Srobert                                   {P, MulOp0, MulOp1, AddendOp}, FMFSource);
1093*d415bd75Srobert 
1094*d415bd75Srobert   return IC.replaceInstUsesWith(II, Res);
1095*d415bd75Srobert }
1096*d415bd75Srobert 
1097*d415bd75Srobert static bool isAllActivePredicate(Value *Pred) {
1098*d415bd75Srobert   // Look through a convert.from.svbool(convert.to.svbool(...)) chain.
1099*d415bd75Srobert   Value *UncastedPred;
1100*d415bd75Srobert   if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
1101*d415bd75Srobert                       m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
1102*d415bd75Srobert                           m_Value(UncastedPred)))))
1103*d415bd75Srobert     // If the predicate has the same or fewer lanes than the uncasted
1104*d415bd75Srobert     // predicate, then we know the casting has no effect.
1105*d415bd75Srobert     if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
1106*d415bd75Srobert         cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
1107*d415bd75Srobert       Pred = UncastedPred;
1108*d415bd75Srobert 
1109*d415bd75Srobert   return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1110*d415bd75Srobert                          m_ConstantInt<AArch64SVEPredPattern::all>()));
1111*d415bd75Srobert }
1112*d415bd75Srobert 
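// Lower an sve.ld1 to a plain vector load when the predicate is known to be
// all-active, and to a generic masked.load otherwise. Rough sketch
// (illustrative):
//   ld1(%pg_all, %ptr) --> load <vscale x N x Ty>, ptr %ptr
//   ld1(%pg, %ptr)     --> masked.load(%ptr, align, %pg, zeroinitializer)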
1113*d415bd75Srobert static std::optional<Instruction *>
1114*d415bd75Srobert instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1115*d415bd75Srobert   IRBuilder<> Builder(II.getContext());
1116*d415bd75Srobert   Builder.SetInsertPoint(&II);
1117*d415bd75Srobert 
1118*d415bd75Srobert   Value *Pred = II.getOperand(0);
1119*d415bd75Srobert   Value *PtrOp = II.getOperand(1);
1120*d415bd75Srobert   Type *VecTy = II.getType();
1121*d415bd75Srobert   Value *VecPtr = Builder.CreateBitCast(PtrOp, VecTy->getPointerTo());
1122*d415bd75Srobert 
1123*d415bd75Srobert   if (isAllActivePredicate(Pred)) {
1124*d415bd75Srobert     LoadInst *Load = Builder.CreateLoad(VecTy, VecPtr);
1125*d415bd75Srobert     Load->copyMetadata(II);
1126*d415bd75Srobert     return IC.replaceInstUsesWith(II, Load);
1127*d415bd75Srobert   }
1128*d415bd75Srobert 
1129*d415bd75Srobert   CallInst *MaskedLoad =
1130*d415bd75Srobert       Builder.CreateMaskedLoad(VecTy, VecPtr, PtrOp->getPointerAlignment(DL),
1131*d415bd75Srobert                                Pred, ConstantAggregateZero::get(VecTy));
1132*d415bd75Srobert   MaskedLoad->copyMetadata(II);
1133*d415bd75Srobert   return IC.replaceInstUsesWith(II, MaskedLoad);
1134*d415bd75Srobert }
1135*d415bd75Srobert 
1136*d415bd75Srobert static std::optional<Instruction *>
1137*d415bd75Srobert instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1138*d415bd75Srobert   IRBuilder<> Builder(II.getContext());
1139*d415bd75Srobert   Builder.SetInsertPoint(&II);
1140*d415bd75Srobert 
1141*d415bd75Srobert   Value *VecOp = II.getOperand(0);
1142*d415bd75Srobert   Value *Pred = II.getOperand(1);
1143*d415bd75Srobert   Value *PtrOp = II.getOperand(2);
1144*d415bd75Srobert   Value *VecPtr =
1145*d415bd75Srobert       Builder.CreateBitCast(PtrOp, VecOp->getType()->getPointerTo());
1146*d415bd75Srobert 
1147*d415bd75Srobert   if (isAllActivePredicate(Pred)) {
1148*d415bd75Srobert     StoreInst *Store = Builder.CreateStore(VecOp, VecPtr);
1149*d415bd75Srobert     Store->copyMetadata(II);
1150*d415bd75Srobert     return IC.eraseInstFromFunction(II);
1151*d415bd75Srobert   }
1152*d415bd75Srobert 
1153*d415bd75Srobert   CallInst *MaskedStore = Builder.CreateMaskedStore(
1154*d415bd75Srobert       VecOp, VecPtr, PtrOp->getPointerAlignment(DL), Pred);
1155*d415bd75Srobert   MaskedStore->copyMetadata(II);
1156*d415bd75Srobert   return IC.eraseInstFromFunction(II);
1157*d415bd75Srobert }
1158*d415bd75Srobert 
1159*d415bd75Srobert static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
1160*d415bd75Srobert   switch (Intrinsic) {
1161*d415bd75Srobert   case Intrinsic::aarch64_sve_fmul:
1162*d415bd75Srobert     return Instruction::BinaryOps::FMul;
1163*d415bd75Srobert   case Intrinsic::aarch64_sve_fadd:
1164*d415bd75Srobert     return Instruction::BinaryOps::FAdd;
1165*d415bd75Srobert   case Intrinsic::aarch64_sve_fsub:
1166*d415bd75Srobert     return Instruction::BinaryOps::FSub;
1167*d415bd75Srobert   default:
1168*d415bd75Srobert     return Instruction::BinaryOpsEnd;
1169*d415bd75Srobert   }
1170*d415bd75Srobert }
1171*d415bd75Srobert 
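// When an SVE mul/fmul/add/fadd/sub/fsub is predicated by ptrue(all), every
// lane is active, so it can be replaced by the equivalent plain IR binary
// operator. Sketch (illustrative): fadd(ptrue(all), %a, %b) --> fadd %a, %b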
1172*d415bd75Srobert static std::optional<Instruction *>
1173*d415bd75Srobert instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) {
1174*d415bd75Srobert   auto *OpPredicate = II.getOperand(0);
1175*d415bd75Srobert   auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
1176*d415bd75Srobert   if (BinOpCode == Instruction::BinaryOpsEnd ||
1177*d415bd75Srobert       !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1178*d415bd75Srobert                               m_ConstantInt<AArch64SVEPredPattern::all>())))
1179*d415bd75Srobert     return std::nullopt;
1180*d415bd75Srobert   IRBuilder<> Builder(II.getContext());
1181*d415bd75Srobert   Builder.SetInsertPoint(&II);
1182*d415bd75Srobert   Builder.setFastMathFlags(II.getFastMathFlags());
1183*d415bd75Srobert   auto BinOp =
1184*d415bd75Srobert       Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
1185*d415bd75Srobert   return IC.replaceInstUsesWith(II, BinOp);
1186*d415bd75Srobert }
1187*d415bd75Srobert 
1188*d415bd75Srobert static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
1189*d415bd75Srobert                                                             IntrinsicInst &II) {
1190*d415bd75Srobert   if (auto FMLA =
1191*d415bd75Srobert           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1192*d415bd75Srobert                                             Intrinsic::aarch64_sve_fmla>(IC, II,
1193*d415bd75Srobert                                                                          true))
1194*d415bd75Srobert     return FMLA;
1195*d415bd75Srobert   if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1196*d415bd75Srobert                                                    Intrinsic::aarch64_sve_mla>(
1197*d415bd75Srobert           IC, II, true))
1198*d415bd75Srobert     return MLA;
1199*d415bd75Srobert   if (auto FMAD =
1200*d415bd75Srobert           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1201*d415bd75Srobert                                             Intrinsic::aarch64_sve_fmad>(IC, II,
1202*d415bd75Srobert                                                                          false))
1203*d415bd75Srobert     return FMAD;
1204*d415bd75Srobert   if (auto MAD = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1205*d415bd75Srobert                                                    Intrinsic::aarch64_sve_mad>(
1206*d415bd75Srobert           IC, II, false))
1207*d415bd75Srobert     return MAD;
1208*d415bd75Srobert   return instCombineSVEVectorBinOp(IC, II);
1209*d415bd75Srobert }
1210*d415bd75Srobert 
1211*d415bd75Srobert static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC,
1212*d415bd75Srobert                                                             IntrinsicInst &II) {
1213*d415bd75Srobert   if (auto FMLS =
1214*d415bd75Srobert           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1215*d415bd75Srobert                                             Intrinsic::aarch64_sve_fmls>(IC, II,
1216*d415bd75Srobert                                                                          true))
1217*d415bd75Srobert     return FMLS;
1218*d415bd75Srobert   if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1219*d415bd75Srobert                                                    Intrinsic::aarch64_sve_mls>(
1220*d415bd75Srobert           IC, II, true))
1221*d415bd75Srobert     return MLS;
1222*d415bd75Srobert   if (auto FMSB =
1223*d415bd75Srobert           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1224*d415bd75Srobert                                             Intrinsic::aarch64_sve_fnmsb>(
1225*d415bd75Srobert               IC, II, false))
1226*d415bd75Srobert     return FMSB;
1227*d415bd75Srobert   return instCombineSVEVectorBinOp(IC, II);
1228*d415bd75Srobert }
1229*d415bd75Srobert 
1230*d415bd75Srobert static std::optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
123173471bf0Spatrick                                                             IntrinsicInst &II) {
123273471bf0Spatrick   auto *OpPredicate = II.getOperand(0);
123373471bf0Spatrick   auto *OpMultiplicand = II.getOperand(1);
123473471bf0Spatrick   auto *OpMultiplier = II.getOperand(2);
123573471bf0Spatrick 
123673471bf0Spatrick   IRBuilder<> Builder(II.getContext());
123773471bf0Spatrick   Builder.SetInsertPoint(&II);
123873471bf0Spatrick 
1239*d415bd75Srobert   // Return true if a given instruction is a unit splat value, false otherwise.
1240*d415bd75Srobert   auto IsUnitSplat = [](auto *I) {
1241*d415bd75Srobert     auto *SplatValue = getSplatValue(I);
1242*d415bd75Srobert     if (!SplatValue)
124373471bf0Spatrick       return false;
124473471bf0Spatrick     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
124573471bf0Spatrick   };
124673471bf0Spatrick 
124773471bf0Spatrick   // Return true if a given instruction is an aarch64_sve_dup intrinsic call
124873471bf0Spatrick   // with a unit splat value, false otherwise.
124973471bf0Spatrick   auto IsUnitDup = [](auto *I) {
125073471bf0Spatrick     auto *IntrI = dyn_cast<IntrinsicInst>(I);
125173471bf0Spatrick     if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
125273471bf0Spatrick       return false;
125373471bf0Spatrick 
125473471bf0Spatrick     auto *SplatValue = IntrI->getOperand(2);
125573471bf0Spatrick     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
125673471bf0Spatrick   };
125773471bf0Spatrick 
1258*d415bd75Srobert   if (IsUnitSplat(OpMultiplier)) {
1259*d415bd75Srobert     // [f]mul pg %n, (dupx 1) => %n
126073471bf0Spatrick     OpMultiplicand->takeName(&II);
126173471bf0Spatrick     return IC.replaceInstUsesWith(II, OpMultiplicand);
126273471bf0Spatrick   } else if (IsUnitDup(OpMultiplier)) {
1263*d415bd75Srobert     // [f]mul pg %n, (dup pg 1) => %n
126473471bf0Spatrick     auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
126573471bf0Spatrick     auto *DupPg = DupInst->getOperand(1);
126673471bf0Spatrick     // TODO: this is naive. The optimization is still valid if DupPg
126773471bf0Spatrick     // 'encompasses' OpPredicate, not only if they're the same predicate.
126873471bf0Spatrick     if (OpPredicate == DupPg) {
126973471bf0Spatrick       OpMultiplicand->takeName(&II);
127073471bf0Spatrick       return IC.replaceInstUsesWith(II, OpMultiplicand);
127173471bf0Spatrick     }
127273471bf0Spatrick   }
127373471bf0Spatrick 
1274*d415bd75Srobert   return instCombineSVEVectorBinOp(IC, II);
127573471bf0Spatrick }
127673471bf0Spatrick 
1277*d415bd75Srobert static std::optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
1278*d415bd75Srobert                                                          IntrinsicInst &II) {
1279*d415bd75Srobert   IRBuilder<> Builder(II.getContext());
1280*d415bd75Srobert   Builder.SetInsertPoint(&II);
1281*d415bd75Srobert   Value *UnpackArg = II.getArgOperand(0);
1282*d415bd75Srobert   auto *RetTy = cast<ScalableVectorType>(II.getType());
1283*d415bd75Srobert   bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
1284*d415bd75Srobert                   II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
1285*d415bd75Srobert 
1286*d415bd75Srobert   // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
1287*d415bd75Srobert   // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
1288*d415bd75Srobert   if (auto *ScalarArg = getSplatValue(UnpackArg)) {
1289*d415bd75Srobert     ScalarArg =
1290*d415bd75Srobert         Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
1291*d415bd75Srobert     Value *NewVal =
1292*d415bd75Srobert         Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
1293*d415bd75Srobert     NewVal->takeName(&II);
1294*d415bd75Srobert     return IC.replaceInstUsesWith(II, NewVal);
1295*d415bd75Srobert   }
1296*d415bd75Srobert 
1297*d415bd75Srobert   return std::nullopt;
1298*d415bd75Srobert }
1299*d415bd75Srobert static std::optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
130073471bf0Spatrick                                                       IntrinsicInst &II) {
130173471bf0Spatrick   auto *OpVal = II.getOperand(0);
130273471bf0Spatrick   auto *OpIndices = II.getOperand(1);
130373471bf0Spatrick   VectorType *VTy = cast<VectorType>(II.getType());
130473471bf0Spatrick 
1305*d415bd75Srobert   // Check whether OpIndices is a constant splat value smaller than the
1306*d415bd75Srobert   // minimum element count of the result.
1307*d415bd75Srobert   auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
130873471bf0Spatrick   if (!SplatValue ||
130973471bf0Spatrick       SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
1310*d415bd75Srobert     return std::nullopt;
131173471bf0Spatrick 
131273471bf0Spatrick   // Convert sve_tbl(OpVal, sve_dup_x(SplatValue)) to
131373471bf0Spatrick   // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
131473471bf0Spatrick   IRBuilder<> Builder(II.getContext());
131573471bf0Spatrick   Builder.SetInsertPoint(&II);
131673471bf0Spatrick   auto *Extract = Builder.CreateExtractElement(OpVal, SplatValue);
131773471bf0Spatrick   auto *VectorSplat =
131873471bf0Spatrick       Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
131973471bf0Spatrick 
132073471bf0Spatrick   VectorSplat->takeName(&II);
132173471bf0Spatrick   return IC.replaceInstUsesWith(II, VectorSplat);
132273471bf0Spatrick }
132373471bf0Spatrick 
1324*d415bd75Srobert static std::optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
1325*d415bd75Srobert                                                       IntrinsicInst &II) {
1326*d415bd75Srobert   // zip1(uzp1(A, B), uzp2(A, B)) --> A
1327*d415bd75Srobert   // zip2(uzp1(A, B), uzp2(A, B)) --> B
1328*d415bd75Srobert   Value *A, *B;
1329*d415bd75Srobert   if (match(II.getArgOperand(0),
1330*d415bd75Srobert             m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
1331*d415bd75Srobert       match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
1332*d415bd75Srobert                                      m_Specific(A), m_Specific(B))))
1333*d415bd75Srobert     return IC.replaceInstUsesWith(
1334*d415bd75Srobert         II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
1335*d415bd75Srobert 
1336*d415bd75Srobert   return std::nullopt;
1337*d415bd75Srobert }
1338*d415bd75Srobert 
1339*d415bd75Srobert static std::optional<Instruction *>
1340*d415bd75Srobert instCombineLD1GatherIndex(InstCombiner &IC, IntrinsicInst &II) {
1341*d415bd75Srobert   Value *Mask = II.getOperand(0);
1342*d415bd75Srobert   Value *BasePtr = II.getOperand(1);
1343*d415bd75Srobert   Value *Index = II.getOperand(2);
1344*d415bd75Srobert   Type *Ty = II.getType();
1345*d415bd75Srobert   Value *PassThru = ConstantAggregateZero::get(Ty);
1346*d415bd75Srobert 
1347*d415bd75Srobert   // Contiguous gather => masked load.
1348*d415bd75Srobert   // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
1349*d415bd75Srobert   // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
1350*d415bd75Srobert   Value *IndexBase;
1351*d415bd75Srobert   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1352*d415bd75Srobert                        m_Value(IndexBase), m_SpecificInt(1)))) {
1353*d415bd75Srobert     IRBuilder<> Builder(II.getContext());
1354*d415bd75Srobert     Builder.SetInsertPoint(&II);
1355*d415bd75Srobert 
1356*d415bd75Srobert     Align Alignment =
1357*d415bd75Srobert         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1358*d415bd75Srobert 
1359*d415bd75Srobert     Type *VecPtrTy = PointerType::getUnqual(Ty);
1360*d415bd75Srobert     Value *Ptr = Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
1361*d415bd75Srobert                                    BasePtr, IndexBase);
1362*d415bd75Srobert     Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
1363*d415bd75Srobert     CallInst *MaskedLoad =
1364*d415bd75Srobert         Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
1365*d415bd75Srobert     MaskedLoad->takeName(&II);
1366*d415bd75Srobert     return IC.replaceInstUsesWith(II, MaskedLoad);
1367*d415bd75Srobert   }
1368*d415bd75Srobert 
1369*d415bd75Srobert   return std::nullopt;
1370*d415bd75Srobert }
1371*d415bd75Srobert 
1372*d415bd75Srobert static std::optional<Instruction *>
1373*d415bd75Srobert instCombineST1ScatterIndex(InstCombiner &IC, IntrinsicInst &II) {
1374*d415bd75Srobert   Value *Val = II.getOperand(0);
1375*d415bd75Srobert   Value *Mask = II.getOperand(1);
1376*d415bd75Srobert   Value *BasePtr = II.getOperand(2);
1377*d415bd75Srobert   Value *Index = II.getOperand(3);
1378*d415bd75Srobert   Type *Ty = Val->getType();
1379*d415bd75Srobert 
1380*d415bd75Srobert   // Contiguous scatter => masked store.
1381*d415bd75Srobert   // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
1382*d415bd75Srobert   // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
1383*d415bd75Srobert   Value *IndexBase;
1384*d415bd75Srobert   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1385*d415bd75Srobert                        m_Value(IndexBase), m_SpecificInt(1)))) {
1386*d415bd75Srobert     IRBuilder<> Builder(II.getContext());
1387*d415bd75Srobert     Builder.SetInsertPoint(&II);
1388*d415bd75Srobert 
1389*d415bd75Srobert     Align Alignment =
1390*d415bd75Srobert         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1391*d415bd75Srobert 
1392*d415bd75Srobert     Value *Ptr = Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
1393*d415bd75Srobert                                    BasePtr, IndexBase);
1394*d415bd75Srobert     Type *VecPtrTy = PointerType::getUnqual(Ty);
1395*d415bd75Srobert     Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
1396*d415bd75Srobert 
1397*d415bd75Srobert     (void)Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
1398*d415bd75Srobert 
1399*d415bd75Srobert     return IC.eraseInstFromFunction(II);
1400*d415bd75Srobert   }
1401*d415bd75Srobert 
1402*d415bd75Srobert   return std::nullopt;
1403*d415bd75Srobert }
1404*d415bd75Srobert 
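// Replace a predicated signed divide by a (possibly negated) power-of-two
// splat with an arithmetic shift. Rough sketch (illustrative constants):
//   sdiv(%pg, %x, splat(8))  --> asrd(%pg, %x, 3)
//   sdiv(%pg, %x, splat(-8)) --> neg(%pg, asrd(%pg, %x, 3))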
1405*d415bd75Srobert static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
1406*d415bd75Srobert                                                        IntrinsicInst &II) {
1407*d415bd75Srobert   IRBuilder<> Builder(II.getContext());
1408*d415bd75Srobert   Builder.SetInsertPoint(&II);
1409*d415bd75Srobert   Type *Int32Ty = Builder.getInt32Ty();
1410*d415bd75Srobert   Value *Pred = II.getOperand(0);
1411*d415bd75Srobert   Value *Vec = II.getOperand(1);
1412*d415bd75Srobert   Value *DivVec = II.getOperand(2);
1413*d415bd75Srobert 
1414*d415bd75Srobert   Value *SplatValue = getSplatValue(DivVec);
1415*d415bd75Srobert   ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
1416*d415bd75Srobert   if (!SplatConstantInt)
1417*d415bd75Srobert     return std::nullopt;
1418*d415bd75Srobert   APInt Divisor = SplatConstantInt->getValue();
1419*d415bd75Srobert 
1420*d415bd75Srobert   if (Divisor.isPowerOf2()) {
1421*d415bd75Srobert     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1422*d415bd75Srobert     auto ASRD = Builder.CreateIntrinsic(
1423*d415bd75Srobert         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1424*d415bd75Srobert     return IC.replaceInstUsesWith(II, ASRD);
1425*d415bd75Srobert   }
1426*d415bd75Srobert   if (Divisor.isNegatedPowerOf2()) {
1427*d415bd75Srobert     Divisor.negate();
1428*d415bd75Srobert     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1429*d415bd75Srobert     auto ASRD = Builder.CreateIntrinsic(
1430*d415bd75Srobert         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1431*d415bd75Srobert     auto NEG = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_neg,
1432*d415bd75Srobert                                        {ASRD->getType()}, {ASRD, Pred, ASRD});
1433*d415bd75Srobert     return IC.replaceInstUsesWith(II, NEG);
1434*d415bd75Srobert   }
1435*d415bd75Srobert 
1436*d415bd75Srobert   return std::nullopt;
1437*d415bd75Srobert }
1438*d415bd75Srobert 
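// Returns true if Vec is a repeating pattern whose length can be halved, and
// shrinks Vec (recursively) to the repeated prefix. Rough sketch
// (illustrative):
//   (A, B, A, B) --> (A, B); with AllowPoison, (A, poison, A, B) --> (A, B)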
1439*d415bd75Srobert bool SimplifyValuePattern(SmallVector<Value *> &Vec, bool AllowPoison) {
1440*d415bd75Srobert   size_t VecSize = Vec.size();
1441*d415bd75Srobert   if (VecSize == 1)
1442*d415bd75Srobert     return true;
1443*d415bd75Srobert   if (!isPowerOf2_64(VecSize))
1444*d415bd75Srobert     return false;
1445*d415bd75Srobert   size_t HalfVecSize = VecSize / 2;
1446*d415bd75Srobert 
1447*d415bd75Srobert   for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize;
1448*d415bd75Srobert        RHS != Vec.end(); LHS++, RHS++) {
1449*d415bd75Srobert     if (*LHS != nullptr && *RHS != nullptr) {
1450*d415bd75Srobert       if (*LHS == *RHS)
1451*d415bd75Srobert         continue;
1452*d415bd75Srobert       else
1453*d415bd75Srobert         return false;
1454*d415bd75Srobert     }
1455*d415bd75Srobert     if (!AllowPoison)
1456*d415bd75Srobert       return false;
1457*d415bd75Srobert     if (*LHS == nullptr && *RHS != nullptr)
1458*d415bd75Srobert       *LHS = *RHS;
1459*d415bd75Srobert   }
1460*d415bd75Srobert 
1461*d415bd75Srobert   Vec.resize(HalfVecSize);
1462*d415bd75Srobert   SimplifyValuePattern(Vec, AllowPoison);
1463*d415bd75Srobert   return true;
1464*d415bd75Srobert }
1465*d415bd75Srobert 
1466*d415bd75Srobert // Try to simplify dupqlane patterns like dupqlane(f32 A, f32 B, f32 A, f32 B)
1467*d415bd75Srobert // to dupqlane(f64(C)), where C is A concatenated with B.
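// In rough terms (a descriptive sketch of the code below): the simplified
// (A, B) prefix is inserted into a scalable vector, reinterpreted as a vector
// of wider integer elements, splatted from lane 0 via a zero shufflevector
// mask, and bitcast back to the original element type.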
1468*d415bd75Srobert static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC,
1469*d415bd75Srobert                                                            IntrinsicInst &II) {
1470*d415bd75Srobert   Value *CurrentInsertElt = nullptr, *Default = nullptr;
1471*d415bd75Srobert   if (!match(II.getOperand(0),
1472*d415bd75Srobert              m_Intrinsic<Intrinsic::vector_insert>(
1473*d415bd75Srobert                  m_Value(Default), m_Value(CurrentInsertElt), m_Value())) ||
1474*d415bd75Srobert       !isa<FixedVectorType>(CurrentInsertElt->getType()))
1475*d415bd75Srobert     return std::nullopt;
1476*d415bd75Srobert   auto IIScalableTy = cast<ScalableVectorType>(II.getType());
1477*d415bd75Srobert 
1478*d415bd75Srobert   // Insert the scalars into a container ordered by InsertElement index
1479*d415bd75Srobert   SmallVector<Value *> Elts(IIScalableTy->getMinNumElements(), nullptr);
1480*d415bd75Srobert   while (auto InsertElt = dyn_cast<InsertElementInst>(CurrentInsertElt)) {
1481*d415bd75Srobert     auto Idx = cast<ConstantInt>(InsertElt->getOperand(2));
1482*d415bd75Srobert     Elts[Idx->getValue().getZExtValue()] = InsertElt->getOperand(1);
1483*d415bd75Srobert     CurrentInsertElt = InsertElt->getOperand(0);
1484*d415bd75Srobert   }
1485*d415bd75Srobert 
1486*d415bd75Srobert   bool AllowPoison =
1487*d415bd75Srobert       isa<PoisonValue>(CurrentInsertElt) && isa<PoisonValue>(Default);
1488*d415bd75Srobert   if (!SimplifyValuePattern(Elts, AllowPoison))
1489*d415bd75Srobert     return std::nullopt;
1490*d415bd75Srobert 
1491*d415bd75Srobert   // Rebuild the simplified chain of InsertElements, e.g. (a, b, a, b) as (a, b).
1492*d415bd75Srobert   IRBuilder<> Builder(II.getContext());
1493*d415bd75Srobert   Builder.SetInsertPoint(&II);
1494*d415bd75Srobert   Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType());
1495*d415bd75Srobert   for (size_t I = 0; I < Elts.size(); I++) {
1496*d415bd75Srobert     if (Elts[I] == nullptr)
1497*d415bd75Srobert       continue;
1498*d415bd75Srobert     InsertEltChain = Builder.CreateInsertElement(InsertEltChain, Elts[I],
1499*d415bd75Srobert                                                  Builder.getInt64(I));
1500*d415bd75Srobert   }
1501*d415bd75Srobert   if (InsertEltChain == nullptr)
1502*d415bd75Srobert     return std::nullopt;
1503*d415bd75Srobert 
1504*d415bd75Srobert   // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64
1505*d415bd75Srobert   // value or (f16 a, f16 b) as one i32 value. This requires the InsertSubvector
1506*d415bd75Srobert   // to be bitcast to a type wide enough to fit the sequence, splatted, and then
1507*d415bd75Srobert   // narrowed back to the original type.
1508*d415bd75Srobert   unsigned PatternWidth = IIScalableTy->getScalarSizeInBits() * Elts.size();
1509*d415bd75Srobert   unsigned PatternElementCount = IIScalableTy->getScalarSizeInBits() *
1510*d415bd75Srobert                                  IIScalableTy->getMinNumElements() /
1511*d415bd75Srobert                                  PatternWidth;
1512*d415bd75Srobert 
1513*d415bd75Srobert   IntegerType *WideTy = Builder.getIntNTy(PatternWidth);
1514*d415bd75Srobert   auto *WideScalableTy = ScalableVectorType::get(WideTy, PatternElementCount);
1515*d415bd75Srobert   auto *WideShuffleMaskTy =
1516*d415bd75Srobert       ScalableVectorType::get(Builder.getInt32Ty(), PatternElementCount);
1517*d415bd75Srobert 
1518*d415bd75Srobert   auto ZeroIdx = ConstantInt::get(Builder.getInt64Ty(), APInt(64, 0));
1519*d415bd75Srobert   auto InsertSubvector = Builder.CreateInsertVector(
1520*d415bd75Srobert       II.getType(), PoisonValue::get(II.getType()), InsertEltChain, ZeroIdx);
1521*d415bd75Srobert   auto WideBitcast =
1522*d415bd75Srobert       Builder.CreateBitOrPointerCast(InsertSubvector, WideScalableTy);
1523*d415bd75Srobert   auto WideShuffleMask = ConstantAggregateZero::get(WideShuffleMaskTy);
1524*d415bd75Srobert   auto WideShuffle = Builder.CreateShuffleVector(
1525*d415bd75Srobert       WideBitcast, PoisonValue::get(WideScalableTy), WideShuffleMask);
1526*d415bd75Srobert   auto NarrowBitcast =
1527*d415bd75Srobert       Builder.CreateBitOrPointerCast(WideShuffle, II.getType());
1528*d415bd75Srobert 
1529*d415bd75Srobert   return IC.replaceInstUsesWith(II, NarrowBitcast);
1530*d415bd75Srobert }
1531*d415bd75Srobert 
1532*d415bd75Srobert static std::optional<Instruction *> instCombineMaxMinNM(InstCombiner &IC,
1533*d415bd75Srobert                                                         IntrinsicInst &II) {
1534*d415bd75Srobert   Value *A = II.getArgOperand(0);
1535*d415bd75Srobert   Value *B = II.getArgOperand(1);
1536*d415bd75Srobert   if (A == B)
1537*d415bd75Srobert     return IC.replaceInstUsesWith(II, A);
1538*d415bd75Srobert 
1539*d415bd75Srobert   return std::nullopt;
1540*d415bd75Srobert }
1541*d415bd75Srobert 
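// SRSHL of a known non-negative value by a non-negative shift behaves like a
// plain LSL, since no rounding can occur. Rough sketch (illustrative):
//   srshl(%pg, %abs_x, splat(2)) --> lsl(%pg, %abs_x, splat(2))
// where %abs_x is the result of an abs/sqabs intrinsic matched below.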
1542*d415bd75Srobert static std::optional<Instruction *> instCombineSVESrshl(InstCombiner &IC,
1543*d415bd75Srobert                                                         IntrinsicInst &II) {
1544*d415bd75Srobert   IRBuilder<> Builder(&II);
1545*d415bd75Srobert   Value *Pred = II.getOperand(0);
1546*d415bd75Srobert   Value *Vec = II.getOperand(1);
1547*d415bd75Srobert   Value *Shift = II.getOperand(2);
1548*d415bd75Srobert 
1549*d415bd75Srobert   // Convert SRSHL into the simpler LSL intrinsic when fed by an ABS intrinsic.
1550*d415bd75Srobert   Value *AbsPred, *MergedValue;
1551*d415bd75Srobert   if (!match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_sqabs>(
1552*d415bd75Srobert                       m_Value(MergedValue), m_Value(AbsPred), m_Value())) &&
1553*d415bd75Srobert       !match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_abs>(
1554*d415bd75Srobert                       m_Value(MergedValue), m_Value(AbsPred), m_Value())))
1555*d415bd75Srobert 
1556*d415bd75Srobert     return std::nullopt;
1557*d415bd75Srobert 
1558*d415bd75Srobert   // Transform is valid if any of the following are true:
1559*d415bd75Srobert   // * The ABS merge value is undef or non-negative
1560*d415bd75Srobert   // * The ABS predicate is all active
1561*d415bd75Srobert   // * The ABS predicate and the SRSHL predicates are the same
1562*d415bd75Srobert   if (!isa<UndefValue>(MergedValue) && !match(MergedValue, m_NonNegative()) &&
1563*d415bd75Srobert       AbsPred != Pred && !isAllActivePredicate(AbsPred))
1564*d415bd75Srobert     return std::nullopt;
1565*d415bd75Srobert 
1566*d415bd75Srobert   // Only valid when the shift amount is non-negative, otherwise the rounding
1567*d415bd75Srobert   // behaviour of SRSHL cannot be ignored.
1568*d415bd75Srobert   if (!match(Shift, m_NonNegative()))
1569*d415bd75Srobert     return std::nullopt;
1570*d415bd75Srobert 
1571*d415bd75Srobert   auto LSL = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_lsl, {II.getType()},
1572*d415bd75Srobert                                      {Pred, Vec, Shift});
1573*d415bd75Srobert 
1574*d415bd75Srobert   return IC.replaceInstUsesWith(II, LSL);
1575*d415bd75Srobert }
1576*d415bd75Srobert 
1577*d415bd75Srobert std::optional<Instruction *>
157873471bf0Spatrick AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
157973471bf0Spatrick                                      IntrinsicInst &II) const {
158073471bf0Spatrick   Intrinsic::ID IID = II.getIntrinsicID();
158173471bf0Spatrick   switch (IID) {
158273471bf0Spatrick   default:
158373471bf0Spatrick     break;
1584*d415bd75Srobert   case Intrinsic::aarch64_neon_fmaxnm:
1585*d415bd75Srobert   case Intrinsic::aarch64_neon_fminnm:
1586*d415bd75Srobert     return instCombineMaxMinNM(IC, II);
158773471bf0Spatrick   case Intrinsic::aarch64_sve_convert_from_svbool:
158873471bf0Spatrick     return instCombineConvertFromSVBool(IC, II);
158973471bf0Spatrick   case Intrinsic::aarch64_sve_dup:
159073471bf0Spatrick     return instCombineSVEDup(IC, II);
1591*d415bd75Srobert   case Intrinsic::aarch64_sve_dup_x:
1592*d415bd75Srobert     return instCombineSVEDupX(IC, II);
159373471bf0Spatrick   case Intrinsic::aarch64_sve_cmpne:
159473471bf0Spatrick   case Intrinsic::aarch64_sve_cmpne_wide:
159573471bf0Spatrick     return instCombineSVECmpNE(IC, II);
159673471bf0Spatrick   case Intrinsic::aarch64_sve_rdffr:
159773471bf0Spatrick     return instCombineRDFFR(IC, II);
159873471bf0Spatrick   case Intrinsic::aarch64_sve_lasta:
159973471bf0Spatrick   case Intrinsic::aarch64_sve_lastb:
160073471bf0Spatrick     return instCombineSVELast(IC, II);
1601*d415bd75Srobert   case Intrinsic::aarch64_sve_clasta_n:
1602*d415bd75Srobert   case Intrinsic::aarch64_sve_clastb_n:
1603*d415bd75Srobert     return instCombineSVECondLast(IC, II);
160473471bf0Spatrick   case Intrinsic::aarch64_sve_cntd:
160573471bf0Spatrick     return instCombineSVECntElts(IC, II, 2);
160673471bf0Spatrick   case Intrinsic::aarch64_sve_cntw:
160773471bf0Spatrick     return instCombineSVECntElts(IC, II, 4);
160873471bf0Spatrick   case Intrinsic::aarch64_sve_cnth:
160973471bf0Spatrick     return instCombineSVECntElts(IC, II, 8);
161073471bf0Spatrick   case Intrinsic::aarch64_sve_cntb:
161173471bf0Spatrick     return instCombineSVECntElts(IC, II, 16);
161273471bf0Spatrick   case Intrinsic::aarch64_sve_ptest_any:
161373471bf0Spatrick   case Intrinsic::aarch64_sve_ptest_first:
161473471bf0Spatrick   case Intrinsic::aarch64_sve_ptest_last:
161573471bf0Spatrick     return instCombineSVEPTest(IC, II);
161673471bf0Spatrick   case Intrinsic::aarch64_sve_mul:
161773471bf0Spatrick   case Intrinsic::aarch64_sve_fmul:
161873471bf0Spatrick     return instCombineSVEVectorMul(IC, II);
1619*d415bd75Srobert   case Intrinsic::aarch64_sve_fadd:
1620*d415bd75Srobert   case Intrinsic::aarch64_sve_add:
1621*d415bd75Srobert     return instCombineSVEVectorAdd(IC, II);
1622*d415bd75Srobert   case Intrinsic::aarch64_sve_fsub:
1623*d415bd75Srobert   case Intrinsic::aarch64_sve_sub:
1624*d415bd75Srobert     return instCombineSVEVectorSub(IC, II);
162573471bf0Spatrick   case Intrinsic::aarch64_sve_tbl:
162673471bf0Spatrick     return instCombineSVETBL(IC, II);
1627*d415bd75Srobert   case Intrinsic::aarch64_sve_uunpkhi:
1628*d415bd75Srobert   case Intrinsic::aarch64_sve_uunpklo:
1629*d415bd75Srobert   case Intrinsic::aarch64_sve_sunpkhi:
1630*d415bd75Srobert   case Intrinsic::aarch64_sve_sunpklo:
1631*d415bd75Srobert     return instCombineSVEUnpack(IC, II);
1632*d415bd75Srobert   case Intrinsic::aarch64_sve_zip1:
1633*d415bd75Srobert   case Intrinsic::aarch64_sve_zip2:
1634*d415bd75Srobert     return instCombineSVEZip(IC, II);
1635*d415bd75Srobert   case Intrinsic::aarch64_sve_ld1_gather_index:
1636*d415bd75Srobert     return instCombineLD1GatherIndex(IC, II);
1637*d415bd75Srobert   case Intrinsic::aarch64_sve_st1_scatter_index:
1638*d415bd75Srobert     return instCombineST1ScatterIndex(IC, II);
1639*d415bd75Srobert   case Intrinsic::aarch64_sve_ld1:
1640*d415bd75Srobert     return instCombineSVELD1(IC, II, DL);
1641*d415bd75Srobert   case Intrinsic::aarch64_sve_st1:
1642*d415bd75Srobert     return instCombineSVEST1(IC, II, DL);
1643*d415bd75Srobert   case Intrinsic::aarch64_sve_sdiv:
1644*d415bd75Srobert     return instCombineSVESDIV(IC, II);
1645*d415bd75Srobert   case Intrinsic::aarch64_sve_sel:
1646*d415bd75Srobert     return instCombineSVESel(IC, II);
1647*d415bd75Srobert   case Intrinsic::aarch64_sve_srshl:
1648*d415bd75Srobert     return instCombineSVESrshl(IC, II);
1649*d415bd75Srobert   case Intrinsic::aarch64_sve_dupq_lane:
1650*d415bd75Srobert     return instCombineSVEDupqLane(IC, II);
165173471bf0Spatrick   }
165273471bf0Spatrick 
1653*d415bd75Srobert   return std::nullopt;
1654*d415bd75Srobert }
1655*d415bd75Srobert 
1656*d415bd75Srobert std::optional<Value *> AArch64TTIImpl::simplifyDemandedVectorEltsIntrinsic(
1657*d415bd75Srobert     InstCombiner &IC, IntrinsicInst &II, APInt OrigDemandedElts,
1658*d415bd75Srobert     APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3,
1659*d415bd75Srobert     std::function<void(Instruction *, unsigned, APInt, APInt &)>
1660*d415bd75Srobert         SimplifyAndSetOp) const {
1661*d415bd75Srobert   switch (II.getIntrinsicID()) {
1662*d415bd75Srobert   default:
1663*d415bd75Srobert     break;
1664*d415bd75Srobert   case Intrinsic::aarch64_neon_fcvtxn:
1665*d415bd75Srobert   case Intrinsic::aarch64_neon_rshrn:
1666*d415bd75Srobert   case Intrinsic::aarch64_neon_sqrshrn:
1667*d415bd75Srobert   case Intrinsic::aarch64_neon_sqrshrun:
1668*d415bd75Srobert   case Intrinsic::aarch64_neon_sqshrn:
1669*d415bd75Srobert   case Intrinsic::aarch64_neon_sqshrun:
1670*d415bd75Srobert   case Intrinsic::aarch64_neon_sqxtn:
1671*d415bd75Srobert   case Intrinsic::aarch64_neon_sqxtun:
1672*d415bd75Srobert   case Intrinsic::aarch64_neon_uqrshrn:
1673*d415bd75Srobert   case Intrinsic::aarch64_neon_uqshrn:
1674*d415bd75Srobert   case Intrinsic::aarch64_neon_uqxtn:
1675*d415bd75Srobert     SimplifyAndSetOp(&II, 0, OrigDemandedElts, UndefElts);
1676*d415bd75Srobert     break;
1677*d415bd75Srobert   }
1678*d415bd75Srobert 
1679*d415bd75Srobert   return std::nullopt;
1680*d415bd75Srobert }
1681*d415bd75Srobert 
1682*d415bd75Srobert TypeSize
1683*d415bd75Srobert AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
1684*d415bd75Srobert   switch (K) {
1685*d415bd75Srobert   case TargetTransformInfo::RGK_Scalar:
1686*d415bd75Srobert     return TypeSize::getFixed(64);
1687*d415bd75Srobert   case TargetTransformInfo::RGK_FixedWidthVector:
1688*d415bd75Srobert     if (!ST->isStreamingSVEModeDisabled() &&
1689*d415bd75Srobert         !EnableFixedwidthAutovecInStreamingMode)
1690*d415bd75Srobert       return TypeSize::getFixed(0);
1691*d415bd75Srobert 
1692*d415bd75Srobert     if (ST->hasSVE())
1693*d415bd75Srobert       return TypeSize::getFixed(
1694*d415bd75Srobert           std::max(ST->getMinSVEVectorSizeInBits(), 128u));
1695*d415bd75Srobert 
1696*d415bd75Srobert     return TypeSize::getFixed(ST->hasNEON() ? 128 : 0);
1697*d415bd75Srobert   case TargetTransformInfo::RGK_ScalableVector:
1698*d415bd75Srobert     if (!ST->isStreamingSVEModeDisabled() && !EnableScalableAutovecInStreamingMode)
1699*d415bd75Srobert       return TypeSize::getScalable(0);
1700*d415bd75Srobert 
1701*d415bd75Srobert     return TypeSize::getScalable(ST->hasSVE() ? 128 : 0);
1702*d415bd75Srobert   }
1703*d415bd75Srobert   llvm_unreachable("Unsupported register kind");
170473471bf0Spatrick }
170573471bf0Spatrick 
170609467b48Spatrick bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
170709467b48Spatrick                                            ArrayRef<const Value *> Args) {
170809467b48Spatrick 
170909467b48Spatrick   // A helper that returns a vector type from the given type. The number of
1710*d415bd75Srobert   // elements in type Ty determines the vector width.
171109467b48Spatrick   auto toVectorTy = [&](Type *ArgTy) {
171273471bf0Spatrick     return VectorType::get(ArgTy->getScalarType(),
171373471bf0Spatrick                            cast<VectorType>(DstTy)->getElementCount());
171409467b48Spatrick   };
171509467b48Spatrick 
171609467b48Spatrick   // Exit early if DstTy is not a vector type whose elements are at least
1717*d415bd75Srobert   // 16-bits wide. SVE doesn't generally have the same set of instructions to
1718*d415bd75Srobert   // perform an extend with the add/sub/mul. There are SMULLB style
1719*d415bd75Srobert   // instructions, but they operate on top/bottom, requiring some sort of lane
1720*d415bd75Srobert   // interleaving to be used with zext/sext.
1721*d415bd75Srobert   if (!useNeonVector(DstTy) || DstTy->getScalarSizeInBits() < 16)
172209467b48Spatrick     return false;
172309467b48Spatrick 
172409467b48Spatrick   // Determine if the operation has a widening variant. We consider both the
172509467b48Spatrick   // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
172609467b48Spatrick   // instructions.
172709467b48Spatrick   //
1728*d415bd75Srobert   // TODO: Add additional widening operations (e.g., shl, etc.) once we
172909467b48Spatrick   //       verify that their extending operands are eliminated during code
173009467b48Spatrick   //       generation.
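  //
  // For example (illustrative): with DstTy == <8 x i16>, Opcode == Add and
  // Args == {%a, zext <8 x i8> %b to <8 x i16>}, the extend can fold into a
  // single uaddw, so the cast is treated as free by the cast-cost logic that
  // queries this helper.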
173109467b48Spatrick   switch (Opcode) {
173209467b48Spatrick   case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
173309467b48Spatrick   case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
1734*d415bd75Srobert   case Instruction::Mul: // SMULL(2), UMULL(2)
173509467b48Spatrick     break;
173609467b48Spatrick   default:
173709467b48Spatrick     return false;
173809467b48Spatrick   }
173909467b48Spatrick 
174009467b48Spatrick   // To be a widening instruction (either the "wide" or "long" versions), the
1741*d415bd75Srobert   // second operand must be a sign- or zero extend.
174209467b48Spatrick   if (Args.size() != 2 ||
1743*d415bd75Srobert       (!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])))
174409467b48Spatrick     return false;
174509467b48Spatrick   auto *Extend = cast<CastInst>(Args[1]);
1746*d415bd75Srobert   auto *Arg0 = dyn_cast<CastInst>(Args[0]);
1747*d415bd75Srobert 
1748*d415bd75Srobert   // A mul only has a "long" (mull) form and no "wide" form (unlike add/sub).
1749*d415bd75Srobert   // Both operands therefore need to be extends of the same source type.
1750*d415bd75Srobert   if (Opcode == Instruction::Mul &&
1751*d415bd75Srobert       (!Arg0 || Arg0->getOpcode() != Extend->getOpcode() ||
1752*d415bd75Srobert        Arg0->getOperand(0)->getType() != Extend->getOperand(0)->getType()))
1753*d415bd75Srobert     return false;
175409467b48Spatrick 
175509467b48Spatrick   // Legalize the destination type and ensure it can be used in a widening
175609467b48Spatrick   // operation.
1757*d415bd75Srobert   auto DstTyL = getTypeLegalizationCost(DstTy);
175809467b48Spatrick   unsigned DstElTySize = DstTyL.second.getScalarSizeInBits();
175909467b48Spatrick   if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits())
176009467b48Spatrick     return false;
176109467b48Spatrick 
176209467b48Spatrick   // Legalize the source type and ensure it can be used in a widening
176309467b48Spatrick   // operation.
1764097a140dSpatrick   auto *SrcTy = toVectorTy(Extend->getSrcTy());
1765*d415bd75Srobert   auto SrcTyL = getTypeLegalizationCost(SrcTy);
176609467b48Spatrick   unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
176709467b48Spatrick   if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
176809467b48Spatrick     return false;
176909467b48Spatrick 
177009467b48Spatrick   // Get the total number of vector elements in the legalized types.
177173471bf0Spatrick   InstructionCost NumDstEls =
177273471bf0Spatrick       DstTyL.first * DstTyL.second.getVectorMinNumElements();
177373471bf0Spatrick   InstructionCost NumSrcEls =
177473471bf0Spatrick       SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
177509467b48Spatrick 
177609467b48Spatrick   // Return true if the legalized types have the same number of vector elements
177709467b48Spatrick   // and the destination element type size is twice that of the source type.
177809467b48Spatrick   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
177909467b48Spatrick }
178009467b48Spatrick 
178173471bf0Spatrick InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
178273471bf0Spatrick                                                  Type *Src,
178373471bf0Spatrick                                                  TTI::CastContextHint CCH,
1784097a140dSpatrick                                                  TTI::TargetCostKind CostKind,
178509467b48Spatrick                                                  const Instruction *I) {
178609467b48Spatrick   int ISD = TLI->InstructionOpcodeToISD(Opcode);
178709467b48Spatrick   assert(ISD && "Invalid opcode");
178809467b48Spatrick 
178909467b48Spatrick   // If the cast is observable, and it is used by a widening instruction (e.g.,
179009467b48Spatrick   // uaddl, saddw, etc.), it may be free.
1791*d415bd75Srobert   if (I && I->hasOneUser()) {
179209467b48Spatrick     auto *SingleUser = cast<Instruction>(*I->user_begin());
179309467b48Spatrick     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
179409467b48Spatrick     if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands)) {
179509467b48Spatrick       // If the cast is the second operand, it is free. We will generate either
179609467b48Spatrick       // a "wide" or "long" version of the widening instruction.
179709467b48Spatrick       if (I == SingleUser->getOperand(1))
179809467b48Spatrick         return 0;
179909467b48Spatrick       // If the cast is not the second operand, it will be free if it looks the
180009467b48Spatrick       // same as the second operand. In this case, we will generate a "long"
180109467b48Spatrick       // version of the widening instruction.
180209467b48Spatrick       if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1)))
180309467b48Spatrick         if (I->getOpcode() == unsigned(Cast->getOpcode()) &&
180409467b48Spatrick             cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy())
180509467b48Spatrick           return 0;
180609467b48Spatrick     }
180709467b48Spatrick   }
180809467b48Spatrick 
1809097a140dSpatrick   // TODO: Allow non-throughput costs that aren't binary.
181073471bf0Spatrick   auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
1811097a140dSpatrick     if (CostKind != TTI::TCK_RecipThroughput)
1812097a140dSpatrick       return Cost == 0 ? 0 : 1;
1813097a140dSpatrick     return Cost;
1814097a140dSpatrick   };
1815097a140dSpatrick 
181609467b48Spatrick   EVT SrcTy = TLI->getValueType(DL, Src);
181709467b48Spatrick   EVT DstTy = TLI->getValueType(DL, Dst);
181809467b48Spatrick 
181909467b48Spatrick   if (!SrcTy.isSimple() || !DstTy.isSimple())
182073471bf0Spatrick     return AdjustCost(
182173471bf0Spatrick         BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
182209467b48Spatrick 
182309467b48Spatrick   static const TypeConversionCostTblEntry
182409467b48Spatrick   ConversionTbl[] = {
1825*d415bd75Srobert     { ISD::TRUNCATE, MVT::v2i8,   MVT::v2i64,  1},  // xtn
1826*d415bd75Srobert     { ISD::TRUNCATE, MVT::v2i16,  MVT::v2i64,  1},  // xtn
1827*d415bd75Srobert     { ISD::TRUNCATE, MVT::v2i32,  MVT::v2i64,  1},  // xtn
1828*d415bd75Srobert     { ISD::TRUNCATE, MVT::v4i8,   MVT::v4i32,  1},  // xtn
1829*d415bd75Srobert     { ISD::TRUNCATE, MVT::v4i8,   MVT::v4i64,  3},  // 2 xtn + 1 uzp1
1830*d415bd75Srobert     { ISD::TRUNCATE, MVT::v4i16,  MVT::v4i32,  1},  // xtn
1831*d415bd75Srobert     { ISD::TRUNCATE, MVT::v4i16,  MVT::v4i64,  2},  // 1 uzp1 + 1 xtn
1832*d415bd75Srobert     { ISD::TRUNCATE, MVT::v4i32,  MVT::v4i64,  1},  // 1 uzp1
1833*d415bd75Srobert     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i16,  1},  // 1 xtn
1834*d415bd75Srobert     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i32,  2},  // 1 uzp1 + 1 xtn
1835*d415bd75Srobert     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i64,  4},  // 3 x uzp1 + xtn
1836*d415bd75Srobert     { ISD::TRUNCATE, MVT::v8i16,  MVT::v8i32,  1},  // 1 uzp1
1837*d415bd75Srobert     { ISD::TRUNCATE, MVT::v8i16,  MVT::v8i64,  3},  // 3 x uzp1
1838*d415bd75Srobert     { ISD::TRUNCATE, MVT::v8i32,  MVT::v8i64,  2},  // 2 x uzp1
1839*d415bd75Srobert     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i16, 1},  // uzp1
1840*d415bd75Srobert     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i32, 3},  // (2 + 1) x uzp1
1841*d415bd75Srobert     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i64, 7},  // (4 + 2 + 1) x uzp1
1842*d415bd75Srobert     { ISD::TRUNCATE, MVT::v16i16, MVT::v16i32, 2},  // 2 x uzp1
1843*d415bd75Srobert     { ISD::TRUNCATE, MVT::v16i16, MVT::v16i64, 6},  // (4 + 2) x uzp1
1844*d415bd75Srobert     { ISD::TRUNCATE, MVT::v16i32, MVT::v16i64, 4},  // 4 x uzp1
184509467b48Spatrick 
184673471bf0Spatrick     // Truncations on nxvmiN
184773471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
184873471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
184973471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
185073471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
185173471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
185273471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
185373471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
185473471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
185573471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
185673471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
185773471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
185873471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
185973471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
186073471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
186173471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
186273471bf0Spatrick     { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },
186373471bf0Spatrick 
186409467b48Spatrick     // The number of shll instructions for the extension.
186509467b48Spatrick     { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
186609467b48Spatrick     { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
186709467b48Spatrick     { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
186809467b48Spatrick     { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
186909467b48Spatrick     { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
187009467b48Spatrick     { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
187109467b48Spatrick     { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
187209467b48Spatrick     { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
187309467b48Spatrick     { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
187409467b48Spatrick     { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
187509467b48Spatrick     { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
187609467b48Spatrick     { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
187709467b48Spatrick     { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
187809467b48Spatrick     { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
187909467b48Spatrick     { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
188009467b48Spatrick     { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
188109467b48Spatrick 
188209467b48Spatrick     // LowerVectorINT_TO_FP:
188309467b48Spatrick     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
188409467b48Spatrick     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
188509467b48Spatrick     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
188609467b48Spatrick     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
188709467b48Spatrick     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
188809467b48Spatrick     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
188909467b48Spatrick 
189009467b48Spatrick     // Complex: to v2f32
189109467b48Spatrick     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
189209467b48Spatrick     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
189309467b48Spatrick     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
189409467b48Spatrick     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
189509467b48Spatrick     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
189609467b48Spatrick     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
189709467b48Spatrick 
189809467b48Spatrick     // Complex: to v4f32
189909467b48Spatrick     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8,  4 },
190009467b48Spatrick     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
190109467b48Spatrick     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8,  3 },
190209467b48Spatrick     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
190309467b48Spatrick 
190409467b48Spatrick     // Complex: to v8f32
190509467b48Spatrick     { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
190609467b48Spatrick     { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
190709467b48Spatrick     { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
190809467b48Spatrick     { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
190909467b48Spatrick 
191009467b48Spatrick     // Complex: to v16f32
191109467b48Spatrick     { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
191209467b48Spatrick     { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
191309467b48Spatrick 
191409467b48Spatrick     // Complex: to v2f64
191509467b48Spatrick     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
191609467b48Spatrick     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
191709467b48Spatrick     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
191809467b48Spatrick     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
191909467b48Spatrick     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
192009467b48Spatrick     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
192109467b48Spatrick 
1922*d415bd75Srobert     // Complex: to v4f64
1923*d415bd75Srobert     { ISD::SINT_TO_FP, MVT::v4f64, MVT::v4i32,  4 },
1924*d415bd75Srobert     { ISD::UINT_TO_FP, MVT::v4f64, MVT::v4i32,  4 },
192509467b48Spatrick 
192609467b48Spatrick     // LowerVectorFP_TO_INT
192709467b48Spatrick     { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 },
192809467b48Spatrick     { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 },
192909467b48Spatrick     { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 },
193009467b48Spatrick     { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 },
193109467b48Spatrick     { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 },
193209467b48Spatrick     { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 },
193309467b48Spatrick 
193409467b48Spatrick     // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
193509467b48Spatrick     { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 },
193609467b48Spatrick     { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 },
193709467b48Spatrick     { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f32, 1 },
193809467b48Spatrick     { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 },
193909467b48Spatrick     { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 },
194009467b48Spatrick     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f32, 1 },
194109467b48Spatrick 
194209467b48Spatrick     // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
194309467b48Spatrick     { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 },
194409467b48Spatrick     { ISD::FP_TO_SINT, MVT::v4i8,  MVT::v4f32, 2 },
194509467b48Spatrick     { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
194609467b48Spatrick     { ISD::FP_TO_UINT, MVT::v4i8,  MVT::v4f32, 2 },
194709467b48Spatrick 
194873471bf0Spatrick     // Complex, from nxv2f32.
194973471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
195073471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
195173471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
195273471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
195373471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
195473471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
195573471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
195673471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
195773471bf0Spatrick 
195809467b48Spatrick     // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
195909467b48Spatrick     { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
196009467b48Spatrick     { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
196109467b48Spatrick     { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f64, 2 },
196209467b48Spatrick     { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
196309467b48Spatrick     { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
196409467b48Spatrick     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f64, 2 },
196573471bf0Spatrick 
196673471bf0Spatrick     // Complex, from nxv2f64.
196773471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
196873471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
196973471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
197073471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
197173471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
197273471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
197373471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
197473471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
197573471bf0Spatrick 
197673471bf0Spatrick     // Complex, from nxv4f32.
197773471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
197873471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
197973471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
198073471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
198173471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
198273471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
198373471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
198473471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
198573471bf0Spatrick 
198673471bf0Spatrick     // Complex, from nxv8f64. Illegal -> illegal conversions not required.
198773471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
198873471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
198973471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
199073471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
199173471bf0Spatrick 
199273471bf0Spatrick     // Complex, from nxv4f64. Illegal -> illegal conversions not required.
199373471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
199473471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
199573471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
199673471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
199773471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
199873471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
199973471bf0Spatrick 
200073471bf0Spatrick     // Complex, from nxv8f32. Illegal -> illegal conversions not required.
200173471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
200273471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
200373471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
200473471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
200573471bf0Spatrick 
200673471bf0Spatrick     // Complex, from nxv8f16.
200773471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
200873471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
200973471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
201073471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
201173471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
201273471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
201373471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
201473471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
201573471bf0Spatrick 
201673471bf0Spatrick     // Complex, from nxv4f16.
201773471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
201873471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
201973471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
202073471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
202173471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
202273471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
202373471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
202473471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
202573471bf0Spatrick 
202673471bf0Spatrick     // Complex, from nxv2f16.
202773471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
202873471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
202973471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
203073471bf0Spatrick     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
203173471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
203273471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
203373471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
203473471bf0Spatrick     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
203573471bf0Spatrick 
203673471bf0Spatrick     // Truncate from nxvmf32 to nxvmf16.
203773471bf0Spatrick     { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
203873471bf0Spatrick     { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
203973471bf0Spatrick     { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },
204073471bf0Spatrick 
204173471bf0Spatrick     // Truncate from nxvmf64 to nxvmf16.
204273471bf0Spatrick     { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
204373471bf0Spatrick     { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
204473471bf0Spatrick     { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },
204573471bf0Spatrick 
204673471bf0Spatrick     // Truncate from nxvmf64 to nxvmf32.
204773471bf0Spatrick     { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
204873471bf0Spatrick     { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
204973471bf0Spatrick     { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },
205073471bf0Spatrick 
205173471bf0Spatrick     // Extend from nxvmf16 to nxvmf32.
205273471bf0Spatrick     { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
205373471bf0Spatrick     { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
205473471bf0Spatrick     { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
205573471bf0Spatrick 
205673471bf0Spatrick     // Extend from nxvmf16 to nxvmf64.
205773471bf0Spatrick     { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
205873471bf0Spatrick     { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
205973471bf0Spatrick     { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
206073471bf0Spatrick 
206173471bf0Spatrick     // Extend from nxvmf32 to nxvmf64.
206273471bf0Spatrick     { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
206373471bf0Spatrick     { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
206473471bf0Spatrick     { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
206573471bf0Spatrick 
2066*d415bd75Srobert     // Bitcasts from float to integer
2067*d415bd75Srobert     { ISD::BITCAST, MVT::nxv2f16, MVT::nxv2i16, 0 },
2068*d415bd75Srobert     { ISD::BITCAST, MVT::nxv4f16, MVT::nxv4i16, 0 },
2069*d415bd75Srobert     { ISD::BITCAST, MVT::nxv2f32, MVT::nxv2i32, 0 },
2070*d415bd75Srobert 
2071*d415bd75Srobert     // Bitcasts from integer to float
2072*d415bd75Srobert     { ISD::BITCAST, MVT::nxv2i16, MVT::nxv2f16, 0 },
2073*d415bd75Srobert     { ISD::BITCAST, MVT::nxv4i16, MVT::nxv4f16, 0 },
2074*d415bd75Srobert     { ISD::BITCAST, MVT::nxv2i32, MVT::nxv2f32, 0 },
207509467b48Spatrick   };
207609467b48Spatrick 
207709467b48Spatrick   if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
207809467b48Spatrick                                                  DstTy.getSimpleVT(),
207909467b48Spatrick                                                  SrcTy.getSimpleVT()))
2080097a140dSpatrick     return AdjustCost(Entry->Cost);
208109467b48Spatrick 
2082*d415bd75Srobert   static const TypeConversionCostTblEntry FP16Tbl[] = {
2083*d415bd75Srobert       {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f16, 1}, // fcvtzs
2084*d415bd75Srobert       {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f16, 1},
2085*d415bd75Srobert       {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f16, 1}, // fcvtzs
2086*d415bd75Srobert       {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f16, 1},
2087*d415bd75Srobert       {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f16, 2}, // fcvtl+fcvtzs
2088*d415bd75Srobert       {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f16, 2},
2089*d415bd75Srobert       {ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f16, 2}, // fcvtzs+xtn
2090*d415bd75Srobert       {ISD::FP_TO_UINT, MVT::v8i8, MVT::v8f16, 2},
2091*d415bd75Srobert       {ISD::FP_TO_SINT, MVT::v8i16, MVT::v8f16, 1}, // fcvtzs
2092*d415bd75Srobert       {ISD::FP_TO_UINT, MVT::v8i16, MVT::v8f16, 1},
2093*d415bd75Srobert       {ISD::FP_TO_SINT, MVT::v8i32, MVT::v8f16, 4}, // 2*fcvtl+2*fcvtzs
2094*d415bd75Srobert       {ISD::FP_TO_UINT, MVT::v8i32, MVT::v8f16, 4},
2095*d415bd75Srobert       {ISD::FP_TO_SINT, MVT::v16i8, MVT::v16f16, 3}, // 2*fcvtzs+xtn
2096*d415bd75Srobert       {ISD::FP_TO_UINT, MVT::v16i8, MVT::v16f16, 3},
2097*d415bd75Srobert       {ISD::FP_TO_SINT, MVT::v16i16, MVT::v16f16, 2}, // 2*fcvtzs
2098*d415bd75Srobert       {ISD::FP_TO_UINT, MVT::v16i16, MVT::v16f16, 2},
2099*d415bd75Srobert       {ISD::FP_TO_SINT, MVT::v16i32, MVT::v16f16, 8}, // 4*fcvtl+4*fcvtzs
2100*d415bd75Srobert       {ISD::FP_TO_UINT, MVT::v16i32, MVT::v16f16, 8},
2101*d415bd75Srobert       {ISD::UINT_TO_FP, MVT::v8f16, MVT::v8i8, 2},   // ushll + ucvtf
2102*d415bd75Srobert       {ISD::SINT_TO_FP, MVT::v8f16, MVT::v8i8, 2},   // sshll + scvtf
2103*d415bd75Srobert       {ISD::UINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * ushl(2) + 2 * ucvtf
2104*d415bd75Srobert       {ISD::SINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * sshl(2) + 2 * scvtf
2105*d415bd75Srobert   };
2106*d415bd75Srobert 
2107*d415bd75Srobert   if (ST->hasFullFP16())
2108*d415bd75Srobert     if (const auto *Entry = ConvertCostTableLookup(
2109*d415bd75Srobert             FP16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
2110*d415bd75Srobert       return AdjustCost(Entry->Cost);
2111*d415bd75Srobert 
211273471bf0Spatrick   return AdjustCost(
211373471bf0Spatrick       BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
211409467b48Spatrick }
211509467b48Spatrick 
getExtractWithExtendCost(unsigned Opcode,Type * Dst,VectorType * VecTy,unsigned Index)211673471bf0Spatrick InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
211773471bf0Spatrick                                                          Type *Dst,
211809467b48Spatrick                                                          VectorType *VecTy,
211909467b48Spatrick                                                          unsigned Index) {
212009467b48Spatrick 
212109467b48Spatrick   // Make sure we were given a valid extend opcode.
212209467b48Spatrick   assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
212309467b48Spatrick          "Invalid opcode");
212409467b48Spatrick 
212509467b48Spatrick   // We are extending an element we extract from a vector, so the source type
212609467b48Spatrick   // of the extend is the element type of the vector.
212709467b48Spatrick   auto *Src = VecTy->getElementType();
212809467b48Spatrick 
212909467b48Spatrick   // Sign- and zero-extends are for integer types only.
213009467b48Spatrick   assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
213109467b48Spatrick 
213209467b48Spatrick   // Get the cost for the extract. We compute the cost (if any) for the extend
213309467b48Spatrick   // below.
2134*d415bd75Srobert   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2135*d415bd75Srobert   InstructionCost Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy,
2136*d415bd75Srobert                                             CostKind, Index, nullptr, nullptr);
213709467b48Spatrick 
213809467b48Spatrick   // Legalize the types.
2139*d415bd75Srobert   auto VecLT = getTypeLegalizationCost(VecTy);
214009467b48Spatrick   auto DstVT = TLI->getValueType(DL, Dst);
214109467b48Spatrick   auto SrcVT = TLI->getValueType(DL, Src);
214209467b48Spatrick 
214309467b48Spatrick   // If the resulting type is still a vector and the destination type is legal,
214409467b48Spatrick   // we may get the extension for free. If not, get the default cost for the
214509467b48Spatrick   // extend.
214609467b48Spatrick   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
214773471bf0Spatrick     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
214873471bf0Spatrick                                    CostKind);
214909467b48Spatrick 
215009467b48Spatrick   // The destination type should be larger than the element type. If not, get
215109467b48Spatrick   // the default cost for the extend.
215273471bf0Spatrick   if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
215373471bf0Spatrick     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
215473471bf0Spatrick                                    CostKind);
215509467b48Spatrick 
215609467b48Spatrick   switch (Opcode) {
215709467b48Spatrick   default:
215809467b48Spatrick     llvm_unreachable("Opcode should be either SExt or ZExt");
215909467b48Spatrick 
216009467b48Spatrick   // For sign-extends, we only need a smov, which performs the extension
216109467b48Spatrick   // automatically.
216209467b48Spatrick   case Instruction::SExt:
216309467b48Spatrick     return Cost;
216409467b48Spatrick 
216509467b48Spatrick   // For zero-extends, the extend is performed automatically by a umov unless
216609467b48Spatrick   // the destination type is i64 and the element type is i8 or i16.
216709467b48Spatrick   case Instruction::ZExt:
216809467b48Spatrick     if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
216909467b48Spatrick       return Cost;
217009467b48Spatrick   }
217109467b48Spatrick 
217209467b48Spatrick   // If we are unable to perform the extend for free, get the default cost.
217373471bf0Spatrick   return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
217473471bf0Spatrick                                  CostKind);
2175097a140dSpatrick }
2176097a140dSpatrick 
getCFInstrCost(unsigned Opcode,TTI::TargetCostKind CostKind,const Instruction * I)217773471bf0Spatrick InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
217873471bf0Spatrick                                                TTI::TargetCostKind CostKind,
217973471bf0Spatrick                                                const Instruction *I) {
2180097a140dSpatrick   if (CostKind != TTI::TCK_RecipThroughput)
2181097a140dSpatrick     return Opcode == Instruction::PHI ? 0 : 1;
2182097a140dSpatrick   assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
2183097a140dSpatrick   // Branches are assumed to be predicted.
2184097a140dSpatrick   return 0;
218509467b48Spatrick }
218609467b48Spatrick 
getVectorInstrCostHelper(Type * Val,unsigned Index,bool HasRealUse)2187*d415bd75Srobert InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(Type *Val,
2188*d415bd75Srobert                                                          unsigned Index,
2189*d415bd75Srobert                                                          bool HasRealUse) {
219009467b48Spatrick   assert(Val->isVectorTy() && "This must be a vector type");
219109467b48Spatrick 
219209467b48Spatrick   if (Index != -1U) {
219309467b48Spatrick     // Legalize the type.
2194*d415bd75Srobert     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
219509467b48Spatrick 
219609467b48Spatrick     // This type is legalized to a scalar type.
219709467b48Spatrick     if (!LT.second.isVector())
219809467b48Spatrick       return 0;
219909467b48Spatrick 
2200*d415bd75Srobert     // The type may be split. For fixed-width vectors we can normalize the
2201*d415bd75Srobert     // index to the new type.
2202*d415bd75Srobert     if (LT.second.isFixedLengthVector()) {
220309467b48Spatrick       unsigned Width = LT.second.getVectorNumElements();
220409467b48Spatrick       Index = Index % Width;
2205*d415bd75Srobert     }
220609467b48Spatrick 
220709467b48Spatrick     // The element at index zero is already inside the vector.
2208*d415bd75Srobert     // - For a physical (HasRealUse==true) insert-element or extract-element
2209*d415bd75Srobert     // instruction that extracts integers, an explicit FPR -> GPR move is
2210*d415bd75Srobert     // needed. So it has non-zero cost.
2211*d415bd75Srobert     // - For the remaining cases (a virtual instruction, or the element type
2212*d415bd75Srobert     //   is a float), consider the instruction free.
2213*d415bd75Srobert     //
2214*d415bd75Srobert     // FIXME:
2215*d415bd75Srobert     // If the extract-element and insert-element instructions could be
2216*d415bd75Srobert     // simplified away (e.g., could be combined into users by looking at use-def
2217*d415bd75Srobert     // context), they have no cost. This is not done in the first place for
2218*d415bd75Srobert     // compile-time considerations.
2219*d415bd75Srobert     if (Index == 0 && (!HasRealUse || !Val->getScalarType()->isIntegerTy()))
222009467b48Spatrick       return 0;
222109467b48Spatrick   }
222209467b48Spatrick 
222309467b48Spatrick   // All other insert/extracts cost this much.
222409467b48Spatrick   return ST->getVectorInsertExtractBaseCost();
222509467b48Spatrick }
222609467b48Spatrick 
getVectorInstrCost(unsigned Opcode,Type * Val,TTI::TargetCostKind CostKind,unsigned Index,Value * Op0,Value * Op1)2227*d415bd75Srobert InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
2228*d415bd75Srobert                                                    TTI::TargetCostKind CostKind,
2229*d415bd75Srobert                                                    unsigned Index, Value *Op0,
2230*d415bd75Srobert                                                    Value *Op1) {
2231*d415bd75Srobert   return getVectorInstrCostHelper(Val, Index, false /* HasRealUse */);
2232*d415bd75Srobert }
2233*d415bd75Srobert 
getVectorInstrCost(const Instruction & I,Type * Val,TTI::TargetCostKind CostKind,unsigned Index)2234*d415bd75Srobert InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
2235*d415bd75Srobert                                                    Type *Val,
2236*d415bd75Srobert                                                    TTI::TargetCostKind CostKind,
2237*d415bd75Srobert                                                    unsigned Index) {
2238*d415bd75Srobert   return getVectorInstrCostHelper(Val, Index, true /* HasRealUse */);
2239*d415bd75Srobert }
2240*d415bd75Srobert 
getArithmeticInstrCost(unsigned Opcode,Type * Ty,TTI::TargetCostKind CostKind,TTI::OperandValueInfo Op1Info,TTI::OperandValueInfo Op2Info,ArrayRef<const Value * > Args,const Instruction * CxtI)224173471bf0Spatrick InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
2242097a140dSpatrick     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
2243*d415bd75Srobert     TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
2244*d415bd75Srobert     ArrayRef<const Value *> Args,
224509467b48Spatrick     const Instruction *CxtI) {
2246*d415bd75Srobert 
2247097a140dSpatrick   // TODO: Handle more cost kinds.
2248097a140dSpatrick   if (CostKind != TTI::TCK_RecipThroughput)
2249*d415bd75Srobert     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2250*d415bd75Srobert                                          Op2Info, Args, CxtI);
2251097a140dSpatrick 
225209467b48Spatrick   // Legalize the type.
2253*d415bd75Srobert   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
225409467b48Spatrick   int ISD = TLI->InstructionOpcodeToISD(Opcode);
225509467b48Spatrick 
225609467b48Spatrick   switch (ISD) {
225709467b48Spatrick   default:
2258*d415bd75Srobert     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2259*d415bd75Srobert                                          Op2Info);
226009467b48Spatrick   case ISD::SDIV:
2261*d415bd75Srobert     if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
226209467b48Spatrick       // On AArch64, scalar signed division by a power-of-two constant is
226309467b48Spatrick       // normally expanded to the sequence ADD + CMP + SELECT + SRA.
226409467b48Spatrick       // The OperandValue properties may not be the same as those of the
226509467b48Spatrick       // previous operation; conservatively assume OP_None.
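      // Illustrative sketch (not taken from the cost tables; register names
      // are only for exposition): a signed division of w0 by 8 is commonly
      // emitted as something like
      //   add  w8, w0, #7          // bias negative dividends
      //   cmp  w0, #0
      //   csel w8, w8, w0, lt      // use the biased value only when negative
      //   asr  w0, w8, #3
      // i.e. the ADD + CMP + SELECT + SRA sequence whose pieces are costed
      // individually below.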
2266*d415bd75Srobert       InstructionCost Cost = getArithmeticInstrCost(
2267*d415bd75Srobert           Instruction::Add, Ty, CostKind,
2268*d415bd75Srobert           Op1Info.getNoProps(), Op2Info.getNoProps());
2269097a140dSpatrick       Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
2270*d415bd75Srobert                                      Op1Info.getNoProps(), Op2Info.getNoProps());
2271*d415bd75Srobert       Cost += getArithmeticInstrCost(
2272*d415bd75Srobert           Instruction::Select, Ty, CostKind,
2273*d415bd75Srobert           Op1Info.getNoProps(), Op2Info.getNoProps());
2274097a140dSpatrick       Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
2275*d415bd75Srobert                                      Op1Info.getNoProps(), Op2Info.getNoProps());
227609467b48Spatrick       return Cost;
227709467b48Spatrick     }
2278*d415bd75Srobert     [[fallthrough]];
2279*d415bd75Srobert   case ISD::UDIV: {
2280*d415bd75Srobert     if (Op2Info.isConstant() && Op2Info.isUniform()) {
228109467b48Spatrick       auto VT = TLI->getValueType(DL, Ty);
228209467b48Spatrick       if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
228309467b48Spatrick         // Vector signed divisions by a constant are expanded to the
228409467b48Spatrick         // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned divisions
228509467b48Spatrick         // to MULHS + SUB + SRL + ADD + SRL.
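        // (The return below approximates either five-instruction sequence as
        // two multiplies, two adds and two shifts, plus one.)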
228673471bf0Spatrick         InstructionCost MulCost = getArithmeticInstrCost(
2287*d415bd75Srobert             Instruction::Mul, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
228873471bf0Spatrick         InstructionCost AddCost = getArithmeticInstrCost(
2289*d415bd75Srobert             Instruction::Add, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
229073471bf0Spatrick         InstructionCost ShrCost = getArithmeticInstrCost(
2291*d415bd75Srobert             Instruction::AShr, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
229209467b48Spatrick         return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
229309467b48Spatrick       }
229409467b48Spatrick     }
229509467b48Spatrick 
2296*d415bd75Srobert     InstructionCost Cost = BaseT::getArithmeticInstrCost(
2297*d415bd75Srobert         Opcode, Ty, CostKind, Op1Info, Op2Info);
229809467b48Spatrick     if (Ty->isVectorTy()) {
2299*d415bd75Srobert       if (TLI->isOperationLegalOrCustom(ISD, LT.second) && ST->hasSVE()) {
2300*d415bd75Srobert         // When SDIV/UDIV operations are lowered using SVE, the cost is
2301*d415bd75Srobert         // lower.
2302*d415bd75Srobert         if (isa<FixedVectorType>(Ty) && cast<FixedVectorType>(Ty)
2303*d415bd75Srobert                                                 ->getPrimitiveSizeInBits()
2304*d415bd75Srobert                                                 .getFixedValue() < 128) {
2305*d415bd75Srobert           EVT VT = TLI->getValueType(DL, Ty);
2306*d415bd75Srobert           static const CostTblEntry DivTbl[]{
2307*d415bd75Srobert               {ISD::SDIV, MVT::v2i8, 5},  {ISD::SDIV, MVT::v4i8, 8},
2308*d415bd75Srobert               {ISD::SDIV, MVT::v8i8, 8},  {ISD::SDIV, MVT::v2i16, 5},
2309*d415bd75Srobert               {ISD::SDIV, MVT::v4i16, 5}, {ISD::SDIV, MVT::v2i32, 1},
2310*d415bd75Srobert               {ISD::UDIV, MVT::v2i8, 5},  {ISD::UDIV, MVT::v4i8, 8},
2311*d415bd75Srobert               {ISD::UDIV, MVT::v8i8, 8},  {ISD::UDIV, MVT::v2i16, 5},
2312*d415bd75Srobert               {ISD::UDIV, MVT::v4i16, 5}, {ISD::UDIV, MVT::v2i32, 1}};
2313*d415bd75Srobert 
2314*d415bd75Srobert           const auto *Entry = CostTableLookup(DivTbl, ISD, VT.getSimpleVT());
2315*d415bd75Srobert           if (nullptr != Entry)
2316*d415bd75Srobert             return Entry->Cost;
2317*d415bd75Srobert         }
2318*d415bd75Srobert         // For 8/16-bit elements, the cost is higher because the type
2319*d415bd75Srobert         // requires promotion and possibly splitting:
2320*d415bd75Srobert         if (LT.second.getScalarType() == MVT::i8)
2321*d415bd75Srobert           Cost *= 8;
2322*d415bd75Srobert         else if (LT.second.getScalarType() == MVT::i16)
2323*d415bd75Srobert           Cost *= 4;
2324*d415bd75Srobert         return Cost;
2325*d415bd75Srobert       } else {
2326*d415bd75Srobert         // If one of the operands is a uniform constant then the cost for each
2327*d415bd75Srobert         // element is the cost of insertion, extraction and division.
2328*d415bd75Srobert         // Insertion cost = 2, extraction cost = 2, division = cost of the
2329*d415bd75Srobert         // operation with the scalar type.
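        // (Hence the (4 + DivCost) per-element figure below: 2 for insertion
        // + 2 for extraction + the scalar division cost, multiplied by the
        // number of vector elements.)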
2330*d415bd75Srobert         if ((Op1Info.isConstant() && Op1Info.isUniform()) ||
2331*d415bd75Srobert             (Op2Info.isConstant() && Op2Info.isUniform())) {
2332*d415bd75Srobert           if (auto *VTy = dyn_cast<FixedVectorType>(Ty)) {
2333*d415bd75Srobert             InstructionCost DivCost = BaseT::getArithmeticInstrCost(
2334*d415bd75Srobert                 Opcode, Ty->getScalarType(), CostKind, Op1Info, Op2Info);
2335*d415bd75Srobert             return (4 + DivCost) * VTy->getNumElements();
2336*d415bd75Srobert           }
2337*d415bd75Srobert         }
2338*d415bd75Srobert         // On AArch64, without SVE, vector divisions are expanded
2339*d415bd75Srobert         // into scalar divisions of each pair of elements.
2340*d415bd75Srobert         Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty,
2341*d415bd75Srobert                                        CostKind, Op1Info, Op2Info);
2342097a140dSpatrick         Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
2343*d415bd75Srobert                                        Op1Info, Op2Info);
2344*d415bd75Srobert       }
2345*d415bd75Srobert 
234609467b48Spatrick       // TODO: if one of the arguments is scalar, then it's not necessary to
234709467b48Spatrick       // double the cost of handling the vector elements.
234809467b48Spatrick       Cost += Cost;
234909467b48Spatrick     }
235009467b48Spatrick     return Cost;
2351*d415bd75Srobert   }
235209467b48Spatrick   case ISD::MUL:
2353*d415bd75Srobert     // When SVE is available, we can lower the v2i64 operation using the
2354*d415bd75Srobert     // SVE mul instruction, which has a lower cost.
2355*d415bd75Srobert     if (LT.second == MVT::v2i64 && ST->hasSVE())
2356*d415bd75Srobert       return LT.first;
2357*d415bd75Srobert 
2358*d415bd75Srobert     // When SVE is not available, there is no MUL.2d instruction,
2359*d415bd75Srobert     // which means mul <2 x i64> is expensive as elements are extracted
2360*d415bd75Srobert     // from the vectors and the muls scalarized.
2361*d415bd75Srobert     // As getScalarizationOverhead is a bit too pessimistic, we
2362*d415bd75Srobert     // estimate the cost for a i64 vector directly here, which is:
2363*d415bd75Srobert     // - four 2-cost i64 extracts,
2364*d415bd75Srobert     // - two 2-cost i64 inserts, and
2365*d415bd75Srobert     // - two 1-cost muls.
2366*d415bd75Srobert     // So, for a v2i64 with LT.first = 1 the cost is 14, and for a v4i64 with
2367*d415bd75Srobert     // LT.first = 2 the cost is 28. If both operands are extensions it will not
2368*d415bd75Srobert     // need to scalarize, so the cost can be cheaper (smull or umull).
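    // (For reference, the per-v2i64 figure works out as
    // 4 extracts * 2 + 2 inserts * 2 + 2 muls * 1 = 14.)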
2370*d415bd75Srobert     if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
2371*d415bd75Srobert       return LT.first;
2372*d415bd75Srobert     return LT.first * 14;
237373471bf0Spatrick   case ISD::ADD:
237409467b48Spatrick   case ISD::XOR:
237509467b48Spatrick   case ISD::OR:
237609467b48Spatrick   case ISD::AND:
2377*d415bd75Srobert   case ISD::SRL:
2378*d415bd75Srobert   case ISD::SRA:
2379*d415bd75Srobert   case ISD::SHL:
238009467b48Spatrick     // These nodes are marked as 'custom' for combining purposes only.
238109467b48Spatrick     // We know that they are legal. See LowerAdd in ISelLowering.
2382*d415bd75Srobert     return LT.first;
2383097a140dSpatrick 
2384097a140dSpatrick   case ISD::FADD:
2385*d415bd75Srobert   case ISD::FSUB:
2386*d415bd75Srobert   case ISD::FMUL:
2387*d415bd75Srobert   case ISD::FDIV:
2388*d415bd75Srobert   case ISD::FNEG:
2389097a140dSpatrick     // These nodes are marked as 'custom' just to lower them to SVE.
2390097a140dSpatrick     // We know said lowering will incur no additional cost.
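    // (fp128 is excluded below because AArch64 generally expands fp128
    // arithmetic to library calls rather than vector instructions, so the
    // generic base cost is used for it instead.)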
2391*d415bd75Srobert     if (!Ty->getScalarType()->isFP128Ty())
2392*d415bd75Srobert       return 2 * LT.first;
2393097a140dSpatrick 
2394*d415bd75Srobert     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2395*d415bd75Srobert                                          Op2Info);
239609467b48Spatrick   }
239709467b48Spatrick }
239809467b48Spatrick 
getAddressComputationCost(Type * Ty,ScalarEvolution * SE,const SCEV * Ptr)239973471bf0Spatrick InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
240073471bf0Spatrick                                                           ScalarEvolution *SE,
240109467b48Spatrick                                                           const SCEV *Ptr) {
240209467b48Spatrick   // Address computations in vectorized code with non-consecutive addresses will
240309467b48Spatrick   // likely result in more instructions compared to scalar code where the
240409467b48Spatrick   // computation can more often be merged into the index mode. The resulting
240509467b48Spatrick   // extra micro-ops can significantly decrease throughput.
240609467b48Spatrick   unsigned NumVectorInstToHideOverhead = 10;
240709467b48Spatrick   int MaxMergeDistance = 64;
240809467b48Spatrick 
240909467b48Spatrick   if (Ty->isVectorTy() && SE &&
241009467b48Spatrick       !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
241109467b48Spatrick     return NumVectorInstToHideOverhead;
241209467b48Spatrick 
241309467b48Spatrick   // In many cases the address computation is not merged into the instruction
241409467b48Spatrick   // addressing mode.
241509467b48Spatrick   return 1;
241609467b48Spatrick }
241709467b48Spatrick 
getCmpSelInstrCost(unsigned Opcode,Type * ValTy,Type * CondTy,CmpInst::Predicate VecPred,TTI::TargetCostKind CostKind,const Instruction * I)241873471bf0Spatrick InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
2419097a140dSpatrick                                                    Type *CondTy,
242073471bf0Spatrick                                                    CmpInst::Predicate VecPred,
2421097a140dSpatrick                                                    TTI::TargetCostKind CostKind,
2422097a140dSpatrick                                                    const Instruction *I) {
2423097a140dSpatrick   // TODO: Handle other cost kinds.
2424097a140dSpatrick   if (CostKind != TTI::TCK_RecipThroughput)
242573471bf0Spatrick     return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
242673471bf0Spatrick                                      I);
242709467b48Spatrick 
242809467b48Spatrick   int ISD = TLI->InstructionOpcodeToISD(Opcode);
242909467b48Spatrick   // We don't lower some vector selects that are wider than the register
243009467b48Spatrick   // width particularly well.
243173471bf0Spatrick   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
243209467b48Spatrick     // We would need this many instructions to hide the scalarization happening.
243309467b48Spatrick     const int AmortizationCost = 20;
243473471bf0Spatrick 
243573471bf0Spatrick     // If VecPred is not set, check if we can get a predicate from the context
243673471bf0Spatrick     // instruction, if its type matches the requested ValTy.
243773471bf0Spatrick     if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
243873471bf0Spatrick       CmpInst::Predicate CurrentPred;
243973471bf0Spatrick       if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
244073471bf0Spatrick                             m_Value())))
244173471bf0Spatrick         VecPred = CurrentPred;
244273471bf0Spatrick     }
2443*d415bd75Srobert     // Check if we have a compare/select chain that can be lowered using
2444*d415bd75Srobert     // a (F)CMxx & BFI pair.
2445*d415bd75Srobert     if (CmpInst::isIntPredicate(VecPred) || VecPred == CmpInst::FCMP_OLE ||
2446*d415bd75Srobert         VecPred == CmpInst::FCMP_OLT || VecPred == CmpInst::FCMP_OGT ||
2447*d415bd75Srobert         VecPred == CmpInst::FCMP_OGE || VecPred == CmpInst::FCMP_OEQ ||
2448*d415bd75Srobert         VecPred == CmpInst::FCMP_UNE) {
2449*d415bd75Srobert       static const auto ValidMinMaxTys = {
2450*d415bd75Srobert           MVT::v8i8,  MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
2451*d415bd75Srobert           MVT::v4i32, MVT::v2i64, MVT::v2f32, MVT::v4f32, MVT::v2f64};
2452*d415bd75Srobert       static const auto ValidFP16MinMaxTys = {MVT::v4f16, MVT::v8f16};
2453*d415bd75Srobert 
2454*d415bd75Srobert       auto LT = getTypeLegalizationCost(ValTy);
2455*d415bd75Srobert       if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }) ||
2456*d415bd75Srobert           (ST->hasFullFP16() &&
2457*d415bd75Srobert            any_of(ValidFP16MinMaxTys, [&LT](MVT M) { return M == LT.second; })))
245873471bf0Spatrick         return LT.first;
245973471bf0Spatrick     }
246073471bf0Spatrick 
246109467b48Spatrick     static const TypeConversionCostTblEntry
246209467b48Spatrick     VectorSelectTbl[] = {
246309467b48Spatrick       { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
246409467b48Spatrick       { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
246509467b48Spatrick       { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
246609467b48Spatrick       { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
246709467b48Spatrick       { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
246809467b48Spatrick       { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
246909467b48Spatrick     };
247009467b48Spatrick 
247109467b48Spatrick     EVT SelCondTy = TLI->getValueType(DL, CondTy);
247209467b48Spatrick     EVT SelValTy = TLI->getValueType(DL, ValTy);
247309467b48Spatrick     if (SelCondTy.isSimple() && SelValTy.isSimple()) {
247409467b48Spatrick       if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
247509467b48Spatrick                                                      SelCondTy.getSimpleVT(),
247609467b48Spatrick                                                      SelValTy.getSimpleVT()))
247709467b48Spatrick         return Entry->Cost;
247809467b48Spatrick     }
247909467b48Spatrick   }
248073471bf0Spatrick   // The base case handles scalable vectors fine for now, since it treats the
248173471bf0Spatrick   // cost as 1 * legalization cost.
248273471bf0Spatrick   return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
248309467b48Spatrick }
248409467b48Spatrick 
248509467b48Spatrick AArch64TTIImpl::TTI::MemCmpExpansionOptions
enableMemCmpExpansion(bool OptSize,bool IsZeroCmp) const248609467b48Spatrick AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
248709467b48Spatrick   TTI::MemCmpExpansionOptions Options;
2488097a140dSpatrick   if (ST->requiresStrictAlign()) {
2489097a140dSpatrick     // TODO: Add cost modeling for strict align. Misaligned loads expand to
2490097a140dSpatrick     // a bunch of instructions when strict align is enabled.
2491097a140dSpatrick     return Options;
2492097a140dSpatrick   }
2493097a140dSpatrick   Options.AllowOverlappingLoads = true;
249409467b48Spatrick   Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
249509467b48Spatrick   Options.NumLoadsPerBlock = Options.MaxNumLoads;
249609467b48Spatrick   // TODO: Though vector loads usually perform well on AArch64, on some targets
249709467b48Spatrick   // they may wake up the FP unit, which raises the power consumption. Perhaps
249809467b48Spatrick   // they could be used with no holds barred (-O3).
249909467b48Spatrick   Options.LoadSizes = {8, 4, 2, 1};
250009467b48Spatrick   return Options;
250109467b48Spatrick }
250209467b48Spatrick 
prefersVectorizedAddressing() const2503*d415bd75Srobert bool AArch64TTIImpl::prefersVectorizedAddressing() const {
2504*d415bd75Srobert   return ST->hasSVE();
2505*d415bd75Srobert }
2506*d415bd75Srobert 
250773471bf0Spatrick InstructionCost
getMaskedMemoryOpCost(unsigned Opcode,Type * Src,Align Alignment,unsigned AddressSpace,TTI::TargetCostKind CostKind)250873471bf0Spatrick AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
250973471bf0Spatrick                                       Align Alignment, unsigned AddressSpace,
251073471bf0Spatrick                                       TTI::TargetCostKind CostKind) {
2511*d415bd75Srobert   if (useNeonVector(Src))
251273471bf0Spatrick     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
251373471bf0Spatrick                                         CostKind);
2514*d415bd75Srobert   auto LT = getTypeLegalizationCost(Src);
251573471bf0Spatrick   if (!LT.first.isValid())
251673471bf0Spatrick     return InstructionCost::getInvalid();
251773471bf0Spatrick 
251873471bf0Spatrick   // The code-generator is currently not able to handle scalable vectors
251973471bf0Spatrick   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
252073471bf0Spatrick   // it. This change will be removed when code-generation for these types is
252173471bf0Spatrick   // sufficiently reliable.
252273471bf0Spatrick   if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1))
252373471bf0Spatrick     return InstructionCost::getInvalid();
252473471bf0Spatrick 
2525*d415bd75Srobert   return LT.first;
2526*d415bd75Srobert }
2527*d415bd75Srobert 
getSVEGatherScatterOverhead(unsigned Opcode)2528*d415bd75Srobert static unsigned getSVEGatherScatterOverhead(unsigned Opcode) {
2529*d415bd75Srobert   return Opcode == Instruction::Load ? SVEGatherOverhead : SVEScatterOverhead;
253073471bf0Spatrick }
253173471bf0Spatrick 
getGatherScatterOpCost(unsigned Opcode,Type * DataTy,const Value * Ptr,bool VariableMask,Align Alignment,TTI::TargetCostKind CostKind,const Instruction * I)253273471bf0Spatrick InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
253373471bf0Spatrick     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
253473471bf0Spatrick     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
2535*d415bd75Srobert   if (useNeonVector(DataTy))
253673471bf0Spatrick     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
253773471bf0Spatrick                                          Alignment, CostKind, I);
253873471bf0Spatrick   auto *VT = cast<VectorType>(DataTy);
2539*d415bd75Srobert   auto LT = getTypeLegalizationCost(DataTy);
254073471bf0Spatrick   if (!LT.first.isValid())
254173471bf0Spatrick     return InstructionCost::getInvalid();
254273471bf0Spatrick 
254373471bf0Spatrick   // The code-generator is currently not able to handle scalable vectors
254473471bf0Spatrick   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
254573471bf0Spatrick   // it. This change will be removed when code-generation for these types is
254673471bf0Spatrick   // sufficiently reliable.
254773471bf0Spatrick   if (cast<VectorType>(DataTy)->getElementCount() ==
254873471bf0Spatrick       ElementCount::getScalable(1))
254973471bf0Spatrick     return InstructionCost::getInvalid();
255073471bf0Spatrick 
255173471bf0Spatrick   ElementCount LegalVF = LT.second.getVectorElementCount();
255273471bf0Spatrick   InstructionCost MemOpCost =
2553*d415bd75Srobert       getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind,
2554*d415bd75Srobert                       {TTI::OK_AnyValue, TTI::OP_None}, I);
2555*d415bd75Srobert   // Add on an overhead cost for using gathers/scatters.
2556*d415bd75Srobert   // TODO: At the moment this is applied uniformly for all CPUs, but at some
2557*d415bd75Srobert   // point we may want a per-CPU overhead.
2558*d415bd75Srobert   MemOpCost *= getSVEGatherScatterOverhead(Opcode);
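  // The final estimate below is therefore roughly:
  //   scalar memory-op cost * gather/scatter overhead
  //     * elements per legalized vector * number of legalized vectors.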
255973471bf0Spatrick   return LT.first * MemOpCost * getMaxNumElements(LegalVF);
256073471bf0Spatrick }
256173471bf0Spatrick 
useNeonVector(const Type * Ty) const256273471bf0Spatrick bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
256373471bf0Spatrick   return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
256473471bf0Spatrick }
256573471bf0Spatrick 
getMemoryOpCost(unsigned Opcode,Type * Ty,MaybeAlign Alignment,unsigned AddressSpace,TTI::TargetCostKind CostKind,TTI::OperandValueInfo OpInfo,const Instruction * I)256673471bf0Spatrick InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
256773471bf0Spatrick                                                 MaybeAlign Alignment,
256873471bf0Spatrick                                                 unsigned AddressSpace,
2569097a140dSpatrick                                                 TTI::TargetCostKind CostKind,
2570*d415bd75Srobert                                                 TTI::OperandValueInfo OpInfo,
257109467b48Spatrick                                                 const Instruction *I) {
257273471bf0Spatrick   EVT VT = TLI->getValueType(DL, Ty, true);
2573097a140dSpatrick   // Type legalization can't handle structs
257473471bf0Spatrick   if (VT == MVT::Other)
2575097a140dSpatrick     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
2576097a140dSpatrick                                   CostKind);
2577097a140dSpatrick 
2578*d415bd75Srobert   auto LT = getTypeLegalizationCost(Ty);
257973471bf0Spatrick   if (!LT.first.isValid())
258073471bf0Spatrick     return InstructionCost::getInvalid();
258173471bf0Spatrick 
258273471bf0Spatrick   // The code-generator is currently not able to handle scalable vectors
258373471bf0Spatrick   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
258473471bf0Spatrick   // it. This change will be removed when code-generation for these types is
258573471bf0Spatrick   // sufficiently reliable.
258673471bf0Spatrick   if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
258773471bf0Spatrick     if (VTy->getElementCount() == ElementCount::getScalable(1))
258873471bf0Spatrick       return InstructionCost::getInvalid();
258973471bf0Spatrick 
259073471bf0Spatrick   // TODO: consider latency as well for TCK_SizeAndLatency.
259173471bf0Spatrick   if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
259273471bf0Spatrick     return LT.first;
259373471bf0Spatrick 
259473471bf0Spatrick   if (CostKind != TTI::TCK_RecipThroughput)
259573471bf0Spatrick     return 1;
259609467b48Spatrick 
259709467b48Spatrick   if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
259809467b48Spatrick       LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
259909467b48Spatrick     // Unaligned stores are extremely inefficient. We don't split all
260009467b48Spatrick     // unaligned 128-bit stores because of the negative impact that has been
260109467b48Spatrick     // shown in practice on inlined block copy code.
260209467b48Spatrick     // We make such stores expensive so that we will only vectorize if there
260309467b48Spatrick     // are 6 other instructions getting vectorized.
260409467b48Spatrick     const int AmortizationCost = 6;
260509467b48Spatrick 
260609467b48Spatrick     return LT.first * 2 * AmortizationCost;
260709467b48Spatrick   }
260809467b48Spatrick 
2609*d415bd75Srobert   // Opaque ptr or ptr vector types are i64s and can be lowered to STP/LDPs.
2610*d415bd75Srobert   if (Ty->isPtrOrPtrVectorTy())
2611*d415bd75Srobert     return LT.first;
2612*d415bd75Srobert 
261373471bf0Spatrick   // Check truncating stores and extending loads.
261473471bf0Spatrick   if (useNeonVector(Ty) &&
261573471bf0Spatrick       Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
261673471bf0Spatrick     // v4i8 types are lowered to a scalar load/store and sshll/xtn.
261773471bf0Spatrick     if (VT == MVT::v4i8)
261873471bf0Spatrick       return 2;
261973471bf0Spatrick     // Otherwise we need to scalarize.
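    // (Informally: roughly one scalar memory access plus one lane
    // insert/extract per element, which is where the factor of two in the
    // estimate below comes from.)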
262073471bf0Spatrick     return cast<FixedVectorType>(Ty)->getNumElements() * 2;
262109467b48Spatrick   }
262209467b48Spatrick 
262309467b48Spatrick   return LT.first;
262409467b48Spatrick }
262509467b48Spatrick 
getInterleavedMemoryOpCost(unsigned Opcode,Type * VecTy,unsigned Factor,ArrayRef<unsigned> Indices,Align Alignment,unsigned AddressSpace,TTI::TargetCostKind CostKind,bool UseMaskForCond,bool UseMaskForGaps)262673471bf0Spatrick InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
2627097a140dSpatrick     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
2628097a140dSpatrick     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
2629097a140dSpatrick     bool UseMaskForCond, bool UseMaskForGaps) {
263009467b48Spatrick   assert(Factor >= 2 && "Invalid interleave factor");
2631097a140dSpatrick   auto *VecVTy = cast<FixedVectorType>(VecTy);
263209467b48Spatrick 
263309467b48Spatrick   if (!UseMaskForCond && !UseMaskForGaps &&
263409467b48Spatrick       Factor <= TLI->getMaxSupportedInterleaveFactor()) {
2635097a140dSpatrick     unsigned NumElts = VecVTy->getNumElements();
2636097a140dSpatrick     auto *SubVecTy =
2637097a140dSpatrick         FixedVectorType::get(VecTy->getScalarType(), NumElts / Factor);
263809467b48Spatrick 
263909467b48Spatrick     // ldN/stN only support legal vector types of size 64 or 128 in bits.
264009467b48Spatrick     // Accesses having vector types that are a multiple of 128 bits can be
264109467b48Spatrick     // matched to more than one ldN/stN instruction.
2642*d415bd75Srobert     bool UseScalable;
264309467b48Spatrick     if (NumElts % Factor == 0 &&
2644*d415bd75Srobert         TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
2645*d415bd75Srobert       return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
264609467b48Spatrick   }
264709467b48Spatrick 
264809467b48Spatrick   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
2649097a140dSpatrick                                            Alignment, AddressSpace, CostKind,
265009467b48Spatrick                                            UseMaskForCond, UseMaskForGaps);
265109467b48Spatrick }
265209467b48Spatrick 
265373471bf0Spatrick InstructionCost
getCostOfKeepingLiveOverCall(ArrayRef<Type * > Tys)265473471bf0Spatrick AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
265573471bf0Spatrick   InstructionCost Cost = 0;
2656097a140dSpatrick   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
265709467b48Spatrick   for (auto *I : Tys) {
265809467b48Spatrick     if (!I->isVectorTy())
265909467b48Spatrick       continue;
2660097a140dSpatrick     if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
2661097a140dSpatrick         128)
2662097a140dSpatrick       Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
2663097a140dSpatrick               getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
266409467b48Spatrick   }
266509467b48Spatrick   return Cost;
266609467b48Spatrick }
266709467b48Spatrick 
getMaxInterleaveFactor(unsigned VF)266809467b48Spatrick unsigned AArch64TTIImpl::getMaxInterleaveFactor(unsigned VF) {
266909467b48Spatrick   return ST->getMaxInterleaveFactor();
267009467b48Spatrick }
267109467b48Spatrick 
267209467b48Spatrick // For Falkor, we want to avoid having too many strided loads in a loop since
267309467b48Spatrick // that can exhaust the HW prefetcher resources.  We adjust the unroller
267409467b48Spatrick // MaxCount preference below to attempt to ensure unrolling doesn't create too
267509467b48Spatrick // many strided loads.
267609467b48Spatrick static void
getFalkorUnrollingPreferences(Loop * L,ScalarEvolution & SE,TargetTransformInfo::UnrollingPreferences & UP)267709467b48Spatrick getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
267809467b48Spatrick                               TargetTransformInfo::UnrollingPreferences &UP) {
267909467b48Spatrick   enum { MaxStridedLoads = 7 };
268009467b48Spatrick   auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
268109467b48Spatrick     int StridedLoads = 0;
268209467b48Spatrick     // FIXME? We could make this more precise by looking at the CFG and
268309467b48Spatrick     // e.g. not counting loads in each side of an if-then-else diamond.
268409467b48Spatrick     for (const auto BB : L->blocks()) {
268509467b48Spatrick       for (auto &I : *BB) {
268609467b48Spatrick         LoadInst *LMemI = dyn_cast<LoadInst>(&I);
268709467b48Spatrick         if (!LMemI)
268809467b48Spatrick           continue;
268909467b48Spatrick 
269009467b48Spatrick         Value *PtrValue = LMemI->getPointerOperand();
269109467b48Spatrick         if (L->isLoopInvariant(PtrValue))
269209467b48Spatrick           continue;
269309467b48Spatrick 
269409467b48Spatrick         const SCEV *LSCEV = SE.getSCEV(PtrValue);
269509467b48Spatrick         const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
269609467b48Spatrick         if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
269709467b48Spatrick           continue;
269809467b48Spatrick 
269909467b48Spatrick         // FIXME? We could take pairing of unrolled load copies into account
270009467b48Spatrick         // by looking at the AddRec, but we would probably have to limit this
270109467b48Spatrick         // to loops with no stores or other memory optimization barriers.
270209467b48Spatrick         ++StridedLoads;
270309467b48Spatrick         // We've seen enough strided loads that seeing more won't make a
270409467b48Spatrick         // difference.
270509467b48Spatrick         if (StridedLoads > MaxStridedLoads / 2)
270609467b48Spatrick           return StridedLoads;
270709467b48Spatrick       }
270809467b48Spatrick     }
270909467b48Spatrick     return StridedLoads;
271009467b48Spatrick   };
271109467b48Spatrick 
271209467b48Spatrick   int StridedLoads = countStridedLoads(L, SE);
271309467b48Spatrick   LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
271409467b48Spatrick                     << " strided loads\n");
271509467b48Spatrick   // Pick the largest power of 2 unroll count that won't result in too many
271609467b48Spatrick   // strided loads.
271709467b48Spatrick   if (StridedLoads) {
271809467b48Spatrick     UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
271909467b48Spatrick     LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
272009467b48Spatrick                       << UP.MaxCount << '\n');
272109467b48Spatrick   }
272209467b48Spatrick }
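// Worked example of the clamp above: with MaxStridedLoads == 7, a loop with
// 3 detected strided loads gets UP.MaxCount = 1 << Log2_32(7 / 3) = 2, and a
// loop with a single strided load gets 1 << Log2_32(7) = 4, keeping the
// unrolled body at or below roughly MaxStridedLoads strided loads.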
272309467b48Spatrick 
272409467b48Spatrick void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
2725*d415bd75Srobert                                              TTI::UnrollingPreferences &UP,
2726*d415bd75Srobert                                              OptimizationRemarkEmitter *ORE) {
272709467b48Spatrick   // Enable partial unrolling and runtime unrolling.
2728*d415bd75Srobert   BaseT::getUnrollingPreferences(L, SE, UP, ORE);
2729*d415bd75Srobert 
2730*d415bd75Srobert   UP.UpperBound = true;
273109467b48Spatrick 
273209467b48Spatrick   // An inner loop is more likely to be hot, and its runtime check can be
273309467b48Spatrick   // promoted out by the LICM pass, so the overhead is lower; try a larger
273409467b48Spatrick   // threshold to unroll more loops.
273509467b48Spatrick   if (L->getLoopDepth() > 1)
273609467b48Spatrick     UP.PartialThreshold *= 2;
273709467b48Spatrick 
273809467b48Spatrick   // Disable partial & runtime unrolling on -Os.
273909467b48Spatrick   UP.PartialOptSizeThreshold = 0;
274009467b48Spatrick 
274109467b48Spatrick   if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
274209467b48Spatrick       EnableFalkorHWPFUnrollFix)
274309467b48Spatrick     getFalkorUnrollingPreferences(L, SE, UP);
274473471bf0Spatrick 
274573471bf0Spatrick   // Scan the loop: don't unroll loops with calls as this could prevent
274673471bf0Spatrick   // inlining. Don't unroll vector loops either, as they don't benefit much from
274773471bf0Spatrick   // unrolling.
274873471bf0Spatrick   for (auto *BB : L->getBlocks()) {
274973471bf0Spatrick     for (auto &I : *BB) {
275073471bf0Spatrick       // Don't unroll vectorised loops.
275173471bf0Spatrick       if (I.getType()->isVectorTy())
275273471bf0Spatrick         return;
275373471bf0Spatrick 
275473471bf0Spatrick       if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
275573471bf0Spatrick         if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
275673471bf0Spatrick           if (!isLoweredToCall(F))
275773471bf0Spatrick             continue;
275873471bf0Spatrick         }
275973471bf0Spatrick         return;
276073471bf0Spatrick       }
276173471bf0Spatrick     }
276273471bf0Spatrick   }
276373471bf0Spatrick 
276473471bf0Spatrick   // Enable runtime unrolling for in-order models.
276573471bf0Spatrick   // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so
276673471bf0Spatrick   // by checking for that case we can ensure that the default behaviour is
276773471bf0Spatrick   // unchanged.
276873471bf0Spatrick   if (ST->getProcFamily() != AArch64Subtarget::Others &&
276973471bf0Spatrick       !ST->getSchedModel().isOutOfOrder()) {
277073471bf0Spatrick     UP.Runtime = true;
277173471bf0Spatrick     UP.Partial = true;
277273471bf0Spatrick     UP.UnrollRemainder = true;
277373471bf0Spatrick     UP.DefaultUnrollRuntimeCount = 4;
277473471bf0Spatrick 
277573471bf0Spatrick     UP.UnrollAndJam = true;
277673471bf0Spatrick     UP.UnrollAndJamInnerLoopThreshold = 60;
277773471bf0Spatrick   }
277809467b48Spatrick }
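// Net effect, assuming an in-order core such as Cortex-A55 is selected via
// -mcpu: a small scalar loop with no real calls gets runtime/partial
// unrolling with a default count of 4 plus unroll-and-jam, while a loop
// containing vector values or a genuine call returns early above and keeps
// only the settings established earlier in this function.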
277909467b48Spatrick 
2780097a140dSpatrick void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
2781097a140dSpatrick                                            TTI::PeelingPreferences &PP) {
2782097a140dSpatrick   BaseT::getPeelingPreferences(L, SE, PP);
2783097a140dSpatrick }
2784097a140dSpatrick 
278509467b48Spatrick Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
278609467b48Spatrick                                                          Type *ExpectedType) {
278709467b48Spatrick   switch (Inst->getIntrinsicID()) {
278809467b48Spatrick   default:
278909467b48Spatrick     return nullptr;
279009467b48Spatrick   case Intrinsic::aarch64_neon_st2:
279109467b48Spatrick   case Intrinsic::aarch64_neon_st3:
279209467b48Spatrick   case Intrinsic::aarch64_neon_st4: {
279309467b48Spatrick     // Create a struct type
279409467b48Spatrick     StructType *ST = dyn_cast<StructType>(ExpectedType);
279509467b48Spatrick     if (!ST)
279609467b48Spatrick       return nullptr;
2797*d415bd75Srobert     unsigned NumElts = Inst->arg_size() - 1;
279809467b48Spatrick     if (ST->getNumElements() != NumElts)
279909467b48Spatrick       return nullptr;
280009467b48Spatrick     for (unsigned i = 0, e = NumElts; i != e; ++i) {
280109467b48Spatrick       if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
280209467b48Spatrick         return nullptr;
280309467b48Spatrick     }
2804*d415bd75Srobert     Value *Res = PoisonValue::get(ExpectedType);
280509467b48Spatrick     IRBuilder<> Builder(Inst);
280609467b48Spatrick     for (unsigned i = 0, e = NumElts; i != e; ++i) {
280709467b48Spatrick       Value *L = Inst->getArgOperand(i);
280809467b48Spatrick       Res = Builder.CreateInsertValue(Res, L, i);
280909467b48Spatrick     }
281009467b48Spatrick     return Res;
281109467b48Spatrick   }
281209467b48Spatrick   case Intrinsic::aarch64_neon_ld2:
281309467b48Spatrick   case Intrinsic::aarch64_neon_ld3:
281409467b48Spatrick   case Intrinsic::aarch64_neon_ld4:
281509467b48Spatrick     if (Inst->getType() == ExpectedType)
281609467b48Spatrick       return Inst;
281709467b48Spatrick     return nullptr;
281809467b48Spatrick   }
281909467b48Spatrick }
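// Example of the st2/st3/st4 path (hypothetical IR): for an
// @llvm.aarch64.neon.st3 call storing three <4 x i32> values, with
// ExpectedType == { <4 x i32>, <4 x i32>, <4 x i32> }, the stored value is
// reassembled as three insertvalues into a poison struct, so that e.g. a
// redundant ld3 from the same address can be replaced with it.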
282009467b48Spatrick 
282109467b48Spatrick bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
282209467b48Spatrick                                         MemIntrinsicInfo &Info) {
282309467b48Spatrick   switch (Inst->getIntrinsicID()) {
282409467b48Spatrick   default:
282509467b48Spatrick     break;
282609467b48Spatrick   case Intrinsic::aarch64_neon_ld2:
282709467b48Spatrick   case Intrinsic::aarch64_neon_ld3:
282809467b48Spatrick   case Intrinsic::aarch64_neon_ld4:
282909467b48Spatrick     Info.ReadMem = true;
283009467b48Spatrick     Info.WriteMem = false;
283109467b48Spatrick     Info.PtrVal = Inst->getArgOperand(0);
283209467b48Spatrick     break;
283309467b48Spatrick   case Intrinsic::aarch64_neon_st2:
283409467b48Spatrick   case Intrinsic::aarch64_neon_st3:
283509467b48Spatrick   case Intrinsic::aarch64_neon_st4:
283609467b48Spatrick     Info.ReadMem = false;
283709467b48Spatrick     Info.WriteMem = true;
2838*d415bd75Srobert     Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
283909467b48Spatrick     break;
284009467b48Spatrick   }
284109467b48Spatrick 
284209467b48Spatrick   switch (Inst->getIntrinsicID()) {
284309467b48Spatrick   default:
284409467b48Spatrick     return false;
284509467b48Spatrick   case Intrinsic::aarch64_neon_ld2:
284609467b48Spatrick   case Intrinsic::aarch64_neon_st2:
284709467b48Spatrick     Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
284809467b48Spatrick     break;
284909467b48Spatrick   case Intrinsic::aarch64_neon_ld3:
285009467b48Spatrick   case Intrinsic::aarch64_neon_st3:
285109467b48Spatrick     Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
285209467b48Spatrick     break;
285309467b48Spatrick   case Intrinsic::aarch64_neon_ld4:
285409467b48Spatrick   case Intrinsic::aarch64_neon_st4:
285509467b48Spatrick     Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
285609467b48Spatrick     break;
285709467b48Spatrick   }
285809467b48Spatrick   return true;
285909467b48Spatrick }
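// For instance, an @llvm.aarch64.neon.ld2 call is reported as ReadMem with
// PtrVal taken from operand 0 and MatchingId == VECTOR_LDST_TWO_ELEMENTS,
// while the corresponding st2 is WriteMem with PtrVal taken from its last
// operand; passes such as EarlyCSE can then pair matching loads and stores
// through this interface.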
286009467b48Spatrick 
286109467b48Spatrick /// See if \p I should be considered for address type promotion. We check if \p
286209467b48Spatrick /// I is a sext with the right type that is used in memory accesses. If it is
286309467b48Spatrick /// used in a "complex" getelementptr, we allow it to be promoted without
286409467b48Spatrick /// finding other sext instructions that sign extended the same initial value.
286509467b48Spatrick /// A getelementptr is considered "complex" if it has more than 2 operands.
286609467b48Spatrick bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
286709467b48Spatrick     const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
286809467b48Spatrick   bool Considerable = false;
286909467b48Spatrick   AllowPromotionWithoutCommonHeader = false;
287009467b48Spatrick   if (!isa<SExtInst>(&I))
287109467b48Spatrick     return false;
287209467b48Spatrick   Type *ConsideredSExtType =
287309467b48Spatrick       Type::getInt64Ty(I.getParent()->getParent()->getContext());
287409467b48Spatrick   if (I.getType() != ConsideredSExtType)
287509467b48Spatrick     return false;
287609467b48Spatrick   // See if the sext is the one with the right type and used in at least one
287709467b48Spatrick   // GetElementPtrInst.
287809467b48Spatrick   for (const User *U : I.users()) {
287909467b48Spatrick     if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
288009467b48Spatrick       Considerable = true;
288109467b48Spatrick       // A getelementptr is considered "complex" if it has more than 2
288209467b48Spatrick       // operands. We will promote a SExt used in such a complex GEP, as we
288309467b48Spatrick       // expect some of the computation to be merged if it is done on 64 bits.
288409467b48Spatrick       if (GEPInst->getNumOperands() > 2) {
288509467b48Spatrick         AllowPromotionWithoutCommonHeader = true;
288609467b48Spatrick         break;
288709467b48Spatrick       }
288809467b48Spatrick     }
288909467b48Spatrick   }
289009467b48Spatrick   return Considerable;
289109467b48Spatrick }
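// Concrete illustration (hypothetical IR):
//   %idx = sext i32 %i to i64
//   %p   = getelementptr inbounds i32, ptr %base, i64 %idx
// The GEP has only 2 operands, so the sext is Considerable but still needs a
// common header, whereas a user such as
//   getelementptr [16 x i32], ptr %base, i64 0, i64 %idx
// has 3 operands and also sets AllowPromotionWithoutCommonHeader.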
289209467b48Spatrick 
289373471bf0Spatrick bool AArch64TTIImpl::isLegalToVectorizeReduction(
289473471bf0Spatrick     const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
289573471bf0Spatrick   if (!VF.isScalable())
289673471bf0Spatrick     return true;
289773471bf0Spatrick 
289873471bf0Spatrick   Type *Ty = RdxDesc.getRecurrenceType();
289973471bf0Spatrick   if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
290009467b48Spatrick     return false;
290173471bf0Spatrick 
290273471bf0Spatrick   switch (RdxDesc.getRecurrenceKind()) {
290373471bf0Spatrick   case RecurKind::Add:
290473471bf0Spatrick   case RecurKind::FAdd:
290573471bf0Spatrick   case RecurKind::And:
290673471bf0Spatrick   case RecurKind::Or:
290773471bf0Spatrick   case RecurKind::Xor:
290873471bf0Spatrick   case RecurKind::SMin:
290973471bf0Spatrick   case RecurKind::SMax:
291073471bf0Spatrick   case RecurKind::UMin:
291173471bf0Spatrick   case RecurKind::UMax:
291273471bf0Spatrick   case RecurKind::FMin:
291373471bf0Spatrick   case RecurKind::FMax:
2914*d415bd75Srobert   case RecurKind::SelectICmp:
2915*d415bd75Srobert   case RecurKind::SelectFCmp:
2916*d415bd75Srobert   case RecurKind::FMulAdd:
291773471bf0Spatrick     return true;
291809467b48Spatrick   default:
291909467b48Spatrick     return false;
292009467b48Spatrick   }
292173471bf0Spatrick }
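// E.g. an i32 add reduction or a float fmax reduction remains legal at
// scalable VFs, whereas a bfloat reduction or an unhandled kind such as
// RecurKind::Mul is rejected, so the vectorizer will not choose a scalable
// VF for such loops.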
292209467b48Spatrick 
292373471bf0Spatrick InstructionCost
292473471bf0Spatrick AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
292573471bf0Spatrick                                        bool IsUnsigned,
2926097a140dSpatrick                                        TTI::TargetCostKind CostKind) {
2927*d415bd75Srobert   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
292809467b48Spatrick 
2929*d415bd75Srobert   if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
2930*d415bd75Srobert     return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
2931*d415bd75Srobert 
2932*d415bd75Srobert   assert((isa<ScalableVectorType>(Ty) == isa<ScalableVectorType>(CondTy)) &&
2933*d415bd75Srobert          "Both vectors need to be equally scalable");
2934*d415bd75Srobert 
293573471bf0Spatrick   InstructionCost LegalizationCost = 0;
293673471bf0Spatrick   if (LT.first > 1) {
293773471bf0Spatrick     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
2938*d415bd75Srobert     unsigned MinMaxOpcode =
2939*d415bd75Srobert         Ty->isFPOrFPVectorTy()
2940*d415bd75Srobert             ? Intrinsic::maxnum
2941*d415bd75Srobert             : (IsUnsigned ? Intrinsic::umin : Intrinsic::smin);
2942*d415bd75Srobert     IntrinsicCostAttributes Attrs(MinMaxOpcode, LegalVTy, {LegalVTy, LegalVTy});
2943*d415bd75Srobert     LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
294473471bf0Spatrick   }
294509467b48Spatrick 
294673471bf0Spatrick   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
294773471bf0Spatrick }
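// For example, an smin reduction over <vscale x 8 x i32> legalizes to two
// nxv4i32 halves (LT.first == 2), so one extra smin intrinsic is costed for
// combining the halves before the fixed horizontal-reduction cost of 2 is
// added.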
294873471bf0Spatrick 
294973471bf0Spatrick InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
295073471bf0Spatrick     unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
2951*d415bd75Srobert   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
295273471bf0Spatrick   InstructionCost LegalizationCost = 0;
295373471bf0Spatrick   if (LT.first > 1) {
295473471bf0Spatrick     Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
295573471bf0Spatrick     LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
295673471bf0Spatrick     LegalizationCost *= LT.first - 1;
295773471bf0Spatrick   }
295873471bf0Spatrick 
295973471bf0Spatrick   int ISD = TLI->InstructionOpcodeToISD(Opcode);
296073471bf0Spatrick   assert(ISD && "Invalid opcode");
296173471bf0Spatrick   // Add the final reduction cost for the legal horizontal reduction
296273471bf0Spatrick   switch (ISD) {
296373471bf0Spatrick   case ISD::ADD:
296473471bf0Spatrick   case ISD::AND:
296573471bf0Spatrick   case ISD::OR:
296673471bf0Spatrick   case ISD::XOR:
296773471bf0Spatrick   case ISD::FADD:
296873471bf0Spatrick     return LegalizationCost + 2;
296973471bf0Spatrick   default:
297073471bf0Spatrick     return InstructionCost::getInvalid();
297173471bf0Spatrick   }
297273471bf0Spatrick }
297373471bf0Spatrick 
297473471bf0Spatrick InstructionCost
297573471bf0Spatrick AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
2976*d415bd75Srobert                                            std::optional<FastMathFlags> FMF,
297773471bf0Spatrick                                            TTI::TargetCostKind CostKind) {
297873471bf0Spatrick   if (TTI::requiresOrderedReduction(FMF)) {
2979*d415bd75Srobert     if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
2980*d415bd75Srobert       InstructionCost BaseCost =
2981*d415bd75Srobert           BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
2982*d415bd75Srobert       // Add on extra cost to reflect the extra overhead on some CPUs. We still
2983*d415bd75Srobert       // end up vectorizing for more computationally intensive loops.
2984*d415bd75Srobert       return BaseCost + FixedVTy->getNumElements();
2985*d415bd75Srobert     }
298673471bf0Spatrick 
298773471bf0Spatrick     if (Opcode != Instruction::FAdd)
298873471bf0Spatrick       return InstructionCost::getInvalid();
298973471bf0Spatrick 
299073471bf0Spatrick     auto *VTy = cast<ScalableVectorType>(ValTy);
299173471bf0Spatrick     InstructionCost Cost =
299273471bf0Spatrick         getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
299373471bf0Spatrick     Cost *= getMaxNumElements(VTy->getElementCount());
299473471bf0Spatrick     return Cost;
299573471bf0Spatrick   }
299673471bf0Spatrick 
299773471bf0Spatrick   if (isa<ScalableVectorType>(ValTy))
299873471bf0Spatrick     return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);
299973471bf0Spatrick 
3000*d415bd75Srobert   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
300109467b48Spatrick   MVT MTy = LT.second;
300209467b48Spatrick   int ISD = TLI->InstructionOpcodeToISD(Opcode);
300309467b48Spatrick   assert(ISD && "Invalid opcode");
300409467b48Spatrick 
300509467b48Spatrick   // Horizontal adds can use the 'addv' instruction. We model the cost of these
300673471bf0Spatrick   // instructions as twice a normal vector add, plus 1 for each legalization
300773471bf0Spatrick   // step (LT.first). This is the only arithmetic vector reduction operation for
300873471bf0Spatrick   // which we have an instruction.
300973471bf0Spatrick   // OR, XOR and AND costs should match the codegen from:
301073471bf0Spatrick   // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
301173471bf0Spatrick   // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
301273471bf0Spatrick   // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
301309467b48Spatrick   static const CostTblEntry CostTblNoPairwise[]{
301473471bf0Spatrick       {ISD::ADD, MVT::v8i8,   2},
301573471bf0Spatrick       {ISD::ADD, MVT::v16i8,  2},
301673471bf0Spatrick       {ISD::ADD, MVT::v4i16,  2},
301773471bf0Spatrick       {ISD::ADD, MVT::v8i16,  2},
301873471bf0Spatrick       {ISD::ADD, MVT::v4i32,  2},
3019*d415bd75Srobert       {ISD::ADD, MVT::v2i64,  2},
302073471bf0Spatrick       {ISD::OR,  MVT::v8i8,  15},
302173471bf0Spatrick       {ISD::OR,  MVT::v16i8, 17},
302273471bf0Spatrick       {ISD::OR,  MVT::v4i16,  7},
302373471bf0Spatrick       {ISD::OR,  MVT::v8i16,  9},
302473471bf0Spatrick       {ISD::OR,  MVT::v2i32,  3},
302573471bf0Spatrick       {ISD::OR,  MVT::v4i32,  5},
302673471bf0Spatrick       {ISD::OR,  MVT::v2i64,  3},
302773471bf0Spatrick       {ISD::XOR, MVT::v8i8,  15},
302873471bf0Spatrick       {ISD::XOR, MVT::v16i8, 17},
302973471bf0Spatrick       {ISD::XOR, MVT::v4i16,  7},
303073471bf0Spatrick       {ISD::XOR, MVT::v8i16,  9},
303173471bf0Spatrick       {ISD::XOR, MVT::v2i32,  3},
303273471bf0Spatrick       {ISD::XOR, MVT::v4i32,  5},
303373471bf0Spatrick       {ISD::XOR, MVT::v2i64,  3},
303473471bf0Spatrick       {ISD::AND, MVT::v8i8,  15},
303573471bf0Spatrick       {ISD::AND, MVT::v16i8, 17},
303673471bf0Spatrick       {ISD::AND, MVT::v4i16,  7},
303773471bf0Spatrick       {ISD::AND, MVT::v8i16,  9},
303873471bf0Spatrick       {ISD::AND, MVT::v2i32,  3},
303973471bf0Spatrick       {ISD::AND, MVT::v4i32,  5},
304073471bf0Spatrick       {ISD::AND, MVT::v2i64,  3},
304109467b48Spatrick   };
304273471bf0Spatrick   switch (ISD) {
304373471bf0Spatrick   default:
304473471bf0Spatrick     break;
304573471bf0Spatrick   case ISD::ADD:
304609467b48Spatrick     if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
304773471bf0Spatrick       return (LT.first - 1) + Entry->Cost;
304873471bf0Spatrick     break;
304973471bf0Spatrick   case ISD::XOR:
305073471bf0Spatrick   case ISD::AND:
305173471bf0Spatrick   case ISD::OR:
305273471bf0Spatrick     const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
305373471bf0Spatrick     if (!Entry)
305473471bf0Spatrick       break;
305573471bf0Spatrick     auto *ValVTy = cast<FixedVectorType>(ValTy);
305673471bf0Spatrick     if (!ValVTy->getElementType()->isIntegerTy(1) &&
305773471bf0Spatrick         MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
305873471bf0Spatrick         isPowerOf2_32(ValVTy->getNumElements())) {
305973471bf0Spatrick       InstructionCost ExtraCost = 0;
306073471bf0Spatrick       if (LT.first != 1) {
306173471bf0Spatrick         // Type needs to be split, so there is an extra cost of LT.first - 1
306273471bf0Spatrick         // arithmetic ops.
306373471bf0Spatrick         auto *Ty = FixedVectorType::get(ValTy->getElementType(),
306473471bf0Spatrick                                         MTy.getVectorNumElements());
306573471bf0Spatrick         ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
306673471bf0Spatrick         ExtraCost *= LT.first - 1;
306773471bf0Spatrick       }
306873471bf0Spatrick       return Entry->Cost + ExtraCost;
306973471bf0Spatrick     }
307073471bf0Spatrick     break;
307173471bf0Spatrick   }
307273471bf0Spatrick   return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
307309467b48Spatrick }
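// A couple of data points from the table above: an add reduction of
// <4 x i32> costs 2 (a single addv), while an or reduction of <16 x i8>
// costs 17, reflecting the much longer expansion checked in
// llvm/test/CodeGen/AArch64/reduce-or.ll.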
307409467b48Spatrick 
307573471bf0Spatrick InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
307673471bf0Spatrick   static const CostTblEntry ShuffleTbl[] = {
307773471bf0Spatrick       { TTI::SK_Splice, MVT::nxv16i8,  1 },
307873471bf0Spatrick       { TTI::SK_Splice, MVT::nxv8i16,  1 },
307973471bf0Spatrick       { TTI::SK_Splice, MVT::nxv4i32,  1 },
308073471bf0Spatrick       { TTI::SK_Splice, MVT::nxv2i64,  1 },
308173471bf0Spatrick       { TTI::SK_Splice, MVT::nxv2f16,  1 },
308273471bf0Spatrick       { TTI::SK_Splice, MVT::nxv4f16,  1 },
308373471bf0Spatrick       { TTI::SK_Splice, MVT::nxv8f16,  1 },
308473471bf0Spatrick       { TTI::SK_Splice, MVT::nxv2bf16, 1 },
308573471bf0Spatrick       { TTI::SK_Splice, MVT::nxv4bf16, 1 },
308673471bf0Spatrick       { TTI::SK_Splice, MVT::nxv8bf16, 1 },
308773471bf0Spatrick       { TTI::SK_Splice, MVT::nxv2f32,  1 },
308873471bf0Spatrick       { TTI::SK_Splice, MVT::nxv4f32,  1 },
308973471bf0Spatrick       { TTI::SK_Splice, MVT::nxv2f64,  1 },
309073471bf0Spatrick   };
309173471bf0Spatrick 
3092*d415bd75Srobert   // The code-generator is currently not able to handle scalable vectors
3093*d415bd75Srobert   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3094*d415bd75Srobert   // it. This change will be removed when code-generation for these types is
3095*d415bd75Srobert   // sufficiently reliable.
3096*d415bd75Srobert   if (Tp->getElementCount() == ElementCount::getScalable(1))
3097*d415bd75Srobert     return InstructionCost::getInvalid();
3098*d415bd75Srobert 
3099*d415bd75Srobert   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
310073471bf0Spatrick   Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
310173471bf0Spatrick   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
310273471bf0Spatrick   EVT PromotedVT = LT.second.getScalarType() == MVT::i1
310373471bf0Spatrick                        ? TLI->getPromotedVTForPredicate(EVT(LT.second))
310473471bf0Spatrick                        : LT.second;
310573471bf0Spatrick   Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
310673471bf0Spatrick   InstructionCost LegalizationCost = 0;
310773471bf0Spatrick   if (Index < 0) {
310873471bf0Spatrick     LegalizationCost =
310973471bf0Spatrick         getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
311073471bf0Spatrick                            CmpInst::BAD_ICMP_PREDICATE, CostKind) +
311173471bf0Spatrick         getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
311273471bf0Spatrick                            CmpInst::BAD_ICMP_PREDICATE, CostKind);
311373471bf0Spatrick   }
311473471bf0Spatrick 
311573471bf0Spatrick   // Predicated splices are promoted during lowering (see
311673471bf0Spatrick   // AArch64ISelLowering.cpp); the cost is computed on the promoted type.
311773471bf0Spatrick   if (LT.second.getScalarType() == MVT::i1) {
311873471bf0Spatrick     LegalizationCost +=
311973471bf0Spatrick         getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
312073471bf0Spatrick                          TTI::CastContextHint::None, CostKind) +
312173471bf0Spatrick         getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
312273471bf0Spatrick                          TTI::CastContextHint::None, CostKind);
312373471bf0Spatrick   }
312473471bf0Spatrick   const auto *Entry =
312573471bf0Spatrick       CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
312673471bf0Spatrick   assert(Entry && "Illegal Type for Splice");
312773471bf0Spatrick   LegalizationCost += Entry->Cost;
312873471bf0Spatrick   return LegalizationCost * LT.first;
312973471bf0Spatrick }
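// For instance, splicing two nxv4i32 vectors is a single ext-style
// instruction (table cost 1), while an nxv16i1 predicate splice is costed on
// the promoted nxv16i8 type and additionally pays for the zext/trunc casts
// (plus the compare and select when Index is negative).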
313073471bf0Spatrick 
313173471bf0Spatrick InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
313273471bf0Spatrick                                                VectorType *Tp,
3133*d415bd75Srobert                                                ArrayRef<int> Mask,
3134*d415bd75Srobert                                                TTI::TargetCostKind CostKind,
3135*d415bd75Srobert                                                int Index, VectorType *SubTp,
3136*d415bd75Srobert                                                ArrayRef<const Value *> Args) {
3137*d415bd75Srobert   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
3138*d415bd75Srobert   // If we have a Mask, and the LT is being legalized somehow, split the Mask
3139*d415bd75Srobert   // into smaller vectors and sum the cost of each shuffle.
3140*d415bd75Srobert   if (!Mask.empty() && isa<FixedVectorType>(Tp) && LT.second.isVector() &&
3141*d415bd75Srobert       Tp->getScalarSizeInBits() == LT.second.getScalarSizeInBits() &&
3142*d415bd75Srobert       cast<FixedVectorType>(Tp)->getNumElements() >
3143*d415bd75Srobert           LT.second.getVectorNumElements() &&
3144*d415bd75Srobert       !Index && !SubTp) {
3145*d415bd75Srobert     unsigned TpNumElts = cast<FixedVectorType>(Tp)->getNumElements();
3146*d415bd75Srobert     assert(Mask.size() == TpNumElts && "Expected Mask and Tp size to match!");
3147*d415bd75Srobert     unsigned LTNumElts = LT.second.getVectorNumElements();
3148*d415bd75Srobert     unsigned NumVecs = (TpNumElts + LTNumElts - 1) / LTNumElts;
3149*d415bd75Srobert     VectorType *NTp =
3150*d415bd75Srobert         VectorType::get(Tp->getScalarType(), LT.second.getVectorElementCount());
3151*d415bd75Srobert     InstructionCost Cost;
3152*d415bd75Srobert     for (unsigned N = 0; N < NumVecs; N++) {
3153*d415bd75Srobert       SmallVector<int> NMask;
3154*d415bd75Srobert       // Split the existing mask into chunks of size LTNumElts. Track the source
3155*d415bd75Srobert       // sub-vectors to ensure the result has at most 2 inputs.
3156*d415bd75Srobert       unsigned Source1, Source2;
3157*d415bd75Srobert       unsigned NumSources = 0;
3158*d415bd75Srobert       for (unsigned E = 0; E < LTNumElts; E++) {
3159*d415bd75Srobert         int MaskElt = (N * LTNumElts + E < TpNumElts) ? Mask[N * LTNumElts + E]
3160*d415bd75Srobert                                                       : UndefMaskElem;
3161*d415bd75Srobert         if (MaskElt < 0) {
3162*d415bd75Srobert           NMask.push_back(UndefMaskElem);
3163*d415bd75Srobert           continue;
3164*d415bd75Srobert         }
3165*d415bd75Srobert 
3166*d415bd75Srobert         // Calculate which source sub-vector of the input this element comes
3167*d415bd75Srobert         // from and whether it is new to us.
3168*d415bd75Srobert         unsigned Source = MaskElt / LTNumElts;
3169*d415bd75Srobert         if (NumSources == 0) {
3170*d415bd75Srobert           Source1 = Source;
3171*d415bd75Srobert           NumSources = 1;
3172*d415bd75Srobert         } else if (NumSources == 1 && Source != Source1) {
3173*d415bd75Srobert           Source2 = Source;
3174*d415bd75Srobert           NumSources = 2;
3175*d415bd75Srobert         } else if (NumSources >= 2 && Source != Source1 && Source != Source2) {
3176*d415bd75Srobert           NumSources++;
3177*d415bd75Srobert         }
3178*d415bd75Srobert 
3179*d415bd75Srobert         // Add to the new mask. For the NumSources>2 case these are not correct,
3180*d415bd75Srobert         // but are only used for the modular lane number.
3181*d415bd75Srobert         if (Source == Source1)
3182*d415bd75Srobert           NMask.push_back(MaskElt % LTNumElts);
3183*d415bd75Srobert         else if (Source == Source2)
3184*d415bd75Srobert           NMask.push_back(MaskElt % LTNumElts + LTNumElts);
3185*d415bd75Srobert         else
3186*d415bd75Srobert           NMask.push_back(MaskElt % LTNumElts);
3187*d415bd75Srobert       }
3188*d415bd75Srobert       // If the sub-mask has at most 2 input sub-vectors then re-cost it using
3189*d415bd75Srobert       // getShuffleCost. If not then cost it using the worst case.
3190*d415bd75Srobert       if (NumSources <= 2)
3191*d415bd75Srobert         Cost += getShuffleCost(NumSources <= 1 ? TTI::SK_PermuteSingleSrc
3192*d415bd75Srobert                                                : TTI::SK_PermuteTwoSrc,
3193*d415bd75Srobert                                NTp, NMask, CostKind, 0, nullptr, Args);
3194*d415bd75Srobert       else if (any_of(enumerate(NMask), [&](const auto &ME) {
3195*d415bd75Srobert                  return ME.value() % LTNumElts == ME.index();
3196*d415bd75Srobert                }))
3197*d415bd75Srobert         Cost += LTNumElts - 1;
3198*d415bd75Srobert       else
3199*d415bd75Srobert         Cost += LTNumElts;
3200*d415bd75Srobert     }
3201*d415bd75Srobert     return Cost;
3202*d415bd75Srobert   }
3203*d415bd75Srobert 
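  // Example of the splitting above: a shuffle of a 16-element i32 vector
  // legalizes to v4i32 (LTNumElts == 4), so the mask is rewritten as four
  // 4-element sub-masks; each chunk drawing from at most two legal
  // sub-vectors is re-costed recursively, and anything wider is charged the
  // worst case of LTNumElts (or LTNumElts - 1 if some lane is already in
  // place).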
320473471bf0Spatrick   Kind = improveShuffleKindFromMask(Kind, Mask);
3205*d415bd75Srobert 
3206*d415bd75Srobert   // Check for broadcast loads.
3207*d415bd75Srobert   if (Kind == TTI::SK_Broadcast) {
3208*d415bd75Srobert     bool IsLoad = !Args.empty() && isa<LoadInst>(Args[0]);
3209*d415bd75Srobert     if (IsLoad && LT.second.isVector() &&
3210*d415bd75Srobert         isLegalBroadcastLoad(Tp->getElementType(),
3211*d415bd75Srobert                              LT.second.getVectorElementCount()))
3212*d415bd75Srobert       return 0; // broadcast is handled by ld1r
3213*d415bd75Srobert   }
3214*d415bd75Srobert 
3215*d415bd75Srobert   // If we have 4 elements for the shuffle and a Mask, get the cost straight
3216*d415bd75Srobert   // from the perfect shuffle tables.
3217*d415bd75Srobert   if (Mask.size() == 4 && Tp->getElementCount() == ElementCount::getFixed(4) &&
3218*d415bd75Srobert       (Tp->getScalarSizeInBits() == 16 || Tp->getScalarSizeInBits() == 32) &&
3219*d415bd75Srobert       all_of(Mask, [](int E) { return E < 8; }))
3220*d415bd75Srobert     return getPerfectShuffleCost(Mask);
3221*d415bd75Srobert 
322209467b48Spatrick   if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
322373471bf0Spatrick       Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
3224*d415bd75Srobert       Kind == TTI::SK_Reverse || Kind == TTI::SK_Splice) {
322509467b48Spatrick     static const CostTblEntry ShuffleTbl[] = {
322609467b48Spatrick         // Broadcast shuffle kinds can be performed with 'dup'.
322709467b48Spatrick         {TTI::SK_Broadcast, MVT::v8i8, 1},
322809467b48Spatrick         {TTI::SK_Broadcast, MVT::v16i8, 1},
322909467b48Spatrick         {TTI::SK_Broadcast, MVT::v4i16, 1},
323009467b48Spatrick         {TTI::SK_Broadcast, MVT::v8i16, 1},
323109467b48Spatrick         {TTI::SK_Broadcast, MVT::v2i32, 1},
323209467b48Spatrick         {TTI::SK_Broadcast, MVT::v4i32, 1},
323309467b48Spatrick         {TTI::SK_Broadcast, MVT::v2i64, 1},
323409467b48Spatrick         {TTI::SK_Broadcast, MVT::v2f32, 1},
323509467b48Spatrick         {TTI::SK_Broadcast, MVT::v4f32, 1},
323609467b48Spatrick         {TTI::SK_Broadcast, MVT::v2f64, 1},
323709467b48Spatrick         // Transpose shuffle kinds can be performed with 'trn1/trn2' and
323809467b48Spatrick         // 'zip1/zip2' instructions.
323909467b48Spatrick         {TTI::SK_Transpose, MVT::v8i8, 1},
324009467b48Spatrick         {TTI::SK_Transpose, MVT::v16i8, 1},
324109467b48Spatrick         {TTI::SK_Transpose, MVT::v4i16, 1},
324209467b48Spatrick         {TTI::SK_Transpose, MVT::v8i16, 1},
324309467b48Spatrick         {TTI::SK_Transpose, MVT::v2i32, 1},
324409467b48Spatrick         {TTI::SK_Transpose, MVT::v4i32, 1},
324509467b48Spatrick         {TTI::SK_Transpose, MVT::v2i64, 1},
324609467b48Spatrick         {TTI::SK_Transpose, MVT::v2f32, 1},
324709467b48Spatrick         {TTI::SK_Transpose, MVT::v4f32, 1},
324809467b48Spatrick         {TTI::SK_Transpose, MVT::v2f64, 1},
324909467b48Spatrick         // Select shuffle kinds.
325009467b48Spatrick         // TODO: handle vXi8/vXi16.
325109467b48Spatrick         {TTI::SK_Select, MVT::v2i32, 1}, // mov.
325209467b48Spatrick         {TTI::SK_Select, MVT::v4i32, 2}, // rev+trn (or similar).
325309467b48Spatrick         {TTI::SK_Select, MVT::v2i64, 1}, // mov.
325409467b48Spatrick         {TTI::SK_Select, MVT::v2f32, 1}, // mov.
325509467b48Spatrick         {TTI::SK_Select, MVT::v4f32, 2}, // rev+trn (or similar).
325609467b48Spatrick         {TTI::SK_Select, MVT::v2f64, 1}, // mov.
325709467b48Spatrick         // PermuteSingleSrc shuffle kinds.
325809467b48Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v2i32, 1}, // mov.
325909467b48Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v4i32, 3}, // perfectshuffle worst case.
326009467b48Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v2i64, 1}, // mov.
326109467b48Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v2f32, 1}, // mov.
326209467b48Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v4f32, 3}, // perfectshuffle worst case.
326309467b48Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v2f64, 1}, // mov.
326473471bf0Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v4i16, 3}, // perfectshuffle worst case.
326573471bf0Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v4f16, 3}, // perfectshuffle worst case.
3266*d415bd75Srobert         {TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3}, // same
326773471bf0Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v8i16, 8},  // constpool + load + tbl
326873471bf0Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v8f16, 8},  // constpool + load + tbl
326973471bf0Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8}, // constpool + load + tbl
327073471bf0Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v8i8, 8},   // constpool + load + tbl
327173471bf0Spatrick         {TTI::SK_PermuteSingleSrc, MVT::v16i8, 8},  // constpool + load + tbl
327273471bf0Spatrick         // Reverse can be lowered with `rev`.
3273*d415bd75Srobert         {TTI::SK_Reverse, MVT::v2i32, 1}, // REV64
327473471bf0Spatrick         {TTI::SK_Reverse, MVT::v4i32, 2}, // REV64; EXT
3275*d415bd75Srobert         {TTI::SK_Reverse, MVT::v2i64, 1}, // EXT
3276*d415bd75Srobert         {TTI::SK_Reverse, MVT::v2f32, 1}, // REV64
327773471bf0Spatrick         {TTI::SK_Reverse, MVT::v4f32, 2}, // REV64; EXT
3278*d415bd75Srobert         {TTI::SK_Reverse, MVT::v2f64, 1}, // EXT
3279*d415bd75Srobert         {TTI::SK_Reverse, MVT::v8f16, 2}, // REV64; EXT
3280*d415bd75Srobert         {TTI::SK_Reverse, MVT::v8i16, 2}, // REV64; EXT
3281*d415bd75Srobert         {TTI::SK_Reverse, MVT::v16i8, 2}, // REV64; EXT
3282*d415bd75Srobert         {TTI::SK_Reverse, MVT::v4f16, 1}, // REV64
3283*d415bd75Srobert         {TTI::SK_Reverse, MVT::v4i16, 1}, // REV64
3284*d415bd75Srobert         {TTI::SK_Reverse, MVT::v8i8, 1},  // REV64
3285*d415bd75Srobert         // Splice can all be lowered as `ext`.
3286*d415bd75Srobert         {TTI::SK_Splice, MVT::v2i32, 1},
3287*d415bd75Srobert         {TTI::SK_Splice, MVT::v4i32, 1},
3288*d415bd75Srobert         {TTI::SK_Splice, MVT::v2i64, 1},
3289*d415bd75Srobert         {TTI::SK_Splice, MVT::v2f32, 1},
3290*d415bd75Srobert         {TTI::SK_Splice, MVT::v4f32, 1},
3291*d415bd75Srobert         {TTI::SK_Splice, MVT::v2f64, 1},
3292*d415bd75Srobert         {TTI::SK_Splice, MVT::v8f16, 1},
3293*d415bd75Srobert         {TTI::SK_Splice, MVT::v8bf16, 1},
3294*d415bd75Srobert         {TTI::SK_Splice, MVT::v8i16, 1},
3295*d415bd75Srobert         {TTI::SK_Splice, MVT::v16i8, 1},
3296*d415bd75Srobert         {TTI::SK_Splice, MVT::v4bf16, 1},
3297*d415bd75Srobert         {TTI::SK_Splice, MVT::v4f16, 1},
3298*d415bd75Srobert         {TTI::SK_Splice, MVT::v4i16, 1},
3299*d415bd75Srobert         {TTI::SK_Splice, MVT::v8i8, 1},
330073471bf0Spatrick         // Broadcast shuffle kinds for scalable vectors
330173471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv16i8, 1},
330273471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv8i16, 1},
330373471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv4i32, 1},
330473471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv2i64, 1},
330573471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv2f16, 1},
330673471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv4f16, 1},
330773471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv8f16, 1},
330873471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv2bf16, 1},
330973471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv4bf16, 1},
331073471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv8bf16, 1},
331173471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv2f32, 1},
331273471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv4f32, 1},
331373471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv2f64, 1},
331473471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv16i1, 1},
331573471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv8i1, 1},
331673471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv4i1, 1},
331773471bf0Spatrick         {TTI::SK_Broadcast, MVT::nxv2i1, 1},
331873471bf0Spatrick         // Handle the cases for vector.reverse with scalable vectors
331973471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv16i8, 1},
332073471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv8i16, 1},
332173471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv4i32, 1},
332273471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv2i64, 1},
332373471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv2f16, 1},
332473471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv4f16, 1},
332573471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv8f16, 1},
332673471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv2bf16, 1},
332773471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv4bf16, 1},
332873471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv8bf16, 1},
332973471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv2f32, 1},
333073471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv4f32, 1},
333173471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv2f64, 1},
333273471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv16i1, 1},
333373471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv8i1, 1},
333473471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv4i1, 1},
333573471bf0Spatrick         {TTI::SK_Reverse, MVT::nxv2i1, 1},
333609467b48Spatrick     };
333709467b48Spatrick     if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
333809467b48Spatrick       return LT.first * Entry->Cost;
333909467b48Spatrick   }
3340*d415bd75Srobert 
334173471bf0Spatrick   if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
334273471bf0Spatrick     return getSpliceCost(Tp, Index);
3343*d415bd75Srobert 
3344*d415bd75Srobert   // Inserting a subvector can often be done with either a D, S or H register
3345*d415bd75Srobert   // move, so long as the inserted vector is "aligned".
3346*d415bd75Srobert   if (Kind == TTI::SK_InsertSubvector && LT.second.isFixedLengthVector() &&
3347*d415bd75Srobert       LT.second.getSizeInBits() <= 128 && SubTp) {
3348*d415bd75Srobert     std::pair<InstructionCost, MVT> SubLT = getTypeLegalizationCost(SubTp);
3349*d415bd75Srobert     if (SubLT.second.isVector()) {
3350*d415bd75Srobert       int NumElts = LT.second.getVectorNumElements();
3351*d415bd75Srobert       int NumSubElts = SubLT.second.getVectorNumElements();
3352*d415bd75Srobert       if ((Index % NumSubElts) == 0 && (NumElts % NumSubElts) == 0)
3353*d415bd75Srobert         return SubLT.first;
3354*d415bd75Srobert     }
3355*d415bd75Srobert   }
3356*d415bd75Srobert 
3357*d415bd75Srobert   return BaseT::getShuffleCost(Kind, Tp, Mask, CostKind, Index, SubTp);
3358*d415bd75Srobert }
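// Sample costs from the logic above: a v4i32 broadcast or transpose is a
// single dup/trn instruction (cost 1), a v16i8 single-source permute is
// modelled as 8 (constant-pool load plus tbl), and inserting a v2i32
// subvector into v4i32 at an "aligned" index such as 0 or 2 costs only
// SubLT.first, i.e. typically a single register move.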
3359*d415bd75Srobert 
3360*d415bd75Srobert bool AArch64TTIImpl::preferPredicateOverEpilogue(
3361*d415bd75Srobert     Loop *L, LoopInfo *LI, ScalarEvolution &SE, AssumptionCache &AC,
3362*d415bd75Srobert     TargetLibraryInfo *TLI, DominatorTree *DT, LoopVectorizationLegality *LVL,
3363*d415bd75Srobert     InterleavedAccessInfo *IAI) {
3364*d415bd75Srobert   if (!ST->hasSVE() || TailFoldingKindLoc == TailFoldingKind::TFDisabled)
3365*d415bd75Srobert     return false;
3366*d415bd75Srobert 
3367*d415bd75Srobert   // We don't currently support vectorisation with interleaving for SVE - with
3368*d415bd75Srobert   // such loops we're better off not using tail-folding. This gives us a chance
3369*d415bd75Srobert   // to fall back on fixed-width vectorisation using NEON's ld2/st2/etc.
3370*d415bd75Srobert   if (IAI->hasGroups())
3371*d415bd75Srobert     return false;
3372*d415bd75Srobert 
3373*d415bd75Srobert   TailFoldingKind Required; // Defaults to 0.
3374*d415bd75Srobert   if (LVL->getReductionVars().size())
3375*d415bd75Srobert     Required.add(TailFoldingKind::TFReductions);
3376*d415bd75Srobert   if (LVL->getFixedOrderRecurrences().size())
3377*d415bd75Srobert     Required.add(TailFoldingKind::TFRecurrences);
3378*d415bd75Srobert   if (!Required)
3379*d415bd75Srobert     Required.add(TailFoldingKind::TFSimple);
3380*d415bd75Srobert 
3381*d415bd75Srobert   return (TailFoldingKindLoc & Required) == Required;
3382*d415bd75Srobert }
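// For example, if the tail-folding option string is "reductions+recurrences",
// a loop whose only special property is a reduction needs TFReductions,
// which is covered, so a predicated (tail-folded) loop is preferred; a plain
// loop needs TFSimple instead and keeps a scalar epilogue because "simple"
// was not requested.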
3383*d415bd75Srobert 
3384*d415bd75Srobert InstructionCost
3385*d415bd75Srobert AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
3386*d415bd75Srobert                                      int64_t BaseOffset, bool HasBaseReg,
3387*d415bd75Srobert                                      int64_t Scale, unsigned AddrSpace) const {
3388*d415bd75Srobert   // Scaling factors are not free at all.
3389*d415bd75Srobert   // Operands                     | Rt Latency
3390*d415bd75Srobert   // -------------------------------------------
3391*d415bd75Srobert   // Rt, [Xn, Xm]                 | 4
3392*d415bd75Srobert   // -------------------------------------------
3393*d415bd75Srobert   // Rt, [Xn, Xm, lsl #imm]       | Rn: 4 Rm: 5
3394*d415bd75Srobert   // Rt, [Xn, Wm, <extend> #imm]  |
3395*d415bd75Srobert   TargetLoweringBase::AddrMode AM;
3396*d415bd75Srobert   AM.BaseGV = BaseGV;
3397*d415bd75Srobert   AM.BaseOffs = BaseOffset;
3398*d415bd75Srobert   AM.HasBaseReg = HasBaseReg;
3399*d415bd75Srobert   AM.Scale = Scale;
3400*d415bd75Srobert   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
3401*d415bd75Srobert     // Scale represents reg2 * scale, thus account for 1 if
3402*d415bd75Srobert     // it is not equal to 0 or 1.
3403*d415bd75Srobert     return AM.Scale != 0 && AM.Scale != 1;
3404*d415bd75Srobert   return -1;
340509467b48Spatrick }
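// Concretely: an access such as ldr w0, [x1, x2, lsl #2] (Scale == 4, a
// legal addressing mode) is charged 1, [x1, x2] with Scale == 1 is free, and
// a combination the target cannot encode makes isLegalAddressingMode fail,
// returning -1 to indicate the scaling factor should not be used.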
3406