1 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines some vectorizer utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_ANALYSIS_VECTORUTILS_H 14 #define LLVM_ANALYSIS_VECTORUTILS_H 15 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/SmallVector.h" 18 #include "llvm/Analysis/LoopAccessAnalysis.h" 19 #include "llvm/IR/Module.h" 20 #include "llvm/IR/VFABIDemangler.h" 21 #include "llvm/Support/CheckedArithmetic.h" 22 23 namespace llvm { 24 class TargetLibraryInfo; 25 26 /// The Vector Function Database. 27 /// 28 /// Helper class used to find the vector functions associated to a 29 /// scalar CallInst. 30 class VFDatabase { 31 /// The Module of the CallInst CI. 32 const Module *M; 33 /// The CallInst instance being queried for scalar to vector mappings. 34 const CallInst &CI; 35 /// List of vector functions descriptors associated to the call 36 /// instruction. 37 const SmallVector<VFInfo, 8> ScalarToVectorMappings; 38 39 /// Retrieve the scalar-to-vector mappings associated to the rule of 40 /// a vector Function ABI. 41 static void getVFABIMappings(const CallInst &CI, 42 SmallVectorImpl<VFInfo> &Mappings) { 43 if (!CI.getCalledFunction()) 44 return; 45 46 const StringRef ScalarName = CI.getCalledFunction()->getName(); 47 48 SmallVector<std::string, 8> ListOfStrings; 49 // The check for the vector-function-abi-variant attribute is done when 50 // retrieving the vector variant names here. 
51 VFABI::getVectorVariantNames(CI, ListOfStrings); 52 if (ListOfStrings.empty()) 53 return; 54 for (const auto &MangledName : ListOfStrings) { 55 const std::optional<VFInfo> Shape = 56 VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType()); 57 // A match is found via scalar and vector names, and also by 58 // ensuring that the variant described in the attribute has a 59 // corresponding definition or declaration of the vector 60 // function in the Module M. 61 if (Shape && (Shape->ScalarName == ScalarName)) { 62 assert(CI.getModule()->getFunction(Shape->VectorName) && 63 "Vector function is missing."); 64 Mappings.push_back(*Shape); 65 } 66 } 67 } 68 69 public: 70 /// Retrieve all the VFInfo instances associated to the CallInst CI. 71 static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) { 72 SmallVector<VFInfo, 8> Ret; 73 74 // Get mappings from the Vector Function ABI variants. 75 getVFABIMappings(CI, Ret); 76 77 // Other non-VFABI variants should be retrieved here. 78 79 return Ret; 80 } 81 82 static bool hasMaskedVariant(const CallInst &CI, 83 std::optional<ElementCount> VF = std::nullopt) { 84 // Check whether we have at least one masked vector version of a scalar 85 // function. If no VF is specified then we check for any masked variant, 86 // otherwise we look for one that matches the supplied VF. 87 auto Mappings = VFDatabase::getMappings(CI); 88 for (VFInfo Info : Mappings) 89 if (!VF || Info.Shape.VF == *VF) 90 if (Info.isMasked()) 91 return true; 92 93 return false; 94 } 95 96 /// Constructor, requires a CallInst instance. 97 VFDatabase(CallInst &CI) 98 : M(CI.getModule()), CI(CI), 99 ScalarToVectorMappings(VFDatabase::getMappings(CI)) {} 100 101 /// \defgroup VFDatabase query interface. 102 /// 103 /// @{ 104 /// Retrieve the Function with VFShape \p Shape. 
105 Function *getVectorizedFunction(const VFShape &Shape) const { 106 if (Shape == VFShape::getScalarShape(CI.getFunctionType())) 107 return CI.getCalledFunction(); 108 109 for (const auto &Info : ScalarToVectorMappings) 110 if (Info.Shape == Shape) 111 return M->getFunction(Info.VectorName); 112 113 return nullptr; 114 } 115 /// @} 116 }; 117 118 template <typename T> class ArrayRef; 119 class DemandedBits; 120 template <typename InstTy> class InterleaveGroup; 121 class IRBuilderBase; 122 class Loop; 123 class ScalarEvolution; 124 class TargetTransformInfo; 125 class Type; 126 class Value; 127 128 namespace Intrinsic { 129 typedef unsigned ID; 130 } 131 132 /// A helper function for converting Scalar types to vector types. If 133 /// the incoming type is void, we return void. If the EC represents a 134 /// scalar, we return the scalar type. 135 inline Type *ToVectorTy(Type *Scalar, ElementCount EC) { 136 if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar()) 137 return Scalar; 138 return VectorType::get(Scalar, EC); 139 } 140 141 inline Type *ToVectorTy(Type *Scalar, unsigned VF) { 142 return ToVectorTy(Scalar, ElementCount::getFixed(VF)); 143 } 144 145 /// Identify if the intrinsic is trivially vectorizable. 146 /// This method returns true if the intrinsic's argument types are all scalars 147 /// for the scalar form of the intrinsic and all vectors (or scalars handled by 148 /// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic. 149 bool isTriviallyVectorizable(Intrinsic::ID ID); 150 151 /// Identifies if the vector form of the intrinsic has a scalar operand. 152 bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, 153 unsigned ScalarOpdIdx); 154 155 /// Identifies if the vector form of the intrinsic is overloaded on the type of 156 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1. 
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);

