//===-- X86PartialReduction.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass looks for add instructions used by a horizontal reduction to see
// if we might be able to use pmaddwd or psadbw. Some cases of this require
// cross-basic-block knowledge and can't be done in SelectionDAG.
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86TargetMachine.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

#define DEBUG_TYPE "x86-partial-reduction"

namespace {

class X86PartialReduction : public FunctionPass {
  const DataLayout *DL = nullptr;
  const X86Subtarget *ST = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid.

  X86PartialReduction() : FunctionPass(ID) {}

  bool runOnFunction(Function &Fn) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
  }

  StringRef getPassName() const override {
    return "X86 Partial Reduction";
  }

private:
  bool tryMAddReplacement(Instruction *Op, bool ReduceInOneBB);
  bool trySADReplacement(Instruction *Op);
};
} // namespace

FunctionPass *llvm::createX86PartialReductionPass() {
  return new X86PartialReduction();
}

char X86PartialReduction::ID = 0;

INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE,
                "X86 Partial Reduction", false, false)

// This function should be aligned with detectExtMul() in X86ISelLowering.cpp.
static bool matchVPDPBUSDPattern(const X86Subtarget *ST, BinaryOperator *Mul,
                                 const DataLayout *DL) {
  if (!ST->hasVNNI() && !ST->hasAVXVNNI())
    return false;

  Value *LHS = Mul->getOperand(0);
  Value *RHS = Mul->getOperand(1);

  if (isa<SExtInst>(LHS))
    std::swap(LHS, RHS);

  auto IsFreeTruncation = [&](Value *Op) {
    if (auto *Cast = dyn_cast<CastInst>(Op)) {
      if (Cast->getParent() == Mul->getParent() &&
          (Cast->getOpcode() == Instruction::SExt ||
           Cast->getOpcode() == Instruction::ZExt) &&
          Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 8)
        return true;
    }

    return isa<Constant>(Op);
  };

  // (dpbusd (zext a), (sext b)). Since the first operand should be an unsigned
  // value, we need to check that LHS is a zero-extended value. RHS should be a
  // signed value, so we just check its sign bits.
  if ((IsFreeTruncation(LHS) &&
       computeKnownBits(LHS, *DL).countMaxActiveBits() <= 8) &&
      (IsFreeTruncation(RHS) && ComputeMaxSignificantBits(RHS, *DL) <= 8))
    return true;

  return false;
}

bool X86PartialReduction::tryMAddReplacement(Instruction *Op,
                                             bool ReduceInOneBB) {
  if (!ST->hasSSE2())
    return false;

  // Need at least 8 elements.
  if (cast<FixedVectorType>(Op->getType())->getNumElements() < 8)
    return false;

  // Element type should be i32.
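  // (pmaddwd multiplies adjacent pairs of i16 elements and sums each pair
  // into an i32, which is why only i32 element types qualify here.)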
  if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
    return false;

  auto *Mul = dyn_cast<BinaryOperator>(Op);
  if (!Mul || Mul->getOpcode() != Instruction::Mul)
    return false;

  Value *LHS = Mul->getOperand(0);
  Value *RHS = Mul->getOperand(1);

  // If the target supports VNNI, leave it to ISel to combine the reduce
  // operation into a VNNI instruction.
  // TODO: we can support transforming cross-block reductions to the VNNI
  // intrinsic in this pass.
  if (ReduceInOneBB && matchVPDPBUSDPattern(ST, Mul, DL))
    return false;

  // LHS and RHS should only be used once, or if they are the same then only
  // used twice. Only check this when SSE4.1 is enabled and we have zext/sext
  // instructions, otherwise we use punpck to emulate zero extend in stages. The
  // truncs we need to do likely won't introduce new instructions in that case.
  if (ST->hasSSE41()) {
    if (LHS == RHS) {
      if (!isa<Constant>(LHS) && !LHS->hasNUses(2))
        return false;
    } else {
      if (!isa<Constant>(LHS) && !LHS->hasOneUse())
        return false;
      if (!isa<Constant>(RHS) && !RHS->hasOneUse())
        return false;
    }
  }

  auto CanShrinkOp = [&](Value *Op) {
    auto IsFreeTruncation = [&](Value *Op) {
      if (auto *Cast = dyn_cast<CastInst>(Op)) {
        if (Cast->getParent() == Mul->getParent() &&
            (Cast->getOpcode() == Instruction::SExt ||
             Cast->getOpcode() == Instruction::ZExt) &&
            Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 16)
          return true;
      }

      return isa<Constant>(Op);
    };

    // If the operation can be freely truncated and has enough sign bits we
    // can shrink.
    if (IsFreeTruncation(Op) &&
        ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
      return true;

    // SelectionDAG has limited support for truncating through an add or sub if
    // the inputs are freely truncatable.
    if (auto *BO = dyn_cast<BinaryOperator>(Op)) {
      if (BO->getParent() == Mul->getParent() &&
          IsFreeTruncation(BO->getOperand(0)) &&
          IsFreeTruncation(BO->getOperand(1)) &&
          ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
        return true;
    }

    return false;
  };

  // At least one of the operands needs to be shrinkable.
  if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS))
    return false;

  IRBuilder<> Builder(Mul);

  auto *MulTy = cast<FixedVectorType>(Op->getType());
  unsigned NumElts = MulTy->getNumElements();

  // Extract even elements and odd elements and add them together. This will
  // be pattern matched by SelectionDAG to pmaddwd. This add will be half the
  // original width.
  SmallVector<int, 16> EvenMask(NumElts / 2);
  SmallVector<int, 16> OddMask(NumElts / 2);
  for (int i = 0, e = NumElts / 2; i != e; ++i) {
    EvenMask[i] = i * 2;
    OddMask[i] = i * 2 + 1;
  }
  // Creating a new mul so the replaceAllUsesWith below doesn't replace the
  // uses in the shuffles we're creating.
  Value *NewMul = Builder.CreateMul(Mul->getOperand(0), Mul->getOperand(1));
  Value *EvenElts = Builder.CreateShuffleVector(NewMul, NewMul, EvenMask);
  Value *OddElts = Builder.CreateShuffleVector(NewMul, NewMul, OddMask);
  Value *MAdd = Builder.CreateAdd(EvenElts, OddElts);

  // Concatenate zeroes to extend back to the original type.
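  // For example, with NumElts == 8 the madd above is a v4i32; shuffling it
  // with a v4i32 zero vector using mask <0,1,2,3,4,5,6,7> yields a v8i32
  // whose upper half is zero, which is harmless in the surrounding add
  // reduction.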
  SmallVector<int, 32> ConcatMask(NumElts);
  std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
  Value *Zero = Constant::getNullValue(MAdd->getType());
  Value *Concat = Builder.CreateShuffleVector(MAdd, Zero, ConcatMask);

  Mul->replaceAllUsesWith(Concat);
  Mul->eraseFromParent();

  return true;
}

bool X86PartialReduction::trySADReplacement(Instruction *Op) {
  if (!ST->hasSSE2())
    return false;

  // TODO: There's nothing special about i32; any integer type above i16 should
  // work just as well.
  if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
    return false;

  Value *LHS;
  if (match(Op, PatternMatch::m_Intrinsic<Intrinsic::abs>())) {
    LHS = Op->getOperand(0);
  } else {
    // Operand should be a select.
    auto *SI = dyn_cast<SelectInst>(Op);
    if (!SI)
      return false;

    Value *RHS;
    // Select needs to implement absolute value.
    auto SPR = matchSelectPattern(SI, LHS, RHS);
    if (SPR.Flavor != SPF_ABS)
      return false;
  }

  // Need a subtract of two values.
  auto *Sub = dyn_cast<BinaryOperator>(LHS);
  if (!Sub || Sub->getOpcode() != Instruction::Sub)
    return false;

  // Look for zero extend from i8.
  auto getZeroExtendedVal = [](Value *Op) -> Value * {
    if (auto *ZExt = dyn_cast<ZExtInst>(Op))
      if (cast<VectorType>(ZExt->getOperand(0)->getType())
              ->getElementType()
              ->isIntegerTy(8))
        return ZExt->getOperand(0);

    return nullptr;
  };

  // Both operands of the subtract should be extends from vXi8.
  Value *Op0 = getZeroExtendedVal(Sub->getOperand(0));
  Value *Op1 = getZeroExtendedVal(Sub->getOperand(1));
  if (!Op0 || !Op1)
    return false;

  IRBuilder<> Builder(Op);

  auto *OpTy = cast<FixedVectorType>(Op->getType());
  unsigned NumElts = OpTy->getNumElements();

  unsigned IntrinsicNumElts;
  Intrinsic::ID IID;
  if (ST->hasBWI() && NumElts >= 64) {
    IID = Intrinsic::x86_avx512_psad_bw_512;
    IntrinsicNumElts = 64;
  } else if (ST->hasAVX2() && NumElts >= 32) {
    IID = Intrinsic::x86_avx2_psad_bw;
    IntrinsicNumElts = 32;
  } else {
    IID = Intrinsic::x86_sse2_psad_bw;
    IntrinsicNumElts = 16;
  }

  Function *PSADBWFn = Intrinsic::getOrInsertDeclaration(Op->getModule(), IID);

  if (NumElts < 16) {
    // Pad input with zeroes.
    SmallVector<int, 32> ConcatMask(16);
    for (unsigned i = 0; i != NumElts; ++i)
      ConcatMask[i] = i;
    for (unsigned i = NumElts; i != 16; ++i)
      ConcatMask[i] = (i % NumElts) + NumElts;

    Value *Zero = Constant::getNullValue(Op0->getType());
    Op0 = Builder.CreateShuffleVector(Op0, Zero, ConcatMask);
    Op1 = Builder.CreateShuffleVector(Op1, Zero, ConcatMask);
    NumElts = 16;
  }

  // Intrinsics produce vXi64 and need to be cast to vXi32.
  auto *I32Ty =
      FixedVectorType::get(Builder.getInt32Ty(), IntrinsicNumElts / 4);

  assert(NumElts % IntrinsicNumElts == 0 && "Unexpected number of elements!");
  unsigned NumSplits = NumElts / IntrinsicNumElts;

  // First collect the pieces we need.
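  // Each piece covers IntrinsicNumElts i8 elements: extract that slice from
  // both inputs, run psadbw on it, then reinterpret the vXi64 result as i32
  // elements so the pieces can be concatenated back together below.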
  SmallVector<Value *, 4> Ops(NumSplits);
  for (unsigned i = 0; i != NumSplits; ++i) {
    SmallVector<int, 64> ExtractMask(IntrinsicNumElts);
    std::iota(ExtractMask.begin(), ExtractMask.end(), i * IntrinsicNumElts);
    Value *ExtractOp0 = Builder.CreateShuffleVector(Op0, Op0, ExtractMask);
    Value *ExtractOp1 = Builder.CreateShuffleVector(Op1, Op1, ExtractMask);
    Ops[i] = Builder.CreateCall(PSADBWFn, {ExtractOp0, ExtractOp1});
    Ops[i] = Builder.CreateBitCast(Ops[i], I32Ty);
  }

  assert(isPowerOf2_32(NumSplits) && "Expected power of 2 splits");
  unsigned Stages = Log2_32(NumSplits);
  for (unsigned s = Stages; s > 0; --s) {
    unsigned NumConcatElts =
        cast<FixedVectorType>(Ops[0]->getType())->getNumElements() * 2;
    for (unsigned i = 0; i != 1U << (s - 1); ++i) {
      SmallVector<int, 64> ConcatMask(NumConcatElts);
      std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
      Ops[i] =
          Builder.CreateShuffleVector(Ops[i * 2], Ops[i * 2 + 1], ConcatMask);
    }
  }

  // At this point the final value should be in Ops[0]. Now we need to adjust
  // it back to the original type.
  NumElts = OpTy->getNumElements();
  if (NumElts == 2) {
    // Extract down to 2 elements.
    Ops[0] = Builder.CreateShuffleVector(Ops[0], Ops[0], ArrayRef<int>{0, 1});
  } else if (NumElts >= 8) {
    SmallVector<int, 32> ConcatMask(NumElts);
    unsigned SubElts =
        cast<FixedVectorType>(Ops[0]->getType())->getNumElements();
    for (unsigned i = 0; i != SubElts; ++i)
      ConcatMask[i] = i;
    for (unsigned i = SubElts; i != NumElts; ++i)
      ConcatMask[i] = (i % SubElts) + SubElts;

    Value *Zero = Constant::getNullValue(Ops[0]->getType());
    Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask);
  }

  Op->replaceAllUsesWith(Ops[0]);
  Op->eraseFromParent();

  return true;
}

// Walk backwards from the ExtractElementInst and determine if it is the end of
// a horizontal reduction. Return the input to the reduction if we find one.
static Value *matchAddReduction(const ExtractElementInst &EE,
                                bool &ReduceInOneBB) {
  ReduceInOneBB = true;
  // Make sure we're extracting index 0.
  auto *Index = dyn_cast<ConstantInt>(EE.getIndexOperand());
  if (!Index || !Index->isNullValue())
    return nullptr;

  const auto *BO = dyn_cast<BinaryOperator>(EE.getVectorOperand());
  if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse())
    return nullptr;
  if (EE.getParent() != BO->getParent())
    ReduceInOneBB = false;

  unsigned NumElems = cast<FixedVectorType>(BO->getType())->getNumElements();
  // Ensure the reduction size is a power of 2.
  if (!isPowerOf2_32(NumElems))
    return nullptr;

  const Value *Op = BO;
  unsigned Stages = Log2_32(NumElems);
  for (unsigned i = 0; i != Stages; ++i) {
    const auto *BO = dyn_cast<BinaryOperator>(Op);
    if (!BO || BO->getOpcode() != Instruction::Add)
      return nullptr;
    if (EE.getParent() != BO->getParent())
      ReduceInOneBB = false;

    // If this isn't the first add, then it should only have 2 users, the
    // shuffle and another add which we checked in the previous iteration.
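    // (For i == 0 the add was already checked above: the extractelement must
    // be its only use.)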
    if (i != 0 && !BO->hasNUses(2))
      return nullptr;

    Value *LHS = BO->getOperand(0);
    Value *RHS = BO->getOperand(1);

    auto *Shuffle = dyn_cast<ShuffleVectorInst>(LHS);
    if (Shuffle) {
      Op = RHS;
    } else {
      Shuffle = dyn_cast<ShuffleVectorInst>(RHS);
      Op = LHS;
    }

    // The first operand of the shuffle should be the same as the other operand
    // of the bin op.
    if (!Shuffle || Shuffle->getOperand(0) != Op)
      return nullptr;

    // Verify the shuffle has the expected (at this stage of the pyramid) mask.
    unsigned MaskEnd = 1 << i;
    for (unsigned Index = 0; Index < MaskEnd; ++Index)
      if (Shuffle->getMaskValue(Index) != (int)(MaskEnd + Index))
        return nullptr;
  }

  return const_cast<Value *>(Op);
}

// See if this BO is reachable from this Phi by walking forward through single
// use BinaryOperators with the same opcode. If we get back to the BO then we
// know we've found a loop and it is safe to step through this Add to find more
// leaves.
static bool isReachableFromPHI(PHINode *Phi, BinaryOperator *BO) {
  // The PHI itself should only have one use.
  if (!Phi->hasOneUse())
    return false;

  Instruction *U = cast<Instruction>(*Phi->user_begin());
  if (U == BO)
    return true;

  while (U->hasOneUse() && U->getOpcode() == BO->getOpcode())
    U = cast<Instruction>(*U->user_begin());

  return U == BO;
}

// Collect all the leaves of the tree of adds that feeds into the horizontal
// reduction. Root is the Value that is used by the horizontal reduction.
// We look through single-use phis, single-use adds, or adds that are used by
// a phi that forms a loop with the add.
static void collectLeaves(Value *Root, SmallVectorImpl<Instruction *> &Leaves) {
  SmallPtrSet<Value *, 8> Visited;
  SmallVector<Value *, 8> Worklist;
  Worklist.push_back(Root);

  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();
    if (!Visited.insert(V).second)
      continue;

    if (auto *PN = dyn_cast<PHINode>(V)) {
      // The PHI node should have a single use unless it is the root node, in
      // which case it has 2 uses.
      if (!PN->hasNUses(PN == Root ? 2 : 1))
        break;

      // Push incoming values to the worklist.
      append_range(Worklist, PN->incoming_values());

      continue;
    }

    if (auto *BO = dyn_cast<BinaryOperator>(V)) {
      if (BO->getOpcode() == Instruction::Add) {
        // Simple case. Single use, just push its operands to the worklist.
        if (BO->hasNUses(BO == Root ? 2 : 1)) {
          append_range(Worklist, BO->operands());
          continue;
        }

        // If there is an additional use, make sure it is an unvisited phi that
        // gets us back to this node.
        if (BO->hasNUses(BO == Root ? 3 : 2)) {
          PHINode *PN = nullptr;
          for (auto *U : BO->users())
            if (auto *P = dyn_cast<PHINode>(U))
              if (!Visited.count(P))
                PN = P;

          // If we didn't find a 2-input PHI then this isn't a case we can
          // handle.
          if (!PN || PN->getNumIncomingValues() != 2)
            continue;

          // Walk forward from this phi to see if it reaches back to this add.
          if (!isReachableFromPHI(PN, BO))
            continue;

          // The phi forms a loop with this Add, push its operands.
          append_range(Worklist, BO->operands());
        }
      }
    }

    // Not an add or phi, make it a leaf.
    if (auto *I = dyn_cast<Instruction>(V)) {
      if (!V->hasNUses(I == Root ? 2 : 1))
        continue;

      // Add this as a leaf.
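      // The caller will try to rewrite each leaf with pmaddwd or psadbw.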
      Leaves.push_back(I);
    }
  }
}

bool X86PartialReduction::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
  if (!TPC)
    return false;

  auto &TM = TPC->getTM<X86TargetMachine>();
  ST = TM.getSubtargetImpl(F);

  DL = &F.getDataLayout();

  bool MadeChange = false;
  for (auto &BB : F) {
    for (auto &I : BB) {
      auto *EE = dyn_cast<ExtractElementInst>(&I);
      if (!EE)
        continue;

      bool ReduceInOneBB;
      // First find a reduction tree.
      // FIXME: Do we need to handle opcodes other than Add?
      Value *Root = matchAddReduction(*EE, ReduceInOneBB);
      if (!Root)
        continue;

      SmallVector<Instruction *, 8> Leaves;
      collectLeaves(Root, Leaves);

      for (Instruction *I : Leaves) {
        if (tryMAddReplacement(I, ReduceInOneBB)) {
          MadeChange = true;
          continue;
        }

        // Don't do SAD matching on the root node. SelectionDAG already
        // has support for that and currently generates better code.
        if (I != Root && trySADReplacement(I))
          MadeChange = true;
      }
    }
  }

  return MadeChange;
}