1 //===- CombinerHelperVectorOps.cpp-----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements CombinerHelper for G_EXTRACT_VECTOR_ELT, 10 // G_INSERT_VECTOR_ELT, and G_VSCALE 11 // 12 //===----------------------------------------------------------------------===// 13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" 14 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" 15 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 16 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" 17 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" 18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 19 #include "llvm/CodeGen/GlobalISel/Utils.h" 20 #include "llvm/CodeGen/LowLevelTypeUtils.h" 21 #include "llvm/CodeGen/MachineOperand.h" 22 #include "llvm/CodeGen/MachineRegisterInfo.h" 23 #include "llvm/CodeGen/TargetLowering.h" 24 #include "llvm/CodeGen/TargetOpcodes.h" 25 #include "llvm/Support/Casting.h" 26 #include <optional> 27 28 #define DEBUG_TYPE "gi-combiner" 29 30 using namespace llvm; 31 using namespace MIPatternMatch; 32 33 bool CombinerHelper::matchExtractVectorElement(MachineInstr &MI, 34 BuildFnTy &MatchInfo) const { 35 GExtractVectorElement *Extract = cast<GExtractVectorElement>(&MI); 36 37 Register Dst = Extract->getReg(0); 38 Register Vector = Extract->getVectorReg(); 39 Register Index = Extract->getIndexReg(); 40 LLT DstTy = MRI.getType(Dst); 41 LLT VectorTy = MRI.getType(Vector); 42 43 // The vector register can be def'd by various ops that have vector as its 44 // type. They can all be used for constant folding, scalarizing, 45 // canonicalization, or combining based on symmetry. 46 // 47 // vector like ops 48 // * build vector 49 // * build vector trunc 50 // * shuffle vector 51 // * splat vector 52 // * concat vectors 53 // * insert/extract vector element 54 // * insert/extract subvector 55 // * vector loads 56 // * scalable vector loads 57 // 58 // compute like ops 59 // * binary ops 60 // * unary ops 61 // * exts and truncs 62 // * casts 63 // * fneg 64 // * select 65 // * phis 66 // * cmps 67 // * freeze 68 // * bitcast 69 // * undef 70 71 // We try to get the value of the Index register. 72 std::optional<ValueAndVReg> MaybeIndex = 73 getIConstantVRegValWithLookThrough(Index, MRI); 74 std::optional<APInt> IndexC = std::nullopt; 75 76 if (MaybeIndex) 77 IndexC = MaybeIndex->Value; 78 79 // Fold extractVectorElement(Vector, TOOLARGE) -> undef 80 if (IndexC && VectorTy.isFixedVector() && 81 IndexC->uge(VectorTy.getNumElements()) && 82 isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) { 83 // For fixed-length vectors, it's invalid to extract out-of-range elements. 84 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); }; 85 return true; 86 } 87 88 return false; 89 } 90 91 bool CombinerHelper::matchExtractVectorElementWithDifferentIndices( 92 const MachineOperand &MO, BuildFnTy &MatchInfo) const { 93 MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI); 94 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root); 95 96 // 97 // %idx1:_(s64) = G_CONSTANT i64 1 98 // %idx2:_(s64) = G_CONSTANT i64 2 99 // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>), 100 // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %insert(<2 101 // x s32>), %idx1(s64) 102 // 103 // --> 104 // 105 // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>), 106 // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x 107 // s32>), %idx1(s64) 108 // 109 // 110 111 Register Index = Extract->getIndexReg(); 112 113 // We try to get the value of the Index register. 114 std::optional<ValueAndVReg> MaybeIndex = 115 getIConstantVRegValWithLookThrough(Index, MRI); 116 std::optional<APInt> IndexC = std::nullopt; 117 118 if (!MaybeIndex) 119 return false; 120 else 121 IndexC = MaybeIndex->Value; 122 123 Register Vector = Extract->getVectorReg(); 124 125 GInsertVectorElement *Insert = 126 getOpcodeDef<GInsertVectorElement>(Vector, MRI); 127 if (!Insert) 128 return false; 129 130 Register Dst = Extract->getReg(0); 131 132 std::optional<ValueAndVReg> MaybeInsertIndex = 133 getIConstantVRegValWithLookThrough(Insert->getIndexReg(), MRI); 134 135 if (MaybeInsertIndex && MaybeInsertIndex->Value != *IndexC) { 136 // There is no one-use check. We have to keep the insert. When both Index 137 // registers are constants and not equal, we can look into the Vector 138 // register of the insert. 139 MatchInfo = [=](MachineIRBuilder &B) { 140 B.buildExtractVectorElement(Dst, Insert->getVectorReg(), Index); 141 }; 142 return true; 143 } 144 145 return false; 146 } 147 148 bool CombinerHelper::matchExtractVectorElementWithBuildVector( 149 const MachineInstr &MI, const MachineInstr &MI2, 150 BuildFnTy &MatchInfo) const { 151 const GExtractVectorElement *Extract = cast<GExtractVectorElement>(&MI); 152 const GBuildVector *Build = cast<GBuildVector>(&MI2); 153 154 // 155 // %zero:_(s64) = G_CONSTANT i64 0 156 // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32) 157 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64) 158 // 159 // --> 160 // 161 // %extract:_(32) = COPY %arg1(s32) 162 // 163 // 164 165 Register Vector = Extract->getVectorReg(); 166 LLT VectorTy = MRI.getType(Vector); 167 168 // There is a one-use check. There are more combines on build vectors. 169 EVT Ty(getMVTForLLT(VectorTy)); 170 if (!MRI.hasOneNonDBGUse(Build->getReg(0)) || 171 !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty)) 172 return false; 173 174 APInt Index = getIConstantFromReg(Extract->getIndexReg(), MRI); 175 176 // We now know that there is a buildVector def'd on the Vector register and 177 // the index is const. The combine will succeed. 178 179 Register Dst = Extract->getReg(0); 180 181 MatchInfo = [=](MachineIRBuilder &B) { 182 B.buildCopy(Dst, Build->getSourceReg(Index.getZExtValue())); 183 }; 184 185 return true; 186 } 187 188 bool CombinerHelper::matchExtractVectorElementWithBuildVectorTrunc( 189 const MachineOperand &MO, BuildFnTy &MatchInfo) const { 190 MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI); 191 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root); 192 193 // 194 // %zero:_(s64) = G_CONSTANT i64 0 195 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64) 196 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64) 197 // 198 // --> 199 // 200 // %extract:_(32) = G_TRUNC %arg1(s64) 201 // 202 // 203 // 204 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64) 205 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64) 206 // 207 // --> 208 // 209 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64) 210 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64) 211 // 212 213 Register Vector = Extract->getVectorReg(); 214 215 // We expect a buildVectorTrunc on the Vector register. 216 GBuildVectorTrunc *Build = getOpcodeDef<GBuildVectorTrunc>(Vector, MRI); 217 if (!Build) 218 return false; 219 220 LLT VectorTy = MRI.getType(Vector); 221 222 // There is a one-use check. There are more combines on build vectors. 223 EVT Ty(getMVTForLLT(VectorTy)); 224 if (!MRI.hasOneNonDBGUse(Build->getReg(0)) || 225 !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty)) 226 return false; 227 228 Register Index = Extract->getIndexReg(); 229 230 // If the Index is constant, then we can extract the element from the given 231 // offset. 232 std::optional<ValueAndVReg> MaybeIndex = 233 getIConstantVRegValWithLookThrough(Index, MRI); 234 if (!MaybeIndex) 235 return false; 236 237 // We now know that there is a buildVectorTrunc def'd on the Vector register 238 // and the index is const. The combine will succeed. 239 240 Register Dst = Extract->getReg(0); 241 LLT DstTy = MRI.getType(Dst); 242 LLT SrcTy = MRI.getType(Build->getSourceReg(0)); 243 244 // For buildVectorTrunc, the inputs are truncated. 245 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) 246 return false; 247 248 MatchInfo = [=](MachineIRBuilder &B) { 249 B.buildTrunc(Dst, Build->getSourceReg(MaybeIndex->Value.getZExtValue())); 250 }; 251 252 return true; 253 } 254 255 bool CombinerHelper::matchExtractVectorElementWithShuffleVector( 256 const MachineInstr &MI, const MachineInstr &MI2, 257 BuildFnTy &MatchInfo) const { 258 const GExtractVectorElement *Extract = cast<GExtractVectorElement>(&MI); 259 const GShuffleVector *Shuffle = cast<GShuffleVector>(&MI2); 260 261 // 262 // %zero:_(s64) = G_CONSTANT i64 0 263 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>), 264 // shufflemask(0, 0, 0, 0) 265 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %zero(s64) 266 // 267 // --> 268 // 269 // %zero1:_(s64) = G_CONSTANT i64 0 270 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %arg1(<4 x s32>), %zero1(s64) 271 // 272 // 273 // 274 // 275 // %three:_(s64) = G_CONSTANT i64 3 276 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>), 277 // shufflemask(0, 0, 0, -1) 278 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %three(s64) 279 // 280 // --> 281 // 282 // %extract:_(s32) = G_IMPLICIT_DEF 283 // 284 // 285 286 APInt Index = getIConstantFromReg(Extract->getIndexReg(), MRI); 287 288 ArrayRef<int> Mask = Shuffle->getMask(); 289 290 unsigned Offset = Index.getZExtValue(); 291 int SrcIdx = Mask[Offset]; 292 293 LLT Src1Type = MRI.getType(Shuffle->getSrc1Reg()); 294 // At the IR level a <1 x ty> shuffle vector is valid, but we want to extract 295 // from a vector. 296 assert(Src1Type.isVector() && "expected to extract from a vector"); 297 unsigned LHSWidth = Src1Type.isVector() ? Src1Type.getNumElements() : 1; 298 299 // Note that there is no one use check. 300 Register Dst = Extract->getReg(0); 301 LLT DstTy = MRI.getType(Dst); 302 303 if (SrcIdx < 0 && 304 isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) { 305 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); }; 306 return true; 307 } 308 309 // If the legality check failed, then we still have to abort. 310 if (SrcIdx < 0) 311 return false; 312 313 Register NewVector; 314 315 // We check in which vector and at what offset to look through. 316 if (SrcIdx < (int)LHSWidth) { 317 NewVector = Shuffle->getSrc1Reg(); 318 // SrcIdx unchanged 319 } else { // SrcIdx >= LHSWidth 320 NewVector = Shuffle->getSrc2Reg(); 321 SrcIdx -= LHSWidth; 322 } 323 324 LLT IdxTy = MRI.getType(Extract->getIndexReg()); 325 LLT NewVectorTy = MRI.getType(NewVector); 326 327 // We check the legality of the look through. 328 if (!isLegalOrBeforeLegalizer( 329 {TargetOpcode::G_EXTRACT_VECTOR_ELT, {DstTy, NewVectorTy, IdxTy}}) || 330 !isConstantLegalOrBeforeLegalizer({IdxTy})) 331 return false; 332 333 // We look through the shuffle vector. 334 MatchInfo = [=](MachineIRBuilder &B) { 335 auto Idx = B.buildConstant(IdxTy, SrcIdx); 336 B.buildExtractVectorElement(Dst, NewVector, Idx); 337 }; 338 339 return true; 340 } 341 342 bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI, 343 BuildFnTy &MatchInfo) const { 344 GInsertVectorElement *Insert = cast<GInsertVectorElement>(&MI); 345 346 Register Dst = Insert->getReg(0); 347 LLT DstTy = MRI.getType(Dst); 348 Register Index = Insert->getIndexReg(); 349 350 if (!DstTy.isFixedVector()) 351 return false; 352 353 std::optional<ValueAndVReg> MaybeIndex = 354 getIConstantVRegValWithLookThrough(Index, MRI); 355 356 if (MaybeIndex && MaybeIndex->Value.uge(DstTy.getNumElements()) && 357 isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) { 358 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); }; 359 return true; 360 } 361 362 return false; 363 } 364 365 bool CombinerHelper::matchAddOfVScale(const MachineOperand &MO, 366 BuildFnTy &MatchInfo) const { 367 GAdd *Add = cast<GAdd>(MRI.getVRegDef(MO.getReg())); 368 GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Add->getLHSReg())); 369 GVScale *RHSVScale = cast<GVScale>(MRI.getVRegDef(Add->getRHSReg())); 370 371 Register Dst = Add->getReg(0); 372 373 if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)) || 374 !MRI.hasOneNonDBGUse(RHSVScale->getReg(0))) 375 return false; 376 377 MatchInfo = [=](MachineIRBuilder &B) { 378 B.buildVScale(Dst, LHSVScale->getSrc() + RHSVScale->getSrc()); 379 }; 380 381 return true; 382 } 383 384 bool CombinerHelper::matchMulOfVScale(const MachineOperand &MO, 385 BuildFnTy &MatchInfo) const { 386 GMul *Mul = cast<GMul>(MRI.getVRegDef(MO.getReg())); 387 GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Mul->getLHSReg())); 388 389 std::optional<APInt> MaybeRHS = getIConstantVRegVal(Mul->getRHSReg(), MRI); 390 if (!MaybeRHS) 391 return false; 392 393 Register Dst = MO.getReg(); 394 395 if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0))) 396 return false; 397 398 MatchInfo = [=](MachineIRBuilder &B) { 399 B.buildVScale(Dst, LHSVScale->getSrc() * *MaybeRHS); 400 }; 401 402 return true; 403 } 404 405 bool CombinerHelper::matchSubOfVScale(const MachineOperand &MO, 406 BuildFnTy &MatchInfo) const { 407 GSub *Sub = cast<GSub>(MRI.getVRegDef(MO.getReg())); 408 GVScale *RHSVScale = cast<GVScale>(MRI.getVRegDef(Sub->getRHSReg())); 409 410 Register Dst = MO.getReg(); 411 LLT DstTy = MRI.getType(Dst); 412 413 if (!MRI.hasOneNonDBGUse(RHSVScale->getReg(0)) || 414 !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, DstTy})) 415 return false; 416 417 MatchInfo = [=](MachineIRBuilder &B) { 418 auto VScale = B.buildVScale(DstTy, -RHSVScale->getSrc()); 419 B.buildAdd(Dst, Sub->getLHSReg(), VScale, Sub->getFlags()); 420 }; 421 422 return true; 423 } 424 425 bool CombinerHelper::matchShlOfVScale(const MachineOperand &MO, 426 BuildFnTy &MatchInfo) const { 427 GShl *Shl = cast<GShl>(MRI.getVRegDef(MO.getReg())); 428 GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Shl->getSrcReg())); 429 430 std::optional<APInt> MaybeRHS = getIConstantVRegVal(Shl->getShiftReg(), MRI); 431 if (!MaybeRHS) 432 return false; 433 434 Register Dst = MO.getReg(); 435 LLT DstTy = MRI.getType(Dst); 436 437 if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)) || 438 !isLegalOrBeforeLegalizer({TargetOpcode::G_VSCALE, DstTy})) 439 return false; 440 441 MatchInfo = [=](MachineIRBuilder &B) { 442 B.buildVScale(Dst, LHSVScale->getSrc().shl(*MaybeRHS)); 443 }; 444 445 return true; 446 } 447