/// Returns the intrinsic ID for a call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if no mapping is found, it returns not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then extracted
/// from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If \p Index is specified (i.e. it is not
/// the default -1), either every element of the vector is poisoned or the
/// element at that index is not poisoned and equal to every other non-poisoned
/// element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);

/// Transform a shuffle mask's output demanded element mask into demanded
/// element masks for the 2 operands, returns false if the mask isn't valid.
/// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
/// \p AllowUndefElts permits "-1" indices to be treated as undef.
bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
                            const APInt &DemandedElts, APInt &DemandedLHS,
                            APInt &DemandedRHS, bool AllowUndefElts = false);

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements. The result is written to \p ScaledMask.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// Attempt to narrow/widen the \p Mask shuffle mask to the \p NumDstElts target
/// width. Internally this will call narrowShuffleMaskElts/widenShuffleMaskElts.
/// This will assert unless NumDstElts is a multiple of Mask.size (or vice-versa).
/// Returns false on failure, and \p ScaledMask will be in an undefined state.
bool scaleShuffleMaskElts(unsigned NumDstElts, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
/// to get the shuffle mask with widest possible elements.
void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
                                  SmallVectorImpl<int> &ScaledMask);

/// Splits and processes shuffle mask depending on the number of input and
/// output registers. The function does 2 main things: 1) splits the
/// source/destination vectors into real registers; 2) performs the mask
/// analysis to identify which real registers are permuted. Then the function
/// processes the resulting per-register masks using the provided action items.
/// If no input register is defined, \p NoInputAction is used. If only 1 input
/// register is used, \p SingleInputAction is used; otherwise (two or more input
/// registers) \p ManyInputsAction is used to process the input registers and
/// masks.
/// \param Mask Original shuffle mask.
/// \param NumOfSrcRegs Number of source registers.
/// \param NumOfDestRegs Number of destination registers.
/// \param NumOfUsedRegs Number of actually used destination registers.
void processShuffleMasks(
    ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
    unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);

/// Compute the demanded elements mask of horizontal binary operations. A
/// horizontal operation combines two adjacent elements in a vector operand.
/// This function computes a mask for the elements that correspond to the first
/// operand of this horizontal combination. For example, for two vectors
/// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
/// elements X1, X3, Y1, and Y3. To get the other operands, simply shift the
/// result of this function to the left by 1.
///
/// \param VectorBitWidth the total bit width of the vector
/// \param DemandedElts the demanded elements mask for the operation
/// \param [out] DemandedLHS the demanded elements mask for the left operand
/// \param [out] DemandedRHS the demanded elements mask for the right operand
void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
                                         const APInt &DemandedElts,
                                         APInt &DemandedLHS,
                                         APInt &DemandedRHS);

