//===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for vector predication intrinsics, allowing
// targets to enable vector predication until just before codegen.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ExpandVectorPredication.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include <optional>

using namespace llvm;

using VPLegalization = TargetTransformInfo::VPLegalization;
using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;

// Keep this in sync with TargetTransformInfo::VPLegalization.
#define VPINTERNAL_VPLEGAL_CASES                                               \
  VPINTERNAL_CASE(Legal)                                                       \
  VPINTERNAL_CASE(Discard)                                                     \
  VPINTERNAL_CASE(Convert)

#define VPINTERNAL_CASE(X) "|" #X

// Override options.
static cl::opt<std::string> EVLTransformOverride(
    "expandvp-override-evl-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore TargetTransformInfo and always use this "
             "transformation for the %evl parameter (Used in testing)."));

static cl::opt<std::string> MaskTransformOverride(
    "expandvp-override-mask-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore TargetTransformInfo and always use this "
             "transformation for the %mask parameter (Used in testing)."));

#undef VPINTERNAL_CASE
#define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)

static VPTransform parseOverrideOption(const std::string &TextOpt) {
  return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;
}

#undef VPINTERNAL_VPLEGAL_CASES

// Whether any override options are set.
static bool anyExpandVPOverridesSet() {
  return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
}

#define DEBUG_TYPE "expandvp"

STATISTIC(NumFoldedVL, "Number of folded vector length params");
STATISTIC(NumLoweredVPOps, "Number of lowered vector predication operations");

///// Helpers {

/// \returns Whether the vector mask \p MaskVal has all lane bits set.
static bool isAllTrueMask(Value *MaskVal) {
  if (Value *SplattedVal = getSplatValue(MaskVal))
    if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
      return ConstValue->isAllOnesValue();

  return false;
}
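
// For illustration: isAllTrueMask only recognizes masks that are a known
// constant all-ones splat, e.g. <4 x i1> <i1 true, i1 true, i1 true, i1 true>.
// Any other mask, even one that happens to be all-true at runtime, is
// conservatively treated as predicated.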

/// \returns A non-excepting divisor constant for this type.
static Constant *getSafeDivisor(Type *DivTy) {
  assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");
  return ConstantInt::get(DivTy, 1u, false);
}

/// Transfer operation properties from \p VPI to \p NewVal.
static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
  auto *NewInst = dyn_cast<Instruction>(&NewVal);
  if (!NewInst || !isa<FPMathOperator>(NewVal))
    return;

  auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
  if (!OldFMOp)
    return;

  NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
}

/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
/// \p OldOp gets erased.
static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
  transferDecorations(NewOp, OldOp);
  OldOp.replaceAllUsesWith(&NewOp);
  OldOp.eraseFromParent();
}

static bool maySpeculateLanes(VPIntrinsic &VPI) {
  // The result of VP reductions depends on the mask and evl.
  if (isa<VPReductionIntrinsic>(VPI))
    return false;
  // Fallback to whether the intrinsic is speculatable.
  if (auto IntrID = VPI.getFunctionalIntrinsicID())
    return Intrinsic::getAttributes(VPI.getContext(), *IntrID)
        .hasFnAttr(Attribute::AttrKind::Speculatable);
  if (auto Opc = VPI.getFunctionalOpcode())
    return isSafeToSpeculativelyExecuteWithOpcode(*Opc, &VPI);
  return false;
}
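
// For illustration: lanes of vp.add may be speculated, since integer adds
// have no side effects, so its predicate can be dropped outright. Lanes of
// vp.udiv may not, because executing a masked-off lane with a zero divisor
// would introduce undefined behavior.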

//// } Helpers

namespace {

// Expansion pass state at function scope.
struct CachingVPExpander {
  Function &F;
  const TargetTransformInfo &TTI;

  /// \returns A (fixed length) vector with ascending integer indices
  /// (<0, 1, ..., NumElems-1>).
  /// \p Builder
  ///    Used for instruction creation.
  /// \p LaneTy
  ///    Integer element type of the result vector.
  /// \p NumElems
  ///    Number of vector elements.
  Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                          unsigned NumElems);

  /// \returns A bitmask that is true where the lane position is less-than \p
  /// EVLParam
  ///
  /// \p Builder
  ///    Used for instruction creation.
  /// \p EVLParam
  ///    The explicit vector length parameter to test against the lane
  ///    positions.
  /// \p ElemCount
  ///    Static (potentially scalable) number of vector elements.
  Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
                          ElementCount ElemCount);

  Value *foldEVLIntoMask(VPIntrinsic &VPI);

  /// "Remove" the %evl parameter of \p PI by setting it to the static vector
  /// length of the operation.
  void discardEVLParameter(VPIntrinsic &PI);

  /// Lower this VP binary operator to an unpredicated binary operator.
  Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                           VPIntrinsic &PI);

  /// Lower this VP int call to an unpredicated int call.
  Value *expandPredicationToIntCall(IRBuilder<> &Builder, VPIntrinsic &PI,
                                    unsigned UnpredicatedIntrinsicID);

  /// Lower this VP fp call to an unpredicated fp call.
  Value *expandPredicationToFPCall(IRBuilder<> &Builder, VPIntrinsic &PI,
                                   unsigned UnpredicatedIntrinsicID);

  /// Lower this VP reduction to a call to an unpredicated reduction intrinsic.
  Value *expandPredicationInReduction(IRBuilder<> &Builder,
                                      VPReductionIntrinsic &PI);

  /// Lower this VP cast operation to an unpredicated cast instruction.
  Value *expandPredicationToCastIntrinsic(IRBuilder<> &Builder,
                                          VPIntrinsic &VPI);

  /// Lower this VP memory operation to an unpredicated or masked memory
  /// instruction.
  Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                            VPIntrinsic &VPI);

  /// Lower this VP comparison to a call to an unpredicated comparison.
  Value *expandPredicationInComparison(IRBuilder<> &Builder,
                                       VPCmpIntrinsic &PI);

  /// Query TTI and expand the vector predication in \p PI accordingly.
  Value *expandPredication(VPIntrinsic &PI);

  /// Determine how and whether the VPIntrinsic \p VPI shall be expanded. This
  /// overrides TTI with the cl::opts listed at the top of this file.
  VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
  bool UsingTTIOverrides;

public:
  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
      : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}

  bool expandVectorPredication();
};

//// CachingVPExpander {

Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                                           unsigned NumElems) {
  // TODO add caching
  SmallVector<Constant *, 16> ConstElems;

  for (unsigned Idx = 0; Idx < NumElems; ++Idx)
    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));

  return ConstantVector::get(ConstElems);
}

Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
                                           Value *EVLParam,
                                           ElementCount ElemCount) {
  // TODO add caching
  // Scalable vector %evl conversion.
  if (ElemCount.isScalable()) {
    auto *M = Builder.GetInsertBlock()->getModule();
    Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
    Function *ActiveMaskFunc = Intrinsic::getDeclaration(
        M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
    // `get_active_lane_mask` performs an implicit less-than comparison.
    Value *ConstZero = Builder.getInt32(0);
    return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
  }

  // Fixed vector %evl conversion.
  Type *LaneTy = EVLParam->getType();
  unsigned NumElems = ElemCount.getFixedValue();
  Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);
  Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);
  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
}
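
// For illustration (hypothetical values): for a fixed <8 x i32> operation
// with %evl = 5, the fixed-vector path above emits
//   %mask = icmp ult <8 x i32> <i32 0, ..., i32 7>, (splat of %evl)
// which sets exactly the first five lanes.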

Value *
CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                                     VPIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  auto OC = static_cast<Instruction::BinaryOps>(*VPI.getFunctionalOpcode());
  assert(Instruction::isBinaryOp(OC));

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  Value *Mask = VPI.getMaskParam();

  // Blend in safe operands.
  if (Mask && !isAllTrueMask(Mask)) {
    switch (OC) {
    default:
      // Can safely ignore the predicate.
      break;

    // Division operators need a safe divisor on masked-off lanes (1).
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::URem:
    case Instruction::SRem:
      // 2nd operand must not be zero.
      Value *SafeDivisor = getSafeDivisor(VPI.getType());
      Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);
    }
  }

  Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());

  replaceOperation(*NewBinOp, VPI);
  return NewBinOp;
}

Value *CachingVPExpander::expandPredicationToIntCall(
    IRBuilder<> &Builder, VPIntrinsic &VPI, unsigned UnpredicatedIntrinsicID) {
  switch (UnpredicatedIntrinsicID) {
  case Intrinsic::abs:
  case Intrinsic::smax:
  case Intrinsic::smin:
  case Intrinsic::umax:
  case Intrinsic::umin: {
    Value *Op0 = VPI.getOperand(0);
    Value *Op1 = VPI.getOperand(1);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op0, Op1}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  case Intrinsic::bswap:
  case Intrinsic::bitreverse: {
    Value *Op = VPI.getOperand(0);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  }
  return nullptr;
}

Value *CachingVPExpander::expandPredicationToFPCall(
    IRBuilder<> &Builder, VPIntrinsic &VPI, unsigned UnpredicatedIntrinsicID) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  switch (UnpredicatedIntrinsicID) {
  case Intrinsic::fabs:
  case Intrinsic::sqrt: {
    Value *Op0 = VPI.getOperand(0);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op0}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  case Intrinsic::maxnum:
  case Intrinsic::minnum:
  // vp.maximum/vp.minimum also dispatch here; handle their functional
  // intrinsics so they do not silently fail to expand.
  case Intrinsic::maximum:
  case Intrinsic::minimum: {
    Value *Op0 = VPI.getOperand(0);
    Value *Op1 = VPI.getOperand(1);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op0, Op1}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  case Intrinsic::fma:
  case Intrinsic::fmuladd:
  case Intrinsic::experimental_constrained_fma:
  case Intrinsic::experimental_constrained_fmuladd: {
    Value *Op0 = VPI.getOperand(0);
    Value *Op1 = VPI.getOperand(1);
    Value *Op2 = VPI.getOperand(2);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp;
    if (Intrinsic::isConstrainedFPIntrinsic(UnpredicatedIntrinsicID))
      NewOp =
          Builder.CreateConstrainedFPCall(Fn, {Op0, Op1, Op2}, VPI.getName());
    else
      NewOp = Builder.CreateCall(Fn, {Op0, Op1, Op2}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  }

  return nullptr;
}
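
// For illustration (hypothetical values): llvm.sqrt is speculatable, so
// expandPredicationToFPCall above turns
//   %r = call <4 x float> @llvm.vp.sqrt.v4f32(<4 x float> %x, <4 x i1> %m,
//                                             i32 %evl)
// into
//   %r = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
// with %m and %evl dropped.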

static Value *getNeutralReductionElement(const VPReductionIntrinsic &VPI,
                                         Type *EltTy) {
  bool Negative = false;
  unsigned EltBits = EltTy->getScalarSizeInBits();
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Expecting a VP reduction intrinsic");
  case Intrinsic::vp_reduce_add:
  case Intrinsic::vp_reduce_or:
  case Intrinsic::vp_reduce_xor:
  case Intrinsic::vp_reduce_umax:
    return Constant::getNullValue(EltTy);
  case Intrinsic::vp_reduce_mul:
    return ConstantInt::get(EltTy, 1, /*IsSigned*/ false);
  case Intrinsic::vp_reduce_and:
  case Intrinsic::vp_reduce_umin:
    return ConstantInt::getAllOnesValue(EltTy);
  case Intrinsic::vp_reduce_smin:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMaxValue(EltBits));
  case Intrinsic::vp_reduce_smax:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMinValue(EltBits));
  case Intrinsic::vp_reduce_fmax:
    Negative = true;
    [[fallthrough]];
  case Intrinsic::vp_reduce_fmin: {
    FastMathFlags Flags = VPI.getFastMathFlags();
    const fltSemantics &Semantics = EltTy->getFltSemantics();
    return !Flags.noNaNs() ? ConstantFP::getQNaN(EltTy, Negative)
           : !Flags.noInfs()
               ? ConstantFP::getInfinity(EltTy, Negative)
               : ConstantFP::get(EltTy,
                                 APFloat::getLargest(Semantics, Negative));
  }
  case Intrinsic::vp_reduce_fadd:
    return ConstantFP::getNegativeZero(EltTy);
  case Intrinsic::vp_reduce_fmul:
    return ConstantFP::get(EltTy, 1.0);
  }
}
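
// For illustration (hypothetical values): a masked vp.reduce.add is expanded
// below by blending the neutral element (0 for add) into masked-off lanes:
//   %safe = select <4 x i1> %m, <4 x i32> %v, <4 x i32> zeroinitializer
//   %sum  = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %safe)
//   %r    = add i32 %sum, %start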

Value *
CachingVPExpander::expandPredicationInReduction(IRBuilder<> &Builder,
                                                VPReductionIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  Value *Mask = VPI.getMaskParam();
  Value *RedOp = VPI.getOperand(VPI.getVectorParamPos());

  // Insert neutral element in masked-out positions.
  if (Mask && !isAllTrueMask(Mask)) {
    auto *NeutralElt = getNeutralReductionElement(VPI, VPI.getType());
    auto *NeutralVector = Builder.CreateVectorSplat(
        cast<VectorType>(RedOp->getType())->getElementCount(), NeutralElt);
    RedOp = Builder.CreateSelect(Mask, RedOp, NeutralVector);
  }

  Value *Reduction;
  Value *Start = VPI.getOperand(VPI.getStartParamPos());

  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Impossible reduction kind");
  case Intrinsic::vp_reduce_add:
    Reduction = Builder.CreateAddReduce(RedOp);
    Reduction = Builder.CreateAdd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_mul:
    Reduction = Builder.CreateMulReduce(RedOp);
    Reduction = Builder.CreateMul(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_and:
    Reduction = Builder.CreateAndReduce(RedOp);
    Reduction = Builder.CreateAnd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_or:
    Reduction = Builder.CreateOrReduce(RedOp);
    Reduction = Builder.CreateOr(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_xor:
    Reduction = Builder.CreateXorReduce(RedOp);
    Reduction = Builder.CreateXor(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_smax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_smin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmax:
    Reduction = Builder.CreateFPMaxReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmin:
    Reduction = Builder.CreateFPMinReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::minnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fadd:
    Reduction = Builder.CreateFAddReduce(Start, RedOp);
    break;
  case Intrinsic::vp_reduce_fmul:
    Reduction = Builder.CreateFMulReduce(Start, RedOp);
    break;
  }

  replaceOperation(*Reduction, VPI);
  return Reduction;
}

Value *CachingVPExpander::expandPredicationToCastIntrinsic(IRBuilder<> &Builder,
                                                           VPIntrinsic &VPI) {
  Value *CastOp = nullptr;
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Not a VP cast intrinsic");
  case Intrinsic::vp_sext:
    CastOp =
        Builder.CreateSExt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_zext:
    CastOp =
        Builder.CreateZExt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_trunc:
    CastOp =
        Builder.CreateTrunc(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_inttoptr:
    CastOp =
        Builder.CreateIntToPtr(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_ptrtoint:
    CastOp =
        Builder.CreatePtrToInt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fptosi:
    CastOp =
        Builder.CreateFPToSI(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fptoui:
    CastOp =
        Builder.CreateFPToUI(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_sitofp:
    CastOp =
        Builder.CreateSIToFP(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_uitofp:
    CastOp =
        Builder.CreateUIToFP(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fptrunc:
    CastOp =
        Builder.CreateFPTrunc(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fpext:
    CastOp =
        Builder.CreateFPExt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  }
  replaceOperation(*CastOp, VPI);
  return CastOp;
}
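
// For illustration (hypothetical values): VP casts lower to plain cast
// instructions, e.g.
//   %r = call <4 x i64> @llvm.vp.sext.v4i64.v4i32(<4 x i32> %x, <4 x i1> %m,
//                                                 i32 %evl)
// becomes
//   %r = sext <4 x i32> %x to <4 x i64>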

Value *
CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                                      VPIntrinsic &VPI) {
  assert(VPI.canIgnoreVectorLengthParam());

  const auto &DL = F.getParent()->getDataLayout();

  Value *MaskParam = VPI.getMaskParam();
  Value *PtrParam = VPI.getMemoryPointerParam();
  Value *DataParam = VPI.getMemoryDataParam();
  bool IsUnmasked = isAllTrueMask(MaskParam);

  MaybeAlign AlignOpt = VPI.getPointerAlignment();

  Value *NewMemoryInst = nullptr;
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Not a VP memory intrinsic");
  case Intrinsic::vp_store:
    if (IsUnmasked) {
      StoreInst *NewStore =
          Builder.CreateStore(DataParam, PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewStore->setAlignment(*AlignOpt);
      NewMemoryInst = NewStore;
    } else
      NewMemoryInst = Builder.CreateMaskedStore(
          DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam);

    break;
  case Intrinsic::vp_load:
    if (IsUnmasked) {
      LoadInst *NewLoad =
          Builder.CreateLoad(VPI.getType(), PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewLoad->setAlignment(*AlignOpt);
      NewMemoryInst = NewLoad;
    } else
      NewMemoryInst = Builder.CreateMaskedLoad(
          VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam);

    break;
  case Intrinsic::vp_scatter: {
    auto *ElementType =
        cast<VectorType>(DataParam->getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedScatter(
        DataParam, PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam);
    break;
  }
  case Intrinsic::vp_gather: {
    auto *ElementType = cast<VectorType>(VPI.getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedGather(
        VPI.getType(), PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam, nullptr,
        VPI.getName());
    break;
  }
  }

  assert(NewMemoryInst);
  replaceOperation(*NewMemoryInst, VPI);
  return NewMemoryInst;
}

Value *CachingVPExpander::expandPredicationInComparison(IRBuilder<> &Builder,
                                                        VPCmpIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  assert(*VPI.getFunctionalOpcode() == Instruction::ICmp ||
         *VPI.getFunctionalOpcode() == Instruction::FCmp);

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  auto Pred = VPI.getPredicate();

  auto *NewCmp = Builder.CreateCmp(Pred, Op0, Op1);

  replaceOperation(*NewCmp, VPI);
  return NewCmp;
}

void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");

  if (VPI.canIgnoreVectorLengthParam())
    return;

  Value *EVLParam = VPI.getVectorLengthParam();
  if (!EVLParam)
    return;

  ElementCount StaticElemCount = VPI.getStaticVectorLength();
  Value *MaxEVL = nullptr;
  Type *Int32Ty = Type::getInt32Ty(VPI.getContext());
  if (StaticElemCount.isScalable()) {
    // TODO add caching
    auto *M = VPI.getModule();
    Function *VScaleFunc =
        Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
    IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
    Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
    Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
    MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
                               /*NUW*/ true, /*NSW*/ false);
  } else {
    MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
  }
  VPI.setVectorLengthParam(MaxEVL);
}
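
// For illustration: for a <vscale x 2 x i32> operation, discardEVLParameter
// above replaces %evl with the static vector length
//   %vscale        = call i32 @llvm.vscale.i32()
//   %scalable_size = mul nuw i32 %vscale, 2
// which renders the %evl parameter ineffective.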

Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // Ineffective %evl parameter and so nothing to do here.
  if (VPI.canIgnoreVectorLengthParam())
    return &VPI;

  // Only VP intrinsics can have an %evl parameter.
  Value *OldMaskParam = VPI.getMaskParam();
  Value *OldEVLParam = VPI.getVectorLengthParam();
  assert(OldMaskParam && "no mask param to fold the vl param into");
  assert(OldEVLParam && "no EVL param to fold away");

  LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
  LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');

  // Convert the %evl predication into vector mask predication.
  ElementCount ElemCount = VPI.getStaticVectorLength();
  Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
  Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
  VPI.setMaskParam(NewMaskParam);

  // Drop the %evl parameter.
  discardEVLParameter(VPI);
  assert(VPI.canIgnoreVectorLengthParam() &&
         "transformation did not render the evl param ineffective!");

  // Reassess the modified instruction.
  return &VPI;
}

Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // Try lowering to an LLVM instruction first.
  auto OC = VPI.getFunctionalOpcode();

  if (OC && Instruction::isBinaryOp(*OC))
    return expandPredicationInBinaryOperator(Builder, VPI);

  if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))
    return expandPredicationInReduction(Builder, *VPRI);

  if (auto *VPCmp = dyn_cast<VPCmpIntrinsic>(&VPI))
    return expandPredicationInComparison(Builder, *VPCmp);

  if (VPCastIntrinsic::isVPCast(VPI.getIntrinsicID()))
    return expandPredicationToCastIntrinsic(Builder, VPI);

  switch (VPI.getIntrinsicID()) {
  default:
    break;
  case Intrinsic::vp_fneg: {
    Value *NewNegOp = Builder.CreateFNeg(VPI.getOperand(0), VPI.getName());
    replaceOperation(*NewNegOp, VPI);
    return NewNegOp;
  }
  case Intrinsic::vp_abs:
  case Intrinsic::vp_smax:
  case Intrinsic::vp_smin:
  case Intrinsic::vp_umax:
  case Intrinsic::vp_umin:
  case Intrinsic::vp_bswap:
  case Intrinsic::vp_bitreverse:
    return expandPredicationToIntCall(Builder, VPI,
                                      VPI.getFunctionalIntrinsicID().value());
  case Intrinsic::vp_fabs:
  case Intrinsic::vp_sqrt:
  case Intrinsic::vp_maxnum:
  case Intrinsic::vp_minnum:
  case Intrinsic::vp_maximum:
  case Intrinsic::vp_minimum:
  case Intrinsic::vp_fma:
  case Intrinsic::vp_fmuladd:
    return expandPredicationToFPCall(Builder, VPI,
                                     VPI.getFunctionalIntrinsicID().value());
  case Intrinsic::vp_load:
  case Intrinsic::vp_store:
  case Intrinsic::vp_gather:
  case Intrinsic::vp_scatter:
    return expandPredicationInMemoryIntrinsic(Builder, VPI);
  }

  if (auto CID = VPI.getConstrainedIntrinsicID())
    if (Value *Call = expandPredicationToFPCall(Builder, VPI, *CID))
      return Call;

  return &VPI;
}

//// } CachingVPExpander

struct TransformJob {
  VPIntrinsic *PI;
  TargetTransformInfo::VPLegalization Strategy;
  TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
      : PI(PI), Strategy(InitStrat) {}

  bool isDone() const { return Strategy.shouldDoNothing(); }
};
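
// Note: legalization proceeds in two steps per intrinsic. First the %evl
// parameter is legalized (kept, discarded, or folded into %mask), then the
// operation itself is either kept as a VP intrinsic or converted to
// unpredicated IR. See expandVectorPredication below.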

void sanitizeStrategy(VPIntrinsic &VPI, VPLegalization &LegalizeStrat) {
  // Operations with speculatable lanes do not strictly need predication.
  if (maySpeculateLanes(VPI)) {
    // Converting a speculatable VP intrinsic means dropping %mask and %evl.
    // No need to expand %evl into the %mask only to ignore that code.
    if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
      LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
    return;
  }

  // We have to preserve the predicating effect of %evl for this
  // non-speculatable VP intrinsic.
  // 1) Never discard %evl.
  // 2) If this VP intrinsic will be expanded to non-VP code, make sure that
  //    %evl gets folded into %mask.
  if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||
      (LegalizeStrat.OpStrategy == VPLegalization::Convert)) {
    LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
  }
}

VPLegalization
CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
  auto VPStrat = TTI.getVPLegalizationStrategy(VPI);
  if (LLVM_LIKELY(!UsingTTIOverrides)) {
    // No overrides - we are in production.
    return VPStrat;
  }

  // Overrides set - we are in testing, the following does not need to be
  // efficient.
  VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
  VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
  return VPStrat;
}
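
// For illustration: in tests, passing
//   -expandvp-override-evl-transform=Convert
//   -expandvp-override-mask-transform=Convert
// forces the Convert strategy for every VP intrinsic, regardless of what TTI
// reports.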

/// Expand llvm.vp.* intrinsics as requested by \p TTI.
bool CachingVPExpander::expandVectorPredication() {
  SmallVector<TransformJob, 16> Worklist;

  // Collect all VPIntrinsics that need expansion and determine their expansion
  // strategy.
  for (auto &I : instructions(F)) {
    auto *VPI = dyn_cast<VPIntrinsic>(&I);
    if (!VPI)
      continue;
    auto VPStrat = getVPLegalizationStrategy(*VPI);
    sanitizeStrategy(*VPI, VPStrat);
    if (!VPStrat.shouldDoNothing())
      Worklist.emplace_back(VPI, VPStrat);
  }
  if (Worklist.empty())
    return false;

  // Transform all VPIntrinsics on the worklist.
  LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()
                    << " instructions ::::\n");
  for (TransformJob Job : Worklist) {
    // Transform the EVL parameter.
    switch (Job.Strategy.EVLParamStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      discardEVLParameter(*Job.PI);
      break;
    case VPLegalization::Convert:
      if (foldEVLIntoMask(*Job.PI))
        ++NumFoldedVL;
      break;
    }
    Job.Strategy.EVLParamStrategy = VPLegalization::Legal;

    // Replace with a non-predicated operation.
    switch (Job.Strategy.OpStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      llvm_unreachable("Invalid strategy for operators.");
    case VPLegalization::Convert:
      expandPredication(*Job.PI);
      ++NumLoweredVPOps;
      break;
    }
    Job.Strategy.OpStrategy = VPLegalization::Legal;

    assert(Job.isDone() && "incomplete transformation");
  }

  return true;
}

class ExpandVectorPredication : public FunctionPass {
public:
  static char ID;
  ExpandVectorPredication() : FunctionPass(ID) {
    initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    CachingVPExpander VPExpander(F, *TTI);
    return VPExpander.expandVectorPredication();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

char ExpandVectorPredication::ID;
INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",
                      "Expand vector predication intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",
                    "Expand vector predication intrinsics", false, false)

FunctionPass *llvm::createExpandVectorPredicationPass() {
  return new ExpandVectorPredication();
}

PreservedAnalyses
ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  CachingVPExpander VPExpander(F, TTI);
  if (!VPExpander.expandVectorPredication())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}