/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g.
/// NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these lengthens and truncates when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function can do less
/// work by only looking at illegal types.
MapVector<Instruction*, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
                         DemandedBits &DB,
                         const TargetTransformInfo *TTI=nullptr);

/// Compute the union of two access-group lists.
///
/// If the resulting list contains just one access group, it is returned
/// directly. If the list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that \p Inst1 and \p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the resulting list contains just one access group, it is returned
/// directly. If the list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Propagate the metadata of the values in \p VL onto the instruction \p I.
///
/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group, MD_mmra].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);

/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose Interleaved-factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and begin
/// at \p Start. The mask contains \p NumInts integers and is padded with
/// \p NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

/// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
/// mask assuming both operands are identical. This assumes that the unary
/// shuffle will use elements from operand 0 (operand 1 will be unused).
llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
                                           unsigned NumElts);

/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into
/// a single large vector. The number of vectors should be greater than one, and
/// their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef. That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef. That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if any of the elements of this
/// predicate mask are known to be true or undef. That is, return true if at
/// least one lane can be assumed active.
bool maskContainsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt of bitwidth Y
/// with one bit per lane, set for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);

/// The group of interleaved loads/stores sharing the same stride and
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than the interleave factor, which is equal to the
/// absolute value of the access's stride.
///
/// E.g.
An interleaved load group of factor 4: 452 /// for (unsigned i = 0; i < 1024; i+=4) { 453 /// a = A[i]; // Member of index 0 454 /// b = A[i+1]; // Member of index 1 455 /// d = A[i+3]; // Member of index 3 456 /// ... 457 /// } 458 /// 459 /// An interleaved store group of factor 4: 460 /// for (unsigned i = 0; i < 1024; i+=4) { 461 /// ... 462 /// A[i] = a; // Member of index 0 463 /// A[i+1] = b; // Member of index 1 464 /// A[i+2] = c; // Member of index 2 465 /// A[i+3] = d; // Member of index 3 466 /// } 467 /// 468 /// Note: the interleaved load group could have gaps (missing members), but 469 /// the interleaved store group doesn't allow gaps. 470 template <typename InstTy> class InterleaveGroup { 471 public: 472 InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment) 473 : Factor(Factor), Reverse(Reverse), Alignment(Alignment), 474 InsertPos(nullptr) {} 475 476 InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment) 477 : Alignment(Alignment), InsertPos(Instr) { 478 Factor = std::abs(Stride); 479 assert(Factor > 1 && "Invalid interleave factor"); 480 481 Reverse = Stride < 0; 482 Members[0] = Instr; 483 } 484 485 bool isReverse() const { return Reverse; } 486 uint32_t getFactor() const { return Factor; } 487 Align getAlign() const { return Alignment; } 488 uint32_t getNumMembers() const { return Members.size(); } 489 490 /// Try to insert a new member \p Instr with index \p Index and 491 /// alignment \p NewAlign. The index is related to the leader and it could be 492 /// negative if it is the new leader. 493 /// 494 /// \returns false if the instruction doesn't belong to the group. 495 bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) { 496 // Make sure the key fits in an int32_t. 497 std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey); 498 if (!MaybeKey) 499 return false; 500 int32_t Key = *MaybeKey; 501 502 // Skip if the key is used for either the tombstone or empty special values. 
503 if (DenseMapInfo<int32_t>::getTombstoneKey() == Key || 504 DenseMapInfo<int32_t>::getEmptyKey() == Key) 505 return false; 506 507 // Skip if there is already a member with the same index. 508 if (Members.contains(Key)) 509 return false; 510 511 if (Key > LargestKey) { 512 // The largest index is always less than the interleave factor. 513 if (Index >= static_cast<int32_t>(Factor)) 514 return false; 515 516 LargestKey = Key; 517 } else if (Key < SmallestKey) { 518 519 // Make sure the largest index fits in an int32_t. 520 std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key); 521 if (!MaybeLargestIndex) 522 return false; 523 524 // The largest index is always less than the interleave factor. 525 if (*MaybeLargestIndex >= static_cast<int64_t>(Factor)) 526 return false; 527 528 SmallestKey = Key; 529 } 530 531 // It's always safe to select the minimum alignment. 532 Alignment = std::min(Alignment, NewAlign); 533 Members[Key] = Instr; 534 return true; 535 } 536 537 /// Get the member with the given index \p Index 538 /// 539 /// \returns nullptr if contains no such member. 540 InstTy *getMember(uint32_t Index) const { 541 int32_t Key = SmallestKey + Index; 542 return Members.lookup(Key); 543 } 544 545 /// Get the index for the given member. Unlike the key in the member 546 /// map, the index starts from 0. 547 uint32_t getIndex(const InstTy *Instr) const { 548 for (auto I : Members) { 549 if (I.second == Instr) 550 return I.first - SmallestKey; 551 } 552 553 llvm_unreachable("InterleaveGroup contains no such member"); 554 } 555 556 InstTy *getInsertPos() const { return InsertPos; } 557 void setInsertPos(InstTy *Inst) { InsertPos = Inst; } 558 559 /// Add metadata (e.g. alias info) from the instructions in this group to \p 560 /// NewInst. 561 /// 562 /// FIXME: this function currently does not add noalias metadata a'la 563 /// addNewMedata. To do that we need to compute the intersection of the 564 /// noalias info from all members. 
565 void addMetadata(InstTy *NewInst) const; 566 567 /// Returns true if this Group requires a scalar iteration to handle gaps. 568 bool requiresScalarEpilogue() const { 569 // If the last member of the Group exists, then a scalar epilog is not 570 // needed for this group. 571 if (getMember(getFactor() - 1)) 572 return false; 573 574 // We have a group with gaps. It therefore can't be a reversed access, 575 // because such groups get invalidated (TODO). 576 assert(!isReverse() && "Group should have been invalidated"); 577 578 // This is a group of loads, with gaps, and without a last-member 579 return true; 580 } 581 582 private: 583 uint32_t Factor; // Interleave Factor. 584 bool Reverse; 585 Align Alignment; 586 DenseMap<int32_t, InstTy *> Members; 587 int32_t SmallestKey = 0; 588 int32_t LargestKey = 0; 589 590 // To avoid breaking dependences, vectorized instructions of an interleave 591 // group should be inserted at either the first load or the last store in 592 // program order. 593 // 594 // E.g. %even = load i32 // Insert Position 595 // %add = add i32 %even // Use of %even 596 // %odd = load i32 597 // 598 // store i32 %even 599 // %odd = add i32 // Def of %odd 600 // store i32 %odd // Insert Position 601 InstTy *InsertPos; 602 }; 603 604 /// Drive the analysis of interleaved memory accesses in the loop. 605 /// 606 /// Use this class to analyze interleaved accesses only when we can vectorize 607 /// a loop. Otherwise it's meaningless to do analysis as the vectorization 608 /// on interleaved accesses is unsafe. 609 /// 610 /// The analysis collects interleave groups and records the relationships 611 /// between the member and the group in a map. 
612 class InterleavedAccessInfo { 613 public: 614 InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L, 615 DominatorTree *DT, LoopInfo *LI, 616 const LoopAccessInfo *LAI) 617 : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {} 618 619 ~InterleavedAccessInfo() { invalidateGroups(); } 620 621 /// Analyze the interleaved accesses and collect them in interleave 622 /// groups. Substitute symbolic strides using \p Strides. 623 /// Consider also predicated loads/stores in the analysis if 624 /// \p EnableMaskedInterleavedGroup is true. 625 void analyzeInterleaving(bool EnableMaskedInterleavedGroup); 626 627 /// Invalidate groups, e.g., in case all blocks in loop will be predicated 628 /// contrary to original assumption. Although we currently prevent group 629 /// formation for predicated accesses, we may be able to relax this limitation 630 /// in the future once we handle more complicated blocks. Returns true if any 631 /// groups were invalidated. 632 bool invalidateGroups() { 633 if (InterleaveGroups.empty()) { 634 assert( 635 !RequiresScalarEpilogue && 636 "RequiresScalarEpilog should not be set without interleave groups"); 637 return false; 638 } 639 640 InterleaveGroupMap.clear(); 641 for (auto *Ptr : InterleaveGroups) 642 delete Ptr; 643 InterleaveGroups.clear(); 644 RequiresScalarEpilogue = false; 645 return true; 646 } 647 648 /// Check if \p Instr belongs to any interleave group. 649 bool isInterleaved(Instruction *Instr) const { 650 return InterleaveGroupMap.contains(Instr); 651 } 652 653 /// Get the interleave group that \p Instr belongs to. 654 /// 655 /// \returns nullptr if doesn't have such group. 
656 InterleaveGroup<Instruction> * 657 getInterleaveGroup(const Instruction *Instr) const { 658 return InterleaveGroupMap.lookup(Instr); 659 } 660 661 iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>> 662 getInterleaveGroups() { 663 return make_range(InterleaveGroups.begin(), InterleaveGroups.end()); 664 } 665 666 /// Returns true if an interleaved group that may access memory 667 /// out-of-bounds requires a scalar epilogue iteration for correctness. 668 bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; } 669 670 /// Invalidate groups that require a scalar epilogue (due to gaps). This can 671 /// happen when optimizing for size forbids a scalar epilogue, and the gap 672 /// cannot be filtered by masking the load/store. 673 void invalidateGroupsRequiringScalarEpilogue(); 674 675 /// Returns true if we have any interleave groups. 676 bool hasGroups() const { return !InterleaveGroups.empty(); } 677 678 private: 679 /// A wrapper around ScalarEvolution, used to add runtime SCEV checks. 680 /// Simplifies SCEV expressions in the context of existing SCEV assumptions. 681 /// The interleaved access analysis can also add new predicates (for example 682 /// by versioning strides of pointers). 683 PredicatedScalarEvolution &PSE; 684 685 Loop *TheLoop; 686 DominatorTree *DT; 687 LoopInfo *LI; 688 const LoopAccessInfo *LAI; 689 690 /// True if the loop may contain non-reversed interleaved groups with 691 /// out-of-bounds accesses. We ensure we don't speculatively access memory 692 /// out-of-bounds by executing at least one scalar epilogue iteration. 693 bool RequiresScalarEpilogue = false; 694 695 /// Holds the relationships between the members and the interleave group. 696 DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap; 697 698 SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups; 699 700 /// Holds dependences among the memory accesses in the loop. 
It maps a source 701 /// access to a set of dependent sink accesses. 702 DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences; 703 704 /// The descriptor for a strided memory access. 705 struct StrideDescriptor { 706 StrideDescriptor() = default; 707 StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size, 708 Align Alignment) 709 : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {} 710 711 // The access's stride. It is negative for a reverse access. 712 int64_t Stride = 0; 713 714 // The scalar expression of this access. 715 const SCEV *Scev = nullptr; 716 717 // The size of the memory object. 718 uint64_t Size = 0; 719 720 // The alignment of this access. 721 Align Alignment; 722 }; 723 724 /// A type for holding instructions and their stride descriptors. 725 using StrideEntry = std::pair<Instruction *, StrideDescriptor>; 726 727 /// Create a new interleave group with the given instruction \p Instr, 728 /// stride \p Stride and alignment \p Align. 729 /// 730 /// \returns the newly created interleave group. 731 InterleaveGroup<Instruction> * 732 createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) { 733 assert(!InterleaveGroupMap.count(Instr) && 734 "Already in an interleaved access group"); 735 InterleaveGroupMap[Instr] = 736 new InterleaveGroup<Instruction>(Instr, Stride, Alignment); 737 InterleaveGroups.insert(InterleaveGroupMap[Instr]); 738 return InterleaveGroupMap[Instr]; 739 } 740 741 /// Release the group and remove all the relationships. 742 void releaseGroup(InterleaveGroup<Instruction> *Group) { 743 InterleaveGroups.erase(Group); 744 releaseGroupWithoutRemovingFromSet(Group); 745 } 746 747 /// Do everything necessary to release the group, apart from removing it from 748 /// the InterleaveGroups set. 
749 void releaseGroupWithoutRemovingFromSet(InterleaveGroup<Instruction> *Group) { 750 for (unsigned i = 0; i < Group->getFactor(); i++) 751 if (Instruction *Member = Group->getMember(i)) 752 InterleaveGroupMap.erase(Member); 753 754 delete Group; 755 } 756 757 /// Collect all the accesses with a constant stride in program order. 758 void collectConstStrideAccesses( 759 MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, 760 const DenseMap<Value *, const SCEV *> &Strides); 761 762 /// Returns true if \p Stride is allowed in an interleaved group. 763 static bool isStrided(int Stride); 764 765 /// Returns true if \p BB is a predicated block. 766 bool isPredicated(BasicBlock *BB) const { 767 return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); 768 } 769 770 /// Returns true if LoopAccessInfo can be used for dependence queries. 771 bool areDependencesValid() const { 772 return LAI && LAI->getDepChecker().getDependences(); 773 } 774 775 /// Returns true if memory accesses \p A and \p B can be reordered, if 776 /// necessary, when constructing interleaved groups. 777 /// 778 /// \p A must precede \p B in program order. We return false if reordering is 779 /// not necessary or is prevented because \p A and \p B may be dependent. 780 bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A, 781 StrideEntry *B) const { 782 // Code motion for interleaved accesses can potentially hoist strided loads 783 // and sink strided stores. The code below checks the legality of the 784 // following two conditions: 785 // 786 // 1. Potentially moving a strided load (B) before any store (A) that 787 // precedes B, or 788 // 789 // 2. Potentially moving a strided store (A) after any load or store (B) 790 // that A precedes. 791 // 792 // It's legal to reorder A and B if we know there isn't a dependence from A 793 // to B. Note that this determination is conservative since some 794 // dependences could potentially be reordered safely. 
795 796 // A is potentially the source of a dependence. 797 auto *Src = A->first; 798 auto SrcDes = A->second; 799 800 // B is potentially the sink of a dependence. 801 auto *Sink = B->first; 802 auto SinkDes = B->second; 803 804 // Code motion for interleaved accesses can't violate WAR dependences. 805 // Thus, reordering is legal if the source isn't a write. 806 if (!Src->mayWriteToMemory()) 807 return true; 808 809 // At least one of the accesses must be strided. 810 if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride)) 811 return true; 812 813 // If dependence information is not available from LoopAccessInfo, 814 // conservatively assume the instructions can't be reordered. 815 if (!areDependencesValid()) 816 return false; 817 818 // If we know there is a dependence from source to sink, assume the 819 // instructions can't be reordered. Otherwise, reordering is legal. 820 return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink); 821 } 822 823 /// Collect the dependences from LoopAccessInfo. 824 /// 825 /// We process the dependences once during the interleaved access analysis to 826 /// enable constant-time dependence queries. 827 void collectDependences() { 828 if (!areDependencesValid()) 829 return; 830 const auto &DepChecker = LAI->getDepChecker(); 831 auto *Deps = DepChecker.getDependences(); 832 for (auto Dep : *Deps) 833 Dependences[Dep.getSource(DepChecker)].insert( 834 Dep.getDestination(DepChecker)); 835 } 836 }; 837 838 } // llvm namespace 839 840 #endif 841